This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new db39bb1126 Fix multivariate normal bug (#21105)
db39bb1126 is described below
commit db39bb1126dee62e44ca0bd1121695a77e2873d2
Author: hankaj <[email protected]>
AuthorDate: Fri Jul 29 13:06:00 2022 +0200
Fix multivariate normal bug (#21105)
---
python/mxnet/numpy_op_fallback.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/python/mxnet/numpy_op_fallback.py b/python/mxnet/numpy_op_fallback.py
index 17b6327bb7..69b3f28662 100644
--- a/python/mxnet/numpy_op_fallback.py
+++ b/python/mxnet/numpy_op_fallback.py
@@ -182,7 +182,7 @@ class MultivariateNormal(operator.CustomOp):
scale = _mx_np.linalg.cholesky(cov)
#set context
noise = _mx_np.random.normal(size=out_data[0].shape, dtype=loc.dtype,
device=loc.device)
- out = loc + _mx_np.einsum('...jk,...j->...k', scale, noise)
+ out = loc + _mx_np.einsum('...jk,...k->...j', scale, noise)
self.assign(out_data[0], req[0], out)
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):