add type conversions [feature request] #30

Open
PeterMitrano opened this issue Apr 4, 2021 · 4 comments

Comments

@PeterMitrano

Converting between dtypes such as float32/float64/int32/int64 would be a really nice addition!

@jonasrauber
Owner

We do have support for astype(dtype), which allows you to do type conversions. But you are right, it's not clear if/how this works with specific dtypes. So far, we only use astype together with another tensor's dtype attribute, which is always compatible with the backend. Maybe that already solves your use case.
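A minimal sketch of that "cast to another tensor's dtype" pattern, under the assumption that both tensors come from the same backend. `cast_like` is a hypothetical helper name; it relies only on `astype` and `dtype`, so NumPy arrays stand in for EagerPy tensors here to keep the example self-contained.

```python
import numpy as np


def cast_like(x, reference):
    """Cast x to reference's dtype (hypothetical helper).

    reference.dtype is already the backend's own dtype object, so the same
    backend's astype accepts it; no framework-specific constant is needed.
    """
    return x.astype(reference.dtype)


ints = np.array([1, 2, 3])
floats = np.array([0.5, 1.5], dtype=np.float32)
print(cast_like(ints, floats).dtype)  # float32
```

This stays agnostic only as long as a reference tensor with the desired dtype is at hand, which is exactly the limitation raised in the next comment.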

@PeterMitrano
Author

Sure, astype exists, but IMHO it's not useful as is. Currently I have to write my_int_tensor.astype(torch.float32), which defeats the whole point of being framework-agnostic.

Here's how I think it should behave (not sure if this makes sense):

def foo(my_int_tensor: eagerpy.Tensor):
    # proposed API: a framework-agnostic eagerpy.float32 dtype
    return my_int_tensor.astype(eagerpy.float32)

Here's a simple wrapper that does what I want, though it's not a good implementation:

import eagerpy
import torch
import tensorflow as tf

def my_astype(x: eagerpy.Tensor, dtype: str) -> eagerpy.Tensor:
    # Map a framework-agnostic dtype name to the native dtype of x's backend.
    if dtype == 'float32':
        if isinstance(x.raw, torch.Tensor):
            specific_dtype = torch.float32
        elif isinstance(x.raw, tf.Tensor):
            specific_dtype = tf.float32
        else:
            # more backends here...
            raise NotImplementedError(f"unsupported backend: {type(x.raw)}")
        return x.astype(specific_dtype)
    else:
        # more dtypes here...
        raise NotImplementedError(f"unsupported dtype: {dtype}")

# PyTorch: int64 -> float32
x = torch.tensor([1, 2, 3])
ex = eagerpy.astensor(x)

print(ex.dtype)
print(my_astype(ex, 'float32').dtype)

# TensorFlow: int32 -> float32
x = tf.constant([1, 2, 3])
ex = eagerpy.astensor(x)

print(ex.dtype)
print(my_astype(ex, 'float32').dtype)
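One way to avoid the growing if/elif chain is a lookup by name, sketched below under two assumptions: that the backends expose their dtypes under the same attribute names (torch.float32, tf.float32, np.float32), and that the backend can be identified from the top-level package of the raw tensor's type. `native_dtype`, `SUPPORTED_DTYPES` and `DTYPE_MODULES` are hypothetical names, not part of EagerPy; JAX is deliberately left out.

```python
import sys

import numpy as np

# Hypothetical: the framework-agnostic dtype names this sketch accepts.
SUPPORTED_DTYPES = {"float16", "float32", "float64", "int32", "int64"}

# Top-level package of a raw tensor's type -> module exposing dtype constants.
# A dict rather than a set, so a backend whose dtypes live in a different
# module could be mapped there. JAX is omitted (assumption: its array type
# lives in version-dependent modules).
DTYPE_MODULES = {"numpy": "numpy", "torch": "torch", "tensorflow": "tensorflow"}


def native_dtype(raw, name: str):
    """Look up the backend-native dtype called `name` for the raw tensor `raw`."""
    if name not in SUPPORTED_DTYPES:
        raise ValueError(f"unsupported dtype: {name!r}")
    package = type(raw).__module__.split(".")[0]
    if package not in DTYPE_MODULES:
        raise NotImplementedError(f"unsupported backend: {package}")
    # The backend is already imported, since `raw` is one of its objects.
    return getattr(sys.modules[DTYPE_MODULES[package]], name)


def my_astype(x, name: str):
    """Cast an EagerPy tensor (anything with .raw and .astype) to dtype `name`."""
    return x.astype(native_dtype(x.raw, name))


raw = np.array([1, 2, 3])
print(raw.astype(native_dtype(raw, "float32")).dtype)  # float32
```

With EagerPy installed, the intended call would be `my_astype(eagerpy.astensor(torch.tensor([1, 2, 3])), "float32")`; the string names play the role of the proposed eagerpy.float32 constant.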

@jonasrauber
Owner

Yes, I fully agree. For float32 specifically, we actually also have a tensor.float32() method (but not yet for the other dtypes).

@PeterMitrano
Author

PeterMitrano commented Apr 7, 2021 via email

2 participants