import io
import os
from dataclasses import dataclass
from typing import Literal, Optional

import pandas as pd
import pyarrow as pa
import pyarrow.json as paj

import datasets
import datasets.config
from datasets.builder import Key
from datasets.table import table_cast
from datasets.utils.file_utils import readline
from datasets.utils.json import (
    find_mixed_struct_types_field_paths,
    get_json_field_path_from_pyarrow_json_error,
    get_json_field_paths_from_feature,
    insert_json_field_path,
    json_encode_field,
    json_encode_fields_in_json_lines,
    set_json_types_in_feature,
    ujson_dumps,
    ujson_loads,
)


logger = datasets.utils.logging.get_logger(__name__)


def pandas_read_json(path_or_buf, **kwargs):
    """Read JSON with ``pd.read_json``, requesting pyarrow-backed dtypes on pandas >= 2.

    On pandas >= 2, any ``dtype_backend`` passed by the caller is overridden with ``"pyarrow"``.
    All other keyword arguments are forwarded unchanged.
    """
    use_pyarrow_backend = datasets.config.PANDAS_VERSION.major >= 2
    read_kwargs = {**kwargs, "dtype_backend": "pyarrow"} if use_pyarrow_backend else kwargs
    return pd.read_json(path_or_buf, **read_kwargs)


class FullReadDisallowed(Exception):
    """Raised when a file can only be loaded by reading it entirely but full reads are not allowed."""


@dataclass
class JsonConfig(datasets.BuilderConfig):
    """BuilderConfig for JSON.

    Controls how the ``Json`` builder reads JSON, JSON Lines and agent traces data files.
    """

    # Explicit features of the dataset; when None, the builder tries to infer them from the data.
    features: Optional[datasets.Features] = None
    # Text encoding of the data files; content in another encoding is re-encoded to utf-8 for pyarrow.
    encoding: str = "utf-8"
    # Error handler used when decoding with `encoding` (None means strict decoding).
    encoding_errors: Optional[str] = None
    # If set, each file holds one JSON object and the rows are read from this top-level field.
    field: Optional[str] = None
    use_threads: bool = True  # deprecated: only triggers a warning, has no effect anymore
    block_size: Optional[int] = None  # deprecated: copied into `chunksize` with a warning
    chunksize: int = 10 << 20  # 10MB, number of bytes read per batch from JSON Lines files
    # No longer supported: any value other than None makes the builder raise a ValueError.
    newlines_in_values: Optional[bool] = None
    # How to handle fields whose type changes across rows: "use_json" re-encodes them as JSON fields.
    on_mixed_types: Optional[Literal["use_json"]] = "use_json"
    # Whether files whose inferred features match an agent harness are loaded as agent traces
    # (one row per file).
    parse_agent_traces: bool = True

    def __post_init__(self):
        """Run the base ``BuilderConfig`` post-initialization."""
        super().__post_init__()


class Json(datasets.ArrowBasedBuilder):
    """Arrow-based builder for JSON, JSON Lines and agent traces data files.

    Tables are produced by ``_generate_tables`` and cast by ``_cast_table`` to the configured or
    inferred features.
    """

    BUILDER_CONFIG_CLASS = JsonConfig

    def _info(self):
        """Handle deprecated/removed config parameters and return the ``DatasetInfo``.

        ``block_size`` is copied into ``chunksize`` with a warning, a non-default ``use_threads``
        only triggers a warning, and setting ``newlines_in_values`` raises.

        Raises:
            ValueError: if ``newlines_in_values`` is not None.
        """
        if self.config.block_size is not None:
            logger.warning("The JSON loader parameter `block_size` is deprecated. Please use `chunksize` instead")
            self.config.chunksize = self.config.block_size
        if self.config.use_threads is not True:
            logger.warning(
                "The JSON loader parameter `use_threads` is deprecated and doesn't have any effect anymore."
            )
        if self.config.newlines_in_values is not None:
            raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in datafiles.

        Downloads and extracts the data files and builds one ``SplitGenerator`` per split. When no
        features are configured, they are inferred from the first table of the first split without
        reading any file entirely. Inferred features matching an agent traces marker are replaced by
        ``AGENT_TRACES_FEATURES`` when ``parse_agent_traces`` is enabled.

        Raises:
            ValueError: if no data files are specified.
        """
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        base_data_files = dl_manager.download(self.config.data_files)
        extracted_data_files = dl_manager.extract(base_data_files)
        splits = []
        for split_name, extracted_files in extracted_data_files.items():
            # One iterable of local file paths per extracted data file (one shard each).
            files_iterables = [dl_manager.iter_files(extracted_file) for extracted_file in extracted_files]
            splits.append(
                datasets.SplitGenerator(
                    name=split_name,
                    gen_kwargs={
                        "files_iterables": files_iterables,
                        "base_files": base_data_files[split_name],
                        "original_files": self.config.data_files[split_name],
                    },
                )
            )
        if self.info.features is None:
            try:
                first_item = next(
                    iter(self._generate_tables(**splits[0].gen_kwargs, allow_full_read=False)), None
                )
            except FullReadDisallowed:
                # Inference would need a full read of a file: leave the features unset.
                first_item = None
            # `first_item` is also None when the first split yields no table at all (e.g. only empty
            # files): leave the features unset instead of letting a bare StopIteration escape.
            if first_item is not None:
                pa_table = first_item[1]
                self.info.features = datasets.Features.from_arrow_schema(pa_table.schema)
                if self.config.parse_agent_traces and has_agent_traces_markers(self.info.features):
                    self.info.features = AGENT_TRACES_FEATURES
        return splits

    def _cast_table(self, pa_table: pa.Table, json_field_paths=()) -> pa.Table:
        """Cast ``pa_table`` to the expected schema.

        With known features, the following happens before the table is cast to the features schema:

        * missing columns are added as all-null columns;
        * struct columns declared as ``Value("string")`` are serialized to JSON strings.

        Without features, a non-empty ``json_field_paths`` is applied to the features inferred from
        the table (through ``set_json_types_in_feature``) before casting. Otherwise the table is
        returned unchanged.
        """
        if self.info.features is not None:
            # adding missing columns
            for column_name in set(self.info.features) - set(pa_table.column_names):
                arrow_type = self.info.features.arrow_schema.field(column_name).type
                pa_table = pa_table.append_column(column_name, pa.array([None] * len(pa_table), type=arrow_type))
            # convert to string when needed
            for i, column_name in enumerate(pa_table.column_names):
                if pa.types.is_struct(pa_table[column_name].type) and self.info.features.get(
                    column_name, None
                ) == datasets.Value("string"):
                    # Serialize each struct row as one JSON line; rows serialized as "null" become nulls.
                    jsonl = (
                        pa_table[column_name]
                        .to_pandas(types_mapper=pd.ArrowDtype)
                        .to_json(orient="records", lines=True)
                    )
                    string_array = pa.array(
                        (None if x.strip() == "null" else x.strip() for x in jsonl.split("\n") if x.strip()),
                        type=pa.string(),
                    )
                    pa_table = pa_table.set_column(i, column_name, string_array)
            # more expensive cast to support nested structures with keys in a different order
            # allows str <-> int/float or str to Audio for example
            pa_table = table_cast(pa_table, self.info.features.arrow_schema)
        elif json_field_paths:
            features = datasets.Features.from_arrow_schema(pa_table.schema)
            features = set_json_types_in_feature(features, json_field_paths)
            pa_table = table_cast(pa_table, features.arrow_schema)
        return pa_table

    def _generate_shards(self, base_files, files_iterables, original_files):
        """Yield one shard per downloaded data file of the split."""
        yield from base_files

    def _generate_tables(self, base_files, files_iterables, original_files, allow_full_read=True):
        """Yield ``(Key, pa.Table)`` pairs for every file of a split.

        Three layouts are handled:

        * ``config.field`` is set: each file is one JSON object whose ``field`` holds the rows
          (always a full read).
        * the features are ``AGENT_TRACES_FEATURES``: each file becomes one row holding its
          non-blank lines as traces.
        * otherwise the file is read as JSON Lines in chunks of ``config.chunksize`` bytes. A file
          starting with ``[`` is first converted from a JSON array to JSON Lines (full read), and a
          file pyarrow cannot parse falls back to pandas (full read).

        Args:
            base_files: downloaded data files of the split (not used by this method).
            files_iterables: one iterable of local file paths per shard.
            original_files: the configured data files, indexed by shard.
            allow_full_read: if False, raise ``FullReadDisallowed`` instead of reading a whole file.

        Raises:
            FullReadDisallowed: when ``allow_full_read`` is False and a full read would be needed.
            ValueError: when the pandas fallback DataFrame cannot be converted to an Arrow table.
        """
        json_field_paths = []
        is_agent_traces = False

        if self.info.features is not None:
            if self.info.features == AGENT_TRACES_FEATURES:
                is_agent_traces = True
            else:
                # Fields declared with JSON types are re-encoded as JSON before parsing (see below).
                json_field_paths = get_json_field_paths_from_feature(self.info.features)

        for shard_idx, files_iterable in enumerate(files_iterables):
            for file in files_iterable:
                # If the file is one json object and if we need to look at the items in one specific field
                if self.config.field is not None:
                    if not allow_full_read:
                        raise FullReadDisallowed()
                    with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
                        dataset = ujson_loads(f.read())
                    # We keep only the field we are interested in
                    dataset = dataset[self.config.field]
                    df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
                    if df.columns.tolist() == [0]:
                        df.columns = list(self.config.features) if self.config.features else ["text"]
                    pa_table = pa.Table.from_pandas(df, preserve_index=False)
                    yield Key(shard_idx, 0), self._cast_table(pa_table)

                # If the files are agent traces (one row = one file)
                elif is_agent_traces:
                    # Decode with the configured encoding, like the other layouts do.
                    with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
                        # One trace per non-blank line (line terminator kept); blank lines are not
                        # JSON documents, so they are not kept as traces.
                        traces = [line for line in f if line.strip()]
                    harness, session_id = parse_traces_info(traces)
                    file_path = original_files[shard_idx]
                    # Report the path relative to the base path when possible. The truthiness check
                    # avoids `str.startswith(None)` raising TypeError when `base_path` is unset.
                    if self.base_path and file_path.startswith(self.base_path):
                        file_path = os.path.relpath(file_path, self.base_path)
                    pa_table = pa.Table.from_pydict(
                        {
                            "harness": [harness],
                            "session_id": [session_id],
                            "traces": [traces],
                            "file_path": [file_path],
                        }
                    )
                    yield Key(shard_idx, 0), self._cast_table(pa_table)

                # If the file has one json object per line
                else:
                    with open(file, "rb") as f:
                        batch_idx = 0
                        # Use block_size equal to the chunk size divided by 32 to leverage multithreading
                        # Set a default minimum value of 16kB if the chunk size is really small
                        block_size = max(self.config.chunksize // 32, 16 << 10)
                        encoding_errors = (
                            self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
                        )
                        while True:
                            batch = f.read(self.config.chunksize)
                            if not batch:
                                break
                            if batch.startswith(b"["):
                                if not allow_full_read:
                                    raise FullReadDisallowed()
                                else:
                                    # convert to JSON Lines
                                    full_data = batch + f.read()
                                    if b"{" in batch[:100].split(b'"', 1)[0]:  # list of objects
                                        batch = "\n".join(ujson_dumps(x) for x in ujson_loads(full_data)).encode()
                                    else:  # list of strings
                                        batch = "\n".join(
                                            ujson_dumps({"text": x}) for x in ujson_loads(full_data)
                                        ).encode()
                            # Finish current line
                            try:
                                batch += f.readline()
                            except (AttributeError, io.UnsupportedOperation):
                                batch += readline(f)
                            # PyArrow only accepts utf-8 encoded bytes
                            if self.config.encoding != "utf-8":
                                batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
                            # On first batch we check for lists of objects with arbitrary fields
                            if (
                                shard_idx == 0
                                and batch_idx == 0
                                and self.info.features is None
                                and self.config.on_mixed_types == "use_json"
                            ):
                                try:
                                    examples = [ujson_loads(line) for line in batch.splitlines()]
                                except ValueError:
                                    # the file is likely not JSON Lines and may contain one single multi-line JSON object
                                    pass
                                else:
                                    json_field_paths += find_mixed_struct_types_field_paths(examples)
                            # Re-encode JSON fields (blank lines carry no row and are skipped)
                            original_batch = batch
                            if json_field_paths:
                                examples = [ujson_loads(line) for line in batch.splitlines() if line.strip()]
                                for json_field_path in json_field_paths:
                                    examples = [json_encode_field(example, json_field_path) for example in examples]
                                batch = "\n".join(ujson_dumps(example) for example in examples).encode()
                            # Disable parallelism if block size is ~ len(batch) to avoid segfault: once the
                            # block size is within a factor of 8 of the batch size, parse the batch as a
                            # single block. Large batches keep the small block size (multithreading).
                            if block_size > len(batch) // 8:
                                block_size = max(block_size, len(batch))
                            try:
                                while True:
                                    try:
                                        pa_table = paj.read_json(
                                            io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
                                        )
                                        break
                                    except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
                                        if batch.startswith(b"["):  # paj.read_json only supports json lines
                                            raise
                                        elif self.config.on_mixed_types == "use_json" and (
                                            isinstance(e, pa.ArrowInvalid)
                                            and "JSON parse error: Column(" in str(e)
                                            and ") changed from" in str(e)
                                        ):
                                            # A column changed type: encode it as JSON and retry
                                            json_field_path = get_json_field_path_from_pyarrow_json_error(str(e))
                                            insert_json_field_path(json_field_paths, json_field_path)
                                            batch = json_encode_fields_in_json_lines(original_batch, json_field_paths)
                                        elif (
                                            "straddling" in str(e) or "JSON conversion to" in str(e)
                                        ) and block_size < len(batch):
                                            # Increase the block size in case it was too small.
                                            # The block size will be reset for the next file.
                                            # this is needed in case of "stradding" or for some JSON conversions (see https://github.com/huggingface/datasets/issues/2799)
                                            logger.debug(
                                                f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
                                            )
                                            block_size *= 2
                                        else:
                                            raise
                            except pa.ArrowInvalid as e:
                                # pyarrow could not parse the file: fall back to reading it with pandas
                                if not allow_full_read:
                                    raise FullReadDisallowed()
                                try:
                                    with open(
                                        file, encoding=self.config.encoding, errors=self.config.encoding_errors
                                    ) as f:
                                        df = pandas_read_json(f)
                                except ValueError:
                                    logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
                                    raise e
                                if df.columns.tolist() == [0]:
                                    df.columns = list(self.config.features) if self.config.features else ["text"]
                                try:
                                    pa_table = pa.Table.from_pandas(df, preserve_index=False)
                                except pa.ArrowInvalid as e:
                                    logger.error(
                                        f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
                                    )
                                    raise ValueError(
                                        f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
                                    ) from None
                                yield Key(shard_idx, 0), self._cast_table(pa_table)
                                break
                            yield (
                                Key(shard_idx, batch_idx),
                                self._cast_table(pa_table, json_field_paths=json_field_paths),
                            )
                            batch_idx += 1


# Values of the per-line "type" key that identify the agent harness which wrote a traces file.
AGENT_TRACES_TYPES_VALUES = {
    "claude_code": ["user", "assistant", "system"],
    "pi": ["session", "message"],
    "codex": ["session_meta", "turn_context", "response_item", "event_msg"],
}
# Reverse lookup: trace "type" value -> harness name.
AGENT_TRACES_TYPE_TO_HARNESS = {
    trace_type: harness_name
    for harness_name, trace_type_values in AGENT_TRACES_TYPES_VALUES.items()
    for trace_type in trace_type_values
}


# Columns (and their features) that characterize the JSON Lines files of each supported agent
# harness: a "type" string plus a JSON column holding the content.
AGENT_TRACES_FEATURES_MARKERS = {
    harness_name: datasets.Features({"type": datasets.Value("string"), content_column: datasets.Json()})
    for harness_name, content_column in (
        ("claude_code", "message"),
        ("pi", "message"),
        ("codex", "payload"),
    )
}

# Features of a dataset loaded as agent traces: one row per traces file.
AGENT_TRACES_FEATURES = datasets.Features(
    {
        "harness": datasets.Value("string"),
        "session_id": datasets.Value("string"),
        "traces": datasets.List(datasets.Json()),
        "file_path": datasets.Value("string"),
    }
)


def has_agent_traces_markers(features: datasets.Features) -> bool:
    """Return whether ``features`` match the marker columns of at least one agent harness.

    A marker matches when each of its columns is present in ``features`` with an equal feature.
    Extra columns in ``features`` do not prevent a match.
    """
    return any(
        all(features.get(column) == expected_feature for column, expected_feature in marker.items())
        for marker in AGENT_TRACES_FEATURES_MARKERS.values()
    )


def parse_traces_info(traces: list[str]) -> tuple[Optional[str], Optional[str]]:
    """Detect the agent harness and the session id from the lines of an agent traces file.

    The harness comes from the first line whose string ``"type"`` is listed in
    ``AGENT_TRACES_TYPE_TO_HARNESS``. A ``"session"`` line whose ``cwd`` contains ``/.openclaw/``
    sets it to ``"openclaw"``.

    The session id comes from the first line exposing one:

    * ``sessionId`` / ``session_id`` (Claude Code);
    * ``payload.id`` (Codex);
    * ``id`` of a ``"session"`` line (pi / openclaw).

    Scanning stops as soon as both values are known. Blank lines and lines that do not decode to a
    JSON object are skipped.

    Args:
        traces: the raw lines of the file, one JSON document per line.

    Returns:
        ``(harness, session_id)``; each is None if it could not be determined.
    """
    harness, session_id = None, None
    for trace in traces:
        # Blank lines carry no trace and are not valid JSON.
        if not trace.strip():
            continue
        decoded_trace = ujson_loads(trace)
        # Only JSON objects can carry the keys inspected below.
        if not isinstance(decoded_trace, dict):
            continue
        if harness is None:
            if "type" in decoded_trace and isinstance(decoded_trace["type"], str):
                harness = AGENT_TRACES_TYPE_TO_HARNESS.get(decoded_trace["type"])
        if session_id is None:
            # claude
            if "sessionId" in decoded_trace and isinstance(decoded_trace["sessionId"], str):
                session_id = decoded_trace["sessionId"]
            # claude (not sure but this format does exist online)
            elif "session_id" in decoded_trace and isinstance(decoded_trace["session_id"], str):
                session_id = decoded_trace["session_id"]
            # codex
            elif (
                "payload" in decoded_trace
                and isinstance(decoded_trace["payload"], dict)
                and "id" in decoded_trace["payload"]
                and isinstance(decoded_trace["payload"]["id"], str)
            ):
                session_id = decoded_trace["payload"]["id"]
            # pi / openclaw (openclaw embeds pi-agent; distinguish via cwd)
            elif (
                "type" in decoded_trace
                and decoded_trace["type"] == "session"
                and "id" in decoded_trace
                and isinstance(decoded_trace["id"], str)
            ):
                session_id = decoded_trace["id"]
                if isinstance(decoded_trace.get("cwd"), str) and "/.openclaw/" in decoded_trace["cwd"]:
                    harness = "openclaw"
        if harness and session_id:
            break
    return harness, session_id
