casbin多租户模型

This commit is contained in:
wuaho 2021-05-15 15:38:02 +08:00
parent d4b49cbb2b
commit 182799e6f7
56 changed files with 3434 additions and 77 deletions

View File

@ -4,3 +4,13 @@ x后端 使用 mongo 数据库
### casbin 接口权限控制
修改 utils\casbin\model\assertion.py 20行
为了能定义角色访问所有域
```python
for rule in self.policy:
if len(rule) < count:
pass
# raise RuntimeError("grouping policy elements do not meet role definition")
if len(rule) > count:
rule = rule[:count]
```

View File

@ -13,29 +13,41 @@ router = APIRouter()
@router.get("/api_list")
async def api_list(request: Request, game: str,
async def api_list(request: Request,
current_user: schemas.UserDB = Depends(deps.get_current_user)) -> schemas.Msg:
"""api 列表"""
app = request.app
data = []
data = {}
for r in app.routes:
title = r.tags[0] if hasattr(r, 'description') else None
if not title:
continue
data.setdefault(title, {'list': []})
path = r.path
name = r.description if hasattr(r, 'description') else r.name
data.append({'api': path, 'name': name})
return schemas.Msg(code=0, msg='ok', data=data)
data[title]['list'].append({'api': path, 'title': name})
res = [{'title': k, 'list': v['list']} for k, v in data.items()]
return schemas.Msg(code=0, msg='ok', data=res)
@router.post("/add_role")
async def add_role(request: Request, game: str, data_in: schemas.CasbinRoleCreate,
async def add_role(request: Request,
data_in: schemas.CasbinRoleCreate,
game: str = Depends(deps.get_game_project),
db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""创建角色"""
role_dom = game
api_dict = dict()
for r in request.app.routes:
api_dict[r.path] = r.description if hasattr(r, 'description') else r.name
# 角色有的接口权限
for obj in data_in.role_api:
casbin_enforcer.add_policy(data_in.role_name, role_dom, obj, '*')
await crud.authority.create(db, 'p', data_in.role_name, role_dom, obj, '*')
await crud.authority.create(db, 'p', data_in.role_name, role_dom, obj, '*', api_name=api_dict.get(obj))
# 管理员默认拥有该角色 方便从db中读出
await crud.authority.create(db, 'g', settings.SUPERUSER_NAME, data_in.role_name, '*', '*',
@ -45,38 +57,93 @@ async def add_role(request: Request, game: str, data_in: schemas.CasbinRoleCreat
return schemas.Msg(code=0, msg='ok')
@router.post("/add_sys_role")
async def add_sys_role(request: Request,
data_in: schemas.CasbinRoleCreate,
game: str = Depends(deps.get_game_project),
db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""创建系统角色"""
api_dict = dict()
for r in request.app.routes:
api_dict[r.path] = r.description if hasattr(r, 'description') else r.name
# 角色有的接口权限
for obj in data_in.role_api:
casbin_enforcer.add_policy(data_in.role_name, '*', obj, '*')
await crud.authority.create(db, 'p', data_in.role_name, '*', obj, '*', api_name=api_dict.get(obj))
# 管理员默认拥有该角色 方便从db中读出
await crud.authority.create(db, 'g', settings.SUPERUSER_NAME, data_in.role_name,
role_name=data_in.role_name,
game='*')
return schemas.Msg(code=0, msg='ok')
@router.post("/add_account")
async def add_account(request: Request,
game: str,
data_in: schemas.AccountCreate,
data_in: schemas.AccountsCreate,
game: str = Depends(deps.get_game_project),
db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""创建账号 并设置角色"""
account = schemas.UserCreate(name=data_in.username, nickname=data_in.nickname, password=settings.DEFAULT_PASSWORD)
for item in data_in.accounts:
account = schemas.UserCreate(name=item.username, password=settings.DEFAULT_PASSWORD)
try:
await crud.user.create(db, account)
except pymongo.errors.DuplicateKeyError:
return schemas.Msg(code=-1, msg='用户名已存在')
casbin_enforcer.add_grouping_policy(data_in.username, data_in.role_name, game)
await crud.authority.create(db, 'g', data_in.username, data_in.role_name, game)
casbin_enforcer.add_grouping_policy(item.username, item.role_name, game)
await crud.authority.create(db, 'g', item.username, item.role_name, game)
return schemas.Msg(code=0, msg='ok')
@router.get("/data_authority")
async def data_authority(request: Request,
db: AsyncIOMotorDatabase = Depends(get_database),
game: str = Depends(deps.get_game_project),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""获取数据权限"""
# todo 这是假数据
data = [{'title': '全部事件', 'check_event_num': 100, 'total_event_num': 100, 'update_time': '2021-05-12 18:49:19'}]
return schemas.Msg(code=0, msg='ok', data=data)
@router.get("/all_role")
async def all_role(request: Request,
game: str,
db: AsyncIOMotorDatabase = Depends(get_database),
game: str = Depends(deps.get_game_project),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""获取所有角色"""
roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game=game)
data = [{'role': item['v1'], 'name': item['role_name']} for item in roles]
return schemas.Msg(code=0, msg='ok', data=data)
"""获取域内所有角色"""
roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game=game)
dom_data = [{'role': item['v1'], 'name': item['role_name']} for item in roles]
for item in dom_data:
q = await crud.authority.get_role_dom_authority(db, item['role'], game)
item['authority'] = q
# 获取系统角色
roles = await crud.authority.find_many(db, role_name={'$exists': 1}, game='*')
sys_data = [{'role': item['v1'], 'name': item['role_name']} for item in roles]
for item in sys_data:
q = await crud.authority.get_role_dom_authority(db, item['role'], dom='*')
item['authority'] = q
data = {
'dom_role': dom_data,
'sys_role': sys_data
}
return schemas.Msg(code=0, msg='ok', data=data)
# @router.get("/all_role")
# async def all_role(request: Request,
@ -112,18 +179,18 @@ async def all_role(request: Request,
# return schemas.Msg(code=0, msg='ok', data={'roles': roles, 'permissions': permissions})
@router.post("/set_role")
async def set_role(request: Request,
data_id: schemas.AccountSetRole,
db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""设置账号角色"""
casbin_enforcer.delete_user(data_id.name)
casbin_enforcer.add_role_for_user(data_id.name, data_id.role_name)
await crud.authority.update_one(db, {'ptype': 'g', 'v0': data_id.name}, dict(v1=data_id.role_name))
return schemas.Msg(code=0, msg='ok')
# @router.post("/set_role")
# async def set_role(request: Request,
# data_id: schemas.AccountSetRole,
# db: AsyncIOMotorDatabase = Depends(get_database),
# current_user: schemas.UserDB = Depends(deps.get_current_user)
# ) -> schemas.Msg:
# """设置账号角色"""
# casbin_enforcer.delete_user(data_id.name)
# casbin_enforcer.add_role_for_user(data_id.name, data_id.role_name)
# await crud.authority.update_one(db, {'ptype': 'g', 'v0': data_id.name}, dict(v1=data_id.role_name))
#
# return schemas.Msg(code=0, msg='ok')
# @router.get("/delete_user")
# async def delete_user(request: Request,

View File

@ -6,6 +6,7 @@ from api import deps
from core.config import settings
from db import get_database
from db.ckdb import CKDrive, get_ck_db
from schemas.project import ProjectCreate
from utils import casbin_enforcer
@ -46,6 +47,46 @@ async def read_project(request: Request,
return schemas.Msg(code=0, msg='ok', data=res)
@router.post("/detail")
async def read_project(request: Request,
game: str,
data_in: schemas.ProjectDetail,
db: AsyncIOMotorDatabase = Depends(get_database),
ck: CKDrive = Depends(get_ck_db),
current_user: schemas.UserDB = Depends(deps.get_current_user)
):
"""查看项目信息"""
res = await crud.project.read_project(db, user_id=current_user.id, _id=data_in.project_id)
if res:
res = res[0]
event_count = await ck.count(game, 'event')
user_count = await ck.count(game, 'user_view')
event_type_count = await ck.distinct_count(game, 'event')
event_attr_count = await ck.field_count(db=game, tb='event')
user_attr_count = await ck.field_count(db=game, tb='user_view')
res['event_count'] = event_count
res['user_count'] = user_count
res['event_type_count'] = event_type_count
res['event_attr_count'] = event_attr_count
res['user_attr_count'] = user_attr_count
return schemas.Msg(code=0, msg='ok', data=res)
@router.post("/rename")
async def rename_project(request: Request,
data_in: schemas.ProjectRename,
db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.UserDB = Depends(deps.get_current_user)
):
"""修改项目名"""
try:
res = await crud.project.rename(db, data_in)
except pymongo.errors.DuplicateKeyError:
return schemas.Msg(code=-1, msg='项目名已存在')
return schemas.Msg(code=0, msg='ok', data=res)
@router.post("/add_members")
async def add_members(request: Request,
game: str,
@ -56,7 +97,7 @@ async def add_members(request: Request,
"""项目添加成员"""
#
# await crud.project.add_members(db, data_in)
await crud.project.add_members(db, data_in)
for item in data_in.members:
casbin_enforcer.add_grouping_policy(item.username, item.role, game)
await crud.authority.create(db, 'g', item.username, item.role, game)
@ -70,10 +111,18 @@ async def members(request: Request,
current_user: schemas.UserDB = Depends(deps.get_current_user)
):
"""查看项目成员"""
res = await crud.authority.find_many(db, ptype='g', v2=game)
data = [{'name': 'root', 'role': '超级管理员'}]
data += [{'name': item['v0'], 'role': item['v1']} for item in res]
return schemas.Msg(code=0, msg='ok', data=data)
roles = await crud.authority.find_many(db, ptype='g', v2=game)
data = {item['v0']: {'name': item['v0'], 'role': item['v1']} for item in roles}
data['root'] = {'name': 'root', 'role': '超级管理员'}
users = await crud.user.get_by_users(db, name={'$in': list(data.keys())})
res = []
for user in users.data:
res.append({
**user.dict(),
'role': data[user.name]['role']
})
return schemas.Msg(code=0, msg='ok', data=res)
@router.post("/del_member")
@ -85,6 +134,7 @@ async def members(request: Request,
):
"""删除项目成员"""
casbin_enforcer.delete_user(data_in.username)
await crud.project.del_members(data_in)
await crud.authority.delete(db, ptype='g', v2=game, v0=data_in.username)
return schemas.Msg(code=0, msg='ok')

View File

@ -1,12 +1,9 @@
import json
import aioch
import pandas as pd
from fastapi import APIRouter, Depends, Request
import crud, schemas
from api import deps
from db.ckdb import get_ck_db
from db.ckdb import get_ck_db, CKDrive
router = APIRouter()
@ -15,20 +12,19 @@ router = APIRouter()
async def query_sql(
request: Request,
data_in: schemas.Sql,
ckdb: aioch.Client = Depends(get_ck_db),
ckdb: CKDrive = Depends(get_ck_db),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
"""原 sql 查询 """
data, columns = await ckdb.execute(data_in.sql, with_column_types=True, columnar=True)
df = pd.DataFrame({col[0]: d for d, col in zip(data, columns)})
return schemas.Msg(code=0, msg='ok', data=df.to_dict())
data = await ckdb.execute(data_in.sql)
return schemas.Msg(code=0, msg='ok', data=data)
@router.post("/query")
async def query(
request: Request,
data_in: schemas.CkQuery,
ckdb: aioch.Client = Depends(get_ck_db),
ckdb: CKDrive = Depends(get_ck_db),
current_user: schemas.UserDB = Depends(deps.get_current_user)
) -> schemas.Msg:
""" json解析 sql 查询"""

View File

@ -31,6 +31,7 @@ async def login(
# raise HTTPException(status_code=400, detail="Incorrect name or password")
return schemas.Msg(code=-1, msg='密码或用户名错误')
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
await crud.user.update_login_time(db, data.username)
return {
'data': {
@ -39,6 +40,7 @@ async def login(
'email': user.email,
'token': security.create_access_token(
expires_delta=access_token_expires, _id=str(user.id), email=user.email,
nickname=user.nickname,
is_superuser=user.is_superuser, name=user.name
),
"token_type": "bearer",
@ -64,6 +66,20 @@ def me(current_user: schemas.User = Depends(deps.get_current_user)) -> Any:
return current_user
@router.post("/reset_password")
async def reset_password(request: Request,
game: str,
data_in: schemas.UserRestPassword,
db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.User = Depends(deps.get_current_user)
) -> Any:
"""
修改密码
"""
await crud.user.reset_password(db, data_in)
return schemas.Msg(code=0, msg='ok')
@router.get("/all_account")
async def all_account(page: int = 1, limit: int = 10, db: AsyncIOMotorDatabase = Depends(get_database),
current_user: schemas.User = Depends(deps.get_current_user)

View File

@ -1,11 +1,15 @@
from fastapi import Depends, HTTPException, status
from fastapi import Depends, status, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from motor.motor_asyncio import AsyncIOMotorDatabase
from pydantic import ValidationError
import crud
import schemas
import utils
from core import security
from core.config import settings
from db import get_database
reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/user/login"
@ -45,3 +49,13 @@ def get_current_user2(token: str) -> schemas.UserDB:
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
async def get_game_project(game: str, db: AsyncIOMotorDatabase = Depends(get_database)) -> str:
is_exists = await crud.project.find_one(db, {'game': game}, {'_id': True})
if not is_exists:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='没有该项目'
)
return game

View File

@ -10,6 +10,9 @@ class CRUDBase:
async def get(self, db, id: Union[ObjectId, str]) -> dict:
return (await db[self.coll_name].find_one({'_id': ObjectId(id)})) or dict()
async def find_one(self, db, filter=None, *args, **kwargs):
return (await db[self.coll_name].find_one(filter, *args, **kwargs)) or dict()
async def read_have(self, db, user_id: str, **kwargs):
where = {'members': user_id}
where.update(kwargs)

View File

@ -29,6 +29,19 @@ class CRUDAuthority(CRUDBase):
data.update(kwargs)
await self.update_one(db, data, {'$set': data}, upsert=True)
async def get_all_dom_role(self, db, dom):
pass
async def get_role_dom_authority(self, db, role, dom):
data = await self.find_many(db, v0=role, v1=dom)
res = []
for item in data:
res.append({
'api': item['v2'],
'api_name': item.get('api_name', item['v2'])
})
return res
async def create_index(self, db: AsyncIOMotorDatabase):
await db[self.coll_name].create_index(
[('ptype', pymongo.DESCENDING), ('v0', pymongo.DESCENDING), ('v1', pymongo.DESCENDING),

View File

@ -15,15 +15,24 @@ class CRUDProject(CRUDBase):
)
return await db[self.coll_name].insert_one(db_obj.dict(by_alias=True))
async def read_project(self, db: AsyncIOMotorDatabase, user_id: str):
return await self.read_have(db, user_id=user_id)
async def read_project(self, db: AsyncIOMotorDatabase, user_id: str, **kwargs):
return await self.read_have(db, user_id=user_id, **kwargs)
# async def add_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember):
# p = await self.get(db, obj_in.project_id)
# members = list(set(p.get('members')) | set(obj_in.members))
# await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}})
async def add_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember):
p = await self.get(db, obj_in.project_id)
members = list(set(p.get('members')) | set(obj_in.members))
await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}})
async def del_members(self, db: AsyncIOMotorDatabase, obj_in: ProjectMember):
p = await self.get(db, obj_in.project_id)
members = list(set(p.get('members')) - set(obj_in.members))
await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'members': members}})
async def rename(self, db: AsyncIOMotorDatabase, obj_in: ProjectRename):
await self.update_one(db, {'_id': obj_in.project_id}, {'$set': {'name': obj_in.rename}})
async def create_index(self, db: AsyncIOMotorDatabase):
await db[self.coll_name].create_index('game', unique=True)
await db[self.coll_name].create_index('name', unique=True)

View File

@ -1,7 +1,10 @@
import datetime
import time
import uuid
from motor.motor_asyncio import AsyncIOMotorDatabase
import schemas
from core.security import get_password_hash, verify_password
from crud.base import CRUDBase
from schemas import UserCreate, UserDBRW
@ -15,6 +18,11 @@ class CRUDUser(CRUDBase):
res = await db[self.coll_name].find_one({'name': name})
return res
async def update_login_time(self, db, name):
await self.update_one(db, {'name': name},
{'$set': {'last_login_ts': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}})
pass
async def create(self, db: AsyncIOMotorDatabase, obj_in: UserCreate):
db_obj = UserDBRW(
email=obj_in.email,
@ -26,6 +34,10 @@ class CRUDUser(CRUDBase):
)
return await db[self.coll_name].insert_one(db_obj.dict(by_alias=True))
async def reset_password(self, db: AsyncIOMotorDatabase, obj_in: schemas.UserRestPassword):
hashed_password = get_password_hash(obj_in.password)
await self.update_one(db, {'name': obj_in.username}, {'hashed_password': hashed_password})
async def authenticate(self, db: AsyncIOMotorDatabase, name: str, password: str):
user_obj = UserDBRW(**await self.get_by_user(db, name=name))
if not user_obj:
@ -34,6 +46,10 @@ class CRUDUser(CRUDBase):
return None
return user_obj
async def get_by_users(self, db, **kwargs) -> schemas.Users:
res = await self.find_many(db, **kwargs)
return schemas.Users(data=res)
async def create_index(self, db: AsyncIOMotorDatabase):
await db[self.coll_name].create_index('name', unique=True)

View File

@ -1,12 +1,38 @@
from aioch import Client
import pandas as pd
class CKBase:
class CKDrive:
client: Client = None
async def execute(self, sql) -> dict:
data, columns = await self.client.execute(sql, with_column_types=True, columnar=True)
df = pd.DataFrame({col[0]: d for d, col in zip(data, columns)})
return df.T.to_dict()
ckdb = CKBase()
async def query_dataframe(self, sql):
data, columns = await self.client.execute(sql, with_column_types=True, columnar=True)
df = pd.DataFrame({col[0]: d for d, col in zip(data, columns)})
return df
async def count(self, db: str, tb: str):
sql = f'select count() as `count` from {db}.{tb}'
res = await self.execute(sql)
return res[0]['count']
async def distinct_count(self, db: str, tb: str):
sql = f'select count(distinct `#event_name`) as `count` from {db}.{tb}'
res = await self.execute(sql)
return res[0]['count']
async def field_count(self, db: str, tb: str):
sql = f"select count(name) as `count` from system.columns where database='{db}' and table='{tb}'"
res = await self.execute(sql)
return res[0]['count']
def get_ck_db() -> Client:
return ckdb.client
ckdb = CKDrive()
def get_ck_db() -> CKDrive:
return ckdb

View File

@ -1,12 +1,12 @@
from aioch import Client
from core.config import settings
from .ckdb import ckdb
from .ckdb import CKDrive
async def connect_to_ck():
ckdb.client = Client(**settings.CK_CONFIG)
CKDrive.client = Client(**settings.CK_CONFIG)
async def close_ck_connection():
await ckdb.client.disconnect()
await CKDrive.client.disconnect()

View File

@ -1,4 +1,4 @@
from casbin.enforcer import Enforcer
from utils.casbin.enforcer import Enforcer
from fastapi import HTTPException
from starlette.authentication import BaseUser
from starlette.requests import Request

View File

@ -11,4 +11,4 @@ g = _, _, _
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub, r.dom) && (p.dom=="*" || r.dom == p.dom) && ( p.obj=="*" || r.obj == p.obj) && (p.act=="*" || r.act == p.act) || r.sub=="root"
m = (g(r.sub, p.sub) || g(r.sub, p.sub, r.dom)) && (p.dom=="*" || r.dom == p.dom) && ( p.obj=="*" || r.obj == p.obj) && (p.act=="*" || r.act == p.act) || r.sub=="root"

View File

@ -11,7 +11,7 @@ class Ptype(str, Enum):
class CasbinRoleCreate(BaseModel):
role_name: str
role_api: List
role_api: List[str]
class CasbinDB(BaseModel):
@ -24,7 +24,12 @@ class CasbinDB(BaseModel):
class AccountCreate(BaseModel):
username: str
role_name: str
nickname: str
# nickname: str
data_authority: str
class AccountsCreate(BaseModel):
accounts: List[AccountCreate]
class AccountDeleteUser(BaseModel):

View File

@ -18,7 +18,16 @@ class MemberRole(BaseModel):
class ProjectMember(BaseModel):
members: List[MemberRole]
# project_id: str
project_id: str
class ProjectDetail(BaseModel):
project_id: str
class ProjectRename(BaseModel):
project_id: str
rename: str
class ProjectDelMember(BaseModel):

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
from schemas.base import DBBase
@ -9,22 +9,31 @@ class UserBase(BaseModel):
email: Optional[EmailStr] = None
is_superuser: bool = False
name: Optional[str] = None
nickname: str
nickname: str = ''
last_login_ts: str = '尚未登录'
class User(UserBase):
name: str
class Users(BaseModel):
data: List[User] = []
class UserLogin(BaseModel):
username: str = ...
password: str = ...
class UserRestPassword(BaseModel):
username: str = ...
password: str = ...
class UserCreate(UserBase):
password: str
name: str
nickname: str
# ****************************************************************************
@ -36,6 +45,7 @@ class UserDB(DBBase):
is_superuser: bool = False
name: str
nickname: str
last_login_ts: str = '尚未登录'
class UserDBRW(UserDB):

View File

@ -1 +1,3 @@
from .adapter import *
from . import casbin

View File

@ -1,5 +1,5 @@
import casbin
from casbin import persist
from utils import casbin
from utils.casbin import persist
from pymongo import MongoClient
from core.config import settings
@ -7,6 +7,7 @@ from core.config import settings
__all__ = 'casbin_adapter', 'casbin_enforcer', 'casbin_model'
class CasbinRule:
'''
CasbinRule model

View File

@ -1,5 +1,5 @@
from casbin import persist
from .casbin import persist
class CasbinRule:

7
utils/casbin/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from .enforcer import *
from .synced_enforcer import SyncedEnforcer
from .distributed_enforcer import DistributedEnforcer
from . import util
from .persist import *
from .effect import *
from .model import *

View File

@ -0,0 +1 @@
from .config import Config

View File

@ -0,0 +1,151 @@
from io import StringIO
class Config:
"""represents an implementation of the ConfigInterface"""
# DEFAULT_SECTION specifies the name of a section if no name provided
DEFAULT_SECTION = 'default'
# DEFAULT_COMMENT defines what character(s) indicate a comment `#`
DEFAULT_COMMENT = '#'
# DEFAULT_COMMENT_SEM defines what alternate character(s) indicate a comment `;`
DEFAULT_COMMENT_SEM = ';'
# DEFAULT_MULTI_LINE_SEPARATOR defines what character indicates a multi-line content
DEFAULT_MULTI_LINE_SEPARATOR = '\\'
_data = dict()
def __init__(self):
self._data = dict()
@staticmethod
def new_config(conf_name):
c = Config()
c._parse(conf_name)
return c
@staticmethod
def new_config_from_text(text):
c = Config()
f = StringIO(text)
c._parse_buffer(f)
return c
def add_config(self, section, option, value):
if section == '':
section = self.DEFAULT_SECTION
if section not in self._data.keys():
self._data[section] = {}
self._data[section][option] = value
def _parse(self, fname):
with open(fname, 'r', encoding='utf-8') as f:
self._parse_buffer(f)
def _parse_buffer(self, f):
section = ''
line_num = 0
buf = []
can_write = False
while True:
if can_write:
self._write(section, line_num, buf)
can_write = False
line_num = line_num + 1
line = f.readline()
if not line:
if len(buf) > 0:
self._write(section, line_num, buf)
break
line = line.strip()
if '' == line or self.DEFAULT_COMMENT == line[0:1] or self.DEFAULT_COMMENT_SEM == line[0:1]:
can_write = True
continue
elif '[' == line[0:1] and ']' == line[-1]:
if len(buf) > 0:
self._write(section, line_num, buf)
can_write = False
section = line[1:-1]
else:
p = ''
if self.DEFAULT_MULTI_LINE_SEPARATOR == line[-1]:
p = line[0:-1].strip()
p = p + ' '
else:
p = line
can_write = True
buf.append(p)
def _write(self, section, line_num, b):
buf = "".join(b)
if len(buf) <= 0:
return
option_val = buf.split('=', 1)
if len(option_val) != 2:
raise RuntimeError('parse the content error : line {} , {} = ?'.format(line_num, option_val[0]))
option = option_val[0].strip()
value = option_val[1].strip()
self.add_config(section, option, value)
del b[:]
def get_bool(self, key):
"""lookups up the value using the provided key and converts the value to a bool."""
return self.get(key).capitalize() == "True"
def get_int(self, key):
"""lookups up the value using the provided key and converts the value to a int"""
return int(self.get(key))
def get_float(self, key):
"""lookups up the value using the provided key and converts the value to a float"""
return float(self.get(key))
def get_string(self, key):
"""lookups up the value using the provided key and converts the value to a string"""
return self.get(key)
def get_strings(self, key):
"""lookups up the value using the provided key and converts the value to an array of string"""
value = self.get(key)
if value == "":
return None
return value.split(",")
def set(self, key, value):
if len(key) == 0:
raise RuntimeError("key is empty")
keys = key.lower().split('::')
if len(keys) >= 2:
section = keys[0]
option = keys[1]
else:
section = ""
option = keys[0]
self.add_config(section, option, value)
def get(self, key):
"""section.key or key"""
keys = key.lower().split('::')
if len(keys) >= 2:
section = keys[0]
option = keys[1]
else:
section = self.DEFAULT_SECTION
option = keys[0]
if section in self._data.keys():
if option in self._data[section].keys():
return self._data[section][option]
return ''

View File

@ -0,0 +1,373 @@
import logging
from utils.casbin.effect import Effector, get_effector, effect_to_bool
from utils.casbin.model import Model, FunctionMap
from utils.casbin.persist import Adapter
from utils.casbin.persist.adapters import FileAdapter
from utils.casbin.rbac import default_role_manager
from utils.casbin.util import generate_g_function, SimpleEval, util
class CoreEnforcer:
"""CoreEnforcer defines the core functionality of an enforcer."""
model_path = ""
model = None
fm = None
eft = None
adapter = None
watcher = None
rm_map = None
enabled = False
auto_save = False
auto_build_role_links = False
def __init__(self, model=None, adapter=None):
self.logger = logging.getLogger()
if isinstance(model, str):
if isinstance(adapter, str):
self.init_with_file(model, adapter)
else:
self.init_with_adapter(model, adapter)
pass
else:
if isinstance(adapter, str):
raise RuntimeError("Invalid parameters for enforcer.")
else:
self.init_with_model_and_adapter(model, adapter)
def init_with_file(self, model_path, policy_path):
"""initializes an enforcer with a model file and a policy file."""
a = FileAdapter(policy_path)
self.init_with_adapter(model_path, a)
def init_with_adapter(self, model_path, adapter=None):
"""initializes an enforcer with a database adapter."""
m = self.new_model(model_path)
self.init_with_model_and_adapter(m, adapter)
self.model_path = model_path
def init_with_model_and_adapter(self, m, adapter=None):
"""initializes an enforcer with a model and a database adapter."""
if not isinstance(m, Model) or adapter is not None and not isinstance(adapter, Adapter):
raise RuntimeError("Invalid parameters for enforcer.")
self.adapter = adapter
self.model = m
self.model.print_model()
self.fm = FunctionMap.load_function_map()
self._initialize()
# Do not initialize the full policy when using a filtered adapter
if self.adapter and not self.is_filtered():
self.load_policy()
def _initialize(self):
self.rm_map = dict()
self.eft = get_effector(self.model.model["e"]["e"].value)
self.watcher = None
self.enabled = True
self.auto_save = True
self.auto_build_role_links = True
self.init_rm_map()
@staticmethod
def new_model(path="", text=""):
"""creates a model."""
m = Model()
if len(path) > 0:
m.load_model(path)
else:
m.load_model_from_text(text)
return m
def load_model(self):
"""reloads the model from the model CONF file.
Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy().
"""
self.model = self.new_model()
self.model.load_model(self.model_path)
self.model.print_model()
self.fm = FunctionMap.load_function_map()
def get_model(self):
"""gets the current model."""
return self.model
def set_model(self, m):
"""sets the current model."""
self.model = m
self.fm = FunctionMap.load_function_map()
def get_adapter(self):
"""gets the current adapter."""
return self.adapter
def set_adapter(self, adapter):
"""sets the current adapter."""
self.adapter = adapter
def set_watcher(self, watcher):
"""sets the current watcher."""
self.watcher = watcher
pass
def get_role_manager(self):
"""gets the current role manager."""
return self.rm_map['g']
def set_role_manager(self, rm):
"""sets the current role manager."""
self.rm_map['g'] = rm
def set_effector(self, eft):
"""sets the current effector."""
self.eft = eft
def clear_policy(self):
""" clears all policy."""
self.model.clear_policy()
def init_rm_map(self):
if 'g' in self.model.model.keys():
for ptype in self.model.model['g']:
self.rm_map[ptype] = default_role_manager.RoleManager(10)
def load_policy(self):
"""reloads the policy from file/database."""
self.model.clear_policy()
self.adapter.load_policy(self.model)
self.init_rm_map()
self.model.print_policy()
if self.auto_build_role_links:
self.build_role_links()
def load_filtered_policy(self, filter):
"""reloads a filtered policy from file/database."""
self.model.clear_policy()
if not hasattr(self.adapter, "is_filtered"):
raise ValueError("filtered policies are not supported by this adapter")
self.adapter.load_filtered_policy(self.model, filter)
self.init_rm_map()
self.model.print_policy()
if self.auto_build_role_links:
self.build_role_links()
def load_increment_filtered_policy(self, filter):
"""LoadIncrementalFilteredPolicy append a filtered policy from file/database."""
if not hasattr(self.adapter, "is_filtered"):
raise ValueError("filtered policies are not supported by this adapter")
self.adapter.load_filtered_policy(self.model, filter)
self.model.print_policy()
if self.auto_build_role_links:
self.build_role_links()
def is_filtered(self):
"""returns true if the loaded policy has been filtered."""
return hasattr(self.adapter, "is_filtered") and self.adapter.is_filtered()
def save_policy(self):
if self.is_filtered():
raise RuntimeError("cannot save a filtered policy")
self.adapter.save_policy(self.model)
if self.watcher:
self.watcher.update()
def enable_enforce(self, enabled=True):
"""changes the enforcing state of Casbin,
when Casbin is disabled, all access will be allowed by the Enforce() function.
"""
self.enabled = enabled
def enable_auto_save(self, auto_save):
"""controls whether to save a policy rule automatically to the adapter when it is added or removed."""
self.auto_save = auto_save
def enable_auto_build_role_links(self, auto_build_role_links):
"""controls whether to rebuild the role inheritance relations when a role is added or deleted."""
self.auto_build_role_links = auto_build_role_links
def build_role_links(self):
"""manually rebuild the role inheritance relations."""
for rm in self.rm_map.values():
rm.clear()
self.model.build_role_links(self.rm_map)
def add_named_matching_func(self, ptype, fn):
"""add_named_matching_func add MatchingFunc by ptype RoleManager"""
try:
self.rm_map[ptype].add_matching_func(fn)
return True
except:
return False
def add_named_domain_matching_func(self, ptype, fn):
"""add_named_domain_matching_func add MatchingFunc by ptype to RoleManager"""
try:
self.rm_map[ptype].add_domain_matching_func(fn)
return True
except:
return False
def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
"""
result, _ = self.enforce_ex(*rvals)
return result
def enforce_ex(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
return judge result with reason
"""
if not self.enabled:
return False
functions = self.fm.get_functions()
if "g" in self.model.model.keys():
for key, ast in self.model.model["g"].items():
rm = ast.rm
functions[key] = generate_g_function(rm)
if "m" not in self.model.model.keys():
raise RuntimeError("model is undefined")
if "m" not in self.model.model["m"].keys():
raise RuntimeError("model is undefined")
r_tokens = self.model.model["r"]["r"].tokens
p_tokens = self.model.model["p"]["p"].tokens
if len(r_tokens) != len(rvals):
raise RuntimeError("invalid request size")
exp_string = self.model.model["m"]["m"].value
has_eval = util.has_eval(exp_string)
if not has_eval:
expression = self._get_expression(exp_string, functions)
policy_effects = set()
r_parameters = dict(zip(r_tokens, rvals))
policy_len = len(self.model.model["p"]["p"].policy)
explain_index = -1
if not 0 == policy_len:
for i, pvals in enumerate(self.model.model["p"]["p"].policy):
if len(p_tokens) != len(pvals):
raise RuntimeError("invalid policy size")
p_parameters = dict(zip(p_tokens, pvals))
parameters = dict(r_parameters, **p_parameters)
if util.has_eval(exp_string):
rule_names = util.get_eval_value(exp_string)
rules = [util.escape_assertion(p_parameters[rule_name]) for rule_name in rule_names]
exp_with_rule = util.replace_eval(exp_string, rules)
expression = self._get_expression(exp_with_rule, functions)
result = expression.eval(parameters)
if isinstance(result, bool):
if not result:
policy_effects.add(Effector.INDETERMINATE)
continue
elif isinstance(result, float):
if 0 == result:
policy_effects.add(Effector.INDETERMINATE)
continue
else:
raise RuntimeError("matcher result should be bool, int or float")
if "p_eft" in parameters.keys():
eft = parameters["p_eft"]
if "allow" == eft:
policy_effects.add(Effector.ALLOW)
elif "deny" == eft:
policy_effects.add(Effector.DENY)
else:
policy_effects.add(Effector.INDETERMINATE)
else:
policy_effects.add(Effector.ALLOW)
if self.eft.intermediate_effect(policy_effects) != Effector.INDETERMINATE:
explain_index = i
break
else:
if has_eval:
raise RuntimeError("please make sure rule exists in policy when using eval() in matcher")
parameters = r_parameters.copy()
for token in self.model.model["p"]["p"].tokens:
parameters[token] = ""
result = expression.eval(parameters)
if result:
policy_effects.add(Effector.ALLOW)
else:
policy_effects.add(Effector.INDETERMINATE)
final_effect = self.eft.final_effect(policy_effects)
result = effect_to_bool(final_effect)
# Log request.
req_str = "Request: "
req_str = req_str + ", ".join([str(v) for v in rvals])
req_str = req_str + " ---> %s" % result
if result:
self.logger.info(req_str)
else:
# leaving this in error for now, if it's very noise this can be changed to info or debug
self.logger.error(req_str)
explain_rule = []
if explain_index != -1 and explain_index < policy_len:
explain_rule = self.model.model["p"]["p"].policy[explain_index]
return result, explain_rule
@staticmethod
def _get_expression(expr, functions=None):
expr = expr.replace("&&", "and")
expr = expr.replace("||", "or")
expr = expr.replace("!", "not")
return SimpleEval(expr, functions)

View File

@ -0,0 +1,132 @@
from utils.casbin import SyncedEnforcer
import logging
from utils.casbin.persist import batch_adapter
from utils.casbin.model.policy_op import PolicyOp
from utils.casbin.persist.adapters import update_adapter
class DistributedEnforcer(SyncedEnforcer):
"""DistributedEnforcer wraps SyncedEnforcer for dispatcher."""
def __init__(self, model=None, adapter=None):
self.logger = logging.getLogger()
SyncedEnforcer.__init__(self, model, adapter)
def add_policy_self(self, should_persist, sec, ptype, rules):
"""
AddPolicySelf provides a method for dispatcher to add authorization rules to the current policy.
The function returns the rules affected and error.
"""
no_exists_policy = []
for rule in rules:
if not self.get_model().has_policy(sec, ptype, rule):
no_exists_policy.append(rule)
if should_persist:
try:
if isinstance(self.adapter, batch_adapter):
self.adapter.add_policies(sec, ptype, rules)
except Exception as e:
self.logger.log("An error occurred: " + e)
self.get_model().add_policies(sec, ptype, no_exists_policy)
if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_add, ptype, no_exists_policy)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return no_exists_policy
return no_exists_policy
def remove_policy_self(self, should_persist, sec, ptype, rules):
"""
remove_policy_self provides a method for dispatcher to remove policies from current policy.
The function returns the rules affected and error.
"""
if(should_persist):
try:
if(isinstance(self.adapter, batch_adapter)):
self.adapter.remove_policy(sec, ptype, rules)
except Exception as e:
self.logger.log("An exception occurred: " + e)
effected = self.get_model().remove_policies_with_effected(sec, ptype, rules)
if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, rules)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return effected
return effected
def remove_filtered_policy_self(self, should_persist, sec, ptype, field_index, *field_values):
"""
remove_filtered_policy_self provides a method for dispatcher to remove an authorization
rule from the current policy,field filters can be specified.
The function returns the rules affected and error.
"""
if should_persist:
try:
self.adapter.remove_filtered_policy(sec, ptype, field_index, field_values)
except Exception as e:
self.logger.log("An exception occurred: " + e)
effects = self.get_model().remove_filtered_policy_returns_effects(sec, ptype, field_index, field_values)
if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, effects)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return effects
return effects
def clear_policy_self(self, should_persist):
"""
clear_policy_self provides a method for dispatcher to clear all rules from the current policy.
"""
if should_persist:
try:
self.adapter.save_policy(None)
except Exception as e:
self.logger.log("An exception occurred: " + e)
self.get_model().clear_policy()
def update_policy_self(self, should_persist, sec, ptype, old_rule, new_rule):
"""
update_policy_self provides a method for dispatcher to update an authorization rule from the current policy.
"""
if should_persist:
try:
if isinstance(self.adapter, update_adapter):
self.adapter.update_policy(sec, ptype, old_rule, new_rule)
except Exception as e:
self.logger.log("An exception occurred: " + e)
return False
rule_updated = self.get_model().update_policy(sec, ptype, old_rule, new_rule)
if not rule_updated:
return False
if sec == "g":
try:
self.build_incremental_role_links(PolicyOp.Policy_remove, ptype, [old_rule])
except Exception as e:
return False
try:
self.build_incremental_role_links(PolicyOp.Policy_add, ptype, [new_rule])
except Exception as e:
return False
return True

View File

@ -0,0 +1,24 @@
from .default_effectors import AllowOverrideEffector, DenyOverrideEffector, AllowAndDenyEffector, PriorityEffector
from .effector import Effector
def get_effector(expr):
''' creates an effector based on the current policy effect expression '''
if expr == "some(where (p_eft == allow))":
return AllowOverrideEffector()
elif expr == "!some(where (p_eft == deny))":
return DenyOverrideEffector()
elif expr == "some(where (p_eft == allow)) && !some(where (p_eft == deny))":
return AllowAndDenyEffector()
elif expr == "priority(p_eft) || deny":
return PriorityEffector()
else:
raise RuntimeError("unsupported effect")
def effect_to_bool(effect):
""" """
if effect == Effector.ALLOW:
return True
if effect == Effector.DENY:
return False
raise RuntimeError("effect can't be converted to boolean")

View File

@ -0,0 +1,61 @@
from .effector import Effector
class AllowOverrideEffector(Effector):
def intermediate_effect(self, effects):
""" returns a intermediate effect based on the matched effects of the enforcer """
if Effector.ALLOW in effects:
return Effector.ALLOW
return Effector.INDETERMINATE
def final_effect(self, effects):
""" returns the final effect based on the matched effects of the enforcer """
if Effector.ALLOW in effects:
return Effector.ALLOW
return Effector.DENY
class DenyOverrideEffector(Effector):
def intermediate_effect(self, effects):
""" returns a intermediate effect based on the matched effects of the enforcer """
if Effector.DENY in effects:
return Effector.DENY
return Effector.INDETERMINATE
def final_effect(self, effects):
""" returns the final effect based on the matched effects of the enforcer """
if Effector.DENY in effects:
return Effector.DENY
return Effector.ALLOW
class AllowAndDenyEffector(Effector):
def intermediate_effect(self, effects):
""" returns a intermediate effect based on the matched effects of the enforcer """
if Effector.DENY in effects:
return Effector.DENY
return Effector.INDETERMINATE
def final_effect(self, effects):
""" returns the final effect based on the matched effects of the enforcer """
if Effector.DENY in effects or Effector.ALLOW not in effects:
return Effector.DENY
return Effector.ALLOW
class PriorityEffector(Effector):
def intermediate_effect(self, effects):
""" returns a intermediate effect based on the matched effects of the enforcer """
if Effector.ALLOW in effects:
return Effector.ALLOW
if Effector.DENY in effects:
return Effector.DENY
return Effector.INDETERMINATE
def final_effect(self, effects):
""" returns the final effect based on the matched effects of the enforcer """
if Effector.ALLOW in effects:
return Effector.ALLOW
if Effector.DENY in effects:
return Effector.DENY
return Effector.DENY

View File

@ -0,0 +1,19 @@
class Effector:
"""Effector is the interface for Casbin effectors."""
ALLOW = 0
INDETERMINATE = 1
DENY = 2
def intermediate_effect(self, effects):
""" returns a intermediate effect based on the matched effects of the enforcer """
pass
def final_effect(self, effects):
""" returns the final effect based on the matched effects of the enforcer """
pass

211
utils/casbin/enforcer.py Normal file
View File

@ -0,0 +1,211 @@
from utils.casbin.management_enforcer import ManagementEnforcer
from utils.casbin.util import join_slice, set_subtract
class Enforcer(ManagementEnforcer):
"""
Enforcer = ManagementEnforcer + RBAC_API + RBAC_WITH_DOMAIN_API
"""
"""creates an enforcer via file or DB.
File:
e = casbin.Enforcer("path/to/basic_model.conf", "path/to/basic_policy.csv")
MySQL DB:
a = mysqladapter.DBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/")
e = casbin.Enforcer("path/to/basic_model.conf", a)
"""
def get_roles_for_user(self, name):
""" gets the roles that a user has. """
return self.model.model["g"]["g"].rm.get_roles(name)
def get_users_for_role(self, name):
""" gets the users that has a role. """
return self.model.model["g"]["g"].rm.get_users(name)
def has_role_for_user(self, name, role):
""" determines whether a user has a role. """
roles = self.get_roles_for_user(name)
return any(r == role for r in roles)
def add_role_for_user(self, user, role):
"""
adds a role for a user.
Returns false if the user already has the role (aka not affected).
"""
return self.add_grouping_policy(user, role)
def delete_role_for_user(self, user, role):
"""
deletes a role for a user.
Returns false if the user does not have the role (aka not affected).
"""
return self.remove_grouping_policy(user, role)
def delete_roles_for_user(self, user):
"""
deletes all roles for a user.
Returns false if the user does not have any roles (aka not affected).
"""
return self.remove_filtered_grouping_policy(0, user)
def delete_user(self, user):
"""
deletes a user.
Returns false if the user does not exist (aka not affected).
"""
res1 = self.remove_filtered_grouping_policy(0, user)
res2 = self.remove_filtered_policy(0, user)
return res1 or res2
def delete_role(self, role):
"""
deletes a role.
Returns false if the role does not exist (aka not affected).
"""
res1 = self.remove_filtered_grouping_policy(1, role)
res2 = self.remove_filtered_policy(0, role)
return res1 or res2
def delete_permission(self, *permission):
"""
deletes a permission.
Returns false if the permission does not exist (aka not affected).
"""
return self.remove_filtered_policy(1, *permission)
def add_permission_for_user(self, user, *permission):
"""
adds a permission for a user or role.
Returns false if the user or role already has the permission (aka not affected).
"""
return self.add_policy(join_slice(user, *permission))
def delete_permission_for_user(self, user, *permission):
"""
deletes a permission for a user or role.
Returns false if the user or role does not have the permission (aka not affected).
"""
return self.remove_policy(join_slice(user, *permission))
def delete_permissions_for_user(self, user):
"""
deletes permissions for a user or role.
Returns false if the user or role does not have any permissions (aka not affected).
"""
return self.remove_filtered_policy(0, user)
def get_permissions_for_user(self, user):
"""
gets permissions for a user or role.
"""
return self.get_filtered_policy(0, user)
def has_permission_for_user(self, user, *permission):
"""
determines whether a user has a permission.
"""
return self.has_policy(join_slice(user, *permission))
def get_implicit_roles_for_user(self, name, domain=None):
"""
gets implicit roles that a user has.
Compared to get_roles_for_user(), this function retrieves indirect roles besides direct roles.
For example:
g, alice, role:admin
g, role:admin, role:user
get_roles_for_user("alice") can only get: ["role:admin"].
But get_implicit_roles_for_user("alice") will get: ["role:admin", "role:user"].
"""
res = []
queue = [name]
while queue:
name = queue.pop(0)
for rm in self.rm_map.values():
roles = rm.get_roles(name, domain)
for r in roles:
if r not in res:
res.append(r)
queue.append(r)
return res
def get_implicit_permissions_for_user(self, user, domain=None):
"""
gets implicit permissions for a user or role.
Compared to get_permissions_for_user(), this function retrieves permissions for inherited roles.
For example:
p, admin, data1, read
p, alice, data2, read
g, alice, admin
get_permissions_for_user("alice") can only get: [["alice", "data2", "read"]].
But get_implicit_permissions_for_user("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]].
"""
roles = self.get_implicit_roles_for_user(user, domain)
roles.insert(0, user)
res = []
for role in roles:
if domain:
permissions = self.get_permissions_for_user_in_domain(role, domain)
else:
permissions = self.get_permissions_for_user(role)
res.extend(permissions)
return res
def get_implicit_users_for_permission(self, *permission):
"""
gets implicit users for a permission.
For example:
p, admin, data1, read
p, bob, data1, read
g, alice, admin
get_implicit_users_for_permission("data1", "read") will get: ["alice", "bob"].
Note: only users will be returned, roles (2nd arg in "g") will be excluded.
"""
subjects = self.get_all_subjects()
roles = self.get_all_roles()
users = set_subtract(subjects, roles)
res = list()
for user in users:
req = join_slice(user, *permission)
allowed = self.enforce(*req)
if allowed:
res.append(user)
return res
def get_roles_for_user_in_domain(self, name, domain):
"""gets the roles that a user has inside a domain."""
return self.model.model['g']['g'].rm.get_roles(name, domain)
def get_users_for_role_in_domain(self, name, domain):
"""gets the users that has a role inside a domain."""
return self.model.model['g']['g'].rm.get_users(name, domain)
def add_role_for_user_in_domain(self, user, role, domain):
"""adds a role for a user inside a domain."""
"""Returns false if the user already has the role (aka not affected)."""
return self.add_grouping_policy(user, role, domain)
def delete_roles_for_user_in_domain(self, user, role, domain):
"""deletes a role for a user inside a domain."""
"""Returns false if the user does not have any roles (aka not affected)."""
return self.remove_filtered_grouping_policy(0, user, role, domain)
def get_permissions_for_user_in_domain(self, user, domain):
"""gets permissions for a user or role inside domain."""
return self.get_filtered_policy(0, user, domain)

View File

@ -0,0 +1,122 @@
from utils.casbin.core_enforcer import CoreEnforcer
from utils.casbin.model.policy_op import PolicyOp
class InternalEnforcer(CoreEnforcer):
"""
InternalEnforcer = CoreEnforcer + Internal API.
"""
def _add_policy(self, sec, ptype, rule):
"""adds a rule to the current policy."""
rule_added = self.model.add_policy(sec, ptype, rule)
if not rule_added:
return rule_added
if self.adapter and self.auto_save:
if self.adapter.add_policy(sec, ptype, rule) is False:
return False
if self.watcher:
self.watcher.update()
return rule_added
def _add_policies(self,sec,ptype,rules):
"""adds rules to the current policy."""
rules_added = self.model.add_policies(sec, ptype, rules)
if not rules_added:
return rules_added
if self.adapter and self.auto_save:
if hasattr(self.adapter,'add_policies') is False:
return False
if self.adapter.add_policies(sec, ptype, rules) is False:
return False
if self.watcher:
self.watcher.update()
return rules_added
def _update_policy(self, sec, ptype, old_rule, new_rule):
"""updates a rule from the current policy."""
rule_updated = self.model.update_policy(sec, ptype, old_rule, new_rule)
if not rule_updated:
return rule_updated
if self.adapter and self.auto_save:
if self.adapter.update_policy(sec, ptype, old_rule, new_rule) is False:
return False
if self.watcher:
self.watcher.update()
return rule_updated
def _update_policies(self, sec, ptype, old_rules, new_rules):
"""updates rules from the current policy."""
rules_updated = self.model.update_policies(sec, ptype, old_rules, new_rules)
if not rules_updated:
return rules_updated
if self.adapter and self.auto_save:
if self.adapter.update_policies(sec, ptype, old_rules, new_rules) is False:
return False
if self.watcher:
self.watcher.update()
return rules_updated
def _remove_policy(self, sec, ptype, rule):
"""removes a rule from the current policy."""
rule_removed = self.model.remove_policy(sec, ptype, rule)
if not rule_removed:
return rule_removed
if self.adapter and self.auto_save:
if self.adapter.remove_policy(sec, ptype, rule) is False:
return False
if self.watcher:
self.watcher.update()
return rule_removed
def _remove_policies(self, sec, ptype, rules):
"""RemovePolicies removes policy rules from the model."""
rules_removed = self.model.remove_policies(sec, ptype, rules)
if not rules_removed:
return rules_removed
if self.adapter and self.auto_save:
if hasattr(self.adapter,'remove_policies') is False:
return False
if self.adapter.remove_policies(sec, ptype, rules) is False:
return False
if self.watcher:
self.watcher.update()
return rules_removed
def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes rules based on field filters from the current policy."""
rule_removed = self.model.remove_filtered_policy(sec, ptype, field_index, *field_values)
if not rule_removed:
return rule_removed
if self.adapter and self.auto_save:
if self.adapter.remove_filtered_policy(sec, ptype, field_index, *field_values) is False:
return False
if self.watcher:
self.watcher.update()
return rule_removed

View File

@ -0,0 +1,271 @@
from utils.casbin.internal_enforcer import InternalEnforcer
class ManagementEnforcer(InternalEnforcer):
"""
ManagementEnforcer = InternalEnforcer + Management API.
"""
def get_all_subjects(self):
"""gets the list of subjects that show up in the current policy."""
return self.get_all_named_subjects('p')
def get_all_named_subjects(self, ptype):
"""gets the list of subjects that show up in the current named policy."""
return self.model.get_values_for_field_in_policy('p', ptype, 0)
def get_all_objects(self):
"""gets the list of objects that show up in the current policy."""
return self.get_all_named_objects('p')
def get_all_named_objects(self, ptype):
"""gets the list of objects that show up in the current named policy."""
return self.model.get_values_for_field_in_policy('p', ptype, 1)
def get_all_actions(self):
"""gets the list of actions that show up in the current policy."""
return self.get_all_named_actions('p')
def get_all_named_actions(self, ptype):
"""gets the list of actions that show up in the current named policy."""
return self.model.get_values_for_field_in_policy('p', ptype, 2)
def get_all_roles(self):
"""gets the list of roles that show up in the current named policy."""
return self.get_all_named_roles('g')
def get_all_named_roles(self, ptype):
"""gets all the authorization rules in the policy."""
return self.model.get_values_for_field_in_policy('g', ptype, 1)
def get_policy(self):
"""gets all the authorization rules in the policy."""
return self.get_named_policy('p')
def get_filtered_policy(self, field_index, *field_values):
"""gets all the authorization rules in the policy, field filters can be specified."""
return self.get_filtered_named_policy('p', field_index, *field_values)
def get_named_policy(self, ptype):
"""gets all the authorization rules in the named policy."""
return self.model.get_policy('p', ptype)
def get_filtered_named_policy(self, ptype, field_index, *field_values):
"""gets all the authorization rules in the named policy, field filters can be specified."""
return self.model.get_filtered_policy('p', ptype, field_index, *field_values)
def get_grouping_policy(self):
"""gets all the role inheritance rules in the policy."""
return self.get_named_grouping_policy('g')
def get_filtered_grouping_policy(self, field_index, *field_values):
"""gets all the role inheritance rules in the policy, field filters can be specified."""
return self.get_filtered_named_grouping_policy("g", field_index, *field_values)
def get_named_grouping_policy(self, ptype):
"""gets all the role inheritance rules in the policy."""
return self.model.get_policy('g', ptype)
def get_filtered_named_grouping_policy(self, ptype, field_index, *field_values):
"""gets all the role inheritance rules in the policy, field filters can be specified."""
return self.model.get_filtered_policy('g', ptype, field_index, *field_values)
def has_policy(self, *params):
"""determines whether an authorization rule exists."""
return self.has_named_policy('p', *params)
def has_named_policy(self, ptype, *params):
"""determines whether a named authorization rule exists."""
if len(params) == 1 and isinstance(params[0], list):
str_slice = params[0]
return self.model.has_policy('p', ptype, str_slice)
return self.model.has_policy('p', ptype, list(params))
def add_policy(self, *params):
"""adds an authorization rule to the current policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
return self.add_named_policy('p', *params)
def add_policies(self,rules):
"""adds authorization rules to the current policy.
If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
Otherwise the function returns true for the corresponding rule by adding the new rule.
"""
return self.add_named_policies('p',rules)
def add_named_policy(self, ptype, *params):
"""adds an authorization rule to the current named policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
if len(params) == 1 and isinstance(params[0], list):
str_slice = params[0]
rule_added = self._add_policy('p', ptype, str_slice)
else:
rule_added = self._add_policy('p', ptype, list(params))
return rule_added
def add_named_policies(self,ptype,rules):
"""adds authorization rules to the current named policy.
If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
Otherwise the function returns true for the corresponding by adding the new rule."""
return self._add_policies('p',ptype,rules)
def update_policy(self, old_rule, new_rule):
"""updates an authorization rule from the current policy."""
return self.update_named_policy('p', old_rule, new_rule)
def update_policies(self, old_rules, new_rules):
"""updates authorization rules from the current policy."""
return self.update_named_policies('p', old_rules, new_rules)
def update_named_policy(self, ptype, old_rule, new_rule):
"""updates an authorization rule from the current named policy."""
return self._update_policy('p', ptype, old_rule, new_rule)
def update_named_policies(self, ptype, old_rules, new_rules):
"""updates authorization rules from the current named policy."""
return self._update_policies('p', ptype, old_rules, new_rules)
def remove_policy(self, *params):
"""removes an authorization rule from the current policy."""
return self.remove_named_policy('p', *params)
def remove_policies(self,rules):
"""removes authorization rules from the current policy."""
return self.remove_named_policies('p',rules)
def remove_filtered_policy(self, field_index, *field_values):
"""removes an authorization rule from the current policy, field filters can be specified."""
return self.remove_filtered_named_policy('p', field_index, *field_values)
def remove_named_policy(self, ptype, *params):
"""removes an authorization rule from the current named policy."""
if len(params) == 1 and isinstance(params[0], list):
str_slice = params[0]
rule_removed = self._remove_policy('p', ptype, str_slice)
else:
rule_removed = self._remove_policy('p', ptype, list(params))
return rule_removed
def remove_named_policies(self,ptype,rules):
"""removes authorization rules from the current named policy."""
return self._remove_policies('p',ptype,rules)
def remove_filtered_named_policy(self, ptype, field_index, *field_values):
"""removes an authorization rule from the current named policy, field filters can be specified."""
return self._remove_filtered_policy('p', ptype, field_index, *field_values)
def has_grouping_policy(self, *params):
"""determines whether a role inheritance rule exists."""
return self.has_named_grouping_policy('g', *params)
def has_named_grouping_policy(self, ptype, *params):
"""determines whether a named role inheritance rule exists."""
if len(params) == 1 and isinstance(params[0], list):
str_slice = params[0]
return self.model.has_policy('g', ptype, str_slice)
return self.model.has_policy('g', ptype, list(params))
def add_grouping_policy(self, *params):
"""adds a role inheritance rule to the current policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
return self.add_named_grouping_policy('g', *params)
def add_grouping_policies(self,rules):
"""adds role inheritance rulea to the current policy.
If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
Otherwise the function returns true for the corresponding policy rule by adding the new rule.
"""
return self.add_named_grouping_policies('g',rules)
def add_named_grouping_policy(self, ptype, *params):
"""adds a named role inheritance rule to the current policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
if len(params) == 1 and isinstance(params[0], list):
str_slice = params[0]
rule_added = self._add_policy('g', ptype, str_slice)
else:
rule_added = self._add_policy('g', ptype, list(params))
if self.auto_build_role_links:
self.build_role_links()
return rule_added
def add_named_grouping_policies(self,ptype,rules):
""""adds named role inheritance rules to the current policy.
If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
Otherwise the function returns true for the corresponding policy rule by adding the new rule."""
rules_added = self._add_policies('g',ptype,rules)
if self.auto_build_role_links:
self.build_role_links()
return rules_added
def remove_grouping_policy(self, *params):
"""removes a role inheritance rule from the current policy."""
return self.remove_named_grouping_policy('g', *params)
def remove_grouping_policies(self,rules):
"""removes role inheritance rulea from the current policy."""
return self.remove_named_grouping_policies('g',rules)
def remove_filtered_grouping_policy(self, field_index, *field_values):
"""removes a role inheritance rule from the current policy, field filters can be specified."""
return self.remove_filtered_named_grouping_policy('g', field_index, *field_values)
def remove_named_grouping_policy(self, ptype, *params):
"""removes a role inheritance rule from the current named policy."""
if len(params) == 1 and isinstance(params[0], list):
str_slice = params[0]
rule_removed = self._remove_policy('g', ptype, str_slice)
else:
rule_removed = self._remove_policy('g', ptype, list(params))
if self.auto_build_role_links:
self.build_role_links()
return rule_removed
def remove_named_grouping_policies(self,ptype,rules):
""" removes role inheritance rules from the current named policy."""
rules_removed = self._remove_policies('g',ptype,rules)
if self.auto_build_role_links:
self.build_role_links()
return rules_removed
def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_values):
"""removes a role inheritance rule from the current named policy, field filters can be specified."""
rule_removed = self._remove_filtered_policy('g', ptype, field_index, *field_values)
if self.auto_build_role_links:
self.build_role_links()
return rule_removed
def add_function(self, name, func):
"""adds a customized function."""
self.fm.add_function(name, func)

View File

@ -0,0 +1,4 @@
from .assertion import Assertion
from .model import Model
from .policy import Policy
from .function import FunctionMap

View File

@ -0,0 +1,47 @@
import logging
from utils.casbin.model.policy_op import PolicyOp
class Assertion:
def __init__(self):
self.logger = logging.getLogger()
self.key = ""
self.value = ""
self.tokens = []
self.policy = []
self.rm = None
def build_role_links(self, rm):
self.rm = rm
count = self.value.count("_")
if count < 2:
raise RuntimeError('the number of "_" in role definition should be at least 2')
for rule in self.policy:
if len(rule) < count:
pass
# raise RuntimeError("grouping policy elements do not meet role definition")
if len(rule) > count:
rule = rule[:count]
self.rm.add_link(*rule[:count])
self.logger.info("Role links for: {}".format(self.key))
self.rm.print_roles()
def build_incremental_role_links(self, rm, op, rules):
self.rm = rm
count = self.value.count("_")
if count < 2:
raise RuntimeError('the number of "_" in role definition should be at least 2')
for rule in rules:
if len(rule) < count:
raise TypeError("grouping policy elements do not meet role definition")
if len(rule) > count:
rule = rule[:count]
if op == PolicyOp.Policy_add:
rm.add_link(rule[0], rule[1], *rule[2:])
elif op == PolicyOp.Policy_remove:
rm.delete_link(rule[0], rule[1], *rule[2:])
else:
raise TypeError("Invalid operation: " + str(op))

View File

@ -0,0 +1,22 @@
from utils.casbin import util
class FunctionMap:
fm = dict()
def add_function(self, name, func):
self.fm[name] = func
@staticmethod
def load_function_map():
fm = FunctionMap()
fm.add_function("keyMatch", util.key_match_func)
fm.add_function("keyMatch2", util.key_match2_func)
fm.add_function("regexMatch", util.regex_match_func)
fm.add_function("ipMatch", util.ip_match_func)
fm.add_function("globMatch", util.glob_match_func)
return fm
def get_functions(self):
return self.fm

View File

@ -0,0 +1,80 @@
from . import Assertion
from utils.casbin import util, config
from .policy import Policy
class Model(Policy):
section_name_map = {
'r': 'request_definition',
'p': 'policy_definition',
'g': 'role_definition',
'e': 'policy_effect',
'm': 'matchers',
}
def _load_assertion(self, cfg, sec, key):
value = cfg.get(self.section_name_map[sec] + "::" + key)
return self.add_def(sec, key, value)
def add_def(self, sec, key, value):
if value == "":
return
ast = Assertion()
ast.key = key
ast.value = value
if "r" == sec or "p" == sec:
ast.tokens = ast.value.split(",")
for i,token in enumerate(ast.tokens):
ast.tokens[i] = key + "_" + token.strip()
else:
ast.value = util.remove_comments(util.escape_assertion(ast.value))
if sec not in self.model.keys():
self.model[sec] = {}
self.model[sec][key] = ast
return True
def _get_key_suffix(self, i):
if i == 1:
return ""
return str(i)
def _load_section(self, cfg, sec):
i = 1
while True:
if not self._load_assertion(cfg, sec, sec + self._get_key_suffix(i)):
break
else:
i = i + 1
def load_model(self, path):
cfg = config.Config.new_config(path)
self._load_section(cfg, "r")
self._load_section(cfg, "p")
self._load_section(cfg, "e")
self._load_section(cfg, "m")
self._load_section(cfg, "g")
def load_model_from_text(self, text):
cfg = config.Config.new_config_from_text(text)
self._load_section(cfg, "r")
self._load_section(cfg, "p")
self._load_section(cfg, "e")
self._load_section(cfg, "m")
self._load_section(cfg, "g")
def print_model(self):
self.logger.info("Model:")
for k, v in self.model.items():
for i, j in v.items():
self.logger.info("%s.%s: %s", k, i, j.value)

View File

@ -0,0 +1,190 @@
import logging
class Policy:
def __init__(self):
self.logger = logging.getLogger()
self.model = {}
def build_role_links(self, rm_map):
"""initializes the roles in RBAC."""
if "g" not in self.model.keys():
return
for ptype, ast in self.model["g"].items():
rm = rm_map[ptype]
ast.build_role_links(rm)
def build_incremental_role_links(self, rm, op, sec, ptype, rules):
if sec == "g":
self.model.get(sec).get(ptype).build_incremental_role_links(rm, op, rules)
def print_policy(self):
"""Log using info"""
self.logger.info("Policy:")
for sec in ["p", "g"]:
if sec not in self.model.keys():
continue
for key, ast in self.model[sec].items():
self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy))
def clear_policy(self):
"""clears all current policy."""
for sec in ["p", "g"]:
if sec not in self.model.keys():
continue
for key in self.model[sec].keys():
self.model[sec][key].policy = []
def get_policy(self, sec, ptype):
"""gets all rules in a policy."""
return self.model[sec][ptype].policy
def get_filtered_policy(self, sec, ptype, field_index, *field_values):
"""gets rules based on field filters from a policy."""
return [
rule for rule in self.model[sec][ptype].policy
if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values))
]
def has_policy(self, sec, ptype, rule):
"""determines whether a model has the specified policy rule."""
if sec not in self.model.keys():
return False
if ptype not in self.model[sec]:
return False
return rule in self.model[sec][ptype].policy
def add_policy(self, sec, ptype, rule):
"""adds a policy rule to the model."""
if not self.has_policy(sec, ptype, rule):
self.model[sec][ptype].policy.append(rule)
return True
return False
def add_policies(self,sec,ptype,rules):
"""adds policy rules to the model."""
for rule in rules:
if self.has_policy(sec,ptype,rule):
return False
for rule in rules:
self.model[sec][ptype].policy.append(rule)
return True
def update_policy(self, sec, ptype, old_rule, new_rule):
"""update a policy rule from the model."""
if not self.has_policy(sec, ptype, old_rule):
return False
return self.remove_policy(sec, ptype, old_rule) and self.add_policy(sec, ptype, new_rule)
def update_policies(self, sec, ptype, old_rules, new_rules):
"""update policy rules from the model."""
for rule in old_rules:
if not self.has_policy(sec, ptype, rule):
return False
return self.remove_policies(sec, ptype, old_rules) and self.add_policies(sec, ptype, new_rules)
def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the model."""
if not self.has_policy(sec, ptype, rule):
return False
self.model[sec][ptype].policy.remove(rule)
return rule not in self.model[sec][ptype].policy
def remove_policies(self, sec, ptype, rules):
"""RemovePolicies removes policy rules from the model."""
for rule in rules:
if not self.has_policy(sec,ptype,rule):
return False
self.model[sec][ptype].policy.remove(rule)
if rule in self.model[sec][ptype].policy:
return False
return True
def remove_policies_with_effected(self, sec, ptype, rules):
effected = []
for rule in rules:
if self.has_policy(sec, ptype, rule):
effected.append(rule)
self.remove_policy(sec, ptype, rule)
return effected
def remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *field_values):
"""
remove_filtered_policy_returns_effects removes policy rules based on field filters from the model.
"""
tmp = []
effects = []
if(len(field_values) == 0):
return []
if sec not in self.model.keys():
return []
if ptype not in self.model[sec]:
return []
for rule in self.model[sec][ptype].policy:
if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values[0])):
effects.append(rule)
else:
tmp.append(rule)
self.model[sec][ptype].policy = tmp
return effects
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules based on field filters from the model."""
tmp = []
res = False
if sec not in self.model.keys():
return res
if ptype not in self.model[sec]:
return res
for rule in self.model[sec][ptype].policy:
if all(value == "" or rule[field_index + i] == value for i, value in enumerate(field_values)):
res = True
else:
tmp.append(rule)
self.model[sec][ptype].policy = tmp
return res
def get_values_for_field_in_policy(self, sec, ptype, field_index):
"""gets all values for a field for all rules in a policy, duplicated values are removed."""
values = []
if sec not in self.model.keys():
return values
if ptype not in self.model[sec]:
return values
for rule in self.model[sec][ptype].policy:
value = rule[field_index]
if value not in values:
values.append(value)
return values

View File

@ -0,0 +1,5 @@
import enum
class PolicyOp(enum.Enum):
Policy_add = 1
Policy_remove = 2

View File

@ -0,0 +1,4 @@
from .adapter import *
from .adapter_filtered import *
from .batch_adapter import *
from .adapters import *

View File

@ -0,0 +1,46 @@
def load_policy_line(line, model):
"""loads a text line as a policy rule to model."""
if line == "":
return
if line[:1] == "#":
return
tokens = line.split(", ")
key = tokens[0]
sec = key[0]
if sec not in model.model.keys():
return
if key not in model.model[sec].keys():
return
model.model[sec][key].policy.append(tokens[1:])
class Adapter:
"""the interface for Casbin adapters."""
def load_policy(self, model):
"""loads all policy rules from the storage."""
pass
def save_policy(self, model):
"""saves all policy rules to the storage."""
pass
def add_policy(self, sec, ptype, rule):
"""adds a policy rule to the storage."""
pass
def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the storage."""
pass
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules that match the filter from the storage.
This is part of the Auto-Save feature.
"""
pass

View File

@ -0,0 +1,13 @@
from .adapter import Adapter
""" FilteredAdapter is the interface for Casbin adapters supporting filtered policies."""
class FilteredAdapter(Adapter):
def is_filtered(self):
"""IsFiltered returns true if the loaded policy has been filtered
Marks if the loaded policy is filtered or not
"""
pass
def load_filtered_policy(self, model, filter):
"""Loads policy rules that match the filter from the storage."""
pass

View File

@ -0,0 +1,2 @@
from .file_adapter import FileAdapter
from .adapter_filtered import FilteredAdapter

View File

@ -0,0 +1,89 @@
from utils.casbin import persist
from utils.casbin import model
from .file_adapter import FileAdapter
import os
class Filter:
#P,G are string []
P = []
G = []
class FilteredAdapter (FileAdapter,persist.FilteredAdapter):
filtered = False
_file_path = ""
filter = Filter()
#new_filtered_adapte is the constructor for FilteredAdapter.
def __init__(self,file_path):
self.filtered = True
self._file_path = file_path
def load_policy(self,model):
if not os.path.isfile(self._file_path):
raise RuntimeError("invalid file path, file path cannot be empty")
self.filtered=False
self._load_policy_file(model)
#load_filtered_policy loads only policy rules that match the filter.
def load_filtered_policy(self,model,filter):
if filter == None:
return self.load_policy(model)
if not os.path.isfile(self._file_path):
raise RuntimeError("invalid file path, file path cannot be empty")
try:
filter_value = [filter.__dict__['P']]+[filter.__dict__['G']]
except:
raise RuntimeError("invalid filter type")
self.load_filtered_policy_file(model,filter_value,persist.load_policy_line)
self.filtered = True
def load_filtered_policy_file(self,model,filter,hanlder):
with open(self._file_path, "rb") as file:
while True:
line = file.readline()
line = line.decode().strip()
if line == '\n':
continue
if not line :
break
if filter_line(line,filter):
continue
hanlder(line,model)
#is_filtered returns true if the loaded policy has been filtered.
def is_filtered(self):
return self.filtered
def save_policy(self,model):
if self.filtered:
raise RuntimeError("cannot save a filtered policy")
self._save_policy_file(model)
def filter_line(line,filter):
if filter == None:
return False
p = line.split(',')
if len(p) == 0:
return True
filter_slice = []
if p[0].strip()== 'p':
filter_slice = filter[0]
elif p[0].strip() == 'g':
filter_slice = filter[1]
return filter_words(p,filter_slice)
def filter_words(line,filter):
if len(line) < len(filter)+1:
return True
skip_line=False
for i,v in enumerate(filter):
if(len(v) >0 and ( v.strip() != line[i+1].strip() ) ):
skip_line = True
break
return skip_line

View File

@ -0,0 +1,62 @@
from utils.casbin import persist
import os
class FileAdapter(persist.Adapter):
"""the file adapter for Casbin.
It can load policy from file or save policy to file.
"""
_file_path = ""
def __init__(self, file_path):
self._file_path = file_path
def load_policy(self, model):
if not os.path.isfile(self._file_path):
raise RuntimeError("invalid file path, file path cannot be empty")
self._load_policy_file(model)
def save_policy(self, model):
if not os.path.isfile(self._file_path):
raise RuntimeError("invalid file path, file path cannot be empty")
self._save_policy_file(model)
def _load_policy_file(self, model):
with open(self._file_path, "rb") as file:
line = file.readline()
while line:
persist.load_policy_line(line.decode().strip(), model)
line = file.readline()
def _save_policy_file(self, model):
with open(self._file_path, "w") as file:
lines = []
if "p" in model.model.keys():
for key, ast in model.model["p"].items():
for pvals in ast.policy:
lines.append(key + ", " + ", ".join(pvals))
if "g" in model.model.keys():
for key, ast in model.model["g"].items():
for pvals in ast.policy:
lines.append(key + ", " + ", ".join(pvals))
for i, line in enumerate(lines):
if i != len(lines) - 1:
lines[i] += "\n"
file.writelines(lines)
def add_policy(self, sec, ptype, rule):
pass
def add_policies(self,sec,ptype,rules):
pass
def remove_policy(self, sec, ptype, rule):
pass
def remove_policies(self,sec,ptype,rules):
pass

View File

@ -0,0 +1,9 @@
class UpdateAdapter:
""" UpdateAdapter is the interface for Casbin adapters with add update policy function. """
def update_policy(self, sec, ptype, old_rule, new_policy):
"""
update_policy updates a policy rule from storage.
This is part of the Auto-Save feature.
"""
pass

View File

@ -0,0 +1,11 @@
from .adapter import Adapter
"""BatchAdapter is the interface for Casbin adapters with multiple add and remove policy functions."""
class BatchAdapter(Adapter):
def add_policies(self,sec,ptype,rules):
"""AddPolicies adds policy rules to the storage."""
pass
def remove_policies(self,sec,ptype,rules):
"""RemovePolicies removes policy rules from the storage."""
pass

View File

@ -0,0 +1,21 @@
class Dispatcher:
"""Dispatcher is the interface for pycasbin dispatcher"""
def add_policies(self, sec, ptype, rules):
"""add_policies adds policies rule to all instance."""
pass
def remove_policies(self, sec, ptype, rules):
"""remove_policies removes policies rule from all instance."""
pass
def remove_filtered_policy(self, sec, ptype, field_index, field_values):
"""remove_filtered_policy removes policy rules that match the filter from all instance."""
pass
def clear_policy(self):
"""clear_policy clears all current policy in all instances."""
pass
def update_policy(self, sec, ptype, old_rule, new_rule):
"""update_policy updates policy rule from all instance."""
pass

View File

@ -0,0 +1 @@
from .role_manager import RoleManager

View File

@ -0,0 +1 @@
from .role_manager import RoleManager

View File

@ -0,0 +1,219 @@
import logging
from utils.casbin.rbac import RoleManager
class RoleManager(RoleManager):
"""provides a default implementation for the RoleManager interface"""
all_roles = dict()
max_hierarchy_level = 0
def __init__(self, max_hierarchy_level):
self.logger = logging.getLogger()
self.all_roles = dict()
self.max_hierarchy_level = max_hierarchy_level
self.matching_func = None
self.domain_matching_func = None
self.has_pattern = None
self.has_domain_pattern = None
def add_matching_func(self, fn=None):
self.has_pattern = True
self.matching_func = fn
def add_domain_matching_func(self, fn=None):
self.has_domain_pattern = True
self.domain_matching_func = fn
def has_role(self, name):
if self.matching_func is None:
return name in self.all_roles.keys()
else:
for key in self.all_roles.keys():
if self.matching_func(name, key):
return True
return False
def create_role(self, name):
if name not in self.all_roles.keys():
self.all_roles[name] = Role(name)
return self.all_roles[name]
def clear(self):
self.all_roles.clear()
def add_link(self, name1, name2, *domain):
if len(domain) == 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
elif len(domain) > 1:
raise RuntimeError("error: domain should be 1 parameter")
role1 = self.create_role(name1)
role2 = self.create_role(name2)
role1.add_role(role2)
if self.matching_func is not None:
for key, role in self.all_roles.items():
if self.matching_func(key, name1) and name1 != key:
self.all_roles[key].add_role(role1)
if self.matching_func(key, name2) and name2 != key:
self.all_roles[name2].add_role(role)
if self.matching_func(name1, key) and name1 != key:
self.all_roles[key].add_role(role1)
if self.matching_func(name2, key) and name2 != key:
self.all_roles[name2].add_role(role)
def delete_link(self, name1, name2, *domain):
if len(domain) == 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
elif len(domain) > 1:
raise RuntimeError("error: domain should be 1 parameter")
if not self.has_role(name1) or not self.has_role(name2):
raise RuntimeError("error: name1 or name2 does not exist")
role1 = self.create_role(name1)
role2 = self.create_role(name2)
role1.delete_role(role2)
def has_link(self, name1, name2, *domain):
if len(domain) == 1:
name1 = domain[0] + "::" + name1
name2 = domain[0] + "::" + name2
elif len(domain) > 1:
raise RuntimeError("error: domain should be 1 parameter")
if name1 == name2:
return True
if not self.has_role(name1) or not self.has_role(name2):
return False
if self.matching_func is None:
role1 = self.create_role(name1)
return role1.has_role(name2, self.max_hierarchy_level)
else:
for key, role in self.all_roles.items():
if self.matching_func(name1, key) and role.has_role(name2, self.max_hierarchy_level,
self.matching_func):
return True
return False
def get_roles(self, name, domain=None):
"""
gets the roles that a subject inherits.
domain is a prefix to the roles.
"""
if domain:
name = domain + "::" + name
if not self.has_role(name):
return []
roles = self.create_role(name).get_roles()
if domain:
for key, value in enumerate(roles):
roles[key] = value[len(domain) + 2:]
return roles
def get_users(self, name, *domain):
"""
gets the users that inherits a subject.
domain is an unreferenced parameter here, may be used in other implementations.
"""
if len(domain) == 1:
name = domain[0] + "::" + name
elif len(domain) > 1:
return RuntimeError("error: domain should be 1 parameter")
if not self.has_role(name):
return []
names = []
for role in self.all_roles.values():
if role.has_direct_role(name):
if len(domain) == 1:
names.append(role.name[len(domain[0]) + 2:])
else:
names.append(role.name)
return names
def print_roles(self):
line = []
for role in self.all_roles.values():
text = role.to_string()
if text:
line.append(text)
self.logger.info(", ".join(line))
class Role:
"""represents the data structure for a role in RBAC."""
name = ""
roles = []
def __init__(self, name):
self.name = name
self.roles = []
def add_role(self, role):
for rr in self.roles:
if rr.name == role.name:
return
self.roles.append(role)
def delete_role(self, role):
for rr in self.roles:
if rr.name == role.name:
self.roles.remove(rr)
return
def has_role(self, name, hierarchy_level, matching_func=None):
if self.has_direct_role(name, matching_func):
return True
if hierarchy_level <= 0:
return False
for role in self.roles:
if role.has_role(name, hierarchy_level - 1, matching_func):
return True
return False
def has_direct_role(self, name, matching_func=None):
if matching_func is None:
for role in self.roles:
if role.name == name:
return True
else:
for role in self.roles:
if matching_func(name, role.name):
return True
return False
def to_string(self):
if len(self.roles) == 0:
return ""
names = ", ".join(self.get_roles())
if len(self.roles) == 1:
return self.name + " < " + names
else:
return self.name + " < (" + names + ")"
def get_roles(self):
names = []
for role in self.roles:
names.append(role.name)
return names

View File

@ -0,0 +1,23 @@
class RoleManager:
"""provides interface to define the operations for managing roles."""
def clear(self):
pass
def add_link(self, name1, name2, *domain):
pass
def delete_link(self, name1, name2, *domain):
pass
def has_link(self, name1, name2, *domain):
pass
def get_roles(self, name, *domain):
pass
def get_users(self, name, *domain):
pass
def print_roles(self):
pass

View File

@ -0,0 +1,580 @@
import threading
import time
from utils.casbin.enforcer import Enforcer
from utils.casbin.util.rwlock import RWLockWrite
class AtomicBool():
def __init__(self, value):
self._lock = threading.Lock()
self._value = value
@property
def value(self):
with self._lock:
return self._value
@value.setter
def value(self, value):
with self._lock:
self._value = value
class SyncedEnforcer():
"""SyncedEnforcer wraps Enforcer and provides synchronized access.
It's also a drop-in replacement for Enforcer"""
def __init__(self, model=None, adapter=None):
self._e = Enforcer(model, adapter)
self._rwlock = RWLockWrite()
self._rl = self._rwlock.gen_rlock()
self._wl = self._rwlock.gen_wlock()
self._auto_loading = AtomicBool(False)
self._auto_loading_thread = None
def is_auto_loading_running(self):
"""check if SyncedEnforcer is auto loading policies"""
return self._auto_loading.value
def _auto_load_policy(self, interval):
while self.is_auto_loading_running():
time.sleep(interval)
self.load_policy()
def start_auto_load_policy(self, interval):
"""starts a thread that will call load_policy every interval seconds"""
if self.is_auto_loading_running():
return
self._auto_loading.value = True
self._auto_loading_thread = threading.Thread(target=self._auto_load_policy, args=[interval], daemon=True)
def stop_auto_load_policy(self):
"""stops the thread started by start_auto_load_policy"""
if self.is_auto_loading_running():
self._auto_loading.value = False
def get_model(self):
"""gets the current model."""
with self._rl:
return self._e.get_model()
def set_model(self, m):
"""sets the current model."""
with self._wl:
return self._e.set_model(m)
def load_model(self):
"""reloads the model from the model CONF file.
Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy().
"""
with self._wl:
return self._e.load_model()
def get_role_manager(self):
"""gets the current role manager."""
with self._rl:
return self._e.get_role_manager()
def set_role_manager(self, rm):
with self._wl:
self._e.set_role_manager(rm)
def get_adapter(self):
"""gets the current adapter."""
with self._rl:
self._e.get_adapter()
def set_adapter(self, adapter):
"""sets the current adapter."""
with self._wl:
self._e.set_adapter(adapter)
def set_watcher(self, watcher):
"""sets the current watcher."""
with self._wl:
self._e.set_watcher(watcher)
def set_effector(self, eft):
"""sets the current effector."""
with self._wl:
self._e.set_effector(eft)
def clear_policy(self):
""" clears all policy."""
with self._wl:
return self._e.clear_policy()
def load_policy(self):
"""reloads the policy from file/database."""
with self._wl:
return self._e.load_policy()
def load_filtered_policy(self, filter):
""""reloads a filtered policy from file/database."""
with self._wl:
return self._e.load_filtered_policy(filter)
def save_policy(self):
with self._rl:
return self._e.save_policy()
def build_role_links(self):
"""manually rebuild the role inheritance relations."""
with self._rl:
return self._e.build_role_links()
def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
"""
with self._rl:
return self._e.enforce(*rvals)
def enforce_ex(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
return judge result with reason
"""
with self._rl:
return self._e.enforce_ex(*rvals)
def get_all_subjects(self):
"""gets the list of subjects that show up in the current policy."""
with self._rl:
return self._e.get_all_subjects()
def get_all_named_subjects(self, ptype):
"""gets the list of subjects that show up in the current named policy."""
with self._rl:
return self._e.get_all_named_subjects(ptype)
def get_all_objects(self):
"""gets the list of objects that show up in the current policy."""
with self._rl:
return self._e.get_all_objects()
def get_all_named_objects(self, ptype):
"""gets the list of objects that show up in the current named policy."""
with self._rl:
return self._e.get_all_named_objects(ptype)
def get_all_actions(self):
"""gets the list of actions that show up in the current policy."""
with self._rl:
return self._e.get_all_actions()
def get_all_named_actions(self, ptype):
"""gets the list of actions that show up in the current named policy."""
with self._rl:
return self._e.get_all_named_actions(ptype)
def get_all_roles(self):
"""gets the list of roles that show up in the current named policy."""
with self._rl:
return self._e.get_all_roles()
def get_all_named_roles(self, ptype):
"""gets all the authorization rules in the policy."""
with self._rl:
return self._e.get_all_named_roles(ptype)
def get_policy(self):
"""gets all the authorization rules in the policy."""
with self._rl:
return self._e.get_policy()
def get_filtered_policy(self, field_index, *field_values):
"""gets all the authorization rules in the policy, field filters can be specified."""
with self._rl:
return self._e.get_filtered_policy(field_index, *field_values)
def get_named_policy(self, ptype):
"""gets all the authorization rules in the named policy."""
with self._rl:
return self._e.get_named_policy(ptype)
def get_filtered_named_policy(self, ptype, field_index, *field_values):
"""gets all the authorization rules in the named policy, field filters can be specified."""
with self._rl:
return self._e.get_filtered_named_policy(ptype, field_index, *field_values)
def get_grouping_policy(self):
"""gets all the role inheritance rules in the policy."""
with self._rl:
return self._e.get_grouping_policy()
def get_filtered_grouping_policy(self, field_index, *field_values):
"""gets all the role inheritance rules in the policy, field filters can be specified."""
with self._rl:
return self._e.get_filtered_grouping_policy(field_index, *field_values)
def get_named_grouping_policy(self, ptype):
"""gets all the role inheritance rules in the policy."""
with self._rl:
return self._e.get_named_grouping_policy(ptype)
def get_filtered_named_grouping_policy(self, ptype, field_index, *field_values):
"""gets all the role inheritance rules in the policy, field filters can be specified."""
with self._rl:
return self._e.get_filtered_named_grouping_policy(ptype, field_index, *field_values)
def has_policy(self, *params):
"""determines whether an authorization rule exists."""
with self._rl:
return self._e.has_policy(*params)
def has_named_policy(self, ptype, *params):
"""determines whether a named authorization rule exists."""
with self._rl:
return self._e.has_named_policy(ptype, *params)
def add_policy(self, *params):
"""adds an authorization rule to the current policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
with self._wl:
return self._e.add_policy(*params)
def add_named_policy(self, ptype, *params):
"""adds an authorization rule to the current named policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
with self._wl:
return self._e.add_named_policy(ptype, *params)
def remove_policy(self, *params):
"""removes an authorization rule from the current policy."""
with self._wl:
return self._e.remove_policy(*params)
def remove_filtered_policy(self, field_index, *field_values):
"""removes an authorization rule from the current policy, field filters can be specified."""
with self._wl:
return self._e.remove_filtered_policy(field_index, *field_values)
def remove_named_policy(self, ptype, *params):
"""removes an authorization rule from the current named policy."""
with self._wl:
return self._e.remove_named_policy(ptype, *params)
def remove_filtered_named_policy(self, ptype, field_index, *field_values):
"""removes an authorization rule from the current named policy, field filters can be specified."""
with self._wl:
return self._e.remove_filtered_named_policy(ptype, field_index, *field_values)
def has_grouping_policy(self, *params):
"""determines whether a role inheritance rule exists."""
with self._rl:
return self._e.has_grouping_policy(*params)
def has_named_grouping_policy(self, ptype, *params):
"""determines whether a named role inheritance rule exists."""
with self._rl:
return self._e.has_named_grouping_policy(ptype, *params)
def add_grouping_policy(self, *params):
"""adds a role inheritance rule to the current policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
with self._wl:
return self._e.add_grouping_policy(*params)
def add_named_grouping_policy(self, ptype, *params):
"""adds a named role inheritance rule to the current policy.
If the rule already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rule.
"""
with self._wl:
return self._e.add_named_grouping_policy(ptype, *params)
def remove_grouping_policy(self, *params):
"""removes a role inheritance rule from the current policy."""
with self._wl:
return self._e.remove_grouping_policy(*params)
def remove_filtered_grouping_policy(self, field_index, *field_values):
"""removes a role inheritance rule from the current policy, field filters can be specified."""
with self._wl:
return self._e.remove_filtered_grouping_policy(field_index, *field_values)
def remove_named_grouping_policy(self, ptype, *params):
"""removes a role inheritance rule from the current named policy."""
with self._wl:
return self._e.remove_named_grouping_policy(ptype, *params)
def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_values):
"""removes a role inheritance rule from the current named policy, field filters can be specified."""
with self._wl:
return self._e.remove_filtered_named_grouping_policy(ptype, field_index, *field_values)
def add_function(self, name, func):
"""adds a customized function."""
with self._wl:
return self._e.add_function(name, func)
# enforcer.py
def get_roles_for_user(self, name):
""" gets the roles that a user has. """
with self._rl:
return self._e.get_roles_for_user(name)
def get_users_for_role(self, name):
""" gets the users that has a role. """
with self._rl:
return self._e.get_users_for_role(name)
def has_role_for_user(self, name, role):
""" determines whether a user has a role. """
with self._rl:
return self._e.has_role_for_user(name, role)
def add_role_for_user(self, user, role):
"""
adds a role for a user.
Returns false if the user already has the role (aka not affected).
"""
with self._wl:
return self._e.add_role_for_user(user, role)
def delete_role_for_user(self, user, role):
"""
deletes a role for a user.
Returns false if the user does not have the role (aka not affected).
"""
with self._wl:
return self._e.delete_role_for_user(user, role)
def delete_roles_for_user(self, user):
"""
deletes all roles for a user.
Returns false if the user does not have any roles (aka not affected).
"""
with self._wl:
return self._e.delete_roles_for_user(user)
def delete_user(self, user):
"""
deletes a user.
Returns false if the user does not exist (aka not affected).
"""
with self._wl:
return self._e.delete_user(user)
def delete_role(self, role):
"""
deletes a role.
Returns false if the role does not exist (aka not affected).
"""
with self._wl:
return self._e.delete_role(role)
def delete_permission(self, *permission):
"""
deletes a permission.
Returns false if the permission does not exist (aka not affected).
"""
with self._wl:
return self._e.delete_permission(*permission)
def add_permission_for_user(self, user, *permission):
"""
adds a permission for a user or role.
Returns false if the user or role already has the permission (aka not affected).
"""
with self._wl:
return self._e.add_permission_for_user(user, *permission)
def delete_permission_for_user(self, user, *permission):
"""
deletes a permission for a user or role.
Returns false if the user or role does not have the permission (aka not affected).
"""
with self._wl:
return self._e.delete_permission_for_user(user, *permission)
def delete_permissions_for_user(self, user):
"""
deletes permissions for a user or role.
Returns false if the user or role does not have any permissions (aka not affected).
"""
with self._wl:
return self._e.delete_permissions_for_user(user)
def get_permissions_for_user(self, user):
"""
gets permissions for a user or role.
"""
with self._rl:
return self._e.get_permissions_for_user(user)
def has_permission_for_user(self, user, *permission):
"""
determines whether a user has a permission.
"""
with self._rl:
return self._e.has_permission_for_user(user, *permission)
def get_implicit_roles_for_user(self, name, *domain):
"""
gets implicit roles that a user has.
Compared to get_roles_for_user(), this function retrieves indirect roles besides direct roles.
For example:
g, alice, role:admin
g, role:admin, role:user
get_roles_for_user("alice") can only get: ["role:admin"].
But get_implicit_roles_for_user("alice") will get: ["role:admin", "role:user"].
"""
with self._rl:
return self._e.get_implicit_roles_for_user(name, *domain)
def get_implicit_permissions_for_user(self, user, *domain):
"""
gets implicit permissions for a user or role.
Compared to get_permissions_for_user(), this function retrieves permissions for inherited roles.
For example:
p, admin, data1, read
p, alice, data2, read
g, alice, admin
get_permissions_for_user("alice") can only get: [["alice", "data2", "read"]].
But get_implicit_permissions_for_user("alice") will get: [["admin", "data1", "read"], ["alice", "data2", "read"]].
"""
with self._rl:
return self._e.get_implicit_permissions_for_user(user, *domain)
def get_implicit_users_for_permission(self, *permission):
"""
gets implicit users for a permission.
For example:
p, admin, data1, read
p, bob, data1, read
g, alice, admin
get_implicit_users_for_permission("data1", "read") will get: ["alice", "bob"].
Note: only users will be returned, roles (2nd arg in "g") will be excluded.
"""
with self._rl:
return self._e.get_implicit_users_for_permission(*permission)
def get_roles_for_user_in_domain(self, name, domain):
"""gets the roles that a user has inside a domain."""
with self._rl:
return self._e.get_roles_for_user_in_domain(name, domain)
def get_users_for_role_in_domain(self, name, domain):
"""gets the users that has a role inside a domain."""
with self._rl:
return self._e.get_users_for_role_in_domain(name, domain)
def add_role_for_user_in_domain(self, user, role, domain):
"""adds a role for a user inside a domain."""
"""Returns false if the user already has the role (aka not affected)."""
with self._wl:
return self._e.add_role_for_user_in_domain(user, role, domain)
def delete_roles_for_user_in_domain(self, user, role, domain):
"""deletes a role for a user inside a domain."""
"""Returns false if the user does not have any roles (aka not affected)."""
with self._wl:
return self._e.delete_roles_for_user_in_domain(user, role, domain)
def get_permissions_for_user_in_domain(self, user, domain):
"""gets permissions for a user or role inside domain."""
with self._rl:
return self._e.get_permissions_for_user_in_domain(user, domain)
def enable_auto_build_role_links(self, auto_build_role_links):
"""controls whether to rebuild the role inheritance relations when a role is added or deleted."""
with self._wl:
return self._e.enable_auto_build_role_links(auto_build_role_links)
def enable_auto_save(self, auto_save):
"""controls whether to save a policy rule automatically to the adapter when it is added or removed."""
with self._wl:
return self._e.enable_auto_save(auto_save)
def enable_enforce(self, enabled=True):
"""changes the enforcing state of Casbin,
when Casbin is disabled, all access will be allowed by the Enforce() function.
"""
with self._wl:
return self._e.enable_enforce(enabled)
def add_named_matching_func(self, ptype, fn):
"""add_named_matching_func add MatchingFunc by ptype RoleManager"""
with self._wl:
self._e.add_named_matching_func(ptype, fn)
def add_named_domain_matching_func(self, ptype, fn):
"""add_named_domain_matching_func add MatchingFunc by ptype to RoleManager"""
with self._wl:
self._e.add_named_domain_matching_func(ptype, fn)
def is_filtered(self):
"""returns true if the loaded policy has been filtered."""
with self._rl:
self._e.is_filtered()
def add_policies(self,rules):
"""adds authorization rules to the current policy.
If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
Otherwise the function returns true for the corresponding rule by adding the new rule.
"""
with self._wl:
return self._e.add_policies(rules)
def add_named_policies(self,ptype,rules):
"""adds authorization rules to the current named policy.
If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
Otherwise the function returns true for the corresponding by adding the new rule."""
with self._wl:
return self._e.add_named_policies(ptype,rules)
def remove_policies(self,rules):
"""removes authorization rules from the current policy."""
with self._wl:
return self._e.remove_policies(rules)
def remove_named_policies(self,ptype,rules):
"""removes authorization rules from the current named policy."""
with self._wl:
return self._e.remove_named_policies(ptype,rules)
def add_grouping_policies(self,rules):
"""adds role inheritance rulea to the current policy.
If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
Otherwise the function returns true for the corresponding policy rule by adding the new rule.
"""
with self._wl:
return self._e.add_grouping_policies(rules)
def add_named_grouping_policies(self,ptype,rules):
""""adds named role inheritance rules to the current policy.
If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
Otherwise the function returns true for the corresponding policy rule by adding the new rule."""
with self._wl:
return self._e.add_named_grouping_policies(ptype,rules)
def remove_grouping_policies(self,rules):
"""removes role inheritance rulea from the current policy."""
with self._wl:
return self._e.addremove_grouping_policies_policies(rules)
def remove_named_grouping_policies(self,ptype,rules):
""" removes role inheritance rules from the current named policy."""
with self._wl:
return self._e.remove_named_grouping_policies(ptype,rules)
def build_incremental_role_links(self, op, ptype, rules):
self.get_model().build_incremental_role_links(self.get_role_manager(), op, "g", ptype, rules)

View File

@ -0,0 +1,3 @@
from .builtin_operators import *
from .expression import *
from .util import *

View File

@ -0,0 +1,137 @@
import fnmatch
import re
import ipaddress
KEY_MATCH2_PATTERN = re.compile(r'(.*?):[^\/]+(.*?)')
KEY_MATCH3_PATTERN = re.compile(r'(.*?){[^\/]+}(.*?)')
def key_match(key1, key2):
"""determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *.
For example, "/foo/bar" matches "/foo/*"
"""
i = key2.find("*")
if i == -1:
return key1 == key2
if len(key1) > i:
return key1[:i] == key2[:i]
return key1 == key2[:i]
def key_match_func(*args):
"""The wrapper for key_match.
"""
name1 = args[0]
name2 = args[1]
return key_match(name1, name2)
def key_match2(key1, key2):
"""determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *.
For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/:resource"
"""
key2 = key2.replace("/*", "/.*")
key2 = KEY_MATCH2_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0)
return regex_match(key1, "^" + key2 + "$")
def key_match2_func(*args):
name1 = args[0]
name2 = args[1]
return key_match2(name1, name2)
def key_match3(key1, key2):
"""determines determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a *.
For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/{resource}"
"""
key2 = key2.replace("/*", "/.*")
key2 = KEY_MATCH3_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0)
return regex_match(key1, "^" + key2 + "$")
def key_match3_func(*args):
name1 = args[0]
name2 = args[1]
return key_match3(name1, name2)
def regex_match(key1, key2):
"""determines whether key1 matches the pattern of key2 in regular expression."""
res = re.match(key2, key1)
if res:
return True
else:
return False
def regex_match_func(*args):
"""the wrapper for RegexMatch."""
name1 = args[0]
name2 = args[1]
return regex_match(name1, name2)
def glob_match(string, pattern):
"""determines whether string matches the pattern in glob expression."""
return fnmatch.fnmatch(string, pattern)
def glob_match_func(*args):
"""the wrapper for globMatch."""
string = args[0]
pattern = args[1]
return glob_match(string, pattern)
def ip_match(ip1, ip2):
"""IPMatch determines whether IP address ip1 matches the pattern of IP address ip2, ip2 can be an IP address or a CIDR pattern.
For example, "192.168.2.123" matches "192.168.2.0/24"
"""
ip1 = ipaddress.ip_address(ip1)
try:
network = ipaddress.ip_network(ip2, strict=False)
return ip1 in network
except ValueError:
return ip1 == ip2
def ip_match_func(*args):
"""the wrapper for IPMatch."""
ip1 = args[0]
ip2 = args[1]
return ip_match(ip1, ip2)
def generate_g_function(rm):
"""the factory method of the g(_, _) function."""
def f(*args):
name1 = args[0]
name2 = args[1]
if not rm:
return name1 == name2
elif 2 == len(args):
return rm.has_link(name1, name2)
else:
domain = str(args[2])
return rm.has_link(name1, name2, domain)
return f

View File

@ -0,0 +1,29 @@
from simpleeval import SimpleEval
import ast
class SimpleEval(SimpleEval):
""" Rewrite SimpleEval.
>>> s = SimpleEval("20 + 30 - ( 10 * 5)")
>>> s.eval()
0
"""
ast_parsed_value = None
def __init__(self, expr, functions=None):
"""Create the evaluator instance. Set up valid operators (+,-, etc)
functions (add, random, get_val, whatever) and names. """
super(SimpleEval, self).__init__(functions=functions)
if expr != "":
self.expr = expr
self.expr_parsed_value = ast.parse(expr.strip()).body[0].value
def eval(self, names=None):
""" evaluate an expresssion, using the operators, functions and
names previously set up. """
if names:
self.names = names
return self._eval(self.expr_parsed_value)

View File

@ -0,0 +1,68 @@
from threading import RLock, Condition
# This implementation was adapted from https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock
class RWLockWrite():
''' write preferring readers-wirter lock '''
def __init__(self):
self._lock = RLock()
self._cond = Condition(self._lock)
self._active_readers = 0
self._waiting_writers = 0
self._writer_active = False
def aquire_read(self):
with self._lock:
while self._waiting_writers > 0 or self._writer_active:
self._cond.wait()
self._active_readers += 1
def release_read(self):
with self._lock:
self._active_readers -= 1
if self._active_readers == 0:
self._cond.notify_all()
def aquire_write(self):
with self._lock:
self._waiting_writers += 1
while self._active_readers > 0 or self._writer_active:
self._cond.wait()
self._waiting_writers -= 1
self._writer_active = True
def release_write(self):
with self._lock:
self._writer_active = False
self._cond.notify_all()
def gen_rlock(self):
return ReadRWLock(self)
def gen_wlock(self):
return WriteRWLock(self)
class ReadRWLock():
def __init__(self, rwlock):
self.rwlock = rwlock
def __enter__(self):
self.rwlock.aquire_read()
def __exit__(self, exc_type, exc_value, traceback):
self.rwlock.release_read()
return False
class WriteRWLock():
def __init__(self, rwlock):
self.rwlock = rwlock
def __enter__(self):
self.rwlock.aquire_write()
def __exit__(self, exc_type, exc_value, traceback):
self.rwlock.release_write()
return False

72
utils/casbin/util/util.py Normal file
View File

@ -0,0 +1,72 @@
from collections import OrderedDict
import re
eval_reg = re.compile(r'\beval\((?P<rule>[^)]*)\)')
def escape_assertion(s):
"""escapes the dots in the assertion, because the expression evaluation doesn't support such variable names."""
s = re.sub(r'\br\.', 'r_', s)
s = re.sub(r'\bp\.', 'p_', s)
return s
def remove_comments(s):
"""removes the comments starting with # in the text."""
pos = s.find("#")
if pos == -1:
return s
return s[0:pos].strip()
def array_remove_duplicates(s):
"""removes any duplicated elements in a string array."""
return list(OrderedDict.fromkeys(s))
def array_to_string(s):
"""gets a printable string for a string array."""
return ", ".join(s)
def params_to_string(*s):
"""gets a printable string for variable number of parameters."""
return ", ".join(s)
def join_slice(a, *b):
''' joins a string and a slice into a new slice.'''
res = [a]
res.extend(b)
return res
def set_subtract(a, b):
''' returns the elements in `a` that aren't in `b`. '''
return [i for i in a if i not in b]
def has_eval(s):
'''determine whether matcher contains function eval'''
return eval_reg.search(s)
def replace_eval(expr, rules):
''' replace all occurences of function eval with rules '''
pos = 0
match = eval_reg.search(expr, pos)
while match:
rule = "(" + rules.pop(0) + ")"
expr = expr[:match.start()] + rule + expr[match.end():]
pos = match.start() + len(rule)
match = eval_reg.search(expr, pos)
return expr
def get_eval_value(s):
'''returns the parameters of function eval'''
sub_match = eval_reg.findall(s)
return sub_match.copy()