Skip to content

Commit

Permalink
remove torch_xla2 jax dependency in native install (#7126)
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore authored May 29, 2024
1 parent 7770a49 commit ed90be1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
7 changes: 6 additions & 1 deletion experimental/torch_xla2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ Note: `dev-requirements.txt` will install the CPU-only version of PyTorch.

#### 1.1 Install this package

Install `torch_xla2` from source for your platform:
If you want to install torch_xla2 without the jax dependency and use the jax dependency from torch_xla:
```bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -e .
```

Otherwise, install `torch_xla2` from source for your platform:
```bash
pip install -e .[cpu]
pip install -e .[cuda]
Expand Down
7 changes: 3 additions & 4 deletions experimental/torch_xla2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ name = "torch_xla2"
dependencies = [
"absl-py",
"immutabledict",
"jax[cpu]>=0.4.24",
"pytest",
"tensorflow-cpu",
# Developers should install `dev-requirements.txt` first
Expand All @@ -18,10 +17,10 @@ requires-python = ">=3.10"
license = {file = "LICENSE"}

[project.optional-dependencies]
cpu = ["jax[cpu]"]
cpu = ["jax[cpu]>=0.4.24", "jax[cpu]"]
# Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html`
tpu = ["jax[tpu]"]
cuda = ["jax[cuda12]"]
tpu = ["jax[cpu]>=0.4.24", "jax[tpu]"]
cuda = ["jax[cpu]>=0.4.24", "jax[cuda12]"]

[tool.pytest.ini_options]
addopts="-n auto"
Expand Down

0 comments on commit ed90be1

Please sign in to comment.