"""Helpers for parsing docstrings. Used for helptext generation."""

import builtins
import collections.abc
import dataclasses
import functools
import inspect
import io
import itertools
import sys
import tokenize
from typing import Callable, Dict, Hashable, List, Optional, Set, Tuple, Type, TypeVar

from typing_extensions import get_origin, is_typeddict

from tyro._typing_compat import is_typing_generic

from . import _resolver, _strings, _unsafe_cache
from .conf import _markers

# Type variable bound to callables.
# NOTE(review): not referenced in this part of the file; it may be imported by
# other modules or be dead code -- confirm before removing.
T = TypeVar("T", bound=Callable)


@dataclasses.dataclass(frozen=True)
class _Token:
    """One token produced by `tokenize`, tagged with the line counters that
    `_ClassTokenization` uses to group tokens."""

    # A `tokenize` token-type constant (e.g. `tokenize.NAME`, `tokenize.COMMENT`).
    token_type: int
    # The token's source text.
    content: str
    # 1-based index of the logical line (lines delimited by `tokenize.NEWLINE`).
    logical_line: int
    # 1-based line index counted from NEWLINE and NL tokens. Physical lines
    # swallowed by multi-line strings or backslash continuations are not counted,
    # so this can differ from the real line number in the file.
    actual_line: int


@dataclasses.dataclass(frozen=True)
class _FieldData:
    """Location of a field-like `name: ...` declaration inside a class's tokens."""

    # Position of the field's name token in `_ClassTokenization.tokens`.
    index: int
    # Logical line of the field's name token.
    logical_line: int
    # Actual (NEWLINE/NL-counted) line of the field's name token.
    actual_line: int
    # Logical line of the previously detected field, or 1 for the first field.
    prev_field_logical_line: int


@dataclasses.dataclass(frozen=True)
class _ClassTokenization:
    """Tokenized source of a class, plus field and comment data derived from it.

    Instances are built by the cached `make()`. All line counters start at 1 at
    the beginning of `inspect.getsource(clz)`, not at the start of the file.
    """

    tokens: List[_Token]
    tokens_from_logical_line: Dict[int, List[_Token]]
    tokens_from_actual_line: Dict[int, List[_Token]]
    field_data_from_name: Dict[str, _FieldData]
    classdef_logical_line: int
    field_comments: Dict[str, str]  # Pre-computed comment for each field.

    @staticmethod
    @_unsafe_cache.unsafe_cache(64)
    def make(clz) -> "_ClassTokenization":
        """Parse the source code of a class, and cache some tokenization information.

        Args:
            clz: The class whose source should be tokenized.

        Returns:
            The tokenization, including the comment pre-computed for each field.

        Raises:
            Whatever `inspect.getsource` raises when the source is unavailable
            (e.g. `OSError`, or `TypeError` for built-in classes); callers are
            expected to handle these.
        """
        (
            tokens,
            tokens_from_logical_line,
            tokens_from_actual_line,
            classdef_logical_line,
        ) = _ClassTokenization._tokenize_source(inspect.getsource(clz))

        field_data_from_name = _ClassTokenization._find_fields(
            tokens, tokens_from_logical_line
        )

        # Skip the comment pass entirely when the source contains no comments.
        field_comments: Dict[str, str] = {}
        if any(token.token_type == tokenize.COMMENT for token in tokens):
            field_comments = _ClassTokenization._collect_field_comments(
                tokens_from_actual_line, field_data_from_name, classdef_logical_line
            )

        return _ClassTokenization(
            tokens=tokens,
            tokens_from_logical_line=tokens_from_logical_line,
            tokens_from_actual_line=tokens_from_actual_line,
            field_data_from_name=field_data_from_name,
            classdef_logical_line=classdef_logical_line,
            field_comments=field_comments,
        )

    @staticmethod
    def _tokenize_source(
        source: str,
    ) -> Tuple[
        List[_Token], Dict[int, List[_Token]], Dict[int, List[_Token]], int
    ]:
        """Tokenize `source`, grouping tokens by logical line and by actual line.

        Returns:
            `(tokens, tokens_from_logical_line, tokens_from_actual_line,
            classdef_logical_line)`. `classdef_logical_line` is the logical line
            of the first `class` keyword, or -1 if there is none.
        """
        readline = io.BytesIO(source.encode("utf-8")).readline

        tokens: List[_Token] = []
        tokens_from_logical_line: Dict[int, List[_Token]] = {1: []}
        tokens_from_actual_line: Dict[int, List[_Token]] = {1: []}
        classdef_logical_line: int = -1

        logical_line: int = 1
        actual_line: int = 1
        for toktype, tok, _start, _end, _line in tokenize.tokenize(readline):
            # Logical lines are delimited by `tokenize.NEWLINE`. `tokenize.NL`
            # marks a line break that does not end a logical line (blank lines,
            # comment-only lines, breaks inside brackets), so it only advances
            # the actual-line counter.
            if toktype == tokenize.NEWLINE:
                logical_line += 1
                actual_line += 1
                tokens_from_logical_line[logical_line] = []
                tokens_from_actual_line[actual_line] = []
            elif toktype == tokenize.NL:
                actual_line += 1
                tokens_from_actual_line[actual_line] = []
            elif toktype != tokenize.INDENT:
                # DEDENT tokens are kept; `_find_fields` treats them as
                # insignificant when deciding whether a name starts its line.
                token = _Token(
                    token_type=toktype,
                    content=tok,
                    logical_line=logical_line,
                    actual_line=actual_line,
                )
                tokens.append(token)
                tokens_from_logical_line[logical_line].append(token)
                tokens_from_actual_line[actual_line].append(token)

                # Track if we've seen the class definition.
                if (
                    toktype == tokenize.NAME
                    and tok == "class"
                    and classdef_logical_line == -1
                ):
                    classdef_logical_line = logical_line

        return (
            tokens,
            tokens_from_logical_line,
            tokens_from_actual_line,
            classdef_logical_line,
        )

    @staticmethod
    def _find_fields(
        tokens: List[_Token], tokens_from_logical_line: Dict[int, List[_Token]]
    ) -> Dict[str, _FieldData]:
        """Find field-like `name: ...` declarations with a naive heuristic.

        A NAME token counts as a field when it is the first significant token
        of its logical line and is immediately followed by `:`. Only the first
        occurrence of each name is recorded.
        """
        # Token types allowed to precede a field name on its logical line.
        # Comment-only lines end in NL rather than NEWLINE, so they share the
        # following field's logical line. DEDENT tokens precede the first name
        # after an indented block, e.g. a field declared after a method.
        insignificant_types = (tokenize.COMMENT, tokenize.DEDENT)

        field_data_from_name: Dict[str, _FieldData] = {}
        prev_field_logical_line: int = 1
        for i, token in enumerate(tokens[:-1]):
            if token.token_type != tokenize.NAME:
                continue

            is_first_token = True
            for t in tokens_from_logical_line[token.logical_line]:
                # Compare by identity: `_Token` has value equality, so an
                # identical earlier token on the same line would be mistaken for
                # `token` itself. For example, in `x = f(key=lambda x: x)` the
                # lambda's `x` would be treated as a field named `x`.
                if t is token:
                    break
                if t.token_type not in insignificant_types:
                    is_first_token = False
                    break

            if not is_first_token:
                continue

            if (
                tokens[i + 1].content == ":"
                and token.content not in field_data_from_name
            ):
                field_data_from_name[token.content] = _FieldData(
                    index=i,
                    logical_line=token.logical_line,
                    actual_line=token.actual_line,
                    prev_field_logical_line=prev_field_logical_line,
                )
                prev_field_logical_line = token.logical_line

        return field_data_from_name

    @staticmethod
    def _collect_field_comments(
        tokens_from_actual_line: Dict[int, List[_Token]],
        field_data_from_name: Dict[str, _FieldData],
        classdef_logical_line: int,
    ) -> Dict[str, str]:
        """Associate comments with fields in a single forward pass over lines.

        An inline comment on a field's line takes precedence. Otherwise, the run
        of comment-only lines directly above the field is used. Any other line
        (blank or code) in between discards buffered comments.
        """
        field_comments: Dict[str, str] = {}

        # Build reverse mapping: actual_line -> field_name.
        line_to_field: Dict[int, str] = {
            field_data.actual_line: field_name
            for field_name, field_data in field_data_from_name.items()
        }

        comment_buffer: List[Tuple[str, bool]] = []  # (comment_text, is_sphinx)
        sorted_actual_lines = sorted(tokens_from_actual_line.keys())

        for line_idx, actual_line in enumerate(sorted_actual_lines):
            line_tokens = tokens_from_actual_line[actual_line]

            # Check if this line has a field.
            if actual_line in line_to_field:
                field_name = line_to_field[actual_line]

                # Check for inline comment on the same line as the field.
                if (
                    len(line_tokens) > 0
                    and line_tokens[-1].token_type == tokenize.COMMENT
                ):
                    comment_text = line_tokens[-1].content
                    assert comment_text.startswith("#")
                    if comment_text.startswith("#:"):
                        field_comments[field_name] = _strings.remove_single_line_breaks(
                            comment_text[2:].strip()
                        )
                    else:
                        field_comments[field_name] = _strings.remove_single_line_breaks(
                            comment_text[1:].strip()
                        )
                    # Inline comments always clear the buffer.
                    comment_buffer = []
                # Otherwise, assign buffered comments if any.
                elif len(comment_buffer) > 0:
                    # Sphinx-style comments only apply if directly above.
                    has_sphinx = any(is_sphinx for _, is_sphinx in comment_buffer)
                    field_comments[field_name] = _strings.remove_single_line_breaks(
                        "\n".join(text for text, _ in comment_buffer)
                    )

                    # After assigning comments, decide whether to keep buffer:
                    # - Sphinx comments: always clear (apply to one field only).
                    # - Non-Sphinx comments: keep buffer if next line is also a field
                    #   (for grouped comments like "Description of both y and z").
                    # - Otherwise: clear buffer (prevents comments from applying to
                    #   non-adjacent fields).
                    next_line_is_field = (
                        line_idx + 1 < len(sorted_actual_lines)
                        and sorted_actual_lines[line_idx + 1] in line_to_field
                    )
                    if has_sphinx or not next_line_is_field:
                        comment_buffer = []

            # Track comments for the buffer (only those inside the class body).
            elif (
                len(line_tokens) == 1
                and line_tokens[0].token_type == tokenize.COMMENT
                and line_tokens[0].logical_line > classdef_logical_line
            ):
                comment_text = line_tokens[0].content
                assert comment_text.startswith("#")
                is_sphinx = comment_text.startswith("#:")
                if is_sphinx:
                    comment_buffer.append((comment_text[2:].strip(), True))
                else:
                    comment_buffer.append((comment_text[1:].strip(), False))

            # Empty line or non-comment, non-field line: clear buffer.
            # This prevents comments above methods, assignments, or other
            # non-field code from leaking to subsequent fields.
            else:
                comment_buffer = []

        return field_comments


@_unsafe_cache.unsafe_cache(1024)
def get_class_tokenization_with_field(
    cls: Type, field_name: str
) -> Optional[_ClassTokenization]:
    """Find the tokenization of the class in `cls.__mro__` that declares `field_name`.

    Classes are visited in MRO order and the first one whose tokenization lists
    `field_name` is returned. `None` is returned as soon as a class's source
    cannot be read. If the MRO is exhausted without a match, the last
    tokenization built is returned; for dataclasses this trips an assertion
    instead.

    NOTE(review): every MRO ends with `object`, for which `inspect.getsource`
    raises a "built-in class" `TypeError`. In practice a missing field therefore
    returns `None` before the exhausted-MRO branch is reached -- confirm.
    """
    tokenization: Optional[_ClassTokenization] = None
    for candidate in cls.__mro__:
        # Inherited generics seem challenging for now.
        # https://github.com/python/typing/issues/777
        assert is_typing_generic(candidate) or get_origin(candidate) is None

        try:
            tokenization = _ClassTokenization.make(candidate)  # type: ignore
        except OSError:
            # The source is unreadable: "could not find class definition"
            # (dynamic dataclasses), "source code not available" (pydantic), or
            # "could not get source code". Treat this as "no docstring".
            return None
        except TypeError as e:  # pragma: no cover
            # Notebooks cause “___ is a built-in class” TypeError.
            assert "built-in class" in e.args[0]
            return None

        if field_name in tokenization.field_data_from_name:
            break
    else:
        # Walked the whole MRO without finding the field.
        assert not dataclasses.is_dataclass(cls), (
            "Docstring parsing error -- this usually means that there are multiple"
            " dataclasses in the same file with the same name but different scopes."
        )

    return tokenization


@functools.lru_cache(maxsize=1024)
def parse_docstring_from_object(obj: object) -> Dict[str, str]:
    """Map each documented parameter of `obj` to its description.

    Parsing is delegated to `docstring_parser.parse_from_object`. Parameters
    whose description is `None` are left out. Results are memoized with
    `functools.lru_cache`, so `obj` must be hashable.
    """
    import docstring_parser

    descriptions: Dict[str, str] = {}
    for param in docstring_parser.parse_from_object(obj).params:
        if param.description is None:
            continue
        descriptions[param.arg_name] = param.description
    return descriptions


@_unsafe_cache.unsafe_cache(1024)
def get_field_docstring(
    cls: Type, field_name: str, markers: Tuple[_markers.Marker, ...]
) -> Optional[str]:
    """Get docstring for a field in a class.

    Two sources are consulted, in order:

    1. Parameter descriptions parsed from the docstring of each class in
       `cls.__mro__`, skipping classes defined in `builtins`.
    2. Unless `HelptextFromCommentsOff` is among `markers`, source comments
       attached to the field in the class body.

    Returns `None` when neither source yields anything.
    """
    # NoneType will break docstring_parser.
    if cls is type(None):
        return None

    # Skip `object`, `Callable`, `tuple`, etc.
    non_builtin_classes = (
        klass for klass in cls.__mro__ if klass.__module__ != "builtins"
    )
    for klass in non_builtin_classes:
        description = parse_docstring_from_object(klass).get(field_name)
        if description is None:
            continue
        unwrapped = _strings.remove_single_line_breaks(description)
        return _strings.dedent(unwrapped).strip()

    comments_disabled = _markers.HelptextFromCommentsOff in markers
    if comments_disabled:
        return None

    # Fall back to comments, which are pre-computed during tokenization. A
    # `None` tokenization means the source could not be read (currently only
    # happens for dynamic dataclasses).
    tokenization = get_class_tokenization_with_field(cls, field_name)
    if tokenization is None:
        return None
    return tokenization.field_comments.get(field_name)


# Objects whose docstrings are never used as descriptions: every hashable value
# exported by `builtins` and `collections.abc` (e.g. `int`, `len`, `Sequence`).
# Unhashable values (such as lists) are filtered out so the set can be built.
_callable_description_blocklist: Set[Hashable] = {
    candidate
    for candidate in itertools.chain(
        vars(builtins).values(), vars(collections.abc).values()
    )
    if isinstance(candidate, Hashable)
}


@_unsafe_cache.unsafe_cache(1024)
def get_callable_description(f: Callable) -> str:
    """Get description associated with a callable via docstring parsing.

    `dataclasses.dataclass` will automatically populate __doc__ based on the
    fields of the class if a docstring is not specified; this helper will
    ignore these docstrings.

    Args:
        f: The callable to describe (function, class, `functools.partial`, or
            a generic alias, which is resolved first).

    Returns:
        The docstring's short and long descriptions joined by a newline, or
        `""` when there is no usable docstring (including builtins and
        `collections.abc` members).
    """
    f, _ = _resolver.resolve_generic_types(f)
    f = _resolver.unwrap_origin_strip_extras(f)

    # Return original docstring when used with functools.partial, not
    # functools.partial's docstring. This happens *before* the blocklist check
    # so that e.g. `functools.partial(int, base=2)` is filtered like `int`.
    # `functools` only flattens nested plain partials, not subclasses, hence
    # the loop.
    while isinstance(f, functools.partial):
        f = f.func

    try:
        if f in _callable_description_blocklist:
            return ""
    except TypeError:
        # Unhashable callables (e.g. instances of classes that define `__eq__`
        # without `__hash__`) cannot be members of the blocklist.
        pass

    if "pydantic" in sys.modules:
        try:
            import pydantic
        except ImportError:
            # Needed for mock import test.
            pydantic = None  # type: ignore
    else:
        pydantic = None  # type: ignore

    # Note inspect.getdoc() causes some corner cases with TypedDicts.
    docstring = f.__doc__
    if (
        docstring is None
        and inspect.isclass(f)
        # Ignore TypedDict's __init__ docstring, because it will just be `dict`
        and not is_typeddict(f)
        # Ignore NamedTuple __init__ docstring.
        and not _resolver.is_namedtuple(f)
        # Ignore pydantic base model constructor docstring.
        and not (pydantic is not None and f.__init__ is pydantic.BaseModel.__init__)  # type: ignore
    ):
        docstring = f.__init__.__doc__  # type: ignore
    if docstring is None:
        return ""

    docstring = _strings.dedent(docstring)

    if dataclasses.is_dataclass(f):
        # Rebuild the default `__doc__` that `dataclasses` generates when no
        # docstring is given. Like `dataclasses`, fall back to an empty
        # signature when `inspect.signature` cannot compute one, rather than
        # crashing.
        try:
            text_sig = str(inspect.signature(f)).replace(" -> None", "")  # type: ignore
        except (TypeError, ValueError):
            text_sig = ""
        default_doc = f.__name__ + text_sig  # type: ignore
        if docstring == default_doc:
            return ""

    import docstring_parser

    parsed_docstring = docstring_parser.parse(docstring)

    parts: List[str] = []
    if parsed_docstring.short_description is not None:
        parts.append(parsed_docstring.short_description)
    if parsed_docstring.long_description is not None:
        parts.append(parsed_docstring.long_description)
    return "\n".join(parts)
