forked from facebookresearch/pytorchvideo
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path test_models_memory_bank.py
37 lines (30 loc) · 1.01 KB
/
test_models_memory_bank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from typing import Iterator, Tuple

import torch
from pytorchvideo.models.memory_bank import MemoryBank
from torch import nn
class TestMemoryBank(unittest.TestCase):
def setUp(self):
super().setUp()
torch.set_rng_state(torch.manual_seed(42).get_state())
def test_memory_bank(self):
simclr = MemoryBank(
backbone=nn.Linear(8, 4),
mlp=nn.Linear(4, 2),
temperature=0.07,
bank_size=8,
dim=2,
)
for crop, ind in TestMemoryBank._get_inputs():
simclr(crop, ind)
@staticmethod
def _get_inputs(bank_size: int = 8) -> torch.tensor:
"""
Provide different tensors as test cases.
Yield:
(torch.tensor): tensor as test case input.
"""
# Prepare random inputs as test cases.
shapes = ((2, 8),)
for shape in shapes:
yield torch.rand(shape), torch.randint(0, bank_size, size=(shape[0],))