diff --git a/.gitignore b/.gitignore
index 1777e86..53d473c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,7 +14,7 @@ dist/
 downloads/
 eggs/
 .eggs/
-lib/
+# lib/
 lib64/
 parts/
 sdist/
diff --git a/README.md b/README.md
index 3bee2e0..f6e0292 100644
--- a/README.md
+++ b/README.md
@@ -60,9 +60,9 @@ Don't forget to update the path to the tokenizer in the config file and set `num
 The model consist of a ViT [[1](#References)] encoder with a ResNet backbone and a Transformer [[2](#References)] decoder.
 
 ### Performance
-| BLEU score | normed edit distance |
-| ---------- | -------------------- |
-| 0.88 | 0.10 |
+| BLEU score | normed edit distance | token accuracy |
+| ---------- | -------------------- | -------------- |
+| 0.88 | 0.10 | 0.60 |
 
 ## Data
 We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. [wikipedia](https://www.wikipedia.org), [arXiv](https://www.arxiv.org). We also use the formulae from the [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) [[3](#References)] dataset.
diff --git a/pix2tex/__init__.py b/pix2tex/__init__.py
index e69de29..7c990f2 100644
--- a/pix2tex/__init__.py
+++ b/pix2tex/__init__.py
@@ -0,0 +1,2 @@
+import os
+os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'
diff --git a/pix2tex/dataset/__init__.py b/pix2tex/dataset/__init__.py
index a7ec24a..e69de29 100644
--- a/pix2tex/dataset/__init__.py
+++ b/pix2tex/dataset/__init__.py
@@ -1,6 +0,0 @@
-import pix2tex.dataset.arxiv
-import pix2tex.dataset.extract_latex
-import pix2tex.dataset.latex2png
-import pix2tex.dataset.render
-import pix2tex.dataset.scraping
-import pix2tex.dataset.dataset
diff --git a/pix2tex/dataset/arxiv.py b/pix2tex/dataset/arxiv.py
index 01c98af..5351042 100644
--- a/pix2tex/dataset/arxiv.py
+++ b/pix2tex/dataset/arxiv.py
@@ -1,7 +1,7 @@
 # modified from https://github.com/soskek/arxiv_leaks
 
 import argparse
-import json
+import subprocess
 import os
 import glob
 import re
@@ -10,7 +10,6 @@
 import logging
 import tarfile
 import tempfile
-import chardet
 import logging
 import requests
 import urllib.request
@@ -22,7 +21,7 @@
 # logging.getLogger().setLevel(logging.INFO)
 arxiv_id = re.compile(r'(?
diff --git a/pix2tex/dataset/demacro.py b/pix2tex/dataset/demacro.py
 def bracket_replace(string: str) -> str:
     '''
     replaces all layered brackets with special symbols
@@ -66,7 +62,9 @@ def sweep(t, cmds):
         nargs = int(c[1][1]) if c[1] != r'' else 0
         optional = c[2] != r''
         if nargs == 0:
-            t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
+            num_matches += len(re.findall(r'\\%s([\W_^\dĊ])' % c[0], t))
+            if num_matches > 0:
+                t = re.sub(r'\\%s([\W_^\dĊ])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
         else:
             matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if optional else 0))+r')', t)
             num_matches += len(matches)
@@ -81,18 +79,49 @@ def sweep(t, cmds):
 def unfold(t):
-    t = remove_labels(t).replace('\n', 'Ċ')
-
-    cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}Ċ', t)
+    #t = queue.get()
+    t = t.replace('\n', 'Ċ')
+    t = bracket_replace(t)
+    commands_pattern = r'\\(?:re)?newcommand\*?{\\(.+?)}[\sĊ]*(\[\d\])?[\sĊ]*(\[.+?\])?[\sĊ]*{(.*?)}\s*(?:Ċ|\\)'
+    cmds = re.findall(commands_pattern, t)
+    t = re.sub(r'(?
+            # something went wrong here. No multiple definitions allowed
+            del cmds[i]
+        elif '\\newcommand' in cmds[i][-1]:
+            logging.debug("Command recognition pattern didn't work properly. %s" % (undo_bracket_replace(cmds[i][-1])))
%s" % (undo_bracket_replace(cmds[i][-1]))) + del cmds[i] + start = time.time() + try: + for i in range(10): + # check for up to 10 nested commands + if i > 0: + t = bracket_replace(t) + t, N = sweep(t, cmds) + if time.time()-start > 5: # not optimal. more sophisticated methods didnt work or are slow + raise TimeoutError + t = undo_bracket_replace(t) + if N == 0 or i == 9: + #print("Needed %i iterations to demacro" % (i+1)) + break + elif N > 4000: + raise ValueError("Too many matches. Processing would take too long.") + except ValueError: + pass + except TimeoutError: + pass + except re.error as e: + raise DemacroError(e) + t = remove_labels(t.replace('Ċ', '\n')) + # queue.put(t) + return t + + +def pydemacro(t): + return unfold(convert(re.sub('\n+', '\n', re.sub(r'(? 0: @@ -72,9 +72,12 @@ def recursive_wiki(seeds, depth=4, skip=[]): url = [sys.argv[1]] else: url = ['https://en.wikipedia.org/wiki/Mathematics', 'https://en.wikipedia.org/wiki/Physics'] - visited, math = recursive_wiki(url) + try: + visited, math = recursive_wiki(url) + except KeyboardInterrupt: + pass for l, name in zip([visited, math], ['visited_wiki.txt', 'math_wiki.txt']): - f = open(os.path.join(sys.path[0], 'dataset', 'data', name), 'a', encoding='utf-8') + f = open(os.path.join(sys.path[0], 'data', name), 'a', encoding='utf-8') for element in l: f.write(element) f.write('\n') diff --git a/pix2tex/eval.py b/pix2tex/eval.py index b0735ca..b81c0d6 100644 --- a/pix2tex/eval.py +++ b/pix2tex/eval.py @@ -44,13 +44,12 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i assert len(dataset) > 0 device = args.device log = {} - bleus, edit_dists = [], [] - bleu_score, edit_distance = 0, 1 + bleus, edit_dists, token_acc = [], [], [] + bleu_score, edit_distance, token_accuracy = 0, 1, 0 pbar = tqdm(enumerate(iter(dataset)), total=len(dataset)) for i, (seq, im) in pbar: if seq is None or im is None: continue - tgt_seq, tgt_mask = seq['input_ids'].to(device), seq['attention_mask'].bool().to(device) encoded = model.encoder(im.to(device)) #loss = decoder(tgt_seq, mask=tgt_mask, context=encoded) dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len, @@ -62,7 +61,17 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i ts = post_process(truthi) if len(ts) > 0: edit_dists.append(distance(post_process(predi), ts)/len(ts)) - pbar.set_description('BLEU: %.3f, ED: %.2e' % (np.mean(bleus), np.mean(edit_dists))) + dec = dec.cpu() + tgt_seq = seq['input_ids'][:, 1:] + shape_diff = dec.shape[1]-tgt_seq.shape[1] + if shape_diff < 0: + dec = torch.nn.functional.pad(dec, (0, -shape_diff), "constant", args.pad_token) + elif shape_diff > 0: + tgt_seq = torch.nn.functional.pad(tgt_seq, (0, shape_diff), "constant", args.pad_token) + mask = torch.logical_or(tgt_seq != args.pad_token, dec != args.pad_token) + tok_acc = (dec == tgt_seq)[mask].float().mean().item() + token_acc.append(tok_acc) + pbar.set_description('BLEU: %.3f, ED: %.2e, ACC: %.3f' % (np.mean(bleus), np.mean(edit_dists), np.mean(token_acc))) if num_batches is not None and i >= num_batches: break if len(bleus) > 0: @@ -71,6 +80,9 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i if len(edit_dists) > 0: edit_distance = np.mean(edit_dists) log[name+'/edit_distance'] = edit_distance + if len(token_acc) > 0: + token_accuracy = np.mean(token_acc) + log[name+'/token_acc'] = token_accuracy if args.wandb: # samples pred = token2str(dec, 
@@ -83,7 +95,7 @@
     else:
         print('\n%s\n%s' % (truth, pred))
         print('BLEU: %.2f' % bleu_score)
-    return bleu_score, edit_distance
+    return bleu_score, edit_distance, token_accuracy
 
 
 if __name__ == '__main__':
diff --git a/setup.py b/setup.py
index 41a5d79..ee8a93e 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,7 @@
 
 setuptools.setup(
     name='pix2tex',
-    version='0.0.15',
+    version='0.0.20',
     description='pix2tex: Using a ViT to convert images of equations into LaTeX code.',
     long_description=long_description,
     long_description_content_type='text/markdown',
@@ -58,7 +58,6 @@
         'PyYAML>=5.4.1',
         'pandas>=1.0.0',
         'timm',
-        'chardet>=3.0.4',
         'python-Levenshtein>=0.12.2',
         'torchtext>=0.6.0',
         'albumentations>=0.5.2',
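
Note on the demacro.py hunks above: the rewritten unfold() expands user-defined macros by repeatedly substituting \newcommand bodies back into the source, bounded by an iteration cap (10 passes for nested definitions) and a wall-clock budget. The following is a minimal, self-contained sketch of that expansion loop (assumed semantics, simplified to zero-argument commands and non-nested braces; expand() and its regexes are illustrative, not the patch's actual code):

import re
import time

def expand(source: str, max_passes: int = 10, budget_s: float = 5.0) -> str:
    # collect zero-argument \newcommand / \renewcommand definitions
    definition = r'\\(?:re)?newcommand\*?{\\(.+?)}{(.+?)}\n?'
    cmds = re.findall(definition, source)
    source = re.sub(definition, '', source)  # drop the definitions themselves
    start = time.time()
    for _ in range(max_passes):  # repeat so macros defined via other macros resolve
        hits = 0
        for name, body in cmds:
            pattern = r'\\%s(?![a-zA-Z])' % re.escape(name)
            hits += len(re.findall(pattern, source))
            # escape backslashes so the body survives re.sub's template expansion
            source = re.sub(pattern, body.replace('\\', r'\\'), source)
        if hits == 0 or time.time() - start > budget_s:  # fixed point, or out of budget
            break
    return source

print(expand('\\newcommand{\\eps}{\\varepsilon}\n$\\eps > 0$'))  # -> $\varepsilon > 0$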
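Note on the eval.py hunks: the new token-accuracy metric right-pads the prediction and the ground truth to a common length, then compares them only at positions where at least one of the two holds a real (non-pad) token, so trailing padding cannot inflate the score. A minimal standalone sketch of that logic (assumed semantics; the function name and the toy tensors are illustrative, not part of the patch):

import torch
import torch.nn.functional as F

def token_accuracy(dec: torch.Tensor, tgt_seq: torch.Tensor, pad_token: int) -> float:
    # pad the shorter sequence so both have the same length
    shape_diff = dec.shape[1] - tgt_seq.shape[1]
    if shape_diff < 0:
        dec = F.pad(dec, (0, -shape_diff), 'constant', pad_token)
    elif shape_diff > 0:
        tgt_seq = F.pad(tgt_seq, (0, shape_diff), 'constant', pad_token)
    # keep a position if either tensor carries a non-pad token there
    mask = torch.logical_or(tgt_seq != pad_token, dec != pad_token)
    return (dec == tgt_seq)[mask].float().mean().item()

pred = torch.tensor([[1, 2, 3, 5, 0]])   # 0 is the pad token
truth = torch.tensor([[1, 2, 4, 5, 0]])
print(token_accuracy(pred, truth, pad_token=0))  # 3 of 4 compared positions match -> 0.75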