octoJon commented on code in PR #13802:
URL: https://github.com/apache/tvm/pull/13802#discussion_r1088224316


##########
tests/python/frontend/onnx/test_forward.py:
##########
@@ -6707,6 +6707,117 @@ def verify_qlinearsigmoid(a_shape):
     verify_qlinearsigmoid([])
 
 
[email protected]("llvm")
+def test_random_bernoulli(target, dev):
+    """test_random_bernoulli"""
+
+    def verify_bernoulli(
+        inputs=None,
+        shape=[],
+        in_dtype="float32",
+        out_dtype="int32",
+        seed=None,
+        target=target,
+        dev=dev,
+        use_vm=False,
+        freeze_params=False,
+        rtol=0.1,
+        atol=0.1,
+        in_out_equal=False,
+    ):
+        def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", seed=None):
+            onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+            onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
+            node = helper.make_node(
+                "Bernoulli",
+                ["input"],
+                ["output"],
+            )
+            dtype_attr = helper.make_attribute("dtype", onnx_otype)
+            node.attribute.append(dtype_attr)
+            if seed is not None:
+                seed_attr = helper.make_attribute("seed", float(seed))
+                node.attribute.append(seed_attr)
+
+            graph = helper.make_graph(
+                [node],
+                "random_bernoulli_test",
+                inputs=[helper.make_tensor_value_info("input", onnx_itype, list(shape))],
+                outputs=[helper.make_tensor_value_info("output", onnx_otype, list(shape))],
+            )
+            return helper.make_model(graph, producer_name="random_bernoulli_test")
+
+        if inputs is None:
+            assert len(shape) != 0
+            inputs = np.random.uniform(size=shape).astype(in_dtype)
+        else:
+            shape = inputs.shape
+            in_dtype = inputs.dtype
+        model = get_bernoulli_model(shape, in_dtype, out_dtype, seed)
+
+        if use_vm:
+            tvm_out = get_tvm_output_with_vm(
+                model,
+                inputs,
+                target,
+                dev,
+                freeze_params=freeze_params,
+            )
+        else:
+            tvm_out = get_tvm_output(
+                model,
+                inputs,
+                target,
+                dev,
+            )
+
+        if isinstance(tvm_out, list):
+            tvm_out = tvm_out[0]
+        ideal_mean = np.mean(inputs)
+        # check that values are 0 or 1
+        tvm_flat = tvm_out.flatten()
+        for i in range(len(tvm_flat)):
+            assert tvm_flat[i] == 0 or tvm_flat[i] == 1
+        if in_out_equal:
+            tvm.testing.assert_allclose(inputs, tvm_out)
+        else:
+            # check that mean value is close to the theoretical one by binomial test
+            bnm_test_res = scipy.stats.binomtest(

Review Comment:
   Oh, god. I didn't read that reference link closely enough. It contains a 
really blatantly misogynistic comment. Apologies to anyone I exposed to that. 
Please don't include that reference link in the code.
   
   (I think the math is right, though.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to