Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new features: Safe LoRA #2098

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

chiayi-hsu
Copy link

Hello,

We have published a paper called Safe LoRA (https://arxiv.org/abs/2405.16833).
This work focuses on improving the safety of well-trained LoRA models.
Additionally, I have provided an example implementation in the model.py file to illustrate how to apply the Safe LoRA approach effectively.

@BenjaminBossan BenjaminBossan changed the title Add new features Add new features: Safe LoRA Sep 26, 2024
@BenjaminBossan
Copy link
Member

Thanks for proposing this PR to add the Safe LoRA method to PEFT. I have only skimmed the code and paper, so this is not an in-depth review yet.

Based on what I saw, this is my high level understanding. As a user, I start out by training a LoRA adapter on my dataset using an aligned model as the base model. Next, I want to restore safety which may be reduced after training. For this, I take my trained LoRA model, the base model, and the aligned model, Then Safe LoRA will create a new adapter that "injects" the alignment back into my LoRA adapter. Please correct me if my understanding is incorrect.

Implementation-wise, I wonder if we really need a SafeLoRA class. Here is a proposal for a different user API, LMK what you think:

peft_model = ...  # my trained LoRA model

# apply_safe_lora is a new function that does what SafeLoRA currently does
# additional arguments like select_layers_type are option args for apply_safe_lora
aligned_model = apply_safe_lora(peft_model, base_model_id, aligned_model_id)

# persist the new safe LoRA adapter
aligned_model.save_pretrained(...)

IMO, this would be a simpler API and it would achieve the same thing at the end.

In this example, the aligned_model could actually be the same peft_model as initially, just with a new LoRA adapter loaded which is the Safe LoRA adapter.

Another concern I saw when I read the code is that it looks like it will require a lot of memory. IIUC, we need to have the PEFT model, a copy of the PEFT model, the base model, and the aligned model in memory all at once. Even on CPU RAM, this could be difficult to achieve for many users. I wonder if it would be possible to load the weights from the base model and the aligned model one at a time, at least as an option.

I have already implemented something like this for another use case. Here is the code:

class _SafetensorLoader:
"""
Simple utility class that loads tensors with safetensors from a single file or sharded files.
Takes care of file name normalization etc.
"""
def __init__(self, peft_model, model_path):
if model_path is None:
try:
model_path = snapshot_download(peft_model.base_model.config._name_or_path, local_files_only=True)
except (AttributeError, HFValidationError) as exc:
raise ValueError(
"The provided model does not appear to be a transformers model or is a local model. In this case, "
"you must pass the model_path argument that points to the safetensors file."
) from exc
except LocalEntryNotFoundError as exc:
raise ValueError(
"The model.safetensors file must be present on disk, but it could not be found."
) from exc
suffix = "model.safetensors"
if not model_path.endswith(suffix):
model_path = os.path.join(model_path, suffix)
self.model_path = model_path
self.base_model_prefix = getattr(peft_model.get_base_model(), "base_model_prefix", None)
self.prefix = "base_model.model."
self.is_sharded = False
self.weight_map = None
if not os.path.exists(model_path):
# check if the file is sharded
par_dir = model_path.rpartition(os.path.sep)[0]
try:
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
par_dir, cached_file(par_dir, "model.safetensors.index.json")
)
except OSError as exc:
raise FileNotFoundError(
f"Could not find file for {model_path}, ensure that there is a (sharded) safetensors file of the model."
) from exc
self.is_sharded = True
# maps from 'model-X-of-Y.safetensors' to full file path
file_map = {k.rpartition(os.path.sep)[-1]: k for k in resolved_archive_file}
self.weight_map = {k: file_map[v] for k, v in sharded_metadata["weight_map"].items()}
def get_tensor(self, name):
if not self.is_sharded:
file_path = self.model_path
else:
file_path = self.weight_map[name]
with safe_open(file_path, framework="pt", device="cpu") as f:
try:
tensor = f.get_tensor(name)
except SafetensorError as exc:
# no matching key found, we probably need to remove the base model prefix
if self.base_model_prefix:
# remove 1 extra character for "."
name = name[len(self.base_model_prefix) + 1 :]
tensor = f.get_tensor(name)
else:
raise exc
return tensor
@torch.no_grad()
def replace_lora_weights_loftq(
peft_model,
model_path: Optional[str] = None,
adapter_name: str = "default",
callback: Optional[Callable[[torch.nn.Module, str], bool]] = None,
):
"""
Replace the LoRA weights of a model quantized with bitsandbytes, using the LoftQ technique.
The replacement is done on the fly by loading in the non-quantized weights from a locally stored safetensors model
file and initializing the LoRA weights such that the quantization error between the original and quantized weights
is minimized.
As lazy loading is not possible with pickle, normal PyTorch checkpoint files cannot be supported.
Depending on the model size, calling this function may take some time to finish.
Args:
peft_model (`PeftModel`):
The model to replace the weights of. Must be a quantized PEFT model with LoRA layers.
model_path (`Optional[str]`):
The path to the model safetensors file. If the model is a Hugging Face model, this will be inferred from
the model's config. Otherwise, it must be provided.
adapter_name (`str`):
The name of the adapter to replace the weights of. The default adapter name is "default".
callback (`Optional[Callable[[PeftModel, str], bool]]`):
A callback function that will be called after each module is replaced. The callback function should take
the model and the name of the current module as input and return a boolean indicating whether the
replacement should be kept. If the callback returns False, the replacement will be rolled back. This can be
very useful to confirm that the LoftQ initialization actually decreases the quantization error of the
model. As an example, this callback could generate logits for given input and compare it with the logits
from the original, non-quanitzed model with the same input, and only return `True` if there is an
improvement. As this is a greedy optimization, it's possible that calling this function multiple times
yields incremental improvements.
"""
if not is_bnb_4bit_available():
raise ValueError("bitsandbytes must be installed and the model must be quantized in 4bits.")
from peft.tuners.lora import Linear4bit
# model_path = _check_model_path_loftq(model_path, peft_model)
prefix = "base_model.model."
any_match = False
safetensor_loader = _SafetensorLoader(peft_model, model_path)
# if too slow, consider adding tqdm as an option
for name, module in peft_model.named_modules():
if not isinstance(module, Linear4bit):
continue
if not name.startswith(prefix):
raise TypeError("The passed model does not appear to be a valid PeftModel")
any_match = True
name = name[len(prefix) :]
tensor = safetensor_loader.get_tensor(name + ".weight")
reduced_rank = module.r[adapter_name]
lora_A, lora_B = _loftq_init_new(module.weight, tensor, num_bits=4, reduced_rank=reduced_rank)
if not callback:
module.lora_A[adapter_name].weight.data = lora_A
module.lora_B[adapter_name].weight.data = lora_B
continue
lora_A_before = module.lora_A[adapter_name].weight.data
lora_B_before = module.lora_B[adapter_name].weight.data
module.lora_A[adapter_name].weight.data = lora_A
module.lora_B[adapter_name].weight.data = lora_B
should_replace = callback(peft_model, name)
if not should_replace:
# roll back
module.lora_A[adapter_name].weight.data = lora_A_before
module.lora_B[adapter_name].weight.data = lora_B_before
del lora_A_before, lora_B_before
if not any_match:
raise ValueError("No bnb LoRA module found on the model")

It only works for safetensors, because pickle does not allow to load the weights lazily, but I still think it could be a nice addition.

Also, do we really need a full copy of the PEFT model? Would it be sufficient to only copy the LoRA weights? Ideally, I would like to see a solution that works even if the user has less than twice the memory required for the base model. If this doesn't work, it's also okay, but it would greatly reduce the number of users who can use the method.

@chiayi-hsu
Copy link
Author

Thank you for taking the time to read the code and the paper.

Yes, your understanding is correct.

Since SafeLoRA is not a new training method for LoRA, but rather a process that enhances the safety of a well-trained LoRA model through certain operations, it may not be necessary to have a separate SafeLoRA class.

The most CPU memory-intensive operation here is likely when executing get_aligned_matrix(), as it requires loading both the base model and the aligned model. If it were possible to implement lazy loading for both models simultaneously, allowing for the subtraction of weights in the same position to obtain the aligned matrix, it could potentially be a solution.

Regarding your last question, my current code in projected_weighted() does load the complete PEFT model (aligned model + LoRA). However, it is indeed possible to operate only on the LoRA weights without needing to load the weights of the aligned model as well. Once the user has obtained the new LoRA weights, they can then add them to the original aligned model.

@BenjaminBossan
Copy link
Member

Thanks for confirming and further explaining. Would you be willing to make the suggested changes? I think it would help a lot with user adoption of the new method.

@chiayi-hsu
Copy link
Author

Yes, I will make the changes you suggested.

I would like to ask if SafeLoRA is modified into the form of a function like apply_safelora instead of being a safeLoRA class, should it still be placed under tuners/, or somewhere else?

@BenjaminBossan
Copy link
Member

Yes, I will make the changes you suggested.

Great, thanks.

I would like to ask if SafeLoRA is modified into the form of a function like apply_safelora instead of being a safeLoRA class, should it still be placed under tuners/, or somewhere else?

Good question, I think it should be placed elsewhere. I don't have a strong opinion, how about utils/safelora.py?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants