This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7926a5d  [Relay][Op] Add unbiased variance op and corresponding support in pytorch frontend (#6232)
7926a5d is described below

commit 7926a5da09a7098507d365b69524f7675e54970c
Author: shiwenloong <shiwenlo...@outlook.com>
AuthorDate: Mon Aug 10 14:07:43 2020 +0800

    [Relay][Op] Add unbiased variance op and corresponding support in pytorch frontend (#6232)
---
 include/tvm/relay/attrs/reduce.h              | 31 ++++++++++++++++++++++++
 python/tvm/relay/frontend/pytorch.py          | 25 ++++++++-----------
 python/tvm/relay/op/_tensor_grad.py           | 15 +++++++++---
 python/tvm/relay/op/op_attrs.py               |  5 ++++
 python/tvm/relay/op/reduce.py                 | 18 +++++++++-----
 src/relay/op/make_op.h                        |  3 ++-
 src/relay/op/tensor/reduce.cc                 | 31 +++++++++++++++++-------
 src/relay/transforms/pattern_util.h           |  5 ++--
 tests/python/frontend/pytorch/test_forward.py | 35 +++++++++++++++++++++++++++
 tests/python/relay/test_op_grad_level4.py     |  5 +++-
 tests/python/relay/test_op_level4.py          | 12 +++++++++
 11 files changed, 147 insertions(+), 38 deletions(-)
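
For reference, the change implements Bessel's correction: the biased estimator
divides the summed squared deviations by N, the unbiased one by N - 1, which is
PyTorch's default for var/std. A quick NumPy sketch of the distinction (not
part of the commit; ddof is NumPy's knob for the same choice):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    n = x.size
    mean = x.mean()

    biased = ((x - mean) ** 2).sum() / n          # what relay.variance computed before
    unbiased = ((x - mean) ** 2).sum() / (n - 1)  # Bessel's correction

    assert np.isclose(biased, np.var(x))            # np.var defaults to ddof=0
    assert np.isclose(unbiased, np.var(x, ddof=1))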

diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h
index f57c1f4..14b75ff 100644
--- a/include/tvm/relay/attrs/reduce.h
+++ b/include/tvm/relay/attrs/reduce.h
@@ -60,6 +60,37 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
         "Whether to perform reduction on axis that are NOT in axis instead.");
   }
 };
+
+struct VarianceAttrs : public tvm::AttrsNode<VarianceAttrs> {
+  Array<Integer> axis;
+  bool keepdims;
+  bool exclude;
+  bool unbiased;
+
+  TVM_DECLARE_ATTRS(VarianceAttrs, "relay.attrs.VarianceAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .set_default(NullValue<Array<Integer>>())
+        .describe(R"code(The axis or axes along which to perform the reduction.
+
+      The default, `axis=()`, will compute over all elements into a
+      scalar array with shape `(1,)`.
+
+      If `axis` is int, a reduction is performed on a particular axis.
+
+      If `axis` is a tuple of ints, a reduction is performed on all the axes
+      specified in the tuple.
+
+      If `exclude` is true, reduction will be performed on the axes that are
+      NOT in axis instead.)code");
+
+    TVM_ATTR_FIELD(keepdims).set_default(false).describe(
+        "If this is set to `True`, the reduced axes are left "
+        "in the result as dimension with size one.");
+    TVM_ATTR_FIELD(exclude).set_default(false).describe(
+        "Whether to perform reduction on axis that are NOT in axis instead.");
+    TVM_ATTR_FIELD(unbiased).set_default(false).describe("Whether to use the unbiased estimation.");
+  }
+};
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_REDUCE_H_
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index bbc684e..a1cabcd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1262,28 +1262,23 @@ def _std():
             keepdims = bool(inputs[3])
             unbiased = bool(inputs[2])
 
-        if unbiased:
-            msg = "Currently only supports standard-deviation calculated via the biased "\
-                  "estimator. PyTorch's Bessel's correction is not supported."
-            raise NotImplementedError(msg)
-
-        return _op.reduce.std(data, axis=axis, keepdims=keepdims)
+        return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased)
 
     return _impl
 
 def _variance():
     def _impl(inputs, input_types):
         data = inputs[0]
-        axis = list(_infer_shape(inputs[1]))
-        keepdims = bool(inputs[3])
-        unbiased = bool(inputs[2])
-
-        if unbiased:
-            msg = "Currently only supports standard-deviation calculated via the biased "\
-                  "estimator. PyTorch's Bessel's correction is not supported."
-            raise NotImplementedError(msg)
+        if len(inputs) == 2:
+            axis = None
+            keepdims = False
+            unbiased = bool(inputs[1])
+        else:
+            axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[3])
+            unbiased = bool(inputs[2])
 
-        return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
+        return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased)
 
     return _impl
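
The two branches above track the two aten::var / aten::std overloads that
TorchScript emits: a 2-input form (self, unbiased) that reduces over all
elements, and a 4-input form (self, dim, unbiased, keepdim). A sketch of the
calls that produce each form (our reading of the branch logic, not something
the commit itself asserts):

    import torch

    x = torch.rand(4, 2, 3)

    torch.var(x, unbiased=True)                        # 2-input overload, full reduction
    torch.var(x, dim=1, unbiased=True, keepdim=False)  # 4-input overload, per-axis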
 
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index aee8603..46a4535 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -589,16 +589,23 @@ def mean_grad(orig, grad):
 def variance_grad(orig, grad):
     """Note that we take mean as an argument in the variance node"""
     data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
+    unbiased = orig.attrs.unbiased
     shape = data.checked_type.concrete_shape
     if axis is None:
         axis = list(range(len(data.checked_type.concrete_shape)))
     if not orig.attrs.keepdims:
         grad = _unreduce_expand(grad, axis)
-    mult = 2.0
+    mult1 = 2.0
+    mult2 = -2.0
+    count = 1
     for a in axis:
-        mult /= shape[a]
-    return [(grad * const(mult, dtype=data.checked_type.dtype)) * data,
-            const(-2, dtype=data.checked_type.dtype) * grad * data_mean]
+        count *= shape[a]
+    if unbiased:
+        mult2 = mult2 * count / (count - 1)
+        count -= 1
+    mult1 /= count
+    return [(grad * const(mult1, dtype=data.checked_type.dtype)) * data,
+            const(mult2, dtype=data.checked_type.dtype) * grad * data_mean]
 
 
 @register_gradient("copy")
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 32540a5..7f91989 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -460,6 +460,11 @@ class ReduceAttrs(Attrs):
     """Attributes used in reduction operators (e.g. sum)"""
 
 
+@tvm._ffi.register_object("relay.attrs.VarianceAttrs")
+class VarianceAttrs(Attrs):
+    """Attributes used in reduction operators (e.g. sum)"""
+
+
 @tvm._ffi.register_object("relay.attrs.RequantizeAttrs")
 class RequantizeAttrs(Attrs):
     """Attributes used in requantize operators"""
diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py
index 988c949..99189f8 100644
--- a/python/tvm/relay/op/reduce.py
+++ b/python/tvm/relay/op/reduce.py
@@ -312,7 +312,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
     return _make.mean(data, axis, keepdims, exclude)
 
 
-def variance(data, axis=None, keepdims=False, exclude=False):
+def variance(data, axis=None, keepdims=False, exclude=False, unbiased=False):
     """Computes the variance of data over given axes.
 
     Parameters
@@ -334,6 +334,9 @@ def variance(data, axis=None, keepdims=False, exclude=False):
         If `exclude` is true, reduction will be performed on the axes that are
         NOT in axis instead.
 
+    unbiased : bool
+        If this is set to True, the unbiased estimation will be used.
+
     Returns
     -------
     result : relay.Expr
@@ -341,10 +344,10 @@ def variance(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    return _make._variance(data, m, axis, keepdims, exclude)
+    return _make._variance(data, m, axis, keepdims, exclude, unbiased)
 
 
-def std(data, axis=None, keepdims=False, exclude=False):
+def std(data, axis=None, keepdims=False, exclude=False, unbiased=False):
     """Computes the standard deviation of data over given axes.
 
     Parameters
@@ -366,6 +369,9 @@ def std(data, axis=None, keepdims=False, exclude=False):
         If `exclude` is true, reduction will be performed on the axes that are
         NOT in axis instead.
 
+    unbiased : bool
+        If this is set to True, the unbiased estimation will be used.
+
     Returns
     -------
     result : relay.Expr
@@ -373,7 +379,7 @@ def std(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    return sqrt(_make._variance(data, m, axis, keepdims, exclude))
+    return sqrt(_make._variance(data, m, axis, keepdims, exclude, unbiased))
 
 
 def mean_variance(data, axis=None, keepdims=False, exclude=False):
@@ -405,7 +411,7 @@ def mean_variance(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    var = _make._variance(data, m, axis, keepdims, exclude)
+    var = _make._variance(data, m, axis, keepdims, exclude, False)
     if not keepdims:
         m = squeeze(m)
     return TupleWrapper(Tuple((m, var)), 2)
@@ -440,7 +446,7 @@ def mean_std(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     m = mean(data, axis, True, exclude)
-    s = sqrt(_make._variance(data, m, axis, keepdims, exclude))
+    s = sqrt(_make._variance(data, m, axis, keepdims, exclude, False))
     if not keepdims:
         m = squeeze(m)
     return TupleWrapper(Tuple((m, s)), 2)
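
End-to-end, the new parameter lets the op be checked directly against NumPy's
ddof=1. A minimal sketch against the interpreter API as of this commit:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4, 2), dtype="float32")
    func = relay.Function([x], relay.variance(x, axis=1, unbiased=True))

    data = np.random.rand(4, 2).astype("float32")
    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    res = intrp.evaluate(func)(data).asnumpy()

    assert np.allclose(res, data.var(axis=1, ddof=1), rtol=1e-5)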
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index c03a7bf..8ca2203 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -74,7 +74,8 @@ Expr MakeTile(Expr data, Array<Integer> reps);
 
 Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);
 
-Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude);
+Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
+                  bool unbiased);
 
 Expr MakeZeros(Array<Integer> shape, DataType dtype);
 
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 9fd1400..e16ecb6 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -38,6 +38,7 @@ namespace tvm {
 namespace relay {
 
 TVM_REGISTER_NODE_TYPE(ReduceAttrs);
+TVM_REGISTER_NODE_TYPE(VarianceAttrs);
 
 /*!
  * \brief GetReduceAxes, get the new axis from indim and other arguments
@@ -193,12 +194,14 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inp
 /*!
  * \brief ReduceShapeImpl get the outshape for the reduction operator
  * \param in_shape Shape of input data.
- * \param param ReduceAttrs details.
+ * \param param Attrs details.
  * \param reporter The reporter to report solution to.
  * \return oshape Output shape inferred.
+ * \tparam AttrsType The attribute type.
  */
+template <typename AttrsType>
 inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
-                                              const ReduceAttrs* param,
+                                              const AttrsType* param,
                                               const TypeReporter& reporter) {
   uint32_t indim = in_shape.size();
   auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
@@ -542,7 +545,7 @@ bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   std::vector<IndexExpr> mean_shape(mean->shape.begin(), mean->shape.end());
   CHECK_EQ(in_shape.size(), mean_shape.size());
 
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  const VarianceAttrs* param = attrs.as<VarianceAttrs>();
   CHECK(param != nullptr);
 
   // assign output type and shape
@@ -554,39 +557,49 @@ bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                   const Type& out_type) {
   IndexExpr count = tir::make_const(inputs[0]->dtype, 1);
-  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
+  const VarianceAttrs* param = attrs.as<VarianceAttrs>();
   CHECK(param != nullptr);
   auto axes = param->axis;
+  bool unbiased = param->unbiased;
   auto data = inputs[0];
   auto mean = inputs[1];
   for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) {
     count *= data->shape[i];
   }
+  if (unbiased) {
+    count -= 1;
+  }
   std::vector<Integer> expand_shape;
   auto sq_diff = topi::power(topi::subtract(data, mean), 2);
-  auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, topi::sum)[0], count);
+  if (param->exclude) {
+    axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
+    CHECK_NE(axes.size(), 0);
+  }
+  auto var = topi::divide(topi::sum(sq_diff, axes, param->keepdims, false), count);
 
   return {var};
 }
 
-Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
-  auto attrs = make_object<ReduceAttrs>();
+Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
+                  bool unbiased = false) {
+  auto attrs = make_object<VarianceAttrs>();
   attrs->axis = std::move(axis);
   attrs->keepdims = keepdims;
   attrs->exclude = exclude;
+  attrs->unbiased = unbiased;
   static const Op& op = Op::Get("variance");
   return Call(op, {data, mean}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) {
-  runtime::detail::unpack_call<Expr, 5>(MakeVariance, args, rv);
+  runtime::detail::unpack_call<Expr, 6>(MakeVariance, args, rv);
 });
 
 RELAY_REGISTER_OP("variance")
     .describe(R"code(Computes the variance of array elements over given axes.
 
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type<ReduceAttrs>()
+    .set_attrs_type<VarianceAttrs>()
     .set_support_level(4)
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index b3e3681..ee65503 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -580,8 +580,9 @@ inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
   return MakeReduce(data, axis, keepdims, exclude, "mean");
 }
 
-inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
-  return MakeVariance(data, mean, axis, keepdims, exclude);
+inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
+                     bool unbiased = false) {
+  return MakeVariance(data, mean, axis, keepdims, exclude, unbiased);
 }
 
 static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 3c9dfb1..ae03a70 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1873,6 +1873,18 @@ def test_forward_std():
         def forward(self, *args):
             return args[0].std(unbiased=False)
 
+    class Std7(Module):
+        def forward(self, *args):
+            return args[0].std(dim=1, keepdim=False, unbiased=True)
+
+    class Std8(Module):
+        def forward(self, *args):
+            return args[0].std(dim=(2,3), keepdim=True, unbiased=True)
+
+    class Std9(Module):
+        def forward(self, *args):
+            return args[0].std(unbiased=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Std1().float().eval(), input_data=input_data)
     verify_model(Std2().float().eval(), input_data=input_data)
@@ -1880,6 +1892,9 @@ def test_forward_std():
     verify_model(Std4().float().eval(), input_data=input_data)
     verify_model(Std5().float().eval(), input_data=input_data)
     verify_model(Std6().float().eval(), input_data=input_data)
+    verify_model(Std7().float().eval(), input_data=input_data)
+    verify_model(Std8().float().eval(), input_data=input_data)
+    verify_model(Std9().float().eval(), input_data=input_data)
 
 
 def test_forward_variance():
@@ -1906,12 +1921,32 @@ def test_forward_variance():
         def forward(self, *args):
             return args[0].var(dim=(2,3), keepdim=False, unbiased=False)
 
+    class Variance6(Module):
+        def forward(self, *args):
+            return args[0].var(unbiased=False)
+
+    class Variance7(Module):
+        def forward(self, *args):
+            return args[0].var(dim=1, keepdim=False, unbiased=True)
+
+    class Variance8(Module):
+        def forward(self, *args):
+            return args[0].var(dim=(2,3), keepdim=True, unbiased=True)
+
+    class Variance9(Module):
+        def forward(self, *args):
+            return args[0].var(unbiased=True)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Variance1().float().eval(), input_data=input_data)
     verify_model(Variance2().float().eval(), input_data=input_data)
     verify_model(Variance3().float().eval(), input_data=input_data)
     verify_model(Variance4().float().eval(), input_data=input_data)
     verify_model(Variance5().float().eval(), input_data=input_data)
+    verify_model(Variance6().float().eval(), input_data=input_data)
+    verify_model(Variance7().float().eval(), input_data=input_data)
+    verify_model(Variance8().float().eval(), input_data=input_data)
+    verify_model(Variance9().float().eval(), input_data=input_data)
 
 
 def test_forward_rsub():
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index 956c6af..b35ffe9 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -26,7 +26,10 @@ def verify_reduction_grad(red_fn, d_shape, axis=None, keepdims=False, exclude=Fa
 
 
 def test_reduction_grad():
-    for op in (relay.sum, relay.variance, relay.mean):
+    def _unbiased_variance(x, axis=None, keepdims=False, exclude=False):
+        return relay.variance(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
+
+    for op in (relay.sum, relay.variance, _unbiased_variance, relay.mean):
         verify_reduction_grad(op, (4, 2))
         verify_reduction_grad(op, (4, 2), axis=-1, keepdims=True)
         verify_reduction_grad(op, (4, 2, 1), axis=(1, 2), exclude=True)
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index c800b1c..8e01fa2 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -225,6 +225,16 @@ def test_reduce_functions():
         if not keepdims:
             x = np.squeeze(x, axis=axis)
         return x
+    
+    def _unbiased_relay_wrapper(f):
+        def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
+            return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
+        return _unbiased_func
+    
+    def _unbiased_np_wrapper(f):
+        def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
+            return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)
+        return _unbiased_func
 
     d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
     for func in [[relay.sum, np.sum],
@@ -232,7 +242,9 @@ def test_reduce_functions():
                  [relay.min, np.min],
                  [relay.mean, np.mean],
                  [relay.variance, np.var],
+                 [_unbiased_relay_wrapper(relay.variance), _unbiased_np_wrapper(np.var)],
                  [relay.std, np.std],
+                 [_unbiased_relay_wrapper(relay.std), _unbiased_np_wrapper(np.std)],
                  [relay.prod, np.prod],
                  [relay.all, np.all],
                  [relay.any, np.any],
