"""FastAPI application entry point.

Builds the module-level ``app``: mounts the v1 API router, registers the
database startup/shutdown hooks, and installs the CORS, authentication and
Casbin authorization middleware. Running this file directly starts a uvicorn
development server.
"""
# Import order is kept exactly as before: ``from utils import *`` can shadow
# (or be shadowed by) names imported around it, so reordering is not safe.
import binascii
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, BaseUser, SimpleUser
from starlette.middleware.authentication import AuthenticationMiddleware
from fastapi_authz import CasbinMiddleware
from db import connect_to_mongo, close_mongo_connection, get_database
from db.ckdb_utils import connect_to_ck, close_ck_connection
from db.redisdb_utils import connect_to_redis, close_redis_connection
# NOTE(review): this star import is presumably what provides ``casbin_enforcer``
# (used below, not defined in this file); prefer an explicit import once confirmed.
from utils import *
from api.api_v1.api import api_router
from core.config import settings
from api.deps import get_current_user2

app = FastAPI(title=settings.PROJECT_NAME)
app.include_router(api_router, prefix=settings.API_V1_STR)

# Lifecycle hooks for the Mongo, Redis and "ck" (presumably ClickHouse) clients.
app.add_event_handler("startup", connect_to_mongo)
app.add_event_handler("startup", connect_to_redis)
app.add_event_handler("startup", connect_to_ck)
app.add_event_handler("shutdown", close_mongo_connection)
app.add_event_handler("shutdown", close_redis_connection)
app.add_event_handler("shutdown", close_ck_connection)


class CurrentUser(BaseUser):
    """Authenticated principal that ``BasicAuth`` attaches to ``request.user``.

    Args:
        username: Exposed through ``display_name``.
        user_id: Stored as ``self.id``.
    """

    def __init__(self, username: str, user_id: str) -> None:
        self.username = username
        self.id = user_id

    @property
    def is_authenticated(self) -> bool:
        # In this module, instances are only built after a successful login.
        return True

    @property
    def display_name(self) -> str:
        return self.username

    @property
    def identity(self) -> str:
        # NOTE(review): always empty; code needing the user id should read ``self.id``.
        return ''


# Authorization headers shorter than this are treated as "no credentials"
# (anonymous request) rather than rejected.
# NOTE(review): the rationale for 20 is not recorded here — confirm.
_MIN_AUTH_HEADER_LEN = 20


class BasicAuth(AuthenticationBackend):
    """Authenticate a request from its ``Authorization`` header.

    Despite the class name, this is not HTTP Basic decoding: the header is
    expected to look like ``"<scheme> <token>"``. The second whitespace-separated
    field is passed to ``get_current_user2`` to resolve the user.
    """

    async def authenticate(self, request):
        """Return ``(AuthCredentials, CurrentUser)``, or ``None`` for anonymous requests.

        Raises:
            AuthenticationError: if the header has no ``<scheme> <token>``
                shape, or if ``get_current_user2`` rejects the token with
                ``ValueError``, ``UnicodeDecodeError`` or ``binascii.Error``.
        """
        auth = request.headers.get("Authorization")
        if auth is None or len(auth) < _MIN_AUTH_HEADER_LEN:
            return None
        parts = auth.split()
        if len(parts) < 2:
            # Previously ``auth.split(' ')[1]`` raised IndexError here, which
            # escaped as a 500; a malformed header is a credentials error.
            raise AuthenticationError("Invalid basic auth credentials")
        token = parts[1]
        try:
            # NOTE(review): this is a synchronous call inside an async backend;
            # if it performs blocking I/O it will stall the event loop.
            user = get_current_user2(token)
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            raise AuthenticationError("Invalid basic auth credentials") from exc
        return AuthCredentials(["authenticated"]), CurrentUser(user.name, user.id)


# Starlette makes the most recently added middleware the outermost one.
# Requests therefore flow CORS -> authentication -> Casbin, so Casbin runs
# after BasicAuth has populated ``request.user``. Keep this order.
app.add_middleware(CasbinMiddleware, enforcer=casbin_enforcer)
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed cross-origin requests; consider an explicit allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == '__main__':
    # Development server. An import string ('main:app') rather than the app
    # object is required for reload=True.
    # NOTE(review): newer uvicorn releases no longer accept ``debug`` (a
    # TypeError is raised); verify against the installed version.
    uvicorn.run(app='main:app', host="0.0.0.0", port=8889, reload=True, debug=True)