from __future__ import annotations

import dataclasses
import hashlib
import math
import re
import typing_extensions
from typing import Any, cast, TYPE_CHECKING

import sympy  # noqa: TC002

import torch  # noqa: TC001
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import ModularIndexing

from .. import config
from ..runtime.runtime_utils import torch_dtype_to_jax
from ..utils import get_fused_kernel_name, get_kernel_metadata
from ..virtualized import V
from .block_analysis import BlockPatternMatcher
from .common import (
    BackendFeature,
    CSEVariable,
    IndentedBuffer,
    OpOverrides,
    PythonPrinter,
)
from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling


class PallasPrinter(PythonPrinter):
    """
    Custom sympy printer for Pallas that handles JAX-specific constructs.
    """

    def _print_Where(self, expr: sympy.Expr) -> str:
        """Convert sympy Where to jnp.where."""
        c = self.doprint(expr.args[0])
        p = self.doprint(expr.args[1])
        q = self.doprint(expr.args[2])
        return f"jnp.where({c}, {p}, {q})"

    def _print_Min(self, expr: sympy.Expr) -> str:
        """Convert sympy Min to jnp.minimum for JAX compatibility."""
        args = [self.doprint(arg) for arg in expr.args]
        result = args[0]
        for arg in args[1:]:
            result = f"jnp.minimum({result}, {arg})"
        return result

    def _print_Max(self, expr: sympy.Expr) -> str:
        """Convert sympy Max to jnp.maximum for JAX compatibility."""
        args = [self.doprint(arg) for arg in expr.args]
        result = args[0]
        for arg in args[1:]:
            result = f"jnp.maximum({result}, {arg})"
        return result


# Use Pallas-specific printer for expression generation
pallas_pexpr = PallasPrinter().doprint
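# Illustrative usage (not executed here): pallas_pexpr(sympy.Min(sympy.Symbol("x0"), 16))
# yields a string such as "jnp.minimum(16, x0)" (argument order follows sympy's
# canonical ordering), which is spliced into the generated kernel source.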


if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

    from ..ir import IRNode
    from ..ops_handler import ReductionType
    from ..scheduler import BaseSchedulerNode


# Main function suffix used in generated Pallas code
MAIN_SUFFIX = "main"

# Logger for Pallas kernel code
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")


class PallasKernelWrapper:
    """Wrapper to provide .run() interface for Pallas kernels"""

    def __init__(self, kernel_fn: Callable[..., Any], kernel_path: str | None = None):
        self.kernel_fn = kernel_fn
        self.kernel_path = kernel_path
        kernel_code_log.info("Pallas kernel path: %s", kernel_path)

    def run(self, *args: Any, stream: Any = None, **kwargs: Any) -> Any:
        """
        Execute the Pallas kernel.

        Args:
            *args: Arguments to pass to the kernel function
            stream: CUDA stream to pass to the kernel function
            **kwargs: Additional keyword arguments for the kernel

        Returns:
            Result of the kernel execution
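
        Example (illustrative; module and buffer names hypothetical)::

            PallasKernelWrapper(mod.pallas_kernel_main, path).run(
                in_buf, out_buf, stream=stream
            )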
        """
        return self.kernel_fn(*args, stream=stream, **kwargs)


class Unsupported(RuntimeError):
    """Exception raised when an operation is not supported by the Pallas backend."""


class PallasKernelOverrides(OpOverrides):
    """
    Map element-wise ops to JAX/Pallas operations.

    For now, we use the default Python operators which are compatible
    with JAX numpy broadcasting semantics.
    """

    @staticmethod
    # pyrefly: ignore [bad-override]
    def sin(x: str) -> str:
        return f"jnp.sin({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def cos(x: str) -> str:
        return f"jnp.cos({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def tan(x: str) -> str:
        return f"jnp.tan({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def sinh(x: str) -> str:
        return f"jnp.sinh({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def cosh(x: str) -> str:
        return f"jnp.cosh({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def tanh(x: str) -> str:
        return f"jnp.tanh({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def asin(x: str) -> str:
        return f"jnp.arcsin({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def acos(x: str) -> str:
        return f"jnp.arccos({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def atan(x: str) -> str:
        return f"jnp.arctan({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def exp(x: str) -> str:
        return f"jnp.exp({x})"

    @staticmethod
    def exp2(x: str) -> str:
        return f"jnp.exp2({x})"

    @staticmethod
    def expm1(x: str) -> str:
        return f"jnp.expm1({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def log(x: str) -> str:
        return f"jnp.log({x})"

    @staticmethod
    def log10(x: str) -> str:
        return f"jnp.log10({x})"

    @staticmethod
    def log2(x: str) -> str:
        return f"jnp.log2({x})"

    @staticmethod
    def log1p(x: str) -> str:
        return f"jnp.log1p({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def sqrt(x: str) -> str:
        return f"jnp.sqrt({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def rsqrt(x: str) -> str:
        return f"jax.lax.rsqrt({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def abs(x: str) -> str:
        return f"jnp.abs({x})"

    @staticmethod
    def neg(x: str) -> str:
        return f"(-{x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def floor(x: str) -> str:
        return f"jnp.floor({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def ceil(x: str) -> str:
        return f"jnp.ceil({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def trunc(x: str) -> str:
        return f"jnp.trunc({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def round(x: str) -> str:
        return f"jnp.round({x})"

    @staticmethod
    def sigmoid(x: str) -> str:
        return f"jax.nn.sigmoid({x})"

    @staticmethod
    def relu(x: str) -> str:
        return f"jnp.maximum({x}, 0)"

    @staticmethod
    def pow(a: str, b: str) -> str:
        return f"jnp.power({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def maximum(a: str, b: str) -> str:
        return f"jnp.maximum({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def minimum(a: str, b: str) -> str:
        return f"jnp.minimum({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def where(cond: str, a: str, b: str) -> str:
        return f"jnp.where({cond}, {a}, {b})"

    @staticmethod
    def masked(mask: str, body: Callable[[], str], other: float) -> str:
        """
        Computes body, but only uses the result where mask is true.
        Where mask is false, uses the 'other' value instead.
        """
        result = body()
        # Format the 'other' value properly for JAX
        if isinstance(other, float):
            if math.isnan(other):
                other_str = "jnp.nan"
            elif math.isinf(other):
                other_str = "jnp.inf" if other > 0 else "-jnp.inf"
            else:
                other_str = repr(other)
        else:
            other_str = repr(other)
        # Use jnp.where to select between result and other based on mask
        return f"jnp.where({mask}, {result}, {other_str})"

    @staticmethod
    def to_dtype(
        x: str,
        dtype: torch.dtype,
        src_dtype: torch.dtype | None = None,
        use_compute_types: bool = True,
    ) -> str:
        # TPU doesn't support 64-bit integer types, so downcast int64 to int32
        if dtype == torch.int64 and V.graph.get_current_device_or_throw().type == "tpu":
            dtype = torch.int32
        jax_dtype = torch_dtype_to_jax(dtype)
        # Wrap in jnp.asarray to handle scalars from integer indexing
        return f"jnp.asarray({x}).astype({jax_dtype})"

    @staticmethod
    def to_dtype_bitcast(x: str, dtype: torch.dtype, src_dtype: torch.dtype) -> str:
        """Bitcast a value from one dtype to another with the same size."""
        jax_dtype = torch_dtype_to_jax(dtype)
        jax_src_dtype = torch_dtype_to_jax(src_dtype)
        # First ensure the value is the correct source dtype, then bitcast
        return f"jax.lax.bitcast_convert_type(jnp.asarray({x}).astype({jax_src_dtype}), {jax_dtype})"

    @staticmethod
    def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str:
        """Convert a sympy expression to a JAX array indexing expression."""
        from ..utils import get_bounds_index_expr

        # Track which iteration variables are used
        V.kernel.used_iter_vars.update(V.kernel._get_used_iter_vars(expr))

        # Prepare and rename indexing to register size symbols as kernel args
        prepared = V.kernel.prepare_indexing(expr)
        renamed = V.kernel.rename_indexing(prepared)
        idx_str = V.kernel.kexpr(renamed)
        var = V.kernel.cse.generate(
            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
        )
        return PallasKernelOverrides.to_dtype(var, dtype)

    @staticmethod
    def constant(val, dtype: torch.dtype) -> str:
        """Convert a constant value to JAX representation."""
        jax_dtype = torch_dtype_to_jax(dtype)
        if dtype == torch.bool:
            return "True" if val else "False"
        # Handle special float values
        if isinstance(val, float):
            if math.isnan(val):
                return "jnp.nan"
            if math.isinf(val):
                return "jnp.inf" if val > 0 else "-jnp.inf"
        return f"jnp.array({val}, dtype={jax_dtype})"

    @staticmethod
    def real(x: str) -> str:
        return f"jnp.real({x})"

    @staticmethod
    def imag(x: str) -> str:
        return f"jnp.imag({x})"

    @staticmethod
    def conj(x: str) -> str:
        return f"jnp.conj({x})"

    @staticmethod
    def angle(x: str) -> str:
        return f"jnp.angle({x})"

    @staticmethod
    def view_as_real(x: str) -> str:
        """View complex tensor as real tensor with extra dimension."""
        return f"jnp.stack([jnp.real({x}), jnp.imag({x})], axis=-1)"

    @staticmethod
    def view_as_complex(x: str) -> str:
        """View real tensor as complex tensor."""
        return f"({x}[..., 0] + 1j * {x}[..., 1])"

    # Comparison operations
    @staticmethod
    def eq(a: str, b: str) -> str:
        return f"({a} == {b})"

    @staticmethod
    def ne(a: str, b: str) -> str:
        return f"({a} != {b})"

    @staticmethod
    def lt(a: str, b: str) -> str:
        return f"({a} < {b})"

    @staticmethod
    def le(a: str, b: str) -> str:
        return f"({a} <= {b})"

    @staticmethod
    def gt(a: str, b: str) -> str:
        return f"({a} > {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def isnan(x: str) -> str:
        return f"jnp.isnan({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def isinf(x: str) -> str:
        return f"jnp.isinf({x})"

    @staticmethod
    def isfinite(x: str) -> str:
        return f"jnp.isfinite({x})"

    @staticmethod
    def ge(a: str, b: str) -> str:
        return f"({a} >= {b})"

    # Logical operations
    @staticmethod
    # pyrefly: ignore [bad-override]
    def logical_and(a: str, b: str) -> str:
        return f"jnp.logical_and({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def logical_or(a: str, b: str) -> str:
        return f"jnp.logical_or({a}, {b})"

    @staticmethod
    def logical_not(x: str) -> str:
        return f"jnp.logical_not({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def logical_xor(a: str, b: str) -> str:
        return f"jnp.logical_xor({a}, {b})"

    # Math operations
    @staticmethod
    # pyrefly: ignore [bad-override]
    def atan2(a: str, b: str) -> str:
        return f"jnp.arctan2({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def hypot(a: str, b: str) -> str:
        return f"jnp.hypot({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def fmod(a: str, b: str) -> str:
        return f"jnp.fmod({a}, {b})"

    @staticmethod
    def remainder(a: str, b: str) -> str:
        return f"jnp.remainder({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def truncdiv(a: str, b: str) -> str:
        # Truncated division (rounds toward zero)
        # For integers: sign(a)*sign(b) * (abs(a) // abs(b))
        return f"(jnp.sign({a}) * jnp.sign({b}) * (jnp.abs({a}) // jnp.abs({b}))).astype({a}.dtype)"

    @staticmethod
    def floordiv(a: str, b: str) -> str:
        return f"({a} // {b})"

    @staticmethod
    def clamp(x: str, min_val: str, max_val: str) -> str:
        return f"jnp.clip({x}, {min_val}, {max_val})"

    clip = clamp

    # Sign operations
    @staticmethod
    # pyrefly: ignore [bad-override]
    def sign(x: str) -> str:
        # PyTorch returns 0 for NaN, JAX returns NaN
        return f"jnp.where(jnp.isnan({x}), 0.0, jnp.sign({x}))"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def signbit(x: str) -> str:
        return f"jnp.signbit({x})"

    # Special math functions
    @staticmethod
    # pyrefly: ignore [bad-override]
    def erf(x: str) -> str:
        return f"jax.scipy.special.erf({x})"

    @staticmethod
    def erfc(x: str) -> str:
        return f"jax.scipy.special.erfc({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def erfinv(x: str) -> str:
        return f"jax.scipy.special.erfinv({x})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def lgamma(x: str) -> str:
        return f"jax.scipy.special.gammaln({x})"

    @staticmethod
    def digamma(x: str) -> str:
        return f"jax.scipy.special.digamma({x})"

    @staticmethod
    def bessel_j0(x: str) -> str:
        # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN)
        # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n
        # Handle by: convert to float64, compute, handle x=0, convert back
        # J0(0) = 1.0
        return (
            f"jnp.where({x}.astype(jnp.float64) == 0.0, 1.0, "
            f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=0)[0])"
            f".astype({x}.dtype)"
        )

    @staticmethod
    def bessel_j1(x: str) -> str:
        # bessel_jn requires float64 and has numerical issues at x=0 (returns NaN)
        # bessel_jn(x, v=n) returns array of shape (n+1, ...) with J_0 to J_n
        # Handle by: convert to float64, compute, handle x=0, convert back
        # J1(0) = 0.0
        return (
            f"jnp.where({x}.astype(jnp.float64) == 0.0, 0.0, "
            f"jax.scipy.special.bessel_jn({x}.astype(jnp.float64), v=1)[1])"
            f".astype({x}.dtype)"
        )

    @staticmethod
    def modified_bessel_i0(x: str) -> str:
        # Modified Bessel function of the first kind I_0(x)
        # I_0(x) = bessel_i0e(x) * exp(|x|) where bessel_i0e is the scaled version
        return f"jax.lax.bessel_i0e({x}) * jnp.exp(jnp.abs({x}))"

    @staticmethod
    def modified_bessel_i1(x: str) -> str:
        # Modified Bessel function of the first kind I_1(x)
        # I_1(x) = bessel_i1e(x) * exp(|x|) where bessel_i1e is the scaled version
        return f"jax.lax.bessel_i1e({x}) * jnp.exp(jnp.abs({x}))"

    @staticmethod
    def spherical_bessel_j0(x: str) -> str:
        # Spherical Bessel function of the first kind j_0(x) = sin(x) / x
        # Handle x=0: j_0(0) = 1
        return f"jnp.where({x} == 0.0, 1.0, jnp.sin({x}) / {x})"

    i0 = modified_bessel_i0

    @staticmethod
    def i0e(x: str) -> str:
        # Exponentially scaled modified Bessel function I_0
        return f"jax.lax.bessel_i0e({x})"

    i1 = modified_bessel_i1

    @staticmethod
    def i1e(x: str) -> str:
        # Exponentially scaled modified Bessel function I_1
        return f"jax.lax.bessel_i1e({x})"

    @staticmethod
    def gammainc(x: str, y: str) -> str:
        # Regularized lower incomplete gamma function P(a, x)
        # Note: PyTorch uses gammainc(input, other) where input is a (shape param)
        return f"jax.scipy.special.gammainc({x}, {y})"

    @staticmethod
    def gammaincc(x: str, y: str) -> str:
        # Regularized upper incomplete gamma function Q(a, x)
        return f"jax.scipy.special.gammaincc({x}, {y})"

    igamma = gammainc

    igammac = gammaincc

    @staticmethod
    def polygamma(x: str, y: str) -> str:
        # Polygamma function psi^(n)(x), x is order n, y is the value
        # Note: JAX uses polygamma(n, x) where n is integer order
        return f"jax.scipy.special.polygamma({x}.astype(jnp.int32), {y})"

    @staticmethod
    def ndtri(x: str) -> str:
        # Inverse of the standard normal CDF
        return f"jax.scipy.special.ndtri({x})"

    @staticmethod
    def zeta(x: str, y: str) -> str:
        # Hurwitz zeta function zeta(x, q) = sum_{k=0}^inf 1/(k+q)^x
        return f"jax.scipy.special.zeta({x}, {y})"

    @staticmethod
    def xlogy(x: str, y: str) -> str:
        # x * log(y), with proper handling of x=0
        return f"jax.scipy.special.xlogy({x}, {y})"

    @staticmethod
    def xlog1py(x: str, y: str) -> str:
        # x * log1p(y), with proper handling of x=0
        return f"jax.scipy.special.xlog1py({x}, {y})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def chebyshev_polynomial_t(x: str, n: str) -> str:
        # Chebyshev polynomial of the first kind T_n(x)
        # For |x| <= 1: T_n(x) = cos(n * arccos(x))
        # For x > 1: T_n(x) = cosh(n * arccosh(x))
        # For x < -1: T_n(x) = (-1)^n * cosh(n * arccosh(-x))
        return (
            f"jnp.where(jnp.abs({x}) <= 1, "
            f"jnp.cos({n} * jnp.arccos(jnp.clip({x}, -1, 1))), "
            f"jnp.where({x} > 1, "
            f"jnp.cosh({n} * jnp.arccosh(jnp.maximum({x}, 1.0))), "
            f"((-1.0) ** {n}) * jnp.cosh({n} * jnp.arccosh(jnp.maximum(-{x}, 1.0)))))"
        )
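        # Sanity check of the closed form: T_2(x) = 2x^2 - 1, so at x = 0.5 the
        # |x| <= 1 branch gives cos(2 * arccos(0.5)) = cos(2*pi/3) = -0.5,
        # which equals 2 * 0.5**2 - 1.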

    @staticmethod
    # pyrefly: ignore [bad-override]
    def chebyshev_polynomial_u(x: str, n: str) -> str:
        # Chebyshev polynomial of the second kind U_n(x)
        # For |x| < 1: U_n(x) = sin((n+1) * arccos(x)) / sqrt(1 - x^2)
        # For x = 1: U_n(1) = n+1
        # For x = -1: U_n(-1) = (-1)^n * (n+1)
        # For x > 1: U_n(x) = sinh((n+1) * arccosh(x)) / sqrt(x^2 - 1)
        # For x < -1: U_n(x) = (-1)^n * U_n(-x) (symmetry)
        return (
            f"jnp.where(jnp.abs({x}) < 1, "
            f"jnp.sin(({n} + 1) * jnp.arccos(jnp.clip({x}, -1, 1))) / "
            f"jnp.sqrt(jnp.maximum(1 - {x}**2, 1e-10)), "
            f"jnp.where({x} >= 1, "
            f"jnp.where({x} == 1, {n} + 1.0, "
            f"jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum({x}, 1.0))) / "
            f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10))), "
            f"jnp.where({x} == -1, ((-1.0) ** {n}) * ({n} + 1.0), "
            f"((-1.0) ** {n}) * jnp.sinh(({n} + 1) * jnp.arccosh(jnp.maximum(-{x}, 1.0))) / "
            f"jnp.sqrt(jnp.maximum({x}**2 - 1, 1e-10)))))"
        )

    @staticmethod
    # pyrefly: ignore [bad-override]
    def chebyshev_polynomial_v(x: str, n: str) -> str:
        # Chebyshev polynomial of the third kind V_n(x)
        # V_n(x) = (T_n(x) - T_{n+1}(x)) / (1 - x) for x != 1
        # V_n(1) = 1, recurrence: V_0 = 1, V_1 = 2x - 1, V_n = 2x*V_{n-1} - V_{n-2}
        # Explicit: V_0 = 1, V_1 = 2x-1, V_2 = 4x^2-2x-1, V_3 = 8x^3-4x^2-4x+1
        return (
            f"jnp.where({n} == 0, jnp.ones_like({x}), "
            f"jnp.where({n} == 1, 2*{x} - 1, "
            f"jnp.where({n} == 2, 4*{x}**2 - 2*{x} - 1, "
            f"jnp.where({n} == 3, 8*{x}**3 - 4*{x}**2 - 4*{x} + 1, "
            f"jnp.where({n} == 4, 16*{x}**4 - 8*{x}**3 - 12*{x}**2 + 4*{x} + 1, "
            f"jnp.where({n} == 5, 32*{x}**5 - 16*{x}**4 - 32*{x}**3 + 12*{x}**2 + 6*{x} - 1, "
            f"jnp.zeros_like({x})))))))"
        )

    @staticmethod
    # pyrefly: ignore [bad-override]
    def chebyshev_polynomial_w(x: str, n: str) -> str:
        # Chebyshev polynomial of the fourth kind W_n(x)
        # W_n(x) = (T_n(x) + T_{n+1}(x)) / (1 + x) for x != -1
        # W_n(-1) = (-1)^n, recurrence: W_0 = 1, W_1 = 2x + 1, W_n = 2x*W_{n-1} - W_{n-2}
        # Explicit: W_0 = 1, W_1 = 2x+1, W_2 = 4x^2+2x-1, W_3 = 8x^3+4x^2-4x-1
        return (
            f"jnp.where({n} == 0, jnp.ones_like({x}), "
            f"jnp.where({n} == 1, 2*{x} + 1, "
            f"jnp.where({n} == 2, 4*{x}**2 + 2*{x} - 1, "
            f"jnp.where({n} == 3, 8*{x}**3 + 4*{x}**2 - 4*{x} - 1, "
            f"jnp.where({n} == 4, 16*{x}**4 + 8*{x}**3 - 12*{x}**2 - 4*{x} + 1, "
            f"jnp.where({n} == 5, 32*{x}**5 + 16*{x}**4 - 32*{x}**3 - 12*{x}**2 + 6*{x} + 1, "
            f"jnp.zeros_like({x})))))))"
        )

    @staticmethod
    # pyrefly: ignore [bad-override]
    def shifted_chebyshev_polynomial_t(x: str, n: str) -> str:
        return PallasKernelOverrides.chebyshev_polynomial_t(f"(2 * {x} - 1)", n)

    @staticmethod
    # pyrefly: ignore [bad-override]
    def shifted_chebyshev_polynomial_u(x: str, n: str) -> str:
        return PallasKernelOverrides.chebyshev_polynomial_u(f"(2 * {x} - 1)", n)

    @staticmethod
    # pyrefly: ignore [bad-override]
    def shifted_chebyshev_polynomial_v(x: str, n: str) -> str:
        return PallasKernelOverrides.chebyshev_polynomial_v(f"(2 * {x} - 1)", n)

    @staticmethod
    # pyrefly: ignore [bad-override]
    def shifted_chebyshev_polynomial_w(x: str, n: str) -> str:
        return PallasKernelOverrides.chebyshev_polynomial_w(f"(2 * {x} - 1)", n)

    @staticmethod
    # pyrefly: ignore [bad-override]
    def hermite_polynomial_h(x: str, n: str) -> str:
        # Physicist's Hermite polynomial H_n(x), hardcoded via closed forms for
        # small n (general form:
        # H_n(x) = n! * sum_{m=0}^{n//2} (-1)^m / (m! * (n-2m)!) * (2x)^(n-2m)):
        # H_0 = 1, H_1 = 2x, H_2 = 4x^2 - 2, H_3 = 8x^3 - 12x
        return (
            f"jnp.where({n} == 0, jnp.ones_like({x}), "
            f"jnp.where({n} == 1, 2 * {x}, "
            f"jnp.where({n} == 2, 4 * {x}**2 - 2, "
            f"jnp.where({n} == 3, 8 * {x}**3 - 12 * {x}, "
            f"jnp.where({n} == 4, 16 * {x}**4 - 48 * {x}**2 + 12, "
            f"jnp.where({n} == 5, 32 * {x}**5 - 160 * {x}**3 + 120 * {x}, "
            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
        )

    @staticmethod
    # pyrefly: ignore [bad-override]
    def hermite_polynomial_he(x: str, n: str) -> str:
        # Probabilist's Hermite polynomial He_n(x)
        # He_0 = 1, He_1 = x, He_2 = x^2 - 1, He_3 = x^3 - 3x
        return (
            f"jnp.where({n} == 0, jnp.ones_like({x}), "
            f"jnp.where({n} == 1, {x}, "
            f"jnp.where({n} == 2, {x}**2 - 1, "
            f"jnp.where({n} == 3, {x}**3 - 3 * {x}, "
            f"jnp.where({n} == 4, {x}**4 - 6 * {x}**2 + 3, "
            f"jnp.where({n} == 5, {x}**5 - 10 * {x}**3 + 15 * {x}, "
            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
        )

    @staticmethod
    # pyrefly: ignore [bad-override]
    def laguerre_polynomial_l(x: str, n: str) -> str:
        # Laguerre polynomial L_n(x)
        # L_0 = 1, L_1 = 1 - x, L_2 = (x^2 - 4x + 2)/2, L_3 = (-x^3 + 9x^2 - 18x + 6)/6
        return (
            f"jnp.where({n} == 0, jnp.ones_like({x}), "
            f"jnp.where({n} == 1, 1 - {x}, "
            f"jnp.where({n} == 2, ({x}**2 - 4*{x} + 2) / 2, "
            f"jnp.where({n} == 3, (-{x}**3 + 9*{x}**2 - 18*{x} + 6) / 6, "
            f"jnp.where({n} == 4, ({x}**4 - 16*{x}**3 + 72*{x}**2 - 96*{x} + 24) / 24, "
            f"jnp.where({n} == 5, (-{x}**5 + 25*{x}**4 - 200*{x}**3 + 600*{x}**2 - 600*{x} + 120) / 120, "
            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
        )

    @staticmethod
    # pyrefly: ignore [bad-override]
    def legendre_polynomial_p(x: str, n: str) -> str:
        # Legendre polynomial P_n(x)
        # P_0 = 1, P_1 = x, P_2 = (3x^2 - 1)/2, P_3 = (5x^3 - 3x)/2
        return (
            f"jnp.where({n} == 0, jnp.ones_like({x}), "
            f"jnp.where({n} == 1, {x}, "
            f"jnp.where({n} == 2, (3 * {x}**2 - 1) / 2, "
            f"jnp.where({n} == 3, (5 * {x}**3 - 3 * {x}) / 2, "
            f"jnp.where({n} == 4, (35 * {x}**4 - 30 * {x}**2 + 3) / 8, "
            f"jnp.where({n} == 5, (63 * {x}**5 - 70 * {x}**3 + 15 * {x}) / 8, "
            f"jnp.zeros_like({x})))))))"  # Fallback for higher n
        )

    # Reciprocal and square
    @staticmethod
    def reciprocal(x: str) -> str:
        return f"jnp.reciprocal({x})"

    @staticmethod
    def square(x: str) -> str:
        return f"jnp.square({x})"

    # Additional operations
    @staticmethod
    def fma(a: str, b: str, c: str) -> str:
        """Fused multiply-add: a * b + c

        JAX doesn't have jnp.fma, so we use the unfused version.
        The compiler may still fuse this on supported hardware.
        """
        return f"(({a}) * ({b}) + ({c}))"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def copysign(a: str, b: str) -> str:
        return f"jnp.copysign({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def nextafter(a: str, b: str) -> str:
        return f"jnp.nextafter({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def ldexp(a: str, b: str) -> str:
        return f"jnp.ldexp({a}, {b})"

    @staticmethod
    # pyrefly: ignore [bad-override]
    def frexp(x: str) -> str:
        return f"jnp.frexp({x})"

    @staticmethod
    def modf(x: str) -> str:
        return f"jnp.modf({x})"

    # Bitwise operations
    @staticmethod
    def bitwise_and(a: str, b: str) -> str:
        return f"jnp.bitwise_and({a}, {b})"

    @staticmethod
    def bitwise_or(a: str, b: str) -> str:
        return f"jnp.bitwise_or({a}, {b})"

    @staticmethod
    def bitwise_xor(a: str, b: str) -> str:
        return f"jnp.bitwise_xor({a}, {b})"

    @staticmethod
    def bitwise_not(x: str) -> str:
        return f"jnp.bitwise_not({x})"

    @staticmethod
    def left_shift(a: str, b: str) -> str:
        return f"jnp.left_shift({a}, {b})"

    @staticmethod
    def right_shift(a: str, b: str) -> str:
        return f"jnp.right_shift({a}, {b})"

    # Random number generation operations
    @staticmethod
    def load_seed(name: str, offset: str) -> str:
        """Load the random seed value from a buffer."""
        # Load the seed from the buffer and add offset for uniqueness
        seed_offset = V.kernel.args.seed_offset("load_seed_offset", offset)
        return f"({V.kernel.args.input(name)}[0] + {seed_offset})"

    @staticmethod
    def rand(seed: str, offset: str) -> str:
        """Generate uniform random numbers in [0, 1).

        Uses JAX's counter-based PRNG (threefry2x32 by default) for vectorized
        random generation: the seed provides the base key, and the offset
        provides per-element uniqueness.
        """
        # For vectorized random, we use jax.random.uniform with shape from offset
        # Create a base key from seed, then use fold_in with vmap for per-element keys
        # Use float32 dtype to match PyTorch's default
        return (
            f"jax.vmap(lambda o: jax.random.uniform("
            f"jax.random.fold_in(jax.random.PRNGKey(jnp.uint32({seed})), jnp.uint32(o)), (), dtype=jnp.float32))"
            f"(jnp.asarray({offset}).flatten()).reshape(jnp.asarray({offset}).shape)"
        )
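        # Keying scheme (illustrative): for seed s and offsets [0, 1, 2], element
        # o draws from fold_in(PRNGKey(s), o), so a given (seed, offset) pair
        # always reproduces the same uniform sample.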

    @staticmethod
    def randn(seed: str, offset: str) -> str:
        """Generate standard normal random numbers.

        Uses JAX's counter-based PRNG (threefry2x32 by default) for vectorized
        random generation: the seed provides the base key, and the offset
        provides per-element uniqueness.
        """
        # For vectorized random, use vmap to fold in each offset value
        # Use float32 dtype to match PyTorch's default
        return (
            f"jax.vmap(lambda o: jax.random.normal("
            f"jax.random.fold_in(jax.random.PRNGKey(jnp.uint32({seed})), jnp.uint32(o)), (), dtype=jnp.float32))"
            f"(jnp.asarray({offset}).flatten()).reshape(jnp.asarray({offset}).shape)"
        )

    @staticmethod
    def randint64(seed: str, offset: str, low: str, high: str) -> str:
        """Generate random int64 values in [low, high)."""
        # For vectorized random, use vmap to fold in each offset value
        return (
            f"jax.vmap(lambda o: jax.random.randint("
            f"jax.random.fold_in(jax.random.PRNGKey(jnp.uint32({seed})), jnp.uint32(o)), (), {low}, {high}, dtype=jnp.int64))"
            f"(jnp.asarray({offset}).flatten()).reshape(jnp.asarray({offset}).shape)"
        )


@dataclasses.dataclass
class _IndirectAccessInfo:
    """Describes a detected indirect (data-dependent) buffer access."""

    table_param: str
    table_buf_name: str
    table_shape: tuple
    indirect_dim: int
    indirect_var: str
    indices_param: str


@dataclasses.dataclass
class _BufferIndexing:
    """Encapsulates index string and flattening requirements for buffer access."""

    index_str: str
    needs_flatten: bool


@dataclasses.dataclass
class _BroadcastedIterVar:
    """Encapsulates information needed to codegen a broadcasted iteration var"""

    # index of this var in `self.range_tree_nodes.items()`
    idx: int
    var_sym: sympy.Symbol
    entry: IterationRangesEntry
    length_val: int | None


@dataclasses.dataclass
class _CodegenContext:
    """Bundles local state shared across codegen_kernel helper methods."""

    code: IndentedBuffer
    kernel_name: str
    is_tpu: bool
    interpret_is_cpu: bool
    interpret_literal: str
    kernel_params: list[str]
    pure_out_params: list[str]
    output_params: list[str]
    size_var_params: list[str]
    output_buffer_lookup: dict[str, str]
    aliasable_flags: dict[str, bool]
    alias_params: list[str]
    pointer_tail: list[str]
    kernel_input_params: list[str]
    full_kernel_params: list[str]
    non_alias_out_set: OrderedSet[str]
    copy_output_indices: list[int]
    alias_pairs: list[tuple[int, int]]


class PallasKernel(SIMDKernel):
    """Pallas kernel codegen for TPU and GPU (Mosaic backend).

    Generates Python code that defines a Pallas kernel and a host entrypoint,
    compiled and loaded via async_compile.pallas.
    """

    overrides = PallasKernelOverrides  # type: ignore[assignment]
    kexpr: Callable[[sympy.Expr], str] = pallas_pexpr  # Use Pallas expression printer

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Determine device type once at initialization
        device = V.graph.get_current_device_or_throw()
        self.is_gpu = device.type == "cuda"
        self.is_tpu = device.type == "tpu"
        # Use TMA (Tensor Memory Accelerator) for GPU to handle non-aligned tensor sizes
        # TMA automatically masks OOB accesses, eliminating the need for explicit
        # padding to multiples of 128. Uses lax.fori_loop with direct TMA primitives.
        self.use_emit_pipeline = self.is_gpu  # Enable TMA approach for GPU
        # Track which output param each store uses: list of (out_ptr_name, store_line)
        self.store_with_output: list[tuple[str, str]] = []
        # Track load index expressions for reduction axis detection
        self.load_index_exprs: dict[str, sympy.Expr] = {}
        # Track outputs that need to be readable (for scatter operations)
        self.outputs_need_read: OrderedSet[str] = OrderedSet()
        # Map input buffer names to their detected permutation tuples.
        self.permuted_input_buffers: dict[str, tuple[int, ...]] = {}
        self.collapsed_reshape_inputs: dict[str, tuple[int, ...]] = {}
        self.collapsed_output_shape: tuple[int, ...] | None = None
        self._cpu_max_grid_product: int | None = None
        # Precompute output buffer names from scheduler nodes so that the
        # load path can check output shapes before stores are processed.
        self._output_buffer_names: list[str] = []
        for snode in self.features.scheduler_nodes():
            for dep in snode.read_writes.writes:
                self._output_buffer_names.append(dep.name)
        # Track which iteration variables are actually used in the kernel
        self.used_iter_vars: OrderedSet[sympy.Symbol] = OrderedSet()
        # Iteration vars that have been emitted in tile-relative form
        # (safe for tiling even when they appear in the compute body)
        self.tile_relative_iter_vars: OrderedSet[sympy.Symbol] = OrderedSet()
        # Track if any load/store uses flatten-based indexing (buf[...].flatten()[idx])
        self.has_flatten_indexing = False
        # Strided input buffers: map graph buffer name -> per-dim
        # (stride, offset, skip) triples.  Used to reshape inputs outside
        # the kernel and generate static indexing inside
        # (e.g. in_ref[:, :, offset] instead of in_ref[...].flatten()[idx]).
        self.strided_input_buffers: dict[str, list[tuple[int, int, int]]] = {}
        # Buffers that already use flatten+gather indexing; strided
        # decomposition must not reshape these (it would break flat offsets).
        self.flatten_indexed_buffers: OrderedSet[str] = OrderedSet()
        # Indirect (data-dependent) access info for scalar prefetch
        self.indirect_access: _IndirectAccessInfo | None = None
        self._cse_to_param: dict[str, str] = {}
        self._param_to_graph_name: dict[str, str] = {}

    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ) -> None:
        """Check array bounds for indirect indexing."""
        # For now, skip explicit bounds checking as JAX/Pallas handles this internally
        # TODO: Implement explicit bounds checking with assertions if needed

    def _get_index_str(self, index: sympy.Expr) -> str:
        """
        Convert an index expression to a string suitable for Pallas indexing.

        Pallas operates on full arrays, so we need to convert index expressions
        to JAX array slicing. For example:
        - x0 -> "..." (contiguous access, full array)
        - 2*x0 -> "::2" (strided access with stride 2)
        - 2*x0 + 1 -> "1::2" (strided access with offset 1, stride 2)

        Args:
            index: The indexing expression to convert

        Returns:
            The indexing string to use in generated code
        """
        # Prepare and simplify the index
        prepared_index = self.prepare_indexing(index)

        # Note: Block variable detection (im2col patterns) is handled in load()/store()
        # where we have access to buffer dimensions. We check the buffer size
        # against iteration variables there to detect gather patterns.

        # For simple single-symbol access (contiguous case), we can use [...]
        # which is more efficient as it operates on the entire array at once
        if isinstance(prepared_index, sympy.Symbol):
            return "..."
        elif prepared_index.is_Integer:
            # Scalar index
            return str(prepared_index)
        else:
            # Complex expression (strided/scatter access)
            # Try to extract stride and offset for common patterns
            return self._convert_to_jax_slice(prepared_index)

    def _convert_to_jax_slice(self, index: sympy.Expr) -> str:
        """
        Convert a sympy index expression to JAX slice notation.

        Handles common patterns like:
        - stride*var -> ::stride
        - stride*var + offset -> offset::stride

        For more complex patterns, falls back to explicit indexing.
        Uses BlockPatternMatcher for robust pattern matching.
        """
        # Get the iteration variables for this kernel
        if not self.range_trees:
            return "..."

        # Rename symbolic sizes to kernel parameter names upfront
        index = self.rename_indexing(index)

        # Check for ModularIndexing - this is NOT contiguous access
        # ModularIndexing is used for roll/wrap-around operations
        if index.has(ModularIndexing):
            # Track which iteration variables are used before returning
            self.used_iter_vars.update(self._get_used_iter_vars(index))
            # Generate actual index expression - iteration variables are already
            # defined as jnp.arange arrays, so we just convert to JAX code
            return self.kexpr(index)

        # Simplify the index
        index = V.graph.sizevars.simplify(index)
        # Find which iteration variable(s) are used
        used_vars = self._get_used_iter_vars(index)

        # Track which iteration variables are used
        self.used_iter_vars.update(used_vars)

        if len(used_vars) == 0:
            # No iteration variables, this is a constant index
            return str(index)
        elif len(used_vars) == 1:
            # Single iteration variable - try to extract stride and offset using BlockPatternMatcher
            var = next(iter(used_vars))

            # Get the subexpression involving this variable
            var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)

            # Try to match affine pattern: stride * var
            stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)

            if stride is not None:
                offset = index - var_expr
                offset = V.graph.sizevars.simplify(offset)

                if stride < 0:
                    return self.kexpr(index)

                if offset == 0:
                    return "..."

                # Non-zero offset: check if we can use slice notation
                if stride != 1:
                    return self.kexpr(index)

                try:
                    offset_val = int(offset)
                    if offset_val < 0:
                        return self.kexpr(index)
                except (TypeError, ValueError):
                    return self.kexpr(index)

                return f"{self.kexpr(offset)}::1"
            else:
                # Couldn't match affine pattern, fall back to original logic
                offset = index - var_expr
                offset = V.graph.sizevars.simplify(offset)
                if offset == 0 and var_expr == var:
                    # Just the variable itself, unit stride
                    return "..."
        elif len(used_vars) > 1:
            # Multi-dimensional indexing
            # For contiguous multi-dim access, all terms should have unit stride
            all_unit_stride = True
            for var in used_vars:
                var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
                stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
                if stride != 1:
                    all_unit_stride = False
                    break
            if all_unit_stride:
                # Contiguous multi-dimensional access
                return "..."
            else:
                # Strided multi-dimensional access
                # For most cases, inputs are made contiguous before passing to JAX,
                # so strided tensors become contiguous and we can use [...]
                # The buffer size check in load() handles im2col-like patterns
                return "..."

        # For complex cases, use [...] since inputs are made contiguous
        return "..."

    def _generate_strided_index(self, index: sympy.Expr) -> str:
        """
        Generate JAX code to compute an index array for strided/complex indexing patterns.

        For expressions like `2 * x3 + 32 * x2 + 256 * x1 + 1024 * x0`, we generate
        code that computes the flattened index array using broadcasting.

        The iteration variables (x0, x1, x2, x3) are already defined as jnp.arange arrays
        in the kernel. We just need to convert the sympy expression to JAX code.
        """
        free_symbols = index.free_symbols
        iter_vars = self._get_iter_vars()

        # Check that all free symbols are iteration variables (no indirect vars)
        used_vars = free_symbols & iter_vars
        if used_vars != free_symbols:
            raise Unsupported(
                f"Pallas backend does not yet support mixed index pattern: {index}"
            )

        # Track which iteration variables are used
        self.used_iter_vars.update(used_vars)

        # Convert sympy expression to Python/JAX code string
        # The iteration variables are already defined as jnp.arange arrays
        index_str = self.kexpr(index)

        # Mark this as requiring flatten access
        return index_str
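    # Illustrative example (buffer name hypothetical): for index 32*x0 + 2*x1,
    # with x0 and x1 already bound to jnp.arange arrays, this returns a string
    # like "32*x0 + 2*x1", which the caller embeds in flatten-based gather
    # form: in_ptr0[...].flatten()[32*x0 + 2*x1].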

    def _get_iter_vars(self) -> OrderedSet[sympy.Symbol]:
        """Get the set of iteration variable symbols."""
        return OrderedSet(self.range_tree_nodes.keys())

    def _get_used_iter_vars(self, index: sympy.Expr) -> OrderedSet[sympy.Symbol]:
        """Get iteration variables used in an index expression."""
        return index.free_symbols & self._get_iter_vars()

    def _has_iteration_vars(self, index: sympy.Expr) -> bool:
        """Check if index expression contains iteration variables."""
        return bool(self._get_used_iter_vars(index))

    def _get_indirect_vars(self, index: sympy.Expr) -> list[sympy.Symbol]:
        """Get list of indirect variable symbols (tmp*) in an index expression."""
        return [s for s in index.free_symbols if str(s).startswith("tmp")]

    def _has_indirect_vars(self, index: sympy.Expr) -> bool:
        """Check if index expression contains indirect variables."""
        return len(self._get_indirect_vars(index)) > 0

    def _decompose_strided_access(
        self, index: sympy.Expr, name: str
    ) -> list[tuple[int, int, int]] | None:
        """Decompose a flat index into per-dimension (stride, offset, skip) triples.

        Given flat index like ``64*x0 + 2*x1 + 5`` and buffer shape ``(32, 64)``
        with C-contiguous strides ``[64, 1]``:
          - x0 coefficient 64 / buffer_stride[0]=64 -> dim 0: stride=1
          - x1 coefficient 2 / buffer_stride[1]=1  -> dim 1: stride=2
          - constant 5: dim 0 gets 5//64=0, dim 1 gets 5//1=5
          - dim 1 offset 5 with stride 2: skip=5//2=2, offset=5%2=1

        Returns per-dim ``[(stride, offset, skip), ...]`` where:
          - stride: access stride on this dim (1 = contiguous)
          - offset: static index into the stride dim (0 <= offset < stride)
          - skip: number of stride-blocks to skip at the start of this dim
        Returns None if decomposition fails.
        """
        if self._has_indirect_vars(index) or index.has(ModularIndexing):
            return None

        # Don't reshape a buffer that already has flatten+gather loads;
        # the reshape would change the flat layout and break those loads.
        if name in self.flatten_indexed_buffers:
            return None

        info = self._get_buffer_info(name)
        if info is None:
            return None
        _, buf_size, _, _, _ = info

        buf_shape_or_none = [self._safe_int(s) for s in buf_size]
        if any(s is None or s <= 0 for s in buf_shape_or_none):
            return None
        buf_shape: list[int] = cast(list[int], buf_shape_or_none)
        ndim = len(buf_shape)
        if ndim == 0:
            return None

        c_strides = self._c_contiguous_strides(buf_shape)

        # Extract per-variable coefficients
        used_vars = self._get_used_iter_vars(index)
        if not used_vars:
            return None

        # [stride, raw_offset] per dim — raw_offset may be >= stride
        result: list[list[int]] = [[1, 0] for _ in range(ndim)]
        # Track which variable maps to which dimension
        var_to_dim: dict[sympy.Symbol, int] = {}

        remaining = V.graph.sizevars.simplify(index)
        for var in used_vars:
            var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(remaining, var)
            coeff = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
            if coeff is None:
                return None
            coeff_int = self._safe_int(coeff)
            if coeff_int is None or coeff_int <= 0:
                return None

            # Find which buffer dim this variable maps to
            dim = None
            for d in range(ndim):
                if c_strides[d] == 0:
                    continue
                if coeff_int % c_strides[d] == 0:
                    per_dim_stride = coeff_int // c_strides[d]
                    if per_dim_stride >= 1:
                        dim = d
                        break
            if dim is None:
                return None

            per_dim_stride = coeff_int // c_strides[dim]
            if per_dim_stride < 1 or buf_shape[dim] % per_dim_stride != 0:
                return None

            result[dim][0] = per_dim_stride
            var_to_dim[var] = dim
            remaining = V.graph.sizevars.simplify(remaining - var_expr)

        # Remaining is the constant offset; distribute across dims using
        # divmod with C-contiguous strides (largest stride first).
        offset_val = self._safe_int(remaining)
        if offset_val is None:
            return None
        if offset_val < 0:
            return None
        for d in range(ndim):
            if c_strides[d] > 0:
                result[d][1] = offset_val // c_strides[d]
                offset_val = offset_val % c_strides[d]
        if offset_val != 0:
            return None

        # Only return if there's at least one dim with stride > 1.
        # Contiguous accesses with just an offset (stride=1 everywhere)
        # are handled by the normal tiling path.
        if all(s == 1 for s, _ in result):
            return None

        # Decompose each dim's raw offset into (skip, offset) where
        # offset < stride: raw = skip * stride + offset.
        # Then validate the output numel matches.
        decomposed: list[tuple[int, int, int]] = []
        output_numel_expected = 1
        for d in range(ndim):
            stride, raw_offset = result[d]
            offset = raw_offset % stride
            skip = raw_offset // stride
            n_blocks = buf_shape[d] // stride
            if skip >= n_blocks:
                return None
            output_numel_expected *= n_blocks - skip
            decomposed.append((stride, offset, skip))

        output_numel, _ = self._compute_output_numel_from_index(index)
        if output_numel != output_numel_expected:
            return None

        # Verify each variable's range matches its assigned dimension's
        # effective size.  When the kernel collapses multiple buffer dims
        # into one iteration variable (e.g. batch*channels), the variable
        # range won't match any single buffer dimension and we must bail
        # out to avoid shape mismatches in the generated code.
        for var, dim in var_to_dim.items():
            if var not in self.range_tree_nodes:
                return None
            var_range = self._safe_int(self.range_tree_nodes[var].length)
            if var_range is None:
                return None
            stride_d, _offset_d, skip_d = decomposed[dim]
            effective_size = buf_shape[dim] // stride_d - skip_d
            if var_range != effective_size:
                return None

        return decomposed
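    # For the docstring example (index 64*x0 + 2*x1 + 5, buffer shape (32, 64)),
    # the decomposition works out to [(1, 0, 0), (2, 1, 2)], provided the numel
    # and variable-range checks above pass.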

    @staticmethod
    def _strided_load_expr(buf: str, decomp: list[tuple[int, int, int]]) -> str:
        """Build ``buf[:, :, offset]`` for strided dims, ``:`` for others."""
        parts: list[str] = []
        for stride, offset, _skip in decomp:
            if stride == 1:
                parts.append(":")
            else:
                parts.append(":")  # the halved dim
                parts.append(str(offset))  # static index into stride dim
        return f"{buf}[{', '.join(parts)}]"

    def _codegen_strided_reshapes(
        self, code: IndentedBuffer, params: list[str]
    ) -> None:
        """Emit reshape + optional slice for strided input parameters.

        For each strided param, reshapes ``(M, N)`` to ``(M, N/stride, stride)``
        and, when ``skip > 0``, slices off leading blocks so the remaining
        elements align with the output.
        """
        for param in params:
            buf_name = self._param_to_buf_name(param)
            if buf_name is None or buf_name not in self.strided_input_buffers:
                continue
            strides = self.strided_input_buffers[buf_name]
            info = self._get_buffer_info(buf_name)
            if info is None:
                continue
            _, buf_size, _, _, _ = info
            new_shape_parts: list[str] = []
            for d, (stride, _offset, _skip) in enumerate(strides):
                dim = self._safe_int(buf_size[d])
                if dim is None:
                    break
                if stride > 1:
                    new_shape_parts.append(str(dim // stride))
                    new_shape_parts.append(str(stride))
                else:
                    new_shape_parts.append(str(dim))
            else:
                code.writeline(
                    f"{param} = {param}.reshape({', '.join(new_shape_parts)})"
                )
                if any(skip > 0 for _, _, skip in strides):
                    slice_parts: list[str] = []
                    for stride, _offset, skip in strides:
                        if stride == 1:
                            slice_parts.append(":")
                        else:
                            slice_parts.append(f"{skip}:" if skip > 0 else ":")
                            slice_parts.append(":")
                    code.writeline(f"{param} = {param}[{', '.join(slice_parts)}]")

    @staticmethod
    def _get_actual_out_strides(out_buf, n: int) -> list[int] | None:
        """Extract actual output buffer strides from its layout."""
        layout = getattr(out_buf, "get_layout", lambda: None)()
        if layout is None:
            return None
        stride_raw = getattr(layout, "stride", None)
        if stride_raw is None or len(stride_raw) != n:
            return None
        strides: list[int] = []
        for s in stride_raw:
            v = int(s) if isinstance(s, (int, sympy.Integer)) else None
            if v is None:
                return None
            strides.append(v)
        return strides

    def _compute_store_coeffs(
        self, ordered: list[sympy.Symbol]
    ) -> dict[sympy.Symbol, int] | None:
        """Compute store-side linearization coefficients from range tree nesting.

        The tree structure encodes the output iteration order: later
        trees (prefix ``x``) are innermost, earlier trees (``y``, ``z``)
        are outer.  Within a tree, dict order goes inner-to-outer.
        The innermost variable gets coefficient 1; each successive
        variable (moving outward) multiplies by the previous range.

        Returns ``{sympy.Symbol: int}`` mapping each RT var to its store
        coefficient, or ``None`` on failure.
        """
        prefix_groups: dict[str, list[sympy.Symbol]] = {}
        prefix_order: list[str] = []
        for v in ordered:
            node = self.range_tree_nodes[v]
            p = node.prefix
            if p not in prefix_groups:
                prefix_groups[p] = []
                prefix_order.append(p)
            prefix_groups[p].append(v)
        inner_to_outer: list[sympy.Symbol] = []
        for p in reversed(prefix_order):
            inner_to_outer.extend(prefix_groups[p])
        coeffs: dict[sympy.Symbol, int] = {}
        coeff = 1
        for v in inner_to_outer:
            sz = self._safe_int(self.range_tree_nodes[v].length)
            if sz is None:
                return None
            coeffs[v] = coeff
            coeff *= sz
        return coeffs
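    # Illustrative example (names and lengths hypothetical): with trees y
    # (y0 of length 4) and x (x1 inner of length 8, x0 outer of length 16),
    # x is the innermost tree, so the coefficients come out as
    # {x1: 1, x0: 8, y0: 128}.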

    def _get_full_load_permutation(
        self, name: str, index: sympy.Expr
    ) -> tuple[int, ...] | None:
        """Return permutation for a full-array load, or None.

        Computes the permutation by mapping each range-tree variable to
        both an output dimension (via store coefficients + actual output
        strides) and an input dimension (via load coefficients + input
        C-contiguous strides).  The permutation is then:

            perm[out_dim] = in_dim   for each RT variable

        Using actual output strides (not C-contiguous) is critical: the
        scheduler may choose a non-standard output layout (e.g. column-
        major) to optimise for transposed inputs.

        When all dimensions collapse to a single flat RT variable (e.g.
        (2,2,2,2,2) with all dims size 2), infers the permutation
        directly from output strides vs input C-contiguous strides.
        """
        info = self._get_buffer_info(name)
        if not info:
            return None
        _, buf_size, _, _, is_contiguous = info
        in_shape_raw = [self._safe_int(s) for s in buf_size]
        if len(in_shape_raw) < 2 or None in in_shape_raw:
            return None
        in_shape: list[int] = cast(list[int], in_shape_raw)
        if not is_contiguous:
            return None  # .contiguous() at JAX boundary handles this

        # Extract index coefficients for each non-reduction RT variable.
        iter_used = self._get_used_iter_vars(index)
        ordered = [
            s
            for s, e in self.range_tree_nodes.items()
            if s in iter_used and not e.is_reduction
        ]
        if len(ordered) != len(in_shape):
            # All dims may have collapsed to a single flat RT variable
            # (e.g. (2,2,2,2,2) → single x0 of length 32).  In this
            # case, infer the permutation directly from output strides
            # vs input C-contiguous strides.
            n = len(in_shape)
            if len(ordered) == 1 and self._safe_int(
                self.range_tree_nodes[ordered[0]].length
            ) == math.prod(in_shape):
                in_strides = self._c_contiguous_strides(in_shape)
                for out_name in self._output_buffer_names:
                    out_buf = V.graph.get_buffer(out_name)
                    if out_buf is None:
                        continue
                    out_shape = [self._safe_int(s) for s in out_buf.get_size()]
                    if any(s is None for s in out_shape) or len(out_shape) != n:
                        continue
                    actual = self._get_actual_out_strides(out_buf, n)
                    if actual is None:
                        break
                    # Map each output dim to the input dim with the
                    # same stride.
                    perm = self._map_coeffs_to_dims(actual, in_strides)
                    if perm is None:
                        break
                    if list(perm) == list(range(n)):
                        return None
                    return tuple(perm)
            return None
        coeffs_raw = [
            self._get_index_coefficient(V.graph.sizevars.simplify(index), v)
            for v in ordered
        ]
        if not all(isinstance(c, int) and c > 0 for c in coeffs_raw):
            return None
        coeffs: list[int] = cast(list[int], coeffs_raw)

        n = len(ordered)
        in_strides = self._c_contiguous_strides(in_shape)
        store_coeffs = self._compute_store_coeffs(ordered)

        # --- Primary path: dimension-mapping with actual output strides ---
        if store_coeffs is not None:
            for out_name in self._output_buffer_names:
                out_buf = V.graph.get_buffer(out_name)
                if out_buf is None:
                    continue
                out_shape = [self._safe_int(s) for s in out_buf.get_size()]
                if any(s is None for s in out_shape) or len(out_shape) != n:
                    continue

                actual = self._get_actual_out_strides(out_buf, n)
                if actual is not None:
                    rt_to_out = self._map_coeffs_to_dims(
                        [store_coeffs[v] for v in ordered], actual
                    )
                    rt_to_in = self._map_coeffs_to_dims(list(coeffs), in_strides)
                    if rt_to_out is not None and rt_to_in is not None:
                        perm = [0] * n
                        for k in range(n):
                            perm[rt_to_out[k]] = rt_to_in[k]
                        if list(perm) == list(range(n)):
                            return None
                        return tuple(perm)
                break

        return None

    def _get_collapsed_load_permutation(
        self, name: str, index: sympy.Expr
    ) -> tuple[tuple[int, ...], tuple[int, ...]] | None:
        """Handle permutation when range tree has collapsed dimensions.

        When simplify_and_reorder merges contiguous dims, the range tree
        has fewer variables than the buffer's rank.  This method detects
        the permutation in the collapsed space and returns
        (collapsed_input_shape, perm) so the caller can generate:
            jnp.permute_dims(load.reshape(collapsed_shape), perm)

        Uses index coefficients on both sides: load-index coefficients
        map vars to collapsed input dims, and store-side coefficients
        (derived from the range tree nesting) map vars to collapsed
        output dims.  Both sets of strides are always unique, so
        matching is unambiguous even with duplicate group sizes.
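
        Illustrative example (shapes assumed, not from a real graph): a
        contiguous (2, 3, 4) input whose range tree collapsed to lengths
        [6, 4], stored transposed, yields ((6, 4), (1, 0)), so the caller
        emits jnp.permute_dims(load.reshape((6, 4)), (1, 0)).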
        """
        info = self._get_buffer_info(name)
        if not info:
            return None
        _, buf_size, _, _, is_contiguous = info
        in_shape_raw = [self._safe_int(s) for s in buf_size]
        if len(in_shape_raw) < 2 or None in in_shape_raw:
            return None
        in_shape: list[int] = cast(list[int], in_shape_raw)
        if not is_contiguous:
            return None

        iter_used = self._get_used_iter_vars(index)
        ordered = [
            s
            for s, e in self.range_tree_nodes.items()
            if s in iter_used and not e.is_reduction
        ]
        n = len(ordered)
        if n < 2 or n >= len(in_shape):
            return None
        ranges_raw = [self._safe_int(self.range_tree_nodes[v].length) for v in ordered]
        if None in ranges_raw:
            return None
        ranges: list[int] = cast(list[int], ranges_raw)
        if math.prod(ranges) != math.prod(in_shape):
            return None

        # Group consecutive input dims (right-to-left) to match ranges
        in_groups = self._group_dims_to_ranges(in_shape, ranges)
        if in_groups is None:
            return None

        # Compute collapsed input strides (row-major) and use load-index
        # coefficients to map each range tree var to a collapsed input dim.
        # Strides are always unique, so this is unambiguous even when
        # group sizes are duplicated.
        collapsed_in_strides = [0] * n
        stride = 1
        for i in range(n - 1, -1, -1):
            collapsed_in_strides[i] = stride
            stride *= in_groups[i]

        simplified = V.graph.sizevars.simplify(index)
        in_coeffs_raw = [self._get_index_coefficient(simplified, v) for v in ordered]
        if not all(isinstance(c, int) and c > 0 for c in in_coeffs_raw):
            return None
        in_coeffs: list[int] = cast(list[int], in_coeffs_raw)

        in_stride_to_dim = {s: i for i, s in enumerate(collapsed_in_strides)}
        var_to_in_dim = []
        for coeff in in_coeffs:
            dim = in_stride_to_dim.get(coeff)
            if dim is None:
                return None
            var_to_in_dim.append(dim)

        store_coeffs = self._compute_store_coeffs(ordered)
        if store_coeffs is None:
            return None

        # Find the output-side mapping using store coefficients.
        for out_name in self._output_buffer_names:
            out_buf = V.graph.get_buffer(out_name)
            if out_buf is None:
                continue
            out_shape_raw = [self._safe_int(s) for s in out_buf.get_size()]
            if any(s is None for s in out_shape_raw) or len(out_shape_raw) < 2:
                continue
            out_shape: list[int] = cast(list[int], out_shape_raw)
            if math.prod(out_shape) != math.prod(in_shape):
                continue
            out_groups = self._group_dims_to_ranges(out_shape, list(in_groups))
            if out_groups is None:
                continue

            # Compute collapsed output strides and match store coefficients.
            collapsed_out_strides = [0] * n
            stride = 1
            for i in range(n - 1, -1, -1):
                collapsed_out_strides[i] = stride
                stride *= out_groups[i]

            out_stride_to_dim = {s: j for j, s in enumerate(collapsed_out_strides)}
            var_to_out_dim = []
            for v in ordered:
                j = out_stride_to_dim.get(store_coeffs[v])
                if j is None:
                    return None
                var_to_out_dim.append(j)

            # Build perm: perm[out_dim] = in_dim
            perm = [0] * n
            for k in range(n):
                perm[var_to_out_dim[k]] = var_to_in_dim[k]
            if perm == list(range(n)):
                return None
            return (tuple(in_groups), tuple(perm))
        return None

    @staticmethod
    def _group_dims_to_ranges(dims: list[int], ranges: list[int]) -> list[int] | None:
        """Group consecutive dims (right-to-left) to match range values.

        Returns collapsed shape (left-to-right) or None if no valid grouping.
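
        Illustrative: dims=[2, 3, 4] with ranges=[6, 4] groups right to
        left (4 -> 4, then 2*3 -> 6) and returns [6, 4]; dims=[2, 3] with
        ranges=[5] has no valid grouping and returns None.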
        """
        available = list(ranges)
        groups: list[int] = []
        product = 1
        for i in range(len(dims) - 1, -1, -1):
            product *= dims[i]
            try:
                idx = available.index(product)
            except ValueError:
                continue
            groups.append(product)
            available.pop(idx)
            product = 1
        if product != 1 or available:
            return None
        groups.reverse()
        return groups

    def _get_index_expr(self, index: sympy.Expr) -> _BufferIndexing:
        """Get the index expression string and whether it needs flattening."""
        has_indirect = self._has_indirect_vars(index)
        has_iter_vars = self._has_iteration_vars(index)

        if has_indirect and has_iter_vars:
            return _BufferIndexing(
                index_str=self._handle_mixed_indexing(index), needs_flatten=True
            )
        elif has_indirect:
            return _BufferIndexing(index_str=self.kexpr(index), needs_flatten=False)
        else:
            index_str = self._get_index_str(index)
            # Check if index contains ModularIndexing - this requires flattened access
            # ModularIndexing is used for roll/wrap-around operations
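            # (e.g. a roll over a size-16 dim indexes via something like
            # ModularIndexing(x0 + shift, 1, 16); example illustrative)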
            needs_flatten = index.has(ModularIndexing) and index_str != "..."
            # If index_str is an actual expression (not "..." or a slice pattern),
            # we need flattened access because it uses block variables
            if not needs_flatten and index_str != "...":
                # Check if it's a simple slice pattern (::N or M::N)
                if not ("::" in index_str or index_str.lstrip("-").isdigit()):
                    needs_flatten = True
            return _BufferIndexing(index_str=index_str, needs_flatten=needs_flatten)

    @staticmethod
    def _safe_int(val: Any) -> int | None:
        """Convert value to int, returning None on failure."""
        try:
            return int(val)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _c_contiguous_strides(shape: list[int]) -> list[int]:
        """Return C-contiguous strides for the given shape."""
        n = len(shape)
        strides = [1] * n
        for i in range(n - 2, -1, -1):
            strides[i] = strides[i + 1] * shape[i + 1]
        return strides

    @staticmethod
    def _map_coeffs_to_dims(coeffs: list[int], strides: list[int]) -> list[int] | None:
        """Map coefficient values to dimension indices via stride matching.

        Returns a list where entry k is the dimension whose stride equals
        coeffs[k], or None if the mapping is ambiguous or incomplete.
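
        Example: coeffs=[1, 12, 4] against strides=[12, 4, 1] gives
        [2, 0, 1]; duplicate strides or an unmatched coefficient give None.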
        """
        stride_to_dim: dict[int, int] = {}
        for d, s in enumerate(strides):
            if s in stride_to_dim:
                return None  # duplicate strides
            stride_to_dim[s] = d
        mapping: list[int] = []
        for c in coeffs:
            d = stride_to_dim.get(c)
            if d is None:
                return None
            mapping.append(d)
        if len(OrderedSet(mapping)) != len(coeffs):
            return None
        return mapping

    def _zero_dim_output_flags(self, ctx: _CodegenContext) -> tuple[bool, bool]:
        """Return whether an output has a zero or unknown dimension."""
        has_unknown_dim = False
        for buf_name in ctx.output_buffer_lookup.values():
            buf = V.graph.try_get_buffer(buf_name)
            if buf is None:
                has_unknown_dim = True
                continue
            for dim in buf.get_size():
                dim_int = self._safe_int(dim)
                if dim_int == 0:
                    return True, has_unknown_dim
                if dim_int is None:
                    has_unknown_dim = True
        return False, has_unknown_dim

    def _get_reduction_axes(self) -> tuple[int, ...]:
        """Determine which axes of the loaded array are reduction axes.

        Finds the innermost reduction stride from the load index
        expression, then walks outward through the buffer's dims
        using stride ratios until the accumulated product reaches
        red_numel.  Falls back to stride-direction analysis for
        gather/flatten loads.
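
        Illustrative (shapes assumed): for a contiguous (4, 8, 16) load
        with red_numel=128 and an innermost reduction stride of 1, the
        walk also absorbs the stride-16 dim and returns (-2, -1).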
        """
        if not self.load_index_exprs:
            return (-1,)

        r_vars = [v for v, e in self.range_tree_nodes.items() if e.is_reduction]
        pw_vars = [v for v, e in self.range_tree_nodes.items() if not e.is_reduction]
        if not r_vars or not pw_vars:
            return (-1,)

        red_numel = self._compute_reduction_numel()
        if not red_numel or red_numel <= 1:
            return (-1,)

        for buf_name, load_index in self.load_index_exprs.items():
            info = self._get_buffer_info(buf_name)
            if info is None:
                continue
            _, buf_size, _, actual_strides, _ = info
            nd = len(buf_size)
            if nd < 2:
                continue
            strides_or_none = [self._safe_int(s) for s in actual_strides]
            if any(s is None for s in strides_or_none):
                continue
            strides: list[int] = cast(list[int], strides_or_none)

            # Get reduction stride coefficients by zeroing pw_vars.
            r_only = load_index
            for pv in pw_vars:
                r_only = r_only.subs(pv, 0)
            r_coeffs: OrderedSet[int] = OrderedSet()
            for term in sympy.Add.make_args(r_only):
                if term.is_number:
                    continue
                coeff, _ = term.as_coeff_Mul()
                c = self._safe_int(coeff)
                if c is not None and c > 0:
                    r_coeffs.add(c)
            if not r_coeffs:
                continue

            # Match all coefficients against buffer strides
            matched = sorted(
                (i for i in range(nd) if strides[i] in r_coeffs),
            )
            if not matched:
                continue

            # Multiple r_vars each map to a distinct dim — return directly.
            # Single r_var with multiple coefficients (transposed access)
            # → skip to fallback.
            if len(r_coeffs) > 1:
                if len(r_coeffs) == len(matched) and len(r_vars) > 1:
                    return tuple(i - nd for i in matched)
                continue

            # Single coefficient: walk outward from the matched dim
            # using span to find flattened contiguous dims.
            r_stride = next(iter(r_coeffs))
            span = (red_numel - 1) * r_stride
            is_descending = all(strides[i] > strides[i + 1] for i in range(nd - 1))
            if is_descending:
                # Walk by dim index (strides are strictly descending)
                inner = matched[-1]
                start = inner
                while start > 0 and span > strides[start - 1]:
                    start -= 1
                axes = list(range(start, inner + 1))
            else:
                # Non-contiguous layout: collect dims whose strides
                # fall within the r_var's traversal range
                axes = sorted(
                    i
                    for i in range(nd)
                    if r_stride <= strides[i] and strides[i] < span + r_stride
                )
                if not axes:
                    axes = list(matched)
            return tuple(i - nd for i in axes)

        # Fallback: stride-direction for gather/flatten loads
        load_index = next(iter(self.load_index_exprs.values()))
        r_coeff = load_index.coeff(r_vars[0])
        r_stride = self._safe_int(r_coeff) if r_coeff != 0 else 1
        if r_stride is None:
            r_stride = 1
        pw_coeff = load_index.coeff(pw_vars[0])
        pw_stride = self._safe_int(pw_coeff) if pw_coeff != 0 else 1
        if pw_stride is None:
            pw_stride = 1
        if r_stride > pw_stride:
            return (0,)
        return (-1,)

    def _compute_prefix_numel(self, prefixes: OrderedSet) -> int | None:
        """Compute total numel for given prefixes (e.g., pointwise prefixes)."""
        result = 1
        for p in prefixes:
            if p in self.numels:
                numel = self._safe_int(self.numels[p])
                if numel is None:
                    return None
                result *= numel
        return result

    def _compute_reduction_numel(self) -> int | None:
        """Compute total reduction numel."""
        result = 1
        for tree in self.range_trees:
            if tree.is_reduction:
                numel = self._safe_int(tree.numel)
                if numel is None:
                    return None
                result *= numel
        return result

    def _can_use_tma_approach(self) -> bool:
        """
        Check if TMA (Tensor Memory Accelerator) approach can be used.
        TMA works for simple element-wise ops but not for:
        - Reductions (need different accumulation patterns)
          TODO: TMA supports float64 for loading but not for reductions
        - Broadcasting (inputs have different shapes or output differs)
        - Non-contiguous tensors (strided, transposed)
        """
        # TMA flattens to 1D tiles, incompatible with permutation detection
        # which emits jnp.permute_dims expecting N-D input.
        if self.permuted_input_buffers:
            return False

        # Check for reductions
        reduction_numel = self._compute_reduction_numel()
        if reduction_numel is not None and reduction_numel > 1:
            return False

        # Check all input buffers for contiguity, dtype, and shape consistency
        input_shapes: list[tuple] = []
        for name in self.args.input_buffers:
            info = self._get_buffer_info(name)
            if info is None:
                return False
            buf_obj, buf_size, buf_numel, actual_strides, is_contiguous = info
            if not is_contiguous:
                return False

            # Check for unsupported dtypes
            # TODO: TMA supports float64 for loading but current JAX Mosaic GPU
            # implementation doesn't support it yet. Re-enable when JAX adds support.
            buf_dtype = getattr(buf_obj, "get_dtype", lambda: None)()
            if buf_dtype is not None and buf_dtype == torch.float64:
                return False

            # Collect shape as tuple for comparison
            shape_tuple = tuple(self._safe_int(s) for s in buf_size)
            if None in shape_tuple:
                return False  # Dynamic shapes not supported
            input_shapes.append(shape_tuple)

        # Check if all input shapes are identical (no broadcasting)
        if input_shapes and len(OrderedSet(input_shapes)) > 1:
            return False

        # Check that output numel matches input numel (no broadcasting expansion)
        if input_shapes:
            input_numel = 1
            for s in input_shapes[0]:
                input_numel *= s

            # Compute output numel from pointwise range trees (non-reduction)
            output_numel = 1
            for tree in self.range_trees:
                if not tree.is_reduction:
                    numel = self._safe_int(tree.numel)
                    if numel is None:
                        return False  # Dynamic shapes not supported
                    output_numel *= numel

            if output_numel != input_numel:
                return False

        return True

    def _get_buffer_info(self, name: str) -> tuple[Any, Any, Any, list, bool] | None:
        """Get buffer metadata (buf_obj, buf_size, buf_numel, actual_strides, is_contiguous).

        Returns None if the buffer doesn't exist.
        """
        buf_obj = V.graph.get_buffer(name)
        if buf_obj is None:
            return None
        buf_size = buf_obj.get_size()
        buf_numel = 1
        for s in buf_size:
            sval = self._safe_int(s)
            buf_numel *= sval if sval is not None else s

        # Get buffer strides and check contiguity
        actual_strides: list = []
        is_contiguous = True

        layout = getattr(buf_obj, "get_layout", lambda: None)()
        buf_stride = getattr(layout, "stride", None) if layout else None

        if buf_stride is not None:
            for i in range(len(buf_size)):
                actual_stride = self._safe_int(buf_stride[i])
                actual_strides.append(actual_stride)

            # Check contiguity
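            # e.g. size (2, 3): strides (3, 1) are contiguous, (1, 2) are not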
            if len(buf_size) == 1:
                if actual_strides[0] is not None and actual_strides[0] != 1:
                    is_contiguous = False
            elif len(buf_size) > 1:
                expected_stride = 1
                for i in range(len(buf_size) - 1, -1, -1):
                    actual_stride = actual_strides[i]
                    if actual_stride is None or actual_stride != expected_stride:
                        is_contiguous = False
                    dim_size = self._safe_int(buf_size[i])
                    if dim_size is not None:
                        expected_stride *= dim_size

        return buf_obj, buf_size, buf_numel, actual_strides, is_contiguous

    def _compute_output_numel_from_index(
        self, index: sympy.Expr
    ) -> tuple[int, OrderedSet]:
        """Compute expected output numel and used vars from iteration variables in index."""
        used_vars = self._get_used_iter_vars(index)

        used_range_lengths = []
        for var in used_vars:
            if var in self.range_tree_nodes:
                entry = self.range_tree_nodes[var]
                length_val = self._safe_int(entry.length)
                if length_val is not None:
                    used_range_lengths.append(length_val)

        output_numel = math.prod(used_range_lengths)

        return output_numel, used_vars

    def _get_index_coefficients(
        self, index: sympy.Expr, used_vars: OrderedSet
    ) -> OrderedSet:
        """
        Extract coefficients of iteration variables from index expression.
        """
        coefficients: OrderedSet = OrderedSet()
        for var in used_vars:
            var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
            stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
            if stride is None:
                stride = 1  # Variable without explicit coefficient has stride 1
            coef = self._safe_int(stride)
            coefficients.add(coef if coef is not None else stride)
        return coefficients

    def _check_gather_pattern(
        self,
        buf_size: list,
        actual_strides: list,
        is_contiguous: bool,
        coefficients: OrderedSet,
    ) -> bool:
        """
        Check if access pattern requires gather (non-standard striding).
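
        E.g. a coefficient of 2 against a contiguous (4, 8) buffer
        (expected strides {8, 1}) signals a gather.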
        """
        expected_strides = [1]  # 1D buffers have stride 1

        if len(buf_size) > 1:
            expected_stride = 1
            expected_strides = []
            for i in range(len(buf_size) - 1, -1, -1):
                expected_strides.insert(0, expected_stride)
                dim_size = self._safe_int(buf_size[i])
                if dim_size is not None:
                    expected_stride *= dim_size

        if is_contiguous:
            # Buffer is contiguous - check if access coefficients match expected strides
            expected_stride_set = OrderedSet(expected_strides)
            for coef in coefficients:
                if coef not in expected_stride_set:
                    return True
        else:
            # Buffer is NOT contiguous (strided input)
            # Check if coefficients match actual buffer strides
            actual_stride_set = OrderedSet(s for s in actual_strides if s is not None)
            for coef in coefficients:
                if coef not in actual_stride_set:
                    return True

        return False

    def _needs_strided_indexing(
        self,
        name: str,
        index: sympy.Expr,
        indexing: _BufferIndexing,
    ) -> _BufferIndexing:
        """Check if buffer access needs strided indexing due to size mismatch or gather patterns."""
        # Only applies when full array access is indicated
        if indexing.index_str != "..." or indexing.needs_flatten:
            return indexing

        info = self._get_buffer_info(name)
        if info is None:
            return indexing

        buf_obj, buf_size, buf_numel, actual_strides, is_contiguous = info
        output_numel, used_vars = self._compute_output_numel_from_index(index)
        all_iter_vars = self._get_iter_vars()
        coefficients = self._get_index_coefficients(index, used_vars)

        # Check for gather pattern
        has_non_unit_strides = self._check_gather_pattern(
            buf_size, actual_strides, is_contiguous, coefficients
        )

        # Check for im2col-like pattern: only a strict subset of the
        # iteration vars is used, yet more vars than the buffer has dims
        buf_effective_dims = sum(1 for s in buf_size if self._safe_int(s) != 1)
        not_all_vars_used = (
            len(used_vars) < len(all_iter_vars)
            and len(used_vars) > 0
            and buf_effective_dims > 1
            and len(used_vars) > len(buf_size)
        )

        # Check various conditions for skipping strided indexing
        is_tpu = V.graph.get_current_device_or_throw().type == "tpu"
        is_known_non_contiguous = not is_contiguous and all(
            s is not None for s in actual_strides
        )
        has_symbolic_coef = any(not isinstance(c, int | float) for c in coefficients)
        skip_for_non_contiguous = (
            is_known_non_contiguous and not is_tpu and buf_numel == output_numel
        )

        # Determine if strided indexing is needed
        if (
            output_numel > 0
            and (buf_numel != output_numel or not_all_vars_used or has_non_unit_strides)
            and len(used_vars) > 0
            and not skip_for_non_contiguous
            and not has_symbolic_coef
        ):
            return _BufferIndexing(
                index_str=self._generate_strided_index(index), needs_flatten=True
            )

        return indexing

    def _adjust_index_for_buffer_shape(
        self,
        name: str,
        index: sympy.Expr,
        indexing: _BufferIndexing,
    ) -> _BufferIndexing:
        """
        Adjust index expression based on buffer shape (0-dim scalar, multi-dim, etc.).
        """
        if indexing.needs_flatten or indexing.index_str == "...":
            return indexing

        buf_obj = V.graph.get_buffer(name)
        if buf_obj is None:
            return indexing

        buf_size = buf_obj.get_size()

        # 0-dimensional (scalar) buffer - use [...] to access it
        if len(buf_size) == 0:
            return _BufferIndexing(
                index_str="...", needs_flatten=indexing.needs_flatten
            )

        # Multi-dimensional buffer with constant/scalar index
        if len(buf_size) > 1:
            has_iter_vars = self._has_iteration_vars(index)
            if not has_iter_vars:
                return _BufferIndexing(
                    index_str=indexing.index_str, needs_flatten=True
                )  # Use flattened access
            elif "::" in indexing.index_str:
                # Strided slice patterns need flattened indexing for multi-dim
                return _BufferIndexing(
                    index_str=self._generate_strided_index(index), needs_flatten=True
                )

        # GPU doesn't support gather from slice patterns on 1D buffers
        if self.is_gpu and "::" in indexing.index_str:
            return _BufferIndexing(
                index_str=self._generate_strided_index(index), needs_flatten=True
            )

        return indexing

    def _try_multidim_slice(
        self,
        name: str,
        index: sympy.Expr,
        indexing: _BufferIndexing,
    ) -> _BufferIndexing:
        """
        Try to emit multi-dim slice notation instead of flatten + gather.

        For a buffer with shape (d0, ..., dk) and index `stride * var + offset`,
        emit `buf[:, ..., :, offset::stride]` when stride divides dk.
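
        Illustrative: a (4, 8) buffer indexed as 2*x0 + 1 with x0 of
        length 16 becomes buf[:, 1::2].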
        """
        if not indexing.needs_flatten:
            return indexing

        buf_obj = V.graph.get_buffer(name)
        if buf_obj is None:
            return indexing

        buf_size = buf_obj.get_size()
        ndim = len(buf_size)
        if ndim < 2:
            return indexing

        # Need a single iteration variable with an affine index
        used_vars = self._get_used_iter_vars(index)
        if len(used_vars) != 1:
            return indexing

        var = next(iter(used_vars))
        var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
        stride = self._safe_int(
            BlockPatternMatcher.match_affine_block_expr(var_expr, var)
        )
        if stride is None or stride <= 1:
            return indexing

        offset = V.graph.sizevars.simplify(index - var_expr)
        try:
            offset_val = int(offset)
        except (TypeError, ValueError):
            return indexing

        if offset_val < 0 or offset_val >= stride:
            return indexing

        last_dim = self._safe_int(buf_size[-1])
        if last_dim is None or last_dim % stride != 0:
            return indexing

        # Verify the iteration variable covers all buffer elements at the
        # given stride: var_length * stride == buf_numel. This ensures
        # the flattened stride-access 0, stride, 2*stride, ... maps exactly
        # to buf[:, ..., :, offset::stride].
        entry = self.range_tree_nodes.get(var)
        if entry is None:
            return indexing
        var_length = self._safe_int(entry.length)
        buf_numel = 1
        for s in buf_size:
            d = self._safe_int(s)
            if d is None:
                return indexing
            buf_numel *= d
        if var_length is None or var_length * stride != buf_numel:
            return indexing

        prefix = ":, " * (ndim - 1)
        if offset_val == 0:
            slice_str = f"{prefix}::{stride}"
        else:
            slice_str = f"{prefix}{offset_val}::{stride}"
        return _BufferIndexing(index_str=slice_str, needs_flatten=False)

    @staticmethod
    def _gather_permute_expr(load_expr: str, perm: tuple[int, ...]) -> str:
        """Generate gather-based permutation instead of jnp.permute_dims.

        Avoids a Mosaic compiler bug where jnp.permute_dims produces
        corrupted output tensors on TPU for 3D+ arrays.  Uses
        pallas_permute which flattens to 1D and does a 1D gather.
        """
        return f"pallas_permute({load_expr}, {perm})"

    def _trace_to_load_source(self, var_name: str) -> str | None:
        """Trace a tmp variable back to its source buffer's kernel param.

        Follows CSE assignments backward through bounds-checking (where/clamp)
        until it finds a variable that was directly loaded from a buffer.
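
        For example (names illustrative), if tmp2 = jnp.where(..., tmp1, ...)
        and tmp1 was loaded directly from kernel param in_ptr0, tracing
        "tmp2" returns "in_ptr0".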
        """
        if var_name in self._cse_to_param:
            return self._cse_to_param[var_name]
        for line in self.compute._lines:
            line_str = str(line).lstrip()
            if not line_str.startswith(f"{var_name} = "):
                continue
            for ref in re.findall(r"\btmp\d+\b", line_str.split(" = ", 1)[1]):
                result = self._trace_to_load_source(ref)
                if result is not None:
                    return result
        return None

    def _detect_indirect_access(
        self, buf: str, name: str, index: sympy.Expr
    ) -> _IndirectAccessInfo | None:
        """Detect a load with data-dependent indexing suitable for scalar prefetch.

        Matches exactly one indirect variable whose coefficient corresponds to
        a C-contiguous stride dimension.  Rejects 1-to-1 gather patterns where
        the indices buffer covers the full iteration space.
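
        Illustrative: an embedding-style load with index 64*tmp0 + x1 into
        a (1000, 64) table maps tmp0 to dimension 0 via its stride of 64.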
        """
        buf_info = self._get_buffer_info(name)
        if buf_info is None:
            return None
        _, buf_size, _, _, _ = buf_info
        buf_size_raw = [self._safe_int(s) for s in buf_size]
        if len(buf_size_raw) < 2 or any(s is None for s in buf_size_raw):
            return None
        buf_size_ints: list[int] = cast(list[int], buf_size_raw)

        indirect_vars = self._get_indirect_vars(index)
        if len(indirect_vars) != 1:
            return None
        indirect_var = indirect_vars[0]

        coeff = self._get_index_coefficient(index, indirect_var)
        if coeff == 0 or not isinstance(coeff, int):
            return None

        # Use existing stride mapping to find which dimension is indirected
        strides = self._c_contiguous_strides(buf_size_ints)
        mapping = self._map_coeffs_to_dims([coeff], strides)
        if mapping is None:
            return None
        indirect_dim = mapping[0]

        ndim = len(buf_size_ints)
        if indirect_dim >= max(1, ndim - 2):
            return None

        indirect_var_name = str(indirect_var)
        indices_param = self._trace_to_load_source(indirect_var_name)
        if indices_param is None:
            return None

        # Reject gather patterns: only 1-D static index tensors supported
        indices_graph_name = self._param_to_graph_name.get(indices_param)
        if indices_graph_name is not None:
            indices_info = self._get_buffer_info(indices_graph_name)
            if indices_info is not None:
                _, indices_size, _, _, _ = indices_info
                if len(indices_size) != 1:
                    return None
                if self._safe_int(indices_size[0]) is None:
                    return None
                indices_numel = math.prod(
                    v for s in indices_size if (v := self._safe_int(s)) is not None
                )
                iter_product = math.prod(
                    length
                    for var in self._get_used_iter_vars(index)
                    if var in self.range_tree_nodes
                    if (length := self._safe_int(self.range_tree_nodes[var].length))
                    is not None
                )
                if indices_numel >= iter_product:
                    return None

        return _IndirectAccessInfo(
            table_param=buf,
            table_buf_name=name,
            table_shape=tuple(buf_size_ints),
            indirect_dim=indirect_dim,
            indirect_var=indirect_var_name,
            indices_param=indices_param,
        )

    def _eliminate_dead_indirect_code(self) -> None:
        """Remove dead compute lines after scalar prefetch replaces indirect load.

        When the table load is simplified to buf[0] (scalar prefetch handles
        indexing), the indices load and all derived bounds-checking code become
        dead.  This performs backward liveness analysis from the store variables
        to identify and remove dead lines.
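
        E.g. if only tmp5 reaches a store and tmp5 = tmp4 + 1 with
        tmp4 = table[0], both stay live while an orphaned indices load
        (say tmp2) is removed.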
        """
        # Collect variables used by stores (live roots)
        live_vars: OrderedSet[str] = OrderedSet()
        for _, store_line in self.store_with_output:
            for m in re.finditer(r"\btmp\d+\b", store_line):
                live_vars.add(m.group())

        # Parse assignments from compute lines
        assignments: list[tuple[str | None, str, Any]] = []
        for line in self.compute._lines:
            line_str = str(line).lstrip()
            m = re.match(r"^(tmp\d+)\s*=\s*(.*)", line_str, re.DOTALL)
            if m:
                assignments.append((m.group(1), m.group(2), line))
            else:
                assignments.append((None, line_str, line))

        # Propagate liveness backward
        changed = True
        while changed:
            changed = False
            for var_name, rhs, _ in reversed(assignments):
                if var_name and var_name in live_vars:
                    for m in re.finditer(r"\btmp\d+\b", rhs):
                        if m.group() not in live_vars:
                            live_vars.add(m.group())
                            changed = True

        # Keep only live assignments (and non-assignment lines)
        self.compute._lines = [
            line
            for var_name, _, line in assignments
            if var_name is None or var_name in live_vars
        ]

    def _build_load_expr(
        self,
        buf: str,
        name: str,
        index: sympy.Expr,
        indexing: _BufferIndexing,
    ) -> str:
        """
        Build the load expression based on indexing mode.
        """
        if indexing.needs_flatten:
            # Detect indirect (data-dependent) access for scalar prefetch
            indirect = self._detect_indirect_access(buf, name, index)
            if indirect is not None:
                if self.indirect_access is not None:
                    # Fused nodes may re-visit the same indirect load (e.g.
                    # a reduction + pointwise over the same embedding).
                    # Allow that, but reject truly different indirect accesses.
                    assert indirect == self.indirect_access, (
                        "only one indirect access per kernel supported"
                    )
                self.indirect_access = indirect
                return f"{buf}[0]"

            self.has_flatten_indexing = True
            self.flatten_indexed_buffers.add(name)
            # Flatten then index for non-contiguous access (gather operation)
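            # (e.g. emits in_ptr0[...].flatten()[4*x1 + x0]; names illustrative)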
            has_minmax = index.has(sympy.Min) or index.has(sympy.Max)
            idx_dtype = "jnp.int32" if self.is_tpu else "jnp.int64"
            idx = (
                f"({indexing.index_str}).astype({idx_dtype})"
                if has_minmax
                else indexing.index_str
            )
            return f"{buf}[...].flatten()[{idx}]"
        else:
            # Direct indexing for contiguous access
            load_expr = f"{buf}[{indexing.index_str}]"

            if indexing.index_str == "..." and not self.is_gpu:
                perm = self._get_full_load_permutation(name, index)
                if perm is not None:
                    load_expr = self._gather_permute_expr(load_expr, perm)
                    self.permuted_input_buffers[name] = perm
                else:
                    collapsed = self._get_collapsed_load_permutation(name, index)
                    if collapsed is not None:
                        collapsed_shape, cperm = collapsed
                        load_expr = f"jnp.permute_dims({load_expr}, {cperm})"
                        # Don't store cperm in permuted_input_buffers: it applies
                        # to the reshaped tensor rather than the original shape,
                        # and would break the tiling logic that consumes it later.
                        self.collapsed_reshape_inputs[name] = collapsed_shape
                        self.collapsed_output_shape = tuple(
                            collapsed_shape[p] for p in cperm
                        )

            return load_expr

    def _maybe_squeeze_intermediate_buffer(self, name: str, load_expr: str) -> str:
        """
        Squeeze (N,1) intermediate buffers when kernel has 1D graph inputs.

        This avoids wrong broadcasting: (N,) op (N,1) -> (N,N) instead of (N,)
        """
        if not name.startswith("buf"):
            return load_expr

        # Check if any input buffer is a 1D graph input
        has_1d_input = any(
            not buf_name.startswith("buf")
            and (buf_obj := V.graph.get_buffer(buf_name)) is not None
            and len(buf_obj.get_size()) == 1
            for buf_name in self.args.input_buffers
        )

        if has_1d_input:
            buf_obj = V.graph.get_buffer(name)
            if buf_obj is not None:
                buf_size = buf_obj.get_size()
                if len(buf_size) == 2 and buf_size[-1] == 1:
                    return f"jnp.squeeze({load_expr}, axis=-1)"

        return load_expr

    def _maybe_broadcast_1d_buffer(
        self, name: str, index: sympy.Expr, load_expr: str
    ) -> str:
        """Reshape 1D buffers for higher-dim broadcasting in reduction kernels.

        When a 1D buffer (e.g. a reduction result from a prior kernel, or a
        batch-norm parameter) is loaded into a kernel with 2+ iteration dims,
        JAX right-aligns it for broadcasting: (N,) becomes (1, N).  This is
        wrong when the buffer corresponds to a non-trailing axis; we reshape
        to (N, 1, ...) so broadcasting matches the correct axis.
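
        E.g. with a (4, 8) reference buffer and a length-4 parameter bound
        to the leading axis, the load is rewritten as load.reshape(-1, 1).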
        """
        buf_obj = V.graph.get_buffer(name)
        if buf_obj is None or len(buf_obj.get_size()) != 1:
            return load_expr

        # Only graph inputs, not intermediate buffers — intermediates are
        # already shaped by the IR and their dim order may not match the
        # reference buffer used below for axis inference.
        if name.startswith("buf"):
            return load_expr

        buf_length = self._safe_int(buf_obj.get_size()[0])
        if buf_length is None:
            return load_expr

        dtype = V.graph.get_dtype(name)
        if dtype is not None and not dtype.is_floating_point:
            return load_expr

        # Find a higher-dimensional reference buffer
        ref_buf_size = None
        for buf_name in self.args.input_buffers:
            other_buf = V.graph.get_buffer(buf_name)
            if other_buf is not None and len(other_buf.get_size()) > 1:
                ref_buf_size = [self._safe_int(s) for s in other_buf.get_size()]
                if all(s is not None for s in ref_buf_size):
                    break
                ref_buf_size = None
        if ref_buf_size is None or len(ref_buf_size) <= 1:
            return load_expr

        # Must use exactly one iteration variable
        used_vars = self._get_used_iter_vars(index)
        if len(used_vars) != 1:
            return load_expr
        used_var = next(iter(used_vars))
        if used_var not in self.range_tree_nodes:
            return load_expr

        # Verify buffer length matches variable length
        entry = self.range_tree_nodes[used_var]
        if self._safe_int(entry.length) != buf_length:
            return load_expr

        # Buffer length must uniquely match one non-reduction iteration variable.
        # If multiple pointwise vars share the same length (e.g. 2D pointwise
        # kernel with both dims equal), the axis is ambiguous and we bail out.
        matching_vars = [
            v
            for v, e in self.range_tree_nodes.items()
            if self._safe_int(e.length) == buf_length and not e.is_reduction
        ]
        if len(matching_vars) != 1:
            return load_expr

        # Determine axis position from the iteration variable's position
        # in the range tree (pointwise vars first, then reduction vars).
        axis_pos = None
        matching_dims = [i for i, s in enumerate(ref_buf_size) if s == buf_length]
        if len(matching_dims) == 1:
            axis_pos = matching_dims[0]
        else:
            # Ambiguous by size (e.g. square tensor with reduction).
            # Use the variable's position in the range tree.
            pw_idx = 0
            for sym, e in self.range_tree_nodes.items():
                if sym == used_var:
                    axis_pos = pw_idx
                    break
                if not e.is_reduction:
                    pw_idx += 1

        if axis_pos is None:
            return load_expr
        if axis_pos == len(ref_buf_size) - 1:
            return load_expr  # Last dim uses default broadcasting

        reshape_dims = [1] * len(ref_buf_size)
        reshape_dims[axis_pos] = -1
        return f"{load_expr}.reshape({', '.join(map(str, reshape_dims))})"

    def _check_im2col_pattern(
        self, index: sympy.Expr, indexing: _BufferIndexing
    ) -> _BufferIndexing:
        """
        Check for im2col-like patterns where store uses block variables but load doesn't.

        For cat/expand patterns, both load and store prepared indices share block vars.
        For im2col patterns, store compresses to block vars but load doesn't.
        """
        if indexing.index_str != "..." or indexing.needs_flatten:
            return indexing

        prepared_index = self.prepare_indexing(index)
        iter_vars = self._get_iter_vars()
        store_orig_vars = self._get_used_iter_vars(index)
        store_prep_vars = (
            prepared_index.free_symbols
            if hasattr(prepared_index, "free_symbols")
            else OrderedSet()
        ) & iter_vars
        new_vars = store_prep_vars - store_orig_vars

        # Only trigger if store introduces new block vars
        if not new_vars or len(store_orig_vars) <= 1:
            return indexing

        # Check if loads are compatible with broadcast or cat pattern
        has_im2col_pattern = False
        for buf_name, load_index in self.load_index_exprs.items():
            load_orig_vars = self._get_used_iter_vars(load_index)
            if not load_orig_vars:
                continue

            # Load has iteration variables
            if load_orig_vars != store_orig_vars:
                continue

            # Same vars - check if load gets compressed too
            prep_load = self.prepare_indexing(load_index)
            load_prep_vars = (
                prep_load.free_symbols
                if hasattr(prep_load, "free_symbols")
                else OrderedSet()
            ) & iter_vars

            # If store compresses but load doesn't, check for strided input vs im2col
            if load_orig_vars != load_prep_vars or store_prep_vars == store_orig_vars:
                continue

            # Check if load coefficients match buffer strides
            if not self._check_load_is_strided_input(
                buf_name, load_index, load_orig_vars
            ):
                has_im2col_pattern = True
                break

        if has_im2col_pattern:
            return _BufferIndexing(
                index_str=self._generate_strided_index(prepared_index),
                needs_flatten=True,
            )

        return indexing

    def _check_load_is_strided_input(
        self, buf_name: str, load_index: sympy.Expr, load_orig_vars: OrderedSet
    ) -> bool:
        """
        Check if load coefficients match buffer strides (strided input vs im2col).
        """
        buf = V.graph.get_buffer(buf_name)
        if buf is None:
            return False

        layout = getattr(buf, "get_layout", lambda: None)()
        if layout is None:
            return False

        buf_strides = getattr(layout, "stride", None)
        if buf_strides is None:
            return False

        buf_sizes = buf.get_size()

        # Get load coefficients
        load_coeffs = []
        for var in load_orig_vars:
            var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(load_index, var)
            coef = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
            if coef is not None:
                int_coef = self._safe_int(coef)
                load_coeffs.append(int_coef if int_coef is not None else coef)

        # Check if coefficients match buffer strides
        # Only include strides for non-trivial dimensions (size > 1)
        buf_stride_set = OrderedSet()
        for i, s in enumerate(buf_strides):
            dim_size = self._safe_int(buf_sizes[i])
            if dim_size is None or dim_size > 1:
                int_s = self._safe_int(s)
                buf_stride_set.add(int_s if int_s is not None else s)

        return OrderedSet(load_coeffs) == buf_stride_set

    def _check_store_needs_transpose(self, name: str) -> bool:
        """
        Check if output needs transpose for column-major storage.

        Transpose on store is needed when:
        - Output has column-major stride (s0 < s1)
        - But input(s) have row-major stride
        - And we haven't already transposed on load
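
        E.g. a (4, 8) output with strides (1, 4) is column-major; if every
        input is row-major, the stored value must be transposed.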
        """
        if self.permuted_input_buffers:
            return False

        info = self._get_buffer_info(name)
        if info is None:
            return False

        _, buf_size, _, actual_strides, _ = info
        if len(actual_strides) != 2 or len(buf_size) != 2:
            return False

        size0 = self._safe_int(buf_size[0])
        size1 = self._safe_int(buf_size[1])
        s0 = actual_strides[0]
        s1 = actual_strides[1]

        # Check if output is column-major with valid dimensions
        if not (
            s0 is not None
            and s1 is not None
            and s0 < s1
            and size0 is not None
            and size1 is not None
            and size0 > 1
            and size1 > 1
        ):
            return False

        # Check if any input is column-major (if so, no transpose needed)
        for inp_name in self.args.input_buffers:
            inp_info = self._get_buffer_info(inp_name)
            if inp_info is None:
                continue
            _, _, _, inp_strides, _ = inp_info
            if len(inp_strides) != 2:
                continue
            inp_s0 = inp_strides[0]
            inp_s1 = inp_strides[1]
            if inp_s0 is not None and inp_s1 is not None and inp_s0 < inp_s1:
                return False  # Input is also column-major

        return True

    def _build_full_array_store_expr(
        self, out: str, value: CSEVariable, needs_transpose: bool
    ) -> list[str]:
        """
        Build store expression for full array assignment.

        Handles scalar broadcast, shape matching, and optional transpose.
        Returns a list of lines to emit (variable assignment + store).
        """
        lines = [f"_val = jnp.asarray({value})"]
        if needs_transpose:
            lines.append(
                f"{out}[...] = "
                f"jnp.full({out}.shape, _val) if _val.ndim == 0 "
                f"else jnp.transpose(_val)"
            )
        else:
            lines.append(
                f"{out}[...] = "
                f"jnp.full({out}.shape, _val) if _val.ndim == 0 "
                f"else (_val.reshape({out}.shape) if _val.size == {out}.size "
                f"else jnp.broadcast_to(_val, {out}.shape))"
            )
        return lines

    def _build_store_expr(
        self,
        out: str,
        name: str,
        index: sympy.Expr,
        value: CSEVariable,
        indexing: _BufferIndexing,
        mode: Any = None,
    ) -> list[str]:
        """
        Build the store expression based on indexing mode.
        mode can be None (set) or "atomic_add" (accumulate).
        Returns a list of lines to emit.
        """
        if indexing.index_str == "...":
            # Full array store with shape matching
            needs_transpose = self._check_store_needs_transpose(name)
            return self._build_full_array_store_expr(out, value, needs_transpose)

        if indexing.needs_flatten:
            self.has_flatten_indexing = True
            # Block variable indexing (e.g., im2col) - use flattened scatter
            scatter_op = "add" if mode == "atomic_add" else "set"
            return [
                f"{out}[...] = {out}[...].flatten().at[({indexing.index_str}).flatten()].{scatter_op}("
                f"jnp.asarray({value}).flatten()).reshape({out}.shape)"
            ]

        # Direct indexed assignment
        has_indirect = self._has_indirect_vars(index)
        buf = V.graph.get_buffer(name)

        if buf is not None:
            buf_size = buf.get_size()
            if len(buf_size) > 1 and not self._has_iteration_vars(index):
                # Multi-dim output with constant index - use [...] for full assignment
                return self._build_full_array_store_expr(out, value, False)

        if has_indirect:
            # Indirect indexed store (scatter): use .add() for atomic_add, .set() otherwise
            scatter_op = "add" if mode == "atomic_add" else "set"
            lines = [f"_val = jnp.asarray({value})"]
            value_expr = f"(jnp.full({indexing.index_str}.shape, _val) if _val.ndim == 0 else {value})"
            if mode == "atomic_add":
                # For atomic_add, mark output as needing to be readable (for aliasing)
                self.outputs_need_read.add(out)
                alias_param = f"{out}_alias"
                lines.append(
                    f"{out}[...] = {alias_param}[...].flatten().at[({indexing.index_str}).flatten()].{scatter_op}("
                    f"{value_expr}.flatten()).reshape({out}.shape)"
                )
            else:
                lines.append(f"{out}[{indexing.index_str}] = {value_expr}")
            return lines

        return [f"{out}[{indexing.index_str}] = {value}"]

    def _build_scatter_store_expr(
        self,
        out: str,
        value: CSEVariable,
        scatter_info: dict[str, Any],
        name: str,
        mode: Any,
    ) -> str:
        """Build store expression for scatter operations (indirect indexing)."""
        is_point_scatter = scatter_info.get("is_point_scatter", False)

        # Mark this output parameter as needing to be readable (for aliasing)
        self.outputs_need_read.add(out)
        alias_param = f"{out}_alias"

        # Use .add() for atomic_add mode, .set() otherwise
        scatter_op = "add" if mode == "atomic_add" else "set"

        if is_point_scatter:
            # Single-element scatter
            indirect_var = scatter_info["indirect_var"]
            indirect_dim = scatter_info["indirect_dim"]
            output_shape = scatter_info["output_shape"]

            # Build index tuple with 0s for other dimensions
            index_parts = []
            for dim in range(len(output_shape)):
                if dim == indirect_dim:
                    index_parts.append(indirect_var)
                else:
                    index_parts.append("0")

            index_tuple = ", ".join(index_parts)
            return f"{out}[...] = {alias_param}[...].at[{index_tuple}].{scatter_op}({value})"

        # Scatter with iteration variables
        indirect_var = scatter_info["indirect_var"]
        dims_before = scatter_info["dims_before"]
        dims_after = scatter_info["dims_after"]

        # Determine if element-wise or slice-based scatter
        buf = V.graph.get_buffer(name)
        output_ndim = len(buf.get_size()) if buf is not None else 0

        num_iter_vars_in_store = len(dims_before) + len(dims_after)
        total_kernel_iter_vars = len(self.range_tree_nodes)
        remaining_dims = output_ndim - 1  # dims other than indirect

        is_element_wise = (
            num_iter_vars_in_store == remaining_dims
            and num_iter_vars_in_store == total_kernel_iter_vars
        )

        if is_element_wise:
            # Element-wise scatter: use iteration variable names
            index_parts = [var_name for var_name, size in dims_before]

            # Reshape indirect var for broadcasting if needed
            n_leading = len(dims_before)
            n_trailing = len(dims_after)
            if n_leading > 0 and n_trailing > 0:
                leading_ones = "None, " * n_leading
                trailing_nones = ", None" * n_trailing
                indirect_reshaped = f"{indirect_var}[{leading_ones}...{trailing_nones}]"
            else:
                indirect_reshaped = indirect_var
            index_parts.append(indirect_reshaped)

            index_parts.extend(var_name for var_name, size in dims_after)
        else:
            # Slice-based scatter: use : for iteration dimensions
            index_parts = [":" for _ in dims_before]
            index_parts.append(indirect_var)
            index_parts.extend(":" for _ in dims_after)

        index_tuple = ", ".join(index_parts)
        return (
            f"{out}[...] = {alias_param}[...].at[{index_tuple}].{scatter_op}({value})"
        )

    @typing_extensions.override
    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        buf = self.args.input(name)
        dtype = V.graph.get_dtype(name)

        # Track the load index expression for argmax/argmin axis detection
        self.load_index_exprs[name] = index

        # Get base index expression
        indexing = self._get_index_expr(index)

        # Check for buffer size mismatch requiring strided indexing
        indexing = self._needs_strided_indexing(name, index, indexing)

        # Try strided decomposition before multidim slice or flatten.
        # This generates reshape + static indexing which works on both
        # CPU and TPU (unlike slice notation which fails on Mosaic).
        decomp = self._decompose_strided_access(index, name)
        if decomp is not None:
            self.strided_input_buffers[name] = decomp
            load_expr = self._strided_load_expr(buf, decomp)
        else:
            # Adjust index for buffer shape (scalar, multi-dim, etc.)
            indexing = self._adjust_index_for_buffer_shape(name, index, indexing)

            # Try to emit multi-dim slice instead of flatten + gather
            indexing = self._try_multidim_slice(name, index, indexing)

            # Build the load expression
            load_expr = self._build_load_expr(buf, name, index, indexing)

        # Handle intermediate buffer squeezing for correct broadcasting
        if not indexing.needs_flatten and indexing.index_str == "...":
            load_expr = self._maybe_squeeze_intermediate_buffer(name, load_expr)
            # Handle 1D buffer broadcasting for higher-dimensional kernels
            load_expr = self._maybe_broadcast_1d_buffer(name, index, load_expr)

        cse_var = self.cse.generate(
            self.compute,
            load_expr,
            dtype=dtype,
        )
        # Track CSE var -> param -> graph name for indirect access detection
        buf_param = self.args.input(name)
        self._cse_to_param[str(cse_var)] = buf_param
        self._param_to_graph_name[buf_param] = name
        return cse_var

    def _handle_mixed_indexing(self, index: sympy.Expr) -> str:
        """
        Handle indexing with both indirect variables and iteration variables.

        For example, x[indices, :] generates index = i0 + stride * tmp0
        where tmp0 is loaded from indices and i0 is the iteration variable.

        We need to convert this to JAX advanced indexing with proper broadcasting.
        When there are multiple iteration variables, they need different shapes
        to form an outer product (grid) rather than broadcasting together.

        Special case: For gather operations where a single iteration variable
        and single indirect variable have the same extent, they should be
        element-wise aligned, not broadcast into an outer product.

        PyTorch advanced indexing semantics: When multiple indirect indices have
        the same shape, they are paired element-wise (not outer product), and
        the combined result dimension appears at the FRONT of the output.
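
        E.g. (illustrative) an index 16*tmp0 + r0 with a length-16
        reduction var r0 emits the index string "16*tmp0 + jnp.arange(16)".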
        """
        used_iter_vars_set = self._get_used_iter_vars(index)

        # Track which iteration variables are used
        self.used_iter_vars.update(used_iter_vars_set)

        if len(used_iter_vars_set) == 0:
            return self.kexpr(index)

        # Sort iteration variables by their coefficient (stride) in the index expression.
        # Variables with larger strides correspond to earlier output dimensions.
        # Use inf default so symbolic coefficients sort as outermost dimensions.
        def _coeff(var):
            return self._get_index_coefficient(index, var, default=float("inf"))

        used_iter_vars = sorted(used_iter_vars_set, key=_coeff, reverse=True)
        iter_coeffs = [_coeff(var) for var in used_iter_vars]

        # Rename symbolic sizes to kernel parameter names
        index_str = self.kexpr(self.rename_indexing(index))
        indirect_var_syms = self._get_indirect_vars(index)
        indirect_vars = [str(sym) for sym in indirect_var_syms]

        # Get coefficients for indirect vars to determine output ordering
        indirect_coeffs = {str(s): _coeff(s) for s in indirect_var_syms}

        # Special case: reduction var + single indirect var = element-wise gather
        # Reduction vars (r prefix) iterate over the reduction dimension, and when paired
        # with an indirect var, both are aligned to that dimension (element-wise).
        # Pointwise vars form output dimensions and need the complex reshape code.
        if len(used_iter_vars) == 1 and len(indirect_vars) == 1:
            var = used_iter_vars[0]
            var_name = str(var)
            is_reduction_var = (
                var in self.range_tree_nodes and self.range_tree_nodes[var].is_reduction
            )

            if is_reduction_var:
                # Reduction var: simple element-wise gather
                if var in self.range_tree_nodes:
                    range_entry = self.range_tree_nodes[var]
                    range_size = range_entry.length
                    # Rename to use kernel parameter names for symbolic sizes
                    renamed_size = self.rename_indexing(range_size)
                    arange_expr = f"jnp.arange({self.kexpr(renamed_size)})"
                    index_str = index_str.replace(var_name, arange_expr)
                return index_str
            # For pointwise vars, fall through to the complex reshape code

        # Check if multiple indirect vars should be paired element-wise.
        # In PyTorch, when multiple advanced indices have the same shape, they pair up.
        # The paired dimension goes to the FRONT of the output.
        # However, if indirect vars have different shapes (e.g., (1,4) and (4,1)),
        # they form an outer product instead.
        # We detect element-wise pairing when:
        # 1. Multiple indirect vars exist
        # 2. There's exactly ONE unused iteration variable (for the shared paired dim)
        # For outer product, there are MULTIPLE unused iter vars (one per indirect dim)
        paired_indirect = False
        if len(indirect_vars) > 1:
            # Count unused iteration variables (defined but not in index expression)
            unused_iter_vars = self._get_iter_vars() - used_iter_vars_set
            # Element-wise pairing: one unused iter var for the shared paired dimension
            # Outer product: multiple unused iter vars (one for each indirect var dimension)
            paired_indirect = len(unused_iter_vars) == 1

        if paired_indirect:
            # Multiple indirect vars with element-wise pairing
            # Output order: (paired_indirect_dim, iter_var_dims...)
            # All indirect vars get the same shape: (N, 1, 1, ...) for first dim
            # Iter vars come after: second dim onwards
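            # Illustrative example (hedged): index = 64*tmp0 + 8*tmp1 + x1 with
            # one unused iter var becomes
            #   tmp0 -> tmp0.reshape(-1, 1)
            #   tmp1 -> tmp1.reshape(-1, 1)
            #   x1   -> jnp.arange(8).reshape(1, 8)
            # so the paired indirect dim leads and x1 fills the second dim.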

            # Total output dims: 1 (the paired indirect dim) + one per iter var.
            # Some iter vars may correspond to newaxis (size-1) dimensions.
            n_output_dims = 1 + len(used_iter_vars)

            # Reshape indirect vars to occupy the first dimension
            for indirect_var in indirect_vars:
                trailing_ones = ", 1" * len(used_iter_vars)
                reshape_expr = f"{indirect_var}.reshape(-1{trailing_ones})"
                index_str = index_str.replace(indirect_var, reshape_expr)

            # Reshape iteration variables to occupy subsequent dimensions
            # Sort by coefficient (descending) to determine order
            for i, var in enumerate(used_iter_vars):
                var_name = str(var)
                if var in self.range_tree_nodes:
                    range_entry = self.range_tree_nodes[var]
                    range_size = range_entry.length
                    # Rename to use kernel parameter names for symbolic sizes
                    renamed_size = self.rename_indexing(range_size)

                    # Shape: (1, ..., N, ..., 1) where N is at position i+1
                    # Position 0 is for paired indirect vars
                    shape_parts = ["1"] * n_output_dims
                    shape_parts[i + 1] = self.kexpr(renamed_size)
                    shape_str = ", ".join(shape_parts)
                    arange_expr = (
                        f"jnp.arange({self.kexpr(renamed_size)}).reshape({shape_str})"
                    )

                    index_str = index_str.replace(var_name, arange_expr)

            return index_str

        # General case: zero or one indirect var, or multiple indirect vars
        # forming an outer product (the element-wise paired case returned above).
        # Build a list of components sorted by coefficient (descending); each
        # component is (coeff, type, var) where type is 'iter' or 'indirect'.
        all_components = []
        for var in used_iter_vars:
            all_components.append((_coeff(var), "iter", var))
        for sym in indirect_var_syms:
            all_components.append((_coeff(sym), "indirect", sym))
        all_components.sort(key=lambda x: x[0], reverse=True)

        # Calculate trailing dims needed for each component
        # Each component needs trailing dims for all subsequent iter vars
        # plus trailing dims for all dimensions of subsequent indirect vars
        # For simplicity, assume each indirect var contributes some dimensions
        # that will be handled by the reshape at store time

        # For iter vars, we need to count how many dimensions come after in the output
        for i, var in enumerate(used_iter_vars):
            var_name = str(var)
            if var in self.range_tree_nodes:
                range_entry = self.range_tree_nodes[var]
                range_size = range_entry.length
                # Rename to use kernel parameter names for symbolic sizes
                renamed_size = self.rename_indexing(range_size)
                var_coeff = _coeff(var)

                arange_expr = f"jnp.arange({self.kexpr(renamed_size)})"

                # Count trailing dims needed:
                # - One for each subsequent iter var (with smaller coeff)
                # - One for each dimension of indirect vars with smaller coeff
                # For indirect vars, assume each contributes 2 dims (common case)
                # The actual reshape at store time will fix any shape mismatches
                n_trailing_iter = sum(1 for c in iter_coeffs if c < var_coeff)
                n_trailing_indirect = sum(
                    2 for c in indirect_coeffs.values() if c < var_coeff
                )
                n_trailing = n_trailing_iter + n_trailing_indirect

                if n_trailing > 0:
                    trailing_dims = ", None" * n_trailing
                    arange_expr = f"{arange_expr}[:{trailing_dims}]"

                index_str = index_str.replace(var_name, arange_expr)

        # Reshape indirect variables for proper broadcasting.
        for indirect_var in indirect_vars:
            indirect_coeff = indirect_coeffs[indirect_var]

            # Count dims needed before and after this indirect var
            n_leading = sum(1 for c in iter_coeffs if c > indirect_coeff)
            n_trailing = sum(1 for c in iter_coeffs if c < indirect_coeff)

            # Build the indexing expression with leading Nones, ellipsis, trailing Nones
            if n_leading > 0 and n_trailing > 0:
                leading_nones = "None, " * n_leading
                trailing_nones = ", None" * n_trailing
                reshape_expr = f"{indirect_var}[{leading_nones}...{trailing_nones}]"
            elif n_leading > 0:
                leading_nones = "None, " * n_leading
                reshape_expr = f"{indirect_var}[{leading_nones}...]"
            elif n_trailing > 0:
                trailing_nones = ", None" * n_trailing
                reshape_expr = f"{indirect_var}[...{trailing_nones}]"
            else:
                reshape_expr = indirect_var

            index_str = index_str.replace(indirect_var, reshape_expr)

        return index_str

    @typing_extensions.override
    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: Any = None
    ) -> None:
        # mode can be None (set), "atomic_add" (accumulate), etc.
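        # E.g. (illustrative) a plain full-array store emits a line like
        #   out_ptr0[...] = tmp3
        # while "atomic_add" accumulates into the existing output values.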
        if mode is not None and mode != "atomic_add":
            raise Unsupported(f"pallas store mode '{mode}' not supported")
        out = self.args.output(name)
        self.store_buffer_names.add(name)

        # Check if this is a scalar output (reduction to scalar)
        buf = V.graph.get_buffer(name)
        is_scalar = buf is not None and len(buf.get_size()) == 0

        if is_scalar:
            store_lines = [
                f"_val = jnp.asarray({value})",
                f"{out}[...] = jnp.full({out}.shape, _val) if _val.ndim == 0 else _val.reshape({out}.shape)",
            ]
        else:
            # When collapsed_output_shape is set, the load-side permutation
            # already produces data in the correct layout for the collapsed
            # output.  Force a full-array store ("...") so the scatter index
            # (which was computed for the original output layout) does not
            # rearrange the permuted data.
            if self.collapsed_output_shape is not None:
                store_lines = self._build_full_array_store_expr(out, value, False)
            else:
                # Check for scatter pattern (indirect indexing for stores)
                scatter_info = self._detect_scatter_pattern(index, name)

                if scatter_info is not None:
                    # Track iteration variables used in scatter index
                    self.used_iter_vars.update(self._get_used_iter_vars(index))
                    store_lines = [
                        self._build_scatter_store_expr(
                            out, value, scatter_info, name, mode
                        )
                    ]
                else:
                    # Get base index expression
                    indexing = self._get_index_expr(index)

                    # Check for im2col-like patterns
                    indexing = self._check_im2col_pattern(index, indexing)

                    # Build the store expression
                    store_lines = self._build_store_expr(
                        out, name, index, value, indexing, mode
                    )

        for line in store_lines:
            self.stores.writeline(line)
            # Track which output param this store uses for filtering in codegen_kernel
            self.store_with_output.append((out, line))

    @staticmethod
    def _get_index_coefficient(
        index: sympy.Expr, var: sympy.Symbol, default: int | float = 0
    ) -> int | float:
        """Get integer coefficient of a variable in an index expression."""
        coeff = index.coeff(var)
        if coeff == 0:
            coeff = sympy.diff(index, var)
        try:
            return int(coeff)
        except (TypeError, ValueError):
            return default

    def _detect_scatter_pattern(
        self, index: sympy.Expr, output_name: str = ""
    ) -> dict[str, Any] | None:
        """Detect scatter operation pattern. Returns scatter info dict or None."""
        indirect_syms = self._get_indirect_vars(index)
        if len(indirect_syms) != 1:
            return None

        indirect_sym = indirect_syms[0]
        indirect_var = str(indirect_sym)
        indirect_coeff: int = int(self._get_index_coefficient(index, indirect_sym))
        if indirect_coeff == 0:
            return None

        # Point scatter: no iteration variables, just indirect indexing
        if not self._has_iteration_vars(index):
            return self._detect_point_scatter(output_name, indirect_var, indirect_coeff)

        # Regular scatter: has both indirect and iteration variables
        return self._detect_iter_scatter(index, indirect_var, indirect_coeff)

    def _detect_point_scatter(
        self, output_name: str, indirect_var: str, indirect_coeff: int
    ) -> dict[str, Any] | None:
        """Detect single-element scatter pattern."""
        if not output_name:
            return None
        try:
            buf = V.graph.get_buffer(output_name)
            output_shape = [int(s) for s in buf.get_size()]
        except Exception:
            return None

        if len(output_shape) < 2:
            return None

        # Find which dimension indirect var indexes based on coefficient
        cumulative = 1
        indirect_dim = len(output_shape) - 1
        for dim in range(len(output_shape) - 1, -1, -1):
            if indirect_coeff == cumulative:
                indirect_dim = dim
                break
            cumulative *= output_shape[dim]

        return {
            "indirect_var": indirect_var,
            "indirect_dim": indirect_dim,
            "dims_before": [],
            "dims_after": [],
            "is_point_scatter": True,
            "output_shape": output_shape,
        }

    def _detect_iter_scatter(
        self, index: sympy.Expr, indirect_var: str, indirect_coeff: int
    ) -> dict[str, Any] | None:
        """Detect scatter pattern with iteration variables."""
        used_iter_vars = self._get_used_iter_vars(index)

        # Collect (var_name, coefficient, length) for each variable
        all_vars: list[tuple[str, int, int]] = []
        for var in used_iter_vars:
            coeff = int(self._get_index_coefficient(index, var))
            if coeff > 0 and var in self.range_tree_nodes:
                length = self._safe_int(self.range_tree_nodes[var].length)
                if length is None:
                    return None
                all_vars.append((str(var), coeff, length))

        all_vars.append((indirect_var, indirect_coeff, -1))
        all_vars.sort(key=lambda x: x[1], reverse=True)

        # Find indirect variable position
        indirect_pos = next(
            (i for i, (name, _, _) in enumerate(all_vars) if name == indirect_var),
            None,
        )
        if indirect_pos is None:
            return None

        # Verify coefficients form valid stride pattern
        expected = 1
        for _, coeff, length in reversed(all_vars[indirect_pos + 1 :]):
            if coeff != expected:
                return None
            expected *= length
        if indirect_coeff != expected:
            return None

        return {
            "indirect_var": indirect_var,
            "indirect_dim": indirect_pos,
            "dims_before": [(n, l) for n, _, l in all_vars[:indirect_pos]],
            "dims_after": [(n, l) for n, _, l in all_vars[indirect_pos + 1 :]],
            "is_point_scatter": False,
            "output_shape": None,
        }

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: CSEVariable | tuple[CSEVariable, ...],
    ) -> CSEVariable | tuple[CSEVariable, ...]:  # type: ignore[override]
        """
        Generate code for reduction operations in JAX/Pallas.

        Reductions in Pallas work by:
        1. Loading the input data into the kernel
        2. Applying JAX reduction operations (jnp.sum, jnp.max, etc.)
        3. Storing the reduced result

        The reduction happens over the loaded block of data.
        """
        assert self.inside_reduction

        # Handle welford_reduce using the fallback (computes via sum reductions)
        if reduction_type == "welford_reduce":
            return self.welford_reduce_fallback(dtype, value)

        if isinstance(value, tuple):
            raise Unsupported(
                "Tuple reductions (e.g., welford_combine) not supported in Pallas backend"
            )

        # Check if this reduction is already cached.
        cache_key = (src_dtype, reduction_type, value)
        if cache_key in self.cse.reduction_cache:
            return self.cse.reduction_cache[cache_key]

        # Map reduction types to JAX functions
        reduction_ops = {
            "sum": "jnp.sum",
            "prod": "jnp.prod",  # CPU only - not supported in Pallas GPU (Mosaic) backend
            "max": "jnp.max",
            "min": "jnp.min",
            "any": "jnp.any",
            "argmax": "jnp.argmax",
            "argmin": "jnp.argmin",
        }

        # Determine if this is a partial reduction (has pointwise dimensions)
        # or a full reduction to scalar
        pointwise_prefixes = OrderedSet(["x", "y", "z"])
        has_pointwise = any(p in self.numels for p in pointwise_prefixes)
        pointwise_numel: int | None = self._compute_prefix_numel(pointwise_prefixes)
        reduction_numel: int | None = self._compute_reduction_numel()
        n_reduction_dims = sum(
            1 for entry in self.range_tree_nodes.values() if entry.is_reduction
        )

        is_partial_reduction = (
            has_pointwise
            and pointwise_numel is not None
            and pointwise_numel > 1
            and reduction_numel
            and n_reduction_dims > 0
        )
        is_symbolic_partial = (
            has_pointwise and n_reduction_dims > 0 and pointwise_numel is None
        )

        if reduction_type == "xor_sum":
            if is_partial_reduction:
                axes = self._get_reduction_axes()
                axis_expr = axes[0] if len(axes) == 1 else axes
                reduction_expr = f"jnp.bitwise_xor.reduce({value}, axis={axis_expr})"
            else:
                reduction_expr = f"jnp.bitwise_xor.reduce({value})"
        elif reduction_type in ("argmax", "argmin"):
            reduction_op = reduction_ops[reduction_type]
            if is_partial_reduction:
                # argmax/argmin only accept a single axis
                axes = self._get_reduction_axes()
                reduction_expr = f"{reduction_op}({value}, axis={axes[-1]})"
            else:
                reduction_expr = f"{reduction_op}({value})"
        elif reduction_type in reduction_ops:
            reduction_op = reduction_ops[reduction_type]
            if is_partial_reduction:
                axes = self._get_reduction_axes()
                axis_expr = axes[0] if len(axes) == 1 else axes
                reduction_expr = (
                    f"{reduction_op}({value}, axis={axis_expr}, keepdims=True)"
                )
            elif is_symbolic_partial:
                # With symbolic shapes, strided loads produce a degenerate
                # batch dim at axis=0 that just needs squeezing.
                reduction_expr = f"{reduction_op}({value}, axis=0)"
            else:
                reduction_expr = f"{reduction_op}({value})"
        else:
            raise Unsupported(
                f"Reduction type '{reduction_type}' not yet supported in Pallas backend. "
                f"Supported types: {list(reduction_ops.keys())}, xor_sum"
            )

        # Generate CSE variable for the reduction result
        result = self.cse.generate(
            self.compute,
            reduction_expr,
            dtype=dtype,
        )

        # Cache the result
        self.cse.reduction_cache[cache_key] = result
        return result

    @staticmethod
    def _buffer_is_contiguous(buffer_name: str) -> bool:
        buf = V.graph.get_buffer(buffer_name)
        layout = buf.get_layout()
        return layout.is_contiguous()

    def _can_tile_cpu_tpu(self) -> bool:
        """Check if this kernel can use tiling on CPU/TPU.

        Tiling is compatible with reductions, transpositions, and multi-range-tree
        kernels as long as no flatten-based indexing is used (buf[...].flatten()[idx]).
        Flatten indexing requires global flat indices which don't decompose into
        per-tile local indices.

        Reject:
        - GPU (has its own TMA / padding path)
        - Flatten-based indexing
        - Scatter outputs (indirect indexing complicates tile boundaries)
        """
        if self.is_gpu:
            return False
        if self.has_flatten_indexing:
            return False
        if self.outputs_need_read:
            return False

        # If iteration variables appear in the compute body (not just in
        # load/store index resolution that collapses to [...]), tiling is
        # unsafe because the arange-based vars have full-tensor shapes.
        # Exception: vars emitted in tile-relative form are safe.
        if self.used_iter_vars:
            compute_text = "\n".join(str(line) for line in self.compute._lines)
            for var_sym in self.used_iter_vars:
                if var_sym in self.tile_relative_iter_vars:
                    continue
                if str(var_sym) in compute_text:
                    return False

        # Determine the reference output shape (highest-ndim output).
        out_bufs = list(self.args.output_buffers.keys())

        # Only check the current kernel's actual output buffers for transpose,
        # not _has_column_major_output() which scans all graph buffers and can
        # be triggered by unrelated intermediates (e.g., (N,1) reductions with
        # degenerate column-major strides).
        has_col_major_out = False
        for buf_name in out_bufs:
            info = self._get_buffer_info(buf_name)
            if info is None:
                continue
            _, buf_size, _, actual_strides, _ = info
            if len(actual_strides) >= 2 and len(buf_size) >= 2:
                s0 = actual_strides[0]
                s1 = actual_strides[1]
                d0 = self._safe_int(buf_size[0])
                d1 = self._safe_int(buf_size[1])
                if (
                    s0 is not None
                    and s1 is not None
                    and s0 < s1
                    and d0 is not None
                    and d1 is not None
                    and d0 > 1
                    and d1 > 1
                ):
                    has_col_major_out = True
                    break
        self.tile_has_transpose = bool(self.permuted_input_buffers) or has_col_major_out

        # Count trailing reduction dimensions in the output shape that must
        # not be tiled (the kernel body needs the full reduction range).
        # Only count when the kernel actually performs reduction (numel > 1).
        reduction_numel = self._compute_reduction_numel()
        has_reduction = reduction_numel is not None and reduction_numel > 1
        self.tile_skip_last_n = (
            sum(1 for tree in self.range_trees if tree.is_reduction)
            if has_reduction
            else 0
        )

        ref_shape: list[int] = []
        for buf_name in out_bufs:
            info = self._get_buffer_info(buf_name)
            if info is None:
                return False
            _, buf_size, _, _, _ = info
            int_size = [self._safe_int(s) for s in buf_size]
            if any(s is None for s in int_size):
                return False
            if len(int_size) > len(ref_shape):
                ref_shape = int_size  # type: ignore[assignment]

        if not ref_shape:
            return False

        # For collapsed permutation kernels, override ref_shape with the
        # collapsed output shape so all compatibility checks operate in
        # collapsed-shape space.
        if self.collapsed_output_shape is not None:
            ref_shape = list(self.collapsed_output_shape)

        ref_nd = len(ref_shape)

        all_bufs = list(self.args.input_buffers) + out_bufs
        has_tileable = False
        for buf_name in all_bufs:
            info = self._get_buffer_info(buf_name)
            if info is None:
                return False
            _, buf_size, _, _, _ = info
            if len(buf_size) == 0:
                continue  # scalar

            # Use collapsed shapes when available so dimension checks
            # operate in the same space as the kernel.
            if buf_name in self.collapsed_reshape_inputs:
                int_size = list(self.collapsed_reshape_inputs[buf_name])
            elif self.collapsed_output_shape is not None and buf_name in out_bufs:
                int_size = list(self.collapsed_output_shape)
            else:
                int_size = [self._safe_int(s) for s in buf_size]
                if any(s is None for s in int_size):
                    return False
            buf_nd = len(int_size)

            if buf_nd == ref_nd:
                # Same ndim: check dimensions match or are broadcast (1).
                # Allow strided buffers (dims may differ after reshape).
                is_strided = buf_name in self.strided_input_buffers
                mismatch = False
                for i in range(ref_nd):
                    if (
                        int_size[i] == ref_shape[i]
                        or int_size[i] == 1
                        or ref_shape[i] == 1
                        or is_strided
                    ):
                        continue
                    mismatch = True
                    break

                if mismatch and buf_name in self.permuted_input_buffers:
                    perm = self.permuted_input_buffers[buf_name]
                    if not (
                        len(perm) == ref_nd
                        and all(
                            int_size[perm[i]] == ref_shape[i]
                            or int_size[perm[i]] == 1
                            or ref_shape[i] == 1
                            for i in range(ref_nd)
                        )
                    ):
                        return False
                elif mismatch:
                    return False

                # At least one buffer with a tileable dim
                if is_strided or any(
                    int_size[i] == ref_shape[i] and ref_shape[i] > 1
                    for i in range(ref_nd)
                ):
                    has_tileable = True

            elif buf_nd > ref_nd:
                # Reduction input with extra dims. Find an alignment offset k
                # such that buf_shape[k+i] == ref_shape[i] for all i (skipping
                # broadcast dims where ref_shape[i] == 1).
                found = False
                for k in range(buf_nd - ref_nd + 1):
                    ok = True
                    for i in range(ref_nd):
                        if ref_shape[i] == 1:
                            continue
                        if int_size[k + i] != ref_shape[i]:
                            ok = False
                            break
                    if ok:
                        found = True
                        break
                if not found:
                    return False
                has_tileable = True

            else:
                # Fewer dims: verify numpy-style broadcastability
                for a, b in zip(reversed(int_size), reversed(ref_shape)):
                    if a != b and a != 1 and b != 1:
                        return False

        if not has_tileable:
            return False

        # On CPU (interpret mode) each tile iteration has significant
        # Python/JAX overhead, so cap the grid size.  Store the cap
        # so _codegen_tiled_specs can pass it to pallas_compute_tiling,
        # which will scale up tiles to stay within the limit.
        is_tpu = V.graph.get_current_device_or_throw().type == "tpu"
        self._cpu_max_grid_product = None if is_tpu else 64

        return True

    def codegen_kernel(self, name: str | None = None) -> str:  # type: ignore[override]
        """
        Generate the complete Pallas kernel code as a Python string.

        This includes:
        - Import statements for JAX/Pallas
        - The kernel function that operates on refs
        - The main wrapper function that handles PyTorch<->JAX conversions via DLPack

        Args:
            name: Optional kernel name (will use placeholder if not provided)

        Returns:
            str: Complete Python source code for the Pallas kernel
        """
        code = IndentedBuffer()

        # Define the Pallas kernel: accepts refs, uses broadcasted expressions
        arg_defs, _, _, _ = self.args.python_argdefs()
        kernel_params = [a.name for a in arg_defs]
        pure_out_params = [p for p in kernel_params if p.startswith("out_ptr")]
        output_params = [
            p for p in kernel_params if p.startswith(("out_ptr", "in_out_ptr"))
        ]
        # Identify size variable parameters (scalars like load_seed_offset)
        size_var_names = OrderedSet(self.args.sizevars.values())
        size_var_params = [p for p in kernel_params if p in size_var_names]
        if not output_params:
            raise RuntimeError("Pallas backend requires at least one output buffer")

        output_buffer_lookup = {
            inner: outer
            for outer, inner in self.args.output_buffers.items()
            if isinstance(inner, str)
        }

        kernel_name = name or "<KERNEL_NAME>"
        interpret_is_cpu = V.graph.get_current_device_or_throw().type == "cpu"
        interpret_literal = "True" if interpret_is_cpu else "False"

        aliasable_flags: dict[str, bool] = {}
        for param in pure_out_params:
            aliasable_flags[param] = True
        alias_params = [
            f"{param}_alias" for param in pure_out_params if aliasable_flags[param]
        ]
        pointer_tail = [
            p for p in kernel_params if p.startswith(("in_out_ptr", "in_ptr"))
        ]
        kernel_input_params = alias_params + pointer_tail
        full_kernel_params = alias_params + kernel_params
        non_alias_out_set = OrderedSet(
            [name for name, flag in aliasable_flags.items() if not flag]
        )
        # On CPU (interpret=True), pallas_call returns new arrays so we must
        # copy back every output.  On TPU, call_custom_kernel with
        # input_output_aliases handles donation (zero-copy), so no copy is
        # needed.  On CUDA, aliased outputs are mutated in-place by the
        # donated-buffer mechanism so only non-aliased outputs need a copy.
        if interpret_is_cpu:
            copy_output_indices = list(range(len(output_params)))
        elif self.is_tpu:
            copy_output_indices = []
        else:
            copy_output_indices = [
                idx
                for idx, name in enumerate(output_params)
                if name in non_alias_out_set
            ]

        ctx = _CodegenContext(
            code=code,
            kernel_name=kernel_name,
            is_tpu=self.is_tpu,
            interpret_is_cpu=interpret_is_cpu,
            interpret_literal=interpret_literal,
            kernel_params=kernel_params,
            pure_out_params=pure_out_params,
            output_params=output_params,
            size_var_params=size_var_params,
            output_buffer_lookup=output_buffer_lookup,
            aliasable_flags=aliasable_flags,
            alias_params=alias_params,
            pointer_tail=pointer_tail,
            kernel_input_params=kernel_input_params,
            full_kernel_params=full_kernel_params,
            non_alias_out_set=non_alias_out_set,
            copy_output_indices=copy_output_indices,
            alias_pairs=[],
        )
        self.aliasable_out_ptrs = aliasable_flags

        self._codegen_imports(ctx)

        kernel_body = IndentedBuffer()
        with kernel_body.indent():
            self._codegen_iteration_vars(kernel_body, ctx)

            for line in self.compute._lines:
                kernel_body.writeline(str(line))

        # Recompute kernel parameters after kernel body generation.
        # Size variables may have been registered during kernel body generation
        # (e.g., via rename_indexing for symbolic sizes), so we need to re-fetch
        # the arg defs to capture all parameters including newly-registered size vars.
        arg_defs, _, _, _ = self.args.python_argdefs()
        kernel_params = [a.name for a in arg_defs]
        size_var_names = OrderedSet(self.args.sizevars.values())
        ctx.size_var_params = [p for p in kernel_params if p in size_var_names]
        ctx.pointer_tail = [
            p for p in kernel_params if p.startswith(("in_out_ptr", "in_ptr"))
        ]
        ctx.kernel_input_params = alias_params + ctx.pointer_tail
        ctx.full_kernel_params = alias_params + kernel_params

        # Decide whether to use tiling for CPU/TPU after kernel body is fully
        # generated (used_iter_vars is populated during load/store codegen).
        self.tile_cpu_tpu = self._can_tile_cpu_tpu()

        extra_kernel_params = ""
        if self.tile_relative_iter_vars:
            extra_kernel_params = ", _pallas_tile=None, _pallas_ax2g=None"

        ctx.alias_pairs = self._compute_alias_pairs(ctx, aliasable_flags)

        use_scalar_prefetch = bool(self.indirect_access)

        if use_scalar_prefetch:
            self._eliminate_dead_indirect_code()
            kernel_body_sp = IndentedBuffer()
            with kernel_body_sp.indent():
                for line in self.compute._lines:
                    kernel_body_sp.writeline(str(line))
            self._codegen_scalar_prefetch_wrapper(
                ctx,
                kernel_name,
                kernel_body_sp,
            )
            return code.getvalue()

        # Emit the kernel function with the correct signature
        kernel_signature = f"def {kernel_name}_kernel({', '.join(ctx.full_kernel_params)}{extra_kernel_params}):"
        code.writeline(kernel_signature)

        with code.indent():
            self._emit_kernel_body(code, kernel_body, ctx)

        code.writeline("")
        jit_wrapper_name = f"{kernel_name}_jit_wrapper"
        donate_indices = []
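        # Wrapper args are (out_shapes, out_dtypes, *size_vars, *inputs), so
        # donated input positions are offset past those leading static args.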
        base_offset = 2 + len(ctx.size_var_params)
        for idx, param in enumerate(ctx.kernel_input_params):
            if (param in alias_params) or param.startswith("in_out_ptr"):
                donate_indices.append(idx + base_offset)
        if donate_indices:
            donate_literal = "(" + ", ".join(str(x) for x in donate_indices) + ",)"
        else:
            donate_literal = "()"
        static_argnums = list(range(2 + len(ctx.size_var_params)))
        static_argnums_literal = "(" + ", ".join(str(x) for x in static_argnums) + ",)"
        code.writeline(
            "@functools.partial("
            f"jax.jit, static_argnums={static_argnums_literal}, donate_argnums="
            f"{donate_literal})"
        )
        wrapper_params = (
            ["out_shapes", "out_dtypes"] + ctx.size_var_params + ctx.kernel_input_params
        )
        code.writeline(f"def {jit_wrapper_name}({', '.join(wrapper_params)}):")

        alias_map_literal = ", ".join(f"{i}: {o}" for (i, o) in ctx.alias_pairs)

        has_zero_dim, has_unknown_dim = self._zero_dim_output_flags(ctx)

        zero_dim_return = (
            "results = tuple(jnp.empty(s, dtype=dt) "
            "for s, dt in zip(out_shapes, out_dtypes))",
            "return results if len(results) > 1 else results[0]",
        )

        with code.indent():
            if has_zero_dim:
                code.writelines(zero_dim_return)
            else:
                if has_unknown_dim:
                    code.writeline("if any(0 in shape for shape in out_shapes):")
                    with code.indent():
                        code.writelines(zero_dim_return)
                # Pallas requires >= 1-d tensors; promote 0-d to (1,)
                code.writeline(
                    "_pallas_out_shapes = tuple("
                    "s if len(s) > 0 else (1,) for s in out_shapes)"
                )
                if self.collapsed_output_shape is not None:
                    code.writeline(
                        f"_pallas_out_shapes = ({self.collapsed_output_shape},)"
                    )
                # Reshape aliased inputs to match promoted output shapes
                for input_idx, out_idx in ctx.alias_pairs:
                    param = ctx.kernel_input_params[input_idx]
                    code.writeline(
                        f"{param} = {param}.reshape(_pallas_out_shapes[{out_idx}])"
                    )
                code.writeline("out_shapes_pallas = tuple(")
                code.writeline("    jax.ShapeDtypeStruct(shape, dtype)")
                code.writeline(
                    "    for shape, dtype in zip(_pallas_out_shapes, out_dtypes)"
                )
                code.writeline(")")
                if self.tile_cpu_tpu:
                    self._codegen_tiled_specs(ctx)
                else:
                    self._codegen_strided_reshapes(code, ctx.kernel_input_params)
                    for param in ctx.kernel_input_params:
                        buf_name = self._param_to_buf_name(param)
                        cshape = (
                            self.collapsed_reshape_inputs.get(buf_name)
                            if buf_name
                            else None
                        )
                        if cshape is not None:
                            code.writeline(f"{param} = {param}.reshape({cshape})")

                    code.writeline("out_specs_pallas = tuple(")
                    code.writeline("    pallas_make_block_spec_non_tiled(shape)")
                    code.writeline(
                        "    for shape, dtype in zip(_pallas_out_shapes, out_dtypes)"
                    )
                    code.writeline(")")
                    code.writeline("in_specs_pallas = tuple(")
                    code.writeline("    pallas_make_block_spec_non_tiled(i.shape)")
                    code.writeline(
                        "    for i in [" + ", ".join(ctx.kernel_input_params) + "]"
                    )
                    code.writeline(")")

                if self.tile_relative_iter_vars:
                    if self.tile_cpu_tpu:
                        code.writeline("_pallas_tile = _tile")
                        code.writeline("_pallas_ax2g = _ax2g")
                    else:
                        code.writeline("_pallas_tile = _pallas_out_shapes[0]")
                        code.writeline("_pallas_ax2g = {}")

                # Wrap kernel with functools.partial to pass scalar arguments (size variables)
                partial_args = []
                for sv_param in ctx.size_var_params:
                    partial_args.append(f"{sv_param}={sv_param}")

                if self.tile_relative_iter_vars:
                    partial_args.append("_pallas_tile=_pallas_tile")
                    partial_args.append("_pallas_ax2g=_pallas_ax2g")

                if partial_args:
                    kernel_arg = f"functools.partial({kernel_name}_kernel, {', '.join(partial_args)}),"
                else:
                    kernel_arg = f"{kernel_name}_kernel,"

                use_tma = (
                    self.is_gpu
                    and self.use_emit_pipeline
                    and self._can_use_tma_approach()
                )
                if use_tma:
                    self._codegen_jit_wrapper_tma(ctx, kernel_arg)
                elif self.is_gpu:
                    self._codegen_jit_wrapper_legacy_gpu(ctx, kernel_arg)
                else:
                    self._codegen_jit_wrapper_cpu_tpu(
                        ctx, kernel_arg, ctx.alias_pairs, alias_map_literal
                    )

        self._codegen_main_entry(ctx, jit_wrapper_name)
        return code.getvalue()

    @staticmethod
    def _compute_alias_pairs(
        ctx: _CodegenContext, aliasable_flags: dict[str, bool]
    ) -> list[tuple[int, int]]:
        alias_pairs: list[tuple[int, int]] = []
        for out_idx, name in enumerate(ctx.output_params):
            if name.startswith("out_ptr"):
                if aliasable_flags.get(name, False):
                    alias_name = f"{name}_alias"
                    input_idx = ctx.kernel_input_params.index(alias_name)
                    alias_pairs.append((input_idx, out_idx))
            else:
                input_idx = ctx.kernel_input_params.index(name)
                alias_pairs.append((input_idx, out_idx))
        return alias_pairs

    def _emit_kernel_body(
        self,
        code: IndentedBuffer,
        kernel_body: IndentedBuffer,
        ctx: _CodegenContext,
    ) -> None:
        """Emit the kernel body lines and store operations into code."""
        for line in kernel_body._lines:
            if isinstance(line, str):
                code.writeline(line.lstrip())
            else:
                code._lines.append(line)
        for out_ptr, store_line in self.store_with_output:
            if out_ptr in ctx.full_kernel_params:
                code.writeline(store_line)

    def _codegen_scalar_prefetch_wrapper(
        self,
        ctx: _CodegenContext,
        kernel_name: str,
        kernel_body: IndentedBuffer,
    ) -> None:
        """Emit kernel, JIT wrapper, and main entry for scalar prefetch."""
        assert self.indirect_access is not None
        indirect = self.indirect_access
        code = ctx.code

        alias_set = OrderedSet(ctx.alias_params)
        other_input_params = [
            p
            for p in ctx.kernel_input_params
            if p != indirect.indices_param
            and p != indirect.table_param
            and p not in alias_set
        ]

        # Emit kernel function with params reordered for PrefetchScalarGridSpec:
        # [scalar_prefetch] + [in_specs refs] + [out_specs refs]
        prefetch_kernel_params = (
            [indirect.indices_param]
            + [indirect.table_param]
            + other_input_params
            + list(ctx.alias_params)
            + ctx.output_params
        )
        code.writeline(
            f"def {kernel_name}_kernel({', '.join(prefetch_kernel_params)}):"
        )
        with code.indent():
            self._emit_kernel_body(code, kernel_body, ctx)

        # Emit JIT wrapper
        code.writeline("")
        jit_wrapper_name = f"{kernel_name}_jit_wrapper"
        wrapper_params = (
            ["out_shapes", "out_dtypes"] + ctx.size_var_params + ctx.kernel_input_params
        )
        static_argnums = list(range(2 + len(ctx.size_var_params)))
        static_argnums_literal = "(" + ", ".join(str(x) for x in static_argnums) + ",)"
        code.writeline(
            f"@functools.partial(jax.jit, static_argnums={static_argnums_literal})"
        )
        code.writeline(f"def {jit_wrapper_name}({', '.join(wrapper_params)}):")

        with code.indent():
            table = indirect.table_param
            indices = indirect.indices_param

            ind_dim = indirect.indirect_dim
            ndim = len(indirect.table_shape)
            code.writeline("_D = 1")
            for i in range(ndim):
                if i != ind_dim:
                    code.writeline(f"_D = _D * {table}.shape[{i}]")
            code.writeline(f"_seq = {indices}.shape[0]")

            if ind_dim == 0:
                code.writeline(f"_table_3d = {table}.reshape({table}.shape[0], 1, _D)")
            else:
                perm = (ind_dim, *[d for d in range(ndim) if d != ind_dim])
                code.writeline(
                    f"_table_3d = {table}.transpose{perm}.reshape("
                    f"{table}.shape[{ind_dim}], 1, _D)"
                )

            # Reshape other (non-table, non-indices) inputs to 3D to match the
            # table's (seq, 1, D) layout.  Currently handles:
            #   - 2D with leading dim == seq: row-aligned, reshape to (seq, 1, D)
            #   - 1D: broadcast scalar/vector, reshape to (1, 1, numel)
            #   - else: flatten to (1, 1, -1) — assumes broadcastable with
            #     (seq, 1, D).  This may not work correctly for 3D+ inputs.
            pallas_call_other_args = []
            for p in other_input_params:
                p3d = f"_{p}_3d"
                code.writeline(f"if {p}.ndim == 2 and {p}.shape[0] == _seq:")
                code.writeline(f"    {p3d} = {p}.reshape(_seq, 1, _D)")
                code.writeline(f"elif {p}.ndim == 1:")
                code.writeline(f"    {p3d} = {p}.reshape(1, 1, {p}.shape[0])")
                code.writeline("else:")
                code.writeline(f"    {p3d} = {p}.reshape(1, 1, -1)")
                pallas_call_other_args.append(p3d)

            pallas_call_alias_args = []
            for p in ctx.alias_params:
                p3d = f"_{p}_3d"
                code.writeline(f"{p3d} = {p}.reshape(_seq, 1, _D)")
                pallas_call_alias_args.append(p3d)

            partial_args = [f"{sv}={sv}" for sv in ctx.size_var_params]
            if partial_args:
                kernel_ref = (
                    f"functools.partial({kernel_name}_kernel,"
                    f" {', '.join(partial_args)})"
                )
            else:
                kernel_ref = f"{kernel_name}_kernel"

            # Reusable row-tiled BlockSpec (all i32 index_map for Mosaic compat)
            code.writeline(
                "_ROW_SPEC = pl.BlockSpec((1, 1, _D),"
                " lambda i, _: (i, jnp.int32(0), jnp.int32(0)))"
            )

            num_non_alias_in_specs = 1 + len(pallas_call_other_args)
            code.writeline("_in_specs = [")
            with code.indent():
                code.writeline(
                    "pl.BlockSpec((1, 1, _D),"
                    " lambda gi, idx: (idx[gi], jnp.int32(0), jnp.int32(0))),"
                )
                for p3d in pallas_call_other_args:
                    code.writeline(
                        f"_ROW_SPEC"
                        f" if {p3d}.shape[0] == _seq else"
                        f" pl.BlockSpec({p3d}.shape,"
                        f" lambda i, _: (jnp.int32(0), jnp.int32(0), jnp.int32(0))),"
                    )
                for _ in ctx.alias_params:
                    code.writeline("_ROW_SPEC,")
            code.writeline("]")

            num_outputs = len(ctx.output_params)
            code.writeline(
                "_out_specs = [" + ", ".join(["_ROW_SPEC"] * num_outputs) + "]"
            )

            # input_output_aliases: pallas_call arg index -> output index
            # (offset by 1 for scalar prefetch arg)
            alias_map_parts = []
            for out_idx, _ in enumerate(ctx.alias_params):
                arg_idx = 1 + num_non_alias_in_specs + out_idx
                alias_map_parts.append(f"{arg_idx}: {out_idx}")
            alias_map_literal = ", ".join(alias_map_parts)

            out_shape_parts = [
                f"jax.ShapeDtypeStruct((_seq, 1, _D), out_dtypes[{i}])"
                for i in range(num_outputs)
            ]
            out_shape_expr = "[" + ", ".join(out_shape_parts) + "]"

            code.writeline("_result = pl.pallas_call(")
            with code.indent():
                code.writeline(f"{kernel_ref},")
                code.writeline(f"out_shape={out_shape_expr},")
                code.writeline("grid_spec=pltpu.PrefetchScalarGridSpec(")
                with code.indent():
                    code.writeline("num_scalar_prefetch=1,")
                    code.writeline("grid=(_seq,),")
                    code.writeline("in_specs=_in_specs,")
                    code.writeline("out_specs=_out_specs,")
                code.writeline("),")
                if alias_map_parts:
                    code.writeline(f"input_output_aliases={{ {alias_map_literal} }},")
                if not self.is_tpu:
                    code.writeline(f"interpret={ctx.interpret_literal},")

            all_pallas_args = (
                [indices]
                + ["_table_3d"]
                + pallas_call_other_args
                + pallas_call_alias_args
            )
            code.writeline(f")({', '.join(all_pallas_args)})")

            code.writeline(
                "return tuple(r.reshape(s) for r, s in zip(_result, out_shapes))"
            )

        self._codegen_main_entry(ctx, jit_wrapper_name)

    def _codegen_imports(self, ctx: _CodegenContext) -> None:
        imports = """
import functools
import math
import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils._ordered_set import OrderedSet
from torch._inductor.runtime.runtime_utils import (
    pallas_compute_tiling, pallas_make_block_spec, pallas_permute,
    pallas_gpu_align_output_specs, pallas_gpu_pad_inputs,
    pallas_gpu_unpad_results,
    pallas_ensure_nonzero_rank,
    pallas_make_block_spec_non_tiled,
    torch_dtype_to_jax_runtime,
)
"""
        if ctx.is_tpu:
            imports += "\nimport jax.export"
            imports += "\nfrom jax.experimental.pallas import tpu as pltpu"
            imports += "\nfrom torch_tpu._internal.pallas import tpu_torch_pallas"
        elif not ctx.interpret_is_cpu:
            imports += "\nfrom jax.experimental.pallas import mosaic_gpu as plgpu"
        if self.indirect_access and not ctx.is_tpu:
            imports += (
                "\nimport os as _os; _os.environ.setdefault('JAX_PLATFORMS', 'cpu')"
            )
            imports += "\nfrom jax.experimental.pallas import tpu as pltpu"
        ctx.code.splice(imports, strip=True)

    def _get_iter_var_axis(self, var_sym: sympy.Symbol) -> int | None:
        """Map an iteration variable to its output tensor axis index.

        Non-reduction variables map to axes 0, 1, 2, ... in order.
        Reduction variables map to axes after all pointwise axes.
        Returns None if the mapping cannot be determined.
        """
        pw_idx = 0
        r_idx = 0
        n_pw = sum(1 for e in self.range_tree_nodes.values() if not e.is_reduction)
        for sym, entry in self.range_tree_nodes.items():
            if sym == var_sym:
                return pw_idx if not entry.is_reduction else n_pw + r_idx
            if entry.is_reduction:
                r_idx += 1
            else:
                pw_idx += 1
        return None

    def _get_reshape_target_shape_and_numel(
        self,
    ) -> tuple[tuple[int, ...] | None, int | None]:
        # Find reshape target: N-D shape whose numel matches an iteration
        # var. Try output first (repeat/upsample), then inputs (reductions).
        iter_lengths = OrderedSet(
            [
                int(e.length)
                for e in self.range_tree_nodes.values()
                if isinstance(e.length, (int, sympy.Integer))
            ]
        )

        def _get_nd_shape_if_matches(buf_name):
            buf = V.graph.try_get_buffer(buf_name)
            if buf is None or len(buf.get_size()) <= 1:
                return None, None
            shape = tuple(
                int(s) if isinstance(s, (int, sympy.Integer)) else s
                for s in buf.get_size()
            )
            numel = math.prod(shape)
            return (shape, numel) if numel in iter_lengths else (None, None)

        candidate_buf_names = self._output_buffer_names.copy()
        candidate_buf_names.extend(self.args.input_buffers)

        reshape_target_shape, reshape_target_numel = None, None
        for buf_name in candidate_buf_names:
            result = _get_nd_shape_if_matches(buf_name)
            if result[0]:
                reshape_target_shape, reshape_target_numel = result
                break

        return reshape_target_shape, reshape_target_numel

    def _make_broadcasted_iteration_var_expr(
        self, broadcast_vars: list[_BroadcastedIterVar], broadcast_idx: int
    ) -> str:
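        # Example (hedged): with three broadcast vars, a var of length 7
        # placed at axis 1 yields "jnp.arange(7).reshape(1, 7, 1)".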
        bv = broadcast_vars[broadcast_idx]
        length = bv.entry.length
        renamed_length = self.rename_indexing(length)
        length_str = self.kexpr(renamed_length)

        num_broadcast_dims = len(broadcast_vars)
        axis_idx = self._broadcast_axis_idx(
            broadcast_vars, broadcast_idx, num_broadcast_dims
        )
        shape_parts = ["1"] * num_broadcast_dims
        shape_parts[axis_idx] = length_str
        shape_str = ", ".join(shape_parts)
        arange = f"jnp.arange({length_str})"
        reshaped = f"{arange}.reshape({shape_str})"
        return reshaped

    def _codegen_iteration_vars(
        self, kernel_body: IndentedBuffer, ctx: _CodegenContext
    ) -> None:
        # Generate iteration variables as jnp.arange arrays
        # Skip on GPU - jnp.arange is not supported by Pallas Mosaic backend
        if not (self.range_tree_nodes and not self.is_gpu and self.used_iter_vars):
            return
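        # The emitted forms are, illustratively:
        #   plain:         x0 = jnp.arange(16)
        #   reshaped:      x0 = jnp.arange(64).reshape(8, 8)
        #   tile-relative: x0 = jnp.arange(_pallas_tile[0])
        #                       + pl.program_id(_pallas_ax2g.get(0, 0)) * _pallas_tile[0]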

        kernel_body.writeline("# Define iteration variables as JAX arrays")

        reshape_target_shape, reshape_target_numel = (
            self._get_reshape_target_shape_and_numel()
        )

        var_items = list(self.range_tree_nodes.items())

        broadcast_vars = []
        total_var_idx = None
        for idx, (var_sym, entry) in enumerate(var_items):
            length_val = self._safe_int(entry.length)
            if length_val is not None and length_val == reshape_target_numel:
                total_var_idx = idx
            else:
                broadcast_vars.append(
                    _BroadcastedIterVar(idx, var_sym, entry, length_val)
                )

        num_broadcast_dims = len(broadcast_vars)

        for idx, (var_sym, entry) in enumerate(var_items):
            if var_sym not in self.used_iter_vars:
                continue
            var_name = str(var_sym)
            length = entry.length
            renamed_length = self.rename_indexing(length)
            length_str = self.kexpr(renamed_length)
            length_val = self._safe_int(length)

            if length_val is None:
                if (
                    reshape_target_shape
                    and num_broadcast_dims > 1
                    and idx != total_var_idx
                ):
                    broadcast_idx = next(
                        (i for i, v in enumerate(broadcast_vars) if v.idx == idx),
                        None,
                    )
                    if broadcast_idx is not None:
                        expr = self._make_broadcasted_iteration_var_expr(
                            broadcast_vars, broadcast_idx
                        )
                        kernel_body.writeline(f"{var_name} = {expr}")
                        continue
                kernel_body.writeline(f"{var_name} = jnp.arange({length_str})")
                continue

            if (
                reshape_target_shape
                and len(reshape_target_shape) > 1
                and length_val == reshape_target_numel
            ):
                shape_str = ", ".join(str(s) for s in reshape_target_shape)
                arange = f"jnp.arange({length_str})"
                kernel_body.writeline(f"{var_name} = {arange}.reshape({shape_str})")
            elif num_broadcast_dims > 1 and idx != total_var_idx:
                broadcast_idx = next(
                    i for i, v in enumerate(broadcast_vars) if v.idx == idx
                )
                expr = self._make_broadcasted_iteration_var_expr(
                    broadcast_vars, broadcast_idx
                )
                kernel_body.writeline(f"{var_name} = {expr}")
            else:
                # Simple 1D arange — emit tile-relative form so tiling is safe.
                # When grid=(1,), _pallas_tile[ax] == full length and
                # pl.program_id(0) == 0, so this degenerates to jnp.arange(N).
                # Only do this when the var actually appears in compute body
                # (otherwise tiling is not blocked and the full arange is fine).
                # Skip for scatter/index kernels where the iter var is used
                # as a global index, not a data value.
                compute_text = "\n".join(str(line) for line in self.compute._lines)
                var_in_compute = var_name in compute_text
                can_tile_relative = (
                    var_in_compute
                    and not self.is_gpu
                    and not self.outputs_need_read
                    and not self.has_flatten_indexing
                    and not entry.is_reduction
                )
                axis_idx = (
                    self._get_iter_var_axis(var_sym) if can_tile_relative else None
                )
                if axis_idx is not None:
                    kernel_body.writeline(
                        f"{var_name} = jnp.arange(_pallas_tile[{axis_idx}])"
                        f" + pl.program_id(_pallas_ax2g.get({axis_idx}, 0))"
                        f" * _pallas_tile[{axis_idx}]"
                    )
                    self.tile_relative_iter_vars.add(var_sym)
                else:
                    kernel_body.writeline(f"{var_name} = jnp.arange({length_str})")

    @staticmethod
    def _broadcast_axis_idx(
        broadcast_vars: list[_BroadcastedIterVar],
        broadcast_idx: int,
        num_broadcast_dims: int,
    ) -> int:
        # Axis placement depends on var types (reduction r* vs x*):
        # - Mixed: pointwise first, reduction last for output reshape
        # - Same-type: reverse order, first var innermost
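        # E.g. two pointwise vars: idx 0 -> axis 1 (innermost), idx 1 -> axis 0;
        # one pointwise + one reduction var: each keeps its own index.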
        has_reduction_vars = any(
            str(bv.var_sym).startswith("r") for bv in broadcast_vars
        )
        has_pointwise_vars = any(
            not str(bv.var_sym).startswith("r") for bv in broadcast_vars
        )
        is_mixed = has_reduction_vars and has_pointwise_vars
        if is_mixed:
            return broadcast_idx
        return num_broadcast_dims - 1 - broadcast_idx

    def _codegen_jit_wrapper_tma(self, ctx: _CodegenContext, kernel_arg: str) -> None:
        code = ctx.code
        kernel_input_params = ctx.kernel_input_params
        output_params = ctx.output_params

        # TMA automatically handles out-of-bounds accesses
        code.writeline("# Use lax.fori_loop with TMA for automatic OOB masking")
        code.writeline("from jax import lax")
        code.writeline("_tile_size = 128  # Warpgroup size")
        code.writeline("_orig_out_shapes = out_shapes")

        code.writeline("_max_numel = 0")
        for param in kernel_input_params:
            code.writeline(f"_max_numel = max(_max_numel, {param}.size)")
        code.writeline("for shape in out_shapes:")
        code.writeline("    _max_numel = max(_max_numel, math.prod(shape))")

        code.writeline("_num_tiles = (_max_numel + _tile_size - 1) // _tile_size")

        gmem_input_params = [f"{p}_gmem" for p in kernel_input_params]
        gmem_output_params = [f"{p}_gmem" for p in output_params]
        smem_input_params = [f"{p}_smem" for p in kernel_input_params]
        smem_output_params = [f"{p}_smem" for p in output_params]

        code.writeline("")
        code.writeline("# Wrapper kernel using lax.fori_loop with direct TMA")

        wrapper_kernel_params = gmem_input_params + gmem_output_params
        all_smem_params = smem_input_params + smem_output_params
        barrier_params = [f"_barrier_{i}" for i in range(len(kernel_input_params))]
        scratch_params = ", ".join(all_smem_params + barrier_params)

        code.writeline(
            f"def _tma_kernel({', '.join(wrapper_kernel_params)}, *, {scratch_params}):"
        )
        with code.indent():
            code.writeline("")
            code.writeline("def _tile_body(_tile_idx, _):")
            with code.indent():
                code.writeline("_tile_start = _tile_idx * _tile_size")
                code.writeline("")

                code.writeline("# TMA load inputs from GMEM to SMEM (OOB auto-masked)")
                for i, (gmem_in, smem_in) in enumerate(
                    zip(gmem_input_params, smem_input_params)
                ):
                    code.writeline(
                        f"plgpu.copy_gmem_to_smem({gmem_in}.at[pl.ds(_tile_start, _tile_size)], {smem_in}, _barrier_{i})"
                    )

                code.writeline("")
                code.writeline("# Wait for TMA loads to complete")
                for i, _ in enumerate(gmem_input_params):
                    code.writeline(f"plgpu.barrier_wait(_barrier_{i})")

                code.writeline("")
                code.writeline("# Compute on SMEM tiles")
                kernel_call_args = smem_input_params + smem_output_params
                kernel_fn = kernel_arg.strip().rstrip(",")
                code.writeline(f"{kernel_fn}({', '.join(kernel_call_args)})")

                code.writeline("")
                code.writeline(
                    "# TMA store outputs from SMEM to GMEM (OOB auto-masked)"
                )
                code.writeline("plgpu.commit_smem()")
                for gmem_out, smem_out in zip(gmem_output_params, smem_output_params):
                    code.writeline(
                        f"plgpu.copy_smem_to_gmem({smem_out}, {gmem_out}.at[pl.ds(_tile_start, _tile_size)])"
                    )
                code.writeline("plgpu.wait_smem_to_gmem(0)")
                code.writeline("")
                code.writeline("return None")

            code.writeline("")
            code.writeline("# Iterate over all tiles")
            code.writeline("lax.fori_loop(0, _num_tiles, _tile_body, None)")

        # Build scratch_shapes dict
        code.writeline("")
        code.writeline(
            "# Build SMEM scratch shapes for inputs, outputs, and TMA barriers"
        )
        code.writeline("_scratch_shapes = {}")
        for i, smem_param in enumerate(smem_input_params):
            orig_param = kernel_input_params[i]
            code.writeline(
                f"_scratch_shapes['{smem_param}'] = plgpu.SMEM((_tile_size,), {orig_param}.dtype)"
            )
        for i, smem_param in enumerate(smem_output_params):
            code.writeline(
                f"_scratch_shapes['{smem_param}'] = plgpu.SMEM((_tile_size,), out_dtypes[{i}])"
            )
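        # One single-arrival barrier per input: each barrier guards exactly
        # one copy_gmem_to_smem per tile iteration.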
        for barrier_param in barrier_params:
            code.writeline(
                f"_scratch_shapes['{barrier_param}'] = plgpu.Barrier(num_arrivals=1)"
            )

        code.writeline("")
        code.writeline("# Create flattened output specs aligned to tile size")
        code.writeline(
            "_flat_out_specs, _ = pallas_gpu_align_output_specs(out_shapes, out_dtypes, _tile_size)"
        )

        code.writeline("")
        code.writeline("# Call plgpu.kernel with TMA kernel")
        code.writeline("_result = plgpu.kernel(")
        with code.indent():
            code.writeline("_tma_kernel,")
            code.writeline("out_shape=_flat_out_specs,")
            code.writeline("scratch_shapes=_scratch_shapes,")
        code.writeline(")(")
        for param in kernel_input_params:
            code.writeline(f"    {param}.flatten(),")
        code.writeline(")")

        code.writeline("")
        code.writeline("# Reshape results to original shapes")
        code.writeline("return pallas_gpu_unpad_results(_result, _orig_out_shapes)")

    def _codegen_jit_wrapper_legacy_gpu(
        self, ctx: _CodegenContext, kernel_arg: str
    ) -> None:
        code = ctx.code
        kernel_input_params = ctx.kernel_input_params
        input_list = f"[{', '.join(kernel_input_params)}]"

        # Legacy GPU path with explicit padding (use_emit_pipeline=False)
        # Mosaic GPU requires tensor sizes to be multiples of 128.
        # Only apply padding when all tensors have the same size (no broadcasting).
        code.writeline("# Check if all tensors have same size (no broadcasting)")
        code.writeline("_all_sizes = []")
        for param in kernel_input_params:
            code.writeline(f"_all_sizes.append({param}.size)")
        code.writeline("for shape in out_shapes:")
        code.writeline("    _all_sizes.append(math.prod(shape))")
        code.writeline("_unique_sizes = OrderedSet(_all_sizes)")
        code.writeline(
            "_can_pad = len(_unique_sizes) == 1 and all(s > 1 for s in _unique_sizes)"
        )
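        # The generated code below branches three ways: uniform sizes ->
        # flatten and pad; scalar (reduction) output -> pad inputs only;
        # otherwise broadcast inputs to the output shape, then pad.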

        code.writeline("")
        code.writeline("if _can_pad:")
        code.writeline("    # All tensors same size - safe to flatten and pad")
        code.writeline(f"    _padded_inputs = pallas_gpu_pad_inputs({input_list})")
        code.writeline(
            "    _aligned_out_specs, _is_scalar = pallas_gpu_align_output_specs(out_shapes, out_dtypes)"
        )
        code.writeline("    _result = plgpu.kernel(")
        code.writeline("        " + kernel_arg)
        code.writeline("        out_shape=_aligned_out_specs,")
        code.writeline("    )(*_padded_inputs)")
        code.writeline(
            "    return pallas_gpu_unpad_results(_result, out_shapes, _is_scalar)"
        )

        code.writeline("else:")
        code.writeline(
            "    # Different sizes - check if it's a reduction (scalar output)"
        )
        code.writeline("    _out_numel = math.prod(out_shapes[0])")
        code.writeline("    ")
        code.writeline("    if _out_numel <= 1:")
        code.writeline(
            "        # Scalar output (reduction) - pad inputs but keep scalar output"
        )
        code.writeline(f"        _padded_inputs = pallas_gpu_pad_inputs({input_list})")
        code.writeline("        _aligned_out_specs = tuple(")
        code.writeline("            jax.ShapeDtypeStruct(shape, dtype)")
        code.writeline("            for shape, dtype in zip(out_shapes, out_dtypes)")
        code.writeline("        )")
        code.writeline("        _result = plgpu.kernel(")
        code.writeline("            " + kernel_arg)
        code.writeline("            out_shape=_aligned_out_specs,")
        code.writeline("        )(*_padded_inputs)")
        code.writeline("        return _result")
        code.writeline("    else:")
        code.writeline(
            "        # Non-scalar output with broadcasting - broadcast inputs to output shape"
        )
        code.writeline("        _target_shape = out_shapes[0]")
        code.writeline("        _broadcasted = [")
        code.writeline(
            f"            jnp.broadcast_to(_inp, _target_shape) for _inp in {input_list}"
        )
        code.writeline("        ]")
        code.writeline("        _padded_inputs = pallas_gpu_pad_inputs(_broadcasted)")
        code.writeline(
            "        _aligned_out_specs, _is_scalar = pallas_gpu_align_output_specs(out_shapes, out_dtypes)"
        )
        code.writeline("        _result = plgpu.kernel(")
        code.writeline("            " + kernel_arg)
        code.writeline("            out_shape=_aligned_out_specs,")
        code.writeline("        )(*_padded_inputs)")
        code.writeline(
            "        return pallas_gpu_unpad_results(_result, out_shapes, _is_scalar)"
        )

    def _param_to_buf_name(self, param: str) -> str | None:
        """Map a kernel parameter name back to its graph buffer name."""
        for graph_name, inner_name in self.args.input_buffers.items():
            if inner_name == param:
                return graph_name
        return None

    def _codegen_tiled_specs(self, ctx: _CodegenContext) -> None:
        """Generate tiled BlockSpec and grid variables for CPU/TPU.

        Tiles the last 1–2 dimensions of each tensor, respecting TPU
        alignment constraints (last dim multiple of 128, second-to-last
        multiple of 8).  Lower-ndim inputs are right-aligned with the
        reference output shape per numpy broadcast rules.
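
        Illustrative example (the actual tiles are chosen by
        pallas_compute_tiling): a (512, 1024) output could be tiled as
        (8, 128) blocks over a (64, 8) grid, satisfying both constraints.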
        """
        code = ctx.code
        skip_n = self.tile_skip_last_n
        has_transpose = "True" if self.tile_has_transpose else "False"
        is_tpu_literal = "True" if ctx.is_tpu else "False"

        # Collect per-input permutations for tiling alignment.
        all_perms: list[tuple[int, ...] | None] = []
        for p in ctx.kernel_input_params:
            buf_name = self._param_to_buf_name(p)
            all_perms.append(
                self.permuted_input_buffers.get(buf_name) if buf_name else None
            )

        mgp = self._cpu_max_grid_product
        mgp_arg = f", max_grid_product={mgp}" if mgp else ""
        perms_arg = (
            f", permutations={repr(all_perms)}"
            if any(p is not None for p in all_perms)
            else ""
        )
        code.writeline(
            f"_tile, _grid, _ax2g = pallas_compute_tiling("
            f"_pallas_out_shapes[0], "
            f"transpose={has_transpose}, "
            f"skip_last_n={skip_n}, exact_only=len(_pallas_out_shapes[0]) < 2, "
            f"is_tpu={is_tpu_literal}"
            f"{perms_arg}{mgp_arg})"
        )
        code.writeline("_ng = len(_grid)")
        code.writeline("_ref = _pallas_out_shapes[0]")

        code.writeline("out_specs_pallas = tuple(")
        code.writeline(
            "    pallas_make_block_spec(s, _ref, _tile, _ax2g, _ng, is_output=True)"
        )
        code.writeline("    for s in _pallas_out_shapes")
        code.writeline(")")

        self._codegen_strided_reshapes(code, ctx.kernel_input_params)

        # Reshape collapsed inputs before building specs.
        for param in ctx.kernel_input_params:
            buf_name = self._param_to_buf_name(param)
            cshape = self.collapsed_reshape_inputs.get(buf_name) if buf_name else None
            if cshape is not None:
                code.writeline(f"{param} = {param}.reshape({cshape})")

        # Build input BlockSpecs (with per-input permutation when needed).
        input_list = ", ".join(ctx.kernel_input_params)
        if any(p is not None for p in all_perms):
            perm_list = ", ".join(repr(p) for p in all_perms)
            code.writeline(f"_perm_flags = [{perm_list}]")
            code.writeline("in_specs_pallas = tuple(")
            code.writeline(
                f"    pallas_make_block_spec(i.shape, _ref, _tile, _ax2g, _ng, permutation=p)"
                f" for i, p in zip([{input_list}], _perm_flags)"
            )
            code.writeline(")")
        else:
            code.writeline("in_specs_pallas = tuple(")
            code.writeline(
                f"    pallas_make_block_spec(i.shape, _ref, _tile, _ax2g, _ng) for i in [{input_list}]"
            )
            code.writeline(")")

    def _codegen_jit_wrapper_cpu_tpu(
        self,
        ctx: _CodegenContext,
        kernel_arg: str,
        alias_pairs: list[tuple[int, int]],
        alias_map_literal: str,
    ) -> None:
        code = ctx.code
        grid_expr = "_grid" if self.tile_cpu_tpu else "(1,)"
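        # With tiling disabled the grid degenerates to (1,), so the kernel
        # body sees whole (untiled) arrays.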
        code.writeline("_result = pl.pallas_call(")
        code.writeline("    " + kernel_arg)
        code.writeline("    out_shape=out_shapes_pallas,")
        code.writeline("    out_specs=out_specs_pallas,")
        code.writeline("    in_specs=in_specs_pallas,")
        code.writeline(f"    interpret={ctx.interpret_literal},")
        code.writeline(f"    grid={grid_expr},")
        code.writeline(
            f"    input_output_aliases={{ {alias_map_literal} }},"
            if alias_pairs
            else "    input_output_aliases={},"
        )
        code.writeline(")(")
        if ctx.kernel_input_params:
            kernel_input_params_nonzero_rank = [
                f"pallas_ensure_nonzero_rank({p})" for p in ctx.kernel_input_params
            ]
            code.writeline(f"    {', '.join(kernel_input_params_nonzero_rank)},")
        code.writeline(")")
        # Reshape results back to original shapes (restores 0-d from promoted (1,))
        code.writeline("if isinstance(_result, tuple):")
        code.writeline(
            "    _result = tuple(r.reshape(s) for r, s in zip(_result, out_shapes))"
        )
        code.writeline("else:")
        code.writeline("    _result = _result.reshape(out_shapes[0])")
        code.writeline("return _result")

    def _codegen_main_entry(self, ctx: _CodegenContext, jit_wrapper_name: str) -> None:
        if ctx.is_tpu:
            self._codegen_main_entry_tpu(ctx, jit_wrapper_name)
        else:
            self._codegen_main_entry_default(ctx, jit_wrapper_name)

    def _codegen_main_entry_tpu(
        self, ctx: _CodegenContext, jit_wrapper_name: str
    ) -> None:
        code = ctx.code
        code.writeline("")
        main_name = f"{ctx.kernel_name}_main"
        kernel_name_str = ctx.kernel_name
        code.writeline(
            f"def {main_name}({', '.join(ctx.full_kernel_params)}, stream=None):"
        )
        with code.indent():
            # `jax_enable_x64` is per-process. The CPU path sets it to True,
            # so running both CPU and TPU tests in one process can cause
            # x64-related TPU crashes if we do not explicitly set it to
            # False here.
            code.writeline("jax.config.update('jax_enable_x64', False)")
            code.writeline("jax.clear_caches()")

            # Convert int64 inputs to int32 (TPU doesn't support int64)
            all_input_params = list(ctx.alias_params) + list(ctx.pointer_tail)
            for param_name in all_input_params:
                code.writeline(
                    f"{param_name} = {param_name}.to(torch.int32) "
                    f"if {param_name}.dtype == torch.int64 else {param_name}"
                )

            # Build JAX placeholders for all inputs
            code.writeline("# Build JAX placeholders for export tracing")
            all_jax_input_names = []
            for alias_name in ctx.alias_params:
                code.writeline(
                    f"{alias_name}_placeholder = jax.ShapeDtypeStruct("
                    f"{alias_name}.shape, torch_dtype_to_jax_runtime({alias_name}.dtype))"
                )
                all_jax_input_names.append(f"{alias_name}_placeholder")
            for ptr in ctx.pointer_tail:
                code.writeline(
                    f"{ptr}_placeholder = jax.ShapeDtypeStruct("
                    f"{ptr}.shape, torch_dtype_to_jax_runtime({ptr}.dtype))"
                )
                all_jax_input_names.append(f"{ptr}_placeholder")

            # Prepare output metadata
            code.writeline(
                "out_shapes = ("
                + ", ".join([f"tuple({name}.shape)" for name in ctx.output_params])
                + ",)"
            )
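            # Prefer compile-time dtypes recorded in the graph; fall back to
            # a runtime dtype lookup when the buffer dtype is unknown.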
            dtype_exprs: list[str] = []
            for name in ctx.output_params:
                buf_name = ctx.output_buffer_lookup.get(name)
                if buf_name is not None:
                    dtype = V.graph.get_dtype(buf_name)
                    if dtype is not None:
                        dtype_exprs.append(torch_dtype_to_jax(dtype))
                        continue
                dtype_exprs.append(f"torch_dtype_to_jax_runtime({name}.dtype)")
            code.writeline("out_dtypes = (" + ", ".join(dtype_exprs) + ",)")

            # Export the jit_wrapper
            wrapper_placeholder_args = ["out_shapes", "out_dtypes"]
            wrapper_placeholder_args.extend(ctx.size_var_params)
            wrapper_placeholder_args.extend(all_jax_input_names)
            code.writeline(
                f"exported = jax.export.export("
                f"{jit_wrapper_name}, platforms=['tpu'])"
                f"({', '.join(wrapper_placeholder_args)})"
            )

            # Register and call via tpu_torch_pallas
            # Include all output and input shapes in the key to avoid stale
            # cache hits when the same kernel name is compiled with different
            # input/output ranks (e.g. broadcasting vs non-broadcasting calls).
            shape_key_parts = []
            for p in ctx.output_params:
                shape_key_parts.append(f"'_'.join(str(s) for s in {p}.shape)")
            output_key_expr = (
                " + 'x' + ".join(shape_key_parts) if shape_key_parts else "''"
            )
            input_key_parts = []
            for p in ctx.kernel_input_params:
                input_key_parts.append(f"'_'.join(str(s) for s in {p}.shape)")
            input_key_expr = (
                " + 'x' + ".join(input_key_parts) if input_key_parts else "''"
            )
            code.writeline(
                f"kernel_key = '{kernel_name_str}_out_' + "
                f"{output_key_expr}"
                f" + '_in_' + {input_key_expr}"
            )

            code.writeline(
                f"if not tpu_torch_pallas.lookup_custom_kernel('{kernel_name_str}', kernel_key):"
            )
            with code.indent():
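                # Try the keyword signature first; some tpu_torch_pallas
                # builds apparently accept only the positional form, hence
                # the TypeError fallback.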
                code.writeline("try:")
                with code.indent():
                    code.writeline(
                        f"tpu_torch_pallas.register_custom_kernel("
                        f"'{kernel_name_str}', kernel_key, "
                        f"serialized_mlir_module=exported.mlir_module_serialized)"
                    )
                code.writeline("except TypeError:")
                with code.indent():
                    code.writeline(
                        f"tpu_torch_pallas.register_custom_kernel("
                        f"'{kernel_name_str}', kernel_key, "
                        f"exported.mlir_module_serialized)"
                    )

            # Build input tensor list (all non-size-var inputs)
            input_tensor_names = list(ctx.alias_params) + list(ctx.pointer_tail)
            code.writeline(f"input_tensors = [{', '.join(input_tensor_names)}]")

            # Build output placeholder tensors carrying each output's shape/dtype
            code.writeline("output_shape_tensors = [")
            with code.indent():
                for name in ctx.output_params:
                    buf_name = ctx.output_buffer_lookup.get(name)
                    if buf_name is not None:
                        dtype = V.graph.get_dtype(buf_name)
                        if dtype is not None:
                            code.writeline(
                                f"torch.empty({name}.shape, dtype={dtype!r}, device='tpu'),"
                            )
                            continue
                    code.writeline(
                        f"torch.empty({name}.shape, dtype={name}.dtype, device='tpu'),"
                    )
            code.writeline("]")

            # Build input_output_aliases for zero-copy donation
            if ctx.alias_pairs:
                alias_map_str = ", ".join(f"{i}: {o}" for (i, o) in ctx.alias_pairs)
                code.writeline(f"_input_output_aliases = {{ {alias_map_str} }}")
            else:
                code.writeline("_input_output_aliases = {}")

            code.writeline("try:")
            with code.indent():
                code.writeline(
                    f"tpu_torch_pallas.call_custom_kernel("
                    f"'{kernel_name_str}', kernel_key, "
                    f"inputs=input_tensors, "
                    f"output_shapes=output_shape_tensors, "
                    f"input_output_aliases=_input_output_aliases)"
                )
            code.writeline("except TypeError:")
            with code.indent():
                code.writeline(
                    f"tpu_torch_pallas.call_custom_kernel("
                    f"input_tensors, output_shape_tensors, "
                    f"'{kernel_name_str}', kernel_key, "
                    f"_input_output_aliases)"
                )

    def _codegen_main_entry_default(
        self, ctx: _CodegenContext, jit_wrapper_name: str
    ) -> None:
        code = ctx.code
        code.writeline("")
        main_name = f"{ctx.kernel_name}_main"
        code.writeline(
            f"def {main_name}({', '.join(ctx.full_kernel_params)}, stream=None):"
        )
        with code.indent():
            code.writeline("jax.config.update('jax_enable_x64', True)")
            if ctx.interpret_is_cpu:
                code.writeline(
                    "jax.config.update('jax_default_device', jax.devices('cpu')[0])"
                )
            code.writeline("jax.clear_caches()")
            if ctx.alias_params:
                code.writeline("# Convert Torch -> JAX for donated outputs")
                for alias_name in ctx.alias_params:
                    # On CPU/TPU, alias outputs may be non-contiguous (e.g.
                    # torch.cat slices) and JAX's from_dlpack rejects
                    # non-trivially strided tensors.  Making them contiguous
                    # is safe because CPU/TPU already copies all results back
                    # via copy_output_indices.  On CUDA, the donated-buffer
                    # mechanism requires the original buffer for in-place
                    # mutation, so we cannot make a contiguous copy.
                    self._emit_torch_to_jax(
                        code,
                        alias_name,
                        ctx.is_tpu,
                        contiguous=ctx.interpret_is_cpu,
                    )
            code.writeline("# Convert Torch -> JAX for in-place tensors")
            for ptr in ctx.pointer_tail:
                if ptr.startswith("in_out_ptr"):
                    self._emit_torch_to_jax(code, ptr, ctx.is_tpu, contiguous=False)
            code.writeline("# Convert Torch -> JAX for inputs")
            for ptr in ctx.pointer_tail:
                if ptr.startswith("in_ptr"):
                    self._emit_torch_to_jax(code, ptr, ctx.is_tpu, contiguous=True)

            code.writeline("# Prepare output metadata from PyTorch tensor")
            code.writeline(
                "out_shapes = ("
                + ", ".join([f"tuple({name}.shape)" for name in ctx.output_params])
                + ",)"
            )
            dtype_exprs: list[str] = []
            for name in ctx.output_params:
                buf_name = ctx.output_buffer_lookup.get(name)
                if buf_name is not None:
                    dtype = V.graph.get_dtype(buf_name)
                    if dtype is not None:
                        dtype_exprs.append(torch_dtype_to_jax(dtype))
                        continue
                dtype_exprs.append(f"torch_dtype_to_jax_runtime({name}.dtype)")
            code.writeline("out_dtypes = (" + ", ".join(dtype_exprs) + ",)")
            arg_name_map: dict[str, str] = {}
            for alias_name in ctx.alias_params:
                arg_name_map[alias_name] = f"{alias_name}_jax"
            for ptr in ctx.pointer_tail:
                arg_name_map[ptr] = f"{ptr}_jax"

            wrapper_call_args = ["out_shapes", "out_dtypes"]
            wrapper_call_args.extend(ctx.size_var_params)
            wrapper_call_args.extend(
                arg_name_map[name] for name in ctx.kernel_input_params
            )
            code.writeline(f"res = {jit_wrapper_name}({', '.join(wrapper_call_args)})")
            code.writeline("jax.block_until_ready(res)")
            if ctx.copy_output_indices:
                code.writeline(
                    "result_values = res if isinstance(res, tuple) else (res,)"
                )
                for idx in ctx.copy_output_indices:
                    out_name = ctx.output_params[idx]
                    code.writeline(
                        f"{out_name}.copy_(torch.from_dlpack(result_values[{idx}]))"
                    )

    @staticmethod
    def _emit_torch_to_jax(
        code: IndentedBuffer, var_name: str, is_tpu: bool, *, contiguous: bool
    ) -> None:
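        """Emit a Torch -> JAX DLPack conversion line for ``var_name``.

        ``contiguous=True`` inserts ``.contiguous()`` because from_dlpack
        rejects non-trivially strided tensors; ``is_tpu`` is currently
        unused by this helper.
        """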
        suffix = ".detach().contiguous()" if contiguous else ".detach()"
        code.writeline(f"{var_name}_jax = jax.dlpack.from_dlpack({var_name}{suffix})")

    def call_kernel(self, name: str, node: IRNode | None = None) -> None:  # type: ignore[override]
        """Generate the Python code that calls this Pallas kernel."""
        wrapper = V.graph.wrapper_code
        arg_defs, call_args, _, _ = self.args.python_argdefs()
        kernel_param_names = [a.name for a in arg_defs]
        pure_out_params = [p for p in kernel_param_names if p.startswith("out_ptr")]
        call_arg_strs = list(map(str, call_args))
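        # Aliasable out_ptr args are prepended so they line up with the
        # alias parameters at the front of the generated main function's
        # signature, where they are donated/aliased to outputs.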
        aliasable = getattr(self, "aliasable_out_ptrs", {})
        alias_call_args = [
            call_arg_strs[kernel_param_names.index(p)]
            for p in pure_out_params
            if aliasable.get(p, False)
        ]

        # Generate kernel call: kernel_name.run(arg1, arg2, ...)
        # Note: async_compile.pallas loads the {name}_main function and wraps
        # it in a PallasKernelWrapper, which exposes a run() method.
        kernel_call = f"{name}.run({', '.join(alias_call_args + call_arg_strs)})"
        wrapper.writeline(kernel_call)


class PallasScheduling(SIMDScheduling):
    kernel_type = PallasKernel  # type: ignore[assignment]

    @classmethod
    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
        # Pallas/JAX can handle reductions to single elements efficiently
        # without requiring split reductions
        return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT])

    def can_fuse(self, node1, node2):  # type: ignore[override]
        if not super().can_fuse(node1, node2):
            return False
        # Pallas partial reductions use keepdims, so fusing two reductions
        # that read the same buffer with different index patterns produces
        # intermediates with incompatible shapes (e.g. (1,8) + (8,1) = (8,8)
        # instead of (8,)).  Prevent this by rejecting fusion when the read
        # indices differ.
        if node1.is_reduction() and node2.is_reduction():
            from torch._inductor.dependencies import MemoryDep

            reads1 = {}
            for dep in node1.read_writes.reads:
                if isinstance(dep, MemoryDep):
                    reads1[dep.name] = dep.index
            for dep in node2.read_writes.reads:
                if isinstance(dep, MemoryDep) and dep.name in reads1:
                    if reads1[dep.name] != dep.index:
                        return False
        return True

    can_fuse_vertical = can_fuse  # type: ignore[assignment]
    can_fuse_horizontal = can_fuse  # type: ignore[assignment]

    def define_kernel(
        self,
        src_code: str,
        node_schedule: Sequence[BaseSchedulerNode],
        kernel: PallasKernel,
    ) -> str:  # type: ignore[override]
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            return wrapper.src_to_kernel[src_code]

        fused_name = (
            get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
            if config.triton.descriptive_names
            else ""
        )
        kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8]
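        # A short content hash keeps kernel names unique even when the
        # descriptive fused name collides across distinct source code.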
        if fused_name == "fused":
            kernel_name = f"pallas_{kernel_hash}"
        else:
            kernel_name = f"pallas_{fused_name}_{kernel_hash}"
        wrapper.src_to_kernel[src_code] = kernel_name

        # Replace placeholder if any
        src_code = src_code.replace("<KERNEL_NAME>", kernel_name)

        compile_wrapper = IndentedBuffer()
        compile_wrapper.writeline(f"async_compile.pallas({kernel_name!r}, r'''")
        compile_wrapper.splice(src_code, strip=True)
        compile_wrapper.writeline("''')")

        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
        metadata_comment = f"{origins}\n{detailed_origins}"
        wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment)

        return kernel_name
