Commit
version bump + fixes + hrank example
alvinwan committed Jul 14, 2023
1 parent 10583d8 commit 33e2fc5
Showing 6 changed files with 53 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -19,7 +19,7 @@ Mask and prune channels, using the default magnitude pruner.
import torch, torchvision
from upscale import MaskingManager, PruningManager

-x = torch.rand((1, 3, 224, 224)).cuda()
+x = torch.rand((1, 3, 224, 224), device='cuda')
model = torchvision.models.get_model('resnet18', pretrained=True).cuda() # get any pytorch model
MaskingManager(model).importance().mask()
PruningManager(model).compute([x]).prune()
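
The README change constructs the example input directly on the GPU instead of creating it on the CPU and moving it with `.cuda()`. Both forms yield a CUDA tensor; a minimal sketch of the equivalence (assuming a CUDA device is available):

import torch

# Both produce a float32 CUDA tensor of the same shape; the second form
# skips the intermediate CPU allocation and host-to-device copy.
a = torch.rand((1, 3, 224, 224)).cuda()           # old README form
b = torch.rand((1, 3, 224, 224), device='cuda')   # new README form
assert a.shape == b.shape and a.device == b.device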
29 changes: 29 additions & 0 deletions examples/heuristic.py
@@ -0,0 +1,29 @@
"""
For licensing see accompanying LICENSE file.
Copyright (C) 2023 Apple Inc. All Rights Reserved.
How to customize heuristic used for pruning, in UPSCALE
"""

import torch, torchvision
from upscale import MaskingManager, PruningManager
from upscale.masking.importance import HRank, LAMP


# for most heuristics, simply pass the heuristic to `.importance(...)`
x = torch.rand((1, 3, 224, 224), device='cuda')
model = torchvision.models.get_model('resnet18', pretrained=True).cuda() # get any pytorch model
MaskingManager(model).importance(LAMP()).mask()
PruningManager(model).compute([x]).prune()

# for HRank, we need to run several forward passes to collect feature map
# statistics

heuristic = HRank()
heuristic.register(model)
for _ in range(10):
    model(torch.rand((1, 3, 224, 224), device='cuda'))
heuristic.deregister(model)

MaskingManager(model).importance(heuristic).mask()
PruningManager(model).compute([x]).prune()
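
HRank scores channels using statistics of the feature maps it observes during these forward passes, so the random inputs above act purely as calibration data. A minimal sketch of calibrating with real images instead, where `calibration_loader` is a hypothetical DataLoader (reusing `model`, `HRank`, and `MaskingManager` from the example above):

heuristic = HRank()
heuristic.register(model)
with torch.no_grad():
    for i, (images, _) in enumerate(calibration_loader):  # hypothetical DataLoader
        model(images.cuda())
        if i == 9:  # roughly mirrors the 10 random passes above
            break
heuristic.deregister(model)

MaskingManager(model).importance(heuristic).mask()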
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@

[project]
name = "apple-upscale" # Required
-version = "0.1.0" # Required
+version = "0.1.1" # Required
description = "Export utility for unconstrained channel pruned models" # Optional
readme = "README.md" # Optional
requires-python = ">=3.7"
1 change: 1 addition & 0 deletions src/upscale/masking/importance.py
@@ -151,6 +151,7 @@ def hook(module, input, output):
        logging.debug(f"Loaded precomputed rank for {module._name})")

    model._handles = []
+   model._name = ''
    is_hook_registered = False
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
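
The `register`/`deregister` calls in the HRank example above work through standard PyTorch forward hooks, whose handles are what `model._handles` stores in this hunk. A generic sketch of that mechanism (illustrative only, not UPSCALE's exact internals):

import torch.nn as nn

def register(model: nn.Module) -> None:
    # Attach a forward hook to every Conv2d and remember the handles.
    model._handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            def hook(module, input, output):
                pass  # collect feature-map statistics from `output` here
            model._handles.append(module.register_forward_hook(hook))

def deregister(model: nn.Module) -> None:
    # Remove every hook registered above.
    for handle in model._handles:
        handle.remove()
    model._handles = []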
4 changes: 2 additions & 2 deletions src/upscale/pruning/manager.py
@@ -32,7 +32,7 @@ class PruningManager(nn.Module):
>>> manager = PruningManager(net)
>>> pruned_inputs = [('layer1.1.conv1.weight', 5)]
->>> manager.compute([x], pruned_inputs=pruned_inputs)
+>>> _ = manager.compute([x], pruned_inputs=pruned_inputs)
Then, during training, you can use the mock-pruned model. This mock-pruned
model applies masks instead of modifying the model itself.
@@ -43,7 +43,7 @@ class PruningManager(nn.Module):
Finally, actually prune the model. Then run inference using the *original
(now modified, in place) model.
->>> manager.prune()
+>>> _ = manager.prune()
>>> y_pruned = net(x)
Check that both the mocked and pruned outputs match.
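
Assigning the return values to `_` keeps these docstring lines from producing output when run as doctests, since doctest would otherwise expect the repr of the returned objects. Pieced together from the fragments visible in this diff, the documented workflow looks roughly like the sketch below (the construction of `net` and `x`, and the mock-pruned training step, are elided in the hunk and assumed here):

net = resnet18().cuda()                           # assumed setup
x = torch.rand((1, 3, 224, 224), device='cuda')   # assumed setup

manager = PruningManager(net)
pruned_inputs = [('layer1.1.conv1.weight', 5)]
_ = manager.compute([x], pruned_inputs=pruned_inputs)

# ... train with the mock-pruned (masked) model here ...

_ = manager.prune()   # modifies the original model in place
y_pruned = net(x)     # inference with the pruned model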
19 changes: 19 additions & 0 deletions tests/test_api.py
@@ -0,0 +1,19 @@
"""
For licensing see accompanying LICENSE file.
Copyright (C) 2023 Apple Inc. All Rights Reserved.
Tests for the UPSCALE library external-facing API
"""

import torch
import torch.nn as nn
from torchvision.models import resnet18

from upscale import MaskingManager, PruningManager


def test_simple():
    x = torch.rand((1, 3, 224, 224), device='cuda')
    model = resnet18().cuda()
    MaskingManager(model).importance().mask(amount=0.1)
    PruningManager(model).compute([x]).prune()
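
The test allocates its input with `device='cuda'`, so it only runs on a machine with a CUDA GPU. Not part of this commit, but a sketch of a guard one might add so the test is skipped rather than failing elsewhere:

import pytest

@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires a CUDA device')
def test_simple():
    ...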
