import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse

import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError

import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)


def remove_request_params(url: str) -> str:
    """Return `url` without its path parameters, query string and fragment.

    Useful for producing a version of a URL that is safe to log, since query
    strings may carry credentials or session identifiers.

    The URL is rebuilt from its parsed components rather than joined against
    its own path, so the query is also dropped when the path is empty
    (e.g. ``http://host?token=x``) and a path beginning with ``//`` is not
    misinterpreted as a network location.

    Args:
        url: The URL to sanitize.

    Returns:
        The URL reduced to scheme, network location and path.
    """
    return urlparse(url)._replace(params="", query="", fragment="").geturl()


def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
    query_params = parse_qs(urlparse(endpoint_url).query)
    return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]


@asynccontextmanager
async def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5,
    sse_read_timeout: float = 60 * 5,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
    on_session_created: Callable[[str], None] | None = None,
):
    """
    Client transport for SSE.

    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: The SSE endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations.
        sse_read_timeout: Timeout for SSE read operations.
        auth: Optional HTTPX authentication handler.
        on_session_created: Optional callback invoked with the session ID when received.
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
            async with httpx_client_factory(
                headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
            ) as client:
                async with aconnect_sse(
                    client,
                    "GET",
                    url,
                ) as event_source:
                    event_source.response.raise_for_status()
                    logger.debug("SSE connection established")

                    async def sse_reader(
                        task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
                    ):
                        try:
                            async for sse in event_source.aiter_sse():  # pragma: no branch
                                logger.debug(f"Received SSE event: {sse.event}")
                                match sse.event:
                                    case "endpoint":
                                        endpoint_url = urljoin(url, sse.data)
                                        logger.debug(f"Received endpoint URL: {endpoint_url}")

                                        url_parsed = urlparse(url)
                                        endpoint_parsed = urlparse(endpoint_url)
                                        if (  # pragma: no cover
                                            url_parsed.netloc != endpoint_parsed.netloc
                                            or url_parsed.scheme != endpoint_parsed.scheme
                                        ):
                                            error_msg = (  # pragma: no cover
                                                f"Endpoint origin does not match connection origin: {endpoint_url}"
                                            )
                                            logger.error(error_msg)  # pragma: no cover
                                            raise ValueError(error_msg)  # pragma: no cover

                                        if on_session_created:
                                            session_id = _extract_session_id_from_endpoint(endpoint_url)
                                            if session_id:
                                                on_session_created(session_id)

                                        task_status.started(endpoint_url)

                                    case "message":
                                        # Skip empty data (keep-alive pings)
                                        if not sse.data:
                                            continue
                                        try:
                                            message = types.JSONRPCMessage.model_validate_json(  # noqa: E501
                                                sse.data
                                            )
                                            logger.debug(f"Received server message: {message}")
                                        except Exception as exc:  # pragma: no cover
                                            logger.exception("Error parsing server message")  # pragma: no cover
                                            await read_stream_writer.send(exc)  # pragma: no cover
                                            continue  # pragma: no cover

                                        session_message = SessionMessage(message)
                                        await read_stream_writer.send(session_message)
                                    case _:  # pragma: no cover
                                        logger.warning(f"Unknown SSE event: {sse.event}")  # pragma: no cover
                        except SSEError as sse_exc:  # pragma: no cover
                            logger.exception("Encountered SSE exception")  # pragma: no cover
                            raise sse_exc  # pragma: no cover
                        except Exception as exc:  # pragma: no cover
                            logger.exception("Error in sse_reader")  # pragma: no cover
                            await read_stream_writer.send(exc)  # pragma: no cover
                        finally:
                            await read_stream_writer.aclose()

                    async def post_writer(endpoint_url: str):
                        try:
                            async with write_stream_reader:
                                async for session_message in write_stream_reader:
                                    logger.debug(f"Sending client message: {session_message}")
                                    response = await client.post(
                                        endpoint_url,
                                        json=session_message.message.model_dump(
                                            by_alias=True,
                                            mode="json",
                                            exclude_none=True,
                                        ),
                                    )
                                    response.raise_for_status()
                                    logger.debug(f"Client message sent successfully: {response.status_code}")
                        except Exception:  # pragma: no cover
                            logger.exception("Error in post_writer")  # pragma: no cover
                        finally:
                            await write_stream.aclose()

                    endpoint_url = await tg.start(sse_reader)
                    logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
                    tg.start_soon(post_writer, endpoint_url)

                    try:
                        yield read_stream, write_stream
                    finally:
                        tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
