leandron commented on a change in pull request #6232:
URL: https://github.com/apache/incubator-tvm/pull/6232#discussion_r466997354
##########
File path: python/tvm/relay/op/reduce.py
##########
@@ -376,6 +408,39 @@ def std(data, axis=None, keepdims=False, exclude=False):
return sqrt(_make._variance(data, m, axis, keepdims, exclude))
+def unbiased_std(data, axis=None, keepdims=False, exclude=False):
+ """Computes the unbiased standard deviation of data over given axes.
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data
+
+ axis : None or int or tuple of int
+ Axis or axes along which a standard deviation operation is performed.
+ The default, axis=None, will compute the standard deviation of all elements in the
+ input array. If axis is negative it counts from the last to the first axis.
+
+ keepdims : bool
+ If this is set to True, the axes which are reduced are left in the result as dimensions
+ with size one.
+ With this option, the result will broadcast correctly against the input array.
+
+ exclude : bool
+ If `exclude` is true, reduction will be performed on the axes that are
+ NOT in axis instead.
+
+ Returns
+ -------
+ result : relay.Expr
+ The computed result.
+ """
+ axis = [axis] if isinstance(axis, int) else axis
+ m = mean(data, axis, True, exclude)
+
+ return sqrt(_make._unbiased_variance(data, m, axis, keepdims, exclude))
Review comment:
I'm not very familiar with the specifics here, but could this be
replaced with the snippet below, to avoid repetition?
```
return sqrt(unbiased_variance(data, axis, keepdims, exclude))
```
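For illustration only, here is a plain-Python sketch (not TVM's Relay API) of the composition the reviewer suggests: the unbiased standard deviation is the square root of a shared variance helper, so the Bessel-corrected (N - 1) reduction lives in one place. The names `variance`, `std` and `unbiased_std` below are hypothetical stand-ins that work on a flat list of floats rather than on a `relay.Expr` with `axis`, `keepdims` and `exclude`.

```python
import math
from typing import Sequence


def variance(data: Sequence[float], unbiased: bool = False) -> float:
    """Hypothetical helper: population (N) or Bessel-corrected (N - 1) variance."""
    n = len(data)
    if n == 0 or (unbiased and n < 2):
        raise ValueError("not enough elements for the requested estimator")
    m = sum(data) / n
    squared_dev = sum((x - m) ** 2 for x in data)
    # Bessel's correction: divide by N - 1 instead of N.
    return squared_dev / (n - 1 if unbiased else n)


def std(data: Sequence[float]) -> float:
    # Biased standard deviation built on the shared variance helper.
    return math.sqrt(variance(data))


def unbiased_std(data: Sequence[float]) -> float:
    # Reuse the variance helper instead of repeating the mean/variance logic.
    return math.sqrt(variance(data, unbiased=True))


if __name__ == "__main__":
    xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
    print(std(xs))  # 2.0
    print(unbiased_std(xs))
```

With this shape, a fix to the variance computation automatically applies to both standard-deviation variants.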
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1263,27 +1263,32 @@ def _impl(inputs, input_types):
unbiased = bool(inputs[2])
if unbiased:
- msg = "Currently only supports standard-deviation calculated via the biased "\
- "estimator. PyTorch's Bessel's correction is not supported."
- raise NotImplementedError(msg)
+ std_op = _op.reduce.unbiased_std
+ else:
+ std_op = _op.reduce.std
Review comment:
Minor suggestion: you could use the same one-line conditional expression
you used below for `axis` to make this statement shorter.
```
std_op = _op.reduce.unbiased_std if unbiased else _op.reduce.std
```
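A plain-Python sketch of the one-line selection proposed above, using the standard library's estimators as stand-ins: `statistics.stdev` divides by N - 1 (Bessel's correction, the estimator PyTorch uses when `unbiased=True`), while `statistics.pstdev` divides by N. The function `convert_std` is hypothetical; it mimics only the op-selection step of the frontend converter and is not the actual TVM code.

```python
import statistics
from typing import Callable, Sequence

# Stand-in estimators from the standard library:
#   stdev  -> sample standard deviation (divides by N - 1, Bessel's correction)
#   pstdev -> population standard deviation (divides by N)
StdOp = Callable[[Sequence[float]], float]


def convert_std(data: Sequence[float], unbiased: bool) -> float:
    """Hypothetical stand-in for the converter's op selection."""
    # One conditional expression replaces the four-line if/else block.
    std_op: StdOp = statistics.stdev if unbiased else statistics.pstdev
    return std_op(data)


if __name__ == "__main__":
    xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
    print(convert_std(xs, unbiased=False))  # 2.0
    print(convert_std(xs, unbiased=True))
```

The conditional expression keeps the mapping from the `unbiased` flag to the estimator in a single, easily scanned line.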
----------------------------------------------------------------
This is an automated message from the Apache Git Service.