This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 6027412 [IR] Update the type_keys to reflect the code-org (#5074)
6027412 is described below
commit 6027412bcb443572088d71f1060370317eb6e671
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Mar 15 15:39:47 2020 -0700
[IR] Update the type_keys to reflect the code-org (#5074)
---
include/tvm/ir/expr.h | 2 +-
include/tvm/ir/module.h | 2 +-
include/tvm/ir/span.h | 4 +-
include/tvm/ir/transform.h | 6 +-
include/tvm/ir/type.h | 18 +++---
include/tvm/ir/type_relation.h | 6 +-
python/tvm/ir/__init__.py | 2 +-
python/tvm/ir/base.py | 4 +-
python/tvm/ir/expr.py | 2 +-
python/tvm/ir/json_compact.py | 24 +++++++
python/tvm/ir/module.py | 2 +-
python/tvm/ir/transform.py | 10 +--
python/tvm/ir/type.py | 25 ++++++--
python/tvm/ir/type_relation.py | 3 +-
src/ir/transform.cc | 4 +-
tests/python/relay/test_ir_nodes.py | 78 -----------------------
tests/python/relay/test_json_compact.py | 73 ++++++++++++++++++++-
tests/python/unittest/test_ir_type.py | 108 ++++++++++++++++++++++++++++++++
18 files changed, 255 insertions(+), 118 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index e37374a..c8b1a3f 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
- static constexpr const char* _type_key = "relay.GlobalVar";
+ static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 23d1f6e..1ee7c32 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
- static constexpr const char* _type_key = "relay.Module";
+ static constexpr const char* _type_key = "IRModule";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h
index 8cbfff7..4720dfe 100644
--- a/include/tvm/ir/span.h
+++ b/include/tvm/ir/span.h
@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
- static constexpr const char* _type_key = "relay.SourceName";
+ static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
@@ -89,7 +89,7 @@ class SpanNode : public Object {
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
- static constexpr const char* _type_key = "relay.Span";
+ static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 2afcb17..1b6ea25 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -110,7 +110,7 @@ class PassContextNode : public Object {
v->Visit("disabled_pass", &disabled_pass);
}
- static constexpr const char* _type_key = "relay.PassContext";
+ static constexpr const char* _type_key = "transform.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v->Visit("required", &required);
}
- static constexpr const char* _type_key = "relay.PassInfo";
+ static constexpr const char* _type_key = "transform.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};
@@ -265,7 +265,7 @@ class PassNode : public Object {
void VisitAttrs(AttrVisitor* v) {}
- static constexpr const char* _type_key = "relay.Pass";
+ static constexpr const char* _type_key = "transform.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
};
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 7fd224b..a9475a1 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -78,7 +78,7 @@ class TypeNode : public Object {
*/
mutable Span span;
- static constexpr const char* _type_key = "relay.Type";
+ static constexpr const char* _type_key = "Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}
- static constexpr const char* _type_key = "relay.PrimType";
+ static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}
- static constexpr const char* _type_key = "relay.TypeVar";
+ static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}
- static constexpr const char* _type_key = "relay.GlobalTypeVar";
+ static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}
- static constexpr const char* _type_key = "relay.TupleType";
+ static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
@@ -289,7 +289,7 @@ inline Type VoidType() {
*/
class TypeConstraintNode : public TypeNode {
public:
- static constexpr const char* _type_key = "relay.TypeConstraint";
+ static constexpr const char* _type_key = "TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}
- static constexpr const char* _type_key = "relay.FuncType";
+ static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}
- static constexpr const char* _type_key = "relay.IncompleteType";
+ static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}
+ // Keep the relay prefix in the type as this type is specific
+ // to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};
diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h
index ff36b96..f7bfb68 100644
--- a/include/tvm/ir/type_relation.h
+++ b/include/tvm/ir/type_relation.h
@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}
- static constexpr const char* _type_key = "relay.TypeCall";
+ static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable.
void VisitAttrs(AttrVisitor* v) {}
- static constexpr const char* _type_key = "relay.TypeReporter";
+ static constexpr const char* _type_key = "TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};
@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}
- static constexpr const char* _type_key = "relay.TypeRelation";
+ static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4160326..8418d63 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
+from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 944daa1..810d78f 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -56,7 +56,7 @@ class Node(Object):
return _ffi_api.PrettyPrint(self)
-@tvm._ffi.register_object("relay.SourceName")
+@tvm._ffi.register_object("SourceName")
class SourceName(Object):
"""A identifier for a source location.
@@ -69,7 +69,7 @@ class SourceName(Object):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
-@tvm._ffi.register_object("relay.Span")
+@tvm._ffi.register_object("Span")
class Span(Object):
"""Specifies a location in a source program.
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 4e6bf16..eedfff8 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
return ret
-@tvm._ffi.register_object("relay.GlobalVar")
+@tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index d1cac95..10ecbaa 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -62,11 +62,35 @@ def create_updater_06_to_07():
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
+ assert item["type_key"].startswith("relay.")
+ item["type_key"] = item["type_key"][len("relay."):]
return item
+ def _rename(new_name):
+ def _convert(item, _):
+ item["type_key"] = new_name
+ return item
+ return _convert
+
node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
+ "relay.Type": _rename("Type"),
+ "relay.TupleType": _rename("TupleType"),
+ "relay.TypeConstraint": _rename("TypeConstraint"),
+ "relay.FuncType": _rename("FuncType"),
+ "relay.IncompleteType": _rename("IncompleteType"),
+ "relay.TypeRelation": _rename("TypeRelation"),
+ "relay.TypeCall": _rename("TypeCall"),
+ "relay.Module": _rename("IRModule"),
+ "relay.SourceName": _rename("SourceName"),
+ "relay.Span": _rename("Span"),
+ "relay.GlobalVar": _rename("GlobalVar"),
+ "relay.Pass": _rename("transform.Pass"),
+ "relay.PassInfo": _rename("transform.PassInfo"),
+ "relay.PassContext": _rename("transform.PassContext"),
+ "relay.ModulePass": _rename("transform.ModulePass"),
+ "relay.Sequential": _rename("transform.Sequential"),  # must match SequentialNode::_type_key exactly
}
return create_updater(node_map, "0.6", "0.7")
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 2d7481f..24f5211 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -24,7 +24,7 @@ from . import type as _ty
from . import _ffi_api
-@tvm._ffi.register_object("relay.Module")
+@tvm._ffi.register_object("IRModule")
class IRModule(Node):
"""IRModule that holds functions and type definitions.
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index a35feb3..cdb9257 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
from . import _ffi_transform_api
-@tvm._ffi.register_object("relay.PassInfo")
+@tvm._ffi.register_object("transform.PassInfo")
class PassInfo(Object):
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
@@ -51,7 +51,7 @@ class PassInfo(Object):
_ffi_transform_api.PassInfo, opt_level, name, required)
-@tvm._ffi.register_object("relay.PassContext")
+@tvm._ffi.register_object("transform.PassContext")
class PassContext(Object):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
@@ -112,7 +112,7 @@ class PassContext(Object):
return _ffi_transform_api.GetCurrentPassContext()
-@tvm._ffi.register_object("relay.Pass")
+@tvm._ffi.register_object("transform.Pass")
class Pass(Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
@@ -141,7 +141,7 @@ class Pass(Object):
return _ffi_transform_api.RunPass(self, mod)
-@tvm._ffi.register_object("relay.ModulePass")
+@tvm._ffi.register_object("transform.ModulePass")
class ModulePass(Pass):
"""A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through
@@ -152,7 +152,7 @@ class ModulePass(Pass):
"""
-@tvm._ffi.register_object("relay.Sequential")
+@tvm._ffi.register_object("transform.Sequential")
class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index ebe2aae..ebbb629 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData = 6
-@tvm._ffi.register_object("relay.TypeVar")
+@tvm._ffi.register_object("PrimType")
+class PrimType(Type):
+    """Primitive data type in the low level IR.
+
+    Parameters
+    ----------
+    dtype : str
+        The runtime data type this primitive type corresponds to, e.g. "int32".
+    """
+    def __init__(self, dtype):
+        self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype)
+
+
+@tvm._ffi.register_object("TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
@@ -85,7 +98,7 @@ class TypeVar(Type):
return TypeCall(self, args)
-@tvm._ffi.register_object("relay.GlobalTypeVar")
+@tvm._ffi.register_object("GlobalTypeVar")
class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type
aliases.
@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
return TypeCall(self, args)
-@tvm._ffi.register_object("relay.TupleType")
+@tvm._ffi.register_object("TupleType")
class TupleType(Type):
"""The type of tuple values.
@@ -135,12 +148,12 @@ class TupleType(Type):
_ffi_api.TupleType, fields)
-@tvm._ffi.register_object("relay.TypeConstraint")
+@tvm._ffi.register_object("TypeConstraint")
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
-@tvm._ffi.register_object("relay.FuncType")
+@tvm._ffi.register_object("FuncType")
class FuncType(Type):
"""Function type.
@@ -179,7 +192,7 @@ class FuncType(Type):
_ffi_api.FuncType, arg_types, ret_type, type_params,
type_constraints)
-@tvm._ffi.register_object("relay.IncompleteType")
+@tvm._ffi.register_object("IncompleteType")
class IncompleteType(Type):
"""Incomplete type during type inference.
diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py
index 63c83d9..bacb2c2 100644
--- a/python/tvm/ir/type_relation.py
+++ b/python/tvm/ir/type_relation.py
@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
from . import _ffi_api
+@tvm._ffi.register_object("TypeCall")
class TypeCall(Type):
"""Type function application.
@@ -41,7 +42,7 @@ class TypeCall(Type):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
-@tvm._ffi.register_object("relay.TypeRelation")
+@tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types.
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 2b5010b..6878abc 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/
PassInfo Info() const override { return pass_info; }
- static constexpr const char* _type_key = "relay.ModulePass";
+ static constexpr const char* _type_key = "transform.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};
@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const
final;
- static constexpr const char* _type_key = "relay.Sequential";
+ static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index d3d0808..968a3bb 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
assert graph_equal(back, node)
-def test_bad_constructor():
- try:
- x = relay.ty.TensorType("xx", "xx")
- except tvm.error.TVMError:
- pass
-
-
# Span
def test_span():
span = relay.Span(None, 1, 1)
@@ -55,71 +48,6 @@ def test_span():
assert back.lineno == span.lineno
assert back.col_offset == span.col_offset
-# Types
-
-def test_tensor_type():
- shape = tvm.runtime.convert([1, 2, 3])
- dtype = 'float32'
- tt = relay.TensorType(shape, dtype)
- assert tt.dtype == dtype
- assert tt.shape == shape
- assert tt.span == None
- str(tt)
- check_json_roundtrip(tt)
-
-
-def test_type_param():
- tp = relay.TypeVar('name', relay.TypeKind.Type)
- assert tp.kind == relay.TypeKind.Type
- # assert tp.span # TODO allow us to set span
- str(tp)
- check_json_roundtrip(tp)
-
-
-def test_func_type():
- type_params = tvm.runtime.convert([])
- type_constraints = tvm.runtime.convert([]) # TODO: fill me in
- arg_types = tvm.runtime.convert([])
- ret_type = relay.TensorType((1, 2, 3), 'float32')
- tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
- assert tf.type_params == type_params
- assert tf.type_constraints == type_constraints
- assert tf.arg_types == arg_types
- assert tf.ret_type == ret_type
- assert tf.span == None
- # TODO make sure we can set span
- str(tf)
- check_json_roundtrip(tf)
-
-
-def test_tuple_type():
- tp = relay.TypeVar('tp', relay.TypeKind.Type)
- tf = relay.FuncType(tvm.runtime.convert([]), None,
tvm.runtime.convert([]), tvm.runtime.convert([]))
- tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
- fields = tvm.runtime.convert([tp, tf, tt])
-
- tup_ty = relay.TupleType(fields)
- assert tup_ty.fields == fields
- str(tup_ty)
- check_json_roundtrip(tup_ty)
-
-
-def test_type_relation():
- tp = relay.TypeVar('tp', relay.TypeKind.Type)
- tf = relay.FuncType(tvm.runtime.convert([]), None,
tvm.runtime.convert([]), tvm.runtime.convert([]))
- tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
- args = tvm.runtime.convert([tp, tf, tt])
-
- num_inputs = 2
- func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
- attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
-
- tr = relay.TypeRelation(func, args, num_inputs, attrs)
- assert tr.args == args
- assert tr.num_inputs == num_inputs
- str(tr)
- check_json_roundtrip(tr)
-
def test_constant():
arr = tvm.nd.array(10)
@@ -280,13 +208,7 @@ def test_conv2d_attrs():
if __name__ == "__main__":
- test_bad_constructor()
test_span()
- test_tensor_type()
- test_type_param()
- test_func_type()
- test_tuple_type()
- test_type_relation()
test_constant()
test_tuple()
test_local_var()
diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py
index 6316791..d58ddd5 100644
--- a/tests/python/relay/test_json_compact.py
+++ b/tests/python/relay/test_json_compact.py
@@ -17,7 +17,6 @@
import tvm
from tvm import te
-from tvm import relay
import json
def test_type_var():
@@ -36,13 +35,81 @@ def test_type_var():
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
- assert isinstance(tvar, relay.TypeVar)
+ assert isinstance(tvar, tvm.ir.TypeVar)
assert tvar.name_hint == "in0"
nodes[1]["type_key"] = "relay.GlobalTypeVar"
tvar = tvm.ir.load_json(json.dumps(data))
- assert isinstance(tvar, relay.GlobalTypeVar)
+ assert isinstance(tvar, tvm.ir.GlobalTypeVar)
assert tvar.name_hint == "in0"
+def test_incomplete_type():
+    nodes = [
+        {"type_key": ""},  # node 0: null reference target
+        {"type_key": "relay.IncompleteType",
+         "attrs": {"kind": "0", "span": "0"}}]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},  # triggers the 0.6 -> 0.7 upgrader
+        "b64ndarrays": [],
+    }
+    ty = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(ty, tvm.ir.IncompleteType)
+
+
+def test_func_tuple_type():
+    nodes = [
+        {"type_key": ""},  # node 0: null reference target
+        {"type_key": "relay.FuncType",
+         "attrs": {
+             "arg_types": "2",
+             "ret_type": "3",
+             "span": "0",
+             "type_constraints": "6",
+             "type_params": "5"
+         }
+        },
+        {"type_key": "Array"},
+        {"type_key": "relay.TupleType",  # node 3: the FuncType's ret_type
+         "attrs": { "fields": "4", "span": "0" }},
+        {"type_key": "Array"},
+        {"type_key": "Array"},
+        {"type_key": "Array"}]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    ty = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(ty, tvm.ir.FuncType)
+    assert isinstance(ty.ret_type, tvm.ir.TupleType)  # nested node must be upgraded too
+
+
+def test_global_var():
+    nodes = [
+        {"type_key": ""},  # node 0: null reference target
+        {"type_key": "relay.GlobalVar",
+         "attrs": {
+             "_checked_type_": "0",
+             "name_hint": "x",
+             "span": "0"
+         }
+        }]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    gvar = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(gvar, tvm.ir.GlobalVar)
+    assert gvar.name_hint == "x"
+
+
if __name__ == "__main__":
test_type_var()
+ test_incomplete_type()
+ test_func_tuple_type()
+ test_global_var()
diff --git a/tests/python/unittest/test_ir_type.py b/tests/python/unittest/test_ir_type.py
new file mode 100644
index 0000000..f919f92
--- /dev/null
+++ b/tests/python/unittest/test_ir_type.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test type nodes in the IR"""
+import tvm
+
+def check_json_roundtrip(node):
+    from tvm.relay.analysis import graph_equal
+    back = tvm.ir.load_json(tvm.ir.save_json(node))  # save -> load must reproduce node
+    assert graph_equal(back, node)
+
+
+def test_prim_type():
+    x = tvm.ir.PrimType("int32")
+    assert isinstance(x, tvm.ir.PrimType)
+    assert x.dtype == "int32"
+
+
+def test_tensor_type_bad_constructor():
+    try:
+        tvm.ir.TensorType("xx", "xx")  # neither a valid shape nor a valid dtype
+    except tvm.error.TVMError:
+        return
+    raise AssertionError("TensorType should reject an invalid shape and dtype")
+
+def test_tensor_type():
+    shape = tvm.runtime.convert([1, 2, 3])
+    dtype = 'float32'
+    tt = tvm.ir.TensorType(shape, dtype)
+    assert tt.dtype == dtype
+    assert tt.shape == shape
+    assert tt.span is None
+    str(tt)
+    check_json_roundtrip(tt)
+
+
+def test_type_param():
+    tp = tvm.ir.TypeVar('name', tvm.ir.TypeKind.Type)
+    assert tp.kind == tvm.ir.TypeKind.Type
+    # assert tp.span # TODO allow us to set span
+    str(tp)
+    check_json_roundtrip(tp)
+
+
+def test_func_type():
+    type_params = tvm.runtime.convert([])
+    type_constraints = tvm.runtime.convert([]) # TODO: fill me in
+    arg_types = tvm.runtime.convert([])
+    ret_type = tvm.ir.TensorType((1, 2, 3), 'float32')
+    tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)
+    assert tf.type_params == type_params
+    assert tf.type_constraints == type_constraints
+    assert tf.arg_types == arg_types
+    assert tf.ret_type == ret_type
+    assert tf.span is None
+    # TODO make sure we can set span
+    str(tf)
+    check_json_roundtrip(tf)
+
+
+def test_tuple_type():
+    tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
+    tf = tvm.ir.FuncType([], None, [], [])
+    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
+    fields = tvm.runtime.convert([tp, tf, tt])
+
+    tup_ty = tvm.ir.TupleType(fields)
+    assert tup_ty.fields == fields
+    str(tup_ty)
+    check_json_roundtrip(tup_ty)
+
+def test_type_relation():
+    tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
+    tf = tvm.ir.FuncType([], None, [], [])
+    tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
+    args = tvm.runtime.convert([tp, tf, tt])
+
+    num_inputs = 2
+    func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
+    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
+
+    tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs)
+    assert tr.args == args
+    assert tr.num_inputs == num_inputs
+    str(tr)
+    check_json_roundtrip(tr)
+
+if __name__ == "__main__":
+    test_prim_type()
+    test_tensor_type_bad_constructor()
+    test_tensor_type()
+    test_type_param()
+    test_func_type()
+    test_tuple_type()
+    test_type_relation()