import binascii
import time

import uvicorn
from fastapi import FastAPI, Request
from starlette.middleware.cors import CORSMiddleware
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, BaseUser, SimpleUser
from starlette.middleware.authentication import AuthenticationMiddleware

from middleware import CasbinMiddleware
from db import connect_to_mongo, close_mongo_connection, get_database
from db.ckdb_utils import connect_to_ck, close_ck_connection
from db.redisdb_utils import connect_to_redis, close_redis_connection
from utils import *
from api.api_v1.api import api_router
from core.config import settings
from api.deps import get_current_user2

app = FastAPI(title=settings.PROJECT_NAME)
app.include_router(api_router, prefix=settings.API_V1_STR)

# Open the MongoDB, Redis and ClickHouse connections when the app starts,
# and close them again on shutdown.
app.add_event_handler("startup", connect_to_mongo)
app.add_event_handler("startup", connect_to_redis)
app.add_event_handler("startup", connect_to_ck)
app.add_event_handler("shutdown", close_mongo_connection)
app.add_event_handler("shutdown", close_redis_connection)
app.add_event_handler("shutdown", close_ck_connection)


class CurrentUser(BaseUser):
    """Authenticated user that ``BasicAuth`` attaches to each request.

    Args:
        username: Name exposed through ``display_name``.
        user_id: Identifier stored on the ``id`` attribute.
    """

    def __init__(self, username: str, user_id: str) -> None:
        self.username = username
        self.id = user_id

    @property
    def is_authenticated(self) -> bool:
        # Instances are only created after a token has been resolved to a user.
        return True

    @property
    def display_name(self) -> str:
        return self.username

    @property
    def identity(self) -> str:
        # NOTE(review): returns an empty string rather than the user id;
        # confirm nothing downstream relies on ``request.user.identity``.
        return ''


class BasicAuth(AuthenticationBackend):
    """Resolve the ``Authorization: <scheme> <token>`` header to a ``CurrentUser``.

    The token (the second space-separated part of the header) is passed to
    ``get_current_user2``. Despite the class name, the token is not decoded
    here as HTTP Basic credentials.
    """

    async def authenticate(self, request):
        """Authenticate ``request`` from its ``Authorization`` header.

        Returns:
            ``None`` when the header is absent or shorter than 20 characters.
            Otherwise, a ``(AuthCredentials(["authenticated"]), CurrentUser)``
            pair.

        Raises:
            AuthenticationError: the header has no ``<scheme> <token>``
                separator, or ``get_current_user2`` raised ``ValueError``,
                ``UnicodeDecodeError`` or ``binascii.Error``.
        """
        auth = request.headers.get("Authorization")
        if auth is None:
            return None
        # Short values are treated as "no credentials" rather than as an error.
        # NOTE(review): the 20-character threshold is a heuristic; confirm intent.
        if len(auth) < 20:
            return None
        parts = auth.split(' ')
        if len(parts) < 2:
            # Without this check, indexing parts[1] raised an uncaught
            # IndexError for a header with no space.
            raise AuthenticationError("Invalid basic auth credentials")
        try:
            user = get_current_user2(parts[1])
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            raise AuthenticationError("Invalid basic auth credentials") from exc
        # NOTE(review): assumes get_current_user2 returns an object with
        # ``name`` and ``id`` (never None) on success - verify in api.deps.
        return AuthCredentials(["authenticated"]), CurrentUser(user.name, user.id)


# Starlette runs the most recently added middleware outermost, so
# AuthenticationMiddleware executes before CasbinMiddleware on each request.
# ``casbin_enforcer`` is not defined in this file; presumably it comes from
# ``from utils import *`` - TODO confirm.
app.add_middleware(CasbinMiddleware, enforcer=casbin_enforcer)
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())
# NOTE(review): allow_origins=['*'] together with allow_credentials=True accepts
# credentialed cross-origin requests from any origin; confirm this is intended
# outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach an ``X-Process-Time`` header giving the handling time in whole milliseconds.

    The duration is measured with ``time.perf_counter()`` rather than the
    wall-clock ``time.time()``. The wall clock can jump when the system time
    is adjusted, which would skew the reported value or make it negative.
    """
    start = time.perf_counter()
    response = await call_next(request)
    process_time_ms = int((time.perf_counter() - start) * 1000)
    response.headers["X-Process-Time"] = str(process_time_ms)
    return response


if __name__ == '__main__':
    # Run the server directly with auto-reload enabled.
    # NOTE(review): recent uvicorn releases no longer accept the ``debug``
    # keyword; confirm it is supported by the pinned uvicorn version.
    uvicorn.run(app='main:app', host="0.0.0.0", port=7889, reload=True, debug=True)