This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/rebase-09162022-autotensorization in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 71d0343a43ac282816ac1e9cbcf5ceb744d7fa2c Author: Andrew Zhao Luo <[email protected]> AuthorDate: Fri Sep 2 15:12:20 2022 -0700 ad simplify optional --- src/te/autodiff/ad_simplify.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 26047e879e..240adf14b3 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -44,6 +44,7 @@ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage. * */ +#include <dmlc/optional.h> #include <tvm/arith/analyzer.h> #include <tvm/arith/int_solver.h> #include <tvm/runtime/registry.h> @@ -53,7 +54,6 @@ #include <iterator> #include <memory> -#include <optional> #include <utility> #include "ad_utils.h" @@ -629,9 +629,9 @@ class EliminateDivModMutator : public ExprMutator { } private: - std::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut, - int64_t val, DivMode mode) { - using tresult = std::optional<std::pair<Var, Var>>; + dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut, + int64_t val, DivMode mode) { + using tresult = dmlc::optional<std::pair<Var, Var>>; // Try to find the variables using the mutated expressions if (!e.same_as(mut)) { @@ -1183,19 +1183,21 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges); } + PrimExpr new_outer_cond, new_reduce_cond; Array<PrimExpr> new_source = red->source; // Partially lift conditions from the reduce condition - auto [new_outer_cond, new_reduce_cond] = + std::tie(new_outer_cond, new_reduce_cond) = LiftConditionsThroughReduction(red->condition, red->axis, axis); // If it's not sum then we haven't yet lifted nonzeroness cond from the source if (!is_sum) { + PrimExpr outer_nz_cond, nz_cond, nz_source; auto nz = NonzeronessCondition(red->source[red->value_index]); // Append conditions from the reduction - PrimExpr nz_source = nz.value; - auto [outer_nz_cond, nz_cond] = - LiftConditionsThroughReduction(new_reduce_cond && nz.cond, red->axis, axis); + nz_cond = new_reduce_cond && nz.cond; + nz_source = nz.value; + std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis); new_outer_cond = new_outer_cond && outer_nz_cond; new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype()))); }
