xbackend/utils/adapter.py
2021-10-11 17:36:15 +08:00

149 lines
4.2 KiB
Python

from utils import casbin
from utils.casbin import persist
from pymongo import MongoClient
from core.config import settings
# Public names of this module: the shared adapter, enforcer and model instances.
__all__ = (
    'casbin_adapter',
    'casbin_enforcer',
    'casbin_model',
)
class CasbinRule:
    """In-memory form of one Casbin policy rule.

    ``ptype`` names the policy type and ``v0``..``v5`` hold the rule's
    positional fields; a field left as ``None`` is treated as absent.
    """

    def __init__(self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None):
        self.ptype = ptype
        self.v0, self.v1, self.v2 = v0, v1, v2
        self.v3, self.v4, self.v5 = v3, v4, v5

    def dict(self):
        """Return ``ptype`` plus every non-``None`` field keyed ``v0``..``v5``.

        A present field keeps its positional key even when an earlier field
        is absent (e.g. only ``v1`` set yields ``{'ptype': ..., 'v1': ...}``).
        """
        fields = (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5)
        doc = {'ptype': self.ptype}
        doc.update(
            (f'v{index}', value)
            for index, value in enumerate(fields)
            if value is not None
        )
        return doc

    def __str__(self):
        """Join ``ptype`` and the present fields with ``', '``.

        Every present value must be a ``str``; ``str.join`` raises
        ``TypeError`` otherwise.
        """
        values = self.dict().values()
        return ', '.join(values)

    def __repr__(self):
        return f'<CasbinRule :"{self}">'
class Adapter(persist.Adapter):
    """Casbin adapter that stores policy rules in a MongoDB collection.

    Each rule is one document shaped like ``CasbinRule.dict()``: a ``ptype``
    key plus ``v0``..``v5`` for the rule's non-``None`` positional fields.
    """

    def __init__(self, uri, dbname, collection="casbin_rule"):
        """Bind the adapter to ``collection`` in database ``dbname``.

        :param uri: MongoDB connection URI handed to ``MongoClient``.
        :param dbname: name of the database holding the rule collection.
        :param collection: name of the rule collection.
        """
        client = MongoClient(uri)
        db = client[dbname]
        self._collection = db[collection]

    @staticmethod
    def format_policy(ptype, field_values, field_index=0):
        """Build a ``CasbinRule`` whose values start at field ``v{field_index}``.

        ``field_values[0]`` goes into ``v{field_index}``, the next value into
        the following field, and so on; earlier fields stay ``None``.
        """
        line = CasbinRule(ptype=ptype)
        for offset, value in enumerate(field_values):
            setattr(line, f'v{field_index + offset}', value)
        return line

    def load_policy(self, model):
        """Load every stored rule into ``model`` (Casbin adapter interface).

        Documents without a ``ptype`` key are skipped. Missing ``vN`` keys are
        left out of the line, so a document lacking an earlier field (e.g.
        ``v0``) produces a shifted line.
        """
        for line in self._collection.find():
            if 'ptype' not in line:
                continue
            rule = CasbinRule(line['ptype'], *(line.get(f'v{i}') for i in range(6)))
            persist.load_policy_line(str(rule), model)

    def _save_policy_line(self, ptype, rule):
        """Upsert one rule; only its first six values are stored.

        The rule document is both the filter and the ``$set`` payload, so
        saving an already-stored rule changes nothing. The filter only
        constrains the fields present, so an existing document with extra
        trailing fields also matches and no new document is inserted.
        """
        line = CasbinRule(ptype, *rule[:6])
        self._collection.update_one(line.dict(), {'$set': line.dict()}, upsert=True)

    def save_policy(self, model):
        """Upsert every ``p`` and ``g`` section rule of ``model`` into MongoDB.

        Documents already stored whose rule is no longer in ``model`` are not
        removed.

        :returns: ``True``.
        """
        for sec in ("p", "g"):
            if sec not in model.model:
                continue
            for ptype, ast in model.model[sec].items():
                for rule in ast.policy:
                    self._save_policy_line(ptype, rule)
        return True

    def add_policy(self, sec, ptype, rule):
        """Persist one rule added through the enforcer; ``sec`` is unused."""
        self._save_policy_line(ptype, rule)

    def remove_policy(self, sec, ptype, rule):
        """Delete at most one document matching ``ptype`` and ``rule``; ``sec`` is unused."""
        line = self.format_policy(ptype, rule)
        self._collection.delete_one(line.dict())

    def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
        """Delete every stored rule of ``ptype`` that matches the filter.

        ``field_values[i]`` is compared with field ``v{field_index + i}``. An
        empty string (or ``None``) matches any value, mirroring Casbin's
        in-memory ``remove_filtered_policy``, which drops *all* matching rules.

        :returns: ``False`` (nothing deleted) when the filter addresses fields
            outside ``v0``..``v5``; otherwise ``True``.
        """
        if field_index < 0 or field_index + len(field_values) > 6:
            # Fields beyond v5 are dropped by CasbinRule.dict(); silently
            # ignoring them would widen the filter and delete unrelated rules.
            return False
        query = {'ptype': ptype}
        for offset, value in enumerate(field_values):
            if value is None or value == '':
                continue  # wildcard: do not constrain this field
            query[f'v{field_index + offset}'] = value
        # delete_many, not delete_one: otherwise leftover matches survive in the
        # DB and come back into the enforcer on the next policy reload.
        self._collection.delete_many(query)
        return True
# Module-level singletons exported via __all__. Building them here creates a
# MongoClient and configures the enforcer as a side effect of importing this module.
casbin_adapter = Adapter(settings.DATABASE_URI, settings.MDB_DB)
# NOTE(review): 'rbac_model.conf' is a relative path, so it presumably resolves
# against the process working directory rather than this file — confirm deploy layout.
casbin_enforcer = casbin.SyncedEnforcer('rbac_model.conf', casbin_adapter)
# NOTE(review): this captures the model object once; if the enforcer swaps in a
# new model object when it reloads policy, this reference could go stale — verify in utils.casbin.
casbin_model = casbin_enforcer.get_model()
# Presumably reloads policy from casbin_adapter every 30 seconds in the
# background — confirm against utils.casbin.SyncedEnforcer.start_auto_load_policy.
casbin_enforcer.start_auto_load_policy(30)