-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
72 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# ruff: noqa: E402 | ||
""" | ||
This is a tool for dumping the state dict of a dummy model. | ||
Purpose: | ||
When adding/testing model detection or model parameter detection code, | ||
it is useful to see the effects a single parameter has on the state dict of a | ||
model. Since there aren't pretrained models for every possible parameter | ||
configuration, this script can be used to generate a dummy model with the given | ||
parameters. | ||
Usage: | ||
To use this script, you need to edit the `create_dummy` function below. Edit | ||
the function to make it return a model with your desired parameters. As always, | ||
VSCode is the recommended IDE for this task. | ||
After you edited the function, run this script, and it will dump the state dict | ||
of the dummy model to `dump.yml`. | ||
python scripts/dump_dummy.py | ||
For more detail on the dump itself, see the docs of `dump_state_dict.py`. | ||
""" | ||
|
||
|
||
import inspect | ||
import os | ||
import sys | ||
from textwrap import dedent | ||
|
||
import torch | ||
|
||
# This hack is necessary to make our module import | ||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) | ||
|
||
from dump_state_dict import dump | ||
|
||
from spandrel.architectures import SCUNet | ||
|
||
|
||
def create_dummy() -> torch.nn.Module: | ||
"""Edit this function""" | ||
return SCUNet.SCUNet() | ||
|
||
|
||
if __name__ == "__main__": | ||
net = create_dummy() | ||
state = net.state_dict() | ||
|
||
# get source code expression of network | ||
source = inspect.getsource(create_dummy) | ||
source = "\n".join(source.split("\n")[1:]) # remove "def create_dummy(): | ||
source = dedent(source) | ||
if source.startswith("return "): | ||
source = source[7:] | ||
|
||
dump(state, source) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters