This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 18b8089564e3eda6caf42ede61a24a6d47efb841 Author: Andrew Zhao Luo <[email protected]> AuthorDate: Fri Sep 2 15:22:55 2022 -0700 final optional --- src/relay/transforms/fold_explicit_padding.cc | 13 ++--- src/relay/transforms/pattern_utils.h | 69 +++++++------------------- src/tir/transforms/common_subexpr_elim_tools.h | 5 +- src/tir/transforms/loop_partition.cc | 51 +++++++------------ 4 files changed, 44 insertions(+), 94 deletions(-) diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc index 794bcfd3d0..37385f80c1 100644 --- a/src/relay/transforms/fold_explicit_padding.cc +++ b/src/relay/transforms/fold_explicit_padding.cc @@ -22,6 +22,7 @@ * \brief A pass for folding explicit pads into other ops. */ +#include <dmlc/optional.h> #include <tvm/relay/dataflow_matcher.h> #include <tvm/relay/expr.h> #include <tvm/relay/expr_functor.h> @@ -31,10 +32,6 @@ #include <tvm/tir/op.h> #include <tvm/topi/nn/pooling.h> -#include <optional> -#include <set> -#include <string> - #include "../op/tensor/transform.h" #include "pattern_utils.h" @@ -183,10 +180,10 @@ class SimplifyExplicitPad { return attrs; } - static const std::optional<Array<PrimExpr>> get_padding(const PadAttrs* param, - std::string data_layout) { + static const Optional<Array<PrimExpr>> get_padding(const PadAttrs* param, + std::string data_layout) { // Gets spatial axes padding from the given PadAttrs `param`. If padding - // is non-zero on non-spatial axes, return std::nullopt. + // is non-zero on non-spatial axes, return NullOpt. ICHECK(param); ICHECK(data_layout.size() == param->pad_width.size()) << "Data Layout and padding attributes should have the same extent"; @@ -199,7 +196,7 @@ class SimplifyExplicitPad { if (!image_dims.count(data_layout[i])) { for (size_t j = 0; j < param->pad_width[i].size(); ++j) { if (param->pad_width[i][j] != 0) { - return std::nullopt; + return NullOpt; } } } diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index ffe1cc2ca2..f71d84434d 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -27,6 +27,7 @@ #define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ #include <builtin_fp16.h> +#include <dmlc/optional.h> #include <tvm/node/structural_equal.h> #include <tvm/relay/analysis.h> #include <tvm/relay/attrs/nn.h> @@ -39,7 +40,6 @@ #include <tvm/tir/data_layout.h> #include <limits> -#include <optional> #include <string> #include <utility> #include <vector> @@ -344,40 +344,6 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s return Constant(arr); } -/*! - * \brief Create a Constant tensor of zeros. - * - * \param dtype The data type. - * \param shape The shape of the output constant tensor. - * \return A Constant. - */ -static inline Constant MakeConstantZeros(DataType dtype, std::vector<int64_t> shape) { - runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0}); - int64_t data_size = 1; - for (int64_t dim : shape) { - data_size *= dim; - } - TVM_DTYPE_DISPATCH(dtype, DType, { - for (int64_t i = 0; i < data_size; i++) { - if (dtype == DataType::Float(16)) { - // convert to float16 - // storage is uint16_t - // Similar handling as that in MakeConstantScalar - *(static_cast<DType*>(arr->data) + i) = - __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(0)); - } else if (dtype == DataType::BFloat(16)) { - // convert to bfloat16 - // storage is uint16_t - *(static_cast<DType*>(arr->data) + i) = - __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(0)); - } else { - *(static_cast<DType*>(arr->data) + i) = 0; - } - } - }) - return Constant(arr); -} - /*! * \brief Check whether a shape is static and create corresponding Constant. Eventually this will be removed and replaced with CheckConstantShapeArrayInteger @@ -439,47 +405,48 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { * \param i element index * \return Converted scalar value, or None if conversion failed */ -static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) { +static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) { if (array->dtype.code == kDLInt) { if (array->dtype.bits == 8) { - return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]); } else if (array->dtype.bits == 16) { - return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]); } else if (array->dtype.bits == 32) { - return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]); } else if (array->dtype.bits == 64) { - return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]); } } else if (array->dtype.code == kDLUInt) { if (array->dtype.bits == 1) { // bool - return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); } else if (array->dtype.bits == 8) { - return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); } else if (array->dtype.bits == 16) { - return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]); } else if (array->dtype.bits == 32) { - return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]); } else if (array->dtype.bits == 64) { - return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]); } } else if (array->dtype.code == kDLFloat) { if (array->dtype.bits == 16) { - return std::optional<long double>( + return dmlc::optional<long double>( __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>( reinterpret_cast<uint16_t*>(array->data)[i])); } if (array->dtype.bits == 32) { - return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]); } else if (array->dtype.bits == 64) { - return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]); + return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]); } } else if (array->dtype.code == kDLBfloat) { if (array->dtype.bits == 16) { - return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>( - reinterpret_cast<uint16_t*>(array->data)[i])); + return dmlc::optional<long double>( + __extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>( + reinterpret_cast<uint16_t*>(array->data)[i])); } } - return std::nullopt; + return dmlc::optional<long double>(); } /*! diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 0871fd0091..fcd29fddc0 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -33,11 +33,12 @@ #include <tvm/tir/stmt.h> #include <tvm/tir/stmt_functor.h> // For the class StmtExprVisitor -#include <optional> #include <unordered_map> // For the hashtable datatype #include <utility> // For pairs datatype #include <vector> +#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h" + namespace tvm { namespace tir { @@ -176,7 +177,7 @@ class UsesVarName : public StmtExprVisitor { */ void PrintComputationTable(const ComputationTable& table); -using MaybeValue = std::optional<PrimExpr>; +using MaybeValue = dmlc::optional<PrimExpr>; bool EqualTerms(const PrimExpr& a, const PrimExpr& b); // Used for deciding the (decidable) equivalence relation diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 6ecc6459b9..677506889e 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -29,7 +29,6 @@ #include <tvm/tir/stmt_functor.h> #include <tvm/tir/transform.h> -#include <optional> #include <unordered_map> #include <unordered_set> @@ -554,39 +553,25 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max); - - auto [middle_interval, cond_set, - opt_cond_value] = [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> { - { - // find an interval in which all conditions on var are true - auto [middle_interval, cond_set] = - GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_); - if (!middle_interval.IsNothing()) { - return {middle_interval, cond_set, true}; - } - } - - { - // if such interval doesn't exist, find an interval in which all - // conditions on var are false - auto [middle_interval, cond_set] = - GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_); - - if (!middle_interval.IsNothing()) { - return {middle_interval, cond_set, false}; - } - } - - // we couldn't find an interval in which the conditions are - // provably true or false. Therefore, we can't partition the loop - // based on those conds - return {{}, {}, std::nullopt}; - }(); - - if (!opt_cond_value.has_value()) { - return Stmt(); + bool cond_value; + IntSet middle_interval; + ExpressionSet cond_set; + // find an interval in which all conditions on var are true + std::tie(middle_interval, cond_set) = + GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_); + if (middle_interval.IsNothing()) { + // if such interval doesn't exist, find an interval in which all + // conditions on var are false + std::tie(middle_interval, cond_set) = + GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_); + if (middle_interval.IsNothing()) + // we couldn't find an interval in which the conditions are provably true or false + // Therefore, we can't partition the loop based on those conds + return Stmt(); + cond_value = false; + } else { + cond_value = true; } - bool cond_value = opt_cond_value.value(); IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval); // middle_interval is the subrange of the loop variable range for which a
