"""
This module is responsible for transforming functions to be traced into a form
that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis)
to handle.

It does so by:
1. functionalization (including RNG functionalization)
2. creating a joint graph when required
3. transforming mutations into extra outputs
4. dispatching subclasses
"""

import warnings
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import Any, Optional, TypeVar, Union
from unittest.mock import patch

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
    _proxy_tensor_disable_update_tensor_tracker,
    get_proxy_mode,
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)
from torch.fx.experimental.symbolic_shapes import (
    guard_or_true,
    PropagateUnbackedSymInts,
    sym_eq,
)
from torch.nn.utils import stateless
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._pytree import TreeSpec

from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .descriptors import (
    AOTInput,
    AOTOutput,
    BackwardTokenAOTOutput,
    ForwardTokenAOTInput,
    ForwardTokenAOTOutput,
    GradAOTOutput,
    InputMutationAOTOutput,
    IntermediateBaseAOTOutput,
    PhiloxBackwardBaseOffsetAOTInput,
    PhiloxBackwardSeedAOTInput,
    PhiloxForwardBaseOffsetAOTInput,
    PhiloxForwardSeedAOTInput,
    PhiloxUpdatedBackwardOffsetAOTOutput,
    PhiloxUpdatedForwardOffsetAOTOutput,
)
from .functional_utils import (
    _check_if_mutation_can_be_in_graph,
    are_all_mutations_hidden_from_autograd,
    are_all_mutations_under_no_grad_or_inference_mode,
    from_fun,
    has_data_mutation,
    has_metadata_mutation,
    is_fun,
    sync_functional_tensor,
    to_fun,
    was_inductor_storage_resized,
)
from .logging_utils import setup_stacktrace_preservation_hooks
from .schemas import (
    AOTConfig,
    FxValue,
    InputAliasInfo,
    JointTraceFn,
    MutationType,
    OutputType,
    PreppedForAutogradTraceFn,
    SubclassMeta,
    SubclassTracingInfo,
    TraceFn,
    ViewAndMutationMeta,
)
from .subclass_utils import (
    create_subclass_meta,
    remap_unwrapped_subclass_arg_indices,
    requires_subclass_dispatch,
    unwrap_tensor_subclasses,
    wrap_tensor_subclasses_maybe_joint,
)
from .utils import (
    _is_tangent,
    call_and_expect_output_descs,
    maybe_to_fresh_input,
    simple_wraps,
    without_output_descs,
)


# This function returns a new function that returns mutated inputs as outputs.
# if keep_data_input_mutations is set, then we assume that data-only mutations
# will be left in the graph, and we only return metadata-mutated inputs as outputs.
def fn_input_mutations_to_outputs(
    fn: Callable[..., Any],
    args_descs: list[AOTInput],
    meta: ViewAndMutationMeta,
    keep_data_input_mutations: bool,
) -> Any:
    """Wrap ``fn`` so that mutated inputs are returned in front of its outputs.

    Args:
        fn: Function to wrap; it is invoked through
            ``call_and_expect_output_descs`` and must yield ``(outs, outs_descs)``.
        args_descs: Descriptors paired positionally with the wrapper's arguments.
        meta: Supplies ``output_info`` (used for a length check) and
            ``mutated_inp_runtime_indices`` (which argument positions to return).
        keep_data_input_mutations: Not read by this wrapper; the set of inputs
            to return is taken solely from ``meta.mutated_inp_runtime_indices``.

    Returns:
        A function returning ``((*mutated_inputs, *outs), (*descs_for_both,))``.
    """

    @simple_wraps(fn)
    def inner_fn(*args: FxValue) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
        outs, outs_descs = call_and_expect_output_descs(fn, args)
        if len(meta.output_info) != len(outs):
            raise AssertionError(
                f"output_info length ({len(meta.output_info)}) != outs length ({len(outs)})"
            )

        # Walk the arguments in order and keep the ones flagged as mutated,
        # tagging each with a descriptor derived from the input's descriptor.
        returned_inputs: list[Any] = []
        returned_inputs_descs: list[Any] = []
        for position, (arg, arg_desc) in enumerate(zip(args, args_descs)):
            if position not in meta.mutated_inp_runtime_indices:
                continue
            returned_inputs.append(arg)
            returned_inputs_descs.append(InputMutationAOTOutput(arg_desc))

        combined_outs = (*returned_inputs, *outs)
        combined_descs = (*returned_inputs_descs, *outs_descs)
        return combined_outs, combined_descs

    return inner_fn


@contextmanager
def disable_autocast() -> Generator[None, None, None]:
    with ExitStack() as stack:
        autocast_enabled_devices = torch._C._autocast_supported_devices()
        for device_type in autocast_enabled_devices:
            if hasattr(torch, device_type):
                stack.enter_context(torch.amp.autocast(device_type, enabled=False))
        yield


# This function takes in a fn with external aliasing and mutation,
# and returns a new fn with no external aliasing and mutation,
# as needed for autograd.
# The main transformations are:
# - Return mutated inputs as extra outputs
# - Clone mutated inputs that require gradients,
#   because autograd will require us to pass the pre-mutated inputs into autograd.grad
# - Return intermediate bases of outputs as additional outputs,
#   needed to appease autograd.Function
# The new function returns:
# (1) The updated outputs
# (2) A boolean mask of len(new_fn_outputs),
#     that can be used to tell autograd.grad which outputs should get tangents
#     if we trace the backward.
def fn_prepped_for_autograd(
    fn: TraceFn,
    args_descs: list[AOTInput],
    meta: ViewAndMutationMeta,
    aot_config: AOTConfig,
) -> PreppedForAutogradTraceFn:
    """Wrap ``fn`` so its result is ``(mutated inputs, outputs, intermediate bases)``
    together with a boolean mask saying which of those should receive tangents.

    Args:
        fn: Function to wrap; invoked through ``call_and_expect_output_descs``.
        args_descs: Descriptors paired positionally with the wrapper's arguments.
        meta: Supplies ``input_info``, ``output_info``,
            ``mutated_inp_runtime_indices`` and ``num_intermediate_bases``.
        aot_config: Only ``disable_functionalization`` is read here.

    Returns:
        A function returning
        ``((fw_outs, grad_mask), (fw_outs_descs, grad_mask))``.

    Raises:
        AssertionError: (from the wrapper) if ``fn`` does not return a tuple or
            list, if lengths disagree with ``meta``, or if an output marked
            ``alias_of_intermediate_save_as_output`` is not a tensor.
    """

    @simple_wraps(fn)
    def inner_fn(
        *args: FxValue,
    ) -> tuple[tuple[list[FxValue], list[bool]], list[AOTOutput]]:
        # Each argument goes through maybe_to_fresh_input before fn sees it.
        fresh_args = []
        for position, arg in enumerate(args):
            fresh_args.append(maybe_to_fresh_input(position, arg, meta))

        outs, outs_descs = call_and_expect_output_descs(fn, fresh_args)  # type: ignore[arg-type]
        if not isinstance(outs, (tuple, list)):
            raise AssertionError(f"expected outs to be tuple or list, got {type(outs)}")
        outs = list(outs)
        if len(meta.output_info) != len(outs):
            raise AssertionError(
                f"output_info length ({len(meta.output_info)}) != outs length ({len(outs)})"
            )

        # Mutated inputs, in argument order, with their output descriptors.
        returned_inputs: list[Any] = []
        returned_inputs_descs: list[Any] = []
        for position, (arg, arg_desc) in enumerate(zip(fresh_args, args_descs)):
            if position in meta.mutated_inp_runtime_indices:
                returned_inputs.append(arg)
                returned_inputs_descs.append(InputMutationAOTOutput(arg_desc))

        # For outputs that alias an intermediate, the output's ._base is
        # returned as an extra output.
        intermediate_bases = []
        intermediate_bases_descs = []
        for out, out_info, out_desc in zip(outs, meta.output_info, outs_descs):
            if out_info.output_type == OutputType.alias_of_intermediate_save_as_output:
                if not isinstance(out, torch.Tensor):
                    raise AssertionError(
                        f"Expected tensor for intermediate base, got {type(out)}"
                    )
                intermediate_bases.append(out._base)
                intermediate_bases_descs.append(IntermediateBaseAOTOutput(out_desc))

        if meta.num_intermediate_bases != len(intermediate_bases):
            raise AssertionError(
                f"num_intermediate_bases ({meta.num_intermediate_bases}) != len(intermediate_bases) ({len(intermediate_bases)})"
            )

        # Layout of the forward result: (mutated_inputs, user_outs, intermediate_bases)
        fw_outs_to_return = (*returned_inputs, *outs, *intermediate_bases)
        fw_outs_to_return_descs = (
            *returned_inputs_descs,
            *outs_descs,
            *intermediate_bases_descs,
        )

        # Mask, part 1: a returned mutated input gets a tangent when its
        # input_info says it both mutates data and requires grad. The k-th
        # returned input is looked up via the k-th mutated runtime index.
        mutated_inputs_grad_mask = []
        for k in range(len(returned_inputs)):
            inp_info = meta.input_info[meta.mutated_inp_runtime_indices[k]]
            mutated_inputs_grad_mask.append(
                inp_info.mutates_data and inp_info.requires_grad
            )

        # Mask, part 2: only non-aliased style outputs that are tensors and
        # require grad get tangents (e.g. SymInt outputs never do). Outputs that
        # alias intermediates are covered by their returned _base instead.
        tangent_eligible_types = [
            OutputType.non_alias,
            OutputType.unsafe_view_alias,
            OutputType.custom_function_view,
        ]
        output_grad_mask = [
            out_info.output_type in tangent_eligible_types
            and issubclass(out_info.raw_type, Tensor)
            and out_info.requires_grad
            for out_info in meta.output_info
        ]

        # Mask, part 3: every intermediate base gets a tangent.
        intermediate_base_grad_mask = [True] * len(intermediate_bases)

        out_grad_mask = [
            *mutated_inputs_grad_mask,
            *output_grad_mask,
            *intermediate_base_grad_mask,
        ]
        if len(out_grad_mask) != len(fw_outs_to_return):
            raise AssertionError(
                f"out_grad_mask length ({len(out_grad_mask)}) != fw_outs_to_return length ({len(fw_outs_to_return)})"
            )

        # The joint function has to be functionalization-aware: the tensors we
        # actually passed to fn (fresh_args) are synced here, after the outputs
        # and masks have been assembled.
        if not aot_config.disable_functionalization:
            for arg in fresh_args:
                if isinstance(arg, Tensor):
                    sync_functional_tensor(arg)

        # pyrefly: ignore[bad-return]
        return (fw_outs_to_return, out_grad_mask), (
            fw_outs_to_return_descs,
            out_grad_mask,
        )

    return inner_fn


@dataclass
class JointFnHandle:
    """Mutable holder for an optional hook invoked by the joint function.

    ``create_joint`` instantiates one of these per joint function and exposes
    it as the ``handle`` attribute of the callable it returns.
    """

    # Optional callback. When set, the joint function calls it with the primals
    # after the forward has run and before the backward is computed.
    post_forward: Callable[..., Any] | None = None


# Given a fn, computes the joint.
# NOTE: fn is expected to have the following behavior:
# (1) fn() needs to return a tuple of (outs, mask),
#     where `mask` tells us which outputs are meant to have tangents.
#     we don't know this info automatically, because we don't actually want to blindly
#     compute tangents for every output that requires grad.
#     Specifically, outputs that alias inputs won't participate in the backward and get tangents.
# (2) fn() cannot mutate any inputs that require gradient.
#     otherwise, when we compute autograd.grad(), we will not take those input mutations into account
#     (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(
    fn: Callable[..., Any],
    primals_descs: Optional[list[AOTInput]] = None,
    *,
    aot_config: AOTConfig,
) -> Callable[..., Any]:
    """Build a joint forward+backward function around ``fn``.

    The returned callable takes ``(primals, tangents)``, runs ``fn`` on the
    primals, and then uses ``torch.autograd.grad`` to compute gradients of the
    tangent-receiving outputs with respect to every primal that is a tensor
    requiring grad.

    Args:
        fn: Callable returning ``(outs, tangent_mask)``. When ``primals_descs``
            is given it is invoked through ``call_and_expect_output_descs`` and
            its descriptors for ``outs`` are propagated.
        primals_descs: Optional descriptors, zipped with the primals to build
            the gradient descriptors. When ``None`` no descriptors are returned.
        aot_config: Only ``aot_config.no_tangents`` is read here.

    Returns:
        ``joint_helper``. Its result is ``(outs, grads)`` when ``primals_descs``
        is ``None``, otherwise ``((outs, grads), (outs_descs, grads_descs))``.
        ``grads`` has one entry per primal: the gradient, or ``None`` for
        primals that do not need one. The callable carries a ``handle``
        attribute holding the ``JointFnHandle`` created here.

    Raises:
        AssertionError: (from the returned callable) on a missing proxy mode,
            on length mismatches between mask/outputs/tangents, on a non-tensor
            tangent, or on an invalid ``backward_pass_autocast`` config.
    """
    joint_fn_handle = JointFnHandle()

    # NB: this type is inaccurate when primals_descs is None
    @simple_wraps(fn)
    def inner_fn(
        primals: list[FxValue], tangents: list[FxValue]
    ) -> tuple[
        tuple[list[FxValue], list[Optional[Tensor]]],
        tuple[list[AOTOutput], list[Optional[AOTOutput]]],
    ]:
        # --- Run the forward ---
        outs_descs = None
        if primals_descs is None:
            # Descriptor-free mode: fn returns (outs, tangent_mask) directly.
            outs, tangent_mask = fn(*primals)
            if pytree.tree_any(lambda x: isinstance(x, AOTOutput), tangent_mask):
                raise AssertionError(
                    "tangent_mask should not contain AOTOutput instances"
                )
        else:
            # Descriptor mode: the descriptor for the mask is discarded.
            (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
                fn,
                primals,  # type: ignore[arg-type]
            )
        # Tag every node currently in the traced graph: nodes recognized by
        # _is_tangent are "is_backward", everything else is "is_forward".
        mode = get_proxy_mode()
        if mode is None:
            raise AssertionError("Expected non-None proxy mode")
        for node in mode.tracer.graph.nodes:
            if _is_tangent(node):
                node.meta["partitioner_tag"] = "is_backward"
            else:
                node.meta["partitioner_tag"] = "is_forward"

        # TODO: I think this hook can also be eliminated now
        if joint_fn_handle and joint_fn_handle.post_forward:
            joint_fn_handle.post_forward(primals)

        # --- Select the outputs that receive tangents ---
        if len(tangent_mask) != len(outs):
            raise AssertionError(
                f"tangent_mask length ({len(tangent_mask)}) != outs length ({len(outs)})"
            )
        outs_to_grad = [
            o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
        ]
        if len(outs_to_grad) != len(tangents):
            raise AssertionError(
                f"outs_to_grad length ({len(outs_to_grad)}) != tangents length ({len(tangents)})"
            )

        # Get the inputs that need gradients
        grad_primals: list[torch.Tensor] = []
        # One bool per primal; used later to line gradients back up with primals.
        inputs_needs_grads = []
        # NOTE(review): the original comment here said "we're not using primals
        # here", but the loop does iterate `primals`. Presumably this relies on
        # the caller having cloned mutated inputs so that fn never mutates the
        # primals seen here -- confirm against the callers.
        for p in primals:
            if isinstance(p, Tensor) and p.requires_grad:
                inputs_needs_grads.append(True)
                if not isinstance(p, torch.Tensor):  # Help mypy understand the type
                    raise AssertionError(f"expected Tensor, got {type(p)}")
                grad_primals.append(p)
            else:
                inputs_needs_grads.append(False)

        # Get the outputs that need gradients
        needed_outs: list[Tensor] = []
        needed_tangents: list[Tensor] = []
        for out, tangent in zip(outs_to_grad, tangents):
            if isinstance(out, Tensor) and out.requires_grad:
                # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
                # The issue is that we are sensitive to decomps that don't accurately maintain
                # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
                # The guard_or_true also sketchy; if unbacked
                # symints are involved, we're just going to assume that the
                # decomps setup the base shape correctly

                # Return out if the result of out.shape==tangent.shape is unknown or known to be true.
                # otherwise if its a known false return out.view(tangent.shape).
                # tangent should also be a tensor since it corresponds to a tensor output
                if not isinstance(tangent, torch.Tensor):
                    raise AssertionError(
                        f"Expected tensor tangent, got {type(tangent)}"
                    )
                needed_outs.append(
                    out
                    if guard_or_true(sym_eq(out.shape, tangent.shape))
                    else out.view(tangent.shape)
                )
                needed_tangents.append(tangent)

        # Hooks are installed on the grad_fn of every output that has one.
        setup_stacktrace_preservation_hooks(
            [out.grad_fn for out in needed_outs if out.grad_fn is not None]
        )

        if config.functionalize_rng_ops:
            PhiloxStateTracker.mark_beginning_of_backward()
        # Stays empty when no primal requires grad.
        backward_out: tuple[Tensor, ...] = ()
        # Call the backwards pass
        if grad_primals:
            functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )
            if functional_tensor_mode is not None:
                # Side-Effect Tokens:
                # We want to have independent chains of tokens for forward and backward.
                # functional_tensor_mode._tokens is used by both.
                # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output,
                # to return them as joint graph outputs.
                # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward.
                # Joint graph tracing allows tokens discovery,
                # So all the tokens in backward will be created and added as a graph inputs during tracing.
                functional_tensor_mode._tokens_forward_output = (
                    functional_tensor_mode._tokens
                )
                functional_tensor_mode._tokens = {}  # pyrefly: ignore[implicit-any]

            # Everything traced inside this block is tagged "is_backward".
            with (
                set_partitioner_tag_is_backward(),
                fx_traceback.preserve_node_meta(),
                ExitStack() as stack,
            ):
                # backward_pass_autocast is one of: "same_as_forward", "off",
                # or a list of kwargs dicts for torch.amp.autocast.
                backward_pass_autocast = torch._functorch.config.backward_pass_autocast
                if backward_pass_autocast == "same_as_forward":
                    # Use the ambient autocast mode(s)
                    pass
                elif backward_pass_autocast == "off":
                    stack.enter_context(disable_autocast())
                else:
                    # Disable autocast, then enable anything in `backward_pass_autocast`.
                    stack.enter_context(disable_autocast())
                    if not isinstance(backward_pass_autocast, list):
                        raise AssertionError(
                            f"expected backward_pass_autocast to be a list, got {type(backward_pass_autocast)}"
                        )
                    for kwargs in backward_pass_autocast:
                        if not isinstance(kwargs, dict):
                            raise AssertionError(
                                f"expected kwargs to be a dict, got {type(kwargs)}"
                            )
                        stack.enter_context(torch.amp.autocast(**kwargs))

                # for full graph export, we always export a joint graph where we assume no tangents are needed.
                if aot_config.no_tangents:
                    # Exactly one single-element output is required, and
                    # autograd.grad is called without grad_outputs.
                    if not (
                        len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
                    ):
                        raise AssertionError(
                            f"expected single scalar tangent for no_tangents mode, got {len(needed_tangents)} tangents"
                        )
                    backward_out = torch.autograd.grad(
                        needed_outs,
                        grad_primals,
                        allow_unused=True,
                    )
                else:
                    backward_out = torch.autograd.grad(
                        needed_outs,
                        grad_primals,
                        grad_outputs=needed_tangents,
                        allow_unused=True,
                    )
        # Scatter the gradients back to primal positions: primals that need a
        # grad consume the next gradient, all others get None.
        backward_out_iter = iter(backward_out)
        final_outs = (
            outs,
            [next(backward_out_iter) if i else None for i in inputs_needs_grads],
        )
        if primals_descs is None:
            return final_outs  # type: ignore[return-value]
        if outs_descs is None:
            raise AssertionError("outs_descs must not be None")
        # pyrefly: ignore[bad-return]
        return final_outs, (
            outs_descs,
            [
                # TODO: ideally we do know this is DifferentiableAOTInput
                # but this is quite an involved refactor
                GradAOTOutput(desc) if i else None  # type: ignore[arg-type]
                for i, desc in zip(inputs_needs_grads, primals_descs)
            ],
        )

    # Runs inner_fn with node-meta preservation and autograd anomaly detection
    # (check_nan=False) enabled, while suppressing the warning that enabling
    # anomaly detection emits.
    @simple_wraps(inner_fn)
    def inner_fn_with_anomaly(
        primals: list[FxValue], tangents: list[FxValue]
    ) -> tuple[
        tuple[list[FxValue], list[Optional[Tensor]]],
        tuple[list[AOTOutput], list[Optional[AOTOutput]]],
    ]:
        with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.")
            with torch.autograd.detect_anomaly(check_nan=False):
                return inner_fn(primals, tangents)

    # Plain (undecorated) entry point; this is the object returned to callers
    # and the one the handle is attached to.
    def joint_helper(
        primals: list[FxValue], tangents: list[FxValue]
    ) -> tuple[
        tuple[list[FxValue], list[Optional[Tensor]]],
        tuple[list[AOTOutput], list[Optional[AOTOutput]]],
    ]:
        return inner_fn_with_anomaly(primals, tangents)

    joint_helper.handle = joint_fn_handle  # type: ignore[attr-defined]

    # pyrefly: ignore[bad-return]
    return joint_helper


def create_functionalized_rng_ops_wrapper(
    func: Callable[..., Any],
    args: Any,
    args_descs: list[AOTInput],
    trace_joint: bool = True,
) -> Any:
    """Wrap ``func`` so RNG seed/offset become explicit inputs and outputs.

    Functionalizing RNG ops changes the calling convention: the seed and base
    offset are appended after the existing arguments, and the updated offset(s)
    are appended to the outputs. While ``func`` runs,
    ``torch.cuda.get_rng_state`` / ``torch.cuda.set_rng_state`` are patched to
    go through ``PhiloxStateTracker``.

    Args:
        func: The joint function (called as ``func(primals, tangents)``) when
            ``trace_joint`` is true, otherwise the forward function (called
            with the unpacked primals).
        args: Existing arguments; unpacked into the returned argument tuple.
        args_descs: Descriptors for ``args``.
        trace_joint: Whether ``func`` is a joint function.

    Returns:
        ``(wrapped_fn, new_args, new_args_descs)``. For the joint case four
        values are appended (forward seed/offset, backward seed/offset); for
        the forward-only case two (forward seed/offset).
    """
    detected_fake_mode = detect_fake_mode()
    # Fall back to a no-op context when no fake mode is detected.
    fake_mode: AbstractContextManager[Any] = (
        nullcontext() if detected_fake_mode is None else detected_fake_mode
    )

    def override_get_rng_state(
        device: Union[int, str, torch.device] = "cuda",
    ) -> Tensor:
        return PhiloxStateTracker.get_state_as_tensor()

    def override_set_rng_state(
        x: Tensor, device: Union[int, str, torch.device] = "cuda"
    ) -> None:
        PhiloxStateTracker.set_state_from_tensor(x)

    @contextmanager
    def cuda_rng_state_redirected() -> Generator[None, None, None]:
        # Route torch.cuda RNG state reads/writes through PhiloxStateTracker.
        with patch("torch.cuda.get_rng_state", override_get_rng_state):
            with patch("torch.cuda.set_rng_state", override_set_rng_state):
                yield

    def append_rng_offsets(outs: Any, outs_descs: Any) -> Any:
        if not trace_joint:
            # outs signature before: Tuple(fwd_outputs)
            # outs signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
            return (
                (*outs, PhiloxStateTracker.get_updated_fwd_offset()),
                (*outs_descs, PhiloxUpdatedForwardOffsetAOTOutput()),
            )

        # outs signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
        # outs signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_outputs, new_bwd_rng_offset)
        fwd_with_offset = (*outs[0], PhiloxStateTracker.get_updated_fwd_offset())
        bwd_with_offset = (*outs[1], PhiloxStateTracker.get_updated_bwd_offset())
        fwd_descs = (*outs_descs[0], PhiloxUpdatedForwardOffsetAOTOutput())
        bwd_descs = (*outs_descs[1], PhiloxUpdatedBackwardOffsetAOTOutput())
        return (fwd_with_offset, bwd_with_offset), (fwd_descs, bwd_descs)

    def traced_joint(
        primals: list[FxValue],
        tangents: list[FxValue],
        fwd_seed: Tensor,
        fwd_base_offset: Tensor,
        bwd_seed: Tensor,
        bwd_base_offset: Tensor,
    ) -> tuple[
        tuple[tuple[FxValue, ...], tuple[FxValue, ...]],
        tuple[tuple[AOTOutput, ...], tuple[AOTOutput, ...]],
    ]:
        # The seed/offset parameters are part of the signature only; the body
        # does not read them.
        with cuda_rng_state_redirected():
            joint_result = func(primals, tangents)
            return append_rng_offsets(*joint_result)

    def traced_forward(*primals_fwd_seed_fwd_base_offset: Any) -> Any:
        # The signature is (*primals, seed, offset); the last two are dropped
        # before calling func.
        with cuda_rng_state_redirected():
            forward_result = func(*primals_fwd_seed_fwd_base_offset[:-2])
            return append_rng_offsets(*forward_result)

    # Get the current seed and offset to setup tracing.
    fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
        fake_mode
    )

    if not trace_joint:
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        forward_args = (*args, fwd_seed, fwd_base_offset)
        forward_descs = (
            *args_descs,
            PhiloxForwardSeedAOTInput(),
            PhiloxForwardBaseOffsetAOTInput(),
        )
        return traced_forward, forward_args, forward_descs

    bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
        fake_mode
    )
    PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
    PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
    joint_args = (
        *args,
        fwd_seed,
        fwd_base_offset,
        bwd_seed,
        bwd_base_offset,
    )
    joint_descs = (
        *args_descs,
        PhiloxForwardSeedAOTInput(),
        PhiloxForwardBaseOffsetAOTInput(),
        PhiloxBackwardSeedAOTInput(),
        PhiloxBackwardBaseOffsetAOTInput(),
    )
    return traced_joint, joint_args, joint_descs


@contextmanager
def set_partitioner_tag(tag: str) -> Generator[None, None, None]:
    """Temporarily set ``"partitioner_tag"`` in ``fx_traceback.current_meta``.

    On exit the key is set back to the value it had on entry; if the key was
    absent on entry it is set to ``None`` (it is not removed).

    Raises:
        AssertionError: if node-meta preservation is not currently active.
    """
    if not fx_traceback.has_preserved_node_meta():
        raise AssertionError("expected preserved node meta")

    key = "partitioner_tag"
    previous_tag = fx_traceback.current_meta.get(key)
    fx_traceback.current_meta[key] = tag
    try:
        yield
    finally:
        fx_traceback.current_meta[key] = previous_tag


def set_partitioner_tag_is_backward() -> AbstractContextManager[None]:
    """Return a context manager that sets the partitioner tag to ``"is_backward"``."""
    return set_partitioner_tag("is_backward")


def set_partitioner_tag_must_be_in_backward() -> AbstractContextManager[None]:
    """Return a context manager that sets the partitioner tag to ``"must_be_in_backward"``."""
    return set_partitioner_tag("must_be_in_backward")


def set_partitioner_tag_must_be_in_forward() -> AbstractContextManager[None]:
    """Return a context manager that sets the partitioner tag to ``"must_be_in_forward"``."""
    return set_partitioner_tag("must_be_in_forward")


@dataclass
class MutationCounters:
    """Snapshot of the functionalization counters read for a single tensor.

    Instances are built by ``_get_mutation_counters``; ``apply_in_graph_mutations``
    compares two snapshots to decide which mutations still need to be applied.
    """

    # Value read via torch._functionalize_mutation_counter.
    mc_data: int
    # Value read via torch._functionalize_storage_changed_counter.
    mc_storage: int
    # Value read via torch._functionalize_inductor_storage_resized_counter.
    mc_inductor_storage_resized: int


T = TypeVar("T")


def sc_visit(
    t: torch.Tensor,
    fn: Callable[[Tensor], T],
    reduce_fn: Callable[[T, T], T],
    accum_init: T,
) -> T:
    if not is_traceable_wrapper_subclass(t):
        return fn(t)

    accum = accum_init

    def visit(e: Any) -> None:
        if not is_traceable_wrapper_subclass(e):
            nonlocal accum
            accum = reduce_fn(accum, fn(e))
            return

        for a in e.__tensor_flatten__()[0]:
            visit(getattr(e, a))

    visit(t)
    return accum


def _get_mutation_counter(t: torch.Tensor) -> int:
    """Return the functionalization mutation counter of ``t``.

    The per-leaf values are combined through ``sc_visit`` with ``max`` and an
    initial value of ``-1``.
    """

    def leaf_counter(leaf: Tensor) -> int:
        return torch._functionalize_mutation_counter(leaf.elem)  # type: ignore[attr-defined]

    return sc_visit(t, fn=leaf_counter, reduce_fn=max, accum_init=-1)


def _get_storage_changed_counter(t: torch.Tensor) -> int:
    """Return the functionalization storage-changed counter of ``t``.

    The per-leaf values are combined through ``sc_visit`` with ``max`` and an
    initial value of ``-1``.
    """

    def leaf_counter(leaf: Tensor) -> int:
        return torch._functionalize_storage_changed_counter(leaf.elem)  # type: ignore[attr-defined]

    return sc_visit(t, fn=leaf_counter, reduce_fn=max, accum_init=-1)


def _get_inductor_storage_resized_counter(t: torch.Tensor) -> int:
    """Return the functionalization inductor-storage-resized counter of ``t``.

    The per-leaf values are combined through ``sc_visit`` with ``max`` and an
    initial value of ``-1``.
    """

    def leaf_counter(leaf: Tensor) -> int:
        return torch._functionalize_inductor_storage_resized_counter(leaf.elem)  # type: ignore[attr-defined]

    return sc_visit(t, fn=leaf_counter, reduce_fn=max, accum_init=-1)


def _get_mutation_counters(t: torch.Tensor) -> MutationCounters:
    """Read all three functionalization counters of ``t`` into one snapshot."""
    data_counter = _get_mutation_counter(t)
    storage_counter = _get_storage_changed_counter(t)
    resized_counter = _get_inductor_storage_resized_counter(t)
    return MutationCounters(
        mc_data=data_counter,
        mc_storage=storage_counter,
        mc_inductor_storage_resized=resized_counter,
    )


def apply_in_graph_mutations(
    input_info: InputAliasInfo,
    inpt_old: Tensor,
    inpt_new: Tensor,
    f_inpt: Tensor,
    input_idx: int,
    mcs: Optional[MutationCounters] = None,
    applied_mcs: Optional[MutationCounters] = None,
) -> None:
    """Apply the mutations recorded for one graph input onto ``inpt_old``.

    Depending on ``input_info`` this issues, in order: a ``set_()`` (storage
    metadata mutation), a storage resize, and a ``copy_()`` (data mutation).

    Args:
        input_info: Mutation metadata for the input; its ``mutation_type`` must
            be ``MutationType.MUTATED_IN_GRAPH``.
        inpt_old: The tensor to mutate.
        inpt_new: The tensor holding the post-mutation value.
        f_inpt: The functional wrapper for the input; only used (and required
            to be a ``FunctionalTensor``) on the storage-resize path.
        input_idx: Index of the input, used in the resize error message.
        mcs: Optional counter snapshot. When ``None``, every mutation flagged in
            ``input_info`` is applied.
        applied_mcs: Counters for mutations already applied. When ``mcs`` is
            given, a mutation kind is applied only if its counter in ``mcs`` is
            strictly greater than in ``applied_mcs``; ``applied_mcs`` is
            dereferenced in that case and so must not be ``None``.

    Raises:
        AssertionError: on a wrong ``mutation_type``, a non-``FunctionalTensor``
            ``f_inpt`` on the resize path, or a storage resize where neither the
            old nor the new storage size is 0.
    """
    if input_info.mutation_type != MutationType.MUTATED_IN_GRAPH:
        raise AssertionError(
            f"expected mutation_type MUTATED_IN_GRAPH, got {input_info.mutation_type}"
        )
    # See Note [set_() Input Mutations in AOTAutograd]
    # all mutations on the input must be under no_grad, so it is safe to put in the graph
    # Here, we're saying that if an input experienced a set call, inp.set_(other),
    # then we can effectively not have to worry about whether its data was mutated.
    # There are 3 cases:
    # (1) We mutate inp *after* the set_() call. other is a graph intermediate.
    #     In this case, we're not really mutating the input storage of "inp";
    #     we're mutating the storage of an intermediate value (other),
    #     and slamming that storage into the input tensor. So no data mutation is necessary.
    # (2) We mutate inp *after* the set_() call. other is a graph *input*.
    #     In this case, the data mutation will be properly handled in the runtime
    #     epilogue during the processing of "other"
    # (3) We mutate inp *before* the set_() call.
    #     This case is *not* currently handled.
    if input_info.mutates_storage_metadata:
        if mcs is None or mcs.mc_storage > applied_mcs.mc_storage:  # type: ignore[union-attr]
            with torch.no_grad():
                # pyrefly: ignore[no-matching-overload]
                inpt_old.set_(inpt_new)

    # Note [Ordering of resize_() and set_()]
    # Importantly: the common usage in FSDP is that we have a dummy parameter
    # that sees a set_() and **Then** a resize_().
    # We must put those mutations into the graph in the same order,
    # Since running them in the opposite order will have different behavior.
    # We fully ban resize_() followed by set_() for now, although in principle
    # we could support this
    if input_info.mutation_inductor_storage_resize:
        if (
            mcs is None
            or mcs.mc_inductor_storage_resized > applied_mcs.mc_inductor_storage_resized  # type: ignore[union-attr]
        ):
            # resizing is not supported on subclasses (we error earlier if this happens)
            from torch._subclasses.functional_tensor import FunctionalTensor

            if not isinstance(f_inpt, FunctionalTensor):
                raise AssertionError(f"expected FunctionalTensor, got {type(f_inpt)}")
            old_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
                f_inpt.elem, before=True
            )
            new_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
                f_inpt.elem, before=False
            )
            if old_storage_size != new_storage_size:
                if not (old_storage_size == 0 or new_storage_size == 0):
                    # The previous message text was garbled; spell out what
                    # happened and which resizes are accepted.
                    raise AssertionError(
                        f"Encountered a storage resize during tracing on input {input_idx}. "
                        f"Old nbytes={old_storage_size}, new nbytes={new_storage_size}\n"
                        "We only support storage resizing on graph inputs as long as "
                        "the input either starts or ends with a storage size of 0\n"
                        "(the case for FSDP)"
                    )
                torch.ops.inductor.resize_storage_bytes_(inpt_old, new_storage_size)
            if new_storage_size == 0:
                # Even if we marked the input as having a data mutation (thus needing a copy_()),
                # We should **ignore** it if our input has no storage
                # (this can happen if, e.g. we temporarily resize our input, copy data into it,
                #  and resize it back down to zero)
                return

    # Optimization: if the copy_() is a no-op then don't include it in the graph.
    # In theory inductor could optimize this away, however in fsdp, we end up with
    # param.copy_(param), where param is a zero-storage-size tensor,
    # and running this op in eager mode (using the aot_eager backend) will result in a segfault.
    # So we may as well optimize it away here.
    if inpt_old is inpt_new:
        # (This check needs to be done after putting resize_() in the graph,
        # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
        return
    # We found an input that had a (data-only) mutation.
    # Since keep_input_mutations is set, we need to faithfully apply a copy_()
    # so the compiler will see the input mutation in the graph.

    if not input_info.mutates_data:
        return

    # Skip the copy_() if no data mutation happened beyond what was already applied.
    if mcs is not None and mcs.mc_data <= applied_mcs.mc_data:  # type: ignore[union-attr]
        return

    if input_info.mutations_hidden_from_autograd:
        # Hidden from autograd = run under no_grad, **and** don't bump VC
        # (although if the tensor was created in inference mode, it has no VC)
        if inpt_old.is_inference():
            maybe_preserve_vc = nullcontext()
        else:
            maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
                inpt_old  # type: ignore[assignment]
            )
        with torch.no_grad(), maybe_preserve_vc:
            inpt_old.copy_(inpt_new)
    elif input_info.mutations_under_no_grad_or_inference_mode:
        # Under no_grad = run under no_grad (we still bump the VC though)
        # (inference_mode will also bump the VC, as long as the tensor in question
        # was created outside of inference_mode)

        with torch.no_grad():
            inpt_old.copy_(inpt_new)
    else:
        inpt_old.copy_(inpt_new)


# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
# - fn corresponds to the user's fw function
# - fn arguments have been flattened, duplicate arguments have been handled
# - In the returned function, the "primals" arguments *includes* synthetic bases.
# This function does the work of functionalizing the input function,
# and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
# The function returned has signature that is either:
# (1) "traced_fn(primals: List[Any])" if trace_joint is False
# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
# Returns a new (functionalized) function, and updated arguments to call it with.
def create_functionalized_fn(
    fn: Callable[..., Any],
    args: Any,
    args_descs: Any,
    *,
    meta: ViewAndMutationMeta,
    aot_config: AOTConfig,
    trace_joint: bool,
    joint_fn_handle: Optional[JointFnHandle] = None,
) -> Any:
    """Build the functionalized wrapper around ``fn`` that will be traced.

    The returned callable wraps its inputs into functional tensors (``to_fun``),
    runs ``fn`` with the Functionalize dispatch key excluded, and unwraps the
    outputs (``from_fun``). In addition:

    - When ``trace_joint`` is set, it inspects the primals and tangents after
      ``fn`` ran, raises on the mutation kinds that are rejected below, and
      calls ``copy_()`` for the data mutations that are accepted.
    - When ``aot_config.keep_inference_input_mutations`` is set, it applies the
      mutations of inputs marked ``MutationType.MUTATED_IN_GRAPH`` through
      ``apply_in_graph_mutations`` and returns the original input object for
      outputs that are such a mutated input.

    Args:
        fn: The function to functionalize.
        args: Arguments for ``fn``. Indexed as ``(primals, tangents)`` when
            ``trace_joint`` is True.
        args_descs: Descriptors that accompany ``args``. They are passed
            through unchanged unless ``config.functionalize_rng_ops`` is set.
        meta: Metadata whose ``input_info`` and ``output_info`` are read. While
            tracing the joint, indices are appended to
            ``meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw``.
        aot_config: Only ``keep_inference_input_mutations`` is read here.
        trace_joint: Selects the ``(primals, tangents)`` calling convention and
            enables the backward-mutation checks.
        joint_fn_handle: Optional handle. When given while tracing a joint that
            has in-graph input mutations, its ``post_forward`` attribute is set
            to a callback that snapshots the primals and their mutation counters.

    Returns:
        A ``(helper, args, args_descs)`` triple. When
        ``config.functionalize_rng_ops`` is set, all three come from
        ``create_functionalized_rng_ops_wrapper``.

    Raises:
        AssertionError: Raised by the returned helper (not by this function)
            when it meets an unexpected argument structure or an unsupported
            mutation.
    """
    # Snapshots taken by _post_forward (defined inside the helper below).
    # They stay None unless that callback is installed and gets called.
    primals_after_forward = None
    f_args_after_forward = None
    f_args_mutation_counters_after_forward: Optional[list[MutationCounters]] = None
    # One flag per input: True when the input's mutation is applied inside the graph.
    inputs_mutated_in_graph = [
        info.mutation_type == MutationType.MUTATED_IN_GRAPH for info in meta.input_info
    ]
    has_input_mutated_in_graph = any(inputs_mutated_in_graph)

    @simple_wraps(fn)
    def _functionalized_f_helper(
        *args: list[FxValue],
    ) -> tuple[tuple[list[FxValue], list[Tensor]], list[Optional[AOTOutput]]]:
        with maybe_enable_thunkify():
            # See Note [Disabling Functionalize TLS Above Python Functionalization]
            disable_above = torch._C._ExcludeDispatchKeyGuard(
                torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
            )

            with disable_above:
                # The functionalization code here can potentially trigger traces
                # into the graph, but we'd prefer to NOT do this, because if we
                # trace them now, we will end up with FX nodes that don't have
                # module stack annotations, which makes unflattener unhappy.
                # Wrap inputs into functional wrappers
                f_args = pytree.tree_map(to_fun, args)

                if trace_joint and has_input_mutated_in_graph and joint_fn_handle:
                    # TODO(ivankobzarev): Support fw and bw mutations for subclasses
                    def _post_forward(primals: Any) -> None:
                        # Snapshot the unwrapped primals, the functional primals and
                        # their mutation counters. Inputs that are not mutated in the
                        # graph get the placeholder MutationCounters(-1, -1, -1).
                        nonlocal primals_after_forward
                        primals_after_forward = pytree.tree_map(from_fun, primals)
                        nonlocal f_args_after_forward
                        f_args_after_forward = f_args[0]
                        nonlocal f_args_mutation_counters_after_forward
                        f_args_mutation_counters_after_forward = [
                            MutationCounters(-1, -1, -1)
                            if not inputs_mutated_in_graph[i]
                            else _get_mutation_counters(f_arg)
                            for i, f_arg in enumerate(f_args_after_forward)
                        ]

                    # Install the callback on the handle.
                    # NOTE(review): presumably the joint function invokes it once the
                    # forward part has run - confirm against the users of JointFnHandle.
                    joint_fn_handle.post_forward = _post_forward

                # Run the joint
                f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args)

            if trace_joint:
                # We support a limited amount of mutation of graph inputs during the backward pass.
                # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
                # Here, we perform extra checks for primals that were mutated in the **backward**
                # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
                # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
                #   during the forward, because the handling is different: some input mutations from the forward
                #   can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
                #   types of mutations in the backward we would need a bw-only runtime epilogue.
                # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
                #   the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
                #   require an extra round of tracing though, so it's more efficient to do in-line here.
                if not (
                    isinstance(args, tuple)
                    and len(args) == 2
                    and isinstance(args[0], (list, tuple))
                ):
                    raise AssertionError(
                        f"expected args to be tuple of (list/tuple, ...), got {type(args)}"
                    )
                # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
                primals_before = args[0]
                primals_after = pytree.tree_map(from_fun, f_args[0])
                for idx, (f_inpt, before, after, inpt_info) in enumerate(
                    zip(f_args[0], primals_before, primals_after, meta.input_info)
                ):
                    # Store information about mutations in the joint (for backward analysis)
                    joint_mutates_data = has_data_mutation(f_inpt)

                    joint_mutates_metadata = has_metadata_mutation(
                        f_inpt, before, check_only_storage_mutation=False
                    )

                    # Ban metadata mutations on fw inputs during the bw
                    if not inpt_info.mutates_metadata:
                        if joint_mutates_metadata:
                            raise AssertionError(
                                "Found a graph input that had its metadata mutated in the backward. This is not supported"
                            )

                    # Ban storage resizing on fw inputs during the bw
                    if not inpt_info.mutation_inductor_storage_resize:
                        if was_inductor_storage_resized(f_inpt):
                            raise AssertionError(
                                "Found a graph input that had storage resizing in the backward. This is not supported"
                            )

                    # Allow data mutations on fw inputs during the bw, but only if they do not require grad
                    # So we can guarantee that we can keep the mutations in the graph
                    if (
                        joint_mutates_data
                        and not inpt_info.mutates_data
                        and not inpt_info.mutates_storage_metadata
                    ):
                        # Not banning here mutations on inpt_info.requires_grad -
                        # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
                        # Add node meta for copy_ for partitioner that this node should be in backward graph.
                        with (
                            torch.fx.traceback.preserve_node_meta(),
                            set_partitioner_tag_must_be_in_backward(),
                        ):
                            # before and after should be tensors if we're calling copy_ on them
                            if not (
                                isinstance(before, torch.Tensor)
                                and isinstance(after, torch.Tensor)
                            ):
                                raise AssertionError(
                                    f"expected both before and after to be Tensors, got {type(before)} and {type(after)}"
                                )
                            before.copy_(after)
                        # Record the index of this input: the joint mutated its data
                        # although the forward metadata reports no data mutation.
                        meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
                            idx
                        )
                # Now that we covered mutations to *forward* inputs during the backward,
                # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
                # Today, we will just error in all cases of this happening unless someone needs us to support it.
                tangents_before = args[1]
                tangents_after = pytree.tree_map(from_fun, f_args[1])
                for f_inpt, before, after in zip(
                    f_args[1], tangents_before, tangents_after
                ):
                    if has_metadata_mutation(
                        f_inpt, before, check_only_storage_mutation=False
                    ):
                        raise AssertionError(
                            "Found an input to the backward that had metadata mutated "
                            "during the backward pass. This is not supported"
                        )
                    if has_data_mutation(f_inpt):
                        # Reuse the in-graph eligibility check, describing this tangent
                        # as a data-only mutation with keep_input_mutations forced on.
                        can_be_in_graph = _check_if_mutation_can_be_in_graph(
                            keep_input_mutations=True,
                            mutates_data=True,
                            mutates_metadata=False,
                            mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd(
                                f_inpt
                            ),
                            mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode(
                                f_inpt
                            ),
                            mutates_storage_metadata=False,
                            mutation_inductor_storage_resize=was_inductor_storage_resized(
                                f_inpt
                            ),
                            requires_grad=f_inpt.requires_grad,
                        )
                        if not can_be_in_graph:
                            # NOTE(review): unlike the sibling messages, this one is a
                            # sentence fragment (it lacks the leading "Found ...").
                            raise AssertionError(
                                "a backward input that had data mutated in an autograd-aware way. This is not supported"
                            )
                        # Perform the input mutation
                        with torch.fx.traceback.preserve_node_meta():
                            # before and after should be tensors if we're calling copy_ on them
                            if not (
                                isinstance(before, torch.Tensor)
                                and isinstance(after, torch.Tensor)
                            ):
                                raise AssertionError(
                                    f"expected both before and after to be Tensors, got {type(before)} and {type(after)}"
                                )
                            before.copy_(after)

            if aot_config.keep_inference_input_mutations:
                # Note: This is a bit annoying. There's a layering issue here, where:
                # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
                # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
                #     However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
                #     because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
                #     This makes it pretty difficult for this logic to operate on synthetic bases.
                # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
                #     (unpacked) input aliases, instead of the synthetic base.
                # Example case where (3) could be important:
                #
                #     def f(x, y):
                #         x.mul_(2)
                #         y.mul_(3)
                #         return x, y
                #    a = torch.ones(1'000'000)
                #    x, y = out(a[0:9], a[1:10])
                #
                # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
                # a giant "updated synthetic base" and copying into a's entire storage.
                #
                # For now, we are pessimistically not performing the optimization from (3);
                # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
                # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
                # about synthetic bases.

                # Apply in graph forward mutations only in joint case.
                # Note: Mutations of primals in forward AND backward.
                # If we have mutations of the same input in forward and in backward,
                # we can not fuse them into one copy_ node. As in this case partitioner will put it
                # either in forward or in backward. This will lead to incorrect state
                # after forward and before backward.
                # We have to emit two copy_ nodes, marking with additional meta each node,
                # if it must be in forward or backward.
                # We memorize mutation counter of the inputs after forward.
                # Based on this after joint graph we check if backward also mutated input or not.
                # We emit copy_ only in the end of joint tracing, to provide invariant for joint
                # graph passes, that our graph is functional, except only some number of copy_ nodes
                # in the end.

                # One entry per input: the mutation counters that have already been
                # handed to apply_in_graph_mutations. Entries are replaced, never
                # mutated in place, so sharing the initial object is harmless.
                mcs_applied: list[MutationCounters] = [MutationCounters(0, 0, 0)] * len(
                    meta.input_info
                )
                # First pass: only runs when _post_forward recorded a snapshot. It
                # applies the state captured after the forward, tagged must-be-in-forward.
                if f_args_mutation_counters_after_forward is not None:
                    primals_before = args[0]
                    for idx, (f_inpt, before, after, inpt_info) in enumerate(
                        # pyrefly: ignore [no-matching-overload]
                        zip(
                            f_args_after_forward,  # type: ignore[arg-type]
                            primals_before,  # type: ignore[arg-type]
                            primals_after_forward,  # type: ignore[arg-type]
                            meta.input_info,
                        )
                    ):
                        if inpt_info.mutation_type != MutationType.MUTATED_IN_GRAPH:
                            continue

                        mcs_after_forward = f_args_mutation_counters_after_forward[idx]
                        with (
                            torch.fx.traceback.preserve_node_meta(),
                            set_partitioner_tag_must_be_in_forward(),
                            _proxy_tensor_disable_update_tensor_tracker(),
                        ):
                            apply_in_graph_mutations(
                                inpt_info,
                                before,
                                after,
                                f_inpt,
                                idx,
                                mcs_after_forward,
                                mcs_applied[idx],
                            )
                            mcs_applied[idx] = mcs_after_forward

                # Second pass: walk the (forward) inputs in their final state and apply
                # the in-graph mutations that the first pass has not already covered.
                for idx, (inpt_old, f_inpt) in enumerate(
                    zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])  # type: ignore[arg-type]
                ):
                    if not isinstance(f_inpt, torch.Tensor):
                        continue
                    if not is_fun(f_inpt):
                        raise AssertionError(
                            f"expected functional tensor, got {type(f_inpt)}"
                        )
                    inpt_new = from_fun(f_inpt)
                    if (
                        meta.input_info[idx].mutation_type
                        != MutationType.MUTATED_IN_GRAPH
                    ):
                        continue
                    # mcs stays None when no post-forward snapshot exists.
                    mcs: Optional[MutationCounters] = None
                    if f_args_mutation_counters_after_forward is not None:
                        # This could happen for subclasses tracing
                        # Subclasses support for mutations in fw and bw is TBD.
                        mcs = _get_mutation_counters(f_inpt)
                        if mcs == mcs_applied[idx]:
                            # No mutation in backward; mutation was already applied.
                            continue

                    with (
                        torch.fx.traceback.preserve_node_meta(),
                        set_partitioner_tag_must_be_in_backward(),
                    ):
                        apply_in_graph_mutations(
                            meta.input_info[idx],
                            # pyrefly: ignore[bad-argument-type]
                            inpt_old,
                            # pyrefly: ignore[bad-argument-type]
                            inpt_new,
                            f_inpt,
                            idx,
                            mcs,
                            mcs_applied[idx],
                        )

                # When an output tensor is a functionalized mutated input, and we
                # were able to move the mutation in to the graph then we can return
                # the mutated input directly. This prevents duplicating the
                # tensors contents.
                flat_outs, outs_spec = pytree.tree_flatten(f_outs)
                flat_outs = [from_fun(o) for o in flat_outs]
                num_outs = len(meta.output_info)

                for i in range(num_outs):
                    info = meta.output_info[i]
                    if info.output_type != OutputType.is_input:
                        continue

                    if info.base_idx is None:
                        raise AssertionError("info.base_idx must not be None")
                    if (
                        meta.input_info[info.base_idx].mutation_type
                        == MutationType.MUTATED_IN_GRAPH
                    ):
                        # Substitute the original (non-functional) input object.
                        fw_args = args[0] if trace_joint else args
                        flat_outs[i] = fw_args[info.base_idx]
                return pytree.tree_unflatten(flat_outs, outs_spec), f_outs_descs

            # keep_inference_input_mutations is off: just unwrap the functional outputs.
            return pytree.tree_map(from_fun, f_outs), f_outs_descs

    # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
    # and "tangents" as its input names (which are special-cased by the partitioner)
    # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export
    def joint_helper(primals: list[FxValue], tangents: list[FxValue]) -> Any:
        return _functionalized_f_helper(primals, tangents)

    # The joint variant takes named (primals, tangents); the forward-only one takes *args.
    helper = joint_helper if trace_joint else _functionalized_f_helper
    if config.functionalize_rng_ops:
        # Setup the wrapper for functionalization of rng ops
        helper, args, args_descs = create_functionalized_rng_ops_wrapper(
            helper, args, args_descs, trace_joint
        )

    return helper, args, args_descs


def handle_effect_tokens_fn(
    fn: Callable[..., Any],
    args: Any,
    args_descs: list[AOTInput],
    *,
    meta: ViewAndMutationMeta,
    trace_joint: bool,
) -> Any:
    """Thread effect tokens through ``fn`` as extra inputs and extra outputs.

    ``len(meta.tokens)`` empty tensors are prepended to the forward inputs
    (``args[0]`` when ``trace_joint`` is True, ``args`` otherwise), together
    with matching ``ForwardTokenAOTInput`` descriptors. The returned wrapper
    strips those tokens off again, registers their functional wrappers on the
    active functional tensor mode, calls ``fn`` and returns the resulting
    tokens alongside ``fn``'s outputs:

    - forward-only: ``(*tokens, *outs)``
    - joint: ``((*fw_tokens, *outs[0]), (*outs[1], *bw_tokens))``; this path
      also stores the number of backward tokens in ``meta.num_backward_tokens``.

    Returns:
        ``(inner_fn, args, args_descs)`` with the token inputs/descriptors added.

    Raises:
        AssertionError: Raised by the wrapper when the arguments, the tokens,
            the functional mode or the joint outputs are not as expected.
    """
    num_tokens = len(meta.tokens)

    @simple_wraps(fn)
    def inner_fn(*args: Any) -> Any:
        # See Note [Disabling Functionalize TLS Above Python Functionalization]
        with torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        ):
            # Peel the leading tokens off the forward inputs.
            # See Note [Side-Effectful Tokens in AOTAutograd]
            if trace_joint:
                if not isinstance(args, tuple) or not isinstance(
                    args[0], (list, tuple)
                ):
                    raise AssertionError(
                        f"expected args to be tuple with first element as list/tuple, got {type(args)}"
                    )
                primals = args[0]
                tokens = primals[:num_tokens]
                fn_args = (primals[num_tokens:], *args[1:])
            else:
                tokens = args[:num_tokens]
                fn_args = args[num_tokens:]

            if not all(token.numel() == 0 for token in tokens):
                raise AssertionError("all tokens must have numel() == 0")

            # Seed the active FunctionalTensorMode with one token per operator.
            # See Note [FunctionalTensorMode is Stateful]
            functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )
            if functional_tensor_mode is None:
                raise AssertionError("functional_tensor_mode must not be None")
            f_tokens = pytree.tree_map(to_fun, tokens)
            for position, key in enumerate(meta.tokens.keys()):
                functional_tensor_mode._tokens[key] = f_tokens[position]

            # Run the joint
            outs, outs_descs = call_and_expect_output_descs(fn, fn_args)

        # Hand back the tokens together with the outputs.
        # See Note [Side-Effectful Tokens in AOTAutograd]
        if not trace_joint:
            out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()]
            # TODO: can probably do a little more resolution here
            out_tokens_descs = [
                ForwardTokenAOTOutput(i) for i in range(len(out_tokens))
            ]
            return ((*out_tokens, *outs), (*out_tokens_descs, *outs_descs))

        if len(outs) != 2:
            raise AssertionError(
                f"expected len(outs) == 2 for joint trace, got {len(outs)}"
            )
        if len(functional_tensor_mode._tokens_forward_output) != num_tokens:
            raise AssertionError(
                f"expected {num_tokens} forward output tokens, got {len(functional_tensor_mode._tokens_forward_output)}"
            )

        # Unwrap forward tokens first, then the backward ones.
        f_fwd_out_tokens = [
            from_fun(t)
            for t in functional_tensor_mode._tokens_forward_output.values()
        ]
        f_bwd_out_tokens = [
            from_fun(t) for t in functional_tensor_mode._tokens.values()
        ]
        num_fwd_out_tokens = len(f_fwd_out_tokens)
        num_bwd_out_tokens = len(f_bwd_out_tokens)
        f_fwd_out_tokens_descs = [
            ForwardTokenAOTOutput(i) for i in range(num_fwd_out_tokens)
        ]
        f_bwd_out_tokens_descs = [
            BackwardTokenAOTOutput(i) for i in range(num_bwd_out_tokens)
        ]

        meta.num_backward_tokens = num_bwd_out_tokens

        joint_outs = (
            (*f_fwd_out_tokens, *outs[0]),
            (*outs[1], *f_bwd_out_tokens),
        )
        joint_outs_descs = (
            (*f_fwd_out_tokens_descs, *outs_descs[0]),
            (*outs_descs[1], *f_bwd_out_tokens_descs),
        )
        return (joint_outs, joint_outs_descs)

    # Additionally pass in tokens as inputs
    # See Note [Side-Effectful Tokens in AOTAutograd]
    token_inputs = [torch.tensor([])] * num_tokens
    token_inputs_descs = [ForwardTokenAOTInput(i) for i in range(num_tokens)]

    if trace_joint:
        return (
            inner_fn,
            ([*token_inputs, *args[0]], *args[1:]),
            (
                [*token_inputs_descs, *args_descs[0]],  # type: ignore[misc]
                *args_descs[1:],
            ),
        )
    return (
        inner_fn,
        [*token_inputs, *args],
        [*token_inputs_descs, *args_descs],
    )


# Given a function operating on Subclass -> Subclass, returns a function that operates on Tensor -> Tensor
# Also returns:
# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated)
# - the updated ViewAndMutationMeta for this dense -> dense function.
# The other important arguments are:
# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function.
#                        when is_joint_structure=False, this is just the forward function.
# - fw_only: this is *always* the forward-only function.
#   Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions.
#   In particular, we need this to tell the partitioner how many dense forward outputs there are.
def aot_dispatch_subclass(
    flat_fn_maybe_joint: Union[JointTraceFn, TraceFn],
    args: Union[list[FxValue], tuple[list[FxValue], list[FxValue]]],
    args_descs: Union[list[AOTInput], tuple[list[AOTInput], list[AOTInput]]],
    *,
    is_joint_structure: bool,
    meta: ViewAndMutationMeta,
    fw_only: Callable[..., Any],
) -> SubclassTracingInfo:
    """Turn a subclass -> subclass function into one over unwrapped arguments.

    Args:
        flat_fn_maybe_joint: The joint fw/bw function when
            ``is_joint_structure`` is True, otherwise the forward function.
        args: Arguments for ``flat_fn_maybe_joint``; indexed as
            ``(primals, tangents)`` in the joint case.
        args_descs: Descriptors matching ``args``.
        is_joint_structure: Whether ``args`` has the joint structure.
        meta: Metadata of the original function; used to re-wrap arguments and
            to read ``static_input_indices``, ``keep_input_mutations`` and
            ``is_train``.
        fw_only: The forward-only function, used to recollect metadata on the
            unwrapped function.

    Returns:
        A ``SubclassTracingInfo``. When ``requires_subclass_dispatch`` reports
        that no dispatch is needed, the inputs are returned as-is with
        ``maybe_subclass_meta=None``. Otherwise it carries the function to
        trace, the unwrapped arguments and descriptors, and a ``SubclassMeta``
        holding the recollected forward metadata.
    """
    # Fast path: nothing to do when no subclasses need to be traced through.
    if not requires_subclass_dispatch(args, meta):  # type: ignore[arg-type]
        return SubclassTracingInfo(
            plain_tensor_trace_fn=flat_fn_maybe_joint,
            plain_tensor_args=args,
            plain_tensor_args_descs=args_descs,
            maybe_subclass_meta=None,
        )

    # TODO: add subclass guards (later PR).

    # Subclass metadata for the joint's outputs (grad_inputs) is not known until
    # the joint is actually being traced, so this object is created empty here
    # and filled in from inside inner_fn() below. Running
    # run_functionalized_fw_and_collect_metadata() directly on the joint would
    # also work, but costs one more pass through the joint at compile time.
    subclass_meta = SubclassMeta()

    # NB: doesn't take descs, this is going from the NEW flat_args to the
    # subclasses, we don't need to do bookkeeping here
    def inner_fn(fn: Callable[..., Any], args: Any, *, use_trace_joint: bool) -> Any:
        # Step 1: re-wrap the flat inputs into subclasses where needed.
        all_args = wrap_tensor_subclasses_maybe_joint(
            args, is_joint_structure=use_trace_joint, meta=meta
        )

        # Step 2: run the wrapped function on the (maybe subclass) inputs.
        wrapped_outs, wrapped_outs_descs = call_and_expect_output_descs(fn, all_args)  # type: ignore[arg-type]

        # Step 3 (forward-only): unwrap subclass outputs and return.
        if not use_trace_joint:
            return unwrap_tensor_subclasses(
                wrapped_outs, wrapped_outs_descs, append_symints=True
            )

        # Step 3 (joint): the outputs must be a (forward_outs, grad_inputs) pair.
        if not isinstance(wrapped_outs, tuple) or len(wrapped_outs) != 2:
            raise AssertionError(
                f"expected wrapped_outs to be tuple of length 2, got {type(wrapped_outs)}, {wrapped_outs_descs}"
            )
        fw_wrapped, grad_inputs = wrapped_outs

        # See Note: [Computing Subclass Metadata about grad_inputs]
        # Only the grad_inputs need recording; the forward outputs already
        # have subclass metadata.
        subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs)

        # Unwrap both halves, adding the extra symints as graph outputs
        # (nested ints are ignored here).
        forward_outs, forward_outs_descs = unwrap_tensor_subclasses(
            fw_wrapped, wrapped_outs_descs[0], append_symints=True
        )
        backward_outs, backward_outs_descs = unwrap_tensor_subclasses(
            grad_inputs, wrapped_outs_descs[1], append_symints=True
        )
        return (
            (forward_outs, backward_outs),
            (forward_outs_descs, backward_outs_descs),
        )

    def joint_fn(
        primals: list[FxValue], tangents: list[FxValue]
    ) -> tuple[
        tuple[list[FxValue], list[FxValue]], tuple[list[AOTOutput], list[AOTOutput]]
    ]:
        with maybe_enable_thunkify():
            return inner_fn(
                flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
            )

    def fw_fn(*primals: FxValue) -> tuple[list[FxValue], list[AOTOutput]]:
        with maybe_enable_thunkify():
            return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)

    def metadata_fn(*primals: FxValue) -> tuple[list[FxValue], list[AOTOutput]]:
        @simple_wraps(fw_only)
        def inner_fw_only(*args: Any) -> Any:
            return call_and_expect_output_descs(fw_only, args)

        return inner_fn(inner_fw_only, primals, use_trace_joint=False)

    if is_joint_structure:
        # Primals: extra symints (sizes/strides) become forward graph inputs.
        primals_unwrapped, primals_unwrapped_descs = unwrap_tensor_subclasses(
            args[0],  # type: ignore[arg-type]
            args_descs[0],  # type: ignore[arg-type]
            append_symints=True,
        )
        # Tangents: append_symints=False because the partitioner will
        # capture and add any extra argument.
        tangents_unwrapped, tangents_unwrapped_descs = unwrap_tensor_subclasses(
            args[1],  # type: ignore[arg-type]
            args_descs[1],  # type: ignore[arg-type]
            append_symints=False,
        )
        args_unwrapped = (primals_unwrapped, tangents_unwrapped)
        args_descs_unwrapped = (primals_unwrapped_descs, tangents_unwrapped_descs)
        remapped_static_indices = remap_unwrapped_subclass_arg_indices(
            args[0],  # type: ignore[arg-type]
            meta.static_input_indices,  # type: ignore[arg-type]
        )
        fn_to_trace = joint_fn
    else:
        primals_unwrapped, primals_unwrapped_descs = unwrap_tensor_subclasses(
            args,  # type: ignore[arg-type]
            args_descs,  # type: ignore[arg-type]
            append_symints=True,
        )
        # Forward-only: the unwrapped arguments are exactly the primals.
        args_unwrapped = primals_unwrapped  # type: ignore[assignment]
        args_descs_unwrapped = primals_unwrapped_descs  # type: ignore[assignment]
        remapped_static_indices = remap_unwrapped_subclass_arg_indices(
            args,  # type: ignore[arg-type]
            meta.static_input_indices,  # type: ignore[arg-type]
        )
        fn_to_trace = fw_fn  # type: ignore[assignment]

    # Note: [Partitioner handling for Subclasses, Part 1]
    # The way the partitioner works is that:
    # (1) we pass in a single graph containing the joint fw/bw,
    #     where the # of graph outputs corresponds to # fw_outputs + # grad_inputs
    # (2) The partitioner accepts an argument, num_fwd_outputs,
    #     and assumes that the first "num_fwd_outputs" graph outputs correspond
    #     to outputs of the forward graph.
    # How do tensor subclasses enter the picture?
    # The num_fwd_outputs in the final graph is actually non-trivial to compute,
    # because it can be influenced by input mutations and intermediate bases.
    # So we compute it by inspecting the current ViewAndMutationMeta object.
    # However, the original ViewAndMutationMeta that we computed was created
    # on the subclass -> subclass graph,
    # which can have a different number of outputs than the dense -> dense graph.
    # That's why we create a fresh metadata object on the dense -> dense function here,
    # and plumb it back up to the partitioner.
    # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
    collect_metadata = run_functionalized_fw_and_collect_metadata(
        without_output_descs(metadata_fn),
        # pyrefly: ignore [bad-argument-type]
        flat_args_descs=primals_unwrapped_descs,
        static_input_indices=remapped_static_indices,
        keep_input_mutations=meta.keep_input_mutations,
        is_train=meta.is_train,
    )
    # pyrefly: ignore [not-iterable]
    subclass_meta.fw_metadata = collect_metadata(*primals_unwrapped)

    return SubclassTracingInfo(
        plain_tensor_trace_fn=fn_to_trace,
        plain_tensor_args=args_unwrapped,
        plain_tensor_args_descs=args_descs_unwrapped,
        maybe_subclass_meta=subclass_meta,
    )


def create_functional_call(
    mod: Any,
    params_spec: Any,
    params_len: int,
    store_orig_mod: bool = False,
    strict_out_tuple: bool = True,
) -> Callable[..., Any]:
    """
    Build a callable that runs ``mod`` with its parameters passed explicitly.

    The returned function treats its first ``params_len`` positional arguments
    as flat parameter values and everything after them as the actual inputs of
    ``mod``. On every call the parameter values are rebuilt from
    ``params_spec`` and handed to ``stateless._reparametrize_module`` for the
    duration of the call.

    Args:
        mod: The callable module to run. A ``torch.fx.GraphModule`` is run
            through ``PropagateUnbackedSymInts``; anything else is called
            directly.
        params_spec: Either a ``TreeSpec`` used to unflatten the flat parameter
            values, or a list of keys that is zipped with them into a dict.
        params_len: Number of leading positional arguments that are parameters.
        store_orig_mod: When true, stash ``mod`` on the returned function as
            ``_orig_mod`` unless that attribute is already present.
        strict_out_tuple: When true, reject outputs that are neither a tuple
            nor a list.

    Returns:
        The wrapped function. Calling it raises ``AssertionError`` if
        ``params_spec`` is neither a ``TreeSpec`` nor a list, or if ``mod`` is a
        ``GraphModule`` and no fake mode is detected; it raises
        ``RuntimeError`` if ``strict_out_tuple`` is set and the output is not a
        tuple or list.
    """
    # Redundant with dynamo, but worth having in case this gets invoked elsewhere.
    # https://github.com/pytorch/pytorch/issues/103569

    def _rebuild_params(leading_values: tuple[Any, ...]) -> Any:
        # Turn the flat leading arguments back into the structure that is
        # passed to the reparametrization context manager.
        if isinstance(params_spec, TreeSpec):
            return pytree.tree_unflatten(leading_values, params_spec)
        if isinstance(params_spec, list):
            return dict(zip(params_spec, leading_values))
        raise AssertionError(
            f"expected params_spec to be a list, got {type(params_spec)}"
        )

    @simple_wraps(mod)
    def functional_call(*args: Any, **kwargs: Any) -> Any:
        module_params = _rebuild_params(args[:params_len])
        user_inputs = args[params_len:]

        with stateless._reparametrize_module(mod, module_params):
            with maybe_disable_thunkify():
                if not isinstance(mod, torch.fx.GraphModule):
                    out = mod(*user_inputs, **kwargs)
                else:
                    # Handle **kwargs. FX only natively supports positional
                    # arguments (through placeholders), so keyword values are
                    # appended after the positional inputs in dict order.
                    graph_inputs = (*user_inputs, *kwargs.values())

                    with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                        # Silence the warning emitted by enabling anomaly
                        # detection just below.
                        warnings.filterwarnings(
                            "ignore", "Anomaly Detection has been enabled."
                        )
                        with torch.autograd.detect_anomaly(check_nan=False):
                            active_fake_mode = detect_fake_mode()
                            if active_fake_mode is None:
                                raise AssertionError("fake_mode must not be None")
                            active_fake_mode.epoch += 1
                            interpreter = PropagateUnbackedSymInts(mod)
                            out = interpreter.run(*graph_inputs)

        if strict_out_tuple:
            if not isinstance(out, (tuple, list)):
                raise RuntimeError(
                    "Graph output must be a (). This is so that we can avoid "
                    "pytree processing of the outputs. Please change the module to "
                    "have tuple outputs or use aot_module instead."
                )
        return out

    # Note [Preserving the nn module stack metadata during export non-strict mode]
    # This path is currently only used by the non-strict export flow,
    # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph.
    # Instead, we stash the original user nn module here, and rely on `make_fx` to grab
    # this stashed module and use it to track nn module stack metadata
    if store_orig_mod:
        if not hasattr(functional_call, "_orig_mod"):
            functional_call._orig_mod = mod  # type: ignore[attr-defined]

    return functional_call
