Skip to content

Commit

Permalink
add load_peft_adapter helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
mobicham committed Feb 19, 2025
1 parent bab90a6 commit 0bd8d2f
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions hqq/utils/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,35 @@ def forward_updated(self, x: Tensor) -> Tensor:
layer.forward = lambda x: forward_updated(layer, x)

return layer


#Loads Peft-HF adapter into an HQQ model
def load_peft_adapter(model, adapter_dir):
    """Load a HuggingFace PEFT (LoRA) adapter into an HQQ model.

    Args:
        model: the (HQQ-quantized) base model to wrap with the adapter.
        adapter_dir: directory containing ``adapter_config.json`` and
            ``adapter_model.safetensors`` as produced by
            ``PeftModel.save_pretrained``.

    Returns:
        The PEFT-wrapped model with the adapter weights assigned.
    """
    import os
    from safetensors import safe_open
    from peft import get_peft_model, PeftConfig

    # Wrap the base model with the saved PEFT config. This creates the (empty)
    # adapter modules under the default adapter name. NOTE: get_peft_model
    # returns a new wrapper object, so we must return it at the end — the
    # original code dropped it, leaving the caller with the unwrapped model.
    model = get_peft_model(model, PeftConfig.from_pretrained(adapter_dir))

    # Collect the saved adapter weights, grouped by owning module name.
    # Checkpoint keys look like "...lora_A.weight" while the live module path
    # is "...lora_A.default.weight", hence the '.default' suffix insertion.
    tensors = {}
    adapter_file = os.path.join(adapter_dir, "adapter_model.safetensors")
    with safe_open(adapter_file, framework="pt", device=model.device.type) as f:
        for key in f.keys():
            base, param = key.rsplit(".", 1)
            base += ".default"
            tensors.setdefault(base, {})[param] = torch.nn.Parameter(f.get_tensor(key))

    # Tag every module with its fully-qualified name so the recursive pass
    # below can match modules against the checkpoint keys.
    for name, module in model.named_modules():
        module.name = name

    # Recursively walk the module tree and assign the loaded parameters onto
    # their owning modules.
    def _patch(module):
        for _, child in module.named_children():
            if child.name in tensors:
                for p, v in tensors[child.name].items():
                    setattr(child, p, v)
            _patch(child)

    _patch(model)

    return model

0 comments on commit 0bd8d2f

Please sign in to comment.