diff --git a/.gitignore b/.gitignore index 13d1490..de0e8b2 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ dmypy.json # Pyre type checker .pyre/ +.idea \ No newline at end of file diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/api_v1/__init__.py b/api/api_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/api_v1/api.py b/api/api_v1/api.py new file mode 100644 index 0000000..e3e6c7a --- /dev/null +++ b/api/api_v1/api.py @@ -0,0 +1,6 @@ +from fastapi import APIRouter +from api.api_v1.endpoints import login + +api_router = APIRouter() + +api_router.include_router(login.router, tags=["登录接口"]) diff --git a/api/api_v1/endpoints/__init__.py b/api/api_v1/endpoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/api_v1/endpoints/dashboard.py b/api/api_v1/endpoints/dashboard.py new file mode 100644 index 0000000..0b534df --- /dev/null +++ b/api/api_v1/endpoints/dashboard.py @@ -0,0 +1,31 @@ +from datetime import timedelta +from typing import Any + +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from sqlalchemy.orm import Session + +import crud, models, schemas +from api import deps +from core import security +from core.config import settings +from core.security import get_password_hash +from utils import ( + verify_password_reset_token, +) + +router = APIRouter() + + +@router.post("/create-space") +def create_space() -> Any: + pass + + +@router.post("/create-folder") +def create_folder() -> Any: + pass + + +@router.post("/create-folder") +def create_folder() -> Any: + pass diff --git a/api/api_v1/endpoints/login.py b/api/api_v1/endpoints/login.py new file mode 100644 index 0000000..9504369 --- /dev/null +++ b/api/api_v1/endpoints/login.py @@ -0,0 +1,86 @@ +import json +from datetime import timedelta +from typing import Any + +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from fastapi.security import OAuth2PasswordRequestForm +from sqlalchemy.orm import Session + +import crud, models, schemas +from api import deps +from core import security +from core.config import settings +from core.security import get_password_hash +from utils import ( + verify_password_reset_token, +) + +router = APIRouter() + + +@router.post("/login") +def login( + data: schemas.UserLogin, + # data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(deps.get_db), +) -> Any: + """ + OAuth2兼容令牌登录,获取将来令牌的访问令牌 + """ + user = crud.user.authenticate( + db, name=data.username, password=data.password + ) + if not user: + raise HTTPException(status_code=400, detail="Incorrect name or password") + elif not crud.user.is_active(user): + raise HTTPException(status_code=400, detail="Inactive user") + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + return { + 'data': { + 'name': user.name, + 'email': user.email, + 'token': security.create_access_token( + expires_delta=access_token_expires, user_id=user.id, email=user.email, is_active=user.is_active, + is_superuser=user.is_superuser, name=user.name + ), + }, + + 'code': 0, + 'msg': 'success', + } + + +@router.post("/me", response_model=schemas.UserBase) +# @router.post("/me/") +def me(request: Request) -> Any: + """ + Test access token + """ + return request.state.user + + +@router.post("/reset-password", response_model=schemas.Msg) +def reset_password( + token: str = Body(...), + new_password: str = Body(...), + db: Session = Depends(deps.get_db), +) -> Any: + """ + 重设密码 + """ + user_id = verify_password_reset_token(token) + if not user_id: + raise HTTPException(status_code=400, detail="Invalid token") + user = crud.user.get(db, user_id) + if not user: + raise HTTPException( + status_code=404, + detail="The user with this username does not exist in the system.", + ) + elif not crud.user.is_active(user): + raise HTTPException(status_code=400, detail="Inactive user") + hashed_password = get_password_hash(new_password) + user.hashed_password = hashed_password + db.add(user) + db.commit() + return {"msg": "Password updated successfully"} diff --git a/api/api_v1/endpoints/manage.py b/api/api_v1/endpoints/manage.py new file mode 100644 index 0000000..a78adfb --- /dev/null +++ b/api/api_v1/endpoints/manage.py @@ -0,0 +1,22 @@ +from datetime import timedelta +from typing import Any + +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from sqlalchemy.orm import Session + +import crud, models, schemas +from api import deps +from core import security +from core.config import settings +from core.security import get_password_hash +from utils import ( + verify_password_reset_token, +) + +router = APIRouter() + + +@router.post("/create-project") +def create_project() -> Any: + pass + diff --git a/api/deps.py b/api/deps.py new file mode 100644 index 0000000..ca8e8fa --- /dev/null +++ b/api/deps.py @@ -0,0 +1,60 @@ +import json +from typing import Generator + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import jwt +from pydantic import ValidationError +from sqlalchemy.orm import Session + +import crud, models, schemas +from core import security +from core.config import settings +from db.session import SessionLocal + +reusable_oauth2 = OAuth2PasswordBearer( + tokenUrl=f"{settings.API_V1_STR}/login/" +) + + +def get_db() -> Generator: + try: + db = SessionLocal() + yield db + finally: + db.close() + + +def get_current_user(token:str + ) -> schemas.UserBase: + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + user = schemas.UserBase(**payload) + except (jwt.JWTError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + + +def get_current_active_user( + current_user: models.User = Depends(get_current_user), +) -> models.User: + if not crud.user.is_active(current_user): + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +def get_current_active_superuser( + current_user: models.User = Depends(get_current_user), +) -> models.User: + if not crud.user.is_superuser(current_user): + raise HTTPException( + status_code=400, detail="The user doesn't have enough privileges" + ) + return current_user diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/config.py b/core/config.py new file mode 100644 index 0000000..26627b3 --- /dev/null +++ b/core/config.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import AnyHttpUrl, BaseSettings, EmailStr, HttpUrl, validator + + +class Settings(BaseSettings): + PROJECT_NAME: str = 'X数据分析后台' + API_V1_STR: str = '/api/v1' + + ALLOW_ANONYMOUS: tuple = ('login','openapi.json') + + BACKEND_CORS_ORIGINS: List[str] = ['*'] + + MYSQL_HOST: str = '127.0.0.1' + MYSQL_PORT: int = 3306 + MYSQL_USER: str = 'root' + MYSQL_PASSWORD: str = 'root' + MYSQL_DB: str = 'xdata' + + SQLALCHEMY_DATABASE_URI = f'mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}' + + FIRST_EMAIL: str = '15392746632@qq.com' + FIRST_SUPERUSER_PASSWORD: str = '123456' + FIRST_NAME: str = 'root' + + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 + SECRET_KEY: str = 'ZaFX6EypK6PtuhGv11q4DLRvAb0csiLx4dbKUwLwCe8' + + class Config: + case_sensitive = True + + +settings = Settings() diff --git a/core/security.py b/core/security.py new file mode 100644 index 0000000..de9d365 --- /dev/null +++ b/core/security.py @@ -0,0 +1,33 @@ +from datetime import datetime, timedelta +from typing import Any, Union + +from jose import jwt +from passlib.context import CryptContext + +from core.config import settings + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +ALGORITHM = "HS256" + + +def create_access_token( + expires_delta: timedelta = None, **payload +) -> str: + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) + payload["exp"] = expire + encoded_jwt = jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + return pwd_context.hash(password) diff --git a/crud/__init__.py b/crud/__init__.py new file mode 100644 index 0000000..fd75a0e --- /dev/null +++ b/crud/__init__.py @@ -0,0 +1,3 @@ +from .crud_user import user + + diff --git a/crud/base.py b/crud/base.py new file mode 100644 index 0000000..f7ce6ee --- /dev/null +++ b/crud/base.py @@ -0,0 +1,66 @@ +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union + +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from db.base_class import Base + +ModelType = TypeVar("ModelType", bound=Base) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + + +class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default methods to Create, Read, Update, Delete (CRUD). + + **Parameters** + + * `model`: A SQLAlchemy model class + * `schema`: A Pydantic model (schema) class + """ + self.model = model + + def get(self, db: Session, id: Any) -> Optional[ModelType]: + return db.query(self.model).filter(self.model.id == id).first() + + def get_multi( + self, db: Session, *, skip: int = 0, limit: int = 100 + ) -> List[ModelType]: + return db.query(self.model).offset(skip).limit(limit).all() + + def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) # type: ignore + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def update( + self, + db: Session, + *, + db_obj: ModelType, + obj_in: Union[UpdateSchemaType, Dict[str, Any]] + ) -> ModelType: + obj_data = jsonable_encoder(db_obj) + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.dict(exclude_unset=True) + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def remove(self, db: Session, *, id: int) -> ModelType: + obj = db.query(self.model).get(id) + db.delete(obj) + db.commit() + return obj diff --git a/crud/crud_item.py b/crud/crud_item.py new file mode 100644 index 0000000..dcb87cd --- /dev/null +++ b/crud/crud_item.py @@ -0,0 +1,34 @@ +from typing import List + +from fastapi.encoders import jsonable_encoder +from sqlalchemy.orm import Session + +from app.crud.base import CRUDBase +from app.models.item import Item +from app.schemas.item import ItemCreate, ItemUpdate + + +class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]): + def create_with_owner( + self, db: Session, *, obj_in: ItemCreate, owner_id: int + ) -> Item: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data, owner_id=owner_id) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def get_multi_by_owner( + self, db: Session, *, owner_id: int, skip: int = 0, limit: int = 100 + ) -> List[Item]: + return ( + db.query(self.model) + .filter(Item.owner_id == owner_id) + .offset(skip) + .limit(limit) + .all() + ) + + +item = CRUDItem(Item) diff --git a/crud/crud_user.py b/crud/crud_user.py new file mode 100644 index 0000000..90a9b46 --- /dev/null +++ b/crud/crud_user.py @@ -0,0 +1,55 @@ +from typing import Any, Dict, Optional, Union + +from sqlalchemy.orm import Session + +from core.security import get_password_hash, verify_password +from crud.base import CRUDBase +from models.user import User +from schemas.user import UserCreate, UserUpdate + + +class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): + def get_by_user(self, db: Session, *, name: str) -> Optional[User]: + return db.query(User).filter(User.name == name).first() + + def create(self, db: Session, *, obj_in: UserCreate) -> User: + db_obj = User( + email=obj_in.email, + hashed_password=get_password_hash(obj_in.password), + name=obj_in.name, + is_superuser=obj_in.is_superuser, + ) + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + + def update( + self, db: Session, *, db_obj: User, obj_in: Union[UserUpdate, Dict[str, Any]] + ) -> User: + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.dict(exclude_unset=True) + if update_data["password"]: + hashed_password = get_password_hash(update_data["password"]) + del update_data["password"] + update_data["hashed_password"] = hashed_password + return super().update(db, db_obj=db_obj, obj_in=update_data) + + def authenticate(self, db: Session, *, name: str, password: str) -> Optional[User]: + user = self.get_by_user(db, name=name) + if not user: + return None + if not verify_password(password, user.hashed_password): + return None + return user + + def is_active(self, user: User) -> bool: + return user.is_active + + def is_superuser(self, user: User) -> bool: + return user.is_superuser + + +user = CRUDUser(User) diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/db/base.py b/db/base.py new file mode 100644 index 0000000..e69de29 diff --git a/db/base_class.py b/db/base_class.py new file mode 100644 index 0000000..0b94d5c --- /dev/null +++ b/db/base_class.py @@ -0,0 +1,14 @@ +from typing import Any + +from sqlalchemy.ext.declarative import as_declarative, declared_attr + + +@as_declarative() +class Base: + id: Any + __name__: str + + # 自动生成表名 + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__.lower() diff --git a/db/init_db.py b/db/init_db.py new file mode 100644 index 0000000..5097a85 --- /dev/null +++ b/db/init_db.py @@ -0,0 +1,16 @@ +from sqlalchemy.orm import Session +import crud, schemas +from core.config import settings +from db import base # noqa: F401 + + +def init_db(db: Session) -> None: + user = crud.user.get_by_user(db, name=settings.FIRST_NAME) + if not user: + user_in = schemas.UserCreate( + name=settings.FIRST_NAME, + email=settings.FIRST_EMAIL, + password=settings.FIRST_SUPERUSER_PASSWORD, + is_superuser=True, + ) + user = crud.user.create(db, obj_in=user_in) # noqa: F841 diff --git a/db/session.py b/db/session.py new file mode 100644 index 0000000..4e2ceb6 --- /dev/null +++ b/db/session.py @@ -0,0 +1,7 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from core.config import settings + +engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/initial_data.py b/initial_data.py new file mode 100644 index 0000000..7a02390 --- /dev/null +++ b/initial_data.py @@ -0,0 +1,22 @@ +import logging + +from db.init_db import init_db +from db.session import SessionLocal + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def init() -> None: + db = SessionLocal() + init_db(db) + + +def main() -> None: + logger.info("Creating initial data") + init() + logger.info("Initial data created") + + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py new file mode 100644 index 0000000..a208b01 --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +import time + +import uvicorn +from fastapi import FastAPI, Request, Depends +from starlette.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from api.api_v1.api import api_router +from api.deps import get_current_user +from core.config import settings + +app = FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json") + +allow_anonymous = [f'{settings.API_V1_STR}/{page}' for page in settings.ALLOW_ANONYMOUS] +allow_anonymous.extend(['/docs']) + + +@app.middleware("http") +async def add_jwt_auth(request: Request, call_next): + fail = {'code': -1, 'msg': '身份验证失败'} + if request.scope.get('path') not in allow_anonymous: + token = request.headers.get("Authorization") + try: + user = get_current_user(token) + except: + return JSONResponse(fail) + # + # request.state.user = user + + response = await call_next(request) + return response + + +if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + +app.include_router(api_router, prefix=settings.API_V1_STR) + +if __name__ == '__main__': + uvicorn.run(app='main:app', host="0.0.0.0", port=8888, reload=True, debug=True) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..b7bb9be --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .user import User \ No newline at end of file diff --git a/models/dashboard.py b/models/dashboard.py new file mode 100644 index 0000000..c9079d2 --- /dev/null +++ b/models/dashboard.py @@ -0,0 +1,15 @@ +from sqlalchemy import Boolean, Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship + +from db.base_class import Base + + +class Dashboard(Base): + id = Column(Integer, primary_key=True, index=True) + folder_type = Column(String) + pid = Column(Integer, ForeignKey('dashboard.id')) + parent = relationship('Dashboard', remote_side=[id]) + children = relationship('Dashboard', remote_side=[pid]) + name = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey('user.id')) + project_id = Column(Integer, ForeignKey('project.id')) diff --git a/models/project.py b/models/project.py new file mode 100644 index 0000000..1714f0d --- /dev/null +++ b/models/project.py @@ -0,0 +1,8 @@ +from sqlalchemy import Boolean, Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship + +from db.base_class import Base + + +class Dashboard(Base): + id = Column(Integer, primary_key=True, index=True) diff --git a/models/user.py b/models/user.py new file mode 100644 index 0000000..209f44d --- /dev/null +++ b/models/user.py @@ -0,0 +1,13 @@ +from sqlalchemy import Boolean, Column, Integer, String +from sqlalchemy.orm import relationship +from db.base_class import Base + + +class User(Base): + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + email = Column(String, unique=True, index=True, nullable=False) + hashed_password = Column(String, nullable=False) + is_active = Column(Boolean(), default=True) + is_superuser = Column(Boolean(), default=False) + dashboard = relationship('Dashboard', back_populates='user') diff --git a/schemas/__init__.py b/schemas/__init__.py new file mode 100644 index 0000000..d8b73d6 --- /dev/null +++ b/schemas/__init__.py @@ -0,0 +1,3 @@ +from .user import User, UserCreate, UserInDB, UserUpdate,UserLogin,UserBase +from .token import Token, TokenPayload +from .msg import Msg \ No newline at end of file diff --git a/schemas/msg.py b/schemas/msg.py new file mode 100644 index 0000000..945e0c6 --- /dev/null +++ b/schemas/msg.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class Msg(BaseModel): + msg: str diff --git a/schemas/token.py b/schemas/token.py new file mode 100644 index 0000000..03a65de --- /dev/null +++ b/schemas/token.py @@ -0,0 +1,15 @@ +from typing import Optional + +from pydantic import BaseModel + + +class Token(BaseModel): + token: str + code: int + name: str + email: str + msg: str + + +class TokenPayload(BaseModel): + user_id: int = None diff --git a/schemas/user.py b/schemas/user.py new file mode 100644 index 0000000..7a28596 --- /dev/null +++ b/schemas/user.py @@ -0,0 +1,44 @@ +from typing import Optional + +from pydantic import BaseModel, EmailStr + + +# Shared properties +class UserBase(BaseModel): + email: Optional[EmailStr] = None + is_active: Optional[bool] = True + is_superuser: bool = False + name: Optional[str] = None + + +class UserLogin(BaseModel): + username: str = ... + password: str = ... + + +# Properties to receive via API on creation +class UserCreate(UserBase): + email: EmailStr + password: str + + +# Properties to receive via API on update +class UserUpdate(UserBase): + password: Optional[str] = None + + +class UserInDBBase(UserBase): + id: Optional[int] = None + + class Config: + orm_mode = True + + +# Additional properties to return via API +class User(UserInDBBase): + pass + + +# Additional properties stored in DB +class UserInDB(UserInDBBase): + hashed_password: str diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..88ac30c --- /dev/null +++ b/utils.py @@ -0,0 +1,107 @@ +import json +import logging +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, Optional + +import emails +from emails.template import JinjaTemplate +from jose import jwt + +from core.config import settings + + +def send_email( + email_to: str, + subject_template: str = "", + html_template: str = "", + environment: Dict[str, Any] = {}, +) -> None: + assert settings.EMAILS_ENABLED, "no provided configuration for email variables" + message = emails.Message( + subject=JinjaTemplate(subject_template), + html=JinjaTemplate(html_template), + mail_from=(settings.EMAILS_FROM_NAME, settings.EMAILS_FROM_EMAIL), + ) + smtp_options = {"host": settings.SMTP_HOST, "port": settings.SMTP_PORT} + if settings.SMTP_TLS: + smtp_options["tls"] = True + if settings.SMTP_USER: + smtp_options["user"] = settings.SMTP_USER + if settings.SMTP_PASSWORD: + smtp_options["password"] = settings.SMTP_PASSWORD + response = message.send(to=email_to, render=environment, smtp=smtp_options) + logging.info(f"send email result: {response}") + + +def send_test_email(email_to: str) -> None: + project_name = settings.PROJECT_NAME + subject = f"{project_name} - Test email" + with open(Path(settings.EMAIL_TEMPLATES_DIR) / "test_email.html") as f: + template_str = f.read() + send_email( + email_to=email_to, + subject_template=subject, + html_template=template_str, + environment={"project_name": settings.PROJECT_NAME, "email": email_to}, + ) + + +def send_reset_password_email(email_to: str, email: str, token: str) -> None: + project_name = settings.PROJECT_NAME + subject = f"{project_name} - Password recovery for user {email}" + with open(Path(settings.EMAIL_TEMPLATES_DIR) / "reset_password.html") as f: + template_str = f.read() + server_host = settings.SERVER_HOST + link = f"{server_host}/reset-password?token={token}" + send_email( + email_to=email_to, + subject_template=subject, + html_template=template_str, + environment={ + "project_name": settings.PROJECT_NAME, + "username": email, + "email": email_to, + "valid_hours": settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS, + "link": link, + }, + ) + + +def send_new_account_email(email_to: str, username: str, password: str) -> None: + project_name = settings.PROJECT_NAME + subject = f"{project_name} - New account for user {username}" + with open(Path(settings.EMAIL_TEMPLATES_DIR) / "new_account.html") as f: + template_str = f.read() + link = settings.SERVER_HOST + send_email( + email_to=email_to, + subject_template=subject, + html_template=template_str, + environment={ + "project_name": settings.PROJECT_NAME, + "username": username, + "password": password, + "email": email_to, + "link": link, + }, + ) + + +def generate_password_reset_token(email: str) -> str: + delta = timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS) + now = datetime.utcnow() + expires = now + delta + exp = expires.timestamp() + encoded_jwt = jwt.encode( + {"exp": exp, "nbf": now, "sub": email}, settings.SECRET_KEY, algorithm="HS256", + ) + return encoded_jwt + + +def verify_password_reset_token(token: str) -> Optional[str]: + try: + decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + return decoded_token.get('user_id') + except jwt.JWTError: + return None