
[Bug]: deepseek-r1 multi-node crash #13136

Open
fan-niu opened this issue Feb 12, 2025 · 15 comments
Labels
bug Something isn't working

Comments

@fan-niu

fan-niu commented Feb 12, 2025

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.12.9 (main, Feb  5 2025, 08:49:00) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.2
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.7.2


🐛 Describe the bug

I use 2x8 H100 to deploy the deepseek-r1 model in Kubernetes, but when I run the BBH test set with a concurrency of 3, the service runs for a while and then crashes. nvidia-smi shows that the GPU memory of one process has dropped from more than 60 GB to 2 GB, and the server log contains the following error.

(RayWorkerWrapper pid=4823, ip=10.233.68.253) *** SIGSEGV received at time=1739339449 on cpu 57 ***
(RayWorkerWrapper pid=4823, ip=10.233.68.253) PC: @     0x7ee184049b8a  (unknown)  addProxyOpIfNeeded()
(RayWorkerWrapper pid=4823, ip=10.233.68.253)     @     0x7f11ab22d520  (unknown)  (unknown)
(RayWorkerWrapper pid=4823, ip=10.233.68.253)     @            0x80000  (unknown)  (unknown)
(RayWorkerWrapper pid=4823, ip=10.233.68.253) [2025-02-11 21:50:49,831 E 4823 4823] logging.cc:460: *** SIGSEGV received at time=1739339449 on cpu 57 ***
(RayWorkerWrapper pid=4823, ip=10.233.68.253) [2025-02-11 21:50:49,831 E 4823 4823] logging.cc:460: PC: @     0x7ee184049b8a  (unknown)  addProxyOpIfNeeded()
(RayWorkerWrapper pid=4823, ip=10.233.68.253) [2025-02-11 21:50:49,832 E 4823 4823] logging.cc:460:     @     0x7f11ab22d520  (unknown)  (unknown)
(RayWorkerWrapper pid=4823, ip=10.233.68.253) [2025-02-11 21:50:49,832 E 4823 4823] logging.cc:460:     @            0x80000  (unknown)  (unknown)
(RayWorkerWrapper pid=4823, ip=10.233.68.253) Fatal Python error: Segmentation fault
(RayWorkerWrapper pid=4823, ip=10.233.68.253)
(RayWorkerWrapper pid=4823, ip=10.233.68.253) Stack (most recent call first):
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 290 in ncclAllReduce
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 126 in all_reduce
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 370 in _all_reduce_out_place
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 111 in all_reduce
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1116 in __call__
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 357 in all_reduce
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 13 in tensor_model_parallel_all_reduce
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 1145 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/mla/utils.py", line 541 in _forward_prefill_flash
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/triton_mla.py", line 694 in _forward_prefill
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/mla/utils.py", line 499 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 307 in unified_attention
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1116 in __call__
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 201 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 480 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 561 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 643 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 172 in __call__
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 687 in forward
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747 in _call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1719 in execute_model
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 413 in execute_model
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 93 in start_worker_execution_loop
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2220 in run_method
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 566 in execute_method
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/ray/util/tracing/tracing_helper.py", line 463 in _resume_span
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/ray/_private/function_manager.py", line 696 in actor_method_executor
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 935 in main_loop
(RayWorkerWrapper pid=4823, ip=10.233.68.253)   File "/usr/local/lib/python3.12/dist-packages/ray/_private/workers/default_worker.py", line 297 in <module>
(RayWorkerWrapper pid=4823, ip=10.233.68.253)
(RayWorkerWrapper pid=4823, ip=10.233.68.253) Extension modules: msgpack._cmsgpack, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, setproctitle, yaml._yaml, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, ray._raylet, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, markupsafe._speedups, PIL._imaging, msgspec._core, PIL._imagingft, zmq.backend.cython._zmq, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, vllm.cumem_allocator, sentencepiece._sentencepiece, cuda_utils, __triton_launcher (total: 52)
(RayWorkerWrapper pid=4820, ip=10.233.68.253) loc("/usr/local/lib/python3.12/dist-packages/vllm/attention/ops/triton_decode_attention.py":308:16): error: operation scheduled before its operands [repeated 14x across cluster]

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@fan-niu fan-niu added the bug Something isn't working label Feb 12, 2025
@hiyforever

+1

@qingwu11

+1

@YejinHwang909

+1

@jeasonli0912

+1, waiting online for an answer, it's quite urgent

@ayrnb

ayrnb commented Feb 18, 2025

+1

@simon-mo
Collaborator

cc @youkaichao if you have any suggestions

But given this is a hard-to-reproduce NCCL segfault, I recommend setting up fault tolerance for the service for now.
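
For illustration only (my sketch, not part of this recommendation): "fault tolerance" can be as simple as a watchdog that polls the server's /health endpoint and relaunches it when it stops answering; on Kubernetes the same idea is usually a liveness probe plus a restart policy. The URL and launch command below are placeholders.

```python
# Hypothetical watchdog sketch: poll the vLLM OpenAI server's /health endpoint
# and restart the serving command if it stops responding or the process exits.
import subprocess
import time

import requests

HEALTH_URL = "http://localhost:8000/health"   # assumption: default server port
SERVE_CMD = ["bash", "start_vllm_server.sh"]  # hypothetical launch script

def start_server() -> subprocess.Popen:
    return subprocess.Popen(SERVE_CMD)

proc = start_server()
while True:
    time.sleep(30)
    try:
        ok = requests.get(HEALTH_URL, timeout=10).status_code == 200
    except requests.RequestException:
        ok = False
    if not ok or proc.poll() is not None:
        # Server is down (e.g. after the NCCL segfault above): clean up and relaunch.
        proc.kill()
        proc.wait()
        proc = start_server()
```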

@fan-niu
Author

fan-niu commented Feb 18, 2025

cc @youkaichao if you have any suggestions

But given this is a hard-to-reproduce NCCL segfault, I recommend setting up fault tolerance for the service for now.

@simon-mo Thanks for your reply. How do I set up the fault tolerance you recommend?

Another question: if deepseek-r1 is started on 2 nodes, could you tell me how to run a profiler? I used the torch profiler but it failed. Thanks a lot.
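
For reference, a rough sketch of how torch profiling is commonly triggered on the vLLM OpenAI-compatible server: export VLLM_TORCH_PROFILER_DIR on every node before launching, then toggle the profiler over HTTP. The env var and the /start_profile and /stop_profile routes are my understanding of recent vLLM versions, so treat them as assumptions and verify against the 0.7.2 profiling docs.

```python
# Rough sketch, assuming VLLM_TORCH_PROFILER_DIR was exported on all nodes
# before the server started (each worker then writes its trace there), and
# that the server exposes /start_profile and /stop_profile for this purpose.
import requests

BASE = "http://localhost:8000"  # assumption: head-node server address

requests.post(f"{BASE}/start_profile")

# ... send a few representative requests to /v1/completions here ...

requests.post(f"{BASE}/stop_profile")
# The resulting trace files can then be opened with the PyTorch profiler tools.
```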

@youkaichao
Member

Seems to be an NCCL error.

Can you please run the sanity check script at https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#incorrect-hardware-driver first, to see if NCCL works as expected?
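
For context, a condensed sketch of the kind of multi-node check that page describes: a plain torch.distributed all_reduce over NCCL, run on both nodes. The linked docs' script is the authoritative version; the file name and launch flags below are placeholders.

```python
# Minimal cross-node NCCL check. Launch on both nodes, e.g.:
#   torchrun --nnodes=2 --nproc-per-node=8 --node-rank=<0|1> \
#            --master-addr=<head-ip> --master-port=29500 nccl_check.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

data = torch.ones(128, device="cuda")
dist.all_reduce(data, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()

# After summing a tensor of ones across all ranks, every element equals world_size.
assert data.mean().item() == dist.get_world_size(), "NCCL all_reduce gave a wrong result"
if dist.get_rank() == 0:
    print("NCCL all_reduce across all ranks succeeded")
dist.destroy_process_group()
```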

@Sisphyus

Upgrading the NCCL version to 2.25.1 solved this problem.

@YejinHwang909

I upgraded the NCCL version to 2.25.1-1+cuda12.8 and that fixed it, thank you

@fan-niu
Author

fan-niu commented Feb 19, 2025

I upgraded the NCCL version to 2.25.1-1+cuda12.8 and that fixed it, thank you

@Sisphyus @YejinHwang909 Thanks a lot, will try it again.

@ayrnb

ayrnb commented Feb 21, 2025

When I set NCCL_MIN_NCHANNELS=24 and NCCL_IB_QPS_PER_CONNECTION=8, the same error occurred again. 😩 (2x8 H20)

@hmellor hmellor moved this to Backlog in DeepSeek V3/R1 Feb 25, 2025
@youkaichao
Member

I upgraded the NCCL version to 2.25.1-1+cuda12.8 and that fixed it, thank you

Multiple users reported that installing the latest NCCL with `pip install -U nvidia-nccl-cu12` can solve the problem. If you hit this issue, it's worth giving it a try.
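
A quick way to confirm the upgrade actually took effect in the environment vLLM runs in, sketched here as a suggestion rather than an official check: compare the installed nvidia-nccl-cu12 wheel with the NCCL version PyTorch reports at runtime.

```python
# Compare the installed wheel version with the NCCL version PyTorch loads.
from importlib.metadata import version

import torch
import torch.cuda.nccl

print("nvidia-nccl-cu12 wheel:", version("nvidia-nccl-cu12"))
print("NCCL seen by torch    :", ".".join(map(str, torch.cuda.nccl.version())))
# After `pip install -U nvidia-nccl-cu12` both should report 2.25.x; if the
# second line still says 2.21.x, the process is picking up a different libnccl
# (for example a system-wide copy).
```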

@forearrow

@youkaichao after upgrading nvidia-nccl to v2.25.1, pip reports that torch v2.5.1 requires nvidia-nccl v2.21. Does that mean v2.25.1 is incompatible with the torch v2.5.1 that vLLM uses?

@youkaichao
Member

@youkaichao after upgrading nvidia-nccl to v2.25.1, pip reports that torch v2.5.1 requires nvidia-nccl v2.21. Does that mean v2.25.1 is incompatible with the torch v2.5.1 that vLLM uses?

You can install PyTorch first, and then upgrade only NCCL with `pip install -U nvidia-nccl-cu12`. torch v2.5.1 requires nvidia-nccl v2.21, but v2.25.1 should also work.
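
If you want to double-check which libnccl the running process actually loads after doing this, here is one possible sketch; it assumes Linux (/proc is available) and the dynamically linked pip wheels, so treat it as illustrative only.

```python
# Check which libnccl shared object the current process has mapped.
import torch
import torch.cuda.nccl

# Touch NCCL so the shared library is definitely loaded.
print("NCCL version:", torch.cuda.nccl.version())

with open("/proc/self/maps") as f:
    libs = sorted({line.split()[-1] for line in f if "libnccl" in line})
for path in libs:
    print("loaded:", path)
# The printed path should point into the upgraded nvidia-nccl-cu12 wheel
# (.../nvidia/nccl/lib/libnccl.so.2) rather than an older system copy.
```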
