"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
# @generated-id: 4000b05e3b8d

from __future__ import annotations
from mistralai.client.types import (
    BaseModel,
    Nullable,
    OptionalNullable,
    UNSET,
    UNSET_SENTINEL,
)
from pydantic import model_serializer
from typing import Optional
from typing_extensions import NotRequired, TypedDict


class ClassifierTrainingParametersTypedDict(TypedDict):
    # Plain-dict counterpart of ClassifierTrainingParameters: it declares the
    # same six keys as that model. Every key is NotRequired, so an empty dict
    # satisfies this type. All keys except `learning_rate` are additionally
    # marked Nullable.
    training_steps: NotRequired[Nullable[int]]
    r"""The number of training steps to perform. A training step refers to a single update of the model weights during the fine-tuning process. This update is typically calculated using a batch of samples from the training dataset."""
    learning_rate: NotRequired[float]
    r"""A parameter describing how much to adjust the pre-trained model's weights in response to the estimated error each time the weights are updated during the fine-tuning process."""
    weight_decay: NotRequired[Nullable[float]]
    r"""(Advanced Usage) Weight decay adds a term to the loss function that is proportional to the sum of the squared weights. This term reduces the magnitude of the weights and prevents them from growing too large."""
    warmup_fraction: NotRequired[Nullable[float]]
    r"""(Advanced Usage) A parameter that specifies the percentage of the total training steps at which the learning rate warm-up phase ends. During this phase, the learning rate gradually increases from a small value to the initial learning rate, helping to stabilize the training process and improve convergence. Similar to `pct_start` in [mistral-finetune](https://github.com/mistralai/mistral-finetune)"""
    # NOTE(review): no description is given for `epochs` in this file;
    # presumably the number of passes over the training data (typed as float,
    # so fractional values are accepted) -- confirm against the API spec.
    epochs: NotRequired[Nullable[float]]
    # NOTE(review): no description is given for `seq_len` in this file;
    # presumably the training sequence length -- confirm against the API spec.
    seq_len: NotRequired[Nullable[int]]


class ClassifierTrainingParameters(BaseModel):
    # Model form of the classifier training parameters. Every field has a
    # default: `learning_rate` defaults to 0.0001, all other fields default
    # to UNSET. Serialization is customized by `serialize_model` below.
    training_steps: OptionalNullable[int] = UNSET
    r"""The number of training steps to perform. A training step refers to a single update of the model weights during the fine-tuning process. This update is typically calculated using a batch of samples from the training dataset."""

    learning_rate: Optional[float] = 0.0001
    r"""A parameter describing how much to adjust the pre-trained model's weights in response to the estimated error each time the weights are updated during the fine-tuning process."""

    weight_decay: OptionalNullable[float] = UNSET
    r"""(Advanced Usage) Weight decay adds a term to the loss function that is proportional to the sum of the squared weights. This term reduces the magnitude of the weights and prevents them from growing too large."""

    warmup_fraction: OptionalNullable[float] = UNSET
    r"""(Advanced Usage) A parameter that specifies the percentage of the total training steps at which the learning rate warm-up phase ends. During this phase, the learning rate gradually increases from a small value to the initial learning rate, helping to stabilize the training process and improve convergence. Similar to `pct_start` in [mistral-finetune](https://github.com/mistralai/mistral-finetune)"""

    epochs: OptionalNullable[float] = UNSET

    seq_len: OptionalNullable[int] = UNSET

    @model_serializer(mode="wrap")
    def serialize_model(self, handler):
        """Build the serialized dict, omitting unset and defaulted-None fields.

        The wrapped ``handler`` produces the raw serialization first. Each
        declared field is then copied into the result unless:

        * its serialized value equals ``UNSET_SENTINEL``, or
        * its value is ``None``, the field is optional, and it is not a
          nullable field that was explicitly set on this instance.

        Keys in the result use the field alias when one is defined,
        otherwise the field name.
        """
        optional_fields = {
            "training_steps",
            "learning_rate",
            "weight_decay",
            "warmup_fraction",
            "epochs",
            "seq_len",
        }
        nullable_fields = {
            "training_steps",
            "weight_decay",
            "warmup_fraction",
            "epochs",
            "seq_len",
        }
        dumped = handler(self)
        result = {}

        for name, field in type(self).model_fields.items():
            key = field.alias or name
            # Prefer the aliased key, falling back to the attribute name.
            value = dumped.get(key, dumped.get(name))

            # An explicit None on a nullable field must survive serialization,
            # so track whether the caller actually assigned this field.
            explicit_nullable = (
                key in nullable_fields
                and name in self.__pydantic_fields_set__  # pylint: disable=no-member
            )

            if value != UNSET_SENTINEL and (
                value is not None
                or key not in optional_fields
                or explicit_nullable
            ):
                result[key] = value

        return result
