by Tony Assi
Image regression model training and inference. The trainer fine-tunes google/vit-base-patch16-224 inside a custom PyTorch model that takes an image as input and outputs a single number. You can upload the trained model to the 🤗 Hub and run inference with a simple predict function. Built with 🤗 and PyTorch.
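The README doesn't spell out the regression head or the training loss. Here is a minimal sketch, assuming a single linear layer on ViT's [CLS] embedding trained with mean squared error. Both of these are assumptions and have not been confirmed against the repository code:

$$
\hat{y}_i = \mathbf{w}^\top \mathbf{h}_{\text{CLS},i} + b, \qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left(\hat{y}_i - y_i\right)^2
$$

Here $\mathbf{h}_{\text{CLS},i} \in \mathbb{R}^{768}$ is the [CLS] output of vit-base-patch16-224 for image $i$ (ViT-Base has a hidden size of 768). $\mathbf{w}$ and $b$ are the head's learned weights, and $y_i$ is the image's value from the dataset's value column.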
```bash
git clone https://github.com/TonyAssi/ImageRegression.git
cd ImageRegression
pip install -r requirements.txt
```
```python
from ImageRegression import train_model, upload_model, predict
```
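Train the model with train_model, which takes these parameters:
- dataset_id: 🤗 dataset id (see Dataset)
- value_column_name: name of the dataset column that holds the values to predict
- test_split: fraction of the dataset held out for testing (e.g. 0.2 for an 80/20 train/test split)
- output_dir: directory where the checkpoints will be saved
- num_train_epochs: number of training epochs
- learning_rate: learning rate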
```python
train_model(dataset_id='tonyassi/clothing-sales-ds',
            value_column_name='sales',
            test_split=0.2,
            output_dir='./results',
            num_train_epochs=10,
            learning_rate=1e-4)
```
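test_split is the fraction of rows held out for evaluation. The standard-library sketch below shows what a 0.2 split means in practice. split_indices is a hypothetical helper, not part of ImageRegression. It is not confirmed whether the trainer shuffles with a fixed seed or relies on something like the 🤗 datasets method Dataset.train_test_split(test_size=...).

```python
import random


def split_indices(n_rows, test_split, seed=42):
    """Hypothetical helper: shuffle row indices and hold out a test_split fraction."""
    if not 0.0 < test_split < 1.0:
        raise ValueError("test_split must be between 0 and 1")
    indices = list(range(n_rows))
    # Seeded shuffle so the same rows land in the test set on every run
    random.Random(seed).shuffle(indices)
    n_test = int(round(n_rows * test_split))
    # Return (train indices, test indices)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(100, test_split=0.2)
print(len(train_idx), len(test_idx))  # 80 20
```

A fixed seed keeps the held-out rows stable between runs. Without it, evaluation numbers from different training runs would not be directly comparable.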
The trainer saves checkpoints in output_dir. Each checkpoint folder's model.safetensors file contains the trained weights you'll use for inference (prediction).
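upload_model uploads your model to the 🤗 Hub, so predict can later load it by repo id. It takes these parameters:
- model_id: name of the model repo on the 🤗 Hub (without your username)
- token: a 🤗 access token with write access, which you can create in your 🤗 account settings under Access Tokens
- checkpoint_dir: checkpoint folder to upload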
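checkpoint_dir must point at one specific checkpoint-N folder inside output_dir. The folder name (checkpoint-940 in the example) depends on how many steps training ran. The sketch below picks the highest-numbered folder that actually contains model.safetensors. latest_checkpoint is a hypothetical helper, and it assumes the checkpoint-N naming seen in the example.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir):
    """Hypothetical helper: highest-numbered checkpoint-N folder holding model.safetensors."""
    pattern = re.compile(r"checkpoint-(\d+)$")
    candidates = []
    for path in Path(output_dir).iterdir():
        match = pattern.match(path.name)
        # Skip anything that isn't a checkpoint folder with trained weights in it
        if match and path.is_dir() and (path / "model.safetensors").exists():
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint with model.safetensors in {output_dir}")
    # Compare by step number only
    return max(candidates, key=lambda c: c[0])[1]
```

For example, `checkpoint_dir=str(latest_checkpoint('./results'))`. Note that the last checkpoint isn't necessarily the best-performing one. If you track evaluation loss, you may prefer a different folder.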
```python
upload_model(model_id='sales-prediction',
             token='YOUR_HF_TOKEN',
             checkpoint_dir='./results/checkpoint-940')
```
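To avoid pasting your token into source files, you can read it from an environment variable and pass it in explicitly. HF_TOKEN is the variable name huggingface_hub itself recognizes. Whether upload_model would pick it up without a token argument is not confirmed, so this sketch always passes it explicitly. get_hf_token is a hypothetical helper.

```python
import os


def get_hf_token(env_var="HF_TOKEN"):
    """Hypothetical helper: read a 🤗 access token from the environment."""
    token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable to your 🤗 access token")
    return token
```

Usage: `upload_model(model_id='sales-prediction', token=get_hf_token(), checkpoint_dir='./results/checkpoint-940')`.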
Go to your 🤗 profile to find the uploaded model. Its repo id is your username followed by model_id, similar to tonyassi/sales-prediction. This is the repo_id you pass to predict.
Run inference with predict, which takes these parameters:
- repo_id: 🤗 repo id of the model
- image_path: path to the image
```python
predict(repo_id='tonyassi/sales-prediction',
        image_path='image.jpg')
```
The first call to predict downloads the model.safetensors weights from the 🤗 Hub. Subsequent calls run faster.
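The speedup on later calls is consistent with a download-once cache. huggingface_hub's hf_hub_download, for instance, stores files under a local cache directory (by default ~/.cache/huggingface/hub) and reuses them. Whether predict uses that function is an assumption. The sketch below only illustrates the pattern: cached_download is hypothetical, and fetch stands in for the real network download.

```python
from pathlib import Path


def cached_download(repo_id, filename, cache_dir, fetch):
    """Download-once pattern: call fetch only when the file isn't cached yet."""
    # One subfolder per repo so files with the same name don't collide
    local_path = Path(cache_dir) / repo_id.replace("/", "--") / filename
    if not local_path.exists():
        local_path.parent.mkdir(parents=True, exist_ok=True)
        # Slow path: only taken on the first call for this repo/file
        local_path.write_bytes(fetch(repo_id, filename))
    return local_path
```

On the second call with the same repo_id and filename, the function returns the existing local path without calling fetch.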
The trainer takes a 🤗 dataset id as input, so your dataset must be uploaded to the 🤗 Hub. It needs one column of images and one column of numeric values (floats or ints). See Create an image dataset if you need help creating a 🤗 dataset. Your dataset should look like tonyassi/clothing-sales-ds. The value column can have any name; pass that name as value_column_name.
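Before uploading, it can help to confirm that every row has a numeric value. The sketch below treats rows as plain dicts, which is what iterating a 🤗 datasets Dataset yields. The file names in the example rows are stand-ins; a real image column would hold decoded images. check_value_column is a hypothetical helper, not part of ImageRegression.

```python
import numbers


def check_value_column(rows, value_column_name):
    """Hypothetical helper: return indices of rows whose value isn't an int or float."""
    bad = []
    for i, row in enumerate(rows):
        value = row.get(value_column_name)
        # bool is a subclass of int in Python, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, numbers.Real):
            bad.append(i)
    return bad


rows = [{"image": "shirt.jpg", "sales": 120}, {"image": "dress.jpg", "sales": "n/a"}]
print(check_value_column(rows, "sales"))  # [1]
```

An empty result means every row has a usable value. Otherwise, fix or drop the listed rows before creating the dataset.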