Dear all,
I have the following problem: when I try to differentiate a row of an
output matrix w.r.t. a row of an input matrix, I get a
DisconnectedInputError. A minimal example is in the attached file. I
suspect Theano does not consider theta and theta[i] to be the same
variable, but rather takes theta[i] to be a different variable that
depends on theta.
Long story:
I have a neural network with M-dimensional input, N-dimensional output,
and batch size B. The input matrix of my net is therefore (B, M)-dimensional
and the output matrix is (B, N)-dimensional. For every batch element I want
to compute the Jacobian of the network, that is, an M*N matrix of gradients
of that element's outputs with respect to that element's inputs. In the scan
function I can compute the Jacobians of individual output rows w.r.t. the
whole input (thus getting B matrices of size (B*M, N)), but when I try to
compute the Jacobian w.r.t. only a row of the input, I get the
DisconnectedInputError.
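To make the shapes concrete, here is a small NumPy sketch of what I am after. It is not Theano code: the toy network `net`, the helper `jacobian_of_row`, and the (N, B, M) layout are all assumptions made for illustration, and the derivatives are plain finite differences. It only shows that the Jacobian of output row i w.r.t. the whole input already contains the per-sample Jacobian as its i-th block, with all other blocks zero.

```python
import numpy as np

def net(x, W):
    # Hypothetical toy network: each input row is mapped independently
    # to one output row (this is an assumption, not my real model).
    return np.tanh(np.dot(x, W))

def jacobian_of_row(i, x, W, eps=1e-6):
    # Central finite-difference Jacobian of output row i w.r.t. the
    # whole input matrix x. Result layout (assumed): (N, B, M).
    B, M = x.shape
    N = W.shape[1]
    J = np.zeros((N, B, M))
    for b in range(B):
        for m in range(M):
            d = np.zeros_like(x)
            d[b, m] = eps
            J[:, b, m] = (net(x + d, W)[i] - net(x - d, W)[i]) / (2 * eps)
    return J

B, M, N = 4, 3, 2
rng = np.random.RandomState(0)
x = rng.randn(B, M)
W = rng.randn(M, N)

i = 1
J_full = jacobian_of_row(i, x, W)   # output row i w.r.t. whole input
J_i = J_full[:, i, :]               # per-sample Jacobian, shape (N, M)

# Analytic check for this toy net: d tanh(x_i W)_n / d x_im = (1 - y_in^2) * W[m, n]
y = net(x, W)
J_analytic = (1 - y[i] ** 2)[:, None] * W.T
```

So one possible workaround would be to differentiate w.r.t. the whole input and then index the result, at the cost of computing the zero blocks as well.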
The question is:
Is there a way to differentiate w.r.t. a subtensor?
Best regards,
Jan
import theano
from theano import tensor as T
import numpy as np
n_samples = 10
theta_single = T.as_tensor_variable([0.2, 0.15, 0.4, 0.05, 0.2])
theta = T.tile(theta_single,(n_samples, 1))
print(theta_single.eval().shape)
print(theta.eval().shape)
# differentiating with the original theta_single
J1, _ = theano.scan(
    lambda i, theta, theta_single: theano.gradient.jacobian(theta[i, :], theta_single),
    sequences=T.arange(theta.shape[0]),
    non_sequences=[theta, theta_single])
J1.eval()
# same as before, but differentiating w.r.t. a single row of theta taken from
# that same array; this is the call that gives the DisconnectedInputError
J2, _ = theano.scan(
    lambda i, theta, theta_single: theano.gradient.jacobian(theta[i, :], theta[i, :]),
    sequences=T.arange(theta.shape[0]),
    non_sequences=[theta, theta_single])
J2.eval()