from __future__ import annotations

import operator
from io import BytesIO
from typing import TYPE_CHECKING, Any, cast

import ibis
import ibis.expr.types as ir

from narwhals._ibis.expr import IbisExpr
from narwhals._ibis.utils import evaluate_exprs, lit, native_to_narwhals_dtype
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    Version,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    to_pyarrow_table,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from pathlib import Path
    from types import ModuleType
    from typing import TypeAlias

    import pandas as pd
    import pyarrow as pa
    from ibis.expr.operations import Binary
    from typing_extensions import Self, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._ibis.group_by import IbisGroupBy
    from narwhals._ibis.namespace import IbisNamespace
    from narwhals._ibis.series import IbisInterchangeSeries
    from narwhals._typing import _EagerAllowedImpl
    from narwhals._utils import _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.stable.v1 import DataFrame as DataFrameV1
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, UniqueKeepStrategy

    JoinPredicates: TypeAlias = "Sequence[ir.BooleanColumn] | Sequence[str]"


class IbisLazyFrame(
    SQLLazyFrame["IbisExpr", "ir.Table", "LazyFrame[ir.Table] | DataFrameV1[ir.Table]"],
    ValidateBackendVersion,
):
    """Narwhals compliant lazy frame that wraps a native Ibis table (`ir.Table`)."""

    # Marks this compliant frame as belonging to the Ibis backend.
    _implementation = Implementation.IBIS

    def __init__(
        self, df: ir.Table, *, version: Version, validate_backend_version: bool = False
    ) -> None:
        self._native_frame: ir.Table = df
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: ir.Table | Any) -> TypeIs[ir.Table]:
        """Return whether `obj` is a native Ibis table."""
        is_table = isinstance(obj, ir.Table)
        return is_table

    @classmethod
    def from_native(cls, data: ir.Table, /, *, context: _LimitedContext) -> Self:
        """Build a frame around `data`, taking the Narwhals version from `context`."""
        version = context._version
        return cls(data, version=version)

    def to_narwhals(self) -> LazyFrame[ir.Table] | DataFrameV1[ir.Table]:
        if self._version is Version.V1:
            from narwhals.stable.v1 import DataFrame

            return DataFrame(self, level="interchange")
        return self._version.lazyframe(self, level="lazy")

    def __narwhals_dataframe__(self) -> Self:  # pragma: no cover
        # Keep around for backcompat.
        if self._version is not Version.V1:
            msg = "__narwhals_dataframe__ is not implemented for IbisLazyFrame"
            raise AttributeError(msg)
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        return ibis

    def __narwhals_namespace__(self) -> IbisNamespace:
        from narwhals._ibis.namespace import IbisNamespace

        return IbisNamespace(version=self._version)

    def get_column(self, name: str) -> IbisInterchangeSeries:
        from narwhals._ibis.series import IbisInterchangeSeries

        return IbisInterchangeSeries(self.native.select(name), version=self._version)

    def _iter_columns(self) -> Iterator[ir.Expr]:
        for name in self.columns:
            yield self.native[name]

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is None or backend is Implementation.PYARROW:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                to_pyarrow_table(self.native.to_pyarrow()),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                self.native.to_polars(),
                validate_backend_version=True,
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n))

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: IbisExpr) -> Self:
        selection = [
            cast("ir.Scalar", val.name(name))
            for name, val in evaluate_exprs(self, *exprs)
        ]
        return self._with_native(self.native.aggregate(selection))

    def select(self, *exprs: IbisExpr) -> Self:
        selection = [val.name(name) for name, val in evaluate_exprs(self, *exprs)]
        if not selection:
            msg = "At least one expression must be provided to `select` with the Ibis backend."
            raise ValueError(msg)

        t = self.native.select(*selection)
        return self._with_native(t)

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        selection = (col for col in self.columns if col not in columns_to_drop)
        return self._with_native(self.native.select(*selection))

    def lazy(self, backend: None = None, **_: None) -> Self:
        # The `backend`` argument has no effect but we keep it here for
        # backwards compatibility because in `narwhals.stable.v1`
        # function `.from_native()` will return a DataFrame for Ibis.

        if backend is not None:  # pragma: no cover
            msg = "`backend` argument is not supported for Ibis"
            raise ValueError(msg)
        return self

    def with_columns(self, *exprs: IbisExpr) -> Self:
        new_columns_map = dict(evaluate_exprs(self, *exprs))
        return self._with_native(self.native.mutate(**new_columns_map))

    def filter(self, predicate: IbisExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = cast("ir.BooleanValue", predicate(self)[0])
        return self._with_native(self.native.filter(mask))

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            # Note: prefer `self._cached_schema` over `functools.cached_property`
            # due to Python3.13 failures.
            self._cached_schema = {
                name: native_to_narwhals_dtype(dtype, self._version)
                for name, dtype in self.native.schema().fields.items()
            }
        return self._cached_schema

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else list(self.native.columns)
            )
        return self._cached_columns

    def to_pandas(self) -> pd.DataFrame:
        # only if version is v1, keep around for backcompat
        return self.native.to_pandas()

    def to_arrow(self) -> pa.Table:
        # only if version is v1, keep around for backcompat
        return self.native.to_pyarrow()

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: ir.Table) -> Self:
        return self.__class__(df, version=self._version)

    def group_by(
        self, keys: Sequence[str] | Sequence[IbisExpr], *, drop_null_keys: bool
    ) -> IbisGroupBy:
        from narwhals._ibis.group_by import IbisGroupBy

        return IbisGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        def _rename(col: str) -> str:
            return mapping.get(col, col)

        return self._with_native(self.native.rename(_rename))

    @staticmethod
    def _join_drop_duplicate_columns(df: ir.Table, columns: Iterable[str], /) -> ir.Table:
        """Drop from `df` whichever of `columns` are actually present.

        Ibis adds a suffix to the right table col, even when it matches the
        left during a join; this helper removes those suffixed leftovers.
        `df` is returned as-is when none of `columns` exist in it.
        """
        existing = set(df.columns)
        duplicates = existing.intersection(columns)
        if not duplicates:
            return df
        return df.drop(*duplicates)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        how_native = "outer" if how == "full" else how
        rname = "{name}" + suffix
        if other == self:
            # Ibis does not support self-references unless created as a view
            other = self._with_native(other.native.view())
        if how_native == "cross":
            joined = self.native.join(other.native, how=how_native, rname=rname)
            return self._with_native(joined)
        # help mypy
        assert left_on is not None  # noqa: S101
        assert right_on is not None  # noqa: S101
        predicates = self._convert_predicates(other, left_on, right_on)
        joined = self.native.join(other.native, predicates, how=how_native, rname=rname)
        if how_native == "left":
            right_names = (n + suffix for n in right_on)
            joined = self._join_drop_duplicate_columns(joined, right_names)
            it = (cast("Binary", p.op()) for p in predicates if not isinstance(p, str))
            to_drop = []
            for pred in it:
                right = pred.right.name
                # Mirrors how polars works.
                if right not in self.columns and pred.left.name != right:
                    to_drop.append(right)
            if to_drop:
                joined = joined.drop(*to_drop)
        return self._with_native(joined)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        rname = "{name}" + suffix
        strategy_op = {"backward": operator.ge, "forward": operator.le}
        predicates: JoinPredicates = []
        if op := strategy_op.get(strategy):
            on: ir.BooleanColumn = op(self.native[left_on], other.native[right_on])
        else:
            msg = "Only `backward` and `forward` strategies are currently supported for Ibis"
            raise NotImplementedError(msg)
        if by_left is not None and by_right is not None:
            predicates = self._convert_predicates(other, by_left, by_right)
        joined = self.native.asof_join(other.native, on, predicates, rname=rname)
        joined = self._join_drop_duplicate_columns(joined, [right_on + suffix])
        if by_right is not None:
            right_names = (n + suffix for n in by_right)
            joined = self._join_drop_duplicate_columns(joined, right_names)
        return self._with_native(joined)

    def _convert_predicates(
        self, other: Self, left_on: Sequence[str], right_on: Sequence[str]
    ) -> JoinPredicates:
        if left_on == right_on:
            return left_on
        return [
            cast("ir.BooleanColumn", (self.native[left] == other.native[right]))
            for left, right in zip(left_on, right_on, strict=True)
        ]

    def collect_schema(self) -> dict[str, DType]:
        return {
            name: native_to_narwhals_dtype(dtype, self._version)
            for name, dtype in self.native.schema().fields.items()
        }

    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        order_by: Sequence[str] | None,
    ) -> Self:
        subset_ = subset or self.columns
        if error := self._check_columns_exist(subset_):
            raise error
        tmp_name = generate_temporary_column_name(8, self.columns, prefix="row_index_")
        if order_by and keep == "last":
            order_by_ = IbisExpr._sort(*order_by, descending=True, nulls_last=True)
        elif order_by:
            order_by_ = IbisExpr._sort(*order_by, descending=False, nulls_last=False)
        else:
            order_by_ = lit(1)
        window = ibis.window(group_by=subset_, order_by=order_by_)
        if keep == "none":
            expr = self.native.count().over(window)
        else:
            expr = ibis.row_number().over(window) + lit(1)
        df = (
            self.native.mutate(**{tmp_name: expr})
            .filter(ibis._[tmp_name] == lit(1))
            .drop(tmp_name)
        )
        return self._with_native(df)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        from narwhals._ibis.expr import IbisExpr

        cols = IbisExpr._sort(*by, descending=descending, nulls_last=nulls_last)
        return self._with_native(self.native.order_by(*cols))

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        from narwhals._ibis.expr import IbisExpr

        desc = not reverse if isinstance(reverse, bool) else [not el for el in reverse]
        cols = IbisExpr._sort(*by, descending=desc, nulls_last=True)
        return self._with_native(self.native.order_by(*cols).head(k))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        subset_ = subset if subset is not None else self.columns
        return self._with_native(self.native.drop_null(subset_))

    def explode(self, columns: Sequence[str]) -> Self:
        dtypes = self._version.dtypes
        schema = self.collect_schema()
        for col in columns:
            dtype = schema[col]

            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with Ibis backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        return self._with_native(self.native.unnest(columns[0], keep_empty=True))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        import ibis.selectors as s

        index_: Sequence[str] = [] if index is None else index
        on_: Sequence[str] = (
            [c for c in self.columns if c not in index_] if on is None else on
        )

        # Discard columns not in the index
        final_columns = list(dict.fromkeys([*index_, variable_name, value_name]))

        unpivoted = self.native.pivot_longer(
            s.cols(*on_), names_to=variable_name, values_to=value_name
        )
        return self._with_native(unpivoted.select(*final_columns))

    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
        to_select = [
            ibis.row_number().over(ibis.window(order_by=order_by)).name(name),
            ibis.selectors.all(),
        ]
        return self._with_native(self.native.select(*to_select))

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        if isinstance(file, BytesIO):  # pragma: no cover
            msg = "Writing to BytesIO is not supported for Ibis backend."
            raise NotImplementedError(msg)
        self.native.to_parquet(file)

    # Intentionally not implemented, as Ibis does its own expression rewriting.
    # Both names are bound to `not_implemented()` placeholders.
    # NOTE(review): presumably these are hooks that `SQLLazyFrame` expects its
    # subclasses to provide - confirm against the base class.
    _evaluate_window_expr = not_implemented()
    _filter = not_implemented()
