diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 0555578f..b25ef53d 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -6,10 +6,9 @@ import numpy as np import torch -from diffusers.callbacks import PipelineCallback from rich.console import Console from rich.markdown import Markdown -from transformers import LogitsProcessor, TrainerCallback +from transformers import TrainerCallback CONSOLE = Console() LOGGER = getLogger("latency") @@ -283,7 +282,7 @@ def get_latency(self) -> Latency: return Latency.from_values(latencies, unit=LATENCY_UNIT) -class PerTokenLatencySessionTrackerLogitsProcessor(LogitsProcessor): +class PerTokenLatencySessionTrackerLogitsProcessor: def __init__(self, device: str, backend: str): self.device = device self.backend = backend @@ -428,7 +427,7 @@ def get_per_token_latency(self) -> Latency: return Latency.from_values(latencies, unit=LATENCY_UNIT) -class PerStepLatencySessionTrackerPipelineCallback(PipelineCallback): +class PerStepLatencySessionTrackerPipelineCallback: tensor_inputs = [] def __init__(self, device: str, backend: str):