You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
CUDA_VISIBLE_DEVICES=0 python examples/text_generation/shakespeare_rnn.py
[GpuDevice(id=0)]
X_one_hot.shape: (10, 10, 65)
Z.shape: (100, 65)
to be or not to beEQGnyo.3wZ
Traceback (most recent call last):
File "/home/dberth/Code/objax/objax/module.py", line 226, in jit
return f(*args, **kwargs), self.vc.tensors()
File "/home/dberth/Code/objax/objax/module.py", line 180, in __call__
return self.__wrapped__(*args, **kwargs)
File "examples/text_generation/shakespeare_rnn.py", line 193, in train_op
g, v = gv(x, xl) # returns gradients, loss
File "/home/dberth/Code/objax/objax/gradient.py", line 84, in __call__
list(args), kwargs)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 756, in grad_f_aux
(_, aux), g = value_and_grad_f(*args, **kwargs)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 815, in value_and_grad_f
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 1863, in _vjp
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/ad.py", line 100, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 404, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/linear_util.py", line 151, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dberth/Code/objax/objax/gradient.py", line 63, in f_func
self.vc.assign(original_vc)
File "/home/dberth/Code/objax/objax/variable.py", line 249, in assign
var.assign(tensor)
File "/home/dberth/Code/objax/objax/variable.py", line 61, in assign
assert_assigned_type_and_shape_match(self.value, tensor)
File "/home/dberth/Code/objax/objax/util/check.py", line 52, in assert_assigned_type_and_shape_match
shape_mismatch_error
AssertionError: Assign can not change shape of variable. The current variable shape is (10, 256), but the requested new shape is (1, 256).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "examples/text_generation/shakespeare_rnn.py", line 210, in <module>
v = train_op(X_one_hot, flat_labels)
File "/home/dberth/Code/objax/objax/module.py", line 236, in __call__
output, changes = self._call(self.vc.tensors(), kwargs, *args)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 216, in f_jitted
donated_invars=donated_invars)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 1155, in bind
return call_bind(self, fun, *args, **params)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 1146, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 1158, in process
return trace.process_call(self, fun, tracers, params)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 577, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/xla.py", line 557, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/linear_util.py", line 234, in memoized_fun
ans = call(fun, *args)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/xla.py", line 622, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1038, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1019, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/dberth/jax3/lib/python3.6/site-packages/jax/linear_util.py", line 151, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dberth/Code/objax/objax/module.py", line 228, in jit
self.vc.assign(original_values)
File "/home/dberth/Code/objax/objax/variable.py", line 249, in assign
var.assign(tensor)
File "/home/dberth/Code/objax/objax/variable.py", line 61, in assign
assert_assigned_type_and_shape_match(self.value, tensor)
File "/home/dberth/Code/objax/objax/util/check.py", line 52, in assert_assigned_type_and_shape_match
shape_mismatch_error
AssertionError: Assign can not change shape of variable. The current variable shape is (10, 256), but the requested new shape is (1, 256).
The text was updated successfully, but these errors were encountered:
The text was updated successfully, but these errors were encountered: