forked from facebookresearch/pytorchvideo
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_models_masked_multistream.py
130 lines (111 loc) · 4.92 KB
/
test_models_masked_multistream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import copy
import unittest
import torch
import torch.nn
from pytorchvideo.layers import PositionalEncoding, make_multilayer_perceptron
from pytorchvideo.models.masked_multistream import (
LSTM,
LearnMaskedDefault,
MaskedSequential,
MaskedTemporalPooling,
TransposeMultiheadAttention,
TransposeTransformerEncoder,
)
class TestMaskedMultiStream(unittest.TestCase):
    """Unit tests for the masked multistream layers in
    pytorchvideo.models.masked_multistream.

    Each test builds small random inputs plus a boolean validity mask
    (via the module-level ``_lengths2mask`` helper) and checks output
    shapes and, where tractable, exact values.
    """

    def setUp(self):
        super().setUp()
        # Seed the global RNG so random-tensor tests are deterministic.
        torch.set_rng_state(torch.manual_seed(42).get_state())

    def test_masked_multistream_model(self):
        """An end-to-end MaskedSequential stream maps (B, T, C) -> (B, out_dim)."""
        feature_dim = 8
        mlp, out_dim = make_multilayer_perceptron([feature_dim, 2])
        input_stream = MaskedSequential(
            PositionalEncoding(feature_dim),
            TransposeMultiheadAttention(feature_dim),
            MaskedTemporalPooling(method="avg"),
            torch.nn.LayerNorm(feature_dim),
            mlp,
            LearnMaskedDefault(out_dim),
        )

        seq_len = 10
        input_tensor = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), input_tensor.shape[1]
        )
        output = input_stream(input=input_tensor, mask=mask)
        self.assertEqual(output.shape, torch.Size([4, out_dim]))

    def test_masked_temporal_pooling(self):
        """Pooling honors the mask: fully-masked rows pool to zeros, and only
        the valid prefix of each sequence contributes to max/avg/sum."""
        fake_input = torch.Tensor(
            [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]]
        ).float()
        valid_lengths = torch.Tensor([2, 1, 0]).int()
        valid_mask = _lengths2mask(valid_lengths, fake_input.shape[1])
        expected_output_for_method = {
            "max": torch.Tensor([[4, 0], [0, 2], [0, 0]]).float(),
            "avg": torch.Tensor([[3.5, -1], [0, 2], [0, 0]]).float(),
            "sum": torch.Tensor([[7, -2], [0, 2], [0, 0]]).float(),
        }
        for method, expected_output in expected_output_for_method.items():
            model = MaskedTemporalPooling(method)
            # deepcopy guards against in-place mutation of the shared input.
            output = model(copy.deepcopy(fake_input), mask=valid_mask)
            self.assertTrue(torch.equal(output, expected_output))

    def test_transpose_attention(self):
        """TransposeMultiheadAttention preserves the (B, T, C) input shape."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )

        model = TransposeMultiheadAttention(feature_dim, num_heads=2)
        output = model(fake_input, mask=mask)
        # BUG FIX: the original used assertTrue(a, b), where b is only a
        # failure *message* — the check passed vacuously. assertEqual
        # actually compares the shapes.
        self.assertEqual(output.shape, fake_input.shape)

    def test_masked_lstm(self):
        """LSTM pools (B, T, C) to (B, H); bidirectional doubles the width."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )
        hidden_dim = 128

        model = LSTM(feature_dim, hidden_dim=hidden_dim, bidirectional=False)
        output = model(fake_input, mask=mask)
        # BUG FIX: assertTrue(a, b) -> assertEqual(a, b); see note above.
        self.assertEqual(output.shape, (fake_input.shape[0], hidden_dim))

        model = LSTM(feature_dim, hidden_dim=hidden_dim, bidirectional=True)
        output = model(fake_input, mask=mask)
        # BUG FIX: assertTrue(a, b) -> assertEqual(a, b); see note above.
        self.assertEqual(output.shape, (fake_input.shape[0], hidden_dim * 2))

    def test_masked_transpose_transformer_encoder(self):
        """TransposeTransformerEncoder pools (B, T, C) to (B, C)."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, seq_len, feature_dim])
        mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )

        model = TransposeTransformerEncoder(feature_dim)
        output = model(fake_input, mask=mask)
        self.assertEqual(output.shape, (fake_input.shape[0], feature_dim))

    def test_learn_masked_default(self):
        """LearnMaskedDefault passes valid rows through and substitutes the
        learned default vector for fully-masked rows."""
        feature_dim = 8
        seq_len = 10
        fake_input = torch.rand([4, feature_dim])

        # All rows valid: input is returned unchanged.
        all_valid_mask = _lengths2mask(
            torch.tensor([seq_len, seq_len, seq_len, seq_len]), fake_input.shape[1]
        )
        model = LearnMaskedDefault(feature_dim)
        output = model(fake_input, mask=all_valid_mask)
        self.assertTrue(output.equal(fake_input))

        # No rows valid: every row is replaced by the learned default.
        no_valid_mask = _lengths2mask(torch.tensor([0, 0, 0, 0]), fake_input.shape[1])
        model = LearnMaskedDefault(feature_dim)
        output = model(fake_input, mask=no_valid_mask)
        self.assertTrue(output.equal(model._learned_defaults.repeat(4, 1)))

        # Mixed: valid rows pass through, masked rows get the default.
        half_valid_mask = _lengths2mask(torch.tensor([1, 1, 0, 0]), fake_input.shape[1])
        model = LearnMaskedDefault(feature_dim)
        output = model(fake_input, mask=half_valid_mask)
        self.assertTrue(output[:2].equal(fake_input[:2]))
        self.assertTrue(output[2:].equal(model._learned_defaults.repeat(2, 1)))
def _lengths2mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
return torch.lt(
torch.arange(seq_len, device=lengths.device)[None, :], lengths[:, None].long()
)