Skip to content

Commit

Permalink
Fixed model.py and train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
David Hovstadius committed Oct 28, 2024
1 parent 75465b2 commit f88d897
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
6 changes: 2 additions & 4 deletions client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
helper = get_helper(HELPER_MODULE)

def compile_model():
yaml_file = glob.glob("yolov8*.yaml")

if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
return YOLO(yaml_file[0]).to(device)
return YOLO('model.yaml').to(device)


def load_parameters(model_path):
Expand All @@ -33,7 +31,7 @@ def load_parameters(model_path):
model = compile_model()
params_dict = zip(model.state_dict().keys(), parameters_np)
state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
model.load_state_dict(state_dict, strict=True)
model.load_state_dict(state_dict, strict=False)
with tempfile.NamedTemporaryFile(suffix='.pt') as tmp_file:
torch.save(model,tmp_file.name)
model = YOLO(tmp_file.name)
Expand Down
6 changes: 3 additions & 3 deletions client/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from data import get_train_size
import yaml

def train(in_model_path, out_model_path, data_yaml_path='data.yaml', epochs=10,batch_size=16):
def train(in_model_path, out_model_path, epochs=10, data_yaml_path='data.yaml', batch_size=16):
"""Complete a model update using YOLOv8.
Load model parameters from in_model_path (managed by the FEDn client),
Expand Down Expand Up @@ -41,7 +41,7 @@ def train(in_model_path, out_model_path, data_yaml_path='data.yaml', epochs=10,b
epochs = config.get('local_epochs', epochs)
batch_size = config.get('batch_size', batch_size)
else:
print(f"Config file not found at {config_path}. Using default epochs ({epochs}) and batch size ({batch_size}).")
print(f"Client config file not found at {config_path}. Using default epochs ({epochs}) and batch size ({batch_size}).")

# Train the model and remove the unnecessary files
with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -60,7 +60,7 @@ def train(in_model_path, out_model_path, data_yaml_path='data.yaml', epochs=10,b

if __name__ == "__main__":
if len(sys.argv) < 3:
print("Usage: python train.py <in_model_path> <out_model_path> [data_yaml_path] [epochs]")
print("Usage: python train.py <in_model_path> <out_model_path> [epochs]")
sys.exit(1)

in_model_path = sys.argv[1]
Expand Down

0 comments on commit f88d897

Please sign in to comment.