Skip to content

Commit

Permalink
Adding isort (#2375)
Browse files Browse the repository at this point in the history
* adding isort and executing

* add worflows

* move out import kubernetes from utils to skylet in order to avoid circular import

* adding llm folder to yapf

* avoid circular import and isort fix

* add docs in isort

* Apply suggestions from code review

Co-authored-by: Zongheng Yang <[email protected]>

* fix gcp isort

* remove operator import

* Update sky/clouds/__init__.py

Co-authored-by: Zongheng Yang <[email protected]>

* remove isort for providers

* checkout form master

---------

Co-authored-by: Zongheng Yang <[email protected]>
  • Loading branch information
gbmarc1 and concretevitamin authored Aug 11, 2023
1 parent 5dd9aa1 commit 47ecf99
Show file tree
Hide file tree
Showing 120 changed files with 443 additions and 439 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
pip install yapf==0.32.0
pip install toml==0.10.2
pip install black==22.10.0
pip install isort==5.12.0
- name: Running yapf
run: |
yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \
Expand All @@ -42,3 +43,14 @@ jobs:
sky/skylet/providers/gcp/ \
sky/skylet/providers/azure/ \
sky/skylet/providers/ibm/
- name: Running isort for black formatted files
run: |
isort --diff --check --profile black -l 88 -m 3 \
sky/skylet/providers/ibm/
- name: Running isort for yapf formatted files
run: |
isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \
--sg 'sky/skylet/providers/aws/**' \
--sg 'sky/skylet/providers/gcp/**' \
--sg 'sky/skylet/providers/azure/**' \
--sg 'sky/skylet/providers/ibm/**'
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Configuration file for the Sphinx documentation builder.

import sys
import os
import sys

sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('../'))
Expand Down
3 changes: 2 additions & 1 deletion examples/docker/echo_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
# python echo_app.py

import random
import sky
import string

import sky

with sky.Dag() as dag:
# The setup command to build the container image
setup = 'docker build -t echo:v0 /echo_app'
Expand Down
4 changes: 2 additions & 2 deletions examples/example_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
Incorporate the notion of region/zone (affects pricing).
Incorporate the notion of per-account egress quota (affects pricing).
"""
import sky

import time_estimators

import sky


def make_application():
"""A simple application: train_op -> infer_op."""
Expand Down
3 changes: 2 additions & 1 deletion examples/horovod_distributed_tf_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import json
from typing import Dict, List

import sky
import time_estimators

import sky

IPAddr = str

with sky.Dag() as dag:
Expand Down
2 changes: 1 addition & 1 deletion examples/local/launch_cloud_onprem.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import tempfile
import textwrap
import uuid
import yaml

from click import testing as cli_testing
import yaml

from sky import cli
from sky import global_user_state
Expand Down
3 changes: 2 additions & 1 deletion examples/playground/storage_playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# These are not exhaustive tests. Actual Tests are in tests/test_storage.py and
# tests/test_smoke.py.

from sky.data import storage, StoreType
from sky.data import storage
from sky.data import StoreType


def get_args():
Expand Down
11 changes: 5 additions & 6 deletions examples/ray_tune_examples/tune_ptl_example.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
### Source: https://docs.ray.io/en/latest/tune/examples/mnist_ptl_mini.html
import math
import os

import torch
from filelock import FileLock
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback

import pytorch_lightning as pl
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import torch
from torch.nn import functional as F


class LightningMNISTClassifier(pl.LightningModule):
Expand Down
22 changes: 13 additions & 9 deletions examples/spot/lightning_cifar10/train.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
# Code modified from https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/cifar10-baseline.html

import argparse
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import LightningModule
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.optim.swa_utils import AveragedModel
from torch.optim.swa_utils import update_bn
from torchmetrics.functional import accuracy

import argparse, glob
import torchvision

seed_everything(7)

Expand Down
16 changes: 7 additions & 9 deletions examples/spot/resnet_ddp/resnet_ddp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import argparse
import os
import random

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np

import wandb


Expand Down
4 changes: 2 additions & 2 deletions examples/tpu/tpu_app_code/run_tpu.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import tensorflow_datasets as tfds
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
from transformers import TFDistilBertForSequenceClassification
from transformers import TFBertForSequenceClassification
from transformers import TFDistilBertForSequenceClassification

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
Expand Down
21 changes: 18 additions & 3 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ YAPF_EXCLUDES=(
'--exclude' 'sky/skylet/providers/ibm/**'
)

ISORT_YAPF_EXCLUDES=(
'--sg' 'build/**'
'--sg' 'sky/skylet/providers/aws/**'
'--sg' 'sky/skylet/providers/gcp/**'
'--sg' 'sky/skylet/providers/azure/**'
'--sg' 'sky/skylet/providers/ibm/**'
)

BLACK_INCLUDES=(
'sky/skylet/providers/aws'
'sky/skylet/providers/gcp'
Expand Down Expand Up @@ -86,9 +94,12 @@ format_changed() {

# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sky tests examples
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" sky tests examples llm
}

echo 'SkyPilot Black:'
black "${BLACK_INCLUDES[@]}"

## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
Expand All @@ -102,8 +113,12 @@ else
format_changed
fi
echo 'SkyPilot yapf: Done'
echo 'SkyPilot Black:'
black "${BLACK_INCLUDES[@]}"

echo 'SkyPilot isort:'
isort sky tests examples llm docs "${ISORT_YAPF_EXCLUDES[@]}"

isort --profile black -l 88 -m 3 "sky/skylet/providers/ibm"


# Run mypy
# TODO(zhwu): When more of the codebase is typed properly, the mypy flags
Expand Down
13 changes: 6 additions & 7 deletions llm/vicuna-llama-2/scripts/flash_attn_patch.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import List, Optional, Tuple
import logging
from typing import List, Optional, Tuple

from einops import rearrange
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
# pip3 install "flash-attn>=2.0"
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
import torch
from torch import nn

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
from flash_attn.flash_attn_interface import ( # pip3 install "flash-attn>=2.0"
flash_attn_varlen_qkvpacked_func,)
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
self,
Expand Down
10 changes: 5 additions & 5 deletions llm/vicuna-llama-2/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from dataclasses import dataclass
from dataclasses import field
import json
import pathlib
import os
import pathlib
import shutil
import subprocess
from typing import Dict, Optional

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


Expand Down
3 changes: 1 addition & 2 deletions llm/vicuna-llama-2/scripts/train_flash_attn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from flash_attn_patch import (
replace_llama_attn_with_flash_attn,)
from flash_attn_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

Expand Down
3 changes: 1 addition & 2 deletions llm/vicuna-llama-2/scripts/train_xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from xformers_patch import (
replace_llama_attn_with_xformers_attn,)
from xformers_patch import replace_llama_attn_with_xformers_attn

replace_llama_attn_with_xformers_attn()

Expand Down
2 changes: 1 addition & 1 deletion llm/vicuna-llama-2/scripts/xformers_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from typing import Optional, Tuple

import torch
import transformers.models.llama.modeling_llama
from torch import nn
import transformers.models.llama.modeling_llama

try:
import xformers.ops
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ python_version = "3.8"
follow_imports = "skip"
ignore_missing_imports = true
allow_redefinition = true

[tool.isort]
profile = "google"
line_length = 80
multi_line_output = 0
combine_as_imports = true
use_parentheses = true
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ black==22.10.0
# https://github.com/edaniszewski/pylint-quotes
pylint-quotes==0.2.3
toml==0.10.2
isort==5.12.0

# type checking
mypy==0.991
Expand Down
35 changes: 26 additions & 9 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,35 @@
from sky import benchmark
from sky import clouds
from sky.clouds.service_catalog import list_accelerators
from sky.core import autostop
from sky.core import cancel
from sky.core import cost_report
from sky.core import down
from sky.core import download_logs
from sky.core import job_status
from sky.core import queue
from sky.core import spot_cancel
from sky.core import spot_queue
from sky.core import spot_status
from sky.core import start
from sky.core import status
from sky.core import stop
from sky.core import storage_delete
from sky.core import storage_ls
from sky.core import tail_logs
from sky.dag import Dag
from sky.execution import launch, exec, spot_launch # pylint: disable=redefined-builtin
from sky.data import Storage
from sky.data import StorageMode
from sky.data import StoreType
from sky.execution import exec # pylint: disable=redefined-builtin
from sky.execution import launch
from sky.execution import spot_launch
from sky.optimizer import Optimizer
from sky.optimizer import OptimizeTarget
from sky.resources import Resources
from sky.task import Task
from sky.optimizer import Optimizer, OptimizeTarget
from sky.data import Storage, StorageMode, StoreType
from sky.status_lib import ClusterStatus
from sky.skylet.job_lib import JobStatus
from sky.core import (status, start, stop, down, autostop, queue, cancel,
tail_logs, download_logs, job_status, spot_queue,
spot_status, spot_cancel, storage_ls, storage_delete,
cost_report)
from sky.status_lib import ClusterStatus
from sky.task import Task

# Aliases.
IBM = clouds.IBM
Expand Down
2 changes: 1 addition & 1 deletion sky/adaptors/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import contextlib
import functools
import threading
import os
import threading
from typing import Dict, Optional, Tuple

from sky.utils import ux_utils
Expand Down
2 changes: 1 addition & 1 deletion sky/adaptors/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def wrapper(*args, **kwargs):
global googleapiclient, google
if googleapiclient is None or google is None:
try:
import googleapiclient as _googleapiclient
import google as _google
import googleapiclient as _googleapiclient
googleapiclient = _googleapiclient
google = _google
except ImportError:
Expand Down
Loading

0 comments on commit 47ecf99

Please sign in to comment.