diff --git a/README.md b/README.md index 6a1e6ef..8001e96 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,13 @@ x后端 使用 mongo 数据库 ### casbin 接口权限控制 +修改 utils\casbin\model\assertion.py 20行 +为了能定义角色访问所有域 +```python +for rule in self.policy: + if len(rule) < count: + pass + # raise RuntimeError("grouping policy elements do not meet role definition") + if len(rule) > count: + rule = rule[:count] +``` diff --git a/api/api_v1/endpoints/authority.py b/api/api_v1/endpoints/authority.py index fce9f59..a1c5940 100644 --- a/api/api_v1/endpoints/authority.py +++ b/api/api_v1/endpoints/authority.py @@ -13,29 +13,41 @@ router = APIRouter() @router.get("/api_list") -async def api_list(request: Request, game: str, +async def api_list(request: Request, current_user: schemas.UserDB = Depends(deps.get_current_user)) -> schemas.Msg: """api 列表""" app = request.app - data = [] + data = {} for r in app.routes: + title = r.tags[0] if hasattr(r, 'description') else None + if not title: + continue + data.setdefault(title, {'list': []}) path = r.path name = r.description if hasattr(r, 'description') else r.name - data.append({'api': path, 'name': name}) - return schemas.Msg(code=0, msg='ok', data=data) + data[title]['list'].append({'api': path, 'title': name}) + + res = [{'title': k, 'list': v['list']} for k, v in data.items()] + + return schemas.Msg(code=0, msg='ok', data=res) @router.post("/add_role") -async def add_role(request: Request, game: str, data_in: schemas.CasbinRoleCreate, +async def add_role(request: Request, + data_in: schemas.CasbinRoleCreate, + game: str = Depends(deps.get_game_project), db: AsyncIOMotorDatabase = Depends(get_database), current_user: schemas.UserDB = Depends(deps.get_current_user) ) -> schemas.Msg: """创建角色""" role_dom = game + api_dict = dict() + for r in request.app.routes: + api_dict[r.path] = r.description if hasattr(r, 'description') else r.name # 角色有的接口权限 for obj in data_in.role_api: casbin_enforcer.add_policy(data_in.role_name, role_dom, obj, '*') - await crud.authority.create(db, 'p', data_in.role_name, role_dom, obj, '*') + await crud.authority.create(db, 'p', data_in.role_name, role_dom, obj, '*', api_name=api_dict.get(obj)) # 管理员默认拥有该角色 方便从db中读出 await crud.authority.create(db, 'g', settings.SUPERUSER_NAME, data_in.role_name, '*', '*', @@ -45,38 +57,93 @@ async def add_role(request: Request, game: str, data_in: schemas.CasbinRoleCreat return schemas.Msg(code=0, msg='ok') -@router.post("/add_account") -async def add_account(request: Request, - game: str, - data_in: schemas.AccountCreate, - db: AsyncIOMotorDatabase = Depends(get_database), - current_user: schemas.UserDB = Depends(deps.get_current_user) - ) -> schemas.Msg: - """创建账号 并设置角色""" +@router.post("/add_sys_role") +async def add_sys_role(request: Request, + data_in: schemas.CasbinRoleCreate, + game: str = Depends(deps.get_game_project), + db: AsyncIOMotorDatabase = Depends(get_database), + current_user: schemas.UserDB = Depends(deps.get_current_user) + ) -> schemas.Msg: + """创建系统角色""" + api_dict = dict() + for r in request.app.routes: + api_dict[r.path] = r.description if hasattr(r, 'description') else r.name + # 角色有的接口权限 + for obj in data_in.role_api: + casbin_enforcer.add_policy(data_in.role_name, '*', obj, '*') + await crud.authority.create(db, 'p', data_in.role_name, '*', obj, '*', api_name=api_dict.get(obj)) - account = schemas.UserCreate(name=data_in.username, nickname=data_in.nickname, password=settings.DEFAULT_PASSWORD) - try: - await crud.user.create(db, account) - except pymongo.errors.DuplicateKeyError: - return schemas.Msg(code=-1, msg='用户名已存在') - - casbin_enforcer.add_grouping_policy(data_in.username, data_in.role_name, game) - await crud.authority.create(db, 'g', data_in.username, data_in.role_name, game) + # 管理员默认拥有该角色 方便从db中读出 + await crud.authority.create(db, 'g', settings.SUPERUSER_NAME, data_in.role_name, + role_name=data_in.role_name, + game='*') return schemas.Msg(code=0, msg='ok') +@router.post("/add_account") +async def add_account(request: Request, + + data_in: schemas.AccountsCreate, + game: str = Depends(deps.get_game_project), + db: AsyncIOMotorDatabase = Depends(get_database), + current_user: schemas.UserDB = Depends(deps.get_current_user) + ) -> schemas.Msg: + """创建账号 并设置角色""" + for item in data_in.accounts: + account = schemas.UserCreate(name=item.username, password=settings.DEFAULT_PASSWORD) + try: + await crud.user.create(db, account) + except pymongo.errors.DuplicateKeyError: + return schemas.Msg(code=-1, msg='用户名已存在') + + casbin_enforcer.add_grouping_policy(item.username, item.role_name, game) + await crud.authority.create(db, 'g', item.username, item.role_name, game) + + return schemas.Msg(code=0, msg='ok') + + +@router.get("/data_authority") +async def data_authority(request: Request, + db: AsyncIOMotorDatabase = Depends(get_database), + game: str = Depends(deps.get_game_project), + current_user: schemas.UserDB = Depends(deps.get_current_user) + ) -> schemas.Msg: + """获取数据权限""" + + # todo 这是假数据 + data = [{'title': '全部事件', 'check_event_num': 100, 'total_event_num': 100, 'update_time': '2021-05-12 18:49:19'}] + + return schemas.Msg(code=0, msg='ok', data=data) + + @router.get("/all_role") async def all_role(request: Request, - game: str, db: AsyncIOMotorDatabase = Depends(get_database), + game: str = Depends(deps.get_game_project), current_user: schemas.UserDB = Depends(deps.get_current_user) ) -> schemas.Msg: """获取所有角色""" - roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game=game) - data = [{'role': item['v1'], 'name': item['role_name']} for item in roles] - return schemas.Msg(code=0, msg='ok', data=data) + """获取域内所有角色""" + roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game=game) + dom_data = [{'role': item['v1'], 'name': item['role_name']} for item in roles] + for item in dom_data: + q = await crud.authority.get_role_dom_authority(db, item['role'], game) + item['authority'] = q + + # 获取系统角色 + roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game='*') + sys_data = [{'role': item['v1'], 'name': item['role_name']} for item in roles] + for item in sys_data: + q = await crud.authority.get_role_dom_authority(db, item['role'], dom='*') + item['authority'] = q + + data = { + 'dom_role': dom_data, + 'sys_role': sys_data + } + return schemas.Msg(code=0, msg='ok', data=data) # @router.get("/all_role") # async def all_role(request: Request, @@ -112,18 +179,18 @@ async def all_role(request: Request, # return schemas.Msg(code=0, msg='ok', data={'roles': roles, 'permissions': permissions}) -@router.post("/set_role") -async def set_role(request: Request, - data_id: schemas.AccountSetRole, - db: AsyncIOMotorDatabase = Depends(get_database), - current_user: schemas.UserDB = Depends(deps.get_current_user) - ) -> schemas.Msg: - """设置账号角色""" - casbin_enforcer.delete_user(data_id.name) - casbin_enforcer.add_role_for_user(data_id.name, data_id.role_name) - await crud.authority.update_one(db, {'ptype': 'g', 'v0': data_id.name}, dict(v1=data_id.role_name)) - - return schemas.Msg(code=0, msg='ok') +# @router.post("/set_role") +# async def set_role(request: Request, +# data_id: schemas.AccountSetRole, +# db: AsyncIOMotorDatabase = Depends(get_database), +# current_user: schemas.UserDB = Depends(deps.get_current_user) +# ) -> schemas.Msg: +# """设置账号角色""" +# casbin_enforcer.delete_user(data_id.name) +# casbin_enforcer.add_role_for_user(data_id.name, data_id.role_name) +# await crud.authority.update_one(db, {'ptype': 'g', 'v0': data_id.name}, dict(v1=data_id.role_name)) +# +# return schemas.Msg(code=0, msg='ok') # @router.get("/delete_user") # async def delete_user(request: Request, diff --git a/api/api_v1/endpoints/project.py b/api/api_v1/endpoints/project.py index 35d9e8d..a695f83 100644 --- a/api/api_v1/endpoints/project.py +++ b/api/api_v1/endpoints/project.py @@ -6,6 +6,7 @@ from api import deps from core.config import settings from db import get_database +from db.ckdb import CKDrive, get_ck_db from schemas.project import ProjectCreate from utils import casbin_enforcer @@ -46,6 +47,46 @@ async def read_project(request: Request, return schemas.Msg(code=0, msg='ok', data=res) +@router.post("/detail") +async def read_project(request: Request, + game: str, + data_in: schemas.ProjectDetail, + db: AsyncIOMotorDatabase = Depends(get_database), + ck: CKDrive = Depends(get_ck_db), + current_user: schemas.UserDB = Depends(deps.get_current_user) + ): + """查看项目信息""" + res = await crud.project.read_project(db, user_id=current_user.id, _id=data_in.project_id) + if res: + res = res[0] + event_count = await ck.count(game, 'event') + user_count = await ck.count(game, 'user_view') + event_type_count = await ck.distinct_count(game, 'event') + event_attr_count = await ck.field_count(db=game, tb='event') + user_attr_count = await ck.field_count(db=game, tb='user_view') + + res['event_count'] = event_count + res['user_count'] = user_count + res['event_type_count'] = event_type_count + res['event_attr_count'] = event_attr_count + res['user_attr_count'] = user_attr_count + return schemas.Msg(code=0, msg='ok', data=res) + + +@router.post("/rename") +async def rename_project(request: Request, + data_in: schemas.ProjectRename, + db: AsyncIOMotorDatabase = Depends(get_database), + current_user: schemas.UserDB = Depends(deps.get_current_user) + ): + """修改项目名""" + try: + res = await crud.project.rename(db, data_in) + except pymongo.errors.DuplicateKeyError: + return schemas.Msg(code=-1, msg='项目名已存在') + return schemas.Msg(code=0, msg='ok', data=res) + + @router.post("/add_members") async def add_members(request: Request, game: str, @@ -56,7 +97,7 @@ async def add_members(request: Request, """项目添加成员""" # - # await crud.project.add_members(db, data_in) + await crud.project.add_members(db, data_in) for item in data_in.members: casbin_enforcer.add_grouping_policy(item.username, item.role, game) await crud.authority.create(db, 'g', item.username, item.role, game) @@ -70,10 +111,18 @@ async def members(request: Request, current_user: schemas.UserDB = Depends(deps.get_current_user) ): """查看项目成员""" - res = await crud.authority.find_many(db, ptype='g', v2=game) - data = [{'name': 'root', 'role': '超级管理员'}] - data += [{'name': item['v0'], 'role': item['v1']} for item in res] - return schemas.Msg(code=0, msg='ok', data=data) + roles = await crud.authority.find_many(db, ptype='g', v2=game) + data = {item['v0']: {'name': item['v0'], 'role': item['v1']} for item in roles} + data['root'] = {'name': 'root', 'role': '超级管理员'} + users = await crud.user.get_by_users(db, name={'$in': list(data.keys())}) + res = [] + for user in users.data: + res.append({ + **user.dict(), + 'role': data[user.name]['role'] + }) + + return schemas.Msg(code=0, msg='ok', data=res) @router.post("/del_member") @@ -85,6 +134,7 @@ async def members(request: Request, ): """删除项目成员""" casbin_enforcer.delete_user(data_in.username) + await crud.project.del_members(data_in) await crud.authority.delete(db, ptype='g', v2=game, v0=data_in.username) return schemas.Msg(code=0, msg='ok') diff --git a/api/api_v1/endpoints/query.py b/api/api_v1/endpoints/query.py index 52e90d6..37df7c0 100644 --- a/api/api_v1/endpoints/query.py +++ b/api/api_v1/endpoints/query.py @@ -1,12 +1,9 @@ -import json - -import aioch import pandas as pd from fastapi import APIRouter, Depends, Request import crud, schemas from api import deps -from db.ckdb import get_ck_db +from db.ckdb import get_ck_db, CKDrive router = APIRouter() @@ -15,20 +12,19 @@ router = APIRouter() async def query_sql( request: Request, data_in: schemas.Sql, - ckdb: aioch.Client = Depends(get_ck_db), + ckdb: CKDrive = Depends(get_ck_db), current_user: schemas.UserDB = Depends(deps.get_current_user) ) -> schemas.Msg: """原 sql 查询 """ - data, columns = await ckdb.execute(data_in.sql, with_column_types=True, columnar=True) - df = pd.DataFrame({col[0]: d for d, col in zip(data, columns)}) - return schemas.Msg(code=0, msg='ok', data=df.to_dict()) + data = await ckdb.execute(data_in.sql) + return schemas.Msg(code=0, msg='ok', data=data) @router.post("/query") async def query( request: Request, data_in: schemas.CkQuery, - ckdb: aioch.Client = Depends(get_ck_db), + ckdb: CKDrive = Depends(get_ck_db), current_user: schemas.UserDB = Depends(deps.get_current_user) ) -> schemas.Msg: """ json解析 sql 查询""" diff --git a/api/api_v1/endpoints/user.py b/api/api_v1/endpoints/user.py index 3eda3cc..7f928d5 100644 --- a/api/api_v1/endpoints/user.py +++ b/api/api_v1/endpoints/user.py @@ -31,6 +31,7 @@ async def login( # raise HTTPException(status_code=400, detail="Incorrect name or password") return schemas.Msg(code=-1, msg='密码或用户名错误') access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + await crud.user.update_login_time(db, data.username) return { 'data': { @@ -39,6 +40,7 @@ async def login( 'email': user.email, 'token': security.create_access_token( expires_delta=access_token_expires, _id=str(user.id), email=user.email, + nickname=user.nickname, is_superuser=user.is_superuser, name=user.name ), "token_type": "bearer", @@ -64,6 +66,20 @@ def me(current_user: schemas.User = Depends(deps.get_current_user)) -> Any: return current_user +@router.post("/reset_password") +async def reset_password(request: Request, + game: str, + data_in: schemas.UserRestPassword, + db: AsyncIOMotorDatabase = Depends(get_database), + current_user: schemas.User = Depends(deps.get_current_user) + ) -> Any: + """ + 修改密码 + """ + await crud.user.reset_password(db, data_in) + return schemas.Msg(code=0, msg='ok') + + @router.get("/all_account") async def all_account(page: int = 1, limit: int = 10, db: AsyncIOMotorDatabase = Depends(get_database), current_user: schemas.User = Depends(deps.get_current_user) diff --git a/api/deps.py b/api/deps.py index 5e5c37e..da1018c 100644 --- a/api/deps.py +++ b/api/deps.py @@ -1,11 +1,15 @@ -from fastapi import Depends, HTTPException, status +from fastapi import Depends, status, HTTPException from fastapi.security import OAuth2PasswordBearer from jose import jwt +from motor.motor_asyncio import AsyncIOMotorDatabase from pydantic import ValidationError +import crud import schemas +import utils from core import security from core.config import settings +from db import get_database reusable_oauth2 = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/user/login" @@ -45,3 +49,13 @@ def get_current_user2(token: str) -> schemas.UserDB: if not user: raise HTTPException(status_code=404, detail="User not found") return user + + +async def get_game_project(game: str, db: AsyncIOMotorDatabase = Depends(get_database)) -> str: + is_exists = await crud.project.find_one(db, {'game': game}, {'_id': True}) + if not is_exists: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail='没有该项目' + ) + return game diff --git a/crud/base.py b/crud/base.py index 25a01fe..95e2001 100644 --- a/crud/base.py +++ b/crud/base.py @@ -10,6 +10,9 @@ class CRUDBase: async def get(self, db, id: Union[ObjectId, str]) -> dict: return (await db[self.coll_name].find_one({'_id': ObjectId(id)})) or dict() + async def find_one(self, db, filter=None, *args, **kwargs): + return (await db[self.coll_name].find_one(filter, *args, **kwargs)) or dict() + async def read_have(self, db, user_id: str, **kwargs): where = {'members': user_id} where.update(kwargs) diff --git a/crud/crud_authority.py b/crud/crud_authority.py index 6f41fc6..f14b33e 100644 --- a/crud/crud_authority.py +++ b/crud/crud_authority.py @@ -29,6 +29,19 @@ class CRUDAuthority(CRUDBase): data.update(kwargs) await self.update_one(db, data, {'$set': data}, upsert=True) + async def get_all_dom_role(self, db, dom): + pass + + async def get_role_dom_authority(self, db, role, dom): + data = await self.find_many(db, v0=role, v1=dom) + res = [] + for item in data: + res.append({ + 'api': item['v2'], + 'api_name': item.get('api_name', item['v2']) + }) + return res + async def create_index(self, db: AsyncIOMotorDatabase): await db[self.coll_name].create_index( [('ptype', pymongo.DESCENDING), ('v0', pymongo.DESCENDING), ('v1', pymongo.DESCENDING), diff --git a/crud/crud_project.py b/crud/crud_project.py index ed45db7..54f9f81 100644 --- a/crud/crud_project.py +++ b/crud/crud_project.py @@ -15,15 +15,24 @@ class CRUDProject(CRUDBase): ) return await db[self.coll_name].insert_one(db_obj.dict(by_alias=True)) - async def read_project(self, db: AsyncIOMotorDatabase, user_id: str): - return await self.read_have(db, user_id=user_id) + async def read_project(self, db: AsyncIOMotorDatabase, user_id: str, **kwargs): + return await self.read_have(db, user_id=user_id, **kwargs) - # async def add_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember): - # p = await self.get(db, obj_in.project_id) - # members = list(set(p.get('members')) | set(obj_in.members)) - # await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}}) + async def add_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember): + p = await self.get(db, obj_in.project_id) + members = list(set(p.get('members')) | set(obj_in.members)) + await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}}) + + async def del_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember): + p = await self.get(db, obj_in.project_id) + members = list(set(p.get('members')) - set(obj_in.members)) + await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}}) + + async def rename(self, db: AsyncIOMotorDatabase, obj_in: ProjectRename): + await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'name': obj_in.rename}}) async def create_index(self, db: AsyncIOMotorDatabase): + await db[self.coll_name].create_index('game', unique=True) await db[self.coll_name].create_index('name', unique=True) diff --git a/crud/crud_user.py b/crud/crud_user.py index 36bf20d..3ff0714 100644 --- a/crud/crud_user.py +++ b/crud/crud_user.py @@ -1,7 +1,10 @@ +import datetime +import time import uuid from motor.motor_asyncio import AsyncIOMotorDatabase +import schemas from core.security import get_password_hash, verify_password from crud.base import CRUDBase from schemas import UserCreate, UserDBRW @@ -15,6 +18,11 @@ class CRUDUser(CRUDBase): res = await db[self.coll_name].find_one({'name': name}) return res + async def update_login_time(self, db, name): + await self.update_one(db, {'name': name}, + {'$set': {'last_login_ts': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}}) + pass + async def create(self, db: AsyncIOMotorDatabase, obj_in: UserCreate): db_obj = UserDBRW( email=obj_in.email, @@ -26,6 +34,10 @@ class CRUDUser(CRUDBase): ) return await db[self.coll_name].insert_one(db_obj.dict(by_alias=True)) + async def reset_password(self, db: AsyncIOMotorDatabase, obj_in: schemas.UserRestPassword): + hashed_password = get_password_hash(obj_in.password) + await self.update_one(db, {'name': obj_in.username}, {'hashed_password': hashed_password}) + async def authenticate(self, db: AsyncIOMotorDatabase, name: str, password: str): user_obj = UserDBRW(**await self.get_by_user(db, name=name)) if not user_obj: @@ -34,6 +46,10 @@ class CRUDUser(CRUDBase): return None return user_obj + async def get_by_users(self, db, **kwargs) -> schemas.Users: + res = await self.find_many(db, **kwargs) + return schemas.Users(data=res) + async def create_index(self, db: AsyncIOMotorDatabase): await db[self.coll_name].create_index('name', unique=True) diff --git a/db/ckdb.py b/db/ckdb.py index e27429e..0481ffd 100644 --- a/db/ckdb.py +++ b/db/ckdb.py @@ -1,12 +1,38 @@ from aioch import Client +import pandas as pd -class CKBase: +class CKDrive: client: Client = None + async def execute(self, sql) -> dict: + data, columns = await self.client.execute(sql, with_column_types=True, columnar=True) + df = pd.DataFrame({col[0]: d for d, col in zip(data, columns)}) + return df.T.to_dict() -ckdb = CKBase() + async def query_dataframe(self, sql): + data, columns = await self.client.execute(sql, with_column_types=True, columnar=True) + df = pd.DataFrame({col[0]: d for d, col in zip(data, columns)}) + return df + + async def count(self, db: str, tb: str): + sql = f'select count() as `count` from {db}.{tb}' + res = await self.execute(sql) + return res[0]['count'] + + async def distinct_count(self, db: str, tb: str): + sql = f'select count(distinct `#event_name`) as `count` from {db}.{tb}' + res = await self.execute(sql) + return res[0]['count'] + + async def field_count(self, db: str, tb: str): + sql = f"select count(name) as `count` from system.columns where database='{db}' and table='{tb}'" + res = await self.execute(sql) + return res[0]['count'] -def get_ck_db() -> Client: - return ckdb.client +ckdb = CKDrive() + + +def get_ck_db() -> CKDrive: + return ckdb diff --git a/db/ckdb_utils.py b/db/ckdb_utils.py index 55bb263..9d1518f 100644 --- a/db/ckdb_utils.py +++ b/db/ckdb_utils.py @@ -1,12 +1,12 @@ from aioch import Client from core.config import settings -from .ckdb import ckdb +from .ckdb import CKDrive async def connect_to_ck(): - ckdb.client = Client(**settings.CK_CONFIG) + CKDrive.client = Client(**settings.CK_CONFIG) async def close_ck_connection(): - await ckdb.client.disconnect() + await CKDrive.client.disconnect() diff --git a/middleware/casbin.py b/middleware/casbin.py index cc1a0da..ed1436c 100644 --- a/middleware/casbin.py +++ b/middleware/casbin.py @@ -1,4 +1,4 @@ -from casbin.enforcer import Enforcer +from utils.casbin.enforcer import Enforcer from fastapi import HTTPException from starlette.authentication import BaseUser from starlette.requests import Request diff --git a/rbac_model.conf b/rbac_model.conf index 922af8d..396ebe4 100644 --- a/rbac_model.conf +++ b/rbac_model.conf @@ -11,4 +11,4 @@ g = _, _, _ e = some(where (p.eft == allow)) [matchers] -m = g(r.sub, p.sub, r.dom) && (p.dom=="*" || r.dom == p.dom) && ( p.obj=="*" || r.obj == p.obj) && (p.act=="*" || r.act == p.act) || r.sub=="root" \ No newline at end of file +m = (g(r.sub, p.sub) || g(r.sub, p.sub, r.dom)) && (p.dom=="*" || r.dom == p.dom) && ( p.obj=="*" || r.obj == p.obj) && (p.act=="*" || r.act == p.act) || r.sub=="root" \ No newline at end of file diff --git a/schemas/authotity.py b/schemas/authotity.py index 9fca684..6495fa3 100644 --- a/schemas/authotity.py +++ b/schemas/authotity.py @@ -11,7 +11,7 @@ class Ptype(str, Enum): class CasbinRoleCreate(BaseModel): role_name: str - role_api: List + role_api: List[str] class CasbinDB(BaseModel): @@ -24,7 +24,12 @@ class CasbinDB(BaseModel): class AccountCreate(BaseModel): username: str role_name: str - nickname: str + # nickname: str + data_authority: str + + +class AccountsCreate(BaseModel): + accounts: List[AccountCreate] class AccountDeleteUser(BaseModel): diff --git a/schemas/project.py b/schemas/project.py index 9d3eceb..730b891 100644 --- a/schemas/project.py +++ b/schemas/project.py @@ -18,7 +18,16 @@ class MemberRole(BaseModel): class ProjectMember(BaseModel): members: List[MemberRole] - # project_id: str + project_id: str + + +class ProjectDetail(BaseModel): + project_id: str + + +class ProjectRename(BaseModel): + project_id: str + rename: str class ProjectDelMember(BaseModel): diff --git a/schemas/user.py b/schemas/user.py index c0a4ea9..b4d95ab 100644 --- a/schemas/user.py +++ b/schemas/user.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List from schemas.base import DBBase @@ -9,22 +9,31 @@ class UserBase(BaseModel): email: Optional[EmailStr] = None is_superuser: bool = False name: Optional[str] = None - nickname: str + nickname: str = '' + last_login_ts: str = '尚未登录' class User(UserBase): name: str +class Users(BaseModel): + data: List[User] = [] + + class UserLogin(BaseModel): username: str = ... password: str = ... +class UserRestPassword(BaseModel): + username: str = ... + password: str = ... + + class UserCreate(UserBase): password: str name: str - nickname: str # **************************************************************************** @@ -36,6 +45,7 @@ class UserDB(DBBase): is_superuser: bool = False name: str nickname: str + last_login_ts: str = '尚未登录' class UserDBRW(UserDB): diff --git a/utils/__init__.py b/utils/__init__.py index 38fc832..d61d95d 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1 +1,3 @@ from .adapter import * + +from . import casbin diff --git a/utils/adapter.py b/utils/adapter.py index e14cb3d..cca7324 100644 --- a/utils/adapter.py +++ b/utils/adapter.py @@ -1,5 +1,5 @@ -import casbin -from casbin import persist +from utils import casbin +from utils.casbin import persist from pymongo import MongoClient from core.config import settings @@ -7,6 +7,7 @@ from core.config import settings __all__ = 'casbin_adapter', 'casbin_enforcer', 'casbin_model' + class CasbinRule: ''' CasbinRule model diff --git a/utils/async_adapter.py b/utils/async_adapter.py index a3ace43..241d2bf 100644 --- a/utils/async_adapter.py +++ b/utils/async_adapter.py @@ -1,5 +1,5 @@ -from casbin import persist +from .casbin import persist class CasbinRule: diff --git a/utils/casbin/__init__.py b/utils/casbin/__init__.py new file mode 100644 index 0000000..9991b3a --- /dev/null +++ b/utils/casbin/__init__.py @@ -0,0 +1,7 @@ +from .enforcer import * +from .synced_enforcer import SyncedEnforcer +from .distributed_enforcer import DistributedEnforcer +from . import util +from .persist import * +from .effect import * +from .model import * \ No newline at end of file diff --git a/utils/casbin/config/__init__.py b/utils/casbin/config/__init__.py new file mode 100644 index 0000000..3558f42 --- /dev/null +++ b/utils/casbin/config/__init__.py @@ -0,0 +1 @@ +from .config import Config \ No newline at end of file diff --git a/utils/casbin/config/config.py b/utils/casbin/config/config.py new file mode 100644 index 0000000..ac91a79 --- /dev/null +++ b/utils/casbin/config/config.py @@ -0,0 +1,151 @@ +from io import StringIO + + +class Config: + """represents an implementation of the ConfigInterface""" + + # DEFAULT_SECTION specifies the name of a section if no name provided + DEFAULT_SECTION = 'default' + # DEFAULT_COMMENT defines what character(s) indicate a comment `#` + DEFAULT_COMMENT = '#' + # DEFAULT_COMMENT_SEM defines what alternate character(s) indicate a comment `;` + DEFAULT_COMMENT_SEM = ';' + # DEFAULT_MULTI_LINE_SEPARATOR defines what character indicates a multi-line content + DEFAULT_MULTI_LINE_SEPARATOR = '\\' + + _data = dict() + + def __init__(self): + self._data = dict() + + @staticmethod + def new_config(conf_name): + c = Config() + c._parse(conf_name) + return c + + @staticmethod + def new_config_from_text(text): + c = Config() + f = StringIO(text) + c._parse_buffer(f) + return c + + def add_config(self, section, option, value): + if section == '': + section = self.DEFAULT_SECTION + + if section not in self._data.keys(): + self._data[section] = {} + + self._data[section][option] = value + + def _parse(self, fname): + with open(fname, 'r', encoding='utf-8') as f: + self._parse_buffer(f) + + def _parse_buffer(self, f): + section = '' + line_num = 0 + buf = [] + can_write = False + while True: + if can_write: + self._write(section, line_num, buf) + can_write = False + line_num = line_num + 1 + + line = f.readline() + + if not line: + if len(buf) > 0: + self._write(section, line_num, buf) + break + line = line.strip() + + if '' == line or self.DEFAULT_COMMENT == line[0:1] or self.DEFAULT_COMMENT_SEM == line[0:1]: + can_write = True + continue + elif '[' == line[0:1] and ']' == line[-1]: + if len(buf) > 0: + self._write(section, line_num, buf) + can_write = False + section = line[1:-1] + else: + p = '' + if self.DEFAULT_MULTI_LINE_SEPARATOR == line[-1]: + p = line[0:-1].strip() + p = p + ' ' + else: + p = line + can_write = True + buf.append(p) + + def _write(self, section, line_num, b): + + buf = "".join(b) + if len(buf) <= 0: + return + option_val = buf.split('=', 1) + + if len(option_val) != 2: + raise RuntimeError('parse the content error : line {} , {} = ?'.format(line_num, option_val[0])) + + option = option_val[0].strip() + value = option_val[1].strip() + + self.add_config(section, option, value) + + del b[:] + + def get_bool(self, key): + """lookups up the value using the provided key and converts the value to a bool.""" + return self.get(key).capitalize() == "True" + + def get_int(self, key): + """lookups up the value using the provided key and converts the value to a int""" + return int(self.get(key)) + + def get_float(self, key): + """lookups up the value using the provided key and converts the value to a float""" + return float(self.get(key)) + + def get_string(self, key): + """lookups up the value using the provided key and converts the value to a string""" + return self.get(key) + + def get_strings(self, key): + """lookups up the value using the provided key and converts the value to an array of string""" + value = self.get(key) + if value == "": + return None + return value.split(",") + + def set(self, key, value): + if len(key) == 0: + raise RuntimeError("key is empty") + + keys = key.lower().split('::') + if len(keys) >= 2: + section = keys[0] + option = keys[1] + else: + section = "" + option = keys[0] + self.add_config(section, option, value) + + def get(self, key): + """section.key or key""" + + keys = key.lower().split('::') + if len(keys) >= 2: + section = keys[0] + option = keys[1] + else: + section = self.DEFAULT_SECTION + option = keys[0] + + if section in self._data.keys(): + if option in self._data[section].keys(): + return self._data[section][option] + return '' diff --git a/utils/casbin/core_enforcer.py b/utils/casbin/core_enforcer.py new file mode 100644 index 0000000..f4d6cc5 --- /dev/null +++ b/utils/casbin/core_enforcer.py @@ -0,0 +1,373 @@ +import logging + +from utils.casbin.effect import Effector, get_effector, effect_to_bool +from utils.casbin.model import Model, FunctionMap +from utils.casbin.persist import Adapter +from utils.casbin.persist.adapters import FileAdapter +from utils.casbin.rbac import default_role_manager +from utils.casbin.util import generate_g_function, SimpleEval, util + + +class CoreEnforcer: + """CoreEnforcer defines the core functionality of an enforcer.""" + + model_path = "" + model = None + fm = None + eft = None + + adapter = None + watcher = None + rm_map = None + + enabled = False + auto_save = False + auto_build_role_links = False + + def __init__(self, model=None, adapter=None): + self.logger = logging.getLogger() + if isinstance(model, str): + if isinstance(adapter, str): + self.init_with_file(model, adapter) + else: + self.init_with_adapter(model, adapter) + pass + else: + if isinstance(adapter, str): + raise RuntimeError("Invalid parameters for enforcer.") + else: + self.init_with_model_and_adapter(model, adapter) + + def init_with_file(self, model_path, policy_path): + """initializes an enforcer with a model file and a policy file.""" + a = FileAdapter(policy_path) + self.init_with_adapter(model_path, a) + + def init_with_adapter(self, model_path, adapter=None): + """initializes an enforcer with a database adapter.""" + m = self.new_model(model_path) + self.init_with_model_and_adapter(m, adapter) + + self.model_path = model_path + + def init_with_model_and_adapter(self, m, adapter=None): + """initializes an enforcer with a model and a database adapter.""" + + if not isinstance(m, Model) or adapter is not None and not isinstance(adapter, Adapter): + raise RuntimeError("Invalid parameters for enforcer.") + + self.adapter = adapter + + self.model = m + self.model.print_model() + self.fm = FunctionMap.load_function_map() + + self._initialize() + + # Do not initialize the full policy when using a filtered adapter + if self.adapter and not self.is_filtered(): + self.load_policy() + + def _initialize(self): + self.rm_map = dict() + self.eft = get_effector(self.model.model["e"]["e"].value) + self.watcher = None + + self.enabled = True + self.auto_save = True + self.auto_build_role_links = True + + self.init_rm_map() + + @staticmethod + def new_model(path="", text=""): + """creates a model.""" + + m = Model() + if len(path) > 0: + m.load_model(path) + else: + m.load_model_from_text(text) + + return m + + def load_model(self): + """reloads the model from the model CONF file. + Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy(). + """ + + self.model = self.new_model() + self.model.load_model(self.model_path) + self.model.print_model() + self.fm = FunctionMap.load_function_map() + + def get_model(self): + """gets the current model.""" + + return self.model + + def set_model(self, m): + """sets the current model.""" + + self.model = m + self.fm = FunctionMap.load_function_map() + + def get_adapter(self): + """gets the current adapter.""" + + return self.adapter + + def set_adapter(self, adapter): + """sets the current adapter.""" + + self.adapter = adapter + + def set_watcher(self, watcher): + """sets the current watcher.""" + + self.watcher = watcher + pass + + def get_role_manager(self): + """gets the current role manager.""" + return self.rm_map['g'] + + def set_role_manager(self, rm): + """sets the current role manager.""" + self.rm_map['g'] = rm + + def set_effector(self, eft): + """sets the current effector.""" + + self.eft = eft + + def clear_policy(self): + """ clears all policy.""" + + self.model.clear_policy() + + def init_rm_map(self): + if 'g' in self.model.model.keys(): + for ptype in self.model.model['g']: + self.rm_map[ptype] = default_role_manager.RoleManager(10) + + def load_policy(self): + """reloads the policy from file/database.""" + + self.model.clear_policy() + self.adapter.load_policy(self.model) + + self.init_rm_map() + self.model.print_policy() + if self.auto_build_role_links: + self.build_role_links() + + def load_filtered_policy(self, filter): + """reloads a filtered policy from file/database.""" + self.model.clear_policy() + + if not hasattr(self.adapter, "is_filtered"): + raise ValueError("filtered policies are not supported by this adapter") + + self.adapter.load_filtered_policy(self.model, filter) + self.init_rm_map() + self.model.print_policy() + if self.auto_build_role_links: + self.build_role_links() + + def load_increment_filtered_policy(self, filter): + """LoadIncrementalFilteredPolicy append a filtered policy from file/database.""" + if not hasattr(self.adapter, "is_filtered"): + raise ValueError("filtered policies are not supported by this adapter") + + self.adapter.load_filtered_policy(self.model, filter) + self.model.print_policy() + if self.auto_build_role_links: + self.build_role_links() + + def is_filtered(self): + """returns true if the loaded policy has been filtered.""" + + return hasattr(self.adapter, "is_filtered") and self.adapter.is_filtered() + + def save_policy(self): + if self.is_filtered(): + raise RuntimeError("cannot save a filtered policy") + + self.adapter.save_policy(self.model) + + if self.watcher: + self.watcher.update() + + def enable_enforce(self, enabled=True): + """changes the enforcing state of Casbin, + when Casbin is disabled, all access will be allowed by the Enforce() function. + """ + + self.enabled = enabled + + def enable_auto_save(self, auto_save): + """controls whether to save a policy rule automatically to the adapter when it is added or removed.""" + self.auto_save = auto_save + + def enable_auto_build_role_links(self, auto_build_role_links): + """controls whether to rebuild the role inheritance relations when a role is added or deleted.""" + self.auto_build_role_links = auto_build_role_links + + def build_role_links(self): + """manually rebuild the role inheritance relations.""" + + for rm in self.rm_map.values(): + rm.clear() + + self.model.build_role_links(self.rm_map) + + def add_named_matching_func(self, ptype, fn): + """add_named_matching_func add MatchingFunc by ptype RoleManager""" + try: + self.rm_map[ptype].add_matching_func(fn) + return True + except: + return False + + def add_named_domain_matching_func(self, ptype, fn): + """add_named_domain_matching_func add MatchingFunc by ptype to RoleManager""" + try: + self.rm_map[ptype].add_domain_matching_func(fn) + return True + except: + return False + + def enforce(self, *rvals): + """decides whether a "subject" can access a "object" with the operation "action", + input parameters are usually: (sub, obj, act). + """ + result, _ = self.enforce_ex(*rvals) + return result + + def enforce_ex(self, *rvals): + """decides whether a "subject" can access a "object" with the operation "action", + input parameters are usually: (sub, obj, act). + return judge result with reason + """ + + if not self.enabled: + return False + + functions = self.fm.get_functions() + + if "g" in self.model.model.keys(): + for key, ast in self.model.model["g"].items(): + rm = ast.rm + functions[key] = generate_g_function(rm) + + if "m" not in self.model.model.keys(): + raise RuntimeError("model is undefined") + + if "m" not in self.model.model["m"].keys(): + raise RuntimeError("model is undefined") + + r_tokens = self.model.model["r"]["r"].tokens + p_tokens = self.model.model["p"]["p"].tokens + + if len(r_tokens) != len(rvals): + raise RuntimeError("invalid request size") + + exp_string = self.model.model["m"]["m"].value + has_eval = util.has_eval(exp_string) + if not has_eval: + expression = self._get_expression(exp_string, functions) + + policy_effects = set() + + r_parameters = dict(zip(r_tokens, rvals)) + + policy_len = len(self.model.model["p"]["p"].policy) + + explain_index = -1 + if not 0 == policy_len: + for i, pvals in enumerate(self.model.model["p"]["p"].policy): + if len(p_tokens) != len(pvals): + raise RuntimeError("invalid policy size") + + p_parameters = dict(zip(p_tokens, pvals)) + parameters = dict(r_parameters, **p_parameters) + + if util.has_eval(exp_string): + rule_names = util.get_eval_value(exp_string) + rules = [util.escape_assertion(p_parameters[rule_name]) for rule_name in rule_names] + exp_with_rule = util.replace_eval(exp_string, rules) + expression = self._get_expression(exp_with_rule, functions) + + result = expression.eval(parameters) + + if isinstance(result, bool): + if not result: + policy_effects.add(Effector.INDETERMINATE) + continue + elif isinstance(result, float): + if 0 == result: + policy_effects.add(Effector.INDETERMINATE) + continue + else: + raise RuntimeError("matcher result should be bool, int or float") + + if "p_eft" in parameters.keys(): + eft = parameters["p_eft"] + if "allow" == eft: + policy_effects.add(Effector.ALLOW) + elif "deny" == eft: + policy_effects.add(Effector.DENY) + else: + policy_effects.add(Effector.INDETERMINATE) + else: + policy_effects.add(Effector.ALLOW) + + if self.eft.intermediate_effect(policy_effects) != Effector.INDETERMINATE: + explain_index = i + break + + else: + if has_eval: + raise RuntimeError("please make sure rule exists in policy when using eval() in matcher") + + parameters = r_parameters.copy() + + for token in self.model.model["p"]["p"].tokens: + parameters[token] = "" + + result = expression.eval(parameters) + + if result: + policy_effects.add(Effector.ALLOW) + else: + policy_effects.add(Effector.INDETERMINATE) + + final_effect = self.eft.final_effect(policy_effects) + result = effect_to_bool(final_effect) + + # Log request. + + req_str = "Request: " + req_str = req_str + ", ".join([str(v) for v in rvals]) + + req_str = req_str + " ---> %s" % result + if result: + self.logger.info(req_str) + else: + # leaving this in error for now, if it's very noise this can be changed to info or debug + self.logger.error(req_str) + + explain_rule = [] + if explain_index != -1 and explain_index < policy_len: + explain_rule = self.model.model["p"]["p"].policy[explain_index] + + return result, explain_rule + + @staticmethod + def _get_expression(expr, functions=None): + expr = expr.replace("&&", "and") + expr = expr.replace("||", "or") + expr = expr.replace("!", "not") + + return SimpleEval(expr, functions) diff --git a/utils/casbin/distributed_enforcer.py b/utils/casbin/distributed_enforcer.py new file mode 100644 index 0000000..d1e6406 --- /dev/null +++ b/utils/casbin/distributed_enforcer.py @@ -0,0 +1,132 @@ +from utils.casbin import SyncedEnforcer +import logging + +from utils.casbin.persist import batch_adapter +from utils.casbin.model.policy_op import PolicyOp +from utils.casbin.persist.adapters import update_adapter + + +class DistributedEnforcer(SyncedEnforcer): + """DistributedEnforcer wraps SyncedEnforcer for dispatcher.""" + + def __init__(self, model=None, adapter=None): + self.logger = logging.getLogger() + SyncedEnforcer.__init__(self, model, adapter) + + def add_policy_self(self, should_persist, sec, ptype, rules): + """ + AddPolicySelf provides a method for dispatcher to add authorization rules to the current policy. + The function returns the rules affected and error. + """ + + no_exists_policy = [] + for rule in rules: + if not self.get_model().has_policy(sec, ptype, rule): + no_exists_policy.append(rule) + + if should_persist: + try: + if isinstance(self.adapter, batch_adapter): + self.adapter.add_policies(sec, ptype, rules) + except Exception as e: + self.logger.log("An error occurred: " + e) + + self.get_model().add_policies(sec, ptype, no_exists_policy) + + if sec == "g": + try: + self.build_incremental_role_links(PolicyOp.Policy_add, ptype, no_exists_policy) + except Exception as e: + self.logger.log("An exception occurred: " + e) + return no_exists_policy + + return no_exists_policy + + def remove_policy_self(self, should_persist, sec, ptype, rules): + """ + remove_policy_self provides a method for dispatcher to remove policies from current policy. + The function returns the rules affected and error. + """ + if(should_persist): + try: + if(isinstance(self.adapter, batch_adapter)): + self.adapter.remove_policy(sec, ptype, rules) + except Exception as e: + self.logger.log("An exception occurred: " + e) + + effected = self.get_model().remove_policies_with_effected(sec, ptype, rules) + + if sec == "g": + try: + self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, rules) + except Exception as e: + self.logger.log("An exception occurred: " + e) + return effected + + return effected + + def remove_filtered_policy_self(self, should_persist, sec, ptype, field_index, *field_values): + """ + remove_filtered_policy_self provides a method for dispatcher to remove an authorization + rule from the current policy,field filters can be specified. + The function returns the rules affected and error. + """ + if should_persist: + try: + self.adapter.remove_filtered_policy(sec, ptype, field_index, field_values) + except Exception as e: + self.logger.log("An exception occurred: " + e) + + effects = self.get_model().remove_filtered_policy_returns_effects(sec, ptype, field_index, field_values) + + if sec == "g": + try: + self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, effects) + except Exception as e: + self.logger.log("An exception occurred: " + e) + return effects + + return effects + + def clear_policy_self(self, should_persist): + """ + clear_policy_self provides a method for dispatcher to clear all rules from the current policy. + """ + if should_persist: + try: + self.adapter.save_policy(None) + except Exception as e: + self.logger.log("An exception occurred: " + e) + + self.get_model().clear_policy() + + def update_policy_self(self, should_persist, sec, ptype, old_rule, new_rule): + """ + update_policy_self provides a method for dispatcher to update an authorization rule from the current policy. + """ + if should_persist: + try: + if isinstance(self.adapter, update_adapter): + self.adapter.update_policy(sec, ptype, old_rule, new_rule) + except Exception as e: + self.logger.log("An exception occurred: " + e) + return False + + rule_updated = self.get_model().update_policy(sec, ptype, old_rule, new_rule) + + if not rule_updated: + return False + + if sec == "g": + try: + self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, [old_rule]) + except Exception as e: + return False + + try: + self.build_incremental_role_links(PolicyOp.Policy_add, ptype, [new_rule]) + except Exception as e: + return False + + + return True \ No newline at end of file diff --git a/utils/casbin/effect/__init__.py b/utils/casbin/effect/__init__.py new file mode 100644 index 0000000..5163437 --- /dev/null +++ b/utils/casbin/effect/__init__.py @@ -0,0 +1,24 @@ +from .default_effectors import AllowOverrideEffector, DenyOverrideEffector, AllowAndDenyEffector, PriorityEffector +from .effector import Effector + +def get_effector(expr): + ''' creates an effector based on the current policy effect expression ''' + + if expr == "some(where (p_eft == allow))": + return AllowOverrideEffector() + elif expr == "!some(where (p_eft == deny))": + return DenyOverrideEffector() + elif expr == "some(where (p_eft == allow)) && !some(where (p_eft == deny))": + return AllowAndDenyEffector() + elif expr == "priority(p_eft) || deny": + return PriorityEffector() + else: + raise RuntimeError("unsupported effect") + +def effect_to_bool(effect): + """ """ + if effect == Effector.ALLOW: + return True + if effect == Effector.DENY: + return False + raise RuntimeError("effect can't be converted to boolean") \ No newline at end of file diff --git a/utils/casbin/effect/default_effectors.py b/utils/casbin/effect/default_effectors.py new file mode 100644 index 0000000..77d65eb --- /dev/null +++ b/utils/casbin/effect/default_effectors.py @@ -0,0 +1,61 @@ +from .effector import Effector + +class AllowOverrideEffector(Effector): + + def intermediate_effect(self, effects): + """ returns a intermediate effect based on the matched effects of the enforcer """ + if Effector.ALLOW in effects: + return Effector.ALLOW + return Effector.INDETERMINATE + + def final_effect(self, effects): + """ returns the final effect based on the matched effects of the enforcer """ + if Effector.ALLOW in effects: + return Effector.ALLOW + return Effector.DENY + +class DenyOverrideEffector(Effector): + + def intermediate_effect(self, effects): + """ returns a intermediate effect based on the matched effects of the enforcer """ + if Effector.DENY in effects: + return Effector.DENY + return Effector.INDETERMINATE + + def final_effect(self, effects): + """ returns the final effect based on the matched effects of the enforcer """ + if Effector.DENY in effects: + return Effector.DENY + return Effector.ALLOW + +class AllowAndDenyEffector(Effector): + + def intermediate_effect(self, effects): + """ returns a intermediate effect based on the matched effects of the enforcer """ + if Effector.DENY in effects: + return Effector.DENY + return Effector.INDETERMINATE + + def final_effect(self, effects): + """ returns the final effect based on the matched effects of the enforcer """ + if Effector.DENY in effects or Effector.ALLOW not in effects: + return Effector.DENY + return Effector.ALLOW + +class PriorityEffector(Effector): + + def intermediate_effect(self, effects): + """ returns a intermediate effect based on the matched effects of the enforcer """ + if Effector.ALLOW in effects: + return Effector.ALLOW + if Effector.DENY in effects: + return Effector.DENY + return Effector.INDETERMINATE + + def final_effect(self, effects): + """ returns the final effect based on the matched effects of the enforcer """ + if Effector.ALLOW in effects: + return Effector.ALLOW + if Effector.DENY in effects: + return Effector.DENY + return Effector.DENY diff --git a/utils/casbin/effect/effector.py b/utils/casbin/effect/effector.py new file mode 100644 index 0000000..13b0888 --- /dev/null +++ b/utils/casbin/effect/effector.py @@ -0,0 +1,19 @@ +class Effector: + """Effector is the interface for Casbin effectors.""" + + ALLOW = 0 + + INDETERMINATE = 1 + + DENY = 2 + + def intermediate_effect(self, effects): + """ returns a intermediate effect based on the matched effects of the enforcer """ + pass + + def final_effect(self, effects): + """ returns the final effect based on the matched effects of the enforcer """ + pass + + + diff --git a/utils/casbin/enforcer.py b/utils/casbin/enforcer.py new file mode 100644 index 0000000..678f682 --- /dev/null +++ b/utils/casbin/enforcer.py @@ -0,0 +1,211 @@ +from utils.casbin.management_enforcer import ManagementEnforcer +from utils.casbin.util import join_slice, set_subtract + +class Enforcer(ManagementEnforcer): + """ + Enforcer = ManagementEnforcer + RBAC_API + RBAC_WITH_DOMAIN_API + """ + + """creates an enforcer via file or DB. + + File: + e = casbin.Enforcer("path/to/basic_model.conf", "path/to/basic_policy.csv") + MySQL DB: + a = mysqladapter.DBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/") + e = casbin.Enforcer("path/to/basic_model.conf", a) + """ + + def get_roles_for_user(self, name): + """ gets the roles that a user has. """ + return self.model.model["g"]["g"].rm.get_roles(name) + + def get_users_for_role(self, name): + """ gets the users that has a role. """ + return self.model.model["g"]["g"].rm.get_users(name) + + def has_role_for_user(self, name, role): + """ determines whether a user has a role. """ + roles = self.get_roles_for_user(name) + return any(r == role for r in roles) + + def add_role_for_user(self, user, role): + """ + adds a role for a user. + Returns false if the user already has the role (aka not affected). + """ + return self.add_grouping_policy(user, role) + + def delete_role_for_user(self, user, role): + """ + deletes a role for a user. + Returns false if the user does not have the role (aka not affected). + """ + return self.remove_grouping_policy(user, role) + + def delete_roles_for_user(self, user): + """ + deletes all roles for a user. + Returns false if the user does not have any roles (aka not affected). + """ + return self.remove_filtered_grouping_policy(0, user) + + def delete_user(self, user): + """ + deletes a user. + Returns false if the user does not exist (aka not affected). + """ + res1 = self.remove_filtered_grouping_policy(0, user) + + res2 = self.remove_filtered_policy(0, user) + return res1 or res2 + + def delete_role(self, role): + """ + deletes a role. + Returns false if the role does not exist (aka not affected). + """ + res1 = self.remove_filtered_grouping_policy(1, role) + + res2 = self.remove_filtered_policy(0, role) + return res1 or res2 + + def delete_permission(self, *permission): + """ + deletes a permission. + Returns false if the permission does not exist (aka not affected). + """ + return self.remove_filtered_policy(1, *permission) + + def add_permission_for_user(self, user, *permission): + """ + adds a permission for a user or role. + Returns false if the user or role already has the permission (aka not affected). + """ + return self.add_policy(join_slice(user, *permission)) + + def delete_permission_for_user(self, user, *permission): + """ + deletes a permission for a user or role. + Returns false if the user or role does not have the permission (aka not affected). + """ + return self.remove_policy(join_slice(user, *permission)) + + def delete_permissions_for_user(self, user): + """ + deletes permissions for a user or role. + Returns false if the user or role does not have any permissions (aka not affected). + """ + return self.remove_filtered_policy(0, user) + + def get_permissions_for_user(self, user): + """ + gets permissions for a user or role. + """ + return self.get_filtered_policy(0, user) + + def has_permission_for_user(self, user, *permission): + """ + determines whether a user has a permission. + """ + return self.has_policy(join_slice(user, *permission)) + + def get_implicit_roles_for_user(self, name, domain=None): + """ + gets implicit roles that a user has. + Compared to get_roles_for_user(), this function retrieves indirect roles besides direct roles. + For example: + g, alice, role:admin + g, role:admin, role:user + + get_roles_for_user("alice") can only get: ["role:admin"]. + But get_implicit_roles_for_user("alice") will get: ["role:admin", "role:user"]. + """ + res = [] + queue = [name] + + while queue: + name = queue.pop(0) + + for rm in self.rm_map.values(): + roles = rm.get_roles(name, domain) + for r in roles: + if r not in res: + res.append(r) + queue.append(r) + + return res + + def get_implicit_permissions_for_user(self, user, domain=None): + """ + gets implicit permissions for a user or role. + Compared to get_permissions_for_user(), this function retrieves permissions for inherited roles. + For example: + p, admin, data1, read + p, alice, data2, read + g, alice, admin + + get_permissions_for_user("alice") can only get: [["alice", "data2", "read"]]. + But get_implicit_permissions_for_user("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]]. + """ + roles = self.get_implicit_roles_for_user(user, domain) + + roles.insert(0, user) + + res = [] + for role in roles: + if domain: + permissions = self.get_permissions_for_user_in_domain(role, domain) + else: + permissions = self.get_permissions_for_user(role) + + res.extend(permissions) + + return res + + def get_implicit_users_for_permission(self, *permission): + """ + gets implicit users for a permission. + For example: + p, admin, data1, read + p, bob, data1, read + g, alice, admin + + get_implicit_users_for_permission("data1", "read") will get: ["alice", "bob"]. + Note: only users will be returned, roles (2nd arg in "g") will be excluded. + """ + subjects = self.get_all_subjects() + roles = self.get_all_roles() + + users = set_subtract(subjects, roles) + + res = list() + for user in users: + req = join_slice(user, *permission) + allowed = self.enforce(*req) + + if allowed: + res.append(user) + + return res + + def get_roles_for_user_in_domain(self, name, domain): + """gets the roles that a user has inside a domain.""" + return self.model.model['g']['g'].rm.get_roles(name, domain) + + def get_users_for_role_in_domain(self, name, domain): + """gets the users that has a role inside a domain.""" + return self.model.model['g']['g'].rm.get_users(name, domain) + + def add_role_for_user_in_domain(self, user, role, domain): + """adds a role for a user inside a domain.""" + """Returns false if the user already has the role (aka not affected).""" + return self.add_grouping_policy(user, role, domain) + + def delete_roles_for_user_in_domain(self, user, role, domain): + """deletes a role for a user inside a domain.""" + """Returns false if the user does not have any roles (aka not affected).""" + return self.remove_filtered_grouping_policy(0, user, role, domain) + + def get_permissions_for_user_in_domain(self, user, domain): + """gets permissions for a user or role inside domain.""" + return self.get_filtered_policy(0, user, domain) diff --git a/utils/casbin/internal_enforcer.py b/utils/casbin/internal_enforcer.py new file mode 100644 index 0000000..0b9a1ac --- /dev/null +++ b/utils/casbin/internal_enforcer.py @@ -0,0 +1,122 @@ +from utils.casbin.core_enforcer import CoreEnforcer +from utils.casbin.model.policy_op import PolicyOp + +class InternalEnforcer(CoreEnforcer): + """ + InternalEnforcer = CoreEnforcer + Internal API. + """ + + def _add_policy(self, sec, ptype, rule): + """adds a rule to the current policy.""" + rule_added = self.model.add_policy(sec, ptype, rule) + if not rule_added: + return rule_added + + if self.adapter and self.auto_save: + if self.adapter.add_policy(sec, ptype, rule) is False: + return False + + if self.watcher: + self.watcher.update() + + return rule_added + + def _add_policies(self,sec,ptype,rules): + """adds rules to the current policy.""" + rules_added = self.model.add_policies(sec, ptype, rules) + if not rules_added: + return rules_added + + if self.adapter and self.auto_save: + if hasattr(self.adapter,'add_policies') is False: + return False + + if self.adapter.add_policies(sec, ptype, rules) is False: + return False + + if self.watcher: + self.watcher.update() + + return rules_added + + def _update_policy(self, sec, ptype, old_rule, new_rule): + """updates a rule from the current policy.""" + rule_updated = self.model.update_policy(sec, ptype, old_rule, new_rule) + + if not rule_updated: + return rule_updated + + if self.adapter and self.auto_save: + + if self.adapter.update_policy(sec, ptype, old_rule, new_rule) is False: + return False + + if self.watcher: + self.watcher.update() + + return rule_updated + + def _update_policies(self, sec, ptype, old_rules, new_rules): + """updates rules from the current policy.""" + rules_updated = self.model.update_policies(sec, ptype, old_rules, new_rules) + + if not rules_updated: + return rules_updated + + if self.adapter and self.auto_save: + + if self.adapter.update_policies(sec, ptype, old_rules, new_rules) is False: + return False + + if self.watcher: + self.watcher.update() + + return rules_updated + + def _remove_policy(self, sec, ptype, rule): + """removes a rule from the current policy.""" + rule_removed = self.model.remove_policy(sec, ptype, rule) + if not rule_removed: + return rule_removed + + if self.adapter and self.auto_save: + if self.adapter.remove_policy(sec, ptype, rule) is False: + return False + + if self.watcher: + self.watcher.update() + + return rule_removed + + def _remove_policies(self, sec, ptype, rules): + """RemovePolicies removes policy rules from the model.""" + rules_removed = self.model.remove_policies(sec, ptype, rules) + if not rules_removed: + return rules_removed + + if self.adapter and self.auto_save: + if hasattr(self.adapter,'remove_policies') is False: + return False + + if self.adapter.remove_policies(sec, ptype, rules) is False: + return False + + if self.watcher: + self.watcher.update() + + return rules_removed + + def _remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """removes rules based on field filters from the current policy.""" + rule_removed = self.model.remove_filtered_policy(sec, ptype, field_index, *field_values) + if not rule_removed: + return rule_removed + + if self.adapter and self.auto_save: + if self.adapter.remove_filtered_policy(sec, ptype, field_index, *field_values) is False: + return False + + if self.watcher: + self.watcher.update() + + return rule_removed \ No newline at end of file diff --git a/utils/casbin/management_enforcer.py b/utils/casbin/management_enforcer.py new file mode 100644 index 0000000..78851ed --- /dev/null +++ b/utils/casbin/management_enforcer.py @@ -0,0 +1,271 @@ +from utils.casbin.internal_enforcer import InternalEnforcer + +class ManagementEnforcer(InternalEnforcer): + """ + ManagementEnforcer = InternalEnforcer + Management API. + """ + + def get_all_subjects(self): + """gets the list of subjects that show up in the current policy.""" + return self.get_all_named_subjects('p') + + def get_all_named_subjects(self, ptype): + """gets the list of subjects that show up in the current named policy.""" + return self.model.get_values_for_field_in_policy('p', ptype, 0) + + def get_all_objects(self): + """gets the list of objects that show up in the current policy.""" + return self.get_all_named_objects('p') + + def get_all_named_objects(self, ptype): + """gets the list of objects that show up in the current named policy.""" + return self.model.get_values_for_field_in_policy('p', ptype, 1) + + def get_all_actions(self): + """gets the list of actions that show up in the current policy.""" + return self.get_all_named_actions('p') + + def get_all_named_actions(self, ptype): + """gets the list of actions that show up in the current named policy.""" + return self.model.get_values_for_field_in_policy('p', ptype, 2) + + def get_all_roles(self): + """gets the list of roles that show up in the current named policy.""" + return self.get_all_named_roles('g') + + def get_all_named_roles(self, ptype): + """gets all the authorization rules in the policy.""" + return self.model.get_values_for_field_in_policy('g', ptype, 1) + + def get_policy(self): + """gets all the authorization rules in the policy.""" + return self.get_named_policy('p') + + def get_filtered_policy(self, field_index, *field_values): + """gets all the authorization rules in the policy, field filters can be specified.""" + return self.get_filtered_named_policy('p', field_index, *field_values) + + def get_named_policy(self, ptype): + """gets all the authorization rules in the named policy.""" + return self.model.get_policy('p', ptype) + + def get_filtered_named_policy(self, ptype, field_index, *field_values): + """gets all the authorization rules in the named policy, field filters can be specified.""" + return self.model.get_filtered_policy('p', ptype, field_index, *field_values) + + def get_grouping_policy(self): + """gets all the role inheritance rules in the policy.""" + return self.get_named_grouping_policy('g') + + def get_filtered_grouping_policy(self, field_index, *field_values): + """gets all the role inheritance rules in the policy, field filters can be specified.""" + return self.get_filtered_named_grouping_policy("g", field_index, *field_values) + + def get_named_grouping_policy(self, ptype): + """gets all the role inheritance rules in the policy.""" + return self.model.get_policy('g', ptype) + + def get_filtered_named_grouping_policy(self, ptype, field_index, *field_values): + """gets all the role inheritance rules in the policy, field filters can be specified.""" + return self.model.get_filtered_policy('g', ptype, field_index, *field_values) + + def has_policy(self, *params): + """determines whether an authorization rule exists.""" + return self.has_named_policy('p', *params) + + def has_named_policy(self, ptype, *params): + """determines whether a named authorization rule exists.""" + if len(params) == 1 and isinstance(params[0], list): + str_slice = params[0] + return self.model.has_policy('p', ptype, str_slice) + + return self.model.has_policy('p', ptype, list(params)) + + def add_policy(self, *params): + """adds an authorization rule to the current policy. + + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + return self.add_named_policy('p', *params) + + def add_policies(self,rules): + """adds authorization rules to the current policy. + + If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. + Otherwise the function returns true for the corresponding rule by adding the new rule. + """ + return self.add_named_policies('p',rules) + + def add_named_policy(self, ptype, *params): + """adds an authorization rule to the current named policy. + + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + + if len(params) == 1 and isinstance(params[0], list): + str_slice = params[0] + rule_added = self._add_policy('p', ptype, str_slice) + else: + rule_added = self._add_policy('p', ptype, list(params)) + + return rule_added + + def add_named_policies(self,ptype,rules): + """adds authorization rules to the current named policy. + + If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. + Otherwise the function returns true for the corresponding by adding the new rule.""" + return self._add_policies('p',ptype,rules) + + def update_policy(self, old_rule, new_rule): + """updates an authorization rule from the current policy.""" + return self.update_named_policy('p', old_rule, new_rule) + + def update_policies(self, old_rules, new_rules): + """updates authorization rules from the current policy.""" + return self.update_named_policies('p', old_rules, new_rules) + + def update_named_policy(self, ptype, old_rule, new_rule): + """updates an authorization rule from the current named policy.""" + return self._update_policy('p', ptype, old_rule, new_rule) + + def update_named_policies(self, ptype, old_rules, new_rules): + """updates authorization rules from the current named policy.""" + return self._update_policies('p', ptype, old_rules, new_rules) + + def remove_policy(self, *params): + """removes an authorization rule from the current policy.""" + return self.remove_named_policy('p', *params) + + def remove_policies(self,rules): + """removes authorization rules from the current policy.""" + return self.remove_named_policies('p',rules) + + def remove_filtered_policy(self, field_index, *field_values): + """removes an authorization rule from the current policy, field filters can be specified.""" + return self.remove_filtered_named_policy('p', field_index, *field_values) + + def remove_named_policy(self, ptype, *params): + """removes an authorization rule from the current named policy.""" + + if len(params) == 1 and isinstance(params[0], list): + str_slice = params[0] + rule_removed = self._remove_policy('p', ptype, str_slice) + else: + rule_removed = self._remove_policy('p', ptype, list(params)) + + return rule_removed + + def remove_named_policies(self,ptype,rules): + """removes authorization rules from the current named policy.""" + return self._remove_policies('p',ptype,rules) + + def remove_filtered_named_policy(self, ptype, field_index, *field_values): + """removes an authorization rule from the current named policy, field filters can be specified.""" + return self._remove_filtered_policy('p', ptype, field_index, *field_values) + + def has_grouping_policy(self, *params): + """determines whether a role inheritance rule exists.""" + + return self.has_named_grouping_policy('g', *params) + + def has_named_grouping_policy(self, ptype, *params): + """determines whether a named role inheritance rule exists.""" + + if len(params) == 1 and isinstance(params[0], list): + str_slice = params[0] + return self.model.has_policy('g', ptype, str_slice) + + return self.model.has_policy('g', ptype, list(params)) + + def add_grouping_policy(self, *params): + """adds a role inheritance rule to the current policy. + + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + return self.add_named_grouping_policy('g', *params) + + def add_grouping_policies(self,rules): + """adds role inheritance rulea to the current policy. + + If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. + Otherwise the function returns true for the corresponding policy rule by adding the new rule. + """ + return self.add_named_grouping_policies('g',rules) + + def add_named_grouping_policy(self, ptype, *params): + """adds a named role inheritance rule to the current policy. + + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + + if len(params) == 1 and isinstance(params[0], list): + str_slice = params[0] + rule_added = self._add_policy('g', ptype, str_slice) + else: + rule_added = self._add_policy('g', ptype, list(params)) + + if self.auto_build_role_links: + self.build_role_links() + return rule_added + + def add_named_grouping_policies(self,ptype,rules): + """"adds named role inheritance rules to the current policy. + + If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. + Otherwise the function returns true for the corresponding policy rule by adding the new rule.""" + rules_added = self._add_policies('g',ptype,rules) + if self.auto_build_role_links: + self.build_role_links() + + return rules_added + + def remove_grouping_policy(self, *params): + """removes a role inheritance rule from the current policy.""" + return self.remove_named_grouping_policy('g', *params) + + def remove_grouping_policies(self,rules): + """removes role inheritance rulea from the current policy.""" + return self.remove_named_grouping_policies('g',rules) + + def remove_filtered_grouping_policy(self, field_index, *field_values): + """removes a role inheritance rule from the current policy, field filters can be specified.""" + return self.remove_filtered_named_grouping_policy('g', field_index, *field_values) + + def remove_named_grouping_policy(self, ptype, *params): + """removes a role inheritance rule from the current named policy.""" + + if len(params) == 1 and isinstance(params[0], list): + str_slice = params[0] + rule_removed = self._remove_policy('g', ptype, str_slice) + else: + rule_removed = self._remove_policy('g', ptype, list(params)) + + if self.auto_build_role_links: + self.build_role_links() + return rule_removed + + def remove_named_grouping_policies(self,ptype,rules): + """ removes role inheritance rules from the current named policy.""" + rules_removed = self._remove_policies('g',ptype,rules) + + if self.auto_build_role_links: + self.build_role_links() + + return rules_removed + + def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_values): + """removes a role inheritance rule from the current named policy, field filters can be specified.""" + rule_removed = self._remove_filtered_policy('g', ptype, field_index, *field_values) + + if self.auto_build_role_links: + self.build_role_links() + return rule_removed + + def add_function(self, name, func): + """adds a customized function.""" + self.fm.add_function(name, func) \ No newline at end of file diff --git a/utils/casbin/model/__init__.py b/utils/casbin/model/__init__.py new file mode 100644 index 0000000..2bdda40 --- /dev/null +++ b/utils/casbin/model/__init__.py @@ -0,0 +1,4 @@ +from .assertion import Assertion +from .model import Model +from .policy import Policy +from .function import FunctionMap diff --git a/utils/casbin/model/assertion.py b/utils/casbin/model/assertion.py new file mode 100644 index 0000000..3585c65 --- /dev/null +++ b/utils/casbin/model/assertion.py @@ -0,0 +1,47 @@ +import logging +from utils.casbin.model.policy_op import PolicyOp + + +class Assertion: + def __init__(self): + self.logger = logging.getLogger() + self.key = "" + self.value = "" + self.tokens = [] + self.policy = [] + self.rm = None + + def build_role_links(self, rm): + self.rm = rm + count = self.value.count("_") + if count < 2: + raise RuntimeError('the number of "_" in role definition should be at least 2') + + for rule in self.policy: + if len(rule) < count: + pass + # raise RuntimeError("grouping policy elements do not meet role definition") + if len(rule) > count: + rule = rule[:count] + + self.rm.add_link(*rule[:count]) + + self.logger.info("Role links for: {}".format(self.key)) + self.rm.print_roles() + + def build_incremental_role_links(self, rm, op, rules): + self.rm = rm + count = self.value.count("_") + if count < 2: + raise RuntimeError('the number of "_" in role definition should be at least 2') + for rule in rules: + if len(rule) < count: + raise TypeError("grouping policy elements do not meet role definition") + if len(rule) > count: + rule = rule[:count] + if op == PolicyOp.Policy_add: + rm.add_link(rule[0], rule[1], *rule[2:]) + elif op == PolicyOp.Policy_remove: + rm.delete_link(rule[0], rule[1], *rule[2:]) + else: + raise TypeError("Invalid operation: " + str(op)) diff --git a/utils/casbin/model/function.py b/utils/casbin/model/function.py new file mode 100644 index 0000000..731a588 --- /dev/null +++ b/utils/casbin/model/function.py @@ -0,0 +1,22 @@ +from utils.casbin import util + + +class FunctionMap: + fm = dict() + + def add_function(self, name, func): + self.fm[name] = func + + @staticmethod + def load_function_map(): + fm = FunctionMap() + fm.add_function("keyMatch", util.key_match_func) + fm.add_function("keyMatch2", util.key_match2_func) + fm.add_function("regexMatch", util.regex_match_func) + fm.add_function("ipMatch", util.ip_match_func) + fm.add_function("globMatch", util.glob_match_func) + + return fm + + def get_functions(self): + return self.fm diff --git a/utils/casbin/model/model.py b/utils/casbin/model/model.py new file mode 100644 index 0000000..483e8d6 --- /dev/null +++ b/utils/casbin/model/model.py @@ -0,0 +1,80 @@ +from . import Assertion +from utils.casbin import util, config +from .policy import Policy + +class Model(Policy): + + section_name_map = { + 'r': 'request_definition', + 'p': 'policy_definition', + 'g': 'role_definition', + 'e': 'policy_effect', + 'm': 'matchers', + } + + def _load_assertion(self, cfg, sec, key): + value = cfg.get(self.section_name_map[sec] + "::" + key) + + return self.add_def(sec, key, value) + + def add_def(self, sec, key, value): + if value == "": + return + + ast = Assertion() + ast.key = key + ast.value = value + + if "r" == sec or "p" == sec: + ast.tokens = ast.value.split(",") + for i,token in enumerate(ast.tokens): + ast.tokens[i] = key + "_" + token.strip() + else: + ast.value = util.remove_comments(util.escape_assertion(ast.value)) + + if sec not in self.model.keys(): + self.model[sec] = {} + + self.model[sec][key] = ast + + return True + + def _get_key_suffix(self, i): + if i == 1: + return "" + + return str(i) + + def _load_section(self, cfg, sec): + i = 1 + while True: + if not self._load_assertion(cfg, sec, sec + self._get_key_suffix(i)): + break + else: + i = i + 1 + + def load_model(self, path): + cfg = config.Config.new_config(path) + + self._load_section(cfg, "r") + self._load_section(cfg, "p") + self._load_section(cfg, "e") + self._load_section(cfg, "m") + + self._load_section(cfg, "g") + + def load_model_from_text(self, text): + cfg = config.Config.new_config_from_text(text) + + self._load_section(cfg, "r") + self._load_section(cfg, "p") + self._load_section(cfg, "e") + self._load_section(cfg, "m") + + self._load_section(cfg, "g") + + def print_model(self): + self.logger.info("Model:") + for k, v in self.model.items(): + for i, j in v.items(): + self.logger.info("%s.%s: %s", k, i, j.value) diff --git a/utils/casbin/model/policy.py b/utils/casbin/model/policy.py new file mode 100644 index 0000000..d56301f --- /dev/null +++ b/utils/casbin/model/policy.py @@ -0,0 +1,190 @@ +import logging + +class Policy: + def __init__(self): + self.logger = logging.getLogger() + self.model = {} + + def build_role_links(self, rm_map): + """initializes the roles in RBAC.""" + + if "g" not in self.model.keys(): + return + + for ptype, ast in self.model["g"].items(): + rm = rm_map[ptype] + ast.build_role_links(rm) + + def build_incremental_role_links(self, rm, op, sec, ptype, rules): + if sec == "g": + self.model.get(sec).get(ptype).build_incremental_role_links(rm, op, rules) + + def print_policy(self): + """Log using info""" + + self.logger.info("Policy:") + for sec in ["p", "g"]: + if sec not in self.model.keys(): + continue + + for key, ast in self.model[sec].items(): + self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy)) + + def clear_policy(self): + """clears all current policy.""" + + for sec in ["p", "g"]: + if sec not in self.model.keys(): + continue + + for key in self.model[sec].keys(): + self.model[sec][key].policy = [] + + def get_policy(self, sec, ptype): + """gets all rules in a policy.""" + + return self.model[sec][ptype].policy + + def get_filtered_policy(self, sec, ptype, field_index, *field_values): + """gets rules based on field filters from a policy.""" + return [ + rule for rule in self.model[sec][ptype].policy + if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values)) + ] + + def has_policy(self, sec, ptype, rule): + """determines whether a model has the specified policy rule.""" + if sec not in self.model.keys(): + return False + if ptype not in self.model[sec]: + return False + + return rule in self.model[sec][ptype].policy + + def add_policy(self, sec, ptype, rule): + """adds a policy rule to the model.""" + + if not self.has_policy(sec, ptype, rule): + self.model[sec][ptype].policy.append(rule) + return True + + return False + + def add_policies(self,sec,ptype,rules): + """adds policy rules to the model.""" + + for rule in rules: + if self.has_policy(sec,ptype,rule): + return False + + for rule in rules: + self.model[sec][ptype].policy.append(rule) + + return True + + def update_policy(self, sec, ptype, old_rule, new_rule): + """update a policy rule from the model.""" + + if not self.has_policy(sec, ptype, old_rule): + return False + + return self.remove_policy(sec, ptype, old_rule) and self.add_policy(sec, ptype, new_rule) + + def update_policies(self, sec, ptype, old_rules, new_rules): + """update policy rules from the model.""" + + for rule in old_rules: + if not self.has_policy(sec, ptype, rule): + return False + + return self.remove_policies(sec, ptype, old_rules) and self.add_policies(sec, ptype, new_rules) + + def remove_policy(self, sec, ptype, rule): + """removes a policy rule from the model.""" + if not self.has_policy(sec, ptype, rule): + return False + + self.model[sec][ptype].policy.remove(rule) + + return rule not in self.model[sec][ptype].policy + + def remove_policies(self, sec, ptype, rules): + """RemovePolicies removes policy rules from the model.""" + + for rule in rules: + if not self.has_policy(sec,ptype,rule): + return False + self.model[sec][ptype].policy.remove(rule) + if rule in self.model[sec][ptype].policy: + return False + + return True + + def remove_policies_with_effected(self, sec, ptype, rules): + effected = [] + for rule in rules: + if self.has_policy(sec, ptype, rule): + effected.append(rule) + self.remove_policy(sec, ptype, rule) + + return effected + + def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values): + """ + remove_filtered_policy_returns_effects removes policy rules based on field filters from the model. + """ + tmp = [] + effects = [] + + if(len(field_values) == 0): + return [] + if sec not in self.model.keys(): + return [] + if ptype not in self.model[sec]: + return [] + + for rule in self.model[sec][ptype].policy: + if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values[0])): + effects.append(rule) + else: + tmp.append(rule) + + self.model[sec][ptype].policy = tmp + + return effects + + + def remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """removes policy rules based on field filters from the model.""" + tmp = [] + res = False + + if sec not in self.model.keys(): + return res + if ptype not in self.model[sec]: + return res + + for rule in self.model[sec][ptype].policy: + if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values)): + res = True + else: + tmp.append(rule) + + self.model[sec][ptype].policy = tmp + + return res + + def get_values_for_field_in_policy(self, sec, ptype, field_index): + """gets all values for a field for all rules in a policy, duplicated values are removed.""" + values = [] + if sec not in self.model.keys(): + return values + if ptype not in self.model[sec]: + return values + + for rule in self.model[sec][ptype].policy: + value = rule[field_index] + if value not in values: + values.append(value) + + return values diff --git a/utils/casbin/model/policy_op.py b/utils/casbin/model/policy_op.py new file mode 100644 index 0000000..de4b146 --- /dev/null +++ b/utils/casbin/model/policy_op.py @@ -0,0 +1,5 @@ +import enum + +class PolicyOp(enum.Enum): + Policy_add = 1 + Policy_remove = 2 \ No newline at end of file diff --git a/utils/casbin/persist/__init__.py b/utils/casbin/persist/__init__.py new file mode 100644 index 0000000..0c22ec8 --- /dev/null +++ b/utils/casbin/persist/__init__.py @@ -0,0 +1,4 @@ +from .adapter import * +from .adapter_filtered import * +from .batch_adapter import * +from .adapters import * \ No newline at end of file diff --git a/utils/casbin/persist/adapter.py b/utils/casbin/persist/adapter.py new file mode 100644 index 0000000..b290be2 --- /dev/null +++ b/utils/casbin/persist/adapter.py @@ -0,0 +1,46 @@ +def load_policy_line(line, model): + """loads a text line as a policy rule to model.""" + + if line == "": + return + + if line[:1] == "#": + return + + tokens = line.split(", ") + key = tokens[0] + sec = key[0] + + if sec not in model.model.keys(): + return + + if key not in model.model[sec].keys(): + return + + model.model[sec][key].policy.append(tokens[1:]) + + +class Adapter: + """the interface for Casbin adapters.""" + + def load_policy(self, model): + """loads all policy rules from the storage.""" + pass + + def save_policy(self, model): + """saves all policy rules to the storage.""" + pass + + def add_policy(self, sec, ptype, rule): + """adds a policy rule to the storage.""" + pass + + def remove_policy(self, sec, ptype, rule): + """removes a policy rule from the storage.""" + pass + + def remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """removes policy rules that match the filter from the storage. + This is part of the Auto-Save feature. + """ + pass diff --git a/utils/casbin/persist/adapter_filtered.py b/utils/casbin/persist/adapter_filtered.py new file mode 100644 index 0000000..fbb4fef --- /dev/null +++ b/utils/casbin/persist/adapter_filtered.py @@ -0,0 +1,13 @@ +from .adapter import Adapter + +""" FilteredAdapter is the interface for Casbin adapters supporting filtered policies.""" +class FilteredAdapter(Adapter): + def is_filtered(self): + """IsFiltered returns true if the loaded policy has been filtered + Marks if the loaded policy is filtered or not + """ + pass + + def load_filtered_policy(self, model, filter): + """Loads policy rules that match the filter from the storage.""" + pass \ No newline at end of file diff --git a/utils/casbin/persist/adapters/__init__.py b/utils/casbin/persist/adapters/__init__.py new file mode 100644 index 0000000..f1b9da1 --- /dev/null +++ b/utils/casbin/persist/adapters/__init__.py @@ -0,0 +1,2 @@ +from .file_adapter import FileAdapter +from .adapter_filtered import FilteredAdapter \ No newline at end of file diff --git a/utils/casbin/persist/adapters/adapter_filtered.py b/utils/casbin/persist/adapters/adapter_filtered.py new file mode 100644 index 0000000..928ec1e --- /dev/null +++ b/utils/casbin/persist/adapters/adapter_filtered.py @@ -0,0 +1,89 @@ +from utils.casbin import persist +from utils.casbin import model +from .file_adapter import FileAdapter +import os + +class Filter: + #P,G are string [] + P = [] + G = [] + +class FilteredAdapter (FileAdapter,persist.FilteredAdapter): + filtered = False + _file_path = "" + filter = Filter() + #new_filtered_adapte is the constructor for FilteredAdapter. + def __init__(self,file_path): + self.filtered = True + self._file_path = file_path + + def load_policy(self,model): + if not os.path.isfile(self._file_path): + raise RuntimeError("invalid file path, file path cannot be empty") + self.filtered=False + self._load_policy_file(model) + + #load_filtered_policy loads only policy rules that match the filter. + def load_filtered_policy(self,model,filter): + if filter == None: + return self.load_policy(model) + + if not os.path.isfile(self._file_path): + raise RuntimeError("invalid file path, file path cannot be empty") + + try: + filter_value = [filter.__dict__['P']]+[filter.__dict__['G']] + except: + raise RuntimeError("invalid filter type") + + self.load_filtered_policy_file(model,filter_value,persist.load_policy_line) + self.filtered = True + + def load_filtered_policy_file(self,model,filter,hanlder): + with open(self._file_path, "rb") as file: + while True: + line = file.readline() + line = line.decode().strip() + if line == '\n': + continue + if not line : + break + if filter_line(line,filter): + continue + + hanlder(line,model) + + #is_filtered returns true if the loaded policy has been filtered. + def is_filtered(self): + return self.filtered + def save_policy(self,model): + if self.filtered: + raise RuntimeError("cannot save a filtered policy") + + self._save_policy_file(model) + +def filter_line(line,filter): + if filter == None: + return False + + p = line.split(',') + if len(p) == 0: + return True + filter_slice = [] + + if p[0].strip()== 'p': + filter_slice = filter[0] + elif p[0].strip() == 'g': + filter_slice = filter[1] + return filter_words(p,filter_slice) + +def filter_words(line,filter): + if len(line) < len(filter)+1: + return True + skip_line=False + for i,v in enumerate(filter): + if(len(v) >0 and ( v.strip() != line[i+1].strip() ) ): + skip_line = True + break + + return skip_line \ No newline at end of file diff --git a/utils/casbin/persist/adapters/file_adapter.py b/utils/casbin/persist/adapters/file_adapter.py new file mode 100644 index 0000000..f6a81ce --- /dev/null +++ b/utils/casbin/persist/adapters/file_adapter.py @@ -0,0 +1,62 @@ +from utils.casbin import persist +import os + +class FileAdapter(persist.Adapter): + """the file adapter for Casbin. + It can load policy from file or save policy to file. + """ + _file_path = "" + + def __init__(self, file_path): + self._file_path = file_path + + def load_policy(self, model): + if not os.path.isfile(self._file_path): + raise RuntimeError("invalid file path, file path cannot be empty") + + self._load_policy_file(model) + + def save_policy(self, model): + if not os.path.isfile(self._file_path): + raise RuntimeError("invalid file path, file path cannot be empty") + + self._save_policy_file(model) + + def _load_policy_file(self, model): + with open(self._file_path, "rb") as file: + line = file.readline() + while line: + persist.load_policy_line(line.decode().strip(), model) + line = file.readline() + + def _save_policy_file(self, model): + with open(self._file_path, "w") as file: + lines = [] + + if "p" in model.model.keys(): + for key, ast in model.model["p"].items(): + for pvals in ast.policy: + lines.append(key + ", " + ", ".join(pvals)) + + if "g" in model.model.keys(): + for key, ast in model.model["g"].items(): + for pvals in ast.policy: + lines.append(key + ", " + ", ".join(pvals)) + + for i, line in enumerate(lines): + if i != len(lines) - 1: + lines[i] += "\n" + + file.writelines(lines) + + def add_policy(self, sec, ptype, rule): + pass + + def add_policies(self,sec,ptype,rules): + pass + + def remove_policy(self, sec, ptype, rule): + pass + + def remove_policies(self,sec,ptype,rules): + pass \ No newline at end of file diff --git a/utils/casbin/persist/adapters/update_adapter.py b/utils/casbin/persist/adapters/update_adapter.py new file mode 100644 index 0000000..11d4f3f --- /dev/null +++ b/utils/casbin/persist/adapters/update_adapter.py @@ -0,0 +1,9 @@ +class UpdateAdapter: + """ UpdateAdapter is the interface for Casbin adapters with add update policy function. """ + + def update_policy(self, sec, ptype, old_rule, new_policy): + """ + update_policy updates a policy rule from storage. + This is part of the Auto-Save feature. + """ + pass \ No newline at end of file diff --git a/utils/casbin/persist/batch_adapter.py b/utils/casbin/persist/batch_adapter.py new file mode 100644 index 0000000..6dbf88d --- /dev/null +++ b/utils/casbin/persist/batch_adapter.py @@ -0,0 +1,11 @@ +from .adapter import Adapter + +"""BatchAdapter is the interface for Casbin adapters with multiple add and remove policy functions.""" +class BatchAdapter(Adapter): + def add_policies(self,sec,ptype,rules): + """AddPolicies adds policy rules to the storage.""" + pass + + def remove_policies(self,sec,ptype,rules): + """RemovePolicies removes policy rules from the storage.""" + pass \ No newline at end of file diff --git a/utils/casbin/persist/dispatcher.py b/utils/casbin/persist/dispatcher.py new file mode 100644 index 0000000..b499cde --- /dev/null +++ b/utils/casbin/persist/dispatcher.py @@ -0,0 +1,21 @@ +class Dispatcher: + """Dispatcher is the interface for pycasbin dispatcher""" + def add_policies(self, sec, ptype, rules): + """add_policies adds policies rule to all instance.""" + pass + + def remove_policies(self, sec, ptype, rules): + """remove_policies removes policies rule from all instance.""" + pass + + def remove_filtered_policy(self, sec, ptype, field_index, field_values): + """remove_filtered_policy removes policy rules that match the filter from all instance.""" + pass + + def clear_policy(self): + """clear_policy clears all current policy in all instances.""" + pass + + def update_policy(self, sec, ptype, old_rule, new_rule): + """update_policy updates policy rule from all instance.""" + pass diff --git a/utils/casbin/rbac/__init__.py b/utils/casbin/rbac/__init__.py new file mode 100644 index 0000000..cd3c152 --- /dev/null +++ b/utils/casbin/rbac/__init__.py @@ -0,0 +1 @@ +from .role_manager import RoleManager diff --git a/utils/casbin/rbac/default_role_manager/__init__.py b/utils/casbin/rbac/default_role_manager/__init__.py new file mode 100644 index 0000000..4ff0284 --- /dev/null +++ b/utils/casbin/rbac/default_role_manager/__init__.py @@ -0,0 +1 @@ +from .role_manager import RoleManager \ No newline at end of file diff --git a/utils/casbin/rbac/default_role_manager/role_manager.py b/utils/casbin/rbac/default_role_manager/role_manager.py new file mode 100644 index 0000000..551af64 --- /dev/null +++ b/utils/casbin/rbac/default_role_manager/role_manager.py @@ -0,0 +1,219 @@ +import logging + +from utils.casbin.rbac import RoleManager + + +class RoleManager(RoleManager): + """provides a default implementation for the RoleManager interface""" + + all_roles = dict() + max_hierarchy_level = 0 + + def __init__(self, max_hierarchy_level): + self.logger = logging.getLogger() + self.all_roles = dict() + self.max_hierarchy_level = max_hierarchy_level + self.matching_func = None + self.domain_matching_func = None + self.has_pattern = None + self.has_domain_pattern = None + + def add_matching_func(self, fn=None): + self.has_pattern = True + self.matching_func = fn + + def add_domain_matching_func(self, fn=None): + self.has_domain_pattern = True + self.domain_matching_func = fn + + def has_role(self, name): + if self.matching_func is None: + return name in self.all_roles.keys() + else: + for key in self.all_roles.keys(): + if self.matching_func(name, key): + return True + return False + + def create_role(self, name): + if name not in self.all_roles.keys(): + self.all_roles[name] = Role(name) + + return self.all_roles[name] + + def clear(self): + self.all_roles.clear() + + def add_link(self, name1, name2, *domain): + if len(domain) == 1: + name1 = domain[0] + "::" + name1 + name2 = domain[0] + "::" + name2 + elif len(domain) > 1: + raise RuntimeError("error: domain should be 1 parameter") + + role1 = self.create_role(name1) + role2 = self.create_role(name2) + role1.add_role(role2) + + if self.matching_func is not None: + for key, role in self.all_roles.items(): + if self.matching_func(key, name1) and name1 != key: + self.all_roles[key].add_role(role1) + if self.matching_func(key, name2) and name2 != key: + self.all_roles[name2].add_role(role) + if self.matching_func(name1, key) and name1 != key: + self.all_roles[key].add_role(role1) + if self.matching_func(name2, key) and name2 != key: + self.all_roles[name2].add_role(role) + + def delete_link(self, name1, name2, *domain): + if len(domain) == 1: + name1 = domain[0] + "::" + name1 + name2 = domain[0] + "::" + name2 + elif len(domain) > 1: + raise RuntimeError("error: domain should be 1 parameter") + + if not self.has_role(name1) or not self.has_role(name2): + raise RuntimeError("error: name1 or name2 does not exist") + + role1 = self.create_role(name1) + role2 = self.create_role(name2) + role1.delete_role(role2) + + def has_link(self, name1, name2, *domain): + if len(domain) == 1: + name1 = domain[0] + "::" + name1 + name2 = domain[0] + "::" + name2 + elif len(domain) > 1: + raise RuntimeError("error: domain should be 1 parameter") + + if name1 == name2: + return True + + if not self.has_role(name1) or not self.has_role(name2): + return False + + if self.matching_func is None: + role1 = self.create_role(name1) + return role1.has_role(name2, self.max_hierarchy_level) + else: + for key, role in self.all_roles.items(): + if self.matching_func(name1, key) and role.has_role(name2, self.max_hierarchy_level, + self.matching_func): + return True + return False + + def get_roles(self, name, domain=None): + """ + gets the roles that a subject inherits. + domain is a prefix to the roles. + """ + if domain: + name = domain + "::" + name + + if not self.has_role(name): + return [] + + roles = self.create_role(name).get_roles() + if domain: + for key, value in enumerate(roles): + roles[key] = value[len(domain) + 2:] + + return roles + + def get_users(self, name, *domain): + """ + gets the users that inherits a subject. + domain is an unreferenced parameter here, may be used in other implementations. + """ + if len(domain) == 1: + name = domain[0] + "::" + name + elif len(domain) > 1: + return RuntimeError("error: domain should be 1 parameter") + + if not self.has_role(name): + return [] + + names = [] + for role in self.all_roles.values(): + if role.has_direct_role(name): + if len(domain) == 1: + names.append(role.name[len(domain[0]) + 2:]) + else: + names.append(role.name) + + return names + + def print_roles(self): + line = [] + for role in self.all_roles.values(): + text = role.to_string() + if text: + line.append(text) + self.logger.info(", ".join(line)) + + +class Role: + """represents the data structure for a role in RBAC.""" + + name = "" + + roles = [] + + def __init__(self, name): + self.name = name + self.roles = [] + + def add_role(self, role): + for rr in self.roles: + if rr.name == role.name: + return + + self.roles.append(role) + + def delete_role(self, role): + for rr in self.roles: + if rr.name == role.name: + self.roles.remove(rr) + return + + def has_role(self, name, hierarchy_level, matching_func=None): + if self.has_direct_role(name, matching_func): + return True + if hierarchy_level <= 0: + return False + + for role in self.roles: + if role.has_role(name, hierarchy_level - 1, matching_func): + return True + + return False + + def has_direct_role(self, name, matching_func=None): + if matching_func is None: + for role in self.roles: + if role.name == name: + return True + else: + for role in self.roles: + if matching_func(name, role.name): + return True + return False + + def to_string(self): + if len(self.roles) == 0: + return "" + + names = ", ".join(self.get_roles()) + + if len(self.roles) == 1: + return self.name + " < " + names + else: + return self.name + " < (" + names + ")" + + def get_roles(self): + names = [] + for role in self.roles: + names.append(role.name) + + return names diff --git a/utils/casbin/rbac/role_manager.py b/utils/casbin/rbac/role_manager.py new file mode 100644 index 0000000..9c1cefc --- /dev/null +++ b/utils/casbin/rbac/role_manager.py @@ -0,0 +1,23 @@ +class RoleManager: + """provides interface to define the operations for managing roles.""" + + def clear(self): + pass + + def add_link(self, name1, name2, *domain): + pass + + def delete_link(self, name1, name2, *domain): + pass + + def has_link(self, name1, name2, *domain): + pass + + def get_roles(self, name, *domain): + pass + + def get_users(self, name, *domain): + pass + + def print_roles(self): + pass diff --git a/utils/casbin/synced_enforcer.py b/utils/casbin/synced_enforcer.py new file mode 100644 index 0000000..9061c35 --- /dev/null +++ b/utils/casbin/synced_enforcer.py @@ -0,0 +1,580 @@ +import threading +import time + +from utils.casbin.enforcer import Enforcer +from utils.casbin.util.rwlock import RWLockWrite + + +class AtomicBool(): + + def __init__(self, value): + self._lock = threading.Lock() + self._value = value + + @property + def value(self): + with self._lock: + return self._value + + @value.setter + def value(self, value): + with self._lock: + self._value = value + +class SyncedEnforcer(): + + """SyncedEnforcer wraps Enforcer and provides synchronized access. + It's also a drop-in replacement for Enforcer""" + + def __init__(self, model=None, adapter=None): + self._e = Enforcer(model, adapter) + self._rwlock = RWLockWrite() + self._rl = self._rwlock.gen_rlock() + self._wl = self._rwlock.gen_wlock() + self._auto_loading = AtomicBool(False) + self._auto_loading_thread = None + + def is_auto_loading_running(self): + """check if SyncedEnforcer is auto loading policies""" + return self._auto_loading.value + + def _auto_load_policy(self, interval): + while self.is_auto_loading_running(): + time.sleep(interval) + self.load_policy() + + def start_auto_load_policy(self, interval): + """starts a thread that will call load_policy every interval seconds""" + if self.is_auto_loading_running(): + return + self._auto_loading.value = True + self._auto_loading_thread = threading.Thread(target=self._auto_load_policy, args=[interval], daemon=True) + + def stop_auto_load_policy(self): + """stops the thread started by start_auto_load_policy""" + if self.is_auto_loading_running(): + self._auto_loading.value = False + + def get_model(self): + """gets the current model.""" + with self._rl: + return self._e.get_model() + + def set_model(self, m): + """sets the current model.""" + with self._wl: + return self._e.set_model(m) + + def load_model(self): + """reloads the model from the model CONF file. + Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy(). + """ + with self._wl: + return self._e.load_model() + + def get_role_manager(self): + """gets the current role manager.""" + with self._rl: + return self._e.get_role_manager() + + def set_role_manager(self, rm): + with self._wl: + self._e.set_role_manager(rm) + + def get_adapter(self): + """gets the current adapter.""" + with self._rl: + self._e.get_adapter() + + def set_adapter(self, adapter): + """sets the current adapter.""" + with self._wl: + self._e.set_adapter(adapter) + + def set_watcher(self, watcher): + """sets the current watcher.""" + with self._wl: + self._e.set_watcher(watcher) + + def set_effector(self, eft): + """sets the current effector.""" + with self._wl: + self._e.set_effector(eft) + + def clear_policy(self): + """ clears all policy.""" + with self._wl: + return self._e.clear_policy() + + def load_policy(self): + """reloads the policy from file/database.""" + with self._wl: + return self._e.load_policy() + + def load_filtered_policy(self, filter): + """"reloads a filtered policy from file/database.""" + with self._wl: + return self._e.load_filtered_policy(filter) + + def save_policy(self): + with self._rl: + return self._e.save_policy() + + def build_role_links(self): + """manually rebuild the role inheritance relations.""" + with self._rl: + return self._e.build_role_links() + + def enforce(self, *rvals): + """decides whether a "subject" can access a "object" with the operation "action", + input parameters are usually: (sub, obj, act). + """ + with self._rl: + return self._e.enforce(*rvals) + + def enforce_ex(self, *rvals): + """decides whether a "subject" can access a "object" with the operation "action", + input parameters are usually: (sub, obj, act). + return judge result with reason + """ + with self._rl: + return self._e.enforce_ex(*rvals) + + def get_all_subjects(self): + """gets the list of subjects that show up in the current policy.""" + with self._rl: + return self._e.get_all_subjects() + + def get_all_named_subjects(self, ptype): + """gets the list of subjects that show up in the current named policy.""" + with self._rl: + return self._e.get_all_named_subjects(ptype) + + def get_all_objects(self): + """gets the list of objects that show up in the current policy.""" + with self._rl: + return self._e.get_all_objects() + + def get_all_named_objects(self, ptype): + """gets the list of objects that show up in the current named policy.""" + with self._rl: + return self._e.get_all_named_objects(ptype) + + def get_all_actions(self): + """gets the list of actions that show up in the current policy.""" + with self._rl: + return self._e.get_all_actions() + + def get_all_named_actions(self, ptype): + """gets the list of actions that show up in the current named policy.""" + with self._rl: + return self._e.get_all_named_actions(ptype) + + def get_all_roles(self): + """gets the list of roles that show up in the current named policy.""" + with self._rl: + return self._e.get_all_roles() + + def get_all_named_roles(self, ptype): + """gets all the authorization rules in the policy.""" + with self._rl: + return self._e.get_all_named_roles(ptype) + + def get_policy(self): + """gets all the authorization rules in the policy.""" + with self._rl: + return self._e.get_policy() + + def get_filtered_policy(self, field_index, *field_values): + """gets all the authorization rules in the policy, field filters can be specified.""" + with self._rl: + return self._e.get_filtered_policy(field_index, *field_values) + + def get_named_policy(self, ptype): + """gets all the authorization rules in the named policy.""" + with self._rl: + return self._e.get_named_policy(ptype) + + def get_filtered_named_policy(self, ptype, field_index, *field_values): + """gets all the authorization rules in the named policy, field filters can be specified.""" + with self._rl: + return self._e.get_filtered_named_policy(ptype, field_index, *field_values) + + def get_grouping_policy(self): + """gets all the role inheritance rules in the policy.""" + with self._rl: + return self._e.get_grouping_policy() + + def get_filtered_grouping_policy(self, field_index, *field_values): + """gets all the role inheritance rules in the policy, field filters can be specified.""" + with self._rl: + return self._e.get_filtered_grouping_policy(field_index, *field_values) + + def get_named_grouping_policy(self, ptype): + """gets all the role inheritance rules in the policy.""" + with self._rl: + return self._e.get_named_grouping_policy(ptype) + + def get_filtered_named_grouping_policy(self, ptype, field_index, *field_values): + """gets all the role inheritance rules in the policy, field filters can be specified.""" + with self._rl: + return self._e.get_filtered_named_grouping_policy(ptype, field_index, *field_values) + + def has_policy(self, *params): + """determines whether an authorization rule exists.""" + with self._rl: + return self._e.has_policy(*params) + + def has_named_policy(self, ptype, *params): + """determines whether a named authorization rule exists.""" + with self._rl: + return self._e.has_named_policy(ptype, *params) + + def add_policy(self, *params): + """adds an authorization rule to the current policy. + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + with self._wl: + return self._e.add_policy(*params) + + def add_named_policy(self, ptype, *params): + """adds an authorization rule to the current named policy. + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + with self._wl: + return self._e.add_named_policy(ptype, *params) + + def remove_policy(self, *params): + """removes an authorization rule from the current policy.""" + with self._wl: + return self._e.remove_policy(*params) + + def remove_filtered_policy(self, field_index, *field_values): + """removes an authorization rule from the current policy, field filters can be specified.""" + with self._wl: + return self._e.remove_filtered_policy(field_index, *field_values) + + def remove_named_policy(self, ptype, *params): + """removes an authorization rule from the current named policy.""" + with self._wl: + return self._e.remove_named_policy(ptype, *params) + + def remove_filtered_named_policy(self, ptype, field_index, *field_values): + """removes an authorization rule from the current named policy, field filters can be specified.""" + with self._wl: + return self._e.remove_filtered_named_policy(ptype, field_index, *field_values) + + def has_grouping_policy(self, *params): + """determines whether a role inheritance rule exists.""" + with self._rl: + return self._e.has_grouping_policy(*params) + + def has_named_grouping_policy(self, ptype, *params): + """determines whether a named role inheritance rule exists.""" + with self._rl: + return self._e.has_named_grouping_policy(ptype, *params) + + def add_grouping_policy(self, *params): + """adds a role inheritance rule to the current policy. + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + with self._wl: + return self._e.add_grouping_policy(*params) + + def add_named_grouping_policy(self, ptype, *params): + """adds a named role inheritance rule to the current policy. + If the rule already exists, the function returns false and the rule will not be added. + Otherwise the function returns true by adding the new rule. + """ + with self._wl: + return self._e.add_named_grouping_policy(ptype, *params) + + def remove_grouping_policy(self, *params): + """removes a role inheritance rule from the current policy.""" + with self._wl: + return self._e.remove_grouping_policy(*params) + + def remove_filtered_grouping_policy(self, field_index, *field_values): + """removes a role inheritance rule from the current policy, field filters can be specified.""" + with self._wl: + return self._e.remove_filtered_grouping_policy(field_index, *field_values) + + def remove_named_grouping_policy(self, ptype, *params): + """removes a role inheritance rule from the current named policy.""" + with self._wl: + return self._e.remove_named_grouping_policy(ptype, *params) + + def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_values): + """removes a role inheritance rule from the current named policy, field filters can be specified.""" + with self._wl: + return self._e.remove_filtered_named_grouping_policy(ptype, field_index, *field_values) + + def add_function(self, name, func): + """adds a customized function.""" + with self._wl: + return self._e.add_function(name, func) + + # enforcer.py + + def get_roles_for_user(self, name): + """ gets the roles that a user has. """ + with self._rl: + return self._e.get_roles_for_user(name) + + def get_users_for_role(self, name): + """ gets the users that has a role. """ + with self._rl: + return self._e.get_users_for_role(name) + + def has_role_for_user(self, name, role): + """ determines whether a user has a role. """ + with self._rl: + return self._e.has_role_for_user(name, role) + + def add_role_for_user(self, user, role): + """ + adds a role for a user. + Returns false if the user already has the role (aka not affected). + """ + with self._wl: + return self._e.add_role_for_user(user, role) + + def delete_role_for_user(self, user, role): + """ + deletes a role for a user. + Returns false if the user does not have the role (aka not affected). + """ + with self._wl: + return self._e.delete_role_for_user(user, role) + + def delete_roles_for_user(self, user): + """ + deletes all roles for a user. + Returns false if the user does not have any roles (aka not affected). + """ + with self._wl: + return self._e.delete_roles_for_user(user) + + def delete_user(self, user): + """ + deletes a user. + Returns false if the user does not exist (aka not affected). + """ + with self._wl: + return self._e.delete_user(user) + + def delete_role(self, role): + """ + deletes a role. + Returns false if the role does not exist (aka not affected). + """ + with self._wl: + return self._e.delete_role(role) + + def delete_permission(self, *permission): + """ + deletes a permission. + Returns false if the permission does not exist (aka not affected). + """ + with self._wl: + return self._e.delete_permission(*permission) + + def add_permission_for_user(self, user, *permission): + """ + adds a permission for a user or role. + Returns false if the user or role already has the permission (aka not affected). + """ + with self._wl: + return self._e.add_permission_for_user(user, *permission) + + def delete_permission_for_user(self, user, *permission): + """ + deletes a permission for a user or role. + Returns false if the user or role does not have the permission (aka not affected). + """ + with self._wl: + return self._e.delete_permission_for_user(user, *permission) + + def delete_permissions_for_user(self, user): + """ + deletes permissions for a user or role. + Returns false if the user or role does not have any permissions (aka not affected). + """ + with self._wl: + return self._e.delete_permissions_for_user(user) + + def get_permissions_for_user(self, user): + """ + gets permissions for a user or role. + """ + with self._rl: + return self._e.get_permissions_for_user(user) + + def has_permission_for_user(self, user, *permission): + """ + determines whether a user has a permission. + """ + with self._rl: + return self._e.has_permission_for_user(user, *permission) + + def get_implicit_roles_for_user(self, name, *domain): + """ + gets implicit roles that a user has. + Compared to get_roles_for_user(), this function retrieves indirect roles besides direct roles. + For example: + g, alice, role:admin + g, role:admin, role:user + + get_roles_for_user("alice") can only get: ["role:admin"]. + But get_implicit_roles_for_user("alice") will get: ["role:admin", "role:user"]. + """ + with self._rl: + return self._e.get_implicit_roles_for_user(name, *domain) + + def get_implicit_permissions_for_user(self, user, *domain): + """ + gets implicit permissions for a user or role. + Compared to get_permissions_for_user(), this function retrieves permissions for inherited roles. + For example: + p, admin, data1, read + p, alice, data2, read + g, alice, admin + + get_permissions_for_user("alice") can only get: [["alice", "data2", "read"]]. + But get_implicit_permissions_for_user("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]]. + """ + with self._rl: + return self._e.get_implicit_permissions_for_user(user, *domain) + + def get_implicit_users_for_permission(self, *permission): + """ + gets implicit users for a permission. + For example: + p, admin, data1, read + p, bob, data1, read + g, alice, admin + + get_implicit_users_for_permission("data1", "read") will get: ["alice", "bob"]. + Note: only users will be returned, roles (2nd arg in "g") will be excluded. + """ + with self._rl: + return self._e.get_implicit_users_for_permission(*permission) + + def get_roles_for_user_in_domain(self, name, domain): + """gets the roles that a user has inside a domain.""" + with self._rl: + return self._e.get_roles_for_user_in_domain(name, domain) + + def get_users_for_role_in_domain(self, name, domain): + """gets the users that has a role inside a domain.""" + with self._rl: + return self._e.get_users_for_role_in_domain(name, domain) + + def add_role_for_user_in_domain(self, user, role, domain): + """adds a role for a user inside a domain.""" + """Returns false if the user already has the role (aka not affected).""" + with self._wl: + return self._e.add_role_for_user_in_domain(user, role, domain) + + def delete_roles_for_user_in_domain(self, user, role, domain): + """deletes a role for a user inside a domain.""" + """Returns false if the user does not have any roles (aka not affected).""" + with self._wl: + return self._e.delete_roles_for_user_in_domain(user, role, domain) + + def get_permissions_for_user_in_domain(self, user, domain): + """gets permissions for a user or role inside domain.""" + with self._rl: + return self._e.get_permissions_for_user_in_domain(user, domain) + + def enable_auto_build_role_links(self, auto_build_role_links): + """controls whether to rebuild the role inheritance relations when a role is added or deleted.""" + with self._wl: + return self._e.enable_auto_build_role_links(auto_build_role_links) + + def enable_auto_save(self, auto_save): + """controls whether to save a policy rule automatically to the adapter when it is added or removed.""" + with self._wl: + return self._e.enable_auto_save(auto_save) + + def enable_enforce(self, enabled=True): + """changes the enforcing state of Casbin, + when Casbin is disabled, all access will be allowed by the Enforce() function. + """ + with self._wl: + return self._e.enable_enforce(enabled) + + def add_named_matching_func(self, ptype, fn): + """add_named_matching_func add MatchingFunc by ptype RoleManager""" + with self._wl: + self._e.add_named_matching_func(ptype, fn) + + def add_named_domain_matching_func(self, ptype, fn): + """add_named_domain_matching_func add MatchingFunc by ptype to RoleManager""" + with self._wl: + self._e.add_named_domain_matching_func(ptype, fn) + + def is_filtered(self): + """returns true if the loaded policy has been filtered.""" + with self._rl: + self._e.is_filtered() + + def add_policies(self,rules): + """adds authorization rules to the current policy. + + If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. + Otherwise the function returns true for the corresponding rule by adding the new rule. + """ + with self._wl: + return self._e.add_policies(rules) + + def add_named_policies(self,ptype,rules): + """adds authorization rules to the current named policy. + + If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. + Otherwise the function returns true for the corresponding by adding the new rule.""" + with self._wl: + return self._e.add_named_policies(ptype,rules) + + def remove_policies(self,rules): + """removes authorization rules from the current policy.""" + with self._wl: + return self._e.remove_policies(rules) + + def remove_named_policies(self,ptype,rules): + """removes authorization rules from the current named policy.""" + with self._wl: + return self._e.remove_named_policies(ptype,rules) + + def add_grouping_policies(self,rules): + """adds role inheritance rulea to the current policy. + + If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. + Otherwise the function returns true for the corresponding policy rule by adding the new rule. + """ + with self._wl: + return self._e.add_grouping_policies(rules) + + def add_named_grouping_policies(self,ptype,rules): + """"adds named role inheritance rules to the current policy. + + If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. + Otherwise the function returns true for the corresponding policy rule by adding the new rule.""" + with self._wl: + return self._e.add_named_grouping_policies(ptype,rules) + + def remove_grouping_policies(self,rules): + """removes role inheritance rulea from the current policy.""" + with self._wl: + return self._e.addremove_grouping_policies_policies(rules) + + def remove_named_grouping_policies(self,ptype,rules): + """ removes role inheritance rules from the current named policy.""" + with self._wl: + return self._e.remove_named_grouping_policies(ptype,rules) + + def build_incremental_role_links(self, op, ptype, rules): + self.get_model().build_incremental_role_links(self.get_role_manager(), op, "g", ptype, rules) \ No newline at end of file diff --git a/utils/casbin/util/__init__.py b/utils/casbin/util/__init__.py new file mode 100644 index 0000000..a40b64e --- /dev/null +++ b/utils/casbin/util/__init__.py @@ -0,0 +1,3 @@ +from .builtin_operators import * +from .expression import * +from .util import * \ No newline at end of file diff --git a/utils/casbin/util/builtin_operators.py b/utils/casbin/util/builtin_operators.py new file mode 100644 index 0000000..509eba4 --- /dev/null +++ b/utils/casbin/util/builtin_operators.py @@ -0,0 +1,137 @@ +import fnmatch +import re +import ipaddress + +KEY_MATCH2_PATTERN = re.compile(r'(.*?):[^\/]+(.*?)') +KEY_MATCH3_PATTERN = re.compile(r'(.*?){[^\/]+}(.*?)') + + +def key_match(key1, key2): + """determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. + For example, "/foo/bar" matches "/foo/*" + """ + + i = key2.find("*") + if i == -1: + return key1 == key2 + + if len(key1) > i: + return key1[:i] == key2[:i] + return key1 == key2[:i] + + +def key_match_func(*args): + """The wrapper for key_match. + """ + name1 = args[0] + name2 = args[1] + + return key_match(name1, name2) + + +def key_match2(key1, key2): + """determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. + For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/:resource" + """ + + key2 = key2.replace("/*", "/.*") + key2 = KEY_MATCH2_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0) + + return regex_match(key1, "^" + key2 + "$") + + +def key_match2_func(*args): + name1 = args[0] + name2 = args[1] + + return key_match2(name1, name2) + + +def key_match3(key1, key2): + """determines determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *. + For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/{resource}" + """ + + key2 = key2.replace("/*", "/.*") + key2 = KEY_MATCH3_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0) + + return regex_match(key1, "^" + key2 + "$") + + +def key_match3_func(*args): + name1 = args[0] + name2 = args[1] + + return key_match3(name1, name2) + + +def regex_match(key1, key2): + """determines whether key1 matches the pattern of key2 in regular expression.""" + + res = re.match(key2, key1) + if res: + return True + else: + return False + + +def regex_match_func(*args): + """the wrapper for RegexMatch.""" + + name1 = args[0] + name2 = args[1] + + return regex_match(name1, name2) + + +def glob_match(string, pattern): + """determines whether string matches the pattern in glob expression.""" + return fnmatch.fnmatch(string, pattern) + + +def glob_match_func(*args): + """the wrapper for globMatch.""" + + string = args[0] + pattern = args[1] + + return glob_match(string, pattern) + + +def ip_match(ip1, ip2): + """IPMatch determines whether IP address ip1 matches the pattern of IP address ip2, ip2 can be an IP address or a CIDR pattern. + For example, "192.168.2.123" matches "192.168.2.0/24" + """ + ip1 = ipaddress.ip_address(ip1) + try: + network = ipaddress.ip_network(ip2, strict=False) + return ip1 in network + except ValueError: + return ip1 == ip2 + + +def ip_match_func(*args): + """the wrapper for IPMatch.""" + + ip1 = args[0] + ip2 = args[1] + + return ip_match(ip1, ip2) + + +def generate_g_function(rm): + """the factory method of the g(_, _) function.""" + + def f(*args): + name1 = args[0] + name2 = args[1] + + if not rm: + return name1 == name2 + elif 2 == len(args): + return rm.has_link(name1, name2) + else: + domain = str(args[2]) + return rm.has_link(name1, name2, domain) + + return f diff --git a/utils/casbin/util/expression.py b/utils/casbin/util/expression.py new file mode 100644 index 0000000..cd9a20d --- /dev/null +++ b/utils/casbin/util/expression.py @@ -0,0 +1,29 @@ +from simpleeval import SimpleEval +import ast + + +class SimpleEval(SimpleEval): + """ Rewrite SimpleEval. + >>> s = SimpleEval("20 + 30 - ( 10 * 5)") + >>> s.eval() + 0 + """ + + ast_parsed_value = None + + def __init__(self, expr, functions=None): + """Create the evaluator instance. Set up valid operators (+,-, etc) + functions (add, random, get_val, whatever) and names. """ + super(SimpleEval, self).__init__(functions=functions) + if expr != "": + self.expr = expr + self.expr_parsed_value = ast.parse(expr.strip()).body[0].value + + def eval(self, names=None): + """ evaluate an expresssion, using the operators, functions and + names previously set up. """ + + if names: + self.names = names + + return self._eval(self.expr_parsed_value) diff --git a/utils/casbin/util/rwlock.py b/utils/casbin/util/rwlock.py new file mode 100644 index 0000000..2f54b90 --- /dev/null +++ b/utils/casbin/util/rwlock.py @@ -0,0 +1,68 @@ +from threading import RLock, Condition + +# This implementation was adapted from https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock + +class RWLockWrite(): + ''' write preferring readers-wirter lock ''' + + def __init__(self): + self._lock = RLock() + self._cond = Condition(self._lock) + self._active_readers = 0 + self._waiting_writers = 0 + self._writer_active = False + + def aquire_read(self): + with self._lock: + while self._waiting_writers > 0 or self._writer_active: + self._cond.wait() + self._active_readers += 1 + + def release_read(self): + with self._lock: + self._active_readers -= 1 + if self._active_readers == 0: + self._cond.notify_all() + + def aquire_write(self): + with self._lock: + self._waiting_writers += 1 + while self._active_readers > 0 or self._writer_active: + self._cond.wait() + self._waiting_writers -= 1 + self._writer_active = True + + def release_write(self): + with self._lock: + self._writer_active = False + self._cond.notify_all() + + def gen_rlock(self): + return ReadRWLock(self) + + def gen_wlock(self): + return WriteRWLock(self) + +class ReadRWLock(): + + def __init__(self, rwlock): + self.rwlock = rwlock + + def __enter__(self): + self.rwlock.aquire_read() + + def __exit__(self, exc_type, exc_value, traceback): + self.rwlock.release_read() + return False + +class WriteRWLock(): + + def __init__(self, rwlock): + self.rwlock = rwlock + + def __enter__(self): + self.rwlock.aquire_write() + + def __exit__(self, exc_type, exc_value, traceback): + self.rwlock.release_write() + return False diff --git a/utils/casbin/util/util.py b/utils/casbin/util/util.py new file mode 100644 index 0000000..baa0d56 --- /dev/null +++ b/utils/casbin/util/util.py @@ -0,0 +1,72 @@ +from collections import OrderedDict +import re + +eval_reg = re.compile(r'\beval\((?P[^)]*)\)') + +def escape_assertion(s): + """escapes the dots in the assertion, because the expression evaluation doesn't support such variable names.""" + + s = re.sub(r'\br\.', 'r_', s) + s = re.sub(r'\bp\.', 'p_', s) + + return s + + +def remove_comments(s): + """removes the comments starting with # in the text.""" + + pos = s.find("#") + if pos == -1: + return s + + return s[0:pos].strip() + + +def array_remove_duplicates(s): + """removes any duplicated elements in a string array.""" + return list(OrderedDict.fromkeys(s)) + + +def array_to_string(s): + """gets a printable string for a string array.""" + + return ", ".join(s) + + +def params_to_string(*s): + """gets a printable string for variable number of parameters.""" + + return ", ".join(s) + +def join_slice(a, *b): + ''' joins a string and a slice into a new slice.''' + res = [a] + + res.extend(b) + + return res + +def set_subtract(a, b): + ''' returns the elements in `a` that aren't in `b`. ''' + return [i for i in a if i not in b] + +def has_eval(s): + '''determine whether matcher contains function eval''' + return eval_reg.search(s) + +def replace_eval(expr, rules): + ''' replace all occurences of function eval with rules ''' + pos = 0 + match = eval_reg.search(expr, pos) + while match: + rule = "(" + rules.pop(0) + ")" + expr = expr[:match.start()] + rule + expr[match.end():] + pos = match.start() + len(rule) + match = eval_reg.search(expr, pos) + + return expr + +def get_eval_value(s): + '''returns the parameters of function eval''' + sub_match = eval_reg.findall(s) + return sub_match.copy()