yzhliu commented on a change in pull request #6078: URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463901829
########## File path: src/te/autodiff/ad_simplify.cc ########## @@ -0,0 +1,1305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ad_simplify.cc + * \brief Simplify tensor compute generated by tensor-level autodiff. + * + * The major simplification we do in this file is to eliminate + * the Jacobian tensor created by autodiff. + * + * Jacobian tensor is sparse because one output element usually relates + * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping + * between input tensor and output tensor, thus the Jacobian is diagonal. + * + * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, + * \alpha and \beta are vectors represent the indices of In and Out respectively. + * i.e., the non-zero Jacobian indices is a linear combination of the input indices. + * Thereby we solve linear equations of \beta = A \alpha, + * as well as linear inequalities of their domain ranges. + * + * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J]. + * arXiv preprint arXiv:1711.01348, 2017. for more details. + * + * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition, + * replace the compute expression with solved new axes, and create a selection node + * (non-zero-condition ? new_compute_expression : 0). + * + * Due to TVM's restriction, we also lift the reduction to the top of the compute stage. + * + */ +#include <dmlc/optional.h> +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_solver.h> +#include <tvm/runtime/registry.h> +#include <tvm/te/autodiff.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/stmt_functor.h> + +#include <memory> +#include <utility> + +#include "ad_util.h" + +namespace tvm { +namespace te { + +using arith::DivMode; +using arith::kFloorDiv; +using arith::kTruncDiv; + +template <class K, class V> +Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template <class T> +Array<T> Concat(Array<T> a, const Array<T>& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template <class container> +PrimExpr All(const container& c) { + PrimExpr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) { + Map<Var, Range> res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array<IterVar> res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map " + << vranges; + res.push_back(IterVar(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +Array<Var> IterVarsToVars(const Array<IterVar>& itervars) { + Array<Var> res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +template <typename ValueType> +inline bool is_const_value(const PrimExpr& e, ValueType value) { + static_assert(std::is_integral<ValueType>::value, + "Comparison to non-integer values is forbidden."); + if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) { + return i->value == value; + } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) { + return i->value == value; + } else if (const tir::CastNode* c = e.as<tir::CastNode>()) { + return is_const_value(c->value, value); + } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE), + 0)) { + return false; + } + + PrimExpr combiner_result = + analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) || + tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]); +} + +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map<Var, Range>& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE), + 0)) { + return false; + } + + PrimExpr zero = make_zero(combiner->result[value_index].dtype()); + PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + return is_const_value(in, 0); +} + +struct NonzeroConditionResult { + PrimExpr cond; + PrimExpr value; + + PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); } + + friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) { + return os << r.to_expr(); + } +}; + +// The implementation of NonzeroCondition +// transform expression to cond ? value : 0 +class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> { + public: + NonzeroConditionResult NonzeroCondition(const PrimExpr& e) { + if (e.dtype().is_bool()) { + // Boolean expressions are non-zero whenever they are true themselves + return {e, const_true()}; + } else { + return VisitExpr(e); + } + } + + // Most of the cases are implemented using helpers below + result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); } + result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); } + result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); } + result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); } + result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); } + result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); } + result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); } + result_type VisitExpr_(const FloorDivNode* op) final { + return BinOpDivLike_(GetRef<FloorDiv>(op)); + } + result_type VisitExpr_(const FloorModNode* op) final { + return BinOpDivLike_(GetRef<FloorMod>(op)); + } + result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); } + result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); } + + result_type VisitExpr_(const CastNode* op) final { + auto nz_a = NonzeroCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, GetRef<PrimExpr>(op)}; + } else { + return {nz_a.cond, Cast(op->dtype, nz_a.value)}; + } + } + + result_type VisitExpr_(const SelectNode* op) final { + PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // If the false part is zero, we can get rid of the select + if (is_const_value(nz_b.value, 0)) { + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + return {new_cond, nz_a.value}; + } + + // If the true part is zero, we can also get rid of the select + if (is_const_value(nz_a.value, 0)) { + PrimExpr new_cond = + analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + return {new_cond, nz_b.value}; + } + + // Otherwise we retain the select and combine the conditions into this + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef<PrimExpr>(op)}; + } else { + return {new_cond, Select(cond, nz_a.value, nz_b.value)}; + } + } + + result_type VisitExpr_(const CallNode* op) final { + if (op->op.same_as(Op::Get("tir.if_then_else"))) { + PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // We don't have as much freedom here as in the select case + // since the `if` must be preserved in any case + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef<PrimExpr>(op)}; + } else { + return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; + } + } else { + return Default_(GetRef<PrimExpr>(op)); + } + } + + result_type VisitExpr_(const ProducerLoadNode* op) final { + return Default_(GetRef<PrimExpr>(op)); + } + + NonzeroConditionResult Default_(const PrimExpr& e) { + // This is always correct, so it's the default + return {const_true(), e}; + } + + template <class T> + NonzeroConditionResult Const_(const T& op) { + if (op->value == 0) { + return {const_false(), op}; + } else { + return {const_true(), op}; + } + } + + template <class T> + NonzeroConditionResult BinOpAddLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For addition and similar ops the result may be nonzero if either of the arguments is + // nonzero, so we combine the conditions with Or. + if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) { + // If the conditions are the same, we don't need Or + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, nz_b.value)}; + } + } else { + // Otherwise use Or + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + // A little optimization: if the combined condition is the same as one of the inner + // conditions, we don't need to guard the inner value with a select, otherwise + // we create a select in the `to_expr` call. + PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + PrimExpr new_expr = T(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template <class T> + NonzeroConditionResult BinOpMulLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For multiplication and similar ops the result may be nonzero if + // both the arguments are nonzero, so we combine with And. + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, op}; + } else { + return {new_cond, T(nz_a.value, nz_b.value)}; + } + } + + template <class T> + NonzeroConditionResult BinOpDivLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + + // For Div we simply use the condition of the numerator. + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, op->b)}; + } + } + + private: + arith::Analyzer analyzer_; +}; + +inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) { + return NonzeroConditionFunctor().NonzeroCondition(expr); +} + +struct FactorOutAtomicFormulasResult { Review comment: do you mean to change the struct & member name to mention CNF? ########## File path: src/te/autodiff/ad_simplify.cc ########## @@ -0,0 +1,1305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ad_simplify.cc + * \brief Simplify tensor compute generated by tensor-level autodiff. + * + * The major simplification we do in this file is to eliminate + * the Jacobian tensor created by autodiff. + * + * Jacobian tensor is sparse because one output element usually relates + * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping + * between input tensor and output tensor, thus the Jacobian is diagonal. + * + * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, + * \alpha and \beta are vectors represent the indices of In and Out respectively. + * i.e., the non-zero Jacobian indices is a linear combination of the input indices. + * Thereby we solve linear equations of \beta = A \alpha, + * as well as linear inequalities of their domain ranges. + * + * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J]. + * arXiv preprint arXiv:1711.01348, 2017. for more details. + * + * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition, + * replace the compute expression with solved new axes, and create a selection node + * (non-zero-condition ? new_compute_expression : 0). + * + * Due to TVM's restriction, we also lift the reduction to the top of the compute stage. + * + */ +#include <dmlc/optional.h> +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_solver.h> +#include <tvm/runtime/registry.h> +#include <tvm/te/autodiff.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/stmt_functor.h> + +#include <memory> +#include <utility> + +#include "ad_util.h" + +namespace tvm { +namespace te { + +using arith::DivMode; +using arith::kFloorDiv; +using arith::kTruncDiv; + +template <class K, class V> +Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template <class T> +Array<T> Concat(Array<T> a, const Array<T>& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template <class container> +PrimExpr All(const container& c) { + PrimExpr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) { + Map<Var, Range> res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array<IterVar> res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map " + << vranges; + res.push_back(IterVar(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +Array<Var> IterVarsToVars(const Array<IterVar>& itervars) { + Array<Var> res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +template <typename ValueType> +inline bool is_const_value(const PrimExpr& e, ValueType value) { + static_assert(std::is_integral<ValueType>::value, + "Comparison to non-integer values is forbidden."); + if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) { + return i->value == value; + } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) { + return i->value == value; + } else if (const tir::CastNode* c = e.as<tir::CastNode>()) { + return is_const_value(c->value, value); + } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE), + 0)) { + return false; + } + + PrimExpr combiner_result = + analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) || + tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]); +} + +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map<Var, Range>& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE), + 0)) { + return false; + } + + PrimExpr zero = make_zero(combiner->result[value_index].dtype()); + PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + return is_const_value(in, 0); +} + +struct NonzeroConditionResult { + PrimExpr cond; + PrimExpr value; + + PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); } + + friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) { + return os << r.to_expr(); + } +}; + +// The implementation of NonzeroCondition +// transform expression to cond ? value : 0 +class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> { + public: + NonzeroConditionResult NonzeroCondition(const PrimExpr& e) { + if (e.dtype().is_bool()) { + // Boolean expressions are non-zero whenever they are true themselves + return {e, const_true()}; + } else { + return VisitExpr(e); + } + } + + // Most of the cases are implemented using helpers below + result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); } + result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); } + result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); } + result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); } + result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); } + result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); } + result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); } + result_type VisitExpr_(const FloorDivNode* op) final { + return BinOpDivLike_(GetRef<FloorDiv>(op)); + } + result_type VisitExpr_(const FloorModNode* op) final { + return BinOpDivLike_(GetRef<FloorMod>(op)); + } + result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); } + result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); } + + result_type VisitExpr_(const CastNode* op) final { + auto nz_a = NonzeroCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, GetRef<PrimExpr>(op)}; + } else { + return {nz_a.cond, Cast(op->dtype, nz_a.value)}; + } + } + + result_type VisitExpr_(const SelectNode* op) final { + PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // If the false part is zero, we can get rid of the select + if (is_const_value(nz_b.value, 0)) { + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + return {new_cond, nz_a.value}; + } + + // If the true part is zero, we can also get rid of the select + if (is_const_value(nz_a.value, 0)) { + PrimExpr new_cond = + analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + return {new_cond, nz_b.value}; + } + + // Otherwise we retain the select and combine the conditions into this + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef<PrimExpr>(op)}; + } else { + return {new_cond, Select(cond, nz_a.value, nz_b.value)}; + } + } + + result_type VisitExpr_(const CallNode* op) final { + if (op->op.same_as(Op::Get("tir.if_then_else"))) { + PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // We don't have as much freedom here as in the select case + // since the `if` must be preserved in any case + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef<PrimExpr>(op)}; + } else { + return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; + } + } else { + return Default_(GetRef<PrimExpr>(op)); + } + } + + result_type VisitExpr_(const ProducerLoadNode* op) final { + return Default_(GetRef<PrimExpr>(op)); + } + + NonzeroConditionResult Default_(const PrimExpr& e) { + // This is always correct, so it's the default + return {const_true(), e}; + } + + template <class T> + NonzeroConditionResult Const_(const T& op) { + if (op->value == 0) { + return {const_false(), op}; + } else { + return {const_true(), op}; + } + } + + template <class T> + NonzeroConditionResult BinOpAddLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For addition and similar ops the result may be nonzero if either of the arguments is + // nonzero, so we combine the conditions with Or. + if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) { + // If the conditions are the same, we don't need Or + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, nz_b.value)}; + } + } else { + // Otherwise use Or + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + // A little optimization: if the combined condition is the same as one of the inner + // conditions, we don't need to guard the inner value with a select, otherwise + // we create a select in the `to_expr` call. + PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + PrimExpr new_expr = T(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template <class T> + NonzeroConditionResult BinOpMulLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For multiplication and similar ops the result may be nonzero if + // both the arguments are nonzero, so we combine with And. + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, op}; + } else { + return {new_cond, T(nz_a.value, nz_b.value)}; + } + } + + template <class T> + NonzeroConditionResult BinOpDivLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + + // For Div we simply use the condition of the numerator. + + if (nz_a.value.same_as(op->a)) { Review comment: do you mean provide a function `bool value_equals(const PrimExpr&)` in NonzeroConditionResult? ########## File path: src/te/autodiff/ad_simplify.cc ########## @@ -0,0 +1,1305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ad_simplify.cc + * \brief Simplify tensor compute generated by tensor-level autodiff. + * + * The major simplification we do in this file is to eliminate + * the Jacobian tensor created by autodiff. + * + * Jacobian tensor is sparse because one output element usually relates + * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping + * between input tensor and output tensor, thus the Jacobian is diagonal. + * + * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, + * \alpha and \beta are vectors represent the indices of In and Out respectively. + * i.e., the non-zero Jacobian indices is a linear combination of the input indices. + * Thereby we solve linear equations of \beta = A \alpha, + * as well as linear inequalities of their domain ranges. + * + * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J]. + * arXiv preprint arXiv:1711.01348, 2017. for more details. + * + * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition, + * replace the compute expression with solved new axes, and create a selection node + * (non-zero-condition ? new_compute_expression : 0). + * + * Due to TVM's restriction, we also lift the reduction to the top of the compute stage. + * + */ +#include <dmlc/optional.h> +#include <tvm/arith/analyzer.h> +#include <tvm/arith/int_solver.h> +#include <tvm/runtime/registry.h> +#include <tvm/te/autodiff.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/stmt_functor.h> + +#include <memory> +#include <utility> + +#include "ad_util.h" + +namespace tvm { +namespace te { + +using arith::DivMode; +using arith::kFloorDiv; +using arith::kTruncDiv; + +template <class K, class V> +Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) { + for (const auto& p : update) { + original.Set(p.first, p.second); + } + return std::move(original); +} + +// Concatenate two arrays +template <class T> +Array<T> Concat(Array<T> a, const Array<T>& b) { + for (const auto& x : b) { + a.push_back(x); + } + return std::move(a); +} + +// Combine all expressions from the container using &&. +template <class container> +PrimExpr All(const container& c) { + PrimExpr res; + for (const auto& e : c) { + if (res.get()) { + res = res && e; + } else { + res = e; + } + } + if (res.get()) { + return res; + } else { + return const_true(); + } +} + +Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) { + Map<Var, Range> res; + for (const IterVar& v : itervars) { + res.Set(v->var, v->dom); + } + return res; +} + +// Given a map from vars to ranges create an array of itervars +Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges, + IterVarType iter_type = kDataPar, std::string thread_tag = "") { + Array<IterVar> res; + for (const Var& v : vars) { + CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map " + << vranges; + res.push_back(IterVar(vranges[v], v, iter_type, thread_tag)); + } + return res; +} + +Array<Var> IterVarsToVars(const Array<IterVar>& itervars) { + Array<Var> res; + for (const IterVar& v : itervars) { + res.push_back(v->var); + } + return res; +} + +template <typename ValueType> +inline bool is_const_value(const PrimExpr& e, ValueType value) { + static_assert(std::is_integral<ValueType>::value, + "Comparison to non-integer values is forbidden."); + if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) { + return i->value == value; + } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) { + return i->value == value; + } else if (const tir::CastNode* c = e.as<tir::CastNode>()) { + return is_const_value(c->value, value); + } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + +// Return true if this combiner is just a sum. +bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (combiner->result.size() != 1) { + return false; + } + + if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE), + 0)) { + return false; + } + + PrimExpr combiner_result = + analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) || + tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]); +} + +bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index, + const Map<Var, Range>& vranges) { + arith::Analyzer analyzer; + analyzer.Bind(vranges); + if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE), + 0)) { + return false; + } + + PrimExpr zero = make_zero(combiner->result[value_index].dtype()); + PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero}, + {combiner->rhs[value_index], zero}}); + in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + return is_const_value(in, 0); +} + +struct NonzeroConditionResult { + PrimExpr cond; + PrimExpr value; + + PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); } + + friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) { + return os << r.to_expr(); + } +}; + +// The implementation of NonzeroCondition +// transform expression to cond ? value : 0 +class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> { + public: + NonzeroConditionResult NonzeroCondition(const PrimExpr& e) { + if (e.dtype().is_bool()) { + // Boolean expressions are non-zero whenever they are true themselves + return {e, const_true()}; + } else { + return VisitExpr(e); + } + } + + // Most of the cases are implemented using helpers below + result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); } + result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); } + result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); } + result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); } + result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); } + result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); } + result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); } + result_type VisitExpr_(const FloorDivNode* op) final { + return BinOpDivLike_(GetRef<FloorDiv>(op)); + } + result_type VisitExpr_(const FloorModNode* op) final { + return BinOpDivLike_(GetRef<FloorMod>(op)); + } + result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); } + result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); } + + result_type VisitExpr_(const CastNode* op) final { + auto nz_a = NonzeroCondition(op->value); + + if (nz_a.value.same_as(op->value)) { + return {nz_a.cond, GetRef<PrimExpr>(op)}; + } else { + return {nz_a.cond, Cast(op->dtype, nz_a.value)}; + } + } + + result_type VisitExpr_(const SelectNode* op) final { + PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // If the false part is zero, we can get rid of the select + if (is_const_value(nz_b.value, 0)) { + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + return {new_cond, nz_a.value}; + } + + // If the true part is zero, we can also get rid of the select + if (is_const_value(nz_a.value, 0)) { + PrimExpr new_cond = + analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + return {new_cond, nz_b.value}; + } + + // Otherwise we retain the select and combine the conditions into this + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef<PrimExpr>(op)}; + } else { + return {new_cond, Select(cond, nz_a.value, nz_b.value)}; + } + } + + result_type VisitExpr_(const CallNode* op) final { + if (op->op.same_as(Op::Get("tir.if_then_else"))) { + PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2]; + auto nz_a = NonzeroCondition(true_val); + auto nz_b = NonzeroCondition(false_val); + + // We don't have as much freedom here as in the select case + // since the `if` must be preserved in any case + PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) { + return {new_cond, GetRef<PrimExpr>(op)}; + } else { + return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)}; + } + } else { + return Default_(GetRef<PrimExpr>(op)); + } + } + + result_type VisitExpr_(const ProducerLoadNode* op) final { + return Default_(GetRef<PrimExpr>(op)); + } + + NonzeroConditionResult Default_(const PrimExpr& e) { + // This is always correct, so it's the default + return {const_true(), e}; + } + + template <class T> + NonzeroConditionResult Const_(const T& op) { + if (op->value == 0) { + return {const_false(), op}; + } else { + return {const_true(), op}; + } + } + + template <class T> + NonzeroConditionResult BinOpAddLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For addition and similar ops the result may be nonzero if either of the arguments is + // nonzero, so we combine the conditions with Or. + if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) { + // If the conditions are the same, we don't need Or + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, nz_b.value)}; + } + } else { + // Otherwise use Or + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + // A little optimization: if the combined condition is the same as one of the inner + // conditions, we don't need to guard the inner value with a select, otherwise + // we create a select in the `to_expr` call. + PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr(); + PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr(); + PrimExpr new_expr = T(new_a, new_b); + return {new_cond, new_expr}; + } + } + + template <class T> + NonzeroConditionResult BinOpMulLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + auto nz_b = NonzeroCondition(op->b); + + // For multiplication and similar ops the result may be nonzero if + // both the arguments are nonzero, so we combine with And. + PrimExpr new_cond = + analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + + if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) { + return {new_cond, op}; + } else { + return {new_cond, T(nz_a.value, nz_b.value)}; + } + } + + template <class T> + NonzeroConditionResult BinOpDivLike_(const T& op) { + auto nz_a = NonzeroCondition(op->a); + + // For Div we simply use the condition of the numerator. + + if (nz_a.value.same_as(op->a)) { + return {nz_a.cond, op}; + } else { + return {nz_a.cond, T(nz_a.value, op->b)}; + } + } + + private: + arith::Analyzer analyzer_; +}; + +inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) { + return NonzeroConditionFunctor().NonzeroCondition(expr); +} + +struct FactorOutAtomicFormulasResult { + std::vector<PrimExpr> atomic_formulas; + PrimExpr rest; + + PrimExpr to_expr() const { + PrimExpr res = rest; + for (const PrimExpr& e : atomic_formulas) { + res = And(e, res); + } + return res; + } + + Array<PrimExpr> to_array() const { + Array<PrimExpr> res = atomic_formulas; + res.push_back(rest); + return res; + } +}; + +// The implementation of FactorOutAtomicFormulas +class FactorOutAtomicFormulasFunctor + : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> { + public: + result_type Atomic_(const PrimExpr& e) { + // For atomic expressions the result is the expr itself with True as the residual + return {{e}, make_const(e.dtype(), 1)}; + } + + // This is basically the list of expression kinds that are considered atomic + result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); } + + result_type VisitExpr_(const SelectNode* op) final { + // Select can be rewritten through other logical ops + PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value); + return VisitExpr(expr); + } + + result_type VisitExpr_(const NotNode* op) final { + // Not should be moved down + if (const OrNode* or_expr = op->a.as<OrNode>()) { + PrimExpr expr = !or_expr->a && !or_expr->b; + return VisitExpr(expr); + } else if (const AndNode* and_expr = op->a.as<AndNode>()) { + PrimExpr expr = !and_expr->a || !and_expr->b; + return VisitExpr(expr); + } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) { + PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) && + (sel_expr->condition || !sel_expr->false_value)); + return VisitExpr(expr); + } + return Atomic_(GetRef<PrimExpr>(op)); + } + + result_type VisitExpr_(const AndNode* op) final { + auto res_a = VisitExpr(op->a); + auto res_b = VisitExpr(op->b); + + // For the And case we return the union of the sets of atomic formulas + std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set; + res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), + std::inserter(res_set, res_set.end())); + std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), + std::inserter(res_set, res_set.end())); + + std::vector<PrimExpr> res{res_set.begin(), res_set.end()}; + + // And the residuals are combined with && + return {res, res_a.rest && res_b.rest}; + } + + result_type VisitExpr_(const MulNode* op) final { + // Since we work with bools, for multiplication we do the same thing as for And + PrimExpr e_and = op->a && op->b; + return VisitExpr(e_and); + } + + result_type VisitExpr_(const OrNode* op) final { + auto res_a = VisitExpr(op->a); + auto res_b = VisitExpr(op->b); + + std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{ + res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()}; + std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{ + res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()}; + + // For the Or case we intersect the sets of atomic formulas + std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set; + res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + for (const auto& res_b_formula : res_b_set) { + if (res_a_set.count(res_b_formula)) { + res_set.insert(res_b_formula); + } + } + + // Computing the residual is more complex: we have to compute the sets of atomic formulas + // which are left behind, and then combine them with the residuals into the new residual. + std::vector<PrimExpr> new_cond_a; + new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size()); + for (const auto& formula : res_a_set) { + if (!res_set.count(formula)) new_cond_a.emplace_back(formula); + } + + std::vector<PrimExpr> new_cond_b; + new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size()); + for (const auto& formula : res_b_set) { + if (!res_set.count(formula)) new_cond_b.emplace_back(formula); + } + + res_a.atomic_formulas = std::move(new_cond_a); + res_b.atomic_formulas = std::move(new_cond_b); + + PrimExpr new_rest = res_a.to_expr() || res_b.to_expr(); + std::vector<PrimExpr> res{res_set.begin(), res_set.end()}; + + return {res, new_rest}; + } +}; + +// Transform the given formula into a conjunction of atomic formulas (represented as an array) +// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b, +// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level. +FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) { + CHECK(e.dtype().is_bool()); + return FactorOutAtomicFormulasFunctor().VisitExpr(e); +} + +struct EliminateDivModResult { + PrimExpr expr; + Map<Var, PrimExpr> substitution; + Array<Var> new_variables; + Array<PrimExpr> conditions; + Map<Var, Range> ranges; +}; + +inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { + if (mode == kTruncDiv) { + return truncmod(a, b); + } else { + CHECK_EQ(mode, kFloorDiv); + return floormod(a, b); + } +} + +inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { + if (mode == kTruncDiv) { + return truncdiv(a, b); + } else { + CHECK_EQ(mode, kFloorDiv); + return floordiv(a, b); + } +} + +class EliminateDivModMutator : public ExprMutator { + public: + Map<Var, PrimExpr> substitution; + Array<Var> new_variables; + Array<PrimExpr> conditions; + Map<Var, Range> ranges; + + explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {} + + virtual PrimExpr VisitExpr_(const DivNode* op) { + const IntImmNode* imm = op->b.as<IntImmNode>(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x / -c == -(x/c) for truncated division + return make_zero(op->dtype) - + VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) { + return var_pair_opt.value().first; + } else { + return truncdiv(mutated_a, op->b); + } + } + + return truncdiv(VisitExpr(op->a), VisitExpr(op->b)); + } + + virtual PrimExpr VisitExpr_(const ModNode* op) { + const IntImmNode* imm = op->b.as<IntImmNode>(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x % -c == x % c for truncated division + return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) { + return var_pair_opt.value().second; + } else { + return truncmod(mutated_a, op->b); + } + } + + return truncmod(VisitExpr(op->a), VisitExpr(op->b)); + } + + virtual PrimExpr VisitExpr_(const FloorDivNode* op) { + const IntImmNode* imm = op->b.as<IntImmNode>(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x / -c == (-x) / c for flooring division + return VisitExpr( + floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.first; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) { + return var_pair_opt.value().first; + } else { + return floordiv(mutated_a, op->b); + } + } + + return floordiv(VisitExpr(op->a), VisitExpr(op->b)); + } + + virtual PrimExpr VisitExpr_(const FloorModNode* op) { + const IntImmNode* imm = op->b.as<IntImmNode>(); + if (imm && imm->value != 0) { + if (imm->value < 0) { + // x % -c == -(-x % c) for flooring division + return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a, + make_const(op->dtype, -imm->value))); + } + + // Try to find the already existing variables for this expression + auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value)); + if (it != expr_to_vars_.end()) { + return it->second.second; + } + + // Otherwise recursively mutate the left hand side, and create new variables + PrimExpr mutated_a = VisitExpr(op->a); + if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) { + return var_pair_opt.value().second; + } else { + return floormod(mutated_a, op->b); + } + } + + return floormod(VisitExpr(op->a), VisitExpr(op->b)); + } + + private: + dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut, + int64_t val, DivMode mode) { + using tresult = dmlc::optional<std::pair<Var, Var>>; + + // Try to find the variables using the mutated expressions + if (!e.same_as(mut)) { + auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val)); + if (it != expr_to_vars_.end()) { + return tresult(it->second); + } + } + + PrimExpr val_e = make_const(e.dtype(), val); + idx_ += 1; + + // Convert `ranges` to IntSets + std::unordered_map<const VarNode*, IntSet> var_intsets; + for (const auto& p : ranges) { + var_intsets[p.first.get()] = IntSet::FromRange(p.second); + } + + // Infer ranges for the expressions we want to replace with variables + Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range()); + Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range()); + + // We don't want to add unbounded variables + if (!div_range.get() || !mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode) + << " because its bounds cannot be inferred"; + return tresult(); + } + if (!mod_range.get()) { + LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode) + << " because its bounds cannot be inferred"; + return tresult(); + } + + // Create new variables for the expressions + auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype()); + auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype()); + + new_variables.push_back(div); + new_variables.push_back(mod); + + // Note that we have to perform substitution to mut because mut may contain new variables + substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode)); + substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode)); + + ranges.Set(div, div_range); + ranges.Set(mod, mod_range); + + // This additional condition works as a definition for the new variables + conditions.push_back(mut == div * val_e + mod); + + if (!analyzer_.CanProve(mod_range->extent <= val_e)) { + // Since we use the C/C++ definition of mod, there may be multiple values of `mod` + // satisfying the added condition if the expr `e` may change its sign, so we + // have to add another condition. + LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because " + << ModImpl(e, val_e, mode) << " probably may change its sign"; + conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0)); + } + + auto p = std::make_pair(div, mod); + expr_to_vars_[std::make_tuple(mode, e, val)] = p; + if (!e.same_as(mut)) { + expr_to_vars_[std::make_tuple(mode, mut, val)] = p; + } + return tresult(p); + } + + class TupleEqual_ { + public: + bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs, + const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const { + return std::get<0>(lhs) == std::get<0>(rhs) && + tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) && + std::get<2>(lhs) == std::get<2>(rhs); + } + }; + + class TupleHasher_ { + public: + size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const { + return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >> + 1) ^ + (std::hash<int64_t>()(std::get<2>(key)) << 1); + } + }; + + // A counter for naming new variables + int idx_{0}; + // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod) + // such that `div = e / n` and `mod = e % n` + std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_, + TupleEqual_> + expr_to_vars_; + arith::Analyzer analyzer_; +}; + +// Replace every subexpr of the form e/const and e % const with a new variable. +// Syntactically equal expressions will be mapped to the same variable. +EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) { + EliminateDivModResult res; + EliminateDivModMutator mutator(ranges); + res.expr = mutator(expr); + res.conditions = std::move(mutator.conditions); + res.new_variables = std::move(mutator.new_variables); + res.substitution = std::move(mutator.substitution); + res.ranges = std::move(mutator.ranges); + return res; +} + +arith::IntConstraintsTransform EliminateDivModFromDomainConditions( + const arith::IntConstraints& domain) { + auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges); + + Map<Var, Range> new_vranges = elim_res.ranges; + Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables); + PrimExpr new_cond = elim_res.expr && All(elim_res.conditions); + + arith::IntConstraints new_domain(new_axis, new_vranges, + FactorOutAtomicFormulas(new_cond).to_array()); + + Map<Var, PrimExpr> src_to_dst; + Map<Var, PrimExpr> dst_to_src = elim_res.substitution; + for (const Var& v : domain->variables) { + src_to_dst.Set(v, v); + dst_to_src.Set(v, v); + } + + return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src); +} + +// Simplify an iteration domain. +inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) { + Map<Var, PrimExpr> identity_map; + for (const Var& v : domain->variables) { + identity_map.Set(v, v); + } + return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map); +} + +arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains, + bool eliminate_div_mod) { + arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains); + + if (eliminate_div_mod) { + transf = transf + EliminateDivModFromDomainConditions(transf->dst); + } + + // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably + // should find a better terminating criterion (like stop when the domain volume stops decreasing) + // Also 2 steps seems to be slightly better than 3 + for (size_t i = 0; i < 2; ++i) { + arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst); + transf = transf + tr; + // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about + // enabling it in general. The problem it solves is propagating equalities of outer vars. + // tr = AddOuterVariablesIntoDomain(transf->dst); + tr = arith::SolveInequalitiesDeskewRange(transf->dst); + transf = transf + tr; + } + + return transf; +} + +// Use the condition of a reduction op to simplify its domain (axis) +PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) { + if (const ReduceNode* red = expr.as<ReduceNode>()) { + Array<Var> vars = IterVarsToVars(red->axis); + Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis)); + Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array(); + + arith::IntConstraints domain(vars, vranges, relations); + auto res = SimplifyDomain(domain); + + Array<PrimExpr> new_source; + for (const PrimExpr& src : red->source) { + new_source.push_back(Substitute(src, res->src_to_dst)); + } + + Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce); + + // Perform simplification mainly to remove a possibly empty reduction. + arith::Analyzer analyzer; + return analyzer.Simplify( + Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index), + ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE); + } else { + return expr; + } +} + +// Extract from cond an implication of cond not containing vars +std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars( + const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) { + CHECK(cond.dtype().is_bool()) << "The type of cond must be bool"; + // TODO(sgrechanik-h): NOT Review comment: Actually in my understanding it's not straightforward to separate NOT node here, as the false branch of (!pair.a) will also contain the reduction (instead the zero). I'm not sure whether it provides benefit, @sergei-grechanik would you help to comment? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
