ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?) #699

Open
wdykas opened this issue Mar 4, 2025 · 1 comment

Comments

@wdykas

wdykas commented Mar 4, 2025

I am attempting to use CUDA graphs with cudaMallocAsync and Mamba. The code works fine with CUDA graphs and the regular (native) allocator, but with the async allocator I get the error below. I have seen other setups work with this combination, yet the SSM kernels hit this error even though the pointer arguments are on the GPU. Has the team seen errors like this before?
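To rule out a configuration mix-up, a quick way to confirm which allocator backend is actually active (a minimal check using torch.cuda.get_allocator_backend, separate from the repro below):

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch

# Should print "cudaMallocAsync" when the async backend is active, "native" otherwise.
torch.cuda.init()
print(torch.cuda.get_allocator_backend())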

Traceback (most recent call last):
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 148, in <module>
    main()
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 118, in main
    y = selective_state_update(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/selective_state_update.py", line 181, in selective_state_update
    _selective_scan_update_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?)
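To double-check that everything really is on the GPU, a sketch of a diagnostic that dumps the device and raw data_ptr of each tensor passed to selective_state_update (hypothetical helper, assuming the static_* buffers defined in the repro below):

def dump_args(**tensors):
    # Print device / pointer info so an actual CPU tensor would be obvious.
    for name, t in tensors.items():
        print(f"{name}: device={t.device}, is_cuda={t.is_cuda}, data_ptr=0x{t.data_ptr():x}")

dump_args(ssm_state=static_ssm_state, x=static_x_reshaped, dt=static_dt_rep,
          A=static_A_rep, B=static_B, C=static_C, D=static_D_rep,
          z=static_z, dt_bias=static_dt_bias_rep)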

#!/usr/bin/env python
import os
# Use cudaMallocAsync for performance.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch
import time
import math
from einops import rearrange, repeat
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    raise ImportError("selective_state_update kernel is required for this test.")

def main():
    device = torch.device("cuda")
    # Dummy dimensions (example values)
    batch = 2
    nheads = 4         # number of heads
    headdim = 16       # head dimension (p)
    d_state = 32       # state dimension
    ngroups = 2        # groups for B and C

    # Create dummy tensors on GPU
    ssm_state = torch.randn(batch, nheads, headdim, d_state, device=device, dtype=torch.float32)
    x = torch.randn(batch, nheads * headdim, device=device, dtype=torch.float32)
    # Reshape x to shape [batch, nheads, headdim]
    x_reshaped = rearrange(x, "b (h p) -> b h p", p=headdim)
    dt = torch.randn(batch, nheads, device=device, dtype=torch.float32)
    dt_bias = torch.randn(nheads, device=device, dtype=torch.float32)
    A = torch.randn(nheads, device=device, dtype=torch.float32)
    B = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    C = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    D = torch.randn(nheads, device=device, dtype=torch.float32)
    z = torch.randn(batch, nheads, headdim, device=device, dtype=torch.float32)

    # Mimic the repeats and rearrangements used in the main code:
    A_rep = repeat(A, "h -> h p n", p=headdim, n=d_state)
    dt_rep = repeat(dt, "b h -> b h p", p=headdim)
    dt_bias_rep = repeat(dt_bias, "h -> h p", p=headdim)
    D_rep = repeat(D, "h -> h p", p=headdim)

    # Warm-up call (outside graph capture)
    _ = selective_state_update(
        ssm_state,
        x_reshaped,
        dt_rep,
        A_rep,
        B,
        C,
        D_rep,
        z=z,
        dt_bias=dt_bias_rep,
        dt_softplus=True,   # using original boolean
    )
    torch.cuda.synchronize()

    # Clone static buffers and explicitly ensure they are on the correct device
    static_ssm_state = ssm_state.clone().to(device)
    static_x_reshaped = x_reshaped.clone().to(device)
    static_dt_rep = dt_rep.clone().to(device)
    static_A_rep = A_rep.clone().to(device)
    static_B = B.clone().to(device)
    static_C = C.clone().to(device)
    static_D_rep = D_rep.clone().to(device)
    static_dt_bias_rep = dt_bias_rep.clone().to(device)
    static_z = z.clone().to(device)

    # Verify all tensors are on the correct device before proceeding
    tensor_names = ["static_ssm_state", "static_x_reshaped", "static_dt_rep", 
                   "static_A_rep", "static_B", "static_C", "static_D_rep", 
                   "static_z", "static_dt_bias_rep"]
    tensors = [static_ssm_state, static_x_reshaped, static_dt_rep, 
              static_A_rep, static_B, static_C, static_D_rep, 
              static_z, static_dt_bias_rep]
    
    for i, (name, tensor) in enumerate(zip(tensor_names, tensors)):
        if tensor.device.type != 'cuda':
            logger.error(f"Tensor '{name}' at position {i} is on {tensor.device} instead of CUDA!")
            # Note: this only rebinds the loop-local name; the original static_*
            # tensor (and the list entry) are unchanged, so the kernel call below
            # would still see the old tensor. The branch never triggers here since
            # every buffer is created on the GPU.
            tensor = tensor.to(device)  # Try to fix it
            logger.info(f"Moved '{name}' to {tensor.device}")

    # Log the device of all tensors before graph capture
    logger.debug("Before graph capture: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s",
                 static_ssm_state.device, static_x_reshaped.device, static_dt_rep.device,
                 static_A_rep.device, static_B.device, static_C.device,
                 static_D_rep.device, static_dt_bias_rep.device, static_z.device)

    # Warm-up the static buffers with a dummy kernel call
    _ = selective_state_update(
        static_ssm_state,
        static_x_reshaped,
        static_dt_rep,
        static_A_rep,
        static_B,
        static_C,
        static_D_rep,
        z=static_z,
        dt_bias=static_dt_bias_rep,
        dt_softplus=True,
    )
    torch.cuda.synchronize()

    # Create a non-default stream for graph capture.
    capture_stream = torch.cuda.Stream(device=device)
    graph = torch.cuda.CUDAGraph()
    
    # Make sure all operations are completed before starting graph capture
    torch.cuda.synchronize()
    
    with torch.cuda.stream(capture_stream):
        graph.capture_begin()
        y = selective_state_update(
            static_ssm_state,
            static_x_reshaped,
            static_dt_rep,
            static_A_rep,
            static_B,
            static_C,
            static_D_rep,
            z=static_z,
            dt_bias=static_dt_bias_rep,
            dt_softplus=True,
        )
        graph.capture_end()
    torch.cuda.synchronize()

    # Replay the captured graph 100 times.
    start = time.time()
    for i in range(100):
        graph.replay()
    torch.cuda.synchronize()
    end = time.time()

    logger.debug("After selective_state_update: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s",
                 static_ssm_state.device, static_x_reshaped.device, static_dt_rep.device,
                 static_A_rep.device, static_B.device, static_C.device,
                 static_D_rep.device, static_dt_bias_rep.device, static_z.device)
    print("Selective state update graph replay time over 100 iterations: {:.4f} sec".format(end - start))
    print("Output sample:\n", y)

if __name__ == "__main__":
    main()
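For reference, the capture section inside main() could also be written with the torch.cuda.graph context manager, which wraps capture_begin/capture_end and handles the side-stream setup. A sketch (same static_* buffers as above, not verified to change the outcome):

# Warm up on a side stream, then capture with the torch.cuda.graph context manager.
side_stream = torch.cuda.Stream(device=device)
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    _ = selective_state_update(
        static_ssm_state, static_x_reshaped, static_dt_rep, static_A_rep,
        static_B, static_C, static_D_rep,
        z=static_z, dt_bias=static_dt_bias_rep, dt_softplus=True,
    )
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    y = selective_state_update(
        static_ssm_state, static_x_reshaped, static_dt_rep, static_A_rep,
        static_B, static_C, static_D_rep,
        z=static_z, dt_bias=static_dt_bias_rep, dt_softplus=True,
    )

graph.replay()
torch.cuda.synchronize()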
@2020chenlin

I have the same issue. Did you solve it? Please let me know your solution.
