{{py:

"""
Template file to easily generate loops over samples using Tempita
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).

Generated file: _loss.pyx

Each loss class is generated from cdef functions that operate on single samples.
The keywords between double braces are substituted during the build.
"""

doc_HalfSquaredError = (
    # Class docstring of the generated CyHalfSquaredError (see class_list).
    """Half Squared Error with identity link.

    Domain:
    y_true and y_pred all real numbers

    Link:
    y_pred = raw_prediction
    """
)

doc_AbsoluteError = (
    # Class docstring of the generated CyAbsoluteError (see class_list).
    """Absolute Error with identity link.

    Domain:
    y_true and y_pred all real numbers

    Link:
    y_pred = raw_prediction
    """
)

doc_PinballLoss = (
    # Class docstring of the generated CyPinballLoss (see class_list).
    # The note refers to the generated class names (Cy prefix).
    """Quantile Loss aka Pinball Loss with identity link.

    Domain:
    y_true and y_pred all real numbers
    quantile in (0, 1)

    Link:
    y_pred = raw_prediction

    Note: 2 * CyPinballLoss(quantile=0.5) equals CyAbsoluteError()
    """
)

doc_HuberLoss = (
    # Class docstring of the generated CyHuberLoss (see class_list).
    """Huber Loss with identity link.

    Domain:
    y_true and y_pred all real numbers
    delta in positive real numbers

    Link:
    y_pred = raw_prediction
    """
)

doc_HalfPoissonLoss = (
    # Class docstring of the generated CyHalfPoissonLoss (see class_list).
    """Half Poisson deviance loss with log-link.

    Domain:
    y_true in non-negative real numbers
    y_pred in positive real numbers

    Link:
    y_pred = exp(raw_prediction)

    Half Poisson deviance with log-link is
        y_true * log(y_true/y_pred) + y_pred - y_true
        = y_true * log(y_true) - y_true * raw_prediction
          + exp(raw_prediction) - y_true

    Dropping constant terms, this gives:
        exp(raw_prediction) - y_true * raw_prediction
    """
)

doc_HalfGammaLoss = (
    # Class docstring of the generated CyHalfGammaLoss (see class_list).
    """Half Gamma deviance loss with log-link.

    Domain:
    y_true and y_pred in positive real numbers

    Link:
    y_pred = exp(raw_prediction)

    Half Gamma deviance with log-link is
        log(y_pred/y_true) + y_true/y_pred - 1
        = raw_prediction - log(y_true) + y_true * exp(-raw_prediction) - 1

    Dropping constant terms, this gives:
        raw_prediction + y_true * exp(-raw_prediction)
    """
)

doc_HalfTweedieLoss = (
    # Class docstring of the generated CyHalfTweedieLoss (see class_list).
    # The domain lists p <= 0 and the implementation special-cases power == 0,
    # so power is not restricted to positive values.
    """Half Tweedie deviance loss with log-link.

    Domain:
    y_true in real numbers if p <= 0
    y_true in non-negative real numbers if 0 < p < 2
    y_true in positive real numbers if p >= 2
    y_pred in positive real numbers
    power in real numbers

    Link:
    y_pred = exp(raw_prediction)

    Half Tweedie deviance with log-link and p=power is
        max(y_true, 0)**(2-p) / (1-p) / (2-p)
        - y_true * y_pred**(1-p) / (1-p)
        + y_pred**(2-p) / (2-p)
        = max(y_true, 0)**(2-p) / (1-p) / (2-p)
        - y_true * exp((1-p) * raw_prediction) / (1-p)
        + exp((2-p) * raw_prediction) / (2-p)

    Dropping constant terms, this gives:
        exp((2-p) * raw_prediction) / (2-p)
        - y_true * exp((1-p) * raw_prediction) / (1-p)

    Notes:
    - Poisson with p=1 and Gamma with p=2 have different terms dropped such
      that CyHalfTweedieLoss is not continuous in p=power at p=1 and p=2.
      Likewise, p=0 is computed as half squared error of exp(raw_prediction),
      which keeps the constant term 0.5 * y_true**2 dropped by the formula
      above, so the loss is not continuous at p=0 either.
    - While the Tweedie distribution only exists for p<=0 or p>=1, the range
      0<p<1 still gives a strictly consistent scoring function for the
      expectation.
    """
)

doc_HalfTweedieLossIdentity = (
    # Class docstring of the generated CyHalfTweedieLossIdentity (see
    # class_list). As for the log-link version, power may be <= 0, so it is
    # not listed as a positive number.
    """Half Tweedie deviance loss with identity link.

    Domain:
    y_true in real numbers if p <= 0
    y_true in non-negative real numbers if 0 < p < 2
    y_true in positive real numbers if p >= 2
    y_pred in positive real numbers, y_pred may be negative for p=0
    power in real numbers

    Link:
    y_pred = raw_prediction

    Half Tweedie deviance with identity link and p=power is
        max(y_true, 0)**(2-p) / (1-p) / (2-p)
        - y_true * y_pred**(1-p) / (1-p)
        + y_pred**(2-p) / (2-p)

    Notes:
    - Here, we do not drop constant terms in contrast to the version with log-link.
    """
)

doc_HalfBinomialLoss = (
    # Class docstring of the generated CyHalfBinomialLoss (see class_list).
    """Half Binomial deviance loss with logit link.

    Domain:
    y_true in [0, 1]
    y_pred in (0, 1), i.e. boundaries excluded

    Link:
    y_pred = expit(raw_prediction)
    """
)

doc_ExponentialLoss = (
    # Class docstring of the generated CyExponentialLoss (see class_list).
    # The original opened with four quotes, which leaked a stray '"' into the
    # docstring text.
    """Exponential loss with (half) logit link.

    Domain:
    y_true in [0, 1]
    y_pred in (0, 1), i.e. boundaries excluded

    Link:
    y_pred = expit(2 * raw_prediction)

    Exponential loss is
        y_true * exp(-raw_prediction) + (1 - y_true) * exp(raw_prediction)
    """
)

# One 7-tuple per generated extension type. The template loop unpacks each
# entry as (name, docstring, param, closs, closs_grad, cgrad, cgrad_hess):
#   name        name of the generated cdef class
#   docstring   its class docstring (one of the doc_* strings)
#   param       name of the loss parameter stored as attribute by __init__,
#               or None for parameterless losses
#   closs       C function: loss of a single sample
#   closs_grad  C function: loss and gradient at once, or None
#   cgrad       C function: gradient of a single sample
#   cgrad_hess  C function: gradient and hessian of a single sample
class_list = [
    ("CyHalfSquaredError", doc_HalfSquaredError, None,
     "closs_half_squared_error", None,
     "cgradient_half_squared_error", "cgrad_hess_half_squared_error"),
    ("CyAbsoluteError", doc_AbsoluteError, None,
     "closs_absolute_error", None,
     "cgradient_absolute_error", "cgrad_hess_absolute_error"),
    ("CyPinballLoss", doc_PinballLoss, "quantile",
     "closs_pinball_loss", None,
     "cgradient_pinball_loss", "cgrad_hess_pinball_loss"),
    ("CyHuberLoss", doc_HuberLoss, "delta",
     "closs_huber_loss", None,
     "cgradient_huber_loss", "cgrad_hess_huber_loss"),
    ("CyHalfPoissonLoss", doc_HalfPoissonLoss, None,
     "closs_half_poisson", "closs_grad_half_poisson",
     "cgradient_half_poisson", "cgrad_hess_half_poisson"),
    ("CyHalfGammaLoss", doc_HalfGammaLoss, None,
     "closs_half_gamma", "closs_grad_half_gamma",
     "cgradient_half_gamma", "cgrad_hess_half_gamma"),
    ("CyHalfTweedieLoss", doc_HalfTweedieLoss, "power",
     "closs_half_tweedie", "closs_grad_half_tweedie",
     "cgradient_half_tweedie", "cgrad_hess_half_tweedie"),
    ("CyHalfTweedieLossIdentity", doc_HalfTweedieLossIdentity, "power",
     "closs_half_tweedie_identity", "closs_grad_half_tweedie_identity",
     "cgradient_half_tweedie_identity", "cgrad_hess_half_tweedie_identity"),
    ("CyHalfBinomialLoss", doc_HalfBinomialLoss, None,
     "closs_half_binomial", "closs_grad_half_binomial",
     "cgradient_half_binomial", "cgrad_hess_half_binomial"),
    ("CyExponentialLoss", doc_ExponentialLoss, None,
     "closs_exponential", "closs_grad_exponential",
     "cgradient_exponential", "cgrad_hess_exponential"),
]
}}

# Design:
# See https://github.com/scikit-learn/scikit-learn/issues/15123 for reasons.
# a) Merge link functions into loss functions for speed and numerical
#    stability, i.e. use raw_prediction instead of y_pred in signature.
# b) Pure C functions (nogil) calculate single points (single sample)
# c) Wrap C functions in a loop to get Python functions operating on ndarrays.
#   - Write loops manually---use Tempita for this.
#     Reason: There is still some performance overhead when using a wrapper
#     function "wrap" that carries out the loop and gets as argument a function
#     pointer to one of the C functions from b), e.g.
#     wrap(closs_half_poisson, y_true, ...)
#   - Pass n_threads as argument to prange and propagate option to all callers.
# d) Provide classes (Cython extension types) per loss (names start with Cy) in
#    order to have semantically structured objects.
#    - Member functions for single points just call the C function from b).
#      These are used e.g. in SGD `_plain_sgd`.
#    - Member functions operating on ndarrays, see c), looping over calls to C
#      functions from b).
# e) Provide convenience Python classes that compose from these extension types
#    elsewhere (see loss.py)
#    - Example: loss.gradient calls CyLoss.gradient but does some input
#      checking like None -> np.empty().
#
# Note: We require 1-dim ndarrays to be contiguous.

from cython.parallel import parallel, prange
import numpy as np

from libc.math cimport exp, fabs, log, log1p, pow
from libc.stdlib cimport malloc, free


# -------------------------------------
# Helper functions
# -------------------------------------
# Numerically stable version of log(1 + exp(x)) for double precision, see Eq. (10) of
# https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
# Note: The only important cutoff is at x = 18. All others are to save computation
# time. Compared to the reference, we add the additional case distinction x <= -2 in
# order to use log instead of log1p for improved performance. As with the other
# cutoffs, this is accurate within machine precision of double.
cdef inline double log1pexp(double x) noexcept nogil:
    # Return log(1 + exp(x)) in double precision without overflow for large x
    # and without cancellation for very negative x. Each branch uses the
    # cheapest expression that is accurate on its interval; the order of the
    # comparisons matters since every branch assumes the previous ones failed.
    if x <= -37:
        # exp(x) < 1e-16, so log(1 + exp(x)) == exp(x) to double precision.
        return exp(x)
    elif x <= -2:
        # log1p keeps full precision while exp(x) is small compared to 1.
        return log1p(exp(x))
    elif x <= 18:
        # Plain log is accurate here and cheaper than log1p.
        return log(1. + exp(x))
    elif x <= 33.3:
        # log(1 + exp(x)) = x + log1p(exp(-x)) ~ x + exp(-x) for large x;
        # avoids overflow of exp(x).
        return x + exp(-x)
    else:
        # exp(-x) is negligible relative to x at double precision.
        return x


cdef inline double_pair sum_exp_minus_max(
    const int i,
    const floating_in[:, :] raw_prediction,  # IN
    floating_out *p                           # OUT
) noexcept nogil:
    # Shifted exponentials of row i of raw_prediction, written into the
    # thread-local buffer p (length must be n_classes):
    #     p[k] = exp(raw_prediction[i, k] - max_value), k = 0, ..., n_classes-1
    # The returned double_pair holds
    #     val1 = max_value = max over k of raw_prediction[i, k]
    #     val2 = sum_exps  = sum over k of p[k]
    # Notes:
    # - Subtracting max_value keeps the exponents <= 0 (for finite inputs), so
    #   exp cannot overflow.
    # - i is passed in (and kept constant) because otherwise Cython does not
    #   generate optimal code, see
    #   https://github.com/scikit-learn/scikit-learn/issues/17299
    # - p is deliberately not normalized by sum_exps; this saves a loop over k.
    # - sum_exps accumulates the values as stored in p (i.e. after conversion to
    #   floating_out), in the order k = 0, 1, ...
    cdef:
        int k
        int n_classes = raw_prediction.shape[1]
        double max_value = raw_prediction[i, 0]
        double sum_exps = 0
        double_pair result

    # Pass 1: row maximum.
    for k in range(1, n_classes):
        if raw_prediction[i, k] > max_value:
            max_value = raw_prediction[i, k]

    # Pass 2: shifted exponentials and their running sum.
    for k in range(n_classes):
        p[k] = exp(raw_prediction[i, k] - max_value)
        sum_exps += p[k]

    result.val1 = max_value
    result.val2 = sum_exps
    return result


# -------------------------------------
# Single point inline C functions
# -------------------------------------
# Half Squared Error
cdef inline double closs_half_squared_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # 0.5 * (y_pred - y_true)**2 with identity link y_pred = raw_prediction.
    cdef double residual = raw_prediction - y_true
    return 0.5 * residual * residual


cdef inline double cgradient_half_squared_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # Derivative of 0.5 * (raw_prediction - y_true)**2: the residual itself.
    return raw_prediction - y_true


cdef inline double_pair cgrad_hess_half_squared_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = gradient (residual), val2 = hessian (constant 1).
    cdef double_pair grad_hess
    grad_hess.val1 = cgradient_half_squared_error(y_true, raw_prediction)
    grad_hess.val2 = 1.
    return grad_hess


# Absolute Error
cdef inline double closs_absolute_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # |raw_prediction - y_true| with identity link.
    return fabs(raw_prediction - y_true)


cdef inline double cgradient_absolute_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # Sign of the residual. At raw_prediction == y_true (and for NaN inputs)
    # the value -1 is returned.
    if raw_prediction > y_true:
        return 1.
    return -1.


cdef inline double_pair cgrad_hess_absolute_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # The exact hessian is 0 almost everywhere, but optimizers such as HGBT
    # require a positive hessian, so a constant 1 is reported instead.
    cdef double_pair grad_hess
    grad_hess.val1 = cgradient_absolute_error(y_true, raw_prediction)  # gradient
    grad_hess.val2 = 1.                                                # hessian
    return grad_hess


# Quantile Loss / Pinball Loss
cdef inline double closs_pinball_loss(
    double y_true,
    double raw_prediction,
    double quantile
) noexcept nogil:
    # Under-prediction (or an exact hit) is weighted by quantile,
    # over-prediction by 1 - quantile.
    if y_true >= raw_prediction:
        return quantile * (y_true - raw_prediction)
    return (1. - quantile) * (raw_prediction - y_true)


cdef inline double cgradient_pinball_loss(
    double y_true,
    double raw_prediction,
    double quantile
) noexcept nogil:
    # Piecewise constant derivative w.r.t. raw_prediction.
    if y_true >= raw_prediction:
        return -quantile
    return 1. - quantile


cdef inline double_pair cgrad_hess_pinball_loss(
    double y_true,
    double raw_prediction,
    double quantile
) noexcept nogil:
    # The exact hessian is 0 almost everywhere, but optimizers such as HGBT
    # require a positive hessian, so a constant 1 is reported instead.
    cdef double_pair grad_hess
    grad_hess.val1 = cgradient_pinball_loss(y_true, raw_prediction, quantile)  # gradient
    grad_hess.val2 = 1.                                                        # hessian
    return grad_hess


# Huber Loss
cdef inline double closs_huber_loss(
    double y_true,
    double raw_prediction,
    double delta,
) noexcept nogil:
    # Quadratic for |residual| <= delta, linear beyond; both pieces agree
    # at |residual| == delta.
    cdef double abs_residual = fabs(y_true - raw_prediction)
    if abs_residual <= delta:
        return 0.5 * abs_residual**2
    return delta * (abs_residual - 0.5 * delta)


cdef inline double cgradient_huber_loss(
    double y_true,
    double raw_prediction,
    double delta,
) noexcept nogil:
    # The residual inside the quadratic zone, +-delta outside of it.
    cdef double residual = raw_prediction - y_true
    if fabs(residual) <= delta:
        return residual
    if residual >= 0:
        return delta
    return -delta


cdef inline double_pair cgrad_hess_huber_loss(
    double y_true,
    double raw_prediction,
    double delta,
) noexcept nogil:
    # Same case split as cgradient_huber_loss; the hessian is 1 inside the
    # quadratic zone and 0 in the linear zone.
    cdef double_pair grad_hess
    cdef double residual = raw_prediction - y_true
    if fabs(residual) <= delta:
        grad_hess.val1 = residual                           # gradient
        grad_hess.val2 = 1                                  # hessian
    else:
        grad_hess.val1 = delta if residual >= 0 else -delta  # gradient
        grad_hess.val2 = 0                                  # hessian
    return grad_hess


# Half Poisson Deviance with Log-Link, dropping constant terms
cdef inline double closs_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # y_pred - y_true * raw_prediction with y_pred = exp(raw_prediction);
    # terms depending only on y_true are dropped.
    return exp(raw_prediction) - y_true * raw_prediction


cdef inline double cgradient_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # Gradient w.r.t. raw_prediction: y_pred - y_true.
    return exp(raw_prediction) - y_true


cdef inline double_pair closs_grad_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = loss, val2 = gradient, sharing a single exp evaluation.
    cdef double_pair loss_grad
    cdef double y_pred = exp(raw_prediction)
    loss_grad.val1 = y_pred - y_true * raw_prediction
    loss_grad.val2 = y_pred - y_true
    return loss_grad


cdef inline double_pair cgrad_hess_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = gradient y_pred - y_true, val2 = hessian y_pred.
    cdef double_pair grad_hess
    cdef double y_pred = exp(raw_prediction)
    grad_hess.val1 = y_pred - y_true
    grad_hess.val2 = y_pred
    return grad_hess


# Half Gamma Deviance with Log-Link, dropping constant terms
cdef inline double closs_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # raw_prediction + y_true / y_pred with y_pred = exp(raw_prediction);
    # terms depending only on y_true are dropped.
    return raw_prediction + y_true * exp(-raw_prediction)


cdef inline double cgradient_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # Gradient w.r.t. raw_prediction: 1 - y_true / y_pred.
    return 1. - y_true * exp(-raw_prediction)


cdef inline double_pair closs_grad_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = loss, val2 = gradient, sharing one evaluation of 1 / y_pred.
    cdef double_pair loss_grad
    cdef double inv_y_pred = exp(-raw_prediction)
    loss_grad.val1 = raw_prediction + y_true * inv_y_pred
    loss_grad.val2 = 1. - y_true * inv_y_pred
    return loss_grad


cdef inline double_pair cgrad_hess_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = gradient 1 - y_true / y_pred, val2 = hessian y_true / y_pred.
    cdef double_pair grad_hess
    cdef double inv_y_pred = exp(-raw_prediction)
    grad_hess.val1 = 1. - y_true * inv_y_pred
    grad_hess.val2 = inv_y_pred * y_true
    return grad_hess


# Half Tweedie Deviance with Log-Link, dropping constant terms
# Note that by dropping constants this is no longer continuous in parameter power.
cdef inline double closs_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Special powers use dedicated implementations. For p = 1 and p = 2 the
    # generic formula would divide by zero; p = 0 keeps the constant term
    # 0.5 * y_true**2 that the generic formula drops.
    if power == 0.:
        # Half squared error evaluated at y_pred = exp(raw_prediction).
        return closs_half_squared_error(y_true, exp(raw_prediction))
    if power == 1.:
        return closs_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return closs_half_gamma(y_true, raw_prediction)
    # Generic case, y_pred**(2-p) / (2-p) - y_true * y_pred**(1-p) / (1-p)
    # with y_pred = exp(raw_prediction).
    return (exp((2. - power) * raw_prediction) / (2. - power)
            - y_true * exp((1. - power) * raw_prediction) / (1. - power))


cdef inline double cgradient_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Gradient w.r.t. raw_prediction: y_pred**(2-p) - y_true * y_pred**(1-p).
    cdef double y_pred
    if power == 0.:
        y_pred = exp(raw_prediction)
        return y_pred * (y_pred - y_true)
    if power == 1.:
        return cgradient_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return cgradient_half_gamma(y_true, raw_prediction)
    return (exp((2. - power) * raw_prediction)
            - y_true * exp((1. - power) * raw_prediction))


cdef inline double_pair closs_grad_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # val1 = loss, val2 = gradient; same case split as closs_half_tweedie.
    cdef double_pair loss_grad
    cdef double y_pred, exp_1mp, exp_2mp
    if power == 0.:
        y_pred = exp(raw_prediction)
        loss_grad.val1 = closs_half_squared_error(y_true, y_pred)
        loss_grad.val2 = y_pred * (y_pred - y_true)
        return loss_grad
    if power == 1.:
        return closs_grad_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return closs_grad_half_gamma(y_true, raw_prediction)
    exp_1mp = exp((1. - power) * raw_prediction)  # y_pred**(1-p)
    exp_2mp = exp((2. - power) * raw_prediction)  # y_pred**(2-p)
    loss_grad.val1 = exp_2mp / (2. - power) - y_true * exp_1mp / (1. - power)
    loss_grad.val2 = exp_2mp - y_true * exp_1mp
    return loss_grad


cdef inline double_pair cgrad_hess_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # val1 = gradient, val2 = hessian w.r.t. raw_prediction:
    #   hessian = (2-p) * y_pred**(2-p) - (1-p) * y_true * y_pred**(1-p)
    cdef double_pair grad_hess
    cdef double y_pred, exp_1mp, exp_2mp
    if power == 0.:
        y_pred = exp(raw_prediction)
        grad_hess.val1 = y_pred * (y_pred - y_true)
        grad_hess.val2 = y_pred * (2 * y_pred - y_true)
        return grad_hess
    if power == 1.:
        return cgrad_hess_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return cgrad_hess_half_gamma(y_true, raw_prediction)
    exp_1mp = exp((1. - power) * raw_prediction)  # y_pred**(1-p)
    exp_2mp = exp((2. - power) * raw_prediction)  # y_pred**(2-p)
    grad_hess.val1 = exp_2mp - y_true * exp_1mp
    grad_hess.val2 = (2. - power) * exp_2mp - (1. - power) * y_true * exp_1mp
    return grad_hess


# Half Tweedie Deviance with identity link, without dropping constant terms!
# Therefore, best loss value is zero.
cdef inline double closs_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Half Tweedie deviance with identity link (y_pred = raw_prediction),
    # keeping all constant terms.
    cdef double tmp
    if power == 0.:
        # Half squared error.
        return closs_half_squared_error(y_true, raw_prediction)
    elif power == 1.:
        # Half Poisson deviance; the y_true * log(y_true / y_pred) term is
        # taken as 0 when y_true == 0.
        if y_true == 0:
            return raw_prediction
        else:
            return y_true * log(y_true/raw_prediction) + raw_prediction - y_true
    elif power == 2.:
        # Half Gamma deviance.
        return log(raw_prediction/y_true) + y_true/raw_prediction - 1.
    else:
        # tmp = y_pred**(1-p); y_pred**(2-p) is obtained as y_pred * tmp.
        tmp = pow(raw_prediction, 1. - power)
        tmp = raw_prediction * tmp / (2. - power) - y_true * tmp / (1. - power)
        # Term max(y_true, 0)**(2-p) / ((1-p) * (2-p)), only added for y_true > 0.
        if y_true > 0:
            tmp += pow(y_true, 2. - power) / ((1. - power) * (2. - power))
        return tmp


cdef inline double cgradient_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Gradient w.r.t. raw_prediction: y_pred**(-p) * (y_pred - y_true),
    # written out explicitly for p = 0, 1, 2.
    if power == 0.:
        return raw_prediction - y_true
    elif power == 1.:
        return 1. - y_true / raw_prediction
    elif power == 2.:
        return (raw_prediction - y_true) / (raw_prediction * raw_prediction)
    else:
        return pow(raw_prediction, -power) * (raw_prediction - y_true)


cdef inline double_pair closs_grad_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # val1 = loss (as in closs_half_tweedie_identity), val2 = gradient.
    cdef double_pair lg
    cdef double tmp
    if power == 0.:
        lg.val2 = raw_prediction - y_true  # gradient
        lg.val1 = 0.5 * lg.val2 * lg.val2  # loss
    elif power == 1.:
        if y_true == 0:
            lg.val1 = raw_prediction
        else:
            lg.val1 = (y_true * log(y_true/raw_prediction)  # loss
                       + raw_prediction - y_true)
        lg.val2 = 1. - y_true / raw_prediction              # gradient
    elif power == 2.:
        lg.val1 = log(raw_prediction/y_true) + y_true/raw_prediction - 1.  # loss
        tmp = raw_prediction * raw_prediction
        lg.val2 = (raw_prediction - y_true) / tmp                          # gradient
    else:
        # tmp = y_pred**(1-p), shared by loss and gradient.
        tmp = pow(raw_prediction, 1. - power)
        lg.val1 = (raw_prediction * tmp / (2. - power)  # loss
                   - y_true * tmp / (1. - power))
        if y_true > 0:
            lg.val1 += (pow(y_true, 2. - power)
                        / ((1. - power) * (2. - power)))
        # NOTE(review): this gradient form, y_pred**(1-p) * (1 - y_true/y_pred),
        # is mathematically equal to cgradient_half_tweedie_identity's
        # y_pred**(-p) * (y_pred - y_true) but may differ in the last bits.
        lg.val2 = tmp * (1. - y_true / raw_prediction)    # gradient
    return lg


cdef inline double_pair cgrad_hess_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # val1 = gradient, val2 = hessian w.r.t. raw_prediction:
    #   hessian = y_pred**(-p) * ((1-p) + p * y_true / y_pred)
    # which can be negative, e.g. for p > 1 and y_true == 0.
    cdef double_pair gh
    cdef double tmp
    if power == 0.:
        gh.val1 = raw_prediction - y_true  # gradient
        gh.val2 = 1.                       # hessian
    elif power == 1.:
        gh.val1 = 1. - y_true / raw_prediction                # gradient
        gh.val2 = y_true / (raw_prediction * raw_prediction)  # hessian
    elif power == 2.:
        tmp = raw_prediction * raw_prediction
        gh.val1 = (raw_prediction - y_true) / tmp             # gradient
        gh.val2 = (-1. + 2. * y_true / raw_prediction) / tmp  # hessian
    else:
        tmp = pow(raw_prediction, -power)
        gh.val1 = tmp * (raw_prediction - y_true)                         # gradient
        gh.val2 = tmp * ((1. - power) + power * y_true / raw_prediction)  # hessian
    return gh


# Half Binomial deviance with logit-link, aka log-loss or binary cross entropy
cdef inline double closs_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # Loss log(1 + exp(raw)) - y_true * raw with y_pred = expit(raw);
    # log1pexp evaluates the first term without overflow.
    # log1p(exp(raw_prediction)) - y_true * raw_prediction
    return log1pexp(raw_prediction) - y_true * raw_prediction


cdef inline double cgradient_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # gradient = y_pred - y_true = expit(raw_prediction) - y_true
    # Numerically more stable, see http://fa.bianp.net/blog/2019/evaluate_logistic/
    #     if raw_prediction < 0:
    #         exp_tmp = exp(raw_prediction)
    #         return ((1 - y_true) * exp_tmp - y_true) / (1 + exp_tmp)
    #     else:
    #         exp_tmp = exp(-raw_prediction)
    #         return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
    # Note that optimal speed would be achieved, at the cost of precision, by
    #     return expit(raw_prediction) - y_true
    # i.e. no "if else" and an own inline implementation of expit instead of
    #     from scipy.special.cython_special cimport expit
    # The case distinction raw_prediction < 0 in the stable implementation does not
    # provide significant better precision apart from protecting overflow of exp(..).
    # The branch (if else), however, can incur runtime costs of up to 30%.
    # Instead, we help branch prediction by almost always ending in the first if clause
    # and making the second branch (else) a bit simpler. This has the exact same
    # precision but is faster than the stable implementation.
    # As branching criteria, we use the same cutoff as in log1pexp. Note that the
    # maximal value to get gradient = -1 with y_true = 1 is -37.439198610162731
    # (based on mpmath), and scipy.special.logit(np.finfo(float).eps) ~ -36.04365.
    cdef double exp_tmp
    if raw_prediction > -37:
        # exp(-raw_prediction) <= exp(37) here, so it cannot overflow.
        exp_tmp = exp(-raw_prediction)
        return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
    else:
        # expit(raw_prediction) = exp(raw_prediction) for raw_prediction <= -37
        return exp(raw_prediction) - y_true


cdef inline double_pair closs_grad_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = loss (as in closs_half_binomial), val2 = gradient (as in
    # cgradient_half_binomial), with one exp call per branch.
    cdef double_pair lg
    # Same if else conditions as in log1pexp.
    if raw_prediction <= -37:
        # log1pexp(x) == exp(x) and expit(x) == exp(x) in this range.
        lg.val2 = exp(raw_prediction)  # used as temporary
        lg.val1 = lg.val2 - y_true * raw_prediction                  # loss
        lg.val2 -= y_true                                            # gradient
    elif raw_prediction <= -2:
        lg.val2 = exp(raw_prediction)  # used as temporary
        lg.val1 = log1p(lg.val2) - y_true * raw_prediction           # loss
        lg.val2 = ((1 - y_true) * lg.val2 - y_true) / (1 + lg.val2)  # gradient
    elif raw_prediction <= 18:
        lg.val2 = exp(-raw_prediction)  # used as temporary
        # log1p(exp(x)) = log(1 + exp(x)) = x + log1p(exp(-x))
        lg.val1 = log1p(lg.val2) + (1 - y_true) * raw_prediction     # loss
        lg.val2 = ((1 - y_true) - y_true * lg.val2) / (1 + lg.val2)  # gradient
    else:
        # log1p(exp(-x)) ~ exp(-x) for x > 18.
        lg.val2 = exp(-raw_prediction)  # used as temporary
        lg.val1 = lg.val2 + (1 - y_true) * raw_prediction            # loss
        lg.val2 = ((1 - y_true) - y_true * lg.val2) / (1 + lg.val2)  # gradient
    return lg


cdef inline double_pair cgrad_hess_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # with y_pred = expit(raw)
    # hessian = y_pred * (1 - y_pred) = exp( raw) / (1 + exp( raw))**2
    #                                 = exp(-raw) / (1 + exp(-raw))**2
    cdef double_pair gh
    # See comment in cgradient_half_binomial.
    if raw_prediction > -37:
        gh.val2 = exp(-raw_prediction)  # used as temporary
        gh.val1 = ((1 - y_true) - y_true * gh.val2) / (1 + gh.val2)  # gradient
        gh.val2 = gh.val2 / (1 + gh.val2)**2                         # hessian
    else:
        # Here y_pred ~ exp(raw) and 1 - y_pred ~ 1, so hessian ~ exp(raw).
        gh.val2 = exp(raw_prediction)  # = 1. order Taylor in exp(raw_prediction)
        gh.val1 = gh.val2 - y_true
    return gh


# Exponential loss with (half) logit-link, aka boosting loss
cdef inline double closs_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # y_true * exp(-raw_prediction) + (1 - y_true) * exp(raw_prediction),
    # evaluated with a single exp call.
    # NOTE(review): once exp(raw_prediction) overflows to inf (raw > ~709) or
    # underflows to 0 (raw < ~-745), y_true == 1 resp. y_true == 0 gives
    # 0 * inf resp. 0 / 0, i.e. NaN. The same holds for the functions below.
    cdef double exp_raw = exp(raw_prediction)
    return y_true / exp_raw + (1 - y_true) * exp_raw


cdef inline double cgradient_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # -y_true * exp(-raw_prediction) + (1 - y_true) * exp(raw_prediction)
    cdef double exp_raw = exp(raw_prediction)
    return -y_true / exp_raw + (1 - y_true) * exp_raw


cdef inline double_pair closs_grad_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = loss, val2 = gradient, sharing one exp evaluation.
    cdef double_pair loss_grad
    cdef double exp_raw = exp(raw_prediction)
    loss_grad.val1 = y_true / exp_raw + (1 - y_true) * exp_raw
    loss_grad.val2 = -y_true / exp_raw + (1 - y_true) * exp_raw
    return loss_grad


cdef inline double_pair cgrad_hess_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # val1 = gradient, val2 = hessian; the hessian coincides with the loss.
    cdef double_pair grad_hess
    cdef double exp_raw = exp(raw_prediction)
    grad_hess.val1 = -y_true / exp_raw + (1 - y_true) * exp_raw
    grad_hess.val2 = y_true / exp_raw + (1 - y_true) * exp_raw
    return grad_hess


# ---------------------------------------------------
# Extension Types for Loss Functions of 1-dim targets
# ---------------------------------------------------
cdef class CyLossFunction:
    """Base class for convex loss functions."""

    def __reduce__(self):
        return (self.__class__, ())

    cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil:
        """Compute the loss for a single sample.

        Parameters
        ----------
        y_true : double
            Observed, true target value.
        raw_prediction : double
            Raw prediction value (in link space).

        Returns
        -------
        double
            The loss evaluated at `y_true` and `raw_prediction`.
        """
        pass

    cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil:
        """Compute gradient of loss w.r.t. raw_prediction for a single sample.

        Parameters
        ----------
        y_true : double
            Observed, true target value.
        raw_prediction : double
            Raw prediction value (in link space).

        Returns
        -------
        double
            The derivative of the loss function w.r.t. `raw_prediction`.
        """
        pass

    cdef double_pair cy_grad_hess(
        self, double y_true, double raw_prediction
    ) noexcept nogil:
        """Compute gradient and hessian.

        Gradient and hessian of loss w.r.t. raw_prediction for a single sample.

        This is usually diagonal in raw_prediction_i and raw_prediction_j.
        Therefore, we return the diagonal element i=j.

        For a loss with a non-canonical link, this might implement the diagonal
        of the Fisher matrix (=expected hessian) instead of the hessian.

        Parameters
        ----------
        y_true : double
            Observed, true target value.
        raw_prediction : double
            Raw prediction value (in link space).

        Returns
        -------
        double_pair
            Gradient and hessian of the loss function w.r.t. `raw_prediction`.
        """
        pass

    def loss(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] loss_out,             # OUT
        int n_threads=1
    ):
        """Compute the point-wise loss value for each input.

        The point-wise loss is written to `loss_out` and no array is returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Observed, true target values.
        raw_prediction : array of shape (n_samples,)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights.
        loss_out : array of shape (n_samples,)
            A location into which the result is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        pass

    def gradient(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] gradient_out,         # OUT
        int n_threads=1
    ):
        """Compute gradient of loss w.r.t raw_prediction for each input.

        The gradient is written to `gradient_out` and no array is returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Observed, true target values.
        raw_prediction : array of shape (n_samples,)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights.
        gradient_out : array of shape (n_samples,)
            A location into which the result is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        pass

    def loss_gradient(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] loss_out,             # OUT
        floating_out[::1] gradient_out,         # OUT
        int n_threads=1
    ):
        """Compute loss and gradient of loss w.r.t raw_prediction.

        The loss and gradient are written to `loss_out` and `gradient_out` and no arrays
        are returned.

        This generic implementation makes two passes over the data: it calls
        `loss` and then `gradient`. Subclasses that have a combined
        loss-and-gradient kernel override it with a single pass.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Observed, true target values.
        raw_prediction : array of shape (n_samples,)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights.
        loss_out : array of shape (n_samples,)
            A location into which the element-wise loss is stored. It must not
            be None: it is passed straight to `loss`, which writes into it.
        gradient_out : array of shape (n_samples,)
            A location into which the gradient is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        self.loss(y_true, raw_prediction, sample_weight, loss_out, n_threads)
        self.gradient(y_true, raw_prediction, sample_weight, gradient_out, n_threads)

    def gradient_hessian(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] gradient_out,         # OUT
        floating_out[::1] hessian_out,          # OUT
        int n_threads=1
    ):
        """Compute gradient and hessian of loss w.r.t raw_prediction.

        The gradient and hessian are written to `gradient_out` and `hessian_out` and no
        arrays are returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Observed, true target values.
        raw_prediction : array of shape (n_samples,)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights. The generated subclasses treat None as unweighted.
        gradient_out : array of shape (n_samples,)
            A location into which the gradient is stored.
        hessian_out : array of shape (n_samples,)
            A location into which the hessian is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        # Abstract stub: does nothing here; the concrete loss classes override it.
        pass


{{for name, docstring, param, closs, closs_grad, cgrad, cgrad_hess, in class_list}}
{{py:
if param is None:
    with_param = ""
else:
    with_param = ", self." + param
}}

# Concrete loss class generated from one entry of class_list. All per-sample
# math is delegated to the C-level helper functions named in that entry. When
# the loss has a parameter, its stored value self.<param> is appended as an
# extra trailing argument to every helper call.
cdef class {{name}}(CyLossFunction):
    """{{docstring}}"""

    {{if param is not None}}
    def __init__(self, {{param}}):
        self.{{param}} = {{param}}
    {{endif}}

    {{if param is not None}}
    # Pickle support: the instance is rebuilt from its single parameter.
    def __reduce__(self):
        return (self.__class__, (self.{{param}},))
    {{endif}}

    # Single-sample kernels, callable without the GIL.
    cdef inline double cy_loss(self, double y_true, double raw_prediction) noexcept nogil:
        return {{closs}}(y_true, raw_prediction{{with_param}})

    cdef inline double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil:
        return {{cgrad}}(y_true, raw_prediction{{with_param}})

    # Returns the gradient in val1 and the hessian in val2.
    cdef inline double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil:
        return {{cgrad_hess}}(y_true, raw_prediction{{with_param}})

    def loss(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] loss_out,             # OUT
        int n_threads=1
    ):
        # Element-wise loss, multiplied by sample_weight when one is given.
        # Samples are processed in parallel (prange) with n_threads threads.
        cdef:
            int i
            int n_samples = y_true.shape[0]

        if sample_weight is None:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[i] = {{closs}}(y_true[i], raw_prediction[i]{{with_param}})
        else:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[i] = sample_weight[i] * {{closs}}(y_true[i], raw_prediction[i]{{with_param}})

    {{if closs_grad is not None}}
    # Only generated when a combined loss-and-gradient helper exists; it
    # computes both quantities in one pass. Otherwise the base-class version,
    # which calls loss() and then gradient(), is inherited.
    def loss_gradient(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] loss_out,             # OUT
        floating_out[::1] gradient_out,         # OUT
        int n_threads=1
    ):
        cdef:
            int i
            int n_samples = y_true.shape[0]
            # Assigned inside the prange body, so each thread gets its own copy.
            double_pair dbl2

        if sample_weight is None:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                # val1 = loss, val2 = gradient
                dbl2 = {{closs_grad}}(y_true[i], raw_prediction[i]{{with_param}})
                loss_out[i] = dbl2.val1
                gradient_out[i] = dbl2.val2
        else:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                dbl2 = {{closs_grad}}(y_true[i], raw_prediction[i]{{with_param}})
                loss_out[i] = sample_weight[i] * dbl2.val1
                gradient_out[i] = sample_weight[i] * dbl2.val2

    {{endif}}

    def gradient(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] gradient_out,         # OUT
        int n_threads=1
    ):
        # Element-wise gradient, multiplied by sample_weight when one is given.
        cdef:
            int i
            int n_samples = y_true.shape[0]

        if sample_weight is None:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                gradient_out[i] = {{cgrad}}(y_true[i], raw_prediction[i]{{with_param}})
        else:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                gradient_out[i] = sample_weight[i] * {{cgrad}}(y_true[i], raw_prediction[i]{{with_param}})

    def gradient_hessian(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] gradient_out,         # OUT
        floating_out[::1] hessian_out,          # OUT
        int n_threads=1
    ):
        # Element-wise gradient and hessian, both multiplied by sample_weight
        # when one is given.
        cdef:
            int i
            int n_samples = y_true.shape[0]
            # Assigned inside the prange body, so each thread gets its own copy.
            double_pair dbl2

        if sample_weight is None:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                # val1 = gradient, val2 = hessian
                dbl2 = {{cgrad_hess}}(y_true[i], raw_prediction[i]{{with_param}})
                gradient_out[i] = dbl2.val1
                hessian_out[i] = dbl2.val2
        else:
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                dbl2 = {{cgrad_hess}}(y_true[i], raw_prediction[i]{{with_param}})
                gradient_out[i] = sample_weight[i] * dbl2.val1
                hessian_out[i] = sample_weight[i] * dbl2.val2

{{endfor}}


# The multinomial deviance loss is also known as categorical cross-entropy or
# multinomial log-likelihood.
# Here, we do not inherit from CyLossFunction as its cy_gradient method deviates
# from the API.
cdef class CyHalfMultinomialLoss():
    """Half Multinomial deviance loss with multinomial logit link.

    Domain:
    y_true in {0, 1, 2, 3, .., n_classes - 1}
    y_pred in (0, 1)**n_classes, i.e. interval with boundaries excluded

    Link:
    y_pred = softmax(raw_prediction)

    Note: Label encoding is built-in, i.e. {0, 1, 2, 3, .., n_classes - 1} is
    mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1.
    """

    # Here we deviate from the CyLossFunction API. SAG/SAGA needs direct access to
    # sample-wise gradients which we provide here.
    cdef inline void cy_gradient(
        self,
        const floating_in y_true,
        const floating_in[::1] raw_prediction,  # IN
        const floating_in sample_weight,
        floating_out[::1] gradient_out,         # OUT
    ) noexcept nogil:
        """Compute gradient of loss w.r.t. `raw_prediction` for a single sample.

        The gradient of the multinomial logistic loss with respect to a class k,
        and for one sample is:
        grad_k = sw * (p[k] - (y==k))

        where:
            p[k] = proba[k] = exp(raw_prediction[k] - logsumexp(raw_prediction))
            sw = sample_weight

        The gradient is written to `gradient_out`; nothing is returned.

        Parameters
        ----------
        y_true : floating
            Observed, true target value (label-encoded class index).
        raw_prediction : array of shape (n_classes,)
            Raw prediction values (in link space).
        sample_weight : floating
            Sample weight.
        gradient_out : array of shape (n_classes,)
            A location into which the gradient is stored. It also serves as
            scratch space for the intermediate exponentials.
        """
        cdef:
            int k
            int n_classes = raw_prediction.shape[0]
            double_pair max_value_and_sum_exps
            # View the 1d row as a (1, n_classes) array so it can be passed to
            # sum_exp_minus_max with row index 0.
            const floating_in[:, :] raw = raw_prediction[None, :]

        # Presumably writes exp(raw[0, k] - max) into gradient_out[k] and returns
        # (max, sum of those exps) as (val1, val2); TODO confirm against
        # sum_exp_minus_max.
        max_value_and_sum_exps = sum_exp_minus_max(0, raw, &gradient_out[0])
        for k in range(n_classes):
            # gradient_out[k] = p_k = y_pred_k = prob of class k
            gradient_out[k] /= max_value_and_sum_exps.val2
            # gradient_k = (p_k - (y_true == k)) * sw
            gradient_out[k] = (gradient_out[k] - (y_true == k)) * sample_weight

    def _test_cy_gradient(
        self,
        const floating_in[::1] y_true,             # IN
        const floating_in[:, ::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,      # IN
    ):
        """Expose the sample-wise `cy_gradient` to Python. For testing only.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Label-encoded target values.
        raw_prediction : C-contiguous array of shape (n_samples, n_classes)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights; None means a weight of 1.0 for every sample.

        Returns
        -------
        gradient : ndarray of shape (n_samples, n_classes)
            Gradient w.r.t. `raw_prediction`. Its dtype is float32 for float32
            inputs and float64 otherwise.
        """
        cdef:
            int i
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in [:, ::1] gradient_out
        # The output buffer is bound to a `floating_in` memoryview, so its dtype
        # must match the fused-type specialization: a float64 array cannot be
        # bound to a float32 memoryview.
        if floating_in is float:
            gradient = np.empty((n_samples, n_classes), dtype=np.float32)
        else:
            gradient = np.empty((n_samples, n_classes), dtype=np.float64)
        gradient_out = gradient

        for i in range(n_samples):
            self.cy_gradient(
                y_true=y_true[i],
                raw_prediction=raw_prediction[i, :],
                sample_weight=1.0 if sample_weight is None else sample_weight[i],
                gradient_out=gradient_out[i, :],
            )
        return gradient

    # Note that we do not assume memory alignment/contiguity of 2d arrays.
    # There seems to be little benefit in doing so. Benchmarks proving the
    # opposite are welcome.
    def loss(
        self,
        const floating_in[::1] y_true,           # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,    # IN
        floating_out[::1] loss_out,              # OUT
        int n_threads=1
    ):
        """Compute the element-wise multinomial loss.

        For sample i, the loss is log(sum_exps) + max_value - raw_prediction[i, k]
        with k = y_true[i]. Here (max_value, sum_exps) are returned by
        sum_exp_minus_max, i.e. presumably the result is
        logsumexp(raw_prediction[i, :]) minus the raw prediction of the true
        class. It is multiplied by sample_weight[i] when weights are given.
        The result is written to `loss_out` and no array is returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Label-encoded target values, i.e. class indices 0 .. n_classes - 1.
        raw_prediction : array of shape (n_samples, n_classes)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights; None means unweighted.
        loss_out : array of shape (n_samples,)
            A location into which the element-wise loss is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in max_value, sum_exps
            floating_in*  p  # temporary buffer
            double_pair max_value_and_sum_exps

        # We assume n_samples > n_classes. In this case having the inner loop
        # over n_classes is a good default.
        # TODO: If every memoryview is contiguous and raw_prediction is
        #       f-contiguous, can we write a better algo (loops) to improve
        #       performance?
        if sample_weight is None:
            # inner loop over n_classes (here inside sum_exp_minus_max)
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                # In this method p is only scratch space for sum_exp_minus_max;
                # its contents are not read afterwards.
                # NOTE(review): the malloc result is not checked for NULL.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true; assumed to be a valid class index
                    k = int(y_true[i])
                    loss_out[i] -= raw_prediction[i, k]

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
                    k = int(y_true[i])
                    loss_out[i] -= raw_prediction[i, k]

                    loss_out[i] *= sample_weight[i]

                free(p)

    def loss_gradient(
        self,
        const floating_in[::1] y_true,           # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,    # IN
        floating_out[::1] loss_out,              # OUT
        floating_out[:, :] gradient_out,         # OUT
        int n_threads=1
    ):
        """Compute the element-wise loss and its gradient in a single pass.

        The loss is computed as in `loss`. The gradient for class k is
        p_k - (y_true == k), where p_k is the predicted probability of class k.
        Both are multiplied by sample_weight[i] when weights are given. Results
        are written to `loss_out` and `gradient_out` and no arrays are returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Label-encoded target values, i.e. class indices 0 .. n_classes - 1.
        raw_prediction : array of shape (n_samples, n_classes)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights; None means unweighted.
        loss_out : array of shape (n_samples,)
            A location into which the element-wise loss is stored.
        gradient_out : array of shape (n_samples, n_classes)
            A location into which the gradient is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in max_value, sum_exps
            floating_in*  p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                # NOTE(review): the malloc result is not checked for NULL.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    # Presumably fills p[k] with exp(raw[i, k] - max); dividing by
                    # sum_exps below turns it into the class probability.
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
                        # label decode y_true
                        if y_true[i] == k:
                            loss_out[i] -= raw_prediction[i, k]
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = p_k - (y_true == k)
                        gradient_out[i, k] = p[k] - (y_true[i] == k)

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
                        # label decode y_true
                        if y_true[i] == k:
                            loss_out[i] -= raw_prediction[i, k]
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]

                    # Weight the loss only after the class loop has finished
                    # subtracting the true-class raw prediction.
                    loss_out[i] *= sample_weight[i]

                free(p)

    def gradient(
        self,
        const floating_in[::1] y_true,           # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,    # IN
        floating_out[:, :] gradient_out,         # OUT
        int n_threads=1
    ):
        """Compute the gradient of the loss w.r.t. `raw_prediction`.

        The gradient for sample i and class k is p_k - (y_true[i] == k), where
        p_k is the predicted probability of class k. It is multiplied by
        sample_weight[i] when weights are given. The result is written to
        `gradient_out` and no array is returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Label-encoded target values, i.e. class indices 0 .. n_classes - 1.
        raw_prediction : array of shape (n_samples, n_classes)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights; None means unweighted.
        gradient_out : array of shape (n_samples, n_classes)
            A location into which the gradient is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in*  p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                # NOTE(review): the malloc result is not checked for NULL.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = y_pred_k - (y_true == k)
                        gradient_out[i, k] = p[k] - (y_true[i] == k)

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]

                free(p)

    def gradient_hessian(
        self,
        const floating_in[::1] y_true,           # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,    # IN
        floating_out[:, :] gradient_out,         # OUT
        floating_out[:, :] hessian_out,          # OUT
        int n_threads=1
    ):
        """Compute the gradient and the diagonal hessian w.r.t. `raw_prediction`.

        For sample i and class k:
            gradient = p_k - (y_true[i] == k)
            hessian  = p_k * (1 - p_k)
        where p_k is the predicted probability of class k. Only the diagonal
        of the per-sample hessian over classes is computed. Both are multiplied
        by sample_weight[i] when weights are given. Results are written to
        `gradient_out` and `hessian_out` and no arrays are returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Label-encoded target values, i.e. class indices 0 .. n_classes - 1.
        raw_prediction : array of shape (n_samples, n_classes)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights; None means unweighted.
        gradient_out : array of shape (n_samples, n_classes)
            A location into which the gradient is stored.
        hessian_out : array of shape (n_samples, n_classes)
            A location into which the diagonal hessian is stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                # NOTE(review): the malloc result is not checked for NULL.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # hessian_k = p_k * (1 - p_k)
                        # gradient_k = p_k - (y_true == k)
                        gradient_out[i, k] = p[k] - (y_true[i] == k)
                        hessian_out[i, k] = p[k] * (1. - p[k])

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        # hessian_k = p_k * (1 - p_k) * sw
                        gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]
                        hessian_out[i, k] = (p[k] * (1. - p[k])) * sample_weight[i]

                free(p)

    # This method simplifies the implementation of hessp in linear models,
    # i.e. the matrix-vector product of the full hessian, not only of the
    # diagonal (in the classes) approximation as implemented above.
    def gradient_proba(
        self,
        const floating_in[::1] y_true,           # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,    # IN
        floating_out[:, :] gradient_out,         # OUT
        floating_out[:, :] proba_out,            # OUT
        int n_threads=1
    ):
        """Compute the gradient and the predicted class probabilities.

        For sample i and class k, proba_out[i, k] = p_k is the predicted
        probability of class k, and gradient_out[i, k] = p_k - (y_true[i] == k).
        Only the gradient is multiplied by sample_weight[i] when weights are
        given; the probabilities are never weighted. Results are written to
        `gradient_out` and `proba_out` and no arrays are returned.

        Parameters
        ----------
        y_true : array of shape (n_samples,)
            Label-encoded target values, i.e. class indices 0 .. n_classes - 1.
        raw_prediction : array of shape (n_samples, n_classes)
            Raw prediction values (in link space).
        sample_weight : array of shape (n_samples,) or None
            Sample weights; None means unweighted.
        gradient_out : array of shape (n_samples, n_classes)
            A location into which the gradient is stored.
        proba_out : array of shape (n_samples, n_classes)
            A location into which the predicted probabilities are stored.
        n_threads : int
            Number of threads used by OpenMP (if any).
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in*  p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                # NOTE(review): the malloc result is not checked for NULL.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
                        # gradient_k = y_pred_k - (y_true == k)
                        gradient_out[i, k] = proba_out[i, k] - (y_true[i] == k)

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        gradient_out[i, k] = (proba_out[i, k] - (y_true[i] == k)) * sample_weight[i]

                free(p)
