This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 6027412  [IR] Update the type_keys to reflect the code-org (#5074)
6027412 is described below

commit 6027412bcb443572088d71f1060370317eb6e671
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Mar 15 15:39:47 2020 -0700

    [IR] Update the type_keys to reflect the code-org (#5074)
---
 include/tvm/ir/expr.h                   |   2 +-
 include/tvm/ir/module.h                 |   2 +-
 include/tvm/ir/span.h                   |   4 +-
 include/tvm/ir/transform.h              |   6 +-
 include/tvm/ir/type.h                   |  18 +++---
 include/tvm/ir/type_relation.h          |   6 +-
 python/tvm/ir/__init__.py               |   2 +-
 python/tvm/ir/base.py                   |   4 +-
 python/tvm/ir/expr.py                   |   2 +-
 python/tvm/ir/json_compact.py           |  24 +++++++
 python/tvm/ir/module.py                 |   2 +-
 python/tvm/ir/transform.py              |  10 +--
 python/tvm/ir/type.py                   |  25 ++++++--
 python/tvm/ir/type_relation.py          |   3 +-
 src/ir/transform.cc                     |   4 +-
 tests/python/relay/test_ir_nodes.py     |  78 -----------------------
 tests/python/relay/test_json_compact.py |  73 ++++++++++++++++++++-
 tests/python/unittest/test_ir_type.py   | 108 ++++++++++++++++++++++++++++++++
 18 files changed, 255 insertions(+), 118 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index e37374a..c8b1a3f 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
-  static constexpr const char* _type_key = "relay.GlobalVar";
+  static constexpr const char* _type_key = "GlobalVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
 };
 
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 23d1f6e..1ee7c32 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
    */
   TVM_DLL std::unordered_set<std::string> Imports() const;
 
-  static constexpr const char* _type_key = "relay.Module";
+  static constexpr const char* _type_key = "IRModule";
   TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
 
  private:
diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h
index 8cbfff7..4720dfe 100644
--- a/include/tvm/ir/span.h
+++ b/include/tvm/ir/span.h
@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
   // override attr visitor
   void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
 
-  static constexpr const char* _type_key = "relay.SourceName";
+  static constexpr const char* _type_key = "SourceName";
   TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
 };
 
@@ -89,7 +89,7 @@ class SpanNode : public Object {
 
   TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
 
-  static constexpr const char* _type_key = "relay.Span";
+  static constexpr const char* _type_key = "Span";
   TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
 };
 
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 2afcb17..1b6ea25 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -110,7 +110,7 @@ class PassContextNode : public Object {
     v->Visit("disabled_pass", &disabled_pass);
   }
 
-  static constexpr const char* _type_key = "relay.PassContext";
+  static constexpr const char* _type_key = "transform.PassContext";
   TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
 };
 
@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
     v->Visit("required", &required);
   }
 
-  static constexpr const char* _type_key = "relay.PassInfo";
+  static constexpr const char* _type_key = "transform.PassInfo";
   TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
 };
 
@@ -265,7 +265,7 @@ class PassNode : public Object {
 
   void VisitAttrs(AttrVisitor* v) {}
 
-  static constexpr const char* _type_key = "relay.Pass";
+  static constexpr const char* _type_key = "transform.Pass";
   TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
 };
 
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 7fd224b..a9475a1 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -78,7 +78,7 @@ class TypeNode : public Object {
    */
   mutable Span span;
 
-  static constexpr const char* _type_key = "relay.Type";
+  static constexpr const char* _type_key = "Type";
   TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
 };
 
@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
     v->Visit("dtype", &dtype);
   }
 
-  static constexpr const char* _type_key = "relay.PrimType";
+  static constexpr const char* _type_key = "PrimType";
   TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
 };
 
@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  static constexpr const char* _type_key = "relay.TypeVar";
+  static constexpr const char* _type_key = "TypeVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
 };
 
@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
     v->Visit("kind", &kind);
   }
 
-  static constexpr const char* _type_key = "relay.GlobalTypeVar";
+  static constexpr const char* _type_key = "GlobalTypeVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
 };
 
@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  static constexpr const char* _type_key = "relay.TupleType";
+  static constexpr const char* _type_key = "TupleType";
   TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
 };
 
@@ -289,7 +289,7 @@ inline Type VoidType() {
  */
 class TypeConstraintNode : public TypeNode {
  public:
-  static constexpr const char* _type_key = "relay.TypeConstraint";
+  static constexpr const char* _type_key = "TypeConstraint";
   TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
 };
 
@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  static constexpr const char* _type_key = "relay.FuncType";
+  static constexpr const char* _type_key = "FuncType";
   TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
 };
 
@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  static constexpr const char* _type_key = "relay.IncompleteType";
+  static constexpr const char* _type_key = "IncompleteType";
   TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
 };
 
@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  // Keep the "relay." prefix in this type key: the type is specific
+  // to Relay itself.
   static constexpr const char* _type_key = "relay.RefType";
   TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
 };
diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h
index ff36b96..f7bfb68 100644
--- a/include/tvm/ir/type_relation.h
+++ b/include/tvm/ir/type_relation.h
@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
     v->Visit("span", &span);
   }
 
-  static constexpr const char* _type_key = "relay.TypeCall";
+  static constexpr const char* _type_key = "TypeCall";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
 };
 
@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
   // solver is not serializable.
   void VisitAttrs(AttrVisitor* v) {}
 
-  static constexpr const char* _type_key = "relay.TypeReporter";
+  static constexpr const char* _type_key = "TypeReporter";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
 };
 
@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
     v->Visit("span", &span);
   }
 
-  static constexpr const char* _type_key = "relay.TypeRelation";
+  static constexpr const char* _type_key = "TypeRelation";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
 };
 
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4160326..8418d63 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -17,7 +17,7 @@
 # pylint: disable=unused-import
 """Common data structures across all IR variants."""
 from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
+from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
 from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
 from .tensor_type import TensorType
 from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 944daa1..810d78f 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -56,7 +56,7 @@ class Node(Object):
         return _ffi_api.PrettyPrint(self)
 
 
-@tvm._ffi.register_object("relay.SourceName")
+@tvm._ffi.register_object("SourceName")
 class SourceName(Object):
     """A identifier for a source location.
 
@@ -69,7 +69,7 @@ class SourceName(Object):
         self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
 
 
-@tvm._ffi.register_object("relay.Span")
+@tvm._ffi.register_object("Span")
 class Span(Object):
     """Specifies a location in a source program.
 
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 4e6bf16..eedfff8 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
         return ret
 
 
-@tvm._ffi.register_object("relay.GlobalVar")
+@tvm._ffi.register_object("GlobalVar")
 class GlobalVar(RelayExpr):
     """A global variable in the IR.
 
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index d1cac95..10ecbaa 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -62,11 +62,37 @@ def create_updater_06_to_07():
         # set vindex to null
         nodes[vindex]["type_key"] = ""
         del item["attrs"]["var"]
+        assert item["type_key"].startswith("relay.")
+        item["type_key"] = item["type_key"][len("relay."):]
         return item
 
+    def _rename(new_name):
+        """Return an updater that sets the node's type_key to ``new_name``."""
+        def _convert(item, _):
+            item["type_key"] = new_name
+            return item
+        return _convert
+
     node_map = {
         "relay.TypeVar": _ftype_var,
         "relay.GlobalTypeVar": _ftype_var,
+        "relay.Type": _rename("Type"),
+        "relay.TupleType": _rename("TupleType"),
+        "relay.TypeConstraint": _rename("TypeConstraint"),
+        "relay.FuncType": _rename("FuncType"),
+        "relay.IncompleteType": _rename("IncompleteType"),
+        "relay.TypeRelation": _rename("TypeRelation"),
+        "relay.TypeCall": _rename("TypeCall"),
+        "relay.Module": _rename("IRModule"),
+        "relay.SourceName": _rename("SourceName"),
+        "relay.Span": _rename("Span"),
+        "relay.GlobalVar": _rename("GlobalVar"),
+        "relay.Pass": _rename("transform.Pass"),
+        "relay.PassInfo": _rename("transform.PassInfo"),
+        "relay.PassContext": _rename("transform.PassContext"),
+        "relay.ModulePass": _rename("transform.ModulePass"),
+        "relay.Sequential": _rename("transform.Sequential"),
+        # relay.RefType keeps its "relay." prefix in 0.7, so no rename is needed.
     }
     return create_updater(node_map, "0.6", "0.7")
 
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 2d7481f..24f5211 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -24,7 +24,7 @@ from . import type as _ty
 from . import _ffi_api
 
 
-@tvm._ffi.register_object("relay.Module")
+@tvm._ffi.register_object("IRModule")
 class IRModule(Node):
     """IRModule that holds functions and type definitions.
 
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index a35feb3..cdb9257 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
 
 from . import _ffi_transform_api
 
-@tvm._ffi.register_object("relay.PassInfo")
+@tvm._ffi.register_object("transform.PassInfo")
 class PassInfo(Object):
     """The class contains the meta data required by a pass. It is the
     container of information needed by running an optimization or analysis.
@@ -51,7 +51,7 @@ class PassInfo(Object):
             _ffi_transform_api.PassInfo, opt_level, name, required)
 
 
-@tvm._ffi.register_object("relay.PassContext")
+@tvm._ffi.register_object("transform.PassContext")
 class PassContext(Object):
     """The basis where a Relay optimization/analysis runs on.
     Each pass context contains a number of auxiliary information that is used
@@ -112,7 +112,7 @@ class PassContext(Object):
         return _ffi_transform_api.GetCurrentPassContext()
 
 
-@tvm._ffi.register_object("relay.Pass")
+@tvm._ffi.register_object("transform.Pass")
 class Pass(Object):
     """The base class of all passes. All methods here are just simple wrappers
     that are implemented in the backend. They are defined for users to
@@ -141,7 +141,7 @@ class Pass(Object):
         return _ffi_transform_api.RunPass(self, mod)
 
 
-@tvm._ffi.register_object("relay.ModulePass")
+@tvm._ffi.register_object("transform.ModulePass")
 class ModulePass(Pass):
     """A pass that works on tvm.IRModule. Users don't need to interact with
     this class directly. Instead, a module pass should be created through
@@ -152,7 +152,7 @@ class ModulePass(Pass):
     """
 
 
-@tvm._ffi.register_object("relay.Sequential")
+@tvm._ffi.register_object("transform.Sequential")
 class Sequential(Pass):
     """A pass that works on a sequence of pass objects. Multiple passes can be
     executed sequentially using this class.
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index ebe2aae..ebbb629 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
     TypeData = 6
 
 
-@tvm._ffi.register_object("relay.TypeVar")
+@tvm._ffi.register_object("PrimType")
+class PrimType(Type):
+    """Primitive data type in the low-level IR.
+
+    Parameters
+    ----------
+    dtype : str
+        The runtime data type that this primitive type corresponds to.
+    """
+    def __init__(self, dtype):
+        self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype)
+
+
+@tvm._ffi.register_object("TypeVar")
 class TypeVar(Type):
     """Type parameter in functions.
 
@@ -85,7 +98,7 @@ class TypeVar(Type):
         return TypeCall(self, args)
 
 
-@tvm._ffi.register_object("relay.GlobalTypeVar")
+@tvm._ffi.register_object("GlobalTypeVar")
 class GlobalTypeVar(Type):
     """A global type variable that is used for defining new types or type aliases.
 
@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
         return TypeCall(self, args)
 
 
-@tvm._ffi.register_object("relay.TupleType")
+@tvm._ffi.register_object("TupleType")
 class TupleType(Type):
     """The type of tuple values.
 
@@ -135,12 +148,12 @@ class TupleType(Type):
             _ffi_api.TupleType, fields)
 
 
-@tvm._ffi.register_object("relay.TypeConstraint")
+@tvm._ffi.register_object("TypeConstraint")
 class TypeConstraint(Type):
     """Abstract class representing a type constraint."""
 
 
-@tvm._ffi.register_object("relay.FuncType")
+@tvm._ffi.register_object("FuncType")
 class FuncType(Type):
     """Function type.
 
@@ -179,7 +192,7 @@ class FuncType(Type):
             _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)
 
 
-@tvm._ffi.register_object("relay.IncompleteType")
+@tvm._ffi.register_object("IncompleteType")
 class IncompleteType(Type):
     """Incomplete type during type inference.
 
diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py
index 63c83d9..bacb2c2 100644
--- a/python/tvm/ir/type_relation.py
+++ b/python/tvm/ir/type_relation.py
@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
 from . import _ffi_api
 
 
+@tvm._ffi.register_object("TypeCall")
 class TypeCall(Type):
     """Type function application.
 
@@ -41,7 +42,7 @@ class TypeCall(Type):
         self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
 
 
-@tvm._ffi.register_object("relay.TypeRelation")
+@tvm._ffi.register_object("TypeRelation")
 class TypeRelation(TypeConstraint):
     """User defined type relation, it is an input-output relation on types.
 
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 2b5010b..6878abc 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
    */
   PassInfo Info() const override { return pass_info; }
 
-  static constexpr const char* _type_key = "relay.ModulePass";
+  static constexpr const char* _type_key = "transform.ModulePass";
   TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
 };
 
@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
    */
   IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
 
-  static constexpr const char* _type_key = "relay.Sequential";
+  static constexpr const char* _type_key = "transform.Sequential";
   TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
 };
 
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index d3d0808..968a3bb 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
     assert graph_equal(back, node)
 
 
-def test_bad_constructor():
-    try:
-        x = relay.ty.TensorType("xx", "xx")
-    except tvm.error.TVMError:
-        pass
-
-
 # Span
 def test_span():
     span = relay.Span(None, 1, 1)
@@ -55,71 +48,6 @@ def test_span():
     assert back.lineno == span.lineno
     assert back.col_offset == span.col_offset
 
-# Types
-
-def test_tensor_type():
-    shape = tvm.runtime.convert([1, 2, 3])
-    dtype = 'float32'
-    tt = relay.TensorType(shape, dtype)
-    assert tt.dtype == dtype
-    assert tt.shape == shape
-    assert tt.span == None
-    str(tt)
-    check_json_roundtrip(tt)
-
-
-def test_type_param():
-    tp = relay.TypeVar('name', relay.TypeKind.Type)
-    assert tp.kind == relay.TypeKind.Type
-    # assert tp.span  # TODO allow us to set span
-    str(tp)
-    check_json_roundtrip(tp)
-
-
-def test_func_type():
-    type_params = tvm.runtime.convert([])
-    type_constraints = tvm.runtime.convert([])  # TODO: fill me in
-    arg_types = tvm.runtime.convert([])
-    ret_type = relay.TensorType((1, 2, 3), 'float32')
-    tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
-    assert tf.type_params == type_params
-    assert tf.type_constraints == type_constraints
-    assert tf.arg_types == arg_types
-    assert tf.ret_type == ret_type
-    assert tf.span == None
-    # TODO make sure we can set span
-    str(tf)
-    check_json_roundtrip(tf)
-
-
-def test_tuple_type():
-    tp = relay.TypeVar('tp', relay.TypeKind.Type)
-    tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([]))
-    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
-    fields = tvm.runtime.convert([tp, tf, tt])
-
-    tup_ty = relay.TupleType(fields)
-    assert tup_ty.fields == fields
-    str(tup_ty)
-    check_json_roundtrip(tup_ty)
-
-
-def test_type_relation():
-    tp = relay.TypeVar('tp', relay.TypeKind.Type)
-    tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([]))
-    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
-    args = tvm.runtime.convert([tp, tf, tt])
-
-    num_inputs = 2
-    func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
-    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
-
-    tr = relay.TypeRelation(func, args, num_inputs, attrs)
-    assert tr.args == args
-    assert tr.num_inputs == num_inputs
-    str(tr)
-    check_json_roundtrip(tr)
-
 
 def test_constant():
     arr = tvm.nd.array(10)
@@ -280,13 +208,7 @@ def test_conv2d_attrs():
 
 
 if __name__ == "__main__":
-    test_bad_constructor()
     test_span()
-    test_tensor_type()
-    test_type_param()
-    test_func_type()
-    test_tuple_type()
-    test_type_relation()
     test_constant()
     test_tuple()
     test_local_var()
diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py
index 6316791..d58ddd5 100644
--- a/tests/python/relay/test_json_compact.py
+++ b/tests/python/relay/test_json_compact.py
@@ -17,7 +17,6 @@
 
 import tvm
 from tvm import te
-from tvm import relay
 import json
 
 def test_type_var():
@@ -36,13 +35,86 @@ def test_type_var():
         "b64ndarrays": [],
     }
     tvar = tvm.ir.load_json(json.dumps(data))
-    assert isinstance(tvar, relay.TypeVar)
+    assert isinstance(tvar, tvm.ir.TypeVar)
     assert tvar.name_hint == "in0"
     nodes[1]["type_key"] = "relay.GlobalTypeVar"
     tvar = tvm.ir.load_json(json.dumps(data))
-    assert isinstance(tvar, relay.GlobalTypeVar)
+    assert isinstance(tvar, tvm.ir.GlobalTypeVar)
     assert tvar.name_hint == "in0"
 
 
+def test_incomplete_type():
+    """A 0.6 relay.IncompleteType node loads as tvm.ir.IncompleteType."""
+    nodes = [
+        {"type_key": ""},
+        {"type_key": "relay.IncompleteType",
+         "attrs": {"kind": "0", "span": "0"}}]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    ty = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(ty, tvm.ir.IncompleteType)
+
+
+def test_func_tuple_type():
+    """A 0.6 relay.FuncType returning a relay.TupleType loads as IR types."""
+    nodes = [
+        {"type_key": ""},
+        {"type_key": "relay.FuncType",
+         "attrs": {
+             "arg_types": "2",
+             "ret_type": "3",
+             "span": "0",
+             "type_constraints": "6",
+             "type_params": "5"
+         }
+        },
+        {"type_key": "Array"},
+        {"type_key": "relay.TupleType",
+         "attrs": { "fields": "4", "span": "0" }},
+        {"type_key": "Array"},
+        {"type_key": "Array"},
+        {"type_key": "Array"}
+    ]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    ty = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(ty, tvm.ir.FuncType)
+    assert isinstance(ty.ret_type, tvm.ir.TupleType)
+
+
+def test_global_var():
+    """A 0.6 relay.GlobalVar node loads as tvm.ir.GlobalVar with its name."""
+    nodes = [
+        {"type_key": ""},
+        {"type_key": "relay.GlobalVar",
+         "attrs": {
+             "_checked_type_": "0",
+             "name_hint": "x",
+             "span": "0"
+         }
+        }
+    ]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    gvar = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(gvar, tvm.ir.GlobalVar)
+    assert gvar.name_hint == "x"
+
+
 if __name__ == "__main__":
     test_type_var()
+    test_incomplete_type()
+    test_func_tuple_type()
+    test_global_var()
diff --git a/tests/python/unittest/test_ir_type.py b/tests/python/unittest/test_ir_type.py
new file mode 100644
index 0000000..f919f92
--- /dev/null
+++ b/tests/python/unittest/test_ir_type.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test type nodes in the IR"""
+import tvm
+
+
+def check_json_roundtrip(node):
+    """Save ``node`` to JSON, load it back, and assert graph_equal holds."""
+    from tvm.relay.analysis import graph_equal
+    json_str = tvm.ir.save_json(node)
+    back = tvm.ir.load_json(json_str)
+    assert graph_equal(back, node)
+
+
+def test_prim_type():
+    x = tvm.ir.PrimType("int32")
+    assert isinstance(x, tvm.ir.PrimType)
+    assert x.dtype == "int32"
+
+
+def test_tensor_type_bad_constructor():
+    # Passing a string as the shape should be rejected with a TVMError.
+    try:
+        tvm.ir.TensorType("xx", "xx")
+    except tvm.error.TVMError:
+        pass
+    else:
+        assert False, "TensorType accepted an invalid shape"
+
+
+def test_tensor_type():
+    shape = tvm.runtime.convert([1, 2, 3])
+    dtype = 'float32'
+    tt = tvm.ir.TensorType(shape, dtype)
+    assert tt.dtype == dtype
+    assert tt.shape == shape
+    assert tt.span is None
+    str(tt)
+    check_json_roundtrip(tt)
+
+
+def test_type_param():
+    tp = tvm.ir.TypeVar('name', tvm.ir.TypeKind.Type)
+    assert tp.kind == tvm.ir.TypeKind.Type
+    # assert tp.span  # TODO allow us to set span
+    str(tp)
+    check_json_roundtrip(tp)
+
+
+def test_func_type():
+    type_params = tvm.runtime.convert([])
+    type_constraints = tvm.runtime.convert([])  # TODO: fill me in
+    arg_types = tvm.runtime.convert([])
+    ret_type = tvm.ir.TensorType((1, 2, 3), 'float32')
+    tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)
+    assert tf.type_params == type_params
+    assert tf.type_constraints == type_constraints
+    assert tf.arg_types == arg_types
+    assert tf.ret_type == ret_type
+    assert tf.span is None
+    # TODO make sure we can set span
+    str(tf)
+    check_json_roundtrip(tf)
+
+
+def test_tuple_type():
+    tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
+    tf = tvm.ir.FuncType([], None, [], [])
+    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
+    fields = tvm.runtime.convert([tp, tf, tt])
+
+    tup_ty = tvm.ir.TupleType(fields)
+    assert tup_ty.fields == fields
+    str(tup_ty)
+    check_json_roundtrip(tup_ty)
+
+
+def test_type_relation():
+    tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
+    tf = tvm.ir.FuncType([], None, [], [])
+    tt = tvm.ir.TensorType(
+        tvm.runtime.convert([1, 2, 3]), 'float32')
+    args = tvm.runtime.convert([tp, tf, tt])
+
+    num_inputs = 2
+    func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
+    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+
+    tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs)
+    assert tr.args == args
+    assert tr.num_inputs == num_inputs
+    str(tr)
+    check_json_roundtrip(tr)
+
+
+if __name__ == "__main__":
+    test_prim_type()
+    test_tensor_type_bad_constructor()
+    test_tensor_type()
+    test_type_param()
+    test_func_type()
+    test_tuple_type()
+    test_type_relation()

Reply via email to