Commit de2848b: Update README.md
ayaka14732 committed Mar 29, 2022 (1 parent: 1188ce1)
Showing 2 changed files with 20 additions and 33 deletions.

README.md (20 additions, 33 deletions):
You can learn more about the TRC program on its [homepage](https://sites.research.google/trc/).

## 2. Environment Setup

### 2.1. Modify VPC firewall

Open the [Firewall management page](https://console.cloud.google.com/networking/firewalls/list) in VPC network.

Click the button to create a new firewall rule.

![](assets/2.png)

Set name to allow-all, targets to 'All instances in the network', source filter to 0.0.0.0/0, protocols and ports to 'Allow all', and then click 'Create'.
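The same rule can also be created from the command line; a sketch with `gcloud`, assuming the TPU instances live in the `default` VPC network:

```sh
# Create an ingress rule that allows all protocols and ports
# from any source, for all instances in the 'default' network.
gcloud compute firewall-rules create allow-all \
    --network=default \
    --direction=INGRESS \
    --action=ALLOW \
    --rules=all \
    --source-ranges=0.0.0.0/0
```

Allowing all traffic from 0.0.0.0/0 is convenient for experiments, but you may want to narrow the source range later.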

### 2.2. Create a TPU instance

Open [Google Cloud Platform](https://cloud.google.com/tpu) and navigate to the [TPU management page](https://console.cloud.google.com/compute/tpus).

gcloud alpha compute tpus tpu-vm create node-1 --project tpu-develop --zone=europe-west4-a

If the command fails because no TPUs are currently available, re-run it until it succeeds.
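The retry can be automated; a minimal sketch, reusing the node name, project, and `europe-west4-a` zone used elsewhere in this guide:

```sh
# Keep retrying until a TPU VM is successfully allocated.
until gcloud alpha compute tpus tpu-vm create node-1 \
    --project tpu-develop --zone europe-west4-a; do
  echo 'Creation failed, retrying in 30 seconds...'
  sleep 30
done
```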

### 2.3. Add public key to the server

In Cloud Shell, log in to the Cloud VM with the `gcloud` command:

```sh
gcloud alpha compute tpus tpu-vm ssh node-1 --zone europe-west4-a
```

After logging in, add your public key to `~/.ssh/authorized_keys`.
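For example (the key below is a placeholder; use the contents of your own `~/.ssh/id_ed25519.pub` or `id_rsa.pub` instead):

```sh
mkdir -p ~/.ssh
chmod 700 ~/.ssh
# Placeholder key; replace it with your real public key.
echo 'ssh-ed25519 AAAAC3Placeholder you@your-laptop' >> ~/.ssh/authorized_keys
chmod 600 ~/.ssh/authorized_keys
```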

### 2.4. Basic configurations

Install packages:

pip install -r requirements.txt

### 2.5. How can I verify that the TPU is working?

Run:

```python
import jax.numpy as np

a = np.array([1, 2, 3])
a.device()  # should print TpuDevice
```
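Another quick smoke test is to ask JAX which backend it dispatches to by default; on a TPU VM this should report `tpu` (on an ordinary machine it reports `cpu` or `gpu`):

```python
import jax

# The platform JAX uses for computations by default.
print(jax.default_backend())
```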

### 2.6. Development environment
torch.dot(a_, b_) # error: 1D tensors expected, but got 3D and 3D tensors

[google/jax#9973](https://github.com/google/jax/issues/9973).
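For contrast with `torch.dot`, which accepts only 1-D tensors, `np.dot` (in both NumPy and `jax.numpy`) generalizes to higher-rank operands; a small sketch in plain NumPy:

```python
import numpy as np

a = np.ones((2, 3))
b = np.ones((3, 4))

# For 2-D arrays, np.dot is matrix multiplication;
# torch.dot would reject these shapes outright.
c = np.dot(a, b)
print(c.shape)  # (2, 4)
```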

## 7. Community

As of 23 Feb, 2022, there is no official chat group for Cloud TPUs. You can join my chat group [@cloudtpu](https://t.me/cloudtpu) on Telegram.
Binary file added: assets/2.png
