Skip to content

Commit

Permalink
fix: pyright type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
nfrasser committed Dec 7, 2023
1 parent 27bf96f commit ae4eb91
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ repos:
- id: black
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.4
rev: v0.1.7
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.335
rev: v1.1.339
hooks:
- id: pyright
additional_dependencies: [cython, httpretty, numpy, pytest]
4 changes: 2 additions & 2 deletions cryosparc/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def __new__(cls, field: Field, data: Data):

return obj

def __array_wrap__(self, obj, **kwargs):
def __array_wrap__(self, obj, context=None, /):
# This prevents wrapping single results such as aggregations from n.sum
# or n.median
return obj[()] if obj.shape == () else super().__array_wrap__(obj, **kwargs)
return obj[()] if obj.shape == () else super().__array_wrap__(obj, context)

def to_fixed(self) -> "Column":
"""
Expand Down
8 changes: 4 additions & 4 deletions cryosparc/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
Callable,
Collection,
Dict,
Generator,
Generic,
Iterable,
List,
Mapping,
MutableMapping,
Expand Down Expand Up @@ -645,7 +645,7 @@ def save(self, file: Union[str, PurePath, IO[bytes]], format: int = DEFAULT_FORM
else:
raise TypeError(f"Invalid dataset save format for {file}: {format}")

def stream(self, compression: Literal["lz4", None] = None) -> Iterable[bytes]:
def stream(self, compression: Literal["lz4", None] = None) -> Generator[bytes, None, None]:
"""
Generate a binary representation for this dataset. Results may be
written to a file or buffer to be sent over the network.
Expand Down Expand Up @@ -824,7 +824,7 @@ def __delitem__(self, key: str):
"""
self.drop_fields([key])

def __contains__(self, key: str) -> bool:
def __contains__(self, key: object) -> bool:
"""
Use the ``in`` operator to check if the given field exists in dataset.
Expand All @@ -834,7 +834,7 @@ def __contains__(self, key: str) -> bool:
Returns:
bool: True if exists, False otherwise.
"""
return self._data.has(key)
return self._data.has(key) if isinstance(key, str) else False

def __eq__(self, other: object):
"""
Expand Down
20 changes: 12 additions & 8 deletions cryosparc/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def kill(self):
)
self.refresh()

def wait_for_status(self, status: Union[JobStatus, Iterable[JobStatus]], timeout: Optional[int] = None) -> str:
def wait_for_status(self, status: Union[JobStatus, Iterable[JobStatus]], *, timeout: Optional[int] = None) -> str:
"""
Wait for a job's status to reach the specified value. Must be one of
the following:
Expand Down Expand Up @@ -207,7 +207,7 @@ def wait_for_status(self, status: Union[JobStatus, Iterable[JobStatus]], timeout
self.refresh()
return self.status

def wait_for_done(self, error_on_incomplete: bool = False, timeout: Optional[int] = None) -> str:
def wait_for_done(self, *, error_on_incomplete: bool = False, timeout: Optional[int] = None) -> str:
"""
Wait until a job reaches status "completed", "killed" or "failed".
Expand All @@ -225,7 +225,7 @@ def wait_for_done(self, error_on_incomplete: bool = False, timeout: Optional[int
), f"Job {self.project_uid}-{self.uid} did not complete (status {status})"
return status

def interact(self, action: str, body: Any = {}, timeout: int = 10, refresh: bool = False) -> Any:
def interact(self, action: str, body: Any = {}, *, timeout: int = 10, refresh: bool = False) -> Any:
"""
Call an interactive action on a waiting interactive job. The possible
actions and expected body depends on the job type.
Expand Down Expand Up @@ -253,7 +253,7 @@ def clear(self):
self.cs.cli.clear_job(self.project_uid, self.uid) # type: ignore
self.refresh()

def set_param(self, name: str, value: Any, refresh: bool = True) -> bool:
def set_param(self, name: str, value: Any, *, refresh: bool = True) -> bool:
"""
Set the given param name on the current job to the given value. Only
works if the job is in "building" status.
Expand Down Expand Up @@ -283,7 +283,7 @@ def set_param(self, name: str, value: Any, refresh: bool = True) -> bool:
self.refresh()
return result

def connect(self, target_input: str, source_job_uid: str, source_output: str, refresh: bool = True) -> bool:
def connect(self, target_input: str, source_job_uid: str, source_output: str, *, refresh: bool = True) -> bool:
"""
Connect the given input for this job to an output with given job UID and
name.
Expand Down Expand Up @@ -321,7 +321,7 @@ def connect(self, target_input: str, source_job_uid: str, source_output: str, re
self.refresh()
return result

def disconnect(self, target_input: str, connection_idx: Optional[int] = None, refresh: bool = True):
def disconnect(self, target_input: str, connection_idx: Optional[int] = None, *, refresh: bool = True):
"""
Clear the given job input group.
Expand Down Expand Up @@ -471,6 +471,7 @@ def log_plot(
figure: Union[str, PurePath, IO[bytes], Any],
text: str,
formats: Iterable[ImageFormat] = ["png", "pdf"],
*,
raw_data: Union[str, bytes, Literal[None]] = None,
raw_data_file: Union[str, PurePath, IO[bytes], Literal[None]] = None,
raw_data_format: Optional[TextFormat] = None,
Expand Down Expand Up @@ -1125,6 +1126,7 @@ def add_output(
slots: List[SlotSpec] = ...,
passthrough: Optional[str] = ...,
title: Optional[str] = None,
alloc: Literal[None] = None,
) -> str:
...

Expand Down Expand Up @@ -1243,11 +1245,12 @@ def connect(
target_input: str,
source_job_uid: str,
source_output: str,
*,
slots: List[SlotSpec] = [],
title: str = "",
desc: str = "",
refresh: bool = True,
):
) -> bool:
"""
Connect the given input for this job to an output with given job UID and
name. If this input does not exist, it will be added with the given
Expand Down Expand Up @@ -1302,6 +1305,7 @@ def connect(
raise
if refresh:
self.refresh()
return True

def alloc_output(self, name: str, alloc: Union[int, "ArrayLike", Dataset] = 0) -> Dataset:
"""
Expand Down Expand Up @@ -1362,7 +1366,7 @@ def alloc_output(self, name: str, alloc: Union[int, "ArrayLike", Dataset] = 0) -
else:
return Dataset({"uid": alloc}).add_fields(expected_fields)

def save_output(self, name: str, dataset: Dataset, refresh: bool = True):
def save_output(self, name: str, dataset: Dataset, *, refresh: bool = True):
"""
Save output dataset to external job.
Expand Down
7 changes: 5 additions & 2 deletions cryosparc/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
]
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from typing_extensions import Literal, TypedDict

if TYPE_CHECKING:
from typing_extensions import Self # not present in typing-extensions=3.7

# Database document
D = TypeVar("D")

Expand Down Expand Up @@ -881,6 +884,6 @@ def doc(self) -> D:
return self._doc

@abstractmethod
def refresh(self):
def refresh(self) -> Self:
# Must be implemented in subclasses
return self
3 changes: 1 addition & 2 deletions cryosparc/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

if TYPE_CHECKING:
from typing_extensions import Self # not present in typing-extensions=3.7
from .core import MemoryView


class AsyncBinaryIO(Protocol):
Expand Down Expand Up @@ -154,7 +153,7 @@ async def from_async_iterator(cls, iterator: Union[AsyncIterator[bytes], AsyncGe
return await cls.from_async_stream(AsyncBinaryIteratorIO(iterator))

@abstractmethod
def stream(self) -> Generator[Union[bytes, memoryview, "MemoryView"], None, None]:
def stream(self) -> Generator[bytes, None, None]:
...

async def astream(self):
Expand Down

0 comments on commit ae4eb91

Please sign in to comment.