# quantization_format.py

from typing import Optional

from compressed_tensors import CompressionFormat
from compressed_tensors.config import SparsityStructure
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
from compressed_tensors.quantization.utils import (
    is_model_quantized,
    is_module_quantized,
    iter_named_leaf_modules,
)

__all__ = ["infer_quantization_format"]


def infer_quantization_format(
    model,
    quantization_format: Optional[str] = None,
    save_compressed: bool = False,
    sparsity_structure: Optional[str] = None,
) -> Optional[str]:
    """
    Infers the quantization format for a model based on its state and the provided
    compression arguments.

    The following table outlines the possible quantization and sparsity formats
    along with their corresponding compressor formats:

    +---------------+----------+----------------------+---------------------+
    | Quantization  | Sparsity | Quant Compressor     | Sparsity Compressor |
    |               |          | Format               | Format              |
    +---------------+----------+----------------------+---------------------+
    | W8A8 - int    | None     | int_quantized        | Dense               |
    | W8A8 - float  | None     | float_quantized      | Dense               |
    | W4A16 - int   | None     | pack_quantized       | Dense               |
    | W8A16 - int   | None     | pack_quantized       | Dense               |
    | W8A16 - float | None     | naive_quantized      | Dense               |
    | W8A8 - int    | 2:4      | int_quantized        | Sparse24            |
    | W8A8 - float  | 2:4      | float_quantized      | Sparse24            |
    | W4A16 - int   | 2:4      | marlin_24            | Dense               |
    | W8A16 - int   | 2:4      | marlin_24            | Dense               |
    | W8A16 - float | 2:4      | naive_quantized      | Dense               |
    +---------------+----------+----------------------+---------------------+

    :param model: model to check for quantization; if the model is not quantized, no
        quantization format is returned
    :param quantization_format: user-provided quantization format, supersedes any
        inferred quantization format
    :param save_compressed: used to infer a quantization format when
        quantization_format is None
    :param sparsity_structure: optional sparsity structure of the model (e.g. "2:4"),
        used to decide whether the marlin_24 compressor applies
    :return: compression format appropriate for the model
    """
    if not is_model_quantized(model):
        return None

    if quantization_format is not None:
        return quantization_format

    if save_compressed:
        weight_args, input_args = _get_unique_quant_args(model)
        is_24_structure = (
            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
        )
        is_weight_only = len(input_args) == 0 and len(weight_args) > 0

        if is_weight_only:  # w4a16 and w8a16
            is_valid_pack = all(
                weight_arg.num_bits in [4, 8]
                and weight_arg.type == QuantizationType.INT.value
                for weight_arg in weight_args
            )
            if not is_valid_pack:  # packing is only valid for int4 and int8
                return CompressionFormat.naive_quantized
            if is_24_structure:
                for arg in weight_args:
                    if (
                        arg.strategy != QuantizationStrategy.CHANNEL.value
                        and arg.strategy != QuantizationStrategy.GROUP.value
                    ):
                        # the marlin24 kernel is only applicable to channel/group
                        # quantization
                        return CompressionFormat.pack_quantized
                return CompressionFormat.marlin_24
            return CompressionFormat.pack_quantized
        else:  # w8a8 float and int
            if len(weight_args) == 1:
                if (
                    weight_args[0].type == QuantizationType.FLOAT.value
                    and weight_args[0].num_bits == 8
                ):
                    return CompressionFormat.float_quantized
                if weight_args[0].type == QuantizationType.INT.value:
                    return CompressionFormat.int_quantized
            return CompressionFormat.naive_quantized
    else:
        # format will be inferred from the config
        return None
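

# Example usage (a minimal sketch, kept as a comment): `model` below is assumed to be a
# torch.nn.Module that has already been quantized with a compressed-tensors
# quantization config; the variable names are illustrative and not part of this module.
#
#     compression_format = infer_quantization_format(
#         model,
#         save_compressed=True,
#         sparsity_structure="2:4",
#     )
#     # For an int W4A16 model with channel/group-quantized weights and 2:4 sparsity,
#     # this resolves to CompressionFormat.marlin_24 (see the table in the docstring).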


def _get_unique_quant_args(model):
    """
    Gets lists of the unique weight and input activation quantization settings
    present in the model
    """
    quant_info_weight = []
    quant_info_inputs = []
    for _, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            weight_scheme = submodule.quantization_scheme.weights
            input_scheme = submodule.quantization_scheme.input_activations
            if weight_scheme is not None:
                if weight_scheme not in quant_info_weight:
                    quant_info_weight.append(weight_scheme)
            if input_scheme is not None:
                if input_scheme not in quant_info_inputs:
                    quant_info_inputs.append(input_scheme)

    return quant_info_weight, quant_info_inputs
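

# Example (hypothetical, for illustration only): for a model in which every quantized
# Linear layer shares a single W4A16 scheme (4-bit int weights, no input activation
# quantization), _get_unique_quant_args returns one unique weight scheme and an empty
# input list, which infer_quantization_format treats as the weight-only
# pack_quantized path.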