import json
from typing import Dict, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from utils.JWTToken import decode_access_token
from utils.logging import logger

# Router on which this module's WebSocket route is registered.
websocket_router = APIRouter()

class ConnectionManager:
    """Registry of accepted WebSocket connections, grouped by project id.

    All state is held in the ``active_connections`` dict of this object.
    Projects are added when their first connection registers and dropped
    again when their last connection is removed through ``disconnect``.
    """

    def __init__(self):
        # Maps project_id to a list of active WebSocket connections
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, project_id: str):
        """Accept the WebSocket handshake and register the connection.

        Args:
            websocket: Connection to accept and track.
            project_id: Project whose broadcasts this connection should receive.
        """
        await websocket.accept()
        # Registration happens only after a successful accept(), so a failed
        # handshake never leaves an entry behind.
        project_connections = self.active_connections.setdefault(project_id, [])
        project_connections.append(websocket)
        logger.info(f"WebSocket connected for project {project_id}. Total connections: {len(project_connections)}")

    def disconnect(self, websocket: WebSocket, project_id: str):
        """Remove *websocket* from *project_id*'s connection list.

        Calling this for an unknown project or for a connection that is not
        (or no longer) registered is a no-op, so it is safe to call twice.
        """
        project_connections = self.active_connections.get(project_id)
        if project_connections is None:
            return
        try:
            project_connections.remove(websocket)
        except ValueError:
            # Not registered (e.g. already removed after a failed broadcast).
            return
        logger.info(f"WebSocket disconnected for project {project_id}")
        if not project_connections:
            # Last connection gone: drop the project key so the dict does not
            # accumulate empty lists.
            del self.active_connections[project_id]

    async def broadcast_to_project(self, project_id: str, message: dict):
        """Send *message*, JSON-encoded, to every connection of *project_id*.

        Connections are written to one after another. A connection whose send
        raises is logged and unregistered once the loop has finished; the
        remaining connections still receive the message.

        Args:
            project_id: Project whose connections should receive the message.
            message: Payload passed to ``json.dumps``; it is serialized once
                and the same text is sent to every connection.
        """
        project_connections = self.active_connections.get(project_id)
        if not project_connections:
            return

        text_data = json.dumps(message)
        failed = []
        # Iterate over a snapshot: each ``await`` below yields to the event
        # loop, where connect()/disconnect() may mutate the live list.
        # Iterating the live list would then silently skip connections.
        for connection in list(project_connections):
            try:
                await connection.send_text(text_data)
            except Exception as e:
                logger.error(f"Error broadcasting to connection: {e}")
                failed.append(connection)

        for conn in failed:
            self.disconnect(conn, project_id)

# Module-level ConnectionManager instance; websocket_endpoint registers and
# unregisters its connections through it.
manager = ConnectionManager()

@websocket_router.websocket("/ws/{org_slug}/{project_id}")
async def websocket_endpoint(websocket: WebSocket, org_slug: str, project_id: str, token: str = None):
    """Authenticate a WebSocket client and keep it registered for a project.

    The connection is closed with code 1008 (policy violation) when ``token``
    is missing, when decoding it raises, or when the decoded payload has no
    ``user_id``. Otherwise the socket is registered with ``manager`` under
    ``project_id`` and the handler loops on incoming text frames, answering
    ``"ping"`` with ``"pong"`` and ignoring everything else.

    Args:
        websocket: The incoming connection.
        org_slug: Path parameter; not used by this handler.
        project_id: Path parameter; key under which the connection is registered.
        token: Access token; presumably supplied as a query parameter, since
            it is not part of the path — confirm against the client.

    NOTE(review): only the presence of ``user_id`` in the token is checked.
    Nothing here verifies that the user may access ``project_id`` or
    ``org_slug`` — confirm that authorization is enforced elsewhere.
    """
    if not token:
        await websocket.close(code=1008, reason="Token missing")
        return

    # Keep the try body limited to token decoding so that failures of
    # websocket.close() are not misreported as JWT decode errors.
    try:
        payload = decode_access_token(token)
        user_id = payload.get("user_id")
    except Exception as e:
        logger.error(f"WebSocket JWT Decode Error: {e}")
        await websocket.close(code=1008, reason="Invalid token")
        return

    if not user_id:
        await websocket.close(code=1008, reason="Invalid token")
        return

    await manager.connect(websocket, project_id)
    try:
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    except Exception as e:
        logger.error(f"WebSocket unexpected error: {e}")
    finally:
        # Runs on every exit path, including task cancellation (which is not
        # an ``Exception`` subclass), so the connection is never left behind
        # in the manager.
        manager.disconnect(websocket, project_id)
