diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..5ffed53 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "tpu-tools"] + path = tpu-tools + url = git@github.com:AshishKumar4/tpu-tools.git diff --git a/gcsfuse.sh b/gcsfuse.sh deleted file mode 100755 index de6235f..0000000 --- a/gcsfuse.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Description: -# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxtext-dataset MOUNT_PATH=dataset - -set -e - -# Set environment variables -for ARGUMENT in "$@"; do - IFS='=' read -r KEY VALUE <<< "$ARGUMENT" - export "$KEY"="$VALUE" - echo "$KEY"="$VALUE" -done - -if [[ -z ${DATASET_GCS_BUCKET} || -z ${MOUNT_PATH} ]]; then - echo "Please set arguments: DATASET_GCS_BUCKET and MOUNT_PATH" - exit 1 -fi - -if [[ "$DATASET_GCS_BUCKET" =~ gs:\/\/ ]] ; then - DATASET_GCS_BUCKET="${DATASET_GCS_BUCKET/gs:\/\//}" - echo "Removed gs:// from GCS bucket name, GCS bucket is $DATASET_GCS_BUCKET" -fi - -if [[ -d ${MOUNT_PATH} ]]; then - echo "$MOUNT_PATH exists, removing..." - fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH -fi - -mkdir -p $MOUNT_PATH - -# see https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of gcsfuse CLI -# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py) -# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS - -gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=10000 \ - --experimental-enable-json-read --kernel-list-cache-ttl-secs=-1 -o ro --config-file=$HOME/gcsfuse.yml \ - --log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH" \ No newline at end of file diff --git a/tpu-tools b/tpu-tools new file mode 160000 index 0000000..e30a21d --- /dev/null +++ b/tpu-tools @@ -0,0 +1 @@ +Subproject commit e30a21d4598d1d47d5c7b3eb448afda4a1f85249 diff --git a/tpu_utils/README.md b/tpu_utils/README.md deleted file mode 100644 index e672acb..0000000 --- a/tpu_utils/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Code mostly borrowed from https://github.com/tensorflow/tpu/tree/master/tools/ray_tpu/legacy - -To make life easier when working with Cloud TPUs \ No newline at end of file diff --git a/tpu_utils/gcsfuse.sh b/tpu_utils/gcsfuse.sh deleted file mode 100755 index de6235f..0000000 --- a/tpu_utils/gcsfuse.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Description: -# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxtext-dataset MOUNT_PATH=dataset - -set -e - -# Set environment variables -for ARGUMENT in "$@"; do - IFS='=' read -r KEY VALUE <<< "$ARGUMENT" - export "$KEY"="$VALUE" - echo "$KEY"="$VALUE" -done - -if [[ -z ${DATASET_GCS_BUCKET} || -z ${MOUNT_PATH} ]]; then - echo "Please set arguments: DATASET_GCS_BUCKET and MOUNT_PATH" - exit 1 -fi - -if [[ "$DATASET_GCS_BUCKET" =~ gs:\/\/ ]] ; then - DATASET_GCS_BUCKET="${DATASET_GCS_BUCKET/gs:\/\//}" - echo "Removed gs:// from GCS bucket name, GCS bucket is $DATASET_GCS_BUCKET" -fi - -if [[ -d ${MOUNT_PATH} ]]; then - echo "$MOUNT_PATH exists, removing..." - fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH -fi - -mkdir -p $MOUNT_PATH - -# see https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of gcsfuse CLI -# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py) -# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS - -gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=10000 \ - --experimental-enable-json-read --kernel-list-cache-ttl-secs=-1 -o ro --config-file=$HOME/gcsfuse.yml \ - --log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH" \ No newline at end of file diff --git a/tpu_utils/ray_tpu_controller.py b/tpu_utils/ray_tpu_controller.py deleted file mode 100644 index 24b0f1a..0000000 --- a/tpu_utils/ray_tpu_controller.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Ray-based TPU controller from an admin CPU VM.""" -import collections -import dataclasses -import time -from typing import List, Optional, Mapping, Any -from absl import logging - -import ray -from ray.dashboard.modules.job.sdk import JobSubmissionClient -from ray.experimental.state import api as state_api -from ray.job_submission import JobStatus -import tpu_controller - -BASE_JAX_PIP_INSTALLS = [ - "jax[tpu]", - "-f https://storage.googleapis.com/jax-releases/libtpu_releases.html", -] -_DEFAULT_RAY_PORT = 6379 - - -# TODO(allencwang) - merge with TpuRayJob -@dataclasses.dataclass -class RayRuntimeEnv: - """Representation of a runtime environment.""" - - pip: str - working_dir: str - - -@dataclasses.dataclass -class TpuRayJob: - """Representation of a Tpu-based Ray Job.""" - - entrypoint: str - working_dir: str - pip_installs: List[str] = dataclasses.field(default_factory=list) - env_vars: Mapping[str, str] = None - entrypoint_resources: Mapping[str, int] = None - - def to_ray_job(self) -> Mapping[str, Any]: - return dict( - entrypoint=self.entrypoint, - runtime_env=dict( - working_dir=self.working_dir, - pip=self.pip_installs, - env_vars=self.env_vars, - ), - entrypoint_resources=self.entrypoint_resources, - ) - - -class RayTpuController(tpu_controller.TPUController): - """Ray-based TPU controller. - - By default, `RayTpuController` spins up a ray cluster by appending the Ray - startup commands to the TPU startup script, e.g.: - ``` - controller = RayTpuController(...) - controller.maybe_create_and_wait_for_ready() - # continues once all TPU workers have joined the Ray cluster. - ``` - - If the TPU was already created outside of `RayTpuController`, we have the - ability to start the Ray cluster via: - ``` - controller = RayTpuController(...) - controller.maybe_start_ray_on_workers() - # continues once all TPU workers have joined the Ray cluster. - ``` - - Attributes: - startup_script: an optional set of commands that will be concatenated to run - on TPU VM startup. - """ - - def __init__( - self, - tpu_name: str, - startup_script: Optional[List[str]] = None, - runtime_env: Optional[RayRuntimeEnv] = None, - head_addr: Optional[str] = None, - **kwargs, - ): - if not ray.is_initialized(): - if runtime_env: - result = ray.init(runtime_env=dataclasses.asdict(runtime_env)) - else: - result = ray.init() - self._head_addr = result.address_info["address"] - if head_addr: - self._head_addr = head_addr - self.resource_name = f"{tpu_name}_tpu_host" - ray_setup = self.get_ray_setup_commands() - self._job_client = None - if startup_script: - startup_script = startup_script + ray_setup - else: - startup_script = ray_setup - self._queued_jobs = [] - self._live_nodes = set() - super().__init__(tpu_name=tpu_name, startup_script=startup_script, **kwargs) - - @property - def queued_jobs(self): - return self._queued_jobs - - def maybe_start_ray_on_workers(self): - if self.tpu_hosts_joined_cluster(): - logging.info("Ray already started on each host.") - else: - logging.info("Manually starting Ray on each workers.") - self.run_commands_on_workers(self.get_ray_setup_commands()) - - @property - def job_client(self) -> JobSubmissionClient: - if not self._job_client: - self._job_client = JobSubmissionClient() - return self._job_client - - def get_ray_setup_commands(self) -> List[str]: - return [ - "mkdir -p /dev/shm", - "sudo mount -t tmpfs -o size=100g tmpfs /dev/shm", - "sudo pip3 install ray[default]", - "ray start --resources='{\"%s\": 1}' --address=%s" - % (self.resource_name, self._head_addr), - ] - - def tpu_hosts_joined_cluster(self) -> bool: - ray_nodes = state_api.list_nodes( - limit=10000, filters=[("state", "=", "ALIVE")] - ) - self._live_nodes.clear() - ips_addresses = self.get_ip_addresses() - for node in ray_nodes: - if ( - node.get("resources_total") - and node["resources_total"].get(self.resource_name) == 1 - and node["node_ip"] in ips_addresses - ): - self._live_nodes.add(node["node_id"]) - num_registered_tpu_hosts = len(self._live_nodes) - logging.info( - "Detected %d TPU hosts in cluster, expecting %d hosts in total", - num_registered_tpu_hosts, - self.get_num_nodes(), - ) - return num_registered_tpu_hosts == self.get_num_nodes() - - def maybe_create_and_wait_for_ready( - self, recreate_after_num_trials=5 - ) -> None: - """Creates TPU if not exists and waits for all nodes to join the cluster. - - Firstly, it checks TPU exists or not, if not, it will create one. - It will wait for all the nodes to join, if all nodes fail to join after - `recreate_after_num_trials` trials, it will try to recreate the TPU. The - threshold `recreate_after_num_trials` will be doubled each time TPU is - recreated. - - Args: - recreate_after_num_trials: the trail threshold for TPU recreation. - """ - if not self.tpu_exists(): - logging.warn("TPU is not found, create tpu...") - self.create_tpu() - num_trials = 0 - self.maybe_create_tpu() - while not self.tpu_hosts_joined_cluster(): - if num_trials >= recreate_after_num_trials: - logging.info("Tried %d times, recreating TPU VM ...", num_trials) - if self.tpu_exists(): - self.delete_tpu() - self.create_tpu() - recreate_after_num_trials *= 2 - logging.info( - "Will try to recreate TPU VM after %d trials.", - recreate_after_num_trials, - ) - num_trials = 0 - continue - logging.info("Waiting for 30s for TPU hosts to join cluster...") - num_trials += 1 - time.sleep(30) - - def queue_tpu_workload(self, job: TpuRayJob, reset_queue=False): - if reset_queue: - self._queued_jobs = [] - job.entrypoint_resources = {self.resource_name: 1} - for _ in range(self.get_num_nodes()): - self._queued_jobs.append(self.job_client.submit_job(**job.to_ray_job())) - logging.info("Queued %d jobs.", len(self._queued_jobs)) - - def job_queued_and_healthy(self) -> bool: - """Checks jobs are queued and healthy. - - Returns: - True if all the ondtions are met: - - job number matches node number - - all jobs are in healthy status - - all jobs are scheduled in live nodes. - False otherwise. - """ - if len(self._queued_jobs) != self.get_num_nodes(): - logging.warn( - "Detected %d jobs, expecting %d jobs.", - len(self._queued_jobs), - self.get_num_nodes(), - ) - return False - for job in self._queued_jobs: - job_info = self.job_client.get_job_info(job) - if job_info.status in {JobStatus.STOPPED, JobStatus.FAILED}: - logging.warn("Detected job %s %s.", job, job_info.status) - return False - if ( - job_info.status in {JobStatus.RUNNING, JobStatus.PENDING} - and job_info.driver_node_id - and job_info.driver_node_id not in self._live_nodes - ): - logging.warn( - "Detected job %s running on stale node %s.", - job, - job_info.driver_node_id, - ) - return False - return True - - def clean_stale_jobs(self, resource_name: str) -> None: - """Stops all the jobs with the same entrypoint but not in the job queue.""" - num_jobs_to_stop = 0 - for job in state_api.list_jobs(): - if ( - job["entrypoint_resources"] is None - or job["entrypoint_resources"].get(resource_name, 0) != 1 - ): - continue - if job["status"] not in {"RUNNING", "PENDING"}: - continue - job_id = job["job_id"] - if job_id in self._queued_jobs: - continue - # If node is dead, the job status may still be shown as running and - # occupying the resource. Getting job logs will force head node talk to - # dead node and mark the job as failed. TODO(yejingxin) raise the issue in - # ray github - try: - self.job_client.get_job_logs(job_id) - self.job_client.stop_job(job_id) - num_jobs_to_stop += 1 - except RuntimeError: - logging.warn("%s is not reachable due to stale node.", job_id) - except TimeoutError: - logging.warn("%s is not reachable due to stale node.", job_id) - if num_jobs_to_stop > 0: - logging.info( - "Requested to clean up %d stale jobs from previous failures.", - num_jobs_to_stop, - ) - - async def print_job_log(self) -> None: - if not self._queued_jobs: - return - async for line in self.job_client.tail_job_logs(self._queued_jobs[0]): - print(line, end="") - - def jobs_in_status(self, status) -> bool: - counter = collections.Counter( - (self.job_client.get_job_status(job) for job in self._queued_jobs) - ) - logging.info("TPU %s Job status: %s", self.tpu_name, counter) - return counter.get(status) == len(self._queued_jobs) - - def wait_until_tpu_job_completed(self, poll_timeout_in_s=10): - while self._queued_jobs: - for job in self._queued_jobs: - status = self.job_client.get_job_status(job) - logging.info("[ADMIN]: %s: Status is %s", job, status) - logs = self.job_client.get_job_logs(job) - logging.info("[%s]: %s", job, logs) - if status.is_terminal(): - self._queued_jobs.remove(job) - else: - logging.info("[ADMIN]: Sleeping for %ds.", poll_timeout_in_s) - time.sleep(poll_timeout_in_s) - - def run_tpu_workload(self, job: TpuRayJob): - self.queue_tpu_workload(job) - self.wait_until_tpu_job_completed() \ No newline at end of file diff --git a/tpu_utils/setup_tpu.sh b/tpu_utils/setup_tpu.sh deleted file mode 100755 index 6e36561..0000000 --- a/tpu_utils/setup_tpu.sh +++ /dev/null @@ -1,152 +0,0 @@ -#!/bin/bash - -# Install JAX and Flax -pip install jax[tpu] flax[all] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - -# Install CPU version of tensorflow -pip install tensorflow[cpu] keras orbax optax clu grain augmax transformers opencv-python pandas tensorflow-datasets jupyterlab python-dotenv scikit-learn termcolor wrapt wandb - -pip install flaxdiff - -wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb -sudo dpkg -i knot-resolver-release.deb -sudo apt update -sudo apt install -y knot-resolver -sudo sh -c 'echo `hostname -I` `hostname` >> /etc/hosts' -sudo sh -c 'echo nameserver 127.0.0.1 > /etc/resolv.conf' - -# Backup the original resolv.conf -sudo cp /etc/resolv.conf /etc/resolv.conf.bak - -# Define the new nameservers -nameservers=( - "nameserver 127.0.0.1" - "nameserver 8.8.8.8" - "nameserver 8.8.4.4" - "nameserver 76.76.2.0" - "nameserver 76.76.10.0" - "nameserver 9.9.9.9" - "nameserver 1.1.1.1" - "nameserver 1.0.0.1" -) - -# Clear the existing resolv.conf file -sudo sh -c '> /etc/resolv.conf' - -# Add each nameserver to the resolv.conf file -for ns in "${nameservers[@]}"; do - sudo sh -c "echo \"$ns\" >> /etc/resolv.conf" -done -echo "Nameservers added to /etc/resolv.conf" - -sudo systemctl stop systemd-resolved -sudo systemctl start kresd@1.service -sudo systemctl start kresd@2.service -sudo systemctl start kresd@3.service -sudo systemctl start kresd@4.service -sudo systemctl start kresd@5.service -sudo systemctl start kresd@6.service -sudo systemctl start kresd@7.service -sudo systemctl start kresd@8.service -sudo systemctl start kresd@9.service -sudo systemctl start kresd@10.service -sudo systemctl start kresd@11.service -sudo systemctl start kresd@12.service -sudo systemctl start kresd@13.service -sudo systemctl start kresd@14.service -sudo systemctl start kresd@15.service -sudo systemctl start kresd@16.service - -# Installing and setting up gcsfuse -export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` -echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list -curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo tee /usr/share/keyrings/cloud.google.asc -sudo apt update -sudo apt install -y gcsfuse libgl1 - -# Define the file name -gcsfuse_conf="$HOME/gcsfuse.yml" - -# Define the contents of the file -gcsfuse_conf_content=$(cat < $gcsfuse_conf - -ulimit -n 65535 - -# Increase the limits of number of open files to unlimited -# Add the limits to /etc/security/limits.conf -limits_conf="/etc/security/limits.conf" -sudo bash -c "cat <> $limits_conf -* soft nofile unlimited -* hard nofile unlimited -EOF" - -# Create a systemd override directory if it doesn't exist -systemd_override_dir="/etc/systemd/system.conf.d" -sudo mkdir -p $systemd_override_dir - -# Add the limits to the systemd service configuration -systemd_limits_conf="$systemd_override_dir/99-nofile.conf" -sudo bash -c "cat < $systemd_limits_conf -[Manager] -DefaultLimitNOFILE=infinity -EOF" - -# Reload the systemd configuration -sudo systemctl daemon-reload - -# Check for --mount-gcs argument -for arg in "$@" -do - case $arg in - --mount-gcs=*) - GCS_BUCKET="${arg#*=}" - shift - ;; - --dev) - DEV_MODE=true - shift - ;; - esac -done - -if [ -n "$GCS_BUCKET" ]; then - # URL of the file to download - FILE_URL="https://raw.githubusercontent.com/AshishKumar4/FlaxDiff/main/datasets/gcsfuse.sh" - # Local path to save the downloaded file - LOCAL_FILE="gcsfuse.sh" - - # Download the file - curl -o $LOCAL_FILE $FILE_URL - - # Make the script executable - chmod +x $LOCAL_FILE - echo "Mounting GCS bucket: $GCS_BUCKET to $HOME/gcs_mount" - # Run the script with the specified arguments - ./$LOCAL_FILE DATASET_GCS_BUCKET=$GCS_BUCKET MOUNT_PATH=$HOME/gcs_mount -fi - -if [ "$DEV_MODE" = true ]; then - # Create 'research' directory in the home folder - mkdir -p $HOME/research - - # Clone the repository into the 'research' directory - git clone git@github.com:AshishKumar4/FlaxDiff.git $HOME/research -else - # Download the training.py file into the home folder - wget -O $HOME/training.py https://github.com/AshishKumar4/FlaxDiff/raw/main/training.py -fi \ No newline at end of file diff --git a/tpu_utils/tpu_api.py b/tpu_utils/tpu_api.py deleted file mode 100644 index d94a6e4..0000000 --- a/tpu_utils/tpu_api.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Cloud TPU REST API basic functionality.""" -import os -import subprocess -import time -from typing import Any, Optional, List, Mapping - -import google.auth -import google.auth.transport.requests -import requests - -_TPU_BASE_URL = "https://tpu.googleapis.com/v2alpha1/" - - -def get_headers() -> Mapping[str, str]: - creds, _ = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"] - ) - creds.refresh(google.auth.transport.requests.Request()) - return {"Authorization": f"Bearer {creds.token}"} - - -def create_tpu( - tpu_name: str, - accelerator_type: str, - accelerator_topology: str, - zone: str, - project: str, - version: str, - startup_script: Optional[List[str]] = None, - block_until_completion: bool = True, - network: Optional[str] = "default", - subnetwork: Optional[str] = "default", - preemptible: bool = False, - reserved: bool = False, -): - """Creates a Cloud TPU. - - Note that this only supports TPU v4 creation right now due to - usage of acceleratorConfig(accelerator_type+accelerator_topology) rather than - solely accelerator_type. - - Args: - tpu_name: the TPU name. - accelerator_type: the TPU generation, e.g. V4. - accelerator_topology: the topology of the TPU. E.g. '4x4x4' - zone: the GCP zone. - project: the GCP project. - version: the TPU version, e.g. 'tpu_vm_v4_base'. - startup_script: an optional set of commands that will be concatenated to run - on TPU VM startup. - block_until_completion: Whether or not to wait until the operation has - finished running. - network: the network name the tpu_vm will use. - subnetwork: the subnetwork name the tpu_vm will use. - preemptible: whether to create preemptible TPUs. - reserved: whether to create reserved TPUs. - """ - if preemptible and reserved: - raise ValueError( - "Preemptible and Reserved cannot be set to True simultaneously" - ) - - tpu_node_url = os.path.join( - _TPU_BASE_URL, "projects", project, "locations", zone, "nodes" - ) - params = {"nodeId": tpu_name} - accelerator_config = dict(topology=accelerator_topology, type=accelerator_type) - if startup_script: - startup_script = "#! /bin/bash\n" + "\n".join(startup_script) - metadata = {"startup-script": startup_script} - else: - metadata = {} - - request = { - "accelerator_config": accelerator_config, - "runtimeVersion": version, - "networkConfig": { - "enableExternalIps": True, - "network": network, - "subnetwork": subnetwork, - }, - "metadata": metadata, - "schedulingConfig": { - "preemptible": preemptible, - "reserved": reserved, - }, - } - print("Creating TPU: ", tpu_name) - print("Request: ", request) - resp = requests.post( - tpu_node_url, params=params, json=request, headers=get_headers() - ) - resp.raise_for_status() - if block_until_completion: - create_op_url = os.path.join(_TPU_BASE_URL, resp.json()["name"]) - while not resp.json()["done"]: - print("Create TPU operation still running...") - time.sleep(30) - resp = requests.get(create_op_url, headers=get_headers()) - print("Create TPU operation complete.") - - -def list_tpus(project: str, zone: str) -> List[Mapping[str, Any]]: - """Lists all TPUs under a given project and zone. - - Args: - project: the GCP project. - zone: the GCP zone. - - Returns: - a string of JSON objects representing TPU VMs. - """ - tpu_node_url = os.path.join( - _TPU_BASE_URL, "projects", project, "locations", zone, "nodes" - ) - resp = requests.get(tpu_node_url, headers=get_headers()) - return resp.json()["nodes"] - - -def delete_tpu( - tpu_name: str, project: str, zone: str, block_until_completion: bool = True -): - """Deletes a Cloud TPU.""" - tpu_node_url = os.path.join( - _TPU_BASE_URL, "projects", project, "locations", zone, "nodes", tpu_name - ) - print("Deleting TPU: ", tpu_name) - resp = requests.delete(tpu_node_url, headers=get_headers()) - resp.raise_for_status() - if block_until_completion: - delete_op_url = os.path.join(_TPU_BASE_URL, resp.json()["name"]) - while not resp.json()["done"]: - print("Delete TPU operation still running...") - time.sleep(30) - resp = requests.get(delete_op_url, headers=get_headers()) - print("Delete TPU operation complete.") - - -def get_tpu(tpu_name: str, project: str, zone: str) -> Mapping[str, Any]: - """Gets the details of a Cloud TPU VM.""" - tpu_node_url = os.path.join( - _TPU_BASE_URL, "projects", project, "locations", zone, "nodes", tpu_name - ) - resp = requests.get(tpu_node_url, headers=get_headers()) - return resp.json() - - -def tpu_exists(tpu_name: str, project: str, zone: str) -> bool: - """Check whether a tpu exits or not.""" - resp = get_tpu(tpu_name, project, zone) - not_found = ( - "error" in resp - and "status" in resp["error"] - and "NOT_FOUND" == resp["error"]["status"] - ) - return not not_found - - -def update_tpu_startup_script( - tpu_name: str, - project: str, - zone: str, - startup_script: List[str], - block_until_completion: bool = True, -): - """Updates the TPU startup script.""" - tpu_node_url = os.path.join( - _TPU_BASE_URL, "projects", project, "locations", zone, "nodes", tpu_name - ) - params = { - "updateMask": "metadata", - } - startup_script = "#! /bin/bash\n" + "\n".join(startup_script) - metadata = {"startup-script": startup_script} - request = {"metadata": metadata} - print("Updating TPU: ", tpu_name) - print("Request: ", request) - resp = requests.patch( - tpu_node_url, headers=get_headers(), json=request, params=params - ) - resp.raise_for_status() - if block_until_completion: - create_op_url = os.path.join(_TPU_BASE_URL, resp.json()["name"]) - while not resp.json()["done"]: - print("Patch TPU operation still running...") - time.sleep(30) - resp = requests.get(create_op_url, headers=get_headers()) - print("Patch TPU operation complete.") - - -def get_default_gcp_project() -> str: - """Returns the default GCP project set in gcloud config.""" - return str( - subprocess.check_output("gcloud config get-value project", shell=True) - .strip() - .decode("utf-8") - ) diff --git a/tpu_utils/tpu_controller.py b/tpu_utils/tpu_controller.py deleted file mode 100644 index 8cbe9f4..0000000 --- a/tpu_utils/tpu_controller.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""TPU controller class for common TPU manipulation.""" -import functools -import multiprocessing -import os -import subprocess -from typing import List, Optional, Iterable, Callable, Any, Mapping, Union - -from absl import logging -from fabric import Connection -import patchwork.transfers - -import tpu_api - - -_SSH_KEYS_PATH = os.path.expanduser("~/.ssh/google_compute_engine") - - -def connect(ip_address: str) -> Connection: - return Connection( - ip_address, - connect_kwargs={ - "key_filename": _SSH_KEYS_PATH, - }, - ) - - -class TPUController: - """Generic TPU controller interface. - - Attributes: - tpu_name: the TPU name. - accelerator_type: the TPU generation, e.g. V4. - accelerator_topology: the topology of the TPU. E.g. '4x4x4' - zone: the GCP zone. - project: the GCP project. - version: the TPU version, e.g. 'tpu_vm_v4_base'. - startup_script: an optional set of commands that will be concatenated to run - on TPU VM startup. - """ - - def __init__( - self, - tpu_name: str, - zone: str, - project: str, - accelerator_type: str, - accelerator_topology: str, - version: str, - startup_script: Optional[List[str]], - network: Optional[str] = "default", - subnetwork: Optional[str] = "default", - preemptible: bool = False, - reserved: bool = False, - ): - self._tpu_name = tpu_name - self._zone = zone - self._project = project - self._accelerator_type = accelerator_type - self._accelerator_topology = accelerator_topology - self._version = version - self._startup_script = startup_script - self._ip_addresses = [] - self._connections = {} - self._network = network - self._subnetwork = subnetwork - self._preemptible = preemptible - self._reserved = reserved - - @property - def tpu_name(self) -> str: - return self._tpu_name - - def tpu_exists(self) -> bool: - """Checks if the TPU exists.""" - return tpu_api.tpu_exists( - tpu_name=self._tpu_name, project=self._project, zone=self._zone - ) - - def get_ip_addresses(self) -> List[str]: - """Returns the IP addresses of the workers in the cluster.""" - if not self._ip_addresses: - for endpoint in self.get_tpu()["networkEndpoints"]: - if "ipAddress" in endpoint: - self._ip_addresses.append(endpoint["ipAddress"]) - return self._ip_addresses - - def _maybe_configure_ssh_on_admin(self) -> str: - """Runs the bash command to generate necessary SSH keys on the admin VM.""" - if not os.path.exists(_SSH_KEYS_PATH): - subprocess.check_output("gcloud compute config-ssh", shell=True) - - def get_connections(self) -> Mapping[str, Connection]: - """Returns the mapping between IP and fabric.Connection.""" - if not self._connections: - self._maybe_configure_ssh_on_admin() - for ip_address in self.get_ip_addresses(): - self._connections[ip_address] = connect(ip_address) - return self._connections - - def create_tpu(self): - """Creates the TPU.""" - tpu_api.create_tpu( - tpu_name=self._tpu_name, - zone=self._zone, - project=self._project, - accelerator_type=self._accelerator_type, - accelerator_topology=self._accelerator_topology, - version=self._version, - startup_script=self._startup_script, - network=self._network, - subnetwork=self._subnetwork, - preemptible=self._preemptible, - reserved=self._reserved, - ) - self._ip_addresses.clear() - - def maybe_create_tpu(self) -> bool: - """Creates the TPU if it doesn't exist. - - Returns: - True if the TPU needed to be created, False otherwise. - """ - if not self.tpu_exists(): - self.create_tpu() - return True - return False - - def delete_tpu(self): - """Deletes the TPU.""" - tpu_api.delete_tpu( - tpu_name=self._tpu_name, project=self._project, zone=self._zone - ) - - def get_tpu(self): - """Gets the TPU info.""" - return tpu_api.get_tpu( - tpu_name=self._tpu_name, project=self._project, zone=self._zone - ) - - def get_health(self): - return self.get_tpu()["health"] - - def get_state(self): - return self.get_tpu()["state"] - - def _run_on_worker( - self, ip_address: str, commands: Iterable[str], verbose: bool = True - ): - """Runs command(s) on a single worker.""" - for command in commands: - logging.info("Running %s on %s", command, ip_address) - if command.startswith("sudo"): - # Strip 'sudo' from command - command = command[5:] - output = self.get_connections()[ip_address].sudo(command) - if verbose: - logging.info(f"{ip_address}: " + output.stdout) - else: - output = self.get_connections()[ip_address].run(command) - if verbose: - logging.info(f"{ip_address}: " + output.stdout) - - def _run_per_worker(self, fn: Callable[..., Any]): - """Runs a callable function for all workers.""" - with multiprocessing.Pool(processes=len(self.get_ip_addresses())) as p: - p.map(fn, self.get_ip_addresses()) - - def run_commands_on_workers(self, commands: Iterable[str]): - """Runs a list of commands for all workers.""" - self._run_per_worker(functools.partial(self._run_on_worker, commands=commands)) - - def _copy_files_to_worker(self, ip_address: str, files: Union[str, Iterable[str]]): - """Copies files to a single worker.""" - connection = self.get_connections()[ip_address] - for file in files: - if os.path.isdir(file): - patchwork.transfers.rsync( - connection, file, "~/", exclude=".git", strict_host_keys=False - ) - else: - connection.put(file) - - def copy_files_to_workers(self, files: Union[str, Iterable[str]]): - """Copies files to all workers.""" - if isinstance(files, str): - files = [files] - self._run_per_worker(functools.partial(self._copy_files_to_worker, files=files)) - - def _get_files_from_worker(self, ip_address: str, files: Union[str, Iterable[str]]): - """Gets files from a single worker.""" - connection = self.get_connections()[ip_address] - for file in files: - connection.get(file) - - def get_files_from_workers(self, files: Union[str, Iterable[str]]): - """Gets files from all workers.""" - if isinstance(files, str): - files = [files] - self._run_per_worker( - functools.partial(self._get_files_from_worker, files=files) - ) - - def get_num_nodes(self): - """Returns the number of hosts in the TPU pod.""" - return len(self.get_ip_addresses())