diff --git a/data/datamodule.py b/data/datamodule.py
index 3051602..a16d68f 100644
--- a/data/datamodule.py
+++ b/data/datamodule.py
@@ -19,6 +19,7 @@ def __init__(
         self.train_set = train_set
         self.val_set = val_set
         self.test_set = test_set
+        self.predict_set = predict_set
         self.batch_size = batch_size
         self.text = text
         self.num_workers = num_workers
@@ -54,6 +55,15 @@ def test_dataloader(self):
             persistent_workers=True,
         )

+    def predict_dataloader(self):
+        return DataLoader(
+            self.predict_set,
+            shuffle=False,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,  # persistent_workers requires num_workers > 0
+            persistent_workers=True,
+        )
+
     def collate_fn(self, batch):
         size = len(batch)
         formulas = [self.text.text2int(i[1]) for i in batch]
diff --git a/data/dataset.py b/data/dataset.py
index d78602a..e6f8175 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -4,6 +4,7 @@
 import torchvision
 from torchvision import transforms as tvt
 import math
+import os


 class LatexDataset(Dataset):
@@ -32,3 +33,28 @@ def __getitem__(self, idx):
         image /= image.max()
         image = self.transform(image)  # transform image to [-1, 1]
         return image, formula
+
+
+class LatexPredictDataset(Dataset):
+    def __init__(self, predict_img_path: str):
+        super().__init__()
+        if predict_img_path:
+            assert os.path.exists(predict_img_path), "Image not found"
+            self.walker = [predict_img_path]
+        else:
+            self.walker = []
+        # assumption: mirror LatexDataset's normalization, mapping [0, 1] inputs to [-1, 1]
+        self.transform = tvt.Normalize((0.5,), (0.5,))
+
+    def __len__(self):
+        return len(self.walker)
+
+    def __getitem__(self, idx):
+        img_path = self.walker[idx]
+
+        image = torchvision.io.read_image(img_path)
+        image = image.to(dtype=torch.float)
+        image /= image.max()
+        image = self.transform(image)  # transform image to [-1, 1]
+
+        return image
diff --git a/image2latex/model.py b/image2latex/model.py
index 4e859e4..26fe42d 100644
--- a/image2latex/model.py
+++ b/image2latex/model.py
@@ -215,4 +215,12 @@ def test_step(self, batch, batch_idx):
         self.log("test_bleu4", bleu4, sync_dist=True)
         self.log("test_exact_match", em, sync_dist=True)

-        return edit_dist, bleu4, em, loss
\ No newline at end of file
+        return edit_dist, bleu4, em, loss
+
+    def predict_step(self, batch, batch_idx):
+        # the predict dataloader yields only image tensors, no target formulas
+        image = batch
+
+        latex = self.model.decode(image, self.max_length)
+
+        return latex
diff --git a/main.py b/main.py
index 66e3759..03ec11c 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,7 @@
 from torch.utils.checkpoint import checkpoint
 from torch import nn, Tensor
 from image2latex.model import Image2LatexModel
-from data.dataset import LatexDataset
+from data.dataset import LatexDataset, LatexPredictDataset
 from data.datamodule import DataModule
 from image2latex.text import Text100k, Text170k
 import pytorch_lightning as pl
@@ -18,13 +18,17 @@
     parser.add_argument("--batch-size", type=int, default=16)
     parser.add_argument("--accumulate-batch", type=int, default=32)
     parser.add_argument("--data-path", type=str, help="data path")
-    parser.add_argument("--img-path", type=str, help="data path")
+    parser.add_argument("--img-path", type=str, help="image folder path")
+    parser.add_argument(
+        "--predict-img-path", type=str, help="path to the image to predict", default=None
+    )
     parser.add_argument(
         "--dataset", type=str, help="choose dataset [100k, 170k]", default="100k"
     )
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--val", action="store_true")
     parser.add_argument("--test", action="store_true")
+    parser.add_argument("--predict", action="store_true")
     parser.add_argument("--log-text", action="store_true")
     parser.add_argument("--train-sample", type=int, default=5000)
     parser.add_argument("--val-sample", type=int, default=1000)
@@ -86,11 +90,18 @@
         n_sample=args.test_sample,
         dataset=args.dataset,
     )
+    predict_set = LatexPredictDataset(predict_img_path=args.predict_img_path)

     steps_per_epoch = round(len(train_set) / args.batch_size)
     total_steps = steps_per_epoch * args.max_epochs
     dm = DataModule(
-        train_set, val_set, test_set, args.num_workers, args.batch_size, text
+        train_set,
+        val_set,
+        test_set,
+        predict_set,
+        args.num_workers,
+        args.batch_size,
+        text,
     )

     model = Image2LatexModel(
@@ -146,3 +157,7 @@
     if args.test:
         print("=" * 10 + "[Test]" + "=" * 10)
         trainer.test(datamodule=dm, model=model, ckpt_path=ckpt_path)
+
+    if args.predict:
+        print("=" * 10 + "[Predict]" + "=" * 10)
+        trainer.predict(datamodule=dm, model=model, ckpt_path=ckpt_path)
diff --git a/report/Img2Latex.pdf b/report/Img2Latex.pdf
new file mode 100644
index 0000000..37635d5
Binary files /dev/null and b/report/Img2Latex.pdf differ
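A possible invocation of the new prediction path (a sketch only: the image filename is a placeholder, --data-path and --img-path are assumed still required because main.py builds the train/val/test sets before predicting, and whatever checkpoint argument main.py already accepts is omitted here):

    python main.py --predict --predict-img-path sample_formula.png --data-path <data-path> --img-path <image-folder> --dataset 100k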