This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new da5242b  MXNET-1295 Adding integer index support to Sequence* family 
of operators. (#13880)
da5242b is described below

commit da5242b732de39ad47d8ecee582f261ba5935fa9
Author: stephenrawls <[email protected]>
AuthorDate: Sat Jan 26 22:04:09 2019 -0800

    MXNET-1295 Adding integer index support to Sequence* family of operators. 
(#13880)
    
    * Adding integer index support to Sequence* family of operators.
    
    Adding ability to use int32 arrays, or any castable-to-int type, as
    the sequence_length array to SequenceMask, SequenceLast, and
    SequenceReverse. Previously these operators all required sequence_length
    to be the same data type as the input array.
    
    See MXNet Jira ticket here:
      https://issues.apache.org/jira/browse/MXNET-1295
    
    See also GitHub issues here:
       https://github.com/apache/incubator-mxnet/issues/12649
       https://github.com/dmlc/gluon-nlp/issues/346
    
    * Adding explicit braces to an if statement to fix g++ warning
    
    * fixing sequence_mask.cu by adding IType to template
    
    * Fixing whitespace errors reported by linter
    
    * Adding unit tests
    
    * Fixing length of lines to pass linter
---
 python/mxnet/test_utils.py             | 37 +++++++++++++-------
 src/operator/sequence_last-inl.h       | 30 ++++++++--------
 src/operator/sequence_last.cc          | 16 ++++++---
 src/operator/sequence_last.cu          |  9 +++--
 src/operator/sequence_mask-inl.h       | 24 ++++++-------
 src/operator/sequence_mask.cc          | 16 ++++++---
 src/operator/sequence_mask.cu          |  9 +++--
 src/operator/sequence_reverse-inl.h    | 20 +++++------
 src/operator/sequence_reverse.cc       | 17 +++++++---
 src/operator/sequence_reverse.cu       |  8 +++--
 tests/python/unittest/test_operator.py | 62 ++++++++++++++++++----------------
 11 files changed, 144 insertions(+), 104 deletions(-)

diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 0a4d17d..4138e4d 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -620,8 +620,11 @@ def _parse_location(sym, location, ctx, 
dtype=default_dtype()):
         *In either case, value of all the arguments must be provided.*
     ctx : Context
         Device context.
-    dtype: np.float16 or np.float32 or np.float64
-        Datatype for mx.nd.array.
+    dtype: "asnumpy" or np.float16 or np.float32 or np.float64
+        If dtype is "asnumpy" then the mx.nd.array created will have the same
type as the numpy array from which it is copied.
+        Otherwise, dtype is the explicit datatype for all mx.nd.array objects
+        created in this function.
 
     Returns
     -------
@@ -643,7 +646,7 @@ def _parse_location(sym, location, ctx, 
dtype=default_dtype()):
     ValueError: Symbol arguments and keys of the given location do not match.
     """
     assert isinstance(location, (dict, list, tuple))
-    assert dtype in (np.float16, np.float32, np.float64)
+    assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
     if isinstance(location, dict):
         if set(location.keys()) != set(sym.list_arguments()):
             raise ValueError("Symbol arguments and keys of the given location 
do not match."
@@ -651,8 +654,8 @@ def _parse_location(sym, location, ctx, 
dtype=default_dtype()):
                              % (str(set(sym.list_arguments())), 
str(set(location.keys()))))
     else:
         location = {k: v for k, v in zip(sym.list_arguments(), location)}
-    location = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) if isinstance(v, 
np.ndarray) \
-               else v for k, v in location.items()}
+    location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" 
else dtype) \
+               if isinstance(v, np.ndarray) else v for k, v in 
location.items()}
     return location
 
 
@@ -677,8 +680,11 @@ def _parse_aux_states(sym, aux_states, ctx, 
dtype=default_dtype()):
         *In either case, all aux states of `sym` must be provided.*
     ctx : Context
         Device context.
-    dtype: np.float16 or np.float32 or np.float64
-        Datatype for mx.nd.array.
+    dtype: "asnumpy" or np.float16 or np.float32 or np.float64
+        If dtype is "asnumpy" then the mx.nd.array created will have the same
type as the numpy array from which it is copied.
+        Otherwise, dtype is the explicit datatype for all mx.nd.array objects
+        created in this function.
 
     Returns
     -------
@@ -702,7 +708,7 @@ def _parse_aux_states(sym, aux_states, ctx, 
dtype=default_dtype()):
     >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)
     ValueError: Symbol aux_states names and given aux_states do not match.
     """
-    assert dtype in (np.float16, np.float32, np.float64)
+    assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
     if aux_states is not None:
         if isinstance(aux_states, dict):
             if set(aux_states.keys()) != set(sym.list_auxiliary_states()):
@@ -713,7 +719,8 @@ def _parse_aux_states(sym, aux_states, ctx, 
dtype=default_dtype()):
         elif isinstance(aux_states, (list, tuple)):
             aux_names = sym.list_auxiliary_states()
             aux_states = {k:v for k, v in zip(aux_names, aux_states)}
-        aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in 
aux_states.items()}
+        aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == 
"asnumpy" else dtype) \
+                      for k, v in aux_states.items()}
     return aux_states
 
 
@@ -962,8 +969,11 @@ def check_symbolic_forward(sym, location, expected, 
rtol=1E-4, atol=None,
             Contains the mapping between names of auxiliary states and their 
values.
     ctx : Context, optional
         running context
-    dtype: np.float16 or np.float32 or np.float64
-        Datatype for mx.nd.array.
+    dtype: "asnumpy" or np.float16 or np.float32 or np.float64
+        If dtype is "asnumpy" then the mx.nd.array created will have the same
type as the numpy array from which it is copied.
+        Otherwise, dtype is the explicit datatype for all mx.nd.array objects
+        created in this function.
 
     equal_nan: Boolean
         if True, `nan` is a valid value for checking equivalency (ie `nan` == 
`nan`)
@@ -979,7 +989,7 @@ def check_symbolic_forward(sym, location, expected, 
rtol=1E-4, atol=None,
     >>> ret_expected = np.array([[19, 22], [43, 50]])
     >>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
     """
-    assert dtype in (np.float16, np.float32, np.float64)
+    assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
     if ctx is None:
         ctx = default_context()
 
@@ -988,7 +998,8 @@ def check_symbolic_forward(sym, location, expected, 
rtol=1E-4, atol=None,
                                    dtype=dtype)
     if isinstance(expected, dict):
         expected = [expected[k] for k in sym.list_outputs()]
-    args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=dtype) for k, v in 
location.items()}
+    args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=v.dtype if dtype 
== "asnumpy" else dtype) \
+                      for k, v in location.items()}
 
     executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, 
aux_states=aux_states)
     for g in executor.grad_arrays:
diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h
index 1a59473..61506c2 100644
--- a/src/operator/sequence_last-inl.h
+++ b/src/operator/sequence_last-inl.h
@@ -65,9 +65,9 @@ struct SequenceLastParam : public 
dmlc::Parameter<SequenceLastParam> {
 
 template <int req>
 struct SequenceLastKernel {
-  template <typename DType>
+  template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
-                                  const DType *idx, int offset1, int offset2,
+                                  const IType *idx, int offset1, int offset2,
                                   mshadow::Shape<2> oshape) {
     const auto opos = mxnet_op::unravel(i, oshape);
     const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
@@ -77,9 +77,9 @@ struct SequenceLastKernel {
 };
 
 struct SequenceLastGradKernel {
-  template <typename DType>
+  template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
-                                  const DType *idx, int offset1, int offset2,
+                                  const IType *idx, int offset1, int offset2,
                                   mshadow::Shape<2> oshape) {
     const auto opos = mxnet_op::unravel(i, oshape);
     const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
@@ -88,14 +88,14 @@ struct SequenceLastGradKernel {
   }
 };
 
-template <typename xpu, typename DType>
+template <typename xpu, typename DType, typename IType>
 class SequenceLastOp : public Operator {
  public:
   explicit SequenceLastOp(SequenceLastParam p) { this->param_ = p; }
 
   void sequence_last(const mshadow::Tensor<xpu, 3, DType> &data,
                      const mshadow::Tensor<xpu, 2, DType> &out,
-                     const mshadow::Tensor<xpu, 1, DType> &indices,
+                     const mshadow::Tensor<xpu, 1, IType> &indices,
                      const OpReqType req, mshadow::Stream<xpu> *const s) {
     using namespace mshadow;
     using namespace mshadow::expr;
@@ -115,7 +115,7 @@ class SequenceLastOp : public Operator {
 
   void sequence_last_grad(const mshadow::Tensor<xpu, 3, DType> &in_grad,
                           const mshadow::Tensor<xpu, 2, DType> &out_grad,
-                          const mshadow::Tensor<xpu, 1, DType> &indices,
+                          const mshadow::Tensor<xpu, 1, IType> &indices,
                           mshadow::Stream<xpu> *const s) {
     using namespace mshadow;
     using namespace mshadow::expr;
@@ -163,11 +163,11 @@ class SequenceLastOp : public Operator {
     Tensor<xpu, 2, DType> out =
         out_data[seq_last::kOut].get_with_shape<xpu, 2, DType>(
             Shape2(batch, rest_size), s);
-    Tensor<xpu, 1, DType> indices =
+    Tensor<xpu, 1, IType> indices =
         param_.use_sequence_length
-            ? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
+            ? in_data[seq_last::kSequenceLength].get<xpu, 1, IType>(s)
             : ctx.requested[seq_last::kTempSpace]
-                  .get_space_typed<xpu, 1, DType>(Shape1(batch), s);
+                  .get_space_typed<xpu, 1, IType>(Shape1(batch), s);
     if (!param_.use_sequence_length) indices = max_seq_len;
 
     sequence_last(data, out, indices, req[seq_last::kOut], s);
@@ -206,11 +206,11 @@ class SequenceLastOp : public Operator {
     Tensor<xpu, 2, DType> output_grad =
         out_grad[seq_last::kOut].get_with_shape<xpu, 2, DType>(
             Shape2(batch, rest_size), s);
-    Tensor<xpu, 1, DType> indices =
+    Tensor<xpu, 1, IType> indices =
         param_.use_sequence_length
-            ? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
+            ? in_data[seq_last::kSequenceLength].get<xpu, 1, IType>(s)
             : ctx.requested[seq_last::kTempSpace]
-                  .get_space_typed<xpu, 1, DType>(Shape1(batch), s);
+                  .get_space_typed<xpu, 1, IType>(Shape1(batch), s);
 
     if (req[seq_last::kData] == kWriteTo) data_grad = 0.0f;
     sequence_last_grad(data_grad, output_grad, indices, s);
@@ -221,7 +221,7 @@ class SequenceLastOp : public Operator {
 };  // class SequenceLastOp
 
 template <typename xpu>
-Operator *CreateOp(SequenceLastParam param, int dtype);
+Operator *CreateOp(SequenceLastParam param, int dtype, int itype);
 
 #if DMLC_USE_CXX11
 class SequenceLastProp : public OperatorProperty {
@@ -281,8 +281,6 @@ class SequenceLastProp : public OperatorProperty {
     for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
       }
     }
     out_type->clear();
diff --git a/src/operator/sequence_last.cc b/src/operator/sequence_last.cc
index 345524b..f2388a8 100644
--- a/src/operator/sequence_last.cc
+++ b/src/operator/sequence_last.cc
@@ -28,10 +28,13 @@
 namespace mxnet {
 namespace op {
 template <>
-Operator *CreateOp<cpu>(SequenceLastParam param, int dtype) {
+Operator *CreateOp<cpu>(SequenceLastParam param, int dtype, int itype) {
   Operator *op = nullptr;
-  MSHADOW_TYPE_SWITCH(dtype, DType,
-                           { op = new SequenceLastOp<cpu, DType>(param); })
+  MSHADOW_TYPE_SWITCH(dtype, DType, {
+      MSHADOW_TYPE_SWITCH(itype, IType, {
+          op = new SequenceLastOp<cpu, DType, IType>(param);
+        });
+    });
   return op;
 }
 
@@ -39,7 +42,12 @@ Operator *CreateOp<cpu>(SequenceLastParam param, int dtype) {
 Operator *SequenceLastProp::CreateOperatorEx(Context ctx,
                                              std::vector<TShape> *in_shape,
                                              std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+  if (in_type->size() >= 2 && (*in_type)[1] != -1) {
+    DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
+  }
+
+  // sequence_length not passed in, so fall back to using input array dtype 
for second argument
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
 }
 
 DMLC_REGISTER_PARAMETER(SequenceLastParam);
diff --git a/src/operator/sequence_last.cu b/src/operator/sequence_last.cu
index dfc4e59..fb5ae84 100644
--- a/src/operator/sequence_last.cu
+++ b/src/operator/sequence_last.cu
@@ -28,10 +28,13 @@
 
 namespace mxnet {
 namespace op {
-template <> Operator *CreateOp<gpu>(SequenceLastParam param, int dtype) {
+template <> Operator *CreateOp<gpu>(SequenceLastParam param, int dtype, int 
itype) {
   Operator *op = NULL;
-  MSHADOW_TYPE_SWITCH(dtype, DType,
-                           { op = new SequenceLastOp<gpu, DType>(param); })
+  MSHADOW_TYPE_SWITCH(dtype, DType, {
+      MSHADOW_TYPE_SWITCH(itype, IType, {
+          op = new SequenceLastOp<gpu, DType, IType>(param);
+        });
+    });
   return op;
 }
 
diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h
index c93ffb5..c2584ab 100644
--- a/src/operator/sequence_mask-inl.h
+++ b/src/operator/sequence_mask-inl.h
@@ -68,8 +68,8 @@ struct SequenceMaskParam : public 
dmlc::Parameter<SequenceMaskParam> {
 // (seqlen, batch, rest) case
 template <int req>
 struct SequenceMask0Kernel {
-  template <typename DType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
                                   index_t max_s_len, index_t batch_size,
                                   index_t restsize, DType value) {
     const index_t seqpos = static_cast<int>(idx[b]);
@@ -86,8 +86,8 @@ struct SequenceMask0Kernel {
 // (batch, seqlen, rest) case
 template <int req>
 struct SequenceMask1Kernel {
-  template <typename DType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
                                   index_t max_s_len, index_t batch_size,
                                   index_t restsize, DType value) {
     const index_t seqpos = static_cast<int>(idx[b]);
@@ -101,13 +101,13 @@ struct SequenceMask1Kernel {
   }
 };
 
-template <typename xpu, typename DType>
+template <typename xpu, typename DType, typename IType>
 class SequenceMaskOp : public Operator {
  public:
   explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }
 
   void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
-                     const mshadow::Tensor<xpu, 1, DType> &indices,
+                     const mshadow::Tensor<xpu, 1, IType> &indices,
                      const OpReqType req, mshadow::Stream<xpu> *const s,
                      DType val) {
     using namespace mshadow;
@@ -153,8 +153,8 @@ class SequenceMaskOp : public Operator {
     // Actual implementation of masking
     Assign(out, req[seq_mask::kOut], F<mshadow_op::identity>(data));
     if (param_.use_sequence_length) {
-      Tensor<xpu, 1, DType> indices =
-          in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
+      Tensor<xpu, 1, IType> indices =
+          in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
       sequence_mask(out, indices, req[seq_mask::kOut], s,
                     static_cast<DType>(param_.value));
     }
@@ -190,8 +190,8 @@ class SequenceMaskOp : public Operator {
     if (!param_.use_sequence_length) {
       Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
     } else {
-      Tensor<xpu, 1, DType> indices =
-          in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
+      Tensor<xpu, 1, IType> indices =
+          in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
       if (req[seq_mask::kData] == kAddTo) {
         Tensor<xpu, 3, DType> out_g_temp =
             ctx.requested[seq_mask::kTempSpace].get_space_typed<xpu, 3, DType>(
@@ -212,7 +212,7 @@ class SequenceMaskOp : public Operator {
 };  // class SequenceMaskOp
 
 template <typename xpu>
-Operator *CreateOp(SequenceMaskParam param, int dtype);
+Operator *CreateOp(SequenceMaskParam param, int dtype, int itype);
 
 #if DMLC_USE_CXX11
 class SequenceMaskProp : public OperatorProperty {
@@ -270,8 +270,6 @@ class SequenceMaskProp : public OperatorProperty {
     for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
       }
     }
     out_type->clear();
diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc
index e02c57b..76e5838 100644
--- a/src/operator/sequence_mask.cc
+++ b/src/operator/sequence_mask.cc
@@ -28,10 +28,13 @@
 namespace mxnet {
 namespace op {
 template <>
-Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype) {
+Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = nullptr;
-  MSHADOW_TYPE_SWITCH(dtype, DType,
-                           { op = new SequenceMaskOp<cpu, DType>(param); })
+  MSHADOW_TYPE_SWITCH(dtype, DType, {
+      MSHADOW_TYPE_SWITCH(itype, IType, {
+          op = new SequenceMaskOp<cpu, DType, IType>(param);
+        });
+    });
   return op;
 }
 
@@ -39,7 +42,12 @@ Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype) {
 Operator *SequenceMaskProp::CreateOperatorEx(Context ctx,
                                              std::vector<TShape> *in_shape,
                                              std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+  if (in_type->size() >= 2 && (*in_type)[1] != -1) {
+    DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
+  }
+
+  // sequence_length not passed in, so fall back to using input array dtype 
for second argument
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
 }
 
 DMLC_REGISTER_PARAMETER(SequenceMaskParam);
diff --git a/src/operator/sequence_mask.cu b/src/operator/sequence_mask.cu
index 2ca8832..cec627c 100644
--- a/src/operator/sequence_mask.cu
+++ b/src/operator/sequence_mask.cu
@@ -29,10 +29,13 @@
 namespace mxnet {
 namespace op {
 
-template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype) {
+template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int 
itype) {
   Operator *op = NULL;
-  MSHADOW_TYPE_SWITCH(dtype, DType,
-                           { op = new SequenceMaskOp<gpu, DType>(param); })
+  MSHADOW_TYPE_SWITCH(dtype, DType, {
+      MSHADOW_TYPE_SWITCH(itype, IType, {
+          op = new SequenceMaskOp<gpu, DType, IType>(param);
+        });
+    });
   return op;
 }
 
diff --git a/src/operator/sequence_reverse-inl.h 
b/src/operator/sequence_reverse-inl.h
index 5c48729..eb9f71c 100644
--- a/src/operator/sequence_reverse-inl.h
+++ b/src/operator/sequence_reverse-inl.h
@@ -65,14 +65,14 @@ struct SequenceReverseParam : public 
dmlc::Parameter<SequenceReverseParam> {
 };
 
 struct ReverseKernel {
-  template <typename DType>
+  template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
                                   const DType *const in_data,
                                   const OpReqType req,
                                   const index_t max_seq_len,
                                   const index_t batch_size,
                                   const index_t other_dim, const index_t numel,
-                                  const DType *const indices) {
+                                  const IType *const indices) {
     for (index_t batch = 0; batch < batch_size; ++batch) {
       const index_t num_seq =
           indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
@@ -102,13 +102,13 @@ struct ReverseKernel {
   }
 };
 
-template <typename xpu, typename DType>
+template <typename xpu, typename DType, typename IType>
 class SequenceReverseOp : public Operator {
  public:
   explicit SequenceReverseOp(SequenceReverseParam p) { this->param_ = p; }
   void sequence_reverse(const mshadow::Tensor<xpu, 3, DType> &data,
                         const mshadow::Tensor<xpu, 3, DType> &out,
-                        const OpReqType req, const DType *const indices,
+                        const OpReqType req, const IType *const indices,
                         mshadow::Stream<xpu> *const s) {
     using namespace mshadow;
     using namespace mshadow::expr;
@@ -145,9 +145,9 @@ class SequenceReverseOp : public Operator {
     Tensor<xpu, 3, DType> out =
         out_data[seq_reverse::kOut].get_with_shape<xpu, 3, DType>(s3, s);
 
-    const DType *const indices =
+    const IType *const indices =
         param_.use_sequence_length
-            ? in_data[seq_reverse::kSequenceLength].dptr<DType>()
+            ? in_data[seq_reverse::kSequenceLength].dptr<IType>()
             : nullptr;
 
     sequence_reverse(data, out, req[seq_reverse::kOut], indices, s);
@@ -179,9 +179,9 @@ class SequenceReverseOp : public Operator {
     Tensor<xpu, 3, DType> output_grad =
         out_grad[seq_reverse::kOut].get_with_shape<xpu, 3, DType>(s3, s);
 
-    const DType *const indices =
+    const IType *const indices =
         param_.use_sequence_length
-            ? in_data[seq_reverse::kSequenceLength].dptr<DType>()
+            ? in_data[seq_reverse::kSequenceLength].dptr<IType>()
             : nullptr;
 
     sequence_reverse(output_grad, data_grad, req[seq_reverse::kData], indices,
@@ -193,7 +193,7 @@ class SequenceReverseOp : public Operator {
 };  // class SequenceReverseOp
 
 template <typename xpu>
-Operator *CreateOp(SequenceReverseParam param, int dtype);
+Operator *CreateOp(SequenceReverseParam param, int dtype, int itype);
 
 #if DMLC_USE_CXX11
 class SequenceReverseProp : public OperatorProperty {
@@ -249,8 +249,6 @@ class SequenceReverseProp : public OperatorProperty {
     for (size_t i = 0; i < in_type->size(); ++i) {
       if ((*in_type)[i] == -1) {
         (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
       }
     }
     out_type->clear();
diff --git a/src/operator/sequence_reverse.cc b/src/operator/sequence_reverse.cc
index 21cab78..9225b6b 100644
--- a/src/operator/sequence_reverse.cc
+++ b/src/operator/sequence_reverse.cc
@@ -28,10 +28,13 @@
 namespace mxnet {
 namespace op {
 template <>
-Operator *CreateOp<cpu>(SequenceReverseParam param, int dtype) {
+Operator *CreateOp<cpu>(SequenceReverseParam param, int dtype, int itype) {
   Operator *op = nullptr;
-  MSHADOW_TYPE_SWITCH(dtype, DType,
-                           { op = new SequenceReverseOp<cpu, DType>(param); })
+  MSHADOW_TYPE_SWITCH(dtype, DType, {
+      MSHADOW_TYPE_SWITCH(itype, IType, {
+          op = new SequenceReverseOp<cpu, DType, IType>(param);
+        });
+    });
   return op;
 }
 
@@ -39,7 +42,13 @@ Operator *CreateOp<cpu>(SequenceReverseParam param, int 
dtype) {
 Operator *SequenceReverseProp::CreateOperatorEx(
     Context ctx, std::vector<TShape> *in_shape,
     std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+
+  if (in_type->size() >= 2 && (*in_type)[1] != -1) {
+    DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
+  }
+
+  // sequence_length not passed in, so fall back to using input array dtype 
for second argument
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
 }
 
 DMLC_REGISTER_PARAMETER(SequenceReverseParam);
diff --git a/src/operator/sequence_reverse.cu b/src/operator/sequence_reverse.cu
index 1edc9c1..db5b416 100644
--- a/src/operator/sequence_reverse.cu
+++ b/src/operator/sequence_reverse.cu
@@ -28,11 +28,13 @@
 
 namespace mxnet {
 namespace op {
-template <> Operator *CreateOp<gpu>(SequenceReverseParam param, int dtype) {
+template <> Operator *CreateOp<gpu>(SequenceReverseParam param, int dtype, int 
itype) {
   Operator *op = nullptr;
   MSHADOW_TYPE_SWITCH(dtype, DType, {
-    op = new SequenceReverseOp<gpu, DType>(param);
-  })
+      MSHADOW_TYPE_SWITCH(itype, IType, {
+          op = new SequenceReverseOp<gpu, DType, IType>(param);
+        });
+    });
   return op;
 }
 
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index 3f34ade..670cc7e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3292,36 +3292,38 @@ def check_sequence_func(ftype, mask_value=0, axis=0):
     L = mx.symbol.Variable('L') # lengths
     shapes = [(3, 4), (1, 1), (3, 4, 3, 1, 1)]
     for seqlenQ in [True, False]:
-        for s in shapes:
-            x = mx.random.uniform(-1, 1, s, ctx=mx.cpu()).copyto(xpu)
-            batch = s[1] if (axis == 0) else s[0]
-            seqlen = s[axis]
-            l_np = np.random.randint(1, seqlen + 1, batch)
-            l = mx.nd.array(l_np, ctx=mx.cpu()).copyto(xpu)
-            if not seqlenQ:
-                l_np = None
-            args = {'data':X, 'use_sequence_length':seqlenQ, "axis":axis}
-            if seqlenQ:
-                args['sequence_length'] = L
-            if ftype == "last":
-                Y = mx.symbol.SequenceLast(**args)
-                np_out = sequence_last_numpy(x.asnumpy(), l_np, axis)
-            elif ftype == "mask":
-                args['value'] = mask_value
-                Y = mx.symbol.SequenceMask(**args)
-                np_out = sequence_mask_numpy(x.asnumpy(), l_np, axis, 
mask_value)
-            elif ftype == "reverse":
-                Y = mx.symbol.SequenceReverse(**args)
-                np_out = sequence_reverse_numpy(x.asnumpy(), l_np, axis)
-            fargs = [x, l] if seqlenQ else [x]
-            gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()]
-            check_symbolic_forward(Y, fargs, [np_out])
-            check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'},
-                numeric_eps=1e-2, rtol=1e-2)
-            check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'},
-                numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
-            check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'},
-                numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
+        for ary_dtype in [np.float32]:
+            for idx_dtype in [np.int32, np.float32]:
+                for s in shapes:
+                    x = mx.random.uniform(-1, 1, s, 
ctx=mx.cpu()).astype(ary_dtype).copyto(xpu)
+                    batch = s[1] if (axis == 0) else s[0]
+                    seqlen = s[axis]
+                    l_np = np.random.randint(1, seqlen + 1, batch)
+                    l = mx.nd.array(l_np, ctx=mx.cpu(), 
dtype=idx_dtype).copyto(xpu)
+                    if not seqlenQ:
+                        l_np = None
+                    args = {'data':X, 'use_sequence_length':seqlenQ, 
"axis":axis}
+                    if seqlenQ:
+                        args['sequence_length'] = L
+                    if ftype == "last":
+                        Y = mx.symbol.SequenceLast(**args)
+                        np_out = sequence_last_numpy(x.asnumpy(), l_np, axis)
+                    elif ftype == "mask":
+                        args['value'] = mask_value
+                        Y = mx.symbol.SequenceMask(**args)
+                        np_out = sequence_mask_numpy(x.asnumpy(), l_np, axis, 
mask_value)
+                    elif ftype == "reverse":
+                        Y = mx.symbol.SequenceReverse(**args)
+                        np_out = sequence_reverse_numpy(x.asnumpy(), l_np, 
axis)
+                    fargs = [x, l] if seqlenQ else [x]
+                    gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()]
+                    check_symbolic_forward(Y, fargs, [np_out], dtype="asnumpy")
+                    check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'},
+                        numeric_eps=1e-2, rtol=1e-2)
+                    check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'},
+                        numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
+                    check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'},
+                        numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
 
 
 @with_seed()

Reply via email to