diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index d5fa1fc068..54fdcea7cd 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -8,7 +8,7 @@ import warnings from abc import ABCMeta, abstractmethod from collections.abc import MutableMapping -from typing import Any, Callable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union import cv2 import numpy as np @@ -456,6 +456,8 @@ def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], def add_image(self, name: str, image: np.ndarray, + boxes: Optional[Dict] = None, + masks: Optional[Dict] = None, step: int = 0, **kwargs) -> None: """Record the image to wandb. @@ -467,7 +469,7 @@ def add_image(self, step (int): Useless parameter. Wandb does not need this parameter. Defaults to 0. """ - image = self._wandb.Image(image) + image = self._wandb.Image(image, boxes=boxes, masks=masks) self._wandb.log({name: image}, commit=self._commit) @force_init_env @@ -507,7 +509,7 @@ def add_scalars(self, def close(self) -> None: """close an opened wandb object.""" if hasattr(self, '_wandb'): - self._wandb.join() + self._wandb.finish() @VISBACKENDS.register_module()