ChaiBapchya closed pull request #12775: randn operator for symbol and NDarray
API
URL: https://github.com/apache/incubator-mxnet/pull/12775
This is a foreign pull request, merged from a forked repository.
Because GitHub hides the original diff once such a pull request is
merged, the diff is reproduced below for the sake of provenance:
diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index 1e941f79aa1..2d550f52664 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -152,7 +152,7 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null,
ctx=None, out=None, **kwarg
[loc, scale], shape, dtype, ctx, out, kwargs)
-def randn(*shape, **kwargs):
+def randn(loc=0, scale=1, shape=_Null, dtype=_Null, ctx=None, out=None,
**kwargs):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
@@ -193,11 +193,6 @@ def randn(*shape, **kwargs):
[5.357444 5.7793283 3.9896927]]
<NDArray 2x3 @cpu(0)>
"""
- loc = kwargs.pop('loc', 0)
- scale = kwargs.pop('scale', 1)
- dtype = kwargs.pop('dtype', _Null)
- ctx = kwargs.pop('ctx', None)
- out = kwargs.pop('out', None)
assert isinstance(loc, (int, float))
assert isinstance(scale, (int, float))
return _random_helper(_internal._random_normal, _internal._sample_normal,
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index e9abe9c4a18..32e2f6a7cef 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -96,6 +96,33 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null,
**kwargs):
[loc, scale], shape, dtype, kwargs)
+def randn(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs):
+ """Draw random samples from a normal (Gaussian) distribution.
+
+ Samples are distributed according to a normal distribution parametrized
+ by *loc* (mean) and *scale* (standard deviation).
+
+
+ Parameters
+ ----------
+ loc : float or NDArray
+ Mean (centre) of the distribution.
+ scale : float or NDArray
+ Standard deviation (spread or width) of the distribution.
+ shape : int or tuple of ints
+ The number of samples to draw. If shape is, e.g., `(m, n)` and `loc`
and
+ `scale` are scalars, output shape will be `(m, n)`. If `loc` and
`scale`
+ are NDArrays with shape, e.g., `(x, y)`, then output will have shape
+ `(x, y, m, n)`, where `m*n` samples are drawn for each `[loc, scale)`
pair.
+ dtype : {'float16','float32', 'float64'}
+ Data type of output samples. Default is 'float32'
+ """
+ assert isinstance(loc, (int, float))
+ assert isinstance(scale, (int, float))
+ return _random_helper(_internal._random_normal, _internal._sample_normal,
+ [loc, scale], shape, dtype, kwargs)
+
+
def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs):
"""Draw random samples from a Poisson distribution.
diff --git a/tests/python/unittest/test_loss.py
b/tests/python/unittest/test_loss.py
index 18d1ebf8fb1..b34132fa195 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -351,9 +351,9 @@ def test_triplet_loss():
@with_seed()
def test_cosine_loss():
#Generating samples
- input1 = mx.nd.random.randn(3, 2)
- input2 = mx.nd.random.randn(3, 2)
- label = mx.nd.sign(mx.nd.random.randn(input1.shape[0]))
+ input1 = mx.nd.random.randn(shape=(3, 2))
+ input2 = mx.nd.random.randn(shape=(3, 2))
+ label = mx.nd.sign(mx.nd.random.randn(shape=(input1.shape[0])))
#Calculating loss from cosine embedding loss function in Gluon
Loss = gluon.loss.CosineEmbeddingLoss()
loss = Loss(input1, input2, label)
diff --git a/tests/python/unittest/test_random.py
b/tests/python/unittest/test_random.py
index 6a59d8627ba..809df12f1d2 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -57,6 +57,7 @@ def check_with_device(device, dtype):
},
{
'name': 'randn',
+ 'symbol':mx.sym.random.randn,
'ndop': mx.nd.random.randn,
'params': { 'loc': 10.0, 'scale': 0.5 },
'checks': [
@@ -203,9 +204,6 @@ def check_with_device(device, dtype):
params = symbdic['params'].copy()
params.update(shape=shape, dtype=dtype, ctx=device)
args = ()
- if name == 'randn':
- params.pop('shape') # randn does not accept shape param
- args = shape
if name.endswith('_like'):
params['data'] = mx.nd.ones(params.pop('shape'),
dtype=params.pop('dtype'),
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services