Update constraints to torch 2.4 dynamic_shapes API (#806)
This commit switches dynamic dimensions over to the dynamic_shapes API, since the
constraints argument is deprecated in torch 2.4. (The SHARK test checks out the
main branch of this repo and will pass once this change is merged.)
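For context, a minimal standalone sketch of the torch 2.4 API being adopted here. The toy module, sizes, and the "x" key below are illustrative, not from this repo; the turbine jittable path keys dynamic_shapes by FX placeholder names such as "arg0_1", while plain torch.export.export keys them by the forward() parameter name.

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

# torch 2.4 style: declare a named symbolic dimension and pass it via
# dynamic_shapes, replacing the deprecated constraints=[x.dynamic_dim(0) < 16].
batch = torch.export.Dim("batch", max=15)
exported = torch.export.export(
    Toy(),
    (torch.randn(4, 3, 224, 224),),
    dynamic_shapes={"x": {0: batch}},  # keyed by forward()'s parameter name
)
print(exported.graph_signature)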
saienduri authored Aug 7, 2024
1 parent 26ce08e commit 5e8cbb7
Showing 2 changed files with 24 additions and 21 deletions.
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/resnet_18.py
@@ -56,8 +56,8 @@ class CompiledResnet18Model(CompiledModule):
         params = export_parameters(resnet_model.model)

         def main(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
-            const = [x.dynamic_dim(0) < 16]
-            return jittable(resnet_model.forward)(x, constraints=const)
+            dynamic_shapes = {"arg0_1": {0: torch.export.Dim("dim", max=15)}}
+            return jittable(resnet_model.forward)(x, dynamic_shapes=dynamic_shapes)

     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
     inst = CompiledResnet18Model(context=Context(), import_to=import_to)
41 changes: 22 additions & 19 deletions models/turbine_models/custom_models/stateless_llama.py
@@ -237,8 +237,10 @@ class StateUpdateModule(CompiledModule):
         def run_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
-            init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
-            token, *state = self.initialize(x, constraints=init_const)
+            dynamic_shapes_init = {
+                "arg0_1": {1: torch.export.Dim("dim", max=MAX_STEP_SEQ - 1)}
+            }
+            token, *state = self.initialize(x, dynamic_shapes=dynamic_shapes_init)
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
@@ -267,16 +269,15 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 HIDDEN_DIM,
                 NUM_LAYERS,
             )
-            forw_const = (
-                [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
-                + [
-                    x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
-                    for x in state_arg[1:]
-                ]
-                + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
-            )
+            state_arg0_dim = torch.export.Dim(
+                "state_arg0_dim", max=MAX_STEP_SEQ - 1
+            )
+            dynamic_shapes_forw = {"arg0_1": None, "arg1_1": {1: state_arg0_dim}}
+            for state_arg_idx in range(2, len(state_arg) + 1):
+                current_dim_dict = {f"arg{state_arg_idx}_1": {1: state_arg0_dim}}
+                dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
             token, *state_update = self.forward(
-                x, *state_arg, constraints=forw_const
+                x, *state_arg, dynamic_shapes=dynamic_shapes_forw
             )
             for i in range(NUM_LAYERS):
                 update = IREE.tensor_reshape(
@@ -343,17 +344,19 @@ def run_cached_initialize(
                 HIDDEN_DIM,
                 NUM_LAYERS,
             )
-            forw_const = (
-                [x.dynamic_dim(1) < MAX_STEP_SEQ]
-                + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
-                + [
-                    x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
-                    for x in state_arg[1:]
-                ]
-                + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
-            )
+            state_arg0_dim1 = torch.export.Dim(
+                "state_arg0_dim1", max=MAX_STEP_SEQ - 1
+            )
+            x_dim = torch.export.Dim("x_dim", max=MAX_STEP_SEQ - 1)
+            dynamic_shapes_forw = {
+                "arg0_1": {1: x_dim},
+                "arg1_1": {1: state_arg0_dim1},
+            }
+            for state_arg_idx in range(2, len(state_arg) + 1):
+                current_dim_dict = {f"arg{state_arg_idx}_1": {1: state_arg0_dim1}}
+                dynamic_shapes_forw = {**dynamic_shapes_forw, **current_dim_dict}
             token, *state = self.cached_initialize(
-                x, *state_arg, constraints=forw_const
+                x, *state_arg, dynamic_shapes=dynamic_shapes_forw
             )
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
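A side note on the new scheme (my reading of the diff, not something the commit states): reusing one torch.export.Dim object for dim 1 of every cache tensor expresses the same equality that the old constraints wrote as x.dynamic_dim(1) == state_arg[0].dynamic_dim(1). A rough standalone sketch of that idea, with made-up sizes and placeholder names:

import torch

# Hypothetical stand-ins for the repo's MAX_STEP_SEQ and number of cache tensors.
MAX_STEP_SEQ = 512
NUM_STATE_ARGS = 4

seq = torch.export.Dim("seq", max=MAX_STEP_SEQ - 1)

# Every entry that references the same Dim object is exported with the same
# symbolic size, so the per-tensor equality constraints become implicit.
dynamic_shapes = {"arg0_1": None}
dynamic_shapes.update({f"arg{i}_1": {1: seq} for i in range(1, NUM_STATE_ARGS + 1)})
print(dynamic_shapes)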
