llama : add option to override model tensor buffers #11397
base: master
Conversation
Is there a chance that the direction you're taking these changes might allow for scheduling specific threads to work on specific tensors? With R1 coming out, I'm very interested in reviving my work on trying to improve memory locality to increase CPU inference speeds. |
No, that's something that would need to be handled at a lower level in the CPU backend. |
Thanks for the reply @slaren. I figured it wouldn't directly help, but that maybe you'd be adding useful metadata to tensor objects that could help coordinate affinity in the future. I'll start a fresh branch and see how far I get.
I'll also try to pull this branch and test it to see what the speedup and sysmem savings look like. |
Quick, non-scientific initial test with Deepseek R1 at q6 on llama-server with -ot exps=CPU: -ngl 0 = 4.65t/s. So there is definitely major speedup potential for this patch. I can't offload all 62 layers for this model because I only have 24GB VRAM, but I expect the trend would continue in the same general direction. This is without dropping caches, so it's inefficient, but I didn't have the time to do a proper drop/reload cycle since it takes so long to be read back into memory on each test run. |
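(For reference, the usual way to do a clean drop/reload cycle on Linux is to flush and drop the page cache as root between runs:)

# flush dirty pages, then drop the page cache so the next run re-reads the model from disk
sync
echo 3 > /proc/sys/vm/drop_caches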
@bmtwl |
What are the shared expert tensors called in |
I believe the pattern |
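If in doubt about what the expert tensors are called in a particular GGUF, one quick (if crude) check is to lean on the -v debug output mentioned in the PR description and grep the load log; the model path below is a placeholder and the exact log wording may differ:

# run with verbose output and see which tensors the override pattern matched
# and which buffer type each one was placed in
./build/bin/llama-cli -m model.gguf -ngl 99 -ot "exps=CPU" -v 2>&1 | grep -i exps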
Thanks - I'll give this a try later in the week. This PR together with this Reddit post: https://old.reddit.com/r/LocalLLaMA/comments/1ibbloy/158bit_deepseek_r1_131gb_dynamic_gguf/ opens up the interesting possibility of quantising up/gate projections to q2_k and down projections to q4_k (or something similar), then keeping everything else at higher precision. Sadly I need to move some stuff about to get space to upscale the fp8 download to bf16 before I can try it, but will report back when I do. |
It might be worth trying |
Just being able to split the experts between NUMA nodes would make a big difference, but not sure how easy that would be as IIRC the experts' tensors are all in one huge tensor now? |
During normal operation, when I fit a model between RAM and VRAM, does the offloading follow a set layer sequence (layer 0 is chosen first to be offloaded to the GPU, then layer 1, etc.)? Between GPU offloading and RAM, which takes priority?
Do you remember how much of a speedup? No need for extensive benchmarks, just the rough % estimate. |
I can't seem to offload more than 29 layers of R1 (unsloth's UD-IQ2_XXS) via RPC. 29 layers and below work fine, but 30 just crashes my rpc_server, with no error output. It is not a VRAM issue: even with the context set very low, so that it takes up nowhere near my GPU's limits, it still crashes. |
I had a similar problem where if I used a single GPU (via If I didn't use either of these it tried to allocate this 1.4TB monster buffer:
After some searching I found this issue: and recompiled using (It's likely nothing to do with this PR, but thought it might help!) |
I figured it out: you have to reorder the devices so the local and mainly these:
Means this works: --device "RPC[IP1:PORT1],RPC[IP1:PORT2],RPC[IP1:PORT1],RPC[IP2:PORT2],CUDA0,CUDA1" But if I don't do this I get OOM errors with plenty of VRAM left like you had. |
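For reference, a hedged sketch of what such an invocation might look like end-to-end (the IPs, ports, and model path are placeholders, and the --rpc flag is assumed to be how the RPC servers were registered; the key point is listing the RPC devices before the local CUDA ones):

# assumed: rpc-server instances already listening on the remote hosts
./build/bin/llama-server -m model.gguf -ngl 62 \
  --rpc "IP1:PORT1,IP1:PORT2,IP2:PORT1,IP2:PORT2" \
  --device "RPC[IP1:PORT1],RPC[IP1:PORT2],RPC[IP2:PORT1],RPC[IP2:PORT2],CUDA0,CUDA1"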
I'm testing this with and without #11446. Without it, on unsloth's UD-IQ2_XXS I was only able to offload 29 layers, and with it I was able to allocate only 28 (on a Q4_K_S quant). This is not a VRAM issue: there is plenty of spare VRAM, and it even gets past allocation and reaches warmup, where the rpc-server then just crashes. The other issue is performance: the more layers I allocate, the worse performance gets, while bmtwl shows a performance increase with more layers offloaded using non-RPC-based offloading. |
I am able to load the model with
But as soon as I send the prompt I receive:
Without the Testing with 4x RTX 3090 and 320GiB RAM. Built with |
Maybe try |
No luck, still the same issue. Oddly enough, the issue only happens when sending more than 450 tokens. |
It's trying to allocate a tensor of size 2^64, which suggests there is an integer overflow somewhere. If you set the environment variable
It is the Is it possible to try to force this particular one to be allocated into the GPU buffer? |
This is most likely a bug, we need to understand why it is happening and fix it. Since you mentioned that it only happens with large prompts, I suspect that this is caused by a zero-sized tensor. When evaluating a batch where no logits are required (which happens when evaluating a prompt that needs to be split into multiple ubatches), zero-size tensors are created to skip the calculation of the logits. The patch below adds the offending tensor's name and shape to the error message:

diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 9a3bf9f29..470ef13e6 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -179,6 +179,9 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
// this should never happen
GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
__func__, size, max_avail);
+ GGML_LOG_ERROR("%s: tensor: %s, shape: %ld %ld %ld %ld, size: %zu",
+ __func__, tensor->name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
+ ggml_nbytes(tensor));
GGML_ABORT("not enough space in the buffer");
}
} |
Ok nvm, I think I see the problem. I will push a possible fix soon. |
The names of these tensors do not match the names of the tensors in llama.cpp. I suggest running with
You would need to increase the value of |
@slaren: thanks for pointing out that I was using the incorrect tensor names (in fact ktransformers was using the model names from the safetensors format files and not GGUF). So now I have rerun some tests and can see improved GPU usage, increasing to 50%:
However, using the -ot option it seems impossible to utilise the full memory on the GPUs; the ffn_gate/ffn_up/ffn_down layers are simply too large to be loaded into 48 GB of VRAM. But this results in ~3.2 tok/s. The best combination appears to be --ngl 36 --tensor-split 19,20, where I can get over 4.2 tok/s. It seems that the bottleneck is CPU memory. With -ot we get more GPU utilisation, but this doesn't seem to make up for the time lost to having some of the layers in slower CPU memory. @jukofyork: I'm using --ctk q4_0 as per the unsloth blog (https://unsloth.ai/blog/deepseekr1-dynamic). If I remove this and use the default, I get CUDA OOM. I have tried the different ctk values but there don't appear to be any noticeable performance improvements. |
Have you tried recompiling with the right CUDA architecture flag set? #4215 Looking at the NVIDIA docs, sm_100 is what you need:
I looked into that and it seems to have done the trick.

cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES="100"

This changed it to sm=100 while it compiled. I still need to mess with settings to get the best speed, but here is the very first run.

llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --flash-attn --threads 36 --temp 0.6 --min-p 0.05 --ctx-size 2048 --no-mmap -ngl 36 -ot exps=CPU

I am getting about 28% higher t/s for eval_time. For prompt eval_time, around a 50% improvement (6.2 t/s / 14.1 t/s). This one leaves a lot of room for context as it only uses 17 GB of GPU memory.

llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --flash-attn --threads 36 --temp 0.6 --min-p 0.05 --ctx-size 2048 --no-mmap -ngl 62 -ot exps=CPU

This command uses 26 GB of GPU memory, so there is still 6 GB left for context beyond 2k (I tested this and it uses 31.1 GB at 4096 context). This gets me around 7.8 t/s eval_time and 20.5 t/s prompt eval_time.

Overall, the changes you made lead to a 66% performance increase on eval_time and around a 100% performance increase on prompt eval_time vs CPU-only on a Threadripper 7965WX, 512 GB memory, 5090. You are an absolute genius. If you have some proper benches you want me to run, let me know. |
Another update.

llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --flash-attn --threads 40 --temp 0.6 --min-p 0.05 --ctx-size 4096 --no-mmap -ngl 62 -ot exps=CPU

This uses up all my threads completely and I get a small performance bump: an 82% performance increase now on eval time. Makes me really want a 64-core Threadripper now, and also a second 5090 for more context. Using 31 GB of GPU memory right now at 4k. I am also curious if getting double the system memory bandwidth will make a difference after the 64-core Threadripper upgrade. Maybe I can get up to 10-15 t/s.

Another thing I noticed is that it no longer drops off a cliff in inference speed as I continue a story. After 1k context generated, then another new 2k context, the new speed was still 8.01 t/s. If this was CPU-only, it would have dropped by 25% by then. The only real limiting factor is that 3.5k context seems like the absolute upper limit; I was having trouble with 4k context. I really need more context. Another issue is that prompt eval time is actually all over the place. Sometimes it is fast, sometimes it does this:

Another update: I found that --flash-attn makes no difference. Also, I changed --no-mmap to --mlock and I get consistent prompt eval now, around 12 t/s. Still pretty amazing for running Q4 of R1 on CPU with one consumer-grade GPU.

Yet another update. This time using Unsloth DeepSeek-R1-UD-IQ2_XXS-00001-of-00004.gguf. This model is still really good and uses only ~200 GB system memory and 27.5 GB GPU memory at 3k context. I was able to get 3600 context max with this unsloth model. The only real limiting factor with this setup is context. Any chance KV cache allocation will resolve this issue? |
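For anyone with VRAM to spare after -ot exps=CPU, the regex form from the PR description should in principle let you keep only some layers' experts on the CPU and let the remaining layers' experts use the leftover VRAM. A hedged sketch, not something tested here (the layer range and the blk.N.ffn_*_exps naming are assumptions to adapt to your model):

# keep only layers 10-61's experts on the CPU; experts of the earlier layers follow -ngl onto the GPU
llama-server --model DeepSeek-R1-Q4_K_M-00001-of-00011.gguf --threads 40 --ctx-size 4096 --no-mmap -ngl 62 \
  -ot "blk\.(1[0-9]|[2-5][0-9]|6[01])\.ffn_.*_exps\.=CPU"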
Thanks for all the testing, I will try to get this ready for merging over the next few days. |
Yeah, flash-attn is not supported yet in llama.cpp for DeepSeek-R1, pretty sure; check out #11557 |
This is likely because without those args, llama.cpp defaults to normal

Thanks for the benchmarks. I'm revisiting this exciting branch after playing with ktransformers and trying to figure out how they get almost 2x inference speeds on R1. I noticed that when I disabled CUDA graphs on ktransformers, it performs almost the same as llama.cpp again... however CUDA graphs only work when not offloading any experts into VRAM, hrmm... Anyway, enjoying the quest for more tok/sec! Cheers! |
You can try it with the PR the comment is from and the modification shown at the bottom of the comment: #11446 (comment). This further comment showed it worked: #11446 (comment) |
@Reactantvr Thanks for sharing your test results. Just curious, what are the ratings of the DIMM memory you are using in your setup? If you run nvtop, do you see your GPU running at max compute? In my testing, CPU memory seems to be the limiting factor/bottleneck. |
My memory is 8x64 GB V-Color DDR5 6000 running at 4800. I didn't bother overclocking it yet because I am on 4 CCDs, which should limit me to around 230 GB/s. I assume I would not get more bandwidth until I upgrade to a 64-core Threadripper. Waiting on Shimada Peak for that. I'll probably run it at 6400 once I get that CPU. I've never used nvtop. Plus, I am doing everything in Windows 10, so I'm not sure if I can use it. I can give you stats from GPU-Z. Looks like GPU load is around 18-19%. This was using DeepSeek-R1-UD-IQ2_XXS. |
Works perfectly for me with dual E5 v2 + 2080 Ti running DeepSeek-R1-UD-Q2_K_XL. It boosts the token generation speed from 1.8 tps to 3.3 tps. With one NUMA node disabled, it increases to 3.8 tps. |
Not sure if this is related, but I get slightly worse t/s generation speed (0.4 t/s slower) offloading any layers of q2_k_xl to VRAM (24 GB over 2x 3060s) than using -ngl 0 with only quad-channel 2133 DDR4 system RAM. This is using the main branch and an old version. |
Is there a way to disable the KV cache and just recompute values as required? The freed cache memory could be used for loading additional model layers. In my tests I've not seen my GPU max out, so maybe there is a sweet spot between caching and recalculating? |
No, but you can keep it in system memory with |
Would it be possible to have llama.cpp only load some experts from disk to RAM or VRAM, or from RAM to VRAM, on demand? It would come at the cost of latency after the prompt is sent to the model. |
I am not sure if this is similar, but would it also be possible to keep several instances of experts or the most-used tensors on each compute device to increase inference speed for common queries, and also to separate each expert into commonly used and rarely used neurons (hot and cold neurons, respectively), like PowerInfer and PowerInfer 2 do? Would it also be possible to shard the model to achieve tensor parallelism between different types of devices, like CPUs with GPUs, using the hot/cold neuron approach, on any kind of AI model? |
The good news

Using this, it is counter-intuitive to me that offloading fewer layers onto the GPU makes it go faster, and I presume this has something to do with CUDA graphs not working as well with even a single expert also in VRAM, but I'm really just speculating wildly. This method is still not quite as fast as

The technically unrelated news

I had hoped to use this fine-grained offload method to distribute experts across 6 NUMA nodes on a big dual-socket Intel Xeon 6980P. While it does technically work and runs, it is much slower than just running normally with no NUMA optimizations at all. I even tried making a patch to
Example

I tried a few configurations including 5x. I'll leave the commands and some info for anyone interested inside the fold below. Also a whole discussion on the challenges of running llama.cpp in more than a single NUMA node over here. Cheers!

EDIT: Tried one last time with

Example selective RPC backend offloading experiments

System Info
$ numactl -H --cpu-compress
available: 6 nodes (0-5)
node 0 cpus: 0-42, 256-298 (86)
node 0 size: 257688 MB
node 1 cpus: 43-85, 299-341 (86)
node 1 size: 258018 MB
node 2 cpus: 86-127, 342-383 (84)
node 2 size: 258019 MB
node 3 cpus: 128-170, 384-426 (86)
node 3 size: 258018 MB
node 4 cpus: 171-213, 427-469 (86)
node 4 size: 258018 MB
node 5 cpus: 214-255, 470-511 (84)
node 5 size: 257949 MB

Backend RPC server(s)
Bash script to distribute rpc-servers across NUMA nodes.
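A hedged sketch of what such a launcher might look like (not the exact script used; the binary path, ports, and the -H/-p flags are assumptions about a build with -DGGML_RPC=ON):

#!/usr/bin/env bash
# start one CPU rpc-server per NUMA node, pinned to that node's cores and memory,
# listening on consecutive ports starting at 50052
set -euo pipefail
BIN=./build_rpc/bin/rpc-server
BASE_PORT=50052
for NODE in 0 1 2 3 4 5; do
  PORT=$((BASE_PORT + NODE))
  numactl -N "$NODE" -m "$NODE" "$BIN" -H 127.0.0.1 -p "$PORT" > "rpc-node${NODE}.log" 2>&1 &
done
wait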
Frontend Client
I noticed llama-server starts something like 555 threads for some reason. I tested llama-cli, which starts the correct requested number of threads. Both seem to generate at the same very poor speeds.

## Start frontend in node 0
CMD="numactl -N 0 -m 0 \
./build_amx/bin/llama-server \
--model /mnt/ai/models/unsloth/DeepSeek-R1-GGUF/DeepSeek-R1-UD-Q2_K_XL/DeepSeek-R1-UD-Q2_K_XL-00001-of-00005.gguf \
--threads 42 \
--numa numactl \
--ctx-size 2048 \
--rpc 127.0.0.1:50053 \
--device RPC[127.0.0.1:50053] \
--n-gpu-layers 62 \
--override-tensor exps=RPC[127.0.0.1:50053] \
--override-tensor \.*=CPU \
--host 127.0.0.1 \
--port 8080 -v" numastat confirming allocations in correct nodes$ watch numastat -p $(pidof llama-server)
Per-node process memory usage (in MBs) for PID 3493834 (llama-server)
Node 0 Node 1 Node 2 Node 3 Node 4 Node 5 Total
--------------- --------------- --------------- --------------- --------------- --------------- ---------------
Huge 0.00 0.00 0.00 0.00 0.00 0.00 0.00
Heap 39.93 0.00 0.00 0.00 0.00 0.00 39.93
Stack 0.07 0.00 0.00 0.00 0.00 0.00 0.07
Private 210980.80 0.04 0.00 1.49 0.00 0.00 210982.33
---------------- --------------- --------------- --------------- --------------- --------------- --------------- ---------------
Total 211020.81 0.04 0.00 1.49 0.00 0.00 211022.34
$ watch numastat -m -v z
Per-node system memory usage (in MBs):
Node 0 Node 1 Node 2 Node 3 Node 4 Node 5 Total
--------------- --------------- --------------- --------------- --------------- --------------- ---------------
MemTotal 257688.18 258018.79 258019.46 258018.79 258018.79 257949.26 1547713.25
MemFree 38192.29 256419.21 256928.47 40355.96 255835.55 256676.09 1104407.58
MemUsed 219495.88 1599.57 1090.99 217662.82 2183.24 1273.16 443305.67
SwapCached 1.16 0.05 0.00 0.76 0.46 0.02 2.45
Active 136.95 50.34 3.38 215533.59 120.05 30.82 215875.14
Inactive 215620.14 3.86 0.36 5.04 19.48 0.00 215648.88
Active(anon) 125.72 7.79 0.78 215517.16 85.89 28.62 215765.98
Inactive(anon) 6.27 0.47 0.00 0.00 0.28 0.00 7.02
Active(file) 11.23 42.55 2.59 16.43 34.16 2.20 109.16
Inactive(file) 215613.86 3.39 0.36 5.04 19.21 0.00 215641.86
Unevictable 33.43 1.52 0.00 0.00 0.41 0.00 35.36
Mlocked 24.64 1.52 0.00 0.00 0.41 0.00 26.57
Dirty 0.01 0.25 0.00 0.00 0.00 0.00 0.27
FilePages 215641.79 47.52 3.71 22.75 57.19 2.84 215775.79
Mapped 210910.81 36.04 2.95 17.79 31.80 2.21 211001.59
AnonPages 148.90 8.21 0.02 215515.99 82.81 15.12 215771.06
Shmem 8.86 0.01 0.75 0.52 2.95 0.62 13.72
KernelStack 22.39 12.57 12.44 12.93 13.10 12.08 85.50
PageTables 419.59 0.23 0.02 422.52 0.97 0.05 843.37
Slab 2064.86 488.38 210.09 332.59 920.96 355.46 4372.35
SReclaimable 561.28 32.97 21.71 40.41 59.71 34.47 750.55
SUnreclaim 1503.58 455.41 188.39 292.18 861.25 320.99 3621.80
AnonHugePages 68.00 4.00 0.00 215502.00 66.00 0.00 215640.00
KReclaimable 561.28 32.97 21.71 40.41 59.71 34.47 750.55

btop
Confirm CPU cores are running on the NUMA nodes with memory allocation.

Layer Mappings
I explicitly wildcard all non-exps tensors to CPU.
Again but with
|
Thank you for this feature; it allows me to be more efficient in hybrid CPU/GPU inference for the DeepSeek R1 671B model, achieving approximately 13.x t/s with it.

load_tensors: loading model tensors, this can take a while... (mmap = true)
<|User|>Hello<|Assistant|>Hi there<|end▁of▁sentence|><|User|>How are you?<|Assistant|>
system_info: n_threads = 48 (n_threads_batch = 48) / 192 | CUDA : ARCHS = 860,890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | AARCH64_REPACK = 1 |
main: interactive mode on.
== Running in interactive mode. ==
You are a helpful assistant.

Okay, the user said "Hello". That's a friendly greeting. I should respond in a warm and welcoming manner. Let me make sure to keep it open-ended so they feel comfortable to ask anything. Maybe add a smiley to keep it friendly. Alright, something like, "Hello! How can I assist you today? 😊" That should work.

Hello! How can I assist you today? 😊

llama_perf_sampler_print: sampling time = 5.85 ms / 96 runs ( 0.06 ms per token, 16413.06 tokens per second) |
Could you provide more information on your machine and the commands that get 13.x t/s? |
Certainly. My setup is: AMD EPYC 9654 * 2, 64 GB DDR5 4800 MHz * 24, 4070 Ti Super 16 GB GPU * 1, Debian 12. The model used is DeepSeek R1 671B Q4_K_M.

CMake command:
cmake -B build -DGGML_CUDA=ON -DGGML_BUILD_NUMBER=3 -DGGML_OPENMP=OFF -DGGML_SCHED_MAX_COPIES=1

Run command:
CUDA_VISIBLE_DEVICES=0 ./build/bin/llama-cli -m /data/deepseekR1/DeepSeek-R1-Q4_K_M-000000.gguf -cnv -p "You are a helpful assistant." -fa -c 65536 --temp 0.6 --top-p 0.95 -s 3047 -if -mli -t 48 -ngl 160 -nkvo -c 163840 -ctk q8_0 -ot exps=CPU |
To ubergarm: Therefore, I used the -nkvo option, which allows for large-scale text processing.

The performance optimization approach of ktransformers has two key points. The first is using multiple NUMA nodes with separate data copies, which can release more memory bandwidth; however, from an engineering perspective, synchronizing across multiple NUMA nodes is a major issue. By analyzing the inference program with the perf tool, it can be seen that approximately 50% of CPU usage is spent on thread synchronization, which is a significant area for optimization. While KT achieves some acceleration in inference across different NUMA nodes, it still faces the problem of high synchronization overhead when merging data back to the main thread.

The second optimization point is exactly what this PR implements: placing sparse expert weights in CPU memory instead of GPU memory, which saves GPU memory. CUDA graphs can effectively reduce the cost of interaction between CUDA and CPU computations.

Placing the KV cache in GPU memory, as mentioned earlier, only has an advantage in benchmark scores but is practically useless in real-world applications. For example, with a context length of 163,840, the quantized KV cache requires 540 GB of storage space. I do not have that much GPU memory, and the cost is too high, making it inefficient. Of course, MLA can alleviate the performance degradation when dealing with long contexts, and combining it with -ot provides the best single-machine deployment experience for R1 inference.

VLLM and SGLang have advantages in enterprise-scale deployments, but for single-machine deployments, I believe we may still need to rely on llama.cpp (my personal opinion). |
Yeah, I think if you care about the quality of the generation then I've got a custom Just changing the |
(sorry for spamming this PR thread) Thanks for the discussion and I agree with many of your points. One more detail you left out:
My guess is you have BIOS set to
Could you share the |
(sorry for spamming this PR thread too)

@ubergarm NPS0 allows for the maximum memory bandwidth with minimal overhead, but it introduces additional latency by involving the remote NUMA node's memory controller. In this mode, the CPU's L3 cache becomes available only after data is fetched from both local (assumed latency of 10) and remote (assumed latency of 50) memory controllers.

NPS1 provides the highest memory bandwidth while maintaining low latency. However, it requires additional memory access handling from the program. This mode offers the maximum bandwidth, which aligns with the optimization approach of KT. By storing the full model data across two NUMA nodes, local threads can access local model data, thereby unlocking the system's maximum memory bandwidth.

I adopted a similar approach by modifying the struct ggml_tensor to include data storage for each NUMA node (doubling memory consumption). Additionally, I bound NUMA nodes (CPU nodes) within the thread pool. During the ggml_compute_forward_mul_mat computation, threads can access local data backups based on their NUMA node ID, achieving the same NUMA optimization result as KT. (This is why I require the GGML_OPENMP=OFF option: I need to control thread CPU binding manually, distributing threads evenly across AMD CPU CCDs.) Through this method, in an NPS1 system configuration, faster inference performance can be achieved.

@jukofyork Thank you for sharing your quantization experience. I tested scenarios with 4, 6, 8, and 12 experts and ultimately settled on the 6-expert configuration, which strikes the best balance between inference quality and speed. |
Adds command line parameter --override-tensor (-ot) that allows changing the buffer type where a model tensor is allocated. This gives the user fine-grained control over which tensors are offloaded to each device.

How is this useful: for example, to force the experts in MoE models to stay on the CPU while offloading the rest to the GPU, you could use -ngl 99 -ot exps=CPU. This may allow more efficient offloading schemes.

The syntax is <tensor name pattern>=<buffer type>. Currently the pattern is just a string search (edit: this is no longer the case, it is a C++ regex search), i.e. any tensor whose name contains the characters in <tensor name pattern> will be matched and loaded into the given buffer type. Multiple overrides can be given by separating them with commas, or by passing the -ot option multiple times. To see which tensors are being matched, enable debugging output with -v.

At this point it is just a demo, feel free to experiment and report if you find any interesting uses.

Edit: added regex support, for example to keep the experts of layers 20-99 on the CPU you could use -ot "[2-9][0-9]\.ffn_.*_exps\.=CPU"
TODO: