This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 968b6f6  Add `is_floating_point()` test and better type support in `verify_model_vm()` (#7134)
968b6f6 is described below

commit 968b6f60da37d85232af6f9a6070d8ff2ed4be8a
Author: Tyler Davis <[email protected]>
AuthorDate: Tue Dec 22 01:20:36 2020 -0800

    Add `is_floating_point()` test and better type support in `verify_model_vm()` (#7134)
    
    * Add div_ and is_floating_point operators
    
    * Add handling of exprs to op, update tests
    
    * add test + supporting functions
    
    * Revert whitespace changes
    
    * Properly assign dtype to random integers
    
    * Reformat with black
    
    * Switched default dtype logic, removed extra line
---
 tests/python/frontend/pytorch/test_forward.py | 85 +++++++++++++++++++++++++--
 1 file changed, 80 insertions(+), 5 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2dda675..74d9c78 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1889,9 +1889,10 @@ def _get_default_vm_targets():
     return [tgt for (tgt, _) in tvm.testing.enabled_targets()]
 
 
-def verify_script_model(pt_model, ishapes, targets):
+def verify_script_model(pt_model, ishapes, targets, idtype=None):
     script_module = torch.jit.script(pt_model)
-    verify_model_vm(script_module, ishapes, targets=targets)
+
+    verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets)
 
 
 def verify_trace_model(pt_model, idata, targets):
@@ -1900,10 +1901,60 @@ def verify_trace_model(pt_model, idata, targets):
     verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)
 
 
-def verify_model_vm(input_model, ishapes, idtype=torch.float, idata=None, 
targets=["llvm"]):
+def convert_pt_to_tvm_type(idtype):
+    """ Accepts a pytorch dtype and returns string TVM dtype."""
+    # TVM does not support PyTorch complex dtypes
+    if idtype == torch.float64:
+        curr_dtype = "float64"
+    elif idtype == torch.float32:
+        curr_dtype = "float32"
+    elif idtype == torch.float16:
+        curr_dtype = "float16"
+    elif idtype == torch.bfloat16:
+        curr_dtype = "bfloat16"
+    elif idtype == torch.int64:
+        curr_dtype = "int64"
+    elif idtype == torch.int32:
+        curr_dtype = "int32"
+    elif idtype == torch.int16:
+        curr_dtype = "int16"
+    elif idtype == torch.int8:
+        curr_dtype = "int8"
+    elif idtype == torch.uint8:
+        curr_dtype = "uint8"
+    elif idtype == torch.bool:
+        curr_dtype = "bool"
+    else:
+        raise NotImplementedError("Unsupported dtype: {}".format(idtype))
+    return curr_dtype
+
+
+def verify_model_vm(input_model, ishapes, idtype=None, idata=None, 
targets=["llvm"]):
+    if not idtype:
+        idtype = torch.float
+
     input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
-    input_shapes = list(zip(input_names, ishapes))
-    input_data = idata if idata else [torch.randn(shape, dtype=idtype) for 
shape in ishapes]
+    tvm_dtype = convert_pt_to_tvm_type(idtype)
+    input_dtypes = [tvm_dtype] * len(input_names)
+    input_shapes = list(zip(input_names, list(zip(ishapes, input_dtypes))))
+
+    if idata:
+        input_data = idata
+    # If no input_data provided, generate random data of specified dtype
+    else:
+        if idtype == torch.bool:
+            input_data = [
+                torch.Tensor.bool(torch.randint(low=0, high=2, size=shape)) 
for shape in ishapes
+            ]
+        # Torch dtype can be float, complex, int, or Bool. Complex not 
supported, so if not float or Bool,
+        # dtype must be int!
+        elif not idtype.is_floating_point:
+            input_data = [
+                torch.randint(low=0, high=10, size=shape, dtype=idtype) for 
shape in ishapes
+            ]
+        else:
+            input_data = [torch.randn(shape, dtype=idtype) for shape in 
ishapes]
+
     # Compile via VM
     mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
 
@@ -2951,6 +3002,29 @@ def test_forward_true_divide():
 
 
 @tvm.testing.uses_gpu
+def test_forward_is_floating_point():
+    torch.set_grad_enabled(False)
+
+    class IsFloatingPoint(Module):
+        def forward(self, arg):
+            # `torch.jit.trace` cannot accept something that outputs
+            # a Bool, so `torch.jit.script` will be used instead
+            return torch.is_floating_point(arg)
+
+    targets = _get_default_vm_targets()
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.float64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.float32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.float16)
+    # todo(dvisnty): Run the test for bfloat16 when full bfloat16 support is 
implemented
+    # verify_script_model(IsFloatingPoint(), [(1,1)], targets, 
idtype=torch.bfloat16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.int64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.int32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.int16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.int8)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, 
idtype=torch.uint8)
+
+
+@tvm.testing.uses_gpu
 def test_forward_traced_function():
     def fn(t1, t2):
         return t1 + t2
@@ -3425,6 +3499,7 @@ if __name__ == "__main__":
     test_forward_addcdiv()
     test_forward_addcmul()
     test_forward_true_divide()
+    test_forward_is_floating_point()
     test_forward_clone()
     test_forward_softplus()
     test_forward_softsign()

Reply via email to