whisper_example.py
from datasets import load_dataset
from transformers import AutoProcessor, WhisperForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "openai/whisper-large-v2"

# Load model.
model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
model.config.forced_decoder_ids = None
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.set_prefix_tokens(language="en", task="transcribe")

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp8 with per-channel scales via PTQ
#   * quantize the activations to fp8 with dynamic per-token scales
recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]"
)
sample = ds[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features
input_features = input_features.to(model.device)
output_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
print(processor.batch_decode(output_ids, skip_special_tokens=False)[0])
# Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
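
# Optional sanity check (a hedged sketch, not part of the original example):
# compressed-tensors checkpoints can typically be reloaded through the same
# transformers from_pretrained API when the `compressed-tensors` package is
# installed; whether this reload path is supported on your transformers
# version is an assumption here, so treat the lines below as illustrative.
reloaded = WhisperForConditionalGeneration.from_pretrained(
    SAVE_DIR, device_map="auto", torch_dtype="auto"
)
reloaded_ids = reloaded.generate(input_features.to(reloaded.device), language="en")
print(processor.batch_decode(reloaded_ids, skip_special_tokens=True)[0])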