from utils import casbin from utils.casbin import persist from pymongo import MongoClient from core.config import settings __all__ = 'casbin_adapter', 'casbin_enforcer', 'casbin_model' class CasbinRule: ''' CasbinRule model ''' def __init__(self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None): self.ptype = ptype self.v0 = v0 self.v1 = v1 self.v2 = v2 self.v3 = v3 self.v4 = v4 self.v5 = v5 def dict(self): d = {'ptype': self.ptype} for i, v in enumerate([self.v0, self.v1, self.v2, self.v3, self.v4, self.v5]): if v is None: continue d['v' + str(i)] = v return d def __str__(self): return ', '.join(self.dict().values()) def __repr__(self): return ''.format(str(self)) class Adapter(persist.Adapter): """the interface for Casbin adapters.""" def __init__(self, uri, dbname, collection="casbin_rule"): client = MongoClient(uri) db = client[dbname] self._collection = db[collection] @staticmethod def format_policy(ptype, field_values, field_index=0): line = CasbinRule(ptype=ptype) for i in range(field_index, field_index+len(field_values)): line.__setattr__(f'v{i}', field_values[i - field_index]) # if len(args) > 0: # line.v0 = args[0] # if len(args) > 1: # line.v1 = args[1] # if len(args) > 2: # line.v2 = args[2] # if len(args) > 3: # line.v3 = args[3] # if len(args) > 4: # line.v4 = args[4] # if len(args) > 5: # line.v5 = args[5] return line def load_policy(self, model): ''' implementing add Interface for casbin \n load all policy rules from mongodb \n ''' for line in self._collection.find(): if 'ptype' not in line: continue rule = CasbinRule(line['ptype']) if 'v0' in line: rule.v0 = line['v0'] if 'v1' in line: rule.v1 = line['v1'] if 'v2' in line: rule.v2 = line['v2'] if 'v3' in line: rule.v3 = line['v3'] if 'v4' in line: rule.v4 = line['v4'] if 'v5' in line: rule.v5 = line['v5'] persist.load_policy_line(str(rule), model) def _save_policy_line(self, ptype, rule): line = CasbinRule(ptype=ptype) if len(rule) > 0: line.v0 = rule[0] if len(rule) > 1: line.v1 = rule[1] if len(rule) > 2: line.v2 = rule[2] if len(rule) > 3: line.v3 = rule[3] if len(rule) > 4: line.v4 = rule[4] if len(rule) > 5: line.v5 = rule[5] self._collection.update_one(line.dict(), {'$set': line.dict()}, upsert=True) def save_policy(self, model): ''' implementing add Interface for casbin \n save the policy in mongodb \n ''' for sec in ["p", "g"]: if sec not in model.model.keys(): continue for ptype, ast in model.model[sec].items(): for rule in ast.policy: self._save_policy_line(ptype, rule) return True def add_policy(self, sec, ptype, rule): """add policy rules to mongodb""" self._save_policy_line(ptype, rule) def remove_policy(self, sec, ptype, rule): """delete policy rules from mongodb""" line = self.format_policy(ptype, rule) self._collection.delete_one(line.dict()) def remove_filtered_policy(self, sec, ptype, field_index, *field_values): """ delete policy rules for matching filters from mongodb """ line = self.format_policy(ptype, field_values, field_index) self._collection.delete_one(line.dict()) casbin_adapter = Adapter(settings.DATABASE_URI, settings.MDB_DB) casbin_enforcer = casbin.SyncedEnforcer('rbac_model.conf', casbin_adapter) casbin_model = casbin_enforcer.get_model() casbin_enforcer.start_auto_load_policy(30)