haojin2 commented on a change in pull request #15905: [Numpy] Basic indexing in symbolic interface
URL: https://github.com/apache/incubator-mxnet/pull/15905#discussion_r314582546
 
 

 ##########
 File path: python/mxnet/symbol/numpy/_symbol.py
 ##########
 @@ -39,25 +46,90 @@ def _num_outputs(sym):
 
 @set_module('mxnet.symbol.numpy')
 class _Symbol(Symbol):
-    def __getitem__(self, key):
-        num_outputs = _num_outputs(self)
-        if num_outputs == 1:
-            raise NotImplementedError
-        if not isinstance(key, int):
+    def __init__(self, handle):
+        super(_Symbol, self).__init__(handle)
+        self._output_is_list = False
+
+    def __getitem__(self, key): # pylint: disable = too-many-return-statements, inconsistent-return-statements
+        num_outputs = len(self)
+        # print("Num of outputs is ", num_outputs)
+        if num_outputs == 1: # pylint: disable = too-many-nested-blocks
+            # If number of output is one and is not a list, perform ndarray basic slicing
+            if not self._output_is_list:
+                if isinstance(key, integer_types):
+                    sliced = _npi.slice(self, key, key+1)
+                    return _npi.reshape(sliced, (-3, -4))
+                elif isinstance(key, py_slice):
+                    if key.step is None or key.step != 0:
+                        start = [None] if key.start is None else key.start
+                        stop = [None] if key.stop is None else key.stop
+                        return _npi.slice(self, start, stop, key.step)
+                    else:
+                        raise ValueError("slice step cannot be zero")
+                elif isinstance(key, list):
+                    raise NotImplementedError
+                elif isinstance(key, tuple):
+                    begin = []
+                    end = []
+                    step = []
+                    new_shape = ()
+                    for index in key:
+                        if isinstance(index, py_slice):
+                            if index.step is not None and index.step == 0:
+                                raise ValueError("slice step cannot be zero")
+                            begin.append(index.start)
+                            end.append(index.stop)
+                            step.append(index.step)
+                            new_shape += (-2,)
+                        elif isinstance(index, integer_types):
+                            begin.append(index)
+                            end.append(index+1)
+                            step.append(1)
+                            new_shape += (-3,)
+                    new_shape += (-4,)
+                    sliced = _npi.slice(self, begin, end, step)
+                    return _npi.reshape(sliced, new_shape)
+            # perform trivial list slicing on length one list represented by flag
+            else:
+                if isinstance(key, integer_types):
+                    if key in [-1, 0]:
+                        self._output_is_list = False
+                        return self
+                    else:
+                        raise IndexError
+                elif isinstance(key, py_slice):
+                    if (key.start is None or key.start <= 0) and (key.stop is None or key.stop > 0):
+                        return self
+                    else:
+                        raise ValueError
+                else:
+                    raise IndexError
+        # list slicing on several nodes of outputs
+        elif num_outputs > 1:
+            if isinstance(key, py_slice):
+                start = 0 if key.start is None else key.start
+                stop = num_outputs if key.stop is None else key.stop
+                step = 1 if key.step is None else key.step
+                return Group([self[i] for i in range(start, stop, step)], _Symbol)
+            elif isinstance(key, integer_types):
+                if key >= num_outputs:
+                # Important, python determines the end by this exception
+                    raise IndexError
+                handle = SymbolHandle()
+                check_call(_LIB.MXSymbolGetOutput(
+                    self.handle, mx_uint(key), ctypes.byref(handle)))
+                return _Symbol(handle=handle)
+            else:
+                raise NotImplementedError
+        else:
             raise NotImplementedError
-        if key >= num_outputs:
-            # Important, python determines the end by this exception
-            raise IndexError
-        handle = SymbolHandle()
-        check_call(_LIB.MXSymbolGetOutput(
-            self.handle, mx_uint(key), ctypes.byref(handle)))
-        return _Symbol(handle=handle)
+
 
 Review comment:
   Please get rid of the extra blank line.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to