Bring back CI to a normal state (#82)
* Fix invalid dtype checks

* Fix invalid gemma quantization import

* Fix dtype again

* Quality

* Update transformers requirements to support gemma

* Add a few more to .dockerignore

* Remove unused test workflow
mfuntowicz authored Feb 27, 2024
1 parent bb13d65 commit 9601738
Showing 6 changed files with 19 additions and 30 deletions.
3 changes: 3 additions & 0 deletions .dockerignore
@@ -1,3 +1,6 @@
.env*/
third-party/*

**/*.engine
**/*.pyc
**/*.egg-info
15 changes: 0 additions & 15 deletions .github/workflows/test.yml

This file was deleted.

4 changes: 2 additions & 2 deletions .gitignore
@@ -120,8 +120,8 @@ celerybeat.pid
*.sage.py

# Environments
-.env
-.venv
+.env*/
+.venv/
env/
venv/
ENV/
2 changes: 1 addition & 1 deletion setup.py
@@ -35,7 +35,7 @@
"numpy >= 1.22.0",
"onnx >= 1.12.0",
"optimum >= 1.13.0",
"transformers >= 4.32.1",
"transformers >= 4.38.1",
"pynvml"
]

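Gemma model support first shipped in transformers 4.38, which is why the floor in setup.py moves from 4.32.1 to 4.38.1. A minimal sketch, not part of this commit, of how a runtime guard for that floor could look:

# Illustrative only: fail fast when the installed transformers predates the
# 4.38.1 floor declared in setup.py (the Gemma classes do not exist before 4.38).
from packaging.version import Version

import transformers

if Version(transformers.__version__) < Version("4.38.1"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is installed, "
        "but >= 4.38.1 is required for Gemma support"
    )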
2 changes: 1 addition & 1 deletion src/optimum/nvidia/models/gemma.py
@@ -20,14 +20,14 @@
import numpy as np
import tensorrt_llm
import torch
-from quantization import QuantMode
from tensorrt_llm import Mapping, str_dtype_to_torch
from tensorrt_llm._utils import numpy_to_torch, pad_vocab_size, torch_to_numpy
from tensorrt_llm.layers import MoeConfig
from tensorrt_llm.models import PretrainedConfig, PretrainedModel
from tensorrt_llm.models.gemma.model import GemmaForCausalLM as TrtGemmaForCausalLM
from tensorrt_llm.models.gemma.weight import dup_kv_weight, extract_layer_idx, split
from tensorrt_llm.plugin import PluginConfig
+from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime.lora_manager import LoraConfig
from transformers import GemmaForCausalLM as TransformersGemmaForCausalLM
from transformers import PretrainedConfig as TransformersPretrainedConfig
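The only change to gemma.py is where QuantMode comes from: the bare import was invalid, since QuantMode is provided by the tensorrt_llm package rather than a top-level quantization module. Side by side, as a plain Python snippet:

# Parent commit (broken): there is no top-level `quantization` module to import from
# from quantization import QuantMode

# This commit (fixed): QuantMode comes from the tensorrt_llm package
from tensorrt_llm.quantization import QuantMode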
23 changes: 12 additions & 11 deletions tests/test_dtype.py
@@ -24,15 +24,16 @@
@pytest.mark.parametrize(
"literal_dtype,dtype",
[
("int64", torch.int16),
("float32", torch.float32),
("bool", torch.bool),
("uint8", torch.uint8),
("int8", torch.int8),
# ("int16", torch.int16),
("int32", torch.int32),
("int64", torch.int64),
("float8", torch.float8_e4m3fn), # Change this when supported
("float16", torch.float16),
("bfloat16", torch.bfloat16),
("float8", torch.float8_e4m3fn), # Change this when supported
("int8", torch.int8),
("uint8", torch.uint8),
("bool", torch.bool),
("float32", torch.float32),
],
)
def test_convert_str_to_torch(literal_dtype: str, dtype):
@@ -42,14 +43,14 @@ def test_convert_str_to_torch(literal_dtype: str, dtype):
@pytest.mark.parametrize(
"literal_dtype,dtype",
[
("uint8", trt.uint8),
("int8", trt.int8),
("int32", trt.int32),
("int64", trt.int64),
("float32", trt.float32),
("float8", trt.fp8),
("float16", trt.float16),
("bfloat16", trt.bfloat16),
("int32", trt.int32),
("float8", trt.fp8),
("int8", trt.int8),
("uint8", trt.uint8),
("float32", trt.float32),
],
)
def test_convert_str_to_tensorrt(literal_dtype: str, dtype):
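Both tests exercise a string-to-dtype conversion helper, once targeting torch and once targeting TensorRT. The helper itself is not part of this diff; purely as an illustration, a lookup table consistent with the parametrized cases above could be sketched as:

# Illustrative sketch only -- the real conversion helper lives elsewhere in
# optimum-nvidia and is not shown in this commit. The tables mirror the
# parametrized test cases above.
import tensorrt as trt
import torch

_STR_TO_TORCH = {
    "bool": torch.bool,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int32": torch.int32,
    "int64": torch.int64,
    "float8": torch.float8_e4m3fn,  # requires torch >= 2.1
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

_STR_TO_TRT = {
    "uint8": trt.uint8,
    "int8": trt.int8,
    "int32": trt.int32,
    "int64": trt.int64,  # requires a TensorRT build that exposes int64
    "float8": trt.fp8,
    "float16": trt.float16,
    "bfloat16": trt.bfloat16,
    "float32": trt.float32,
}


def str_to_torch_dtype(literal: str) -> torch.dtype:
    # Resolve a literal such as "bfloat16" to the matching torch dtype.
    return _STR_TO_TORCH[literal]


def str_to_trt_dtype(literal: str) -> trt.DataType:
    # Resolve a literal such as "bfloat16" to the matching TensorRT dtype.
    return _STR_TO_TRT[literal]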
