
How to get only the last few layers' gradient? #1101

Open
pmzzs opened this issue Jan 13, 2023 · 2 comments

Comments


pmzzs commented Jan 13, 2023

import torch
from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = fmodel(params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)
gradient = ft_compute_grad(params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())

This returns the gradient of the whole model. However, I only want the second-to-last layer's gradient, like:

gradient = ft_compute_grad(params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())[-2]

Although this approach does give me the gradient I need, it incurs a lot of unnecessary overhead. Is there any way to turn off 'requires_grad' for all of the earlier layers? Thanks for your answer!


zou3519 commented Jan 17, 2023

functorch.grad computes gradients w.r.t. the first argument you pass it. Currently that is params (all of the model's parameters), so the solution is to pass only the parameters you want gradients of as the first argument.

Some pseudocode.

from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    # pseudocode: we need to put the params back together into a single params tuple
    # that fmodel can understand
    params = (*first_layers_params, *last_layers_params)

    predictions = fmodel(params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)

# pseudocode: we need to split the params we want to compute gradients of from
# the params we don't want to compute gradients of.
first_layers_params, last_layers_params = partition(params)

gradient = ft_compute_grad(last_layers_params, first_layers_params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())
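
Concretely, the partition step can just slice the params tuple. A rough, untested sketch, assuming the last two tensors in params are the final layer's weight and bias (adjust the split point for your own architecture):

import torch
from functorch import make_functional_with_buffers, grad

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

# Assumption: the final layer contributes the last two tensors in params
# (e.g. a Linear layer's weight and bias). Adjust `split` as needed.
split = len(params) - 2
first_layers_params = params[:split]
last_layers_params = params[split:]

def compute_loss_stateless_model(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    # Reassemble the full parameter tuple in the original order for fmodel.
    all_params = (*first_layers_params, *last_layers_params)

    predictions = fmodel(all_params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

# grad differentiates only w.r.t. the first argument, so only
# last_layers_params gets gradients; the earlier layers are treated as constants.
ft_compute_grad = grad(compute_loss_stateless_model)

sample = train_poi_set[0][0].cuda()
target = torch.tensor(train_poi_set[0][1]).cuda()
gradient = ft_compute_grad(last_layers_params, first_layers_params, buffers, sample, target)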


skxgogo commented Apr 5, 2024

@zou3519 I have a similar question, but about jacrev. For example, I only want to compute the Jacobian with respect to the last layers. Can the same approach work?
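
Something like this is what I have in mind (a minimal sketch, assuming jacrev, like grad, differentiates with respect to its first argument by default, and that the last two tensors in params belong to the final layer):

from functorch import make_functional_with_buffers, jacrev

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

# Assumption: only the final layer's tensors go into the first argument.
first_layers_params, last_layers_params = params[:-2], params[-2:]

def model_output(last_layers_params, first_layers_params, buffers, sample):
    # Rebuild the full parameter tuple in the original order for fmodel.
    all_params = (*first_layers_params, *last_layers_params)
    return fmodel(all_params, buffers, sample.unsqueeze(0))

# jacrev differentiates w.r.t. its first argument by default (argnums=0),
# so the Jacobian is computed only for last_layers_params.
compute_jac = jacrev(model_output)
jacobian = compute_jac(last_layers_params, first_layers_params, buffers, train_poi_set[0][0].cuda())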
