larroy commented on a change in pull request #13857: float32 -> float16 cast
consistency across implementations
URL: https://github.com/apache/incubator-mxnet/pull/13857#discussion_r248658930
##########
File path: tests/python/unittest/test_operator.py
##########
@@ -3994,6 +3994,44 @@ def test_cast():
assert_almost_equal(exe.grad_arrays[0].asnumpy(),
X.astype(dsttype).astype(srctype), rtol=1e-3, atol=1e-5)
+# Test requires all platforms to round float32->float16 with same
round-to-nearest-even policy.
+@with_seed()
+def test_cast_float32_to_float16():
+ fp16_fraction_bits = 10
+ fp32_fraction_bits = 23
+ fp32_exp_min = -126
+ fp32_exp_max = 127
+ # generate test cases in the vicinity of representable float16 mantissas
+ # and mid-way between them, but over the full range of float32 exponents.
+ def get_data():
+ for sign_bit in [0, 1]:
+ for exponent in range(fp32_exp_min - fp32_fraction_bits - 1,
fp32_exp_max + 2):
+ denominator = 2**(fp16_fraction_bits + 1)
+ for numerator in range(0, denominator):
+ for y in [-1.0, 0.0, 1.0]:
+ small_delta = y / 2**fp32_fraction_bits
+ val = (-1.0)**sign_bit * 2.0**exponent * (1.0 +
+ numerator / float(denominator) + small_delta)
+ yield val
+
+ input_np = np.array(list(get_data())).astype(np.float32)
+ # temp cast to np.float64 gets around numpy bug: see
https://github.com/numpy/numpy/issues/12721
+ expected_output = input_np.astype(np.float64).astype(np.float16)
+
+ x = mx.sym.Variable('x', dtype=np.float32)
+ sym = mx.sym.Cast(x, dtype=np.float16)
+ ctx = default_context()
+ exe = sym.bind(ctx, {'x' : mx.nd.array(input_np, dtype=np.float32,
ctx=ctx)})
+ assert exe.arg_arrays[0].dtype == np.float32
+ assert exe.outputs[0].dtype == np.float16
+ exe.forward(is_train=False)
+ sym_output = exe.outputs[0].asnumpy()
+ for fp32_val, model_fp16_val, np_fp16_val in zip(input_np, sym_output,
expected_output):
+ if model_fp16_val != np_fp16_val:
+ raise RuntimeError('fp32->fp16 cast mismatches seen, e.g. with val
{}, model_fp16 = {},'
Review comment:
Better to raise an `AssertionError` or use the nose testing tools
(https://nose.readthedocs.io/en/latest/testing_tools.html), as `RuntimeError`
has different semantics.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services