Skip to content

Commit

Permalink
fix ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
Seun-Ajayi committed Jul 11, 2024
1 parent f646a8e commit fe4915f
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
3 changes: 2 additions & 1 deletion benchmark/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
XLMRobertaModel,
)

import os, sys
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.mlx_transformers.models import BertForMaskedLM as MlxBertForMaskedLM
Expand Down
2 changes: 1 addition & 1 deletion tests/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def test_forward(self) -> None:
inputs = {key: mx.array(v) for key, v in inputs.items()}
outputs = self.model(**inputs, use_cache=True)

assert type(outputs.logits) == mx.array
assert type(outputs.logits) is mx.array
2 changes: 1 addition & 1 deletion tests/test_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def test_forward(self) -> None:
inputs = {key: mx.array(v) for key, v in inputs.items()}
outputs = self.model(**inputs, use_cache=True)

assert type(outputs.logits) == mx.array
assert type(outputs.logits) is mx.array
2 changes: 1 addition & 1 deletion tests/test_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def test_forward(self) -> None:
inputs = {key: mx.array(v) for key, v in inputs.items()}
outputs = self.model(**inputs, use_cache=True)

assert type(outputs.logits) == mx.array
assert type(outputs.logits) is mx.array

0 comments on commit fe4915f

Please sign in to comment.