
Florence-2 | Default LoRA config produces suboptimal results | Found a better config #162

Open
2 tasks done
patel-zeel opened this issue Feb 16, 2025 · 9 comments
Labels
bug Something isn't working

Comments

@patel-zeel

patel-zeel commented Feb 16, 2025

Search before asking

  • I have searched the Multimodal Maestro issues and found no similar bug report.

Bug

The default LoRA config used in maestro is:

from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
    task_type="CAUSAL_LM",
)

The LoRA config used in the Roboflow "Florence-2 fine-tuning on custom dataset" notebook is the following:

config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    bias="none",
    inference_mode=False,
    use_rslora=True,
    init_lora_weights="gaussian",
    revision=REVISION
)

For the poker-cards-fmjio dataset, maestro's default LoRA config results in an mAP50 of 0.20, while the Roboflow notebook config results in an mAP50 of 0.52. I experimentally found a config that results in an mAP50 of 0.71. Please see the Minimal Reproducible Example section for details.

Environment

  • multimodal-maestro = 1.0.0
  • OS: Ubuntu 20.04
  • Python: 3.10.15

Minimal Reproducible Example

I used four LoRA config variants; the results are described below:

Configs

Maestro default

config = LoraConfig(
  r=8,
  lora_alpha=16,
  target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
  task_type="CAUSAL_LM",
  lora_dropout=0.05,
  bias="none",
)

Maestro default + Gaussian init

config = LoraConfig(
  r=8,
  lora_alpha=16,
  target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
  task_type="CAUSAL_LM",
  lora_dropout=0.05,
  bias="none",
  init_lora_weights="gaussian",
)

Roboflow notebook default

config = LoraConfig(
  r=8,
  lora_alpha=8,
  target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
  task_type="CAUSAL_LM",
  lora_dropout=0.05,
  bias="none",
  inference_mode=False,
  use_rslora=True,
  init_lora_weights="gaussian",
)

Roboflow notebook default except lora_alpha=16

config = LoraConfig(
  r=8,
  lora_alpha=16,
  target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
  task_type="CAUSAL_LM",
  lora_dropout=0.05,
  bias="none",
  inference_mode=False,
  use_rslora=True,
  init_lora_weights="gaussian",
)

Metrics

I used the Roboflow notebook to run the pipeline for 10 epochs and compute the metrics, using the new supervision evaluation API as follows:

import supervision as sv

# predictions and targets come from the notebook's evaluation loop
mean_average_precision = sv.metrics.MeanAveragePrecision().update(predictions, targets).compute()
map50_95 = mean_average_precision.map50_95
map50 = mean_average_precision.map50

p = sv.metrics.Precision().update(predictions, targets).compute()
precision_at_50 = p.precision_at_50

r = sv.metrics.Recall().update(predictions, targets).compute()
recall_at_50 = r.recall_at_50

Results

| Config | mAP50 | mAP50-95 | Precision@50 | Recall@50 |
| --- | --- | --- | --- | --- |
| Maestro default | 0.20 | 0.18 | 0.21 | 0.14 |
| Maestro default + Gaussian init | 0.32 | 0.30 | 0.54 | 0.35 |
| Roboflow notebook default | 0.52 | 0.47 | 0.66 | 0.58 |
| Roboflow notebook default except lora_alpha=16 | 0.71 | 0.65 | 0.78 | 0.75 |

Conclusion

Using lora_alpha=16 in the Roboflow notebook's default LoRA config results in much better performance with the same number of epochs.

Questions

  1. Should maestro give users control over the LoRA config? Users could then experiment with the values and find the config that works best for them. It might be defined in a TOML, JSON, or similar file whose path users pass to the maestro CLI (see the sketch below).
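For illustration only, here is a minimal Python sketch of how such a file-based config could be consumed; the JSON file format, the load_lora_config helper, and the idea of passing the resulting kwargs straight to peft are assumptions, not existing maestro behavior.

import json

from peft import LoraConfig

def load_lora_config(path: str) -> LoraConfig:
    # Hypothetical helper: read user-supplied LoRA parameters from a JSON file,
    # e.g. {"r": 8, "lora_alpha": 16, "lora_dropout": 0.05, "bias": "none"}
    with open(path) as f:
        params = json.load(f)
    # LoraConfig is a dataclass, so unknown keys fail fast with a TypeError
    return LoraConfig(**params)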

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
patel-zeel added the bug (Something isn't working) label on Feb 16, 2025
@SkalskiP
Collaborator

Hi @patel-zeel, great analysis!

> Should maestro give LoRA config control to users?

Absolutely. We just need to develop a consistent system where users can do this both through the SDK and the CLI.

@patel-zeel
Author

Thank you, @SkalskiP. Yes, that'd be great. Having a library that allows fine-tuning VLMs with a single CLI command is fantastic. I'd happily contribute now, or later once the library is relatively stable.

@SkalskiP
Collaborator

One of our main goals right now is to design the solution in a way that remains flexible enough to handle similar fine-tuning scenarios in the future. Your input and code would be tremendously appreciated—whether it’s refining the design, writing new features, or improving documentation. Let’s work together to make it as robust and reusable as possible!

I see 2 main ways we can implement this:

Approach 1: Extend Florence2Configuration with LoRA settings.

SDK Example

from maestro.trainer.models.florence_2.core import train, Florence2Configuration

config = Florence2Configuration(
    dataset="dataset/location",
    epochs=10,
    batch_size=4,
    optimization_strategy="lora",
    metrics=["edit_distance"],
    lora_r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_bias="none",
    lora_init_lora_weights="gaussian",
    lora_use_rslora=True,
    lora_inference_mode=False
)
train(config)

CLI Example

maestro florence_2 train \
  --dataset "dataset/location" \
  --epochs 10 \
  --batch-size 4 \
  --optimization_strategy "lora" \
  --metrics "edit_distance" \
  --lora-r 8 \
  --lora-alpha 16 \
  --lora-dropout 0.05 \
  --lora-bias "none" \
  --lora-init-lora-weights "gaussian" \
  --lora-use-rslora True \
  --lora-inference-mode False

Approach 2: Single "advanced" parameter argument with inlined JSON/YAML

Add a single parameter (e.g., --peft-advanced-params) where users provide JSON or YAML. You parse this into a dictionary and pass it to LoraConfig. This supports all LoRA parameters without adding many new flags.
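A rough sketch of how this could be wired up internally (an assumption, not existing maestro code): merge the user-provided dictionary over the current defaults and pass the result to peft's LoraConfig.

from peft import LoraConfig

# Current maestro defaults, taken from the config quoted at the top of this issue
DEFAULT_LORA_KWARGS = {
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "linear", "Conv2d", "lm_head", "fc2"],
    "task_type": "CAUSAL_LM",
}

def build_lora_config(peft_advanced_params: dict | None = None) -> LoraConfig:
    # User-supplied keys win over the defaults; anything peft does not accept
    # raises a TypeError when LoraConfig is constructed
    kwargs = {**DEFAULT_LORA_KWARGS, **(peft_advanced_params or {})}
    return LoraConfig(**kwargs)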

SDK Example

from maestro.trainer.models.florence_2.core import train, Florence2Configuration

advanced_params = {
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    "init_lora_weights": "gaussian",
    "use_rslora": True
}

config = Florence2Configuration(
    dataset="dataset/location",
    epochs=10,
    batch_size=4,
    optimization_strategy="lora",
    metrics=["edit_distance"],
    peft_advanced_params=advanced_params
)
train(config)

CLI Example

maestro florence_2 train \
  --dataset "dataset/location" \
  --epochs 10 \
  --batch-size 4 \
  --optimization_strategy "lora" \
  --metrics "edit_distance" \
  --peft-advanced-params '{"r":8,"lora_alpha":16,"lora_dropout":0.05,"bias":"none","use_rslora":true}'
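On the CLI side, the flag value could simply be parsed before it reaches the configuration object; this is a sketch under the assumption that the value arrives as a raw JSON string (YAML support would only swap the parser).

import json

raw = '{"r":8,"lora_alpha":16,"lora_dropout":0.05,"bias":"none","use_rslora":true}'
# json.loads turns the flag value into the same dictionary the SDK example uses
advanced_params = json.loads(raw)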

@PawelPeczek-Roboflow @Matvezy @probicheaux what do you think about this?

@patel-zeel
Author

patel-zeel commented Feb 18, 2025

I am leaning toward the second approach so that the API doesn't need to be changed to accommodate every minor change.

@PawelPeczek-Roboflow
Collaborator

👍 on schema-less configs. Plus, it would be great to have the ability to load configs from files and override them with params explicitly given in the CLI command; this way you can have a base config that you quickly modify at will when running a training.
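A minimal sketch of that merge order, assuming a JSON base config file and a dictionary of explicitly passed CLI params (none of this exists in maestro yet):

import json

def resolve_config(config_path: str | None, cli_overrides: dict) -> dict:
    # Start from the file-based config (if any), then let explicit CLI
    # parameters take precedence on conflicting keys
    base = {}
    if config_path:
        with open(config_path) as f:
            base = json.load(f)
    return {**base, **cli_overrides}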

@patel-zeel
Author

patel-zeel commented Feb 18, 2025

@PawelPeczek-Roboflow @SkalskiP Considering the base config idea, a hierarchical config could be unintuitive for users who just want to change a single parameter (e.g., lora_alpha) or a few parameters, which would mostly be the case. In that case, the first option suggested by @SkalskiP (also similar to the Ultralytics API) might work better. Something like:

Default values hardcoded in the library

"epochs": 10,
"batch_size": 4,
"optimization_strategy": "lora",
"metrics": ["edit_distance"],
"lora_r": 8,
"lora_alpha": 8,
"lora_dropout": 0.05,
"lora_bias": "none",
"lora_init_lora_weights": "gaussian",
"lora_use_rslora": true,
"lora_inference_mode": false

SDK Example for one parameter change

from maestro.trainer.models.florence_2.core import train, Florence2Configuration

config = Florence2Configuration(
    dataset="dataset/location",
    lora_alpha=16,  # overrides the default value
)
train(config)

CLI Example for one parameter change

maestro florence_2 train dataset "dataset/location" lora_alpha 16
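A sketch of how such Ultralytics-style "key value" arguments could be turned into overrides on top of the hardcoded defaults (hypothetical parsing, not maestro's current CLI; values arrive as strings and would still need type coercion):

# A subset of the defaults listed above, for illustration
DEFAULTS = {"epochs": 10, "batch_size": 4, "lora_r": 8, "lora_alpha": 8}

def parse_overrides(argv: list[str]) -> dict:
    # argv like ["dataset", "dataset/location", "lora_alpha", "16"]
    overrides = dict(zip(argv[::2], argv[1::2]))
    return {**DEFAULTS, **overrides}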

@SkalskiP
Collaborator

Hi @patel-zeel 👋🏻 sorry for the lack of contact over the past few days. I had to focus on other projects than maestro for a while, but I'm coming back.

I talked to @Matvezy @probicheaux in private messages and I think I'm leaning towards solution number 2, potentially adding support for config files that @PawelPeczek-Roboflow suggested in the future (not as part of this task). @patel-zeel would you like to work on the implementation?

@patel-zeel
Author

> Hi @patel-zeel 👋🏻 sorry for the lack of contact over the past few days. I had to focus on other projects than maestro for a while, but I'm coming back.

No worries, @SkalskiP!

> I talked to @Matvezy @probicheaux in private messages and I think I'm leaning towards solution number 2, potentially adding support for config files that @PawelPeczek-Roboflow suggested in the future (not as part of this task). @patel-zeel would you like to work on the implementation?

Sure. Do you mean we don't want to add config file support just yet, but first enable the dictionary-like support suggested in solution 2?

@SkalskiP
Collaborator

Exactly! 👍🏻
