casbin多租户模型
This commit is contained in:
parent
a802cb2d12
commit
d4b49cbb2b
@ -11,7 +11,7 @@ from .endpoints import query
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(user.router, tags=["登录接口"], prefix='/user')
|
||||
api_router.include_router(user.router, tags=["用户接口"], prefix='/user')
|
||||
api_router.include_router(project.router, tags=["项目接口"], prefix='/project')
|
||||
api_router.include_router(folder.router, tags=["文件夹接口"], prefix='/folder')
|
||||
api_router.include_router(space.router, tags=["空间接口"], prefix='/space')
|
||||
|
@ -13,7 +13,8 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/api_list")
|
||||
async def api_list(request: Request, current_user: schemas.UserDB = Depends(deps.get_current_user)) -> schemas.Msg:
|
||||
async def api_list(request: Request, game: str,
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)) -> schemas.Msg:
|
||||
"""api 列表"""
|
||||
app = request.app
|
||||
data = []
|
||||
@ -25,85 +26,90 @@ async def api_list(request: Request, current_user: schemas.UserDB = Depends(deps
|
||||
|
||||
|
||||
@router.post("/add_role")
|
||||
async def add_role(request: Request, data_in: schemas.CasbinRoleCreate,
|
||||
async def add_role(request: Request, game: str, data_in: schemas.CasbinRoleCreate,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
) -> schemas.Msg:
|
||||
"""创建角色"""
|
||||
role = (
|
||||
'g',
|
||||
'root',
|
||||
data_in.role_name,
|
||||
None
|
||||
)
|
||||
await crud.authority.create(db, role)
|
||||
role_dom = game
|
||||
# 角色有的接口权限
|
||||
for obj in data_in.role_api:
|
||||
casbin_enforcer.add_policy(data_in.role_name, role_dom, obj, '*')
|
||||
await crud.authority.create(db, 'p', data_in.role_name, role_dom, obj, '*')
|
||||
|
||||
# 管理员默认拥有该角色 方便从db中读出
|
||||
await crud.authority.create(db, 'g', settings.SUPERUSER_NAME, data_in.role_name, '*', '*',
|
||||
role_name=data_in.role_name,
|
||||
game=role_dom)
|
||||
|
||||
for item in data_in.role_api:
|
||||
await crud.authority.create(db, (
|
||||
'p',
|
||||
data_in.role_name,
|
||||
item,
|
||||
'*'
|
||||
))
|
||||
return schemas.Msg(code=0, msg='ok')
|
||||
|
||||
|
||||
@router.post("/add_account")
|
||||
async def add_account(request: Request,
|
||||
game: str,
|
||||
data_in: schemas.AccountCreate,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
) -> schemas.Msg:
|
||||
"""创建账号 并设置角色"""
|
||||
|
||||
account = schemas.UserCreate(name=data_in.name, password=settings.DEFAULT_PASSWORD)
|
||||
account = schemas.UserCreate(name=data_in.username, nickname=data_in.nickname, password=settings.DEFAULT_PASSWORD)
|
||||
try:
|
||||
await crud.user.create(db, account)
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='用户名已存在')
|
||||
rule = (
|
||||
'g',
|
||||
data_in.name,
|
||||
data_in.role_name,
|
||||
None
|
||||
)
|
||||
await crud.authority.create(db, rule)
|
||||
|
||||
casbin_enforcer.add_grouping_policy(data_in.username, data_in.role_name, game)
|
||||
await crud.authority.create(db, 'g', data_in.username, data_in.role_name, game)
|
||||
|
||||
return schemas.Msg(code=0, msg='ok')
|
||||
|
||||
|
||||
@router.get("/all_role")
|
||||
async def all_role(request: Request,
|
||||
game: str,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
) -> schemas.Msg:
|
||||
"""获取所有角色 和 角色权限"""
|
||||
routes = {}
|
||||
for item in request.app.routes:
|
||||
routes[item.path] = item.description if hasattr(item, 'description') else item.name
|
||||
roles = casbin_enforcer.get_all_roles()
|
||||
permissions = {}
|
||||
for role in roles:
|
||||
for _, path, _ in casbin_enforcer.get_permissions_for_user(role):
|
||||
permissions.setdefault(role, [])
|
||||
if path == '*':
|
||||
permissions[role].clear()
|
||||
"""获取所有角色"""
|
||||
roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game=game)
|
||||
data = [{'role': item['v1'], 'name': item['role_name']} for item in roles]
|
||||
return schemas.Msg(code=0, msg='ok', data=data)
|
||||
|
||||
permissions[role] = [{
|
||||
'path': k,
|
||||
'name': v
|
||||
} for k, v in routes.items()]
|
||||
break
|
||||
|
||||
if path in routes:
|
||||
permissions[role].append(
|
||||
{
|
||||
'path': path,
|
||||
'name': routes[path]
|
||||
}
|
||||
)
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', data={'roles': roles, 'permissions': permissions})
|
||||
# @router.get("/all_role")
|
||||
# async def all_role(request: Request,
|
||||
# db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
# current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
# ) -> schemas.Msg:
|
||||
# """获取所有角色 和 角色权限"""
|
||||
# routes = {}
|
||||
# for item in request.app.routes:
|
||||
# routes[item.path] = item.description if hasattr(item, 'description') else item.name
|
||||
# roles = casbin_enforcer.get_all_roles()
|
||||
# permissions = {}
|
||||
# for role in roles:
|
||||
# for _, path, _ in casbin_enforcer.get_permissions_for_user(role):
|
||||
# permissions.setdefault(role, [])
|
||||
# if path == '*':
|
||||
# permissions[role].clear()
|
||||
#
|
||||
# permissions[role] = [{
|
||||
# 'path': k,
|
||||
# 'name': v
|
||||
# } for k, v in routes.items()]
|
||||
# break
|
||||
#
|
||||
# if path in routes:
|
||||
# permissions[role].append(
|
||||
# {
|
||||
# 'path': path,
|
||||
# 'name': routes[path]
|
||||
# }
|
||||
# )
|
||||
#
|
||||
# return schemas.Msg(code=0, msg='ok', data={'roles': roles, 'permissions': permissions})
|
||||
|
||||
|
||||
@router.post("/set_role")
|
||||
@ -115,13 +121,10 @@ async def set_role(request: Request,
|
||||
"""设置账号角色"""
|
||||
casbin_enforcer.delete_user(data_id.name)
|
||||
casbin_enforcer.add_role_for_user(data_id.name, data_id.role_name)
|
||||
crud.authority.update_upsert(db, {'prtype': 'g', 'v0': data_id.name}, v1=data_id.role_name)
|
||||
await crud.authority.update_one(db, {'ptype': 'g', 'v0': data_id.name}, dict(v1=data_id.role_name))
|
||||
|
||||
return schemas.Msg(code=0, msg='ok')
|
||||
|
||||
|
||||
|
||||
|
||||
# @router.get("/delete_user")
|
||||
# async def delete_user(request: Request,
|
||||
# data_id: schemas.AccountDeleteUser,
|
||||
|
@ -47,7 +47,8 @@ async def move(
|
||||
"""
|
||||
移动看板
|
||||
"""
|
||||
res = await crud.dashboard.update_one(db, id=data_in.source_id, cat=data_in.cat, pid=data_in.dest_id)
|
||||
res = await crud.dashboard.update_one(db, {'_id': data_in.source_id},
|
||||
{'$set': dict(cat=data_in.cat, pid=data_in.dest_id)})
|
||||
if res.deleted_count == 0:
|
||||
return schemas.Msg(code=-1, msg='error', data='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', data='删除成功')
|
||||
@ -59,7 +60,8 @@ async def add_report(data_in: schemas.AddReport,
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
"""添加报表"""
|
||||
res = await crud.dashboard.update_one(db, id=data_in.id, **{'$push': {'reports': {'$each': data_in.report_ids}}})
|
||||
res = await crud.dashboard.update_one(db, {'_id': data_in.id},
|
||||
{'$push': {'reports': {'$each': data_in.report_ids}}})
|
||||
return schemas.Msg(code=0, msg='ok', data='ok')
|
||||
|
||||
|
||||
@ -70,7 +72,7 @@ async def del_report(data_in: schemas.DelReport,
|
||||
):
|
||||
"""删除报表"""
|
||||
for item in data_in.report_ids:
|
||||
await crud.dashboard.update_one(db, id=data_in.id, **{'$pull': {'reports': item}})
|
||||
await crud.dashboard.update_one(db, {'_id': data_in.id}, {'$pull': {'reports': item}})
|
||||
return schemas.Msg(code=0, msg='ok', data='ok')
|
||||
|
||||
|
||||
|
@ -3,9 +3,11 @@ from fastapi import APIRouter, Depends, Request
|
||||
from motor.motor_asyncio import AsyncIOMotorDatabase
|
||||
import crud, schemas
|
||||
from api import deps
|
||||
from core.config import settings
|
||||
|
||||
from db import get_database
|
||||
from schemas.project import ProjectCreate
|
||||
from utils import casbin_enforcer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -22,8 +24,14 @@ async def create(
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='项目名已存在', data='项目名已存在')
|
||||
# todo 建默认文件夹
|
||||
# schemas.FolderCreate
|
||||
# await crud.folder.create(db, data_in, user_id=current_user.id)
|
||||
|
||||
# 新建项目管理员权限
|
||||
role_name = f'{data_in.game}_admin'
|
||||
role_dom = data_in.game
|
||||
casbin_enforcer.add_policy(role_name, role_dom, '*', '*')
|
||||
await crud.authority.create(db, 'p', role_name, role_dom, '*', '*')
|
||||
# 添加角色
|
||||
await crud.authority.create(db, 'g', settings.SUPERUSER_NAME, role_name, '*', '*', role_name='项目管理员', game=role_dom)
|
||||
|
||||
return schemas.Msg(code=0, msg='创建成功')
|
||||
|
||||
@ -38,6 +46,49 @@ async def read_project(request: Request,
|
||||
return schemas.Msg(code=0, msg='ok', data=res)
|
||||
|
||||
|
||||
@router.post("/add_members")
|
||||
async def add_members(request: Request,
|
||||
game: str,
|
||||
data_in: schemas.ProjectMember,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
"""项目添加成员"""
|
||||
|
||||
#
|
||||
# await crud.project.add_members(db, data_in)
|
||||
for item in data_in.members:
|
||||
casbin_enforcer.add_grouping_policy(item.username, item.role, game)
|
||||
await crud.authority.create(db, 'g', item.username, item.role, game)
|
||||
return schemas.Msg(code=0, msg='ok', data=data_in)
|
||||
|
||||
|
||||
@router.get("/members")
|
||||
async def members(request: Request,
|
||||
game: str,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
"""查看项目成员"""
|
||||
res = await crud.authority.find_many(db, ptype='g', v2=game)
|
||||
data = [{'name': 'root', 'role': '超级管理员'}]
|
||||
data += [{'name': item['v0'], 'role': item['v1']} for item in res]
|
||||
return schemas.Msg(code=0, msg='ok', data=data)
|
||||
|
||||
|
||||
@router.post("/del_member")
|
||||
async def members(request: Request,
|
||||
game: str,
|
||||
data_in: schemas.ProjectDelMember,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
"""删除项目成员"""
|
||||
casbin_enforcer.delete_user(data_in.username)
|
||||
await crud.authority.delete(db, ptype='g', v2=game, v0=data_in.username)
|
||||
return schemas.Msg(code=0, msg='ok')
|
||||
|
||||
|
||||
@router.post("/kanban")
|
||||
async def read_kanban(
|
||||
data_in: schemas.ProjectKanban,
|
||||
|
@ -13,6 +13,7 @@ router = APIRouter()
|
||||
@router.post("/")
|
||||
async def read_table_struct(
|
||||
request: Request,
|
||||
game: str,
|
||||
data_in: schemas.GetTable,
|
||||
rdb: Redis = Depends(get_redis_pool),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
@ -20,4 +21,3 @@ async def read_table_struct(
|
||||
"""获取表结构"""
|
||||
data = await rdb.get(f'{data_in.game}_{data_in.name}')
|
||||
return schemas.Msg(code=0, msg='ok', data=json.loads(data))
|
||||
|
||||
|
@ -35,6 +35,7 @@ async def login(
|
||||
return {
|
||||
'data': {
|
||||
'name': user.name,
|
||||
'nickname': user.nickname,
|
||||
'email': user.email,
|
||||
'token': security.create_access_token(
|
||||
expires_delta=access_token_expires, _id=str(user.id), email=user.email,
|
||||
@ -45,6 +46,7 @@ async def login(
|
||||
},
|
||||
'access_token': security.create_access_token(
|
||||
expires_delta=access_token_expires, _id=str(user.id), email=user.email,
|
||||
nickname=user.nickname,
|
||||
is_superuser=user.is_superuser, name=user.name
|
||||
),
|
||||
"token_type": "bearer",
|
||||
@ -62,16 +64,18 @@ def me(current_user: schemas.User = Depends(deps.get_current_user)) -> Any:
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/all_user")
|
||||
async def all_user(db: AsyncIOMotorDatabase = Depends(get_database)) -> Any:
|
||||
@router.get("/all_account")
|
||||
async def all_account(page: int = 1, limit: int = 10, db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.User = Depends(deps.get_current_user)
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有用户
|
||||
"""
|
||||
users = await crud.user.find_many(db)
|
||||
data = [{
|
||||
"_id": "0",
|
||||
"name": "所有用户"
|
||||
}]
|
||||
data += [schemas.UserDB(**user) for user in users]
|
||||
page -= 1
|
||||
if page < 0:
|
||||
page = 0
|
||||
cursor = crud.user.find(db).skip(page * limit).limit(limit)
|
||||
|
||||
data = [schemas.UserDB(**user) async for user in cursor]
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', data=data)
|
||||
|
@ -19,9 +19,10 @@ class Settings(BaseSettings):
|
||||
|
||||
DATABASE_URI = f'mongodb://{MDB_USER}:{MDB_PASSWORD}@{MDB_HOST}:{MDB_PORT}/admin'
|
||||
|
||||
FIRST_EMAIL: str = '15392746632@qq.com'
|
||||
FIRST_SUPERUSER_PASSWORD: str = '123456'
|
||||
FIRST_NAME: str = 'root'
|
||||
SUPERUSER_EMAIL: str = '15392746632@qq.com'
|
||||
SUPERUSER_PASSWORD: str = '123456'
|
||||
SUPERUSER_NAME: str = 'root'
|
||||
SUPERUSER_NICKNAME: str = 'root'
|
||||
|
||||
DEFAULT_PASSWORD = '123456'
|
||||
|
||||
|
22
crud/base.py
22
crud/base.py
@ -7,27 +7,33 @@ class CRUDBase:
|
||||
def __init__(self, coll_name):
|
||||
self.coll_name = coll_name
|
||||
|
||||
async def get(self, db, id: Union[ObjectId, str]):
|
||||
async def get(self, db, id: Union[ObjectId, str]) -> dict:
|
||||
return (await db[self.coll_name].find_one({'_id': ObjectId(id)})) or dict()
|
||||
|
||||
async def read_have(self, db, user_id: str, **kwargs):
|
||||
where = {'members': user_id}
|
||||
where.update(kwargs)
|
||||
cursor = db[self.coll_name].find(where)
|
||||
return await cursor.to_list(length=999)
|
||||
return await cursor.to_list(length=9999)
|
||||
|
||||
async def find_many(self, db, **kwargs):
|
||||
cursor = db[self.coll_name].find(kwargs)
|
||||
return await cursor.to_list(length=999)
|
||||
return await cursor.to_list(length=9999)
|
||||
|
||||
def find(self, db, *args, **kwargs):
|
||||
cursor = db[self.coll_name].find(*args, **kwargs)
|
||||
return cursor
|
||||
|
||||
@staticmethod
|
||||
async def to_list(cursor):
|
||||
async for doc in cursor:
|
||||
yield doc
|
||||
|
||||
async def delete(self, db, **kwargs):
|
||||
return await db[self.coll_name].delete_many(kwargs)
|
||||
|
||||
async def update_one(self, db, id, **kwargs):
|
||||
return await db[self.coll_name].update_one({'_id': id}, kwargs)
|
||||
|
||||
async def update_upsert(self, db, where: dict, **kwargs):
|
||||
return await db[self.coll_name].update_one(where, {'$set': kwargs}, upsert=True)
|
||||
async def update_one(self, db, filter, update, upsert=False):
|
||||
return await db[self.coll_name].update_one(filter, update, upsert)
|
||||
|
||||
async def distinct(self, db, key, filter=None):
|
||||
return await db[self.coll_name].distinct(key, filter)
|
||||
|
@ -11,21 +11,29 @@ __all__ = 'authority',
|
||||
|
||||
class CRUDAuthority(CRUDBase):
|
||||
|
||||
async def create(self, db: AsyncIOMotorDatabase, *args):
|
||||
casbin_model.add_policy(args[0], args[0], args[1:])
|
||||
data = {'ptype': args[0],
|
||||
'v0': args[1],
|
||||
'v1': args[2],
|
||||
'v2': args[3],
|
||||
}
|
||||
await self.update_upsert(db, data, **data)
|
||||
async def create(self, db: AsyncIOMotorDatabase, *args, **kwargs):
|
||||
data = dict()
|
||||
if len(args) > 0:
|
||||
data['ptype'] = args[0]
|
||||
if len(args) > 1:
|
||||
data['v0'] = args[1]
|
||||
if len(args) > 2:
|
||||
data['v1'] = args[2]
|
||||
if len(args) > 3:
|
||||
data['v2'] = args[3]
|
||||
if len(args) > 4:
|
||||
data['v3'] = args[4]
|
||||
if len(args) > 5:
|
||||
data['v4'] = args[5]
|
||||
|
||||
data.update(kwargs)
|
||||
await self.update_one(db, data, {'$set': data}, upsert=True)
|
||||
|
||||
async def create_index(self, db: AsyncIOMotorDatabase):
|
||||
await db[self.coll_name].create_index(
|
||||
[('ptype', pymongo.DESCENDING), ('v0', pymongo.DESCENDING), ('v1', pymongo.DESCENDING),
|
||||
('v2', pymongo.DESCENDING)],
|
||||
('v2', pymongo.DESCENDING), ('v3', pymongo.DESCENDING)],
|
||||
unique=True)
|
||||
|
||||
|
||||
|
||||
authority = CRUDAuthority(settings.CASBIN_COLL)
|
||||
|
@ -13,11 +13,16 @@ class CRUDProject(CRUDBase):
|
||||
**obj_in.dict(), user_id=user_id, members=[user_id],
|
||||
_id=uuid.uuid1().hex
|
||||
)
|
||||
await db[self.coll_name].insert_one(db_obj.dict(by_alias=True))
|
||||
return await db[self.coll_name].insert_one(db_obj.dict(by_alias=True))
|
||||
|
||||
async def read_project(self, db: AsyncIOMotorDatabase, user_id: str):
|
||||
return await self.read_have(db, user_id=user_id)
|
||||
|
||||
# async def add_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember):
|
||||
# p = await self.get(db, obj_in.project_id)
|
||||
# members = list(set(p.get('members')) | set(obj_in.members))
|
||||
# await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}})
|
||||
|
||||
async def create_index(self, db: AsyncIOMotorDatabase):
|
||||
await db[self.coll_name].create_index('name', unique=True)
|
||||
|
||||
|
@ -21,6 +21,7 @@ class CRUDUser(CRUDBase):
|
||||
hashed_password=get_password_hash(obj_in.password),
|
||||
name=obj_in.name,
|
||||
is_superuser=obj_in.is_superuser,
|
||||
nickname=obj_in.nickname,
|
||||
_id=uuid.uuid1().hex
|
||||
)
|
||||
return await db[self.coll_name].insert_one(db_obj.dict(by_alias=True))
|
||||
|
@ -11,12 +11,13 @@ db = get_database()
|
||||
|
||||
|
||||
async def create_superuser():
|
||||
user = await crud.user.get_by_user(db=db, name=settings.FIRST_NAME)
|
||||
user = await crud.user.get_by_user(db=db, name=settings.SUPERUSER_NAME)
|
||||
if not user:
|
||||
user_in = schemas.UserCreate(
|
||||
name=settings.FIRST_NAME,
|
||||
email=settings.FIRST_EMAIL,
|
||||
password=settings.FIRST_SUPERUSER_PASSWORD,
|
||||
name=settings.SUPERUSER_NAME,
|
||||
email=settings.SUPERUSER_EMAIL,
|
||||
password=settings.SUPERUSER_PASSWORD,
|
||||
nickname=settings.SUPERUSER_NICKNAME,
|
||||
is_superuser=True,
|
||||
)
|
||||
await crud.user.create(db, user_in)
|
||||
@ -45,15 +46,10 @@ async def report_index():
|
||||
|
||||
async def authority_init():
|
||||
await crud.authority.create_index(db)
|
||||
await crud.authority.create(db, 'p', 'admin', '*', '*')
|
||||
await crud.authority.create(db, 'g', 'root', 'admin', None)
|
||||
await crud.authority.create(db, 'p', '*', '/docs', '*')
|
||||
await crud.authority.create(db, 'p', '*', '/openapi.json', '*')
|
||||
|
||||
await crud.authority.create(db, 'p', '*', '/api/v1/user/login', '*')
|
||||
|
||||
await crud.authority.create(db, 'p', '*', '/api/v1/project/', '*')
|
||||
await crud.authority.create(db, 'p', '*', '/api/v1/project/kanban', '*')
|
||||
await crud.authority.create(db, 'p', 'anonymous', '*', '/docs', '*')
|
||||
await crud.authority.create(db, 'p', 'anonymous', '*', '/openapi.json', '*')
|
||||
await crud.authority.create(db, 'p', 'anonymous', '*', '/api/v1/user/login', '*')
|
||||
await crud.authority.create(db, 'p', 'anonymous', '*', '/docs', '*')
|
||||
|
||||
|
||||
async def main():
|
||||
|
2
main.py
2
main.py
@ -5,7 +5,7 @@ from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, BaseUser, SimpleUser
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from fastapi_authz import CasbinMiddleware
|
||||
from middleware import CasbinMiddleware
|
||||
|
||||
from db import connect_to_mongo, close_mongo_connection, get_database
|
||||
from db.ckdb_utils import connect_to_ck, close_ck_connection
|
||||
|
1
middleware/__init__.py
Normal file
1
middleware/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .casbin import CasbinMiddleware
|
70
middleware/casbin.py
Normal file
70
middleware/casbin.py
Normal file
@ -0,0 +1,70 @@
|
||||
from casbin.enforcer import Enforcer
|
||||
from fastapi import HTTPException
|
||||
from starlette.authentication import BaseUser
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
import schemas
|
||||
|
||||
|
||||
class CasbinMiddleware:
|
||||
"""
|
||||
Middleware for Casbin
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
enforcer: Enforcer,
|
||||
) -> None:
|
||||
"""
|
||||
Configure Casbin Middleware
|
||||
|
||||
:param app:Retain for ASGI.
|
||||
:param enforcer:Casbin Enforcer, must be initialized before FastAPI start.
|
||||
"""
|
||||
self.app = app
|
||||
self.enforcer = enforcer
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
if self._enforce(scope, receive):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
else:
|
||||
response = JSONResponse(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
content="没有操作权限"
|
||||
)
|
||||
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
def _enforce(self, scope: Scope, receive: Receive) -> bool:
|
||||
"""
|
||||
Enforce a request
|
||||
|
||||
:param user: user will be sent to enforcer
|
||||
:param request: ASGI Request
|
||||
:return: Enforce Result
|
||||
"""
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
if 'user' not in scope:
|
||||
raise RuntimeError("Casbin Middleware must work with an Authentication Middleware")
|
||||
|
||||
assert isinstance(request.user, BaseUser)
|
||||
|
||||
user = request.user.display_name if request.user.is_authenticated else 'anonymous'
|
||||
dom = request.query_params.get('game', '0')
|
||||
print(user, dom, path, method)
|
||||
|
||||
return self.enforcer.enforce(user, dom, path, method)
|
@ -1,14 +1,14 @@
|
||||
[request_definition]
|
||||
r = sub, obj, act
|
||||
r = sub, dom, obj, act
|
||||
|
||||
[policy_definition]
|
||||
p = sub, obj, act
|
||||
p = sub, dom, obj, act
|
||||
|
||||
[role_definition]
|
||||
g = _, _
|
||||
g = _, _, _
|
||||
|
||||
[policy_effect]
|
||||
e = some(where (p.eft == allow))
|
||||
|
||||
[matchers]
|
||||
m = (p.sub == "*" || g(r.sub, p.sub)) && (r.obj == p.obj || keyMatch(r.obj, p.obj)) && (p.act == "*" || r.act == p.act)
|
||||
m = g(r.sub, p.sub, r.dom) && (p.dom=="*" || r.dom == p.dom) && ( p.obj=="*" || r.obj == p.obj) && (p.act=="*" || r.act == p.act) || r.sub=="root"
|
@ -22,8 +22,9 @@ class CasbinDB(BaseModel):
|
||||
|
||||
|
||||
class AccountCreate(BaseModel):
|
||||
name: str
|
||||
username: str
|
||||
role_name: str
|
||||
nickname: str
|
||||
|
||||
|
||||
class AccountDeleteUser(BaseModel):
|
||||
|
@ -11,6 +11,20 @@ class ProjectBase(BaseModel):
|
||||
name: str = None
|
||||
|
||||
|
||||
class MemberRole(BaseModel):
|
||||
username: str
|
||||
role: str
|
||||
|
||||
|
||||
class ProjectMember(BaseModel):
|
||||
members: List[MemberRole]
|
||||
# project_id: str
|
||||
|
||||
|
||||
class ProjectDelMember(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
# 解析请求json 创建项目
|
||||
class ProjectCreate(ProjectBase):
|
||||
name: str = Field(..., title='项目名')
|
||||
|
@ -9,6 +9,7 @@ class UserBase(BaseModel):
|
||||
email: Optional[EmailStr] = None
|
||||
is_superuser: bool = False
|
||||
name: Optional[str] = None
|
||||
nickname: str
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
@ -23,6 +24,7 @@ class UserLogin(BaseModel):
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
name: str
|
||||
nickname: str
|
||||
|
||||
|
||||
# ****************************************************************************
|
||||
@ -33,6 +35,7 @@ class UserDB(DBBase):
|
||||
email: EmailStr = None
|
||||
is_superuser: bool = False
|
||||
name: str
|
||||
nickname: str
|
||||
|
||||
|
||||
class UserDBRW(UserDB):
|
||||
|
Loading…
Reference in New Issue
Block a user