Skip to content

Commit

Permalink
check W_q in state_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
mobicham committed Feb 20, 2025
1 parent 0bd8d2f commit c60218e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions hqq/core/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,11 +680,13 @@ def _load_from_state_dict(
layer_state_dict[key] = state_dict.pop(prefix + key)
else:
if(key not in ['bias']):
missing_keys.append(prefix + key)

layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False)
missing_keys.append(prefix + key)

self.load_state_dict(layer_state_dict, strict=strict)
if 'W_q' in layer_state_dict:
layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False)
self.load_state_dict(layer_state_dict, strict=strict)
else:
missing_keys.append(prefix + "W_q")

def load_state_dict(self, state_dict, strict=True, assign=False):
if "encoded_state_dict" in state_dict:
Expand Down

0 comments on commit c60218e

Please sign in to comment.