# linearize.py
import abc
import os

import torch
import torch.nn as nn
from functorch import jvp, make_functional_with_buffers

from src.modeling import ImageEncoder
from src.utils import DotDict


class LinearizedModel(nn.Module):
    """Creates a linearized version of a nn.Module.

    The linearized version of a model is a proper PyTorch model and can be
    trained as any other nn.Module.

    Args:
        model (nn.Module): The model to linearize. The trainable parameters of
            the linearized model will be initialized to the parameters of this
            model.
        init_model (nn.Module): A model of the same type as `model` containing
            the parameters around which the model is linearized. If not
            provided, `model` is used as the initialization model.
    """

    def __init__(self, model: nn.Module, init_model: nn.Module = None) -> None:
        """Initializes the linearized model."""
        super().__init__()
        if init_model is None:
            init_model = model

        # Functional version of the initialization model. Its parameters and
        # buffers are detached from autograd so that only the trainable
        # parameters below receive gradients.
        func0, params0, self.buffers0 = make_functional_with_buffers(
            init_model.eval(), disable_autograd_tracking=True
        )
        self.func0 = lambda params, x: func0(params, self.buffers0, x)

        _, params, _ = make_functional_with_buffers(
            model, disable_autograd_tracking=True
        )
        self.params = nn.ParameterList(params)
        self.params0 = nn.ParameterList(params0)
        self._model_name = model.__class__.__name__

        # The initial parameters are not trainable.
        for p in self.params0:
            p.requires_grad = False

        # The params are.
        for p in self.params:
            p.requires_grad = True

    def __call__(self, x) -> torch.Tensor:
        """Computes the linearized model output using a first-order Taylor decomposition.

        Returns f(params0; x) + J_params f(params0; x) @ (params - params0),
        where the Jacobian-vector product is evaluated with `jvp`.
        """
        dparams = [p - p0 for p, p0 in zip(self.params, self.params0)]
        out, dp = jvp(
            lambda param: self.func0(param, x),
            (tuple(self.params0),),
            (tuple(dparams),),
        )
        return out + dp
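

# A minimal sanity check (illustrative sketch only): because an `nn.Linear`
# module is linear in its parameters, its first-order Taylor expansion is
# exact, so the linearized forward pass should match the plain forward pass.
# The `base`, `tuned`, and `x` names below are throwaway example variables,
# not part of the API.
#
#     base = nn.Linear(4, 2)
#     tuned = nn.Linear(4, 2)
#     lin = LinearizedModel(tuned, init_model=base)
#     x = torch.randn(8, 4)
#     assert torch.allclose(lin(x), tuned(x), atol=1e-5)

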
class LinearizedImageEncoder(abc.ABC, nn.Module):
    """Creates a linearized version of an image encoder."""

    def __init__(
        self, args=None, keep_lang=False, image_encoder=None, init_encoder=None
    ):
        super().__init__()
        if image_encoder is None:
            image_encoder = ImageEncoder(args, keep_lang)
        if init_encoder is None:
            init_encoder = image_encoder

        # Copy the attributes from the image encoder.
        self.train_preprocess = image_encoder.train_preprocess
        self.val_preprocess = image_encoder.val_preprocess
        self.cache_dir = image_encoder.cache_dir

        self._model_name = self._get_name(args.model)
        self.model = LinearizedModel(init_model=init_encoder, model=image_encoder)

    def _get_name(self, model_name):
        if "__pretrained__" in model_name:
            model_name, _ = model_name.split("__pretrained__")
        return model_name

    def forward(self, x):
        # Use the Taylorized version of the model.
        return self.model(x)

    def __call__(self, x):
        return self.forward(x)

    def save(self, filename):
        """Saves the linearized image encoder.

        We save the model name in the state dict so that we can load the
        correct model when loading the linearized image encoder. Directly using
        torch.save would not work because func0 is not serializable.

        Args:
            filename (str): The path to save the taylorized image encoder.
        """
        if os.path.dirname(filename) != "":
            os.makedirs(os.path.dirname(filename), exist_ok=True)

        state_dict = self.state_dict()
        state_dict["model_name"] = self._model_name

        torch.save(state_dict, filename)

    @classmethod
    def load(cls, filename):
        """Loads a linearized image encoder.

        It first loads the state dict with the model name and then creates the
        correct model and loads the state dict.

        Args:
            filename (str): The path to the taylorized image encoder.

        Returns:
            LinearizedImageEncoder: The loaded taylorized image encoder.
        """
        print(f"Loading image encoder from {filename}")
        state_dict = torch.load(filename, map_location="cpu")

        # ImageEncoder expects a DotDict.
        args = DotDict({"model": state_dict["model_name"]})
        taylorized_encoder = cls(args)

        # Remove the model name from the state dict so that we can load the
        # model.
        state_dict.pop("model_name")
        taylorized_encoder.load_state_dict(state_dict)
        return taylorized_encoder
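

if __name__ == "__main__":
    # Minimal usage sketch for `LinearizedModel` on a small throwaway MLP, to
    # illustrate that the linearized model trains like any other nn.Module
    # (only `self.params` receive gradients; `self.params0` stays frozen).
    # The architecture, data, and hyperparameters here are illustrative
    # placeholders; in practice the class wraps CLIP image encoders through
    # `LinearizedImageEncoder` above.
    torch.manual_seed(0)

    base_mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    linearized = LinearizedModel(base_mlp)

    optimizer = torch.optim.SGD(
        [p for p in linearized.parameters() if p.requires_grad], lr=1e-2
    )

    inputs = torch.randn(32, 8)
    targets = torch.randn(32, 1)

    for _ in range(10):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(linearized(inputs), targets)
        loss.backward()
        optimizer.step()

    print(f"final loss: {loss.item():.4f}")

    # For image encoders, the intended entry points are
    # `LinearizedImageEncoder(args)`, `.save(path)`, and
    # `LinearizedImageEncoder.load(path)`, which round-trip the state dict
    # together with the model name (see `save` and `load` above).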