octoJon commented on code in PR #13802:
URL: https://github.com/apache/tvm/pull/13802#discussion_r1085997156
##########
tests/python/frontend/onnx/test_forward.py:
##########
@@ -6663,6 +6663,105 @@ def verify_qlinearsigmoid(a_shape):
verify_qlinearsigmoid([])
+@tvm.testing.parametrize_targets("llvm")
+def test_random_bernoulli(target, dev):
+ """test_random_bernoulli"""
+
+ def verify_bernoulli_with_ort(
+ shape,
+ in_dtype="float32",
+ out_dtype="int32",
+ seed=None,
+ out_shape=None,
+ target=target,
+ dev=dev,
+ use_vm=False,
+ opset=None,
+ freeze_params=False,
+ rtol=0.1,
+ atol=0.1,
+ opt_level=1,
+ convert_config=None,
+ ):
+ def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", seed=None):
+ onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+ onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
+ node = helper.make_node(
+ "Bernoulli",
+ ["input"],
+ ["output"],
+ )
+ dtype_attr = helper.make_attribute("dtype", onnx_otype)
+ node.attribute.append(dtype_attr)
+ if seed is not None:
+ seed_attr = helper.make_attribute("seed", seed)
+ node.attribute.append(seed_attr)
+
+ graph = helper.make_graph(
+ [node],
+ "random_bernoulli_test",
+ inputs=[helper.make_tensor_value_info("input", onnx_itype, list(shape))],
+ outputs=[helper.make_tensor_value_info("output", onnx_otype, list(shape))],
+ )
+ return helper.make_model(graph, producer_name="random_bernoulli_test")
+
+ inputs = np.random.uniform(size=shape).astype(in_dtype)
+ if seed is None:
+ ort_seed = None
+ else:
+ ort_seed = float(seed)
+ model = get_bernoulli_model(shape, in_dtype, out_dtype, ort_seed)
+ if opset is not None:
+ model.opset_import[0].version = opset
+
+ ort_out = get_onnxruntime_output(model, inputs)
+ if use_vm:
+ tvm_out = get_tvm_output_with_vm(
+ model,
+ inputs,
+ target,
+ dev,
+ opset=opset,
+ freeze_params=freeze_params,
+ convert_config=convert_config,
+ )
+ else:
+ tvm_out = get_tvm_output(
+ model,
+ inputs,
+ target,
+ dev,
+ out_shape,
+ opset=opset,
+ opt_level=opt_level,
+ convert_config=convert_config,
+ )
+
+ if not isinstance(tvm_out, list):
+ tvm_out = [tvm_out]
+ if not isinstance(ort_out, list):
+ ort_out = [ort_out]
+ for tvm_val, ort_val in zip(tvm_out, ort_out):
+ tvm.testing.assert_allclose(ort_val.mean(), tvm_val.mean(), rtol=rtol, atol=atol)
Review Comment:
Statistician here! Here are the things I would test:
1. Verify all outputs are 0 or 1.
2. Verify that if you feed in an array of input probabilities which is a mix
of zeroes and ones, then the outputs you get out match your inputs. That is, a
Bernoulli distribution with p=1 should always return a 1, and a Bernoulli
distribution with p=0 should always return a 0. And if you have a mix of those
in the same input tensor, then you still get the expected mix of zeroes and
ones out.
3. Pick a couple fixed probabilities to test. (I would recommend p=0.5 and
p=0.1.) Verify that if you feed in a large-ish input array of all, say, p=0.5,
then the mean value of your outputs is "sufficiently" close to 0.5. The best
way to test for that is with a binomial proportion test, like
`scipy.stats.binomtest`. Normally we'd run a binomial proportion test and
confirm that the p-value from that test is larger than 0.05, but, as Andrew
says, that would give you quite a flaky test. For a non-flaky test, I would
recommend having the test fail only if the p-value from the binomial proportion
test is <1e-6, and just making sure that you have a large enough sample that
it's possible to detect small deviations from the intended binomial
distribution. With a sample of size 10,000, there's roughly a 1e-6 chance that
your sample mean will be more than 2.5 percentage points away from the true
mean, so an input tensor with 10,000 entries should get the job done.
So to recap on #3: If it's not too slow to generate a tensor of 10,000
Bernoulli values, my recommendation would be to generate 10,000 Bernoulli
values with p=0.5, run a binomial proportion test, and have the unit test fail
if the binomial proportion test has a p-value of <1e-6. That test should only
have a flaky failure once per million times that it's run. ... And then also
repeat that unit test again with p=0.1 rather than p=0.5.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]