reminisce commented on a change in pull request #15581: Numpy-compatible Infra
URL: https://github.com/apache/incubator-mxnet/pull/15581#discussion_r309459033
 
 

 ##########
 File path: python/mxnet/numpy/multiarray.py
 ##########
 @@ -0,0 +1,2147 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=too-many-lines
+"""numpy ndarray and util functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+
+try:
+    from __builtin__ import slice as py_slice
+except ImportError:
+    from builtins import slice as py_slice
+
+from array import array as native_array
+import sys
+import ctypes
+import warnings
+import numpy as _np
+from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP
+from ..ndarray._internal import _set_np_ndarray_class
+from . import _op as _mx_np_op
+from ..base import check_call, _LIB, NDArrayHandle
+from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, 
integer_types
+from ..util import _sanity_check_params, set_module
+from ..context import current_context
+from ..ndarray import numpy as _mx_nd_np
+from ..ndarray.numpy import _internal as _npi
+
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 
'stack', 'arange',
+           'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 
'concatenate',
+           'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'sin', 'cos',
+           'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan']
+
+
+# This function is copied from ndarray.py because pylint keeps raising a
+# false-alarm undefined-all-variable error.
+def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
+    """Return a new handle with specified shape and context.
+
+    Empty handle is only used to hold results. This function is a copy of the
+    helper in ndarray.py (duplicated only to avoid a pylint false alarm), so
+    keep the two implementations in sync.
+
+    Parameters
+    ----------
+    shape : tuple of int
+        Array shape. Each dimension is packed into an unsigned-int buffer
+        (``array`` type code ``'I'``), so every dimension must be non-negative.
+    ctx : Context
+        Target device; its ``device_typeid`` and ``device_id`` are passed to
+        the backend.
+    delay_alloc : bool
+        Converted to an int and forwarded to ``MXNDArrayCreateEx``; presumably
+        defers the actual memory allocation -- confirm against the C API.
+    dtype : str or numpy dtype, optional
+        Element type, normalized with ``numpy.dtype`` and translated to an
+        MXNet type id through ``_DTYPE_NP_TO_MX``. Defaults to ``mx_real_t``.
+
+    Returns
+    -------
+    handle
+        A new empty `ndarray` handle.
+
+    Raises
+    ------
+    Exception
+        A failed ``_DTYPE_NP_TO_MX`` lookup (unsupported dtype) or any error
+        reported through ``check_call`` propagates to the caller.
+    """
+    hdl = NDArrayHandle()
+    # Positional arguments must follow the MXNDArrayCreateEx C signature order.
+    check_call(_LIB.MXNDArrayCreateEx(
+        c_array_buf(mx_uint, native_array('I', shape)),  # shape buffer
+        mx_uint(len(shape)),  # number of dimensions
+        ctypes.c_int(ctx.device_typeid),
+        ctypes.c_int(ctx.device_id),
+        ctypes.c_int(int(delay_alloc)),
+        ctypes.c_int(int(_DTYPE_NP_TO_MX[_np.dtype(dtype).type])),  # numpy dtype -> MXNet type id
+        ctypes.byref(hdl)))  # output parameter: receives the new handle
+    return hdl
+
+
+# 0 is used as the default value for stype because pylint does not allow
+# importing _STORAGE_TYPE_DEFAULT from ndarray.py.
+def _np_ndarray_cls(handle, writable=True, stype=0):
+    if stype != 0:
+        raise ValueError('_np_ndarray_cls currently only supports default 
storage '
+                         'type, while received stype = {}'.format(stype))
+    return ndarray(handle, writable=writable)
+
+
+_set_np_ndarray_class(_np_ndarray_cls)
+
+
+def _get_index(idx):
+    if isinstance(idx, NDArray) and not isinstance(idx, ndarray):
+        raise TypeError('Cannot have mx.nd.NDArray as index')
+    if isinstance(idx, ndarray):
+        return idx._as_nd_ndarray()
+    elif sys.version_info[0] > 2 and isinstance(idx, range):
+        return arange(idx.start, idx.stop, idx.step, 
dtype='int32')._as_nd_ndarray()
+    else:
+        return idx
+
+
+@set_module('mxnet.numpy')  # pylint: disable=invalid-name
+class ndarray(NDArray):
+    """An array object represents a multidimensional, homogeneous array of 
fixed-size items.
+    An associated data-type object describes the format of each element in the 
array
+    (its byte-order, how many bytes it occupies in memory, whether it is an 
integer, a
+    floating point number, or something else, etc.). Arrays should be 
constructed using
+    `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are 
supported."""
+
+    # pylint: disable=too-many-return-statements
+    def __getitem__(self, key):
+        # TODO(junwu): calling base class __getitem__ is a temp solution
+        ndim = self.ndim
+        shape = self.shape
+        if ndim == 0:
+            if key != ():
+                raise IndexError('scalar tensor can only accept `()` as index')
+        if isinstance(key, tuple) and len(key) == 0:
+            return self
+        elif isinstance(key, tuple) and len(key) == ndim\
+                and all(isinstance(idx, integer_types) for idx in key):
+            out = self
+            for idx in key:
+                out = out[idx]
 
 Review comment:
   `arr[i]` returns a view of the original subarray, but we don't have a C API
   that does the equivalent for `out[merge_idx(idx)]`. We can add one when the
   full basic/advanced indexing implementation is submitted in a PR.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to