Support StableLM2 12B #6635

Merged · 25 commits · Apr 16, 2024
Conversation

ashishdatta (Contributor) opened this pull request.

@Galunid (Collaborator) left a comment:

Is this working, or work in progress?

@Galunid (Collaborator) commented Apr 13, 2024:

Since you require a specific branch to convert, perhaps it'd be a good idea to warn the user if they are using main from huggingface instead of stack-per-head-qk-norm. It would help avoid "StableLM2 12B doesn't convert" issues after this gets merged ;)

@compilade (Collaborator):

> perhaps it'd be a good idea to warn the user

By then the user would already have downloaded more than 20 GB of model files. Ideally, the q and k layernorms should be stacked during conversion (similarly to how Mixtral's expert tensors are concatenated), if they aren't already and if they're present.

@ashishdatta (Contributor, Author):

> Ideally, the q and k layernorms should be stacked during conversion (similarly to how Mixtral's expert tensors are concatenated), if they aren't already and if they're present.

Done, thanks for pointing this out. It makes the conversion a lot simpler.
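
For illustration, a minimal numpy sketch of what this conversion-time stacking can look like. The helper name and the per-head tensor naming (model.layers.{i}.self_attn.q_layernorm.norms.{h}.weight) are assumptions for the sketch, not necessarily the exact names used by the checkpoint or the converter:

    import numpy as np

    def stack_qk_norms(weights: dict, layer: int, n_head: int,
                       which: str = "q_layernorm") -> np.ndarray:
        # Gather each head's (head_dim,) norm weight and stack them into a
        # single (n_head, head_dim) tensor, analogous to how Mixtral's
        # expert tensors are concatenated at conversion time.
        per_head = [
            weights[f"model.layers.{layer}.self_attn.{which}.norms.{h}.weight"]
            for h in range(n_head)
        ]
        return np.stack(per_head, axis=0)

    # Toy usage: fake weights for one layer with two heads of dim 4.
    weights = {
        f"model.layers.0.self_attn.q_layernorm.norms.{h}.weight": np.full(4, h + 1.0)
        for h in range(2)
    }
    assert stack_qk_norms(weights, layer=0, n_head=2).shape == (2, 4)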

@ashishdatta requested a review from ggerganov · Apr 14, 2024 06:58
github-actions bot commented Apr 14, 2024:

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 413 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=11408.45ms p(95)=29598.35ms fails=, finish reason: stop=354 truncated=59
  • Prompt processing (pp): avg=128.03tk/s p(95)=572.74tk/s
  • Token generation (tg): avg=25.78tk/s p(95)=35.8tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=stablelm-12b commit=94e8c490fe210212bc314f01167881a74575f04a

[Benchmark charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing. All from llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 413 iterations.]

@compilade (Collaborator) left a comment:

Hopefully this helps with correcting the flake8 linter errors.

@ashishdatta (Contributor, Author):

@ggerganov does this look good to merge now?

llama.cpp (outdated), comment on lines 8188 to 8195:

    if (model.layers[il].ffn_norm) {
        // non-parallel residual
        cur = ggml_add(ctx0, cur, ffn_inp);
    } else {
        // add together residual + FFN + self-attention
        cur = ggml_add(ctx0, cur, inpL);
        cur = ggml_add(ctx0, cur, attn_out);
    }
@ggerganov (Member):

Aren't these two branches equivalent?

@ashishdatta (Contributor, Author):

I don't believe so. One is doing the parallel residual (e.g. the 12B), and the other, when the FFN norm is present (e.g. StableLM 1.6B and 3B), is not doing the parallel residual. If I am missing something please let me know, thanks!

@ggerganov (Member):

Since ffn_inp = attn_out + inpL, I think these branches do the same thing and can simply be replaced with:

    cur = ggml_add(ctx0, cur, ffn_inp);

I am looking for ways to avoid the unused ffn_inp = ggml_add(...) in the parallel-residual case.
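
A quick numpy sketch, with toy vectors standing in for the ggml tensors, shows why the two orderings agree:

    import numpy as np

    rng = np.random.default_rng(0)
    inpL, attn_out, cur = (rng.standard_normal(8) for _ in range(3))

    ffn_inp = attn_out + inpL           # already computed earlier in the graph
    parallel = (cur + inpL) + attn_out  # the parallel-residual branch, two adds
    merged = cur + ffn_inp              # the proposed single add

    # Addition is associative and commutative, so both paths match.
    assert np.allclose(parallel, merged)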

@compilade (Collaborator) commented Apr 16, 2024:

> Since ffn_inp = attn_out + inpL I think these branches do the same

Reasoning from the relevant modeling code in transformers, even though they separate them for clarity, I think you're right: these branches do the same thing.

@ashishdatta (Contributor, Author) commented Apr 16, 2024:

@ggerganov Thanks! I removed the branches and re-ran on 1.6B, 3B, and 12B with no problems. Please let me know if there is anything else!

@Galunid (Collaborator) left a comment:

Do you want to do anything else, or can we merge?

@ashishdatta (Contributor, Author):

All done from my side.

@ggerganov merged commit dbceec8 into ggml-org:master on Apr 16, 2024 (57 of 62 checks passed).
tybalex pushed a commit to rubra-ai/tools.cpp that referenced this pull request Apr 17, 2024
* StableLM2 12B support for huggingface -> GGUF

* StableLM12 tensormapping and constants

* StableLM-2-12b model support

* fix

* Added 12B support

* Removed autoformatting; resolved bug where model_arch was not selecting StableLM2

* Formatting

* Do QK norm stacking in model conversion step

* Converge StableLM and StableLM2 code to simplify graph construction

* Fix accidental removal

* Removed warnings

* Revert formatter

* Move QK norm stack to private function so it's easier to read

* refactor stablelm graph builder to support 1.6, 3b and 12b more efficiently

* Proper check for None type for new_name to avoid crash; formatting; revert change to base class `write_tensors()`

* Format

* Formatting

* format

Co-authored-by: compilade <[email protected]>

* Fix incorrect check for K norm

* space after commas; Keep indentation multiple of 4 spaces

* Flake8 format

* Removed unnecessary conditional branches

* Removed unused comment

* Fixed incorrect tensor passing

* Format

---------

Co-authored-by: compilade <[email protected]>