I am attempting to use CUDA graphs with cudaMallocAsync and Mamba. The code works fine when I use the regular caching allocator together with CUDA graphs, but I get the errors below with the async allocator. I have seen other setups work with this combination, but the SSM kernels hit this odd error even though all of the pointer arguments are on the GPU. Has the team seen errors like this before?
Traceback (most recent call last):
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 148, in <module>
    main()
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 118, in main
    y = selective_state_update(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/selective_state_update.py", line 181, in selective_state_update
    _selective_scan_update_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?)
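For reference, the reproduction script below switches PyTorch to the cudaMallocAsync backend before importing torch. A minimal sketch of how to toggle and confirm the backend (assuming a recent PyTorch build where torch.cuda.get_allocator_backend() is available) looks like this:

import os
# Must be set before the first `import torch`; "backend:native" selects the default
# caching allocator, "backend:cudaMallocAsync" the async pool allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch
print(torch.cuda.get_allocator_backend())  # expected: "cudaMallocAsync"

Full reproduction script: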
#!/usr/bin/env python
import os

# Use cudaMallocAsync for performance.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch
import time
import math
from einops import rearrange, repeat

import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    raise ImportError("selective_state_update kernel is required for this test.")


def main():
    device = torch.device("cuda")

    # Dummy dimensions (example values)
    batch = 2
    nheads = 4    # number of heads
    headdim = 16  # head dimension (p)
    d_state = 32  # state dimension
    ngroups = 2   # groups for B and C

    # Create dummy tensors on GPU
    ssm_state = torch.randn(batch, nheads, headdim, d_state, device=device, dtype=torch.float32)
    x = torch.randn(batch, nheads * headdim, device=device, dtype=torch.float32)
    # Reshape x to shape [batch, nheads, headdim]
    x_reshaped = rearrange(x, "b (h p) -> b h p", p=headdim)
    dt = torch.randn(batch, nheads, device=device, dtype=torch.float32)
    dt_bias = torch.randn(nheads, device=device, dtype=torch.float32)
    A = torch.randn(nheads, device=device, dtype=torch.float32)
    B = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    C = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    D = torch.randn(nheads, device=device, dtype=torch.float32)
    z = torch.randn(batch, nheads, headdim, device=device, dtype=torch.float32)

    # Mimic the repeats and rearrangements used in the main code:
    A_rep = repeat(A, "h -> h p n", p=headdim, n=d_state)
    dt_rep = repeat(dt, "b h -> b h p", p=headdim)
    dt_bias_rep = repeat(dt_bias, "h -> h p", p=headdim)
    D_rep = repeat(D, "h -> h p", p=headdim)

    # Warm-up call (outside graph capture)
    _ = selective_state_update(
        ssm_state,
        x_reshaped,
        dt_rep,
        A_rep,
        B,
        C,
        D_rep,
        z=z,
        dt_bias=dt_bias_rep,
        dt_softplus=True,  # using original boolean
    )
    torch.cuda.synchronize()
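    # Running the kernel once here forces Triton to JIT-compile and autotune outside of
    # graph capture; autotuning launches candidate configs and synchronizes the device,
    # which is not allowed while a stream capture is in progress.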
    # Clone static buffers and explicitly ensure they are on the correct device
    static_ssm_state = ssm_state.clone().to(device)
    static_x_reshaped = x_reshaped.clone().to(device)
    static_dt_rep = dt_rep.clone().to(device)
    static_A_rep = A_rep.clone().to(device)
    static_B = B.clone().to(device)
    static_C = C.clone().to(device)
    static_D_rep = D_rep.clone().to(device)
    static_dt_bias_rep = dt_bias_rep.clone().to(device)
    static_z = z.clone().to(device)
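    # These "static" tensors are the buffers whose addresses get recorded into the graph;
    # every replay reads and writes these same allocations, so fresh inputs must be
    # copied into them in place (e.g. with copy_()) rather than rebinding the names.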
    # Verify all tensors are on the correct device before proceeding
    tensor_names = ["static_ssm_state", "static_x_reshaped", "static_dt_rep",
                    "static_A_rep", "static_B", "static_C", "static_D_rep",
                    "static_z", "static_dt_bias_rep"]
    tensors = [static_ssm_state, static_x_reshaped, static_dt_rep,
               static_A_rep, static_B, static_C, static_D_rep,
               static_z, static_dt_bias_rep]
    for i, (name, tensor) in enumerate(zip(tensor_names, tensors)):
        if tensor.device.type != 'cuda':
            logger.error(f"Tensor '{name}' at position {i} is on {tensor.device} instead of CUDA!")
            tensor = tensor.to(device)  # Try to fix it
            logger.info(f"Moved '{name}' to {tensor.device}")

    # Log the device of all tensors before graph capture
    logger.debug("Before graph capture: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s",
                 static_ssm_state.device, static_x_reshaped.device, static_dt_rep.device,
                 static_A_rep.device, static_B.device, static_C.device,
                 static_D_rep.device, static_dt_bias_rep.device, static_z.device)

    # Warm-up the static buffers with a dummy kernel call
    _ = selective_state_update(
        static_ssm_state,
        static_x_reshaped,
        static_dt_rep,
        static_A_rep,
        static_B,
        static_C,
        static_D_rep,
        z=static_z,
        dt_bias=static_dt_bias_rep,
        dt_softplus=True,
    )
    torch.cuda.synchronize()
    # Create a non-default stream for graph capture.
    capture_stream = torch.cuda.Stream(device=device)
    graph = torch.cuda.CUDAGraph()

    # Make sure all operations are completed before starting graph capture
    torch.cuda.synchronize()
    with torch.cuda.stream(capture_stream):
        graph.capture_begin()
        y = selective_state_update(
            static_ssm_state,
            static_x_reshaped,
            static_dt_rep,
            static_A_rep,
            static_B,
            static_C,
            static_D_rep,
            z=static_z,
            dt_bias=static_dt_bias_rep,
            dt_softplus=True,
        )
        graph.capture_end()
    torch.cuda.synchronize()
    # Replay the captured graph 100 times.
    start = time.time()
    for i in range(100):
        graph.replay()
    torch.cuda.synchronize()
    end = time.time()

    logger.debug("After selective_state_update: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s",
                 static_ssm_state.device, static_x_reshaped.device, static_dt_rep.device,
                 static_A_rep.device, static_B.device, static_C.device,
                 static_D_rep.device, static_dt_bias_rep.device, static_z.device)

    print("Selective state update graph replay time over 100 iterations: {:.4f} sec".format(end - start))
    print("Output sample:\n", y)


if __name__ == "__main__":
    main()
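As a cross-check (not part of the original report), the same capture can also be expressed with the higher-level torch.cuda.graph context manager, which manages the side stream and pre-capture synchronization itself. A minimal sketch, assuming the static_* tensors from the script above are in scope and have already been used once for warm-up:

# Sketch only: reuses the static buffers defined in the script above.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):  # handles the side stream and sync internally
    y = selective_state_update(
        static_ssm_state, static_x_reshaped, static_dt_rep, static_A_rep,
        static_B, static_C, static_D_rep,
        z=static_z, dt_bias=static_dt_bias_rep, dt_softplus=True,
    )
g.replay()

If this path hits the same ValueError, that would point at the allocator/Triton interaction rather than the manual capture_begin/capture_end setup.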