I am trying to build an Op with a custom/optimized gradient formula. To
override the automatic differentiation, I'm trying to use OpFromGraph.
The gradient formula can reuse intermediate results from the feed-forward
pass, so I have tried to split the Op in two: Op1 computes the intermediate
and final results and passes all of them to Op2; Op2 forwards the final result
and takes care of the gradient computation given all the necessary values.
Note that the gradient of the loss wrt the intermediate results is never
needed.
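The idea of splitting the forward pass from a gradient formula that reuses a cached intermediate can be sketched framework-independently in NumPy. This is not Theano code: `forward`, `backward` and `loss` are hypothetical helper names, and the gradient formula assumes, as in the example in this post, that the intermediate r = sum(m) depends only on the integer mask m and not on x.

```python
import numpy as np


def forward(x, m):
    """Return the final result z together with the intermediate value r."""
    r = np.float32(m.sum())  # intermediate value; depends only on the mask m
    z = x * m / r            # final result
    return z, r


def backward(dz, m, r):
    """Custom gradient of z w.r.t. x that reuses the cached intermediate r.

    Because r does not depend on x, dz_i/dx_j is m_i / r when i == j and 0
    otherwise, so the vector-Jacobian product is elementwise.
    """
    return dz * m / r


x = np.array([1.0, 0.3, 0.0, 0.2], dtype=np.float32)
m = np.array([1, 0, 1, 1], dtype=np.int8)

z, r = forward(x, m)                  # analogous to "op1": compute and cache r
dx = backward(np.ones_like(z), m, r)  # analogous to "op2": d sum(z) / dx


def loss(xv):
    """sum(z) recomputed from scratch, used only for the numerical check."""
    return (xv * m / m.sum()).sum()


# Central finite differences in float64 as an independent check.
eps = 1e-6
x64 = x.astype(np.float64)
num = np.array([(loss(x64 + eps * e) - loss(x64 - eps * e)) / (2 * eps)
                for e in np.eye(len(x64))])
print(dx, num)
```

The point of the split is that `backward` never recomputes r; it only consumes the value that `forward` already produced.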
Below is what I believe to be a minimal working example of my problem. It
exhibits a strange conversion error in the gradient computation involving
the intermediate values. Please note the presence of an integral
(int8) variable.
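A NumPy-only illustration (not Theano code) of why an integer-typed input such as m is awkward for differentiation: a floating-point gradient cannot be safely cast to int8, and forcing the cast truncates it. Whether this is exactly what triggers the Theano conversion error here is an assumption, not something established by the example.

```python
import numpy as np

# A float gradient with the same shape as the int8 mask m, e.g. m / sum(m).
m = np.array([1, 0, 1, 1], dtype=np.int8)
float_grad = m / np.float32(m.sum())

# NumPy's default "safe" casting rules reject a float -> int8 conversion...
safe = np.can_cast(float_grad.dtype, m.dtype)

# ...and a forced cast truncates every entry toward zero, losing the gradient.
forced = float_grad.astype(m.dtype)
print(safe, forced)
```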
import numpy as np
import theano.tensor as T
import theano
def make_ops():
    x = T.vector()
    m = T.bvector()
    r = m.sum().astype('floatX')  # intermediate value
    z = x * m / r  # final result

    def grad_op1(inputs, output_gradients):
        return [
            output_gradients[0],  # gradient computation delegated to op2
            theano.gradient.DisconnectedType()()  # variable has integral type
            # T.zeros_like(inputs[1])
        ]

    op1 = theano.OpFromGraph(
        inputs=[x, m],
        outputs=[z, m, r],
        grad_overrides=grad_op1,
        inline=True,
        name="op1")

    z = T.vector()
    r_forwarded = T.scalar()

    def grad_op2(inputs, output_gradients):
        _, m_, r_ = inputs
        dm_ = theano.gradient.DisconnectedType()(name="dm_")
        # I think the error could be around here <<<<<<<<<<------------------------------
        # dr_ = theano.gradient.DisconnectedType()(name="dr_")
        dr_ = T.zeros_like(r_)
        # scale by the incoming gradient (all ones when the cost is T.sum(z))
        return [output_gradients[0] * m_ / r_, dm_, dr_]

    op2 = theano.OpFromGraph(
        inputs=[z, m, r_forwarded],
        outputs=[z],  # Op 2 forwards the precomputed output
        grad_overrides=grad_op2,
        inline=True,
        name="op2")

    return op1, op2
def main():
    op1, op2 = make_ops()
    x = T.vector(name="x")
    m = T.bvector(name="m")
    z_intermediate, m_forwarded, r = op1(x, m)
    z = op2(z_intermediate, m, r)
    g = theano.grad(T.sum(z), wrt=x)
    print(g.eval({x: np.array([1., .3, .0, .2], dtype=np.float32),
                  m: np.array([1, 0, 1, 1], dtype=np.int8)}))


if __name__ == "__main__":
    main()
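For reference, the gradient the script should print if the overrides behave as intended can be computed directly in NumPy, independently of Theano: since r = sum(m) does not depend on x, the gradient of sum(x * m / r) with respect to x is simply m / r. The variable names below are illustrative only.

```python
import numpy as np

# Same mask as in main(); x itself does not enter the gradient.
m = np.array([1, 0, 1, 1], dtype=np.int8)

r = np.float32(m.sum())   # intermediate value (3 for this mask)
expected_grad = m / r     # d sum(x * m / r) / dx
print(expected_grad)
```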
(Note: I had tried to hijack my previous question thread with this problem
but it went unnoticed, sorry for double posting)
Thank you