import functools

import torch

from torch._inductor.runtime.runtime_utils import ceildiv

# NOTE: the CuTeDSL bindings below are assumed imports; in the generated file
# they may instead be injected by the Inductor template harness.
import cutlass
import cutlass.cute as cute
import cuda.bindings.driver as cuda
from cutlass.cute.runtime import from_dlpack
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, vendored at PyTorch build time from the CUTLASS repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
    GroupedGemmKernel,
)


# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
#   1. prep_cache
#      Caches the compiled executor for build_group_ptrs_from_bases(). This
#      kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
#      and can therefore be safely reused across runs with different group
#      partitioning (`offs`).
#
#   2. gemm_cache
#      Caches the compiled Grouped GEMM executor. Its key extends the prep
#      cache key with hardware- and grid-specific parameters:
#      (prep_cache_key, max_active_clusters, total_num_clusters).
#      This is necessary because different `offs` tensors can change the
#      per-group problem sizes and thus alter `total_num_clusters`, which in
#      turn changes the grid shape and persistent scheduler configuration.
#      Kernels compiled for one grid cannot be safely reused for another.
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache its result to avoid paying
# that cost again when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.

prep_cache = {}
gemm_cache = {}


@functools.lru_cache
def get_hardware_info():
    hw = cutlass.utils.HardwareInfo()
    # Cluster size 1 => one cluster per SM, so this reports the SM count.
    sm_count = hw.get_max_active_clusters(1)
    max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)

    return (sm_count, max_active_clusters)
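
# Illustration (hypothetical device): with 148 SMs and CLUSTER_M * CLUSTER_N == 2,
# get_hardware_info() might return (148, 74), since each 2-CTA cluster occupies
# two SMs while a 1-CTA cluster fits on every SM.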


def get_prep_cache_key(input_a, input_b, output):
    """
    Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
    shapes, strides, and dtypes of input/output tensors.
    """
    return (
        tuple(input_a.shape),
        tuple(input_a.stride()),
        input_a.dtype,
        tuple(input_b.shape),
        tuple(input_b.stride()),
        input_b.dtype,
        tuple(output.shape),
        tuple(output.stride()),
        output.dtype,
    )


def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
    """
    Returns a tuple key for caching the gemm kernel executor by extending the
    prep cache key with hardware- and grid-specific parameters.
    """
    return (
        prep_cache_key,
        max_active_clusters,
        total_num_clusters,
    )
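
# Illustration (hypothetical values): on a device where max_active_clusters = 74
# and the current `offs` partition yields total_num_clusters = 12, the gemm key
# is (prep_key, 74, 12). A different `offs` that changes total_num_clusters
# compiles a new gemm executor, while the prep executor keyed by prep_key alone
# is still reused.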


@cute.kernel
def build_group_ptrs_from_bases_kernel(
    base_A_u64: cutlass.Int64,  # device addr of input_a (bytes)
    base_B_u64: cutlass.Int64,  # device addr of input_b (bytes)
    base_C_u64: cutlass.Int64,  # device addr of output C (bytes)
    offs: cute.Tensor,  # [G], cutlass.Int32/64 cumulative
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Int32,  # bytes
    # -------- STRIDES (in ELEMENTS) --------
    stride_A_m_elems: cutlass.Constexpr,  # A.stride(0)
    stride_A_k_elems: cutlass.Constexpr,  # A.stride(1)
    stride_B0_elems: cutlass.Constexpr,  # B group stride, B.stride(0)
    stride_Bk_elems: cutlass.Constexpr,  # B K-dim stride
    stride_Bn_elems: cutlass.Constexpr,  # B N-dim stride
    stride_C_m_elems: cutlass.Constexpr,  # C.stride(0)
    stride_C_n_elems: cutlass.Constexpr,  # C.stride(1)
    # -------- OUTPUTS --------
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32: (m_g, n, k, 1)
    out_strides_abc: cute.Tensor,  # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
    tidx, _, _ = cute.arch.thread_idx()
    g = tidx

    m_beg_i32 = 0
    if g > 0:
        m_beg_i32 = offs[g - 1]
    m_end_i32 = offs[g]
    m_g_i32 = m_end_i32 - m_beg_i32

    a_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
    )
    c_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
    )
    b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)

    # ---- pointers ----
    out_ptrs[g, 0] = base_A_u64 + a_byte_off
    out_ptrs[g, 1] = base_B_u64 + b_byte_off
    out_ptrs[g, 2] = base_C_u64 + c_byte_off

    # ---- (m, n, k, 1) ----
    out_problem[g, 0] = m_g_i32
    out_problem[g, 1] = N
    out_problem[g, 2] = K
    out_problem[g, 3] = cutlass.Int32(1)

    # ---- strides ----
    out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
    out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
    out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
    out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
    out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
    out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
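
# Illustration (hypothetical values): with offs = [128, 300, 512] (G = 3),
# thread g = 1 computes m_beg = 128 and m_g = 172, then writes
#   out_ptrs[1]    = (base_A + 128 * stride_A_m * sizeof_element,
#                     base_B + 1 * stride_B0 * sizeof_element,
#                     base_C + 128 * stride_C_m * sizeof_element)
#   out_problem[1] = (172, N, K, 1)
# i.e., rows [128, 300) of A and C become group 1's sub-GEMM.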


@cute.jit
def launch_build_group_ptrs_from_bases(
    base_A_u64: cutlass.Int64,
    base_B_u64: cutlass.Int64,
    base_C_u64: cutlass.Int64,
    offs: cute.Tensor,
    G: cutlass.Constexpr,
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Constexpr,
    stride_A_m_elems: cutlass.Constexpr,
    stride_A_k_elems: cutlass.Constexpr,
    stride_B0_elems: cutlass.Constexpr,
    stride_Bk_elems: cutlass.Constexpr,
    stride_Bn_elems: cutlass.Constexpr,
    stride_C_m_elems: cutlass.Constexpr,
    stride_C_n_elems: cutlass.Constexpr,
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32
    out_strides_abc: cute.Tensor,  # [G,3,2] cutlass.Int32
    stream: cuda.CUstream,
):
    build_group_ptrs_from_bases_kernel(
        base_A_u64,
        base_B_u64,
        base_C_u64,
        offs,
        K,
        N,
        sizeof_element,
        stride_A_m_elems,
        stride_A_k_elems,
        stride_B0_elems,
        stride_Bk_elems,
        stride_Bn_elems,
        stride_C_m_elems,
        stride_C_n_elems,
        out_ptrs,
        out_problem,
        out_strides_abc,
    ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)  # one thread per group; assumes G <= 1024


{{def_kernel("input_a", "input_b", "input_a_offs")}}
    # Wrap the raw CUDA stream handle passed in by the Inductor runtime.
    stream = cuda.CUstream(stream)

    # View B as [G, N, K]; transpose(1, 2) permutes strides without moving data.
    input_b = input_b.transpose(1, 2)

    sumM, K = input_a.shape
    G, N, Kb = input_b.shape

    dev = input_a.device

    base_A_u64 = int(input_a.data_ptr())
    base_B_u64 = int(input_b.data_ptr())
    base_C_u64 = int({{get_output()}}.data_ptr())

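    # Per-group descriptor buffers, filled on-device by the prep kernel:
    # base pointers (A, B, C), problem shapes (m, n, k, l), and element strides.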
    ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
    probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
    strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
    ptrs = from_dlpack(ptrs_t)
    probs = from_dlpack(probs_t)
    strides = from_dlpack(strides_t)

    prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
    prep_executor = prep_cache.get(prep_cache_key)

    if prep_executor is None:
        sizeof_element = int(input_a.element_size())
        sA_m, sA_k = map(int, input_a.stride())
        sB_0, sB_n, sB_k = map(int, input_b.stride())
        sC_m, sC_n = map(int, {{get_output()}}.stride())

        prep_executor = cute.compile(
            launch_build_group_ptrs_from_bases,
            base_A_u64=base_A_u64,
            base_B_u64=base_B_u64,
            base_C_u64=base_C_u64,
            offs=from_dlpack(input_a_offs),
            G=int(G),
            K=int(K),
            N=int(N),
            sizeof_element=sizeof_element,
            stride_A_m_elems=sA_m,
            stride_A_k_elems=sA_k,
            stride_B0_elems=sB_0,
            stride_Bk_elems=sB_k,
            stride_Bn_elems=sB_n,
            stride_C_m_elems=sC_m,
            stride_C_n_elems=sC_n,
            out_ptrs=ptrs,
            out_problem=probs,
            out_strides_abc=strides,
            stream=stream,
        )

        prep_cache[prep_cache_key] = prep_executor

    prep_executor(
        base_A_u64=base_A_u64,
        base_B_u64=base_B_u64,
        base_C_u64=base_C_u64,
        offs=from_dlpack(input_a_offs),
        out_ptrs=ptrs,
        out_problem=probs,
        out_strides_abc=strides,
        stream=stream,
    )

    # --- Tensormap workspace: one buffer set per SM, stored as int64 words
    # (hence bytes_per_tensormap // 8 elements per tensormap) ---
    num_tensormap_buffers, max_active_clusters = get_hardware_info()
    tensormap_shape = (
        num_tensormap_buffers,
        GroupedGemmKernel.num_tensormaps,
        GroupedGemmKernel.bytes_per_tensormap // 8,
    )
    tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
    tensormap_workspace = from_dlpack(tensormap_workspace_t)

    # --- Total clusters ---
    def compute_total_num_clusters(
        problem_sizes_mnkl,
        cluster_tile_shape_mn,
    ):
        total_num_clusters = 0
        for m, n, _, _ in problem_sizes_mnkl:
            num_clusters_mn = tuple(
                ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
            )
            total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
        return total_num_clusters
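    # Illustration (hypothetical values): with cluster_tile_shape_mn = (256, 128),
    # groups of sizes (300, 512, k, 1) and (100, 512, k, 1) contribute
    # ceildiv(300, 256) * ceildiv(512, 128) = 2 * 4 = 8 and 1 * 4 = 4 clusters,
    # i.e. total_num_clusters = 12.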

    # Compute cluster tile shape
    def compute_cluster_tile_shape(
        mma_tiler_mn,
        cluster_shape_mn,
        use_2cta_instrs,
    ):
        cta_tile_shape_mn = list(mma_tiler_mn)
        if use_2cta_instrs:
            cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
        return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
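    # Illustration: with mma_tiler_mn = (256, 128), cluster_shape_mn = (2, 1),
    # and use_2cta_instrs = True, each CTA owns a (128, 128) tile and the
    # cluster tile is (256, 128).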

    cluster_tile_shape_mn = compute_cluster_tile_shape(
        (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
    )

    total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))

    gemm_cache_key = get_gemm_cache_key(
        prep_cache_key, max_active_clusters, total_num_clusters
    )
    gemm_executor = gemm_cache.get(gemm_cache_key)

    if gemm_executor is None:
        grouped_gemm = GroupedGemmKernel(
            acc_dtype=ACC_DTYPE,
            use_2cta_instrs=USE_2_CTA,
            mma_tiler_mn=(TILE_M, TILE_N),
            cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
            tensormap_update_mode=TENSORMAP_UPDATE_MODE,
        )

        gemm_executor = cute.compile(
            grouped_gemm,
            # unsqueeze(-1) appends the L mode of the kernel's (M, N, K, L)
            # layout; assumed_align=16 promises TMA-compatible base alignment.
            from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
            from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
            from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
            G,
            probs,
            strides,
            ptrs,
            total_num_clusters,
            tensormap_workspace,
            max_active_clusters,
            stream,
        )

        gemm_cache[gemm_cache_key] = gemm_executor

    gemm_executor(
        from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
        from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
        from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
        probs,
        strides,
        ptrs,
        tensormap_workspace,
        stream,
    )
