Skip to content

Commit

Permalink
models: Fix runtime tracking in multi-sample prediction.
Browse files Browse the repository at this point in the history
  • Loading branch information
l-laura committed Jul 13, 2023
1 parent 988bfd0 commit 3061eac
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions code/models/MLPAutoEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ def predict(self, data: EncodedSampleGenerator, **kwargs) -> SampleGenerator:
prediction = self.model_instance.predict(encoded_sample)

if isinstance(sample, list):
sum_processing_time += time.process_time_ns() - start_time_ref
# Handle the prediction for multi-sample encoding.
for i, sample in enumerate(sample):
start_time_ref = time.process_time_ns()
sample[PredictionField.MODEL_NAME] = self.model_name
sample[PredictionField.OUTPUT_DISTANCE] = prediction[i]
sum_processing_time += time.process_time_ns() - start_time_ref
Expand Down
5 changes: 5 additions & 0 deletions code/models/RandomForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,18 @@ def predict(self, data: EncodedSampleGenerator, **kwargs) ->SampleGenerator:
for sample, encoded_sample in data:
start_time_ref = time.process_time_ns()
prediction = self.model_instance.predict(encoded_sample)

if isinstance(sample, list):
sum_processing_time += time.process_time_ns() - start_time_ref
# Handle the prediction for multi-sample encoding.
for i, sample in enumerate(sample):
start_time_ref = time.process_time_ns()
sample[PredictionField.MODEL_NAME] = self.model_name
sample[PredictionField.OUTPUT_BINARY] = prediction[i]
sum_processing_time += time.process_time_ns() - start_time_ref
sum_samples += 1
yield sample

else:
sample[PredictionField.MODEL_NAME] = self.model_name
sample[PredictionField.OUTPUT_BINARY] = prediction[0]
Expand Down

0 comments on commit 3061eac

Please sign in to comment.