{{py:

"""
Efficient (dense) parameter vector implementation for linear models.

Template file used to generate dtype-specific (fused-type-like) code with Tempita
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).

Generated file: weight_vector.pyx

Each class is duplicated for all dtypes (float and double). The keywords
between double braces are substituted during the build.
"""

# One entry per generated class: (name_suffix, c_type, reset_wscale_threshold).
# ``reset_wscale_threshold`` is the value below which ``scale`` calls
# ``reset_wscale`` to fold ``wscale`` back into the stored weights.
dtypes = [
    ('64', 'double', 1e-9),
    ('32', 'float', 1e-6),
]

}}

# cython: binding=False
#
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

cimport cython
from libc.limits cimport INT_MAX
from libc.math cimport sqrt, fabs

from sklearn.utils._cython_blas cimport _dot, _scal, _axpy, _asum

{{for name_suffix, c_type, reset_wscale_threshold in dtypes}}

cdef class WeightVector{{name_suffix}}(object):
    """Dense vector represented by a scalar and a numpy array.

    The class provides methods to ``add`` a sparse vector
    and scale the vector.
    Representing a vector explicitly as a scalar times a
    vector allows for efficient scaling operations.

    The represented vector is ``wscale * w``. ``sq_norm`` and ``l1_norm``
    describe that scaled vector and are maintained incrementally by
    ``add`` and ``scale`` rather than being recomputed from scratch.

    Attributes
    ----------
    w : ndarray, dtype={{c_type}}, order='C'
        The numpy array which backs the weight vector.
    aw : ndarray, dtype={{c_type}}, order='C'
        The numpy array which backs the average_weight vector.
    w_data_ptr : {{c_type}}*
        A pointer to the data of the numpy array.
    aw_data_ptr : {{c_type}}*
        A pointer to the data of ``aw``; only set when ``aw`` is not None.
    wscale : {{c_type}}
        The scale of the vector.
    n_features : int
        The number of features (= dimensionality of ``w``).
    sq_norm : {{c_type}}
        The squared norm of ``w``.
    l1_norm : {{c_type}}
        The L1 norm of ``w``.
    average_a, average_b
        Scalars of the lazy averaging scheme used by ``add_average``;
        ``reset_wscale`` folds them into ``aw`` and resets them to 0 and 1.
    """

    def __cinit__(self,
                  {{c_type}}[::1] w,
                  {{c_type}}[::1] aw):

        if w.shape[0] > INT_MAX:
            raise ValueError("More than %d features not supported; got %d."
                             % (INT_MAX, w.shape[0]))
        self.w = w
        self.w_data_ptr = &w[0]
        self.wscale = 1.0
        self.n_features = w.shape[0]
        # Initial norms are computed once over the full vector; afterwards
        # they are only updated incrementally.
        self.sq_norm = _dot(self.n_features, self.w_data_ptr, 1, self.w_data_ptr, 1)
        self.l1_norm = _asum(self.n_features, self.w_data_ptr, 1)

        self.aw = aw
        if self.aw is not None:
            self.aw_data_ptr = &aw[0]
            self.average_a = 0.0
            self.average_b = 1.0

    cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz,
                  {{c_type}} c) noexcept nogil:
        """Scales sample x by constant c and adds it to the weight vector.

        This operation updates ``sq_norm`` and ``l1_norm`` incrementally:
        only the coordinates listed in ``x_ind_ptr`` change, so only their
        contribution to the norms is replaced. The cost is O(xnnz), not
        O(n_features). Repeated indices are handled correctly because each
        occurrence accounts for the value it actually overwrote.

        Parameters
        ----------
        x_data_ptr : {{c_type}}*
            The array which holds the feature values of ``x``.
        x_ind_ptr : np.intc*
            The array which holds the feature indices of ``x``.
        xnnz : int
            The number of non-zero features of ``x``.
        c : {{c_type}}
            The scaling constant for the example.
        """
        cdef int j
        cdef int idx
        cdef double val
        cdef double old_w
        cdef double new_w
        # Change of sum(w_i ** 2) and sum(|w_i|) over the touched coordinates,
        # expressed in the unscaled units of ``w`` (before applying wscale).
        cdef double sq_norm_delta = 0.0
        cdef double l1_norm_delta = 0.0

        # the next two lines save a factor of 2!
        cdef {{c_type}} wscale = self.wscale
        cdef {{c_type}}* w_data_ptr = self.w_data_ptr

        for j in range(xnnz):
            idx = x_ind_ptr[j]
            val = x_data_ptr[j]
            old_w = w_data_ptr[idx]
            # Divide by wscale because the stored w is later multiplied by it.
            w_data_ptr[idx] += val * (c / wscale)
            # Re-read after the store so the delta uses the value actually
            # kept in w (which may have been rounded to {{c_type}}).
            new_w = w_data_ptr[idx]
            sq_norm_delta += new_w * new_w - old_w * old_w
            l1_norm_delta += fabs(new_w) - fabs(old_w)

        # Norms describe wscale * w, hence the wscale factors; fabs keeps the
        # L1 update consistent with ``scale``, which uses fabs(c).
        self.sq_norm += sq_norm_delta * (wscale * wscale)
        self.l1_norm += l1_norm_delta * fabs(wscale)

    # Update the average weights according to the sparse trick defined
    # here: https://research.microsoft.com/pubs/192769/tricks-2012.pdf
    # by Leon Bottou
    cdef void add_average(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz,
                          {{c_type}} c, {{c_type}} num_iter) noexcept nogil:
        """Updates the average weight vector.

        Only the coordinates of ``x`` are touched in ``aw``; the rest of the
        averaging is deferred through the scalars ``average_a`` and
        ``average_b``, which are updated after the loop.

        Parameters
        ----------
        x_data_ptr : {{c_type}}*
            The array which holds the feature values of ``x``.
        x_ind_ptr : np.intc*
            The array which holds the feature indices of ``x``.
        xnnz : int
            The number of non-zero features of ``x``.
        c : {{c_type}}
            The scaling constant for the example.
        num_iter : {{c_type}}
            The total number of iterations.
        """
        cdef int j
        cdef int idx
        cdef double val
        cdef double mu = 1.0 / num_iter
        # Local copy of the attribute; it is not modified inside the loop.
        cdef double average_a = self.average_a
        cdef double wscale = self.wscale
        cdef {{c_type}}* aw_data_ptr = self.aw_data_ptr

        for j in range(xnnz):
            idx = x_ind_ptr[j]
            val = x_data_ptr[j]
            aw_data_ptr[idx] += (average_a * val * (-c / wscale))

        # Once the sample has been processed
        # update the average_a and average_b
        if num_iter > 1:
            self.average_b /= (1.0 - mu)
        self.average_a += mu * self.average_b * wscale

    cdef {{c_type}} dot(self, {{c_type}} *x_data_ptr, int *x_ind_ptr,
                    int xnnz) noexcept nogil:
        """Computes the dot product of a sample x and the weight vector.

        Parameters
        ----------
        x_data_ptr : {{c_type}}*
            The array which holds the feature values of ``x``.
        x_ind_ptr : np.intc*
            The array which holds the feature indices of ``x``.
        xnnz : int
            The number of non-zero features of ``x`` (length of x_ind_ptr).

        Returns
        -------
        innerprod : {{c_type}}
            The inner product of ``x`` and ``w``.
        """
        cdef int j
        cdef int idx
        cdef double innerprod = 0.0
        cdef {{c_type}}* w_data_ptr = self.w_data_ptr
        for j in range(xnnz):
            idx = x_ind_ptr[j]
            innerprod += w_data_ptr[idx] * x_data_ptr[j]
        # Apply the scale once at the end instead of per coordinate.
        innerprod *= self.wscale
        return innerprod

    cdef void scale(self, {{c_type}} c) noexcept nogil:
        """Scales the weight vector by a constant ``c``.

        It updates ``wscale``, ``sq_norm``, and ``l1_norm``. If ``wscale`` gets too
        small we call ``reset_wscale``."""
        self.wscale *= c
        self.sq_norm *= (c * c)
        self.l1_norm *= fabs(c)

        # Avoid underflow / precision loss from an ever-shrinking wscale.
        if self.wscale < {{reset_wscale_threshold}}:
            self.reset_wscale()

    cdef void reset_wscale(self) noexcept nogil:
        """Scales each coef of ``w`` by ``wscale`` and resets it to 1.

        When averaging is enabled, the deferred ``average_a`` / ``average_b``
        terms are first folded into ``aw``, and both scalars are reset.
        """
        if self.aw_data_ptr != NULL:
            _axpy(self.n_features, self.average_a,
                  self.w_data_ptr, 1, self.aw_data_ptr, 1)
            _scal(self.n_features, 1.0 / self.average_b, self.aw_data_ptr, 1)
            self.average_a = 0.0
            self.average_b = 1.0

        _scal(self.n_features, self.wscale, self.w_data_ptr, 1)
        self.wscale = 1.0

    cdef {{c_type}} norm(self) noexcept nogil:
        """The L2 norm of the weight vector. """
        return sqrt(self.sq_norm)

    cdef {{c_type}} l1norm(self) noexcept nogil:
        """The L1 norm of the weight vector. """
        return self.l1_norm

{{endfor}}
