-
Notifications
You must be signed in to change notification settings - Fork 4
/
parity_reader.py
48 lines (38 loc) · 1.49 KB
/
parity_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
import numpy as np
from .registry import register
# Length of every generated parity input vector (the second dimension of the
# "inputs" feature produced by the reader below).
VECTOR_SIZE = 64
@register("parity")
def input_fn(data_sources, params, training):
    """Build an input function that generates random parity-task batches.

    Each call of the returned function builds one batch of
    ``params.batch_size`` examples:

    - ``features["inputs"]``: int32 tensor of shape
      ``(params.batch_size, VECTOR_SIZE)``. In row ``i`` the first ``k_i``
      entries are drawn uniformly from {-1, 1}; the remaining entries are 0.
    - ``features["difficulty"]``: int32 tensor of shape
      ``(params.batch_size,)`` holding ``k_i`` for each row, sampled
      uniformly from ``[0, VECTOR_SIZE)``.
    - ``features["seq_length"]``: always ``None``.
    - ``features["target_mask"]``: float32 tensor of shape
      ``(params.batch_size, 1)`` filled with ones.
    - ``labels``: int64 one-hot tensor of shape ``(params.batch_size, 2)``.
      Index 1 is hot when the row contains an odd number of +1 entries, and
      index 0 otherwise. (-1 entries do not count.)

    Args:
        data_sources: Unused here. Presumably kept so all registered readers
            share one signature; confirm against the registry.
        params: Object exposing an integer ``batch_size`` attribute.
        training: Unused here.

    Returns:
        A zero-argument callable returning ``(features, labels)``.
    """

    def _input_fn():
        # Read the batch size once, so the py_func's runtime batch and the
        # static shapes set below always agree.
        batch_size = params.batch_size

        def _sample_batch():
            """Sample padded +/-1 rows and the length of each non-zero prefix."""
            shape = (batch_size, VECTOR_SIZE)
            signs = np.random.choice(a=[-1, 1], size=shape)
            # NOTE(review): range(VECTOR_SIZE) yields 0..VECTOR_SIZE-1, so a row
            # is never fully populated and may be entirely zero. If each row is
            # meant to have 1..VECTOR_SIZE non-zero entries, this should be
            # range(1, VECTOR_SIZE + 1). Confirm before changing, because it
            # alters the data distribution and the "difficulty" values.
            num_non_zero = np.random.choice(
                a=range(VECTOR_SIZE), size=batch_size).astype(np.int32)
            # Broadcast compare: keep[i, j] is True iff j < num_non_zero[i],
            # i.e. keep the first num_non_zero[i] entries of row i.
            keep = np.arange(VECTOR_SIZE) < num_non_zero[:, np.newaxis]
            return np.where(keep, signs, 0).astype(np.int32), num_non_zero

        x, difficulty = tf.py_func(_sample_batch, [], [tf.int32, tf.int32])
        # py_func outputs have unknown static shape; restore them for
        # downstream layers. Note the trailing comma: difficulty is rank 1.
        x.set_shape((batch_size, VECTOR_SIZE))
        difficulty.set_shape((batch_size,))

        # One target step per example: a (batch_size, 1) mask of ones.
        sequence_length = tf.constant([1] * batch_size)
        target_mask = tf.sequence_mask(
            sequence_length, maxlen=1, dtype=tf.float32)

        # Parity label: count the +1 entries per row and take it modulo 2.
        num_ones_per_sample = tf.reduce_sum(
            tf.cast(tf.equal(x, 1), tf.int32), axis=1)
        y = tf.one_hot(
            indices=tf.mod(num_ones_per_sample, 2), depth=2, dtype=tf.int64)

        return {
            "inputs": x,
            "seq_length": None,
            "difficulty": difficulty,
            "target_mask": target_mask,
        }, y

    return _input_fn