Skip to content

Commit

Permalink
replace uses of np.ndarray with npt.NDArray
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/opacus#681

X-link: pytorch/captum#1389

X-link: pytorch/botorch#2586

X-link: pytorch/audio#3846

This replaces uses of `numpy.ndarray` in type annotations with `numpy.typing.NDArray`. In Numpy-1.24.0+ `numpy.ndarray` is annotated as generic type. Without template parameters it triggers static analysis errors:
```counterexample
Generic type `ndarray` expects 2 type parameters.
```
`numpy.typing.NDArray` is an alias that provides default template parameters.

Reviewed By: ryanthomasjohnson

Differential Revision: D64619891

fbshipit-source-id: dffc096b1ce90d11e73d475f0bbcb8867ed9ef01
  • Loading branch information
igorsugak authored and facebook-github-bot committed Oct 19, 2024
1 parent e737b8f commit 06e35fc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
3 changes: 2 additions & 1 deletion torchbenchmark/models/pytorch_unet/pytorch_unet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from PIL import Image
Expand Down Expand Up @@ -102,7 +103,7 @@ def _generate_name(fn):
return args.output or list(map(_generate_name, args.input))


def mask_to_image(mask: np.ndarray):
def mask_to_image(mask: npt.NDArray):
if mask.ndim == 2:
return Image.fromarray((mask * 255).astype(np.uint8))
elif mask.ndim == 3:
Expand Down
13 changes: 7 additions & 6 deletions torchbenchmark/models/sam/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Tuple

import numpy as np
import numpy.typing as npt
import torch

from .sam import Sam
Expand All @@ -32,7 +33,7 @@ def __init__(

def set_image(
self,
image: np.ndarray,
image: npt.NDArray,
image_format: str = "RGB",
) -> None:
"""
Expand Down Expand Up @@ -92,13 +93,13 @@ def set_torch_image(

def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
point_coords: Optional[npt.NDArray] = None,
point_labels: Optional[npt.NDArray] = None,
box: Optional[npt.NDArray] = None,
mask_input: Optional[npt.NDArray] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
"""
Predict masks for the given input prompts, using the currently set image.
Expand Down
11 changes: 6 additions & 5 deletions torchbenchmark/models/sam/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Tuple

import numpy as np
import numpy.typing as npt
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
Expand All @@ -23,7 +24,7 @@ class ResizeLongestSide:
def __init__(self, target_length: int) -> None:
self.target_length = target_length

def apply_image(self, image: np.ndarray) -> np.ndarray:
def apply_image(self, image: npt.NDArray) -> npt.NDArray:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
Expand All @@ -33,8 +34,8 @@ def apply_image(self, image: np.ndarray) -> np.ndarray:
return np.array(resize(to_pil_image(image), target_size))

def apply_coords(
self, coords: np.ndarray, original_size: Tuple[int, ...]
) -> np.ndarray:
self, coords: npt.NDArray, original_size: Tuple[int, ...]
) -> npt.NDArray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
Expand All @@ -49,8 +50,8 @@ def apply_coords(
return coords

def apply_boxes(
self, boxes: np.ndarray, original_size: Tuple[int, ...]
) -> np.ndarray:
self, boxes: npt.NDArray, original_size: Tuple[int, ...]
) -> npt.NDArray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
Expand Down

0 comments on commit 06e35fc

Please sign in to comment.