diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index 5cb283ab288..dc5a1fdffce 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -54,8 +54,13 @@ Note: `dev-requirements.txt` will install the CPU-only version of PyTorch. #### 1.1 Install this package -Install `torch_xla2` from source for your platform: +If you want to install torch_xla2 without the jax dependency and use the jax dependency from torch_xla: +```bash +pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +pip install -e . +``` +Otherwise, install `torch_xla2` from source for your platform: ```bash pip install -e .[cpu] pip install -e .[cuda] diff --git a/experimental/torch_xla2/pyproject.toml b/experimental/torch_xla2/pyproject.toml index 25e5d02db9e..14b77ad0216 100644 --- a/experimental/torch_xla2/pyproject.toml +++ b/experimental/torch_xla2/pyproject.toml @@ -8,7 +8,6 @@ name = "torch_xla2" dependencies = [ "absl-py", "immutabledict", - "jax[cpu]>=0.4.24", "pytest", "tensorflow-cpu", # Developers should install `dev-requirements.txt` first @@ -18,10 +17,10 @@ requires-python = ">=3.10" license = {file = "LICENSE"} [project.optional-dependencies] -cpu = ["jax[cpu]"] +cpu = ["jax[cpu]>=0.4.24", "jax[cpu]"] # Add libtpu index `-f https://storage.googleapis.com/libtpu-releases/index.html` -tpu = ["jax[tpu]"] -cuda = ["jax[cuda12]"] +tpu = ["jax[cpu]>=0.4.24", "jax[tpu]"] +cuda = ["jax[cpu]>=0.4.24", "jax[cuda12]"] [tool.pytest.ini_options] addopts="-n auto"