This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new 4c41afd  Backport of #16827, #16791 and #16888 to 1.6 branch (#16901)
4c41afd is described below

commit 4c41afd1e6014c5cde00f4d253474ffa1e141cac
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Tue Nov 26 09:03:15 2019 -0800

    Backport of #16827, #16791 and #16888 to 1.6 branch (#16901)
    
    * refactor and reduce float types for some functions, also add bitwise_xor (#16827)
    
    * Mixed precision binary op backward (use in) for numpy (#16791)
    
    * mixed precision binary op backward
    
    * reduce unix cpu runtime
    
    * Add evaluation_loss to the estimator base class. (#16888)
    
    * Add evaluation_loss to the estimator base class.
    
    * Update the base estimator class to support the separate evaluation loss.
    
    * Add evaluation loss to the base estimator class.
    
    * Add unittest for evaluation loss in the test_evaluation function
    
    * Update estimator.py
    
    * Update estimator.py
---
 python/mxnet/gluon/contrib/estimator/estimator.py  |  11 +-
 python/mxnet/ndarray/numpy/_op.py                  |  40 +++-
 python/mxnet/numpy/multiarray.py                   |  42 +++-
 python/mxnet/numpy_dispatch_protocol.py            |   1 +
 python/mxnet/symbol/numpy/_symbol.py               |  35 ++-
 src/operator/elemwise_op_common.h                  |   3 +-
 src/operator/numpy/np_elemwise_broadcast_op.cc     | 243 +--------------------
 src/operator/numpy/np_elemwise_broadcast_op.cu     |  75 +------
 src/operator/numpy/np_elemwise_broadcast_op.h      | 114 +++++++++-
 ..._op.cc => np_elemwise_broadcast_op_extended.cc} | 193 ++++------------
 ..._op.cu => np_elemwise_broadcast_op_extended.cu} |  81 +------
 src/operator/operator_tune.cc                      |   4 +-
 src/operator/tensor/elemwise_binary_broadcast_op.h | 136 ++++++++----
 src/operator/tensor/elemwise_binary_op.h           | 148 +++++++------
 src/operator/tensor/elemwise_binary_scalar_op.h    |  20 ++
 src/operator/tensor/elemwise_unary_op.h            |   4 +-
 tests/python/unittest/test_gluon_estimator.py      |   4 +-
 .../python/unittest/test_numpy_interoperability.py |  13 ++
 tests/python/unittest/test_numpy_op.py             |  23 +-
 19 files changed, 528 insertions(+), 662 deletions(-)

diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 83b954d..54a0b16 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -59,6 +59,9 @@ class Estimator(object):
         Trainer to apply optimizer on network parameters.
     context : Context or list of Context
         Device(s) to run the training on.
+    evaluation_loss : gluon.loss.Loss
+        Loss (objective) function to calculate during evaluation. If None,
+        the same loss function as self.loss is used.
 
     """
 
@@ -85,12 +88,16 @@ class Estimator(object):
                  metrics=None,
                  initializer=None,
                  trainer=None,
-                 context=None):
+                 context=None,
+                 evaluation_loss=None):
         self.net = net
         self.loss = self._check_loss(loss)
         self._train_metrics = _check_metrics(metrics)
         self._add_default_training_metrics()
         self._add_validation_metrics()
+        self.evaluation_loss = self.loss
+        if evaluation_loss is not None:
+            self.evaluation_loss = self._check_loss(evaluation_loss)
 
         self.logger = logging.Logger(name='Estimator', level=logging.INFO)
         self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -228,7 +235,7 @@ class Estimator(object):
         """
         data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
         pred = [self.net(x) for x in data]
-        loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+        loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
         # update metrics
         for metric in val_metrics:
             if isinstance(metric, metric_loss):
diff --git a/python/mxnet/ndarray/numpy/_op.py 
b/python/mxnet/ndarray/numpy/_op.py
index ff404a7..ed3d9d8 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -36,7 +36,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 
'multiply', 'divide', 'mo
            'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 
'vsplit', 'concatenate', 'append',
            'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 
'minimum', 'swapaxes', 'clip', 'argmax',
            'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 
'hamming', 'blackman', 'flip',
-           'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 
'identity', 'take',
+           'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 
'lcm', 'tril', 'identity', 'take',
            'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 
'less', 'greater_equal', 'less_equal',
            'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 
'shares_memory', 'may_share_memory', 'diff']
 
@@ -4291,6 +4291,44 @@ def hypot(x1, x2, out=None, **kwargs):
 
 @set_module('mxnet.ndarray.numpy')
 @wrap_np_binary_func
+def bitwise_xor(x1, x2, out=None, **kwargs):
+    r"""
+    Compute the bit-wise XOR of two arrays element-wise.
+
+    Parameters
+    ----------
+    x1, x2 : ndarray or scalar
+        Only integer and boolean types are handled. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which becomes the shape 
of the output).
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have 
a shape that the
+        inputs broadcast to. If not provided or None, a freshly-allocated 
array is returned.
+
+    Returns
+    -------
+    out : ndarray
+        Result.
+
+    Examples
+    --------
+    >>> np.bitwise_xor(13, 17)
+    28
+
+    >>> np.bitwise_xor(31, 5)
+    26
+    >>> np.bitwise_xor(np.array([31,3], dtype='int32'), 5)
+    array([26,  6])
+
+    >>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], 
dtype='int32'))
+    array([26,  5])
+    >>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, 
True], dtype='bool'))
+    array([ True, False])
+    """
+    return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, 
_npi.bitwise_xor_scalar, None, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+@wrap_np_binary_func
 def ldexp(x1, x2, out=None, **kwargs):
     """
     Returns x1 * 2**x2, element-wise.
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index c623f67..ad5fb54 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -53,8 +53,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 
'full', 'add', 'subtrac
            'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 
'expand_dims', 'tile', 'arange',
            'split', 'vsplit', 'concatenate', 'stack', 'vstack', 
'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
            'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 
'copysign', 'ravel', 'hanning', 'hamming',
-           'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 
'deg2rad', 'unique', 'lcm', 'tril',
-           'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 
'not_equal', 'greater', 'less',
+           'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 
'rad2deg', 'deg2rad', 'unique', 'lcm',
+           'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 
'equal', 'not_equal', 'greater', 'less',
            'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 
'true_divide', 'nonzero', 'shares_memory',
            'may_share_memory', 'diff']
 
@@ -6196,6 +6196,44 @@ def hypot(x1, x2, out=None, **kwargs):
 
 @set_module('mxnet.numpy')
 @wrap_np_binary_func
+def bitwise_xor(x1, x2, out=None, **kwargs):
+    r"""
+    Compute the bit-wise XOR of two arrays element-wise.
+
+    Parameters
+    ----------
+    x1, x2 : ndarray or scalar
+        Only integer and boolean types are handled. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which becomes the shape 
of the output).
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have 
a shape that the
+        inputs broadcast to. If not provided or None, a freshly-allocated 
array is returned.
+
+    Returns
+    -------
+    out : ndarray
+        Result.
+
+    Examples
+    --------
+    >>> np.bitwise_xor(13, 17)
+    28
+
+    >>> np.bitwise_xor(31, 5)
+    26
+    >>> np.bitwise_xor(np.array([31,3], dtype=np.int32), 5)
+    array([26,  6])
+
+    >>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], 
dtype='int32'))
+    array([26,  5])
+    >>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, 
True], dtype='bool'))
+    array([ True, False])
+    """
+    return _mx_nd_np.bitwise_xor(x1, x2, out=out)
+
+
+@set_module('mxnet.numpy')
+@wrap_np_binary_func
 def ldexp(x1, x2, out=None, **kwargs):
     """
     Returns x1 * 2**x2, element-wise.
diff --git a/python/mxnet/numpy_dispatch_protocol.py 
b/python/mxnet/numpy_dispatch_protocol.py
index cdd21af..8a4a90c 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -222,6 +222,7 @@ _NUMPY_ARRAY_UFUNC_LIST = [
     'ceil',
     'trunc',
     'floor',
+    'bitwise_xor',
     'logical_not',
     'equal',
     'not_equal',
diff --git a/python/mxnet/symbol/numpy/_symbol.py 
b/python/mxnet/symbol/numpy/_symbol.py
index d3837d2..e4ac462 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -38,7 +38,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 
'divide', 'mod', 'rem
            'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 
'vsplit', 'concatenate', 'append',
            'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 
'minimum', 'swapaxes', 'clip', 'argmax',
            'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 
'hamming', 'blackman', 'flip',
-           'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 
'identity', 'take',
+           'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 
'lcm', 'tril', 'identity', 'take',
            'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 
'less', 'greater_equal',
            'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 
'shares_memory', 'may_share_memory', 'diff']
 
@@ -4057,17 +4057,16 @@ def hypot(x1, x2, out=None, **kwargs):
 
     Parameters
     ----------
-    x1, x2 : array_like
+    x1, x2 : _Symbol or scalar
         Leg of the triangle(s).
-    out : ndarray, None, or tuple of ndarray and None, optional
+    out : _Symbol or None, optional
         A location into which the result is stored. If provided, it must have
         a shape that the inputs broadcast to. If not provided or `None`,
-        a freshly-allocated array is returned. A tuple (possible only as a
-        keyword argument) must have length equal to the number of outputs.
+        a freshly-allocated array is returned.
 
     Returns
     -------
-    z : ndarray
+    z : _Symbol or scalar
         The hypotenuse of the triangle(s).
         This is a scalar if both `x1` and `x2` are scalars.
 
@@ -4080,6 +4079,30 @@ def hypot(x1, x2, out=None, **kwargs):
 
 
 @set_module('mxnet.symbol.numpy')
+@wrap_np_binary_func
+def bitwise_xor(x1, x2, out=None, **kwargs):
+    r"""
+    Compute the bit-wise XOR of two arrays element-wise.
+
+    Parameters
+    ----------
+    x1, x2 : _Symbol or scalar
+        Only integer and boolean types are handled. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which becomes the shape 
of the output).
+    out : _Symbol or None, optional
+        A location into which the result is stored. If provided, it must have
+        a shape that the inputs broadcast to. If not provided or `None`,
+        a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : _Symbol or scalar
+        Result.
+    """
+    return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, 
_npi.bitwise_xor_scalar, None, out)
+
+
+@set_module('mxnet.symbol.numpy')
 def unique(ar, return_index=False, return_inverse=False, return_counts=False, 
axis=None):
     """
     Find the unique elements of an array.
diff --git a/src/operator/elemwise_op_common.h 
b/src/operator/elemwise_op_common.h
index 6711297..2cdd73a 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -209,7 +209,8 @@ inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs,
   CHECK(in_attrs->at(0) == mshadow::kInt64 ||
         in_attrs->at(0) == mshadow::kInt32 ||
         in_attrs->at(0) == mshadow::kInt8 ||
-        in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types.";
+        in_attrs->at(0) == mshadow::kUint8 ||
+        in_attrs->at(0) == mshadow::kBool) << "Only supports integer types.";
   if (n_in != -1) {
     CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " 
<< attrs.name;
   }
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc 
b/src/operator/numpy/np_elemwise_broadcast_op.cc
index a76e59d..f2adfc1 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -28,16 +28,6 @@
 namespace mxnet {
 namespace op {
 
-bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
-                           std::vector<int>* in_attrs,
-                           std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  return in_attrs->at(0) != -1;
-}
-
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name)              \
   NNVM_REGISTER_OP(name)                                            \
   .set_num_inputs(1)                                                \
@@ -147,22 +137,9 @@ 
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
   "FCompute<cpu>",
   NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_mul"});
-
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::mod>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_mod"});
-
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::power>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_power"});
-
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
-.describe(R"code()code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::copysign>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_copysign"});
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
 
-NNVM_REGISTER_OP(_backward_npi_copysign)
+NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
 .set_num_inputs(3)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
@@ -174,44 +151,16 @@ NNVM_REGISTER_OP(_backward_npi_copysign)
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, 
mshadow_op::copysign_grad,
-                                                                  
mshadow_op::copysign_rgrad>);
+.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, 
mshadow_op::right,
+                                                              
mshadow_op::left>);
 
-NNVM_REGISTER_OP(_npi_lcm)
-.set_num_inputs(2)
-.set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-[](const NodeAttrs& attrs) {
-     return std::vector<std::string>{"lhs", "rhs"};
-})
-.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-})
-.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::lcm>)
-.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
-.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
-
-NNVM_REGISTER_OP(_npi_lcm_scalar)
-.set_num_inputs(1)
-.set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    attrs->parsed = std::stod(attrs->dict["scalar"]);
-  })
-.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}};
-  })
-.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-.add_argument("data", "NDArray-or-Symbol", "source input")
-.add_argument("scalar", "int", "scalar input")
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::lcm>);
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::mod>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_mod"});
 
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::power>)
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_power"});
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
@@ -244,177 +193,5 @@ 
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rpower>)
 .set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseOut{"_backward_rpower_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::copysign>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rcopysign>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar)
-.set_attr<FCompute>("FCompute<cpu>",
-                    BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
-.set_attr<FCompute>("FCompute<cpu>",
-                    BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);
-
-inline bool IsFloatType(const int dtype) {
-  return (dtype == mshadow::kFloat16 ||
-          dtype == mshadow::kFloat32 ||
-          dtype == mshadow::kFloat64);
-}
-
-inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs,
-                          std::vector<int>* in_attrs,
-                          std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
-
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
-  // check if it is float16, float32 or float64. If not, raise error.
-  CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n";
-  return out_attrs->at(0) != -1;
-}
-
-NNVM_REGISTER_OP(_npi_arctan2)
-.set_num_inputs(2)
-.set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"x1", "x2"};
-  })
-.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-.set_attr<nnvm::FInferType>("FInferType", Arctan2OpType)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::arctan2>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_arctan2"})
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::pair<int, int> >{{0, 0}};
-  })
-.add_argument("x1", "NDArray-or-Symbol", "The input array")
-.add_argument("x2", "NDArray-or-Symbol", "The input array");
-
-NNVM_REGISTER_OP(_backward_npi_arctan2)
-.set_num_inputs(3)
-.set_num_outputs(2)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FResourceRequest>("FResourceRequest",
-  [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-  })
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, 
mshadow_op::arctan2_grad,
-                                                                  
mshadow_op::arctan2_rgrad>);
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::arctan2>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rarctan2>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});
-
-MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = 
std::stod(attrs->dict["scalar"]); })
-.set_attr<FCompute>("FCompute<cpu>",
-                    BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_grad>);
-
-MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = 
std::stod(attrs->dict["scalar"]); })
-.set_attr<FCompute>("FCompute<cpu>",
-                    BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_rgrad>);
-
-bool HypotOpType(const nnvm::NodeAttrs& attrs,
-                 std::vector<int>* in_attrs,
-                 std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
-
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
-
-  CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n";
-  return out_attrs->at(0) != -1;
-}
-
-// rigister hypot that do not support int here
-NNVM_REGISTER_OP(_npi_hypot)
-.set_num_inputs(2)
-.set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"x1", "x2"};
-  })
-.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
-.set_attr<nnvm::FInferType>("FInferType", HypotOpType)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::hypot>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_hypot"})
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
-  })
-.add_argument("x1", "NDArray-or-Symbol", "The input array")
-.add_argument("x2", "NDArray-or-Symbol", "The input array");
-
-NNVM_REGISTER_OP(_backward_npi_hypot)
-.set_num_inputs(3)
-.set_num_outputs(2)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::pair<int, int> > {{0, 1}};
-  })
-.set_attr<FResourceRequest>("FResourceRequest",
-  [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-  })
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, 
mshadow_op::hypot_grad_left,
-                                                                  
mshadow_op::hypot_grad_right>);
-
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::ldexp>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_ldexp"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::ldexp>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rldexp>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"});
-
-NNVM_REGISTER_OP(_backward_npi_ldexp)
-.set_num_inputs(3)
-.set_num_outputs(2)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 1}};
-  })
-.set_attr<FResourceRequest>("FResourceRequest",
-  [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
-  })
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, 
mshadow_op::ldexp_grad,
-                                                                  
mshadow_op::ldexp_rgrad>);
-
-MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = 
std::stod(attrs->dict["scalar"]); })
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, 
mshadow_op::ldexp_grad>);
-
-MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = 
std::stod(attrs->dict["scalar"]); })
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, 
mshadow_op::rldexp_grad>);
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu 
b/src/operator/numpy/np_elemwise_broadcast_op.cu
index a0a277d..59dfc25 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -64,35 +64,16 @@ NNVM_REGISTER_OP(_npi_multiply)
   NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
 #endif
 
+NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
+.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, 
mshadow_op::right,
+                                                              
mshadow_op::left>);
+
 NNVM_REGISTER_OP(_npi_mod)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::mod>);
 
 NNVM_REGISTER_OP(_npi_power)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::power>);
 
-NNVM_REGISTER_OP(_npi_copysign)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::copysign>);
-
-NNVM_REGISTER_OP(_npi_lcm)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::lcm>);
-
-NNVM_REGISTER_OP(_backward_npi_copysign)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::copysign_grad,
-                                                                  
mshadow_op::copysign_rgrad>);
-
-NNVM_REGISTER_OP(_npi_arctan2)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::arctan2>);
-
-NNVM_REGISTER_OP(_backward_npi_arctan2)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::arctan2_grad,
-                                                                  
mshadow_op::arctan2_rgrad>);
-NNVM_REGISTER_OP(_npi_hypot)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::hypot>);
-
-NNVM_REGISTER_OP(_backward_npi_hypot)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::hypot_grad_left,
-                                                                  
mshadow_op::hypot_grad_right>);
-
 NNVM_REGISTER_OP(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
op::mshadow_op::plus>);
 
@@ -117,53 +98,5 @@ NNVM_REGISTER_OP(_npi_power_scalar)
 NNVM_REGISTER_OP(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rpower>);
 
-NNVM_REGISTER_OP(_npi_copysign_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::copysign>);
-
-NNVM_REGISTER_OP(_npi_rcopysign_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rcopysign>);
-
-NNVM_REGISTER_OP(_backward_npi_copysign_scalar)
-.set_attr<FCompute>("FCompute<gpu>",
-                    BinaryScalarOp::Backward<gpu, mshadow_op::copysign_grad>);
-
-NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
-.set_attr<FCompute>("FCompute<gpu>",
-                    BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);
-
-NNVM_REGISTER_OP(_npi_arctan2_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::arctan2>);
-
-NNVM_REGISTER_OP(_backward_npi_arctan2_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::arctan2_grad>);
-
-NNVM_REGISTER_OP(_npi_rarctan2_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rarctan2>);
-
-NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rarctan2_grad>);
-
-NNVM_REGISTER_OP(_npi_lcm_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::lcm>);
-
-NNVM_REGISTER_OP(_npi_ldexp)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::ldexp>);
-
-NNVM_REGISTER_OP(_npi_ldexp_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::ldexp>);
-
-NNVM_REGISTER_OP(_npi_rldexp_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rldexp>);
-
-NNVM_REGISTER_OP(_backward_npi_ldexp)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::ldexp_grad,
-                                                                  
mshadow_op::ldexp_rgrad>);
-
-NNVM_REGISTER_OP(_backward_npi_ldexp_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, 
mshadow_op::ldexp_grad>);
-
-NNVM_REGISTER_OP(_backward_npi_rldexp_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, 
mshadow_op::rldexp_grad>);
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h 
b/src/operator/numpy/np_elemwise_broadcast_op.h
index 1a4596f..a2b7877 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -25,6 +25,7 @@
 #ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
 #define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
 
+#include <algorithm>
 #include <vector>
 #include <string>
 
@@ -34,6 +35,16 @@
 namespace mxnet {
 namespace op {
 
+inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
+                           std::vector<int>* in_attrs,
+                           std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  return in_attrs->at(0) != -1;
+}
+
 inline void PrintErrorMessage(const std::string& op_name, const int dtype1, 
const int dtype2) {
   LOG(FATAL) << "Operator " << op_name << " does not support combination of "
              << common::dtype_string(dtype1) << " with " << 
common::dtype_string(dtype2)
@@ -381,11 +392,13 @@ void NumpyBinaryBroadcastComputeWithBool(const 
nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu, typename LOP, typename ROP>
-void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
+void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
   CHECK_EQ(inputs.size(), 3U);
   CHECK_EQ(outputs.size(), 2U);
 
@@ -396,7 +409,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& 
attrs,
     return;
   }
 
-  PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  const TBlob& ograd = inputs[0];
+  const TBlob& lgrad = outputs[0];
+  const TBlob& rgrad = outputs[1];
+
+  if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+    // If any of the inputs is a float, it's the same type as the output
+    // So 2 of the 3 tensors have the same data type
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    using namespace broadcast;
+    const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, 
rgrad.shape_, ograd.shape_,
+                                                     &new_lshape, &new_rshape, 
&new_oshape) != 0;
+
+    // Prepare all the temporary memory
+    size_t workspace_size_l = 0, workspace_size_r = 0;
+    TBlob temp_tblob;  // The TBlob for casted input data
+    TBlob temp_igrad;  // The TBlob for casted grad results
+    size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() 
: rgrad.Size();
+    Tensor<xpu, 1, char> workspace;
+
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
+      BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
+        workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
+          s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
+        workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
+          s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
+      });
+      size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
+      size_t cast_tensor_size = tensor_size * sizeof(OType);
+      // Allocate the temporary memories now
+      Tensor<xpu, 1, char> temp_space =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(
+          Shape1(workspace_size + cast_tensor_size * 2), s);
+      // Tensor for temp_tblob
+      Tensor<xpu, 1, OType> temp_tblob_tensor(
+                              reinterpret_cast<OType*>(temp_space.dptr_),
+                              Shape1(tensor_size), s);
+      // Tensor for temp_igrad
+      Tensor<xpu, 1, OType> temp_igrad_tensor(
+                              reinterpret_cast<OType*>(temp_space.dptr_) + 
tensor_size,
+                              Shape1(tensor_size), s);
+      temp_tblob =
+        TBlob(temp_tblob_tensor)
+          .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : 
rhs.shape_));
+      temp_igrad =
+        TBlob(temp_igrad_tensor)
+          .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : 
rhs.shape_));
+      if (temp_igrad.Size() != 0) {
+        Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), 
temp_igrad.dptr<OType>());
+      }
+      workspace =
+        Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, 
Shape1(workspace_size), s);
+    });
+    // Cast the input that does not have consistent type to temp_tblob
+    CastCompute<xpu>(
+      attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, 
{kWriteTo}, {temp_tblob});
+    if (!need_bc) {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
+          attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, 
{temp_igrad, rgrad});
+      } else {
+        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
+          attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, 
temp_igrad});
+      }
+    } else {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, 
LOP, ROP>(
+              ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, 
rgrad},
+              workspace, new_lshape, new_rshape, new_oshape);
+          });
+        });
+      } else {
+        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, 
LOP, ROP>(
+              ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, 
temp_igrad},
+              workspace, new_lshape, new_rshape, new_oshape);
+          });
+        });
+      }
+    }
+
+    // If both inputs are floating numbers, cast the igrad to the input that 
has
+    // the different data type
+    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
+      } else {
+        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
+      }
+    }
+  } else {
+    // Case where both inputs are integer types, should not even do
+    // backward computation for this case.
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  }
 }
 
 }  // namespace op
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc 
b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
similarity index 59%
copy from src/operator/numpy/np_elemwise_broadcast_op.cc
copy to src/operator/numpy/np_elemwise_broadcast_op_extended.cc
index a76e59d..84c47e5 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
@@ -19,25 +19,16 @@
 
 /*!
  *  Copyright (c) 2019 by Contributors
- * \file np_elemwise_binary_op.cc
- * \brief CPU Implementation of basic functions for elementwise numpy binary 
broadcast operator.
+ * \file np_elemwise_binary_op_extended.cc
+ * \brief CPU Implementation of extended functions for elementwise numpy 
binary broadcast operator.
  */
 
+#include "../../common/utils.h"
 #include "./np_elemwise_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
 
-bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
-                           std::vector<int>* in_attrs,
-                           std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 1U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  return in_attrs->at(0) != -1;
-}
-
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name)              \
   NNVM_REGISTER_OP(name)                                            \
   .set_num_inputs(1)                                                \
@@ -54,109 +45,6 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
   .add_argument("data", "NDArray-or-Symbol", "source input")        \
   .add_argument("scalar", "float", "scalar input")
 
-bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
-                                   std::vector<int>* in_attrs,
-                                   std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  const int ltype = in_attrs->at(0);
-  const int rtype = in_attrs->at(1);
-  if (ltype != -1 && rtype != -1 && (ltype != rtype)) {
-    // Only when both input types are known and not the same, we enter the 
mixed-precision mode
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, 
rtype));
-  } else {
-    return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
-  }
-  return true;
-}
-
-#ifndef _WIN32
-#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)                
\
-  NNVM_REGISTER_OP(name)                                                       
\
-  .set_num_inputs(2)                                                           
\
-  .set_num_outputs(1)                                                          
\
-  .set_attr<nnvm::FListInputNames>("FListInputNames",                          
\
-    [](const NodeAttrs& attrs) {                                               
\
-      return std::vector<std::string>{"lhs", "rhs"};                           
\
-    })                                                                         
\
-  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)           
\
-  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)     
\
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            
\
-    [](const NodeAttrs& attrs){                                                
\
-      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};                
\
-    })                                                                         
\
-  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")     
\
-  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
-#else
-#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)                
\
-  NNVM_REGISTER_OP(name)                                                       
\
-  .set_num_inputs(2)                                                           
\
-  .set_num_outputs(1)                                                          
\
-  .set_attr<nnvm::FListInputNames>("FListInputNames",                          
\
-    [](const NodeAttrs& attrs) {                                               
\
-      return std::vector<std::string>{"lhs", "rhs"};                           
\
-    })                                                                         
\
-  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)           
\
-  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)     
\
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            
\
-    [](const NodeAttrs& attrs){                                                
\
-      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};                
\
-    })                                                                         
\
-  .set_attr<FResourceRequest>("FResourceRequest",                              
\
-  [](const NodeAttrs& attrs) {                                                 
\
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};          
\
-  })                                                                           
\
-  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")     
\
-  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
-#endif
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
-#ifndef _WIN32
-.set_attr<FCompute>(
-  "FCompute<cpu>",
-  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus, 
op::mshadow_op::mixed_plus,
-                                      op::mshadow_op::mixed_plus>)
-#else
-.set_attr<FCompute>(
-  "FCompute<cpu>",
-  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
-#endif
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_broadcast_add"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
-#ifndef _WIN32
-.set_attr<FCompute>(
-  "FCompute<cpu>",
-  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus, 
op::mshadow_op::mixed_minus,
-                              op::mshadow_op::mixed_rminus>)
-#else
-.set_attr<FCompute>(
-  "FCompute<cpu>",
-  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
-#endif
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_broadcast_sub"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
-#ifndef _WIN32
-.set_attr<FCompute>(
-  "FCompute<cpu>",
-  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, 
op::mshadow_op::mixed_mul,
-                                      op::mshadow_op::mixed_mul>)
-#else
-.set_attr<FCompute>(
-  "FCompute<cpu>",
-  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
-#endif
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_mul"});
-
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::mod>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_mod"});
-
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::power>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_power"});
-
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
 .describe(R"code()code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::copysign>)
@@ -191,7 +79,7 @@ NNVM_REGISTER_OP(_npi_lcm)
      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
 })
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
mshadow_op::lcm>)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, 
mshadow_op::lcm>)
 .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
 .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
 
@@ -212,37 +100,40 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
 .add_argument("scalar", "int", "scalar input")
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::lcm>);
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
op::mshadow_op::plus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_subtract_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
op::mshadow_op::minus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rsubtract_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rminus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"negative"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_multiply_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
op::mshadow_op::mul>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_mul_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_mod_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::mod>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_mod_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rmod_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rmod>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_rmod_scalar"});
-
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_power_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::power>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_power_scalar"});
+NNVM_REGISTER_OP(_npi_bitwise_xor)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+[](const NodeAttrs& attrs) {
+     return std::vector<std::string>{"lhs", "rhs"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){
+     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+})
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, 
mshadow_op::bitwise_xor>)
+.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
+.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::rpower>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseOut{"_backward_rpower_scalar"});
+NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser([](NodeAttrs* attrs) {
+    attrs->parsed = std::stod(attrs->dict["scalar"]);
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "source input")
+.add_argument("scalar", "int", "scalar input")
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, 
mshadow_op::bitwise_xor>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, 
mshadow_op::copysign>)
@@ -260,12 +151,6 @@ 
MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
 .set_attr<FCompute>("FCompute<cpu>",
                     BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);
 
-inline bool IsFloatType(const int dtype) {
-  return (dtype == mshadow::kFloat16 ||
-          dtype == mshadow::kFloat32 ||
-          dtype == mshadow::kFloat64);
-}
-
 inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs,
                           std::vector<int>* in_attrs,
                           std::vector<int>* out_attrs) {
@@ -277,7 +162,7 @@ inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs,
   TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
   TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
   // check if it is float16, float32 or float64. If not, raise error.
-  CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n";
+  CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as 
input.\n";
   return out_attrs->at(0) != -1;
 }
 
@@ -341,7 +226,7 @@ bool HypotOpType(const nnvm::NodeAttrs& attrs,
   TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
   TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
 
-  CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n";
+  CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as 
input.\n";
   return out_attrs->at(0) != -1;
 }
 
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu 
b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu
similarity index 61%
copy from src/operator/numpy/np_elemwise_broadcast_op.cu
copy to src/operator/numpy/np_elemwise_broadcast_op_extended.cu
index a0a277d..f858fb4 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu
@@ -19,8 +19,8 @@
 
 /*!
  *  Copyright (c) 2019 by Contributors
- * \file np_elemwise_broadcast_op.cu
- * \brief GPU Implementation of basic functions for elementwise binary 
broadcast operator.
+ * \file np_elemwise_broadcast_op_extended.cu
+ * \brief GPU Implementation of extended functions for elementwise binary 
broadcast operator.
  */
 
 #include "./np_elemwise_broadcast_op.h"
@@ -28,53 +28,14 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_npi_add)
-#ifndef _WIN32
-.set_attr<FCompute>(
-  "FCompute<gpu>",
-  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus, 
op::mshadow_op::mixed_plus,
-                                      op::mshadow_op::mixed_plus>);
-#else
-.set_attr<FCompute>(
-  "FCompute<gpu>",
-  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
-#endif
-
-NNVM_REGISTER_OP(_npi_subtract)
-#ifndef _WIN32
-.set_attr<FCompute>(
-  "FCompute<gpu>",
-  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus, 
op::mshadow_op::mixed_minus,
-                              op::mshadow_op::mixed_rminus>);
-#else
-.set_attr<FCompute>(
-  "FCompute<gpu>",
-  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
-#endif
-
-NNVM_REGISTER_OP(_npi_multiply)
-#ifndef _WIN32
-.set_attr<FCompute>(
-  "FCompute<gpu>",
-  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul, 
op::mshadow_op::mixed_mul,
-                                      op::mshadow_op::mixed_mul>);
-#else
-.set_attr<FCompute>(
-  "FCompute<gpu>",
-  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
-#endif
-
-NNVM_REGISTER_OP(_npi_mod)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::mod>);
-
-NNVM_REGISTER_OP(_npi_power)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::power>);
-
 NNVM_REGISTER_OP(_npi_copysign)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::copysign>);
 
 NNVM_REGISTER_OP(_npi_lcm)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::lcm>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, 
mshadow_op::lcm>);
+
+NNVM_REGISTER_OP(_npi_bitwise_xor)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, 
mshadow_op::bitwise_xor>);
 
 NNVM_REGISTER_OP(_backward_npi_copysign)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::copysign_grad,
@@ -92,31 +53,6 @@ NNVM_REGISTER_OP(_npi_hypot)
 NNVM_REGISTER_OP(_backward_npi_hypot)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, 
mshadow_op::hypot_grad_left,
                                                                   
mshadow_op::hypot_grad_right>);
-
-NNVM_REGISTER_OP(_npi_add_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
op::mshadow_op::plus>);
-
-NNVM_REGISTER_OP(_npi_subtract_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
op::mshadow_op::minus>);
-
-NNVM_REGISTER_OP(_npi_rsubtract_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rminus>);
-
-NNVM_REGISTER_OP(_npi_multiply_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
op::mshadow_op::mul>);
-
-NNVM_REGISTER_OP(_npi_mod_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::mod>);
-
-NNVM_REGISTER_OP(_npi_rmod_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rmod>);
-
-NNVM_REGISTER_OP(_npi_power_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::power>);
-
-NNVM_REGISTER_OP(_npi_rpower_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rpower>);
-
 NNVM_REGISTER_OP(_npi_copysign_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::copysign>);
 
@@ -144,7 +80,10 @@ NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::rarctan2_grad>);
 
 NNVM_REGISTER_OP(_npi_lcm_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, 
mshadow_op::lcm>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, 
mshadow_op::lcm>);
+
+NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, 
mshadow_op::bitwise_xor>);
 
 NNVM_REGISTER_OP(_npi_ldexp)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::ldexp>);
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 633f630..e2a4c8a 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -396,10 +396,10 @@ 
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_xor);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor);  
// NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);  // 
NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // 
NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm);  // 
NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>);  
// NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>);  
// NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel);  // NOLINT()
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h 
b/src/operator/tensor/elemwise_binary_broadcast_op.h
index b48ed38..ffd0f12 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -333,6 +333,37 @@ struct csr_dns_map_kernel {
 }  // namespace mxnet_op
 
 template<typename xpu, typename OP>
+void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+  if (outputs[0].shape_.Size() == 0U) return;
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, 
outputs[0].shape_,
+                                         &new_lshape, &new_rshape, 
&new_oshape);
+  if (!ndim) {
+    ElemwiseBinaryOp::ComputeInt<xpu, OP>(attrs, ctx, inputs, req, outputs);
+  } else {
+    if (req[0] == kNullOp) return;
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    if (outputs[0].type_flag_ == mshadow::kBool) {
+      LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
+    }
+    MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
+        inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), 
outputs[0].dptr<DType>());
+      });
+    });
+  }
+}
+
+template<typename xpu, typename OP>
 void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
@@ -345,22 +376,21 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   if (!ndim) {
     ElemwiseBinaryOp::Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
   } else {
-    if (req[0] != kNullOp) {
-      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-      if (outputs[0].type_flag_ == mshadow::kBool) {
-        LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
-      }
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-        BROADCAST_NDIM_SWITCH(ndim, NDim, {
-          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
-          mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
-          mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
-          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
-          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
-          inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), 
outputs[0].dptr<DType>());
-        });
-      });
+    if (req[0] == kNullOp) return;
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    if (outputs[0].type_flag_ == mshadow::kBool) {
+      LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
     }
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
+        inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), 
outputs[0].dptr<DType>());
+      });
+    });
   }
 }
 
@@ -377,19 +407,18 @@ void BinaryBroadcastComputeWithBool(const 
nnvm::NodeAttrs& attrs,
   if (!ndim) {
     ElemwiseBinaryOp::ComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, 
outputs);
   } else {
-    if (req[0] != kNullOp) {
-      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-      MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
-        BROADCAST_NDIM_SWITCH(ndim, NDim, {
-          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
-          mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
-          mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
-          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
-          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
-          inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), 
outputs[0].dptr<DType>());
-        });
+    if (req[0] == kNullOp) return;
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+        mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
+        inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), 
outputs[0].dptr<DType>());
       });
-    }
+    });
   }
 }
 
@@ -406,20 +435,19 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& 
attrs,
   if (!ndim) {
     ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
   } else {
-    if (req[0] != kNullOp) {
-      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-      MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-          BROADCAST_NDIM_SWITCH(ndim, NDim, {
-            mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
-            mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
-            mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
-            mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, 
xpu>::
-            template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
-                              inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-                              outputs[0].dptr<bool>());
-          });
-      });
-    }
+    if (req[0] == kNullOp) return;
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+          mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+          mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
+                            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
+                            outputs[0].dptr<bool>());
+        });
+    });
   }
 }
 
@@ -672,6 +700,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& 
attrs,
                                const std::vector<TBlob>& outputs);
 
 template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
+void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
+                                                   const std::vector<TBlob>& 
inputs,
+                                                   const 
std::vector<OpReqType>& req,
+                                                   const std::vector<TBlob>& 
outputs,
+                                                   const mshadow::Tensor<xpu, 
1, char>& workspace,
+                                                   const mxnet::TShape& 
new_lshape,
+                                                   const mxnet::TShape& 
new_rshape,
+                                                   const mxnet::TShape& 
new_oshape) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob lgrad = outputs[0].reshape(new_lshape);
+  const TBlob rgrad = outputs[1].reshape(new_rshape);
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob lhs = inputs[1].reshape(new_lshape);
+  const TBlob rhs = inputs[2].reshape(new_rshape);
+  if (ograd.Size() != 0) {
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], 
workspace,
+      ograd, lhs, rhs);
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], 
workspace,
+      ograd, lhs, rhs);
+  }
+}
+
+template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
 inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
                                              const std::vector<TBlob>& inputs,
                                              const std::vector<OpReqType>& req,
diff --git a/src/operator/tensor/elemwise_binary_op.h 
b/src/operator/tensor/elemwise_binary_op.h
index c046a28..bc5140a 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -475,31 +475,54 @@ class ElemwiseBinaryOp : public OpBase {
                                        std::vector<int> *out_attrs);
 
   template<typename xpu, typename OP>
+  static void ComputeInt(const nnvm::NodeAttrs &attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+    using namespace mxnet_op;
+    if (req[0] == kNullOp) return;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    CHECK_EQ(inputs.size(), 2U);
+    CHECK_EQ(outputs.size(), 1U);
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
+        + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+        if (size != 0) {
+          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+          outputs[0].dptr<DType>(),
+          inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+        }
+      });
+    });
+  }
+
+  template<typename xpu, typename OP>
   static void Compute(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
     using namespace mxnet_op;
-    if (req[0] != kNullOp) {
-      Stream<xpu> *s = ctx.get_stream<xpu>();
-      CHECK_EQ(inputs.size(), 2U);
-      CHECK_EQ(outputs.size(), 1U);
-      if (outputs[0].type_flag_ == mshadow::kBool) {
-        LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
-      }
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
-          + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-          if (size != 0) {
-            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
-            outputs[0].dptr<DType>(),
-            inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
-          }
-        });
-      });
+    if (req[0] == kNullOp) return;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    CHECK_EQ(inputs.size(), 2U);
+    CHECK_EQ(outputs.size(), 1U);
+    if (outputs[0].type_flag_ == mshadow::kBool) {
+      LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
     }
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
+        + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+        if (size != 0) {
+          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+          outputs[0].dptr<DType>(),
+          inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+        }
+      });
+    });
   }
 
   template<typename xpu, typename OP>
@@ -509,22 +532,21 @@ class ElemwiseBinaryOp : public OpBase {
                               const std::vector<OpReqType> &req,
                               const std::vector<TBlob> &outputs) {
     using namespace mxnet_op;
-    if (req[0] != kNullOp) {
-      Stream<xpu> *s = ctx.get_stream<xpu>();
-      CHECK_EQ(inputs.size(), 2U);
-      CHECK_EQ(outputs.size(), 1U);
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
-          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
-          + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-          if (size != 0) {
-            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
-            outputs[0].dptr<DType>(),
-            inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
-          }
-        });
+    if (req[0] == kNullOp) return;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    CHECK_EQ(inputs.size(), 2U);
+    CHECK_EQ(outputs.size(), 1U);
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
+        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
+        + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+        if (size != 0) {
+          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+          outputs[0].dptr<DType>(),
+          inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+        }
       });
-    }
+    });
   }
 
   template<typename xpu, typename OP>
@@ -534,23 +556,22 @@ class ElemwiseBinaryOp : public OpBase {
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &outputs) {
     using namespace mxnet_op;
-    if (req[0] != kNullOp) {
-      Stream<xpu> *s = ctx.get_stream<xpu>();
-      CHECK_EQ(inputs.size(), 2U);
-      CHECK_EQ(outputs.size(), 1U);
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-            const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
-            + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-            if (size != 0) {
-              Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
-                                                                  
outputs[0].dptr<bool>(),
-                                                                  
inputs[0].dptr<DType>(),
-                                                                  
inputs[1].dptr<DType>());
-            }
-        });
+    if (req[0] == kNullOp) return;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    CHECK_EQ(inputs.size(), 2U);
+    CHECK_EQ(outputs.size(), 1U);
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
+          + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+          if (size != 0) {
+            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+                                                                
outputs[0].dptr<bool>(),
+                                                                
inputs[0].dptr<DType>(),
+                                                                
inputs[1].dptr<DType>());
+          }
       });
-    }
+    });
   }
 
   template<typename xpu, typename OP>
@@ -560,22 +581,21 @@ class ElemwiseBinaryOp : public OpBase {
                                const std::vector<OpReqType> &req,
                                const std::vector<TBlob> &outputs) {
     using namespace mxnet_op;
-    if (req[0] != kNullOp) {
-      Stream<xpu> *s = ctx.get_stream<xpu>();
-      CHECK_EQ(inputs.size(), 2U);
-      CHECK_EQ(outputs.size(), 1U);
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
-          + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-          if (size != 0) {
-            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
-            outputs[0].dptr<DType>(),
-            inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
-          }
-        });
+    if (req[0] == kNullOp) return;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    CHECK_EQ(inputs.size(), 2U);
+    CHECK_EQ(outputs.size(), 1U);
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
+        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
+        + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+        if (size != 0) {
+          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+          outputs[0].dptr<DType>(),
+          inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+        }
       });
-    }
+    });
   }
 
   template<typename xpu, typename OP>
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h 
b/src/operator/tensor/elemwise_binary_scalar_op.h
index 834bbdb..3e87028 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -245,6 +245,26 @@ class BinaryScalarOp : public UnaryOp {
   }
 
   template<typename xpu, typename OP>
+  static void ComputeInt(const nnvm::NodeAttrs &attrs,
+                         const OpContext &ctx,
+                         const std::vector<TBlob> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<TBlob> &outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    DCHECK_EQ(outputs.size(), 1);
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    const double alpha = nnvm::get<double>(attrs.parsed);
+    MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+          s, inputs[0].Size(), outputs[0].dptr<DType>(), 
inputs[0].dptr<DType>(), DType(alpha));
+      });
+    });
+  }
+
+  template<typename xpu, typename OP>
   static void ComputeLogic(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
diff --git a/src/operator/tensor/elemwise_unary_op.h 
b/src/operator/tensor/elemwise_unary_op.h
index 188ccd6..8886e15 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -453,8 +453,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
     MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
       Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
-      if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
-          req[0] != kWriteInplace) {
+      if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
+          req[0] != kWriteInplace) && outputs[0].Size() != 0) {
         Assign(out, req[0], tcast<DstDType>(data));
       }
     });
diff --git a/tests/python/unittest/test_gluon_estimator.py 
b/tests/python/unittest/test_gluon_estimator.py
index aaf9839..cf913a6 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -83,13 +83,15 @@ def test_validation():
     ctx = mx.cpu()
     loss = gluon.loss.L2Loss()
     acc = mx.metric.Accuracy()
+    evaluation_loss = gluon.loss.L1Loss()
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 
0.001})
     est = Estimator(net=net,
                     loss=loss,
                     metrics=acc,
                     trainer=trainer,
-                    context=ctx)
+                    context=ctx,
+                    evaluation_loss=evaluation_loss)
     # Input dataloader
     est.fit(train_data=dataloader,
             val_data=dataloader,
diff --git a/tests/python/unittest/test_numpy_interoperability.py 
b/tests/python/unittest/test_numpy_interoperability.py
index 8416b1a..5b6cea7 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -788,6 +788,18 @@ def _add_workload_lcm():
     OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), 
np.array(195225786*5, dtype=np.int32))
 
 
+def _add_workload_bitwise_xor():
+    OpArgMngr.add_workload('bitwise_xor', np.array([False, False, True, True], 
dtype=np.bool),
+                           np.array([False, True, False, True], dtype=np.bool))
+    for dtype in [np.int8, np.int32, np.int64]:
+        zeros = np.array([0], dtype=dtype)
+        ones = np.array([-1], dtype=dtype)
+        OpArgMngr.add_workload('bitwise_xor', zeros, zeros)
+        OpArgMngr.add_workload('bitwise_xor', ones, zeros)
+        OpArgMngr.add_workload('bitwise_xor', zeros, ones)
+        OpArgMngr.add_workload('bitwise_xor', ones, ones)
+
+
 def _add_workload_ldexp():
     OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, 
np.int8))
     OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, 
np.int8))
@@ -1194,6 +1206,7 @@ def _prepare_workloads():
     _add_workload_inner()
     _add_workload_hypot()
     _add_workload_lcm()
+    _add_workload_bitwise_xor()
     _add_workload_ldexp()
     _add_workload_subtract(array_pool)
     _add_workload_multiply(array_pool)
diff --git a/tests/python/unittest/test_numpy_op.py 
b/tests/python/unittest/test_numpy_op.py
index 643b9c1..3aef5ca 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1649,6 +1649,7 @@ def test_np_binary_funcs():
         'power': (1.0, 2.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2],
                              [lambda y, x1, x2: _np.power(x1, x2) * 
_np.log(x1)]),
         'lcm': (-100, 100, [None], None, [[_np.int32]]),
+        'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),
         'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
                            [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]),
         'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],
@@ -1683,7 +1684,9 @@ def test_np_binary_funcs():
 @with_seed()
 @use_np
 def test_np_mixed_precision_binary_funcs():
-    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, 
ltype, rtype):
+    itypes = [np.bool, np.int8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
+    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, 
lgrad, rgrad, ltype, rtype):
         class TestMixedBinary(HybridBlock):
             def __init__(self, func):
                 super(TestMixedBinary, self).__init__()
@@ -1717,13 +1720,15 @@ def test_np_mixed_precision_binary_funcs():
                             use_broadcast=False, equal_nan=True)
 
     funcs = {
-        'add': (-1.0, 1.0),
-        'subtract': (-1.0, 1.0),
-        'multiply': (-1.0, 1.0),
+        'add': (-1.0, 1.0, None, None),
+        'subtract': (-1.0, 1.0, None, None),
+        'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, 
y.shape),
+                                lambda y, x1, x2: _np.broadcast_to(x1, 
y.shape))
     }
 
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
+                   ((3, 0), (3, 0)),
                    ((3, 1), (3, 0)),
                    ((0, 2), (1, 2)),
                    ((2, 3, 4), (3, 1)),
@@ -1733,16 +1738,16 @@ def test_np_mixed_precision_binary_funcs():
     itypes = [np.bool, np.int8, np.int32, np.int64]
     ftypes = [np.float16, np.float32, np.float64]
     for func, func_data in funcs.items():
-        low, high = func_data
+        low, high, lgrad, rgrad = func_data
         for lshape, rshape in shape_pairs:
             for type1, type2 in itertools.product(itypes, ftypes):
-                check_mixed_precision_binary_func(func, low, high, lshape, 
rshape, type1, type2)
-                check_mixed_precision_binary_func(func, low, high, lshape, 
rshape, type2, type1)
+                check_mixed_precision_binary_func(func, low, high, lshape, 
rshape, lgrad, rgrad, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, 
rshape, lgrad, rgrad, type2, type1)
 
             for type1, type2 in itertools.product(ftypes, ftypes):
                 if type1 == type2:
                     continue
-                check_mixed_precision_binary_func(func, low, high, lshape, 
rshape, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, 
rshape, lgrad, rgrad, type1, type2)
 
 
 @with_seed()
@@ -4102,7 +4107,7 @@ def test_np_diff():
                         mx_out.backward()
                         if (np_out.size == 0):
                             np_backward = _np.zeros(shape)
-                        else:                    
+                        else:
                             np_backward = 
np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis)
                         assert x.grad.shape == np_backward.shape
                         assert_almost_equal(x.grad.asnumpy(), np_backward, 
rtol=rtol, atol=atol)

Reply via email to