diff --git a/README.md b/README.md
index e4ac1d8..0b0af61 100644
--- a/README.md
+++ b/README.md
@@ -90,14 +90,18 @@ After preprocessing, the data should be organized as following:
 
 There are four modes supported in F3-Net.
 
-| Mode(string) |                                                          |
-| ------------ | -------------------------------------------------------- |
-| 'FAD'        | Use FAD branch only.                                      |
-| 'LFS'        | Use LFS branch only.                                      |
-| 'Both'       | Use both of branches and concate before classification.  |
-| 'Mix'        | Use both of branches and MixBlock.                        |
-
-
+| Mode (string)       | Description                                                    |
+| ------------------- | -------------------------------------------------------------- |
+| 'FAD'               | Use the FAD branch only.                                        |
+| 'LFS'               | Use the LFS branch only.                                        |
+| 'Both'              | Use both branches and concatenate them before classification.  |
+| 'Mix' (unavailable) | Use both branches and the MixBlock.                             |
+
+**Note**:
+
+Mode 'Mix' is not available yet. If you are interested in this part, see `class MixBlock` in models.py.
+
+
 #### Run
diff --git a/models.py b/models.py
index 67f40c3..42f1975 100644
--- a/models.py
+++ b/models.py
@@ -314,6 +314,8 @@ def fea_8_12(self, x):
         return x
 
 class MixBlock(nn.Module):
+    # An implementation of the cross-attention module (MixBlock) in F3-Net.
+    # Not yet integrated into the full network.
     def __init__(self, c_in, width, height):
         super(MixBlock, self).__init__()
         self.FAD_query = nn.Conv2d(c_in, c_in, (1,1))
diff --git a/train.py b/train.py
index 1573e24..ea2c4ff 100644
--- a/train.py
+++ b/train.py
@@ -104,5 +104,6 @@
         model.model.train()
 
         epoch = epoch + 1
+        model.model.eval()
         auc, r_acc, f_acc = evaluate(model, dataset_path, mode='test')
         logger.debug(f'(Test @ epoch {epoch}) auc: {auc}, r_acc: {r_acc}, f_acc: {f_acc}')
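
For readers of the updated mode table in the README, the sketch below illustrates what the four mode strings mean in terms of branch fusion. It is a minimal reconstruction based only on the table's wording; `fuse_features`, its argument names, and the tensor shapes are hypothetical and are not the repository's actual API.

```python
import torch


def fuse_features(mode, fea_fad=None, fea_lfs=None):
    """Illustrative fusion of branch features according to the mode table.

    fea_fad / fea_lfs stand for pooled FAD / LFS feature vectors of shape
    (N, C); the real model's shapes and fusion code may differ.
    """
    if mode == 'FAD':
        return fea_fad  # FAD branch only
    if mode == 'LFS':
        return fea_lfs  # LFS branch only
    if mode == 'Both':
        # Concatenate the two branches before the classifier.
        return torch.cat([fea_fad, fea_lfs], dim=1)
    if mode == 'Mix':
        # MixBlock is not wired into the network yet (see models.py).
        raise NotImplementedError("mode 'Mix' is not available yet")
    raise ValueError(f"unknown mode: {mode!r}")


# Example: 'Both' concatenates along the channel dimension.
fad, lfs = torch.randn(4, 2048), torch.randn(4, 2048)
print(fuse_features('Both', fad, lfs).shape)  # torch.Size([4, 4096])
```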
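The new comment on `MixBlock` describes it as a cross-attention module between the two branches that has not been integrated into the network yet. As a rough orientation only, the block below sketches a generic two-stream cross-attention in PyTorch; the class name, the sigmoid gating, and the residual re-weighting are assumptions for illustration and do not reproduce the repository's `MixBlock` (which, per the diff, starts from 1x1 convolutions such as `FAD_query`).

```python
import torch
import torch.nn as nn


class CrossAttentionSketch(nn.Module):
    """Illustrative two-stream cross-attention, NOT the repository's MixBlock.

    Each branch (e.g. the FAD and LFS feature maps) produces a spatial
    attention map for the other branch, which then re-weights that branch's
    features.
    """

    def __init__(self, c_in):
        super().__init__()
        # 1x1 convs project each stream; assumption mirroring the diff's
        # FAD_query = nn.Conv2d(c_in, c_in, (1,1)).
        self.fad_query = nn.Conv2d(c_in, c_in, kernel_size=1)
        self.lfs_query = nn.Conv2d(c_in, c_in, kernel_size=1)

    def forward(self, x_fad, x_lfs):
        # Attention map for the FAD stream, driven by the LFS stream.
        att_fad = torch.sigmoid(self.lfs_query(x_lfs))
        # Attention map for the LFS stream, driven by the FAD stream.
        att_lfs = torch.sigmoid(self.fad_query(x_fad))
        # Residual re-weighting keeps the original signal intact.
        y_fad = x_fad + x_fad * att_fad
        y_lfs = x_lfs + x_lfs * att_lfs
        return y_fad, y_lfs


if __name__ == "__main__":
    block = CrossAttentionSketch(c_in=728)
    fad = torch.randn(2, 728, 19, 19)
    lfs = torch.randn(2, 728, 19, 19)
    out_fad, out_lfs = block(fad, lfs)
    print(out_fad.shape, out_lfs.shape)  # both torch.Size([2, 728, 19, 19])
```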
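The train.py change inserts `model.model.eval()` before the end-of-epoch test pass, which follows the standard PyTorch practice of switching batch-norm and dropout to inference behaviour before evaluating and restoring training behaviour afterwards. The snippet below is a minimal, self-contained sketch of that pattern; `run_epoch` and its parameters are hypothetical names, not functions from this repository.

```python
import torch


def run_epoch(model, train_loader, test_loader, optimizer, criterion, device):
    # Training pass: dropout and batch-norm statistics updates are active.
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Evaluation pass: eval() makes batch-norm use running statistics and
    # disables dropout; no_grad() skips gradient tracking.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    model.train()  # restore training mode for the next epoch
    return correct / max(total, 1)
```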