Skip to content

Commit

Permalink
Fixed bugs
Browse files Browse the repository at this point in the history
1. Lack "model.model.eval()" before testing in train.py
2. Modify Readme
3. Add comments for cross attention module
  • Loading branch information
yyk-wew committed Apr 20, 2021
1 parent f3b4da2 commit 574667f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
20 changes: 12 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,18 @@ After preprocessing, the data should be organized as following:

There are four modes supported in F3-Net​.

| Mode(string) | |
| ------------ | ------------------------------------------------------- |
| 'FAD' | Use FAD branch only. |
| 'LFS' | Use LFS branch only. |
| 'Both' | Use both of branches and concate before classification. |
| 'Mix' | Use both of branches and MixBlock. |


| Mode(string) | |
| ------------------ | ------------------------------------------------------- |
| 'FAD' | Use FAD branch only. |
| 'LFS' | Use LFS branch only. |
| 'Both' | Use both of branches and concate before classification. |
| 'Mix'(unavailable) | Use both of branches and MixBlock. |

**Note**:

Mode 'Mix' is unavailable yet. If you're interested in this part, check 'class Mixblock' in models.py.



#### Run

Expand Down
2 changes: 2 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def fea_8_12(self, x):
return x

class MixBlock(nn.Module):
# An implementation of the cross attention module in F3-Net
# Haven't added into the whole network yet
def __init__(self, c_in, width, height):
super(MixBlock, self).__init__()
self.FAD_query = nn.Conv2d(c_in, c_in, (1,1))
Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@
model.model.train()
epoch = epoch + 1

model.model.eval()
auc, r_acc, f_acc = evaluate(model, dataset_path, mode='test')
logger.debug(f'(Test @ epoch {epoch}) auc: {auc}, r_acc: {r_acc}, f_acc:{f_acc}')

0 comments on commit 574667f

Please sign in to comment.