import os
import threading
import time
from dataclasses import dataclass
from enum import Enum

import httpx

from .. import constants
from . import hf_raise_for_status, http_backoff, validate_hf_hub_args


XET_CONNECTION_INFO_SAFETY_PERIOD = 60  # seconds
XET_CONNECTION_INFO_CACHE_SIZE = 1_000
XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = {}


class XetTokenType(str, Enum):
    READ = "read"
    WRITE = "write"


@dataclass(frozen=True)
class XetFileData:
    file_hash: str
    refresh_route: str


@dataclass(frozen=True)
class XetConnectionInfo:
    access_token: str
    expiration_unix_epoch: int
    endpoint: str


def parse_xet_file_data_from_response(response: httpx.Response, endpoint: str | None = None) -> XetFileData | None:
    """
    Parse XET file metadata from an HTTP response.

    This function extracts XET file metadata from the HTTP headers or HTTP links
    of a given response object. If the required metadata is not found, it returns `None`.

    Args:
        response (`httpx.Response`):
            The HTTP response object containing headers dict and links dict to extract the XET metadata from.
    Returns:
        `Optional[XetFileData]`:
            An instance of `XetFileData` containing the file hash and refresh route if the metadata
            is found. Returns `None` if the required metadata is missing.
    """
    if response is None:
        return None
    try:
        file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]

        if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links:
            refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"]
        else:
            refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
    except KeyError:
        return None
    endpoint = endpoint if endpoint is not None else constants.ENDPOINT
    if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
        refresh_route = refresh_route.replace(constants.HUGGINGFACE_CO_URL_HOME.rstrip("/"), endpoint.rstrip("/"))
    return XetFileData(
        file_hash=file_hash,
        refresh_route=refresh_route,
    )


def parse_xet_connection_info_from_headers(headers: dict[str, str]) -> XetConnectionInfo | None:
    """
    Parse XET connection info from the HTTP headers or return None if not found.
    Args:
        headers (`dict`):
           HTTP headers to extract the XET metadata from.
    Returns:
        `XetConnectionInfo` or `None`:
            The information needed to connect to the XET storage service.
            Returns `None` if the headers do not contain the XET connection info.
    """
    try:
        endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
        access_token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
        expiration_unix_epoch = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
    except (KeyError, ValueError, TypeError):
        return None

    return XetConnectionInfo(
        endpoint=endpoint,
        access_token=access_token,
        expiration_unix_epoch=expiration_unix_epoch,
    )


@validate_hf_hub_args
def refresh_xet_connection_info(
    *,
    file_data: XetFileData,
    headers: dict[str, str],
) -> XetConnectionInfo:
    """
    Utilizes the information in the parsed metadata to request the Hub xet connection information.
    This includes the access token, expiration, and XET service URL.
    Args:
        file_data: (`XetFileData`):
            The file data needed to refresh the xet connection information.
        headers (`dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    if file_data.refresh_route is None:
        raise ValueError("The provided xet metadata does not contain a refresh endpoint.")
    return _fetch_xet_connection_info_with_url(file_data.refresh_route, headers)


@validate_hf_hub_args
def xet_connection_info_refresh_url(
    *,
    token_type: XetTokenType,
    repo_id: str,
    repo_type: str,
    revision: str | None = None,
    endpoint: str | None = None,
) -> str:
    """
    Build the URL used to fetch or refresh a Xet access token for a given repo.
    Args:
        token_type (`XetTokenType`):
            Type of the token to request: `"read"` or `"write"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        repo_type (`str`):
            Type of the repo (e.g. `"model"`, `"dataset"`, `"space"`, `"bucket"`).
        revision (`str`, `optional`):
            The revision of the repo to get the token for.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to the Hub endpoint.
    Returns:
        `str`:
            The fully-qualified URL of the token refresh endpoint.
    """
    endpoint = endpoint if endpoint is not None else constants.ENDPOINT
    url = f"{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token"
    if repo_type != "bucket" or revision is not None:
        # On "bucket" repo type, the revision never needed => don't use it
        # Otherwise, use the revision.
        # Note: when creating a PR on a git-based repo, user needs write access but they don't know the revision in advance.
        # => pass "/None" in URL and server will return a token for PR refs.
        url += f"/{revision}"
    return url


@validate_hf_hub_args
def fetch_xet_connection_info_from_repo_info(
    *,
    token_type: XetTokenType,
    repo_id: str,
    repo_type: str,
    revision: str | None = None,
    headers: dict[str, str],
    endpoint: str | None = None,
    params: dict[str, str] | None = None,
) -> XetConnectionInfo:
    """
    Uses the repo info to request a xet access token from Hub.
    Args:
        token_type (`XetTokenType`):
            Type of the token to request: `"read"` or `"write"`.
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated by a `/`.
        repo_type (`str`):
            Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
        revision (`str`, `optional`):
            The revision of the repo to get the token for.
        headers (`dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        endpoint (`str`, `optional`):
            The endpoint to use for the request. Defaults to the Hub endpoint.
        params (`dict[str, str]`, `optional`):
            Additional parameters to pass with the request.
    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    url = xet_connection_info_refresh_url(
        token_type=token_type,
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        endpoint=endpoint,
    )
    return _fetch_xet_connection_info_with_url(url, headers, params, cache_key_prefix=f"{repo_type}-{repo_id}")


@validate_hf_hub_args
def _fetch_xet_connection_info_with_url(
    url: str,
    headers: dict[str, str],
    params: dict[str, str] | None = None,
    cache_key_prefix: str | None = None,
) -> XetConnectionInfo:
    """
    Requests the xet connection info from the supplied URL. This includes the
    access token, expiration time, and endpoint to use for the xet storage service.

    Result is cached to avoid redundant requests.

    Args:
        url: (`str`):
            The access token endpoint URL.
        headers (`dict[str, str]`):
            Headers to use for the request, including authorization headers and user agent.
        params (`dict[str, str]`, `optional`):
            Additional parameters to pass with the request.
    Returns:
        `XetConnectionInfo`:
            The connection information needed to make the request to the xet storage service.
    Raises:
        [`~utils.HfHubHTTPError`]
            If the Hub API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the Hub API response is improperly formatted.
    """
    # Check cache first
    cache_key = _cache_key(url, headers, params, prefix=cache_key_prefix)
    cached_info = XET_CONNECTION_INFO_CACHE.get(cache_key)
    if cached_info is not None:
        if not _is_expired(cached_info):
            return cached_info

    # Fetch from server
    resp = http_backoff("GET", url, headers=headers, params=params)
    hf_raise_for_status(resp)

    metadata = parse_xet_connection_info_from_headers(resp.headers)  # type: ignore
    if metadata is None:
        raise ValueError("Xet headers have not been correctly set by the server.")

    # Delete expired cache entries
    for k, v in list(XET_CONNECTION_INFO_CACHE.items()):
        if _is_expired(v):
            XET_CONNECTION_INFO_CACHE.pop(k, None)

    # Enforce cache size limit
    if len(XET_CONNECTION_INFO_CACHE) >= XET_CONNECTION_INFO_CACHE_SIZE:
        XET_CONNECTION_INFO_CACHE.pop(next(iter(XET_CONNECTION_INFO_CACHE)))

    # Update cache
    XET_CONNECTION_INFO_CACHE[cache_key] = metadata

    return metadata


def reset_xet_connection_info_cache_for_repo(repo_type: str | None, repo_id: str) -> None:
    """Reset the XET connection info cache for the given repo type and repo id.

    Used when a repo is deleted.
    """
    if repo_type is None:
        repo_type = constants.REPO_TYPE_MODEL
    prefix = f"{repo_type}-{repo_id}|"
    for k in list(XET_CONNECTION_INFO_CACHE.keys()):
        if k.startswith(prefix):
            XET_CONNECTION_INFO_CACHE.pop(k, None)


def _cache_key(url: str, headers: dict[str, str], params: dict[str, str] | None, prefix: str | None = None) -> str:
    """Return a unique cache key for the given request parameters."""
    lower_headers = {k.lower(): v for k, v in headers.items()}  # casing is not guaranteed here
    auth_header = lower_headers.get("authorization", "")
    params_str = "&".join(f"{k}={v}" for k, v in sorted((params or {}).items(), key=lambda x: x[0]))
    return f"{prefix}|{url}|{auth_header}|{params_str}"


def _is_expired(connection_info: XetConnectionInfo) -> bool:
    """Check if the given XET connection info is expired."""
    return connection_info.expiration_unix_epoch <= int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD


class XetSessionHolder:
    """Holds an optional XetSession; supports safe re-creation after sigint_abort or fork.

    Thread-safe: a ``threading.Lock`` guards all state mutations, which matters
    for free-threaded Python (3.14t) where multiple threads can race on ``get()``
    or ``sigint_abort()`` without the GIL serialising them.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._session = None
        self._session_pid: int | None = None

    def get(self):
        """Return the current session, creating one if needed.

        Fork-safe: if the current process PID differs from the PID that created
        the session (i.e. we are in a forked child), the old session is discarded
        and a fresh session is created for this process.
        """
        with self._lock:
            current_pid = os.getpid()

            if self._session is not None and self._session_pid != current_pid:
                # Fork detected. Discard the parent's session; the Rust Drop will
                # call discard_runtime() (std::mem::forget) rather than the normal
                # shutdown path, so this returns immediately without blocking.
                self._session = None

            if self._session is None:
                from hf_xet import XetSession

                self._session = XetSession()
                self._session_pid = current_pid

            return self._session

    def sigint_abort(self):
        """Abort the current session and clear it so the next get() creates a fresh one."""
        with self._lock:
            if self._session is not None:
                try:
                    self._session.sigint_abort()
                except Exception:
                    pass
                self._session = None
                self._session_pid = None


_GLOBAL_XET_HOLDER = XetSessionHolder()


def get_xet_session():
    """Return the global :class:`hf_xet.XetSession`, creating it on first call.

    The session is shared across all calls within a process, just as the HTTP
    client returned by :func:`~huggingface_hub.utils._http.get_session` is shared.
    It is created lazily and is fork-safe and thread-safe.
    """
    return _GLOBAL_XET_HOLDER.get()


def xet_headers_without_auth(headers: dict[str, str]) -> dict[str, str]:
    """Return a copy of headers with the authorization header removed.

    Xet storage requests use a short-lived xet access token for auth, so the
    Hub authorization header must not be forwarded to xet storage endpoints.
    """
    return {key: value for key, value in headers.items() if key.lower() != "authorization"}


def abort_xet_session():
    """Abort the global xet session after a KeyboardInterrupt.

    Cancels any in-flight Rust operation and clears the session so the next
    call to :func:`get_xet_session` starts fresh (notebook-friendly).
    """
    _GLOBAL_XET_HOLDER.sigint_abort()
