MoritzMaxeiner commented on issue #8836: Backward shape inconsistent with custom HybridBlock and gluon.loss
URL: https://github.com/apache/incubator-mxnet/issues/8836#issuecomment-347384351
 
 
   @reminisce I've removed the time unrolling, but if I remove either of the two cells or the reshape operation, the issue no longer arises, so I don't think I can reduce it any further.
   
   ```python
   import mxnet as mx
   
   class Test(mx.gluon.HybridBlock):
       def __init__(self, input_size, output_size, **kwargs):
           super(Test, self).__init__(**kwargs)
           self.input_size = input_size
           self.output_size = output_size
           self.hidden_unit_size = output_size*input_size
   
           self.num_cells = 2
           with self.name_scope():
                self.cell_a = mx.gluon.rnn.GRUCell(self.hidden_unit_size, input_size=input_size)
                self.cell_b = mx.gluon.rnn.GRUCell(self.hidden_unit_size, input_size=input_size)
   
       def hybrid_forward(self, F, inputs, states):
           prev_h = states[0]
            if F is mx.symbol:
                # Symbols can't be sliced along axis 0 like NDArrays, so split
                # the stacked state into one (batch, hidden) symbol per cell.
                prev_h = F.split(prev_h, axis=0, num_outputs=self.num_cells, squeeze_axis=1)
   
           cell_a_next_h, _ = self.cell_a(inputs, [prev_h[0]])
   
           cell_b_next_h, _ = self.cell_b(prev_h[1], [prev_h[1]])
   
            # Removing this reshape (or either of the two cells) makes the
            # backward error go away.
            b_output = cell_b_next_h.reshape(shape=(0, self.input_size, self.output_size))
   
           return cell_a_next_h, b_output, []
   
       def state_info(self, batch_size=0):
            return [{'shape': (self.num_cells, batch_size, self.hidden_unit_size), '__layout__': 'LNC'}]
   
       def begin_state(self, batch_size=0, func=mx.ndarray.zeros, **kwargs):
           states = []
           for i, info in enumerate(self.state_info(batch_size)):
               if info is not None:
                   info.update(kwargs)
               else:
                   info = kwargs
               states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
           return states
   
   args_nof_examples = 1
   args_nof_batches = 1
   args_batch_size = 1
   
   args_input_size = 1
   args_output_size = 1
   
   data = mx.ndarray.zeros(shape=(args_nof_examples, args_input_size))
   labels = mx.ndarray.ones((args_nof_examples, args_input_size))
    gen = mx.io.NDArrayIter(data, labels, args_batch_size, last_batch_handle='discard')
   
   with mx.cpu(0) as context:
       model = Test(args_input_size, args_output_size)
        model.initialize(mx.init.Xavier(), ctx=context)
       model.hybridize()
   
       loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
   
       states = model.begin_state(args_batch_size)
       for batch in gen:
           with mx.autograd.record():
               a, b, _ = model(batch.data[0], states)
               L = loss(b, batch.label[0])
           L.backward()
   ```
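
   For reference, this is what I mean by removing the reshape: a sketch of `hybrid_forward` with `cell_b_next_h` returned unreshaped (everything else in the script unchanged). With this variant the backward error does not arise, and the same holds if either of the two cells is removed instead:
    
    ```python
    def hybrid_forward(self, F, inputs, states):
        prev_h = states[0]
        if F is mx.symbol:
            # Same split as above to get one (batch, hidden) symbol per cell.
            prev_h = F.split(prev_h, axis=0, num_outputs=self.num_cells, squeeze_axis=1)

        cell_a_next_h, _ = self.cell_a(inputs, [prev_h[0]])
        cell_b_next_h, _ = self.cell_b(prev_h[1], [prev_h[1]])

        # No reshape: cell_b_next_h is returned as-is and L.backward() completes.
        return cell_a_next_h, cell_b_next_h, []
    ```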
