This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 64ac555 [Object] Restore the StrMap behavior in JSON/SHash/SEqual (#5719)
64ac555 is described below
commit 64ac55506392f723ff22fb7147ccd4547fbd64fd
Author: Junru Shao <[email protected]>
AuthorDate: Wed Jun 3 13:34:44 2020 -0700
[Object] Restore the StrMap behavior in JSON/SHash/SEqual (#5719)
---
include/tvm/node/container.h | 5 +--
python/tvm/ir/json_compact.py | 1 +
src/node/container.cc | 59 +++++++++++++++++++--------------
src/node/serialization.cc | 29 +++++++++++-----
tests/python/relay/test_json_compact.py | 29 ++++++++++++++++
5 files changed, 89 insertions(+), 34 deletions(-)
diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
index 1a7a8df..a3cfdaf 100644
--- a/include/tvm/node/container.h
+++ b/include/tvm/node/container.h
@@ -50,6 +50,7 @@ using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;
+/*! \brief String-aware ObjectRef hash functor */
struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
if (const auto* str = a.as<StringObj>()) {
@@ -59,6 +60,7 @@ struct ObjectHash {
}
};
+/*! \brief String-aware ObjectRef equal functor */
struct ObjectEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const {
if (a.same_as(b)) {
@@ -96,8 +98,7 @@ class MapNode : public Object {
* \tparam V The value NodeRef type.
*/
template <typename K, typename V,
- typename = typename std::enable_if<std::is_base_of<ObjectRef,
K>::value ||
- std::is_base_of<std::string,
K>::value>::type,
+ typename = typename std::enable_if<std::is_base_of<ObjectRef,
K>::value>::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef,
V>::value>::type>
class Map : public ObjectRef {
public:
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index 6fc24c0..2facc79 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -129,6 +129,7 @@ def create_updater_06_to_07():
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequential": _rename("transform.Sequential"),
+ "StrMap": _rename("Map"),
# TIR
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"),
_update_from_std_str("name")],
diff --git a/src/node/container.cc b/src/node/container.cc
index f7b9dd3..bdebb7f 100644
--- a/src/node/container.cc
+++ b/src/node/container.cc
@@ -247,40 +247,51 @@ struct MapNodeTrait {
}
static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
- if (key->data.empty()) {
- hash_reduce(uint64_t(0));
- return;
- }
- if (key->data.begin()->first->IsInstance<StringObj>()) {
+ bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const
auto& v) {
+ return v.first->template IsInstance<StringObj>();
+ });
+ if (is_str_map) {
SHashReduceForSMap(key, hash_reduce);
} else {
SHashReduceForOMap(key, hash_reduce);
}
}
+ static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs,
SEqualReducer equal) {
+ for (const auto& kv : lhs->data) {
+ // Only allow equal checking if the keys are already mapped
+ // This resolves common use cases where we want to store
+ // Map<Var, Value> where Var is defined in the function
+ // parameters.
+ ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
+ if (!rhs_key.defined()) return false;
+ auto it = rhs->data.find(rhs_key);
+ if (it == rhs->data.end()) return false;
+ if (!equal(kv.second, it->second)) return false;
+ }
+ return true;
+ }
+
+ static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs,
SEqualReducer equal) {
+ for (const auto& kv : lhs->data) {
+ auto it = rhs->data.find(kv.first);
+ if (it == rhs->data.end()) return false;
+ if (!equal(kv.second, it->second)) return false;
+ }
+ return true;
+ }
+
static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs,
SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
if (rhs->data.size() == 0) return true;
- if (lhs->data.begin()->first->IsInstance<StringObj>()) {
- for (const auto& kv : lhs->data) {
- auto it = rhs->data.find(kv.first);
- if (it == rhs->data.end()) return false;
- if (!equal(kv.second, it->second)) return false;
- }
- } else {
- for (const auto& kv : lhs->data) {
- // Only allow equal checking if the keys are already mapped
- // This resolves common use cases where we want to store
- // Map<Var, Value> where Var is defined in the function
- // parameters.
- ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
- if (!rhs_key.defined()) return false;
- auto it = rhs->data.find(rhs_key);
- if (it == rhs->data.end()) return false;
- if (!equal(kv.second, it->second)) return false;
- }
+ bool ls = std::all_of(lhs->data.begin(), lhs->data.end(),
+ [](const auto& v) { return v.first->template
IsInstance<StringObj>(); });
+ bool rs = std::all_of(rhs->data.begin(), rhs->data.end(),
+ [](const auto& v) { return v.first->template
IsInstance<StringObj>(); });
+ if (ls != rs) {
+ return false;
}
- return true;
+ return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) :
SEqualReduceForOMap(lhs, rhs, equal);
}
};
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 9845a6f..3866533 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -110,11 +110,18 @@ class NodeIndexer : public AttrVisitor {
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
- for (const auto& kv : n->data) {
- if (!kv.first->IsInstance<StringObj>()) {
+ bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const
auto& v) {
+ return v.first->template IsInstance<StringObj>();
+ });
+ if (is_str_map) {
+ for (const auto& kv : n->data) {
+ MakeIndex(const_cast<Object*>(kv.second.get()));
+ }
+ } else {
+ for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.first.get()));
+ MakeIndex(const_cast<Object*>(kv.second.get()));
}
- MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
// if the node already have repr bytes, no need to visit Attrs.
@@ -246,13 +253,19 @@ class JSONAttrGetter : public AttrVisitor {
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
- for (const auto& kv : n->data) {
- if (const auto* str = kv.first.as<StringObj>()) {
- node_->keys.push_back(std::string(str->data, str->size));
- } else {
+ bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const
auto& v) {
+ return v.first->template IsInstance<StringObj>();
+ });
+ if (is_str_map) {
+ for (const auto& kv : n->data) {
+ node_->keys.push_back(Downcast<String>(kv.first));
+
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
+ }
+ } else {
+ for (const auto& kv : n->data) {
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get())));
+
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
-
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else {
// recursively index normal object.
diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py
index c961f99..00d41f0 100644
--- a/tests/python/relay/test_json_compact.py
+++ b/tests/python/relay/test_json_compact.py
@@ -186,6 +186,34 @@ def test_tir_var():
assert y.name == "y"
+def test_str_map():
+ nodes = [
+ {'type_key': ''},
+ {'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]},
+ {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
+ {'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}},
+ {'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}},
+ {'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}},
+ {'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7',
'type_annotation': '0'}},
+ {'type_key': 'runtime.String', 'repr_str': 'x'},
+ {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}},
+ {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
+ {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}}
+ ]
+ data = {
+ "root" : 1,
+ "nodes": nodes,
+ "attrs": {"tvm_version": "0.6.0"},
+ "b64ndarrays": [],
+ }
+ x = tvm.ir.load_json(json.dumps(data))
+ assert(isinstance(x, tvm.ir.container.Map))
+ assert(len(x) == 2)
+ assert('x' in x)
+ assert('z' in x)
+ assert(bool(x['z'] == 2))
+
+
if __name__ == "__main__":
test_op()
test_type_var()
@@ -194,3 +222,4 @@ if __name__ == "__main__":
test_func_tuple_type()
test_global_var()
test_tir_var()
+ test_str_map()