-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 2d51d43
Showing
23 changed files
with
1,664 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
data/ | ||
runs/ | ||
checkpoint/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2020 yusanshi | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# News Recommendation System | ||
|
||
Currently included model: | ||
|
||
| Model | Full name | Paper | | ||
| ----- | ------------------------------------------------------------ | --------------------------------------------- | | ||
| NRMS | Neural News Recommendation with Multi-Head Self-Attention | https://www.aclweb.org/anthology/D19-1671/ | | ||
| NAML | Neural News Recommendation with Attentive Multi-View Learning | https://arxiv.org/abs/1907.05576 | | ||
| LSTUR | Neural News Recommendation with Long- and Short-term User Representations | https://www.aclweb.org/anthology/P19-1033.pdf | | ||
|
||
## Get started | ||
|
||
Basic setup. | ||
|
||
```bash | ||
git clone https://github.com/yusanshi/NewsRecommendation | ||
cd NewsRecommendation | ||
pip3 install -r requirements.txt | ||
``` | ||
|
||
Download GloVe pre-trained word embedding. | ||
``` | ||
mkdir data && cd data | ||
wget https://nlp.stanford.edu/data/glove.6B.zip | ||
sudo apt install unzip | ||
unzip glove.6B.zip -d glove | ||
rm glove.6B.zip | ||
``` | ||
|
||
Download the dataset. | ||
|
||
```bash | ||
# By downloading the dataset, you agree to the [Microsoft Research License Terms](https://go.microsoft.com/fwlink/?LinkID=206977). For more detail about the dataset, see https://msnews.github.io/. | ||
wget https://mind201910small.blob.core.windows.net/release/MINDsmall_train.zip https://mind201910small.blob.core.windows.net/release/MINDsmall_dev.zip | ||
unzip MINDsmall_train.zip -d train | ||
unzip MINDsmall_dev.zip -d test | ||
rm MINDsmall_*.zip | ||
|
||
# Preprocess data into appropriate format | ||
cd .. | ||
python3 src/data_preprocess.py | ||
# Remember you shoud modify `num_words` in `src/config.py` by the output of `src/data_preprocess.py` | ||
``` | ||
|
||
Modify `src/config.py` to select target model. The configuration file is organized into general part (which is applied to all models) and model-specific part (that some models not have). | ||
|
||
```bash | ||
vim src/config.py | ||
``` | ||
|
||
Run. | ||
|
||
```bash | ||
python3 src/train.py | ||
python3 src/inference.py | ||
python3 src/evaluate.py | ||
|
||
# or | ||
|
||
chmod +x run.sh | ||
./run.sh | ||
``` | ||
|
||
You can visualize the training loss and accuracy with TensorBoard. | ||
|
||
```bash | ||
tensorboard --logdir=runs | ||
``` | ||
|
||
Note the metrics in validation will differ greatly with final result on evaluation set. You should use it for reference only. | ||
|
||
## Credits | ||
|
||
- Dataset by **MI**crosoft **N**ews **D**ataset (MIND), see <https://msnews.github.io/>. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
torch | ||
numpy | ||
pandas | ||
tensorboard | ||
tqdm | ||
nltk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/bash | ||
|
||
for i in {0..9}; do | ||
LOAD_CHECKPOINT=0 python3 src/train.py | ||
python3 src/inference.py | ||
python3 src/evaluate.py | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
|
||
# Currently included model: 'NRMS', 'NAML', 'LSTUR' | ||
model_name = 'NRMS' | ||
|
||
|
||
class BaseConfig(): | ||
""" | ||
General configurations appiled to all models | ||
""" | ||
num_batches = 8000 # Number of batches to train | ||
num_batches_batch_loss = 50 # Number of batchs to show loss | ||
# Number of batchs to check loss and accuracy on validation dataset | ||
num_batches_val_loss_and_acc = 500 | ||
num_batches_save_checkpoint = 100 | ||
batch_size = 64 | ||
learning_rate = 0.001 | ||
train_validation_split = (0.9, 0.1) | ||
num_workers = 4 # Number of workers for data loading | ||
num_clicked_news_a_user = 50 # Number of sampled click history for each user | ||
# Whether try to load checkpoint | ||
load_checkpoint = os.environ[ | ||
'LOAD_CHECKPOINT'] == '1' if 'LOAD_CHECKPOINT' in os.environ else True | ||
num_words_title = 20 | ||
word_freq_threshold = 3 | ||
negative_sampling_ratio = 4 | ||
inference_radio = 0.1 | ||
dropout_probability = 0.2 | ||
# Modify the following by the output of `src/dataprocess.py` | ||
num_words = 1 + 15352 | ||
word_embedding_dim = 300 | ||
# For additive attention | ||
query_vector_dim = 200 | ||
category_embedding_dim = 100 | ||
num_words_abstract = 50 | ||
|
||
|
||
class NRMSConfig(BaseConfig): | ||
# For multi-head self-attention | ||
num_attention_heads = 15 | ||
num_categories = 1 + 274 | ||
|
||
|
||
class NAMLConfig(BaseConfig): | ||
# For CNN | ||
num_filters = 400 | ||
window_size = 3 | ||
|
||
|
||
class LSTURConfig(BaseConfig): | ||
# For CNN | ||
num_filters = 300 | ||
window_size = 3 | ||
# 'ini' or 'con'. See paper for more detail | ||
long_short_term_method = 'ini' | ||
masking_probability = 0.5 | ||
# Modify the following by the output of `src/dataprocess.py` | ||
num_users = 1 + 50000 | ||
num_categories = 1 + 274 |
Oops, something went wrong.