diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index 05bf1fe4940e..405f04bfa8f5 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -4,27 +4,33 @@ import triton.language as tl import triton -@pytest.mark.parametrize('cond, opt_flag, env_var', [ - (cond, opt_flag, env_var) for cond in [True, False] \ - for opt_flag in [True, False] \ - for env_var in [True, False]\ -]) + +@pytest.mark.parametrize('cond', [True, False]) +@pytest.mark.parametrize('opt_flag', [True, False, None]) +@pytest.mark.parametrize('env_var', [True, False]) +@pytest.mark.parametrize('jit_flag', [True, False]) @pytest.mark.forked -def test_device_assert(cond, opt_flag, env_var, device): +def test_device_assert(cond, opt_flag, env_var, jit_flag, device): os.environ['TRITON_DEBUG'] = str(int(env_var)) torch.zeros([1], dtype=torch.int32, device=device) - @triton.jit + @triton.jit(debug=jit_flag) def _kernel(COND: tl.constexpr): tl.device_assert(COND, 'test') - if not cond and (opt_flag or env_var): + is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag) + + kwargs = {} + if opt_flag is not None: + kwargs["debug"] = opt_flag + + if not cond and is_debug: with pytest.raises(RuntimeError): - _kernel[(1, )](cond, debug=opt_flag) + _kernel[(1, )](cond, **kwargs) getattr(torch, device).synchronize() return - _kernel[(1, )](cond, debug=opt_flag) + _kernel[(1, )](cond, **kwargs) getattr(torch, device).synchronize() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3a7544c3c4ce..b04434fcf5fa 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -29,7 +29,6 @@ def builtin(fn: T) -> T: @wraps(fn) def wrapper(*args, **kwargs): if "_builder" not in kwargs or kwargs["_builder"] is None: - print(kwargs) raise ValueError("Did you forget to add @triton.jit ? " "(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 45178a40bb29..bc883d4fc84b 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -561,7 +561,7 @@ def create_binder(self, backend): ] def run(self, *args, grid, warmup, **kwargs): - kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1" # parse options from ..compiler import make_backend @@ -698,6 +698,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel = None + self.debug = debug self.noinline = noinline # TODO(jlebar): Remove uses of these fields outside this file, then