"forward the precomputed output" means that Op1 already computed the final 
output, so Op2 just has to behave as the identity in the forward pass.

The intermediate value is already an output of Op1, as shown in the example
code. Sorry if that wasn't clear.
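
A minimal NumPy sketch of this split follows. It assumes the loss is sum(z), as in the example code. The functions `op1_forward`, `op2_forward`, `op2_backward` and `op1_backward` are hypothetical stand-ins that mirror only the math, not Theano's OpFromGraph API. Op2 is the identity in the forward pass. Its backward pass scales the upstream gradient by the precomputed m / r, and Op1 passes that result straight through to x. A central finite-difference check is included as a sanity check.

```python
import numpy as np

# Hypothetical NumPy stand-ins for the two Theano ops in the example code.
# They mirror only the math, not Theano's OpFromGraph API.


def op1_forward(x, m):
    """Compute the final output z and the intermediate r in a single pass."""
    r = float(m.sum())  # intermediate value: sum of the mask m
    z = x * m / r       # final result
    return z, m, r


def op2_forward(z, m, r):
    """Identity in the forward pass: op1 already computed z."""
    return z


def op2_backward(gz, m, r):
    """Custom gradient reusing the precomputed m and r.

    Returns the gradient that op1 passes on to x: dz/dx is diagonal with
    entries m / r, so the upstream gradient gz is scaled elementwise.
    """
    return gz * m / r


def op1_backward(g_from_op2):
    """op1 delegates the gradient formula to op2, so it only passes it through."""
    return g_from_op2


x = np.array([1.0, 0.3, 0.0, 0.2])
m = np.array([1, 0, 1, 1], dtype=np.int8)

z_pre, m_fwd, r = op1_forward(x, m)
z = op2_forward(z_pre, m_fwd, r)

# For loss = sum(z), the upstream gradient w.r.t. z is all ones.
gx = op1_backward(op2_backward(np.ones_like(z), m_fwd, r))
print(gx)

# Central finite differences of sum(z) w.r.t. each component of x.
eps = 1e-6
num = np.array([
    (op1_forward(x + eps * e, m)[0].sum()
     - op1_forward(x - eps * e, m)[0].sum()) / (2 * eps)
    for e in np.eye(x.size)
])
print(np.allclose(gx, num))  # True
```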

Nicolas

On Tuesday, August 8, 2017 at 20:56:12 UTC+2, nouiz wrote:
>
> I don't understand what you mean by "forward the precomputed output".
>
> What I would recommend is to make a single op for the forward pass. Make the
> intermediate values that can be reused for the gradient additional outputs.
> Don't use them in the forward pass, but you can reuse them in your grad
> override.
>
> Frédéric
>
> On Mon, Jul 31, 2017 at 9:43 AM <nicolas....@gmail.com>
> wrote:
>
>> I am trying to build an Op with a custom, optimized gradient formula. To
>> override the automatic differentiation, I'm trying to use OpFromGraph.
>> The gradient formula can reuse intermediate results from the forward
>> pass, so I have tried to split the Op in two: Op1 computes the intermediate
>> and final results and passes all of them to Op2; Op2 forwards the final
>> result and takes care of the gradient computation given all the necessary
>> values.
>>
>> Note that the gradient of the loss wrt the intermediate results is never 
>> needed.
>>
>> Below is what I believe to be a minimal working example of my problem.
>> It exhibits a strange conversion error related to the gradient computation
>> with the intermediate values. Note the presence of an integer-typed
>> variable (the bvector m).
>>
>> import numpy as np
>> import theano.tensor as T
>> import theano
>>
>>
>> def make_ops():
>>     x = T.vector()
>>     m = T.bvector()
>>
>>     r = m.sum().astype('floatX')  # intermediate value
>>     z = x * m / r  # final result
>>
>>
>>     def grad_op1(inputs, output_gradients):
>>         return [
>>             output_gradients[0],  # gradient computation delegated to op2
>>             T.DisconnectedType()()  # variable has integral type
>>             # T.zeros_like(inputs[1])
>>         ]
>>
>>
>>     op1 = theano.OpFromGraph(
>>         inputs=[x, m],
>>         outputs=[z, m, r],
>>         grad_overrides=grad_op1,
>>         inline=True,
>>         name="op1")
>>
>>
>>     z = T.vector()
>>     r_forwarded = T.scalar()
>>
>>     def grad_op2(inputs, output_gradients):
>>         _, m_, r_ = inputs
>>         dz = output_gradients[0]  # upstream gradient w.r.t. op2's output
>>         dm_ = theano.gradient.DisconnectedType()(name="dm_")
>>         # I think the error could be around here <<<<<<<<<<----------
>>         # dr_ = theano.gradient.DisconnectedType()(name="dr_")
>>         dr_ = T.zeros_like(r_)
>>         # scale by the upstream gradient so losses other than sum(z) work
>>         return [dz * m_ / r_, dm_, dr_]
>>
>>     op2 = theano.OpFromGraph(
>>         inputs=[z, m, r_forwarded],
>>         outputs=[z],  # Op 2 forwards the precomputed output
>>         grad_overrides=grad_op2,
>>         inline=True,
>>         name="op2")
>>
>>     return op1, op2
>>
>>
>> def main():
>>     op1, op2 = make_ops()
>>     x = T.vector(name="x")
>>     m = T.bvector(name="m")
>>     z_intermediate, m_forwarded, r = op1(x, m)
>>     z = op2(z_intermediate, m, r)
>>
>>     g = theano.grad(T.sum(z), wrt=x)
>>     print(g.eval({x: np.array([1., .3, .0, .2], dtype=np.float32),
>>                   m: np.array([1, 0, 1, 1], dtype=np.int8)}))
>>
>>
>> if __name__ == "__main__":
>>     main()
>>
>> (Note: I had tried to hijack my previous question's thread with this
>> problem, but it went unnoticed. Sorry for double posting.)
>>
>> Thank you
>>
>> -- 
>>
>> --- 
>> You received this message because you are subscribed to the Google Groups 
>> "theano-users" group.
>> To unsubscribe from this group and stop receiving emails from it, send an 
>> email to theano-users...@googlegroups.com.
>> For more options, visit https://groups.google.com/d/optout.
>>
>
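
For comparison, here is a rough NumPy sketch of the single-op approach Frédéric recommends in the quoted reply. It uses one forward function that also returns the intermediate r as an extra output, and a hand-written gradient that reuses r instead of recomputing it. The names `fused_forward` and `fused_grad` and the `(inputs, outputs, output_gradients)` signature are hypothetical illustrations, not Theano's `grad_overrides` interface. A weighted loss sum(w * z) is used so that the upstream gradient is not all ones.

```python
import numpy as np

# Hypothetical single forward op that also exposes its intermediate value r
# as an extra output, plus a hand-written gradient that reuses it.
# The (inputs, outputs, output_gradients) signature is an illustration only,
# not Theano's grad_overrides interface.


def fused_forward(x, m):
    """Return the final result z and the intermediate r as an extra output."""
    r = float(m.sum())  # intermediate value, exposed as an output
    z = x * m / r       # final result
    return z, r


def fused_grad(inputs, outputs, output_gradients):
    """Gradient that reuses the cached intermediate r instead of recomputing it.

    output_gradients[1] (w.r.t. r) is not needed: r depends only on m, so it
    contributes nothing to the gradient w.r.t. x. None marks the integer
    mask m as non-differentiable.
    """
    _, m = inputs
    _, r = outputs
    gz = output_gradients[0]
    return gz * m / r, None


x = np.array([1.0, 0.3, 0.0, 0.2])
m = np.array([1, 0, 1, 1], dtype=np.int8)
w = np.array([2.0, -1.0, 0.5, 3.0])  # arbitrary weights: loss = sum(w * z)

z, r = fused_forward(x, m)
gx, gm = fused_grad((x, m), (z, r), (w, 0.0))


def loss(x_):
    return np.sum(w * fused_forward(x_, m)[0])


# Central finite differences of the loss w.r.t. each component of x.
eps = 1e-6
num = np.array([(loss(x + eps * e) - loss(x - eps * e)) / (2 * eps)
                for e in np.eye(x.size)])
print(np.allclose(gx, num))  # True
print(gm)  # None
```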
