diff --git a/docs/supported_models.md b/docs/supported_models.md index ebc3b5d30..115bc9955 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -289,9 +289,9 @@ Neural Speed supports the following models: 128k - StableLM-3B, - StableLM2-1_6B - StableLM2-Zephyr-1_6B + StableLM-2-1_6B, + StableLM-3B, + StableLM-2-12B ✅ @@ -301,7 +301,7 @@ Neural Speed supports the following models: Latest - 2048 + 4096 gemma-2b-it , @@ -372,7 +372,7 @@ Neural Speed supports the following models: ✅ Latest - + Magicoder-6.7B ✅ ✅ @@ -398,6 +398,18 @@ Neural Speed supports the following models: Latest + + Stable-Code-3B + ✅ + + + + ✅ + + + + Latest + diff --git a/neural_speed/convert/convert_stablelm.py b/neural_speed/convert/convert_stablelm.py index f5f1d43fd..8340aecee 100644 --- a/neural_speed/convert/convert_stablelm.py +++ b/neural_speed/convert/convert_stablelm.py @@ -20,6 +20,7 @@ # This script is similar to "convert-pt-to-ne.py" # import os +import sys import struct import numpy as np from pathlib import Path @@ -27,7 +28,7 @@ from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union) from transformers import AutoModelForCausalLM, AutoTokenizer -import gguf +import torch # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py def bytes_to_unicode(): @@ -51,123 +52,45 @@ def bytes_to_unicode(): cs = [chr(n) for n in cs] return dict(zip(bs, cs)) -def stablelm_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams): - print("stablelm.gguf converting: ") - list_vars = model.state_dict() - n_rot = int(hparams["partial_rotary_factor"] * hparams["hidden_size"] / hparams["num_attention_heads"]) - for name in list_vars.keys(): - print(name, list_vars[name].shape, list_vars[name].dtype) - - print(hparams) - - gguf_file = fname_out + '.gguf' - gguf_writer = gguf.GGUFWriter(gguf_file, "stablelm") - - gguf_writer.add_uint32('magic', 0x67676d66) - gguf_writer.add_uint32('version', 1) - gguf_writer.add_uint32('n_vocab', hparams["vocab_size"]) - gguf_writer.add_embedding_length(hparams["hidden_size"]) - gguf_writer.add_head_count(hparams["num_attention_heads"]) - gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) - - gguf_writer.add_block_count(hparams["num_hidden_layers"]) - gguf_writer.add_rope_dimension_count(n_rot) - gguf_writer.add_uint32('ftype', ftype) - gguf_writer.add_context_length(hparams["max_position_embeddings"]) - gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - - gguf_writer.add_bos_token_id(hparams["bos_token_id"]) - gguf_writer.add_eos_token_id(hparams["eos_token_id"]) - gguf_writer.add_pad_token_id(hparams["pad_token_id"] if hparams["pad_token_id"] else 0) - gguf_writer.add_sep_token_id(hparams["sep_token_id"] if hparams["sep_token_id"] else 0) - - def write_vocab_gguf(dir_model, hparams, gguf_writer): - tokens: list[bytearray] = [] - toktypes: list[int] = [] - - tokenizer = AutoTokenizer.from_pretrained(dir_model) - vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) - assert max(tokenizer.vocab.values()) < vocab_size - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} - added_vocab = tokenizer.get_added_vocab() - - for i in range(vocab_size): - if i not in reverse_vocab: - pad_token = f"[PAD{i}]".encode('utf-8') - tokens.append(bytearray(pad_token)) - toktypes.append(gguf.TokenType.USER_DEFINED) - elif reverse_vocab[i] in added_vocab: - tokens.append(reverse_vocab[i]) - if tokenizer.added_tokens_decoder[i].special: - toktypes.append(gguf.TokenType.CONTROL) - else: - toktypes.append(gguf.TokenType.USER_DEFINED) - else: - tokens.append(reverse_vocab[i]) - toktypes.append(gguf.TokenType.NORMAL) - - gguf_writer.add_tokenizer_model("gpt2") - gguf_writer.add_token_list(tokens) - gguf_writer.add_token_types(toktypes) - - special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) - special_vocab.add_to_gguf(gguf_writer) - - write_vocab_gguf(dir_model, hparams, gguf_writer) - - # tensor info - print("gguf: get tensor metadata") - for name in list_vars.keys(): - data = list_vars[name].squeeze().numpy() - - print("Processing variable: " + name + " with shape: ", data.shape) - if 'inv_freq' in name: - continue - - n_dims = len(data.shape) - - # ftype == 0 -> float32, ftype == 1 -> float16 - ftype_cur = 0 - if ftype != 0: - if name[-7:] == ".weight" and n_dims == 2: - print(" Converting to float16") - data = data.astype(np.float16) - ftype_cur = 1 - else: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 +def stack_qk_norm(block, name, n_head, norms, n_dims, ftype, layer_name="q_layernorm"): + datas = [] + for i in range(n_head): + ename = f"model.layers.{block}.self_attn.{layer_name}.norms.{i}.weight" + print(f"-----> Merging Tensor {ename} with shape {norms[ename].shape}") + datas.append(norms[ename]) + del norms[ename] + data = np.stack(datas, axis=0) + merged_name = f"model.layers.{block}.self_attn.{layer_name}.weight" + + # ftype == 0 -> float32, ftype == 1 -> float16 + if ftype != 0: + if name.endswith(".weight") and not name.endswith("_norm.weight") and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) else: - if data.dtype != np.float32: - print(" Converting to float32") - data = data.astype(np.float32) - ftype_cur = 0 - - gguf_writer.add_tensor(name, data) + print(" Converting to float32") + data = data.astype(np.float32) + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) - print("gguf: write header") - gguf_writer.write_header_to_file() - print("gguf: write metadata") - gguf_writer.write_kv_data_to_file() - print("gguf: write tensors") - gguf_writer.write_tensors_to_file() + return merged_name, data - gguf_writer.close() - - print("Done. Output file: " + gguf_file) - print("") def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): - n_rot = int(hparams["partial_rotary_factor"] * hparams["hidden_size"] / hparams["num_attention_heads"]) - model.eval() - for p in model.parameters(): - p.requires_grad = False - hparams = model.config.to_dict() + print("stablelm ne converting: ") + list_vars = model.state_dict() + n_head = hparams["num_attention_heads"] + n_head_kv = hparams["num_key_value_heads"] + block_count = hparams["num_hidden_layers"] vocab_size = hparams["vocab_size"] + n_rot = int(hparams["partial_rotary_factor"] * hparams["hidden_size"] / hparams["num_attention_heads"]) print("Model loaded: ", dir_model) - fout = open(fname_out, "wb") + ne_file = fname_out + '.bin' if not fname_out.endswith(".bin") else fname_out + fout = open(ne_file, "wb") # 0x67676d6c is unversioned ne # 0x67676d66 is versioned ggmf (requires token scores) @@ -176,11 +99,11 @@ def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex fout.write(struct.pack("i", 1)) - fout.write(struct.pack("i", hparams["vocab_size"])) + fout.write(struct.pack("i", vocab_size)) fout.write(struct.pack("i", hparams["hidden_size"])) fout.write(struct.pack("i", 0)) - fout.write(struct.pack("i", hparams["num_attention_heads"])) - fout.write(struct.pack("i", hparams["num_key_value_heads"])) # multi-query attention + fout.write(struct.pack("i", n_head)) + fout.write(struct.pack("i", n_head_kv)) # multi-query attention fout.write(struct.pack("i", hparams["num_hidden_layers"])) fout.write(struct.pack("i", n_rot)) fout.write(struct.pack("i", ftype)) @@ -197,7 +120,7 @@ def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used - fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma + fout.write(struct.pack("i", hparams["hidden_size"] // n_head)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", hparams["rope_theta"])) # freq_base fout.write(struct.pack("f", 1.0)) # freq_scale, was removed in config.json (by default=1.0) @@ -223,45 +146,72 @@ def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(text) fout.write(struct.pack("f", -10000)) - list_vars = model.state_dict() + def write_header(name, data, ftype=0): + str = name.encode('utf-8') + n_dims = len(data.shape) + fout.write(struct.pack("iii", n_dims, len(str), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + print(str) + fout.write(str) print(hparams) + q_norms, k_norms = dict(), dict() + for name, data_torch in list_vars.items(): + # Convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + # Skip some tensors + if name.endswith((".attention.rotary_emb.inv_freq")): + continue - for name in list_vars.keys(): - # No gradients for these - list_vars[name].requires_grad = False - src = name - print(src, ' -> ', name) - data = list_vars[src].squeeze().numpy() - data = data.astype(np.float32) - + data = data_torch.squeeze().numpy() + old_dtype = data.dtype n_dims = len(data.shape) - print(name, n_dims, data.shape) + if name.find("q_layernorm.norms") != -1: + q_norms[name] = data + if len(q_norms) >= (block_count * n_head): + for block in range(block_count): + name, data = stack_qk_norm(block, name, n_head, q_norms, n_dims, ftype, layer_name="q_layernorm") + print(f"Processing variable {name} with shape {data.shape}, {old_dtype} --> {data.dtype}") + write_header(name, data) + data.tofile(fout) + continue + if name.find("k_layernorm.norms") != -1: + k_norms[name] = data + if len(k_norms) >= (block_count * n_head_kv): + for block in range(block_count): + name, data = stack_qk_norm(block, name, n_head_kv, k_norms, n_dims, ftype, layer_name="k_layernorm") + print(f"Processing variable {name} with shape {data.shape}, {old_dtype} --> {data.dtype}") + write_header(name, data) + data.tofile(fout) + continue - # default type is fp32 + # ftype == 0 -> float32, ftype == 1 -> float16 ftype_cur = 0 - if ftype == 1 and n_dims > 1: - print(" Converting to float16", data.shape, data[:3, :3].tolist()) - data = data.astype(np.float16) - ftype_cur = 1 + if ftype != 0: + if name.endswith(".weight") and not name.endswith("_norm.weight") and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) else: - print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist()) - data = data.astype(np.float32) + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) # header - str = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) - for i in range(n_dims): - fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) - print(str) - fout.write(str) + write_header(name, data, ftype_cur) # data data.tofile(fout) fout.close() - print("Done. Output file: " + fname_out) + print("Done. Output file: " + ne_file) print("") def main(args_in: Optional[List[str]] = None) -> None: @@ -282,13 +232,6 @@ def main(args_in: Optional[List[str]] = None) -> None: default="huggingface", help="hub to load model" ) - parser.add_argument( - "--format", - type=str, - default="NE", - choices=["NE", "GGUF"], - help="convert to the GGUF or NE format" - ) parser.add_argument( "model", type=Path, @@ -315,11 +258,7 @@ def main(args_in: Optional[List[str]] = None) -> None: print("Loading model: ", dir_model) model = AutoModelForCausalLM.from_pretrained(dir_model, trust_remote_code=True) hparams = model.config.to_dict() - if args.format == "GGUF": - stablelm_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams) - else: - stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) - + stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) if __name__ == '__main__': diff --git a/neural_speed/models/stablelm/stablelm.cpp b/neural_speed/models/stablelm/stablelm.cpp index 4b0dc9935..20c843b37 100644 --- a/neural_speed/models/stablelm/stablelm.cpp +++ b/neural_speed/models/stablelm/stablelm.cpp @@ -74,6 +74,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* const int n_ctx = lctx.n_ctx; const int n_keep = lctx.n_keep; const int n_head = hparams.n_head; + const int n_head_kv = hparams.n_head_kv; const int n_vocab = hparams.n_vocab; const int n_rot = hparams.n_rot; const int head_dim = n_embd / n_head; @@ -101,7 +102,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* attn_shape_t attn_shape = { /* .batch_size = */ 1, /* .head_num = */ n_head, - /* .heads_kv = */ n_head, + /* .heads_kv = */ n_head_kv, /* .head_size = */ head_dim, /* .sl_q = */ N, // Note: make sure that bestla reordered attn supports next token inference /* .sl_kv = */ n_past + N, @@ -110,7 +111,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead", bestla_reordered_attn_fp32_support(&attn_shape))); kv_shape_t kv_shape{ - /* .heads_kv = */ static_cast(n_head), + /* .heads_kv = */ static_cast(n_head_kv), /* .head_size = */ static_cast(head_dim), /* .sl_kv_max = */ static_cast(n_ctx), }; @@ -126,6 +127,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* for (int il = 0; il < n_layer; ++il) { struct ne_tensor* cur; + struct ne_tensor* inpPA; lctx.use_buf(ctx0, 0); @@ -133,30 +135,40 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* // layer_norm { cur = ne_norm(ctx0, inpL, hparams.norm_eps); - // cur = cur*attention_norm(broadcasted) cur = ne_mul(ctx0, cur, model.layers[il].norm[0]); cur = ne_add(ctx0, cur, model.layers[il].norm[1]); } + // Store for parallel MLP layer + inpPA = cur; + // Compute QKV struct ne_tensor* Qcur; struct ne_tensor* Kcur; struct ne_tensor* Vcur; - if (n_layer == 24) { // Stablelm2 1.6B & Stablelm2 Zephyr 1.6B + if (n_layer == 24) { // StableLM-2-1.6B & StableLM-2-Zephyr-1.6B Qcur = - ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), model.layers[il].attn[1]), + ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), model.layers[il].attn[4]), head_dim, n_head, N, 1); Kcur = - ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), model.layers[il].attn[3]), - head_dim, n_head, N, 1); + ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), model.layers[il].attn[5]), + head_dim, n_head_kv, N, 1); Vcur = - ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[4], cur), model.layers[il].attn[5]), - head_dim, n_head, N, 1); - } else { // Stablelm 3B + ne_reshape_4d(ctx0, ne_add(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), model.layers[il].attn[6]), + head_dim, n_head_kv, N, 1); + } else if (n_layer == 32) { // StableLM-3B & Stable-Code-3B + Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_dim, n_head, N, 1); + Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_dim, n_head_kv, N, 1); + Vcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_dim, n_head_kv, N, 1); + } else if (n_layer == 40) { // StableLM-2-12B Qcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_dim, n_head, N, 1); - Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_dim, n_head, N, 1); - Vcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_dim, n_head, N, 1); + Qcur = ne_norm(ctx0, Qcur, hparams.norm_eps); + Qcur = ne_mul(ctx0, Qcur, model.layers[il].attn[4]); + Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_dim, n_head_kv, N, 1); + Kcur = ne_norm(ctx0, Kcur, hparams.norm_eps); + Kcur = ne_mul(ctx0, Kcur, model.layers[il].attn[5]); + Vcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_dim, n_head_kv, N, 1); } // using mode = 2 for GPT-NeoX mode @@ -166,7 +178,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* ne_build_forward_expand(&gf, Qcur_Part); ne_set_name(Qcur, "Qcur"); - struct ne_tensor* Kcur_Part = ne_view_4d(ctx0, ne_permute(ctx0, Kcur, 0, 2, 1, 3), n_rot, n_head, N, 1, + struct ne_tensor* Kcur_Part = ne_view_4d(ctx0, ne_permute(ctx0, Kcur, 0, 2, 1, 3), n_rot, n_head_kv, N, 1, Kcur->nb[1], Kcur->nb[2], Kcur->nb[3], 0); Kcur_Part = ne_rope_inplace(ctx0, Kcur_Part, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); ne_build_forward_expand(&gf, Kcur_Part); @@ -184,14 +196,14 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* for (int i = 0; i < batch_size; ++i) { // batch K Kcur_bs[i] = ne_permute(ctx0, - ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim, + ne_view_4d(ctx0, Kcur, head_dim, n_head_kv, N, 1, ne_element_size(Kcur) * head_dim, ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N, i * ne_element_size(Kcur) * n_embd * N), 0, 2, 1, 3); Kcur_temp = Kcur_bs[i]; ne_set_name(Kcur_bs[i], "kcur_bs"); k_bs[i] = ne_view_4d( - ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim, + ctx0, kv_self.k, head_dim, N, n_head_kv, 1, ne_element_size(kv_self.k) * head_dim, ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block + i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k))); @@ -201,10 +213,10 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* ne_reshape_4d(ctx0, ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd, i * ne_element_size(Vcur) * n_embd * N), - head_dim, n_head, N, 1), + head_dim, n_head_kv, N, 1), 1, 2, 0, 3); v_bs[i] = - ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v), + ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head_kv, 1, n_ctx * ne_element_size(kv_self.v), n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd, ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block + i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v))); @@ -216,17 +228,17 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3); ne_set_name(Q, "Q"); // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) - struct ne_tensor* K = - ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim, - ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, - il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block); + struct ne_tensor* K = ne_view_4d( + ctx0, kv_self.k, head_dim, n_past + N, n_head_kv, batch_size, ne_element_size(kv_self.k) * head_dim, + ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx, + il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block); ne_set_name(K, "K"); // K * Q struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); // KQ_scaled = KQ / sqrt(n_embd/n_head) struct ne_tensor* KQ_scaled = - ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(static_cast((n_embd) / n_head)))); + ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(static_cast(head_dim)))); // KQ_masked = mask_past(KQ_scaled) struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); @@ -236,7 +248,7 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() struct ne_tensor* V = - ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v), + ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head_kv, batch_size, n_ctx * ne_element_size(kv_self.v), n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd, il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block); @@ -255,15 +267,15 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* // store key and value to memory { - const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor - head_dim, n_ctx, n_head, // ne - 0, 0, // nb (bestla managed) - il * k_size); // offset + const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor + head_dim, n_ctx, n_head_kv, // ne + 0, 0, // nb (bestla managed) + il * k_size); // offset ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false)); - const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor - head_dim, n_ctx, n_head, // ne - 0, 0, // nb (bestla managed) - il * v_size); // offset + const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor + head_dim, n_ctx, n_head_kv, // ne + 0, 0, // nb (bestla managed) + il * v_size); // offset ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false)); } @@ -272,14 +284,14 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* struct ne_tensor* K = ne_view_3d(ctx0, kv_self.k, // tensor - head_dim, seq_kv, n_head, // ne + head_dim, seq_kv, n_head_kv, // ne kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (bestla managed) il * k_size); // offset *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // us nb0 for layout ne_set_name(K, "K"); struct ne_tensor* V = ne_view_3d(ctx0, kv_self.v, // tensor - seq_kv, head_dim, n_head, // ne + seq_kv, head_dim, n_head_kv, // ne kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (bestla managed) il * v_size); // offset *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout @@ -291,14 +303,8 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0); } - // projection - { - if (n_layer == 24) { // Stablelm2 1.6B & Stablelm2 Zephyr 1.6B - cur = ne_mul_mat(ctx0, model.layers[il].attn[6], cur); - } else { // Stablelm 3B - cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); - } - } + // out projection gemm + { cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); } } lctx.use_buf(ctx0, 1); @@ -309,9 +315,13 @@ static bool stablelm_model_eval_internal(model_context* ctx, const model_input* { // Post Attention norm { - cur = ne_norm(ctx0, cur, hparams.norm_eps); - cur = ne_mul(ctx0, cur, model.layers[il].norm[2]); - cur = ne_add(ctx0, cur, model.layers[il].norm[3]); + if (n_layer < 40) { + cur = ne_norm(ctx0, cur, hparams.norm_eps); + cur = ne_mul(ctx0, cur, model.layers[il].norm[2]); + cur = ne_add(ctx0, cur, model.layers[il].norm[3]); + } else { + cur = inpPA; // Parallel FFN + } } if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data, diff --git a/neural_speed/models/stablelm/stablelm.h b/neural_speed/models/stablelm/stablelm.h index 3df5b75cb..8309a6ff4 100644 --- a/neural_speed/models/stablelm/stablelm.h +++ b/neural_speed/models/stablelm/stablelm.h @@ -20,16 +20,31 @@ enum stablelm_model { STABLELM_UNKNOWN, - STABLELM_1_6B, + STABLELM_2_1_6B, + STABLELM_2_12B, STABLELM_3B, }; -static const model_scratch stablelm_mem_req(int n_layers) { +static const model_scratch stablelm_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { - case 24: - return {512ull * MB, 512ull * MB, 1026ull * MB}; // StableLM2-1.6B & StableLM2-Zephyr-1.6B - case 32: - return {1024ull * MB, 1024ull * MB, 1026ull * MB}; // StableLM-3B + case 24: // StableLM-2-1.6B & StableLM-2-Zephyr-1.6B + return { + static_cast(scratch_size_ratio * 512) * MB, + static_cast(scratch_size_ratio * 512) * MB, + static_cast(scratch_size_ratio * 1024) * MB, + }; + case 32: // StableLM-3B & Stable-Code-3B + return { + static_cast(scratch_size_ratio * 1024) * MB, + static_cast(scratch_size_ratio * 1024) * MB, + static_cast(scratch_size_ratio * 1024) * MB, + }; + case 40: // StableLM-2-12B + return { + static_cast(scratch_size_ratio * 2560) * MB, + static_cast(scratch_size_ratio * 2560) * MB, + static_cast(scratch_size_ratio * 5120) * MB, + }; default: MODEL_ASSERT(false); } @@ -39,7 +54,7 @@ class stablelm : public IModel { private: model_archs name = MODEL_STABLELM; std::unique_ptr ml; - uint32_t n_layer, n_embd, n_ff, n_vocab; + uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv, n_embd_head_k; int n_ctx, n_gpu_layer; bool use_mmap, use_mlock, vocab_only; model_scratch scratch; diff --git a/neural_speed/models/stablelm/stablelm_utils.cpp b/neural_speed/models/stablelm/stablelm_utils.cpp index e12877995..a7714cdb9 100644 --- a/neural_speed/models/stablelm/stablelm_utils.cpp +++ b/neural_speed/models/stablelm/stablelm_utils.cpp @@ -73,6 +73,9 @@ void stablelm::init(const char* path_model, model_context* ctx, int n_gpu_layer_ n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; + n_head = hparams.n_head; + n_head_kv = hparams.n_head_kv; + n_embd_head_k = hparams.n_embd_head_k; n_embd = hparams.n_embd; scratch = stablelm_mem_req(n_layer); model.scratchs = scratch; @@ -130,24 +133,25 @@ void stablelm::load(model_context* ctx, model_progress_callback progress_callbac layer.norm[1] = ml->get_tensor(layers_i + ".input_layernorm.bias", {n_embd}, backend); // qkv GEMM + out proj GEMM - if (ml->verify_tensor(layers_i + ".self_attn.q_proj.bias")) { // Stablelm2 1.6B & Stablelm2 Zephyr 1.6B - layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend); - layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend); - layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend); - layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend); - layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend); - layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend); - layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend); - } else { // Stablelm 3B - layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend); - layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend); - layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend); - layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend); + layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd_head_k * n_head}, backend); + layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd_head_k * n_head_kv}, backend); + layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd_head_k * n_head_kv}, backend); + layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd_head_k * n_head, n_embd}, backend); + + if (n_layer == 24) { // StableLM-2-1.6B & StableLM-2-Zephyr-1.6B + layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend); + layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend); + layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend); + } else if (n_layer == 40) { // StableLM-2-12B + layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.q_layernorm.weight", {n_embd_head_k, n_head}, backend); + layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.k_layernorm.weight", {n_embd_head_k, n_head_kv}, backend); } - // Post Attention norm - layer.norm[2] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend); - layer.norm[3] = ml->get_tensor(layers_i + ".post_attention_layernorm.bias", {n_embd}, backend); + // Post Attention norm - Only present in 1.6B & 3B + if (n_layer < 40) { + layer.norm[2] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend); + layer.norm[3] = ml->get_tensor(layers_i + ".post_attention_layernorm.bias", {n_embd}, backend); + } // ffn GEMM layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend); @@ -155,17 +159,22 @@ void stablelm::load(model_context* ctx, model_progress_callback progress_callbac layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend); if (backend != NE_BACKEND_CPU) { - if (ml->verify_tensor(layers_i + ".self_attn.q_proj.bias")) { + if (n_layer == 24) { vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.norm[2]) + ne_nbytes(layer.norm[3]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.attn[4]) + ne_nbytes(layer.attn[5]) + ne_nbytes(layer.attn[6]) + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); - } else { + } else if (n_layer == 32) { vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.norm[2]) + ne_nbytes(layer.norm[3]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + } else if (n_layer == 40) { + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.attn[0]) + + ne_nbytes(layer.attn[1]) + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + + ne_nbytes(layer.attn[4]) + ne_nbytes(layer.attn[5]) + ne_nbytes(layer.ffn[0]) + + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); } } } @@ -196,7 +205,9 @@ void stablelm::load(model_context* ctx, model_progress_callback progress_callbac class stablelm_quant_layer : public quant_layer_base { public: quant_params_internal get_layer_config(std::string layername, std::vector ne, ne_type type) override { - bool quantize = layername.rfind("weight") == layername.size() - 6; // ends with 'weight'? + bool quantize = + (layername.rfind("weight") == layername.size() - 6) && + (layername.find("layernorm") == std::string::npos); // quantize if ending with 'weight' && not a layernorm if (layername == "model.embed_tokens.weight") { // special layer process, can be loaded by config file return quant_params_internal(); // return q4_0 to cover the usage of getrow