diff --git a/api/api_v1/endpoints/dashboard.py b/api/api_v1/endpoints/dashboard.py index ed0f40f..f27c8da 100644 --- a/api/api_v1/endpoints/dashboard.py +++ b/api/api_v1/endpoints/dashboard.py @@ -18,10 +18,10 @@ async def create( try: await crud.dashboard.create(db, data_in, user_id=current_user.id) except pymongo.errors.DuplicateKeyError: - return schemas.Msg(code=-1, msg='error', detail='看板已存在') + return schemas.Msg(code=-1, msg='error', data='看板已存在') # todo 建默认文件夹 - return schemas.Msg(code=0, msg='ok', detail='创建成功') + return schemas.Msg(code=0, msg='ok', data='创建成功') @router.post("/delete") @@ -34,8 +34,8 @@ async def delete( del_dashboard = await crud.dashboard.delete(db, _id=data_in.id, user_id=current_user.id) if del_dashboard.deleted_count == 0: - return schemas.Msg(code=-1, msg='error', detail='删除失败') - return schemas.Msg(code=0, msg='ok', detail='删除成功') + return schemas.Msg(code=-1, msg='error', data='删除失败') + return schemas.Msg(code=0, msg='ok', data='删除成功') @router.post("/move") @@ -49,8 +49,8 @@ async def delete( """ res = await crud.dashboard.update_one(db, id=data_in.source_id, cat=data_in.cat, pid=data_in.dest_id) if res.deleted_count == 0: - return schemas.Msg(code=-1, msg='error', detail='删除失败') - return schemas.Msg(code=0, msg='ok', detail='删除成功') + return schemas.Msg(code=-1, msg='error', data='删除失败') + return schemas.Msg(code=0, msg='ok', data='删除成功') @router.post("/add_report") @@ -59,7 +59,7 @@ async def add_report(data_in: schemas.AddReport, current_user: schemas.UserDB = Depends(deps.get_current_user) ): res = await crud.dashboard.update_one(db, id=data_in.id, **{'$push': {'reports': {'$each': data_in.report_ids}}}) - return schemas.Msg(code=0, msg='ok', detail='ok') + return schemas.Msg(code=0, msg='ok', data='ok') @router.post("/del_report") @@ -70,7 +70,7 @@ async def add_report(data_in: schemas.DelReport, """删除报表""" for item in data_in.report_ids: await crud.dashboard.update_one(db, id=data_in.id, **{'$pull': {'reports': item}}) - return schemas.Msg(code=0, msg='ok', detail='ok') + return schemas.Msg(code=0, msg='ok', data='ok') @router.get("/") @@ -81,4 +81,4 @@ async def add_report(_id: str, """获取一个看板""" res = await crud.dashboard.get(db, id=_id) res['reports'] = await crud.report.find_many(db, **{'$in': {'_id': res.get('reports')}}) - return schemas.Msg(code=0, msg='ok', detail=res['reports']) + return schemas.Msg(code=0, msg='ok', data=res['reports']) diff --git a/api/api_v1/endpoints/folder.py b/api/api_v1/endpoints/folder.py index dc57968..6cfe197 100644 --- a/api/api_v1/endpoints/folder.py +++ b/api/api_v1/endpoints/folder.py @@ -19,10 +19,10 @@ async def create( try: await crud.folder.create(db, data_in, user_id=current_user.id) except pymongo.errors.DuplicateKeyError: - return schemas.Msg(code=-1, msg='error', detail='文件夹已存在') + return schemas.Msg(code=-1, msg='error', data='文件夹已存在') # todo 建默认文件夹 - return schemas.Msg(code=0, msg='ok', detail='创建成功') + return schemas.Msg(code=0, msg='ok', data='创建成功') @router.post("/delete") @@ -37,5 +37,5 @@ async def delete( # 删除文件夹下的 dashboard del_dashboard = await crud.dashboard.delete(db, pid=data_in.id) if del_folder.deleted_count == 0: - return schemas.Msg(code=-1, msg='error', detail='删除失败') - return schemas.Msg(code=0, msg='ok', detail='删除成功') + return schemas.Msg(code=-1, msg='error', data='删除失败') + return schemas.Msg(code=0, msg='ok', data='删除成功') diff --git a/api/api_v1/endpoints/project.py b/api/api_v1/endpoints/project.py index 609a6c2..ffded54 100644 --- a/api/api_v1/endpoints/project.py +++ b/api/api_v1/endpoints/project.py @@ -1,5 +1,5 @@ import pymongo -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from motor.motor_asyncio import AsyncIOMotorDatabase import crud, schemas from api import deps @@ -20,20 +20,22 @@ async def create( try: await crud.project.create(db, data_in, user_id=current_user.id) except pymongo.errors.DuplicateKeyError: - return schemas.Msg(code=-1, msg='error', detail='项目名已存在') + return schemas.Msg(code=-1, msg='项目名已存在', data='项目名已存在') # todo 建默认文件夹 + # schemas.FolderCreate + # await crud.folder.create(db, data_in, user_id=current_user.id) - return schemas.Msg(code=0, msg='ok', detail='创建成功') + return schemas.Msg(code=0, msg='创建成功') @router.get("/") -async def read_project( - db: AsyncIOMotorDatabase = Depends(get_database), - current_user: schemas.UserDB = Depends(deps.get_current_user) -): +async def read_project(request: Request, + db: AsyncIOMotorDatabase = Depends(get_database), + current_user: schemas.UserDB = Depends(deps.get_current_user) + ): """查看自己拥有的项目""" res = await crud.project.read_project(db, user_id=current_user.id) - return res + return schemas.Msg(code=0, msg='ok', data=res) @router.post("/kanban") @@ -43,19 +45,19 @@ async def read_kanban( current_user: schemas.UserDB = Depends(deps.get_current_user) ): """获取自己的看板""" - res = {'kanban': [], 'space': []} + res = {'kanban': [], 'spaces': []} # 我的看板 kanban = await crud.folder.read_folder(db, project_id=data_in.id, user_id=current_user.id, cat='kanban') for item in kanban: dashboards = await crud.dashboard.find_many(db, pid=item['_id']) res['kanban'].append({ - 'folder_name': item['name'], - 'dashboards': [], + 'name': item['name'], + 'children': [], '_id': item['_id'] }) for d in dashboards: - res['kanban'][-1]['dashboards'].append({ + res['kanban'][-1]['children'].append({ 'name': d['name'], '_id': item['_id'] }) @@ -68,32 +70,33 @@ async def read_kanban( spaces = await crud.space.find_many(db, **where) # 空间 文件夹 看板 for item in spaces: - res['space'].append({ - 'space_name': item['name'], - 'folders': [], - 'dashboards': [], + res['spaces'].append({ + 'name': item['name'], + 'children': [], '_id': item['_id'] }) - res['space'][-1]['authority'] = 'rw' if current_user.id in item['rw_members'] else 'r' + res['spaces'][-1]['authority'] = 'rw' if current_user.id in item['rw_members'] else 'r' for f in await crud.folder.find_many(db, pid=item['_id']): - res['space'][-1]['folders'].append({ + res['spaces'][-1]['children'].append({ 'name': f['name'], '_id': f['_id'], - 'dashboards': [], + 'children': [], + 'isFolder': True }) for d in await crud.dashboard.find_many(db, pid=f['_id']): - res['space'][-1]['folders'][-1]['dashboards'].append({ + res['spaces'][-1]['children'][-1]['children'].append({ 'name': d['name'], '_id': d['_id'] }) # 空间 看板 for d in await crud.dashboard.find_many(db, pid=item['_id']): - res['space'][-1]['dashboards'].append({ + res['spaces'][-1]['children'].append({ 'name': d['name'], - '_id': d['_id'] + '_id': d['_id'], + 'isFolder': False }) - return res + return schemas.Msg(code=0, msg='ok', data=res) diff --git a/api/api_v1/endpoints/report.py b/api/api_v1/endpoints/report.py index 342a530..6e9b706 100644 --- a/api/api_v1/endpoints/report.py +++ b/api/api_v1/endpoints/report.py @@ -21,19 +21,20 @@ async def create( try: await crud.report.create(db, data_in, user_id=current_user.id) except pymongo.errors.DuplicateKeyError: - return schemas.Msg(code=-1, msg='error', detail='报表已存在') + return schemas.Msg(code=-1, msg='error', data='报表已存在') - return schemas.Msg(code=0, msg='ok', detail='创建成功') + return schemas.Msg(code=0, msg='ok', data='创建成功') -@router.get("/") +@router.post("/read_report") async def read_report( + data_in: schemas.ReportRead, db: AsyncIOMotorDatabase = Depends(get_database), current_user: schemas.UserDB = Depends(deps.get_current_user) ) -> Any: """获取已建报表""" - res = await crud.report.read_report(db, user_id=current_user.id) + res = await crud.report.read_report(db, user_id=current_user.id, project_id=data_in.project_id) return res @@ -48,5 +49,5 @@ async def delete( del_report = await crud.report.delete(db, _id=data_in.id, user_id=current_user.id) if del_report.deleted_count == 0: - return schemas.Msg(code=-1, msg='error', detail='删除失败') - return schemas.Msg(code=0, msg='ok', detail='删除成功') + return schemas.Msg(code=-1, msg='error', data='删除失败') + return schemas.Msg(code=0, msg='ok', data='删除成功') diff --git a/api/api_v1/endpoints/space.py b/api/api_v1/endpoints/space.py index 20e78be..983b498 100644 --- a/api/api_v1/endpoints/space.py +++ b/api/api_v1/endpoints/space.py @@ -19,10 +19,10 @@ async def create( try: await crud.space.create(db, data_in, user_id=current_user.id) except pymongo.errors.DuplicateKeyError: - return schemas.Msg(code=-1, msg='error', detail='空间已存在') + return schemas.Msg(code=-1, msg='空间已存在', data='空间已存在') # todo 建默认文件夹 - return schemas.Msg(code=0, msg='ok', detail='创建成功') + return schemas.Msg(code=0, msg='创建成功', data='创建成功') @router.post("/delete") @@ -45,5 +45,5 @@ async def delete( await crud.dashboard.delete(db, pid=data_in.id) if del_space.deleted_count == 0: - return schemas.Msg(code=-1, msg='error', detail='删除失败') - return schemas.Msg(code=0, msg='ok', detail='删除成功') + return schemas.Msg(code=-1, msg='error', data='删除失败') + return schemas.Msg(code=0, msg='ok', dtta='删除成功') diff --git a/api/api_v1/endpoints/user.py b/api/api_v1/endpoints/user.py index 26951db..245cb7d 100644 --- a/api/api_v1/endpoints/user.py +++ b/api/api_v1/endpoints/user.py @@ -9,11 +9,8 @@ import crud, schemas from api import deps from core import security from core.config import settings -from core.security import get_password_hash + from db import get_database -from utils import ( - verify_password_reset_token, -) router = APIRouter() @@ -23,7 +20,7 @@ async def login( # data: schemas.UserLogin, data: OAuth2PasswordRequestForm = Depends(), db: AsyncIOMotorDatabase = Depends(get_database) -) -> dict: +) -> Any: """ OAuth2兼容令牌登录,获取将来令牌的访问令牌 """ @@ -31,18 +28,19 @@ async def login( name=data.username, password=data.password ) if not user: - raise HTTPException(status_code=400, detail="Incorrect name or password") + # raise HTTPException(status_code=400, detail="Incorrect name or password") + return schemas.Msg(code=-1, msg='密码或用户名错误') access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) return { 'data': { 'name': user.name, 'email': user.email, - # 'access_token': security.create_access_token( - # expires_delta=access_token_expires, id=user.id, email=user.email, is_active=user.is_active, - # is_superuser=user.is_superuser, name=user.name - # ), - # "token_type": "bearer", + 'token': security.create_access_token( + expires_delta=access_token_expires, _id=str(user.id), email=user.email, + is_superuser=user.is_superuser, name=user.name + ), + "token_type": "bearer", }, 'access_token': security.create_access_token( @@ -62,3 +60,18 @@ def me(current_user: schemas.User = Depends(deps.get_current_user)) -> Any: Test access token """ return current_user + + +@router.get("/all_user") +async def all_user(db: AsyncIOMotorDatabase = Depends(get_database)) -> Any: + """ + 获取所有用户 + """ + users = await crud.user.find_many(db) + data = [{ + "_id": "0", + "name": "所有用户" + }] + data += [schemas.UserDB(**user) for user in users] + + return schemas.Msg(code=0, msg='ok', data=data) diff --git a/api/deps.py b/api/deps.py index 168b3dc..5e5c37e 100644 --- a/api/deps.py +++ b/api/deps.py @@ -29,3 +29,19 @@ def get_current_user(token: str = Depends(reusable_oauth2) if not user: raise HTTPException(status_code=404, detail="User not found") return user + + +def get_current_user2(token: str) -> schemas.UserDB: + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + user = schemas.UserDB(**payload) + except (jwt.JWTError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user diff --git a/core/config.py b/core/config.py index 3c64c34..1228d2d 100644 --- a/core/config.py +++ b/core/config.py @@ -15,6 +15,8 @@ class Settings(BaseSettings): MDB_PASSWORD: str = 'iamciniao' MDB_DB: str = 'xdata' + CASBIN_COLL: str = 'casbin_rule' + DATABASE_URI = f'mongodb://{MDB_USER}:{MDB_PASSWORD}@{MDB_HOST}:{MDB_PORT}/admin' FIRST_EMAIL: str = '15392746632@qq.com' diff --git a/crud/crud_report.py b/crud/crud_report.py index 3eba83b..2cfe07b 100644 --- a/crud/crud_report.py +++ b/crud/crud_report.py @@ -12,7 +12,8 @@ class CRUDReport(CRUDBase): async def create(self, db: AsyncIOMotorDatabase, obj_in: ReportCreate, user_id: str): db_obj = ReportDB( **obj_in.dict(), user_id=user_id, - _id=uuid.uuid1().hex + _id=uuid.uuid1().hex, + members=[user_id] ) await db[self.coll_name].insert_one(db_obj.dict(by_alias=True)) diff --git a/main.py b/main.py index e92e08d..0afb1a4 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,15 @@ +import base64 +import binascii + import uvicorn from fastapi import FastAPI +import casbin from core.config import settings from starlette.middleware.cors import CORSMiddleware +from starlette.authentication import AuthenticationBackend, AuthenticationError, SimpleUser, AuthCredentials +from starlette.middleware.authentication import AuthenticationMiddleware + from db import connect_to_mongo, close_mongo_connection @@ -19,6 +26,26 @@ if settings.BACKEND_CORS_ORIGINS: app.add_event_handler("startup", connect_to_mongo) app.add_event_handler("shutdown", close_mongo_connection) + +class BasicAuth(AuthenticationBackend): + async def authenticate(self, request): + if "Authorization" not in request.headers: + return None + + auth = request.headers["Authorization"] + try: + scheme, credentials = auth.split() + decoded = base64.b64decode(credentials).decode("ascii") + except (ValueError, UnicodeDecodeError, binascii.Error): + raise AuthenticationError("Invalid basic auth credentials") + + username, _, password = decoded.partition(":") + return AuthCredentials(["authenticated"]), SimpleUser(username) + +# enforcer = casbin.Enforcer('rbac_model.conf', 'rbac_policy.csv') +# app.add_middleware(CasbinMiddleware, enforcer=enforcer) +# app.add_middleware(AuthenticationMiddleware, backend=BasicAuth()) + from api.api_v1.api import api_router app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/rbac_policy.py b/rbac_policy.py new file mode 100644 index 0000000..6d0a186 --- /dev/null +++ b/rbac_policy.py @@ -0,0 +1,21 @@ +import casbin + +from core.config import settings +from pymongo import MongoClient + +from utils import Adapter + +client = MongoClient(settings.DATABASE_URI) +db = client[settings.MDB_DB] +collection = db[settings.CASBIN_COLL] + +adapter = Adapter(settings.DATABASE_URI, settings.MDB_DB) +enforcer = casbin.Enforcer('rbac_model.conf', adapter) + +model = enforcer.get_model() +model.add_policy('g', 'g', ['root', 'superAdmin', ]) +model.add_policy('g', 'g', ['legu', 'admin']) +adapter.save_policy(model) + +res = enforcer.enforce('alice', 'data1', 'read') +print(res) diff --git a/schemas/msg.py b/schemas/msg.py index 3475056..7721feb 100644 --- a/schemas/msg.py +++ b/schemas/msg.py @@ -6,4 +6,4 @@ from pydantic import BaseModel class Msg(BaseModel): code: int msg: str - detail: Any + data: Any diff --git a/schemas/report.py b/schemas/report.py index 92d0b2e..2e431e2 100644 --- a/schemas/report.py +++ b/schemas/report.py @@ -25,6 +25,10 @@ class ReportDelete(DBBase): pass +class ReportRead(DBBase): + project_id: str + + # -------------------------------------------------------------- # 数据库模型 class ReportDB(DBBase): @@ -32,5 +36,6 @@ class ReportDB(DBBase): user_id: str project_id: str # cat: Category + members: List[str] = [] pid: str create_date: datetime = datetime.now() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..a7dde9b --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .adapter import Adapter diff --git a/utils/adapter.py b/utils/adapter.py new file mode 100644 index 0000000..5efbcb6 --- /dev/null +++ b/utils/adapter.py @@ -0,0 +1,112 @@ +import casbin +from casbin import persist +from pymongo import MongoClient + + +class CasbinRule: + ''' + CasbinRule model + ''' + + def __init__(self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None): + self.ptype = ptype + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.v4 = v4 + self.v5 = v5 + + def dict(self): + d = {'ptype': self.ptype} + + for i, v in enumerate([self.v0, self.v1, self.v2, self.v3, self.v4, self.v5]): + if v is None: + break + d['v' + str(i)] = v + + return d + + def __str__(self): + return ', '.join(self.dict().values()) + + def __repr__(self): + return ''.format(str(self)) + + +class Adapter(persist.Adapter): + """the interface for Casbin adapters.""" + + def __init__(self, uri, dbname, collection="casbin_rule"): + client = MongoClient(uri) + db = client[dbname] + self._collection = db[collection] + + def load_policy(self, model): + ''' + implementing add Interface for casbin \n + load all policy rules from mongodb \n + ''' + + for line in self._collection.find(): + if 'ptype' not in line: + continue + + rule = CasbinRule(line['ptype']) + if 'v0' in line: + rule.v0 = line['v0'] + if 'v1' in line: + rule.v1 = line['v1'] + if 'v2' in line: + rule.v2 = line['v2'] + if 'v3' in line: + rule.v3 = line['v3'] + if 'v4' in line: + rule.v4 = line['v4'] + if 'v5' in line: + rule.v5 = line['v5'] + + persist.load_policy_line(str(rule), model) + + def _save_policy_line(self, ptype, rule): + line = CasbinRule(ptype=ptype) + if len(rule) > 0: + line.v0 = rule[0] + if len(rule) > 1: + line.v1 = rule[1] + if len(rule) > 2: + line.v2 = rule[2] + if len(rule) > 3: + line.v3 = rule[3] + if len(rule) > 4: + line.v4 = rule[4] + if len(rule) > 5: + line.v5 = rule[5] + self._collection.update_one(line.dict(), {'$set':line.dict()}, upsert=True) + + def save_policy(self, model): + ''' + implementing add Interface for casbin \n + save the policy in mongodb \n + ''' + for sec in ["p", "g"]: + if sec not in model.model.keys(): + continue + for ptype, ast in model.model[sec].items(): + for rule in ast.policy: + self._save_policy_line(ptype, rule) + return True + + def add_policy(self, sec, ptype, rule): + """add policy rules to mongodb""" + self._save_policy_line(ptype, rule) + + def remove_policy(self, sec, ptype, rule): + """delete policy rules from mongodb""" + pass + + def remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """ + delete policy rules for matching filters from mongodb + """ + pass diff --git a/utils/async_adapter.py b/utils/async_adapter.py new file mode 100644 index 0000000..a4c7d04 --- /dev/null +++ b/utils/async_adapter.py @@ -0,0 +1,109 @@ +from casbin import persist + + +class CasbinRule: + ''' + CasbinRule model + ''' + + def __init__(self, ptype = None, v0 = None, v1 = None, v2 = None, v3 = None, v4 = None, v5 = None): + self.ptype = ptype + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.v4 = v4 + self.v5 = v5 + + def dict(self): + d = {'ptype': self.ptype} + + for i, v in enumerate([self.v0, self.v1, self.v2, self.v3, self.v4, self.v5]): + if v is None: + break + d['v' + str(i)] = v + + return d + + def __str__(self): + return ', '.join(self.dict().values()) + + def __repr__(self): + return ''.format(str(self)) + +class Adapter(persist.Adapter): + """the interface for Casbin adapters.""" + + def __init__(self,db, collection="casbin_rule"): + self._collection = db[collection] + + async def load_policy(self, model): + ''' + implementing add Interface for casbin \n + load all policy rules from mongodb \n + ''' + + async for line in self._collection.find(): + if 'ptype' not in line: + continue + + rule = CasbinRule(line['ptype']) + if 'v0' in line: + rule.v0 = line['v0'] + if 'v1' in line: + rule.v1 = line['v1'] + if 'v2' in line: + rule.v2 = line['v2'] + if 'v3' in line: + rule.v3 = line['v3'] + if 'v4' in line: + rule.v4 = line['v4'] + if 'v5' in line: + rule.v5 = line['v5'] + + persist.load_policy_line(str(rule), model) + + async def _save_policy_line(self, ptype, rule): + line = CasbinRule(ptype=ptype) + if len(rule) > 0: + line.v0 = rule[0] + if len(rule) > 1: + line.v1 = rule[1] + if len(rule) > 2: + line.v2 = rule[2] + if len(rule) > 3: + line.v3 = rule[3] + if len(rule) > 4: + line.v4 = rule[4] + if len(rule) > 5: + line.v5 = rule[5] + await self._collection.insert_one(line.dict()) + + async def save_policy(self, model): + ''' + implementing add Interface for casbin \n + save the policy in mongodb \n + ''' + for sec in ["p", "g"]: + if sec not in model.model.keys(): + continue + for ptype, ast in model.model[sec].items(): + for rule in ast.policy: + await self._save_policy_line(ptype, rule) + return True + + async def add_policy(self, sec, ptype, rule): + """add policy rules to mongodb""" + await self._save_policy_line(ptype, rule) + + def remove_policy(self, sec, ptype, rule): + """delete policy rules from mongodb""" + pass + + def remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """ + delete policy rules for matching filters from mongodb + """ + pass + +