{% macro assign_maybe_constexpr(name, value_expr) -%}
    {#- Render `name = value_expr` as a single line with no leading or
        trailing whitespace (the caller supplies the indentation). When the
        rendered value parses as a number through Jinja's int or float
        filters, the assignment is annotated `: tl.constexpr` so Triton sees a
        compile-time constant; anything else (e.g. a symbolic expression) is
        emitted as a plain assignment. -#}
    {%- set text = value_expr | string -%}
    {%- set marker = "__NOT_A_NUMBER__" -%}
    {%- set numeric = (text | int(default=marker)) != marker
                      or (text | float(default=marker)) != marker -%}
    {%- set annotation = ": tl.constexpr" if numeric else "" -%}
    {{ name }}{{ annotation }} = {{ value_expr }}
{%- endmacro %}

import triton
import triton.language as tl

@triton.jit
def do_tma_loads(
    g, a_desc, b_desc, m_offset, n_offset, k_offset,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Load one A tile and one B tile through their tensor descriptors.

    Each tile keeps its operand's stored orientation: A is (BLOCK_M, BLOCK_K)
    when K-major and (BLOCK_K, BLOCK_M) otherwise; B is (BLOCK_N, BLOCK_K)
    when K-major and (BLOCK_K, BLOCK_N) otherwise. For 3D descriptors the
    group index ``g`` picks the slice and the leading unit block dimension is
    reshaped away; ``g`` is not used for 2D descriptors.
    """
{%- set a_coords = "m_offset, k_offset" if A_IS_K_MAJOR else "k_offset, m_offset" %}
{%- set a_tile = "BLOCK_M, BLOCK_K" if A_IS_K_MAJOR else "BLOCK_K, BLOCK_M" %}
{%- set b_coords = "n_offset, k_offset" if B_IS_K_MAJOR else "k_offset, n_offset" %}
{%- set b_tile = "BLOCK_N, BLOCK_K" if B_IS_K_MAJOR else "BLOCK_K, BLOCK_N" %}
{%- if A_IS_2D %}
    a = a_desc.load([{{ a_coords }}])
{%- else %}
    a = a_desc.load([g, {{ a_coords }}]).reshape({{ a_tile }})
{%- endif %}
{%- if B_IS_2D %}
    b = b_desc.load([{{ b_coords }}])
{%- else %}
    b = b_desc.load([g, {{ b_coords }}]).reshape({{ b_tile }})
{%- endif %}

    return (a, b)


@triton.jit
def do_mma(a, b, accumulator):
    """Add the product of an A tile and a B tile into ``accumulator``.

    The template appends ``.T`` where needed so that A is used M-by-K and B
    K-by-N regardless of which operand is K-major. With USE_FAST_ACCUM the
    accumulator is passed to tl.dot as its accumulator argument; otherwise
    the dot result is computed separately and added to it.
    """
{%- set lhs = "a" if A_IS_K_MAJOR else "a.T" %}
{%- set rhs = "b.T" if B_IS_K_MAJOR else "b" %}
{%- if USE_FAST_ACCUM %}
    accumulator = tl.dot({{ lhs }}, {{ rhs }}, accumulator)
{%- else %}
    accumulator += tl.dot({{ lhs }}, {{ rhs }})
{%- endif %}

    return accumulator


{%- if SCALED %}
{%- if A_IS_2D or B_IS_2D %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}}
{%- else %}
{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}}
{%- endif %}
{%- else %}
{%- if A_IS_2D or B_IS_2D %}
{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}}
{%- else %}
{{def_kernel("a_ptr", "b_ptr")}}
{%- endif %}
{%- endif %}
{#-
  Grouped (optionally scaled) matmul kernel body.

  A_IS_2D / B_IS_2D select whether each operand is a single 2D tensor whose
  varying dimension is split into groups by the cumulative end offsets in
  offsets_ptr, or a 3D tensor with a leading group dimension. Exactly one of
  M, N, K varies per group when at least one operand is 2D (see the
  M/N/K_IS_VARYING flags below); with two 3D operands nothing varies.

  Work distribution: the program with id p handles global tile indices
  p, p + NUM_SMS, p + 2 * NUM_SMS, ...; tiles are numbered group by group and,
  within a group, with the M tile index varying fastest.
#}
    tidx = tl.program_id(0).to(INDEX_DTYPE)

{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %}
{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %}
{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %}

{#- Number of groups: the offsets length when both operands are 2D, otherwise
    the leading dimension of a 3D operand (A's whenever A is 3D). #}
{%- if A_IS_2D and B_IS_2D %}
    {{ assign_maybe_constexpr("G", size("offsets_ptr", 0)) }}
{%- elif A_IS_2D %}
    {{ assign_maybe_constexpr("G", size("b_ptr", 0)) }}
{%- else %}
    {{ assign_maybe_constexpr("G", size("a_ptr", 0)) }}
{%- endif %}

    # A's last two dims are (M, K); B's last dim is N and its second-to-last
    # is K. B's strides are bound as B_STRIDE_N / B_STRIDE_K so B tiles can be
    # addressed N-by-K, which is why the dot products below use b.T.

    {{ assign_maybe_constexpr("M", size("a_ptr", -2)) }}
    {{ assign_maybe_constexpr("N", size("b_ptr", -1)) }}
    {{ assign_maybe_constexpr("K", size("a_ptr", -1)) }}

    {{ assign_maybe_constexpr("A_STRIDE_M", stride("a_ptr", -2)) }}
    {{ assign_maybe_constexpr("A_STRIDE_K", stride("a_ptr", -1)) }}
{%- if not A_IS_2D %}
    {{ assign_maybe_constexpr("A_STRIDE_G", stride("a_ptr", 0)) }}
{%- if SCALED %}
    {{ assign_maybe_constexpr("SCALE_A_STRIDE_G", stride("scale_a_ptr", 0)) }}
{%- endif %}
{%- endif %}
    {{ assign_maybe_constexpr("B_STRIDE_N", stride("b_ptr", -1)) }}
    {{ assign_maybe_constexpr("B_STRIDE_K", stride("b_ptr", -2)) }}
{%- if not B_IS_2D %}
{#- Bind B_STRIDE_G exactly once, through the macro, as is done for
    A_STRIDE_G; a second plain assignment would drop its tl.constexpr
    annotation. #}
    {{ assign_maybe_constexpr("B_STRIDE_G", stride("b_ptr", 0)) }}
{%- if SCALED %}
    {{ assign_maybe_constexpr("SCALE_B_STRIDE_G", stride("scale_b_ptr", 0)) }}
{%- endif %}
{%- endif %}

{%- if USE_TMA_LOAD %}
{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %}
    a_desc = tl._experimental_make_tensor_descriptor(
{%- else %}
    a_desc = tl.make_tensor_descriptor(
{%- endif %}
        a_ptr,
{%- if A_IS_2D %}
{%- if A_IS_K_MAJOR %}
        shape=[M, K],
        strides=[A_STRIDE_M, A_STRIDE_K],
        block_shape=[BLOCK_M, BLOCK_K],
{%- else %}
        shape=[K, M],
        strides=[A_STRIDE_K, A_STRIDE_M],
        block_shape=[BLOCK_K, BLOCK_M],
{%- endif %}
{%- else %}
{%- if A_IS_K_MAJOR %}
        shape=[G, M, K],
        strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
        block_shape=[1, BLOCK_M, BLOCK_K],
{%- else %}
        shape=[G, K, M],
        strides=[A_STRIDE_G, A_STRIDE_K, A_STRIDE_M],
        block_shape=[1, BLOCK_K, BLOCK_M],
{%- endif %}
{%- endif %}
    )

{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %}
    b_desc = tl._experimental_make_tensor_descriptor(
{%- else %}
    b_desc = tl.make_tensor_descriptor(
{%- endif %}
        b_ptr,
{%- if B_IS_2D %}
{%- if B_IS_K_MAJOR %}
        shape=[N, K],
        strides=[B_STRIDE_N, B_STRIDE_K],
        block_shape=[BLOCK_N, BLOCK_K],
{%- else %}
        shape=[K, N],
        strides=[B_STRIDE_K, B_STRIDE_N],
        block_shape=[BLOCK_K, BLOCK_N],
{%- endif %}
{%- else %}
{%- if B_IS_K_MAJOR %}
        shape=[G, N, K],
        strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
        block_shape=[1, BLOCK_N, BLOCK_K],
{%- else %}
        shape=[G, K, N],
        strides=[B_STRIDE_G, B_STRIDE_K, B_STRIDE_N],
        block_shape=[1, BLOCK_K, BLOCK_N],
{%- endif %}
{%- endif %}
    )
{%- endif %}

{%- if M_IS_VARYING %}
    m_end_offset = 0
{%- endif %}
{%- if N_IS_VARYING %}
    n_end_offset = 0
{%- endif %}
{%- if K_IS_VARYING %}
    k_end_offset = 0
{%- endif %}
    # Running count of the tiles in all groups before the current one.
    iterated_tiles = 0
    for g in tl.range(G):
{%- if M_IS_VARYING %}
        # Move across groups: offsets_ptr[g] is the exclusive end of group g
        # along M, and the group starts where the previous one ended.
        m_start_offset = m_end_offset
        m_end_offset = tl.load(offsets_ptr + g)
        m_size = m_end_offset - m_start_offset
{%- if SCALED %}
        m_scale_start_offset = m_start_offset
{%- endif %}
{%- else %}
        m_start_offset = 0
        m_size = M
{%- if SCALED %}
        m_scale_start_offset = g * M
{%- endif %}
{%- endif %}

{%- if N_IS_VARYING %}
        # Move across groups: offsets_ptr[g] is the exclusive end of group g
        # along N, and the group starts where the previous one ended.
        n_start_offset = n_end_offset
        n_end_offset = tl.load(offsets_ptr + g)
        n_size = n_end_offset - n_start_offset
{%- if SCALED %}
        n_scale_start_offset = n_start_offset
{%- endif %}
{%- else %}
        n_start_offset = 0
        n_size = N
{%- if SCALED %}
        n_scale_start_offset = g * N
{%- endif %}
{%- endif %}

        if m_size > 0 and n_size > 0:
{%- if K_IS_VARYING %}
            # Move across groups: offsets_ptr[g] is the exclusive end of
            # group g along K.
            k_start_offset = k_end_offset
            k_end_offset = tl.load(offsets_ptr + g)
            k_size = k_end_offset - k_start_offset
{%- else %}
            k_start_offset = 0
            k_size = K
{%- endif %}

            num_m_tiles = tl.cdiv(m_size, BLOCK_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_N)
            num_tiles = num_m_tiles * num_n_tiles

            # Move across tiles: handle every tile of this group whose global
            # index is tidx, tidx + NUM_SMS, ...
            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
                gidx = tidx - iterated_tiles
                # Split M first and N second.
                tile_m_idx = gidx % num_m_tiles
                tile_n_idx = gidx // num_m_tiles

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{%- if USE_TMA_LOAD %}
                m_tile_offset = tile_m_idx * BLOCK_M
                n_tile_offset = tile_n_idx * BLOCK_N
                m_offset = (m_start_offset + m_tile_offset).to(tl.int32)
                n_offset = (n_start_offset + n_tile_offset).to(tl.int32)

                # Full BLOCK_K steps.
                k_block_offset = 0
                for k in range(k_size // BLOCK_K):
                    k_offset = k_start_offset + k_block_offset
                    a, b = do_tma_loads(
                        g, a_desc, b_desc, m_offset, n_offset, k_offset,
                        BLOCK_M, BLOCK_N, BLOCK_K
                    )
                    accumulator = do_mma(a, b, accumulator)
                    k_block_offset += BLOCK_K

                # Remaining partial K block, if any.
                if k_size % BLOCK_K != 0:
                    k_offset = k_start_offset + k_block_offset
                    a, b = do_tma_loads(
                        g, a_desc, b_desc, m_offset, n_offset, k_offset,
                        BLOCK_M, BLOCK_N, BLOCK_K
                    )
{%- if K_IS_VARYING %}
                    # The descriptors span every group's K range, so columns at
                    # or past this group's k_size are not part of the group;
                    # zero them before the dot.
                    group_offs = k_block_offset + tl.arange(0, BLOCK_K)
                    k_mask = group_offs < k_size
{%- if A_IS_K_MAJOR %}
                    a = tl.where(k_mask[None, :], a, 0)
{%- else %}
                    a = tl.where(k_mask[:, None], a, 0)
{%- endif %}
{%- if B_IS_K_MAJOR %}
                    b = tl.where(k_mask[None, :], b, 0)
{%- else %}
                    b = tl.where(k_mask[:, None], b, 0)
{%- endif %}
{%- endif %}
                    accumulator = do_mma(a, b, accumulator)
{%- else %}
                offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
                # The tile pointers are rebuilt from offs_k on every iteration,
                # so no separate pointer advance is needed.
                for k_block_offset in range(0, k_size, BLOCK_K):
                    block_offs_k = k_block_offset + tl.arange(0, BLOCK_K)
                    offs_k = block_offs_k + k_start_offset
                    a_ptrs = (
                        a_ptr
{%- if not A_IS_2D %}
                        + g * A_STRIDE_G
{%- endif %}
                        + (m_start_offset + offs_am[:, None]) * A_STRIDE_M
                        + offs_k[None, :] * A_STRIDE_K
                    )
                    b_ptrs = (
                        b_ptr
{%- if not B_IS_2D %}
                        + g * B_STRIDE_G
{%- endif %}
                        + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
                        + offs_k[None, :] * B_STRIDE_K
                    )
                    a_mask = (offs_am[:, None] < m_size) & (block_offs_k[None, :] < k_size)
                    b_mask = (offs_bn[:, None] < n_size) & (block_offs_k[None, :] < k_size)
                    a = tl.load(a_ptrs, mask=a_mask, other=tl.zeros((), dtype=a_ptrs.dtype.element_ty))
                    b = tl.load(b_ptrs, mask=b_mask, other=tl.zeros((), dtype=b_ptrs.dtype.element_ty))
{%- if USE_FAST_ACCUM %}
                    accumulator = tl.dot(a, b.T, accumulator)
{%- else %}
                    accumulator += tl.dot(a, b.T)
{%- endif %}
{%- endif %}

                offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
{%- if SCALED %}
                # One A scale per M row and one B scale per N column, applied
                # to the fp32 accumulator.
                scale_a = tl.load(
                    scale_a_ptr
{%- if A_IS_2D %}
                    + m_scale_start_offset
{%- else %}
                    + g * SCALE_A_STRIDE_G
{%- endif %}
                    + offs_am[:, None],
                    mask=offs_am[:, None] < m_size,
                    other=tl.zeros((), dtype=scale_a_ptr.dtype.element_ty),
                )
                scale_b = tl.load(
                    scale_b_ptr
{%- if B_IS_2D %}
                    + n_scale_start_offset
{%- else %}
                    + g * SCALE_B_STRIDE_G
{%- endif %}
                    + offs_bn[None, :],
                    mask=offs_bn[None, :] < n_size,
                    other=tl.zeros((), dtype=scale_b_ptr.dtype.element_ty),
                )
                c = accumulator.to(tl.float32) * scale_a * scale_b
{%- else %}
                c = accumulator.to(tl.float32)
{%- endif %}

                # Output indices: a varying M/N is addressed in one 2D output
                # from the group's start offset; otherwise the group index g
                # is passed as a leading output index. Rows/cols beyond this
                # group's extent are masked off.
{%- if M_IS_VARYING %}
                idx_m = (m_start_offset + offs_am[:, None])
{%- else %}
                idx_m = offs_am[:, None]
{%- endif %}
{%- if N_IS_VARYING %}
                idx_n = (n_start_offset + offs_bn[None, :])
{%- else %}
                idx_n = offs_bn[None, :]
{%- endif %}
                mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size)
{%- if M_IS_VARYING or N_IS_VARYING %}
                {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}}
{%- else %}
                {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16, val_shape=("BLOCK_M", "BLOCK_N"))}}
{%- endif %}
                tidx += NUM_SMS

            iterated_tiles += num_tiles