[Feature] Modify description and api for ascend quantization (#243)
### What this PR does / why we need it?
1. It adds more descriptive docstrings for the classes in quant_config.py.
2. It renames AscendQKVQuantAttentionMethod to AscendKVCacheMethod to
align with the vLLM naming style.
3. It reworks the weight-creation flow that AscendLinearMethod and
AscendKVCacheMethod use in create_weights.


### Does this PR introduce _any_ user-facing change?
Yes. When creating weights, AscendLinearMethod now uses the get_weight,
get_pertensor_param, and get_perchannel_param APIs of the linear
quantization implementation, while AscendKVCacheMethod now passes the
layer into its quantization implementation. A minimal sketch of the
expected interface follows.
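
For illustration only, here is what a linear quantization implementation might expose under this interface. The class name, parameter names, and dtypes are hypothetical assumptions, not part of this PR:

```python
from typing import Dict

import torch


class DummyW8A8LinearMethod:
    """Hypothetical quant implementation matching the new interface."""

    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        # Registered by AscendLinearMethod as ModelWeightParameter
        # (output_dim=0, input_dim=1).
        return {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }

    def get_pertensor_param(
            self, params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        # Registered as PerTensorScaleParameter; the names are illustrative.
        return {
            "input_scale": torch.empty(1, dtype=params_dtype),
            "input_offset": torch.empty(1, dtype=torch.int8),
        }

    def get_perchannel_param(
            self, output_size: int,
            params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        # Registered as ChannelQuantScaleParameter (output_dim=0).
        return {
            "weight_scale": torch.empty(output_size, 1, dtype=params_dtype),
            "weight_offset": torch.empty(output_size, 1, dtype=params_dtype),
        }
```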

### How was this patch tested?
By performing offline inference (see the example below).
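
A rough sketch of such a run; the model path is a placeholder and the exact arguments are assumptions, not taken from this PR:

```python
from vllm import LLM, SamplingParams

# quantization="ascend" selects the config registered via
# @register_quantization_config("ascend"); the model path is a placeholder.
llm = LLM(model="/path/to/ascend-quantized-model", quantization="ascend")

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```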

---------

Signed-off-by: angazenn <[email protected]>
Co-authored-by: angazenn <[email protected]>
Angazenn and angazenn authored Mar 6, 2025
1 parent cff08f9 commit 3217f0d
Showing 2 changed files with 54 additions and 74 deletions.
124 changes: 52 additions & 72 deletions vllm_ascend/quantization/quant_config.py
@@ -30,9 +30,9 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)

from .quantizer import AscendQuantizer

@@ -41,7 +41,11 @@

@register_quantization_config("ascend")
class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend"""
"""Config class for Ascend
This class is a general class that parses quantization configs
that are supported on Ascend hardware.
"""

def __init__(self, quant_config: Dict[str, Any]):
self.quant_description = quant_config
@@ -84,10 +88,10 @@ def get_quant_method(self, layer: torch.nn.Module,
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedLinearMethod()
return AscendLinearMethod(self)
return AscendLinearMethod(self, prefix)
if isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys():
return AscendQKVQuantAttentionMethod(self)
return AscendKVCacheMethod(self, prefix)
return None

def is_layer_skipped_ascend(
@@ -127,13 +131,16 @@ def get_scaled_act_names(self) -> List[str]:
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.
This class calls AscendQuantizer to search for a specific quantization
implementation supported on Ascend hardware for linear methods.
Args:
quant_config: The Ascend quantization config.
"""

def __init__(self, quant_config: AscendQuantConfig) -> None:
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description)
quant_config.quant_description, prefix)
self.quant_method = self.quantizer.build_linear_method()

def create_weights(
@@ -146,57 +153,40 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")

weights = self.quant_method.create_weights(input_size_per_partition,
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)

weight_name = self.quant_method.get_weight()
if weight_name in weights.keys():
for weight_name, weight_param in weight_dict.items():
layer.register_parameter(
weight_name,
ModelWeightParameter(data=weights[weight_name].transpose(0, 1),
ModelWeightParameter(data=weight_param,
input_dim=1,
output_dim=0,
weight_loader=weight_loader))
else:
raise ValueError(
f"{weight_name} is nor registered. Please check your linear quant method implementation."
)

pertensor_names = self.quant_method.get_pertensor_param()
for pertensor_name in pertensor_names:
if pertensor_name in weights.keys():
param = BasevLLMParameter(data=weights[pertensor_name],
weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
else:
raise ValueError(
f"{pertensor_name} is nor registered. Please check your linear quant method implementation."
)

perchannel_names = self.quant_method.get_perchannel_param()
for perchannel_name in perchannel_names:
if perchannel_name in weights.keys():
layer.register_parameter(
perchannel_name,
ChannelQuantScaleParameter(data=weights[perchannel_name],
output_dim=0,
weight_loader=weight_loader))
else:
raise ValueError(
f"{perchannel_name} is nor registered. Please check your linear quant method implementation."
)

pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param,
weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)

perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
layer.register_parameter(
perchannel_name,
ChannelQuantScaleParameter(data=perchannel_param,
output_dim=0,
weight_loader=weight_loader))

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method,
'transpose_weight') and self.quant_method.transpose_weight:
layer.weight.data = layer.weight.data.transpose(1, 0)
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)

def apply(
self,
@@ -210,47 +200,37 @@ def apply(
return self.quant_method.apply(layer, x, bias)


class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
"""Linear method for Ascend quantization.
class AscendKVCacheMethod(BaseKVCacheMethod):
"""KVCache method for Ascend quantization.
This class calls AscendQuantizer to search for a specific quantization
implementation supported on Ascend hardware for kvcache methods.
Args:
quant_config: The Ascend quantization config.
"""

def __init__(self, quant_config: AscendQuantConfig) -> None:
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description)
quant_config.quant_description, prefix)
self.quant_method = self.quantizer.build_attention_method()

def create_weights(self, layer: torch.nn.Module) -> None:
# ascend attention quantization might include some extra weights
# and must be loaded by dummy modules
extra_module_names = self.quant_method.get_extra_module_names()
for name in extra_module_names:
setattr(layer, name, torch.nn.Module())

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
weights = self.quant_method.create_weights(dtype, layer.num_heads,
layer.num_kv_heads)

for name, weight in weights.items():
module_name, weight_name = name.split('.')
module = getattr(layer, module_name)
module.register_parameter(
weight_name, torch.nn.Parameter(weight, requires_grad=False))
# Unlike the linear method, there are no weight processing/slicing
# steps for attention in vLLM, so the whole weight-creation process
# is delegated to the specific quant method.
self.quant_method.create_weights(layer)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)

def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, scale: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
kv_cache: List[torch.Tensor], scale: torch.Tensor,
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, key_cache,
value_cache, scale, seq_lens_tensor_cpu,
return self.quant_method.apply(layer, query, key, value, kv_cache,
scale, seq_lens_tensor_cpu,
block_tables, isPrefill, attn_metadata,
output)
4 changes: 2 additions & 2 deletions vllm_ascend/quantization/quantizer.py
@@ -25,7 +25,7 @@ class AscendQuantizer:
"""An interface to different quantization implementations for ascend hardwares."""

@classmethod
def get_quantizer(cls, quant_config: Dict[str, Any]):
def get_quantizer(cls, quant_config: Dict[str, Any], prefix: str):
# TODO: Need a param to choose quantization algorithms.
quantization_algorithm = ''

@@ -39,7 +39,7 @@ def get_quantizer(cls, quant_config: Dict[str, Any]):
raise NotImplementedError(
"There is no available ascend quantizer.")

return MindIETurboQuantizer.get_quantizer(quant_config)
return MindIETurboQuantizer.get_quantizer(quant_config, prefix)

def build_linear_method(self):
raise NotImplementedError
