batching over model parameters #1094
Is this a legal way to solve this? It doesn't give me an error, but I am very unsure why this now works.

def test_resnet_2(new_params):
    def interpolate(alpha):
        with torch.no_grad():
            # replace every parameter with old_p + alpha * new_p
            for i, (name, old_p) in enumerate(named_params_data):
                new_p = new_params[i]
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p + alpha * new_p))
            out = model_resnet(sample_data)
            # restore the original parameters
            for i, (name, old_p) in enumerate(named_params_data):
                param_names = name.split(".")
                current = model_resnet
                for p in param_names[:-1]:
                    current = getattr(current, p)
                setattr(current, param_names[-1], torch.nn.Parameter(old_p))
        return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_2(rand_tensor)
test_out2 = vmap(to_vmap_resnet)(alphas)

EDIT: found an even simpler solution. This is the correct approach, right?

def test_resnet_4(new_params, sample_data, model_resnet):
    func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
    def interpolate(alpha):
        with torch.no_grad():
            interpol_params = [torch.nn.Parameter(old_p + alpha * new_params[i]) for i, old_p in enumerate(params)]
            out = func_model(interpol_params, buff, sample_data)
        return out
    return interpolate

model_resnet.eval()
to_vmap_resnet = test_resnet_4(rand_tensor, sample_data, model_resnet)
test_out2 = vmap(to_vmap_resnet)(alphas)
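For readers on newer PyTorch: functorch was later folded into torch.func, and make_functional_with_buffers was deprecated in PyTorch 2.0 in favour of torch.func.functional_call. A minimal sketch of the same "interpolate, then vmap over alpha" idea under that API, assuming PyTorch 2.0 or later. The small nn.Sequential model and the names params, buffers and directions are hypothetical stand-ins for model_resnet, rand_tensor and the thread's other variables.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)

# Hypothetical stand-ins for model_resnet, sample_data, rand_tensor and alphas.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
sample_data = torch.randn(5, 4)
params = {name: p.detach() for name, p in model.named_parameters()}
buffers = {name: b.detach() for name, b in model.named_buffers()}
directions = {name: torch.randn_like(p) for name, p in params.items()}
alphas = torch.linspace(0.0, 1.0, 100)


def interpolate(alpha):
    # Build fresh tensors old_p + alpha * direction; the module itself is never mutated.
    new_params = {name: p + alpha * directions[name] for name, p in params.items()}
    return functional_call(model, (new_params, buffers), (sample_data,))


with torch.no_grad():
    batched_out = vmap(interpolate)(alphas)  # vectorized over all 100 alphas -> (100, 5, 2)
    looped_out = torch.stack([interpolate(a) for a in alphas])  # reference Python loop
```

Because every interpolated tensor is created out of place and handed to functional_call, nothing is copied into the module's existing parameters, which is what vmap needs.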
Hi @LeanderK! Thanks for the interesting issue! Since it sounds like this works, that's a totally fine way of doing it! One thing that might come up is if you do … For this use case, since it looks like you want to have very specific initializations, it might be better to riff on the idea of the ensemble API:

def test_resnet_4(func_model, buff, sample_data):
    def interpolate(interpol_params):
        with torch.no_grad():
            out = func_model(interpol_params, buff, sample_data)
        return out
    return interpolate

model_resnet.eval()
func_model, params, buff = make_functional_with_buffers(model_resnet, disable_autograd_tracking=True)
# one full list of interpolated parameters per alpha
interpol_params = [[torch.nn.Parameter(old_p + alpha * rand_tensor[i]) for i, old_p in enumerate(params)] for alpha in alphas]
# regroup by parameter and stack along a new leading (alpha) dimension
interpol_params = [torch.stack(i) for i in zip(*interpol_params)]  # this is basically what the ensemble API is doing
to_vmap_resnet = test_resnet_4(func_model, buff, sample_data)
test_out2 = vmap(to_vmap_resnet)(interpol_params)

Then, if you want to train, you can also expand the buffers and vmap across them along with interpol_params so that batch norm works.

Hope that helps! We are also looking at changing the module API soon to help rationalize some of the functorch API with the PyTorch API. If you're using the nightly build, I can point you to the new API if you're curious.
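The ensemble-style layout (precompute every parameter set, stack along a leading dimension, then vmap over the stack) can be sketched with torch.func as below, assuming PyTorch 2.0 or later. The toy model and the names stacked, run and directions are hypothetical; a dict keyed by parameter name plays the role of the stacked interpol_params list. Buffers are passed with in_dims=None, so all alphas share them; the training case mentioned above, where buffers are expanded and batched too, is not covered here.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

torch.manual_seed(0)

# Hypothetical stand-ins for model_resnet, sample_data, rand_tensor and alphas.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
sample_data = torch.randn(5, 4)
params = {name: p.detach() for name, p in model.named_parameters()}
buffers = {name: b.detach() for name, b in model.named_buffers()}  # empty for this toy model
directions = {name: torch.randn_like(p) for name, p in params.items()}
alphas = torch.linspace(0.0, 1.0, 10)

# Precompute every interpolated parameter set and stack along a new leading
# dimension, so stacked[name][j] holds that parameter for alphas[j].
stacked = {
    name: torch.stack([p + a * directions[name] for a in alphas])
    for name, p in params.items()
}


def run(p, b, x):
    return functional_call(model, (p, b), (x,))


with torch.no_grad():
    # in_dims=(0, None, None): map over the stacked parameters only;
    # buffers and input are shared by every alpha.
    out = vmap(run, in_dims=(0, None, None))(stacked, buffers, sample_data)
```

Precomputing the stack costs memory proportional to the number of alphas, but it allows arbitrary per-member parameter sets rather than only a linear path.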
I have a use-case for functorch. I would like to check many variations of the model parameters very efficiently (I want to eliminate the loop). Here's example code for a simplified case I got working: …

Now I could do for alpha in np.linspace(0.0, 1.0, 100), but I want to vectorise this loop since my code is prohibitively slow. Is functorch applicable here? Executing: … works, but doing something similar for a simple ResNet does not work. I've tried using load_state_dict, but that's not working: … results in:
While copying the parameter named "fc.bias", whose dimensions in the model are torch.Size([1000]) and whose dimensions in the checkpoint are torch.Size([1000]), an exception occurred : ('vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.',).
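Why load_state_dict fails here: it copies each checkpoint tensor into the module's existing parameter in place (via copy_), so a vmapped (batched) value is written into an unbatched tensor. The minimal sketch below reproduces the same class of error without a model, assuming PyTorch 2.0 or later with torch.func; the function names f_inplace and f_out_of_place are hypothetical. It mirrors the mutation example in the torch.func "UX Limitations" docs.

```python
import torch
from torch.func import vmap


def f_inplace(x, y):
    # x is not vmapped (in_dims=None) but y is, so writing y into x in place
    # would require x to hold one value per batch element: vmap rejects this.
    x.add_(y)
    return x


def f_out_of_place(x, y):
    # Out of place: the result is a new tensor that can carry the batch dimension.
    return x + y


x = torch.randn(1)
y = torch.randn(3, 1)  # inside vmap, each slice looks like shape [1]

try:
    vmap(f_inplace, in_dims=(None, 0))(x, y)
    raised = False
except RuntimeError:
    raised = True

out = vmap(f_out_of_place, in_dims=(None, 0))(x, y)  # shape (3, 1)
```

The workarounds in this thread avoid the problem the same way: they build new interpolated tensors and pass them to the model functionally instead of copying them into the module's parameters.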