From 3061eacd066270a2b32ee0381a1e406c66f8698d Mon Sep 17 00:00:00 2001 From: Laura Lahesoo Date: Thu, 13 Jul 2023 15:43:39 +0200 Subject: [PATCH] models: Fix runtime tracking in multi-sample prediction. --- code/models/MLPAutoEncoder.py | 2 ++ code/models/RandomForest.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/code/models/MLPAutoEncoder.py b/code/models/MLPAutoEncoder.py index 3babffc..27f7f6c 100644 --- a/code/models/MLPAutoEncoder.py +++ b/code/models/MLPAutoEncoder.py @@ -112,8 +112,10 @@ def predict(self, data: EncodedSampleGenerator, **kwargs) -> SampleGenerator: prediction = self.model_instance.predict(encoded_sample) if isinstance(sample, list): + sum_processing_time += time.process_time_ns() - start_time_ref # Handle the prediction for multi-sample encoding. for i, sample in enumerate(sample): + start_time_ref = time.process_time_ns() sample[PredictionField.MODEL_NAME] = self.model_name sample[PredictionField.OUTPUT_DISTANCE] = prediction[i] sum_processing_time += time.process_time_ns() - start_time_ref diff --git a/code/models/RandomForest.py b/code/models/RandomForest.py index 81c3827..4c241a8 100644 --- a/code/models/RandomForest.py +++ b/code/models/RandomForest.py @@ -94,13 +94,18 @@ def predict(self, data: EncodedSampleGenerator, **kwargs) ->SampleGenerator: for sample, encoded_sample in data: start_time_ref = time.process_time_ns() prediction = self.model_instance.predict(encoded_sample) + if isinstance(sample, list): + sum_processing_time += time.process_time_ns() - start_time_ref + # Handle the prediction for multi-sample encoding. for i, sample in enumerate(sample): + start_time_ref = time.process_time_ns() sample[PredictionField.MODEL_NAME] = self.model_name sample[PredictionField.OUTPUT_BINARY] = prediction[i] sum_processing_time += time.process_time_ns() - start_time_ref sum_samples += 1 yield sample + else: sample[PredictionField.MODEL_NAME] = self.model_name sample[PredictionField.OUTPUT_BINARY] = prediction[0]