Skip to content

Commit

Permalink
Force features to be floats instead of doubles
Browse files Browse the repository at this point in the history
As per @caitlynlee in google-research#7 , Make predictions give errors for various tensors as `expected dtype double does not equal original dtype float`. Forcing tensors  to be floats instead of doubles in `astronet/predict.py ` solves the issue.

```
115: global_view =  preprocess.global_view(time, flux, FLAGS.period).astype(np.float32)
...
120: local_view = preprocess.local_view(time, flux, FLAGS.period, FLAGS.duration).astype(np.float32) 
```
  • Loading branch information
ritwik12 authored Feb 24, 2020
1 parent 3dfe65f commit 8d6a8c0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions exoplanet-ml/astronet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ def _process_tce(feature_config):
features = {}

if "global_view" in feature_config:
global_view = preprocess.global_view(time, flux, FLAGS.period)
global_view = preprocess.global_view(time, flux, FLAGS.period).astype(np.float32)
# Add a batch dimension.
features["global_view"] = np.expand_dims(global_view, 0)

if "local_view" in feature_config:
local_view = preprocess.local_view(time, flux, FLAGS.period, FLAGS.duration)
local_view = preprocess.local_view(time, flux, FLAGS.period, FLAGS.duration).astype(np.float32)
# Add a batch dimension.
features["local_view"] = np.expand_dims(local_view, 0)

Expand Down

0 comments on commit 8d6a8c0

Please sign in to comment.