diff --git a/main2.py b/main2.py new file mode 100644 index 0000000..aa9ff3c --- /dev/null +++ b/main2.py @@ -0,0 +1,80 @@ +import base64 +import binascii + +import uvicorn +from fastapi import FastAPI +import casbin + +from api.deps import get_current_user2 +from core.config import settings +from starlette.middleware.cors import CORSMiddleware +from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, BaseUser, SimpleUser +from starlette.middleware.authentication import AuthenticationMiddleware +from fastapi_authz import CasbinMiddleware + +from db import connect_to_mongo, close_mongo_connection, get_database +from utils import Adapter + +app = FastAPI(title=settings.PROJECT_NAME) + +if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) +app.add_event_handler("startup", connect_to_mongo) + +app.add_event_handler("shutdown", close_mongo_connection) + + +class CurrentUser(BaseUser): + def __init__(self, username: str, user_id: str) -> None: + self.username = username + self.id = user_id + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.username + + @property + def identity(self) -> str: + return '' + + +class BasicAuth(AuthenticationBackend): + async def authenticate(self, request): + if "Authorization" not in request.headers: + return None + + auth = request.headers["Authorization"] + if len(auth) < 20: + return None + try: + user = get_current_user2(auth.split(' ')[1]) + except (ValueError, UnicodeDecodeError, binascii.Error): + raise AuthenticationError("Invalid basic auth credentials") + + return AuthCredentials(["authenticated"]), CurrentUser(user.name, user.id) + + + +enforcer = casbin.Enforcer('rbac_model.conf', Adapter(settings.DATABASE_URI,settings.MDB_DB)) +app.add_middleware(CasbinMiddleware, enforcer=enforcer) +app.add_middleware(AuthenticationMiddleware, backend=BasicAuth()) + + + + +from api.api_v1.api import api_router + +app.include_router(api_router, prefix=settings.API_V1_STR) + +if __name__ == '__main__': + uvicorn.run(app='main2:app', host="0.0.0.0", port=8899, reload=True, debug=True) diff --git a/rbac_model.conf b/rbac_model.conf new file mode 100644 index 0000000..0d4f321 --- /dev/null +++ b/rbac_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = (p.sub == "*" || g(r.sub, p.sub)) && (r.obj == p.obj || keyMatch(r.obj, p.obj)) && (p.act == "*" || r.act == p.act) \ No newline at end of file diff --git a/rbac_policy.csv b/rbac_policy.csv new file mode 100644 index 0000000..f5c797f --- /dev/null +++ b/rbac_policy.csv @@ -0,0 +1,7 @@ + +p, *, /api/v1/user/login, * +p, *, /api/v1/project/, * +p, *, /docs, * +p, *, /openapi.json, * + +g, cathy, dataset1_admin \ No newline at end of file