diff --git a/optimum_benchmark/benchmarks/inference/benchmark.py b/optimum_benchmark/benchmarks/inference/benchmark.py
index 86e48782..d9420b0b 100644
--- a/optimum_benchmark/benchmarks/inference/benchmark.py
+++ b/optimum_benchmark/benchmarks/inference/benchmark.py
@@ -11,6 +11,7 @@
 from ..base import Benchmark
 from ..report import BenchmarkMeasurements, BenchmarkReport
 from .config import InferenceConfig
+from .inputs_utils import extract_text_generation_inputs
 
 if is_torch_distributed_available():
     import torch.distributed
@@ -80,9 +81,9 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
 
         if backend.config.task in TEXT_GENERATION_TASKS:
             LOGGER.info("\t+ Generating and preparing Text Generation inputs")
-            self.text_generation_inputs = self.input_generator()
-            self.text_generation_inputs = backend.prepare_inputs(self.text_generation_inputs)
-            self.text_generation_inputs = {"input_ids": self.text_generation_inputs["input_ids"]}
+            self.forward_inputs = self.input_generator()
+            self.forward_inputs = backend.prepare_inputs(self.forward_inputs)
+            self.generate_inputs = extract_text_generation_inputs(self.forward_inputs)
             LOGGER.info("\t+ Updating Text Generation kwargs with default values")
             self.config.generate_kwargs = {**TEXT_GENERATION_KWARGS, **self.config.generate_kwargs}
             LOGGER.info("\t+ Initializing Text Generation report")
@@ -90,7 +91,9 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
 
         elif backend.config.task in IMAGE_DIFFUSION_TASKS:
             LOGGER.info("\t+ Generating Image Diffusion inputs")
-            self.image_diffusion_inputs = self.input_generator()
+            self.call_inputs = self.input_generator()
+            self.call_inputs = backend.prepare_inputs(self.call_inputs)
+            self.call_inputs = {"prompt": self.call_inputs["prompt"]}
             LOGGER.info("\t+ Updating Image Diffusion kwargs with default values")
             self.config.call_kwargs = {**IMAGE_DIFFUSION_KWARGS, **self.config.call_kwargs}
             LOGGER.info("\t+ Initializing Image Diffusion report")
@@ -98,8 +101,8 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
 
         else:
             LOGGER.info("\t+ Generating and preparing Inference inputs")
-            self.inference_inputs = self.input_generator()
-            self.inference_inputs = backend.prepare_inputs(self.inference_inputs)
+            self.forward_inputs = self.input_generator()
+            self.forward_inputs = backend.prepare_inputs(self.forward_inputs)
             LOGGER.info("\t+ Initializing Inference report")
             self.report = InferenceReport(forward=BenchmarkMeasurements())
 
@@ -115,11 +118,11 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
         LOGGER.info("\t+ Warming up backend for Inference")
         for _ in range(self.config.warmup_runs):
             if backend.config.task in TEXT_GENERATION_TASKS:
-                _ = backend.generate(self.text_generation_inputs, {"max_new_tokens": 2, "min_new_tokens": 2})
+                _ = backend.generate(self.generate_inputs, {"max_new_tokens": 2, "min_new_tokens": 2})
             elif backend.config.task in IMAGE_DIFFUSION_TASKS:
-                _ = backend.call(self.image_diffusion_inputs, {"num_inference_steps": 2})
+                _ = backend.call(self.call_inputs, {"num_inference_steps": 2})
             else:
-                _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+                _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         if self.config.memory:
             LOGGER.info("\t+ Creating inference memory tracker")
@@ -166,13 +169,13 @@ def run_text_generation_memory_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.forward(self.text_generation_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.prefill.memory = self.memory_tracker.get_max_memory()
 
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.generate(self.text_generation_inputs, self.config.generate_kwargs)
+            _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
 
         self.report.decode.memory = self.memory_tracker.get_max_memory()
 
@@ -180,7 +183,7 @@ def run_image_diffusion_memory_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.call(self.image_diffusion_inputs, self.config.call_kwargs)
+            _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
         self.report.call.memory = self.memory_tracker.get_max_memory()
 
@@ -188,7 +191,7 @@ def run_inference_memory_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
-            _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.memory = self.memory_tracker.get_max_memory()
 
@@ -198,7 +201,7 @@ def run_text_generation_latency_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.forward(self.text_generation_inputs, self.config.forward_kwargs)
+                _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         forward_latency = self.latency_tracker.get_latency()
         forward_latency.log(prefix="forward")
@@ -210,7 +213,7 @@ def run_text_generation_latency_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.generate(self.text_generation_inputs, self.config.generate_kwargs)
+                _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
 
         generate_latency = self.latency_tracker.get_latency()
         generate_latency.log(prefix="generate")
@@ -224,7 +227,7 @@ def run_image_diffusion_latency_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.call(self.image_diffusion_inputs, self.config.call_kwargs)
+                _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
         self.report.call.latency = self.latency_tracker.get_latency()
         self.report.call.throughput = Throughput.from_latency(
@@ -236,7 +239,7 @@ def run_latency_inference_tracking(self, backend: Backend):
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
-                _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+                _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.latency = self.latency_tracker.get_latency()
         self.report.forward.throughput = Throughput.from_latency(
@@ -248,7 +251,7 @@ def run_text_generation_energy_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.forward(self.text_generation_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.prefill.energy = self.energy_tracker.get_energy()
         self.report.prefill.efficiency = Efficiency.from_energy(
@@ -257,7 +260,7 @@ def run_text_generation_energy_tracking(self, backend: Backend):
 
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.generate(self.text_generation_inputs, self.config.generate_kwargs)
+            _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
 
         self.report.decode.energy = self.energy_tracker.get_energy() - self.report.prefill.energy
         self.report.decode.efficiency = Efficiency.from_energy(
@@ -268,7 +271,7 @@ def run_image_diffusion_energy_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.call(self.image_diffusion_inputs, self.config.call_kwargs)
+            _ = backend.call(self.call_inputs, self.config.call_kwargs)
 
         self.report.call.energy = self.energy_tracker.get_energy()
         self.report.call.efficiency = Efficiency.from_energy(
@@ -279,7 +282,7 @@ def run_inference_energy_tracking(self, backend: Backend):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
-            _ = backend.forward(self.inference_inputs, self.config.forward_kwargs)
+            _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
 
         self.report.forward.energy = self.energy_tracker.get_energy()
         self.report.forward.efficiency = Efficiency.from_energy(
diff --git a/optimum_benchmark/benchmarks/inference/inputs_utils.py b/optimum_benchmark/benchmarks/inference/inputs_utils.py
new file mode 100644
index 00000000..f4dc5bd1
--- /dev/null
+++ b/optimum_benchmark/benchmarks/inference/inputs_utils.py
@@ -0,0 +1,17 @@
+def extract_text_generation_inputs(inputs):
+    if "pixel_values" in inputs:
+        # image input (image-to-text models)
+        text_generation_inputs = {"inputs": inputs["pixel_values"]}
+    elif "input_values" in inputs:
+        # raw audio waveform input (e.g. Wav2Vec2-style models)
+        text_generation_inputs = {"inputs": inputs["input_values"]}
+    elif "input_features" in inputs:
+        # extracted audio features, e.g. log-mel spectrogram (Whisper-style models)
+        text_generation_inputs = {"inputs": inputs["input_features"]}
+    elif "input_ids" in inputs:
+        # text input (token ids)
+        text_generation_inputs = {"inputs": inputs["input_ids"]}
+    else:
+        raise ValueError("Could not find any valid text generation inputs.")
+
+    return text_generation_inputs
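
For reference, a minimal sketch of how the new `extract_text_generation_inputs` helper behaves. The input dicts below are illustrative placeholders, not the tensors the benchmark's input generator actually produces:

```python
from optimum_benchmark.benchmarks.inference.inputs_utils import extract_text_generation_inputs

# Text task: only the main tensor is kept (attention_mask etc. are dropped)
# and it is re-keyed under "inputs", the generic name generate() accepts.
text_inputs = {"input_ids": [[101, 2023, 102]], "attention_mask": [[1, 1, 1]]}
print(extract_text_generation_inputs(text_inputs))
# {'inputs': [[101, 2023, 102]]}

# Image-to-text task: pixel_values takes precedence over input_ids.
image_inputs = {"pixel_values": [[[0.1, 0.2]]], "input_ids": [[101]]}
print(extract_text_generation_inputs(image_inputs))
# {'inputs': [[[0.1, 0.2]]]}
```

This is why `run` now keeps two views of the generated inputs: `self.forward_inputs`, the full prepared dict passed to `backend.forward` (prefill measurements), and `self.generate_inputs`, the single main tensor passed to `backend.generate` (decode measurements).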