
Image Regression Model Trainer. Built with PyTorch and 🤗.


TonyAssi/ImageRegression


Image Regression


by Tony Assi

This repo handles image regression model training and inference. The trainer fine-tunes google/vit-base-patch16-224 inside a custom PyTorch model that takes an image as input and outputs a single number. You can upload the trained model to the 🤗 Hub and run inference with a simple predict function. Built with 🤗 and PyTorch.
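One plausible way to build "a custom PyTorch model that takes an image as input and outputs a number" on top of ViT is to put a linear head on the [CLS] embedding and train it with mean squared error. The class name ViTRegressor, the use of the [CLS] token, and the MSE loss are assumptions for illustration, not this repo's code. A tiny randomly initialised config stands in for the pretrained google/vit-base-patch16-224 weights so the sketch runs offline.

```python
import torch
from torch import nn
from transformers import ViTConfig, ViTModel


class ViTRegressor(nn.Module):
    """Hypothetical sketch (not this repo's code): ViT backbone + linear regression head."""

    def __init__(self, backbone: ViTModel):
        super().__init__()
        self.backbone = backbone
        # Assumption: map the [CLS] embedding to one scalar with a single linear layer.
        self.head = nn.Linear(backbone.config.hidden_size, 1)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(pixel_values=pixel_values).last_hidden_state
        cls_embedding = hidden[:, 0]                 # (batch, hidden_size)
        return self.head(cls_embedding).squeeze(-1)  # (batch,)


# Tiny random config so the sketch runs offline; in practice the backbone would be
# ViTModel.from_pretrained("google/vit-base-patch16-224") fed 224x224 images.
config = ViTConfig(image_size=32, patch_size=8, hidden_size=64, num_hidden_layers=2,
                   num_attention_heads=4, intermediate_size=128)
model = ViTRegressor(ViTModel(config))

images = torch.randn(2, 3, 32, 32)             # batch of 2 RGB images
targets = torch.tensor([120.0, 45.0])          # e.g. sales values
preds = model(images)                          # one number per image
loss = nn.functional.mse_loss(preds, targets)  # assumed regression loss
loss.backward()
```

Squeezing the last dimension gives one prediction per image, which lines up directly with a column of scalar targets.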

Download

git clone https://github.com/TonyAssi/ImageRegression.git
cd ImageRegression

Installation

pip install -r requirements.txt

Usage

Import

from ImageRegression import train_model, upload_model, predict

Train Model

  • dataset_id: 🤗 dataset id (see Dataset)
  • value_column_name: name of the dataset column that holds the values to predict
  • test_split: fraction of the dataset held out for the test set (e.g. 0.2 = 20%)
  • output_dir: directory where the checkpoints will be saved
  • num_train_epochs: number of training epochs
  • learning_rate: learning rate
train_model(dataset_id='tonyassi/clothing-sales-ds',
            value_column_name='sales',
            test_split=0.2,
            output_dir='./results',
            num_train_epochs=10,
            learning_rate=1e-4)
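test_split controls how much of the dataset is held out for evaluation. The trainer's splitting code isn't shown in this README (🤗 datasets offers Dataset.train_test_split(test_size=...) for this), so the stdlib sketch below, built around a hypothetical helper train_test_indices, only illustrates what a 0.2 split means: shuffle the rows, then reserve 20% of them, rounded to the nearest row, for testing.

```python
import random


def train_test_indices(n_rows: int, test_split: float, seed: int = 42):
    """Hypothetical helper: shuffle row indices and hold out a fraction for testing."""
    if not 0.0 < test_split < 1.0:
        raise ValueError("test_split must be strictly between 0 and 1")
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)       # seeded so the split is reproducible
    n_test = round(n_rows * test_split)        # size of the held-out set
    return indices[n_test:], indices[:n_test]  # (train, test)


train_idx, test_idx = train_test_indices(1000, test_split=0.2)
print(len(train_idx), len(test_idx))  # 800 200
```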

The trainer saves checkpoints to output_dir. The model.safetensors file inside a checkpoint folder holds the trained weights you'll use for inference (prediction).
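The checkpoint-940 name used below suggests the 🤗 Trainer convention checkpoint-<global step>, so output_dir can hold several checkpoint folders. A small hypothetical helper such as latest_checkpoint can find the most recent one; it compares step numbers as integers so that checkpoint-1000 correctly beats checkpoint-940, which plain string sorting would get wrong. transformers also ships transformers.trainer_utils.get_last_checkpoint for the same purpose. Keep in mind that the latest checkpoint is not necessarily the one that scores best on the test split.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Hypothetical helper: return the checkpoint-<step> folder with the highest step."""
    candidates = []
    for path in Path(output_dir).iterdir():
        match = CHECKPOINT_RE.match(path.name)
        if path.is_dir() and match:
            candidates.append((int(match.group(1)), path))  # numeric step, not string
    if not candidates:
        return None
    return max(candidates)[1]


# Usage sketch against a throwaway directory laid out like output_dir.
with tempfile.TemporaryDirectory() as tmp:
    for step in (94, 188, 940):
        (Path(tmp) / f"checkpoint-{step}").mkdir()
    newest = latest_checkpoint(tmp)

print(newest.name)  # checkpoint-940
```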

Upload Model

This function uploads your model to the 🤗 Hub so that the predict function can load it for inference.

  • model_id: name of the model repository to create on the 🤗 Hub
  • token: 🤗 access token with write permission (create a new one in your 🤗 account settings)
  • checkpoint_dir: checkpoint folder to upload
upload_model(model_id='sales-prediction',
             token='YOUR_HF_TOKEN',
             checkpoint_dir='./results/checkpoint-940')

Go to your 🤗 profile to find the uploaded model; its repo id should look similar to tonyassi/sales-prediction.
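Rather than pasting the token into a script you might commit, you can read it from an environment variable. HF_TOKEN is the variable huggingface_hub itself looks for; the helper name hf_token_from_env is hypothetical.

```python
import os


def hf_token_from_env(var_name: str = "HF_TOKEN") -> str:
    """Hypothetical helper: read a 🤗 access token from the environment."""
    token = os.environ.get(var_name, "").strip()
    if not token:
        raise RuntimeError(f"Set the {var_name} environment variable to your 🤗 token")
    return token


# Usage with this repo's upload function:
# upload_model(model_id='sales-prediction',
#              token=hf_token_from_env(),
#              checkpoint_dir='./results/checkpoint-940')
```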

Inference (Prediction)

  • repo_id: 🤗 repo id of the model
  • image_path: path to the image
predict(repo_id='tonyassi/sales-prediction',
        image_path='image.jpg')

The first time this function is called, it downloads the model weights (model.safetensors) from the 🤗 Hub. Subsequent calls run faster.
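To score a whole folder of images you can call predict in a loop; only the first call pays the download cost. This sketch assumes predict returns the predicted value (the README doesn't say whether it returns or prints it). predict_folder is a hypothetical helper that takes the prediction function as an argument, so it can be exercised offline with a stand-in.

```python
import tempfile
from pathlib import Path
from typing import Callable, Dict, Iterable


def predict_folder(folder: str, repo_id: str, predict_fn: Callable[..., float],
                   suffixes: Iterable[str] = (".jpg", ".jpeg", ".png")) -> Dict[str, float]:
    """Hypothetical helper: run predict_fn on every image file in folder."""
    wanted = {s.lower() for s in suffixes}
    results = {}
    for path in sorted(Path(folder).iterdir()):
        if path.is_file() and path.suffix.lower() in wanted:
            results[path.name] = predict_fn(repo_id=repo_id, image_path=str(path))
    return results


# With the real package (assuming predict returns a number):
# from ImageRegression import predict
# predict_folder('./images', 'tonyassi/sales-prediction', predict)

def fake_predict(repo_id: str, image_path: str) -> float:
    """Stand-in predictor for an offline check: returns a dummy value."""
    return float(len(Path(image_path).stem))


with tempfile.TemporaryDirectory() as tmp:
    for name in ("dress.jpg", "hat.PNG", "notes.txt"):
        (Path(tmp) / name).touch()
    scores = predict_folder(tmp, "tonyassi/sales-prediction", fake_predict)

print(scores)  # {'dress.jpg': 5.0, 'hat.PNG': 3.0}
```

Files with other extensions (notes.txt here) are skipped, and the suffix check is case-insensitive.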

Dataset

The model trainer takes a 🤗 dataset id as input, so your dataset must be uploaded to the 🤗 Hub. It should have one column of images and one column of numeric values (floats or ints). If you need help creating a 🤗 dataset, check out Create an image dataset. Your dataset should look like tonyassi/clothing-sales-ds (the values column can have any name; pass that name as value_column_name).
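Before uploading, it can help to confirm that the values column really contains numbers. check_value_column is a hypothetical stdlib helper you could run on the list of values from your column. Besides the ints-or-floats requirement stated above, it also rejects NaN and infinity; that is an extra assumption, on the grounds that a single non-finite target would make a regression loss non-finite.

```python
import math
from numbers import Real


def check_value_column(values):
    """Hypothetical helper: return values as floats, raising on non-numeric or non-finite entries."""
    cleaned = []
    for i, v in enumerate(values):
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(v, bool) or not isinstance(v, Real):
            raise TypeError(f"row {i}: expected int or float, got {type(v).__name__}")
        if not math.isfinite(v):
            raise ValueError(f"row {i}: value must be finite, got {v!r}")
        cleaned.append(float(v))
    return cleaned


print(check_value_column([120, 45.5, 3]))  # [120.0, 45.5, 3.0]
```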

