1
This commit is contained in:
parent
b345c22932
commit
8a3d34ce00
@ -18,10 +18,10 @@ async def create(
|
||||
try:
|
||||
await crud.dashboard.create(db, data_in, user_id=current_user.id)
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='error', detail='看板已存在')
|
||||
return schemas.Msg(code=-1, msg='error', data='看板已存在')
|
||||
# todo 建默认文件夹
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', detail='创建成功')
|
||||
return schemas.Msg(code=0, msg='ok', data='创建成功')
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
@ -34,8 +34,8 @@ async def delete(
|
||||
del_dashboard = await crud.dashboard.delete(db, _id=data_in.id, user_id=current_user.id)
|
||||
|
||||
if del_dashboard.deleted_count == 0:
|
||||
return schemas.Msg(code=-1, msg='error', detail='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', detail='删除成功')
|
||||
return schemas.Msg(code=-1, msg='error', data='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', data='删除成功')
|
||||
|
||||
|
||||
@router.post("/move")
|
||||
@ -49,8 +49,8 @@ async def delete(
|
||||
"""
|
||||
res = await crud.dashboard.update_one(db, id=data_in.source_id, cat=data_in.cat, pid=data_in.dest_id)
|
||||
if res.deleted_count == 0:
|
||||
return schemas.Msg(code=-1, msg='error', detail='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', detail='删除成功')
|
||||
return schemas.Msg(code=-1, msg='error', data='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', data='删除成功')
|
||||
|
||||
|
||||
@router.post("/add_report")
|
||||
@ -59,7 +59,7 @@ async def add_report(data_in: schemas.AddReport,
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
res = await crud.dashboard.update_one(db, id=data_in.id, **{'$push': {'reports': {'$each': data_in.report_ids}}})
|
||||
return schemas.Msg(code=0, msg='ok', detail='ok')
|
||||
return schemas.Msg(code=0, msg='ok', data='ok')
|
||||
|
||||
|
||||
@router.post("/del_report")
|
||||
@ -70,7 +70,7 @@ async def add_report(data_in: schemas.DelReport,
|
||||
"""删除报表"""
|
||||
for item in data_in.report_ids:
|
||||
await crud.dashboard.update_one(db, id=data_in.id, **{'$pull': {'reports': item}})
|
||||
return schemas.Msg(code=0, msg='ok', detail='ok')
|
||||
return schemas.Msg(code=0, msg='ok', data='ok')
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@ -81,4 +81,4 @@ async def add_report(_id: str,
|
||||
"""获取一个看板"""
|
||||
res = await crud.dashboard.get(db, id=_id)
|
||||
res['reports'] = await crud.report.find_many(db, **{'$in': {'_id': res.get('reports')}})
|
||||
return schemas.Msg(code=0, msg='ok', detail=res['reports'])
|
||||
return schemas.Msg(code=0, msg='ok', data=res['reports'])
|
||||
|
@ -19,10 +19,10 @@ async def create(
|
||||
try:
|
||||
await crud.folder.create(db, data_in, user_id=current_user.id)
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='error', detail='文件夹已存在')
|
||||
return schemas.Msg(code=-1, msg='error', data='文件夹已存在')
|
||||
# todo 建默认文件夹
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', detail='创建成功')
|
||||
return schemas.Msg(code=0, msg='ok', data='创建成功')
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
@ -37,5 +37,5 @@ async def delete(
|
||||
# 删除文件夹下的 dashboard
|
||||
del_dashboard = await crud.dashboard.delete(db, pid=data_in.id)
|
||||
if del_folder.deleted_count == 0:
|
||||
return schemas.Msg(code=-1, msg='error', detail='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', detail='删除成功')
|
||||
return schemas.Msg(code=-1, msg='error', data='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', data='删除成功')
|
||||
|
@ -1,5 +1,5 @@
|
||||
import pymongo
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from motor.motor_asyncio import AsyncIOMotorDatabase
|
||||
import crud, schemas
|
||||
from api import deps
|
||||
@ -20,20 +20,22 @@ async def create(
|
||||
try:
|
||||
await crud.project.create(db, data_in, user_id=current_user.id)
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='error', detail='项目名已存在')
|
||||
return schemas.Msg(code=-1, msg='项目名已存在', data='项目名已存在')
|
||||
# todo 建默认文件夹
|
||||
# schemas.FolderCreate
|
||||
# await crud.folder.create(db, data_in, user_id=current_user.id)
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', detail='创建成功')
|
||||
return schemas.Msg(code=0, msg='创建成功')
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def read_project(
|
||||
async def read_project(request: Request,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
):
|
||||
"""查看自己拥有的项目"""
|
||||
res = await crud.project.read_project(db, user_id=current_user.id)
|
||||
return res
|
||||
return schemas.Msg(code=0, msg='ok', data=res)
|
||||
|
||||
|
||||
@router.post("/kanban")
|
||||
@ -43,19 +45,19 @@ async def read_kanban(
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
):
|
||||
"""获取自己的看板"""
|
||||
res = {'kanban': [], 'space': []}
|
||||
res = {'kanban': [], 'spaces': []}
|
||||
# 我的看板
|
||||
kanban = await crud.folder.read_folder(db, project_id=data_in.id, user_id=current_user.id, cat='kanban')
|
||||
|
||||
for item in kanban:
|
||||
dashboards = await crud.dashboard.find_many(db, pid=item['_id'])
|
||||
res['kanban'].append({
|
||||
'folder_name': item['name'],
|
||||
'dashboards': [],
|
||||
'name': item['name'],
|
||||
'children': [],
|
||||
'_id': item['_id']
|
||||
})
|
||||
for d in dashboards:
|
||||
res['kanban'][-1]['dashboards'].append({
|
||||
res['kanban'][-1]['children'].append({
|
||||
'name': d['name'],
|
||||
'_id': item['_id']
|
||||
})
|
||||
@ -68,32 +70,33 @@ async def read_kanban(
|
||||
spaces = await crud.space.find_many(db, **where)
|
||||
# 空间 文件夹 看板
|
||||
for item in spaces:
|
||||
res['space'].append({
|
||||
'space_name': item['name'],
|
||||
'folders': [],
|
||||
'dashboards': [],
|
||||
res['spaces'].append({
|
||||
'name': item['name'],
|
||||
'children': [],
|
||||
'_id': item['_id']
|
||||
})
|
||||
res['space'][-1]['authority'] = 'rw' if current_user.id in item['rw_members'] else 'r'
|
||||
res['spaces'][-1]['authority'] = 'rw' if current_user.id in item['rw_members'] else 'r'
|
||||
|
||||
for f in await crud.folder.find_many(db, pid=item['_id']):
|
||||
res['space'][-1]['folders'].append({
|
||||
res['spaces'][-1]['children'].append({
|
||||
'name': f['name'],
|
||||
'_id': f['_id'],
|
||||
'dashboards': [],
|
||||
'children': [],
|
||||
'isFolder': True
|
||||
})
|
||||
|
||||
for d in await crud.dashboard.find_many(db, pid=f['_id']):
|
||||
res['space'][-1]['folders'][-1]['dashboards'].append({
|
||||
res['spaces'][-1]['children'][-1]['children'].append({
|
||||
'name': d['name'],
|
||||
'_id': d['_id']
|
||||
})
|
||||
|
||||
# 空间 看板
|
||||
for d in await crud.dashboard.find_many(db, pid=item['_id']):
|
||||
res['space'][-1]['dashboards'].append({
|
||||
res['spaces'][-1]['children'].append({
|
||||
'name': d['name'],
|
||||
'_id': d['_id']
|
||||
'_id': d['_id'],
|
||||
'isFolder': False
|
||||
})
|
||||
|
||||
return res
|
||||
return schemas.Msg(code=0, msg='ok', data=res)
|
||||
|
@ -21,19 +21,20 @@ async def create(
|
||||
try:
|
||||
await crud.report.create(db, data_in, user_id=current_user.id)
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='error', detail='报表已存在')
|
||||
return schemas.Msg(code=-1, msg='error', data='报表已存在')
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', detail='创建成功')
|
||||
return schemas.Msg(code=0, msg='ok', data='创建成功')
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.post("/read_report")
|
||||
async def read_report(
|
||||
data_in: schemas.ReportRead,
|
||||
db: AsyncIOMotorDatabase = Depends(get_database),
|
||||
current_user: schemas.UserDB = Depends(deps.get_current_user)
|
||||
) -> Any:
|
||||
"""获取已建报表"""
|
||||
|
||||
res = await crud.report.read_report(db, user_id=current_user.id)
|
||||
res = await crud.report.read_report(db, user_id=current_user.id, project_id=data_in.project_id)
|
||||
return res
|
||||
|
||||
|
||||
@ -48,5 +49,5 @@ async def delete(
|
||||
del_report = await crud.report.delete(db, _id=data_in.id, user_id=current_user.id)
|
||||
|
||||
if del_report.deleted_count == 0:
|
||||
return schemas.Msg(code=-1, msg='error', detail='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', detail='删除成功')
|
||||
return schemas.Msg(code=-1, msg='error', data='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', data='删除成功')
|
||||
|
@ -19,10 +19,10 @@ async def create(
|
||||
try:
|
||||
await crud.space.create(db, data_in, user_id=current_user.id)
|
||||
except pymongo.errors.DuplicateKeyError:
|
||||
return schemas.Msg(code=-1, msg='error', detail='空间已存在')
|
||||
return schemas.Msg(code=-1, msg='空间已存在', data='空间已存在')
|
||||
# todo 建默认文件夹
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', detail='创建成功')
|
||||
return schemas.Msg(code=0, msg='创建成功', data='创建成功')
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
@ -45,5 +45,5 @@ async def delete(
|
||||
await crud.dashboard.delete(db, pid=data_in.id)
|
||||
|
||||
if del_space.deleted_count == 0:
|
||||
return schemas.Msg(code=-1, msg='error', detail='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', detail='删除成功')
|
||||
return schemas.Msg(code=-1, msg='error', data='删除失败')
|
||||
return schemas.Msg(code=0, msg='ok', dtta='删除成功')
|
||||
|
@ -9,11 +9,8 @@ import crud, schemas
|
||||
from api import deps
|
||||
from core import security
|
||||
from core.config import settings
|
||||
from core.security import get_password_hash
|
||||
|
||||
from db import get_database
|
||||
from utils import (
|
||||
verify_password_reset_token,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -23,7 +20,7 @@ async def login(
|
||||
# data: schemas.UserLogin,
|
||||
data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncIOMotorDatabase = Depends(get_database)
|
||||
) -> dict:
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2兼容令牌登录,获取将来令牌的访问令牌
|
||||
"""
|
||||
@ -31,18 +28,19 @@ async def login(
|
||||
name=data.username, password=data.password
|
||||
)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="Incorrect name or password")
|
||||
# raise HTTPException(status_code=400, detail="Incorrect name or password")
|
||||
return schemas.Msg(code=-1, msg='密码或用户名错误')
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
return {
|
||||
'data': {
|
||||
'name': user.name,
|
||||
'email': user.email,
|
||||
# 'access_token': security.create_access_token(
|
||||
# expires_delta=access_token_expires, id=user.id, email=user.email, is_active=user.is_active,
|
||||
# is_superuser=user.is_superuser, name=user.name
|
||||
# ),
|
||||
# "token_type": "bearer",
|
||||
'token': security.create_access_token(
|
||||
expires_delta=access_token_expires, _id=str(user.id), email=user.email,
|
||||
is_superuser=user.is_superuser, name=user.name
|
||||
),
|
||||
"token_type": "bearer",
|
||||
|
||||
},
|
||||
'access_token': security.create_access_token(
|
||||
@ -62,3 +60,18 @@ def me(current_user: schemas.User = Depends(deps.get_current_user)) -> Any:
|
||||
Test access token
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/all_user")
|
||||
async def all_user(db: AsyncIOMotorDatabase = Depends(get_database)) -> Any:
|
||||
"""
|
||||
获取所有用户
|
||||
"""
|
||||
users = await crud.user.find_many(db)
|
||||
data = [{
|
||||
"_id": "0",
|
||||
"name": "所有用户"
|
||||
}]
|
||||
data += [schemas.UserDB(**user) for user in users]
|
||||
|
||||
return schemas.Msg(code=0, msg='ok', data=data)
|
||||
|
16
api/deps.py
16
api/deps.py
@ -29,3 +29,19 @@ def get_current_user(token: str = Depends(reusable_oauth2)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user2(token: str) -> schemas.UserDB:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
)
|
||||
user = schemas.UserDB(**payload)
|
||||
except (jwt.JWTError, ValidationError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
|
@ -15,6 +15,8 @@ class Settings(BaseSettings):
|
||||
MDB_PASSWORD: str = 'iamciniao'
|
||||
MDB_DB: str = 'xdata'
|
||||
|
||||
CASBIN_COLL: str = 'casbin_rule'
|
||||
|
||||
DATABASE_URI = f'mongodb://{MDB_USER}:{MDB_PASSWORD}@{MDB_HOST}:{MDB_PORT}/admin'
|
||||
|
||||
FIRST_EMAIL: str = '15392746632@qq.com'
|
||||
|
@ -12,7 +12,8 @@ class CRUDReport(CRUDBase):
|
||||
async def create(self, db: AsyncIOMotorDatabase, obj_in: ReportCreate, user_id: str):
|
||||
db_obj = ReportDB(
|
||||
**obj_in.dict(), user_id=user_id,
|
||||
_id=uuid.uuid1().hex
|
||||
_id=uuid.uuid1().hex,
|
||||
members=[user_id]
|
||||
|
||||
)
|
||||
await db[self.coll_name].insert_one(db_obj.dict(by_alias=True))
|
||||
|
27
main.py
27
main.py
@ -1,8 +1,15 @@
|
||||
import base64
|
||||
import binascii
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
import casbin
|
||||
|
||||
from core.config import settings
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.authentication import AuthenticationBackend, AuthenticationError, SimpleUser, AuthCredentials
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
|
||||
|
||||
from db import connect_to_mongo, close_mongo_connection
|
||||
|
||||
@ -19,6 +26,26 @@ if settings.BACKEND_CORS_ORIGINS:
|
||||
app.add_event_handler("startup", connect_to_mongo)
|
||||
app.add_event_handler("shutdown", close_mongo_connection)
|
||||
|
||||
|
||||
class BasicAuth(AuthenticationBackend):
|
||||
async def authenticate(self, request):
|
||||
if "Authorization" not in request.headers:
|
||||
return None
|
||||
|
||||
auth = request.headers["Authorization"]
|
||||
try:
|
||||
scheme, credentials = auth.split()
|
||||
decoded = base64.b64decode(credentials).decode("ascii")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise AuthenticationError("Invalid basic auth credentials")
|
||||
|
||||
username, _, password = decoded.partition(":")
|
||||
return AuthCredentials(["authenticated"]), SimpleUser(username)
|
||||
|
||||
# enforcer = casbin.Enforcer('rbac_model.conf', 'rbac_policy.csv')
|
||||
# app.add_middleware(CasbinMiddleware, enforcer=enforcer)
|
||||
# app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())
|
||||
|
||||
from api.api_v1.api import api_router
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
21
rbac_policy.py
Normal file
21
rbac_policy.py
Normal file
@ -0,0 +1,21 @@
|
||||
import casbin
|
||||
|
||||
from core.config import settings
|
||||
from pymongo import MongoClient
|
||||
|
||||
from utils import Adapter
|
||||
|
||||
client = MongoClient(settings.DATABASE_URI)
|
||||
db = client[settings.MDB_DB]
|
||||
collection = db[settings.CASBIN_COLL]
|
||||
|
||||
adapter = Adapter(settings.DATABASE_URI, settings.MDB_DB)
|
||||
enforcer = casbin.Enforcer('rbac_model.conf', adapter)
|
||||
|
||||
model = enforcer.get_model()
|
||||
model.add_policy('g', 'g', ['root', 'superAdmin', ])
|
||||
model.add_policy('g', 'g', ['legu', 'admin'])
|
||||
adapter.save_policy(model)
|
||||
|
||||
res = enforcer.enforce('alice', 'data1', 'read')
|
||||
print(res)
|
@ -6,4 +6,4 @@ from pydantic import BaseModel
|
||||
class Msg(BaseModel):
|
||||
code: int
|
||||
msg: str
|
||||
detail: Any
|
||||
data: Any
|
||||
|
@ -25,6 +25,10 @@ class ReportDelete(DBBase):
|
||||
pass
|
||||
|
||||
|
||||
class ReportRead(DBBase):
|
||||
project_id: str
|
||||
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# 数据库模型
|
||||
class ReportDB(DBBase):
|
||||
@ -32,5 +36,6 @@ class ReportDB(DBBase):
|
||||
user_id: str
|
||||
project_id: str
|
||||
# cat: Category
|
||||
members: List[str] = []
|
||||
pid: str
|
||||
create_date: datetime = datetime.now()
|
||||
|
1
utils/__init__.py
Normal file
1
utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .adapter import Adapter
|
112
utils/adapter.py
Normal file
112
utils/adapter.py
Normal file
@ -0,0 +1,112 @@
|
||||
import casbin
|
||||
from casbin import persist
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
class CasbinRule:
|
||||
'''
|
||||
CasbinRule model
|
||||
'''
|
||||
|
||||
def __init__(self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None):
|
||||
self.ptype = ptype
|
||||
self.v0 = v0
|
||||
self.v1 = v1
|
||||
self.v2 = v2
|
||||
self.v3 = v3
|
||||
self.v4 = v4
|
||||
self.v5 = v5
|
||||
|
||||
def dict(self):
|
||||
d = {'ptype': self.ptype}
|
||||
|
||||
for i, v in enumerate([self.v0, self.v1, self.v2, self.v3, self.v4, self.v5]):
|
||||
if v is None:
|
||||
break
|
||||
d['v' + str(i)] = v
|
||||
|
||||
return d
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join(self.dict().values())
|
||||
|
||||
def __repr__(self):
|
||||
return '<CasbinRule :"{}">'.format(str(self))
|
||||
|
||||
|
||||
class Adapter(persist.Adapter):
|
||||
"""the interface for Casbin adapters."""
|
||||
|
||||
def __init__(self, uri, dbname, collection="casbin_rule"):
|
||||
client = MongoClient(uri)
|
||||
db = client[dbname]
|
||||
self._collection = db[collection]
|
||||
|
||||
def load_policy(self, model):
|
||||
'''
|
||||
implementing add Interface for casbin \n
|
||||
load all policy rules from mongodb \n
|
||||
'''
|
||||
|
||||
for line in self._collection.find():
|
||||
if 'ptype' not in line:
|
||||
continue
|
||||
|
||||
rule = CasbinRule(line['ptype'])
|
||||
if 'v0' in line:
|
||||
rule.v0 = line['v0']
|
||||
if 'v1' in line:
|
||||
rule.v1 = line['v1']
|
||||
if 'v2' in line:
|
||||
rule.v2 = line['v2']
|
||||
if 'v3' in line:
|
||||
rule.v3 = line['v3']
|
||||
if 'v4' in line:
|
||||
rule.v4 = line['v4']
|
||||
if 'v5' in line:
|
||||
rule.v5 = line['v5']
|
||||
|
||||
persist.load_policy_line(str(rule), model)
|
||||
|
||||
def _save_policy_line(self, ptype, rule):
|
||||
line = CasbinRule(ptype=ptype)
|
||||
if len(rule) > 0:
|
||||
line.v0 = rule[0]
|
||||
if len(rule) > 1:
|
||||
line.v1 = rule[1]
|
||||
if len(rule) > 2:
|
||||
line.v2 = rule[2]
|
||||
if len(rule) > 3:
|
||||
line.v3 = rule[3]
|
||||
if len(rule) > 4:
|
||||
line.v4 = rule[4]
|
||||
if len(rule) > 5:
|
||||
line.v5 = rule[5]
|
||||
self._collection.update_one(line.dict(), {'$set':line.dict()}, upsert=True)
|
||||
|
||||
def save_policy(self, model):
|
||||
'''
|
||||
implementing add Interface for casbin \n
|
||||
save the policy in mongodb \n
|
||||
'''
|
||||
for sec in ["p", "g"]:
|
||||
if sec not in model.model.keys():
|
||||
continue
|
||||
for ptype, ast in model.model[sec].items():
|
||||
for rule in ast.policy:
|
||||
self._save_policy_line(ptype, rule)
|
||||
return True
|
||||
|
||||
def add_policy(self, sec, ptype, rule):
|
||||
"""add policy rules to mongodb"""
|
||||
self._save_policy_line(ptype, rule)
|
||||
|
||||
def remove_policy(self, sec, ptype, rule):
|
||||
"""delete policy rules from mongodb"""
|
||||
pass
|
||||
|
||||
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
|
||||
"""
|
||||
delete policy rules for matching filters from mongodb
|
||||
"""
|
||||
pass
|
109
utils/async_adapter.py
Normal file
109
utils/async_adapter.py
Normal file
@ -0,0 +1,109 @@
|
||||
from casbin import persist
|
||||
|
||||
|
||||
class CasbinRule:
|
||||
'''
|
||||
CasbinRule model
|
||||
'''
|
||||
|
||||
def __init__(self, ptype = None, v0 = None, v1 = None, v2 = None, v3 = None, v4 = None, v5 = None):
|
||||
self.ptype = ptype
|
||||
self.v0 = v0
|
||||
self.v1 = v1
|
||||
self.v2 = v2
|
||||
self.v3 = v3
|
||||
self.v4 = v4
|
||||
self.v5 = v5
|
||||
|
||||
def dict(self):
|
||||
d = {'ptype': self.ptype}
|
||||
|
||||
for i, v in enumerate([self.v0, self.v1, self.v2, self.v3, self.v4, self.v5]):
|
||||
if v is None:
|
||||
break
|
||||
d['v' + str(i)] = v
|
||||
|
||||
return d
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join(self.dict().values())
|
||||
|
||||
def __repr__(self):
|
||||
return '<CasbinRule :"{}">'.format(str(self))
|
||||
|
||||
class Adapter(persist.Adapter):
|
||||
"""the interface for Casbin adapters."""
|
||||
|
||||
def __init__(self,db, collection="casbin_rule"):
|
||||
self._collection = db[collection]
|
||||
|
||||
async def load_policy(self, model):
|
||||
'''
|
||||
implementing add Interface for casbin \n
|
||||
load all policy rules from mongodb \n
|
||||
'''
|
||||
|
||||
async for line in self._collection.find():
|
||||
if 'ptype' not in line:
|
||||
continue
|
||||
|
||||
rule = CasbinRule(line['ptype'])
|
||||
if 'v0' in line:
|
||||
rule.v0 = line['v0']
|
||||
if 'v1' in line:
|
||||
rule.v1 = line['v1']
|
||||
if 'v2' in line:
|
||||
rule.v2 = line['v2']
|
||||
if 'v3' in line:
|
||||
rule.v3 = line['v3']
|
||||
if 'v4' in line:
|
||||
rule.v4 = line['v4']
|
||||
if 'v5' in line:
|
||||
rule.v5 = line['v5']
|
||||
|
||||
persist.load_policy_line(str(rule), model)
|
||||
|
||||
async def _save_policy_line(self, ptype, rule):
|
||||
line = CasbinRule(ptype=ptype)
|
||||
if len(rule) > 0:
|
||||
line.v0 = rule[0]
|
||||
if len(rule) > 1:
|
||||
line.v1 = rule[1]
|
||||
if len(rule) > 2:
|
||||
line.v2 = rule[2]
|
||||
if len(rule) > 3:
|
||||
line.v3 = rule[3]
|
||||
if len(rule) > 4:
|
||||
line.v4 = rule[4]
|
||||
if len(rule) > 5:
|
||||
line.v5 = rule[5]
|
||||
await self._collection.insert_one(line.dict())
|
||||
|
||||
async def save_policy(self, model):
|
||||
'''
|
||||
implementing add Interface for casbin \n
|
||||
save the policy in mongodb \n
|
||||
'''
|
||||
for sec in ["p", "g"]:
|
||||
if sec not in model.model.keys():
|
||||
continue
|
||||
for ptype, ast in model.model[sec].items():
|
||||
for rule in ast.policy:
|
||||
await self._save_policy_line(ptype, rule)
|
||||
return True
|
||||
|
||||
async def add_policy(self, sec, ptype, rule):
|
||||
"""add policy rules to mongodb"""
|
||||
await self._save_policy_line(ptype, rule)
|
||||
|
||||
def remove_policy(self, sec, ptype, rule):
|
||||
"""delete policy rules from mongodb"""
|
||||
pass
|
||||
|
||||
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
|
||||
"""
|
||||
delete policy rules for matching filters from mongodb
|
||||
"""
|
||||
pass
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user