Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : fix tokenizer #2315

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ac793a2
Fix for #2023
goerch Jul 21, 2023
8c9d1e7
Fix typo
goerch Jul 21, 2023
9f055e3
Add missing include
goerch Jul 22, 2023
bf665cc
Replace VLA with std::vector
goerch Jul 22, 2023
c8ae817
Add possibly missing typename
goerch Jul 22, 2023
94a0ee1
More testing of the tokenizer
goerch Jul 22, 2023
0e74a72
Added whitespace escaping and unescaping
goerch Jul 22, 2023
e6b1a50
Fix for #2310
goerch Jul 23, 2023
dba8369
One more test case...
goerch Jul 23, 2023
b97a505
Fix C linkage for llama_token_to_str
goerch Jul 24, 2023
81fae1d
Fixing llama_token_to_str for the different sentence_piece token types
goerch Jul 24, 2023
281a4b4
Fixing tests
goerch Jul 24, 2023
a0d28b2
Remove comment
goerch Jul 24, 2023
39c9a3b
Added test cases
goerch Jul 24, 2023
fe7508c
Fix review remarks.
goerch Jul 24, 2023
8253a53
Fix test
goerch Jul 24, 2023
e68580f
Remove llama.cpp.h
goerch Jul 25, 2023
3bdf106
Merge branch 'master' into fix-2023
goerch Jul 25, 2023
b4a5461
Resolve merge conflict with grammar stuff.
goerch Jul 25, 2023
de41d5e
Fix static declarations
goerch Jul 26, 2023
30a0e4c
Fixing function ordering issue
goerch Aug 6, 2023
1b54429
Fix tokenizer regression in convert.py and improve CPP interface for …
goerch Aug 6, 2023
19e950f
Adding support for Aquila (GPT2?) tokenizer.
goerch Aug 6, 2023
bb6a58d
Simplifying an expression.
goerch Aug 6, 2023
5d52192
Remove inactive code.
goerch Aug 6, 2023
38fbb74
Merge branch 'master' into fix-2023
goerch Aug 7, 2023
f1f85de
Split BPE and SentencePiece vocabularies
goerch Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 61 additions & 42 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,58 @@ def load(model_plus: 'ModelPlus') -> 'Params':
return params


class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
self.vocabtype = vocabtype
if self.vocabtype == "bpe":
self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
else:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
class BpeVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens))
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
added_tokens = {}
if self.vocabtype == "bpe":
vocab_size: int = len(self.sentencepiece_tokenizer)
vocab_size: int = len(self.bpe_tokenizer)
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_list = [text for (text, idx) in items]
self.vocab_size_base: int = vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens

def bpe_tokens(self) -> Iterable[Tuple[bytes, float]]:
tokenizer = self.bpe_tokenizer
from transformers.models.gpt2 import tokenization_gpt2
byte_encoder = tokenization_gpt2.bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i, item in enumerate(tokenizer):
text: bytes = item.encode("utf-8")
score: float = -i
yield text, score

def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
for text in self.added_tokens_list:
score = -1000.0
yield text.encode("utf-8"), score

def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
yield from self.bpe_tokens()
yield from self.added_tokens()

def __repr__(self) -> str:
return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
Expand All @@ -267,32 +303,11 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vo

def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
tokenizer = self.sentencepiece_tokenizer
if self.vocabtype == "bpe":
from transformers.models.gpt2 import tokenization_gpt2
byte_encoder = tokenization_gpt2.bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i, item in enumerate(tokenizer):
text: bytes
text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
score: float = -i
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
text: bytes = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
yield text, score
else:
for i in range(tokenizer.vocab_size()):
text: bytes
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
elif tokenizer.is_control(i):
text = b""
elif tokenizer.is_byte(i):
piece = tokenizer.id_to_piece(i)
if len(piece) != 6:
raise Exception(f"Invalid token: {piece}")
byte_value = int(piece[3:-1], 16)
text = struct.pack("B", byte_value)
else:
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
score: float = tokenizer.get_score(i)
yield text, score

def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
for text in self.added_tokens_list:
Expand All @@ -319,7 +334,7 @@ def __repr__(self) -> str:
return f"<GGMLVocab with {self.vocab_size} tokens>"


Vocab = Union[SentencePieceVocab, GGMLVocab]
Vocab = Union[BpeVocab, SentencePieceVocab, GGMLVocab]


def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
Expand Down Expand Up @@ -1044,7 +1059,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
def check_vocab_size(params: Params, vocab: Vocab) -> None:
if params.n_vocab != vocab.vocab_size:
# GGMLVocab comes from the same file as the model so shouldn't mismatch:
assert isinstance(vocab, SentencePieceVocab)
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
if params.n_vocab == vocab.vocab_size_base:
print("Ignoring added_tokens.json since model matches vocab size without it.")
vocab.added_tokens_list = []
Expand Down Expand Up @@ -1093,7 +1108,7 @@ def write_vocab(self, vocab: Vocab) -> None:
@staticmethod
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
of = OutputFile(fname_out)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0, n_kv_head=None)
of = OutputFile(fname_out)
of.write_file_header(params, file_type=GGMLFileType.AllF32)
of.write_vocab(vocab)
Expand Down Expand Up @@ -1228,7 +1243,7 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
return {name: model[name] for name in TENSORS_LIST if name in model}


def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
print(f"vocabtype: {vocabtype}")
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
Expand All @@ -1250,8 +1265,12 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
"if it's in another directory, pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
vocabtype)
if vocabtype == "bpe":
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "spm":
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
else:
raise ValueError(f"Unsupported vocabulary type {vocabtype}")


def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
Expand Down
11 changes: 0 additions & 11 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,17 +633,6 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
return "The";
}

// TODO: not great allocating this every time
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int) add_bos);
const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);

return res;
}

struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();

Expand Down
7 changes: 1 addition & 6 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#define LLAMA_API_CPP // TODO: eliminate me
#include "llama.h"

#include <string>
Expand Down Expand Up @@ -100,12 +101,6 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params);

std::string gpt_random_prompt(std::mt19937 & rng);

//
// Vocab utils
//

std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);

//
// Model utils
//
Expand Down
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
}
fprintf(stderr, "\n");
}
Expand Down
12 changes: 4 additions & 8 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,6 @@ int main(int argc, char ** argv) {

// tokenize the prompt
std::vector<llama_token> embd_inp;

// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');

if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
} else {
Expand Down Expand Up @@ -278,22 +274,22 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
}

if (ctx_guidance) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
}
}

if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
}
fprintf(stderr, "'\n");
}
Expand Down Expand Up @@ -658,7 +654,7 @@ int main(int argc, char ** argv) {
// display text
if (input_echo) {
for (auto id : embd) {
printf("%s", llama_token_to_str(ctx, id));
printf("%s", llama_token_to_str(ctx, id).c_str());
}
fflush(stdout);
}
Expand Down
1 change: 1 addition & 0 deletions examples/quantize-stats/quantize-stats.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "ggml.h"
#include "build-info.h"

#define LLAMA_API_CPP // TODO: eliminate me
#define LLAMA_API_INTERNAL
#include "llama.h"

Expand Down
9 changes: 4 additions & 5 deletions examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ int main(int argc, char ** argv) {
llama_free_model(model);
return 1;
}
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);

auto tokens = llama_tokenize(ctx, params.prompt.c_str(), true);
auto n_prompt_tokens = tokens.size();
if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
llama_free(ctx);
Expand Down Expand Up @@ -92,7 +91,7 @@ int main(int argc, char ** argv) {
auto next_token_str = llama_token_to_str(ctx, next_token);
last_n_tokens_data.push_back(next_token);

printf("%s", next_token_str);
printf("%s", next_token_str.c_str());
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
Expand Down Expand Up @@ -152,7 +151,7 @@ int main(int argc, char ** argv) {
auto next_token_str = llama_token_to_str(ctx2, next_token);
last_n_tokens_data.push_back(next_token);

printf("%s", next_token_str);
printf("%s", next_token_str.c_str());
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
Expand Down
4 changes: 2 additions & 2 deletions examples/simple/simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ int main(int argc, char ** argv)

for( auto id : tokens_list )
{
printf( "%s" , llama_token_to_str( ctx , id ) );
printf( "%s" , llama_token_to_str( ctx , id ).c_str() );
}

fflush(stdout);
Expand Down Expand Up @@ -162,7 +162,7 @@ int main(int argc, char ** argv)
}

// Print the new token :
printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
printf( "%s" , llama_token_to_str( ctx , new_token_id ).c_str() );
fflush( stdout );

// Push this new token for next evaluation :
Expand Down
20 changes: 10 additions & 10 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ggml.h"
#include "common.h"
#include "llama.h"
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -1961,7 +1962,7 @@ void print_matrix(struct ggml_tensor * probs) {


void print_token(struct llama_context * ctx, llama_token token) {
printf("%s", llama_token_to_str(ctx, token));
printf("%s", llama_token_to_str(ctx, token).c_str());
}

void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {
Expand Down Expand Up @@ -2188,29 +2189,28 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
f.read_raw(buf.data(), f.size);
buf[f.size] = '\0';

out.resize(buf.size());

int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false);
if (n_tokens >= 0) {
out.resize(n_tokens);
int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
if (n_tokens < 0) {
out.resize(-n_tokens);
llama_tokenize(lctx, buf.data(), out.data(), out.size(), false);
}

bool verify = false;
if (verify) {
const char * in = buf.data();
const char * end = buf.data() + buf.size();
for (int i = 0; i < (int) out.size(); ++i) {
const char * s = llama_token_to_str(lctx, out[i]);
int len = strlen(s);
std::string s = llama_token_to_str(lctx, out[i]);
int len = s.length();
if (in >= end) {
printf("%s: unexpected end of original text.\n", __func__);
break;
}
const bool matches = (strncmp(in, s, len) == 0);
const bool matches = (strncmp(in, s.c_str(), len) == 0);
if (matches) {
in += len;
} else {
printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s);
printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str());
}
}
}
Expand Down
Loading