-
Notifications
You must be signed in to change notification settings - Fork 93
/
Copy pathbase.py
412 lines (363 loc) · 16.6 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
import contextlib
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
getattr_chain,
update_offload_parameter,
)
from loguru import logger
from pydantic import Field, PrivateAttr, field_validator
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier, ModifierFactory
from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization
from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
accumulate_hessian,
make_empty_hessian,
quantize_weight,
)
from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.metric_logging import CompressionLogger
from llmcompressor.utils.pytorch.module import get_no_split_params, qat_active
# names exported by a wildcard import of this module
__all__ = [
    "GPTQModifier",
]
class GPTQModifier(Modifier, HooksMixin):
    """
    Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier
    uses activations to calibrate a hessian matrix, which is then used to determine
    optimal quantization values and orderings for the model weights.

    | Sample yaml:
    | test_stage:
    |    obcq_modifiers:
    |      GPTQModifier:
    |          block_size: 128
    |          dampening_frac: 0.001
    |          offload_hessians: False
    |          config_groups:
    |            group_0:
    |                targets:
    |                  - "Linear"
    |                input_activations: null
    |                output_activations: null
    |                weights:
    |                    num_bits: 8
    |                    type: "int"
    |                    symmetric: true
    |                    strategy: "tensor"
    |                    group_size: 128
    |                    actorder: False

    Lifecycle:
        - on_initialize
            - _check_build_quant_modifier
            - register_hook(module, calibrate_module, "forward")
            - run_sequential / run_layer_sequential / run_basic
                - make_empty_hessian
                - accumulate_hessian
                - on_sequential_batch_end
                    - quantize_weight
        - on_finalize
            - remove_hooks()
            - model.apply(freeze_module_quantization)

    :param sequential_update: deprecated, only True is supported. Passing False
        emits a DeprecationWarning and the value is reset to True
    :param sequential_targets: list of layer names to compress during GPTQ, or
        '__ALL__' to compress every layer in the model. When None, the targets are
        inferred by calling `get_no_split_params` on the model
    :param block_size: Used to determine number of columns to compress in one pass
    :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
        diagonal norm
    :param quantize: Set to True to quantize using an existing quantization modifier,
        or pass in the configuration for a quantization modifier if one does not
        already exist in the recipe
    :param offload_hessians: Set to True for decreased memory usage but increased
        runtime.
    :param config_groups: [Used, if a quantization modifier is not specified],
        dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param scheme: [Used, if a quantization modifier is not specified], the quantization
        scheme to apply to the model, this is a dictionary that supports all keys from
        QuantizationScheme except targets, which will be set to the targets parameter
        set at the modifier level. Can also be set to a dictionary of the format
        `preset_scheme_name: targets`, for example `W8A8: ['Linear']` for weight 8 bit
        and activation 8 bit quantization on the Linear layers, or to a string naming
        a preset scheme if targets is provided.
    :param targets: list of layer names to quantize if a scheme is provided. Defaults
        to Linear layers
    :param ignore: [Used, if a quantization modifier is not specified]
        optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param num_calibration_steps: Number of steps to run post training calibration for.
        When None, the entire calibration_dataloader is used
    :param disable_quantization_observer_epoch: [Used, if a quantization modifier is
        not specified] Epoch to disable updates to the module
        quantization observers. At this point, quantized weights and zero points will
        not be updated. Leave None to not disable observers during QAT. Default is None
    """

    # gptq modifier arguments
    sequential_update: bool = True  # DEPRECATED
    sequential_targets: Union[str, List[str], None] = None
    block_size: int = 128
    dampening_frac: Optional[float] = 0.01
    quantize: Union[bool, Dict] = True
    offload_hessians: bool = False

    # arguments used for attached quant modifier
    config_groups: Optional[Dict[str, QuantizationScheme]] = None
    scheme: Optional[Union[str, Dict[str, Any]]] = None
    targets: Union[str, List[str], None] = None
    ignore: List[str] = Field(default_factory=list)
    num_calibration_steps: Optional[int] = None
    disable_quantization_observer_epoch: Optional[float] = None

    # private variables
    _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr(default=None)
    # module -> qualified name, used when logging which module is being quantized
    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
    # module -> hessian accumulated from its calibration inputs
    _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
    # module -> number of calibration samples folded into its hessian
    _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)

    @field_validator("sequential_update", mode="before")
    @classmethod
    def validate_sequential_update(cls, value: bool) -> bool:
        """
        Force the deprecated `sequential_update` field to True, emitting a
        DeprecationWarning when the unsupported value False was requested

        :param value: user-provided value for `sequential_update`
        :return: always True
        """
        if not value:
            warnings.warn(
                "`sequential_update=False` is no longer supported, setting "
                "sequential_update=True",
                DeprecationWarning,
            )

        return True

    def _check_build_quant_modifier(self, model: torch.nn.Module):
        """
        Check the model's quantization state matches that expected by this modifier,
        adding a default quantization scheme if needed

        On return `self.quantize` is True, except when it was False and the model
        has no active quantization, in which case it is left False.

        # TODO: build modifier during recipe validation

        :param model: model whose current quantization state is inspected
        :raises ValueError: if `quantize` is neither a bool nor a dict with exactly
            one entry
        """
        quantization_already_active = qat_active(model)

        if isinstance(self.quantize, bool):
            if not self.quantize and quantization_already_active:
                logger.warning(
                    "GPTQ quantization is set to False, but a "
                    "quantization modifier is already active on the model "
                    "resetting quantize to True"
                )
                self.quantize = True
            elif self.quantize and not quantization_already_active:
                logger.warning(
                    "GPTQ quantization is set to True without an "
                    "active quantization modifier."
                )
                self._build_quant_modifier()
            return  # use existing quantization modifier if there is one

        if not isinstance(self.quantize, dict):
            raise ValueError(
                "GPTQModifier.quantize accepts only a single "
                "quantization modifier or a boolean. Found "
                f"type {type(self.quantize)}"
            )
        if len(self.quantize) != 1:
            raise ValueError(
                "GPTQModifier.quantize accepts only a single "
                "quantization modifier or a boolean. Found "
                f"{len(self.quantize)} modifiers"
            )

        if quantization_already_active:
            logger.warning(
                "Attempting to initialize quantization for GPTQ "
                "but a quantization modifier has already been applied. "
                "The quantization configuration defined under the "
                "GPTQ modifier will be ignored."
            )
            self.quantize = True
            return

        self._build_quant_modifier_from_dict(self.quantize)
        self.quantize = True

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Initialize and run the GPTQ algorithm on the current state

        Builds and initializes the quantization modifier if needed, registers the
        hessian calibration hook on every module with weight quantization (except
        embeddings), then calibrates and quantizes the model with the sequential
        pipeline. If that fails with a recoverable error it falls back to the
        layer-sequential pipeline, and then to the basic pipeline.

        :param state: session state storing input model and calibration data
        :return: True once a pipeline has run to completion
        :raises ValueError: if quantization is not enabled for this modifier
        """
        # build quantization modifier
        self._check_build_quant_modifier(state.model)

        if self._quantization_modifier is not None:
            self._quantization_modifier.initialize(state, **kwargs)
        if not self.quantize:
            raise ValueError("To use the GPTQModifier, quantization must be enabled.")

        # prepare module names
        self._module_names = {m: name for name, m in state.model.named_modules()}

        # register hooks
        for module in state.model.modules():
            if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                # HACK: previously, embeddings were not quantized because they were not
                # accessible by the layer compressor. For now, we manually ignore it,
                # but in the FUTURE this should be ignored by the user
                if not isinstance(module, torch.nn.Embedding):
                    self.register_hook(module, self.calibrate_module, "forward")

        # infer sequential targets
        if self.sequential_targets is None:
            self.sequential_targets = get_no_split_params(state.model)
        if isinstance(self.sequential_targets, str):
            self.sequential_targets = [self.sequential_targets]

        # infer pipeline
        model_name = state.model.__class__.__name__
        input_names = state.data.calib.dataset.column_names
        # errors which switching to another pipeline cannot resolve
        unfixable_errors = (
            torch.OutOfMemoryError,
            torch._C._LinAlgError,
            KeyboardInterrupt,
        )
        try:
            run_sequential(
                state.model,
                state.data.calib,
                self.sequential_targets,
                self.ignore,
                self,
            )
            return True

        except Exception as exception:
            if isinstance(exception, torch.fx.proxy.TraceError):
                warnings.warn(
                    f"Failed to trace {model_name} with inputs {input_names}. For more "
                    "information on tracing with the sequential pipeline, see "
                    "https://github.com/vllm-project/llm-compressor/blob/main/"
                    "src/llmcompressor/transformers/tracing/GUIDE.md"
                )
            if isinstance(exception, unfixable_errors):
                raise

            warnings.warn("Falling back to layer_sequential pipeline")
            # drop hessians partially accumulated by the failed pipeline so they
            # are not mixed into the fallback's calibration (and free their memory)
            self._reset_calibration_state()
            try:
                run_layer_sequential(
                    state.model,
                    state.data.calib,
                    self.sequential_targets,
                    self,
                )
                return True

            except Exception as exception:
                if isinstance(exception, TypeError):
                    warnings.warn(f"{model_name} fails layer-wise assumptions")
                if isinstance(exception, unfixable_errors):
                    raise

                warnings.warn(
                    "Falling back to basic pipeline, which requires extra memory and "
                    "may result in decreased accuracy. Consider using "
                    "`offload_hessians=True`"
                )
                self._reset_calibration_state()
                run_basic(state.model, state.data.calib, self)
                return True

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Finalize the attached quantization modifier (if one exists), remove the
        calibration hooks, discard any remaining calibration state and apply
        `freeze_module_quantization` to every module of the model

        :param state: session state storing input model and calibration data
        :return: True
        """
        if self._quantization_modifier is not None:
            self._quantization_modifier.finalize(state, **kwargs)

        self.remove_hooks()
        self._reset_calibration_state()

        state.model.apply(freeze_module_quantization)

        return True

    def calibrate_module(
        self,
        module: torch.nn.Module,
        args: Tuple[torch.Tensor, ...],
        _output: torch.Tensor,
    ):
        """
        Forward hook which accumulates the hessian of the module's input. The
        weight itself is quantized later, in `on_sequential_batch_end`.

        :param module: module being calibrated
        :param args: positional inputs of the module's forward pass; the first
            element is treated as the module input
        :param _output: module output, unused
        """
        # Assume that first argument is the input
        inp = args[0]

        # Initialize hessian if not present
        if module not in self._num_samples:
            init_device = (
                "cpu" if self.offload_hessians else get_execution_device(module)
            )
            self._hessians[module] = make_empty_hessian(module, device=init_device)
            self._num_samples[module] = 0

        # Accumulate hessian with input with optional offloading
        with self._maybe_onload_hessian(module):
            self._hessians[module], self._num_samples[module] = accumulate_hessian(
                inp,
                module,
                self._hessians[module],
                self._num_samples[module],
            )

    def on_sequential_batch_end(self):
        """
        Quantize every module which has accumulated calibration samples using its
        hessian, write the quantized weight and quantization parameters back to the
        module, then drop its sample count so it is not quantized again

        TODO: implement with event callback
        """
        # iterate over a snapshot because entries are deleted inside the loop
        for module in list(self._num_samples.keys()):
            name = self._module_names[module]
            num_samples = self._num_samples[module]
            quant_args = getattr_chain(module, "quantization_scheme.weights")

            logger.info(f"Quantizing {name} using {num_samples} samples")
            with (
                torch.no_grad(),
                align_module_device(module),
                self._maybe_onload_hessian(module),
                CompressionLogger(module) as comp_logger,
            ):
                loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                    module=module,
                    quant_args=quant_args,
                    hessians_dict=self._hessians,
                    blocksize=self.block_size,
                    percdamp=self.dampening_frac,
                )
                comp_logger.set_loss(loss)

            update_offload_parameter(module, "weight", quantized_weight)
            update_offload_parameter(module, "weight_scale", scale)
            update_offload_parameter(module, "weight_zero_point", zero_point)
            if g_idx is not None:
                update_offload_parameter(module, "weight_g_idx", g_idx)

            # self._hessians[module] already deleted by quantize_weight
            del self._num_samples[module]

    @contextlib.contextmanager
    def _maybe_onload_hessian(self, module: torch.nn.Module):
        """
        When `offload_hessians` is set, move the module's hessian to the module's
        execution device for the duration of the context and back to CPU afterwards,
        even if the body raises. Otherwise this is a no-op.

        :param module: module whose hessian should be onloaded
        """
        if not self.offload_hessians:
            yield
            return

        device = get_execution_device(module)
        self._hessians[module] = self._hessians[module].to(device=device)
        try:
            yield
        finally:
            # the hessian may have been deleted inside the context (quantize_weight)
            if module in self._hessians:
                self._hessians[module] = self._hessians[module].to(device="cpu")

    def _reset_calibration_state(self):
        """Discard all accumulated hessians and calibration sample counts"""
        self._hessians = dict()
        self._num_samples = dict()

    def _build_quant_modifier(self):
        """
        Build a quantization modifier based on the specified config_groups,
        ignore list, and num_calibration_steps.

        Only arguments which are not None are forwarded, so explicit falsy values
        such as `disable_quantization_observer_epoch=0` are preserved.

        :postcondition: self._quantization_modifier is set to the built
            quantization modifier
        """

        quantization_args_names = [
            "config_groups",
            "targets",
            "scheme",
            "num_calibration_steps",
            "ignore",
            "disable_quantization_observer_epoch",
        ]

        quant_args = {
            key: getattr(self, key)
            for key in quantization_args_names
            if getattr(self, key) is not None
        }

        logger.info(f"Building quantization modifier with args: {quant_args}")
        quant_config = {"QuantizationModifier": quant_args}
        self._build_quant_modifier_from_dict(quant_config)

    def _build_quant_modifier_from_dict(self, quant_config: Dict[str, Dict[str, Any]]):
        """
        Instantiate the quantization modifier described by `quant_config`

        :param quant_config: mapping from a modifier type name to the keyword
            arguments used to construct it; only the first entry is used
        :postcondition: self._quantization_modifier is set to the built modifier
        """
        modifier_type = list(quant_config.keys())[0]
        modifier_args = quant_config[modifier_type]
        self._quantization_modifier = ModifierFactory.create(
            modifier_type,
            allow_registered=True,
            allow_experimental=True,
            **modifier_args,
        )