Commit 5097d69: Merge branch 'release-0.5'

xju committed Aug 12, 2020
2 parents bc7c470 + e498b01

Showing 35 changed files with 1,555 additions and 860 deletions.
89 changes: 89 additions & 0 deletions README.md
@@ -13,3 +13,92 @@ git clone https://github.com/xju2/root_gnn.git
cd root_gnn
pip install -e .
```

## Input data
Save data in the folder `data`.

## Commands for the edge classifier
* Graph Construction
```bash
create_tfrecord data/wboson_big.txt tfRec/fully_connected --max-evts 200 --evts-per-record 200 --type WTaggerDataset
```
The graphs are saved in the folder `tfRec`. Then create another folder, `tfRec_val`, and move some of the TFRecord files into it for validation purposes.
Next, modify `tfrec_dir_train` and `tfrec_dir_val` in `configs/train_wtaggers_edges.yaml` so that they point to the training and validation data,
and modify `output_dir` so that it points to an output directory.
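The validation split above is a manual file move; a minimal Python sketch of the same step (the `tfRec`/`tfRec_val` directory names follow the commands above, while the held-out fraction is an arbitrary choice for illustration):

```python
import glob
import os
import shutil

def split_validation(src_dir="tfRec", val_dir="tfRec_val", val_fraction=0.2):
    """Move the last val_fraction of TFRecord files from src_dir into val_dir."""
    os.makedirs(val_dir, exist_ok=True)
    files = sorted(glob.glob(os.path.join(src_dir, "*.tfrec")))
    n_val = max(1, int(len(files) * val_fraction)) if files else 0
    moved = []
    for path in files[len(files) - n_val:]:
        dest = os.path.join(val_dir, os.path.basename(path))
        shutil.move(path, dest)
        moved.append(dest)
    return moved
```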
* Graph Training
```bash
train_classifier configs/train_wtaggers_edges.yaml
```

Create a `data/wboson_small.txt` file from `wboson.txt` using the events that were not used in training.
* Evaluation
```bash
evaluate_wtagger data/wboson_small.txt configs/train_wtaggers_edges.yaml test --nevts 10
```

* Metrics calculation
```bash
calculate_wtagger_metrics test.npz test
```
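The metrics command reads the `test.npz` produced by the evaluation step. The array names stored inside it are not documented here, so this sketch only shows the generic pattern for inspecting such an archive before passing it on:

```python
import numpy as np

def summarize_npz(path):
    """Return a dict mapping each array name in a .npz archive to its shape."""
    summary = {}
    with np.load(path) as data:
        for name in data.files:
            summary[name] = data[name].shape
    return summary
```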

## Commands for event classifier
Train two event classifiers with different inputs: one uses the outputs of the edge classifier, the other the outputs of the anti-$k_t$ algorithm.

### Event classifier using outputs from the edge classifier
* Graph Construction for W boson events
```bash
create_tfrecord "tfRec/fully_connected*.tfrec" tfRec_filtered/wboson \
--type WTaggerFilteredDataset \
--signal --model-config configs/train_wtaggers_edges.yaml \
--max-evts 100 --evts-per-record 100
```

* Graph Construction for q* events
```bash
create_tfrecord data/qstar.txt tfrec_qcd/qcd --type WTaggerDataset --max-evts 100 --evts-per-record 100
```
```bash
create_tfrecord "tfrec_qcd/qcd*.tfrec" \
tfRec_filtered/qcd \
--type WTaggerFilteredDataset \
--model-config configs/train_wtaggers_edges.yaml \
--max-evts 100 --evts-per-record 100
```
Create `tfRec_filtered_val` and put events there for validation.
Change the `tfrec_dir_train`, `tfrec_dir_val` and `output_dir` in `configs/train_w_qcd.yaml` accordingly.
* Graph training
```bash
train_classifier configs/train_w_qcd.yaml
```

### Event classifier using outputs from the anti-$k_t$ algorithm
* Graph construction for W boson events
```bash
create_tfrecord data/wboson.txt \
tfRec_ljet/wboson \
--type WTaggerLeadingJetDataset \
--signal \
--max-evts 95000 --evts-per-record 1000
```

* Graph construction for q* events
```bash
create_tfrecord data/qstar.txt \
tfRec_ljet/qcd \
--type WTaggerLeadingJetDataset \
--max-evts 95000 --evts-per-record 1000
```
Create `tfRec_ljet_val` and put events there for validation.
* Graph training
Again, change the `tfrec_dir_train`, `tfrec_dir_val` and `output_dir` in `configs/train_w_qcd_ljet.yaml` accordingly.
```bash
train_classifier configs/train_w_qcd_ljet.yaml
```

### Evaluate both GNNs
```bash
evaluate_w_qcd_classifier "tfRec_filtered_val/*.tfrec" configs/train_w_qcd.yaml classifier_gnn
```
```bash
evaluate_w_qcd_classifier "tfRec_ljet_val/*.tfrec" configs/train_w_qcd_ljet.yaml classifier_ljet
```
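The two GNNs are typically compared via the ROC AUC of their per-event scores. The output format of `evaluate_w_qcd_classifier` is not shown here, so as a hedged illustration, a small pure-Python AUC over assumed score/label lists:

```python
def roc_auc(scores, labels):
    """ROC AUC via the Mann-Whitney U statistic: the probability that a
    randomly chosen signal event (label 1) scores higher than a randomly
    chosen background event (label 0), with ties counting one half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need both signal (1) and background (0) labels")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
```

A perfectly separating classifier gives 1.0, a random one about 0.5; the O(n^2) pairwise loop is fine for the small validation samples used here.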
16 changes: 16 additions & 0 deletions configs/test_summary.yaml
@@ -0,0 +1,16 @@
tfrec_dir_train: /global/cscratch1/sd/xju/WbosonTagger/tfrec_bigger/*_0.tfrec
tfrec_dir_val: /global/cscratch1/sd/xju/WbosonTagger/tfrec_val_bigger/*_95.tfrec
output_dir: /global/cscratch1/sd/xju/WbosonTagger/trained
prod_name: TESTSummary
model_name: EdgeClassifier
loss_name: EdgeLoss, 2, 1
parameters:
batch_size: 1
n_iters: 5
learning_rate: 0.0001
epochs: 2
earlystop_metric: "auc_te" #auc_te, acc_te, pre_te, rec_te
acceptable_failure: 5
do_profiling: true
profiling_steps: 1000
do_profiling_only: true
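The `loss_name` entries in these configs ("EdgeLoss, 2, 1", "GlobalLoss, 1, 1") look like a comma-separated loss class name followed by numeric arguments; that the numbers are loss-constructor arguments is an assumption. A sketch of how such a field might be parsed:

```python
def parse_loss_name(value):
    """Split a config entry like 'EdgeLoss, 2, 1' into the loss class name
    and a list of numeric arguments."""
    parts = [p.strip() for p in value.split(",")]
    name, args = parts[0], [float(p) for p in parts[1:]]
    return name, args
```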
14 changes: 14 additions & 0 deletions configs/train_w_qcd.yaml
@@ -0,0 +1,14 @@
tfrec_dir_train: /global/homes/x/xju/work/WbosonTagger/tfrec_filtered/*.tfrec
tfrec_dir_val: /global/homes/x/xju/work/WbosonTagger/tfrec_filtered_val/*.tfrec
output_dir: /global/homes/x/xju/work/WbosonTagger/trained/w_qcd
prod_name: Test
model_name: GlobalClassifierNoEdgeInfo
loss_name: GlobalLoss, 1, 1
parameters:
batch_size: 1
n_iters: 8
learning_rate: 0.0005
epochs: 10
earlystop_metric: "auc_te" #auc_te, acc_te, pre_te, rec_te
acceptable_failure: 10
shuffle_buffer_size: -1
14 changes: 14 additions & 0 deletions configs/train_w_qcd_ljet.yaml
@@ -0,0 +1,14 @@
tfrec_dir_train: /global/homes/x/xju/work/WbosonTagger/tfrec_ljet/*.tfrec
tfrec_dir_val: /global/homes/x/xju/work/WbosonTagger/tfrec_ljet_val/*.tfrec
output_dir: /global/homes/x/xju/work/WbosonTagger/trained/w_qcd_ljet
prod_name: Test
model_name: GlobalClassifierNoEdgeInfo
loss_name: GlobalLoss, 1, 1
parameters:
batch_size: 1
n_iters: 8
learning_rate: 0.0005
epochs: 10
earlystop_metric: "auc_te" #auc_te, acc_te, pre_te, rec_te
acceptable_failure: 10
shuffle_buffer_size: -1
4 changes: 2 additions & 2 deletions root_gnn/losses.py
@@ -56,7 +56,7 @@ def __call__(self, target_op, output_ops):
for output_op in output_ops
]
loss_ops += [
-            tf.compact.v1.losses.log_loss(target_op.edges, output_op.edges, weights=edge_weights)
+            tf.compat.v1.losses.log_loss(target_op.edges, output_op.edges, weights=edge_weights)
for output_op in output_ops
]
return tf.stack(loss_ops)
@@ -75,5 +75,5 @@ def __call__(self, target_op, output_ops):
return tf.stack(loss_ops)

if __name__ == "__main__":
-    node_edge_loss = NodeEdgeLoss(2, 1)
+    node_edge_loss = NodeEdgeLoss(2, 1, 2, 1)
node_edge_loss(1, 1)