From c60218e8c7ddb9d15f792f600b75447002897713 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 20 Feb 2025 11:02:33 +0000 Subject: [PATCH] check W_q in state_dict --- hqq/core/quantize.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/hqq/core/quantize.py b/hqq/core/quantize.py index e0a4da0..0870da6 100755 --- a/hqq/core/quantize.py +++ b/hqq/core/quantize.py @@ -680,11 +680,13 @@ def _load_from_state_dict( layer_state_dict[key] = state_dict.pop(prefix + key) else: if(key not in ['bias']): - missing_keys.append(prefix + key) - - layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False) + missing_keys.append(prefix + key) - self.load_state_dict(layer_state_dict, strict=strict) + if 'W_q' in layer_state_dict: + layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False) + self.load_state_dict(layer_state_dict, strict=strict) + else: + missing_keys.append(prefix + "W_q") def load_state_dict(self, state_dict, strict=True, assign=False): if "encoded_state_dict" in state_dict: