This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new da5242b MXNET-1295 Adding integer index support to Sequence* family
of operators. (#13880)
da5242b is described below
commit da5242b732de39ad47d8ecee582f261ba5935fa9
Author: stephenrawls <[email protected]>
AuthorDate: Sat Jan 26 22:04:09 2019 -0800
MXNET-1295 Adding integer index support to Sequence* family of operators.
(#13880)
* Adding integer index support to Sequence* family of operators.
Adding ability to use int32 arrays, or any castable-to-int type, as
the sequence_length array to SequenceMask, SequenceLast, and
SequenceReverse. Previously these operaters all requred sequence_length
to be the same data type as the input array.
See MXNet Jira ticket here:
https://issues.apache.org/jira/browse/MXNET-1295
See also GitHub issues here:
https://github.com/apache/incubator-mxnet/issues/12649
https://github.com/dmlc/gluon-nlp/issues/346
* Adding explicit braces to an if statement to fix g++ warning
* fixing sequence_mask.cu by adding IType to template
* Fixing whitespace errors reported by linter
* Adding unit tests
* Fixing length of lines to pass linter
---
python/mxnet/test_utils.py | 37 +++++++++++++-------
src/operator/sequence_last-inl.h | 30 ++++++++--------
src/operator/sequence_last.cc | 16 ++++++---
src/operator/sequence_last.cu | 9 +++--
src/operator/sequence_mask-inl.h | 24 ++++++-------
src/operator/sequence_mask.cc | 16 ++++++---
src/operator/sequence_mask.cu | 9 +++--
src/operator/sequence_reverse-inl.h | 20 +++++------
src/operator/sequence_reverse.cc | 17 +++++++---
src/operator/sequence_reverse.cu | 8 +++--
tests/python/unittest/test_operator.py | 62 ++++++++++++++++++----------------
11 files changed, 144 insertions(+), 104 deletions(-)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 0a4d17d..4138e4d 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -620,8 +620,11 @@ def _parse_location(sym, location, ctx,
dtype=default_dtype()):
*In either case, value of all the arguments must be provided.*
ctx : Context
Device context.
- dtype: np.float16 or np.float32 or np.float64
- Datatype for mx.nd.array.
+ dtype: "asnumpy" or np.float16 or np.float32 or np.float64
+ If dtype is "asnumpy" then the mx.nd.array created will have the same
+ type as th numpy array from which it is copied.
+ Otherwise, dtype is the explicit datatype for all mx.nd.array objects
+ created in this function.
Returns
-------
@@ -643,7 +646,7 @@ def _parse_location(sym, location, ctx,
dtype=default_dtype()):
ValueError: Symbol arguments and keys of the given location do not match.
"""
assert isinstance(location, (dict, list, tuple))
- assert dtype in (np.float16, np.float32, np.float64)
+ assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if isinstance(location, dict):
if set(location.keys()) != set(sym.list_arguments()):
raise ValueError("Symbol arguments and keys of the given location
do not match."
@@ -651,8 +654,8 @@ def _parse_location(sym, location, ctx,
dtype=default_dtype()):
% (str(set(sym.list_arguments())),
str(set(location.keys()))))
else:
location = {k: v for k, v in zip(sym.list_arguments(), location)}
- location = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) if isinstance(v,
np.ndarray) \
- else v for k, v in location.items()}
+ location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy"
else dtype) \
+ if isinstance(v, np.ndarray) else v for k, v in
location.items()}
return location
@@ -677,8 +680,11 @@ def _parse_aux_states(sym, aux_states, ctx,
dtype=default_dtype()):
*In either case, all aux states of `sym` must be provided.*
ctx : Context
Device context.
- dtype: np.float16 or np.float32 or np.float64
- Datatype for mx.nd.array.
+ dtype: "asnumpy" or np.float16 or np.float32 or np.float64
+ If dtype is "asnumpy" then the mx.nd.array created will have the same
+ type as th numpy array from which it is copied.
+ Otherwise, dtype is the explicit datatype for all mx.nd.array objects
+ created in this function.
Returns
-------
@@ -702,7 +708,7 @@ def _parse_aux_states(sym, aux_states, ctx,
dtype=default_dtype()):
>>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)
ValueError: Symbol aux_states names and given aux_states do not match.
"""
- assert dtype in (np.float16, np.float32, np.float64)
+ assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if aux_states is not None:
if isinstance(aux_states, dict):
if set(aux_states.keys()) != set(sym.list_auxiliary_states()):
@@ -713,7 +719,8 @@ def _parse_aux_states(sym, aux_states, ctx,
dtype=default_dtype()):
elif isinstance(aux_states, (list, tuple)):
aux_names = sym.list_auxiliary_states()
aux_states = {k:v for k, v in zip(aux_names, aux_states)}
- aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in
aux_states.items()}
+ aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype ==
"asnumpy" else dtype) \
+ for k, v in aux_states.items()}
return aux_states
@@ -962,8 +969,11 @@ def check_symbolic_forward(sym, location, expected,
rtol=1E-4, atol=None,
Contains the mapping between names of auxiliary states and their
values.
ctx : Context, optional
running context
- dtype: np.float16 or np.float32 or np.float64
- Datatype for mx.nd.array.
+ dtype: "asnumpy" or np.float16 or np.float32 or np.float64
+ If dtype is "asnumpy" then the mx.nd.array created will have the same
+ type as th numpy array from which it is copied.
+ Otherwise, dtype is the explicit datatype for all mx.nd.array objects
+ created in this function.
equal_nan: Boolean
if True, `nan` is a valid value for checking equivalency (ie `nan` ==
`nan`)
@@ -979,7 +989,7 @@ def check_symbolic_forward(sym, location, expected,
rtol=1E-4, atol=None,
>>> ret_expected = np.array([[19, 22], [43, 50]])
>>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
"""
- assert dtype in (np.float16, np.float32, np.float64)
+ assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if ctx is None:
ctx = default_context()
@@ -988,7 +998,8 @@ def check_symbolic_forward(sym, location, expected,
rtol=1E-4, atol=None,
dtype=dtype)
if isinstance(expected, dict):
expected = [expected[k] for k in sym.list_outputs()]
- args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=dtype) for k, v in
location.items()}
+ args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=v.dtype if dtype
== "asnumpy" else dtype) \
+ for k, v in location.items()}
executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data,
aux_states=aux_states)
for g in executor.grad_arrays:
diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h
index 1a59473..61506c2 100644
--- a/src/operator/sequence_last-inl.h
+++ b/src/operator/sequence_last-inl.h
@@ -65,9 +65,9 @@ struct SequenceLastParam : public
dmlc::Parameter<SequenceLastParam> {
template <int req>
struct SequenceLastKernel {
- template <typename DType>
+ template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
- const DType *idx, int offset1, int offset2,
+ const IType *idx, int offset1, int offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
@@ -77,9 +77,9 @@ struct SequenceLastKernel {
};
struct SequenceLastGradKernel {
- template <typename DType>
+ template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
- const DType *idx, int offset1, int offset2,
+ const IType *idx, int offset1, int offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
@@ -88,14 +88,14 @@ struct SequenceLastGradKernel {
}
};
-template <typename xpu, typename DType>
+template <typename xpu, typename DType, typename IType>
class SequenceLastOp : public Operator {
public:
explicit SequenceLastOp(SequenceLastParam p) { this->param_ = p; }
void sequence_last(const mshadow::Tensor<xpu, 3, DType> &data,
const mshadow::Tensor<xpu, 2, DType> &out,
- const mshadow::Tensor<xpu, 1, DType> &indices,
+ const mshadow::Tensor<xpu, 1, IType> &indices,
const OpReqType req, mshadow::Stream<xpu> *const s) {
using namespace mshadow;
using namespace mshadow::expr;
@@ -115,7 +115,7 @@ class SequenceLastOp : public Operator {
void sequence_last_grad(const mshadow::Tensor<xpu, 3, DType> &in_grad,
const mshadow::Tensor<xpu, 2, DType> &out_grad,
- const mshadow::Tensor<xpu, 1, DType> &indices,
+ const mshadow::Tensor<xpu, 1, IType> &indices,
mshadow::Stream<xpu> *const s) {
using namespace mshadow;
using namespace mshadow::expr;
@@ -163,11 +163,11 @@ class SequenceLastOp : public Operator {
Tensor<xpu, 2, DType> out =
out_data[seq_last::kOut].get_with_shape<xpu, 2, DType>(
Shape2(batch, rest_size), s);
- Tensor<xpu, 1, DType> indices =
+ Tensor<xpu, 1, IType> indices =
param_.use_sequence_length
- ? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
+ ? in_data[seq_last::kSequenceLength].get<xpu, 1, IType>(s)
: ctx.requested[seq_last::kTempSpace]
- .get_space_typed<xpu, 1, DType>(Shape1(batch), s);
+ .get_space_typed<xpu, 1, IType>(Shape1(batch), s);
if (!param_.use_sequence_length) indices = max_seq_len;
sequence_last(data, out, indices, req[seq_last::kOut], s);
@@ -206,11 +206,11 @@ class SequenceLastOp : public Operator {
Tensor<xpu, 2, DType> output_grad =
out_grad[seq_last::kOut].get_with_shape<xpu, 2, DType>(
Shape2(batch, rest_size), s);
- Tensor<xpu, 1, DType> indices =
+ Tensor<xpu, 1, IType> indices =
param_.use_sequence_length
- ? in_data[seq_last::kSequenceLength].get<xpu, 1, DType>(s)
+ ? in_data[seq_last::kSequenceLength].get<xpu, 1, IType>(s)
: ctx.requested[seq_last::kTempSpace]
- .get_space_typed<xpu, 1, DType>(Shape1(batch), s);
+ .get_space_typed<xpu, 1, IType>(Shape1(batch), s);
if (req[seq_last::kData] == kWriteTo) data_grad = 0.0f;
sequence_last_grad(data_grad, output_grad, indices, s);
@@ -221,7 +221,7 @@ class SequenceLastOp : public Operator {
}; // class SequenceLastOp
template <typename xpu>
-Operator *CreateOp(SequenceLastParam param, int dtype);
+Operator *CreateOp(SequenceLastParam param, int dtype, int itype);
#if DMLC_USE_CXX11
class SequenceLastProp : public OperatorProperty {
@@ -281,8 +281,6 @@ class SequenceLastProp : public OperatorProperty {
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
- } else {
- UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
diff --git a/src/operator/sequence_last.cc b/src/operator/sequence_last.cc
index 345524b..f2388a8 100644
--- a/src/operator/sequence_last.cc
+++ b/src/operator/sequence_last.cc
@@ -28,10 +28,13 @@
namespace mxnet {
namespace op {
template <>
-Operator *CreateOp<cpu>(SequenceLastParam param, int dtype) {
+Operator *CreateOp<cpu>(SequenceLastParam param, int dtype, int itype) {
Operator *op = nullptr;
- MSHADOW_TYPE_SWITCH(dtype, DType,
- { op = new SequenceLastOp<cpu, DType>(param); })
+ MSHADOW_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ op = new SequenceLastOp<cpu, DType, IType>(param);
+ });
+ });
return op;
}
@@ -39,7 +42,12 @@ Operator *CreateOp<cpu>(SequenceLastParam param, int dtype) {
Operator *SequenceLastProp::CreateOperatorEx(Context ctx,
std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
- DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+ if (in_type->size() >= 2 && (*in_type)[1] != -1) {
+ DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
+ }
+
+ // sequence_length not passed in, so fall back to using input array dtype
for second argument
+ DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
}
DMLC_REGISTER_PARAMETER(SequenceLastParam);
diff --git a/src/operator/sequence_last.cu b/src/operator/sequence_last.cu
index dfc4e59..fb5ae84 100644
--- a/src/operator/sequence_last.cu
+++ b/src/operator/sequence_last.cu
@@ -28,10 +28,13 @@
namespace mxnet {
namespace op {
-template <> Operator *CreateOp<gpu>(SequenceLastParam param, int dtype) {
+template <> Operator *CreateOp<gpu>(SequenceLastParam param, int dtype, int
itype) {
Operator *op = NULL;
- MSHADOW_TYPE_SWITCH(dtype, DType,
- { op = new SequenceLastOp<gpu, DType>(param); })
+ MSHADOW_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ op = new SequenceLastOp<gpu, DType, IType>(param);
+ });
+ });
return op;
}
diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h
index c93ffb5..c2584ab 100644
--- a/src/operator/sequence_mask-inl.h
+++ b/src/operator/sequence_mask-inl.h
@@ -68,8 +68,8 @@ struct SequenceMaskParam : public
dmlc::Parameter<SequenceMaskParam> {
// (seqlen, batch, rest) case
template <int req>
struct SequenceMask0Kernel {
- template <typename DType>
- MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
+ template <typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
index_t max_s_len, index_t batch_size,
index_t restsize, DType value) {
const index_t seqpos = static_cast<int>(idx[b]);
@@ -86,8 +86,8 @@ struct SequenceMask0Kernel {
// (batch, seqlen, rest) case
template <int req>
struct SequenceMask1Kernel {
- template <typename DType>
- MSHADOW_XINLINE static void Map(int b, DType *in, const DType *idx,
+ template <typename DType, typename IType>
+ MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
index_t max_s_len, index_t batch_size,
index_t restsize, DType value) {
const index_t seqpos = static_cast<int>(idx[b]);
@@ -101,13 +101,13 @@ struct SequenceMask1Kernel {
}
};
-template <typename xpu, typename DType>
+template <typename xpu, typename DType, typename IType>
class SequenceMaskOp : public Operator {
public:
explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }
void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
- const mshadow::Tensor<xpu, 1, DType> &indices,
+ const mshadow::Tensor<xpu, 1, IType> &indices,
const OpReqType req, mshadow::Stream<xpu> *const s,
DType val) {
using namespace mshadow;
@@ -153,8 +153,8 @@ class SequenceMaskOp : public Operator {
// Actual implementation of masking
Assign(out, req[seq_mask::kOut], F<mshadow_op::identity>(data));
if (param_.use_sequence_length) {
- Tensor<xpu, 1, DType> indices =
- in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
+ Tensor<xpu, 1, IType> indices =
+ in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
sequence_mask(out, indices, req[seq_mask::kOut], s,
static_cast<DType>(param_.value));
}
@@ -190,8 +190,8 @@ class SequenceMaskOp : public Operator {
if (!param_.use_sequence_length) {
Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
} else {
- Tensor<xpu, 1, DType> indices =
- in_data[seq_mask::kSequenceLength].get<xpu, 1, DType>(s);
+ Tensor<xpu, 1, IType> indices =
+ in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
if (req[seq_mask::kData] == kAddTo) {
Tensor<xpu, 3, DType> out_g_temp =
ctx.requested[seq_mask::kTempSpace].get_space_typed<xpu, 3, DType>(
@@ -212,7 +212,7 @@ class SequenceMaskOp : public Operator {
}; // class SequenceMaskOp
template <typename xpu>
-Operator *CreateOp(SequenceMaskParam param, int dtype);
+Operator *CreateOp(SequenceMaskParam param, int dtype, int itype);
#if DMLC_USE_CXX11
class SequenceMaskProp : public OperatorProperty {
@@ -270,8 +270,6 @@ class SequenceMaskProp : public OperatorProperty {
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
- } else {
- UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc
index e02c57b..76e5838 100644
--- a/src/operator/sequence_mask.cc
+++ b/src/operator/sequence_mask.cc
@@ -28,10 +28,13 @@
namespace mxnet {
namespace op {
template <>
-Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype) {
+Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype, int itype) {
Operator *op = nullptr;
- MSHADOW_TYPE_SWITCH(dtype, DType,
- { op = new SequenceMaskOp<cpu, DType>(param); })
+ MSHADOW_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ op = new SequenceMaskOp<cpu, DType, IType>(param);
+ });
+ });
return op;
}
@@ -39,7 +42,12 @@ Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype) {
Operator *SequenceMaskProp::CreateOperatorEx(Context ctx,
std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
- DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+ if (in_type->size() >= 2 && (*in_type)[1] != -1) {
+ DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
+ }
+
+ // sequence_length not passed in, so fall back to using input array dtype
for second argument
+ DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
}
DMLC_REGISTER_PARAMETER(SequenceMaskParam);
diff --git a/src/operator/sequence_mask.cu b/src/operator/sequence_mask.cu
index 2ca8832..cec627c 100644
--- a/src/operator/sequence_mask.cu
+++ b/src/operator/sequence_mask.cu
@@ -29,10 +29,13 @@
namespace mxnet {
namespace op {
-template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype) {
+template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int
itype) {
Operator *op = NULL;
- MSHADOW_TYPE_SWITCH(dtype, DType,
- { op = new SequenceMaskOp<gpu, DType>(param); })
+ MSHADOW_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ op = new SequenceMaskOp<gpu, DType, IType>(param);
+ });
+ });
return op;
}
diff --git a/src/operator/sequence_reverse-inl.h
b/src/operator/sequence_reverse-inl.h
index 5c48729..eb9f71c 100644
--- a/src/operator/sequence_reverse-inl.h
+++ b/src/operator/sequence_reverse-inl.h
@@ -65,14 +65,14 @@ struct SequenceReverseParam : public
dmlc::Parameter<SequenceReverseParam> {
};
struct ReverseKernel {
- template <typename DType>
+ template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
const DType *const in_data,
const OpReqType req,
const index_t max_seq_len,
const index_t batch_size,
const index_t other_dim, const index_t numel,
- const DType *const indices) {
+ const IType *const indices) {
for (index_t batch = 0; batch < batch_size; ++batch) {
const index_t num_seq =
indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
@@ -102,13 +102,13 @@ struct ReverseKernel {
}
};
-template <typename xpu, typename DType>
+template <typename xpu, typename DType, typename IType>
class SequenceReverseOp : public Operator {
public:
explicit SequenceReverseOp(SequenceReverseParam p) { this->param_ = p; }
void sequence_reverse(const mshadow::Tensor<xpu, 3, DType> &data,
const mshadow::Tensor<xpu, 3, DType> &out,
- const OpReqType req, const DType *const indices,
+ const OpReqType req, const IType *const indices,
mshadow::Stream<xpu> *const s) {
using namespace mshadow;
using namespace mshadow::expr;
@@ -145,9 +145,9 @@ class SequenceReverseOp : public Operator {
Tensor<xpu, 3, DType> out =
out_data[seq_reverse::kOut].get_with_shape<xpu, 3, DType>(s3, s);
- const DType *const indices =
+ const IType *const indices =
param_.use_sequence_length
- ? in_data[seq_reverse::kSequenceLength].dptr<DType>()
+ ? in_data[seq_reverse::kSequenceLength].dptr<IType>()
: nullptr;
sequence_reverse(data, out, req[seq_reverse::kOut], indices, s);
@@ -179,9 +179,9 @@ class SequenceReverseOp : public Operator {
Tensor<xpu, 3, DType> output_grad =
out_grad[seq_reverse::kOut].get_with_shape<xpu, 3, DType>(s3, s);
- const DType *const indices =
+ const IType *const indices =
param_.use_sequence_length
- ? in_data[seq_reverse::kSequenceLength].dptr<DType>()
+ ? in_data[seq_reverse::kSequenceLength].dptr<IType>()
: nullptr;
sequence_reverse(output_grad, data_grad, req[seq_reverse::kData], indices,
@@ -193,7 +193,7 @@ class SequenceReverseOp : public Operator {
}; // class SequenceReverseOp
template <typename xpu>
-Operator *CreateOp(SequenceReverseParam param, int dtype);
+Operator *CreateOp(SequenceReverseParam param, int dtype, int itype);
#if DMLC_USE_CXX11
class SequenceReverseProp : public OperatorProperty {
@@ -249,8 +249,6 @@ class SequenceReverseProp : public OperatorProperty {
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
- } else {
- UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
diff --git a/src/operator/sequence_reverse.cc b/src/operator/sequence_reverse.cc
index 21cab78..9225b6b 100644
--- a/src/operator/sequence_reverse.cc
+++ b/src/operator/sequence_reverse.cc
@@ -28,10 +28,13 @@
namespace mxnet {
namespace op {
template <>
-Operator *CreateOp<cpu>(SequenceReverseParam param, int dtype) {
+Operator *CreateOp<cpu>(SequenceReverseParam param, int dtype, int itype) {
Operator *op = nullptr;
- MSHADOW_TYPE_SWITCH(dtype, DType,
- { op = new SequenceReverseOp<cpu, DType>(param); })
+ MSHADOW_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ op = new SequenceReverseOp<cpu, DType, IType>(param);
+ });
+ });
return op;
}
@@ -39,7 +42,13 @@ Operator *CreateOp<cpu>(SequenceReverseParam param, int
dtype) {
Operator *SequenceReverseProp::CreateOperatorEx(
Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
- DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+
+ if (in_type->size() >= 2 && (*in_type)[1] != -1) {
+ DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
+ }
+
+ // sequence_length not passed in, so fall back to using input array dtype
for second argument
+ DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
}
DMLC_REGISTER_PARAMETER(SequenceReverseParam);
diff --git a/src/operator/sequence_reverse.cu b/src/operator/sequence_reverse.cu
index 1edc9c1..db5b416 100644
--- a/src/operator/sequence_reverse.cu
+++ b/src/operator/sequence_reverse.cu
@@ -28,11 +28,13 @@
namespace mxnet {
namespace op {
-template <> Operator *CreateOp<gpu>(SequenceReverseParam param, int dtype) {
+template <> Operator *CreateOp<gpu>(SequenceReverseParam param, int dtype, int
itype) {
Operator *op = nullptr;
MSHADOW_TYPE_SWITCH(dtype, DType, {
- op = new SequenceReverseOp<gpu, DType>(param);
- })
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ op = new SequenceReverseOp<gpu, DType, IType>(param);
+ });
+ });
return op;
}
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 3f34ade..670cc7e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3292,36 +3292,38 @@ def check_sequence_func(ftype, mask_value=0, axis=0):
L = mx.symbol.Variable('L') # lengths
shapes = [(3, 4), (1, 1), (3, 4, 3, 1, 1)]
for seqlenQ in [True, False]:
- for s in shapes:
- x = mx.random.uniform(-1, 1, s, ctx=mx.cpu()).copyto(xpu)
- batch = s[1] if (axis == 0) else s[0]
- seqlen = s[axis]
- l_np = np.random.randint(1, seqlen + 1, batch)
- l = mx.nd.array(l_np, ctx=mx.cpu()).copyto(xpu)
- if not seqlenQ:
- l_np = None
- args = {'data':X, 'use_sequence_length':seqlenQ, "axis":axis}
- if seqlenQ:
- args['sequence_length'] = L
- if ftype == "last":
- Y = mx.symbol.SequenceLast(**args)
- np_out = sequence_last_numpy(x.asnumpy(), l_np, axis)
- elif ftype == "mask":
- args['value'] = mask_value
- Y = mx.symbol.SequenceMask(**args)
- np_out = sequence_mask_numpy(x.asnumpy(), l_np, axis,
mask_value)
- elif ftype == "reverse":
- Y = mx.symbol.SequenceReverse(**args)
- np_out = sequence_reverse_numpy(x.asnumpy(), l_np, axis)
- fargs = [x, l] if seqlenQ else [x]
- gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()]
- check_symbolic_forward(Y, fargs, [np_out])
- check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'},
- numeric_eps=1e-2, rtol=1e-2)
- check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'},
- numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
- check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'},
- numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
+ for ary_dtype in [np.float32]:
+ for idx_dtype in [np.int32, np.float32]:
+ for s in shapes:
+ x = mx.random.uniform(-1, 1, s,
ctx=mx.cpu()).astype(ary_dtype).copyto(xpu)
+ batch = s[1] if (axis == 0) else s[0]
+ seqlen = s[axis]
+ l_np = np.random.randint(1, seqlen + 1, batch)
+ l = mx.nd.array(l_np, ctx=mx.cpu(),
dtype=idx_dtype).copyto(xpu)
+ if not seqlenQ:
+ l_np = None
+ args = {'data':X, 'use_sequence_length':seqlenQ,
"axis":axis}
+ if seqlenQ:
+ args['sequence_length'] = L
+ if ftype == "last":
+ Y = mx.symbol.SequenceLast(**args)
+ np_out = sequence_last_numpy(x.asnumpy(), l_np, axis)
+ elif ftype == "mask":
+ args['value'] = mask_value
+ Y = mx.symbol.SequenceMask(**args)
+ np_out = sequence_mask_numpy(x.asnumpy(), l_np, axis,
mask_value)
+ elif ftype == "reverse":
+ Y = mx.symbol.SequenceReverse(**args)
+ np_out = sequence_reverse_numpy(x.asnumpy(), l_np,
axis)
+ fargs = [x, l] if seqlenQ else [x]
+ gargs = [x.asnumpy(), l_np] if seqlenQ else [x.asnumpy()]
+ check_symbolic_forward(Y, fargs, [np_out], dtype="asnumpy")
+ check_numeric_gradient(Y, gargs, grad_nodes={'X':'write'},
+ numeric_eps=1e-2, rtol=1e-2)
+ check_numeric_gradient(Y, gargs, grad_nodes={'X':'add'},
+ numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
+ check_numeric_gradient(Y, gargs, grad_nodes={'X':'null'},
+ numeric_eps=1e-3, rtol=1e-2, atol=1E-4)
@with_seed()