Skip to content

Commit

Permalink
Merge branch 'main' into misc-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 27, 2024
2 parents 13bc8c0 + 9104793 commit 1b68c85
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 10 deletions.
6 changes: 3 additions & 3 deletions docker/cpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ ARG TORCH_VERSION=""
ARG TORCH_RELEASE_TYPE=stable

RUN if [ -n "${TORCH_VERSION}" ]; then \
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ; \
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/cpu ; \
elif [ "${TORCH_RELEASE_TYPE}" = "stable" ]; then \
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ; \
pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/cpu ; \
elif [ "${TORCH_RELEASE_TYPE}" = "nightly" ]; then \
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu ; \
pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/cpu ; \
else \
echo "Error: Invalid TORCH_RELEASE_TYPE. Must be 'stable', 'nightly', or specify a TORCH_VERSION." && exit 1 ; \
fi
6 changes: 3 additions & 3 deletions docker/cuda-ort/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ ARG TORCH_CUDA=cu118
ARG TORCH_VERSION=stable

RUN if [ "${TORCH_VERSION}" = "stable" ]; then \
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
elif [ "${TORCH_VERSION}" = "nightly" ]; then \
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
else \
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
fi

# Install torch-ort and onnxruntime-training
Expand Down
6 changes: 3 additions & 3 deletions docker/cuda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ ARG TORCH_CUDA=cu124
ARG TORCH_RELEASE_TYPE=stable

RUN if [ -n "${TORCH_VERSION}" ]; then \
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
elif [ "${TORCH_RELEASE_TYPE}" = "stable" ]; then \
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
elif [ "${TORCH_RELEASE_TYPE}" = "nightly" ]; then \
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
else \
echo "Error: Invalid TORCH_RELEASE_TYPE. Must be 'stable', 'nightly', or specify a TORCH_VERSION." && exit 1 ; \
fi
Expand Down
5 changes: 5 additions & 0 deletions examples/pytorch_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
"quantization_scheme": "gptq",
"quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
},
"torchao-int4wo-128": {
"torch_dtype": "bfloat16",
"quantization_scheme": "torchao",
"quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
}
}


Expand Down
13 changes: 13 additions & 0 deletions optimum_benchmark/backends/pytorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TrainerState,
TrainingArguments,
)
from transformers import TorchAoConfig

from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
from ..base import Backend
Expand Down Expand Up @@ -323,6 +324,11 @@ def process_quantization_config(self) -> None:
self.quantization_config = BitsAndBytesConfig(
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)
elif self.is_torchao_quantized:
self.logger.info("\t+ Processing TorchAO config")
self.quantization_config = TorchAoConfig(
**dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)
else:
raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")

Expand Down Expand Up @@ -354,6 +360,13 @@ def is_awq_quantized(self) -> bool:
and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
)

@property
def is_torchao_quantized(self) -> bool:
    """Tell whether TorchAO quantization applies to this backend.

    Returns:
        True when the benchmark config's ``quantization_scheme`` is
        ``"torchao"``, or when the pretrained config carries a
        ``quantization_config`` whose ``quant_method`` entry is
        ``"torchao"``; False otherwise.
    """
    # An explicitly requested scheme wins without consulting the checkpoint.
    if self.config.quantization_scheme == "torchao":
        return True

    # Otherwise rely on what the pretrained config declares, if anything.
    if not hasattr(self.pretrained_config, "quantization_config"):
        return False

    declared_method = self.pretrained_config.quantization_config.get("quant_method", None)
    return declared_method == "torchao"

@property
def is_exllamav2(self) -> bool:
return (
Expand Down
2 changes: 1 addition & 1 deletion optimum_benchmark/backends/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
AMP_DTYPES = ["bfloat16", "float16"]
TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]

QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}}
QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}, "torchao": {}}


@dataclass
Expand Down

0 comments on commit 1b68c85

Please sign in to comment.