Skip to content

Commit

Permalink
Merge branch 'master' into aki-opc
Browse files Browse the repository at this point in the history
akihironitta committed Oct 8, 2024
2 parents a14818a + 086c2bd commit 08a5582
Showing 18 changed files with 227 additions and 176 deletions.
4 changes: 2 additions & 2 deletions .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ inputs:
default: '3.8'
torch-version:
required: false
default: '2.3.0'
default: '2.4.0'
cuda-version:
required: false
default: cpu
@@ -43,7 +43,7 @@ runs:
shell: bash

- name: Disable CUDNN
if: ${{ inputs.cuda-version != 'cpu' }}
if: ${{ (inputs.cuda-version != 'cpu') && ((inputs.torch-version == '1.12.0') || (inputs.torch-version == '1.13.0')) }}
run: |
Torch_DIR=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'`
sed -i '95,100d' ${Torch_DIR}/Caffe2/Caffe2Config.cmake
5 changes: 5 additions & 0 deletions .github/workflows/aws/upload_nightly_index.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,11 @@
if '2.1.0' in torch_version:
wheels_dict[torch_version.replace('2.1.0', '2.1.1')].append(wheel)
wheels_dict[torch_version.replace('2.1.0', '2.1.2')].append(wheel)
if '2.2.0' in torch_version:
wheels_dict[torch_version.replace('2.2.0', '2.2.1')].append(wheel)
wheels_dict[torch_version.replace('2.2.0', '2.2.2')].append(wheel)
if '2.3.0' in torch_version:
wheels_dict[torch_version.replace('2.3.0', '2.3.1')].append(wheel)

index_html = html.format('\n'.join([
href.format(f'{version}.html'.replace('+', '%2B'), version)
30 changes: 22 additions & 8 deletions .github/workflows/building.yml
Original file line number Diff line number Diff line change
@@ -12,9 +12,9 @@ jobs:
matrix:
os: [ubuntu-20.04, macos-14, windows-2019]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
# torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0]
torch-version: [2.3.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121']
# torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0]
torch-version: [2.4.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.12.0
python-version: '3.12'
@@ -32,6 +32,8 @@ jobs:
cuda-version: 'cu118'
- torch-version: 1.12.0
cuda-version: 'cu121'
- torch-version: 1.12.0
cuda-version: 'cu124'
- torch-version: 1.13.0
python-version: '3.11'
- torch-version: 1.13.0
@@ -40,30 +42,44 @@ jobs:
cuda-version: 'cu118'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 1.13.0
cuda-version: 'cu124'
- torch-version: 2.0.0
cuda-version: 'cu113'
- torch-version: 2.0.0
cuda-version: 'cu116'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 2.0.0
cuda-version: 'cu124'
- torch-version: 2.1.0
cuda-version: 'cu113'
- torch-version: 2.1.0
cuda-version: 'cu116'
- torch-version: 2.1.0
cuda-version: 'cu117'
- torch-version: 2.1.0
cuda-version: 'cu124'
- torch-version: 2.2.0
cuda-version: 'cu113'
- torch-version: 2.2.0
cuda-version: 'cu116'
- torch-version: 2.2.0
cuda-version: 'cu117'
- torch-version: 2.2.0
cuda-version: 'cu124'
- torch-version: 2.3.0
cuda-version: 'cu113'
- torch-version: 2.3.0
cuda-version: 'cu116'
- torch-version: 2.3.0
cuda-version: 'cu117'
- torch-version: 2.3.0
cuda-version: 'cu124'
- torch-version: 2.4.0
cuda-version: 'cu113'
- torch-version: 2.4.0
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
@@ -75,9 +91,7 @@ jobs:
- os: macos-14
cuda-version: 'cu121'
- os: macos-14
python-version: '3.8'
- os: macos-14
python-version: '3.9'
cuda-version: 'cu124'
- os: windows-2019
torch-version: 2.0.0
cuda-version: 'cu121'
5 changes: 5 additions & 0 deletions .github/workflows/cuda/Linux-env.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#!/bin/bash

case ${1} in
cu124)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-12.4/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
;;
cu121)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-12.1/bin:${PATH}
8 changes: 7 additions & 1 deletion .github/workflows/cuda/Linux.sh
Original file line number Diff line number Diff line change
@@ -3,6 +3,12 @@
OS=ubuntu2004

case ${1} in
cu124)
CUDA=12.4
APT_KEY=${OS}-${CUDA/./-}-local
FILENAME=cuda-repo-${APT_KEY}_${CUDA}.1-550.54.15-1_amd64.deb
URL=https://developer.download.nvidia.com/compute/cuda/${CUDA}.1/local_installers
;;
cu121)
CUDA=12.1
APT_KEY=${OS}-${CUDA/./-}-local
@@ -56,7 +62,7 @@ sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv ${URL}/${FILENAME}
sudo dpkg -i ${FILENAME}

if [ "${1}" = "cu117" ] || [ "${1}" = "cu118" ] || [ "${1}" = "cu121" ]; then
if [ "${1}" = "cu117" ] || [ "${1}" = "cu118" ] || [ "${1}" = "cu121" ] || [ "${1}" = "cu124" ]; then
sudo cp /var/cuda-repo-${APT_KEY}/cuda-*-keyring.gpg /usr/share/keyrings/
else
sudo apt-key add /var/cuda-repo-${APT_KEY}/7fa2af80.pub
4 changes: 4 additions & 0 deletions .github/workflows/cuda/Windows-env.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#!/bin/bash

case ${1} in
cu124)
export FORCE_CUDA=1
export PATH=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.4/bin:${PATH}
;;
cu121)
export FORCE_CUDA=1
export PATH=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.1/bin:${PATH}
5 changes: 5 additions & 0 deletions .github/workflows/cuda/Windows.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#!/bin/bash

case ${1} in
cu124)
CUDA_SHORT=12.4
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers
CUDA_FILE=cuda_${CUDA_SHORT}.1_551.78_windows.exe
;;
cu121)
CUDA_SHORT=12.1
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers
28 changes: 21 additions & 7 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
@@ -16,8 +16,8 @@ jobs:
matrix:
os: [ubuntu-20.04, macos-14, windows-2019]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121']
torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.12.0
python-version: '3.12'
@@ -35,6 +35,8 @@ jobs:
cuda-version: 'cu118'
- torch-version: 1.12.0
cuda-version: 'cu121'
- torch-version: 1.12.0
cuda-version: 'cu124'
- torch-version: 1.13.0
python-version: '3.11'
- torch-version: 1.13.0
@@ -43,30 +45,44 @@ jobs:
cuda-version: 'cu118'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 1.13.0
cuda-version: 'cu124'
- torch-version: 2.0.0
cuda-version: 'cu113'
- torch-version: 2.0.0
cuda-version: 'cu116'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 2.0.0
cuda-version: 'cu124'
- torch-version: 2.1.0
cuda-version: 'cu113'
- torch-version: 2.1.0
cuda-version: 'cu116'
- torch-version: 2.1.0
cuda-version: 'cu117'
- torch-version: 2.1.0
cuda-version: 'cu124'
- torch-version: 2.2.0
cuda-version: 'cu113'
- torch-version: 2.2.0
cuda-version: 'cu116'
- torch-version: 2.2.0
cuda-version: 'cu117'
- torch-version: 2.2.0
cuda-version: 'cu124'
- torch-version: 2.3.0
cuda-version: 'cu113'
- torch-version: 2.3.0
cuda-version: 'cu116'
- torch-version: 2.3.0
cuda-version: 'cu117'
- torch-version: 2.3.0
cuda-version: 'cu124'
- torch-version: 2.4.0
cuda-version: 'cu113'
- torch-version: 2.4.0
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
@@ -78,9 +94,7 @@ jobs:
- os: macos-14
cuda-version: 'cu121'
- os: macos-14
python-version: '3.8'
- os: macos-14
python-version: '3.9'
cuda-version: 'cu124'
- os: windows-2019
torch-version: 2.0.0
cuda-version: 'cu121'
16 changes: 11 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@@ -28,7 +28,7 @@ repos:
args: [--min=10, .]

- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.17.0
hooks:
- id: pyupgrade
name: Upgrade Python syntax
@@ -60,21 +60,27 @@ repos:
name: Sort imports

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.3
rev: v0.6.9
hooks:
- id: ruff
name: Ruff formatting
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
rev: 7.1.1
hooks:
- id: flake8
name: Check PEP8

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
rev: v19.1.1
hooks:
- id: clang-format
name: Format C++ code
args: [--style=file]

- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v1.0.0
hooks:
- id: sphinx-lint
name: Check Sphinx
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [0.5.0] - 2023-MM-DD
### Added
- Added `torch.compile` support for `segment_matmul` ([#333](https://github.com/pyg-team/pyg-lib/pull/333))
- Added PyTorch 2.4 support ([#338](https://github.com/pyg-team/pyg-lib/pull/338))
- Added PyTorch 2.3 support ([#322](https://github.com/pyg-team/pyg-lib/pull/322))
- Added Windows support ([#315](https://github.com/pyg-team/pyg-lib/pull/315))
- Added macOS Apple Silicon support ([#310](https://github.com/pyg-team/pyg-lib/pull/310))
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -93,6 +93,7 @@ if (NOT MSVC)
endif()

find_package(Torch REQUIRED)
message("-- TORCH_LIBRARIES: ${TORCH_LIBRARIES}")
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})

find_package(OpenMP)
44 changes: 25 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -29,45 +29,51 @@ pip install pyg-lib -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html

where

* `${TORCH}` should be replaced by either `1.12.0`, `1.13.0`, `2.0.0`, `2.1.0`, `2.2.0`, or `2.3.0`
* `${TORCH}` should be replaced by either `1.12.0`, `1.13.0`, `2.0.0`, `2.1.0`, `2.2.0`, `2.3.0`, or `2.4.0`
* `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, `cu116`, `cu117`, `cu118`, or `cu121`

The following combinations are supported:

| PyTorch 2.3 | `cpu` | `cu102` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` |
| PyTorch 2.4 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | |||
| **Windows** || | | | |||
| **Linux** || | | | |||
| **Windows** || | | | |||
| **macOS** || | | | | | |

| PyTorch 2.2 | `cpu` | `cu102` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` |
| PyTorch 2.3 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | || |
| **Windows** || | | | || |
| **Linux** || | | | || |
| **Windows** || | | | || |
| **macOS** || | | | | | |

| PyTorch 2.1 | `cpu` | `cu102` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` |
| PyTorch 2.2 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | || |
| **Windows** || | | | || |
| **Linux** || | | | || |
| **Windows** || | | | || |
| **macOS** || | | | | | |

| PyTorch 2.0 | `cpu` | `cu102` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` |
c
| PyTorch 2.1 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | ||| |
| **Linux** || | | ||| |
| **Windows** || | | ||| |
| **macOS** || | | | | | |

| PyTorch 1.13 | `cpu` | `cu102` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` |
| PyTorch 2.0 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | ||| | |
| **Linux** || | ||| | |
| **Windows** || | ||| | |
| **macOS** || | | | | | |

| PyTorch 1.12 | `cpu` | `cu102` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` |
|--------------|-------|---------|---------|---------|---------|---------| --------|
| **Linux** ||||| | | |
| **Windows** ||||| | | |
| PyTorch 1.13 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || ||| | | |
| **Windows** || ||| | | |
| **macOS** || | | | | | |
c
| PyTorch 1.12 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------| --------|---------|
| **Linux** |||| | | | |
| **Windows** |||| | | | |
| **macOS** || | | | | | |

### Form nightly
14 changes: 13 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import sys

import pyg_sphinx_theme
from sphinx.application import Sphinx

import pyg_lib

@@ -20,6 +21,8 @@
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx_copybutton',
'sphinx_autodoc_typehints',
'pyg',
]

@@ -34,5 +37,14 @@

intersphinx_mapping = {
'python': ('http://docs.python.org', None),
'torch': ('https://pytorch.org/docs/master', None),
'torch': ('https://pytorch.org/docs/stable', None),
}

typehints_use_rtype = False
typehints_defaults = 'comma'


def setup(app: Sphinx) -> None:
r"""Setup sphinx application."""
# Do not drop type hints in signatures:
del app.events.listeners['autodoc-process-signature']
4 changes: 2 additions & 2 deletions pyg_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -34,15 +34,15 @@ def load_library(lib_name: str) -> None:
load_library('libpyg')

import pyg_lib.ops # noqa
import pyg_lib.sampler # noqa
import pyg_lib.partition # noqa
import pyg_lib.sampler # noqa


def cuda_version() -> int:
r"""Returns the CUDA version for which :obj:`pyg_lib` was compiled with.
Returns:
(int): The CUDA version.
The CUDA version.
"""
return torch.ops.pyg.cuda_version()

4 changes: 2 additions & 2 deletions pyg_lib/home.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ def get_home_dir() -> str:
variable :obj:`$PYG_LIB_HOME` which defaults to :obj:`"~/.cache/pyg_lib"`.
Returns:
(str): The cache directory.
The cache directory.
"""
if _home_dir is not None:
return _home_dir
@@ -29,7 +29,7 @@ def set_home_dir(path: str):
r"""Sets the cache directory used for storing all :obj:`pyg-lib` data.
Args:
path (str): The path to a local folder.
path: The path to a local folder.
"""
global _home_dir
_home_dir = path
104 changes: 43 additions & 61 deletions pyg_lib/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -106,7 +106,8 @@ def grouped_matmul(
r"""Performs dense-dense matrix multiplication according to groups,
utilizing dedicated kernels that effectively parallelize over groups.
Example:
.. code-block:: python
inputs = [torch.randn(5, 16), torch.randn(3, 32)]
others = [torch.randn(16, 32), torch.randn(32, 64)]
@@ -118,16 +119,12 @@ def grouped_matmul(
assert outs[1] == inputs[1] @ others[1]
Args:
inputs (List[torch.Tensor]): List of left operand 2D matrices of shapes
:obj:`[N_i, K_i]`.
others (List[torch.Tensor]): List of right operand 2D matrices of
shapes :obj:`[K_i, M_i]`.
biases (List[torch.Tensor], optional): Optional bias terms to apply for
each element. (default: :obj:`None`)
inputs: List of left operand 2D matrices of shapes :obj:`[N_i, K_i]`.
others: List of right operand 2D matrices of shapes :obj:`[K_i, M_i]`.
biases: Optional bias terms to apply for each element.
Returns:
List[torch.Tensor]: List of 2D output matrices of shapes
:obj:`[N_i, M_i]`.
List of 2D output matrices of shapes :obj:`[N_i, M_i]`.
"""
# Combine inputs into a single tuple for autograd:
outs = list(GroupedMatmul.apply(tuple(inputs + others)))
@@ -149,7 +146,8 @@ def segment_matmul(
the first dimension of :obj:`inputs` as given by :obj:`ptr`, utilizing
dedicated kernels that effectively parallelize over groups.
Example:
.. code-block:: python
inputs = torch.randn(8, 16)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 16, 32)
@@ -160,18 +158,14 @@ def segment_matmul(
assert torch.allclose(out[5:8], inputs[5:8] @ other[1])
Args:
inputs (torch.Tensor): The left operand 2D matrix of shape
:obj:`[N, K]`.
ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
the boundaries of segments. For best performance, given as a CPU
tensor.
other (torch.Tensor): The right operand 3D tensor of shape
:obj:`[B, K, M]`.
bias (torch.Tensor, optional): Optional bias term of shape
:obj:`[B, M]` (default: :obj:`None`)
inputs: The left operand 2D matrix of shape :obj:`[N, K]`.
ptr: Compressed vector of shape :obj:`[B + 1]`, holding the boundaries
of segments. For best performance, given as a CPU tensor.
other: The right operand 3D tensor of shape :obj:`[B, K, M]`.
bias: The bias term of shape :obj:`[B, M]`.
Returns:
torch.Tensor: The 2D output matrix of shape :obj:`[N, M]`.
The 2D output matrix of shape :obj:`[N, M]`.
"""
out = torch.ops.pyg.segment_matmul(inputs, ptr, other)
if bias is not None:
@@ -213,15 +207,13 @@ def sampled_add(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "add")
return out
@@ -245,15 +237,13 @@ def sampled_sub(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "sub")
return out
@@ -277,15 +267,13 @@ def sampled_mul(
thus being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "mul")
return out
@@ -309,15 +297,13 @@ def sampled_div(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "div")
return out
@@ -338,13 +324,12 @@ def index_sort(
device.
Args:
inputs (torch.Tensor): A vector with positive integer values.
max_value (int, optional): The maximum value stored inside
:obj:`inputs`. This value can be an estimation, but needs to be
greater than or equal to the real maximum. (default: :obj:`None`)
inputs: A vector with positive integer values.
max_value: The maximum value stored inside :obj:`inputs`. This value
can be an estimation, but needs to be greater than or equal to the
real maximum.
Returns:
Tuple[torch.LongTensor, torch.LongTensor]:
A tuple containing sorted values and indices of the elements in the
original :obj:`input` tensor.
"""
@@ -364,14 +349,6 @@ def softmax_csr(
:attr:`ptr`, and then proceeds to compute the softmax individually for
each group.
Args:
src (Tensor): The source tensor.
ptr (LongTensor): Groups defined by CSR representation.
dim (int, optional): The dimension in which to normalize.
(default: :obj:`0`)
:rtype: :class:`Tensor`
Examples:
>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
@@ -380,6 +357,11 @@ def softmax_csr(
[0.1453, 0.2591, 0.5907, 0.2410],
[0.0598, 0.2923, 0.1206, 0.0921],
[0.7792, 0.3502, 0.1638, 0.2145]])
Args:
src: The source tensor.
ptr: Groups defined by CSR representation.
dim: The dimension in which to normalize.
"""
dim = dim + src.dim() if dim < 0 else dim
return torch.ops.pyg.softmax_csr(src, ptr, dim)
19 changes: 8 additions & 11 deletions pyg_lib/partition/__init__.py
Original file line number Diff line number Diff line change
@@ -18,19 +18,16 @@ def metis(
<https://arxiv.org/abs/1905.07953>`_ paper.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
num_partitions (int): The number of partitions.
node_weight (torch.Tensor, optional): Optional node weights.
(default: :obj:`None`)
edge_weight (torch.Tensor, optional): Optional edge weights.
(default: :obj:`None`)
recursive (bool, optional): If set to :obj:`True`, will use multilevel
recursive bisection instead of multilevel k-way partitioning.
(default: :obj:`False`)
rowptr: Compressed source node indices.
col: Target node indices.
num_partitions: The number of partitions.
node_weight: The node weights.
edge_weight: The edge weights.
recursive: If set to :obj:`True`, will use multilevel recursive
bisection instead of multilevel k-way partitioning.
Returns:
torch.Tensor: A vector that assings each node to a partition.
A vector that assings each node to a partition.
"""
return torch.ops.pyg.metis(rowptr, col, num_partitions, node_weight,
edge_weight, recursive)
107 changes: 50 additions & 57 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional, Dict
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
@@ -34,53 +34,48 @@ def neighbor_sample(
binary search to find neighbors that fulfill temporal constraints.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
seed (torch.Tensor): The seed node indices.
num_neighbors (List[int]): The number of neighbors to sample for each
node in each iteration. If an entry is set to :obj:`-1`, all
neighbors will be included.
node_time (torch.Tensor, optional): Timestamps for the nodes in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
rowptr: Compressed source node indices.
col: Target node indices.
seed: The seed node indices.
num_neighbors: The number of neighbors to sample for each node in each
iteration.
If an entry is set to :obj:`-1`, all neighbors will be included.
node_time: Timestamps for the nodes in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* sampled
nodes have an earlier or equal timestamp than the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
within individual neighborhoods.
Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
edge_time (torch.Tensor, optional): Timestamps for the edges in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
edge_time: Timestamps for the edges in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* sampled
edges have an earlier or equal timestamp than the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
within individual neighborhoods.
Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
seed_time (torch.Tensor, optional): Optional values to override the
timestamp for seed nodes. If not set, will use timestamps in
:obj:`node_time` as default for seed nodes.
seed_time: Optional values to override the timestamp for seed nodes.
If not set, will use timestamps in :obj:`node_time` as default for
seed nodes.
Needs to be specified in case edge-level sampling is used via
:obj:`edge_time`. (default: :obj:`None`)
edge_weight (torch.Tensor, optional): If given, will perform biased
sampling based on the weight of each edge. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
replace (bool, optional): If set to :obj:`True`, will sample with
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
disjoint (bool, optional): If set to :obj:`True` , will create disjoint
subgraphs for every seed node. (default: :obj:`False`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
(default: :obj:`"uniform"`)
return_edge_id (bool, optional): If set to :obj:`False`, will not
return the indices of edges of the original graph.
(default: :obj: `True`)
:obj:`edge_time`.
edge_weight: If given, will perform biased sampling based on the weight
of each edge.
csc: If set to :obj:`True`, assumes that the graph is given in CSC
format :obj:`(colptr, row)`.
replace: If set to :obj:`True`, will sample with replacement.
directed: If set to :obj:`False`, will include all edges between all
sampled nodes.
disjoint: If set to :obj:`True` , will create disjoint subgraphs for
every seed node.
temporal_strategy: The sampling strategy when using temporal sampling
(:obj:`"uniform"`, :obj:`"last"`).
return_edge_id: If set to :obj:`False`, will not return the indices of
edges of the original graph.
Returns:
(torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor],
List[int], List[int]):
Row indices, col indices of the returned subtree/subgraph, as well as
original node indices for all nodes sampled.
In addition, may return the indices of edges of the original graph.
@@ -176,16 +171,16 @@ def subgraph(
:obj:`(rowptr, col)`, containing only the nodes in :obj:`nodes`.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
nodes (torch.Tensor): Node indices of the induced subgraph.
return_edge_id (bool, optional): If set to :obj:`False`, will not
rowptr: Compressed source node indices.
col: Target node indices.
nodes: Node indices of the induced subgraph.
return_edge_id: If set to :obj:`False`, will not
return the indices of edges of the original graph contained in the
induced subgraph. (default: :obj:`True`)
induced subgraph.
Returns:
(torch.Tensor, torch.Tensor, Optional[torch.Tensor]): Compressed source
node indices and target node indices of the induced subgraph.
Compressed source node indices and target node indices of the induced
subgraph.
In addition, may return the indices of edges of the original graph.
"""
return torch.ops.pyg.subgraph(rowptr, col, nodes, return_edge_id)
@@ -205,19 +200,17 @@ def random_walk(
<https://arxiv.org/abs/1607.00653>`_ paper.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
seed (torch.Tensor): Seed node indices from where random walks start.
walk_length (int): The walk length of a random walk.
p (float, optional): Likelihood of immediately revisiting a node in the
walk. (default: :obj:`1.0`)
q (float, optional): Control parameter to interpolate between
breadth-first strategy and depth-first strategy.
(default: :obj:`1.0`)
rowptr: Compressed source node indices.
col: Target node indices.
seed: Seed node indices from where random walks start.
walk_length: The walk length of a random walk.
p: Likelihood of immediately revisiting a node in the walk.
q: Control parameter to interpolate between breadth-first strategy and
depth-first strategy.
Returns:
torch.Tensor: A tensor of shape :obj:`[seed.size(0), walk_length + 1]`,
holding the nodes indices of each walk for each seed node.
A tensor of shape :obj:`[seed.size(0), walk_length + 1]`, holding the
nodes indices of each walk for each seed node.
"""
return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q)

0 comments on commit 08a5582

Please sign in to comment.