from typing import Any
import json
import logging
import time
from threading import Thread

from posthog.request import (
    AI_EVENTS_ENDPOINT,
    EVENTS_ENDPOINT,
    APIError,
    DatetimeSerializer,
    batch_post,
    is_ai_event,
)

try:
    from queue import Empty
except ImportError:
    from Queue import Empty


# Per-event size ceiling, measured as the length of the event's UTF-8 encoded
# JSON serialization. `Consumer.next()` logs and drops any larger event.
MAX_MSG_SIZE = 900 * 1024  # 900KiB per event

# `$ai_*` events carry LLM inputs/outputs and, when routed to the dedicated AI
# endpoint, hit a pipeline that accepts larger messages than analytics ingestion,
# so they get a higher per-event ceiling when that routing is enabled
# (i.e. only when the consumer is created with `dedicated_ai_endpoint=True`).
AI_MAX_MSG_SIZE = 8 * 1024 * 1024  # 8MiB per `$ai_*` event

# Soft cap on the summed serialized size of one batch: `Consumer.next()` stops
# collecting once the running total reaches it, so a batch can overshoot by up
# to one event (a single large `$ai_*` event can exceed it on its own).
# The maximum request body size is currently 20MiB; we stay well below it to be
# conservative in case that limit is lowered in the future.
BATCH_SIZE_LIMIT = 5 * 1024 * 1024


class Consumer(Thread):
    """Background thread that drains the client's queue and uploads batches.

    Items are taken from ``queue`` in batches bounded by ``flush_at`` items,
    ``flush_interval`` seconds and ``BATCH_SIZE_LIMIT`` bytes, then posted with
    ``batch_post``. With ``dedicated_ai_endpoint`` enabled, ``$ai_*`` events are
    sent to ``AI_EVENTS_ENDPOINT`` (with a larger per-event size limit) and all
    other events to ``EVENTS_ENDPOINT``.

    Every item taken off the queue is acknowledged with ``queue.task_done()``
    exactly once, whether it was uploaded, failed to upload, or was dropped.
    """

    log = logging.getLogger("posthog")

    def __init__(
        self,
        queue,
        api_key,
        flush_at=100,
        host=None,
        on_error=None,
        flush_interval=0.5,
        gzip=False,
        retries=10,
        timeout=15,
        historical_migration=False,
        dedicated_ai_endpoint=False,
    ):
        """Create a consumer thread.

        Args:
            queue: Source of events; must support ``get(block=..., timeout=...)``
                raising ``Empty`` and ``task_done()`` (e.g. a ``queue.Queue``).
            api_key: Forwarded to ``batch_post``.
            flush_at: Maximum number of items per batch.
            host: Forwarded to ``batch_post``.
            on_error: Optional ``callable(exc, batch)`` invoked when a batch
                fails to upload.
            flush_interval: Maximum number of seconds spent collecting one batch.
            gzip: Forwarded to ``batch_post``.
            retries: Extra attempts per endpoint after a retryable failure.
            timeout: Forwarded to ``batch_post`` as ``timeout``.
            historical_migration: Forwarded to ``batch_post``.
            dedicated_ai_endpoint: Route ``$ai_*`` events to their own endpoint
                and allow them up to ``AI_MAX_MSG_SIZE`` bytes each.
        """
        Thread.__init__(self)
        # Make consumer a daemon thread so that it doesn't block program exit
        self.daemon = True
        self.flush_at = flush_at
        self.flush_interval = flush_interval
        self.api_key = api_key
        self.host = host
        self.on_error = on_error
        self.queue = queue
        self.gzip = gzip
        self.dedicated_ai_endpoint = dedicated_ai_endpoint
        # It's important to set running in the constructor: if we are asked to
        # pause immediately after construction, we might set running to True in
        # run() *after* we set it to False in pause... and keep running
        # forever.
        self.running = True
        self.retries = retries
        self.timeout = timeout
        self.historical_migration = historical_migration

    def run(self):
        """Upload batches in a loop until ``pause()`` is called."""
        self.log.debug("consumer is running...")
        while self.running:
            self.upload()

        self.log.debug("consumer exited.")

    def pause(self):
        """Stop the consumer after its current upload; does not interrupt it."""
        self.running = False

    def upload(self):
        """Upload the next batch of items and return whether it succeeded.

        Returns False when no items were collected. A failed upload is logged
        and forwarded to ``on_error(exc, batch)``; an exception raised by the
        handler itself is logged and swallowed so the consumer keeps running.
        Every item in the batch is acknowledged on the queue regardless of outcome.
        """
        batch = self.next()
        if not batch:
            return False

        success = False
        try:
            self.request(batch)
            success = True
        except Exception as e:
            self.log.error("error uploading: %s", e)
            if self.on_error:
                try:
                    self.on_error(e, batch)
                except Exception as handler_exc:
                    self.log.error("on_error handler failed: %s", handler_exc)
        finally:
            # mark items as acknowledged from queue
            for _ in batch:
                self.queue.task_done()

        return success

    def next(self):
        """Collect the next batch of items to upload.

        Waits at most ``flush_interval`` seconds in total and stops early once
        ``flush_at`` items are collected, the summed serialized size reaches
        ``BATCH_SIZE_LIMIT``, or no item arrives before the deadline.

        Items that cannot be serialized to JSON, or whose serialized size
        exceeds their per-event limit, are logged, acknowledged with
        ``task_done()`` and dropped.
        """
        queue = self.queue
        items: list[Any] = []

        start_time = time.monotonic()
        total_size = 0

        while len(items) < self.flush_at:
            elapsed = time.monotonic() - start_time
            if elapsed >= self.flush_interval:
                break
            try:
                item = queue.get(block=True, timeout=self.flush_interval - elapsed)
            except Empty:
                break

            try:
                item_size = len(json.dumps(item, cls=DatetimeSerializer).encode())
            except (TypeError, ValueError, RecursionError) as e:
                # Letting this propagate would end run() and kill the consumer
                # thread, leaving this item and everything already collected
                # un-acknowledged (so queue.join() would never return).
                self.log.error("Item could not be serialized, dropping: %s", e)
                queue.task_done()
                continue

            max_msg_size = self._max_msg_size(item)
            if item_size > max_msg_size:
                self.log.error(
                    "Item exceeds %dKiB limit, dropping. (%s)",
                    max_msg_size // 1024,
                    str(item),
                )
                queue.task_done()
                continue

            items.append(item)
            total_size += item_size
            if total_size >= BATCH_SIZE_LIMIT:
                self.log.debug("hit batch size limit (size: %d)", total_size)
                break

        return items

    def request(self, batch):
        """Upload the batch, routing `$ai_*` events to their own endpoint when enabled.

        Each destination is attempted independently so a failure on one does not
        skip the other. The first failure is raised (so `upload()` logs it and
        invokes `on_error`); a second is logged here so it isn't silently lost.
        The batch was already dequeued in `upload()`, so unsent events are dropped
        after retries, same as the single-endpoint path.
        """
        if not self.dedicated_ai_endpoint:
            self._send(batch, EVENTS_ENDPOINT)
            return

        ai_events: list[Any] = []
        analytics_events: list[Any] = []
        for item in batch:
            target = ai_events if is_ai_event(item.get("event")) else analytics_events
            target.append(item)

        first_exc = None
        for events, path in (
            (analytics_events, EVENTS_ENDPOINT),
            (ai_events, AI_EVENTS_ENDPOINT),
        ):
            if not events:
                continue
            try:
                self._send(events, path)
            except Exception as e:
                if first_exc is None:
                    first_exc = e
                else:
                    self.log.error("error uploading to %s: %s", path, e)

        if first_exc is not None:
            raise first_exc

    def _send(self, batch, path):
        """Upload ``batch`` to ``path``, retrying retryable failures.

        Makes one attempt plus up to ``self.retries`` retries. Between attempts
        it sleeps for the exception's ``retry_after`` when that is positive,
        otherwise for ``min(2**attempt, 30)`` seconds. Non-retryable errors are
        raised immediately; otherwise the last error is raised once retries are
        exhausted.
        """

        def is_retryable(exc):
            if isinstance(exc, APIError):
                # retry on server errors and client errors
                # with 408 (request timeout) or 429 (rate limited),
                # don't retry on other client errors
                if exc.status == "N/A":
                    return False
                return not ((400 <= exc.status < 500) and exc.status not in (408, 429))
            else:
                # retry on all other errors (eg. network)
                return True

        last_exc = None
        # Always make at least one attempt: with a negative `retries` the range
        # would otherwise be empty and the batch silently dropped while the
        # caller believes the upload succeeded.
        for attempt in range(max(self.retries, 0) + 1):
            try:
                batch_post(
                    self.api_key,
                    self.host,
                    gzip=self.gzip,
                    timeout=self.timeout,
                    batch=batch,
                    historical_migration=self.historical_migration,
                    path=path,
                )
                return
            except Exception as e:
                last_exc = e
                if not is_retryable(e):
                    raise
                if attempt < self.retries:
                    # Respect Retry-After header if present, otherwise use exponential backoff
                    retry_after = getattr(e, "retry_after", None)
                    if retry_after and retry_after > 0:
                        time.sleep(retry_after)
                    else:
                        time.sleep(min(2**attempt, 30))

        if last_exc is not None:
            raise last_exc

    def _max_msg_size(self, item):
        """Return the per-event byte limit for ``item``.

        ``$ai_*`` events get ``AI_MAX_MSG_SIZE`` only when the dedicated AI
        endpoint is enabled; everything else gets ``MAX_MSG_SIZE``.
        """
        if self.dedicated_ai_endpoint and is_ai_event(item.get("event")):
            return AI_MAX_MSG_SIZE
        return MAX_MSG_SIZE
