diff --git a/torch2trt/flattener.py b/torch2trt/flattener.py index 7fbf4070..fca20595 100644 --- a/torch2trt/flattener.py +++ b/torch2trt/flattener.py @@ -3,7 +3,7 @@ def _default_condition(x): - return isinstance(x, torch.Tensor) and (x.dtype is torch.half or x.dtype is torch.float or x.dtype == torch.bool) + return isinstance(x, torch.Tensor) and (x.dtype is torch.half or x.dtype is torch.float or x.dtype == torch.bool or x.dtype == torch.int32 or x.dtype == torch.int64 or x.dtype == torch.long) def _make_schema_from_value(value, condition=_default_condition, size=0): @@ -90,4 +90,4 @@ def unflatten(self, flattened): result[child_key] = Flattener(child_schema, self.size).unflatten(flattened) return result else: - return None \ No newline at end of file + return None