diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py index 18bb65a5525e..899a878bef89 100644 --- a/torch_geometric/data/dataset.py +++ b/torch_geometric/data/dataset.py @@ -1,4 +1,5 @@ import copy +import os import os.path as osp import re import sys @@ -256,16 +257,28 @@ def __getitem__( In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or bool, will return a subset of the dataset at the specified indices.""" + + # We can't hook __getitem__ item as it is a special method + # https://docs.python.org/3/reference/datamodel.html#special-lookup + if os.environ.get('NVIDIA_NVTX_RANGES', + "0") == "1" and torch.cuda.is_available(): + nvtx_handle = torch.cuda.nvtx.range_start( + f"[Dataset] __getitem__ for {self}") + if (isinstance(idx, (int, np.integer)) or (isinstance(idx, Tensor) and idx.dim() == 0) or (isinstance(idx, np.ndarray) and np.isscalar(idx))): data = self.get(self.indices()[idx]) data = data if self.transform is None else self.transform(data) - return data - else: - return self.index_select(idx) + data = self.index_select(idx) + + if os.environ.get('NVIDIA_NVTX_RANGES', + "0") == "1" and torch.cuda.is_available(): + torch.cuda.nvtx.range_end(nvtx_handle) + + return data def index_select(self, idx: IndexType) -> 'Dataset': r"""Creates a subset of the dataset from specified indices :obj:`idx`. diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 494a380023e2..682c24989fb7 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -21,6 +21,46 @@ from .prefetch import PrefetchLoader from .mixin import AffinityMixin +import os +import torch +import inspect + + +def hook_nvtx_collate_fn(class_to_hook): + original_init = class_to_hook.__init__ + + def post_hooked_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + + if not hasattr(self, "collate_fn"): + return + + # Checking if a subclass was already hooked: no need to hook again + if hasattr(self, "collate_fn_hooked"): + return + original_collate_fn = self.collate_fn + + def hooked_collate_fn(*args, **kwargs): + nvtx_handle = torch.cuda.nvtx.range_start( + f"[{class_to_hook.__name__}]] collate_fn for {self}") + ret = original_collate_fn(*args, **kwargs) + torch.cuda.nvtx.range_end(nvtx_handle) + return ret + + self.collate_fn = hooked_collate_fn + self.collate_fn_hooked = True + + class_to_hook.__init__ = post_hooked_init + + +if os.environ.get('NVIDIA_NVTX_RANGES', + "0") == "1" and torch.cuda.is_available(): + syms = list(locals().keys()) + for sym in syms: + cl = locals()[sym] + if inspect.isclass(cl) and issubclass(cl, torch.utils.data.DataLoader): + hook_nvtx_collate_fn(cl) + __all__ = classes = [ 'DataLoader', 'NodeLoader', diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 1daf66490eac..7aaad9e044e8 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -184,6 +184,34 @@ def __init__( self._edge_update_forward_pre_hooks = OrderedDict() self._edge_update_forward_hooks = OrderedDict() + # Init NVTX Ranges for this op + if os.environ.get('NVIDIA_NVTX_RANGES', + "0") == "1" and torch.cuda.is_available(): + self._nvtx_handles = dict() + + def get_hooks_for(func_name): + def nvtx_pre_hook(module, inputs): + self._nvtx_handles[ + func_name] = torch.cuda.nvtx.range_start( + f"[MessagePassing] {func_name} for {self}") + return inputs + + def nvtx_hook(module, inputs, output): + torch.cuda.nvtx.range_end(self._nvtx_handles[func_name]) + return output + + return nvtx_pre_hook, nvtx_hook + + for func_name in [ + "propagate", "message", "aggregate", + "message_and_aggregate", "edge_update" + ]: + nvtx_pre_hook, nvtx_hook = get_hooks_for(func_name) + getattr( + self, + f"register_{func_name}_forward_pre_hook")(nvtx_pre_hook) + getattr(self, f"register_{func_name}_forward_hook")(nvtx_hook) + def reset_parameters(self): r"""Resets all learnable parameters of the module.""" if self.aggr_module is not None: diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 24c1e2834c19..e20ca52da3e2 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -1,5 +1,6 @@ import copy import math +import os import sys import warnings from typing import Callable, Dict, List, Optional, Tuple, Union @@ -151,6 +152,25 @@ def __init__( self.disjoint = disjoint self.temporal_strategy = temporal_strategy + # Init NVTX Ranges for this class + if os.environ.get('NVIDIA_NVTX_RANGES', + "0") == "1" and torch.cuda.is_available(): + + def hook_func_with_nvtx(func_name): + func = getattr(self, func_name) + + def hooked_func(*args, **kwargs): + nvtx_handle = torch.cuda.nvtx.range_start( + f"[Sampler] {func_name} for {self}") + ret = func(*args, **kwargs) + torch.cuda.nvtx.range_end(nvtx_handle) + return ret + + setattr(self, func_name, hooked_func) + + for func_name in ["sample_from_nodes", "sample_from_edges"]: + hook_func_with_nvtx(func_name) + @property def num_neighbors(self) -> NumNeighbors: return self._num_neighbors