198 lines
7.0 KiB
Python
198 lines
7.0 KiB
Python
import binascii
import json
import logging
import threading
import time

import uvicorn
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, BaseUser, SimpleUser
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import HTTPConnection
from starlette.responses import Response, JSONResponse

import crud
import schemas

from db import connect_to_mongo, close_mongo_connection, get_database
from db.ckdb_utils import connect_to_ck, close_ck_connection
from db.redisdb_utils import connect_to_redis, close_redis_connection
from utils import *
from api.api_v1.api import api_router
from core.config import settings
from api.deps import get_current_user2
|
||
|
||
# Application object; every v1 route is mounted under settings.API_V1_STR.
app = FastAPI(title=settings.PROJECT_NAME)
app.include_router(api_router, prefix=settings.API_V1_STR)

# Lifecycle hooks: open the Mongo, Redis and ck (db.ckdb_utils) connections on
# startup and close them on shutdown. Registration order matches the original
# one-call-per-line layout exactly.
for _event, _handler in (
    ("startup", connect_to_mongo),
    ("startup", connect_to_redis),
    ("startup", connect_to_ck),
    ("shutdown", close_mongo_connection),
    ("shutdown", close_redis_connection),
    ("shutdown", close_ck_connection),
):
    app.add_event_handler(_event, _handler)
del _event, _handler
|
||
|
||
|
||
class CurrentUser(BaseUser):
    """Authenticated user that ``BasicAuth`` attaches to ``request.user``.

    Attributes:
        username: Name shown as ``display_name``.
        id: Identifier of the user, also exposed as ``identity``.
    """

    def __init__(self, username: str, user_id: str) -> None:
        self.username = username
        self.id = user_id

    @property
    def is_authenticated(self) -> bool:
        """Always True: instances are only built for resolved tokens."""
        return True

    @property
    def display_name(self) -> str:
        """Human-readable name of the user."""
        return self.username

    @property
    def identity(self) -> str:
        """Stable identifier of the user.

        Previously this returned ``''`` for every user, making all users
        indistinguishable through ``identity``; it now returns the user id.
        """
        return str(self.id)
|
||
|
||
|
||
class BasicAuth(AuthenticationBackend):
    """Authenticate requests from the ``Authorization`` header.

    The header is split on single spaces and its second field is handed to
    ``get_current_user2`` as the token (typically ``"<scheme> <token>"``).
    """

    async def authenticate(self, request):
        """Resolve the request's user from its Authorization header.

        Returns:
            ``None`` (request stays anonymous) when there is no Authorization
            header, when the request targets the login endpoint, or when the
            header is shorter than 20 characters. Otherwise a tuple of
            ``AuthCredentials(["authenticated"])`` and a ``CurrentUser``.

        Raises:
            AuthenticationError: when the header has no token field or
                ``get_current_user2`` rejects the token with ValueError,
                UnicodeDecodeError or binascii.Error.
        """
        auth = request.headers.get("Authorization")
        # The login endpoint must stay reachable even with a stale/invalid header.
        if auth is None or request.scope.get('path') == '/api/v1/user/login':
            return None

        # Headers this short are treated as carrying no real token.
        if len(auth) < 20:
            return None

        parts = auth.split(' ')
        if len(parts) < 2:
            # Without this check parts[1] raised IndexError, which surfaced as
            # a 500 instead of the "please log in again" response.
            raise AuthenticationError("身份验证失败,请重新登录")

        try:
            user = get_current_user2(parts[1])
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            raise AuthenticationError("身份验证失败,请重新登录") from exc

        return AuthCredentials(["authenticated"]), CurrentUser(user.name, user.id)
|
||
|
||
|
||
def login_expired(conn: HTTPConnection, exc: Exception) -> Response:
    """Build the response returned when ``BasicAuth`` raises AuthenticationError.

    Registered as AuthenticationMiddleware's ``on_error`` hook. The HTTP status
    is 200; the failure is signalled by ``code=-5`` in the JSON body instead.
    """
    payload = schemas.Msg(code=-5, msg='请重新登录').dict()
    return JSONResponse(payload, status_code=200)
|
||
#处理路由权限问题
|
||
# @app.middleware("http")
|
||
# async def panduan_quanxian_url(request: Request, call_next):
|
||
# #user_id=request.user.id
|
||
# #user=request.user.username
|
||
# start_time = int(time.time() * 1000)
|
||
# response = await call_next(request)
|
||
# process_time = int(time.time() * 1000) - start_time
|
||
# response.headers["X-Process-Time"] = str(process_time)
|
||
# url=request.url.path
|
||
# if 'docs' in url or 'openapi.json' in url:
|
||
# return response
|
||
# if url == '/api/v1/user/login':
|
||
# return response
|
||
# game=request.url.query.split('=')[1]
|
||
# if 'undefined' in game:
|
||
# return response
|
||
# if '&' in game:
|
||
# game=game.split('&')[0]
|
||
# judge_url = await crud.user_url.get_quanxian(get_database(), schemas.Url_quanxian(user_id=request.user.id))
|
||
# if judge_url == {}:
|
||
# # data='没有匹配这个游戏'
|
||
# return Response(schemas.Msg(code=0, msg='没有操作权限',data='').json())
|
||
# if game not in judge_url['game']:
|
||
# #data='没有匹配这个游戏'
|
||
# return Response(schemas.Msg(code=0, msg='没有操作权限',data='' ).json())
|
||
# quanxian_dict={}
|
||
# for i in range(len(judge_url['game'])):
|
||
# quanxian_dict[judge_url['game'][i]]=judge_url['quanxian'][i]
|
||
# user_list=await crud.url_list.get_url(get_database(),schemas.Url_list(name=quanxian_dict[game]))
|
||
# api_list=[]
|
||
# state_list=[]
|
||
# api_dict={}
|
||
# for i in user_list:
|
||
# for api in i['api_list']:
|
||
# api_list.append(api)
|
||
# for quanxian in i['state']:
|
||
# state_list.append(quanxian)
|
||
# for i in range(len(api_list)):
|
||
# api_dict[api_list[i]]=state_list[i]
|
||
# if url not in api_list:
|
||
# # data='没有对应路由'
|
||
# return Response(schemas.Msg(code=0, msg='没有操作权限',data='').json())
|
||
# elif api_dict[url] != True:
|
||
# # data='路由为False'
|
||
# return Response(schemas.Msg(code=0, msg='没有操作权限',data='').json())
|
||
#
|
||
# return response
|
||
@app.middleware("http")  # If persisted feedback tasks are not running, start them.
async def run_task(request: Request, call_next):
    """Restart persisted feedback-reminder tasks after serving each request.

    After the downstream response is produced, ``task.json`` is re-read.
    Each task whose uid is not already in ``scheduler.queue`` is handled by
    its ``times`` deadline: it is re-entered into the scheduler if the
    deadline is still in the future, or dropped from the file if it has
    passed. The response is returned unchanged apart from the
    ``X-Process-Time`` header (milliseconds spent downstream).
    """
    start_time = int(time.perf_counter() * 1000)
    response = await call_next(request)
    process_time = int(time.perf_counter() * 1000) - start_time
    response.headers["X-Process-Time"] = str(process_time)

    # NOTE(review): the file read / JSON parse below block the event loop;
    # fine for a small file, but consider moving it off-loop.
    restored = _restore_pending_tasks()
    _ensure_task_runner(force=restored)
    return response


# Thread most recently started by ``_ensure_task_runner``; None until the first start.
_task_runner = None


def _restore_pending_tasks():
    """Re-enter tasks from ``task.json`` that are missing from ``scheduler``.

    Tasks whose ``times`` deadline is not in the future are removed from the
    file via ``write_task``. A missing or blank file means nothing to restore.

    Returns:
        True if at least one task was entered into ``scheduler``.
    """
    try:
        with open('task.json', 'r', encoding='utf-8') as f:
            raw = f.read()
    except FileNotFoundError:
        # Previously this crashed every request until the file was created.
        return False
    if not raw.strip():
        return False

    data = json.loads(raw)
    queued_uids = {event.kwargs['uid'] for event in scheduler.queue}
    now = int(time.time())  # wall-clock seconds; ``times`` is compared against it
    expired = []
    restored = False
    for uid, task_kwargs in data.items():
        if uid in queued_uids:
            continue  # already scheduled in this process
        delay = task_kwargs['times'] - now
        # Only tasks that are not yet overdue are restarted.
        if delay > 0:
            scheduler.enter(delay, 1, task, kwargs=task_kwargs)
            restored = True
        else:
            expired.append(uid)

    if expired:
        for uid in expired:
            data.pop(uid)
        write_task(json.dumps(data))
    return restored


def _ensure_task_runner(force=False):
    """Start a daemon thread running ``scheduler.run`` when one is needed.

    A new runner is started when ``force`` is set (tasks were just entered,
    so a fresh runner picks them up without waiting on an older runner's
    sleep), or when the queue holds events but no runner started here is
    alive. Otherwise nothing is started.
    """
    global _task_runner
    if not scheduler.queue:
        return
    if not force and _task_runner is not None and _task_runner.is_alive():
        return
    # The old code spawned a non-daemon runner on every request, so threads
    # accumulated while the queue was non-empty and blocked interpreter exit.
    # daemon=True lets shutdown proceed; pending tasks presumably remain in
    # task.json and are restored on the next request (confirm with task/write_task).
    _task_runner = threading.Thread(target=scheduler.run, daemon=True)
    _task_runner.start()
|
||
# Authentication: BasicAuth populates request.user; when it raises
# AuthenticationError, login_expired produces the response instead.
app.add_middleware(
    AuthenticationMiddleware,
    backend=BasicAuth(),
    on_error=login_expired,
)

# CORS: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site make credentialed cross-origin calls — confirm this is intended.
_cors_options = dict(
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
||
|
||
|
||
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Turn request-validation failures into a 400 ``Msg`` payload (code -4).

    The validation details are passed back as ``str(exc)`` in ``data``.
    """
    # The body is JSON text, so declare it; a bare Response sends no Content-Type.
    return Response(
        schemas.Msg(code=-4, msg='请求错误', data=str(exc)).json(),
        status_code=400,
        media_type="application/json",
    )
|
||
|
||
|
||
@app.exception_handler(Exception)
async def http_exception_handler(request, exc):
    """Answer any unhandled exception with a generic 500 ``Msg`` payload (code -3).

    No exception details are included in the response body.
    """
    # The body is JSON text, so declare it; a bare Response sends no Content-Type.
    return Response(
        schemas.Msg(code=-3, msg='服务器错误').json(),
        status_code=500,
        media_type="application/json",
    )
|
||
|
||
|
||
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Time each request, expose the duration, and write an API log entry.

    Sets ``X-Process-Time`` to the milliseconds spent downstream and records
    the full URL (``api``), the duration (``ms``) and ``user_id`` through
    ``crud.api_log.insert_log``.
    """
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
    start_time = int(time.perf_counter() * 1000)
    response = await call_next(request)
    process_time = int(time.perf_counter() * 1000) - start_time
    response.headers["X-Process-Time"] = str(process_time)

    # Starlette's request.user asserts that a user was put in the scope (it is
    # absent e.g. when the auth backend rejected the request; KeyError if run
    # with -O), and unauthenticated users have no ``id``. All of these are
    # logged as anonymous. The old bare ``except:`` also swallowed BaseException.
    try:
        user_id = request.user.id
    except (AssertionError, AttributeError, KeyError):
        user_id = 'anonymous'

    # Logging is best-effort: a failed log write must not turn an already
    # produced response into a 500.
    try:
        await crud.api_log.insert_log(get_database(), schemas.ApiLogInsert(
            api=str(request.url),
            ms=process_time,
            user_id=user_id,
        ))
    except Exception:
        logging.getLogger(__name__).exception(
            "Failed to write API log for %s", request.url.path
        )
    return response
|
||
|
||
|
||
if __name__ == '__main__':
    # Development entry point: serve on all interfaces, port 7800, auto-reload.
    # (An earlier variant bound host "10.0.0.240" instead of "0.0.0.0".)
    # NOTE(review): newer uvicorn releases dropped the ``debug`` option —
    # confirm the pinned uvicorn version still accepts it.
    uvicorn.run(
        app='main:app',
        host="0.0.0.0",
        port=7800,
        reload=True,
        debug=True,
    )