
Fix broken example: examples/text_generation/shakespeare_rnn.py #172

Open
david-berthelot opened this issue Dec 14, 2020 · 0 comments

CUDA_VISIBLE_DEVICES=0 python examples/text_generation/shakespeare_rnn.py 
[GpuDevice(id=0)]
X_one_hot.shape: (10, 10, 65)
Z.shape: (100, 65)
to be or not to beEQGnyo.3wZ
Traceback (most recent call last):
  File "/home/dberth/Code/objax/objax/module.py", line 226, in jit
    return f(*args, **kwargs), self.vc.tensors()
  File "/home/dberth/Code/objax/objax/module.py", line 180, in __call__
    return self.__wrapped__(*args, **kwargs)
  File "examples/text_generation/shakespeare_rnn.py", line 193, in train_op
    g, v = gv(x, xl)  # returns gradients, loss
  File "/home/dberth/Code/objax/objax/gradient.py", line 84, in __call__
    list(args), kwargs)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 756, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 815, in value_and_grad_f
    ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 1863, in _vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/ad.py", line 100, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 404, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/dberth/Code/objax/objax/gradient.py", line 63, in f_func
    self.vc.assign(original_vc)
  File "/home/dberth/Code/objax/objax/variable.py", line 249, in assign
    var.assign(tensor)
  File "/home/dberth/Code/objax/objax/variable.py", line 61, in assign
    assert_assigned_type_and_shape_match(self.value, tensor)
  File "/home/dberth/Code/objax/objax/util/check.py", line 52, in assert_assigned_type_and_shape_match
    shape_mismatch_error
AssertionError: Assign can not change shape of variable. The current variable shape is (10, 256), but the requested new shape is (1, 256).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "examples/text_generation/shakespeare_rnn.py", line 210, in <module>
    v = train_op(X_one_hot, flat_labels)
  File "/home/dberth/Code/objax/objax/module.py", line 236, in __call__
    output, changes = self._call(self.vc.tensors(), kwargs, *args)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/api.py", line 216, in f_jitted
    donated_invars=donated_invars)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 1155, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 1146, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 1158, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/core.py", line 577, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/xla.py", line 557, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/linear_util.py", line 234, in memoized_fun
    ans = call(fun, *args)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/xla.py", line 622, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1038, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1019, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/dberth/jax3/lib/python3.6/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/dberth/Code/objax/objax/module.py", line 228, in jit
    self.vc.assign(original_values)
  File "/home/dberth/Code/objax/objax/variable.py", line 249, in assign
    var.assign(tensor)
  File "/home/dberth/Code/objax/objax/variable.py", line 61, in assign
    assert_assigned_type_and_shape_match(self.value, tensor)
  File "/home/dberth/Code/objax/objax/util/check.py", line 52, in assert_assigned_type_and_shape_match
    shape_mismatch_error
AssertionError: Assign can not change shape of variable. The current variable shape is (10, 256), but the requested new shape is (1, 256).
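The traceback suggests the following, though the issue does not confirm it. An RNN state variable had shape (1, 256) when `train_op` was entered, probably left over from sampling the `to be or not to be` prompt with batch size 1. Inside the call, the state was resized to the training batch of 10. When objax restored the saved values on exit, the shape check rejected the (1, 256) snapshot, and the same check fired again in the outer `jit` restore. The sketch below reproduces that pattern in plain NumPy. `StateVar`, `call_and_restore` and `train_step` are hypothetical names, not objax APIs, and the workaround at the end is an assumption, not the fix adopted upstream.

```python
import numpy as np


class StateVar:
    """Hypothetical stand-in for a state variable whose assign() rejects shape changes."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def assign(self, tensor):
        tensor = np.asarray(tensor)
        if tensor.shape != self.value.shape:
            raise AssertionError(
                "Assign can not change shape of variable. "
                f"The current variable shape is {self.value.shape}, "
                f"but the requested new shape is {tensor.shape}."
            )
        self.value = tensor


def call_and_restore(state, fn):
    """Hypothetical sketch: snapshot the state, run fn, then restore the snapshot."""
    original = state.value
    try:
        return fn(state)
    finally:
        # Raises if fn left the state with a different shape than the snapshot.
        state.assign(original)


def train_step(state):
    # Hypothetical: the hidden state is re-created for a batch of 10 by writing
    # .value directly, which bypasses the shape check in assign().
    state.value = np.zeros((10, 256))
    return 0.0  # placeholder loss


# State left at batch size 1 by the text-sampling step.
hidden = StateVar(np.zeros((1, 256)))
try:
    call_and_restore(hidden, train_step)
except AssertionError as err:
    message = str(err)  # same wording as the AssertionError in the traceback

# Possible workaround (assumption): size the state for the training batch first,
# so the snapshot and the post-call value have the same shape.
hidden = StateVar(np.zeros((10, 256)))
loss = call_and_restore(hidden, train_step)
```

Because the restore runs in a `finally` block, any shape change made during the call surfaces as an error on exit. In this sketch, resetting the state to the training batch size before the call avoids the mismatch. Keeping separate states for sampling and training would be another option.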