This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 968b6f6  Add `is_floating_point()` test and better type support in `verify_model_vm()` (#7134)
968b6f6 is described below
commit 968b6f60da37d85232af6f9a6070d8ff2ed4be8a
Author: Tyler Davis <[email protected]>
AuthorDate: Tue Dec 22 01:20:36 2020 -0800
    Add `is_floating_point()` test and better type support in `verify_model_vm()` (#7134)
* Add div_ and is_floating_point operators
* Add handling of exprs to op, update tests
* add test + supporting functions
* Revert whitespace changes
* Properly assign dtype to random integers
* Reformat with black
* Switched default dtype logic, removed extra line
---
tests/python/frontend/pytorch/test_forward.py | 85 +++++++++++++++++++++++++--
1 file changed, 80 insertions(+), 5 deletions(-)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2dda675..74d9c78 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1889,9 +1889,10 @@ def _get_default_vm_targets():
return [tgt for (tgt, _) in tvm.testing.enabled_targets()]
-def verify_script_model(pt_model, ishapes, targets):
+def verify_script_model(pt_model, ishapes, targets, idtype=None):
script_module = torch.jit.script(pt_model)
- verify_model_vm(script_module, ishapes, targets=targets)
+
+ verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets)
def verify_trace_model(pt_model, idata, targets):
@@ -1900,10 +1901,60 @@ def verify_trace_model(pt_model, idata, targets):
verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)
-def verify_model_vm(input_model, ishapes, idtype=torch.float, idata=None, targets=["llvm"]):
+def convert_pt_to_tvm_type(idtype):
+ """ Accepts a pytorch dtype and returns string TVM dtype."""
+ # TVM does not support PyTorch complex dtypes
+ if idtype == torch.float64:
+ curr_dtype = "float64"
+ elif idtype == torch.float32:
+ curr_dtype = "float32"
+ elif idtype == torch.float16:
+ curr_dtype = "float16"
+ elif idtype == torch.bfloat16:
+ curr_dtype = "bfloat16"
+ elif idtype == torch.int64:
+ curr_dtype = "int64"
+ elif idtype == torch.int32:
+ curr_dtype = "int32"
+ elif idtype == torch.int16:
+ curr_dtype = "int16"
+ elif idtype == torch.int8:
+ curr_dtype = "int8"
+ elif idtype == torch.uint8:
+ curr_dtype = "uint8"
+ elif idtype == torch.bool:
+ curr_dtype = "bool"
+ else:
+ raise NotImplementedError("Unsupported dtype: {}".format(idtype))
+ return curr_dtype
+
+
+def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llvm"]):
+ if not idtype:
+ idtype = torch.float
+
input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
- input_shapes = list(zip(input_names, ishapes))
-    input_data = idata if idata else [torch.randn(shape, dtype=idtype) for shape in ishapes]
+ tvm_dtype = convert_pt_to_tvm_type(idtype)
+ input_dtypes = [tvm_dtype] * len(input_names)
+ input_shapes = list(zip(input_names, list(zip(ishapes, input_dtypes))))
+
+ if idata:
+ input_data = idata
+ # If no input_data provided, generate random data of specified dtype
+ else:
+ if idtype == torch.bool:
+ input_data = [
+                torch.Tensor.bool(torch.randint(low=0, high=2, size=shape)) for shape in ishapes
+            ]
+        # Torch dtype can be float, complex, int, or Bool. Complex not supported, so if not float or Bool,
+        # dtype must be int!
+ elif not idtype.is_floating_point:
+ input_data = [
+                torch.randint(low=0, high=10, size=shape, dtype=idtype) for shape in ishapes
+            ]
+ else:
+            input_data = [torch.randn(shape, dtype=idtype) for shape in ishapes]
+
# Compile via VM
mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
@@ -2951,6 +3002,29 @@ def test_forward_true_divide():
@tvm.testing.uses_gpu
+def test_forward_is_floating_point():
+ torch.set_grad_enabled(False)
+
+ class IsFloatingPoint(Module):
+ def forward(self, arg):
+ # `torch.jit.trace` cannot accept something that outputs
+ # a Bool, so `torch.jit.script` will be used instead
+ return torch.is_floating_point(arg)
+
+ targets = _get_default_vm_targets()
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float16)
+    # todo(dvisnty): Run the test for bfloat16 when full bfloat16 support is implemented
+    # verify_script_model(IsFloatingPoint(), [(1,1)], targets, idtype=torch.bfloat16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int8)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.uint8)
+
+
[email protected]_gpu
def test_forward_traced_function():
def fn(t1, t2):
return t1 + t2
@@ -3425,6 +3499,7 @@ if __name__ == "__main__":
test_forward_addcdiv()
test_forward_addcmul()
test_forward_true_divide()
+ test_forward_is_floating_point()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()