Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 9, 2024
1 parent 501c42d commit bb59538
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions optimum_benchmark/trackers/latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

import numpy as np
import torch
from diffusers.callbacks import PipelineCallback
from rich.console import Console
from rich.markdown import Markdown
from transformers import LogitsProcessor, TrainerCallback
from transformers import TrainerCallback

CONSOLE = Console()
LOGGER = getLogger("latency")
Expand Down Expand Up @@ -283,7 +282,7 @@ def get_latency(self) -> Latency:
return Latency.from_values(latencies, unit=LATENCY_UNIT)


class PerTokenLatencySessionTrackerLogitsProcessor(LogitsProcessor):
class PerTokenLatencySessionTrackerLogitsProcessor:
def __init__(self, device: str, backend: str):
self.device = device
self.backend = backend
Expand Down Expand Up @@ -428,7 +427,7 @@ def get_per_token_latency(self) -> Latency:
return Latency.from_values(latencies, unit=LATENCY_UNIT)


class PerStepLatencySessionTrackerPipelineCallback(PipelineCallback):
class PerStepLatencySessionTrackerPipelineCallback:
tensor_inputs = []

def __init__(self, device: str, backend: str):
Expand Down

0 comments on commit bb59538

Please sign in to comment.