This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 24847c5515 [IR] Implemented Variant<...> container (#15672)
24847c5515 is described below
commit 24847c55151825ebf4c655cb2e3c5c09c61b48c8
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 13 16:56:21 2023 -0500
[IR] Implemented Variant<...> container (#15672)
* [IR] Implemented Variant<...> container
This commit introduces a new container, `Variant`, which is analogous to
the `std::variant` introduced in C++17, the `enum` in Rust, or a
tagged union in C. The `Variant` class is templated over the types that
it may contain (e.g. `Variant<String, Expr>`), where each type is a
distinct option that can be stored within the container.
`Variant` is implemented as a subclass of `ObjectRef` with no additional
data members, similar to the implementation of `Optional<T>`. It can be
constructed from any of its contained types, and the contents can be
inspected using the usual `my_object.as<T>()` and
`Downcast<T>(my_object)` methods. This is intended to allow for drop-in
replacement of `ObjectRef` with `Variant<Type1, Type2, ...>` in places
that previously used a common base class.
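The "no additional data members" design can be illustrated with a toy, self-contained model. This is a minimal sketch that assumes a hypothetical `Object`/`ObjectRef` pair built on `std::shared_ptr` and `dynamic_cast`; TVM's real classes use intrusive reference counting and type indices instead. `ToyVariant`, `StringRef`, `IntRef` and `as_node` are illustrative names only and are not TVM APIs. The sketch shows the converting constructor constrained to the listed types and a run-time inspection of the held object.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for tvm::runtime::Object / ObjectRef (not TVM code).
struct Object {
  virtual ~Object() = default;
};

class ObjectRef {
 public:
  ObjectRef() = default;
  bool defined() const { return data_ != nullptr; }
  const Object* get() const { return data_.get(); }

 protected:
  std::shared_ptr<Object> data_;
};

struct StringNode : Object {
  std::string value;
};
struct IntNode : Object {
  long value = 0;
};

// Typed references; ContainerType names the node type each one points to.
class StringRef : public ObjectRef {
 public:
  using ContainerType = StringNode;
  explicit StringRef(std::string s) {
    auto node = std::make_shared<StringNode>();
    node->value = std::move(s);
    data_ = std::move(node);
  }
};

class IntRef : public ObjectRef {
 public:
  using ContainerType = IntNode;
  explicit IntRef(long v) {
    auto node = std::make_shared<IntNode>();
    node->value = v;
    data_ = std::move(node);
  }
};

// Like Variant<...>, ToyVariant adds no data members: it is an ObjectRef
// whose converting constructor accepts only the listed types.
template <typename... V>
class ToyVariant : public ObjectRef {
 public:
  template <typename T>
  static constexpr bool is_variant = (std::is_same_v<T, V> || ...);

  template <typename T, typename = std::enable_if_t<is_variant<T>>>
  ToyVariant(T value) : ObjectRef(std::move(value)) {}  // implicit on purpose

  // Hypothetical helper: the node if the held object has type T, else nullptr.
  // (For ObjectRef types, TVM's as<T>() returns an Optional<T> instead.)
  template <typename T>
  const typename T::ContainerType* as_node() const {
    static_assert(is_variant<T>, "T must be one of the variant's types");
    return dynamic_cast<const typename T::ContainerType*>(get());
  }
};

// Only the listed types convert implicitly; anything else fails to compile.
static_assert(std::is_convertible_v<StringRef, ToyVariant<StringRef, IntRef>>);
static_assert(!std::is_convertible_v<double, ToyVariant<StringRef, IntRef>>);
```

Because the variant is only a constrained handle, assigning a different alternative simply rebinds the underlying reference. No tag is stored; the active alternative is recovered from the run-time type of the object.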
To ensure that each variant can be uniquely retrieved, no type
stored within the variant may inherit from any other type within the
variant. This condition is checked at compile-time, with a
`static_assert` explaining the limitation. This condition is necessary
to mimic the semantics of `std::variant`, whose active member depends on
the compile-time type of an object. Without this condition, the
expression `Variant<PrimExpr, tir::Var> variant = PrimExpr(...)` could
populate either of the variants depending on the run-time type of an
object. Because the `Variant` class is primarily intended for use when
two types do not already inherit from each other, this restriction is not
expected to reduce its utility.
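The `std::variant` behaviour referred to above can be checked directly. In this sketch, the plain structs `Base` and `Derived` stand in for a related pair such as `PrimExpr`/`tir::Var`, and the helper functions are hypothetical. Overload resolution in `std::variant`'s converting constructor sees only the static type of the argument. A reference-based container such as `Variant` would instead observe the run-time type, which is why related types are disallowed.

```cpp
#include <cassert>
#include <cstddef>
#include <variant>

// Plain structs standing in for a related pair such as PrimExpr / tir::Var.
struct Base {
  virtual ~Base() = default;
};
struct Derived : Base {};

// std::variant picks the alternative from the *static* type of its argument:
// a Base& that actually refers to a Derived still selects the Base
// alternative, and the value is sliced as it is copied in.
std::size_t index_from_base_ref(const Base& value) {
  std::variant<Base, Derived> v = value;
  return v.index();
}

// With the static type Derived, overload resolution prefers the exact match.
std::size_t index_from_derived(const Derived& value) {
  std::variant<Base, Derived> v = value;
  return v.index();
}
```

The same `Derived` object yields index 0 through the first helper and index 1 through the second, so the selection is fixed entirely at compile time.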
There are several locations within the TVM codebase where this pattern
may be useful, and which currently rely on various workaround
strategies. (This PR does not alter any existing implementations,
instead introducing the `Variant` container that can be used in
subsequent PRs, if desired.)
* Workaround: Store a common base class. For example, the type of
`relax::TensorStructInfoNode::shape` is `Optional<Expr>`, with a
comment stating that it should be only `NullOpt`, `ShapeExpr`, or
`Var`. However, these restrictions are not checked by the compiler,
and a developer could erroneously provide a different type. By
expressing the type as `Optional<Variant<Var, ShapeExpr>>`, these
errors could be automatically caught.
* Workaround: Use additional data structures. For example, a
`PrimFunc` parameter may be either a TIR primitive, which is lowered
to a primitive type, or a TIR Buffer, which is lowered to a
`DLTensor*` argument and appropriate unpacking code. However, these
two kinds of parameter are represented by an `Array<tir::Var>` and a
`Map<tir::Var, tir::Buffer>`, which together encode what is logically an
`Array<Variant<tir::Var, tir::Buffer>>`. The separate data
structures must be kept in sync whenever modified, such as when
removing a parameter.
* Workaround: Use `std::variant`. For example, the
`tvm::tir::IdentifyMemCpyImpl` utility function returns a
`std::variant` with the result or an error message. However, this
is only suitable for use within a C++ implementation, and requires a
wrapper in order to expose it to the FFI.
* Avoid ODR, lint errors for conversion to Variant
* Added more C++ functionality tests.
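The second workaround (parallel data structures) can be contrasted with a single sequence of variants in plain C++. This sketch assumes hypothetical `Var` and `Buffer` structs as stand-ins for `tir::Var` and `tir::Buffer`, and it uses `std::variant` in place of the new `Variant`. The latter would fill the same role while remaining an `ObjectRef` that can cross the FFI. `ParamsWithMap`, `param_name` and `remove_param` are illustrative names only.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for tir::Var and tir::Buffer.
struct Var {
  std::string name;
};
struct Buffer {
  std::string name;
};

// Workaround style: every parameter is a Var, and a separate map records
// which of those Vars are really buffer handles. Both must be kept in sync.
struct ParamsWithMap {
  std::vector<Var> params;
  std::map<std::string, Buffer> buffer_map;  // keyed by Var name for brevity

  void remove(const std::string& name) {
    params.erase(std::remove_if(params.begin(), params.end(),
                                [&](const Var& v) { return v.name == name; }),
                 params.end());
    buffer_map.erase(name);  // easy to forget, leaving a stale entry behind
  }
};

// Variant style: one sequence in which each entry is either kind.
using Param = std::variant<Var, Buffer>;

std::string param_name(const Param& p) {
  return std::visit([](const auto& x) { return x.name; }, p);
}

void remove_param(std::vector<Param>& params, const std::string& name) {
  params.erase(std::remove_if(params.begin(), params.end(),
                              [&](const Param& p) { return param_name(p) == name; }),
               params.end());
}
```

With a single sequence of variants, removing a parameter is one operation on one container. There is no second structure that can fall out of date.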
---
include/tvm/runtime/container/variant.h | 123 +++++++++++++++++++++++++++++
include/tvm/runtime/packed_func.h | 54 ++++++++++++-
src/support/ffi_testing.cc | 12 +++
tests/cpp/container_test.cc | 47 +++++++++++
tests/python/unittest/test_ir_container.py | 26 ++++++
5 files changed, 259 insertions(+), 3 deletions(-)
diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h
new file mode 100644
index 0000000000..7953ac47c1
--- /dev/null
+++ b/include/tvm/runtime/container/variant.h
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/container/variant.h
+ * \brief Runtime Variant container types.
+ */
+#ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_
+#define TVM_RUNTIME_CONTAINER_VARIANT_H_
+
+#include <tvm/runtime/object.h>
+
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+namespace runtime {
+
+namespace detail {
+template <typename Parent, typename ChildTuple>
+constexpr bool parent_is_base_of_any = false;
+
+template <typename Parent, typename... Child>
+constexpr bool parent_is_base_of_any<Parent, std::tuple<Child...>> =
+ ((std::is_base_of_v<Parent, Child> && !std::is_same_v<Parent, Child>) || ...);
+
+/* \brief Utility to check if any parent is a base class of any child
+ *
+ * The type-checking in Variant relies on all types being from
+ * independent types, such that `Object::IsInstance` is sufficient to
+ * determine which variant is populated.
+ *
+ * For example, suppose the illegal `Variant<tir::Var, tir::PrimExpr>`
+ * were allowed (e.g. to represent either the definition of a variable
+ * or the usage of a variable). If a function returned
+ * `tir::PrimExpr`, it could result in either variant being filled, as
+ * the underlying type at runtime could be a `tir::Var`. This
+ * behavior is different from `std::variant`, which determines the
+ * active variant based solely on the compile-time type, and could
+ * produce very unexpected results if the variants have different
+ * semantic interpretations.
+ */
+template <typename ParentTuple, typename ChildTuple>
+static constexpr bool any_parent_is_base_of_any_child = false;
+
+template <typename ChildTuple, typename... Parent>
+static constexpr bool any_parent_is_base_of_any_child<std::tuple<Parent...>, ChildTuple> =
+ (parent_is_base_of_any<Parent, ChildTuple> || ...);
+} // namespace detail
+
+template <typename... V>
+class Variant : public ObjectRef {
+ static constexpr bool all_inherit_from_objectref = (std::is_base_of_v<ObjectRef, V> && ...);
+ static_assert(all_inherit_from_objectref,
+ "All types used in Variant<...> must inherit from ObjectRef");
+
+ static constexpr bool a_variant_inherits_from_another_variant =
+ detail::any_parent_is_base_of_any_child<std::tuple<V...>, std::tuple<V...>>;
+ static_assert(!a_variant_inherits_from_another_variant,
+ "Due to implementation limitations, "
+ "no type stored in a tvm::runtime::Variant "
+ "may be a subclass of any other type "
+ "stored in the same variant.");
+
+ public:
+ /* \brief Helper utility to check if the type is part of the variant */
+ template <typename T>
+ static constexpr bool is_variant = (std::is_same_v<T, V> || ...);
+
+ /* \brief Helper utility for SFINAE if the type is part of the variant */
+ template <typename T>
+ using enable_if_variant = std::enable_if_t<is_variant<T>>;
+
+ template <typename T, typename = enable_if_variant<T>>
+ Variant(T value) : ObjectRef(std::move(value)) {} // NOLINT(*)
+
+ template <typename T, typename = enable_if_variant<T>>
+ Variant& operator=(T value) {
+ ObjectRef::operator=(std::move(value));
+ return *this;
+ }
+
+ // These functions would normally be declared with the
+ // TVM_DEFINE_OBJECT_REF_METHODS macro. However, we need additional
+ // type-checking inside the ObjectPtr<Object> constructor.
+ using ContainerType = Object;
+ Variant() : ObjectRef() {}
+ explicit Variant(ObjectPtr<Object> node) : ObjectRef(node) {
+ CHECK(node == nullptr || (node->IsInstance<typename V::ContainerType>() || ...))
+ << "Variant<"
+ << static_cast<const std::stringstream&>(
+ (std::stringstream() << ... << V::ContainerType::_type_key))
+ .str()
+ << "> cannot hold an object of type " << node->GetTypeKey();
+ }
+ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Variant);
+};
+
+} // namespace runtime
+
+// expose the functions to the root namespace.
+using runtime::Variant;
+
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CONTAINER_VARIANT_H_
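To see the compile-time restriction in isolation, the two `detail::` helpers from `variant.h` can be copied onto plain structs. In this sketch, `has_related_types` is a hypothetical wrapper that mirrors the condition that `Variant<V...>` checks. `Expr`, `VarExpr` and `Str` are hypothetical types chosen to resemble `PrimExpr`, `tir::Var` and `String`.

```cpp
#include <cassert>
#include <tuple>
#include <type_traits>

// Same structure as the detail:: helpers in variant.h, on plain structs.
template <typename Parent, typename ChildTuple>
constexpr bool parent_is_base_of_any = false;

// True if Parent is a proper base of any type in the tuple
// (is_base_of_v<T, T> is true for class types, hence the is_same_v guard).
template <typename Parent, typename... Child>
constexpr bool parent_is_base_of_any<Parent, std::tuple<Child...>> =
    ((std::is_base_of_v<Parent, Child> && !std::is_same_v<Parent, Child>) || ...);

template <typename ParentTuple, typename ChildTuple>
constexpr bool any_parent_is_base_of_any_child = false;

template <typename ChildTuple, typename... Parent>
constexpr bool any_parent_is_base_of_any_child<std::tuple<Parent...>, ChildTuple> =
    (parent_is_base_of_any<Parent, ChildTuple> || ...);

// Hypothetical wrapper mirroring the check performed by Variant<V...>.
template <typename... V>
constexpr bool has_related_types =
    any_parent_is_base_of_any_child<std::tuple<V...>, std::tuple<V...>>;

// Hypothetical hierarchy, loosely analogous to PrimExpr / tir::Var / String.
struct Expr {};
struct VarExpr : Expr {};
struct Str {};

static_assert(!has_related_types<Expr, Str>, "independent types are accepted");
static_assert(has_related_types<Expr, VarExpr>, "a base/derived pair is rejected");
static_assert(has_related_types<VarExpr, Expr>, "the order of types does not matter");
```

Both helpers are pure fold expressions over the type list, so the check adds no run-time cost. A violation surfaces as the `static_assert` message in `Variant`.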
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 7aa8ef1ba7..caaaec3640 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
+#include <tvm/runtime/container/variant.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
@@ -680,9 +681,6 @@ class TVMArgValue : public TVMPODValue_ {
} else if (type_code_ == kTVMStr) {
return std::string(value_.v_str);
} else {
- ICHECK(IsObjectRef<tvm::runtime::String>())
- << "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
- << " to a string.";
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
@@ -2063,6 +2061,56 @@ struct PackedFuncValueConverter<Optional<T>> {
}
};
+template <typename... VariantTypes>
+struct PackedFuncValueConverter<Variant<VariantTypes...>> {
+ using VType = Variant<VariantTypes...>;
+
+ // Can't just take `const TVMPODValue&` as an argument, because
+ // `TVMArgValue` and `TVMRetValue` have different implementations
+ // for `operator std::string()`.
+ template <typename PODSubclass>
+ static VType From(const PODSubclass& val) {
+ if (auto opt = TryAsObjectRef<VariantTypes...>(val)) {
+ return opt.value();
+ }
+
+ if (auto opt = TryValueConverter<PODSubclass, VariantTypes...>(val)) {
+ return opt.value();
+ }
+
+ LOG(FATAL) << "Expected one of "
+ << static_cast<const std::stringstream&>(
+ (std::stringstream() << ... << VariantTypes::ContainerType::_type_key))
+ .str()
+ << " but got " << ArgTypeCode2Str(val.type_code());
+ }
+
+ template <typename VarFirst, typename... VarRest>
+ static Optional<VType> TryAsObjectRef(const TVMPODValue_& val) {
+ if (val.IsObjectRef<VarFirst>()) {
+ return VType(val.AsObjectRef<VarFirst>());
+ } else if constexpr (sizeof...(VarRest)) {
+ return TryAsObjectRef<VarRest...>(val);
+ } else {
+ return NullOpt;
+ }
+ }
+
+ template <typename PODSubclass, typename VarFirst, typename... VarRest>
+ static Optional<VType> TryValueConverter(const PODSubclass& val) {
+ try {
+ return VType(PackedFuncValueConverter<VarFirst>::From(val));
+ } catch (const InternalError&) {
+ }
+
+ if constexpr (sizeof...(VarRest)) {
+ return TryValueConverter<PODSubclass, VarRest...>(val);
+ } else {
+ return NullOpt;
+ }
+ }
+};
+
inline bool String::CanConvertFrom(const TVMArgValue& val) {
return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
}
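The `TryAsObjectRef` recursion above tests the candidate types in order, returns on the first match, and yields `NullOpt` once the list is exhausted. The following standalone sketch has the same shape. It assumes `std::any` as a stand-in for a type-erased FFI value, which TVM represents with `TVMArgValue`/`TVMRetValue`, and it uses the hypothetical names `TryConvert` and `FromAny`. It models only that first pass. It does not model the fallback through `PackedFuncValueConverter<T>::From`, which catches `InternalError`.

```cpp
#include <any>
#include <cassert>
#include <optional>
#include <string>
#include <variant>

// Try First; otherwise recurse on Rest; give up when the list is exhausted.
// The discarded `if constexpr` branch is never instantiated, so the recursion
// stops without needing an empty-pack overload.
template <typename VType, typename First, typename... Rest>
std::optional<VType> TryConvert(const std::any& val) {
  if (const First* p = std::any_cast<First>(&val)) {
    return VType(*p);
  } else if constexpr (sizeof...(Rest) > 0) {
    return TryConvert<VType, Rest...>(val);
  } else {
    return std::nullopt;
  }
}

// Hypothetical entry point: convert a type-erased value into one of Ts...
template <typename... Ts>
std::optional<std::variant<Ts...>> FromAny(const std::any& val) {
  return TryConvert<std::variant<Ts...>, Ts...>(val);
}
```

As in the converter above, a value that matches none of the alternatives yields an empty result. The caller can then report all the expected types in a single error message.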
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index e00b9b8d05..75b5a2527f 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
+#include <tvm/runtime/container/variant.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
@@ -177,4 +178,15 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) {
std::this_thread::sleep_for(duration);
});
+TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant<String, IntImm> {
+ if (x % 2 == 0) {
+ return IntImm(DataType::Int(64), x / 2);
+ } else {
+ return String("argument was odd");
+ }
+});
+
+TVM_REGISTER_GLOBAL("testing.AcceptsVariant")
+ .set_body_typed([](Variant<String, Integer> arg) -> String { return arg->GetTypeKey(); });
+
} // namespace tvm
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index d75a510d0c..5c9af19f9b 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -23,6 +23,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
+#include <tvm/runtime/container/variant.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
@@ -853,3 +854,49 @@ TEST(Optional, PackedCall) {
test_ffi(s, static_cast<int>(kTVMObjectHandle));
test_ffi(String(s), static_cast<int>(kTVMObjectRValueRefArg));
}
+
+TEST(Variant, Construct) {
+ Variant<PrimExpr, String> variant;
+ variant = PrimExpr(1);
+ ICHECK(variant.as<PrimExpr>());
+ ICHECK(!variant.as<String>());
+
+ variant = String("hello");
+ ICHECK(variant.as<String>());
+ ICHECK(!variant.as<PrimExpr>());
+}
+
+TEST(Variant, InvalidTypeThrowsError) {
+ auto expected_to_throw = []() {
+ ObjectPtr<Object> node = make_object<Object>();
+ Variant<PrimExpr, String> variant(node);
+ };
+
+ EXPECT_THROW(expected_to_throw(), InternalError);
+}
+
+TEST(Variant, ReferenceIdentifyPreservedThroughAssignment) {
+ Variant<PrimExpr, String> variant;
+ ICHECK(!variant.defined());
+
+ String string_obj = "dummy_test";
+ variant = string_obj;
+ ICHECK(variant.defined());
+ ICHECK(variant.same_as(string_obj));
+ ICHECK(string_obj.same_as(variant));
+
+ String out_string_obj = Downcast<String>(variant);
+ ICHECK(string_obj.same_as(out_string_obj));
+}
+
+TEST(Variant, ExtractValueFromAssignment) {
+ Variant<PrimExpr, String> variant = String("hello");
+ ICHECK_EQ(variant.as<String>().value(), "hello");
+}
+
+TEST(Variant, AssignmentFromVariant) {
+ Variant<PrimExpr, String> variant = String("hello");
+ auto variant2 = variant;
+ ICHECK(variant2.as<String>());
+ ICHECK_EQ(variant2.as<String>().value(), "hello");
+}
diff --git a/tests/python/unittest/test_ir_container.py b/tests/python/unittest/test_ir_container.py
index 1915849e10..aa482dd65c 100644
--- a/tests/python/unittest/test_ir_container.py
+++ b/tests/python/unittest/test_ir_container.py
@@ -112,5 +112,31 @@ def test_ndarray_container():
assert isinstance(arr[0], tvm.nd.NDArray)
+def test_return_variant_type():
+ func = tvm.get_global_func("testing.ReturnsVariant")
+ res_even = func(42)
+ assert isinstance(res_even, tvm.tir.IntImm)
+ assert res_even == 21
+
+ res_odd = func(17)
+ assert isinstance(res_odd, tvm.runtime.String)
+ assert res_odd == "argument was odd"
+
+
+def test_pass_variant_type():
+ func = tvm.get_global_func("testing.AcceptsVariant")
+
+ assert func("string arg") == "runtime.String"
+ assert func(17) == "IntImm"
+
+
+def test_pass_incorrect_variant_type():
+ func = tvm.get_global_func("testing.AcceptsVariant")
+ float_arg = tvm.tir.FloatImm("float32", 0.5)
+
+ with pytest.raises(Exception):
+ func(float_arg)
+
+
if __name__ == "__main__":
tvm.testing.main()