try reverting pytorch change
Signed-off-by: Yee Hing Tong <[email protected]>
wild-endeavor committed Oct 22, 2024
1 parent 7641697 commit 540a8d2
Showing 1 changed file with 6 additions and 6 deletions.
flytekit/extras/pytorch/native.py (6 additions, 6 deletions)

@@ -4,15 +4,15 @@
 import torch

 from flytekit.core.context_manager import FlyteContext
-from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError
+from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
 from flytekit.models.core import types as _core_types
 from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
 from flytekit.models.types import LiteralType

 T = TypeVar("T")


-class PyTorchTypeTransformer(AsyncTypeTransformer[T]):
+class PyTorchTypeTransformer(TypeTransformer[T]):
     def get_literal_type(self, t: Type[T]) -> LiteralType:
         return LiteralType(
             blob=_core_types.BlobType(
@@ -21,7 +21,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
             )
         )

-    async def async_to_literal(
+    def to_literal(
         self,
         ctx: FlyteContext,
         python_val: T,
@@ -44,17 +44,17 @@ async def async_to_literal(
         # save pytorch tensor/module to a file
         torch.save(python_val, local_path)

-        remote_path = await ctx.file_access.async_put_raw_data(local_path)
+        remote_path = ctx.file_access.put_raw_data(local_path)
         return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

-    async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
         try:
             uri = lv.scalar.blob.uri
         except AttributeError:
             TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

         local_path = ctx.file_access.get_random_local_path()
-        await ctx.file_access.async_get_data(uri, local_path, is_multipart=False)
+        ctx.file_access.get_data(uri, local_path, is_multipart=False)

         # cpu <-> gpu conversion
         if torch.cuda.is_available():
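For readers following along: the revert means the PyTorch transformer's conversion hooks are plain synchronous methods again (TypeTransformer.to_literal / to_python_value rather than the Async variants). The snippet below is a minimal sketch, not part of this commit, of a local tensor round trip through TypeEngine under that assumption; it assumes torch and flytekit are installed and that importing flytekit.extras.pytorch registers the PyTorch transformers.

# Minimal sketch (not part of this commit): round-trip a tensor through the
# synchronous transformer path via TypeEngine. Assumes torch and flytekit are
# installed; importing flytekit.extras.pytorch is assumed to register the transformers.
import torch

import flytekit.extras.pytorch  # noqa: F401  (assumed to register the PyTorch transformers)
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine

ctx = FlyteContextManager.current_context()
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# After the revert, the transformer's to_literal/to_python_value are ordinary sync calls.
lt = TypeEngine.to_literal_type(torch.Tensor)
lit = TypeEngine.to_literal(ctx, tensor, torch.Tensor, lt)
restored = TypeEngine.to_python_value(ctx, lit, torch.Tensor)
assert torch.equal(tensor.cpu(), restored.cpu())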
