From 5501e28b62502ecef81097ac7c9c145b02dd4cb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BC=9F?= <250213850@qq.com> Date: Mon, 29 Aug 2022 15:42:13 +0800 Subject: [PATCH] =?UTF-8?q?1.=E4=BC=98=E5=8C=96=E7=99=BB=E5=BD=95=E6=96=B9?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/api_v1/endpoints/interview.py | 11 +++++++---- schemas/interview_plan.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/api/api_v1/endpoints/interview.py b/api/api_v1/endpoints/interview.py index 0b626f0..5ed650d 100644 --- a/api/api_v1/endpoints/interview.py +++ b/api/api_v1/endpoints/interview.py @@ -15,6 +15,7 @@ from api import deps from core import security from core.config import settings from core.security import get_password_hash +from schemas import ExtendendOAuth from utils.dingding import * from utils.jianli import get_resume from utils.func import get_uid @@ -814,6 +815,7 @@ async def get_operate_log( @router.get("/get_dding_user") async def get_dding_user( request: Request, + current_user: schemas.User = Depends(deps.get_current_user) ) -> schemas.Msg: """获取钉钉的用户id""" data = get_redis_alluid() @@ -921,6 +923,7 @@ async def update_mode( async def name( request: Request, db: AsyncIOMotorDatabase = Depends(get_database), + current_user: schemas.User = Depends(deps.get_current_user) ) -> schemas.Msg: """获取所有用户角色""" res = await crud.user.get_all_users(db, {}) @@ -1173,14 +1176,14 @@ async def up_hint( @router.post("/login") async def login( - data_in: schemas.Login, - data: OAuth2PasswordRequestForm = Depends(), + #data_in: schemas.Login, + data: ExtendendOAuth = Depends(), db: AsyncIOMotorDatabase = Depends(get_database) ) -> Any: """ OAuth2兼容令牌登录,获取将来令牌的访问令牌 """ - if data_in.unionid == '': + if data.unionid == None: # 账号密码登录 user = await crud.user.authenticate(db, name=data.username, password=data.password @@ -1193,7 +1196,7 @@ async def login( await crud.user.update_login_time(db, data.username) else: # 钉钉扫码登录 - user_id = Unionid(data_in.unionid) + user_id = Unionid(data.unionid) user_list = get_alluid_list() if user_id not in user_list: return schemas.Msg(code=-1, msg='密码或用户名错误') diff --git a/schemas/interview_plan.py b/schemas/interview_plan.py index c10717b..a83b406 100644 --- a/schemas/interview_plan.py +++ b/schemas/interview_plan.py @@ -1,6 +1,8 @@ import time from datetime import datetime from typing import List, Union, Dict +from fastapi.param_functions import Form +from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel from typing import Optional @@ -122,3 +124,16 @@ class Ins_section(BaseModel): class Filenames(BaseModel): filenames: str + + +class ExtendendOAuth(OAuth2PasswordRequestForm): + def __init__(self, grant_type: str = Form(None, regex="password"), + username: str = Form(...), + password: str = Form(...), + scope: str = Form(""), + client_id: Optional[str] = Form(None), + client_secret: Optional[str] = Form(None), + unionid: str = Form(None)): + super().__init__(grant_type, username, password, scope, client_id, client_secret) + self.unionid = unionid # 通过钉钉扫码获取的unionid + # unionid :str # 通过钉钉扫码获取的unionid