Cleanup comments and redundant code.
monorimet committed Apr 11, 2024
1 parent 12b91f4 commit a3939e8
Showing 8 changed files with 18 additions and 356 deletions.
7 changes: 1 addition & 6 deletions core/shark_turbine/aot/builtins/jittable.py
@@ -214,11 +214,6 @@ def flat_wrapped_f(*args):
if "functorch_functionalize" in self._passes:
transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)

for node in transformed_f.graph.nodes: # type: ignore
if node.op == "call_function":
if node.target == torch._ops.ops.aten.lift_fresh_copy.default:
print(f"replaced lift_fresh_copy")
node.target = torch._ops.ops.aten.clone.default
transformed_f.recompile() # type: ignore

# Ask dynamo to give us an aten graph.
@@ -233,7 +228,7 @@ def flat_wrapped_f(*args):
)
logger.debug("Invoking dynamo trace")
gm, guards = exported_f(*flat_pytorch_args)
logger.debug("Dyanmo trace complete")
logger.debug("Dynamo trace complete")

# TODO: Add debug logging for the exported graph module.
# gm.print_readable()
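The block removed above was an FX graph rewrite that retargeted aten.lift_fresh_copy nodes to aten.clone before the dynamo trace. For reference, a minimal standalone sketch of that old workaround, applicable to any torch.fx.GraphModule; the helper name is illustrative, not part of the codebase:

import torch
import torch.fx

def retarget_lift_fresh_copy(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Walk the traced graph, swap the lift_fresh_copy op for clone as the
    # deleted loop did, then recompile the module's generated forward.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.lift_fresh_copy.default:
            node.target = torch.ops.aten.clone.default
    gm.recompile()
    return gm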

This file was deleted.

14 changes: 8 additions & 6 deletions models/turbine_models/custom_models/sd_inference/unet.py
@@ -101,15 +101,17 @@ def export_unet_model(
target_triple=None,
max_alloc=None,
upload_ir=False,
decomp_attn=True,
):
mapper = {}
decomp_list = DEFAULT_DECOMPOSITIONS
decomp_list.extend(
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
torch.ops.aten._scaled_dot_product_flash_attention.default,
]
)
if decomp_attn:
decomp_list.extend(
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
torch.ops.aten._scaled_dot_product_flash_attention.default,
]
)
dtype = torch.float16 if precision == "fp16" else torch.float32
unet_model = unet_model.to(dtype)
utils.save_external_weights(
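This hunk replaces the unconditional attention decompositions with a new decomp_attn argument (defaulting to True); the same gating is applied to export_vae_model further down. A minimal sketch of the pattern in isolation, with an assumed import path for DEFAULT_DECOMPOSITIONS and an illustrative helper name:

import torch
# Assumed location of DEFAULT_DECOMPOSITIONS; the diff only shows the name in use.
from shark_turbine.dynamo.passes import DEFAULT_DECOMPOSITIONS

def build_decomp_list(decomp_attn: bool = True):
    # Copying keeps repeated exports from growing the shared default list.
    decomp_list = list(DEFAULT_DECOMPOSITIONS)
    if decomp_attn:
        decomp_list.extend(
            [
                torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
                torch.ops.aten._scaled_dot_product_flash_attention.default,
            ]
        )
    return decomp_list

Leaving decomp_attn at its default preserves the previous behavior; passing decomp_attn=False keeps the scaled-dot-product attention ops out of the decomposition list.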
3 changes: 1 addition & 2 deletions models/turbine_models/custom_models/sd_inference/utils.py
@@ -8,12 +8,10 @@
EulerDiscreteScheduler,
)

winograd_params = "keys=unet.down_blocks.2.resnets.0.conv2.weight keys=unet.down_blocks.2.resnets.1.conv1.weight keys=unet.down_blocks.2.resnets.1.conv2.weight keys=unet.mid_block.resnets.0.conv1.weight keys=unet.mid_block.resnets.0.conv2.weight keys=unet.mid_block.resnets.1.conv1.weight keys=unet.mid_block.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.0.conv2.weight keys=unet.up_blocks.0.resnets.1.conv2.weight keys=unet.up_blocks.0.resnets.2.conv2.weight keys=unet.up_blocks.0.resnets.0.conv1.weight keys=unet.up_blocks.0.resnets.1.conv1.weight keys=unet.up_blocks.0.resnets.2.conv1.weight keys=unet.up_blocks.0.upsamplers.0.conv.weight"
# If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument.
gfx94X_flags = {
"all": [
"--iree-global-opt-propagate-transposes=true",
"--iree-opt-const-eval=false",
"--iree-opt-outer-dim-concat=true",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
"--iree-vm-target-truncate-unsupported-floats",
@@ -95,6 +93,7 @@ def compile_to_vmfb(
"--iree-hal-target-backends=rocm",
"--iree-rocm-target-chip=" + target_triple,
"--verify=false",
"--iree-opt-const-eval=false",
]
)
elif device == "cuda":
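Two flag changes land here: the winograd_params string is dropped, and --iree-opt-const-eval=false moves out of the shared gfx94X_flags["all"] list into the rocm compile path. Per the comment retained above the dictionary, only verified, performance-positive flags belong in gfx94X_flags, with experimental ones passed through --ireec_flags instead. A hedged sketch of that layering; the helper name, the extra_ireec_flags parameter, and the comma-separated format are assumptions, not the repository's API:

# Assumed import path, derived from the file's location in the tree.
from turbine_models.custom_models.sd_inference.utils import gfx94X_flags

def assemble_rocm_flags(target_triple: str, extra_ireec_flags: str = "") -> list:
    flags = [
        "--iree-hal-target-backends=rocm",
        "--iree-rocm-target-chip=" + target_triple,
        "--verify=false",
        "--iree-opt-const-eval=false",  # now set per target rather than in gfx94X_flags["all"]
    ]
    if "gfx94" in target_triple:
        flags.extend(gfx94X_flags["all"])  # verified flags only
    if extra_ireec_flags:
        # Bleeding-edge flags supplied by the user, assumed comma separated.
        flags.extend(f.strip() for f in extra_ireec_flags.split(",") if f.strip())
    return flags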
14 changes: 8 additions & 6 deletions models/turbine_models/custom_models/sd_inference/vae.py
@@ -116,15 +116,17 @@ def export_vae_model(
max_alloc=None,
variant="decode",
upload_ir=False,
decomp_attn=True,
):
mapper = {}
decomp_list = DEFAULT_DECOMPOSITIONS
decomp_list.extend(
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
torch.ops.aten._scaled_dot_product_flash_attention.default,
]
)
if decomp_attn:
decomp_list.extend(
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
torch.ops.aten._scaled_dot_product_flash_attention.default,
]
)
dtype = torch.float16 if precision == "fp16" else torch.float32
vae_model = vae_model.to(dtype)
utils.save_external_weights(