Skip to content

Commit

Permalink
Merge pull request #53 from jmisilo/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
jmisilo authored Nov 6, 2022
2 parents 6980a52 + ef9922b commit 9a80cec
Show file tree
Hide file tree
Showing 14 changed files with 53 additions and 22 deletions.
Binary file modified examples/23012796.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/36979.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed examples/3787801.jpg
Binary file not shown.
Binary file removed examples/7757242158.jpg
Binary file not shown.
Binary file added examples/89407459.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/loss_lr.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 10 additions & 7 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ The Model uses prefixes as in the [ClipCap](https://arxiv.org/abs/2111.09734) pa

The Model was trained with a frozen CLIP, a fully trained Mapping Module (6x Transformer Encoder Layers) and with partially frozen GPT-2 (the first and last 14 layers were trained).

The training process was carried out using the [Kaggle](https://www.kaggle.com/) P100 GPU. Training time is about 2 x 11h (106 epochs) with a linearly changing learning rate (from 0 to 0.0001908) and batch size 64. Originally, the Model was supposed to be trained longer - which results in a non-standard LR. *I also tried a longer training session (150 epochs), but overtraining was noticeable.*
The training process was carried out using the [Kaggle](https://www.kaggle.com/) P100 GPU. Training time - about 3 x 11h (150 epochs) with a linear learning rate warmup (max LR `3e-3`) and batch size 64.

### Example results

![Example1](./examples/23012796.jpg)
#### Loss and Learning Rate during training

![Example2](./examples/3787801.jpg)
![LOSSxLR](./examples/loss_lr.jpg)

![Example3](./examples/7757242158.jpg)
### Example results

As I said, the goal was to test the Model's ability to recognize the situation. In the next phase of the experiments, I will try to improve the Model process and parameters to achieve better captions with the same dataset.
![Example1](./examples/23012796.jpg)
![Example2](./examples/36979.jpg)
![Example3](./examples/89407459.jpg)

### Usage

Expand All @@ -36,7 +36,10 @@ Create environment and install requirements:

```bash
python -m venv venv
# for windows
.\venv\Scripts\activate
# for linux/mac
source venv/bin/activate

pip install -r requirements.txt
```
Expand Down
10 changes: 9 additions & 1 deletion src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
help='Path to the results folder'
)

parser.add_argument(
'-T',
'--temperature',
type=float,
default=1.0,
help='Temperature for sampling'
)

args = parser.parse_args()

ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)
Expand Down Expand Up @@ -93,4 +101,4 @@
if not os.path.exists(save_path):
os.mkdir(save_path)

evaluate_dataset(model, test_dataset, args.img_path, save_path)
evaluate_dataset(model, test_dataset, args.img_path, save_path, args.temperature)
4 changes: 2 additions & 2 deletions src/model/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_step(model, dataset, img_path, num_examples=4):

return Image.open(buf)

def evaluate_dataset(model, dataset, img_path, save_path):
def evaluate_dataset(model, dataset, img_path, save_path, temperature=1.0):
'''
Evaluate model on dataset.
Expand All @@ -147,7 +147,7 @@ def evaluate_dataset(model, dataset, img_path, save_path):
img = Image.open(os.path.join(img_path, img_name))

with torch.no_grad():
caption, _ = model(img)
caption, _ = model(img, temperature)

plt.imshow(img)
plt.title(caption)
Expand Down
9 changes: 7 additions & 2 deletions src/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def freeze_layers(self):
for p in [*list(self.ie.parameters()), *list(self.td.parameters())[14:-14]]: # freeze everything, except 1st and last transformer layer in Decoder
p.requires_grad = False

def forward(self, img):
def forward(self, img, temperature=1.0):
'''
Caption generation for a single image.
Expand All @@ -157,7 +157,10 @@ def forward(self, img):
caption: generated caption [str]
tokens: generated tokens [torch.Tensor]
'''
# only one image at a time

if temperature <= 0.0:
temperature = 1.0
print('Temperature must be positive. Setting it to 1.0')

with torch.no_grad():
img_embedded = self.ie(img)
Expand Down Expand Up @@ -188,6 +191,8 @@ def forward(self, img):
emb += pos_emb
pred = self.td(emb)

pred = torch.softmax(pred / temperature, dim=-1)

_, pred = torch.max(pred, dim=1)

last_token = pred[-1].item()
Expand Down
10 changes: 9 additions & 1 deletion src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
help='Path to the results folder'
)

parser.add_argument(
'-T',
'--temperature',
type=float,
default=1.0,
help='Temperature for sampling'
)

args = parser.parse_args()

# set seed
Expand Down Expand Up @@ -87,7 +95,7 @@
model.eval()

with torch.no_grad():
caption, _ = model(img)
caption, _ = model(img, args.temperature)

plt.imshow(img)
plt.title(caption)
Expand Down
15 changes: 11 additions & 4 deletions src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@
scaler = torch.cuda.amp.GradScaler()

ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)
start_epoch = load_ckp(ckp_path, model, optimizer, scheduler, scaler, device) if os.path.isfile(ckp_path) else 0
start_epoch, total_train_loss, total_valid_loss = (
load_ckp(ckp_path, model, optimizer, scheduler, scaler, device)
if os.path.isfile(ckp_path) else
(0, [], [])
)

# build train model process with experiment tracking from wandb
wandb.init(project='clipXgpt2 captioner', config=config.__dict__)
Expand All @@ -107,19 +111,22 @@
'examples': wandb.Image(test_results)
})

total_train_loss.append(train_loss)
total_valid_loss.append(valid_loss)

if not os.path.exists(config.weights_dir):
os.makedirs(config.weights_dir)

if (epoch + 1) % 10 == 0:
torch.save(
{
'epoch': epoch,
'train_loss': train_loss,
'valid_loss': valid_loss,
'model1_state_dict': model.state_dict(),
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'scaler_state_dict': scaler.state_dict(),
'tloss': total_train_loss,
'vloss': total_valid_loss
},
os.path.join(config.weights_dir, f'epoch_{epoch}.pt')
)
4 changes: 2 additions & 2 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class Config:
num_workers: int = 2
train_size: int = 0.84
val_size: int = 0.13
epochs: int = 200
lr: int = 6e-3
epochs: int = 150
lr: int = 3e-3
k: float = 0.33
batch_size_exp: int = 6
ep_len: int = 4
Expand Down
6 changes: 3 additions & 3 deletions src/utils/load_ckp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def load_ckp(checkpoint_fpath, model, optimizer=None, scheduler=None, scaler=Non

checkpoint = torch.load(checkpoint_fpath, map_location=device)

model.load_state_dict(checkpoint['model1_state_dict'])
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

Expand All @@ -22,11 +22,11 @@ def load_ckp(checkpoint_fpath, model, optimizer=None, scheduler=None, scaler=Non
if scaler is not None:
scaler.load_state_dict(checkpoint['scaler_state_dict'])

return checkpoint['epoch']
return checkpoint['epoch'], checkpoint['tloss'], checkpoint['vloss']

def download_weights(checkpoint_fpath):
'''
Downloads weights from Google Drive.
'''

gdown.download('https://drive.google.com/uc?id=1lEufQVOETFEIhPdFDYaez31uroq_5Lby', checkpoint_fpath, quiet=False)
gdown.download('https://drive.google.com/uc?id=10ieSMMJzE9EeiPIF3CMzeT4timiQTjHV', checkpoint_fpath, quiet=False)

0 comments on commit 9a80cec

Please sign in to comment.