Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama3 #30334

Merged
merged 12 commits into from
Apr 24, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@
title: LLaMA
- local: model_doc/llama2
title: Llama2
- local: model_doc/llama3
title: Llama3
- local: model_doc/longformer
title: Longformer
- local: model_doc/longt5
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Flax), PyTorch, and/or TensorFlow.
| [LiLT](model_doc/lilt) | ✅ | ❌ | ❌ |
| [LLaMA](model_doc/llama) | ✅ | ❌ | ✅ |
| [Llama2](model_doc/llama2) | ✅ | ❌ | ✅ |
| [Llama3](model_doc/llama3) | ✅ | ❌ | ✅ |
| [LLaVa](model_doc/llava) | ✅ | ❌ | ❌ |
| [LLaVA-NeXT](model_doc/llava_next) | ✅ | ❌ | ❌ |
| [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
Expand Down
85 changes: 85 additions & 0 deletions docs/source/en/model_doc/llama3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Llama3


## Overview

The Llama3 model was proposed in [Introducing Meta Llama 3: The most capable openly available LLM to date](https://ai.meta.com/blog/meta-llama-3/) by the meta AI team.

The abstract from the blogpost is the following:

*Today, we’re excited to share the first two models of the next generation of Llama, Meta Llama 3, available for broad use. This release features pretrained and instruction-fine-tuned language models with 8B and 70B parameters that can support a broad range of use cases. This next generation of Llama demonstrates state-of-the-art performance on a wide range of industry benchmarks and offers new capabilities, including improved reasoning. We believe these are the best open source models of their class, period. In support of our longstanding open approach, we’re putting Llama 3 in the hands of the community. We want to kickstart the next wave of innovation in AI across the stack—from applications to developer tools to evals to inference optimizations and more. We can’t wait to see what you build and look forward to your feedback.*

Checkout all Llama3 model checkpoints [here](https://huggingface.co/models?search=llama3).
The original code of the authors can be found [here](https://github.com/meta-llama/llama3).

## Usage tips

<Tip warning={true}>

The `Llama3` models were trained using `bfloat16`, but the original inference uses `float16`. The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.

The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online), then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`), and finally, if there is a `torch_dtype` provided in the config, it will be used.

Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`.

</Tip>
Comment on lines +33 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good!


Tips:

- Weights for the Llama3 models can be obtained by filling out [this form](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
- The architecture is exactly the same as Llama2.
- The tokenizer is a BPE model based on [tiktoken](https://github.com/openai/tiktoken) (vs the one based on sentencepiece implementation for Llama2). The main difference that it ignores BPE merge rules when an input token is part of the vocab. This means that if no merge exist to produce `"hugging"`, instead of having the smallest units, like `["hug","ging"] form 2 tokens, if `"hugging"` is part of the vocab, it will be automatically returned as a token.
- The original model uses `pad_id = -1` which means that there is no padding token. We can't have the same logic, make sure to add a padding token using `tokenizer.add_special_tokens({"pad_token":"<pad>"})` and resize the token embedding accordingly. You should also set the `model.config.pad_token_id`. The `embed_tokens` layer of the model is initialized with `self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.config.padding_idx)`, which makes sure that encoding the padding token will output zeros, so passing it when initializing is recommended.
- The original checkpoint can be converted using the [conversion script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). The script can be called with the following (example) command:

```bash
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path --llama_version 3
```

- After conversion, the model and tokenizer can be loaded via:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/output/path")
model = AutoModelForCausalLM.from_pretrained("/output/path")
```

Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). For the 75B model, it's thus 145GB of RAM needed.


- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type.

## Quick usage

```py3
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
pipeline("Hey how are you doing today?")
```

## Resources
A ton of cool resources are already available on the documentation page of [~llama2], inviting contributors to add new recourses curated for Llama3 here! 🤗
93 changes: 93 additions & 0 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,99 @@ def converted(self) -> Tokenizer:
return tokenizer


# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
characters the bpe code barfs on.

The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


class TikTokenConverter:
"""
A general tiktoken converter.
"""

def __init__(
self,
vocab_file=None,
pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
add_prefix_space=False,
*args,
):
super().__init__(*args)
self.vocab_file = vocab_file
self.pattern = pattern
self.add_prefix_space = add_prefix_space

def extract_vocab_merges_from_model(self, tiktoken_url: str):
try:
from tiktoken.load import load_tiktoken_bpe
except Exception:
raise ValueError(
"`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`."
)

bpe_ranks = load_tiktoken_bpe(tiktoken_url)
byte_encoder = bytes_to_unicode()

def token_bytes_to_string(b):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

merges = []
vocab = {}
for token, rank in bpe_ranks.items():
vocab[token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
local = []
for index in range(1, len(token)):
piece_l, piece_r = token[:index], token[index:]
if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
local.append((piece_l, piece_r, rank))
local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=False)
merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
return vocab, merges

def tokenizer(self):
vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
if hasattr(tokenizer.model, "ignore_merges"):
tokenizer.model.ignore_merges = True
return tokenizer

def converted(self) -> Tokenizer:
tokenizer = self.tokenizer()
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
]
)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
return tokenizer


SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"BartTokenizer": RobertaConverter,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@
("lilt", "LiLT"),
("llama", "LLaMA"),
("llama2", "Llama2"),
("llama3", "Llama3"),
("llava", "LLaVa"),
("llava_next", "LLaVA-NeXT"),
("longformer", "Longformer"),
Expand Down
Loading
Loading