chrishkchris edited a comment on issue #468: Distributted module
URL: https://github.com/apache/incubator-singa/pull/468#issuecomment-527824025
 
 
   > I have combined all the commits into two commits. Meanwhile, I found that resnet.py is not compatible with the `Add` function on the master branch, which was modified to support broadcasting. It fails at `assert(len(self.shape0) <= 2 and len(self.shape1) <= 2),"up till now, the dimensions of tensor a and b should less than 3"` with:
   > `AssertionError: up till now, the dimensions of tensor a and b should less than 3`
   >
   > Since resnet.py calls `out = autograd.add(out, residual)`, whose inputs have more than 2 dimensions, the assertion always evaluates to false and therefore raises the `AssertionError`.
   
   When I disable the assertion `assert(len(self.shape0) <= 2 and len(self.shape1) <= 2)`, resnet.py runs successfully.
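To see why the residual connection trips this check, the sketch below copies only the asserted condition into a plain Python function so it can run without SINGA. The function name `original_add_check` and the NCHW shapes (batch 32, 64 channels, 56 x 56 maps) are illustrative assumptions, not values taken from resnet.py.

```python
def original_add_check(shape0, shape1):
    # Same condition as the assertion in Add.forward: both operands must have at most 2 dims.
    return len(shape0) <= 2 and len(shape1) <= 2


# Hypothetical NCHW shapes for `out` and `residual` in a ResNet block.
out_shape = (32, 64, 56, 56)
residual_shape = (32, 64, 56, 56)

# A 4-D residual add fails the check, so Add.forward would raise AssertionError.
print(original_add_check(out_shape, residual_shape))  # False

# A 2-D activation plus a 1-D bias (e.g. in a dense layer) still passes.
print(original_add_check((32, 10), (10,)))  # True
```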
   
   See the code of the `Add` function:
   
   ```python
   class Add(Operation):
       def __init__(self):
           super(Add, self).__init__()
   
       def forward(self, a, b):
           #up till now, the dimensions of tensor a and b should less than 3
           self.shape0=list(a.shape())
           self.shape1=list(b.shape())
           assert(len(self.shape0) <= 2 and len(self.shape1) <= 2),"up till now, the dimensions of tensor a and b should less than 3"
           return singa.__add__(a, b)
   
       def backward(self, dy):
           if(type(dy)==float):
               assert self.shape0==self.shape1,('should have same shape')
               return dy,dy
           db=CTensor(list(dy.shape()), dy.device())
           db.CopyData(dy)
           for i in range(len(self.shape0)-len(self.shape1)):
               db=singa.Sum(db, 0)
           return dy, db
   ```
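The `backward` method above handles broadcasting by summing `dy` over its leading axis once per extra dimension of `a` (the `singa.Sum(db, 0)` loop). The NumPy sketch below illustrates that reduction outside SINGA. NumPy stands in for SINGA tensors, `reduce_grad_to_shape` and `add_backward` are hypothetical helpers, and the sketch additionally sums over axes where an operand had size 1, a case the excerpt does not appear to cover.

```python
import numpy as np


def reduce_grad_to_shape(dy, shape):
    """Sum the upstream gradient dy back down to an operand's original shape."""
    # Leading axes added by broadcasting: mirrors the `db = singa.Sum(db, 0)` loop.
    while dy.ndim > len(shape):
        dy = dy.sum(axis=0)
    # Axes where the operand had size 1 but dy does not (not handled in the excerpt).
    for axis, size in enumerate(shape):
        if size == 1 and dy.shape[axis] != 1:
            dy = dy.sum(axis=axis, keepdims=True)
    return dy


def add_backward(dy, shape0, shape1):
    # The gradient of a + b is dy for both operands, reduced to each operand's shape.
    return reduce_grad_to_shape(dy, shape0), reduce_grad_to_shape(dy, shape1)


a = np.ones((2, 3, 4, 4))  # 4-D feature map (N, C, H, W)
b = np.ones((3, 1, 1))     # per-channel term, broadcast over N, H and W
dy = np.ones_like(a + b)

da, db = add_backward(dy, a.shape, b.shape)
print(da.shape, db.shape)  # (2, 3, 4, 4) (3, 1, 1)
```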
   Can we allow the inputs to `Add` to have more than two dimensions (e.g. by raising the limit from 2 to 4, or by disabling the assertion)? Convolutional feature maps typically have four dimensions: batch, channel (depth), height and width.
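A minimal sketch of the first option, raising the rank limit from 2 to 4, is shown below as a standalone check on shape lists. `relaxed_add_check` and its `max_rank` parameter are hypothetical names, not part of SINGA; inside `Add.forward`, the same condition could take the place of the current assertion.

```python
def relaxed_add_check(shape0, shape1, max_rank=4):
    """Return True if both operand shapes have at most max_rank dimensions.

    max_rank=4 covers NCHW convolution feature maps (batch, channel, height, width).
    """
    return len(shape0) <= max_rank and len(shape1) <= max_rank


# The ResNet residual add (4-D + 4-D) now passes.
print(relaxed_add_check([32, 64, 56, 56], [32, 64, 56, 56]))  # True
# A 5-D operand is still rejected.
print(relaxed_add_check([2, 3, 4, 5, 6], [6]))  # False
```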
   Alternatively, the assertion could also pass whenever the two operands have the same number of dimensions:
   ```python
   assert (len(self.shape0) <= 2 and len(self.shape1) <= 2) or (len(self.shape0) == len(self.shape1))
   ```
