from cython cimport final

from sklearn.utils._typedefs cimport intp_t, float64_t

{{for name_suffix in ['64', '32']}}

from sklearn.metrics._pairwise_distances_reduction._datasets_pair cimport DatasetsPair{{name_suffix}}


# Declaration only: the body lives in the implementation file, not here.
# One such declaration is generated per suffix of the enclosing template
# loop ('64' and '32').
#
# - X is left untyped, so it is passed as a generic Python object.
# - num_threads is a C integer (intp_t); going by its name, it is the number
#   of threads the implementation may use -- confirm against the definition.
# - The return type is a C-contiguous 1-D float64 memoryview for both
#   suffixes (the return type does not depend on the suffix).
#
# NOTE(review): going by the name, this returns the squared Euclidean norm of
# each row of X; the computation itself is not visible from this file.
cpdef float64_t[::1] _sqeuclidean_row_norms{{name_suffix}}(
    X,
    intp_t num_threads,
)

cdef class BaseDistancesReduction{{name_suffix}}:
    """
    Base float{{name_suffix}} implementation template of the pairwise-distances
    reduction backends.

    Implementations inherit from this template and may override the several
    defined hooks as needed in order to easily extend functionality with
    minimal redundant code.
    """

    cdef:
        # Pair of datasets handled by this reduction, typed with the
        # DatasetsPair variant matching this template's suffix. `readonly`
        # exposes the attribute to Python code for reading only.
        readonly DatasetsPair{{name_suffix}} datasets_pair

        # The number of threads that can be used is stored in
        # effective_n_threads.
        #
        # The number of threads used by the parallelization strategy
        # (i.e. parallel_on_X or parallel_on_Y) can be smaller than
        # effective_n_threads: for small datasets, fewer threads might be
        # needed to loop over pairs of chunks.
        #
        # Hence, the number of threads that _will_ be used for looping over
        # chunks is stored in chunks_n_threads, so that only the threads
        # actually needed are used.
        #
        # Thus, an invariant is:
        #
        #                 chunks_n_threads <= effective_n_threads
        #
        intp_t effective_n_threads
        intp_t chunks_n_threads

        # Chunking parameters.
        # NOTE(review): how n_samples_chunk is derived from chunk_size is
        # decided by the implementation and is not visible in this file.
        intp_t n_samples_chunk, chunk_size

        # Per-dataset chunking bookkeeping for X and for Y. Going by the
        # attribute names: total number of samples, number of samples per
        # chunk, number of chunks, and number of samples in the last chunk
        # -- confirm against the code that assigns them.
        intp_t n_samples_X, X_n_samples_chunk, X_n_chunks, X_n_samples_last_chunk
        intp_t n_samples_Y, Y_n_samples_chunk, Y_n_chunks, Y_n_samples_last_chunk

        # Presumably selects the parallel-on-Y strategy (_parallel_on_Y)
        # over the parallel-on-X one (_parallel_on_X) -- confirm at the
        # place where this flag is read.
        bint execute_in_parallel_on_Y

    # Every method declared below is `noexcept nogil`: it can be called
    # without holding the GIL and does not propagate Python exceptions to
    # its caller.

    # The two strategy drivers are `@final`: subclasses cannot override
    # them and are expected to customize behaviour through the hooks
    # declared further down instead.
    @final
    cdef void _parallel_on_X(self) noexcept nogil

    @final
    cdef void _parallel_on_Y(self) noexcept nogil

    # Placeholder methods which have to be implemented

    # Core hook. It receives the bounds of one chunk of X (X_start, X_end),
    # the bounds of one chunk of Y (Y_start, Y_end) and the number of the
    # calling thread.
    # NOTE(review): whether the *_end bounds are exclusive is defined by the
    # implementation -- confirm there.
    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil


    # Placeholder methods which can be implemented
    #
    # The call sites of the hooks below are in the implementation file, not
    # in this declaration file. Their names are prefixed with the strategy
    # they belong to (_parallel_on_X_* or _parallel_on_Y_*); the stage each
    # name suggests should be confirmed against the two drivers above.

    cdef void compute_exact_distances(self) noexcept nogil

    # Hooks for the parallel-on-X strategy.

    cdef void _parallel_on_X_parallel_init(
        self,
        intp_t thread_num,
    ) noexcept nogil

    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    # Same parameters as _compute_and_reduce_distances_on_chunks;
    # presumably invoked right before it -- confirm in the driver.
    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    cdef void _parallel_on_X_parallel_finalize(
        self,
        intp_t thread_num
    ) noexcept nogil

    # Hooks for the parallel-on-Y strategy.

    # Takes no thread number, unlike the per-thread hooks of this strategy.
    cdef void _parallel_on_Y_init(
        self,
    ) noexcept nogil

    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    # Same parameters as _compute_and_reduce_distances_on_chunks;
    # presumably invoked right before it -- confirm in the driver.
    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil

    # Receives only the bounds of a chunk of X and no thread number.
    # NOTE(review): the name suggests this merges per-thread results for
    # that chunk of X -- confirm in the implementation.
    cdef void _parallel_on_Y_synchronize(
        self,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    # Takes no thread number, like _parallel_on_Y_init.
    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil
{{endfor}}
