Skip to content

Commit

Permalink
Merge pull request #160 from ArgoHA/master
Browse files Browse the repository at this point in the history
ruff 100 formatting
  • Loading branch information
Peterande authored Feb 12, 2025
2 parents 8944d61 + 2c0318b commit 36ca0d3
Show file tree
Hide file tree
Showing 86 changed files with 3,495 additions and 2,478 deletions.
17 changes: 10 additions & 7 deletions reference/convert_weight.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import torch
import os
import argparse
import os

import torch


def save_only_ema_weights(checkpoint_file):
"""Extract and save only the EMA weights."""
checkpoint = torch.load(checkpoint_file, map_location='cpu')
checkpoint = torch.load(checkpoint_file, map_location="cpu")

weights = {}
if 'ema' in checkpoint:
weights['model'] = checkpoint['ema']['module']
if "ema" in checkpoint:
weights["model"] = checkpoint["ema"]["module"]
else:
raise ValueError("The checkpoint does not contain 'ema'.")

Expand All @@ -19,9 +21,10 @@ def save_only_ema_weights(checkpoint_file):
torch.save(weights, output_file)
print(f"EMA weights saved to {output_file}")

if __name__ == '__main__':

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract and save only EMA weights.")
parser.add_argument('checkpoint_file', type=str, help="Path to the input checkpoint file.")
parser.add_argument("checkpoint_file", type=str, help="Path to the input checkpoint file.")

args = parser.parse_args()
save_only_ema_weights(args.checkpoint_file)
5 changes: 1 addition & 4 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,4 @@
"""

# for register purpose
from . import optim
from . import data
from . import nn
from . import zoo
from . import data, nn, optim, zoo
4 changes: 2 additions & 2 deletions src/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from .workspace import GLOBAL_CONFIG, register, create
from .yaml_utils import *
from ._config import BaseConfig
from .workspace import GLOBAL_CONFIG, create, register
from .yaml_config import YAMLConfig
from .yaml_utils import *
Loading

0 comments on commit 36ca0d3

Please sign in to comment.