Commit 6209ed7 (1 parent: 38e6995)
Showing 4 changed files with 67 additions and 103 deletions.
```diff
@@ -7,10 +7,6 @@
box pairs are computed and fed into a simple
MLP to compute class logits for the 600 interactions.
This method has abysmal performance (2% mAP)
and only serves as an example to demonstrate
the usage of pocket.core.MultiLabelClassificationEngine
Fred Zhang <[email protected]>
The Australian National University
```
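The hunk above only touches the module docstring, but it is the best summary of what the script does: the coordinates of human-object box pairs are fed into a small MLP that produces class logits for the 600 HICO-DET interactions. The diff later uses `net = Net()` without showing the class, so the following is a minimal sketch of that kind of model; the class name, the 8-dimensional input (two xyxy boxes) and the hidden width are assumptions, not the repository's actual `Net`.

```python
import torch
from torch import nn

class PairBoxMLP(nn.Module):
    """Toy HOI classifier: concatenated box-pair coordinates -> 600 interaction logits."""

    def __init__(self, in_features: int = 8, hidden: int = 256, num_classes: int = 600):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),  # raw logits, intended for BCEWithLogitsLoss
        )

    def forward(self, box_pairs: torch.Tensor) -> torch.Tensor:
        # box_pairs: (N, 8) = [x1, y1, x2, y2] of the human box followed by the object box
        return self.mlp(box_pairs)
```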
```diff
@@ -81,7 +77,7 @@ def custom_collate(batch):

HICO_ROOT = "./data/hicodet"
if not os.path.exists(HICO_ROOT):
-    raise ValueError("Cannot find the dataset"
+    raise ValueError("Cannot find the dataset. "
        "Make sure a symbolic link is created at {}".format(HICO_ROOT))

net = Net()
```
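The corrected error message points at a common setup step: the script expects the dataset to be reachable at ./data/hicodet. If the data lives elsewhere, one way to satisfy the check is to create the symbolic link the message asks for. A small sketch, where the source path is a placeholder for wherever the extracted HICO-DET release (the hico_20160224_det folder plus the instances_*.json files) actually sits:

```python
import os

HICO_ROOT = "./data/hicodet"
# Placeholder: a directory that contains hico_20160224_det/ and the instances_*.json files
SOURCE = "/path/to/hicodet"

os.makedirs(os.path.dirname(HICO_ROOT), exist_ok=True)  # ensure ./data exists
if not os.path.exists(HICO_ROOT):
    # Point ./data/hicodet at the real dataset location.
    os.symlink(os.path.abspath(SOURCE), HICO_ROOT)
```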
```diff
@@ -91,15 +87,15 @@ def custom_collate(batch):
train_loader = DataLoader(
    HICODet(
        root=os.path.join(HICO_ROOT, "hico_20160224_det/images/train2015"),
-        annoFile=os.path.join(HICO_ROOT, "instances_train2015.json"),
+        anno_file=os.path.join(HICO_ROOT, "instances_train2015.json"),
        transforms=transforms
    ), batch_size=4, shuffle=True, num_workers=4,
    collate_fn=custom_collate, drop_last=True
)
test_loader = DataLoader(
    HICODet(
        root=os.path.join(HICO_ROOT, "hico_20160224_det/images/test2015"),
-        annoFile=os.path.join(HICO_ROOT, "instances_test2015.json"),
+        anno_file=os.path.join(HICO_ROOT, "instances_test2015.json"),
        transforms=transforms
    ), batch_size=4, num_workers=4,
    collate_fn=custom_collate
```
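Both loaders pass `collate_fn=custom_collate`, whose body sits outside these hunks. The usual reason for a custom collate in detection-style pipelines is that images and annotations vary in size and cannot be stacked into fixed-shape tensors, so the batch is kept as lists. The following is a sketch of that pattern only, not the repository's actual function:

```python
def custom_collate(batch):
    """Keep images and per-image targets as lists instead of stacking them.

    Each element of `batch` is assumed to be an (image, target) pair returned
    by the dataset; the exact target format here is an assumption.
    """
    images, targets = [], []
    for image, target in batch:
        images.append(image)
        targets.append(target)
    return images, targets
```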
```diff
@@ -108,42 +104,25 @@ def custom_collate(batch):
engine = MultiLabelClassificationEngine(net, criterion, train_loader,
    val_loader=test_loader, ap_algorithm='11P', print_interval=500)

-engine(5)
+engine(1)

-# Sample output
"""
+Sample output
-=> Validation (+1754.71s)
-Epoch: 0 | mAP: 0.0067 | Loss: 0.6940 | Time: 1751.42s
+=> Validation (+935.19s)
+Epoch: 0 | mAP: 0.0065 | Loss: 0.6958 | Time: 933.54s
+Epoch [1/1], Iter. [0500/9408], Loss: 0.6478, Time[Data/Iter.]: [2.75s/197.86s]
+Epoch [1/1], Iter. [1000/9408], Loss: 0.5503, Time[Data/Iter.]: [2.58s/198.63s]
+Epoch [1/1], Iter. [1500/9408], Loss: 0.4748, Time[Data/Iter.]: [2.57s/198.32s]
+Epoch [1/1], Iter. [2000/9408], Loss: 0.4136, Time[Data/Iter.]: [2.57s/197.08s]
+Epoch [1/1], Iter. [2500/9408], Loss: 0.3648, Time[Data/Iter.]: [2.60s/198.07s]
...
...
...
-[Ep.][Iter.]: [5][38000] | Loss: 0.0388 | Time[Data/Iter.]: [0.4588s/203.3094s]
-[Ep.][Iter.]: [5][38500] | Loss: 0.0383 | Time[Data/Iter.]: [0.0629s/206.3993s]
-[Ep.][Iter.]: [5][39000] | Loss: 0.0387 | Time[Data/Iter.]: [0.0600s/196.8182s]
-[Ep.][Iter.]: [5][39500] | Loss: 0.0387 | Time[Data/Iter.]: [0.0607s/208.0784s]
-[Ep.][Iter.]: [5][40000] | Loss: 0.0383 | Time[Data/Iter.]: [0.0621s/200.6665s]
-[Ep.][Iter.]: [5][40500] | Loss: 0.0382 | Time[Data/Iter.]: [0.0621s/198.3799s]
-[Ep.][Iter.]: [5][41000] | Loss: 0.0376 | Time[Data/Iter.]: [0.0643s/206.2081s]
-[Ep.][Iter.]: [5][41500] | Loss: 0.0378 | Time[Data/Iter.]: [0.0619s/197.4531s]
-[Ep.][Iter.]: [5][42000] | Loss: 0.0370 | Time[Data/Iter.]: [0.0660s/199.5019s]
-[Ep.][Iter.]: [5][42500] | Loss: 0.0364 | Time[Data/Iter.]: [0.0620s/202.1312s]
-[Ep.][Iter.]: [5][43000] | Loss: 0.0370 | Time[Data/Iter.]: [0.0622s/199.2598s]
-[Ep.][Iter.]: [5][43500] | Loss: 0.0369 | Time[Data/Iter.]: [0.0640s/203.9033s]
-[Ep.][Iter.]: [5][44000] | Loss: 0.0367 | Time[Data/Iter.]: [0.0625s/200.8627s]
-[Ep.][Iter.]: [5][44500] | Loss: 0.0358 | Time[Data/Iter.]: [0.0637s/198.3094s]
-[Ep.][Iter.]: [5][45000] | Loss: 0.0358 | Time[Data/Iter.]: [0.0645s/198.1471s]
-[Ep.][Iter.]: [5][45500] | Loss: 0.0365 | Time[Data/Iter.]: [0.0633s/206.5252s]
-[Ep.][Iter.]: [5][46000] | Loss: 0.0362 | Time[Data/Iter.]: [0.0661s/205.7809s]
-[Ep.][Iter.]: [5][46500] | Loss: 0.0354 | Time[Data/Iter.]: [0.0639s/200.6037s]
-[Ep.][Iter.]: [5][47000] | Loss: 0.0355 | Time[Data/Iter.]: [0.0596s/201.4554s]
-=> Training (+25678.19s)
-Epoch: 5 | mAP: 0.0114 | Time(eval): 100.55s
-=> Validation (+26641.95s)
-Epoch: 5 | mAP: 0.0255 | Loss: 0.0284 | Time: 963.73s
+Epoch [1/1], Iter. [8000/9408], Loss: 0.1425, Time[Data/Iter.]: [2.52s/196.84s]
+Epoch [1/1], Iter. [8500/9408], Loss: 0.1364, Time[Data/Iter.]: [2.54s/198.61s]
+Epoch [1/1], Iter. [9000/9408], Loss: 0.1276, Time[Data/Iter.]: [2.61s/193.84s]
+=> Training (+4701.03s)
+Epoch: 1 | mAP: 0.0047 | Time(eval): 61.62s
+=> Validation (+5635.69s)
+Epoch: 1 | mAP: 0.0075 | Loss: 0.0872 | Time: 934.64s
"""
```