Skip to content

Commit

Permalink
add TTA support from geo-inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Turgeon-Pelchat committed Oct 4, 2024
1 parent e8d411c commit ccba099
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
4 changes: 3 additions & 1 deletion config/inference/default_binary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ inference:
workers: 0
prep_data_only: False
heatmap_threshold: 0.3

flip: False
rotate: True

# GPU parameters
gpu: ${training.num_gpus}
max_used_perc: ${training.max_used_perc} # If GPU's usage exceeds this percentage, it will be ignored
Expand Down
2 changes: 2 additions & 0 deletions config/inference/default_multiclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ inference:
workers: 0
prep_data_only: False
heatmap_threshold: 0.3
flip: False
rotate: True

# GPU parameters
gpu: ${training.num_gpus}
Expand Down
5 changes: 5 additions & 0 deletions docs/source/mode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ will be found in :ref:`configurationdefaultparam` under ``inference`` and this c
enhancement is applied and equalized images save to disk.
- ``heatmap_threshold`` (float)
Prediction probability Threshold (fraction of 1) to use. Default is ``0.3``.
- ``flip`` (bool)
If True, perform horizontal and vertical flips during inference.
- ``rotate`` (bool)
If True, perform 90 degree rotation at inference.

- ``gpu`` (int)
Number of gpus to use at inference.
- ``max_used_perc`` (int)
Expand Down
5 changes: 5 additions & 0 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def main(params:Union[DictConfig, Dict]):
input_stac_item = get_key_def('input_stac_item', params['inference'], expected_type=str, to_path=True,
validate_path_exists=True)
vectorize = get_key_def('ras2vec', params['inference'], expected_type=bool, default=False)
transform_flip = get_key_def('flip', params['inference'], expected_type=bool, default=False)
transform_rotate = get_key_def('rotate', params['inference'], expected_type=bool, default=False)

if raw_data_csv and input_stac_item:
raise ValueError(f"Input imagery should be either a csv of stac item. Got inputs from both \"raw_data_csv\" "
Expand Down Expand Up @@ -106,6 +108,9 @@ def main(params:Union[DictConfig, Dict]):
device=device_str,
gpu_id=gpu_index,
prediction_threshold=prediction_threshold,
transformer=True,
transformer_flip=transform_flip,
transformer_rotate=transform_rotate,
)

# LOOP THROUGH LIST OF INPUT IMAGES
Expand Down

0 comments on commit ccba099

Please sign in to comment.