diff --git a/core/shark_turbine/dynamo/passes.py b/core/shark_turbine/dynamo/passes.py index 80be06ae8..cebbacd2f 100644 --- a/core/shark_turbine/dynamo/passes.py +++ b/core/shark_turbine/dynamo/passes.py @@ -54,7 +54,9 @@ # These decompositions either didnt exist or weren't required for 2.1.0 if torch.__version__ > "2.1.0": - DEFAULT_DECOMPOSITIONS.append(torch.ops.aten._scaled_dot_product_flash_attention_for_cpu) + DEFAULT_DECOMPOSITIONS.append( + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu + ) DEFAULT_DECOMPOSITIONS.append(torch.ops.aten.unbind_int)