Skip to content

Commit

Permalink
update config args and add to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 14, 2023
1 parent ff63abb commit 14f895f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.
27 changes: 21 additions & 6 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,34 @@ def __init__(self,

print("Reading config file: {}".format(config_path))
self.config = utilities.read_config(config_path)
self.config["num_classes"] = num_classes
# If num classes is specified, overwrite config
if not num_classes == 1:
warnings.warn(
"Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of config_args. Use deepforest.main(config_args={'num_classes':value})"
)

# Update config with user supplied arguments
if config_args:
for key, value in config_args.items():
if key not in self.config.keys():
raise ValueError(
"Config argument {} not found in config file".format(key))
if type(value) == dict:
for subkey, subvalue in value.items():
print("setting config {} to {}").format(subkey, subvalue)
self.config[key][subkey] = subvalue
else:
print("setting config {} to {}".format(key, value))
self.config[key] = value

self.model = model

# release version id to flag if release is being used
self.__release_version__ = None

self.config["num_classes"] = num_classes
self.create_model()

# If num classes is specified, overwrite config
if not num_classes == 1:
warnings.warn(
"Directly specifying the num_classes arg in deepforest.main will be deprecated in 2.0 in favor of config_args. Use deepforest.main(config_args={'num_classes':value})"
)

# Metrics
self.iou_metric = IntersectionOverUnion(
Expand Down
18 changes: 18 additions & 0 deletions docs/ConfigurationFile.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@ validation:
iou_threshold: 0.4
val_accuracy_interval: 20
```
## Passing config arguments at runtime using a dict

It can often be useful to pass config args directly to a model instead of editing the config file. By using a dict with that matches with the config keys, main.deepforest will update the config after reading from file.

```
m = main.deepforest()
assert not m.config["num_classes"] == 2
m = main.deepforest(config_args={"num_classes":2}, label_dict={"Alive":0,"Dead":1})
assert m.config["num_classes"] == 2
# These call also be nested for train and val arguments
m = main.deepforest()
assert not m.config["train"]["epochs"] == 7
m = main.deepforest(config_args={"train":{"epochs":7}})
assert m.config["train"]["epochs"] == 7
```

## Dataloaders

Expand Down
16 changes: 15 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,18 @@ def test_iou_metric(m):
results = m.trainer.validate(m)
keys = ['val_classification', 'val_bbox_regression', 'iou', 'iou/cl_0']
for x in keys:
assert x in list(results[0].keys())
assert x in list(results[0].keys())

def test_config_args():
m = main.deepforest()
assert not m.config["num_classes"] == 2

m = main.deepforest(config_args={"num_classes":2}, label_dict={"Alive":0,"Dead":1})
assert m.config["num_classes"] == 2

# These call also be nested for train and val arguments
m = main.deepforest()
assert not m.config["train"]["epochs"] == 7

m = main.deepforest(config_args={"train":{"epochs":7}})
assert m.config["train"]["epochs"] == 7

0 comments on commit 14f895f

Please sign in to comment.