import contextlib
import functools
import itertools
import logging
from collections.abc import Callable
from typing import Any

import sympy

import torch
import torch._inductor.config as config
from torch._dynamo.utils import counters
from torch._inductor import ir
from torch._inductor.autotune_process import (
    SubgraphCPUBenchmarkRequest,
    SubgraphGPUBenchmarkRequest,
    TensorMeta,
)
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
    Buffer,
    FixedLayout,
    get_free_symbols,
    get_symbolic_inputs,
    gm_original_output_strides,
    ir_node_to_tensor,
    Layout,
)
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import do_bench_using_profiling
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


def inline_subgraph_to_ir_nodes(
    gm: torch.fx.GraphModule, inputs: list[Any], name: str
) -> Any:
    """Lower ``gm`` op-by-op into the current graph instead of as one opaque buffer.

    Every FX operation of the subgraph is turned into its own IR node (multiple
    ComputedBuffer nodes, which are fusable) by ``process_subgraph_nodes``, so
    subsequent operations can be epilogue-fused with the result.

    Args:
        gm: The subgraph whose operations are inlined.
        inputs: Values bound to the subgraph's inputs; forwarded unchanged to
            ``process_subgraph_nodes``.
        name: Accepted for interface compatibility; not read by this function.

    Returns:
        TensorBox containing the final operation result as individual IR nodes
        (whatever ``process_subgraph_nodes`` returns).
    """
    from torch._inductor.lowering import process_subgraph_nodes

    # V.graph.module is pointed at the subgraph only for the duration of the
    # call; the finally-clause always puts the enclosing module back so that IR
    # nodes created afterwards are not attributed to the wrong graph.
    enclosing_module = V.graph.module
    try:
        V.graph.module = gm
        lowered = process_subgraph_nodes(gm, inputs)
        return lowered
    finally:
        V.graph.module = enclosing_module


class SubgraphChoiceCaller(ir.ChoiceCaller):
    """
    Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
    GraphModule. Compiles the Subgraph down to a module for benchmarking.

    The subgraph is traced with symbolic inputs (``example_inputs``) and
    benchmarked with concrete inputs (``benchmark_inputs``). The benchmark module
    is compiled eagerly in ``__init__`` when ``config.pipeline_max_autotune_gemm``
    is enabled, and lazily on the first ``benchmark``/``benchmark_collective``
    call otherwise.
    """

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        description: str,
        make_fx_graph: Callable[..., Any],
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
        *,
        config_patches: dict[str, Any] | None = None,
    ) -> None:
        """Trace the subgraph and prepare its benchmark inputs.

        Args:
            name: Kernel name of this choice.
            input_nodes: Tensor inputs of the subgraph; their layouts are frozen.
            layout: Output layout of the subgraph.
            description: Human readable description forwarded to ChoiceCaller.
            make_fx_graph: Callable that traces the subgraph from example inputs
                and returns a GraphModule.
            input_gen_fns: Optional mapping from input index to a generator used
                to build that input's benchmark tensor.
            config_patches: Optional inductor config overrides merged into the
                config used to compile the benchmark module and forwarded to the
                SubgraphBuffer built by ``output_node``. Supplying them here
                (instead of assigning ``config_patches`` after construction)
                guarantees that the eager pipelined pre-compile sees them.
        """
        super().__init__(name, input_nodes, layout, description)

        # Create inputs for tracing and benchmarking:
        # - trace_inputs: Always use symbolic inputs for tracing so the graph works
        #   for any size at runtime. This correctly captures shape-dependent operations
        #   like x * x.shape[0].
        # - benchmark_inputs: Use input_gen_fns if provided (for range-specific benchmarks)
        #   or concrete inputs from ir_node_to_tensor with guard_shape=False.
        trace_inputs = []
        self.benchmark_inputs: list[Any] = []
        with V.fake_mode:
            for i, inp in enumerate(self.input_nodes):
                # Here there will be no unbacked symbols, as SubgraphBuffer does not support them
                assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0
                assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0

                inp.data.freeze_layout()  # type: ignore[attr-defined]

                # Always use symbolic inputs for tracing
                trace_inputs.append(ir_node_to_tensor(inp))

                # Use input_gen_fn for benchmarking if provided, otherwise concrete sizes
                if input_gen_fns is not None and i in input_gen_fns:
                    self.benchmark_inputs.append(input_gen_fns[i](inp))
                else:
                    self.benchmark_inputs.append(
                        ir_node_to_tensor(inp, replace_symbols_with_hints=True)
                    )

        # Trace with symbolic inputs for guard detection
        self.gm = make_fx_graph(*trace_inputs)
        gm_original_output_strides(self.gm)
        # Store symbolic inputs for sym_input computation
        self.example_inputs = trace_inputs

        self.sym_inputs = get_symbolic_inputs(self.input_nodes)
        self.sym_input_values = self._compute_sym_input_values()

        # Cached decomposition info for range-based dispatch (set via cache_decomposition)
        self.decomposition: Callable[..., Any] | None = None
        self.decomposition_kwargs: dict[str, Any] = {}
        # Config patches to apply during kernel codegen (e.g., coordinate_descent_tuning).
        # Stored BEFORE the optional pre-compile below so that it honours them.
        self.config_patches: dict[str, Any] = (
            dict(config_patches) if config_patches else {}
        )
        # Cache compiled module to avoid recompiling on every benchmark call
        self._compiled_module: Any = None
        # Cache benchmark request for async autotuning
        self._bmreq: (
            SubgraphGPUBenchmarkRequest | SubgraphCPUBenchmarkRequest | None
        ) = None

        # Pre-compile only if using async pipelined autotuning
        # Must happen in __init__ because compilation requires virtualized context (V.graph, V.debug)
        if config.pipeline_max_autotune_gemm:
            with V.fake_mode:
                self._compiled_module = self._compile_for_benchmarking()
                self._bmreq = self._create_benchmark_request()

    def _compute_sym_input_values(self) -> list[int]:
        """Extract concrete dimension values for sym_inputs from benchmark_inputs.

        The compiled module expects symbolic dimension values as runtime arguments.
        Each symbol is mapped to the size of the matching dimension of the
        benchmark tensors; symbols that cannot be matched fall back to the shape
        env's optimization hint (1 if no hint is available).
        Used for range based autotuning.

        Returns:
            One int per entry of ``self.sym_inputs``, in the same order.
        """
        sym_input_names = OrderedSet(
            [s.name for s in self.sym_inputs if hasattr(s, "name")]
        )

        # Build mapping: symbolic dimension name -> concrete value
        sym_name_to_value: dict[str, int] = {}
        for inp_node, benchmark_inp in zip(self.input_nodes, self.benchmark_inputs):
            if not isinstance(benchmark_inp, torch.Tensor):
                continue
            for sym_dim, actual_dim in zip(inp_node.get_size(), benchmark_inp.shape):
                if isinstance(sym_dim, sympy.Symbol):
                    sym_name_to_value[sym_dim.name] = int(actual_dim)
                elif str(sym_dim) in sym_input_names:
                    sym_name_to_value[str(sym_dim)] = int(actual_dim)

        result = []
        for sym_var in self.sym_inputs:
            if isinstance(sym_var, sympy.Symbol) and sym_var.name in sym_name_to_value:
                result.append(sym_name_to_value[sym_var.name])
            else:
                hint = V.graph.sizevars.shape_env.optimization_hint(sym_var, fallback=1)
                result.append(int(hint))
        return result

    def cache_decomposition(
        self, decomposition: Callable[..., Any], kwargs: dict[str, Any]
    ) -> None:
        """Cache decomposition function and kwargs for range-based dispatch lookup."""
        self.decomposition = decomposition
        self.decomposition_kwargs = kwargs

    def __str__(self) -> str:
        return f"SubgraphCaller({self.name})"

    def _compile_for_benchmarking(self) -> Any:
        """Compile the subgraph for benchmarking, returns the compiled module.

        A fresh GraphLowering sharing the parent graph's shape env and wrapper
        settings lowers ``self.gm``. Nested autotuning is disabled and
        ``self.config_patches`` are applied on top of that.
        """
        from torch._inductor.graph import GraphLowering

        safe_name = self.name.replace("::", "_").replace(".", "_")

        # Symbolic inputs produce code with symbolic sizes resolved at runtime
        # via sym_input_values. Without symbols (static shapes), compile with
        # benchmark_inputs whose sizes match the actual benchmark tensors.
        compile_inputs = (
            self.example_inputs if self.sym_inputs else self.benchmark_inputs
        )
        log.debug("Benchmark compile %s: sym_inputs=%s", self.name, self.sym_inputs)

        assert self.gm is not None
        bm_graph_lowering = GraphLowering(
            gm=self.gm,
            example_inputs=compile_inputs,
            shape_env=V.graph._shape_env,
            cpp_wrapper=V.graph.cpp_wrapper,
            aot_mode=V.graph.aot_mode,
            extern_node_serializer=V.graph.extern_node_serializer,
            is_inference=V.graph.is_inference,
            is_backward=V.graph.is_backward,
            name=f"benchmark_{safe_name}",
        )

        # Symbolic sizes become leading graph inputs; benchmark() passes
        # sym_input_values in front of the tensor arguments to match.
        for sym_inp in self.sym_inputs:
            bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp
            bm_graph_lowering.graph_input_names.append(sym_inp.name)

        with V.set_graph_handler(bm_graph_lowering):
            # Apply config_patches during benchmarking (e.g., coordinate_descent_tuning)
            # Also disable max_autotune to avoid nested autotuning
            benchmark_config: dict[str, Any] = {
                "max_autotune": False,
                "max_autotune_gemm": False,
                "max_autotune_gemm_backends": "ATEN",
                "benchmark_fusion": False,
                "pipeline_max_autotune_gemm": False,
                **self.config_patches,
            }
            with config.patch(benchmark_config):
                bm_graph_lowering.run(*compile_inputs)
                return bm_graph_lowering.compile_to_module()

    def _create_benchmark_request(
        self,
    ) -> SubgraphGPUBenchmarkRequest | SubgraphCPUBenchmarkRequest:
        """Create a benchmark request for async autotuning.

        Picks the CPU request class for CPU output layouts and the GPU request
        class otherwise, pointing it at the already compiled module.
        """
        assert self._compiled_module is not None, (
            "Module must be compiled before creating benchmark request"
        )
        input_tensor_meta = TensorMeta.from_irnodes(self.input_nodes)
        output_tensor_meta = TensorMeta.from_irnodes(self.layout)

        if self.layout.device.type == "cpu":
            bmreq_cls = SubgraphCPUBenchmarkRequest
        else:
            bmreq_cls = SubgraphGPUBenchmarkRequest

        return bmreq_cls(
            kernel_name=self.name,
            input_tensor_meta=input_tensor_meta,
            output_tensor_meta=output_tensor_meta,
            extra_args=tuple(),
            module_path=self._compiled_module.__file__,
            module_cache_key=self._compiled_module.key,
            sym_input_values=self.sym_input_values,
        )

    @property
    def bmreq(
        self,
    ) -> SubgraphGPUBenchmarkRequest | SubgraphCPUBenchmarkRequest:
        """Benchmark request for async autotuning. Pre-compiled when pipeline_max_autotune_gemm is enabled."""
        assert self._bmreq is not None, (
            "bmreq accessed but pipeline_max_autotune_gemm was not enabled during __init__"
        )
        return self._bmreq

    def _ensure_compiled(self) -> None:
        """Ensure the module is compiled. Used for lazy compilation in non-async path."""
        if self._compiled_module is None:
            self._compiled_module = self._compile_for_benchmarking()

    def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        """Regular benchmarking: compile if needed, then use benchmarker.

        Args:
            *args: Runtime tensor arguments, passed after the symbolic size values.
            out: Output tensor (unused; the compiled module allocates its output).

        Returns:
            The measured time reported by the selected benchmarking backend.
        """
        self._ensure_compiled()
        bm_func = self._compiled_module.call
        sym_values = self.sym_input_values

        def fn() -> Any:
            return bm_func([*sym_values, *args])

        if self._benchmark_with_cudagraphs:
            return benchmarker.benchmark_gpu_with_cuda_graph(fn)

        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(fn)
        return benchmarker.benchmark(
            fn,
            device=benchmarker.infer_device(*sym_values, *args),
        )

    def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None:
        """Run once for collective benchmarking (barrier sync handled by caller)."""
        self._ensure_compiled()
        self._compiled_module.call([*self.sym_input_values, *args])

    def hash_key(self) -> str:
        """Return a string key identifying what this choice computes and compiles.

        The trailing ``_<index>`` suffix of the name is dropped so that
        otherwise-identical choices share a key. Input sizes/strides, the traced
        graph and any config patches distinguish genuinely different choices.
        (Presumably consumed by the autotuner for caching/dedup — confirm.)
        """
        assert self.gm is not None
        parts = [
            self.name.rsplit("_", 1)[0],
            *[str(inp.get_size()) for inp in self.input_nodes],
            *[str(inp.get_stride()) for inp in self.input_nodes],
            str(self.gm.graph),
        ]
        # config_patches are merged into the compile config of the benchmark
        # module, so choices differing only in patches must not share a key.
        # Appended only when non-empty to keep existing keys unchanged.
        if self.config_patches:
            parts.append(
                repr(sorted(self.config_patches.items(), key=lambda kv: kv[0]))
            )
        return "-".join(parts)

    def output_node(self) -> ir.TensorBox:
        """Wrap the traced subgraph in a SubgraphBuffer for the final graph."""
        assert self.gm is not None
        return ir.TensorBox.create(
            ir.SubgraphBuffer(
                layout=self.layout,
                input_nodes=self.input_nodes,
                gm=self.gm,
                example_inputs=self.example_inputs,
                subgraph_name=self.name,
                config_patches=self.config_patches if self.config_patches else None,
            )
        )

    def info_dict(self) -> dict[str, Any]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return {
            "backend": "subgraph",
            "kernel_name": self.name,
        }

    def autoheuristic_id(self) -> str:
        """Return the identifier of this choice for autoheuristic bookkeeping."""
        return f"subgraph_{self.name}"


class SubgraphTemplate(KernelTemplate):
    """
    A template for subgraph evaluation to be used in autotuning.

    This class allows creating customized subgraphs that can be appended
    as choices during the autotuning process, enabling the selection of
    optimal implementations for complex operations.
    """

    # Class-wide counter (shared by all templates) appended to every generated
    # choice name so that names are unique.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
    ):
        """
        Initialize a subgraph template.

        Args:
            name: The name of this template
        """
        super().__init__(name=name)

    def generate(  # type: ignore[override]
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        make_fx_graph: Callable[..., Any],
        description: str = "",
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            name: The name for this subgraph choice
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            make_fx_graph: Callable that creates the FX graph for this subgraph
            description: Optional description of this choice
            input_gen_fns: Optional dict mapping input indices to tensor generators
            **kwargs: Additional keyword arguments (accepted but not used)

        Returns:
            SubgraphChoiceCaller: A callable object that can be used for autotuning
        """

        return SubgraphChoiceCaller(
            name=f"{name}_{next(SubgraphTemplate.index_counter)}",
            input_nodes=input_nodes,
            layout=layout,
            description=description,
            make_fx_graph=make_fx_graph,
            input_gen_fns=input_gen_fns,
        )

    def generate_custom_op_choices(
        self,
        name: str,
        decompositions: list[Callable[..., Any]],
        input_nodes: list[Buffer],
        non_tensor_args: list[dict[str, Any]],
        default_impl: Callable[..., Any] | None = None,
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
        config_patches_list: list[dict[str, Any]] | None = None,
    ) -> list[SubgraphChoiceCaller]:
        """
        Generate multiple SubgraphChoiceCaller instances for custom op autotuning.

        This method extends SubgraphTemplate to support custom op decompositions,
        allowing multiple implementations to compete in autotuning. Decompositions
        that add shape guards while being traced are skipped.

        Args:
            name: Base name for the choices
            decompositions: List of decomposition functions to compete in autotuning
            input_nodes: List of tensor inputs. All tensor arguments must be passed here.
            non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition.
            default_impl: Forwarded to layout inference, which currently does not
                read it (layouts are inferred by running each decomposition).
            input_gen_fns: Optional dict mapping input indices to tensor generators
            config_patches_list: Optional list of config patches, one dict per
                decomposition.

        Returns:
            List of SubgraphChoiceCaller instances for autotuning

        Raises:
            AssertionError: If the per-decomposition lists differ in length, or
                if decompositions produce differing output layouts.
        """
        from torch.fx.experimental.symbolic_shapes import _ShapeEnvGuardError

        if not decompositions:
            return []

        assert len(decompositions) == len(non_tensor_args), (
            f"decompositions and non_tensor_args must have same length, "
            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
        )

        # Default to empty config_patches if not provided
        if config_patches_list is None:
            config_patches_list = [{} for _ in decompositions]
        # zip() below would otherwise silently drop decompositions on a mismatch.
        assert len(config_patches_list) == len(decompositions), (
            f"decompositions and config_patches_list must have same length, "
            f"got {len(decompositions)} decompositions and "
            f"{len(config_patches_list)} config patches"
        )

        # Infer layouts and ensure layout consistency for fair autotuning comparison
        layouts = [
            self._infer_custom_op_layout(
                input_nodes, decomp, kwargs, default_impl, input_gen_fns
            )
            for decomp, kwargs in zip(decompositions, non_tensor_args)
        ]

        # Validate all decompositions produce equivalent layouts for fair comparison
        self._validate_layout_equivalence(name, decompositions, layouts)
        layout = layouts[0]  # All layouts are now validated to be equivalent

        choices: list[SubgraphChoiceCaller] = []
        for decomp, decomp_kwargs, config_patches in zip(
            decompositions, non_tensor_args, config_patches_list
        ):
            # Create make_fx_graph function for this decomposition.
            # Default arguments bind the current loop values (avoids late binding).
            # Uses error_on_new_guards to detect impls that add guards
            def make_fx_graph(
                *args: Any,
                decomp: Callable[..., Any] = decomp,
                decomp_kwargs: dict[str, Any] = decomp_kwargs,
            ) -> Any:
                # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs

                from torch.fx.experimental.proxy_tensor import make_fx

                from ..decomposition import select_decomp_table

                decomposition_table = select_decomp_table()
                shape_env = V.fake_mode.shape_env

                # Use error_on_new_guards to detect impls that add guards during tracing
                guard_ctx = (
                    shape_env.error_on_new_guards()
                    if shape_env is not None
                    else contextlib.nullcontext()
                )
                with guard_ctx:
                    return make_fx(
                        functools.partial(decomp, **decomp_kwargs),
                        decomposition_table=decomposition_table,
                        tracing_mode="symbolic",
                    )(*args)

            # Generate descriptive name for this variant
            variant_name = self._generate_variant_name(decomp, decomp_kwargs)

            # Try to create choice; skip if it adds guards
            try:
                choice = self.generate(
                    name=f"{name}_{variant_name}",
                    input_nodes=input_nodes,
                    layout=layout,
                    make_fx_graph=make_fx_graph,
                    description=f"CustomOp {decomp.__name__}",
                    input_gen_fns=input_gen_fns,
                )
            except _ShapeEnvGuardError:
                log.info(
                    "Skipping decomposition %s: adds guards during tracing",
                    decomp.__name__,
                )
                counters["inductor"]["custom_op_decomp_guard_skips"] += 1
                continue

            # Cache decomposition info for range-based dispatch
            choice.cache_decomposition(decomp, decomp_kwargs)
            # Store config_patches for this choice
            choice.config_patches = config_patches
            # A pipelined pre-compile inside generate() ran before the patches
            # were assigned; rebuild it so benchmarks reflect them.
            self._recompile_with_config_patches(choice)
            choices.append(choice)

        return choices

    @staticmethod
    def _recompile_with_config_patches(choice: SubgraphChoiceCaller) -> None:
        """Rebuild an eagerly pre-compiled benchmark module after patches were set.

        SubgraphChoiceCaller pre-compiles in its constructor (when
        pipeline_max_autotune_gemm is enabled) with the config_patches it has at
        that time. Patches assigned afterwards would never reach the benchmarked
        module, so recompile (and refresh the benchmark request, if one exists).
        Nothing happens when there are no patches or nothing was pre-compiled.
        """
        if not choice.config_patches or choice._compiled_module is None:
            return
        with V.fake_mode:
            choice._compiled_module = choice._compile_for_benchmarking()
            if choice._bmreq is not None:
                choice._bmreq = choice._create_benchmark_request()

    def _generate_variant_name(
        self, decomp: Callable[..., Any], kwargs: dict[str, Any]
    ) -> str:
        """Generate a descriptive name for a decomposition variant with its parameters.

        Produces ``<decomp name>`` when there are no kwargs, otherwise
        ``<decomp name>_<key>_<value>...`` with keys sorted and values sanitized
        to identifier characters.
        """
        import re

        base_name = decomp.__name__
        if not kwargs:
            return base_name

        def sanitize_value(v: Any) -> str:
            """Convert a value to a valid Python identifier component."""
            s = str(v)
            # Replace invalid characters with underscores
            s = re.sub(r"[^a-zA-Z0-9_]", "_", s)
            # Ensure it doesn't start with a digit
            if s and s[0].isdigit():
                s = "_" + s
            return s

        param_suffix = "_".join(
            f"{k}_{sanitize_value(v)}" for k, v in sorted(kwargs.items())
        )
        return f"{base_name}_{param_suffix}"

    def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None:
        """Validate that kwargs contains only non-tensor arguments."""
        for key, value in kwargs.items():
            assert not isinstance(value, (torch.Tensor, Buffer)), (
                f"kwargs['{key}'] contains tensor {type(value)}. "
                f"Tensor arguments should be in input_nodes, not kwargs. "
                f"Only scalar/non-tensor parameters should be in kwargs."
            )

    def _validate_layout_equivalence(
        self,
        op_name: str,
        decompositions: list[Callable[..., Any]],
        layouts: list[Layout],
    ) -> None:
        """Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning.

        Raises:
            AssertionError: Naming the first decomposition whose layout differs
                from the first decomposition's layout.
        """
        if not layouts:
            return

        reference = layouts[0]
        for i, layout in enumerate(layouts[1:], start=1):
            if (layout.device, layout.dtype, layout.size, layout.stride) != (
                reference.device,
                reference.dtype,
                reference.size,
                reference.stride,
            ):
                raise AssertionError(
                    f"Layout mismatch in custom op '{op_name}': "
                    f"decomposition '{decompositions[i].__name__}' produces "
                    f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) "
                    f"but '{decompositions[0].__name__}' produces "
                    f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})"
                )

    def _infer_custom_op_layout(
        self,
        input_nodes: list[Buffer],
        function_decomposition: Callable[..., Any],
        kwargs: dict[str, Any],
        default_impl: Callable[..., Any] | None = None,
        input_gen_fns: dict[int, Callable[[Any], torch.Tensor]] | None = None,
    ) -> Layout:
        """Infer the output layout by running ``function_decomposition`` on fake inputs.

        Inputs come from ``input_gen_fns`` when provided, otherwise they are empty
        strided tensors whose sizes/strides are the optimization hints of the IR
        nodes. ``default_impl`` is accepted but not currently read.
        Note that the Subgraph assumes custom ops return exactly one tensor output.
        TODO: Add support for multiple output custom ops.

        Raises:
            AssertionError: If kwargs contain tensors or the output is not a
                single tensor.
        """
        # Assert kwargs contain only non-tensor arguments
        self._validate_non_tensor_kwargs(kwargs)

        with V.fake_mode:
            example_inputs = []
            for i, inp in enumerate(input_nodes):
                if input_gen_fns and i in input_gen_fns:
                    fake_tensor = input_gen_fns[i](inp)
                else:
                    raw_shape = inp.get_size()
                    concrete_shape = V.graph.sizevars.optimization_hints(raw_shape)
                    raw_stride = inp.get_stride()
                    concrete_stride = V.graph.sizevars.optimization_hints(raw_stride)
                    fake_tensor = torch.empty_strided(
                        concrete_shape,
                        concrete_stride,
                        dtype=inp.get_dtype(),
                        device=inp.get_device(),
                    )
                example_inputs.append(fake_tensor)

            fn = functools.partial(function_decomposition, **kwargs)
            output = fn(*example_inputs)

            # Assert single output
            assert isinstance(output, torch.Tensor), (
                f"Expected single tensor output, got {type(output)}. "
                f"Multi-output custom ops not yet supported in autotuning."
            )

            return FixedLayout(
                device=output.device,
                dtype=output.dtype,
                size=output.shape,
                stride=output.stride(),
            )
