Skip to content

Commit

Permalink
Add PyTensor backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 4, 2025
1 parent 253545a commit ea5d803
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 4 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ jobs:
# cupy is not tested because it demands gpu
# oneflow testing is dropped, see details at https://github.com/Oneflow-Inc/oneflow/issues/10340
# paddle was switched off because of divergence with numpy in py3.10, paddle==2.6.1
frameworks: ['numpy pytorch tensorflow jax']
# The last pytensor release that supports python 3.8 doesn't include einsum, so we skip that combination.
frameworks: ['numpy pytorch tensorflow jax', 'pytensor']
exclude:
- python-version: '3.8'
frameworks: 'pytensor'

steps:
- uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ Einops works with ...
- [paddle](https://github.com/PaddlePaddle/Paddle) (community)
- [oneflow](https://github.com/Oneflow-Inc/oneflow) (community)
- [tinygrad](https://github.com/tinygrad/tinygrad) (community)
- [pytensor](https://github.com/pymc-devs/pytensor) (community)

Additionally, einops can be used with any framework that supports
[Python array API standard](https://data-apis.org/array-api/latest/API_specification/index.html),
Expand Down
55 changes: 55 additions & 0 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,58 @@ def is_float_type(self, x):

def einsum(self, pattern, *x):
return self.tinygrad.Tensor.einsum(pattern, *x)


class PyTensorBackend(AbstractBackend):
framework_name = "pytensor"

def __init__(self):
from pytensor import tensor

self.pt = tensor

def is_appropriate_type(self, tensor):
return isinstance(tensor, self.pt.TensorVariable)

def is_float_type(self, x):
return x.dtype in self.pt.type.float_dtypes

def from_numpy(self, x):
return self.pt.as_tensor(x)

def to_numpy(self, x):
return x.eval() # Will only work if there are no symbolic inputs

def create_symbol(self, shape):
if not isinstance(shape, tuple | list):
shape = (shape,)
return self.pt.tensor(shape=shape)

def eval_symbol(self, symbol, input_dict):
# input_dict is actually a list of tuple?
return symbol.eval(dict(input_dict))

def arange(self, start, stop):
return self.pt.arange(start, stop)

def shape(self, x):
# use the static shape dimensions where known
return tuple(
static_dim if static_dim is not None else symbolic_dim
for static_dim, symbolic_dim in zip(x.type.shape, x.shape)
)

def stack_on_zeroth_dimension(self, tensors: list):
return self.pt.stack(tensors)

def tile(self, x, repeats):
return self.pt.tile(x, repeats)

def concat(self, tensors, axis: int):
return self.pt.concatenate(tensors, axis=axis)

def add_axis(self, x, new_position):
return self.pt.expand_dims(x, new_position)

def einsum(self, pattern, *x):
return self.pt.einsum(pattern, *x)
4 changes: 3 additions & 1 deletion einops/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def collect_test_backends(symbolic=False, layers=False) -> List[_backends.Abstra
]
else:
if not layers:
backend_types = []
backend_types = [
_backends.PyTensorBackend,
]
else:
backend_types = [
_backends.TFKerasBackend,
Expand Down
1 change: 1 addition & 0 deletions einops/tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def main():
# "paddle": ["paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html"],
"paddle": ["paddlepaddle"],
"oneflow": ["oneflow==0.9.0"],
"pytensor": ["pytensor"],
}

usage = f"""
Expand Down
3 changes: 2 additions & 1 deletion einops/tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def test_layer():
"cupy",
"tensorflow.keras",
"paddle",
"pytensor",
]


Expand Down Expand Up @@ -254,7 +255,7 @@ def test_functional_symbolic():
)
if predicted_out_data.shape != out_shape:
raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
assert np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)


def test_functional_errors():
Expand Down
2 changes: 1 addition & 1 deletion einops/tests/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_parse_shape_symbolic_ellipsis(backend):
for static_shape, shape, pattern, expected in [
([10, 20], [None, None], "...", dict()),
([10], [None], "... a", dict(a=10)),
([10, 20], [None], "... a", dict(a=20)),
([10, 20], [None, None], "... a", dict(a=20)),
([10, 20, 30], [None, None, None], "... a", dict(a=30)),
([10, 20, 30, 40], [None, None, None, None], "... a", dict(a=40)),
([10], [None], "a ...", dict(a=10)),
Expand Down

0 comments on commit ea5d803

Please sign in to comment.