{% if NEEDS_BLOCK_MASK %}
{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
{% else %}
{{def_kernel("Q", "K", "V", "LOGSUMEXP")}}
{% endif %}
    from flash_attn.cute.interface import _flash_attn_fwd
    from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch

    # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D)
    q_transposed = Q.transpose(1, 2)
    k_transposed = K.transpose(1, 2)
    v_transposed = V.transpose(1, 2)

    {% if HAS_SCORE_MOD %}
    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        {{unpack_buffers("aux_tensors", indent_width=8)}}
        {{ modification(
            subgraph_number=0,
            output_name="tSrS_ssa",
            score="tSrS_ssa",
            b="b_idx",
            h="h_idx",
            m="q_idx",
            n="kv_idx",
            out="tSrS_ssa"
        ) | indent_except_first(2) }}
        return tSrS_ssa
    {{ set_cute_hash("score_mod", "score") }}
    {% else %}
    score_mod = None
    {% endif %}

    # (B,M,H,D) -> (B,H,M,D)
    output = {{get_output()}}
    output_transposed = output.transpose(1, 2)

    {% if NEEDS_BLOCK_MASK %}
    {# mask_mod is subgraph 1 when HAS_SCORE_MOD, else subgraph 0 #}
    {% set mask_subgraph_idx = 1 if HAS_SCORE_MOD else 0 %}
    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        {{unpack_buffers("aux_tensors", indent_width=8)}}
        {{ modification(
        subgraph_number=mask_subgraph_idx,
        output_name="mask_mod_output",
        b="b_idx",
        h="h_idx",
        m="q_idx",
        n="kv_idx",
        ) | indent_except_first(2) }}
        return mask_mod_output
    {{ set_cute_hash("mask_mod", "mask") }}
    block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX, block_size=({{SPARSE_Q_BLOCK_SIZE}}, {{SPARSE_KV_BLOCK_SIZE}}))
    {% else %}
    block_sparse_tensors = None
    mask_mod = None
    {% endif %}

    # Collect any additional tensor buffers that were added during modifications
    {% set tensor_buffers = get_tensor_buffers() -%}
    {% if tensor_buffers -%}
    buffers = [{% for buffer in tensor_buffers %}{{buffer}}{% if not loop.last %}, {% endif %}{% endfor %}]
    buffers = list(buffers)
    {% else -%}
    buffers = None
    {% endif -%}

    # Out and LSE filled inplace
    _flash_attn_fwd(
        q_transposed,
        k_transposed,
        v_transposed,
        softmax_scale={{SM_SCALE}},
        return_lse=True,
        score_mod=score_mod,
        mask_mod=mask_mod,
        out=output_transposed,
        lse=LOGSUMEXP,
        block_sparse_tensors=block_sparse_tensors,
        aux_tensors=buffers
    )
