# middleware/auth_middleware.py

import os
import re

from fastapi import Request, status
from jose import JWTError
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from utils.logging import logger
from utils.JWTToken import decode_access_token
from utils.get_db import get_db
from models.user_model import User


# Raw, comma-separated list of route templates that bypass authentication,
# taken from the PUBLIC_ROUTES environment variable (empty when unset).
PUBLIC_ROUTES = os.getenv("PUBLIC_ROUTES", "")

# Normalised set of public routes: each entry is whitespace-trimmed and empty
# entries (e.g. from "a,,b" or a trailing comma) are discarded.
FINAL_PUBLIC_ROUTES = set(
    filter(None, map(str.strip, PUBLIC_ROUTES.split(",")))
)


def _compile_public_route(route: str) -> "re.Pattern[str]":
    """Compile one configured public-route template into a regex.

    Each ``{param}`` placeholder is replaced with ``[^/]+`` (one non-empty path
    segment). The rest of the route is used as a regex fragment, unchanged
    from the original behaviour, so existing PUBLIC_ROUTES values keep working.
    The result is wrapped in a group and followed by an optional trailing
    slash. It is meant to be used with ``fullmatch``.
    """
    pattern = re.sub(r"\{[^/]+\}", r"[^/]+", route)
    # The group keeps an alternation inside `pattern` from escaping the
    # anchoring. ``fullmatch`` (used by callers) avoids the ``$`` pitfall,
    # where ``$`` also matches just before a trailing newline.
    return re.compile(f"(?:{pattern})/?")


# Compiled once at import instead of on every request. An invalid pattern in
# PUBLIC_ROUTES now fails at startup rather than on each request.
_PUBLIC_ROUTE_PATTERNS = tuple(
    _compile_public_route(route) for route in FINAL_PUBLIC_ROUTES
)


def _is_public_path(path: str) -> bool:
    """Return True when *path* matches any configured public route exactly."""
    return any(pattern.fullmatch(path) for pattern in _PUBLIC_ROUTE_PATTERNS)


def _unauthorized(detail: str) -> JSONResponse:
    """Build a 401 JSON response with the given ``detail`` message."""
    return JSONResponse(
        status_code=status.HTTP_401_UNAUTHORIZED,
        content={"detail": detail}
    )


def _load_user(user_id):
    """Fetch the ``User`` whose ``id`` equals *user_id*, or ``None``.

    This function is synchronous (blocking DB I/O). The middleware runs it in
    a worker thread. The session is obtained from the ``get_db()`` generator.
    The session is closed afterwards, and then the generator is closed, so any
    ``finally`` cleanup around its ``yield`` runs now instead of at garbage
    collection.
    """
    db_gen = get_db()
    db: Session = next(db_gen)
    try:
        return db.query(User).filter(User.id == user_id).first()
    finally:
        try:
            db.close()
        finally:
            db_gen.close()


class AuthorizationMiddleware(BaseHTTPMiddleware):
    """Require a valid bearer JWT for an existing user on non-public routes.

    Requests pass through untouched when any of these apply:
    - the scope is a websocket,
    - the method is OPTIONS (CORS preflight),
    - the path matches a PUBLIC_ROUTES template,
    - the path is under ``/api/v1/ws/``.

    Every other request must carry ``Authorization: Bearer <jwt>``. Media
    requests under ``/uploads/`` may instead pass ``?token=<jwt>``. The token
    is decoded, and its ``user_id`` claim must refer to an existing ``User``.
    That user is stored on ``request.state.user``. Auth failures return 401.
    Unexpected errors return 500 and are logged.
    """

    async def dispatch(self, request: Request, call_next):
        # Defensive: Starlette's BaseHTTPMiddleware normally forwards non-HTTP
        # scopes directly, but keep websockets and CORS preflights unauthenticated.
        if request.scope["type"] == "websocket" or request.method == "OPTIONS":
            return await call_next(request)

        request_path = request.url.path

        if _is_public_path(request_path):
            return await call_next(request)

        # Bypass websocket routes (auth is handled in the endpoint)
        if request_path.startswith("/api/v1/ws/"):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")

        # <img>/<video> tags cannot send headers, so uploads accept the token as
        # a query parameter. NOTE(review): URL tokens can leak into proxy and
        # access logs. Consider short-lived, media-scoped tokens.
        if not auth_header and request_path.startswith("/uploads/"):
            token_query = request.query_params.get("token")
            if token_query:
                auth_header = f"Bearer {token_query}"

        if not auth_header:
            return _unauthorized("Authorization header missing")

        # Expect exactly "<scheme> <token>".
        parts = auth_header.split()
        if len(parts) != 2:
            return _unauthorized("Invalid authorization header format")
        scheme, token = parts

        if scheme.lower() != "bearer":
            return _unauthorized("Invalid auth scheme")

        try:
            payload = decode_access_token(token)

            user_id = payload.get("user_id")
            if not user_id:
                return _unauthorized("Invalid token payload")

            # Keep the blocking DB query off the event loop.
            user = await run_in_threadpool(_load_user, user_id)

        except JWTError:
            return _unauthorized("Invalid or expired token")

        except ValueError:
            # Kept for compatibility: a ValueError raised while decoding or
            # looking up the user was previously reported with this message.
            return _unauthorized("Invalid authorization header format")

        except Exception:
            # Do not swallow unexpected failures silently. Record the traceback.
            logger.exception(
                f"Authorization middleware error for "
                f"{request.method} {request_path}"
            )
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Internal server error"}
            )

        if not user:
            return _unauthorized("User not found")

        # Expose the authenticated user to downstream handlers.
        request.state.user = user

        response = await call_next(request)

        logger.info(
            f"Response Status: {response.status_code} "
            f"for {request.method} {request.url.path}"
        )

        return response