{% if HAS_BLOCK_MASK %}
{{def_kernel("Q", "K", "V", "OUT", "D_OUT", "LSE", "DK", "DV", "Q_NUM_BLKS", "Q_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
{% else %}
{{def_kernel("Q", "K", "V", "OUT", "D_OUT", "LSE", "DK", "DV")}}
{% endif %}
    from flash_attn.cute.interface import _flash_attn_bwd

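    # flash_attn's CUTE interface expects (batch, seqlen, heads, head_dim) layout,
    # while the kernel inputs arrive as (batch, heads, seqlen, head_dim); swap dims 1 and 2.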
    q_transposed = Q.transpose(1, 2)
    k_transposed = K.transpose(1, 2)
    v_transposed = V.transpose(1, 2)
    out_transposed = OUT.transpose(1, 2)
    d_out_transposed = D_OUT.transpose(1, 2)

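    # dq is this kernel's output buffer; dk/dv gradients are written into the caller-provided
    # DK/DV tensors. All three are viewed in the same (batch, seqlen, heads, head_dim) layout.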
    dq_out = {{get_output()}}
    dq_out_transposed = dq_out.transpose(1, 2)
    dk_out_transposed = DK.transpose(1, 2)
    dv_out_transposed = DV.transpose(1, 2)

    {% if HAS_SCORE_MOD %}
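    # Forward score modification (subgraph 0), applied elementwise to the attention
    # scores recomputed during the backward pass.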
    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        {{unpack_buffers("aux_tensors", indent_width=8)}}
        {{ modification(
            subgraph_number=0,
            output_name="tSrS_ssa",
            score="tSrS_ssa",
            b="b_idx",
            h="h_idx",
            m="q_idx",
            n="kv_idx",
            out="tSrS_ssa"
        ) | indent_except_first(2) }}
        return tSrS_ssa
    {{ set_cute_hash("score_mod", "score") }}

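    # Backward of the score modification (subgraph 1): given the incoming gradient w.r.t.
    # the modified score and the recomputed score, returns the gradient w.r.t. the
    # unmodified score.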
    @cute.jit
    def score_mod_bwd(grad_score_mod_ssa, tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        {{unpack_buffers("aux_tensors", indent_width=8)}}
        {{ modification(
            subgraph_number=1,
            output_name="grad_score_mod_ssa_out",
            score="tSrS_ssa",
            grad_score_mod="grad_score_mod_ssa",
            b="b_idx",
            h="h_idx",
            m="q_idx",
            n="kv_idx"
        ) | indent_except_first(2) }}
        return grad_score_mod_ssa_out
    {{ set_cute_hash("score_mod_bwd", "score_bwd") }}
    {% else %}
    score_mod = None
    score_mod_bwd = None
    {% endif %}

    {% if HAS_BLOCK_MASK %}
    from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch
    {# mask_mod is subgraph 2 when HAS_SCORE_MOD, else subgraph 0 #}
    {% set mask_subgraph_idx = 2 if HAS_SCORE_MOD else 0 %}
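    # Mask modification: returns a boolean indicating whether the score element at
    # (b_idx, h_idx, q_idx, kv_idx) participates in attention.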
    @cute.jit
    def mask_mod(b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
        {{unpack_buffers("aux_tensors", indent_width=8)}}
        {{ modification(
            subgraph_number=mask_subgraph_idx,
            output_name="mask_mod_output",
            b="b_idx",
            h="h_idx",
            m="q_idx",
            n="kv_idx"
        ) | indent_except_first(2) }}
        return mask_mod_output
    {{ set_cute_hash("mask_mod", "mask") }}
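    # Package the block-sparsity metadata (per-block counts and indices of partially-masked
    # and fully-unmasked blocks) together with the sparse block sizes for the backward kernel.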
    block_sparse_tensors = BlockSparseTensorsTorch(Q_NUM_BLKS, Q_IDX, FULL_Q_NUM_BLKS, FULL_Q_IDX, block_size=({{SPARSE_Q_BLOCK_SIZE}}, {{SPARSE_KV_BLOCK_SIZE}}))
    {% else %}
    mask_mod = None
    block_sparse_tensors = None
    {% endif %}

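    {# Captured tensors referenced by score_mod / mask_mod are forwarded to the kernel as aux_tensors. #}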
    {% set tensor_buffers = get_tensor_buffers() -%}
    {% if tensor_buffers -%}
    buffers = [{% for buffer in tensor_buffers %}{{buffer}}{% if not loop.last %}, {% endif %}{% endfor %}]
    {% else -%}
    buffers = None
    {% endif -%}

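    # Run the flash-attention backward kernel; gradients are written in place into the
    # dq/dk/dv views above. Determinism is requested only when PyTorch's deterministic
    # mode is enabled in error (non-warn-only) mode.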
    _flash_attn_bwd(
        q_transposed,
        k_transposed,
        v_transposed,
        out_transposed,
        d_out_transposed,
        LSE,
        softmax_scale={{SM_SCALE}},
        score_mod=score_mod,
        score_mod_bwd=score_mod_bwd,
        mask_mod=mask_mod,
        aux_tensors=buffers,
        dq=dq_out_transposed,
        dk=dk_out_transposed,
        dv=dv_out_transposed,
        block_sparse_tensors=block_sparse_tensors,
        deterministic=(
            torch.are_deterministic_algorithms_enabled()
            and not torch.is_deterministic_algorithms_warn_only_enabled()
        ),
    )
