xidulu commented on a change in pull request #16638: [Numpy] Add sampling method for bernoulli
URL: https://github.com/apache/incubator-mxnet/pull/16638#discussion_r344432954
 
 

 ##########
 File path: python/mxnet/symbol/numpy_extension/random.py
 ##########
 @@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Namespace for operators used in Gluon dispatched by F=symbol."""
+
+from __future__ import absolute_import
+from ...context import current_context
+from ..numpy import _internal as _npi
+
+__all__ = ['bernoulli']
+
+
+def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, 
out=None):
+    """Creates a Bernoulli distribution parameterized by :attr:`prob`
+    or :attr:`logit` (but not both).
+
+    Samples are binary (0 or 1). They take the value `1` with probability `p`
+    and `0` with probability `1 - p`.
+
+    Parameters
+    ----------
+    prob : float, _Symbol
+        The probability of sampling '1'.
+    logit : float, _Symbol
+        The log-odds of sampling '1'.
+    size : int or tuple of ints, optional
+        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn.  Default is None, in which case a
+        single value is returned.
+    dtype : dtype, optional
+        Desired dtype of the result. All dtypes are determined by their
+        name, i.e., 'int64', 'int', etc, so byteorder is not available
+        and a specific precision may have different C types depending
+        on the platform. The default value is 'np.float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : symbol, optional
+        The output symbol (default is `None`).
+
+    Returns
+    -------
+    out : _Symbol
+        Drawn samples from the parameterized bernoulli distribution.
+
+    Examples
+    --------
+    >>> prob = np.random.uniform(size=(4,4))
+    >>> logit = np.log(prob) - np.log(1 - prob)
+    >>> npx.random.bernoulli(logit=logit)
+    array([[0., 1., 1., 1.],
+        [0., 1., 1., 1.],
+        [0., 1., 0., 0.],
+        [1., 0., 1., 0.]])
+
+    >>> npx.random.bernoulli(prob=prob)
+    array([[0., 1., 0., 1.],
+        [1., 1., 1., 1.],
+        [1., 1., 1., 0.],
+        [1., 0., 1., 0.]])
+    """
+    from ..numpy import _Symbol as np_symbol
+    tensor_type_name = np_symbol
+    if (prob is None) == (logit is None):
+        raise ValueError(
+            "Either `prob` or `logit` must be specified, but not both.")
 
 Review comment:
   Good suggestion! 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to