-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
version bump + fixes + hrank example
- Loading branch information
Showing
6 changed files
with
53 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
For licensing see accompanying LICENSE file. | ||
Copyright (C) 2023 Apple Inc. All Rights Reserved. | ||
How to customize heuristic used for pruning, in UPSCALE | ||
""" | ||
|
||
import torch, torchvision | ||
from upscale import MaskingManager, PruningManager | ||
from upscale.masking.importance import HRank, LAMP | ||
|
||
|
||
# for most heuristics, simply pass the heuristic to `.importance(...)` | ||
x = torch.rand((1, 3, 224, 224), device='cuda') | ||
model = torchvision.models.get_model('resnet18', pretrained=True).cuda() # get any pytorch model | ||
MaskingManager(model).importance(LAMP()).mask() | ||
PruningManager(model).compute([x]).prune() | ||
|
||
# for HRank, we need to run several forward passes to collect feature map | ||
# statistics | ||
|
||
heuristic = HRank() | ||
heuristic.register(model) | ||
for _ in range(10): | ||
model(torch.rand((1, 3, 224, 224), device='cuda')) | ||
heuristic.deregister(model) | ||
|
||
MaskingManager(model).importance(heuristic).mask() | ||
PruningManager(model).compute([x]).prune() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
""" | ||
For licensing see accompanying LICENSE file. | ||
Copyright (C) 2023 Apple Inc. All Rights Reserved. | ||
Tests for the UPSCALE library external-facing API | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torchvision.models import resnet18 | ||
|
||
from upscale import MaskingManager, PruningManager | ||
|
||
|
||
def test_simple(): | ||
x = torch.rand((1, 3, 224, 224), device='cuda') | ||
model = resnet18().cuda() | ||
MaskingManager(model).importance().mask(amount=0.1) | ||
PruningManager(model).compute([x]).prune() |