import json
import os
import platform
from pathlib import Path
from typing import Any, Union

import pytest

from pydantic_core import SchemaSerializer, core_schema

from ..conftest import plain_repr

# True when the test-suite is running on the PyPy interpreter.
on_pypy = platform.python_implementation() == 'PyPy'

# PyPy does not seem to preserve the ordering of `__dict__`, so on PyPy
# `IsStrictDict` is aliased to the builtin `dict`. On every other
# implementation this module leaves the name unbound.
if on_pypy:
    IsStrictDict = dict


class BaseModel:
    """Minimal model-like test class whose keyword arguments become attributes."""

    __slots__ = (
        '__dict__',
        '__pydantic_fields_set__',
        '__pydantic_extra__',
        '__pydantic_private__',
    )

    def __init__(self, **kwargs):
        """Assign every keyword argument to the instance with ``setattr``."""
        # `setattr` (rather than touching `__dict__` directly) keeps names that
        # match an entry in `__slots__` stored in their slot.
        for name in kwargs:
            setattr(self, name, kwargs[name])


class RootModel:
    """Minimal root-model-like test class holding a single value in ``root``."""

    __slots__ = (
        '__dict__',
        '__pydantic_fields_set__',
        '__pydantic_extra__',
        '__pydantic_private__',
    )

    # Annotation only: no class-level default is defined for `root`.
    root: str

    def __init__(self, data):
        """Store ``data`` unchanged as the instance's ``root`` attribute."""
        self.root = data


class RootSubModel(RootModel):
    """Subclass of ``RootModel`` that adds nothing; used to serialize subclass instances."""


def test_model_root():
    """A root model (and a subclass of it) serializes to its bare root value."""
    schema = core_schema.model_schema(RootModel, core_schema.int_schema(), root_model=True)
    s = SchemaSerializer(schema)

    serializer_repr = plain_repr(s)
    print(serializer_repr)
    # TODO: assert 'mode:RootModel' in plain_repr(s)
    assert 'has_extra:false' in serializer_repr

    # Python mode: both the class itself and a subclass unwrap to the root value.
    assert s.to_python(RootModel(1)) == 1
    assert s.to_python(RootSubModel(1)) == 1

    # JSON mode: exact bytes are only checked off PyPy; on PyPy the parsed value is compared.
    root_json = s.to_json(RootModel(1))
    if not on_pypy:
        assert root_json == b'1'
    else:
        assert json.loads(root_json) == 1

    sub_json = s.to_json(RootSubModel(1))
    assert json.loads(sub_json) == 1


def test_function_plain_field_serializer_to_python():
    """A plain field serializer on a root model is applied by ``to_python``."""

    class Model(RootModel):
        def ser_root(self, v: Any, _info) -> str:
            # The serializer is passed unbound, so `self` must be supplied by the
            # caller; the assert checks it is the model carrying the root value.
            assert self.root == 1_000
            return format(v, '_')

    ser_schema = core_schema.plain_serializer_function_ser_schema(
        Model.ser_root, is_field_serializer=True, info_arg=True
    )
    root_schema = core_schema.int_schema(serialization=ser_schema)
    s = SchemaSerializer(core_schema.model_schema(Model, root_schema, root_model=True))

    assert s.to_python(Model(1000)) == '1_000'


def test_function_wrap_field_serializer_to_python():
    """A wrap field serializer on a root model is applied by ``to_python``."""

    class Model(RootModel):
        def ser_root(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, _info) -> str:
            # Run the wrapped (inner) serializer first, then format its result.
            inner = handler(v)
            assert self.root == 1_000
            return format(inner, '_')

    ser_schema = core_schema.wrap_serializer_function_ser_schema(
        Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
    )
    root_schema = core_schema.int_schema(serialization=ser_schema)
    s = SchemaSerializer(core_schema.model_schema(Model, root_schema, root_model=True))

    assert s.to_python(Model(1000)) == '1_000'


def test_function_plain_field_serializer_to_json():
    """A plain field serializer on a root model is applied by ``to_json``."""

    class Model(RootModel):
        def ser_root(self, v: Any, _info) -> str:
            # The serializer is passed unbound, so `self` must be supplied by the
            # caller; the assert checks it is the model carrying the root value.
            assert self.root == 1_000
            return format(v, '_')

    ser_schema = core_schema.plain_serializer_function_ser_schema(
        Model.ser_root, is_field_serializer=True, info_arg=True
    )
    root_schema = core_schema.int_schema(serialization=ser_schema)
    s = SchemaSerializer(core_schema.model_schema(Model, root_schema, root_model=True))

    dumped = s.to_json(Model(1000))
    assert json.loads(dumped) == '1_000'


def test_function_wrap_field_serializer_to_json():
    """A wrap field serializer on a root model is applied by ``to_json``."""

    class Model(RootModel):
        def ser_root(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, _info) -> str:
            assert self.root == 1_000
            # Delegate to the wrapped (inner) serializer, then format its result.
            inner = handler(v)
            return format(inner, '_')

    ser_schema = core_schema.wrap_serializer_function_ser_schema(
        Model.ser_root, is_field_serializer=True, info_arg=True, schema=core_schema.any_schema()
    )
    root_schema = core_schema.int_schema(serialization=ser_schema)
    s = SchemaSerializer(core_schema.model_schema(Model, root_schema, root_model=True))

    dumped = s.to_json(Model(1000))
    assert json.loads(dumped) == '1_000'


@pytest.mark.parametrize('order', ['BR', 'RB'])
def test_root_model_dump_with_base_model(order):
    """A root model over a list of a model union dumps the same for either union order.

    ``order`` selects whether the fields-model schema (``'BR'``) or the
    root-model schema (``'RB'``) comes first in the union's choices.
    """

    class BModel(BaseModel):
        value: str

    class RModel(RootModel):
        root: int

    value_field = core_schema.model_field(core_schema.str_schema())
    b_schema = core_schema.model_schema(BModel, core_schema.model_fields_schema({'value': value_field}))
    r_schema = core_schema.model_schema(RModel, core_schema.int_schema(), root_model=True)

    if order == 'BR':
        choices = [b_schema, r_schema]

        class Model(RootModel):
            root: list[Union[BModel, RModel]]

    elif order == 'RB':
        choices = [r_schema, b_schema]

        class Model(RootModel):
            root: list[Union[RModel, BModel]]

    items_schema = core_schema.list_schema(core_schema.union_schema(choices=choices))
    s = SchemaSerializer(core_schema.model_schema(Model, items_schema, root_model=True))

    m = Model([RModel(1), RModel(2), BModel(value='abc')])

    # Root-model items unwrap to their root value; the fields model becomes a dict.
    python_dump = s.to_python(m)
    assert python_dump == [1, 2, {'value': 'abc'}]

    json_dump = s.to_json(m)
    assert json_dump == b'[1,2,{"value":"abc"}]'


def test_not_root_model():
    """An object that merely has a ``.root`` attribute is not serialized as a root model.

    See https://github.com/pydantic/pydantic/issues/8963
    """

    class RootModel:
        root: int

    instance = RootModel()
    instance.root = '123'

    schema = core_schema.model_schema(RootModel, core_schema.str_schema(), root_model=True)
    s = SchemaSerializer(schema)

    # A genuine instance of the schema's class serializes to its root value.
    assert s.to_python(instance) == '123'
    assert s.to_json(instance) == b'"123"'

    # Path is chosen because it has a .root property
    # which could look like a root model in bad implementations
    if os.name != 'nt':
        path_value = Path('/a/b')
        path_bytes = b'"/a/b"'
    else:
        path_value = Path('C:\\a\\b')
        path_bytes = b'"C:\\\\a\\\\b"'  # fixme double escaping?

    unexpected_value = r'PydanticSerializationUnexpectedValue\(Expected `RootModel`'

    # With warnings enabled (the default) the path is passed through with a warning.
    with pytest.warns(UserWarning, match=unexpected_value):
        assert s.to_python(path_value) == path_value

    with pytest.warns(UserWarning, match=unexpected_value):
        assert s.to_json(path_value) == path_bytes

    # With warnings disabled the results are the same.
    python_quiet = s.to_python(path_value, warnings=False)
    assert python_quiet == path_value

    json_quiet = s.to_json(path_value, warnings=False)
    assert json_quiet == path_bytes
