# mypy: allow-untyped-defs
import contextlib

import torch


# Common testing utilities for use in public testing APIs.
# NB: these should all be importable without optional dependencies
# (like numpy and expecttest).


def wrapper_set_seed(op, *args, **kwargs):
    """Run ``op(*args, **kwargs)`` with the global seed fixed to 42.

    The call happens inside ``freeze_rng_state``. The CPU RNG state (and the
    current accelerator's RNG state, when one is available) is restored
    afterwards, so randomized ops such as dropout give reproducible results
    without disturbing the caller's RNG stream.

    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        result = op(*args, **kwargs)

        produced_lazy_tensor = (
            isinstance(result, torch.Tensor) and result.device.type == "lazy"
        )
        if produced_lazy_tensor:
            # mark_step must run while still inside freeze_rng_state so the
            # lazy tensor's numerics match eager execution.
            torch._lazy.mark_step()  # type: ignore[attr-defined]

    return result


@contextlib.contextmanager
def freeze_rng_state():
    # no_dispatch needed for test_composite_compliance
    # Some OpInfos use freeze_rng_state for rng determinism, but
    # test_composite_compliance overrides dispatch for all torch functions
    # which we need to disable to get and set rng state
    with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
        rng_state = torch.get_rng_state()
        if torch.accelerator.is_available():
            accelerator = torch.accelerator.current_accelerator(check_available=True)
            if accelerator is not None:
                accelerator_rng_state = torch.get_device_module(
                    accelerator.type
                ).get_rng_state()
    try:
        yield
    finally:
        # Modes are not happy with torch.cuda.set_rng_state
        # because it clones the state (which could produce a Tensor Subclass)
        # and then grabs the new tensor's data pointer in generator.set_state.
        #
        # In the long run torch.cuda.set_rng_state should probably be
        # an operator.
        #
        # NB: Mode disable is to avoid running cross-ref tests on this seeding
        with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
            if torch.accelerator.is_available():
                accelerator = torch.accelerator.current_accelerator(
                    check_available=True
                )
                if accelerator is not None:
                    torch.get_device_module(accelerator.type).set_rng_state(
                        accelerator_rng_state  # type: ignore[possibly-undefined]
                    )
            torch.set_rng_state(rng_state)
