Hi guys,
Could someone explain why I get this exception:
> Traceback (most recent call last):
>   File "test.py", line 51, in <module>
>     grad = T.grad(T.mean(stop_grad(f) * log_p), [E, W])
>   File "python_venv/local/lib/python2.7/site-packages/theano/gradient.py", line 561, in grad
>     rval[i].type.why_null)
> theano.gradient.NullTypeGradError: tensor.grad encountered a NaN. This
> variable is Null because the grad method for input 1 (Reshape{1}.0) of the
> CrossentropyCategorical1Hot op is not implemented.
in the following snippet of code. The code works after replacing the
T.nnet.categorical_crossentropy call with the commented-out line. It looks
like Theano is trying to propagate derivatives through the argmax operation
when using T.nnet.categorical_crossentropy. My Theano version is 0.9.0beta1.
I would appreciate any help on this matter.
import theano
import numpy as np
from theano import tensor as T
from theano.gradient import disconnected_grad as stop_grad
bs = 42       # batch size
vs = 3        # vocabulary size
ms = 4        # sequence length (number of scan steps)
hid_dim = 3   # hidden dimension
E = theano.shared(np.zeros((vs, hid_dim), dtype=np.float32))
W = theano.shared(np.zeros((hid_dim, vs), dtype=np.float32))
def step_fun(x_t, h_tm1):
    h_t = T.dot(x_t, E)
    mu = T.dot(h_t + h_tm1, W)
    x_tp1 = T.nnet.softmax(mu)
    return [x_tp1, h_t]
start_x = T.zeros((bs, vs), dtype=np.float32)
h0 = T.zeros((bs, hid_dim))
[x, _], _ = theano.scan(fn=step_fun,
                        outputs_info=[start_x, h0],
                        n_steps=ms)
x = x.dimshuffle(1, 0, 2)
f = T.sum(T.prod(x, axis=1), axis=1)
pred = x[:, 1:]
pred = T.reshape(pred, (bs * (ms - 1), vs))
x = T.argmax(x, axis=2)  # integer class indices taken from the softmax outputs
targets = T.reshape(x[:, 1:], (bs * (ms - 1),))
log_p = -T.nnet.categorical_crossentropy(pred, targets)
# log_p = T.log(pred[T.arange(targets.shape[0]), targets])
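# (My guess: the indexing version above works because advanced indexing has
# no gradient defined w.r.t. its integer indices and treats them as
# disconnected, while CrossentropyCategorical1Hot returns a null gradient.)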
log_p = T.reshape(log_p, (bs, (ms - 1)))
log_p = T.sum(log_p, axis=1)
grad = T.grad(T.mean(stop_grad(f) * log_p), [E, W])
fun = theano.function([], grad)
fun()
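In case it helps pin this down: if the null gradient really comes from the
targets path, I would expect explicitly blocking the gradient through the
argmax-derived targets to make the crossentropy version work as well. A
minimal, untested sketch of that guess, reusing the stop_grad alias already
imported above:

# Guess: cut the gradient path through the integer targets so that
# T.grad only differentiates through pred (the softmax probabilities).
targets = stop_grad(T.reshape(x[:, 1:], (bs * (ms - 1),)))
log_p = -T.nnet.categorical_crossentropy(pred, targets)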