+
+
+
+
+
+
\ No newline at end of file
diff --git a/Likunlin_final/analyse_text/tests.py b/Likunlin_final/analyse_text/tests.py
new file mode 100644
index 00000000000000..7ce503c2dd97ba
--- /dev/null
+++ b/Likunlin_final/analyse_text/tests.py
@@ -0,0 +1,3 @@
+from django.test import TestCase
+
+# Create your tests here.
diff --git a/Likunlin_final/analyse_text/views.py b/Likunlin_final/analyse_text/views.py
new file mode 100644
index 00000000000000..e40c6ca8d467f3
--- /dev/null
+++ b/Likunlin_final/analyse_text/views.py
@@ -0,0 +1,27 @@
+from django.shortcuts import render
+# -*- coding: utf-8 -*-
+from django.shortcuts import render
+from django.http import HttpResponse
+import json
+import sys
+sys.path =['/home/xd/projects/pytorch-pretrained-BERT'] + sys.path
+from likunlin_final import analyze_text,modify_text
+
+text = []
+def home(request):  # View: responds to the request by rendering the 'home.html' template.
+ return render(request, 'home.html')
+
+
+def analyse(request):  # View: runs analyze_text on the 'text' query parameter and returns tokens, suggestions and avg_gap as a JSON body.
+ global text  # rebinds the module-level `text`, which modify() reads afterwards; NOTE(review): this state is shared by every request handled in the process
+ text = request.GET['text']  # direct indexing: a request without a 'text' parameter raises (Django's MultiValueDictKeyError) instead of returning a 4xx
+ text = [text]  # analyze_text is handed a one-element list wrapping the submitted string
+ print("xiaofang")  # NOTE(review): looks like a leftover debug print - confirm and remove
+ suggestions,tokens,avg_gap = analyze_text(text)  # analyze_text comes from the project-local likunlin_final module; it is unpacked as a 3-tuple
+ return HttpResponse(json.dumps({"tokens":tokens,"suggestions":suggestions,"avg_gap":avg_gap}))  # JSON string body; no explicit content type is passed here
+
+def modify(request):  # View: calls modify_text with the 'index' query parameter and the stored `text`, then returns the new tokens and suggestions as a JSON body.
+ global text  # reads and rebinds the module-level `text`; NOTE(review): if analyse() has not run in this process, `text` is still the initial empty list
+ index = request.GET['index']  # direct indexing: raises when the 'index' parameter is missing
+ text,new_tokens,suggestions = modify_text(int(index),text)  # int() raises ValueError on a non-numeric index; the first returned element replaces the stored text
+ return HttpResponse(json.dumps({"tokens":new_tokens,"suggestions":suggestions}))
diff --git a/Likunlin_final/manage.py b/Likunlin_final/manage.py
new file mode 100755
index 00000000000000..30c456de702310
--- /dev/null
+++ b/Likunlin_final/manage.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+"""Django's command-line utility for administrative tasks."""
+import os
+import sys
+
+
+def main():  # Entry point: selects the settings module, imports Django's command runner and dispatches sys.argv to it.
+ os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'Likunlin_final.settings')  # setdefault keeps any value already present in the environment
+ try:
+ from django.core.management import execute_from_command_line  # imported here so a missing Django install yields the message below
+ except ImportError as exc:
+ raise ImportError(
+ "Couldn't import Django. Are you sure it's installed and "
+ "available on your PYTHONPATH environment variable? Did you "
+ "forget to activate a virtual environment?"
+ ) from exc  # chains the original ImportError as the cause
+ execute_from_command_line(sys.argv)  # hands the full command line to Django's management utility
+
+
+if __name__ == '__main__':
+ main()
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000000000..1aba38f67a2211
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE
diff --git a/README.md b/README.md
index eb337d8253f465..4e7d3bb1090bb4 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
# PyTorch Pretrained Bert
+[![CircleCI](https://circleci.com/gh/huggingface/pytorch-pretrained-BERT.svg?style=svg)](https://circleci.com/gh/huggingface/pytorch-pretrained-BERT)
+
This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
This implementation is provided with [Google's pre-trained models](https://github.com/google-research/bert), examples, notebooks and a command-line interface to load any pre-trained TensorFlow checkpoint for BERT is also provided.
@@ -14,12 +16,12 @@ This implementation is provided with [Google's pre-trained models](https://githu
| [Doc](#doc) | Detailed documentation |
| [Examples](#examples) | Detailed examples on how to fine-tune Bert |
| [Notebooks](#notebooks) | Introduction on the provided Jupyter Notebooks |
-| [TPU](#tup) | Notes on TPU support and pretraining scripts |
+| [TPU](#tpu) | Notes on TPU support and pretraining scripts |
| [Command-line interface](#Command-line-interface) | Convert a TensorFlow checkpoint in a PyTorch dump |
## Installation
-This repo was tested on Python 3.5+ and PyTorch 0.4.1
+This repo was tested on Python 3.5+ and PyTorch 0.4.1/1.0.0
### With pip
@@ -46,13 +48,15 @@ python -m pytest -sv tests/
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
-- Six PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L535) - raw BERT Transformer model (**fully pre-trained**),
- - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L689) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
- - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L750) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
- - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L618) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
- - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L812) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
- - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L877) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
+- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
+ - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
+ - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
+ - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
+ - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
+ - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
+ - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for tasks like SWAG) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
+ - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
+ - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
- Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
- `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
@@ -63,15 +67,17 @@ This package comprises the following classes that can be imported in Python and
- `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
- A configuration class (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
- - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilisities to read and write from JSON configuration files.
+ - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
The repository further comprises:
-- Three examples on how to use Bert (in the [`examples` folder](./examples)):
+- Five examples on how to use Bert (in the [`examples` folder](./examples)):
- [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
-
+ - [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
+ - [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPreTraining` on a target text corpus.
+
These examples are detailed in the [Examples](#examples) section of this readme.
- Three notebooks that were used to check that the TensorFlow and PyTorch models behave identically (in the [`notebooks` folder](./notebooks)):
@@ -153,7 +159,7 @@ Here is a detailed documentation of the classes in the package and how to use th
| Sub-section | Description |
|-|-|
| [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance |
-| [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
+| [PyTorch models](#PyTorch-models) | API of the eight PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForMultipleChoice`, `BertForTokenClassification` or `BertForQuestionAnswering` |
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class|
| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
@@ -162,12 +168,12 @@ Here is a detailed documentation of the classes in the package and how to use th
To load one of Google AI's pre-trained models or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as
```python
-model = BERT_CLASS.from_pretrain(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
+model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
```
where
-- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and
+- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the eight PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering`, and
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
- the shortcut name of a Google AI's pre-trained model selected in the list:
@@ -175,19 +181,26 @@ where
- `bert-base-uncased`: 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-large-uncased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
- `bert-base-cased`: 12-layer, 768-hidden, 12-heads , 110M parameters
- - `bert-base-multilingual`: 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+ - `bert-large-cased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
+ - `bert-base-multilingual-uncased`: (Orig, not recommended) 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+ - `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
- `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
- a path or url to a pretrained model archive containing:
-
- - `bert_config.json` a configuration file for the model, and
- - `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)
+
+ - `bert_config.json` a configuration file for the model, and
+ - `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
-- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information)
+- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information).
+
+`Uncased` means that the text has been lowercased before WordPiece tokenization, e.g., `John Smith` becomes `john smith`. The Uncased model also strips out any accent markers. `Cased` means that the true case and accent markers are preserved. Typically, the Uncased model is better unless you know that case information is important for your task (e.g., Named Entity Recognition or Part-of-Speech tagging). For information about the Multilingual and Chinese model, see the [Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md) or the original TensorFlow repository.
+
+**When using an `uncased` model, make sure to pass `--do_lower_case` to the example training scripts (or pass `do_lower_case=True` to `BertTokenizer` if you're using your own script and loading the tokenizer yourself).**
Example:
```python
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
```
@@ -200,8 +213,8 @@ model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
-
-- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and
+[`modeling.py`](./pytorch_pretrained_bert/modeling.py)
+- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)), and
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
@@ -215,7 +228,7 @@ This model *outputs* a tuple composed of:
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
-An example on how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input.
+An example on how to use this class is given in the [`extract_features.py`](./examples/extract_features.py) script which can be used to extract the hidden states of the model for a given input.
#### 2. `BertForPreTraining`
@@ -236,6 +249,9 @@ An example on how to use this class is given in the `extract_features.py` script
- the masked language modeling logits, and
- the next sentence classification logits.
+
+An example on how to use this class is given in the [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) script which can be used to fine-tune the BERT language model on your own target text corpus. This should improve model performance if the language style of your corpus differs from that of the original BERT training corpus (Wiki + BookCorpus).
+
#### 3. `BertForMaskedLM`
@@ -269,15 +285,31 @@ An example on how to use this class is given in the `extract_features.py` script
The sequence-level classifier is a linear layer that takes as input the last hidden state of the first character in the input sequence (see Figures 3a and 3b in the BERT paper).
-An example on how to use this class is given in the `run_classifier.py` script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
+An example on how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script which can be used to fine-tune a single sequence (or pair of sequence) classifier using BERT, for example for the MRPC task.
+
+#### 6. `BertForMultipleChoice`
+
+`BertForMultipleChoice` is a fine-tuning model that includes `BertModel` and a linear layer on top of the `BertModel`.
+
+The linear layer outputs a single value for each choice of a multiple choice problem, then all the outputs corresponding to an instance are passed through a softmax to get the model choice.
+
+This implementation is largely inspired by the work of OpenAI in [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) and the answer of Jacob Devlin in the following [issue](https://github.com/google-research/bert/issues/38).
+
+An example on how to use this class is given in the [`run_swag.py`](./examples/run_swag.py) script which can be used to fine-tune a multiple choice classifier using BERT, for example for the Swag task.
+
+#### 7. `BertForTokenClassification`
+
+`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.
-#### 6. `BertForQuestionAnswering`
+The token-level classifier is a linear layer that takes as input the last hidden state of the sequence.
+
+#### 8. `BertForQuestionAnswering`
`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifiers on top of the full sequence of last hidden states.
The token-level classifier takes as input the full sequence of the last hidden state and compute several (e.g. two) scores for each tokens that can for example respectively be the score that a given token is a `start_span` and a `end_span` token (see Figures 3c and 3d in the BERT paper).
-An example on how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
+An example on how to use this class is given in the [`run_squad.py`](./examples/run_squad.py) script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
### Tokenizer: `BertTokenizer`
@@ -313,7 +345,7 @@ The optimizer accepts the following arguments:
- `b1` : Adams b1. Default : `0.9`
- `b2` : Adams b2. Default : `0.999`
- `e` : Adams epsilon. Default : `1e-6`
-- `weight_decay_rate:` Weight decay. Default : `0.01`
+- `weight_decay:` Weight decay. Default : `0.01`
- `max_grad_norm` : Maximum norm for the gradients (`-1` means no clipping). Default : `1.0`
## Examples
@@ -321,22 +353,23 @@ The optimizer accepts the following arguments:
| Sub-section | Description |
|-|-|
| [Training large models: introduction, tools and examples](#Training-large-models-introduction,-tools-and-examples) | How to use gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training to train Bert models |
-| [Fine-tuning with BERT: running the examples](#Fine-tuning-with-BERT-running-the-examples) | Running the examples in [`./examples`](./examples/): `extract_classif.py`, `run_classifier.py` and `run_squad.py` |
+| [Fine-tuning with BERT: running the examples](#Fine-tuning-with-BERT-running-the-examples) | Running the examples in [`./examples`](./examples/): `extract_features.py`, `run_classifier.py`, `run_squad.py` and `run_lm_finetuning.py` |
| [Fine-tuning BERT-large on GPUs](#Fine-tuning-BERT-large-on-GPUs) | How to fine tune `BERT large`|
### Training large models: introduction, tools and examples
BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
-To help with fine-tuning these models, we have included five techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
+To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
Here is how to use these techniques in our scripts:
- **Gradient Accumulation**: Gradient accumulation can be used by supplying a integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradient will be accumulated over `gradient_accumulation_steps` steps.
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are splitted over the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument (see below).
-- **Optimize on CPU**: The Adam optimizer stores 2 moving average of the weights of the model. If you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal for large models like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU/RAM to free more room on the GPU(s). As the most computational intensive operation is usually the backward pass, this doesn't have a significant impact on the training time. Activate this option with `--optimize_on_cpu` on the `run_squad.py` script.
-- **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scaling` flag (see the previously linked documentation for details on loss scaling). If the loss scaling is too high (`Nan` in the gradients) it will be automatically scaled down until the value is acceptable. The default loss scaling is 128 which behaved nicely in our tests.
+- **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scale` flag (see the previously linked documentation for details on loss scaling). The loss scale can be zero in which case the scale is dynamically adjusted or a positive power of two in which case the scaling is static.
+
+To use 16-bits training and distributed training, you need to install NVIDIA's apex extension [as detailed here](https://github.com/nvidia/apex). You will find more information regarding the internals of `apex` and how to use `apex` in [the doc and the associated repository](https://github.com/nvidia/apex). The results of the tests performed on pytorch-BERT by the NVIDIA team (and my trials at reproducing them) can be consulted in [the relevant PR of the present repository](https://github.com/huggingface/pytorch-pretrained-BERT/pull/116).
Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post]((https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255)) for more details):
```bash
@@ -346,16 +379,22 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach
### Fine-tuning with BERT: running the examples
-We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD.
+We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):
+
+- a *sequence-level classifier* on the MRPC classification corpus,
+- a *token-level classifier* on the question answering dataset SQuAD,
+- a *sequence-level multiple-choice classifier* on the SWAG classification corpus, and
+- a *BERT language model* on another target corpus.
+
+#### MRPC
+
+This example code fine-tunes BERT on the Microsoft Research Paraphrase
+Corpus (MRPC) and runs in less than 10 minutes on a single K-80 and in 27 seconds (!) on a single Tesla V100 16GB with apex installed.
-Before running these examples you should download the
+Before running this example you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
-and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base`
-checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section.
-
-This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase
-Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80.
+and unpack it to some directory `$GLUE_DIR`.
```shell
export GLUE_DIR=/path/to/glue
@@ -364,6 +403,7 @@ python run_classifier.py \
--task_name MRPC \
--do_train \
--do_eval \
+ --do_lower_case \
--data_dir $GLUE_DIR/MRPC/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
@@ -375,7 +415,29 @@ python run_classifier.py \
Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation results between 84% and 88%.
-The second example fine-tunes `BERT-Base` on the SQuAD question answering task.
+**Fast run with apex and 16 bit precision: fine-tuning on MRPC in 27 seconds!**
+First install apex as indicated [here](https://github.com/NVIDIA/apex).
+Then run
+```shell
+export GLUE_DIR=/path/to/glue
+
+python run_classifier.py \
+ --task_name MRPC \
+ --do_train \
+ --do_eval \
+ --do_lower_case \
+ --data_dir $GLUE_DIR/MRPC/ \
+ --bert_model bert-base-uncased \
+ --max_seq_length 128 \
+ --train_batch_size 32 \
+ --learning_rate 2e-5 \
+ --num_train_epochs 3.0 \
+ --output_dir /tmp/mrpc_output/ --fp16
+```
+
+#### SQuAD
+
+This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on a single Tesla V100 16GB.
The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
@@ -390,6 +452,7 @@ python run_squad.py \
--bert_model bert-base-uncased \
--do_train \
--do_predict \
+ --do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--train_batch_size 12 \
@@ -405,6 +468,54 @@ Training with the previous hyper-parameters gave us the following results:
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
```
+#### SWAG
+
+The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf)
+
+```shell
+export SWAG_DIR=/path/to/SWAG
+
+python run_swag.py \
+ --bert_model bert-base-uncased \
+ --do_train \
+ --do_lower_case \
+ --do_eval \
+ --data_dir $SWAG_DIR/data \
+ --train_batch_size 16 \
+ --learning_rate 2e-5 \
+ --num_train_epochs 3.0 \
+ --max_seq_length 80 \
+ --output_dir /tmp/swag_output/ \
+ --gradient_accumulation_steps 4
+```
+
+Training with the previous hyper-parameters on a single GPU gave us the following results:
+```
+eval_accuracy = 0.8062081375587323
+eval_loss = 0.5966546792367169
+global_step = 13788
+loss = 0.06423990014260186
+```
+
+#### LM Fine-tuning
+
+The data should be a text file in the same format as [sample_text.txt](./samples/sample_text.txt) (one sentence per line, docs separated by empty line).
+You can download an [exemplary training corpus](https://ext-bert-sample.obs.eu-de.otc.t-systems.com/small_wiki_sentence_corpus.txt) generated from Wikipedia articles and split into ~500k sentences with spaCy.
+Training one epoch on this corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with `train_batch_size=200` and `max_seq_length=128`:
+
+
+```shell
+python run_lm_finetuning.py \
+ --bert_model bert-base-cased \
+ --do_train \
+ --train_file samples/sample_text.txt \
+ --output_dir models \
+ --num_train_epochs 5.0 \
+ --learning_rate 3e-5 \
+ --train_batch_size 32 \
+ --max_seq_length 128
+```
+
## Fine-tuning BERT-large on GPUs
The options we list above allow to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.
@@ -424,6 +535,7 @@ python ./run_squad.py \
--bert_model bert-large-uncased \
--do_train \
--do_predict \
+ --do_lower_case \
--train_file $SQUAD_TRAIN \
--predict_file $SQUAD_EVAL \
--learning_rate 3e-5 \
@@ -432,8 +544,7 @@ python ./run_squad.py \
--doc_stride 128 \
--output_dir $OUTPUT_DIR \
--train_batch_size 24 \
- --gradient_accumulation_steps 2 \
- --optimize_on_cpu
+ --gradient_accumulation_steps 2
```
If you have a recent GPU (starting from NVIDIA Volta series), you should try **16-bit fine-tuning** (FP16).
@@ -444,6 +555,7 @@ python ./run_squad.py \
--bert_model bert-large-uncased \
--do_train \
--do_predict \
+ --do_lower_case \
--train_file $SQUAD_TRAIN \
--predict_file $SQUAD_EVAL \
--learning_rate 3e-5 \
@@ -479,7 +591,7 @@ A command-line interface is provided to convert a TensorFlow checkpoint in a PyT
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
-This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
+This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)).
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
diff --git a/Untitled.ipynb b/Untitled.ipynb
new file mode 100644
index 00000000000000..6701ee5f62e8e7
--- /dev/null
+++ b/Untitled.ipynb
@@ -0,0 +1,1003 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from IPython.core.interactiveshell import InteractiveShell\n",
+ "InteractiveShell.ast_node_interactivity = 'all'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.\n",
+ "Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.\n",
+ "Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.\n",
+ "Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.\n",
+ "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# import seaborn as sns\n",
+ "import os\n",
+ "import json\n",
+ "\n",
+ "import numpy as np\n",
+ "import math\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "from pylab import rcParams\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from pytorch_pretrained_bert import tokenization, BertTokenizer, BertModel, BertForMaskedLM, BertForPreTraining, BertConfig\n",
+ "from examples.extract_features import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/10/2019 08:14:45 - INFO - pytorch_pretrained_bert.tokenization - loading vocabulary file /nas/pretrain-bert/pretrain-pytorch/bert-base-uncased-vocab.txt\n",
+ "06/10/2019 08:14:45 - INFO - pytorch_pretrained_bert.modeling - loading archive file /nas/pretrain-bert/pretrain-pytorch/bert-base-uncased/\n",
+ "06/10/2019 08:14:45 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
+ " \"attention_probs_dropout_prob\": 0.1,\n",
+ " \"hidden_act\": \"gelu\",\n",
+ " \"hidden_dropout_prob\": 0.1,\n",
+ " \"hidden_size\": 768,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 3072,\n",
+ " \"max_position_embeddings\": 512,\n",
+ " \"num_attention_heads\": 12,\n",
+ " \"num_hidden_layers\": 12,\n",
+ " \"type_vocab_size\": 2,\n",
+ " \"vocab_size\": 30522\n",
+ "}\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "class Args:\n",
+ " def __init__(self):\n",
+ " pass\n",
+ " \n",
+ "args = Args()\n",
+ "args.no_cuda = True\n",
+ "\n",
+ "CONFIG_NAME = 'bert_config.json'\n",
+ "# BERT_DIR = '/nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/'\n",
+ "BERT_DIR = '/nas/pretrain-bert/pretrain-pytorch/bert-base-uncased/'\n",
+ "config_file = os.path.join(BERT_DIR, CONFIG_NAME)\n",
+ "config = BertConfig.from_json_file(config_file)\n",
+ "\n",
+ "# tokenizer = BertTokenizer.from_pretrained(os.path.join(BERT_DIR, 'vocab.txt'))\n",
+ "tokenizer = BertTokenizer.from_pretrained('/nas/pretrain-bert/pretrain-pytorch/bert-base-uncased-vocab.txt')\n",
+ "model = BertForPreTraining.from_pretrained(BERT_DIR)\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() and not args.no_cuda else \"cpu\")\n",
+ "_ = model.to(device)\n",
+ "_ = model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "def convert_text_to_examples(text):\n",
+ " examples = []\n",
+ " unique_id = 0\n",
+ " if True:\n",
+ " for line in text:\n",
+ " line = line.strip()\n",
+ " text_a = None\n",
+ " text_b = None\n",
+ " m = re.match(r\"^(.*) \\|\\|\\| (.*)$\", line)\n",
+ " if m is None:\n",
+ " text_a = line\n",
+ " else:\n",
+ " text_a = m.group(1)\n",
+ " text_b = m.group(2)\n",
+ " examples.append(\n",
+ " InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))\n",
+ " unique_id += 1\n",
+ " return examples\n",
+ "\n",
+ "def convert_examples_to_features(examples, tokenizer, append_special_tokens=True, replace_mask=True, print_info=False):\n",
+ " features = []\n",
+ " for (ex_index, example) in enumerate(examples):\n",
+ " tokens_a = tokenizer.tokenize(example.text_a)\n",
+ " tokens_b = None\n",
+ " if example.text_b:\n",
+ " tokens_b = tokenizer.tokenize(example.text_b)\n",
+ "\n",
+ " tokens = []\n",
+ " input_type_ids = []\n",
+ " if append_special_tokens:\n",
+ " tokens.append(\"[CLS]\")\n",
+ " input_type_ids.append(0)\n",
+ " for token in tokens_a:\n",
+ " if replace_mask and token == '_': # XD\n",
+ " token = \"[MASK]\"\n",
+ " tokens.append(token)\n",
+ " input_type_ids.append(0)\n",
+ " if append_special_tokens:\n",
+ " tokens.append(\"[SEP]\")\n",
+ " input_type_ids.append(0)\n",
+ "\n",
+ " if tokens_b:\n",
+ " for token in tokens_b:\n",
+ " if replace_mask and token == '_': # XD\n",
+ " token = \"[MASK]\"\n",
+ " tokens.append(token)\n",
+ " input_type_ids.append(1)\n",
+ " if append_special_tokens:\n",
+ " tokens.append(\"[SEP]\")\n",
+ " input_type_ids.append(1)\n",
+ "\n",
+ " input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
+ " input_mask = [1] * len(input_ids)\n",
+ "\n",
+ " if ex_index < 5:\n",
+ "# logger.info(\"*** Example ***\")\n",
+ "# logger.info(\"unique_id: %s\" % (example.unique_id))\n",
+ " logger.info(\"tokens: %s\" % \" \".join([str(x) for x in tokens]))\n",
+ "# logger.info(\"input_ids: %s\" % \" \".join([str(x) for x in input_ids]))\n",
+ "# logger.info(\"input_mask: %s\" % \" \".join([str(x) for x in input_mask]))\n",
+ "# logger.info(\n",
+ "# \"input_type_ids: %s\" % \" \".join([str(x) for x in input_type_ids]))\n",
+ " \n",
+ " features.append(\n",
+ " InputFeatures(\n",
+ " unique_id=example.unique_id,\n",
+ " tokens=tokens,\n",
+ " input_ids=input_ids,\n",
+ " input_mask=input_mask,\n",
+ " input_type_ids=input_type_ids))\n",
+ " return features\n",
+ "\n",
+ "def copy_and_mask_feature(feature, masked_tokens=None):\n",
+ " import copy\n",
+ " tokens = feature.tokens\n",
+ " masked_positions = [tokens.index(t) for t in masked_tokens if t in tokens] \\\n",
+ " if masked_tokens is not None else range(len(tokens))\n",
+ " assert len(masked_positions) > 0\n",
+ " masked_feature_copies = []\n",
+ " for masked_pos in masked_positions:\n",
+ " feature_copy = copy.deepcopy(feature)\n",
+ " feature_copy.input_ids[masked_pos] = tokenizer.vocab[\"[MASK]\"]\n",
+ " masked_feature_copies.append(feature_copy)\n",
+ " return masked_feature_copies, masked_positions\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def show_lm_probs(tokens, input_ids, probs, topk=5, firstk=20):\n",
+ " def print_pair(token, prob, end_str='', hit_mark=' '):\n",
+ " if i < firstk:\n",
+ " # token = token.replace('', '').replace('\\n', '/n')\n",
+ " print('{}{: >3} | {: <12}'.format(hit_mark, int(round(prob*100)), token), end=end_str)\n",
+ " \n",
+ " ret = None\n",
+ " for i in range(len(tokens)):\n",
+ " ind_ = input_ids[i].item() if input_ids is not None else tokenizer.vocab[tokens[i]]\n",
+ " prob_ = probs[i][ind_].item()\n",
+ " print_pair(tokens[i], prob_, end_str='\\t')\n",
+ " values, indices = probs[i].topk(topk)\n",
+ " top_pairs = []\n",
+ " for j in range(topk):\n",
+ " ind, prob = indices[j].item(), values[j].item()\n",
+ " hit_mark = '*' if ind == ind_ else ' '\n",
+ " token = tokenizer.ids_to_tokens[ind]\n",
+ " print_pair(token, prob, hit_mark=hit_mark, end_str='' if j < topk - 1 else '\\n')\n",
+ " top_pairs.append((token, prob))\n",
+ " if tokens[i] == \"[MASK]\":\n",
+ " ret = top_pairs\n",
+ " return ret"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import colored\n",
+ "from colored import stylize\n",
+ "\n",
+ "def show_abnormals(tokens, probs, show_suggestions=False):\n",
+ " def gap2color(gap):\n",
+ " if gap <= 5:\n",
+ " return 'yellow_1'\n",
+ " elif gap <= 10:\n",
+ " return 'orange_1'\n",
+ " else:\n",
+ " return 'red_1'\n",
+ " \n",
+ " def print_token(token, suggestion, gap):\n",
+ " if gap == 0:\n",
+ " print(stylize(token + ' ', colored.fg('white') + colored.bg('black')), end='')\n",
+ " else:\n",
+ " print(stylize(token, colored.fg(gap2color(gap)) + colored.bg('black')), end='')\n",
+ " if show_suggestions and gap > 5:\n",
+ " print(stylize('/' + suggestion + ' ', colored.fg('green' if gap > 10 else 'cyan') + colored.bg('black')), end='')\n",
+ " else:\n",
+ " print(stylize(' ', colored.fg(gap2color(gap)) + colored.bg('black')), end='')\n",
+ " # print('/' + suggestion, end=' ')\n",
+ " # print('%.2f' % gap, end=' ')\n",
+ " \n",
+ " avg_gap = 0.\n",
+ " for i in range(1, len(tokens) - 1): # skip first [CLS] and last [SEP]\n",
+ " ind_ = tokenizer.vocab[tokens[i]]\n",
+ " prob_ = probs[i][ind_].item()\n",
+ " top_prob = probs[i].max().item()\n",
+ " top_ind = probs[i].argmax().item()\n",
+ " gap = math.log(top_prob) - math.log(prob_)\n",
+ " suggestion = tokenizer.ids_to_tokens[top_ind]\n",
+ " print_token(tokens[i], suggestion, gap)\n",
+ " avg_gap += gap\n",
+ " avg_gap /= (len(tokens) - 2)\n",
+ " print()\n",
+ " print(avg_gap)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "analyzed_cache = {}\n",
+ "\n",
+ "def analyze_text(text, masked_tokens=None, show_suggestions=False, show_firstk_probs=20):\n",
+ " if text[0] in analyzed_cache:\n",
+ " features, mlm_probs = analyzed_cache[text[0]]\n",
+ " given_mask = \"[MASK]\" in features[0].tokens\n",
+ " tokens = features[0].tokens\n",
+ " else:\n",
+ " examples = convert_text_to_examples(text)\n",
+ " features = convert_examples_to_features(examples, tokenizer, print_info=False)\n",
+ " given_mask = \"[MASK]\" in features[0].tokens\n",
+ " if not given_mask or masked_tokens is not None:\n",
+ " assert len(features) == 1\n",
+ " features, masked_positions = copy_and_mask_feature(features[0], masked_tokens=masked_tokens)\n",
+ "\n",
+ " input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n",
+ " input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long)\n",
+ " input_ids = input_ids.to(device)\n",
+ " input_type_ids = input_type_ids.to(device)\n",
+ "\n",
+ " mlm_logits, _ = model(input_ids, input_type_ids)\n",
+ " mlm_probs = F.softmax(mlm_logits, dim=-1)\n",
+ "\n",
+ " tokens = features[0].tokens\n",
+ " if not given_mask or masked_tokens is not None:\n",
+ " bsz, seq_len, vocab_size = mlm_probs.size()\n",
+ " assert bsz == len(masked_positions)\n",
+ " # reduced_mlm_probs = torch.Tensor(1, seq_len, vocab_size)\n",
+ " # for i in range(seq_len):\n",
+ " # reduced_mlm_probs[0, i] = mlm_probs[i, i]\n",
+ " reduced_mlm_probs = torch.Tensor(1, len(masked_positions), vocab_size)\n",
+ " for i, pos in enumerate(masked_positions):\n",
+ " reduced_mlm_probs[0, i] = mlm_probs[i, pos]\n",
+ " mlm_probs = reduced_mlm_probs\n",
+ " tokens = [tokens[i] for i in masked_positions]\n",
+ " \n",
+ " analyzed_cache[text[0]] = (features, mlm_probs)\n",
+ " \n",
+ " top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=show_firstk_probs)\n",
+ " if not given_mask:\n",
+ " show_abnormals(tokens, mlm_probs[0], show_suggestions=show_suggestions)\n",
+ " return top_pairs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 0 | [CLS] \t 3 | . 1 | the 1 | , 1 | ) 1 | \" \n",
+ " 100 | \" \t*100 | \" 0 | ' 0 | and 0 | so 0 | did \n",
+ " 100 | is \t*100 | is 0 | was 0 | does 0 | isn 0 | has \n",
+ " 97 | tom \t* 97 | tom 2 | he 0 | thomas 0 | you 0 | she \n",
+ " 100 | taller \t*100 | taller 0 | tall 0 | shorter 0 | height 0 | tallest \n",
+ " 100 | than \t*100 | than 0 | then 0 | as 0 | that 0 | to \n",
+ " 100 | mary \t*100 | mary 0 | tom 0 | you 0 | barbara 0 | maria \n",
+ " 100 | ? \t*100 | ? 0 | . 0 | ! 0 | ... 0 | - \n",
+ " 100 | \" \t*100 | \" 0 | ' 0 | ! 0 | * 0 | ) \n",
+ " 100 | \" \t*100 | \" 0 | no 0 | ' 0 | oh 0 | that \n",
+ " 100 | no \t*100 | no 0 | yes 0 | nope 0 | yeah 0 | oh \n",
+ " 100 | , \t*100 | , 0 | . 0 | ; 0 | - 0 | no \n",
+ " 0 | [MASK] \t 80 | tom 10 | he 4 | mary 2 | she 1 | thomas \n",
+ " 100 | is \t*100 | is 0 | was 0 | does 0 | has 0 | no \n",
+ " 100 | taller \t*100 | taller 0 | shorter 0 | tall 0 | larger 0 | smaller \n",
+ " 100 | . \t*100 | . 0 | ; 0 | , 0 | ! 0 | ) \n",
+ " 100 | \" \t*100 | \" 0 | ' 0 | . 0 | ! 0 | ; \n",
+ " 0 | [SEP] \t 86 | . 4 | , 3 | he 2 | \" 1 | she \n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[('tom', 0.7961671352386475),\n",
+ " ('he', 0.09765198826789856),\n",
+ " ('mary', 0.04068772494792938),\n",
+ " ('she', 0.022535543888807297),\n",
+ " ('thomas', 0.0058586327359080315)]"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "text = [\"_ was the greatest physicist who developed theory of relativity.\"]\n",
+ "text = [\"The trophy doesn't fit into the brown suitcase because the _ is too large.\"] # relational adj\n",
+ "text = ['\"Is Tom taller than Mary?\" \"No, _ is taller.\"'] # yes/no\n",
+ "text = [ \"Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have the same hair color.\"] # compare \n",
+ "text = ['John is taller/shorter than Mary because/although _ is older/younger.'] # causality\n",
+ "text = [\"Jennifer is older than James . Jennifer younger than Robert . _ is the oldest.\"] # transitive inference\n",
+ "\n",
+ "analyze_text(text, show_firstk_probs=100)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def words2heads(attns, tokens, words):\n",
+ " positions = [tokens.index(word) for word in words]\n",
+ "\n",
+ " for layer in range(config.num_hidden_layers):\n",
+ " for head in range(config.num_attention_heads):\n",
+ " for pos_indices in [(0, 1), (1, 0)]:\n",
+ " from_pos, to_pos = positions[pos_indices[0]], positions[pos_indices[1]]\n",
+ " if attns[layer][head][from_pos].max(0)[1].item() == to_pos:\n",
+ " print('Layer %d, head %d: %s -> %s' % (layer, head, tokens[from_pos], tokens[to_pos]), end='\\t')\n",
+ " print(attns[layer][head][from_pos].topk(5)[0].data)\n",
+ "\n",
+ "def head2words(attns, tokens, layer, head):\n",
+ " for from_pos in range(len(tokens)):\n",
+ " to_pos = attns[layer][head][from_pos].max(0)[1].item()\n",
+ " from_word, to_word = tokens[from_pos], tokens[to_pos]\n",
+ " if from_word.isalpha() and to_word.isalpha():\n",
+ " print('%s @ %d -> %s @ %d' % (from_word, from_pos, to_word, to_pos), end='\\t')\n",
+ " print(attns[layer][head][from_pos].topk(5)[0].data)\n",
+ " \n",
+ "special_tokens = ['[CLS]', '[SEP]']\n",
+ "\n",
+ "def get_salient_heads(attns, tokens, attn_thld=0.5):\n",
+ " for layer in range(config.num_hidden_layers):\n",
+ " for head in range(config.num_attention_heads):\n",
+ " pos_pairs = []\n",
+ " for from_pos in range(1, len(tokens) - 1): # skip [CLS] and [SEP]\n",
+ " top_attn, to_pos = attns[layer][head][from_pos].max(0)\n",
+ " top_attn, to_pos = top_attn.item(), to_pos.item()\n",
+ " from_word, to_word = tokens[from_pos], tokens[to_pos]\n",
+ "# if from_word.isalpha() and to_word.isalpha() and top_attn >= attn_thld:\n",
+ " if abs(from_pos - to_pos) <= 1:\n",
+ "# print('Layer %d, head %d: %s @ %d -> %s @ %d' % (layer, head, from_word, from_pos, to_word, to_pos), end='\\t')\n",
+ "# print(attns[layer][head][from_pos].topk(5)[0].data)\n",
+ " pos_pairs.append((from_pos, to_pos))\n",
+ " \n",
+ " ratio = len(pos_pairs) / (len(tokens) - 2)\n",
+ " if ratio > 0.5:\n",
+ " print(ratio)\n",
+ " for from_pos, to_pos in pos_pairs:\n",
+ " print('Layer %d, head %d: %s @ %d -> %s @ %d' % (layer, head, tokens[from_pos], from_pos, tokens[to_pos], to_pos), end='\\t')\n",
+ " print(attns[layer][head][from_pos].topk(5)[0].data)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "01/10/2019 21:46:20 - INFO - examples.extract_features - tokens: [CLS] jim laughed because he was so happy . [SEP]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "jim @ 1 -> jim @ 1\ttensor([0.7248, 0.0842, 0.0656, 0.0407, 0.0319], device='cuda:0')\n"
+ ]
+ }
+ ],
+ "source": [
+ "# text, words = [\"The trophy doesn't fit into the brown suitcase because the it is too large.\"], ['fit', 'large']\n",
+ "# text, words = [\"Mary couldn't beat John in the match because he was too strong.\"], ['beat', 'strong']\n",
+ "text, words = [\"John is taller than Mary because he is older.\"], ['taller', 'older']\n",
+ "# text, words = [\"The red ball is heavier than the blue ball because the red ball is bigger.\"], ['heavier', 'bigger']\n",
+ "text, words = [\"Jim laughed because he was so happy.\"], ['cried', 'sad']\n",
+ "# text, words = [\"Jim ate the cake quickly because he was so hungry.\"], ['ate', 'hungry']\n",
+ "# text, words = [\"Jim drank the juice quickly because he was so thirsty.\"], ['drank', 'thirsty']\n",
+ "# text, words = [\"Tom's drawing hangs high. It is above Susan's drawing\"], ['high', 'above']\n",
+ "# text, words = [\"Tom's drawing hangs low. It is below Susan's drawing\"], ['low', 'below']\n",
+ "# text, words = [\"John is taller than Mary . Mary is shorter than John.\"], ['taller', 'shorter']\n",
+ "# text, words = [\"The drawing is above the cabinet. The cabinet is below the drawing\"], ['above', 'below']\n",
+ "# text, words = [\"Jim is very thin . He is not fat.\"], ['thin', 'fat']\n",
+ "\n",
+ "features = convert_examples_to_features(convert_text_to_examples(text), tokenizer, print_info=False)\n",
+ "input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)\n",
+ "input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long).to(device)\n",
+ "mlm_logits, _ = model(input_ids, input_type_ids)\n",
+ "mlm_probs = F.softmax(mlm_logits, dim=-1)\n",
+ "tokens = features[0].tokens\n",
+ "# top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=100)\n",
+ "\n",
+ "attn_name = 'enc_self_attns'\n",
+ "hypo = {attn_name: [model.bert.encoder.layer[i].attention.self.attention_probs[0] for i in range(config.num_hidden_layers)]}\n",
+ "key_labels = query_labels = tokens\n",
+ "labels_dict = {attn_name: (key_labels, query_labels)}\n",
+ "result_tuple = (hypo, config.num_attention_heads, labels_dict)\n",
+ "# plot_layer_attn(result_tuple, attn_name=attn_name, layer=10, heads=None)\n",
+ "\n",
+ "attns = hypo[attn_name]\n",
+ " \n",
+ "# words2heads(attns, tokens, words)\n",
+ "head2words(attns, tokens, 2, 10)\n",
+ "# get_salient_heads(attns, tokens, attn_thld=0.0)"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "0,2\t-1\n",
+ "0,3\t-1\n",
+ "0,10\t+1 动宾\n",
+ "1,1\t+1 动介\n",
+ "1,4\t-1\n",
+ "1,11\t0\n",
+ "2,0\t+1**\n",
+ "2,6\t0**\n",
+ "2,9\t+1**\n",
+ "3,5\t-1\n",
+ "7,4\t-1\n",
+ "11,8\t0\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "head_size = config.hidden_size // config.num_attention_heads\n",
+ "layer = 1\n",
+ "head = 1 # 2, 3, 10\n",
+ "wq = model.bert.encoder.layer[layer].attention.self.query.weight.data.view(-1, config.num_attention_heads, head_size).permute(1, 0, 2)\n",
+ "wk = model.bert.encoder.layer[layer].attention.self.key.weight.data.view(-1, config.num_attention_heads, head_size).permute(1, 0, 2)\n",
+ "\n",
+ "wqk = torch.bmm(wq, wk.transpose(-1, -2))\n",
+ "# (wqk * wqk.transpose(-1, -2)).sum((1, 2)) / (wqk * wqk).sum((1, 2))\n",
+ "# plt.imshow(wqk[head]*wqk[head])\n",
+ "# plt.show()\n",
+ "\n",
+ "# q = torch.matmul(pos_emb, wq)\n",
+ "# k = torch.matmul(pos_emb_prev, wk)\n",
+ "# (q * k).sum((-2, -1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pos_emb = model.bert.embeddings.position_embeddings.weight.data\n",
+ "pos_emb_prev = torch.zeros_like(pos_emb)\n",
+ "pos_emb_next = torch.zeros_like(pos_emb)\n",
+ "pos_emb_prev[1:] = pos_emb[:-1]\n",
+ "pos_emb_next[:-1] = pos_emb[1:]\n",
+ "pos_emb, pos_emb_prev, pos_emb_next = pos_emb[1:-1], pos_emb_prev[1:-1], pos_emb_next[1:-1]\n",
+ "\n",
+ "# pos_q = torch.matmul(pos_emb, wk[head])\n",
+ "# plt.imshow(pos_q[:32])\n",
+ "# plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have the same hair color.',\n",
+ " 'Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have different hair colors.']"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "text = [\n",
+ " # same / different\n",
+ " \"Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have the same hair color.\",\n",
+ " \"Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have different hair colors.\",\n",
+ " \"Tom has yellow hair. Mary has black hair. John has black hair. Mary and _ have the same hair color.\",\n",
+ " # because / although\n",
+ " \"John is taller/shorter than Mary because/although _ is older/younger.\",\n",
+ " \"The red ball is heavier/lighter than the blue ball because/although the _ ball is bigger/smaller.\",\n",
+ " \"Charles did a lot better/worse than his good friend Nancy on the test because/although _ had/hadn't studied so hard.\",\n",
+ " \"The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.\",\n",
+ " \"John thought that he would arrive earlier than Susan, but/and indeed _ was the first to arrive.\",\n",
+ " # reverse\n",
+ " \"John came then Mary came. They left in reverse order. _ left then _ left.\",\n",
+ " \"John came after Mary. They left in reverse order. _ left after _ .\",\n",
+ " \"John came first, then came Mary. They left in reverse order: _ left first, then left _ .\",\n",
+ " # compare sentences with same / opposite meaning, 2nd order\n",
+ " \"Though John is tall, Tom is taller than John. So John is _ than Tom.\",\n",
+ " \"Tom is taller than John. So _ is shorter than _.\",\n",
+ " # WSC-style: before /after\n",
+ " # \"Mary came before/after John. _ was late/early .\",\n",
+ " # yes / no, 2nd order\n",
+ " \"Was Tom taller than Susan? Yes, _ was taller.\",\n",
+ " # right / wrong, epistemic modality, 2nd order\n",
+ " \"John said/thought that the red ball was heavier than the blue ball. He was wrong. The _ ball was heavier\",\n",
+ " \"John was wrong in saying/thinking that the red ball was heavier than the blue ball. The _ ball was heavier\",\n",
+ " \"John said the rain was about to stop. Mary said the rain would continue. Later the rain stopped. _ was wrong/right.\",\n",
+ " \n",
+ " \"The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.\",\n",
+ " \"John thanked Mary because _ had given help to _ . \",\n",
+ " \"John felt vindicated/crushed when his longtime rival Mary revealed that _ was the winner of the competition.\",\n",
+ " \"John couldn't see the stage with Mary in front of him because _ is so short/tall.\",\n",
+ " \"Although they ran at about the same speed, John beat Sally because _ had such a bad start.\",\n",
+ " \"The fish ate the worm. The _ was hungry/tasty.\",\n",
+ " \n",
+ " \"John beat Mary. _ won the game/e winner.\",\n",
+ "]\n",
+ "text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('WSC_switched_label.json') as f:\n",
+ " examples = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('WSC_child_problem.json') as f:\n",
+ " cexamples = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for ce in cexamples:\n",
+ " for s in ce['sentences']:\n",
+ " for a in s['answer0'] + s['answer1']:\n",
+ " a = a.lower()\n",
+ "# if a not in tokenizer.vocab:\n",
+ "# ce\n",
+ "# print(a, 'not in vocab!!!')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for ce in cexamples:\n",
+ " if len(ce['sentences']) > 0:\n",
+ " e = examples[ce['index']]\n",
+ " assert ce['index'] == e['index']\n",
+ " e['score'] = all([s['score'] for s in ce['sentences']])\n",
+ " assert len(set([s['adjacent_ref'] for s in ce['sentences']])) == 1, 'adjcent_refs are different!'\n",
+ " e['adjacent_ref'] = ce['sentences'][0]['adjacent_ref']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "\n",
+ "groups = defaultdict(list)\n",
+ "for e in examples:\n",
+ " if 'score' in e:\n",
+ " index = e['index']\n",
+ " if index < 252:\n",
+ " if index % 2 == 1:\n",
+ " index -= 1\n",
+ " elif index in [252, 253, 254]:\n",
+ " index = 252\n",
+ " else:\n",
+ " if index % 2 == 0:\n",
+ " index -= 1\n",
+ " groups[index].append(e)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(2,\n",
+ " \"The trophy doesn't fit into the brown suitcase because [it] is too large.\",\n",
+ " 'fit into:large/small'),\n",
+ " (4,\n",
+ " 'Joan made sure to thank Susan for all the help [she] had recieved.',\n",
+ " 'thank:receive/give'),\n",
+ " (10,\n",
+ " 'The delivery truck zoomed by the school bus because [it] was going so fast.',\n",
+ " 'zoom by:fast/slow'),\n",
+ " (12,\n",
+ " 'Frank felt vindicated when his longtime rival Bill revealed that [he] was the winner of the competition.',\n",
+ " 'vindicated/crushed:be the winner'),\n",
+ " (16,\n",
+ " 'The large ball crashed right through the table because [it] was made of steel.',\n",
+ " 'crash through:[hard]/[soft]'),\n",
+ " (18,\n",
+ " \"John couldn't see the stage with Billy in front of him because [he] is so short.\",\n",
+ " '[block]:short/tall'),\n",
+ " (20,\n",
+ " 'Tom threw his schoolbag down to Ray after [he] reached the top of the stairs.',\n",
+ " 'down to:top/bottom'),\n",
+ " (22,\n",
+ " 'Although they ran at about the same speed, Sue beat Sally because [she] had such a good start.',\n",
+ " 'beat:good/bad'),\n",
+ " (26,\n",
+ " \"Sam's drawing was hung just above Tina's and [it] did look much better with another one below it.\",\n",
+ " 'above/below'),\n",
+ " (28,\n",
+ " 'Anna did a lot better than her good friend Lucy on the test because [she] had studied so hard.',\n",
+ " 'better/worse:study hard'),\n",
+ " (30,\n",
+ " 'The firemen arrived after the police because [they] were coming from so far away.',\n",
+ " 'after/before:far away'),\n",
+ " (32,\n",
+ " \"Frank was upset with Tom because the toaster [he] had bought from him didn't work.\",\n",
+ " 'be upset with:buy from not work/sell not work'),\n",
+ " (36,\n",
+ " 'The sack of potatoes had been placed above the bag of flour, so [it] had to be moved first.',\n",
+ " 'above/below:moved first'),\n",
+ " (38,\n",
+ " 'Pete envies Martin although [he] is very successful.',\n",
+ " 'although/because'),\n",
+ " (42,\n",
+ " 'I poured water from the bottle into the cup until [it] was empty.',\n",
+ " 'pour:empty/full'),\n",
+ " (46,\n",
+ " \"Sid explained his theory to Mark but [he] couldn't convince him.\",\n",
+ " 'explain:convince/understand'),\n",
+ " (48,\n",
+ " \"Susan knew that Ann's son had been in a car accident, so [she] told her about it.\",\n",
+ " '?know tell:so/because'),\n",
+ " (50,\n",
+ " \"Joe's uncle can still beat him at tennis, even though [he] is 30 years younger.\",\n",
+ " 'beat:younger/older'),\n",
+ " (64,\n",
+ " 'In the middle of the outdoor concert, the rain started falling, but [it] continued until 10.',\n",
+ " 'but/and'),\n",
+ " (68,\n",
+ " 'Ann asked Mary what time the library closes, because [she] had forgotten.',\n",
+ " 'because/but'),\n",
+ " (84,\n",
+ " 'If the con artist has succeeded in fooling Sam, [he] would have gotten a lot of money.',\n",
+ " 'fool:get/lose'),\n",
+ " (92,\n",
+ " 'Alice tried frantically to stop her daughter from chatting at the party, leaving us to wonder why [she] was behaving so strangely.',\n",
+ " '?stop normal/stop abnormal:strange'),\n",
+ " (98,\n",
+ " \"I was trying to open the lock with the key, but someone had filled the keyhole with chewing gum, and I couldn't get [it] in.\",\n",
+ " 'put ... into filled with ... :get in/get out'),\n",
+ " (100,\n",
+ " 'The dog chased the cat, which ran up a tree. [It] waited at the bottom.',\n",
+ " 'up:at the bottom/at the top'),\n",
+ " (106,\n",
+ " 'John was doing research in the library when he heard a man humming and whistling. [He] was very annoyed.',\n",
+ " 'hear ... humming and whistling:annoyed/annoying'),\n",
+ " (108,\n",
+ " 'John was jogging through the park when he saw a man juggling watermelons. [He] was very impressed.',\n",
+ " 'see ... juggling watermelons:impressed/impressive'),\n",
+ " (132,\n",
+ " 'Jane knocked on the door, and Susan answered it. [She] invited her to come out.',\n",
+ " 'visit:invite come out/invite come in'),\n",
+ " (150,\n",
+ " 'Jackson was greatly influenced by Arnold, though [he] lived two centuries later.',\n",
+ " 'influence:later/earlier'),\n",
+ " (160,\n",
+ " 'The actress used to be named Terpsichore, but she changed it to Tina a few years ago, because she figured [it] was too hard to pronounce.',\n",
+ " 'change:hard/easy'),\n",
+ " (166,\n",
+ " 'Fred is the only man still alive who remembers my great-grandfather. [He] is a remarkable man.',\n",
+ " 'alive:is/was'),\n",
+ " (170,\n",
+ " \"In July, Kamtchatka declared war on Yakutsk. Since Yakutsk's army was much better equipped and ten times larger, [they] were defeated within weeks.\",\n",
+ " 'better equipped and large:defeated/victorious'),\n",
+ " (186,\n",
+ " 'When the sponsors of the bill got to the town hall, they were surprised to find that the room was full of opponents. [They] were very much in the minority.',\n",
+ " 'be full of:minority/majority'),\n",
+ " (188,\n",
+ " 'Everyone really loved the oatmeal cookies; only a few people liked the chocolate chip cookies. Next time, we should make more of [them] .',\n",
+ " 'like over:more/fewer'),\n",
+ " (190,\n",
+ " 'We had hoped to place copies of our newsletter on all the chairs in the auditorium, but there were simply not enough of [them] .',\n",
+ " 'place on all:not enough/too many'),\n",
+ " (196,\n",
+ " \"Steve follows Fred's example in everything. [He] admires him hugely.\",\n",
+ " 'follow:admire/influence'),\n",
+ " (198,\n",
+ " \"The table won't fit through the doorway because [it] is too wide.\",\n",
+ " 'fit through:wide/narrow'),\n",
+ " (200,\n",
+ " 'Grace was happy to trade me her sweater for my jacket. She thinks [it] looks dowdy on her.',\n",
+ " 'trade:dowdy/great'),\n",
+ " (202,\n",
+ " 'John hired Bill to take care of [him] .',\n",
+ " 'hire/hire oneself to:take care of'),\n",
+ " (204,\n",
+ " 'John promised Bill to leave, so an hour later [he] left.',\n",
+ " 'promise/order'),\n",
+ " (210,\n",
+ " \"Jane knocked on Susan's door but [she] did not get an answer.\",\n",
+ " 'knock:get an answer/answer'),\n",
+ " (212,\n",
+ " 'Joe paid the detective after [he] received the final report on the case.',\n",
+ " 'pay:receive/deliver'),\n",
+ " (226,\n",
+ " 'Bill passed the half-empty plate to John because [he] was full.',\n",
+ " 'pass the plate:full/hungry'),\n",
+ " (252,\n",
+ " 'George got free tickets to the play, but he gave them to Eric, even though [he] was particularly eager to see it.',\n",
+ " 'even though/because/not'),\n",
+ " (255,\n",
+ " \"Jane gave Joan candy because [she] wasn't hungry.\",\n",
+ " 'give:not hungry/hungry'),\n",
+ " (259,\n",
+ " 'James asked Robert for a favor but [he] was refused.',\n",
+ " 'ask for a favor:refuse/be refused`'),\n",
+ " (261,\n",
+ " 'Kirilov ceded the presidency to Shatov because [he] was less popular.',\n",
+ " 'cede:less popular/more popular'),\n",
+ " (263,\n",
+ " 'Emma did not pass the ball to Janie although [she] saw that she was open.',\n",
+ " 'not pass although:see open/open')]"
+ ]
+ },
+ "execution_count": 41,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def filter_dict(d, keys=['index', 'sentence', 'correct_answer', 'relational_word', 'is_associative', 'score']):\n",
+ " return {k: d[k] for k in d if k in keys}\n",
+ "\n",
+ "# ([[filter_dict(e) for e in eg] for eg in groups.values() if eg[0]['relational_word'] != 'none' and all([e['score'] for e in eg])])# / len([eg for eg in groups.values() if eg[0]['relational_word'] != 'none'])\n",
+ "# [(index, eg[0]['relational_word'], all([e['score'] for e in eg])) for index, eg in groups.items() if eg[0]['relational_word'] != 'none']\n",
+ "# len([filter_dict(e) for e in examples if 'score' in e and not e['score'] and e['adjacent_ref']])\n",
+ "# for e in examples:\n",
+ "# if e['index'] % 2 == 0:\n",
+ "# print(e['sentence'])\n",
+ "[(eg[0]['index'], eg[0]['sentence'], eg[0]['relational_word']) for index, eg in groups.items() if '/' in eg[0]['relational_word']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "179"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(['because' in e['sentence'] for e in examples]) + \\\n",
+ "sum(['so ' in e['sentence'] for e in examples]) + \\\n",
+ "sum(['but ' in e['sentence'] for e in examples]) + \\\n",
+ "sum(['though' in e['sentence'] for e in examples])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# with open('WSC_switched_label.json', 'w') as f:\n",
+ "# json.dump(examples, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vis_attn_topk = 3\n",
+ "\n",
+ "def has_chinese_label(labels):\n",
+ " labels = [label.split('->')[0].strip() for label in labels]\n",
+ " r = sum([len(label) > 1 for label in labels if label not in ['BOS', 'EOS']]) * 1. / (len(labels) - 1)\n",
+ " return 0 < r < 0.5 # r == 0 means empty query labels used in self attention\n",
+ "\n",
+ "def _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col, color='b'):\n",
+ " assert len(query_labels) == attn.size(0)\n",
+ " assert len(key_labels) == attn.size(1)\n",
+ "\n",
+ " ax1.set_xlim([-1, 1])\n",
+ " ax1.set_xticks([])\n",
+ " ax2 = ax1.twinx()\n",
+ " nlabels = max(len(key_labels), len(query_labels))\n",
+ " pos = range(nlabels)\n",
+ " \n",
+ " if 'self' in attn_name and col < ncols - 1:\n",
+ " query_labels = ['' for _ in query_labels]\n",
+ "\n",
+ " for ax, labels in [(ax1, key_labels), (ax2, query_labels)]:\n",
+ " ax.set_yticks(pos)\n",
+ " if has_chinese_label(labels):\n",
+ " ax.set_yticklabels(labels, fontproperties=zhfont)\n",
+ " else:\n",
+ " ax.set_yticklabels(labels)\n",
+ " ax.set_ylim([nlabels - 1, 0])\n",
+ " ax.tick_params(width=0, labelsize='xx-large')\n",
+ "\n",
+ " for spine in ax.spines.values():\n",
+ " spine.set_visible(False)\n",
+ "\n",
+ "# mask, attn = filter_attn(attn)\n",
+ " for qi in range(attn.size(0)):\n",
+ "# if not mask[qi]:\n",
+ "# continue\n",
+ "# for ki in range(attn.size(1)):\n",
+ " for ki in attn[qi].topk(vis_attn_topk)[1]:\n",
+ " a = attn[qi, ki]\n",
+ " ax1.plot((-1, 1), (ki, qi), color, alpha=a)\n",
+ "# print(attn.mean(dim=0).topk(5)[0])\n",
+ "# ax1.barh(pos, attn.mean(dim=0).data.cpu().numpy())\n",
+ "\n",
+ "def plot_layer_attn(result_tuple, attn_name='dec_self_attns', layer=0, heads=None):\n",
+ " hypo, nheads, labels_dict = result_tuple\n",
+ " key_labels, query_labels = labels_dict[attn_name]\n",
+ " if heads is None:\n",
+ " heads = range(nheads)\n",
+ " else:\n",
+ " nheads = len(heads)\n",
+ " \n",
+ " stride = 2 if attn_name == 'dec_enc_attns' else 1\n",
+ " nlabels = max(len(key_labels), len(query_labels))\n",
+ " rcParams['figure.figsize'] = 20, int(round(nlabels * stride * nheads / 8 * 1.0))\n",
+ " \n",
+ " rows = nheads // ncols * stride\n",
+ " fig, axes = plt.subplots(rows, ncols)\n",
+ " \n",
+ " # for head in range(nheads):\n",
+ " for head_i, head in enumerate(heads):\n",
+ " row, col = head_i * stride // ncols, head_i * stride % ncols\n",
+ " ax1 = axes[row, col]\n",
+ " attn = hypo[attn_name][layer][head]\n",
+ " _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col)\n",
+ " if attn_name == 'dec_enc_attns':\n",
+ " col = col + 1\n",
+ " axes[row, col].axis('off') # next subfig acts as blank place holder\n",
+ " # plt.suptitle('%s with %d heads, Layer %d' % (attn_name, nheads, layer), fontsize=20)\n",
+ " plt.show() \n",
+ " \n",
+ "ncols = 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{\n",
+ " \"attention_probs_dropout_prob\": 0.1,\n",
+ " \"hidden_act\": \"gelu\",\n",
+ " \"hidden_dropout_prob\": 0.1,\n",
+ " \"hidden_size\": 768,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 3072,\n",
+ " \"max_position_embeddings\": 512,\n",
+ " \"num_attention_heads\": 12,\n",
+ " \"num_hidden_layers\": 12,\n",
+ " \"type_vocab_size\": 2,\n",
+ " \"vocab_size\": 30522\n",
+ "}"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "config.num"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/Untitled1.ipynb b/Untitled1.ipynb
new file mode 100644
index 00000000000000..0a6ceec8cab0b2
--- /dev/null
+++ b/Untitled1.ipynb
@@ -0,0 +1,2971 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from IPython.core.interactiveshell import InteractiveShell\n",
+ "InteractiveShell.ast_node_interactivity = 'all'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import itertools\n",
+ "from itertools import product, permutations\n",
+ "from random import sample"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.\n",
+ "Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.\n",
+ "Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.\n",
+ "Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.\n",
+ "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
+ "from pytorch_pretrained_bert.modeling import BertForPreTraining, BertForMaskedLM, BertConfig\n",
+ "from pytorch_pretrained_bert.optimization import BertAdam\n",
+ "from run_child_finetuning import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 14:55:34 - INFO - pytorch_pretrained_bert.tokenization - loading vocabulary file /nas/pretrain-bert/pretrain-pytorch/bert-base-uncased-vocab.txt\n"
+ ]
+ }
+ ],
+ "source": [
+ "BERT_DIR = '/nas/pretrain-bert/pretrain-pytorch/bert-base-uncased'\n",
+ "tokenizer = BertTokenizer.from_pretrained('/nas/pretrain-bert/pretrain-pytorch/bert-base-uncased-vocab.txt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def assert_in_bert_vocab(tokens):\n",
+ " for token in tokens:\n",
+ " if isinstance(token, str): # entities\n",
+ " assert token.lower() in tokenizer.vocab, token + '->' + str(tokenizer.tokenize(token))\n",
+ " elif isinstance(token, tuple): # relations\n",
+ " assert len(token) == 2, str(token)\n",
+ " for rel in token:\n",
+ " rel = rel.split('..')[0]\n",
+ " assert rel in tokenizer.vocab, rel + '->' + str(tokenizer.tokenize(rel))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "19"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "fruits = ['apple', 'banana', 'pear', 'orange', 'peach', 'berry', 'plum', 'pineapple', 'melon', 'cherry', 'grape', 'lemon',\n",
+    "          'papaya', 'durian', 'kiwi', 'mango', 'date', 'jujube', 'watermelon']\n",
+ "len(fruits)\n",
+ "# http://www.manythings.org/vocabulary/lists/e/words.php?f=fruit"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "16"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "animals = ['dog', 'cat', 'pig', 'chicken', 'hen', 'cock', 'duck', 'goose', 'monkey', 'tiger', 'bird', 'bear', 'lion', 'bee', 'ant', 'elephant']\n",
+ "len(animals)\n",
+ "# see more at http://www.manythings.org/vocabulary/lists/a/words.php?f=animals_1\n",
+ "# http://www.manythings.org/vocabulary/lists/a/\n",
+ "# especially http://www.manythings.org/vocabulary/lists/a/words.php?f=classroom_1 things in classroom"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "3"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "3"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "male_names = ['James', 'John', 'Robert', ]#'Michael', 'David', 'Paul', 'Jeff', 'Daniel', 'Charles', 'Thomas']\n",
+ "female_names = ['Mary', 'Linda', 'Jennifer', ]#'Maria', 'Susan', 'Lisa', 'Sandra', 'Barbara', 'Patricia', 'Elizabeth']\n",
+ "len(male_names)\n",
+ "len(female_names)\n",
+ "people_names = (male_names, female_names)\n",
+ "assert_in_bert_vocab(male_names)\n",
+ "assert_in_bert_vocab(female_names)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "spatial_relations = (\n",
+ " ('above', 'below'), \n",
+ " ('in front of/in the front', 'behind/in the back'), \n",
+ " ('on the left..side of', 'on the right..side of')\n",
+ ")\n",
+ "people_adj_relations = (\n",
+ " ('taller..than', 'shorter..than'), \n",
+ "# ('thinner..than', 'fatter..than'), # fatter not in BERT vocab\n",
+ " ('younger..than', 'older..than'), \n",
+ "# ('stronger..than', 'weaker..than'), \n",
+ "# ('faster..than', 'slower..than'),\n",
+ "# ('richer..than', 'poorer..than')\n",
+ ")\n",
+ "animal_adj_relations = (\n",
+ " ('thinner..than', 'fatter..than'), \n",
+ " ('younger..than', 'older..than'), \n",
+ " ('stronger..than', 'weaker..than'), \n",
+ " ('faster..than', 'slower..than')\n",
+ ")\n",
+ "object_adj_relations = (\n",
+ " ('bigger..than', 'smaller..than'), \n",
+ " ('heavier..than', 'lighter..than'), \n",
+ " ('better..than', 'worse..than')\n",
+ ")\n",
+ "assert_in_bert_vocab(people_adj_relations)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rel2entypes = {\n",
+ "# spatial_relations: [fruits, animals, people_names],\n",
+ " people_adj_relations: [people_names],\n",
+ "# animal_adj_relations: [animals],\n",
+ "# object_adj_relations: [fruits, animals]\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "twoent_A_template = 'is {dt} {ent0} {rel} {dt} {ent1}'\n",
+ "twoent_B_template = '{dt} {ent} is {pred}'\n",
+ "twoent_template = '\"{A}?\" \"{conj} {B}.\"'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def reverse(l):\n",
+ " return list(reversed(l)) if isinstance(l, list) else tuple(reversed(l))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def mask(ent_str):\n",
+ " tokens = ent_str.strip().split()\n",
+ " if len(tokens) == 1:\n",
+ " return '[%s]' % tokens[0]\n",
+ " elif len(tokens) == 2:\n",
+ " assert tokens[0] == 'the', ent_str\n",
+ " return '%s [%s]' % (tokens[0], tokens[1])\n",
+ " else:\n",
+ " assert False, ent_str"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_conj(join_type, A, B):\n",
+ " if join_type == 'no':\n",
+ " return 'no,'\n",
+ " return 'yes,'\n",
+ " assert join_type == 'yes'\n",
+ " subB = B.split('is')[0].split()[-1]\n",
+ " w0, w1, w2 = A.split()[: 3]\n",
+ " assert w0 == 'Is'\n",
+ " subA = w1 if w1 != 'the' else w2\n",
+ " if subA == subB and 'not' not in B: # B is repeating A\n",
+ " return 'Yes,'\n",
+ " else:\n",
+ " return 'Yes, in other words,'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 134,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def make_sentences(A_template, B_template, join_template,\n",
+ " index=-1, orig_sentence='', entities=[\"John\", \"Mary\"], entity_substitutes=None, determiner=\"\", \n",
+ " relations=[],\n",
+ " packed_relations=[\"rel/~rel\", \"rev_rel/~rev_rel\"], packed_relation_substitutes=None, relation_suffix=\"\",\n",
+ " packed_predicates=[\"pred0/~pred0\", \"pred1/~pred1\"], predicate_substitutes=None,\n",
+ " predicate_dichotomy=True, reverse_causal=False):\n",
+ "# assert entities[0].lower() in tokenizer.vocab , entities[0]\n",
+ "# assert entities[1].lower() in tokenizer.vocab , entities[1]\n",
+ " determiner = 'the' if entities[0].islower() else ''\n",
+ " relations, predicates = ([r.replace('..', ' ') for r in relations], [r.split('..')[0] for r in relations]) \\\n",
+ " if '..' in relations[0] else ([r.split('/')[0] for r in relations], [r.split('/')[-1] for r in relations])\n",
+ " neg_predicates = ['not ' + p for p in predicates]\n",
+ " As = [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_suffix=relation_suffix) \n",
+ " for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]\n",
+ " negAs = [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_suffix=relation_suffix) \n",
+ " for ent0, ent1, rel in [entities + reverse(relations)[:1], reverse(entities) + relations[:1]]]\n",
+ " \n",
+ " Bs = [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, predicates)]\n",
+ " negBs = [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, neg_predicates)]\n",
+ " if predicate_dichotomy:\n",
+ " Bs += [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, reversed(neg_predicates))]\n",
+ " negBs += [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, reversed(predicates))]\n",
+ " \n",
+ " def form_sentences(sentence_template, join_type, As, Bs):\n",
+ " return [\" \".join(sentence_template.format(A=A, B=B, conj=get_conj(join_type, A, B)).split()) for A, B in itertools.product(As, Bs)]\n",
+ " \n",
+ " yes_sentences = []\n",
+ " for A, B in [(As, Bs), (negAs, negBs)]:\n",
+ " yes_sentences += form_sentences(join_template, 'yes', A, B)\n",
+ "# yes_sentences = list(itertools.chain.from_iterable([form_sentences(join_template, 'yes', A, B) for A, B in [(As, Bs), (negAs, negBs)]]))\n",
+ "\n",
+ " no_sentences = []\n",
+ " for A, B in [(As, negBs), (negAs, Bs)]:\n",
+ " no_sentences += form_sentences(join_template, 'no', A, B)\n",
+ " \n",
+ " return yes_sentences + no_sentences\n",
+ " \n",
+ "# make_sentences(\n",
+ "# twoent_A_template, twoent_B_template, twoent_template, entities=['apple', 'banana'], determiner='', relations=['taller..than', 'shorter..than'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 180,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'make_sentences' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;31m# yes_sent, no_sent = make_sentences(twoent_A_template, twoent_B_template, twoent_template, entities=list(ent_pair), relations=rel)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;31m# sentences += (yes_sent + no_sent)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0msentences\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mmake_sentences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtwoent_A_template\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtwoent_B_template\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtwoent_template\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mentities\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ment_pair\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrelations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0msentence_groups\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'make_sentences' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "sentence_groups = []\n",
+ "for relations, entity_types in rel2entypes.items():\n",
+ " sentences = []\n",
+ " ent_pairs = []\n",
+ " for entities in entity_types:\n",
+ " if isinstance(entities, list):\n",
+ " ent_pairs += permutations(entities, 2)\n",
+ " else:\n",
+ " assert isinstance(entities, tuple) and len(entities) == 2 # people_names\n",
+ " ent_pairs += product(entities[0], entities[1])\n",
+ " ent_pairs += product(entities[1], entities[0])\n",
+ " for (rel, ent_pair) in product(relations, ent_pairs):\n",
+ "# yes_sent, no_sent = make_sentences(twoent_A_template, twoent_B_template, twoent_template, entities=list(ent_pair), relations=rel)\n",
+ "# sentences += (yes_sent + no_sent)\n",
+ " sentences += make_sentences(twoent_A_template, twoent_B_template, twoent_template, entities=list(ent_pair), relations=rel)\n",
+ " sample(sentences, 20)\n",
+ " sentence_groups.append(sentences)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 115,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "4"
+ ]
+ },
+ "execution_count": 115,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[78432, 38400, 32768, 59232]"
+ ]
+ },
+ "execution_count": 115,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(sentence_groups)\n",
+ "[len(sg) for sg in sentence_groups]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def comparative2superlative(comparative_form, structured=False):\n",
+ " assert comparative_form.endswith('er'), comparative_form\n",
+ " superlative_form = 'the ' + comparative_form[:-2] + 'est' \\\n",
+ " if not structured else 'the ' + comparative_form + ' st'\n",
+ " return superlative_form"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def make_relational_atoms(relational_template, entities, relations):\n",
+ " neg_relations = [\"isn't \" + r for r in relations]\n",
+ " relations = [\"is \" + r for r in relations]\n",
+ " atoms = [relational_template.format(ent0=ent0, ent1=ent1, rel=rel) \n",
+ " for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]\n",
+ " atoms += [relational_template.format(ent0=ent0, ent1=ent1, rel=rel) \n",
+ " for ent0, ent1, rel in [entities + reverse(neg_relations)[:1], reverse(entities) + neg_relations[:1]]]\n",
+ " return atoms"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['John is taller than Mary . Mary is taller than Susan . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is Susan shorter than John ? [yes] .',\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| is Susan shorter than John ? [yes] .',\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is Susan taller than John ? [no] .']"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "['John is taller than Mary . Mary is taller than Susan . ||| who is the tallest ? [John] .',\n",
+ " 'John is taller than Mary . Mary is taller than Susan . ||| who is the shortest ? [Susan] .',\n",
+ " 'John is taller than Mary . Mary is taller than Susan . ||| is John taller than Susan ? [yes] .',\n",
+ " 'John is taller than Mary . Mary is taller than Susan . ||| is John shorter than Susan ? [no] .',\n",
+ " 'John is taller than Mary . Mary is taller than Susan . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'John is taller than Mary . Mary is taller than Susan . ||| is Susan taller than John ? [no] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| who is the tallest ? [John] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| who is the shortest ? [Susan] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is John taller than Susan ? [yes] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is John shorter than Susan ? [no] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'John is taller than Mary . Susan is shorter than Mary . ||| is Susan taller than John ? [no] .',\n",
+ " \"John is taller than Mary . Mary isn't shorter than Susan . ||| who is the tallest ? [John] .\",\n",
+ " \"John is taller than Mary . Mary isn't shorter than Susan . ||| who is the shortest ? [Susan] .\",\n",
+ " \"John is taller than Mary . Mary isn't shorter than Susan . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"John is taller than Mary . Mary isn't shorter than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John is taller than Mary . Mary isn't shorter than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John is taller than Mary . Mary isn't shorter than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John is taller than Mary . Susan isn't taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " 'Mary is shorter than John . Mary is taller than Susan . ||| who is the tallest ? [John] .',\n",
+ " 'Mary is shorter than John . Mary is taller than Susan . ||| who is the shortest ? [Susan] .',\n",
+ " 'Mary is shorter than John . Mary is taller than Susan . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Mary is shorter than John . Mary is taller than Susan . ||| is John shorter than Susan ? [no] .',\n",
+ " 'Mary is shorter than John . Mary is taller than Susan . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'Mary is shorter than John . Mary is taller than Susan . ||| is Susan taller than John ? [no] .',\n",
+ " 'Mary is shorter than John . Susan is shorter than Mary . ||| who is the tallest ? [John] .',\n",
+ " 'Mary is shorter than John . Susan is shorter than Mary . ||| who is the shortest ? [Susan] .',\n",
+ " 'Mary is shorter than John . Susan is shorter than Mary . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Mary is shorter than John . Susan is shorter than Mary . ||| is John shorter than Susan ? [no] .',\n",
+ " 'Mary is shorter than John . Susan is shorter than Mary . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'Mary is shorter than John . Susan is shorter than Mary . ||| is Susan taller than John ? [no] .',\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary is shorter than John . Mary isn't shorter than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary is shorter than John . Susan isn't taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| who is the tallest ? [John] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| who is the shortest ? [Susan] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Mary is taller than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Susan is shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Mary isn't shorter than Susan . ||| who is the tallest ? [John] .\",\n",
+ " \"John isn't shorter than Mary . Mary isn't shorter than Susan . ||| who is the shortest ? [Susan] .\",\n",
+ " \"John isn't shorter than Mary . Mary isn't shorter than Susan . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Mary isn't shorter than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Mary isn't shorter than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Mary isn't shorter than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Susan isn't taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"John isn't shorter than Mary . Susan isn't taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"John isn't shorter than Mary . Susan isn't taller than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Susan isn't taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"John isn't shorter than Mary . Susan isn't taller than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"John isn't shorter than Mary . Susan isn't taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Mary is taller than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't taller than John . Susan is shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't taller than John . Susan is shorter than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't taller than John . Susan is shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Susan is shorter than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't taller than John . Susan is shorter than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Susan is shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't taller than John . Mary isn't shorter than Susan . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't taller than John . Mary isn't shorter than Susan . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't taller than John . Mary isn't shorter than Susan . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Mary isn't shorter than Susan . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't taller than John . Mary isn't shorter than Susan . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Mary isn't shorter than Susan . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't taller than John . Susan isn't taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't taller than John . Susan isn't taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't taller than John . Susan isn't taller than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Susan isn't taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't taller than John . Susan isn't taller than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't taller than John . Susan isn't taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| who is the tallest ? [John] .',\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| who is the shortest ? [Susan] .',\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| is John shorter than Susan ? [no] .',\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'Mary is taller than Susan . John is taller than Mary . ||| is Susan taller than John ? [no] .',\n",
+ " 'Mary is taller than Susan . Mary is shorter than John . ||| who is the tallest ? [John] .',\n",
+ " 'Mary is taller than Susan . Mary is shorter than John . ||| who is the shortest ? [Susan] .',\n",
+ " 'Mary is taller than Susan . Mary is shorter than John . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Mary is taller than Susan . Mary is shorter than John . ||| is John shorter than Susan ? [no] .',\n",
+ " 'Mary is taller than Susan . Mary is shorter than John . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'Mary is taller than Susan . Mary is shorter than John . ||| is Susan taller than John ? [no] .',\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary is taller than Susan . John isn't shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary is taller than Susan . Mary isn't taller than John . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary is taller than Susan . Mary isn't taller than John . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary is taller than Susan . Mary isn't taller than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary is taller than Susan . Mary isn't taller than John . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary is taller than Susan . Mary isn't taller than John . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary is taller than Susan . Mary isn't taller than John . ||| is Susan taller than John ? [no] .\",\n",
+ " 'Susan is shorter than Mary . John is taller than Mary . ||| who is the tallest ? [John] .',\n",
+ " 'Susan is shorter than Mary . John is taller than Mary . ||| who is the shortest ? [Susan] .',\n",
+ " 'Susan is shorter than Mary . John is taller than Mary . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Susan is shorter than Mary . John is taller than Mary . ||| is John shorter than Susan ? [no] .',\n",
+ " 'Susan is shorter than Mary . John is taller than Mary . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'Susan is shorter than Mary . John is taller than Mary . ||| is Susan taller than John ? [no] .',\n",
+ " 'Susan is shorter than Mary . Mary is shorter than John . ||| who is the tallest ? [John] .',\n",
+ " 'Susan is shorter than Mary . Mary is shorter than John . ||| who is the shortest ? [Susan] .',\n",
+ " 'Susan is shorter than Mary . Mary is shorter than John . ||| is John taller than Susan ? [yes] .',\n",
+ " 'Susan is shorter than Mary . Mary is shorter than John . ||| is John shorter than Susan ? [no] .',\n",
+ " 'Susan is shorter than Mary . Mary is shorter than John . ||| is Susan shorter than John ? [yes] .',\n",
+ " 'Susan is shorter than Mary . Mary is shorter than John . ||| is Susan taller than John ? [no] .',\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan is shorter than Mary . John isn't shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Susan is shorter than Mary . Mary isn't taller than John . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan is shorter than Mary . Mary isn't taller than John . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Susan is shorter than Mary . Mary isn't taller than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Susan is shorter than Mary . Mary isn't taller than John . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Susan is shorter than Mary . Mary isn't taller than John . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan is shorter than Mary . Mary isn't taller than John . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . John is taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't shorter than Susan . John is taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't shorter than Susan . John is taller than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . John is taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . John is taller than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . John is taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary is shorter than John . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . John isn't shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary isn't taller than John . ||| who is the tallest ? [John] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary isn't taller than John . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary isn't taller than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary isn't taller than John . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary isn't taller than John . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Mary isn't shorter than Susan . Mary isn't taller than John . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . John is taller than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . Mary is shorter than John . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan isn't taller than Mary . Mary is shorter than John . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Susan isn't taller than Mary . Mary is shorter than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . Mary is shorter than John . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . Mary is shorter than John . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . Mary is shorter than John . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . John isn't shorter than Mary . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan isn't taller than Mary . John isn't shorter than Mary . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Susan isn't taller than Mary . John isn't shorter than Mary . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . John isn't shorter than Mary . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . John isn't shorter than Mary . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . John isn't shorter than Mary . ||| is Susan taller than John ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| who is the tallest ? [John] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| who is the shortest ? [Susan] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| is John taller than Susan ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| is John shorter than Susan ? [no] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| is Susan shorter than John ? [yes] .\",\n",
+ " \"Susan isn't taller than Mary . Mary isn't taller than John . ||| is Susan taller than John ? [no] .\"]"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "transitive_P_template = '{ent0} {rel} {ent1} .'\n",
+ "transitive_wh_QA_template = '{which} is {pred} ? {ent} .'\n",
+ "transitive_yesno_QA_template = 'is {ent0} {rel} {ent1} ? {ans} .'\n",
+ "\n",
+ "def make_transitive(P_template, wh_QA_template, yesno_QA_template, join_template,\n",
+ " index=-1, orig_sentence='', entities=[\"John\", \"Mary\", \"Susan\"], entity_substitutes=None, determiner=\"\", \n",
+ " relations=('taller..than', 'shorter..than'), maybe=True, structured=False,\n",
+ " packed_predicates=[\"pred0/~pred0\", \"pred1/~pred1\"], predicate_substitutes=None,\n",
+ " predicate_dichotomy=True, reverse_causal=False):\n",
+ " if entities[0].islower():\n",
+ " entities = ['the ' + e for e in entities]\n",
+ "# print('relations =', relations)\n",
+ " relations, predicates = ([r.replace('..', ' ') for r in relations], [r.split('..')[0] for r in relations]) \\\n",
+ " if '..' in relations[0] else ([r.split('/')[0] for r in relations], [r.split('/')[-1] for r in relations])\n",
+ "# print('relations =', relations, 'predicates =', predicates)\n",
+ " predicates = [comparative2superlative(p, structured=structured) for p in predicates]\n",
+ " \n",
+ " P0_entities, P1_entities = ([entities[0], entities[1]], [entities[1], entities[2]]) \\\n",
+ " if not maybe else ([entities[0], entities[1]], [entities[0], entities[2]])\n",
+ " P0 = make_relational_atoms(P_template, P0_entities, relations)\n",
+ " P1 = make_relational_atoms(P_template, P1_entities, relations)\n",
+ " \n",
+ " wh_pronoun = 'which' if entities[0].startswith('the') else 'who'\n",
+ " wh_QA = [wh_QA_template.format(which=wh_pronoun, pred=pred, ent=ent) \n",
+ " for pred, ent in [(predicates[0], mask(entities[0])), (predicates[-1], mask(entities[-1] if not maybe else 'unknown'))]]\n",
+ " \n",
+ " def _maybe(s):\n",
+ " return s if not maybe else 'maybe'\n",
+ " yesno_entities = (entities[0], entities[-1]) if not maybe else (entities[1], entities[-1])\n",
+ " yesno_QA = [yesno_QA_template.format(ent0=ent0, ent1=ent1, rel=rel, ans=ans) \n",
+ " for ent0, ent1, rel, ans in [\n",
+ " (yesno_entities[0], yesno_entities[-1], relations[0], mask(_maybe('yes'))), \n",
+ " (yesno_entities[0], yesno_entities[-1], relations[-1], mask(_maybe('no'))),\n",
+ " (yesno_entities[-1], yesno_entities[0], relations[-1], mask(_maybe('yes'))),\n",
+ " (yesno_entities[-1], yesno_entities[0], relations[0], mask(_maybe('no')))]]\n",
+ " \n",
+ " Ps = [(p0, p1) for p0, p1 in list(product(P0, P1)) + list(product(P1, P0))]\n",
+ " QAs = wh_QA + yesno_QA\n",
+ " \n",
+ " def get_rel(atom):\n",
+ " for rel in relations:\n",
+ "# assert rel.startswith('is')\n",
+ " rel = rel.split()[0] # \"taller than\" -> \"taller\"\n",
+ " if rel in atom:\n",
+ " return rel\n",
+ " assert False\n",
+ " sentences = [p0 + ' ' + p1 + ' ||| ' + qas for (p0, p1), qas in product(Ps, QAs)\n",
+ " if not structured or get_rel(p0) == get_rel(p1) == get_rel(qas)]\n",
+ "# sentences = [s.replace('er st ', 'est ') for s in sentences]\n",
+ " return sentences\n",
+ "\n",
+ "sentences = make_transitive(transitive_P_template, transitive_wh_QA_template, transitive_yesno_QA_template, None, maybe=False, structured=False)\n",
+ "# len(sentences)\n",
+ "sample(sentences, 20)\n",
+ "sentences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'a . . . b . . . c'"
+ ]
+ },
+ "execution_count": 41,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "ename": "TypeError",
+ "evalue": "object of type 'NoneType' has no len()",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;34m'a'\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' .'\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'b'\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' .'\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'c'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m: object of type 'NoneType' has no len()"
+ ]
+ }
+ ],
+ "source": [
+ "'a' + ' .'*random.randint(0, 10) + ' ' + 'b' + ' .'*random.randint(0, 10) + ' ' + 'c'\n",
+ "len(None)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['James is older than Jennifer . Jennifer is older than John . ||| is James older than John ? [yes] .',\n",
+ " \"James is younger than Jennifer . James isn't younger than Linda . ||| who is the younger st ? [Linda] .\",\n",
+ " \"Linda is shorter than Mary . Linda isn't shorter than Robert . ||| is Mary shorter than Robert ? [no] .\",\n",
+ " 'Linda is shorter than Robert . John is shorter than Linda . ||| who is the shorter st ? [John] .',\n",
+ " 'Mary is older than Robert . John is older than Mary . ||| is Robert older than John ? [no] .',\n",
+ " \"Jennifer isn't younger than Robert . James is younger than Robert . ||| is Jennifer younger than James ? [no] .\",\n",
+ " \"Mary is shorter than Jennifer . Mary isn't shorter than John . ||| who is the shorter st ? [John] .\",\n",
+ " \"Linda isn't taller than Robert . Linda is taller than John . ||| who is the taller st ? [Robert] .\",\n",
+ " \"Robert isn't younger than Mary . Mary isn't younger than Linda . ||| is Robert younger than Linda ? [no] .\",\n",
+ " \"Jennifer isn't taller than Linda . Mary isn't taller than Jennifer . ||| who is the taller st ? [Linda] .\",\n",
+ " \"Mary isn't older than Linda . John isn't older than Mary . ||| is John older than Linda ? [no] .\",\n",
+ " \"Linda is taller than Robert . John isn't taller than Robert . ||| is John taller than Linda ? [no] .\",\n",
+ " \"Robert isn't older than Jennifer . James is older than Jennifer . ||| is Robert older than James ? [no] .\",\n",
+ " \"Linda isn't older than Jennifer . Jennifer isn't older than James . ||| is Linda older than James ? [no] .\",\n",
+ " \"Jennifer is shorter than Robert . John isn't shorter than Robert . ||| is Jennifer shorter than John ? [yes] .\",\n",
+ " 'James is older than Mary . Jennifer is older than James . ||| is Mary older than Jennifer ? [no] .',\n",
+ " 'Jennifer is taller than John . John is taller than Robert . ||| is Jennifer taller than Robert ? [yes] .',\n",
+ " \"John is younger than Linda . Mary isn't younger than Linda . ||| who is the younger st ? [John] .\",\n",
+ " \"Jennifer is younger than Mary . Jennifer isn't younger than John . ||| who is the younger st ? [John] .\",\n",
+ " \"Robert is younger than John . Linda isn't younger than John . ||| is Linda younger than Robert ? [no] .\"]"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "num_sent = 11520 -> 11520\n"
+ ]
+ }
+ ],
+ "source": [
+ "sentence_groups = []\n",
+ "maybe = False\n",
+ "for relations, entity_types in rel2entypes.items():\n",
+ " sentences = []\n",
+ " ent_tuples = []\n",
+ " for entities in entity_types:\n",
+ " if isinstance(entities, list):\n",
+ " ent_tuples += permutations(entities, 3)\n",
+ " else:\n",
+ " assert isinstance(entities, tuple) and len(entities) == 2 # people_names\n",
+ " ent_tuples += permutations(entities[0] + entities[1], 3)\n",
+ " for (rel, ent_tuple) in product(relations, ent_tuples):\n",
+ " sentences += make_transitive(transitive_P_template, transitive_wh_QA_template, transitive_yesno_QA_template, None, \n",
+ " entities=list(ent_tuple), relations=rel, maybe=False, structured=True)\n",
+ " if maybe:\n",
+ " sentences += make_transitive(transitive_P_template, transitive_wh_QA_template, transitive_yesno_QA_template, None, \n",
+ " entities=list(ent_tuple), relations=rel, maybe=True, structured=True)\n",
+ " sample(sentences, 20)\n",
+ " print('num_sent =', len(sentences), '->', len(set(sentences)))\n",
+ " sentence_groups.append(sentences)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 247,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--max_seq_length'], dest='max_seq_length', nargs=None, const=None, default=128, type=, choices=None, help='The maximum total input sequence length after WordPiece tokenization. \\nSequences longer than this will be truncated, and sequences shorter \\nthan this will be padded.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreTrueAction(option_strings=['--do_train'], dest='do_train', nargs=0, const=True, default=False, type=None, choices=None, help='Whether to run training.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreTrueAction(option_strings=['--do_eval'], dest='do_eval', nargs=0, const=True, default=False, type=None, choices=None, help='Whether to run eval on the dev set.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--train_batch_size'], dest='train_batch_size', nargs=None, const=None, default=32, type=, choices=None, help='Total batch size for training.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--eval_batch_size'], dest='eval_batch_size', nargs=None, const=None, default=32, type=, choices=None, help='Total batch size for eval.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--learning_rate'], dest='learning_rate', nargs=None, const=None, default=3e-05, type=, choices=None, help='The initial learning rate for Adam.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--num_train_epochs'], dest='num_train_epochs', nargs=None, const=None, default=3.0, type=, choices=None, help='Total number of training epochs to perform.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--warmup_proportion'], dest='warmup_proportion', nargs=None, const=None, default=0.1, type=, choices=None, help='Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreTrueAction(option_strings=['--no_cuda'], dest='no_cuda', nargs=0, const=True, default=False, type=None, choices=None, help='Whether not to use CUDA when available', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreTrueAction(option_strings=['--do_lower_case'], dest='do_lower_case', nargs=0, const=True, default=False, type=None, choices=None, help='Whether to lower case the input text. True for uncased models, False for cased models.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--seed'], dest='seed', nargs=None, const=None, default=42, type=, choices=None, help='random seed for initialization', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "_StoreAction(option_strings=['--gradient_accumulation_steps'], dest='gradient_accumulation_steps', nargs=None, const=None, default=1, type=, choices=None, help='Number of updates steps to accumualte before performing a backward/update pass.', metavar=None)"
+ ]
+ },
+ "execution_count": 247,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Namespace(do_eval=True, do_lower_case=True, do_train=True, eval_batch_size=128, gradient_accumulation_steps=1, learning_rate=0.0001, max_seq_length=128, no_cuda=False, num_train_epochs=100, seed=42, train_batch_size=32, warmup_proportion=0.1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import argparse\n",
+ "parser = argparse.ArgumentParser()\n",
+ "\n",
+ "parser.add_argument(\"--max_seq_length\",\n",
+ " default=128,\n",
+ " type=int,\n",
+ " help=\"The maximum total input sequence length after WordPiece tokenization. \\n\"\n",
+ " \"Sequences longer than this will be truncated, and sequences shorter \\n\"\n",
+ " \"than this will be padded.\")\n",
+ "parser.add_argument(\"--do_train\",\n",
+ " action='store_true',\n",
+ " help=\"Whether to run training.\")\n",
+ "parser.add_argument(\"--do_eval\",\n",
+ " action='store_true',\n",
+ " help=\"Whether to run eval on the dev set.\")\n",
+ "parser.add_argument(\"--train_batch_size\",\n",
+ " default=32,\n",
+ " type=int,\n",
+ " help=\"Total batch size for training.\")\n",
+ "parser.add_argument(\"--eval_batch_size\",\n",
+ " default=32,\n",
+ " type=int,\n",
+ " help=\"Total batch size for eval.\")\n",
+ "parser.add_argument(\"--learning_rate\",\n",
+ " default=3e-5,\n",
+ " type=float,\n",
+ " help=\"The initial learning rate for Adam.\")\n",
+ "parser.add_argument(\"--num_train_epochs\",\n",
+ " default=3.0,\n",
+ " type=float,\n",
+ " help=\"Total number of training epochs to perform.\")\n",
+ "parser.add_argument(\"--warmup_proportion\",\n",
+ " default=0.1,\n",
+ " type=float,\n",
+ " help=\"Proportion of training to perform linear learning rate warmup for. \"\n",
+ " \"E.g., 0.1 = 10%% of training.\")\n",
+ "parser.add_argument(\"--no_cuda\",\n",
+ " action='store_true',\n",
+ " help=\"Whether not to use CUDA when available\")\n",
+ "parser.add_argument(\"--do_lower_case\",\n",
+ " action='store_true',\n",
+ " help=\"Whether to lower case the input text. True for uncased models, False for cased models.\")\n",
+ "parser.add_argument('--seed',\n",
+ " type=int,\n",
+ " default=42,\n",
+ " help=\"random seed for initialization\")\n",
+ "parser.add_argument('--gradient_accumulation_steps',\n",
+ " type=int,\n",
+ " default=1,\n",
+ " help=\"Number of update steps to accumulate before performing a backward/update pass.\")\n",
+ "parser.add_argument(\"--dev_percent\",\n",
+ " default=0.5,\n",
+ " type=float)\n",
+ "# args = parser.parse_args(['--output_dir', '/home'])\n",
+ "args = parser.parse_args([])\n",
+ "args.do_lower_case = True\n",
+ "args.do_train = True\n",
+ "args.do_eval = True\n",
+ "args.eval_batch_size = 128\n",
+ "args.learning_rate = 1e-4\n",
+ "args.num_train_epochs = 100\n",
+ "print(args)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 243,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "num_train_steps = 10800\n"
+ ]
+ }
+ ],
+ "source": [
+ "child_dataset = CHILDDataset(tokenizer, sentence_groups[0], dev_percent=0.5)\n",
+ "train_features = child_dataset.get_train_features()\n",
+ "num_train_steps = int(\n",
+ " len(train_features) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)\n",
+ "print('num_train_steps =', num_train_steps)\n",
+ "eval_features = child_dataset.get_dev_features()\n",
+ "\n",
+ "train_dataset = child_dataset.build_dataset(train_features)\n",
+ "eval_dataset = child_dataset.build_dataset(eval_features)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 250,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:05:44 - INFO - run_child_finetuning - device: cuda n_gpu: 1\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 250,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() and not args.no_cuda else \"cpu\")\n",
+ "n_gpu = torch.cuda.device_count()\n",
+ "logger.info(\"device: {} n_gpu: {}\".format(\n",
+ " device, n_gpu))\n",
+ "\n",
+ "args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)\n",
+ "\n",
+ "random.seed(args.seed)\n",
+ "np.random.seed(args.seed)\n",
+ "torch.manual_seed(args.seed)\n",
+ "if n_gpu > 0:\n",
+ " torch.cuda.manual_seed_all(args.seed)\n",
+ "\n",
+ "# Prepare model\n",
+ "# model = BertForMaskedLM.from_pretrained(BERT_DIR)\n",
+ "CONFIG_NAME = 'bert_config_small.json'\n",
+ "config = BertConfig(os.path.join(BERT_DIR, CONFIG_NAME))\n",
+ "model = BertForMaskedLM(config)\n",
+ "_ = model.to(device)\n",
+ "if n_gpu > 1:\n",
+ " model = torch.nn.DataParallel(model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 252,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Prepare optimizer\n",
+ "param_optimizer = list(model.named_parameters())\n",
+ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
+ "optimizer_grouped_parameters = [\n",
+ " {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
+ " {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
+ " ]\n",
+ "optimizer = BertAdam(optimizer_grouped_parameters,\n",
+ " lr=args.learning_rate,\n",
+ " warmup=args.warmup_proportion,\n",
+ " t_total=num_train_steps)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 253,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 0%| | 0/100 [00:00, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 0, lr = 0.000000\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:10:32 - INFO - run_child_finetuning - Epoch 1\n",
+ "06/09/2019 10:10:32 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:11:00 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:11:00 - INFO - run_child_finetuning - eval_accuracy = 0.3390625\n",
+ "06/09/2019 10:11:00 - INFO - run_child_finetuning - eval_loss = 9.694811651441785\n",
+ "06/09/2019 10:11:00 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:11:28 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:11:28 - INFO - run_child_finetuning - eval_accuracy = 0.32760416666666664\n",
+ "06/09/2019 10:11:28 - INFO - run_child_finetuning - eval_loss = 9.699780379401313\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 1%| | 1/100 [01:06<1:49:31, 66.37s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:11:39 - INFO - run_child_finetuning - Epoch 2\n",
+ "06/09/2019 10:11:39 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:12:07 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:12:07 - INFO - run_child_finetuning - eval_accuracy = 0.3390625\n",
+ "06/09/2019 10:12:07 - INFO - run_child_finetuning - eval_loss = 7.738626289367676\n",
+ "06/09/2019 10:12:07 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:12:34 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:12:34 - INFO - run_child_finetuning - eval_accuracy = 0.32760416666666664\n",
+ "06/09/2019 10:12:35 - INFO - run_child_finetuning - eval_loss = 7.746824651294284\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 2%|▏ | 2/100 [02:13<1:48:35, 66.49s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 1000, lr = 0.000093\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:12:45 - INFO - run_child_finetuning - Epoch 3\n",
+ "06/09/2019 10:12:45 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:13:13 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:13:13 - INFO - run_child_finetuning - eval_accuracy = 0.3390625\n",
+ "06/09/2019 10:13:13 - INFO - run_child_finetuning - eval_loss = 3.257909724447462\n",
+ "06/09/2019 10:13:13 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:13:40 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:13:40 - INFO - run_child_finetuning - eval_accuracy = 0.32760416666666664\n",
+ "06/09/2019 10:13:40 - INFO - run_child_finetuning - eval_loss = 3.2719171391593087\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 3%|▎ | 3/100 [03:19<1:47:13, 66.32s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:13:51 - INFO - run_child_finetuning - Epoch 4\n",
+ "06/09/2019 10:13:51 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:14:19 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:14:19 - INFO - run_child_finetuning - eval_accuracy = 0.3390625\n",
+ "06/09/2019 10:14:19 - INFO - run_child_finetuning - eval_loss = 2.0499441080623204\n",
+ "06/09/2019 10:14:19 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:14:46 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:14:46 - INFO - run_child_finetuning - eval_accuracy = 0.32760416666666664\n",
+ "06/09/2019 10:14:46 - INFO - run_child_finetuning - eval_loss = 2.066389168633355\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 4%|▍ | 4/100 [04:24<1:45:55, 66.20s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:14:57 - INFO - run_child_finetuning - Epoch 5\n",
+ "06/09/2019 10:14:57 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:15:25 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:15:25 - INFO - run_child_finetuning - eval_accuracy = 0.3390625\n",
+ "06/09/2019 10:15:25 - INFO - run_child_finetuning - eval_loss = 1.707436407936944\n",
+ "06/09/2019 10:15:25 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:15:52 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:15:52 - INFO - run_child_finetuning - eval_accuracy = 0.32760416666666664\n",
+ "06/09/2019 10:15:52 - INFO - run_child_finetuning - eval_loss = 1.7236953417460124\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 5%|▌ | 5/100 [05:30<1:44:41, 66.12s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 2000, lr = 0.000081\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:16:03 - INFO - run_child_finetuning - Epoch 6\n",
+ "06/09/2019 10:16:03 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:16:31 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:16:31 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:16:31 - INFO - run_child_finetuning - eval_loss = 1.4861090461413065\n",
+ "06/09/2019 10:16:31 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:16:59 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:16:59 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:16:59 - INFO - run_child_finetuning - eval_loss = 1.500312285953098\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 6%|▌ | 6/100 [06:37<1:44:00, 66.38s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:17:10 - INFO - run_child_finetuning - Epoch 7\n",
+ "06/09/2019 10:17:10 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:17:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:17:38 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:17:38 - INFO - run_child_finetuning - eval_loss = 1.414702398247189\n",
+ "06/09/2019 10:17:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:18:05 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:18:05 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:18:05 - INFO - run_child_finetuning - eval_loss = 1.4278037812974718\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 7%|▋ | 7/100 [07:44<1:42:47, 66.31s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:18:16 - INFO - run_child_finetuning - Epoch 8\n",
+ "06/09/2019 10:18:16 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:18:44 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:18:44 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:18:44 - INFO - run_child_finetuning - eval_loss = 1.3849829329384697\n",
+ "06/09/2019 10:18:44 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:19:11 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:19:11 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:19:11 - INFO - run_child_finetuning - eval_loss = 1.3974607268969217\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 8%|▊ | 8/100 [08:50<1:41:30, 66.20s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 3000, lr = 0.000072\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:19:22 - INFO - run_child_finetuning - Epoch 9\n",
+ "06/09/2019 10:19:22 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:19:50 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:19:50 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:19:50 - INFO - run_child_finetuning - eval_loss = 1.369037503666348\n",
+ "06/09/2019 10:19:50 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:20:17 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:20:17 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:20:17 - INFO - run_child_finetuning - eval_loss = 1.3817288875579834\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 9%|▉ | 9/100 [09:56<1:40:18, 66.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:20:28 - INFO - run_child_finetuning - Epoch 10\n",
+ "06/09/2019 10:20:28 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:20:56 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:20:56 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:20:56 - INFO - run_child_finetuning - eval_loss = 1.3590228782759772\n",
+ "06/09/2019 10:20:56 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:21:24 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:21:24 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:21:24 - INFO - run_child_finetuning - eval_loss = 1.3718345721562704\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 10%|█ | 10/100 [11:02<1:39:22, 66.25s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:21:34 - INFO - run_child_finetuning - Epoch 11\n",
+ "06/09/2019 10:21:34 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:22:02 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:22:02 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:22:02 - INFO - run_child_finetuning - eval_loss = 1.352443257967631\n",
+ "06/09/2019 10:22:02 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:22:30 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:22:30 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:22:30 - INFO - run_child_finetuning - eval_loss = 1.3645663738250733\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 11%|█ | 11/100 [12:08<1:38:12, 66.21s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 4000, lr = 0.000063\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:22:40 - INFO - run_child_finetuning - Epoch 12\n",
+ "06/09/2019 10:22:40 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:23:08 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:23:08 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:23:08 - INFO - run_child_finetuning - eval_loss = 1.3473159684075249\n",
+ "06/09/2019 10:23:08 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:23:36 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:23:36 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:23:36 - INFO - run_child_finetuning - eval_loss = 1.3603505068355137\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 12%|█▏ | 12/100 [13:14<1:37:00, 66.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:23:46 - INFO - run_child_finetuning - Epoch 13\n",
+ "06/09/2019 10:23:46 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:24:14 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:24:14 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:24:14 - INFO - run_child_finetuning - eval_loss = 1.3420454674296909\n",
+ "06/09/2019 10:24:14 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:24:42 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:24:42 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:24:42 - INFO - run_child_finetuning - eval_loss = 1.3549408475557962\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 13%|█▎ | 13/100 [14:20<1:35:45, 66.04s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 5000, lr = 0.000054\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:24:52 - INFO - run_child_finetuning - Epoch 14\n",
+ "06/09/2019 10:24:52 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:25:20 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:25:20 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:25:20 - INFO - run_child_finetuning - eval_loss = 1.3357309381167093\n",
+ "06/09/2019 10:25:20 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:25:48 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:25:48 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:25:48 - INFO - run_child_finetuning - eval_loss = 1.3490168280071682\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 14%|█▍ | 14/100 [15:26<1:34:36, 66.01s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:25:58 - INFO - run_child_finetuning - Epoch 15\n",
+ "06/09/2019 10:25:58 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:26:26 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:26:26 - INFO - run_child_finetuning - eval_accuracy = 0.4223090277777778\n",
+ "06/09/2019 10:26:26 - INFO - run_child_finetuning - eval_loss = 1.3257557378874885\n",
+ "06/09/2019 10:26:26 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:26:54 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:26:54 - INFO - run_child_finetuning - eval_accuracy = 0.4110243055555556\n",
+ "06/09/2019 10:26:54 - INFO - run_child_finetuning - eval_loss = 1.3387107451756795\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 15%|█▌ | 15/100 [16:32<1:33:27, 65.98s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:27:04 - INFO - run_child_finetuning - Epoch 16\n",
+ "06/09/2019 10:27:04 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:27:32 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:27:32 - INFO - run_child_finetuning - eval_accuracy = 0.4435763888888889\n",
+ "06/09/2019 10:27:32 - INFO - run_child_finetuning - eval_loss = 1.3156095531251695\n",
+ "06/09/2019 10:27:32 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:28:00 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:28:00 - INFO - run_child_finetuning - eval_accuracy = 0.4318576388888889\n",
+ "06/09/2019 10:28:00 - INFO - run_child_finetuning - eval_loss = 1.328736596637302\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 16%|█▌ | 16/100 [17:38<1:32:21, 65.97s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 6000, lr = 0.000044\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:28:10 - INFO - run_child_finetuning - Epoch 17\n",
+ "06/09/2019 10:28:10 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:28:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:28:38 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:28:38 - INFO - run_child_finetuning - eval_loss = 1.3007791850301955\n",
+ "06/09/2019 10:28:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:29:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:29:06 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:29:06 - INFO - run_child_finetuning - eval_loss = 1.314843883779314\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 17%|█▋ | 17/100 [18:44<1:31:17, 65.99s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:29:17 - INFO - run_child_finetuning - Epoch 18\n",
+ "06/09/2019 10:29:17 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:29:45 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:29:45 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:29:45 - INFO - run_child_finetuning - eval_loss = 1.2931998319096036\n",
+ "06/09/2019 10:29:45 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:30:12 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:30:12 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:30:12 - INFO - run_child_finetuning - eval_loss = 1.3075149999724494\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 18%|█▊ | 18/100 [19:51<1:30:30, 66.23s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:30:23 - INFO - run_child_finetuning - Epoch 19\n",
+ "06/09/2019 10:30:23 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:30:51 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:30:51 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:30:51 - INFO - run_child_finetuning - eval_loss = 1.2879919780625237\n",
+ "06/09/2019 10:30:51 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:31:18 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:31:18 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:31:18 - INFO - run_child_finetuning - eval_loss = 1.30219263765547\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 19%|█▉ | 19/100 [20:57<1:29:20, 66.18s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 7000, lr = 0.000035\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:31:27 - INFO - run_child_finetuning - Epoch 20\n",
+ "06/09/2019 10:31:27 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:31:55 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:31:55 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:31:55 - INFO - run_child_finetuning - eval_loss = 1.2851258847448561\n",
+ "06/09/2019 10:31:55 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:32:24 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:32:24 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:32:24 - INFO - run_child_finetuning - eval_loss = 1.2990937895245023\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 20%|██ | 20/100 [22:02<1:27:51, 65.89s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:32:32 - INFO - run_child_finetuning - Epoch 21\n",
+ "06/09/2019 10:32:32 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:33:00 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:33:00 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:33:00 - INFO - run_child_finetuning - eval_loss = 1.282910199960073\n",
+ "06/09/2019 10:33:00 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:33:28 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:33:28 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:33:28 - INFO - run_child_finetuning - eval_loss = 1.296793986691369\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 21%|██ | 21/100 [23:06<1:26:04, 65.38s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:33:36 - INFO - run_child_finetuning - Epoch 22\n",
+ "06/09/2019 10:33:36 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:34:04 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:34:04 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:34:04 - INFO - run_child_finetuning - eval_loss = 1.281203603744507\n",
+ "06/09/2019 10:34:04 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:34:32 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:34:32 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:34:32 - INFO - run_child_finetuning - eval_loss = 1.2950837108823987\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 22%|██▏ | 22/100 [24:10<1:24:31, 65.02s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 8000, lr = 0.000026\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:34:41 - INFO - run_child_finetuning - Epoch 23\n",
+ "06/09/2019 10:34:41 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:35:08 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:35:08 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:35:08 - INFO - run_child_finetuning - eval_loss = 1.2800932976934645\n",
+ "06/09/2019 10:35:08 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:35:36 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:35:36 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:35:36 - INFO - run_child_finetuning - eval_loss = 1.2938868072297838\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 23%|██▎ | 23/100 [25:14<1:23:07, 64.77s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:35:45 - INFO - run_child_finetuning - Epoch 24\n",
+ "06/09/2019 10:35:45 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:36:13 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:36:13 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:36:13 - INFO - run_child_finetuning - eval_loss = 1.2789997299512228\n",
+ "06/09/2019 10:36:13 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:36:41 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:36:41 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:36:41 - INFO - run_child_finetuning - eval_loss = 1.2929178635279337\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 24%|██▍ | 24/100 [26:19<1:22:02, 64.77s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:36:51 - INFO - run_child_finetuning - Epoch 25\n",
+ "06/09/2019 10:36:51 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:37:19 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:37:19 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:37:19 - INFO - run_child_finetuning - eval_loss = 1.2782557209332783\n",
+ "06/09/2019 10:37:19 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:37:47 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:37:47 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:37:47 - INFO - run_child_finetuning - eval_loss = 1.2921040852864583\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 25%|██▌ | 25/100 [27:25<1:21:25, 65.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 9000, lr = 0.000017\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:37:57 - INFO - run_child_finetuning - Epoch 26\n",
+ "06/09/2019 10:37:57 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:38:25 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:38:25 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:38:25 - INFO - run_child_finetuning - eval_loss = 1.2780342843797472\n",
+ "06/09/2019 10:38:25 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:38:53 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:38:53 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:38:53 - INFO - run_child_finetuning - eval_loss = 1.2919086231125725\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 26%|██▌ | 26/100 [28:31<1:20:44, 65.46s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:39:04 - INFO - run_child_finetuning - Epoch 27\n",
+ "06/09/2019 10:39:04 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:39:31 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:39:31 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:39:31 - INFO - run_child_finetuning - eval_loss = 1.2777638885709974\n",
+ "06/09/2019 10:39:31 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:39:59 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:39:59 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:39:59 - INFO - run_child_finetuning - eval_loss = 1.291655433177948\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 27%|██▋ | 27/100 [29:37<1:19:49, 65.62s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 10000, lr = 0.000007\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:40:10 - INFO - run_child_finetuning - Epoch 28\n",
+ "06/09/2019 10:40:10 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:40:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:40:38 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:40:38 - INFO - run_child_finetuning - eval_loss = 1.277630979484982\n",
+ "06/09/2019 10:40:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:41:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:41:06 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:41:06 - INFO - run_child_finetuning - eval_loss = 1.2915024439493814\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 28%|██▊ | 28/100 [30:44<1:19:02, 65.87s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:41:16 - INFO - run_child_finetuning - Epoch 29\n",
+ "06/09/2019 10:41:16 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:41:44 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:41:44 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:41:44 - INFO - run_child_finetuning - eval_loss = 1.2776025176048278\n",
+ "06/09/2019 10:41:44 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:42:12 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:42:12 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:42:12 - INFO - run_child_finetuning - eval_loss = 1.291484334733751\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 29%|██▉ | 29/100 [31:50<1:18:11, 66.08s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:42:23 - INFO - run_child_finetuning - Epoch 30\n",
+ "06/09/2019 10:42:23 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:42:51 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:42:51 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:42:51 - INFO - run_child_finetuning - eval_loss = 1.2775944696532355\n",
+ "06/09/2019 10:42:51 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:43:19 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:43:19 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:43:19 - INFO - run_child_finetuning - eval_loss = 1.291474199295044\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 30%|███ | 30/100 [32:57<1:17:20, 66.29s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 11000, lr = -0.000002\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:43:30 - INFO - run_child_finetuning - Epoch 31\n",
+ "06/09/2019 10:43:30 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:43:58 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:43:58 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:43:58 - INFO - run_child_finetuning - eval_loss = 1.2775861925548977\n",
+ "06/09/2019 10:43:58 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:44:25 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:44:25 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:44:25 - INFO - run_child_finetuning - eval_loss = 1.2914685328801474\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 31%|███ | 31/100 [34:03<1:16:15, 66.31s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:44:36 - INFO - run_child_finetuning - Epoch 32\n",
+ "06/09/2019 10:44:36 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:45:04 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:45:04 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:45:04 - INFO - run_child_finetuning - eval_loss = 1.2775551716486613\n",
+ "06/09/2019 10:45:04 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:45:32 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:45:32 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:45:32 - INFO - run_child_finetuning - eval_loss = 1.291435596677992\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 32%|███▏ | 32/100 [35:10<1:15:10, 66.34s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:45:42 - INFO - run_child_finetuning - Epoch 33\n",
+ "06/09/2019 10:45:42 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:46:10 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:46:10 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:46:10 - INFO - run_child_finetuning - eval_loss = 1.2774020512898763\n",
+ "06/09/2019 10:46:10 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:46:37 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:46:37 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:46:37 - INFO - run_child_finetuning - eval_loss = 1.2912689606348673\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 33%|███▎ | 33/100 [36:15<1:13:49, 66.11s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 12000, lr = -0.000011\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:46:47 - INFO - run_child_finetuning - Epoch 34\n",
+ "06/09/2019 10:46:47 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:47:14 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:47:14 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:47:14 - INFO - run_child_finetuning - eval_loss = 1.2771676964230008\n",
+ "06/09/2019 10:47:14 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:47:42 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:47:42 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:47:42 - INFO - run_child_finetuning - eval_loss = 1.2910632411638896\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 34%|███▍ | 34/100 [37:20<1:12:16, 65.70s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:47:54 - INFO - run_child_finetuning - Epoch 35\n",
+ "06/09/2019 10:47:54 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:48:21 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:48:21 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:48:21 - INFO - run_child_finetuning - eval_loss = 1.2766649497879876\n",
+ "06/09/2019 10:48:22 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:48:49 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:48:49 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:48:49 - INFO - run_child_finetuning - eval_loss = 1.290544174777137\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 35%|███▌ | 35/100 [38:27<1:11:39, 66.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:48:58 - INFO - run_child_finetuning - Epoch 36\n",
+ "06/09/2019 10:48:58 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:49:26 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:49:26 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:49:26 - INFO - run_child_finetuning - eval_loss = 1.276163759496477\n",
+ "06/09/2019 10:49:26 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:49:53 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:49:53 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:49:53 - INFO - run_child_finetuning - eval_loss = 1.2901054302851358\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 36%|███▌ | 36/100 [39:32<1:09:55, 65.56s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 13000, lr = -0.000020\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:50:02 - INFO - run_child_finetuning - Epoch 37\n",
+ "06/09/2019 10:50:02 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:50:30 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:50:30 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:50:30 - INFO - run_child_finetuning - eval_loss = 1.275437773598565\n",
+ "06/09/2019 10:50:30 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:50:58 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:50:58 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:50:58 - INFO - run_child_finetuning - eval_loss = 1.2893267750740052\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 37%|███▋ | 37/100 [40:36<1:08:23, 65.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:51:06 - INFO - run_child_finetuning - Epoch 38\n",
+ "06/09/2019 10:51:06 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:51:34 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:51:34 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:51:34 - INFO - run_child_finetuning - eval_loss = 1.2746005985471938\n",
+ "06/09/2019 10:51:34 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:52:02 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:52:02 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:52:02 - INFO - run_child_finetuning - eval_loss = 1.2885264688067966\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 38%|███▊ | 38/100 [41:40<1:07:01, 64.87s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 14000, lr = -0.000030\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:52:10 - INFO - run_child_finetuning - Epoch 39\n",
+ "06/09/2019 10:52:10 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:52:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:52:38 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:52:38 - INFO - run_child_finetuning - eval_loss = 1.2734754257731968\n",
+ "06/09/2019 10:52:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:53:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:53:06 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:53:06 - INFO - run_child_finetuning - eval_loss = 1.28737782769733\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 39%|███▉ | 39/100 [42:44<1:05:48, 64.74s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:53:18 - INFO - run_child_finetuning - Epoch 40\n",
+ "06/09/2019 10:53:18 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:53:46 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:53:46 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 10:53:46 - INFO - run_child_finetuning - eval_loss = 1.2718001670307584\n",
+ "06/09/2019 10:53:46 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:54:14 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:54:14 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 10:54:14 - INFO - run_child_finetuning - eval_loss = 1.2856513102849325\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 40%|████ | 40/100 [43:52<1:05:42, 65.71s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:54:25 - INFO - run_child_finetuning - Epoch 41\n",
+ "06/09/2019 10:54:25 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:54:53 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:54:53 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:54:53 - INFO - run_child_finetuning - eval_loss = 1.2708097603585986\n",
+ "06/09/2019 10:54:53 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:55:21 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:55:21 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:55:21 - INFO - run_child_finetuning - eval_loss = 1.284542813565996\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 41%|████ | 41/100 [44:59<1:04:51, 65.96s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 15000, lr = -0.000039\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:55:31 - INFO - run_child_finetuning - Epoch 42\n",
+ "06/09/2019 10:55:31 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:55:59 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:55:59 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:55:59 - INFO - run_child_finetuning - eval_loss = 1.2693544705708821\n",
+ "06/09/2019 10:55:59 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:56:27 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:56:27 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:56:27 - INFO - run_child_finetuning - eval_loss = 1.2828620976871914\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 42%|████▏ | 42/100 [46:05<1:03:44, 65.93s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:56:36 - INFO - run_child_finetuning - Epoch 43\n",
+ "06/09/2019 10:56:36 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:57:03 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:57:03 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 10:57:03 - INFO - run_child_finetuning - eval_loss = 1.2677685194545323\n",
+ "06/09/2019 10:57:03 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:57:31 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:57:31 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 10:57:31 - INFO - run_child_finetuning - eval_loss = 1.2817622078789606\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 43%|████▎ | 43/100 [47:09<1:02:12, 65.49s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:57:41 - INFO - run_child_finetuning - Epoch 44\n",
+ "06/09/2019 10:57:41 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:58:09 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:58:09 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 10:58:09 - INFO - run_child_finetuning - eval_loss = 1.2655644430054558\n",
+ "06/09/2019 10:58:09 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:58:37 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:58:37 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 10:58:37 - INFO - run_child_finetuning - eval_loss = 1.2792559107144674\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 44%|████▍ | 44/100 [48:15<1:01:08, 65.51s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 16000, lr = -0.000048\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 10:58:47 - INFO - run_child_finetuning - Epoch 45\n",
+ "06/09/2019 10:58:47 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 10:59:15 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:59:15 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 10:59:15 - INFO - run_child_finetuning - eval_loss = 1.2655728247430589\n",
+ "06/09/2019 10:59:15 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 10:59:43 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 10:59:43 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 10:59:43 - INFO - run_child_finetuning - eval_loss = 1.2796800454457602\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 45%|████▌ | 45/100 [49:21<1:00:09, 65.63s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 10:59:53 - INFO - run_child_finetuning - Epoch 46\n",
+ "06/09/2019 10:59:53 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:00:21 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:00:21 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 11:00:21 - INFO - run_child_finetuning - eval_loss = 1.2633930643399556\n",
+ "06/09/2019 11:00:21 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:00:49 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:00:49 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 11:00:49 - INFO - run_child_finetuning - eval_loss = 1.2771886467933655\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 46%|████▌ | 46/100 [50:27<59:10, 65.74s/it] \u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:00:59 - INFO - run_child_finetuning - Epoch 47\n",
+ "06/09/2019 11:00:59 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:01:27 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:01:27 - INFO - run_child_finetuning - eval_accuracy = 0.45078125\n",
+ "06/09/2019 11:01:27 - INFO - run_child_finetuning - eval_loss = 1.2638664881388346\n",
+ "06/09/2019 11:01:27 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:01:55 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:01:55 - INFO - run_child_finetuning - eval_accuracy = 0.4381076388888889\n",
+ "06/09/2019 11:01:55 - INFO - run_child_finetuning - eval_loss = 1.2772795226838853\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 47%|████▋ | 47/100 [51:33<58:08, 65.82s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 17000, lr = -0.000057\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:02:05 - INFO - run_child_finetuning - Epoch 48\n",
+ "06/09/2019 11:02:05 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:02:33 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:02:33 - INFO - run_child_finetuning - eval_accuracy = 0.45199652777777777\n",
+ "06/09/2019 11:02:33 - INFO - run_child_finetuning - eval_loss = 1.2638536400265163\n",
+ "06/09/2019 11:02:33 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:03:00 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:03:00 - INFO - run_child_finetuning - eval_accuracy = 0.4368923611111111\n",
+ "06/09/2019 11:03:00 - INFO - run_child_finetuning - eval_loss = 1.2784057392014399\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 48%|████▊ | 48/100 [52:39<57:02, 65.81s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:03:11 - INFO - run_child_finetuning - Epoch 49\n",
+ "06/09/2019 11:03:11 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:03:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:03:38 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 11:03:38 - INFO - run_child_finetuning - eval_loss = 1.2610097911622788\n",
+ "06/09/2019 11:03:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:04:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:04:06 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 11:04:06 - INFO - run_child_finetuning - eval_loss = 1.2743374122513664\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 49%|████▉ | 49/100 [53:44<55:55, 65.80s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:04:16 - INFO - run_child_finetuning - Epoch 50\n",
+ "06/09/2019 11:04:16 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:04:44 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:04:44 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:04:44 - INFO - run_child_finetuning - eval_loss = 1.2613953351974487\n",
+ "06/09/2019 11:04:44 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:05:11 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:05:11 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:05:11 - INFO - run_child_finetuning - eval_loss = 1.2759553485446506\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 50%|█████ | 50/100 [54:49<54:39, 65.59s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 18000, lr = -0.000067\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:05:22 - INFO - run_child_finetuning - Epoch 51\n",
+ "06/09/2019 11:05:22 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:05:49 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:05:49 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:05:49 - INFO - run_child_finetuning - eval_loss = 1.2582731948958503\n",
+ "06/09/2019 11:05:49 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:06:17 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:06:17 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:06:17 - INFO - run_child_finetuning - eval_loss = 1.271828391816881\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 51%|█████ | 51/100 [55:55<53:39, 65.69s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:06:28 - INFO - run_child_finetuning - Epoch 52\n",
+ "06/09/2019 11:06:28 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:06:56 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:06:56 - INFO - run_child_finetuning - eval_accuracy = 0.4513888888888889\n",
+ "06/09/2019 11:06:56 - INFO - run_child_finetuning - eval_loss = 1.2580342027876112\n",
+ "06/09/2019 11:06:56 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:07:24 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:07:24 - INFO - run_child_finetuning - eval_accuracy = 0.4375\n",
+ "06/09/2019 11:07:24 - INFO - run_child_finetuning - eval_loss = 1.2725122107399836\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 52%|█████▏ | 52/100 [57:03<52:55, 66.15s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 19000, lr = -0.000076\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:07:35 - INFO - run_child_finetuning - Epoch 53\n",
+ "06/09/2019 11:07:35 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:08:03 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:08:03 - INFO - run_child_finetuning - eval_accuracy = 0.45078125\n",
+ "06/09/2019 11:08:03 - INFO - run_child_finetuning - eval_loss = 1.2580892933739556\n",
+ "06/09/2019 11:08:03 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:08:31 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:08:31 - INFO - run_child_finetuning - eval_accuracy = 0.4381076388888889\n",
+ "06/09/2019 11:08:31 - INFO - run_child_finetuning - eval_loss = 1.2724553810225592\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 53%|█████▎ | 53/100 [58:09<51:53, 66.24s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:08:42 - INFO - run_child_finetuning - Epoch 54\n",
+ "06/09/2019 11:08:42 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:09:10 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:09:10 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 11:09:10 - INFO - run_child_finetuning - eval_loss = 1.2563159465789795\n",
+ "06/09/2019 11:09:10 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:09:37 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:09:37 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 11:09:37 - INFO - run_child_finetuning - eval_loss = 1.2693641344706217\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 54%|█████▍ | 54/100 [59:16<50:51, 66.34s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:09:48 - INFO - run_child_finetuning - Epoch 55\n",
+ "06/09/2019 11:09:48 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:10:16 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:10:16 - INFO - run_child_finetuning - eval_accuracy = 0.45069444444444445\n",
+ "06/09/2019 11:10:16 - INFO - run_child_finetuning - eval_loss = 1.2571652372678122\n",
+ "06/09/2019 11:10:16 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:10:43 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:10:43 - INFO - run_child_finetuning - eval_accuracy = 0.43819444444444444\n",
+ "06/09/2019 11:10:43 - INFO - run_child_finetuning - eval_loss = 1.2704107642173768\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 55%|█████▌ | 55/100 [1:00:22<49:40, 66.22s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 20000, lr = -0.000085\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:10:54 - INFO - run_child_finetuning - Epoch 56\n",
+ "06/09/2019 11:10:54 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:11:22 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:11:22 - INFO - run_child_finetuning - eval_accuracy = 0.45078125\n",
+ "06/09/2019 11:11:22 - INFO - run_child_finetuning - eval_loss = 1.2557978232701619\n",
+ "06/09/2019 11:11:22 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:11:49 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:11:49 - INFO - run_child_finetuning - eval_accuracy = 0.4381076388888889\n",
+ "06/09/2019 11:11:49 - INFO - run_child_finetuning - eval_loss = 1.2688145054711235\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 56%|█████▌ | 56/100 [1:01:27<48:30, 66.15s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:12:00 - INFO - run_child_finetuning - Epoch 57\n",
+ "06/09/2019 11:12:00 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:12:28 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:12:28 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:12:28 - INFO - run_child_finetuning - eval_loss = 1.2567673405011495\n",
+ "06/09/2019 11:12:28 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:12:56 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:12:56 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:12:56 - INFO - run_child_finetuning - eval_loss = 1.2707869211832683\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 57%|█████▋ | 57/100 [1:02:34<47:29, 66.27s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:13:06 - INFO - run_child_finetuning - Epoch 58\n",
+ "06/09/2019 11:13:06 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:13:34 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:13:34 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 11:13:34 - INFO - run_child_finetuning - eval_loss = 1.2552655418713887\n",
+ "06/09/2019 11:13:34 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:14:02 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:14:02 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 11:14:02 - INFO - run_child_finetuning - eval_loss = 1.2685243950949774\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 58%|█████▊ | 58/100 [1:03:40<46:18, 66.16s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 21000, lr = -0.000094\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:14:10 - INFO - run_child_finetuning - Epoch 59\n",
+ "06/09/2019 11:14:10 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:14:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:14:38 - INFO - run_child_finetuning - eval_accuracy = 0.45078125\n",
+ "06/09/2019 11:14:38 - INFO - run_child_finetuning - eval_loss = 1.2543529285324944\n",
+ "06/09/2019 11:14:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:15:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:15:06 - INFO - run_child_finetuning - eval_accuracy = 0.4381076388888889\n",
+ "06/09/2019 11:15:06 - INFO - run_child_finetuning - eval_loss = 1.2681733555263943\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 59%|█████▉ | 59/100 [1:04:44<44:49, 65.59s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:15:16 - INFO - run_child_finetuning - Epoch 60\n",
+ "06/09/2019 11:15:16 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:15:44 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:15:44 - INFO - run_child_finetuning - eval_accuracy = 0.4513888888888889\n",
+ "06/09/2019 11:15:44 - INFO - run_child_finetuning - eval_loss = 1.2541927046246\n",
+ "06/09/2019 11:15:44 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:16:12 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:16:12 - INFO - run_child_finetuning - eval_accuracy = 0.4375\n",
+ "06/09/2019 11:16:12 - INFO - run_child_finetuning - eval_loss = 1.267860644393497\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 60%|██████ | 60/100 [1:05:50<43:49, 65.73s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:16:23 - INFO - run_child_finetuning - Epoch 61\n",
+ "06/09/2019 11:16:23 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:16:51 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:16:51 - INFO - run_child_finetuning - eval_accuracy = 0.4513020833333333\n",
+ "06/09/2019 11:16:51 - INFO - run_child_finetuning - eval_loss = 1.2536385284529792\n",
+ "06/09/2019 11:16:51 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:17:19 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:17:19 - INFO - run_child_finetuning - eval_accuracy = 0.43758680555555557\n",
+ "06/09/2019 11:17:19 - INFO - run_child_finetuning - eval_loss = 1.2662296599811977\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 61%|██████ | 61/100 [1:06:57<42:56, 66.08s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 22000, lr = -0.000104\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:17:31 - INFO - run_child_finetuning - Epoch 62\n",
+ "06/09/2019 11:17:31 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:17:59 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:17:59 - INFO - run_child_finetuning - eval_accuracy = 0.4519097222222222\n",
+ "06/09/2019 11:17:59 - INFO - run_child_finetuning - eval_loss = 1.2535286770926581\n",
+ "06/09/2019 11:17:59 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:18:27 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:18:27 - INFO - run_child_finetuning - eval_accuracy = 0.43697916666666664\n",
+ "06/09/2019 11:18:27 - INFO - run_child_finetuning - eval_loss = 1.267613332801395\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 62%|██████▏ | 62/100 [1:08:05<42:13, 66.66s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:18:37 - INFO - run_child_finetuning - Epoch 63\n",
+ "06/09/2019 11:18:37 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:19:05 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:19:05 - INFO - run_child_finetuning - eval_accuracy = 0.4515625\n",
+ "06/09/2019 11:19:05 - INFO - run_child_finetuning - eval_loss = 1.2538655002911885\n",
+ "06/09/2019 11:19:05 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:19:33 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:19:33 - INFO - run_child_finetuning - eval_accuracy = 0.43732638888888886\n",
+ "06/09/2019 11:19:33 - INFO - run_child_finetuning - eval_loss = 1.2682664235432943\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 63%|██████▎ | 63/100 [1:09:11<40:59, 66.47s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 23000, lr = -0.000113\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:19:42 - INFO - run_child_finetuning - Epoch 64\n",
+ "06/09/2019 11:19:42 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:20:09 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:20:09 - INFO - run_child_finetuning - eval_accuracy = 0.44991319444444444\n",
+ "06/09/2019 11:20:09 - INFO - run_child_finetuning - eval_loss = 1.2541112303733826\n",
+ "06/09/2019 11:20:09 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:20:37 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:20:37 - INFO - run_child_finetuning - eval_accuracy = 0.43897569444444445\n",
+ "06/09/2019 11:20:37 - INFO - run_child_finetuning - eval_loss = 1.267833666006724\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 64%|██████▍ | 64/100 [1:10:15<39:27, 65.76s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:20:46 - INFO - run_child_finetuning - Epoch 65\n",
+ "06/09/2019 11:20:46 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:21:14 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:21:14 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:21:14 - INFO - run_child_finetuning - eval_loss = 1.2567384481430053\n",
+ "06/09/2019 11:21:14 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:21:42 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:21:42 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:21:42 - INFO - run_child_finetuning - eval_loss = 1.2706474079026115\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 65%|██████▌ | 65/100 [1:11:20<38:10, 65.44s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:21:52 - INFO - run_child_finetuning - Epoch 66\n",
+ "06/09/2019 11:21:52 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:22:20 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:22:20 - INFO - run_child_finetuning - eval_accuracy = 0.45078125\n",
+ "06/09/2019 11:22:20 - INFO - run_child_finetuning - eval_loss = 1.253372961945004\n",
+ "06/09/2019 11:22:20 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:22:48 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:22:48 - INFO - run_child_finetuning - eval_accuracy = 0.4381076388888889\n",
+ "06/09/2019 11:22:48 - INFO - run_child_finetuning - eval_loss = 1.267328475581275\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 66%|██████▌ | 66/100 [1:12:26<37:10, 65.61s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 24000, lr = -0.000122\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:22:58 - INFO - run_child_finetuning - Epoch 67\n",
+ "06/09/2019 11:22:58 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:23:27 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:23:27 - INFO - run_child_finetuning - eval_accuracy = 0.45078125\n",
+ "06/09/2019 11:23:27 - INFO - run_child_finetuning - eval_loss = 1.2558889269828797\n",
+ "06/09/2019 11:23:27 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:23:55 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:23:55 - INFO - run_child_finetuning - eval_accuracy = 0.4381076388888889\n",
+ "06/09/2019 11:23:55 - INFO - run_child_finetuning - eval_loss = 1.2680187635951572\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 67%|██████▋ | 67/100 [1:13:33<36:17, 65.97s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:24:05 - INFO - run_child_finetuning - Epoch 68\n",
+ "06/09/2019 11:24:05 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:24:33 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:24:33 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:24:33 - INFO - run_child_finetuning - eval_loss = 1.2534217052989536\n",
+ "06/09/2019 11:24:33 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:25:01 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:25:01 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:25:01 - INFO - run_child_finetuning - eval_loss = 1.267832002374861\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 68%|██████▊ | 68/100 [1:14:39<35:12, 66.03s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:25:12 - INFO - run_child_finetuning - Epoch 69\n",
+ "06/09/2019 11:25:12 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:25:40 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:25:40 - INFO - run_child_finetuning - eval_accuracy = 0.4515625\n",
+ "06/09/2019 11:25:40 - INFO - run_child_finetuning - eval_loss = 1.2543478224012587\n",
+ "06/09/2019 11:25:40 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:26:08 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:26:08 - INFO - run_child_finetuning - eval_accuracy = 0.43732638888888886\n",
+ "06/09/2019 11:26:08 - INFO - run_child_finetuning - eval_loss = 1.2676605025927226\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 69%|██████▉ | 69/100 [1:15:46<34:13, 66.25s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 25000, lr = -0.000131\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:26:20 - INFO - run_child_finetuning - Epoch 70\n",
+ "06/09/2019 11:26:20 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:26:48 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:26:48 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 11:26:48 - INFO - run_child_finetuning - eval_loss = 1.256232378217909\n",
+ "06/09/2019 11:26:48 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:27:16 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:27:16 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 11:27:16 - INFO - run_child_finetuning - eval_loss = 1.2710346884197659\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 70%|███████ | 70/100 [1:16:54<33:23, 66.77s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:27:28 - INFO - run_child_finetuning - Epoch 71\n",
+ "06/09/2019 11:27:28 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:27:56 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:27:56 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:27:56 - INFO - run_child_finetuning - eval_loss = 1.254373996787601\n",
+ "06/09/2019 11:27:56 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:28:24 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:28:24 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:28:24 - INFO - run_child_finetuning - eval_loss = 1.268411025736067\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 71%|███████ | 71/100 [1:18:02<32:27, 67.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:28:34 - INFO - run_child_finetuning - Epoch 72\n",
+ "06/09/2019 11:28:34 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:29:02 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:29:02 - INFO - run_child_finetuning - eval_accuracy = 0.44973958333333336\n",
+ "06/09/2019 11:29:02 - INFO - run_child_finetuning - eval_loss = 1.2569880644480387\n",
+ "06/09/2019 11:29:02 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:29:30 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:29:30 - INFO - run_child_finetuning - eval_accuracy = 0.43914930555555554\n",
+ "06/09/2019 11:29:30 - INFO - run_child_finetuning - eval_loss = 1.2723949763509963\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 72%|███████▏ | 72/100 [1:19:08<31:12, 66.86s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 26000, lr = -0.000141\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:29:40 - INFO - run_child_finetuning - Epoch 73\n",
+ "06/09/2019 11:29:40 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:30:08 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:30:08 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:30:08 - INFO - run_child_finetuning - eval_loss = 1.2545752935939365\n",
+ "06/09/2019 11:30:08 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:30:36 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:30:36 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:30:36 - INFO - run_child_finetuning - eval_loss = 1.269049670961168\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 73%|███████▎ | 73/100 [1:20:14<29:58, 66.62s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:30:46 - INFO - run_child_finetuning - Epoch 74\n",
+ "06/09/2019 11:30:46 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:31:14 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:31:14 - INFO - run_child_finetuning - eval_accuracy = 0.4509548611111111\n",
+ "06/09/2019 11:31:14 - INFO - run_child_finetuning - eval_loss = 1.2558313740624323\n",
+ "06/09/2019 11:31:14 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:31:42 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:31:42 - INFO - run_child_finetuning - eval_accuracy = 0.4379340277777778\n",
+ "06/09/2019 11:31:42 - INFO - run_child_finetuning - eval_loss = 1.2702598147922093\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 74%|███████▍ | 74/100 [1:21:20<28:47, 66.45s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:31:52 - INFO - run_child_finetuning - Epoch 75\n",
+ "06/09/2019 11:31:52 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:32:20 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:32:20 - INFO - run_child_finetuning - eval_accuracy = 0.45069444444444445\n",
+ "06/09/2019 11:32:20 - INFO - run_child_finetuning - eval_loss = 1.25285046365526\n",
+ "06/09/2019 11:32:20 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:32:48 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:32:48 - INFO - run_child_finetuning - eval_accuracy = 0.43819444444444444\n",
+ "06/09/2019 11:32:48 - INFO - run_child_finetuning - eval_loss = 1.2665512853198582\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 75%|███████▌ | 75/100 [1:22:26<27:37, 66.32s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 27000, lr = -0.000150\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:32:58 - INFO - run_child_finetuning - Epoch 76\n",
+ "06/09/2019 11:32:58 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:33:26 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:33:26 - INFO - run_child_finetuning - eval_accuracy = 0.4599826388888889\n",
+ "06/09/2019 11:33:26 - INFO - run_child_finetuning - eval_loss = 1.2468693004714118\n",
+ "06/09/2019 11:33:26 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:35:11 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:35:39 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:35:39 - INFO - run_child_finetuning - eval_accuracy = 0.4732638888888889\n",
+ "06/09/2019 11:35:39 - INFO - run_child_finetuning - eval_loss = 1.2132148491011725\n",
+ "06/09/2019 11:35:39 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:36:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:36:06 - INFO - run_child_finetuning - eval_accuracy = 0.46206597222222223\n",
+ "06/09/2019 11:36:06 - INFO - run_child_finetuning - eval_loss = 1.2266984952820672\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 78%|███████▊ | 78/100 [1:25:44<24:15, 66.17s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:36:17 - INFO - run_child_finetuning - Epoch 79\n",
+ "06/09/2019 11:36:17 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:36:44 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:36:44 - INFO - run_child_finetuning - eval_accuracy = 0.4810763888888889\n",
+ "06/09/2019 11:36:44 - INFO - run_child_finetuning - eval_loss = 1.1860122005144755\n",
+ "06/09/2019 11:36:44 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:37:12 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:37:12 - INFO - run_child_finetuning - eval_accuracy = 0.46848958333333335\n",
+ "06/09/2019 11:37:12 - INFO - run_child_finetuning - eval_loss = 1.1959241694874234\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 79%|███████▉ | 79/100 [1:26:50<23:08, 66.12s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:37:22 - INFO - run_child_finetuning - Epoch 80\n",
+ "06/09/2019 11:37:22 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:37:50 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:37:50 - INFO - run_child_finetuning - eval_accuracy = 0.49288194444444444\n",
+ "06/09/2019 11:37:50 - INFO - run_child_finetuning - eval_loss = 1.1808326456281875\n",
+ "06/09/2019 11:37:50 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:38:18 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:38:18 - INFO - run_child_finetuning - eval_accuracy = 0.4830729166666667\n",
+ "06/09/2019 11:38:18 - INFO - run_child_finetuning - eval_loss = 1.1921248899565802\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 80%|████████ | 80/100 [1:27:56<21:59, 65.98s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 29000, lr = -0.000169\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:38:28 - INFO - run_child_finetuning - Epoch 81\n",
+ "06/09/2019 11:38:28 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:38:56 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:38:56 - INFO - run_child_finetuning - eval_accuracy = 0.5042534722222223\n",
+ "06/09/2019 11:38:56 - INFO - run_child_finetuning - eval_loss = 1.1616202606095207\n",
+ "06/09/2019 11:38:56 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:39:24 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:39:24 - INFO - run_child_finetuning - eval_accuracy = 0.4915798611111111\n",
+ "06/09/2019 11:39:24 - INFO - run_child_finetuning - eval_loss = 1.1748484041955736\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 81%|████████ | 81/100 [1:29:02<20:53, 65.97s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:39:34 - INFO - run_child_finetuning - Epoch 82\n",
+ "06/09/2019 11:39:34 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:40:02 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:40:02 - INFO - run_child_finetuning - eval_accuracy = 0.5051215277777777\n",
+ "06/09/2019 11:40:02 - INFO - run_child_finetuning - eval_loss = 1.1368214580747815\n",
+ "06/09/2019 11:40:02 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:40:30 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:40:30 - INFO - run_child_finetuning - eval_accuracy = 0.49105902777777777\n",
+ "06/09/2019 11:40:30 - INFO - run_child_finetuning - eval_loss = 1.152006494998932\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 82%|████████▏ | 82/100 [1:30:08<19:48, 66.01s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:40:40 - INFO - run_child_finetuning - Epoch 83\n",
+ "06/09/2019 11:40:40 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:41:08 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:41:08 - INFO - run_child_finetuning - eval_accuracy = 0.5098090277777778\n",
+ "06/09/2019 11:41:08 - INFO - run_child_finetuning - eval_loss = 1.0961747805277506\n",
+ "06/09/2019 11:41:08 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:41:36 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:41:36 - INFO - run_child_finetuning - eval_accuracy = 0.5002604166666667\n",
+ "06/09/2019 11:41:36 - INFO - run_child_finetuning - eval_loss = 1.1076407035191853\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 83%|████████▎ | 83/100 [1:31:14<18:42, 66.02s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 30000, lr = -0.000178\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:41:47 - INFO - run_child_finetuning - Epoch 84\n",
+ "06/09/2019 11:41:47 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:42:14 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:42:14 - INFO - run_child_finetuning - eval_accuracy = 0.5082465277777778\n",
+ "06/09/2019 11:42:14 - INFO - run_child_finetuning - eval_loss = 1.076478154791726\n",
+ "06/09/2019 11:42:14 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:42:42 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:42:42 - INFO - run_child_finetuning - eval_accuracy = 0.4957465277777778\n",
+ "06/09/2019 11:42:42 - INFO - run_child_finetuning - eval_loss = 1.0900266740057203\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 84%|████████▍ | 84/100 [1:32:20<17:36, 66.05s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:42:53 - INFO - run_child_finetuning - Epoch 85\n",
+ "06/09/2019 11:42:53 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:43:20 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:43:20 - INFO - run_child_finetuning - eval_accuracy = 0.5261284722222223\n",
+ "06/09/2019 11:43:20 - INFO - run_child_finetuning - eval_loss = 1.04820695983039\n",
+ "06/09/2019 11:43:21 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:43:48 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:43:48 - INFO - run_child_finetuning - eval_accuracy = 0.5151041666666667\n",
+ "06/09/2019 11:43:48 - INFO - run_child_finetuning - eval_loss = 1.0613130741649204\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 85%|████████▌ | 85/100 [1:33:26<16:31, 66.07s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:43:58 - INFO - run_child_finetuning - Epoch 86\n",
+ "06/09/2019 11:43:58 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:44:26 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:44:26 - INFO - run_child_finetuning - eval_accuracy = 0.5272569444444445\n",
+ "06/09/2019 11:44:26 - INFO - run_child_finetuning - eval_loss = 1.0183378650082482\n",
+ "06/09/2019 11:44:26 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:44:54 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:44:54 - INFO - run_child_finetuning - eval_accuracy = 0.5131076388888889\n",
+ "06/09/2019 11:44:54 - INFO - run_child_finetuning - eval_loss = 1.0315876146157583\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 86%|████████▌ | 86/100 [1:34:32<15:23, 65.96s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 31000, lr = -0.000187\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:45:05 - INFO - run_child_finetuning - Epoch 87\n",
+ "06/09/2019 11:45:05 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:45:33 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:45:33 - INFO - run_child_finetuning - eval_accuracy = 0.5419270833333333\n",
+ "06/09/2019 11:45:33 - INFO - run_child_finetuning - eval_loss = 0.9918538702858819\n",
+ "06/09/2019 11:45:33 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:46:01 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:46:01 - INFO - run_child_finetuning - eval_accuracy = 0.5289930555555555\n",
+ "06/09/2019 11:46:01 - INFO - run_child_finetuning - eval_loss = 1.0008921066919962\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 87%|████████▋ | 87/100 [1:35:39<14:19, 66.13s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:46:11 - INFO - run_child_finetuning - Epoch 88\n",
+ "06/09/2019 11:46:11 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:46:39 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:46:39 - INFO - run_child_finetuning - eval_accuracy = 0.5427083333333333\n",
+ "06/09/2019 11:46:39 - INFO - run_child_finetuning - eval_loss = 0.9888957500457763\n",
+ "06/09/2019 11:46:39 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:47:07 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:47:07 - INFO - run_child_finetuning - eval_accuracy = 0.5352430555555555\n",
+ "06/09/2019 11:47:07 - INFO - run_child_finetuning - eval_loss = 1.0018013291888768\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 88%|████████▊ | 88/100 [1:36:45<13:14, 66.23s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 32000, lr = -0.000196\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:47:17 - INFO - run_child_finetuning - Epoch 89\n",
+ "06/09/2019 11:47:17 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:47:45 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:47:45 - INFO - run_child_finetuning - eval_accuracy = 0.56171875\n",
+ "06/09/2019 11:47:45 - INFO - run_child_finetuning - eval_loss = 0.9563338147269355\n",
+ "06/09/2019 11:47:45 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:48:13 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:48:13 - INFO - run_child_finetuning - eval_accuracy = 0.5494791666666666\n",
+ "06/09/2019 11:48:13 - INFO - run_child_finetuning - eval_loss = 0.9669986453321245\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 89%|████████▉ | 89/100 [1:37:51<12:08, 66.23s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:48:24 - INFO - run_child_finetuning - Epoch 90\n",
+ "06/09/2019 11:48:24 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:48:51 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:48:51 - INFO - run_child_finetuning - eval_accuracy = 0.5645833333333333\n",
+ "06/09/2019 11:48:51 - INFO - run_child_finetuning - eval_loss = 0.9435959888829125\n",
+ "06/09/2019 11:48:52 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:49:19 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:49:19 - INFO - run_child_finetuning - eval_accuracy = 0.5533854166666666\n",
+ "06/09/2019 11:49:19 - INFO - run_child_finetuning - eval_loss = 0.9578184756967757\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 90%|█████████ | 90/100 [1:38:57<11:01, 66.16s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:49:30 - INFO - run_child_finetuning - Epoch 91\n",
+ "06/09/2019 11:49:30 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:49:57 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:49:57 - INFO - run_child_finetuning - eval_accuracy = 0.5693576388888889\n",
+ "06/09/2019 11:49:57 - INFO - run_child_finetuning - eval_loss = 0.928356761402554\n",
+ "06/09/2019 11:49:57 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:50:25 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:50:25 - INFO - run_child_finetuning - eval_accuracy = 0.5584201388888889\n",
+ "06/09/2019 11:50:25 - INFO - run_child_finetuning - eval_loss = 0.9383055541250441\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 91%|█████████ | 91/100 [1:40:04<09:55, 66.18s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 33000, lr = -0.000206\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:50:36 - INFO - run_child_finetuning - Epoch 92\n",
+ "06/09/2019 11:50:36 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:51:04 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:51:04 - INFO - run_child_finetuning - eval_accuracy = 0.5755208333333334\n",
+ "06/09/2019 11:51:04 - INFO - run_child_finetuning - eval_loss = 0.9096341941091749\n",
+ "06/09/2019 11:51:04 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:51:32 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:51:32 - INFO - run_child_finetuning - eval_accuracy = 0.56328125\n",
+ "06/09/2019 11:51:32 - INFO - run_child_finetuning - eval_loss = 0.9205585459868113\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 92%|█████████▏| 92/100 [1:41:10<08:49, 66.14s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:51:43 - INFO - run_child_finetuning - Epoch 93\n",
+ "06/09/2019 11:51:43 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:52:11 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:52:11 - INFO - run_child_finetuning - eval_accuracy = 0.58046875\n",
+ "06/09/2019 11:52:11 - INFO - run_child_finetuning - eval_loss = 0.900900975200865\n",
+ "06/09/2019 11:52:11 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:52:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:52:38 - INFO - run_child_finetuning - eval_accuracy = 0.5693576388888889\n",
+ "06/09/2019 11:52:38 - INFO - run_child_finetuning - eval_loss = 0.9142036868466271\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 93%|█████████▎| 93/100 [1:42:16<07:44, 66.33s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:52:47 - INFO - run_child_finetuning - Epoch 94\n",
+ "06/09/2019 11:52:47 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:53:15 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:53:15 - INFO - run_child_finetuning - eval_accuracy = 0.5886284722222223\n",
+ "06/09/2019 11:53:15 - INFO - run_child_finetuning - eval_loss = 0.8851869417561425\n",
+ "06/09/2019 11:53:15 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:53:43 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:53:43 - INFO - run_child_finetuning - eval_accuracy = 0.5755208333333334\n",
+ "06/09/2019 11:53:43 - INFO - run_child_finetuning - eval_loss = 0.8926095426082611\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 94%|█████████▍| 94/100 [1:43:21<06:34, 65.73s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 34000, lr = -0.000215\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:53:54 - INFO - run_child_finetuning - Epoch 95\n",
+ "06/09/2019 11:53:54 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:54:21 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:54:21 - INFO - run_child_finetuning - eval_accuracy = 0.5769965277777778\n",
+ "06/09/2019 11:54:21 - INFO - run_child_finetuning - eval_loss = 0.8847115721967486\n",
+ "06/09/2019 11:54:21 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:54:49 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:54:49 - INFO - run_child_finetuning - eval_accuracy = 0.5650173611111111\n",
+ "06/09/2019 11:54:49 - INFO - run_child_finetuning - eval_loss = 0.8932965106434292\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 95%|█████████▌| 95/100 [1:44:27<05:29, 65.95s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:54:59 - INFO - run_child_finetuning - Epoch 96\n",
+ "06/09/2019 11:54:59 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:55:27 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:55:27 - INFO - run_child_finetuning - eval_accuracy = 0.5789930555555556\n",
+ "06/09/2019 11:55:27 - INFO - run_child_finetuning - eval_loss = 0.8607717480924394\n",
+ "06/09/2019 11:55:27 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:55:54 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:55:54 - INFO - run_child_finetuning - eval_accuracy = 0.5701388888888889\n",
+ "06/09/2019 11:55:54 - INFO - run_child_finetuning - eval_loss = 0.8688309980763329\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 96%|█████████▌| 96/100 [1:45:32<04:22, 65.74s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:56:05 - INFO - run_child_finetuning - Epoch 97\n",
+ "06/09/2019 11:56:05 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:56:33 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:56:33 - INFO - run_child_finetuning - eval_accuracy = 0.6397569444444444\n",
+ "06/09/2019 11:56:33 - INFO - run_child_finetuning - eval_loss = 0.7840271459685432\n",
+ "06/09/2019 11:56:33 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:57:00 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:57:00 - INFO - run_child_finetuning - eval_accuracy = 0.6365451388888889\n",
+ "06/09/2019 11:57:00 - INFO - run_child_finetuning - eval_loss = 0.7873119976785448\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 97%|█████████▋| 97/100 [1:46:38<03:17, 65.81s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "global_step 35000, lr = -0.000224\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "06/09/2019 11:57:11 - INFO - run_child_finetuning - Epoch 98\n",
+ "06/09/2019 11:57:11 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:57:38 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:57:38 - INFO - run_child_finetuning - eval_accuracy = 0.6711805555555556\n",
+ "06/09/2019 11:57:38 - INFO - run_child_finetuning - eval_loss = 0.7093107594384087\n",
+ "06/09/2019 11:57:38 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:58:06 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:58:06 - INFO - run_child_finetuning - eval_accuracy = 0.6711805555555556\n",
+ "06/09/2019 11:58:06 - INFO - run_child_finetuning - eval_loss = 0.7124807761775123\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 98%|█████████▊| 98/100 [1:47:44<02:11, 65.84s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:58:17 - INFO - run_child_finetuning - Epoch 99\n",
+ "06/09/2019 11:58:17 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:58:44 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:58:44 - INFO - run_child_finetuning - eval_accuracy = 0.7196180555555556\n",
+ "06/09/2019 11:58:44 - INFO - run_child_finetuning - eval_loss = 0.6273805638154347\n",
+ "06/09/2019 11:58:44 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 11:59:12 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:59:12 - INFO - run_child_finetuning - eval_accuracy = 0.7190972222222223\n",
+ "06/09/2019 11:59:12 - INFO - run_child_finetuning - eval_loss = 0.630940580368042\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 99%|█████████▉| 99/100 [1:48:50<01:05, 65.88s/it]\u001b[A\u001b[A\u001b[A\u001b[A06/09/2019 11:59:23 - INFO - run_child_finetuning - Epoch 100\n",
+ "06/09/2019 11:59:23 - INFO - run_child_finetuning - Evaluating on train set...\n",
+ "06/09/2019 11:59:50 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 11:59:50 - INFO - run_child_finetuning - eval_accuracy = 0.7444444444444445\n",
+ "06/09/2019 11:59:50 - INFO - run_child_finetuning - eval_loss = 0.545815435383055\n",
+ "06/09/2019 11:59:50 - INFO - run_child_finetuning - Evaluating on valid set...\n",
+ "06/09/2019 12:00:18 - INFO - run_child_finetuning - ***** Eval results *****\n",
+ "06/09/2019 12:00:18 - INFO - run_child_finetuning - eval_accuracy = 0.7447916666666666\n",
+ "06/09/2019 12:00:18 - INFO - run_child_finetuning - eval_loss = 0.5425901783837213\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Epoch: 100%|██████████| 100/100 [1:49:57<00:00, 65.99s/it]\u001b[A\u001b[A\u001b[A\u001b[A"
+ ]
+ }
+ ],
+ "source": [
+ "# train_sampler = RandomSampler(train_dataset)\n",
+ "# train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)\n",
+ "# eval_sampler = SequentialSampler(eval_dataset)\n",
+ "# eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)\n",
+ "\n",
+ "logger.info(\"Epoch 0\")\n",
+ "logger.info(\"Evaluating on train set...\")\n",
+ "validate(model, train_dataset, device)\n",
+ "logger.info(\"Evaluating on valid set...\")\n",
+ "validate(model, eval_dataset, device)\n",
+ "\n",
+ "global_step = 0\n",
+ "for epoch in trange(int(args.num_train_epochs), desc=\"Epoch\"):\n",
+ " _ = model.train()\n",
+ " tr_loss = 0\n",
+ " nb_tr_examples, nb_tr_steps = 0, 0\n",
+ "# for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\n",
+ " for step, batch_idx in enumerate(get_batch_index(len(train_dataset), args.train_batch_size, randomized=True)):\n",
+ " batch = tuple(t[batch_idx] for t in train_dataset.tensors)\n",
+ " batch = tuple(t.to(device) for t in batch)\n",
+ " input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch\n",
+ " loss = model(input_ids, segment_ids, input_mask, lm_label_ids)\n",
+ " if n_gpu > 1:\n",
+ " loss = loss.mean() # mean() to average on multi-gpu.\n",
+ " if args.gradient_accumulation_steps > 1:\n",
+ " loss = loss / args.gradient_accumulation_steps\n",
+ " loss.backward()\n",
+ " tr_loss += loss.item()\n",
+ " nb_tr_examples += input_ids.size(0)\n",
+ " nb_tr_steps += 1\n",
+ " if (step + 1) % args.gradient_accumulation_steps == 0:\n",
+ " # modify learning rate with special warm up BERT uses\n",
+ " lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion)\n",
+ " if global_step % 1000 == 0:\n",
+ " print('global_step %d, lr = %f' % (global_step, lr_this_step))\n",
+ " for param_group in optimizer.param_groups:\n",
+ " param_group['lr'] = lr_this_step\n",
+ " optimizer.step()\n",
+ " optimizer.zero_grad()\n",
+ " global_step += 1\n",
+ "\n",
+ " if args.do_eval:\n",
+ " logger.info(\"Epoch %d\" % (epoch + 1))\n",
+ " logger.info(\"Evaluating on train set...\")\n",
+ " validate(model, train_dataset, device)\n",
+ " logger.info(\"Evaluating on valid set...\")\n",
+ " validate(model, eval_dataset, device)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/Untitled3.ipynb b/Untitled3.ipynb
new file mode 100644
index 00000000000000..eee4c4c8357630
--- /dev/null
+++ b/Untitled3.ipynb
@@ -0,0 +1,804 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from IPython.core.interactiveshell import InteractiveShell\n",
+ "InteractiveShell.ast_node_interactivity = 'all'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import itertools\n",
+ "from itertools import product, chain\n",
+ "\n",
+ "from pytorch_pretrained_bert import tokenization, BertTokenizer, BertModel, BertForMaskedLM, BertForPreTraining, BertConfig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "01/24/2019 22:16:56 - INFO - pytorch_pretrained_bert.tokenization - loading vocabulary file /nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/vocab.txt\n"
+ ]
+ }
+ ],
+ "source": [
+ "CONFIG_NAME = 'bert_config.json'\n",
+ "BERT_DIR = '/nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/'\n",
+ "tokenizer = BertTokenizer.from_pretrained(os.path.join(BERT_DIR, 'vocab.txt'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def reverse(l):\n",
+ " return list(reversed(l))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def mask(ent_str):\n",
+ " tokens = ent_str.strip().split()\n",
+ " if len(tokens) == 1:\n",
+ " return '[%s]' % tokens[0]\n",
+ " elif len(tokens) == 2:\n",
+ " assert tokens[0] == 'the', ent_str\n",
+ " return '%s [%s]' % (tokens[0], tokens[1])\n",
+ " else:\n",
+ " assert False, ent_str"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "A_template = \"{dt} {ent0} {rel} {dt} {ent1} {rel_suffix}\"\n",
+ "B_template = \"{dt} {ent} {pred}\"\n",
+ "\n",
+ "causal_templates = [[\"{A} because {B}.\"],# \"{B} so {A}.\"], \n",
+ " [\"{A} so {B}.\"],# \"{B} because {A}.\"]\n",
+ " ]\n",
+ "turning_templates = [[\"{A} although {B}.\"],# \"{B} but {A}.\"], \n",
+ " [\"{A} but {B}.\"],# \"{B} although {A}.\"]\n",
+ " ]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def make_sentences(A_template, B_template, causal_templates, turning_templates,\n",
+ " index=-1, orig_sentence='', entities=[\"John\", \"Mary\"], entity_substitutes=None, determiner=\"\", \n",
+ " packed_relations=[\"rel/~rel\", \"rev_rel/~rev_rel\"], packed_relation_substitutes=None, relation_suffix=\"\",\n",
+ " packed_predicates=[\"pred0/~pred0\", \"pred1/~pred1\"], predicate_substitutes=None,\n",
+ " predicate_dichotomy=True, reverse_causal=False):\n",
+ " assert entities[0].lower() in tokenizer.vocab , entities[0]\n",
+ " assert entities[1].lower() in tokenizer.vocab , entities[1]\n",
+ " \n",
+ " relations, neg_relations = zip(*[rel.split(\"/\") for rel in packed_relations])\n",
+ " relations, neg_relations = list(relations), list(neg_relations)\n",
+ " predicates, neg_predicates = zip(*[pred.split(\"/\") for pred in packed_predicates])\n",
+ " predicates, neg_predicates = list(predicates), list(neg_predicates)\n",
+ " \n",
+ " As = [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_suffix=relation_suffix) \n",
+ " for ent0, ent1, rel in [entities + relations[:1], reverse(entities) + reverse(relations)[:1]]]\n",
+ " negAs = [A_template.format(dt=determiner, ent0=ent0, ent1=ent1, rel=rel, rel_suffix=relation_suffix) \n",
+ " for ent0, ent1, rel in [entities + neg_relations[:1], reverse(entities) + reverse(neg_relations)[:1]]]\n",
+ "\n",
+ " Bs = [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, predicates)]\n",
+ " negBs = [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, neg_predicates)]\n",
+ " if predicate_dichotomy:\n",
+ " Bs += [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, reversed(neg_predicates))]\n",
+ " negBs += [B_template.format(dt=determiner, ent=mask(ent), pred=pred) for ent, pred in zip(entities, reversed(predicates))]\n",
+ "\n",
+ " def form_sentences(sentence_template, As, Bs):\n",
+ " return [\" \".join(sentence_template.format(A=A, B=B).split()) for A, B in product(As, Bs)]\n",
+ "\n",
+ " causal_sentences = []\n",
+ " for causal_template in causal_templates[int(reverse_causal)]:\n",
+ " for A, B in [(As, Bs), (negAs, negBs)]:\n",
+ " causal_sentences.extend(form_sentences(causal_template, A, B))\n",
+ "\n",
+ " turning_sentences = []\n",
+ " for turning_template in turning_templates[int(reverse_causal)]:\n",
+ " for A, B in [(As, negBs), (negAs, Bs)]:\n",
+ " turning_sentences.extend(form_sentences(turning_template, A, B))\n",
+ " \n",
+ " sentences = causal_sentences + turning_sentences\n",
+ " substituted_sentences = sentences\n",
+ " \n",
+ " if packed_relation_substitutes is not None:\n",
+ " packed_relation_substitutes = list(itertools.product(packed_relations[:1] + packed_relation_substitutes[0], \n",
+ " packed_relations[1:] + packed_relation_substitutes[1]))\n",
+ " substituted_sentences = []\n",
+ " for packed_sub_relations in packed_relation_substitutes:\n",
+ " sub_relations, sub_neg_relations = zip(*[rel.split(\"/\") for rel in packed_sub_relations])\n",
+ " substituted_sentences += [sent.replace(relations[0], sub_relations[0]).replace(relations[1], sub_relations[1])\n",
+ " .replace(neg_relations[0], sub_neg_relations[0]).replace(neg_relations[1], sub_neg_relations[1]) \n",
+ " for sent in sentences]\n",
+ " substituted_sentences = list(set(substituted_sentences))\n",
+ " \n",
+ " if entity_substitutes is not None:\n",
+ " for sub in entity_substitutes:\n",
+ " for ent in sub:\n",
+ " assert ent.lower() in tokenizer.vocab , ent + \" not in BERT vocab\"\n",
+ " assert len(set(chain.from_iterable(entity_substitutes))) == 4, entity_substitutes\n",
+ " assert len(set(chain.from_iterable(entity_substitutes)).union(set(entities))) == 6 \n",
+ " \n",
+ " entity_substitutes = list(itertools.product(entities[:1] + entity_substitutes[0], entities[1:] + entity_substitutes[1]))\n",
+ " substituted_sentences = [sent.replace(entities[0], sub[0]).replace(entities[1], sub[1]) \n",
+ " for sent in substituted_sentences for sub in entity_substitutes]\n",
+ " return causal_sentences, turning_sentences, substituted_sentences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "frames = \\\n",
+ "[\n",
+ " {\n",
+ " \"index\": 2,\n",
+ " \"orig_sentence\": \"The trophy doesn't fit into the brown suitcase because [it] is too large/small.\",\n",
+ " \"entities\": [\"trophy\", \"suitcase\"],\n",
+    "        \"entity_substitutes\": [[\"ball\", \"toy\"], [\"bag\", \"box\"]],\n",
+ " \"determiner\": \"the\",\n",
+ " \"packed_relations\": [\"doesn't fit into/can fit into\", \"doesn't hold/can hold\"],\n",
+ " \"packed_relation_substitutes\": [[\"can't be put into/can be put into\"], [\"doesn't have enough room for/has enough room for\"]],\n",
+ " \"relation_suffix\": \"\",\n",
+ " \"packed_predicates\": [\"is large/isn't large\", \"is small/isn't small\"],\n",
+ " \"predicate_dichotomy\": True,\n",
+ " \"reverse_causal\": False\n",
+ " },\n",
+ " {\n",
+ " \"index\": 4,\n",
+    "        \"orig_sentence\": \"Joan made sure to thank Susan for all the help [she] had received/given.\",\n",
+ " \"entities\": [\"John\", \"Susan\"],\n",
+ " \"entity_substitutes\": [[\"David\", \"Michael\"], [\"Mary\", \"Linda\"]],\n",
+ " \"determiner\": \"\",\n",
+    "        \"packed_relations\": [\"thanked/didn't thank\", \"took good care of/didn't take good care of\"],\n",
+ " \"packed_relation_substitutes\": [[\"felt grateful to/didn't feel grateful to\"], [\"was appreciated by/wasn't appreciated by\"]],\n",
+ " \"relation_suffix\": \"\",\n",
+ " \"packed_predicates\": [\"had received a lot of help/hadn't received a lot of help\", \"had given a lot of help/hadn't given a lot of help\"],\n",
+ " \"predicate_dichotomy\": False,\n",
+ " \"reverse_causal\": False\n",
+ " },\n",
+ " {\n",
+ " \"index\": 4000,\n",
+ " \"orig_sentence\": \"John gave a lot of money to Susan because [he] was very rich/poor.\",\n",
+ " \"entities\": [\"John\", \"Susan\"],\n",
+ " \"entity_substitutes\": [[\"David\", \"Michael\"], [\"Mary\", \"Linda\"]],\n",
+ " \"determiner\": \"\",\n",
+ " \"packed_relations\": [\"gave a lot of money to/didn't give a lot of money to\", \"received a lot of money from/didn't receive a lot of money from\"],\n",
+ " \"packed_relation_substitutes\": [[\"subsidized/didn't subsidize\"], [\"borrowed a lot of money from/didn't borrow any money from\"]],\n",
+ " \"relation_suffix\": \"\",\n",
+ " \"packed_predicates\": [\"was rich/wasn't rich\", \"was poor/wasn't poor\"],\n",
+ " \"predicate_dichotomy\": True,\n",
+ " \"reverse_causal\": False\n",
+ " },\n",
+ " {\n",
+ " \"index\": 10,\n",
+ " \"orig_sentence\": \"The delivery truck zoomed by the school bus because [it] was going so fast/slow.\",\n",
+ " \"entities\": [\"truck\", \"bus\"],\n",
+ " \"entity_substitutes\": [[\"car\", \"ambulance\"], [\"bicycle\", \"tram\"]],\n",
+ " \"determiner\": \"the\",\n",
+ " \"packed_relations\": [\"overtook/couldn't overtake\", \"fell far behind/didn't fall far behind\"],\n",
+ " \"packed_relation_substitutes\": [[\"zoomed by/didn't pass\"], [\"was left behind/wasn't left far behind\"]],\n",
+ " \"relation_suffix\": \"\",\n",
+ " \"packed_predicates\": [\"was going fast/wasn't going fast\", \"was going slow/wasn't going slow\"],\n",
+ " \"predicate_dichotomy\": True,\n",
+ " \"reverse_causal\": False\n",
+ " },\n",
+ " {\n",
+ " \"index\": 12,\n",
+ " \"orig_sentence\": \"Frank felt vindicated/crushed when his longtime rival Bill revealed that [he] was the winner of the competition.\",\n",
+ " \"entities\": [\"John\", \"Susan\"],\n",
+ " \"entity_substitutes\": [[\"David\", \"Michael\"], [\"Mary\", \"Linda\"]],\n",
+ " \"determiner\": \"\",\n",
+ " \"packed_relations\": [\"beat/didn't beat\", \"lost to/didn't lose to\"],\n",
+ " \"relation_suffix\": \"in the game\",\n",
+ " \"packed_predicates\": [\"was happy/wasn't happy\", \"was sad/wasn't sad\"],\n",
+ " \"packed_relation_substitutes\": None,\n",
+ " \"predicate_dichotomy\": True,\n",
+ " \"reverse_causal\": True\n",
+ " },\n",
+ " {\n",
+ " \"index\": 16,\n",
+ " \"orig_sentence\": \"The large ball crashed right through the table because [it] was made of steel/styrofoam.\",\n",
+ " \"entities\": [\"ball\", \"board\"],\n",
+ " \"substitutes\": [[\"bullet\", \"arrow\"], [\"shield\", \"disk\"]],\n",
+ " \"determiner\": \"the\",\n",
+ " \"relations\": [\"crashed right through\", \"failed to block\"],\n",
+ " \"neg_relations\": [\"didn't crash through\", \"blocked\"],\n",
+ " \"relation_suffix\": \"\",\n",
+ " \"predicates\": [\"was hard\", \"was soft\"],\n",
+ " \"neg_predicates\": [\"wasn't hard\", \"wasn't soft\"],\n",
+ " \"predicate_dichotomy\": True,\n",
+ " \"reverse_causal\": False\n",
+ " },\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "causal_sentences, turning_sentences, substituted_sentences = \\\n",
+ " make_sentences(A_template, B_template, causal_templates, turning_templates, **frames[-1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['John beat Susan in the game so [John] was happy.',\n",
+ " 'John beat Susan in the game so [Susan] was sad.',\n",
+ " \"John beat Susan in the game so [John] wasn't sad.\",\n",
+ " \"John beat Susan in the game so [Susan] wasn't happy.\",\n",
+ " 'Susan lost to John in the game so [John] was happy.',\n",
+ " 'Susan lost to John in the game so [Susan] was sad.',\n",
+ " \"Susan lost to John in the game so [John] wasn't sad.\",\n",
+ " \"Susan lost to John in the game so [Susan] wasn't happy.\",\n",
+ " \"John didn't beat Susan in the game so [John] wasn't happy.\",\n",
+ " \"John didn't beat Susan in the game so [Susan] wasn't sad.\",\n",
+ " \"John didn't beat Susan in the game so [John] was sad.\",\n",
+ " \"John didn't beat Susan in the game so [Susan] was happy.\",\n",
+ " \"Susan didn't lose to John in the game so [John] wasn't happy.\",\n",
+ " \"Susan didn't lose to John in the game so [Susan] wasn't sad.\",\n",
+ " \"Susan didn't lose to John in the game so [John] was sad.\",\n",
+ " \"Susan didn't lose to John in the game so [Susan] was happy.\",\n",
+ " \"John beat Susan in the game but [John] wasn't happy.\",\n",
+ " \"John beat Susan in the game but [Susan] wasn't sad.\",\n",
+ " 'John beat Susan in the game but [John] was sad.',\n",
+ " 'John beat Susan in the game but [Susan] was happy.',\n",
+ " \"Susan lost to John in the game but [John] wasn't happy.\",\n",
+ " \"Susan lost to John in the game but [Susan] wasn't sad.\",\n",
+ " 'Susan lost to John in the game but [John] was sad.',\n",
+ " 'Susan lost to John in the game but [Susan] was happy.',\n",
+ " \"John didn't beat Susan in the game but [John] was happy.\",\n",
+ " \"John didn't beat Susan in the game but [Susan] was sad.\",\n",
+ " \"John didn't beat Susan in the game but [John] wasn't sad.\",\n",
+ " \"John didn't beat Susan in the game but [Susan] wasn't happy.\",\n",
+ " \"Susan didn't lose to John in the game but [John] was happy.\",\n",
+ " \"Susan didn't lose to John in the game but [Susan] was sad.\",\n",
+ " \"Susan didn't lose to John in the game but [John] wasn't sad.\",\n",
+ " \"Susan didn't lose to John in the game but [Susan] wasn't happy.\"]"
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "['John beat Susan in the game so [John] was happy.',\n",
+ " 'John beat Mary in the game so [John] was happy.',\n",
+ " 'John beat Linda in the game so [John] was happy.',\n",
+ " 'David beat Susan in the game so [David] was happy.',\n",
+ " 'David beat Mary in the game so [David] was happy.',\n",
+ " 'David beat Linda in the game so [David] was happy.',\n",
+ " 'Michael beat Susan in the game so [Michael] was happy.',\n",
+ " 'Michael beat Mary in the game so [Michael] was happy.',\n",
+ " 'Michael beat Linda in the game so [Michael] was happy.',\n",
+ " 'John beat Susan in the game so [Susan] was sad.',\n",
+ " 'John beat Mary in the game so [Mary] was sad.',\n",
+ " 'John beat Linda in the game so [Linda] was sad.',\n",
+ " 'David beat Susan in the game so [Susan] was sad.',\n",
+ " 'David beat Mary in the game so [Mary] was sad.',\n",
+ " 'David beat Linda in the game so [Linda] was sad.',\n",
+ " 'Michael beat Susan in the game so [Susan] was sad.',\n",
+ " 'Michael beat Mary in the game so [Mary] was sad.',\n",
+ " 'Michael beat Linda in the game so [Linda] was sad.',\n",
+ " \"John beat Susan in the game so [John] wasn't sad.\",\n",
+ " \"John beat Mary in the game so [John] wasn't sad.\",\n",
+ " \"John beat Linda in the game so [John] wasn't sad.\",\n",
+ " \"David beat Susan in the game so [David] wasn't sad.\",\n",
+ " \"David beat Mary in the game so [David] wasn't sad.\",\n",
+ " \"David beat Linda in the game so [David] wasn't sad.\",\n",
+ " \"Michael beat Susan in the game so [Michael] wasn't sad.\",\n",
+ " \"Michael beat Mary in the game so [Michael] wasn't sad.\",\n",
+ " \"Michael beat Linda in the game so [Michael] wasn't sad.\",\n",
+ " \"John beat Susan in the game so [Susan] wasn't happy.\",\n",
+ " \"John beat Mary in the game so [Mary] wasn't happy.\",\n",
+ " \"John beat Linda in the game so [Linda] wasn't happy.\",\n",
+ " \"David beat Susan in the game so [Susan] wasn't happy.\",\n",
+ " \"David beat Mary in the game so [Mary] wasn't happy.\",\n",
+ " \"David beat Linda in the game so [Linda] wasn't happy.\",\n",
+ " \"Michael beat Susan in the game so [Susan] wasn't happy.\",\n",
+ " \"Michael beat Mary in the game so [Mary] wasn't happy.\",\n",
+ " \"Michael beat Linda in the game so [Linda] wasn't happy.\",\n",
+ " 'Susan lost to John in the game so [John] was happy.',\n",
+ " 'Mary lost to John in the game so [John] was happy.',\n",
+ " 'Linda lost to John in the game so [John] was happy.',\n",
+ " 'Susan lost to David in the game so [David] was happy.',\n",
+ " 'Mary lost to David in the game so [David] was happy.',\n",
+ " 'Linda lost to David in the game so [David] was happy.',\n",
+ " 'Susan lost to Michael in the game so [Michael] was happy.',\n",
+ " 'Mary lost to Michael in the game so [Michael] was happy.',\n",
+ " 'Linda lost to Michael in the game so [Michael] was happy.',\n",
+ " 'Susan lost to John in the game so [Susan] was sad.',\n",
+ " 'Mary lost to John in the game so [Mary] was sad.',\n",
+ " 'Linda lost to John in the game so [Linda] was sad.',\n",
+ " 'Susan lost to David in the game so [Susan] was sad.',\n",
+ " 'Mary lost to David in the game so [Mary] was sad.',\n",
+ " 'Linda lost to David in the game so [Linda] was sad.',\n",
+ " 'Susan lost to Michael in the game so [Susan] was sad.',\n",
+ " 'Mary lost to Michael in the game so [Mary] was sad.',\n",
+ " 'Linda lost to Michael in the game so [Linda] was sad.',\n",
+ " \"Susan lost to John in the game so [John] wasn't sad.\",\n",
+ " \"Mary lost to John in the game so [John] wasn't sad.\",\n",
+ " \"Linda lost to John in the game so [John] wasn't sad.\",\n",
+ " \"Susan lost to David in the game so [David] wasn't sad.\",\n",
+ " \"Mary lost to David in the game so [David] wasn't sad.\",\n",
+ " \"Linda lost to David in the game so [David] wasn't sad.\",\n",
+ " \"Susan lost to Michael in the game so [Michael] wasn't sad.\",\n",
+ " \"Mary lost to Michael in the game so [Michael] wasn't sad.\",\n",
+ " \"Linda lost to Michael in the game so [Michael] wasn't sad.\",\n",
+ " \"Susan lost to John in the game so [Susan] wasn't happy.\",\n",
+ " \"Mary lost to John in the game so [Mary] wasn't happy.\",\n",
+ " \"Linda lost to John in the game so [Linda] wasn't happy.\",\n",
+ " \"Susan lost to David in the game so [Susan] wasn't happy.\",\n",
+ " \"Mary lost to David in the game so [Mary] wasn't happy.\",\n",
+ " \"Linda lost to David in the game so [Linda] wasn't happy.\",\n",
+ " \"Susan lost to Michael in the game so [Susan] wasn't happy.\",\n",
+ " \"Mary lost to Michael in the game so [Mary] wasn't happy.\",\n",
+ " \"Linda lost to Michael in the game so [Linda] wasn't happy.\",\n",
+ " \"John didn't beat Susan in the game so [John] wasn't happy.\",\n",
+ " \"John didn't beat Mary in the game so [John] wasn't happy.\",\n",
+ " \"John didn't beat Linda in the game so [John] wasn't happy.\",\n",
+ " \"David didn't beat Susan in the game so [David] wasn't happy.\",\n",
+ " \"David didn't beat Mary in the game so [David] wasn't happy.\",\n",
+ " \"David didn't beat Linda in the game so [David] wasn't happy.\",\n",
+ " \"Michael didn't beat Susan in the game so [Michael] wasn't happy.\",\n",
+ " \"Michael didn't beat Mary in the game so [Michael] wasn't happy.\",\n",
+ " \"Michael didn't beat Linda in the game so [Michael] wasn't happy.\",\n",
+ " \"John didn't beat Susan in the game so [Susan] wasn't sad.\",\n",
+ " \"John didn't beat Mary in the game so [Mary] wasn't sad.\",\n",
+ " \"John didn't beat Linda in the game so [Linda] wasn't sad.\",\n",
+ " \"David didn't beat Susan in the game so [Susan] wasn't sad.\",\n",
+ " \"David didn't beat Mary in the game so [Mary] wasn't sad.\",\n",
+ " \"David didn't beat Linda in the game so [Linda] wasn't sad.\",\n",
+ " \"Michael didn't beat Susan in the game so [Susan] wasn't sad.\",\n",
+ " \"Michael didn't beat Mary in the game so [Mary] wasn't sad.\",\n",
+ " \"Michael didn't beat Linda in the game so [Linda] wasn't sad.\",\n",
+ " \"John didn't beat Susan in the game so [John] was sad.\",\n",
+ " \"John didn't beat Mary in the game so [John] was sad.\",\n",
+ " \"John didn't beat Linda in the game so [John] was sad.\",\n",
+ " \"David didn't beat Susan in the game so [David] was sad.\",\n",
+ " \"David didn't beat Mary in the game so [David] was sad.\",\n",
+ " \"David didn't beat Linda in the game so [David] was sad.\",\n",
+ " \"Michael didn't beat Susan in the game so [Michael] was sad.\",\n",
+ " \"Michael didn't beat Mary in the game so [Michael] was sad.\",\n",
+ " \"Michael didn't beat Linda in the game so [Michael] was sad.\",\n",
+ " \"John didn't beat Susan in the game so [Susan] was happy.\",\n",
+ " \"John didn't beat Mary in the game so [Mary] was happy.\",\n",
+ " \"John didn't beat Linda in the game so [Linda] was happy.\",\n",
+ " \"David didn't beat Susan in the game so [Susan] was happy.\",\n",
+ " \"David didn't beat Mary in the game so [Mary] was happy.\",\n",
+ " \"David didn't beat Linda in the game so [Linda] was happy.\",\n",
+ " \"Michael didn't beat Susan in the game so [Susan] was happy.\",\n",
+ " \"Michael didn't beat Mary in the game so [Mary] was happy.\",\n",
+ " \"Michael didn't beat Linda in the game so [Linda] was happy.\",\n",
+ " \"Susan didn't lose to John in the game so [John] wasn't happy.\",\n",
+ " \"Mary didn't lose to John in the game so [John] wasn't happy.\",\n",
+ " \"Linda didn't lose to John in the game so [John] wasn't happy.\",\n",
+ " \"Susan didn't lose to David in the game so [David] wasn't happy.\",\n",
+ " \"Mary didn't lose to David in the game so [David] wasn't happy.\",\n",
+ " \"Linda didn't lose to David in the game so [David] wasn't happy.\",\n",
+ " \"Susan didn't lose to Michael in the game so [Michael] wasn't happy.\",\n",
+ " \"Mary didn't lose to Michael in the game so [Michael] wasn't happy.\",\n",
+ " \"Linda didn't lose to Michael in the game so [Michael] wasn't happy.\",\n",
+ " \"Susan didn't lose to John in the game so [Susan] wasn't sad.\",\n",
+ " \"Mary didn't lose to John in the game so [Mary] wasn't sad.\",\n",
+ " \"Linda didn't lose to John in the game so [Linda] wasn't sad.\",\n",
+ " \"Susan didn't lose to David in the game so [Susan] wasn't sad.\",\n",
+ " \"Mary didn't lose to David in the game so [Mary] wasn't sad.\",\n",
+ " \"Linda didn't lose to David in the game so [Linda] wasn't sad.\",\n",
+ " \"Susan didn't lose to Michael in the game so [Susan] wasn't sad.\",\n",
+ " \"Mary didn't lose to Michael in the game so [Mary] wasn't sad.\",\n",
+ " \"Linda didn't lose to Michael in the game so [Linda] wasn't sad.\",\n",
+ " \"Susan didn't lose to John in the game so [John] was sad.\",\n",
+ " \"Mary didn't lose to John in the game so [John] was sad.\",\n",
+ " \"Linda didn't lose to John in the game so [John] was sad.\",\n",
+ " \"Susan didn't lose to David in the game so [David] was sad.\",\n",
+ " \"Mary didn't lose to David in the game so [David] was sad.\",\n",
+ " \"Linda didn't lose to David in the game so [David] was sad.\",\n",
+ " \"Susan didn't lose to Michael in the game so [Michael] was sad.\",\n",
+ " \"Mary didn't lose to Michael in the game so [Michael] was sad.\",\n",
+ " \"Linda didn't lose to Michael in the game so [Michael] was sad.\",\n",
+ " \"Susan didn't lose to John in the game so [Susan] was happy.\",\n",
+ " \"Mary didn't lose to John in the game so [Mary] was happy.\",\n",
+ " \"Linda didn't lose to John in the game so [Linda] was happy.\",\n",
+ " \"Susan didn't lose to David in the game so [Susan] was happy.\",\n",
+ " \"Mary didn't lose to David in the game so [Mary] was happy.\",\n",
+ " \"Linda didn't lose to David in the game so [Linda] was happy.\",\n",
+ " \"Susan didn't lose to Michael in the game so [Susan] was happy.\",\n",
+ " \"Mary didn't lose to Michael in the game so [Mary] was happy.\",\n",
+ " \"Linda didn't lose to Michael in the game so [Linda] was happy.\",\n",
+ " \"John beat Susan in the game but [John] wasn't happy.\",\n",
+ " \"John beat Mary in the game but [John] wasn't happy.\",\n",
+ " \"John beat Linda in the game but [John] wasn't happy.\",\n",
+ " \"David beat Susan in the game but [David] wasn't happy.\",\n",
+ " \"David beat Mary in the game but [David] wasn't happy.\",\n",
+ " \"David beat Linda in the game but [David] wasn't happy.\",\n",
+ " \"Michael beat Susan in the game but [Michael] wasn't happy.\",\n",
+ " \"Michael beat Mary in the game but [Michael] wasn't happy.\",\n",
+ " \"Michael beat Linda in the game but [Michael] wasn't happy.\",\n",
+ " \"John beat Susan in the game but [Susan] wasn't sad.\",\n",
+ " \"John beat Mary in the game but [Mary] wasn't sad.\",\n",
+ " \"John beat Linda in the game but [Linda] wasn't sad.\",\n",
+ " \"David beat Susan in the game but [Susan] wasn't sad.\",\n",
+ " \"David beat Mary in the game but [Mary] wasn't sad.\",\n",
+ " \"David beat Linda in the game but [Linda] wasn't sad.\",\n",
+ " \"Michael beat Susan in the game but [Susan] wasn't sad.\",\n",
+ " \"Michael beat Mary in the game but [Mary] wasn't sad.\",\n",
+ " \"Michael beat Linda in the game but [Linda] wasn't sad.\",\n",
+ " 'John beat Susan in the game but [John] was sad.',\n",
+ " 'John beat Mary in the game but [John] was sad.',\n",
+ " 'John beat Linda in the game but [John] was sad.',\n",
+ " 'David beat Susan in the game but [David] was sad.',\n",
+ " 'David beat Mary in the game but [David] was sad.',\n",
+ " 'David beat Linda in the game but [David] was sad.',\n",
+ " 'Michael beat Susan in the game but [Michael] was sad.',\n",
+ " 'Michael beat Mary in the game but [Michael] was sad.',\n",
+ " 'Michael beat Linda in the game but [Michael] was sad.',\n",
+ " 'John beat Susan in the game but [Susan] was happy.',\n",
+ " 'John beat Mary in the game but [Mary] was happy.',\n",
+ " 'John beat Linda in the game but [Linda] was happy.',\n",
+ " 'David beat Susan in the game but [Susan] was happy.',\n",
+ " 'David beat Mary in the game but [Mary] was happy.',\n",
+ " 'David beat Linda in the game but [Linda] was happy.',\n",
+ " 'Michael beat Susan in the game but [Susan] was happy.',\n",
+ " 'Michael beat Mary in the game but [Mary] was happy.',\n",
+ " 'Michael beat Linda in the game but [Linda] was happy.',\n",
+ " \"Susan lost to John in the game but [John] wasn't happy.\",\n",
+ " \"Mary lost to John in the game but [John] wasn't happy.\",\n",
+ " \"Linda lost to John in the game but [John] wasn't happy.\",\n",
+ " \"Susan lost to David in the game but [David] wasn't happy.\",\n",
+ " \"Mary lost to David in the game but [David] wasn't happy.\",\n",
+ " \"Linda lost to David in the game but [David] wasn't happy.\",\n",
+ " \"Susan lost to Michael in the game but [Michael] wasn't happy.\",\n",
+ " \"Mary lost to Michael in the game but [Michael] wasn't happy.\",\n",
+ " \"Linda lost to Michael in the game but [Michael] wasn't happy.\",\n",
+ " \"Susan lost to John in the game but [Susan] wasn't sad.\",\n",
+ " \"Mary lost to John in the game but [Mary] wasn't sad.\",\n",
+ " \"Linda lost to John in the game but [Linda] wasn't sad.\",\n",
+ " \"Susan lost to David in the game but [Susan] wasn't sad.\",\n",
+ " \"Mary lost to David in the game but [Mary] wasn't sad.\",\n",
+ " \"Linda lost to David in the game but [Linda] wasn't sad.\",\n",
+ " \"Susan lost to Michael in the game but [Susan] wasn't sad.\",\n",
+ " \"Mary lost to Michael in the game but [Mary] wasn't sad.\",\n",
+ " \"Linda lost to Michael in the game but [Linda] wasn't sad.\",\n",
+ " 'Susan lost to John in the game but [John] was sad.',\n",
+ " 'Mary lost to John in the game but [John] was sad.',\n",
+ " 'Linda lost to John in the game but [John] was sad.',\n",
+ " 'Susan lost to David in the game but [David] was sad.',\n",
+ " 'Mary lost to David in the game but [David] was sad.',\n",
+ " 'Linda lost to David in the game but [David] was sad.',\n",
+ " 'Susan lost to Michael in the game but [Michael] was sad.',\n",
+ " 'Mary lost to Michael in the game but [Michael] was sad.',\n",
+ " 'Linda lost to Michael in the game but [Michael] was sad.',\n",
+ " 'Susan lost to John in the game but [Susan] was happy.',\n",
+ " 'Mary lost to John in the game but [Mary] was happy.',\n",
+ " 'Linda lost to John in the game but [Linda] was happy.',\n",
+ " 'Susan lost to David in the game but [Susan] was happy.',\n",
+ " 'Mary lost to David in the game but [Mary] was happy.',\n",
+ " 'Linda lost to David in the game but [Linda] was happy.',\n",
+ " 'Susan lost to Michael in the game but [Susan] was happy.',\n",
+ " 'Mary lost to Michael in the game but [Mary] was happy.',\n",
+ " 'Linda lost to Michael in the game but [Linda] was happy.',\n",
+ " \"John didn't beat Susan in the game but [John] was happy.\",\n",
+ " \"John didn't beat Mary in the game but [John] was happy.\",\n",
+ " \"John didn't beat Linda in the game but [John] was happy.\",\n",
+ " \"David didn't beat Susan in the game but [David] was happy.\",\n",
+ " \"David didn't beat Mary in the game but [David] was happy.\",\n",
+ " \"David didn't beat Linda in the game but [David] was happy.\",\n",
+ " \"Michael didn't beat Susan in the game but [Michael] was happy.\",\n",
+ " \"Michael didn't beat Mary in the game but [Michael] was happy.\",\n",
+ " \"Michael didn't beat Linda in the game but [Michael] was happy.\",\n",
+ " \"John didn't beat Susan in the game but [Susan] was sad.\",\n",
+ " \"John didn't beat Mary in the game but [Mary] was sad.\",\n",
+ " \"John didn't beat Linda in the game but [Linda] was sad.\",\n",
+ " \"David didn't beat Susan in the game but [Susan] was sad.\",\n",
+ " \"David didn't beat Mary in the game but [Mary] was sad.\",\n",
+ " \"David didn't beat Linda in the game but [Linda] was sad.\",\n",
+ " \"Michael didn't beat Susan in the game but [Susan] was sad.\",\n",
+ " \"Michael didn't beat Mary in the game but [Mary] was sad.\",\n",
+ " \"Michael didn't beat Linda in the game but [Linda] was sad.\",\n",
+ " \"John didn't beat Susan in the game but [John] wasn't sad.\",\n",
+ " \"John didn't beat Mary in the game but [John] wasn't sad.\",\n",
+ " \"John didn't beat Linda in the game but [John] wasn't sad.\",\n",
+ " \"David didn't beat Susan in the game but [David] wasn't sad.\",\n",
+ " \"David didn't beat Mary in the game but [David] wasn't sad.\",\n",
+ " \"David didn't beat Linda in the game but [David] wasn't sad.\",\n",
+ " \"Michael didn't beat Susan in the game but [Michael] wasn't sad.\",\n",
+ " \"Michael didn't beat Mary in the game but [Michael] wasn't sad.\",\n",
+ " \"Michael didn't beat Linda in the game but [Michael] wasn't sad.\",\n",
+ " \"John didn't beat Susan in the game but [Susan] wasn't happy.\",\n",
+ " \"John didn't beat Mary in the game but [Mary] wasn't happy.\",\n",
+ " \"John didn't beat Linda in the game but [Linda] wasn't happy.\",\n",
+ " \"David didn't beat Susan in the game but [Susan] wasn't happy.\",\n",
+ " \"David didn't beat Mary in the game but [Mary] wasn't happy.\",\n",
+ " \"David didn't beat Linda in the game but [Linda] wasn't happy.\",\n",
+ " \"Michael didn't beat Susan in the game but [Susan] wasn't happy.\",\n",
+ " \"Michael didn't beat Mary in the game but [Mary] wasn't happy.\",\n",
+ " \"Michael didn't beat Linda in the game but [Linda] wasn't happy.\",\n",
+ " \"Susan didn't lose to John in the game but [John] was happy.\",\n",
+ " \"Mary didn't lose to John in the game but [John] was happy.\",\n",
+ " \"Linda didn't lose to John in the game but [John] was happy.\",\n",
+ " \"Susan didn't lose to David in the game but [David] was happy.\",\n",
+ " \"Mary didn't lose to David in the game but [David] was happy.\",\n",
+ " \"Linda didn't lose to David in the game but [David] was happy.\",\n",
+ " \"Susan didn't lose to Michael in the game but [Michael] was happy.\",\n",
+ " \"Mary didn't lose to Michael in the game but [Michael] was happy.\",\n",
+ " \"Linda didn't lose to Michael in the game but [Michael] was happy.\",\n",
+ " \"Susan didn't lose to John in the game but [Susan] was sad.\",\n",
+ " \"Mary didn't lose to John in the game but [Mary] was sad.\",\n",
+ " \"Linda didn't lose to John in the game but [Linda] was sad.\",\n",
+ " \"Susan didn't lose to David in the game but [Susan] was sad.\",\n",
+ " \"Mary didn't lose to David in the game but [Mary] was sad.\",\n",
+ " \"Linda didn't lose to David in the game but [Linda] was sad.\",\n",
+ " \"Susan didn't lose to Michael in the game but [Susan] was sad.\",\n",
+ " \"Mary didn't lose to Michael in the game but [Mary] was sad.\",\n",
+ " \"Linda didn't lose to Michael in the game but [Linda] was sad.\",\n",
+ " \"Susan didn't lose to John in the game but [John] wasn't sad.\",\n",
+ " \"Mary didn't lose to John in the game but [John] wasn't sad.\",\n",
+ " \"Linda didn't lose to John in the game but [John] wasn't sad.\",\n",
+ " \"Susan didn't lose to David in the game but [David] wasn't sad.\",\n",
+ " \"Mary didn't lose to David in the game but [David] wasn't sad.\",\n",
+ " \"Linda didn't lose to David in the game but [David] wasn't sad.\",\n",
+ " \"Susan didn't lose to Michael in the game but [Michael] wasn't sad.\",\n",
+ " \"Mary didn't lose to Michael in the game but [Michael] wasn't sad.\",\n",
+ " \"Linda didn't lose to Michael in the game but [Michael] wasn't sad.\",\n",
+ " \"Susan didn't lose to John in the game but [Susan] wasn't happy.\",\n",
+ " \"Mary didn't lose to John in the game but [Mary] wasn't happy.\",\n",
+ " \"Linda didn't lose to John in the game but [Linda] wasn't happy.\",\n",
+ " \"Susan didn't lose to David in the game but [Susan] wasn't happy.\",\n",
+ " \"Mary didn't lose to David in the game but [Mary] wasn't happy.\",\n",
+ " \"Linda didn't lose to David in the game but [Linda] wasn't happy.\",\n",
+ " \"Susan didn't lose to Michael in the game but [Susan] wasn't happy.\",\n",
+ " \"Mary didn't lose to Michael in the game but [Mary] wasn't happy.\",\n",
+ " \"Linda didn't lose to Michael in the game but [Linda] wasn't happy.\"]"
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "causal_sentences\n",
+ "turning_sentences\n",
+ "# substituted_sentences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "examples = [(2,\n",
+ " \"The trophy doesn't fit into the brown suitcase because [it] is too large.\",\n",
+ " 'fit into:large/small'),\n",
+ " (4,\n",
+ " 'Joan made sure to thank Susan for all the help [she] had recieved.',\n",
+ " 'thank:receive/give'),\n",
+ " (10,\n",
+ " 'The delivery truck zoomed by the school bus because [it] was going so fast.',\n",
+ " 'zoom by:fast/slow'),\n",
+ " (12,\n",
+ " 'Frank felt vindicated when his longtime rival Bill revealed that [he] was the winner of the competition.',\n",
+ " 'vindicated/crushed:be the winner'),\n",
+ " (16,\n",
+ " 'The large ball crashed right through the table because [it] was made of steel.',\n",
+ " 'crash through:[hard]/[soft]'),\n",
+ " (18,\n",
+ " \"John couldn't see the stage with Billy in front of him because [he] is so short.\",\n",
+ " '[block]:short/tall'),\n",
+ " (20,\n",
+ " 'Tom threw his schoolbag down to Ray after [he] reached the top of the stairs.',\n",
+ " 'down to:top/bottom'),\n",
+ " (22,\n",
+ " 'Although they ran at about the same speed, Sue beat Sally because [she] had such a good start.',\n",
+ " 'beat:good/bad'),\n",
+ " (26,\n",
+ " \"Sam's drawing was hung just above Tina's and [it] did look much better with another one below it.\",\n",
+ " 'above/below'),\n",
+ " (28,\n",
+ " 'Anna did a lot better than her good friend Lucy on the test because [she] had studied so hard.',\n",
+ " 'better/worse:study hard'),\n",
+ " (30,\n",
+ " 'The firemen arrived after the police because [they] were coming from so far away.',\n",
+ " 'after/before:far away'),\n",
+ " (32,\n",
+ " \"Frank was upset with Tom because the toaster [he] had bought from him didn't work.\",\n",
+ " 'be upset with:buy from not work/sell not work'),\n",
+ " (36,\n",
+ " 'The sack of potatoes had been placed above the bag of flour, so [it] had to be moved first.',\n",
+ " 'above/below:moved first'),\n",
+ " (38,\n",
+ " 'Pete envies Martin although [he] is very successful.',\n",
+ " 'although/because'),\n",
+ " (42,\n",
+ " 'I poured water from the bottle into the cup until [it] was empty.',\n",
+ " 'pour:empty/full'),\n",
+ " (46,\n",
+ " \"Sid explained his theory to Mark but [he] couldn't convince him.\",\n",
+ " 'explain:convince/understand'),\n",
+ " (48,\n",
+ " \"Susan knew that Ann's son had been in a car accident, so [she] told her about it.\",\n",
+ " '?know tell:so/because'),\n",
+ " (50,\n",
+ " \"Joe's uncle can still beat him at tennis, even though [he] is 30 years younger.\",\n",
+ " 'beat:younger/older'),\n",
+ " (64,\n",
+ " 'In the middle of the outdoor concert, the rain started falling, but [it] continued until 10.',\n",
+ " 'but/and'),\n",
+ " (68,\n",
+ " 'Ann asked Mary what time the library closes, because [she] had forgotten.',\n",
+ " 'because/but'),\n",
+ " (84,\n",
+ " 'If the con artist has succeeded in fooling Sam, [he] would have gotten a lot of money.',\n",
+ " 'fool:get/lose'),\n",
+ " (92,\n",
+ " 'Alice tried frantically to stop her daughter from chatting at the party, leaving us to wonder why [she] was behaving so strangely.',\n",
+ " '?stop normal/stop abnormal:strange'),\n",
+ " (98,\n",
+ " \"I was trying to open the lock with the key, but someone had filled the keyhole with chewing gum, and I couldn't get [it] in.\",\n",
+ " 'put ... into filled with ... :get in/get out'),\n",
+ " (100,\n",
+ " 'The dog chased the cat, which ran up a tree. [It] waited at the bottom.',\n",
+ " 'up:at the bottom/at the top'),\n",
+ " (106,\n",
+ " 'John was doing research in the library when he heard a man humming and whistling. [He] was very annoyed.',\n",
+ " 'hear ... humming and whistling:annoyed/annoying'),\n",
+ " (108,\n",
+ " 'John was jogging through the park when he saw a man juggling watermelons. [He] was very impressed.',\n",
+ " 'see ... juggling watermelons:impressed/impressive'),\n",
+ " (132,\n",
+ " 'Jane knocked on the door, and Susan answered it. [She] invited her to come out.',\n",
+ " 'visit:invite come out/invite come in'),\n",
+ " (150,\n",
+ " 'Jackson was greatly influenced by Arnold, though [he] lived two centuries later.',\n",
+ " 'influence:later/earlier'),\n",
+ " (160,\n",
+ " 'The actress used to be named Terpsichore, but she changed it to Tina a few years ago, because she figured [it] was too hard to pronounce.',\n",
+ " 'change:hard/easy'),\n",
+ " (166,\n",
+ " 'Fred is the only man still alive who remembers my great-grandfather. [He] is a remarkable man.',\n",
+ " 'alive:is/was'),\n",
+ " (170,\n",
+ " \"In July, Kamtchatka declared war on Yakutsk. Since Yakutsk's army was much better equipped and ten times larger, [they] were defeated within weeks.\",\n",
+ " 'better equipped and large:defeated/victorious'),\n",
+ " (186,\n",
+ " 'When the sponsors of the bill got to the town hall, they were surprised to find that the room was full of opponents. [They] were very much in the minority.',\n",
+ " 'be full of:minority/majority'),\n",
+ " (188,\n",
+ " 'Everyone really loved the oatmeal cookies; only a few people liked the chocolate chip cookies. Next time, we should make more of [them] .',\n",
+ " 'like over:more/fewer'),\n",
+ " (190,\n",
+ " 'We had hoped to place copies of our newsletter on all the chairs in the auditorium, but there were simply not enough of [them] .',\n",
+ " 'place on all:not enough/too many'),\n",
+ " (196,\n",
+ " \"Steve follows Fred's example in everything. [He] admires him hugely.\",\n",
+ " 'follow:admire/influence'),\n",
+ " (198,\n",
+ " \"The table won't fit through the doorway because [it] is too wide.\",\n",
+ " 'fit through:wide/narrow'),\n",
+ " (200,\n",
+ " 'Grace was happy to trade me her sweater for my jacket. She thinks [it] looks dowdy on her.',\n",
+ " 'trade:dowdy/great'),\n",
+ " (202,\n",
+ " 'John hired Bill to take care of [him] .',\n",
+ " 'hire/hire oneself to:take care of'),\n",
+ " (204,\n",
+ " 'John promised Bill to leave, so an hour later [he] left.',\n",
+ " 'promise/order'),\n",
+ " (210,\n",
+ " \"Jane knocked on Susan's door but [she] did not get an answer.\",\n",
+ " 'knock:get an answer/answer'),\n",
+ " (212,\n",
+ " 'Joe paid the detective after [he] received the final report on the case.',\n",
+ " 'pay:receive/deliver'),\n",
+ " (226,\n",
+ " 'Bill passed the half-empty plate to John because [he] was full.',\n",
+ " 'pass the plate:full/hungry'),\n",
+ " (252,\n",
+ " 'George got free tickets to the play, but he gave them to Eric, even though [he] was particularly eager to see it.',\n",
+ " 'even though/because/not'),\n",
+ " (255,\n",
+ " \"Jane gave Joan candy because [she] wasn't hungry.\",\n",
+ " 'give:not hungry/hungry'),\n",
+ " (259,\n",
+ " 'James asked Robert for a favor but [he] was refused.',\n",
+ " 'ask for a favor:refuse/be refused`'),\n",
+ " (261,\n",
+ " 'Kirilov ceded the presidency to Shatov because [he] was less popular.',\n",
+ " 'cede:less popular/more popular'),\n",
+ " (263,\n",
+ " 'Emma did not pass the ball to Janie although [she] saw that she was open.',\n",
+ " 'not pass although:see open/open')]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "47"
+ ]
+ },
+ "execution_count": 77,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(examples)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/Untitled_likunlin-Copy1.ipynb b/Untitled_likunlin-Copy1.ipynb
new file mode 100644
index 00000000000000..a48277551d3723
--- /dev/null
+++ b/Untitled_likunlin-Copy1.ipynb
@@ -0,0 +1,827 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from IPython.core.interactiveshell import InteractiveShell\n",
+ "InteractiveShell.ast_node_interactivity = 'all'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/xd/projects/pytorch-pretrained-BERT/pytorch_pretrained_bert/__init__.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "\n",
+ "import numpy as np\n",
+ "import math\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "from pylab import rcParams\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from pytorch_pretrained_bert import tokenization, BertTokenizer, BertModel, BertForMaskedLM, BertForPreTraining, BertConfig\n",
+ "from examples.extract_features import *\n",
+ "\n",
+ "import pytorch_pretrained_bert\n",
+ "print(pytorch_pretrained_bert.__file__)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "01/03/2019 16:37:32 - INFO - pytorch_pretrained_bert.tokenization - loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/xd/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n",
+ "01/03/2019 16:37:32 - INFO - pytorch_pretrained_bert.modeling - loading archive file /nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/\n",
+ "01/03/2019 16:37:32 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
+ " \"attention_probs_dropout_prob\": 0.1,\n",
+ " \"hidden_act\": \"gelu\",\n",
+ " \"hidden_dropout_prob\": 0.1,\n",
+ " \"hidden_size\": 768,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 3072,\n",
+ " \"max_position_embeddings\": 512,\n",
+ " \"num_attention_heads\": 12,\n",
+ " \"num_hidden_layers\": 12,\n",
+ " \"type_vocab_size\": 2,\n",
+ " \"vocab_size\": 30522\n",
+ "}\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "class Args:\n",
+ " def __init__(self):\n",
+ " pass\n",
+ " \n",
+ "args = Args()\n",
+ "args.no_cuda = False\n",
+ "\n",
+ "CONFIG_NAME = 'bert_config.json'\n",
+ "BERT_DIR = '/nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/'\n",
+ "config_file = os.path.join(BERT_DIR, CONFIG_NAME)\n",
+ "config = BertConfig.from_json_file(config_file)\n",
+ "\n",
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
+ "model = BertForPreTraining.from_pretrained(BERT_DIR)\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() and not args.no_cuda else \"cpu\")\n",
+ "_ = model.to(device)\n",
+ "_ = model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+    "def convert_text_to_examples(text):\n",
+    "    \"\"\"Wrap every string of `text` in an InputExample.\n",
+    "\n",
+    "    A line shaped like \"A ||| B\" yields text_a=\"A\" and text_b=\"B\"; any other\n",
+    "    line becomes text_a with text_b left as None. unique_id is the position\n",
+    "    of the line within `text`. Lines are stripped before matching.\n",
+    "    \"\"\"\n",
+    "    pair_pattern = re.compile(r\"^(.*) \\|\\|\\| (.*)$\")\n",
+    "    examples = []\n",
+    "    for unique_id, raw_line in enumerate(text):\n",
+    "        stripped = raw_line.strip()\n",
+    "        matched = pair_pattern.match(stripped)\n",
+    "        if matched is not None:\n",
+    "            text_a, text_b = matched.group(1), matched.group(2)\n",
+    "        else:\n",
+    "            text_a, text_b = stripped, None\n",
+    "        examples.append(InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))\n",
+    "    return examples\n",
+ "\n",
+    "def convert_examples_to_features(examples, tokenizer, append_special_tokens=True, replace_mask=True, print_info=False):\n",
+    "    \"\"\"Tokenize InputExamples and pack them into InputFeatures.\n",
+    "\n",
+    "    Args:\n",
+    "        examples: objects exposing unique_id, text_a and text_b.\n",
+    "        tokenizer: provides tokenize() and convert_tokens_to_ids().\n",
+    "        append_special_tokens: put \"[CLS]\" first and \"[SEP]\" after each segment.\n",
+    "        replace_mask: turn every '_' token into \"[MASK]\".\n",
+    "        print_info: log the token sequence of the first five examples. The\n",
+    "            flag used to be ignored and the line was logged unconditionally.\n",
+    "\n",
+    "    Returns:\n",
+    "        One InputFeatures per example. input_mask is all ones (nothing is\n",
+    "        padded); input_type_ids is 0 for text_a and 1 for text_b.\n",
+    "    \"\"\"\n",
+    "    def substitute(token):\n",
+    "        # '_' marks a blank in the prompts.  # XD\n",
+    "        return \"[MASK]\" if replace_mask and token == '_' else token\n",
+    "\n",
+    "    features = []\n",
+    "    for ex_index, example in enumerate(examples):\n",
+    "        tokens_a = tokenizer.tokenize(example.text_a)\n",
+    "        tokens_b = tokenizer.tokenize(example.text_b) if example.text_b else None\n",
+    "\n",
+    "        tokens = []\n",
+    "        input_type_ids = []\n",
+    "        if append_special_tokens:\n",
+    "            tokens.append(\"[CLS]\")\n",
+    "            input_type_ids.append(0)\n",
+    "        # text_a is always emitted; text_b only when it produced tokens.\n",
+    "        segments = [(tokens_a, 0)] + ([(tokens_b, 1)] if tokens_b else [])\n",
+    "        for segment, type_id in segments:\n",
+    "            tokens.extend(substitute(t) for t in segment)\n",
+    "            input_type_ids.extend([type_id] * len(segment))\n",
+    "            if append_special_tokens:\n",
+    "                tokens.append(\"[SEP]\")\n",
+    "                input_type_ids.append(type_id)\n",
+    "\n",
+    "        input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
+    "        input_mask = [1] * len(input_ids)\n",
+    "\n",
+    "        if print_info and ex_index < 5:\n",
+    "            logger.info(\"tokens: %s\" % \" \".join([str(x) for x in tokens]))\n",
+    "\n",
+    "        features.append(\n",
+    "            InputFeatures(\n",
+    "                unique_id=example.unique_id,\n",
+    "                tokens=tokens,\n",
+    "                input_ids=input_ids,\n",
+    "                input_mask=input_mask,\n",
+    "                input_type_ids=input_type_ids))\n",
+    "    return features\n",
+ "\n",
+    "def copy_and_mask_feature(feature, masked_tokens=None):\n",
+    "    \"\"\"Return (copies, positions): a deep copy of `feature` per masked position, [MASK] id set in input_ids only.\n",
+    "    masked_tokens=None masks every position; otherwise the first occurrence of each listed token.\"\"\"\n",
+    "    import copy\n",
+    "    toks = feature.tokens\n",
+    "    positions = range(len(toks)) if masked_tokens is None else [toks.index(t) for t in masked_tokens if t in toks]\n",
+    "    assert len(positions) > 0\n",
+    "    copies = []\n",
+    "    for pos in positions:\n",
+    "        copies.append(copy.deepcopy(feature))\n",
+    "        copies[-1].input_ids[pos] = tokenizer.vocab[\"[MASK]\"]\n",
+    "    return copies, positions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "def show_lm_probs(tokens, input_ids, probs, topk=5, firstk=20):\n",
+    "    \"\"\"Print each token's own probability and its top-k predictions ('*' marks a hit).\n",
+    "\n",
+    "    Only the first `firstk` rows are printed. Returns the (token, prob) top-k pairs\n",
+    "    of the last \"[MASK]\" token, or None when no token equals \"[MASK]\".\n",
+    "    \"\"\"\n",
+    "    cell = '{}{: >3} | {: <12}'\n",
+    "    ret = None\n",
+    "    for i, tok in enumerate(tokens):\n",
+    "        ind_ = input_ids[i].item() if input_ids is not None else tokenizer.vocab[tok]\n",
+    "        prob_ = probs[i][ind_].item()\n",
+    "        values, indices = probs[i].topk(topk)\n",
+    "        top_ids = [indices[j].item() for j in range(topk)]\n",
+    "        top_pairs = [(tokenizer.ids_to_tokens[ind], values[j].item()) for j, ind in enumerate(top_ids)]\n",
+    "        if i < firstk:\n",
+    "            marks = ['*' if ind == ind_ else ' ' for ind in top_ids]\n",
+    "            cells = [cell.format(' ', int(round(prob_*100)), tok) + '\\t']\n",
+    "            cells += [cell.format(m, int(round(p*100)), t) for m, (t, p) in zip(marks, top_pairs)]\n",
+    "            print(''.join(cells), end='\\n' if topk > 0 else '')\n",
+    "        if tok == \"[MASK]\":\n",
+    "            ret = top_pairs\n",
+    "    return ret"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import colored\n",
+ "from colored import stylize\n",
+ "\n",
+    "def show_abnormals(tokens, probs, show_suggestions=False):\n",
+    "    \"\"\"Print the tokens colored by how unlikely the model finds each one.\n",
+    "\n",
+    "    For every position except the first and last, gap = log(top prob) -\n",
+    "    log(prob of the actual token): 0 prints white, <= 5 yellow, <= 10 orange,\n",
+    "    larger red. With show_suggestions the top prediction follows tokens whose\n",
+    "    gap exceeds 5. The mean gap is printed last (0.0 when there are no interior\n",
+    "    tokens); a token with probability 0 gets gap = inf instead of a math error.\n",
+    "    \"\"\"\n",
+    "    black = colored.bg('black')\n",
+    "    def gap2color(gap):\n",
+    "        return 'yellow_1' if gap <= 5 else 'orange_1' if gap <= 10 else 'red_1'\n",
+    "\n",
+    "    def print_token(token, suggestion, gap):\n",
+    "        if gap == 0:\n",
+    "            print(stylize(token + ' ', colored.fg('white') + black), end='')\n",
+    "            return\n",
+    "        print(stylize(token, colored.fg(gap2color(gap)) + black), end='')\n",
+    "        if show_suggestions and gap > 5:\n",
+    "            print(stylize('/' + suggestion + ' ', colored.fg('green' if gap > 10 else 'cyan') + black), end='')\n",
+    "        else:\n",
+    "            print(stylize(' ', colored.fg(gap2color(gap)) + black), end='')\n",
+    "\n",
+    "    gaps = []\n",
+    "    for i in range(1, len(tokens) - 1):  # skip first [CLS] and last [SEP]\n",
+    "        prob_ = probs[i][tokenizer.vocab[tokens[i]]].item()\n",
+    "        top_prob = probs[i].max().item()\n",
+    "        top_ind = probs[i].argmax().item()\n",
+    "        # log(0) would raise ValueError; treat an underflowed probability as an infinite gap.\n",
+    "        gap = math.log(top_prob) - math.log(prob_) if prob_ > 0 else float('inf')\n",
+    "        print_token(tokens[i], tokenizer.ids_to_tokens[top_ind], gap)\n",
+    "        gaps.append(gap)\n",
+    "    print()\n",
+    "    print(sum(gaps) / len(gaps) if gaps else 0.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "analyzed_cache = {}\n",
+    "\n",
+    "def analyze_text(text, masked_tokens=None, show_suggestions=False, show_firstk_probs=20):\n",
+    "    \"\"\"Run the masked LM over `text` and print per-token predictions.\n",
+    "\n",
+    "    A '_' in the text stands for \"[MASK]\". Without one (or when masked_tokens\n",
+    "    is given) each position / each listed token is masked in turn and only the\n",
+    "    prediction at the masked position is kept. Returns what show_lm_probs\n",
+    "    returns. Results are memoized in analyzed_cache per (text[0],\n",
+    "    masked_tokens) together with the matching token list, so a repeated call\n",
+    "    never pairs reduced probabilities with the full token list.\n",
+    "    \"\"\"\n",
+    "    # masked_tokens is part of the key: the same text masked differently is a different analysis.\n",
+    "    cache_key = (text[0], None if masked_tokens is None else tuple(masked_tokens))\n",
+    "    if cache_key not in analyzed_cache:\n",
+    "        features = convert_examples_to_features(convert_text_to_examples(text), tokenizer, print_info=False)\n",
+    "        tokens = features[0].tokens\n",
+    "        given_mask = \"[MASK]\" in tokens\n",
+    "        masked_positions = None\n",
+    "        if not given_mask or masked_tokens is not None:\n",
+    "            assert len(features) == 1\n",
+    "            features, masked_positions = copy_and_mask_feature(features[0], masked_tokens=masked_tokens)\n",
+    "\n",
+    "        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)\n",
+    "        input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long).to(device)\n",
+    "        with torch.no_grad():  # inference only; keeps the autograd graph out of the cache\n",
+    "            mlm_logits, _ = model(input_ids, input_type_ids)\n",
+    "            mlm_probs = F.softmax(mlm_logits, dim=-1)\n",
+    "\n",
+    "        if masked_positions is not None:\n",
+    "            assert mlm_probs.size(0) == len(masked_positions)\n",
+    "            # Batch row i has position masked_positions[i] masked; keep only that prediction.\n",
+    "            picked = [mlm_probs[i, pos] for i, pos in enumerate(masked_positions)]\n",
+    "            mlm_probs = torch.stack(picked).unsqueeze(0).cpu()\n",
+    "            tokens = [tokens[pos] for pos in masked_positions]\n",
+    "        analyzed_cache[cache_key] = (tokens, mlm_probs, given_mask)\n",
+    "\n",
+    "    tokens, mlm_probs, given_mask = analyzed_cache[cache_key]\n",
+    "    top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=show_firstk_probs)\n",
+    "    if not given_mask:\n",
+    "        show_abnormals(tokens, mlm_probs[0], show_suggestions=show_suggestions)\n",
+    "    return top_pairs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "01/03/2019 17:13:21 - INFO - examples.extract_features - tokens: [CLS] what ingredients account for the marvelous function of a dream ? [SEP]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 0 | [CLS] \t 3 | . 1 | the 1 | , 1 | ) 1 | \" \n",
+ " 35 | what \t* 35 | what 25 | do 9 | can 7 | could 5 | would \n",
+ " 0 | ingredients \t 51 | could 23 | would 13 | can 8 | might 2 | may \n",
+ " 0 | account \t 32 | were 26 | are 7 | remained 6 | existed 6 | exist \n",
+ " 100 | for \t*100 | for 0 | to 0 | of 0 | up 0 | all \n",
+ " 98 | the \t* 98 | the 2 | this 0 | a 0 | that 0 | such \n",
+ " 0 | marvelous \t 5 | biological 5 | normal 4 | cognitive 2 | specific 2 | physiological\n",
+ " 0 | function \t 21 | ##ness 8 | beauty 5 | quality 5 | nature 4 | power \n",
+ " 91 | of \t* 91 | of 8 | in 0 | within 0 | as 0 | during \n",
+ " 14 | a \t 55 | the 16 | this * 14 | a 4 | my 3 | his \n",
+ " 0 | dream \t 3 | heart 3 | plant 3 | soul 2 | brain 2 | body \n",
+ " 98 | ? \t* 98 | ? 2 | . 0 | ; 0 | ! 0 | | \n",
+ " 0 | [SEP] \t 13 | what 12 | \" 7 | they 4 | and 4 | ' \n",
+ "\u001b[38;5;15m\u001b[48;5;0mwhat \u001b[0m\u001b[38;5;196m\u001b[48;5;0mingredients\u001b[0m\u001b[38;5;196m\u001b[48;5;0m \u001b[0m\u001b[38;5;226m\u001b[48;5;0maccount\u001b[0m\u001b[38;5;226m\u001b[48;5;0m \u001b[0m\u001b[38;5;15m\u001b[48;5;0mfor \u001b[0m\u001b[38;5;15m\u001b[48;5;0mthe \u001b[0m\u001b[38;5;214m\u001b[48;5;0mmarvelous\u001b[0m\u001b[38;5;214m\u001b[48;5;0m \u001b[0m\u001b[38;5;214m\u001b[48;5;0mfunction\u001b[0m\u001b[38;5;214m\u001b[48;5;0m \u001b[0m\u001b[38;5;15m\u001b[48;5;0mof \u001b[0m\u001b[38;5;226m\u001b[48;5;0ma\u001b[0m\u001b[38;5;226m\u001b[48;5;0m \u001b[0m\u001b[38;5;226m\u001b[48;5;0mdream\u001b[0m\u001b[38;5;226m\u001b[48;5;0m \u001b[0m\u001b[38;5;15m\u001b[48;5;0m? \u001b[0m\n",
+ "3.421217077676471\n"
+ ]
+ }
+ ],
+ "source": [
+ "# text = [\"Who was Jim Henson? Jim Henson _ a puppeteer.\"]\n",
+ "text = [\"What ingredients account for the marvelous function of a dream?\"]\n",
+ "# text = [\"Last week I went to the theatre. I had a very good seat. The play was very interesting. But I didn't enjoy it. A young man and a young woman were sitting behind me. They were talking loudly. I got very angry. I couldn't hear a word. I turned round. I looked at the man angrily. They didn't pay any attention.In the end, I couldn't bear it. I turned round again. 'I can't hear a word!' I said angrily. 'It's none of your business,' the young man said rudely. 'This is a private conversation!'\"]\n",
+ "# text = [\"After the outbreak of the disease, the Ministry of Agriculture and rural areas immediately sent a supervision team to the local. Local Emergency Response Mechanism has been activated in accordance with the requirements, to take blockade, culling, harmless treatment, disinfection and other treatment measures to all disease and culling of pigs for harmless treatment. At the same time, all live pigs and their products are prohibited from transferring out of the blockade area, and live pigs are not allowed to be transported into the blockade area. At present, all the above measures have been implemented.\"]\n",
+ "# text = [\"Early critics of Emily Dickinson's poetry mistook for simplemindedness the surface of artlessness that in fact she constructed with such innocence.\"]\n",
+ "analyze_text(text, show_firstk_probs=100)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "01/03/2019 17:10:45 - INFO - examples.extract_features - tokens: [CLS] the trophy doesn ' t fit into the brown suitcase because the [MASK] is too large . [SEP]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " 0 | [CLS] \t 2 | . 1 | ) 1 | the 1 | , 1 | \" \n",
+ " 100 | the \t*100 | the 0 | his 0 | a 0 | its 0 | her \n",
+ " 97 | trophy \t* 97 | trophy 0 | cup 0 | prize 0 | trophies 0 | competition \n",
+ " 100 | doesn \t*100 | doesn 0 | can 0 | does 0 | won 0 | didn \n",
+ " 100 | ' \t*100 | ' 0 | t 0 | \" 0 | = 0 | ` \n",
+ " 100 | t \t*100 | t 0 | not 0 | s 0 | n 0 | to \n",
+ " 100 | fit \t*100 | fit 0 | fits 0 | sit 0 | get 0 | fitting \n",
+ " 100 | into \t*100 | into 0 | in 0 | inside 0 | onto 0 | within \n",
+ " 100 | the \t*100 | the 0 | her 0 | his 0 | a 0 | my \n",
+ " 100 | brown \t*100 | brown 0 | black 0 | green 0 | blue 0 | plastic \n",
+ " 95 | suitcase \t* 95 | suitcase 3 | bag 1 | luggage 0 | backpack 0 | trunk \n",
+ " 100 | because \t*100 | because 0 | as 0 | since 0 | due 0 | . \n",
+ " 100 | the \t*100 | the 0 | its 0 | his 0 | it 0 | her \n",
+ " 0 | [MASK] \t 21 | suitcase 19 | bag 6 | box 2 | luggage 2 | case \n",
+ " 99 | is \t* 99 | is 1 | was 0 | being 0 | has 0 | it \n",
+ " 100 | too \t*100 | too 0 | very 0 | extra 0 | overly 0 | more \n",
+ " 87 | large \t* 87 | large 11 | big 1 | small 1 | huge 0 | larger \n",
+ " 100 | . \t*100 | . 0 | ; 0 | , 0 | ! 0 | ' \n",
+ " 0 | [SEP] \t 35 | . 8 | ) 5 | , 4 | ( 3 | it \n"
+ ]
+ }
+ ],
+ "source": [
+ "text = [\"The trophy doesn't fit into the brown suitcase because the _ is too large.\"]\n",
+ "# text = [\"Mary beat John in the match because _ was very strong.\"]\n",
+ "features = convert_examples_to_features(convert_text_to_examples(text), tokenizer, print_info=False)\n",
+ "input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)\n",
+ "input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long).to(device)\n",
+ "mlm_logits, _ = model(input_ids, input_type_ids)\n",
+ "mlm_probs = F.softmax(mlm_logits, dim=-1)\n",
+ "tokens = features[0].tokens\n",
+ "top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=100)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have the same hair color.',\n",
+ " 'Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have different hair colors.']"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "text = [\n",
+ " # same / different\n",
+ " \"Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have the same hair color.\",\n",
+ " \"Tom has black hair. Mary has black hair. John has yellow hair. _ and Mary have different hair colors.\",\n",
+ " \"Tom has yellow hair. Mary has black hair. John has black hair. Mary and _ have the same hair color.\",\n",
+ " # because / although\n",
+ " \"John is taller/shorter than Mary because/although _ is older/younger.\",\n",
+ " \"The red ball is heavier/lighter than the blue ball because/although the _ ball is bigger/smaller.\",\n",
+ " \"Charles did a lot better/worse than his good friend Nancy on the test because/although _ had/hadn't studied so hard.\",\n",
+ " \"The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.\",\n",
+ " \"John thought that he would arrive earlier than Susan, but/and indeed _ was the first to arrive.\",\n",
+ " # reverse\n",
+ " \"John came then Mary came. They left in reverse order. _ left then _ left.\",\n",
+ " \"John came after Mary. They left in reverse order. _ left after _ .\",\n",
+ " \"John came first, then came Mary. They left in reverse order: _ left first, then left _ .\",\n",
+ " # compare\n",
+ " \"Though John is tall, Tom is taller than John. So John is _ than Tom.\",\n",
+ " \"Tom is taller than John. So _ is shorter than _.\",\n",
+ " # WSC-style: before /after\n",
+ " \"Mary came before/after John. _ was late/early .\",\n",
+ " # yes / no\n",
+ " \"Was Tom taller than Susan? Yes, _ was taller.\",\n",
+ " # right / wrong, epistemic modality\n",
+ " \"John said the rain was about to stop. Mary said the rain would continue. Later the rain stopped. _ was wrong.\",\n",
+ " \n",
+ " \"The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.\",\n",
+ " \"John thanked Mary because _ had given help to _ . \",\n",
+ " \"John felt vindicated/crushed when his longtime rival Mary revealed that _ was the winner of the competition.\",\n",
+ " \"John couldn't see the stage with Mary in front of him because _ is so short/tall.\",\n",
+ " \"Although they ran at about the same speed, John beat Sally because _ had such a bad start.\",\n",
+ " \"The fish ate the worm. The _ was hungry/tasty.\",\n",
+ " \n",
+ " \"John beat Mary. _ won the game/e winner.\",\n",
+ "]\n",
+ "text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('WSC_switched_label.json') as f:\n",
+ " examples = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('WSC_child_problem.json') as f:\n",
+ " cexamples = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Report every lower-cased item of answer0 + answer1 that is missing from the tokenizer vocabulary.\n",
+    "for ce in cexamples:\n",
+    "    for s in ce['sentences']:\n",
+    "        answers = [a.lower() for a in s['answer0'] + s['answer1']]\n",
+    "        for a in answers:\n",
+    "            if a not in tokenizer.vocab:\n",
+    "                print(a, 'not in vocab!!!')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "for ce in cexamples:\n",
+    "    if ce['sentences']:  # entries without child sentences add no 'score' / 'adjacent_ref'\n",
+    "        e = examples[ce['index']]\n",
+    "        assert ce['index'] == e['index']  # the list position in examples must equal its 'index' field\n",
+    "        e['score'] = all(s['score'] for s in ce['sentences'])  # truthy only if every child sentence scored\n",
+    "        assert len({s['adjacent_ref'] for s in ce['sentences']}) == 1, 'adjacent_refs are different!'\n",
+    "        e['adjacent_ref'] = ce['sentences'][0]['adjacent_ref']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "\n",
+    "# Bucket scored examples under a shared key: even-start pairs below 252, the triple 252-254, odd-start pairs from 255 on.\n",
+    "groups = defaultdict(list)\n",
+    "for e in examples:\n",
+    "    if 'score' not in e:\n",
+    "        continue\n",
+    "    index = e['index']\n",
+    "    if index < 252:\n",
+    "        index -= index % 2\n",
+    "    elif index in (252, 253, 254):\n",
+    "        index = 252\n",
+    "    else:\n",
+    "        index -= 1 - index % 2\n",
+    "    groups[index].append(e)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(2, 'fit into:large/small', False),\n",
+ " (4, 'thank:receive/give', False),\n",
+ " (6, 'call:successful available', True),\n",
+ " (8, 'ask:repeat answer', False),\n",
+ " (10, 'zoom by:fast/slow', False),\n",
+ " (12, 'vindicated/crushed:be the winner', False),\n",
+ " (14, 'lift:weak heavy', False),\n",
+ " (16, 'crash through:[hard]/[soft]', False),\n",
+ " (18, '[block]:short/tall', False),\n",
+ " (20, 'down to:top/bottom', False),\n",
+ " (22, 'beat:good/bad', False),\n",
+ " (24, 'roll off:anchored level', False),\n",
+ " (26, 'above/below', False),\n",
+ " (28, 'better/worse:study hard', False),\n",
+ " (30, 'after/before:far away', False),\n",
+ " (32, 'be upset with:buy from not work/sell not work', True),\n",
+ " (34, '?yell at comfort:upset', False),\n",
+ " (36, 'above/below:moved first', False),\n",
+ " (38, 'although/because', False),\n",
+ " (40, 'bully:punish rescue', False),\n",
+ " (42, 'pour:empty/full', False),\n",
+ " (44, 'know:nosy indiscreet', False),\n",
+ " (46, 'explain:convince/understand', True),\n",
+ " (48, '?know tell:so/because', True),\n",
+ " (50, 'beat:younger/older', False),\n",
+ " (56, 'clog:cleaned removed', True),\n",
+ " (58, '?immediately follow:short delayed', False),\n",
+ " (60, '?between:see see around', True),\n",
+ " (64, 'but/and', False),\n",
+ " (66, 'clean:put in the trash put in the drawer', False),\n",
+ " (68, 'because/but', False),\n",
+ " (70, 'out of:handy lighter', False),\n",
+ " (72, 'put:tall high', False),\n",
+ " (74, 'show:good famous', True),\n",
+ " (76, 'pay for:generous grateful', False),\n",
+ " (78, 'but', False),\n",
+ " (80, 'if', False),\n",
+ " (82, 'if', False),\n",
+ " (84, 'fool:get/lose', False),\n",
+ " (88, 'wait:impatient cautious', False),\n",
+ " (90, 'give birth:woman baby', True),\n",
+ " (92, '?stop normal/stop abnormal:strange', False),\n",
+ " (96, 'eat:hungry tasty', False),\n",
+ " (98, 'put ... into filled with ... :get in/get out', False),\n",
+ " (100, 'up:at the bottom/at the top', False),\n",
+ " (102, 'crash through:removed repaired', False),\n",
+ " (104, 'stab:taken to the police station taken to the hospital', False),\n",
+ " (106, 'hear ... humming and whistling:annoyed/annoying', True),\n",
+ " (108, 'see ... juggling watermelons:impressed/impressive', True),\n",
+ " (114, 'tell lies: truthful skeptical', True),\n",
+ " (130, 'but:disappointed', True),\n",
+ " (132, 'visit:invite come out/invite come in', True),\n",
+ " (134, 'take classes from:eager known to speak it fluently', False),\n",
+ " (138, 'cover:out gone', True),\n",
+ " (144, 'tuck:work sleep', True),\n",
+ " (150, 'influence:later/earlier', False),\n",
+ " (152, 'can not cut:thick small', False),\n",
+ " (154, 'attack:kill guard', False),\n",
+ " (156, 'attack:bold nervous', False),\n",
+ " (160, 'change:hard:easy', False),\n",
+ " (166, 'alive:is/was', False),\n",
+ " (168, 'infant:twelve years old twelve months old', False),\n",
+ " (170, 'better equipped and large:defeated/victorious', False),\n",
+ " (178, 'interview:persistent cooperative', False),\n",
+ " (186, 'be full of:minority/majority', False),\n",
+ " (188, 'like over:more/fewer', False),\n",
+ " (190, 'place on all:not enough/too many', True),\n",
+ " (192, 'stick:leave have', True),\n",
+ " (196, 'follow:admire/influence', True),\n",
+ " (198, 'fit through:wide/narrow', False),\n",
+ " (200, 'trade:dowdy/great', False),\n",
+ " (202, 'hire/hire oneself to:take care of', True),\n",
+ " (204, 'promise/order', False),\n",
+ " (208, 'mother:education place', True),\n",
+ " (210, 'knock:get an answer/answer', True),\n",
+ " (212, 'pay:receive/deliver', False),\n",
+ " (218, '?', False),\n",
+ " (220, 'say check:move take', False),\n",
+ " (222, '?', False),\n",
+ " (224, 'give a life:drive alone walk', False),\n",
+ " (226, 'pass the plate:full/hungry', False),\n",
+ " (228, 'pass:turn over turn next', False),\n",
+ " (232, 'stretch pat', True),\n",
+ " (234, 'accept share', False),\n",
+ " (236, 'speak:break silence break concentration', False),\n",
+ " (240, 'carry:leg ache leg dangle', True),\n",
+ " (242, 'carry:in arms in bassinet', False),\n",
+ " (244, 'hold:against chest against will', True),\n",
+ " (250, 'stop', False),\n",
+ " (252, 'even though/because/not', False),\n",
+ " (255, 'give:not hungry/hungry', False),\n",
+ " (259, 'ask for a favor:refuse/be refused`', False),\n",
+ " (261, 'cede:less popular/more popular', False),\n",
+ " (263, 'not pass although:see open/open', True),\n",
+ " (271, 'suspect regret', True)]"
+ ]
+ },
+ "execution_count": 62,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def filter_dict(d, keys=('index', 'sentence', 'correct_answer', 'relational_word', 'is_associative', 'score')):\n",
+ "    \"\"\"Return a new dict holding only the entries of ``d`` whose key is in ``keys``.\n",
+ "\n",
+ "    Keys named in ``keys`` but missing from ``d`` are simply absent from the\n",
+ "    result. The default is a tuple so it cannot be mutated between calls.\n",
+ "    \"\"\"\n",
+ "    return {k: v for k, v in d.items() if k in keys}\n",
+ "\n",
+ "# ([[filter_dict(e) for e in eg] for eg in groups.values() if eg[0]['relational_word'] != 'none' and all([e['score'] for e in eg])])# / len([eg for eg in groups.values() if eg[0]['relational_word'] != 'none'])\n",
+ "# One (group key, relational word of the first member, all-members-scored flag)\n",
+ "# tuple per group whose first member's relational word is not 'none'.\n",
+ "[(index, eg[0]['relational_word'], all(e['score'] for e in eg))\n",
+ " for index, eg in groups.items()\n",
+ " if eg[0]['relational_word'] != 'none']\n",
+ "# len([filter_dict(e) for e in examples if 'score' in e and not e['score'] and e['adjacent_ref']])\n",
+ "# for e in examples:\n",
+ "# if e['index'] % 2 == 0:\n",
+ "# print(e['sentence'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "179"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Total number of (keyword, sentence) substring hits over all examples. A\n",
+ "# sentence that contains several of the keywords is counted once per keyword.\n",
+ "sum(kw in e['sentence'] for kw in ('because', 'so ', 'but ', 'though') for e in examples)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# with open('WSC_switched_label.json', 'w') as f:\n",
+ "# json.dump(examples, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vis_attn_topk = 3\n",
+ "\n",
+ "def has_chinese_label(labels):\n",
+ "    \"\"\"Heuristic used by _plot_attn to decide whether tick labels need ``zhfont``.\n",
+ "\n",
+ "    Each label is cut at the first '->' and stripped. The result is True when\n",
+ "    the share of labels longer than one character (ignoring 'BOS' and 'EOS')\n",
+ "    lies strictly between 0 and 0.5. Fewer than two labels yields False.\n",
+ "    \"\"\"\n",
+ "    labels = [label.split('->')[0].strip() for label in labels]\n",
+ "    if len(labels) < 2:\n",
+ "        # A single label would make the denominator below zero.\n",
+ "        return False\n",
+ "    n_multi = sum(len(label) > 1 for label in labels if label not in ('BOS', 'EOS'))\n",
+ "    # NOTE(review): the denominator is len(labels) - 1, kept from the original\n",
+ "    # heuristic; presumably it discounts one marker token - confirm.\n",
+ "    r = float(n_multi) / (len(labels) - 1)\n",
+ "    return 0 < r < 0.5  # r == 0 means empty query labels used in self attention\n",
+ "\n",
+ "def _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col, color='b'):\n",
+ "    \"\"\"Draw one attention matrix as lines between key ticks and query ticks.\n",
+ "\n",
+ "    Key labels go on the y axis of ``ax1`` and query labels on a twin y axis.\n",
+ "    For every query row, a line is drawn to each of its ``vis_attn_topk``\n",
+ "    highest-weighted keys, using the weight as the line's alpha.\n",
+ "\n",
+ "    Args:\n",
+ "        ax1: matplotlib axes to draw on.\n",
+ "        attn_name: attention name; if it contains 'self', query labels are\n",
+ "            blanked in every grid column except the last one.\n",
+ "        attn: 2-D tensor with one row per query label and one column per key label.\n",
+ "        key_labels: labels for the columns of ``attn``.\n",
+ "        query_labels: labels for the rows of ``attn``.\n",
+ "        col: grid column of ``ax1``, compared against the global ``ncols``.\n",
+ "        color: matplotlib format/colour string for the lines.\n",
+ "    \"\"\"\n",
+ "    assert len(query_labels) == attn.size(0)\n",
+ "    assert len(key_labels) == attn.size(1)\n",
+ "\n",
+ "    ax1.set_xlim([-1, 1])\n",
+ "    ax1.set_xticks([])\n",
+ "    ax2 = ax1.twinx()\n",
+ "    nlabels = max(len(key_labels), len(query_labels))\n",
+ "\n",
+ "    if 'self' in attn_name and col < ncols - 1:\n",
+ "        query_labels = ['' for _ in query_labels]\n",
+ "\n",
+ "    for ax, labels in [(ax1, key_labels), (ax2, query_labels)]:\n",
+ "        # One tick per label: matplotlib rejects a tick/label count mismatch,\n",
+ "        # which would happen when key and query lengths differ.\n",
+ "        ax.set_yticks(range(len(labels)))\n",
+ "        if has_chinese_label(labels):\n",
+ "            ax.set_yticklabels(labels, fontproperties=zhfont)\n",
+ "        else:\n",
+ "            ax.set_yticklabels(labels)\n",
+ "        # Both axes share the same inverted range so that index 0 is at the top.\n",
+ "        ax.set_ylim([nlabels - 1, 0])\n",
+ "        ax.tick_params(width=0, labelsize='xx-large')\n",
+ "\n",
+ "        for spine in ax.spines.values():\n",
+ "            spine.set_visible(False)\n",
+ "\n",
+ "    # topk raises when k exceeds the number of keys, so clamp it.\n",
+ "    k = min(vis_attn_topk, attn.size(1))\n",
+ "    for qi in range(attn.size(0)):\n",
+ "        weights, key_indices = attn[qi].topk(k)\n",
+ "        # Convert to plain Python numbers before handing them to matplotlib.\n",
+ "        for a, ki in zip(weights.tolist(), key_indices.tolist()):\n",
+ "            # alpha must lie in [0, 1]; guard against float round-off.\n",
+ "            ax1.plot((-1, 1), (ki, qi), color, alpha=min(max(a, 0.0), 1.0))\n",
+ "\n",
+ "def plot_layer_attn(result_tuple, attn_name='dec_self_attns', layer=0, heads=None):\n",
+ "    \"\"\"Plot the attention heads of one layer on a grid with ``ncols`` columns.\n",
+ "\n",
+ "    Args:\n",
+ "        result_tuple: ``(hypo, nheads, labels_dict)``. ``hypo[attn_name][layer][head]``\n",
+ "            is the matrix passed to _plot_attn and ``labels_dict[attn_name]`` is\n",
+ "            the ``(key_labels, query_labels)`` pair.\n",
+ "        attn_name: which attention to plot. For 'dec_enc_attns' every head takes\n",
+ "            two grid cells, the second one being a blank placeholder.\n",
+ "        layer: index into ``hypo[attn_name]``.\n",
+ "        heads: head indices to plot; None plots ``range(nheads)``.\n",
+ "\n",
+ "    Sets ``rcParams['figure.figsize']`` as a side effect and calls ``plt.show()``.\n",
+ "    \"\"\"\n",
+ "    hypo, nheads, labels_dict = result_tuple\n",
+ "    key_labels, query_labels = labels_dict[attn_name]\n",
+ "    if heads is None:\n",
+ "        heads = range(nheads)\n",
+ "    else:\n",
+ "        heads = list(heads)\n",
+ "        nheads = len(heads)\n",
+ "\n",
+ "    stride = 2 if attn_name == 'dec_enc_attns' else 1\n",
+ "    nlabels = max(len(key_labels), len(query_labels))\n",
+ "    # Keep the height at least 1 so tiny inputs do not produce a zero-size figure.\n",
+ "    rcParams['figure.figsize'] = 20, max(1, int(round(nlabels * stride * nheads / 8 * 1.0)))\n",
+ "\n",
+ "    # Ceiling division: a partially filled last row still needs to exist.\n",
+ "    nslots = nheads * stride\n",
+ "    rows = max(1, -(-nslots // ncols))\n",
+ "    # squeeze=False keeps `axes` 2-D even when there is a single row.\n",
+ "    fig, axes = plt.subplots(rows, ncols, squeeze=False)\n",
+ "\n",
+ "    for head_i, head in enumerate(heads):\n",
+ "        row, col = divmod(head_i * stride, ncols)\n",
+ "        ax1 = axes[row, col]\n",
+ "        attn = hypo[attn_name][layer][head]\n",
+ "        _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col)\n",
+ "        if attn_name == 'dec_enc_attns' and col + 1 < ncols:\n",
+ "            axes[row, col + 1].axis('off')  # next subfig acts as blank place holder\n",
+ "\n",
+ "    # Hide the grid cells left over after the last head.\n",
+ "    for unused_ax in axes.ravel()[nslots:]:\n",
+ "        unused_ax.axis('off')\n",
+ "    plt.show()\n",
+ " \n",
+ "ncols = 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "