piiswrong closed pull request #10491: [MXNET-306] Add slice_like operator
URL: https://github.com/apache/incubator-mxnet/pull/10491
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/ndarray/ndarray.md 
b/docs/api/python/ndarray/ndarray.md
index 08acc1a73c5..923c00820d5 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -323,6 +323,7 @@ The `ndarray` package provides several classes:
     NDArray.__setitem__
     NDArray.slice
     NDArray.slice_axis
+    NDArray.slice_like
     NDArray.take
     NDArray.one_hot
     NDArray.pick
@@ -423,6 +424,7 @@ The `ndarray` package provides several classes:
 
     slice
     slice_axis
+    slice_like
     take
     batch_take
     one_hot
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 2eceead3c0f..c750435c521 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -287,6 +287,7 @@ Composite multiple symbols into a new one by an operator.
 
     Symbol.slice
     Symbol.slice_axis
+    Symbol.slice_like
     Symbol.take
     Symbol.one_hot
     Symbol.pick
@@ -419,6 +420,7 @@ Composite multiple symbols into a new one by an operator.
 
     slice
     slice_axis
+    slice_like
     take
     batch_take
     one_hot
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 34aff0f7eff..361aa240daf 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1124,6 +1124,14 @@ def slice_axis(self, *args, **kwargs):
         """
         return op.slice_axis(self, *args, **kwargs)
 
    def slice_like(self, *args, **kwargs):
        """Convenience fluent method for :py:func:`slice_like`.

        The arguments are the same as for :py:func:`slice_like`, with
        this array as data.
        """
        # Delegates to the op namespace so arr.slice_like(...) and
        # nd.slice_like(arr, ...) behave identically.
        return op.slice_like(self, *args, **kwargs)
+
     def take(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`take`.
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 6d5c8c3a146..94896eed662 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1861,6 +1861,14 @@ def slice_axis(self, *args, **kwargs):
         """
         return op.slice_axis(self, *args, **kwargs)
 
    def slice_like(self, *args, **kwargs):
        """Convenience fluent method for :py:func:`slice_like`.

        The arguments are the same as for :py:func:`slice_like`, with
        this array as data.
        """
        # Delegates to the op namespace so sym.slice_like(...) and
        # mx.sym.slice_like(sym, ...) behave identically.
        return op.slice_like(self, *args, **kwargs)
+
     def take(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`take`.
 
diff --git a/src/operator/tensor/matrix_op-inl.h 
b/src/operator/tensor/matrix_op-inl.h
index 7756119ba3f..c46233c367f 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -1129,6 +1129,166 @@ void SliceAxisGrad_(const nnvm::NodeAttrs& attrs,
   }
 }
 
/*! \brief Parameters for the slice_like operator. */
struct SliceLikeParam : public dmlc::Parameter<SliceLikeParam> {
  // Axes on which the first input is sliced to the second input's extents;
  // empty (the default) means slice on every axis.
  TShape axes;
  DMLC_DECLARE_PARAMETER(SliceLikeParam) {
    DMLC_DECLARE_FIELD(axes).set_default(TShape())
    .describe("List of axes on which input data will be sliced according to the "
              "corresponding size of the second input. By default will slice on "
              "all axes. Negative axes are supported.");
  }
};
+
+inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_attrs,
+                           std::vector<TShape> *out_attrs) {
+  const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TShape& ishape = (*in_attrs)[0];
+  TShape& from_shape = (*in_attrs)[1];
+  if (param.axes.ndim() == 0) {
+    CHECK_EQ(ishape.ndim(), from_shape.ndim())
+      << "By default slice_axis performs slice on all axes, but ndim mismatch "
+         "for inputs: " << ishape.ndim() << " vs. " << from_shape.ndim();
+    for (index_t i = 0; i < ishape.ndim(); ++i) {
+      CHECK_GE(ishape[i], from_shape[i])
+        << "Slice axis " << i << " with size " << from_shape[i]
+        << "exceeds limit of input with size " << ishape[i];
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, from_shape);
+  } else {
+    TShape shape(ishape);
+    for (index_t i = 0; i < param.axes.ndim(); ++i) {
+      int axis = static_cast<int>(param.axes[i]);
+      if (axis < 0) {
+        axis += static_cast<int>(ishape.ndim());
+      }
+      CHECK_GE(axis, 0)
+        << "Slice axis: " << static_cast<int>(param.axes[i]) << " too small";
+      CHECK_GT(ishape.ndim(), axis)
+        << "Slice axis: " << axis << " exceeds first input: " << ishape.ndim();
+      CHECK_GT(from_shape.ndim(), axis)
+        << "Slice axis: " << axis << " exceeds second input: " << 
from_shape.ndim();
+      shape[axis] = from_shape[axis];
+      CHECK_GE(ishape[axis], from_shape[axis])
+        << "Slice axis " << axis << " with size " << from_shape[axis]
+        << "exceeds limit of input with size " << ishape[axis];
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
+  }
+  return true;
+}
+
+inline void SliceLikeInferRanges(const TShape& dshape,
+                                 const TShape& fshape,
+                                 const TShape& axes,
+                                 nnvm::Tuple<dmlc::optional<int>>* param_begin,
+                                 nnvm::Tuple<dmlc::optional<int>>* param_end,
+                                 nnvm::Tuple<dmlc::optional<int>>* param_step) 
{
+  std::vector<dmlc::optional<int>> pb(dshape.ndim());
+  std::vector<dmlc::optional<int>> pe(dshape.ndim());
+  std::vector<dmlc::optional<int>> ps(dshape.ndim());
+  if (axes.ndim() == 0) {
+    for (index_t i = 0; i < dshape.ndim(); ++i) {
+      pb[i] = 0;
+      pe[i] = fshape[i];
+      ps[i] = 1;
+    }
+  } else {
+    for (index_t i = 0; i < axes.ndim(); ++i) {
+      int axis = static_cast<int>(axes[i]);
+      if (axis < 0) {
+        axis += static_cast<int>(dshape.ndim());
+      }
+      CHECK_GE(axis, 0)
+        << "Slice axis: " << static_cast<int>(axes[i]) << " too small";
+      CHECK_LT(axis, static_cast<int>(dshape.ndim()))
+        << "Slice axis: " << axis << " exceeds first input: " << dshape.ndim();
+      CHECK_LT(axis, static_cast<int>(fshape.ndim()))
+        << "Slice axis: " << axis << " exceeds first input: " << fshape.ndim();
+      pb[axis] = 0;
+      pe[axis] = fshape[axis];
+      ps[axis] = 1;
+    }
+  }
+  *param_begin = nnvm::Tuple<dmlc::optional<int>>(pb.begin(), pb.end());
+  *param_end = nnvm::Tuple<dmlc::optional<int>>(pe.begin(), pe.end());
+  *param_step = nnvm::Tuple<dmlc::optional<int>>(ps.begin(), ps.end());
+}
+
/*!
 * \brief Forward pass of slice_like.
 *
 * Slices inputs[0] (data) so that, on the configured axes, its extents
 * match inputs[1] (shape_like); inputs[1] contributes only its shape and
 * its values are never read.
 */
template<typename xpu>
void SliceLikeForward(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  using namespace mshadow::expr;
  const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob& data = inputs[0];
  const TBlob& out = outputs[0];
  const TShape& ishape = data.shape_;
  const TShape& from_shape = inputs[1].shape_;
  // Translate the axes parameter into per-axis begin/end/step tuples so
  // the generic slice kernel can be reused below.
  nnvm::Tuple<dmlc::optional<int>> param_begin;
  nnvm::Tuple<dmlc::optional<int>> param_end;
  nnvm::Tuple<dmlc::optional<int>> param_step;
  SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, &param_end, &param_step);

  MXNET_NDIM_SWITCH(data.ndim(), ndim, {
    common::StaticArray<int, ndim> begin, end, step;
    GetIndexRange(data.shape_, param_begin, param_end, param_step, &begin, &end, &step);
    MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
      // Same kernel as the plain slice operator's forward pass.
      mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
          out.dptr<DType>(), data.dptr<DType>(), req[0],
          data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
    })
  })
}
+
/*!
 * \brief Backward pass of slice_like.
 *
 * Scatters the output gradient (inputs[0]) back into the data gradient
 * (outputs[0]) at the sliced region and zero-fills the rest; the
 * shape_like input's gradient (outputs[1]) is always zero.
 */
template<typename xpu>
void SliceLikeBackward(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U);
  CHECK_EQ(req.size(), 2U);
  if (req[0] == kNullOp) return;
  using namespace mshadow;
  Stream<xpu>* s = ctx.get_stream<xpu>();
  const TBlob& ograd = inputs[0];
  const TBlob& igrad = outputs[0];
  const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
  Fill(s, outputs[1], req[1], 0);  // Second input not relevant to gradients.
  if (req[0] == kWriteTo) {
    // Zero the whole data gradient first; slice_assign below only writes
    // the sliced region.
    Fill(s, igrad, req[0], 0);
  } else if (req[0] == kWriteInplace) {
    // ograd and igrad have different shapes, so in-place is impossible.
    LOG(FATAL) << "_slice_like_backward does not support kWriteInplace";
  }

  const TShape& ishape = ograd.shape_;
  const TShape& from_shape = outputs[1].shape_;
  nnvm::Tuple<dmlc::optional<int>> param_begin;
  nnvm::Tuple<dmlc::optional<int>> param_end;
  nnvm::Tuple<dmlc::optional<int>> param_step;
  SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, &param_end, &param_step);

  MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
    common::StaticArray<int, ndim> begin, end, step;
    GetIndexRange(ograd.shape_, param_begin, param_end, param_step, &begin, &end, &step);
    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
      // Same scatter kernel as the plain slice operator's backward pass.
      mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
          igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
          igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
    })
  })
}
+
 struct ClipParam : public dmlc::Parameter<ClipParam> {
   real_t a_min, a_max;
   DMLC_DECLARE_PARAMETER(ClipParam) {
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 9037827a5aa..29d493ae5a5 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -95,6 +95,7 @@ DMLC_REGISTER_PARAMETER(ClipParam);
 DMLC_REGISTER_PARAMETER(SliceAssignScalarParam);
 DMLC_REGISTER_PARAMETER(SliceParam);
 DMLC_REGISTER_PARAMETER(SliceAxisParam);
+DMLC_REGISTER_PARAMETER(SliceLikeParam);
 DMLC_REGISTER_PARAMETER(RepeatParam);
 DMLC_REGISTER_PARAMETER(TileParam);
 DMLC_REGISTER_PARAMETER(ReverseParam);
@@ -513,6 +514,80 @@ NNVM_REGISTER_OP(_backward_slice_axis)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", SliceAxisGrad_<cpu>);
 
+NNVM_REGISTER_OP(slice_like)
+.describe(R"code(Slices a region of the array like the shape of another array.
+
+This function is similar to ``slice``, however, the `begin` are always `0`s
+and `end` of specific axes are inferred from the second input `shape_like`.
+
+Given the second `shape_like` input of ``shape=(d_0, d_1, ..., d_n-1)``,
+a ``slice_like`` operator with default empty `axes`, it performs the
+following operation:
+
+`` out = slice(input, begin=(0, 0, ..., 0), end=(d_0, d_1, ..., d_n-1))``.
+
+When `axes` is not empty, it is used to speficy which axes are being sliced.
+
+Given a 4-d input data, ``slice_like`` operator with ``axes=(0, 2, -1)``
+will perform the following operation:
+
+`` out = slice(input, begin=(0, 0, 0, 0), end=(d_0, None, d_2, d_3))``.
+
+Note that it is allowed to have first and second input with different 
dimensions,
+however, you have to make sure the `axes` are specified and not exceeding the
+dimension limits.
+
+For example, given `input_1` with ``shape=(2,3,4,5)`` and `input_2` with
+``shape=(1,2,3)``, it is not allowed to use:
+
+`` out = slice_like(a, b)`` because ndim of `input_1` is 4, and ndim of 
`input_2`
+is 3.
+
+The following is allowed in this situation:
+
+`` out = slice_like(a, b, axes=(0, 2))``
+
+Example::
+
+  x = [[  1.,   2.,   3.,   4.],
+       [  5.,   6.,   7.,   8.],
+       [  9.,  10.,  11.,  12.]]
+
+  y = [[  0.,   0.,   0.],
+       [  0.,   0.,   0.]]
+
+  slice_like(x, y) = [[ 1.,  2.,  3.]
+                      [ 5.,  6.,  7.]]
+  slice_like(x, y, axes=(0, 1)) = [[ 1.,  2.,  3.]
+                                   [ 5.,  6.,  7.]]
+  slice_like(x, y, axes=(0)) = [[ 1.,  2.,  3.,  4.]
+                                [ 5.,  6.,  7.,  8.]]
+  slice_like(x, y, axes=(-1)) = [[  1.,   2.,   3.]
+                                 [  5.,   6.,   7.]
+                                 [  9.,  10.,  11.]]
+)code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SliceLikeParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "shape_like"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", SliceLikeShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_slice_like"})
+.set_attr<FCompute>("FCompute<cpu>", SliceLikeForward<cpu>)
+.add_argument("data", "NDArray-or-Symbol", "Source input")
+.add_argument("shape_like", "NDArray-or-Symbol", "Shape like input")
+.add_arguments(SliceLikeParam::__FIELDS__());
+
// Backward registration (CPU): one input (output gradient), two outputs
// (data gradient, zero gradient for shape_like).
NNVM_REGISTER_OP(_backward_slice_like)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr_parser(ParamParser<SliceLikeParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", SliceLikeBackward<cpu>);
+
 NNVM_REGISTER_OP(clip)
 MXNET_ADD_SPARSE_OP_ALIAS(clip)
 .describe(R"code(Clips (limits) the values in an array.
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 1dbb3b8f67e..bd1b9f20825 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -34,7 +34,7 @@ namespace op {
  * \brief Compute the number of elements of every row.
  */
 struct SliceMarkCsrIndPtr {
-  /*! 
+  /*!
    * \brief
    * \param i           the i-th row of the output csr ndarray
    * \param prefix_sum  indptr array of the output csr ndarray
@@ -168,6 +168,12 @@ NNVM_REGISTER_OP(slice_axis)
 NNVM_REGISTER_OP(_backward_slice_axis)
 .set_attr<FCompute>("FCompute<gpu>", SliceAxisGrad_<gpu>);
 
// GPU registrations reuse the same templated kernels as the CPU path.
NNVM_REGISTER_OP(slice_like)
.set_attr<FCompute>("FCompute<gpu>", SliceLikeForward<gpu>);

NNVM_REGISTER_OP(_backward_slice_like)
.set_attr<FCompute>("FCompute<gpu>", SliceLikeBackward<gpu>);
+
 NNVM_REGISTER_OP(clip)
 .set_attr<FCompute>("FCompute<gpu>", Clip<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ClipEx<gpu>);
diff --git a/tests/python/unittest/test_ndarray.py 
b/tests/python/unittest/test_ndarray.py
index a7cb037dcf4..030816ecbbc 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -920,8 +920,8 @@ def test_output():
 def test_ndarray_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 
'sum', 'nansum', 'prod',
                     'nanprod', 'mean', 'max', 'min', 'reshape', 
'broadcast_to', 'split',
-                    'broadcast_axes', 'pad', 'swapaxes', 'slice', 
'slice_axis', 'take',
-                    'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 
'argmin',
+                    'broadcast_axes', 'pad', 'swapaxes', 'slice', 
'slice_axis', 'slice_like',
+                    'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 
'argmax', 'argmin',
                     'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 
'arccos', 'arctan',
                     'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 
'arccosh', 'arctanh',
                     'exp', 'expm1', 'log', 'log10', 'log2', 'log1p', 'sqrt', 
'rsqrt', 'square',
@@ -958,6 +958,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), 
equal_nan=False):
     check_fluent_regular('split', {'axis': 2, 'num_outputs': 3}, shape=(5, 17, 
6))
     check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, 
shape=(5, 17, 6))
     check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
+    check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': 
mx.nd.zeros((3, 3))})
     check_fluent_regular('take', {'indices': mx.nd.array([2, 3])})
     check_fluent_regular('pick', {'axis': 1, 'index': mx.nd.array([[2], [3], 
[5], [6], [11]])})
     check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
@@ -1259,7 +1260,7 @@ def test_ndarray_astype():
     x = mx.nd.zeros((2, 3), dtype='int32')
     y = x.astype('int32', copy=False)
     assert (id(x) == id(y))
-    
+
     # Test the string version 'int32'
     # has the same behaviour as the np.int32
     x = mx.nd.zeros((2, 3), dtype='int32')
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index ca1c42c2386..8d28c6d5470 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1971,6 +1971,46 @@ def test_slice_axis():
             xx[idx] = x.asnumpy()[idx]
             assert_allclose(xx + x_grad_npy, xgrad.asnumpy(), atol=1E-5)
 
@with_seed()
def test_slice_like():
    # Exercise slice_like for ndim 1..5, covering both the default
    # all-axes mode (axes=[]) and random subsets of axes, including one
    # axis flipped to its negative form.
    for ndim in range(1, 6):
        from_shape = np.random.randint(1, 11, size=(ndim,))
        # Data shape is elementwise >= from_shape so slicing is always valid.
        shape = [s + np.random.randint(0, 3) for s in from_shape]
        for t in range(ndim):
            if t > 0:
                axes = np.random.randint(0, ndim, size=t).tolist()
            else:
                axes = []
            # Build the equivalent numpy indexing that slice_like should
            # reproduce: sliced axes take from_shape extents, others are full.
            idx = []
            for i in range(ndim):
                idx.append(slice(0, shape[i]))
                if i in axes or not axes:
                    idx[i] = slice(0, from_shape[i])

            if axes:
                # Turn one positive axis negative to test negative-axis support.
                pos = np.random.randint(0, t)
                if axes[pos] > 0:
                    axes[pos] -= ndim  # negative index

            X = mx.symbol.Variable('X')
            X_1 = mx.symbol.Variable('X1')
            x = mx.nd.array(np.random.normal(size=shape))
            x1 = mx.nd.array(np.random.normal(size=from_shape))
            Y = mx.symbol.slice_like(data=X, shape_like=X_1, axes=axes)

            xgrad = mx.nd.empty(x.shape)
            xgrad1 = mx.nd.empty(x1.shape)
            exec1 = Y.bind(default_context(), args = [x, x1],
                           args_grad = {'X': xgrad, 'X1': xgrad1})
            exec1.forward(is_train=True)
            y = exec1.outputs[0]
            assert_allclose(x.asnumpy()[idx], y.asnumpy())
            # Backward with ograd = y: the data gradient is y scattered into
            # the sliced region (zeros elsewhere); shape_like gets zero grad.
            exec1.backward([y])
            xx = x.asnumpy()
            xx[:] = 0.0
            xx[idx] = x.asnumpy()[idx]
            assert_allclose(xx, xgrad.asnumpy())
            assert_allclose(xgrad1.asnumpy(), mx.nd.zeros_like(xgrad1).asnumpy())
 
 @with_seed()
 def test_flip():
diff --git a/tests/python/unittest/test_symbol.py 
b/tests/python/unittest/test_symbol.py
index d81b8cfb4c0..11c2eba7609 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -166,8 +166,8 @@ def test_symbol_infer_shape_var():
 def test_symbol_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 
'sum', 'nansum', 'prod',
                     'nanprod', 'mean', 'max', 'min', 'reshape', 
'broadcast_to', 'split',
-                    'broadcast_axes', 'pad', 'swapaxes', 'slice', 
'slice_axis', 'take',
-                    'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 
'argmin',
+                    'broadcast_axes', 'pad', 'swapaxes', 'slice', 
'slice_axis', 'slice_like',
+                    'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 
'argmax', 'argmin',
                     'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 
'arccos', 'arctan',
                     'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 
'arccosh', 'arctanh',
                     'exp', 'expm1', 'log', 'log10', 'log2', 'log1p', 'sqrt', 
'rsqrt',
@@ -204,6 +204,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), 
equal_nan=False):
     check_fluent_regular('split', {'axis': 2, 'num_outputs': 3}, shape=(5, 17, 
6))
     check_fluent_regular('slice', {'begin': (2, 5, 1), 'end': (4, 7, 6)}, 
shape=(5, 17, 6))
     check_fluent_regular('slice_axis', {'axis': 1, 'begin': 5, 'end': 7})
+    check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': 
mx.sym.zeros((3, 3))})
     check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
     check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
     check_fluent_regular('pad', {'mode': 'constant', 'pad_width': 
(0,0,0,0,3,0,0,4)}, shape=(5, 17, 2, 3))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to