From cc5facdf4f359aab93a885477e4d00a480d61c86 Mon Sep 17 00:00:00 2001 From: "Gao, Ruiyuan" <905370712@qq.com> Date: Mon, 14 Oct 2024 14:00:55 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20static=E5=92=8Cdoc=E6=B7=BB=E5=8A=A0basi?= =?UTF-8?q?c=20auth=20(#231)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bug: static和doc添加basic auth * Auto-format code 🧹🌟🤖 --------- Co-authored-by: Formatter [BOT] --- xiaomusic/httpserver.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/xiaomusic/httpserver.py b/xiaomusic/httpserver.py index 76aa11c..2a55fd6 100644 --- a/xiaomusic/httpserver.py +++ b/xiaomusic/httpserver.py @@ -24,6 +24,8 @@ from fastapi import ( status, ) from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html +from fastapi.openapi.utils import get_openapi from fastapi.responses import RedirectResponse, StreamingResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.staticfiles import StaticFiles @@ -92,6 +94,9 @@ def no_verification(): app = FastAPI( lifespan=app_lifespan, version=__version__, + docs_url=None, + redoc_url=None, + openapi_url=None, ) app.add_middleware( @@ -111,6 +116,17 @@ def reset_http_server(): app.dependency_overrides = {} +class AuthStaticFiles(StaticFiles): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + async def __call__(self, scope, receive, send) -> None: + request = Request(scope, receive) + if not config.disable_httpauth: + assert verification(await security(request)) + await super().__call__(scope, receive, send) + + def HttpInit(_xiaomusic): global xiaomusic, config, log xiaomusic = _xiaomusic @@ -118,7 +134,7 @@ def HttpInit(_xiaomusic): log = xiaomusic.log folder = os.path.dirname(__file__) - app.mount("/static", StaticFiles(directory=f"{folder}/static"), name="static") + app.mount("/static", AuthStaticFiles(directory=f"{folder}/static"), name="static") reset_http_server() @@ -598,3 +614,18 @@ async def get_picture(request: Request, file_path: str, key: str = "", code: str if mime_type is None: mime_type = "image/jpeg" return FileResponse(absolute_file_path, media_type=mime_type) + + +@app.get("/docs", include_in_schema=False) +async def get_swagger_documentation(Verifcation=Depends(verification)): + return get_swagger_ui_html(openapi_url="/openapi.json", title="docs") + + +@app.get("/redoc", include_in_schema=False) +async def get_redoc_documentation(Verifcation=Depends(verification)): + return get_redoc_html(openapi_url="/openapi.json", title="docs") + + +@app.get("/openapi.json", include_in_schema=False) +async def openapi(Verifcation=Depends(verification)): + return get_openapi(title=app.title, version=app.version, routes=app.routes)