Skip to content

Commit

Permalink
add bert-base-uncased_fp16 to shark_tank
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jan 10, 2023
1 parent ec7b19d commit 2ea5fa6
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 21 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ jobs:
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pip uninstall torch torchvision
pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
Expand Down
2 changes: 1 addition & 1 deletion build_tools/populate_sharktank_ci.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash

IMPORTER=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py --upload=False --ci_tank_dir=True
6 changes: 5 additions & 1 deletion setup_venv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,12 @@ fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/

if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
TORCH_VERSION=${T_VER:9:17}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu117
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu117/torch-${TORCH_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu117/torchvision-${TV_VERSION}%2Bcu117-cp310-cp310-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu117."
else
Expand Down
5 changes: 4 additions & 1 deletion shark/shark_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,10 @@ def benchmark_all_csv(
else:
bench_result["shape_type"] = "static"
bench_result["device"] = device_str
bench_result["data_type"] = inputs[0].dtype
if "fp16" in modelname:
bench_result["data_type"] = "float16"
else:
bench_result["data_type"] = inputs[0].dtype
for e in engines:
(
bench_result["param_count"],
Expand Down
2 changes: 1 addition & 1 deletion tank/all_models.csv
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with at
alexnet,linalg,torch,1e-2,1e-3,default,None,False,False,True,"Assertion Error: Zeros Output"
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,""
bert-base-uncased_fp16,linalg,torch,1e-2,1e-2,default,None,True,True,True,""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,False,True,""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile."
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311"
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390"
Expand Down
28 changes: 13 additions & 15 deletions tank/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,14 @@ def get_vision_model(torch_model):
"mnasnet1_0": models.mnasnet1_0(weights="DEFAULT"),
}
if isinstance(torch_model, str):
fp16_model = None
if "fp16" in torch_model:
fp16_model = True
torch_model = vision_models_dict[torch_model]
model = VisionModule(torch_model)
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)
if fp16_model == True:
if fp16_model is not None:
test_input_fp16 = test_input.to(
device=torch.device("cuda"), dtype=torch.half
)
Expand All @@ -187,17 +188,15 @@ class BertHalfPrecisionModel(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
from transformers import AutoModelForMaskedLM
import transformers as trf

transformers_path = trf.__path__[0]
hf_model_path = f"{transformers_path}/models/{hf_model_name}"
self.model = AutoModelForMaskedLM.from_pretrained(
hf_model_name, # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
torch_dtype=torch.float16,
).to("cuda")

def forward(self, tokens):
return self.model.forward(tokens)[0]
Expand All @@ -210,22 +209,21 @@ def get_fp16_model(torch_model):
model = BertHalfPrecisionModel(modelname)
tokenizer = AutoTokenizer.from_pretrained(modelname)
text = "Replace me by any text you like."
encoded_input = tokenizer(
test_input_fp16 = tokenizer(
text,
truncation=True,
max_length=128,
return_tensors="pt",
)
for key in encoded_input:
encoded_input[key] = (
encoded_input[key].detach().numpy().astype(np.half)
)

).input_ids.to("cuda")
# test_input = torch.randint(2, (1, 128))
# test_input_fp16 = test_input.to(
# device=torch.device("cuda")
# )
model_fp16 = model.half()
model_fp16.eval()
model_fp16.to("cuda")
actual_out_fp16 = model_fp16(encoded_input)
return model_fp16, encoded_input, actual_out_fp16
with torch.no_grad():
actual_out_fp16 = model_fp16(test_input_fp16)
return model_fp16, test_input_fp16, actual_out_fp16


# Utility function for comparing two tensors (torch).
Expand Down
1 change: 1 addition & 0 deletions tank/torch_model_list.csv
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,86M,"image-cla
nvidia/mit-b0,True,hf_img_cls,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"

0 comments on commit 2ea5fa6

Please sign in to comment.