diff --git a/xiaomusic/httpserver.py b/xiaomusic/httpserver.py index 76aa11c..2a55fd6 100644 --- a/xiaomusic/httpserver.py +++ b/xiaomusic/httpserver.py @@ -24,6 +24,8 @@ from fastapi import ( status, ) from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html +from fastapi.openapi.utils import get_openapi from fastapi.responses import RedirectResponse, StreamingResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.staticfiles import StaticFiles @@ -92,6 +94,9 @@ def no_verification(): app = FastAPI( lifespan=app_lifespan, version=__version__, + docs_url=None, + redoc_url=None, + openapi_url=None, ) app.add_middleware( @@ -111,6 +116,17 @@ def reset_http_server(): app.dependency_overrides = {} +class AuthStaticFiles(StaticFiles): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + async def __call__(self, scope, receive, send) -> None: + request = Request(scope, receive) + if not config.disable_httpauth: + assert verification(await security(request)) + await super().__call__(scope, receive, send) + + def HttpInit(_xiaomusic): global xiaomusic, config, log xiaomusic = _xiaomusic @@ -118,7 +134,7 @@ def HttpInit(_xiaomusic): log = xiaomusic.log folder = os.path.dirname(__file__) - app.mount("/static", StaticFiles(directory=f"{folder}/static"), name="static") + app.mount("/static", AuthStaticFiles(directory=f"{folder}/static"), name="static") reset_http_server() @@ -598,3 +614,18 @@ async def get_picture(request: Request, file_path: str, key: str = "", code: str if mime_type is None: mime_type = "image/jpeg" return FileResponse(absolute_file_path, media_type=mime_type) + + +@app.get("/docs", include_in_schema=False) +async def get_swagger_documentation(Verifcation=Depends(verification)): + return get_swagger_ui_html(openapi_url="/openapi.json", title="docs") + + +@app.get("/redoc", include_in_schema=False) +async def get_redoc_documentation(Verifcation=Depends(verification)): + return get_redoc_html(openapi_url="/openapi.json", title="docs") + + +@app.get("/openapi.json", include_in_schema=False) +async def openapi(Verifcation=Depends(verification)): + return get_openapi(title=app.title, version=app.version, routes=app.routes)