Commit
Make rxnfp work with more recent transformers / simpletransformers versions, refactor tokenizer
pschwllr committed Aug 11, 2021
1 parent 459c6bb commit 4e4bbbd
Showing 24 changed files with 1,296 additions and 1,330 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -4,7 +4,6 @@

## Install


For all installations, we recommend using `conda` to get the necessary `rdkit` and `tmap` dependencies:

### From pypi
@@ -35,7 +34,7 @@ Compute a fingerprint from a reaction SMILES
```python
```

```
```python
from rxnfp.transformer_fingerprints import (
RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints
)
@@ -57,7 +56,7 @@ print(fp[:5])

Or for a list of reactions:

```
```python
rxns = [example_rxn, example_rxn]
fps = rxnfp_generator.convert_batch(rxns)
print(len(fps), len(fps[0]))
169 changes: 23 additions & 146 deletions docs/fine_tune_bert_on_uspto_1k_tpl.html

Large diffs are not rendered by default.

45 changes: 12 additions & 33 deletions docs/generate_fingerprints.html
@@ -163,34 +163,6 @@ <h3 id="Load-data">Load data<a class="anchor-link" href="#Load-data"> </a></h3>
</div>
</div>

</div>
{% endraw %}

<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
<div class="text_cell_render border-box-sizing rendered_html">
<h3 id="Initialize-fingerprint-generator-and-convert-reactions">Initialize fingerprint generator and convert reactions<a class="anchor-link" href="#Initialize-fingerprint-generator-and-convert-reactions"> </a></h3>
</div>
</div>
</div>
{% raw %}

<div class="cell border-box-sizing code_cell rendered">
<div class="input">

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_default_model_and_tokenizer</span><span class="p">(</span><span class="s1">&#39;bert_ft_10k_25s&#39;</span><span class="p">)</span>
<span class="n">ft_10k_rxnfp_generator</span> <span class="o">=</span> <span class="n">RXNBERTFingerprintGenerator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>
<span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_default_model_and_tokenizer</span><span class="p">(</span><span class="s1">&#39;bert_ft&#39;</span><span class="p">)</span>
<span class="n">ft_rxnfp_generator</span> <span class="o">=</span> <span class="n">RXNBERTFingerprintGenerator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>
<span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_default_model_and_tokenizer</span><span class="p">(</span><span class="s1">&#39;bert_pretrained&#39;</span><span class="p">)</span>
<span class="n">pretrained_rxnfp_generator</span> <span class="o">=</span> <span class="n">RXNBERTFingerprintGenerator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>
</pre></div>

</div>
</div>
</div>

</div>
{% endraw %}

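Judging by the hunk sizes, the hunk above drops the notebook cell that built all three fingerprint generators up front, and the hunks that follow create each generator immediately before it is used. The commit does not state a motivation. A minimal sketch of that load-then-use pattern is shown here; `FakeGenerator` and `load_generator` are hypothetical stand-ins for `RXNBERTFingerprintGenerator` and `get_default_model_and_tokenizer`, and the "fingerprints" they return are made up for illustration:

```python
# Hypothetical stand-ins: these are NOT rxnfp classes or functions.
class FakeGenerator:
    def __init__(self, name):
        self.name = name

    def convert_batch(self, rxns):
        # Fake "fingerprint": [length of model name, length of reaction string].
        return [[len(self.name), len(rxn)] for rxn in rxns]


def load_generator(name):
    # Stand-in for loading a model/tokenizer pair and wrapping it in a generator.
    return FakeGenerator(name)


def fingerprints_per_model(model_names, rxns):
    results = {}
    for name in model_names:
        generator = load_generator(name)  # created right before it is used
        results[name] = generator.convert_batch(rxns)
        del generator  # drop the reference before loading the next model
    return results


rxns = ["CCO>>CC=O", "CC(=O)O.OCC>>CC(=O)OCC"]
fps = fingerprints_per_model(["bert_ft_10k_25s", "bert_pretrained", "bert_ft"], rxns)
```

With real models, only one generator would be referenced at a time under this layout; whether that matters in practice depends on model size and available memory.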
@@ -209,7 +181,10 @@ <h3 id="ft_10k-model">ft_10k model<a class="anchor-link" href="#ft_10k-model"> </a></h3>

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">fps_ft_10k</span> <span class="o">=</span> <span class="n">generate_fingerprints</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">rxn</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">ft_10k_rxnfp_generator</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_default_model_and_tokenizer</span><span class="p">(</span><span class="s1">&#39;bert_ft_10k_25s&#39;</span><span class="p">)</span>
<span class="n">ft_10k_rxnfp_generator</span> <span class="o">=</span> <span class="n">RXNBERTFingerprintGenerator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>

<span class="n">fps_ft_10k</span> <span class="o">=</span> <span class="n">generate_fingerprints</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">rxn</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">ft_10k_rxnfp_generator</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">savez_compressed</span><span class="p">(</span><span class="s1">&#39;../data/fps_ft_10k&#39;</span><span class="p">,</span> <span class="n">fps</span><span class="o">=</span><span class="n">fps_ft_10k</span><span class="p">)</span>
<span class="n">fps_ft_10k</span><span class="o">.</span><span class="n">shape</span>
</pre></div>
@@ -224,7 +199,7 @@ <h3 id="ft_10k-model">ft_10k model<a class="anchor-link" href="#ft_10k-model"> </a></h3>
<div class="output_area">

<div class="output_subarea output_stream output_stderr output_text">
<pre>100%|██████████| 6250/6250 [02:50&lt;00:00, 36.70it/s]
<pre>100%|██████████| 6250/6250 [02:52&lt;00:00, 36.31it/s]
</pre>
</div>
</div>
@@ -283,7 +258,9 @@ <h3 id="pretrained-model">pretrained model<a class="anchor-link" href="#pretrained-model"> </a></h3>

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">fps_pretrained</span> <span class="o">=</span> <span class="n">generate_fingerprints</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">rxn</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">pretrained_rxnfp_generator</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_default_model_and_tokenizer</span><span class="p">(</span><span class="s1">&#39;bert_pretrained&#39;</span><span class="p">)</span>
<span class="n">pretrained_rxnfp_generator</span> <span class="o">=</span> <span class="n">RXNBERTFingerprintGenerator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>
<span class="n">fps_pretrained</span> <span class="o">=</span> <span class="n">generate_fingerprints</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">rxn</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">pretrained_rxnfp_generator</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">savez_compressed</span><span class="p">(</span><span class="s1">&#39;../data/fps_pretrained&#39;</span><span class="p">,</span> <span class="n">fps</span><span class="o">=</span><span class="n">fps_pretrained</span><span class="p">)</span>
<span class="n">fps_pretrained</span><span class="o">.</span><span class="n">shape</span>
</pre></div>
@@ -332,7 +309,9 @@ <h3 id="ft-model">ft model<a class="anchor-link" href="#ft-model"> </a></h3>

<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">fps_ft</span> <span class="o">=</span> <span class="n">generate_fingerprints</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">rxn</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">ft_rxnfp_generator</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">get_default_model_and_tokenizer</span><span class="p">(</span><span class="s1">&#39;bert_ft&#39;</span><span class="p">)</span>
<span class="n">ft_rxnfp_generator</span> <span class="o">=</span> <span class="n">RXNBERTFingerprintGenerator</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>
<span class="n">fps_ft</span> <span class="o">=</span> <span class="n">generate_fingerprints</span><span class="p">(</span><span class="n">df</span><span class="o">.</span><span class="n">rxn</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">ft_rxnfp_generator</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">savez_compressed</span><span class="p">(</span><span class="s1">&#39;../data/fps_ft&#39;</span><span class="p">,</span> <span class="n">fps</span><span class="o">=</span><span class="n">fps_ft</span><span class="p">)</span>
<span class="n">fps_ft</span><span class="o">.</span><span class="n">shape</span>
</pre></div>
@@ -347,7 +326,7 @@ <h3 id="ft-model">ft model<a class="anchor-link" href="#ft-model"> </a></h3>
<div class="output_area">

<div class="output_subarea output_stream output_stderr output_text">
<pre>100%|██████████| 6250/6250 [02:54&lt;00:00, 35.82it/s]
<pre>100%|██████████| 6250/6250 [00:56&lt;00:00, 111.34it/s]
</pre>
</div>
</div>
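The progress-bar lines in this diff can be sanity-checked with simple arithmetic. The sketch below assumes each of the 6250 iterations is one batch of `batch_size=8` reactions (so at most 50,000 reactions; the last batch may be partial) and that the displayed time is the whole-second elapsed time, i.e. iterations divided by the reported rate. The helper names are illustrative only:

```python
def elapsed_seconds(iterations, rate_per_second):
    # Elapsed time implied by an iteration count and an average rate.
    return iterations / rate_per_second


def as_mmss(seconds):
    # Format as MM:SS, truncating fractional seconds.
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


BATCHES = 6250
BATCH_SIZE = 8  # from generate_fingerprints(..., batch_size=8)

# Upper bound on the number of reactions, assuming every batch is full.
max_reactions = BATCHES * BATCH_SIZE

# Rates reported in the progress-bar lines of this diff (it/s).
for rate in (36.70, 36.31, 35.82, 111.34):
    print(f"{rate:>7.2f} it/s -> {as_mmss(elapsed_seconds(BATCHES, rate))}")
```

Under those assumptions the four rates map to 02:50, 02:52, 02:54 and 00:56, matching the times shown in the progress bars. The diff itself gives no explanation for the faster `ft` run.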