Skip to content

Commit

Permalink
add predict step
Browse files Browse the repository at this point in the history
  • Loading branch information
tuanio committed Nov 13, 2022
1 parent b405eef commit d84cd8b
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 4 deletions.
9 changes: 9 additions & 0 deletions data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
self.train_set = train_set
self.val_set = val_set
self.test_set = test_set
self.predict_set = predict_set
self.batch_size = batch_size
self.text = text
self.num_workers = num_workers
Expand Down Expand Up @@ -54,6 +55,14 @@ def test_dataloader(self):
persistent_workers=True,
)

def predict_dataloader(self):
return DataLoader(
self.predict_set,
shuffle=False,
batch_size=self.batch_size,
persistent_workers=True,
)

def collate_fn(self, batch):
size = len(batch)
formulas = [self.text.text2int(i[1]) for i in batch]
Expand Down
24 changes: 24 additions & 0 deletions data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torchvision
from torchvision import transforms as tvt
import math
import os


class LatexDataset(Dataset):
Expand Down Expand Up @@ -32,3 +33,26 @@ def __getitem__(self, idx):
image /= image.max()
image = self.transform(image) # transform image to [-1, 1]
return image, formula


class LatexPredictDataset(Dataset):
def __init__(self, predict_img_path: str):
super().__init__()
if predict_img_path:
assert os.path.exists(predict_img_path), "Image not found"
self.walker = [predict_img_path]
else:
self.walker = []

def __len__(self):
return len(self.walker)

def __getitem__(self, index):
img_path = self.walker[idx]

image = torchvision.io.read_image(img_path)
image = image.to(dtype=torch.float)
image /= image.max()
image = self.transform(image) # transform image to [-1, 1]

return image
9 changes: 8 additions & 1 deletion image2latex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,11 @@ def test_step(self, batch, batch_idx):
self.log("test_bleu4", bleu4, sync_dist=True)
self.log("test_exact_match", em, sync_dist=True)

return edit_dist, bleu4, em, loss
return edit_dist, bleu4, em, loss

def predict_step(self, batch, batch_idx):
image = batch

latex = self.model.decode(image, self.max_length)

return latex
21 changes: 18 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.checkpoint import checkpoint
from torch import nn, Tensor
from image2latex.model import Image2LatexModel
from data.dataset import LatexDataset
from data.dataset import LatexDataset, LatexPredictDataset
from data.datamodule import DataModule
from image2latex.text import Text100k, Text170k
import pytorch_lightning as pl
Expand All @@ -18,13 +18,17 @@
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--accumulate-batch", type=int, default=32)
parser.add_argument("--data-path", type=str, help="data path")
parser.add_argument("--img-path", type=str, help="data path")
parser.add_argument("--img-path", type=str, help="image folder path")
parser.add_argument(
"--predict-img-path", type=str, help="image for predict path", default=None
)
parser.add_argument(
"--dataset", type=str, help="choose dataset [100k, 170k]", default="100k"
)
parser.add_argument("--train", action="store_true")
parser.add_argument("--val", action="store_true")
parser.add_argument("--test", action="store_true")
parser.add_argument("--predict", action="store_true")
parser.add_argument("--log-text", action="store_true")
parser.add_argument("--train-sample", type=int, default=5000)
parser.add_argument("--val-sample", type=int, default=1000)
Expand Down Expand Up @@ -86,11 +90,18 @@
n_sample=args.test_sample,
dataset=args.dataset,
)
predict_set = LatexPredictDataset(predict_img_path=args.predict_img_path)

steps_per_epoch = round(len(train_set) / args.batch_size)
total_steps = steps_per_epoch * args.max_epochs
dm = DataModule(
train_set, val_set, test_set, args.num_workers, args.batch_size, text
train_set,
val_set,
test_set,
predict_set,
args.num_workers,
args.batch_size,
text,
)

model = Image2LatexModel(
Expand Down Expand Up @@ -146,3 +157,7 @@
if args.test:
print("=" * 10 + "[Test]" + "=" * 10)
trainer.test(datamodule=dm, model=model, ckpt_path=ckpt_path)

if args.predict:
print("=" * 10 + "[Predict]" + "=" * 10)
trainer.predict(datamodule=dm, model=model, ckpt_path=ckpt_path)
Binary file added report/Img2Latex.pdf
Binary file not shown.

0 comments on commit d84cd8b

Please sign in to comment.