Lunderberg commented on code in PR #13130: URL: https://github.com/apache/tvm/pull/13130#discussion_r1019605163
########## src/tir/analysis/control_flow_graph.cc: ########## @@ -0,0 +1,1641 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file control_flow_graph.cc + * \brief Construction of the control-flow graph of a TIR statement, together + * with the BufferTouch utilities (reads, writes, and assumptions of + * buffer values) used when analyzing it. + */ + +#include "control_flow_graph.h" + +#include <tvm/runtime/registry.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/expr.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt_functor.h> + +#include <numeric> +#include <optional> +#include <queue> +#include <set> +#include <sstream> +#include <unordered_set> + +#include "../../arith/conjunctive_normal_form.h" +#include "../../arith/constraint_extract.h" +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/ir_visitor_with_analyzer.h" +#include "../../arith/narrow_predicate_expression.h" +#include "../../arith/unwrap_vector_expr.h" + +namespace tvm { +namespace tir { + +using namespace arith; + +namespace { +/*! \brief Check whether `expr` contains a BufferLoad anywhere in its tree. + * + * The BufferLoadNode override below does not call the base visitor, so the + * indices of a BufferLoad are not themselves searched. This cannot change the + * result: reaching any BufferLoad already sets the flag. + */ +bool HasBufferLoad(PrimExpr expr) { + struct Visitor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; } + bool found_buffer_load{false}; + }; + + Visitor visitor; + visitor(expr); + return visitor.found_buffer_load; +} + 
+/*! \brief Substitute `param_values` for `param_vars` within `expr`. + * + * The two arrays are matched positionally; an ICHECK fails if their lengths + * differ. + */ +Optional<PrimExpr> SubstituteParamValues(const Array<Var>& param_vars, + const Array<PrimExpr>& param_values, + const PrimExpr& expr) { + ICHECK_EQ(param_vars.size(), param_values.size()) + << "Expression was defined as having " << param_vars.size() << " parameters, but received " + << param_values.size() << " arguments."; + + Map<tir::Var, PrimExpr> var_map; + for (size_t i = 0; i < param_values.size(); i++) { + var_map.Set(param_vars[i], param_values[i]); + } + + return Substitute(expr, var_map); +} +} // namespace + +/*! \brief Predicate that the current loop iteration is at or before the + * iteration described by `loop_var_expressions`. + * + * The predicate is a lexicographic comparison built from the innermost loop + * outward. An outer loop variable strictly less than its expression satisfies + * the predicate regardless of inner loops; only on equality are the inner + * loops compared. NOTE(review): assumes `loop_var_expressions` is ordered + * outermost-first -- confirm against its declaration. + */ +PrimExpr BufferTouch::BeforeLoopIteration() const { + PrimExpr loop_predicate = Bool(true); + for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { + const Var& loop_var = it->first; + const PrimExpr& loop_expr = it->second; + // Must be strict: with `<=` the equality branch is subsumed and the + // constraints on all inner loops are silently dropped. + loop_predicate = (loop_var < loop_expr) || ((loop_var == loop_expr) && loop_predicate); + } + return loop_predicate; +} + +/*! \brief Predicate that every loop variable equals its expression in + * `loop_var_expressions`. + */ +PrimExpr BufferTouch::AtLoopIteration() const { + PrimExpr loop_predicate = Bool(true); + for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { + const Var& loop_var = it->first; + const PrimExpr& loop_expr = it->second; + loop_predicate = (loop_var == loop_expr) && loop_predicate; + } + return loop_predicate; +} + +/*! \brief Predicate that the current loop iteration is at or after the + * iteration described by `loop_var_expressions`. + * + * Mirror image of BeforeLoopIteration(): a strict `>` at each level, with + * inner loops compared only when the outer loop variable is equal. + */ +PrimExpr BufferTouch::AfterLoopIteration() const { + PrimExpr loop_predicate = Bool(true); + for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { + const Var& loop_var = it->first; + const PrimExpr& loop_expr = it->second; + // Must be strict for the same reason as in BeforeLoopIteration(). + loop_predicate = (loop_var > loop_expr) || ((loop_var == loop_expr) && loop_predicate); + } + return loop_predicate; +} + +/*! \brief True if, assuming this touch's predicate, `analyzer` can prove + * `other`'s predicate, i.e. the indices touched here are a subset of those + * touched by `other`. Touches of different buffers are never subsets. + */ +bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const { + if (this->buffer.same_as(other.buffer)) { + With<ConstraintContext> constraint(analyzer, predicate); + + return analyzer->CanProve(other.predicate); + } else { + return false; + } +} + +/*! \brief True if, assuming this touch's predicate, `analyzer` can prove + * that `other`'s predicate is false, i.e. the touched indices do not + * overlap. Touches of different buffers are always distinct. + */ +bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* 
analyzer) const { + if (this->buffer.same_as(other.buffer)) { + With<ConstraintContext> constraint(analyzer, predicate); + + return analyzer->CanProve(!other.predicate); + } else { + return true; + } +} + +std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) { + auto touch_type = [&]() { + if (tp.touch_type == BufferTouch::AccessType::Read) { + return "read"; + } else if (tp.touch_type == BufferTouch::AccessType::Write) { + return "write"; + } else if (tp.touch_type == BufferTouch::AccessType::Assume) { + return "assume"; + } else { + return "???"; + } + }(); + + os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate + << ", value = " << tp.value << ")"; + return os; +} + +/*! \brief Rewrite BufferLoads using known buffer values. + * + * A BufferLoad is replaced by the value of the first entry of `knowns` on the + * same buffer whose predicate can be proven for the load's indices. A + * vectorized index is unwrapped to a scalar `lane` variable, constrained to + * [0, lanes), before proving; at most one index may be vectorized. + * + * NOTE(review): `axis_var_lookup` and `knowns` are held by reference, so both + * must outlive this mutator. + */ +class BufferConstraintApply : public IRMutatorWithAnalyzer { + public: + using Parent = IRMutatorWithAnalyzer; + + BufferConstraintApply(const Map<Buffer, Array<Var>>& axis_var_lookup, + const std::vector<BufferTouch>& knowns, Analyzer* analyzer) + : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} + + using Parent::VisitExpr_; + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + for (const auto& known : knowns_) { + if (!op->buffer.same_as(known.buffer)) { + continue; + } + + Optional<Var> lane_var = NullOpt; + IntImm num_lanes; + + Array<PrimExpr> indices = op->indices.Map([&](const auto& index) { + if (index.dtype().lanes() == 1) { + return index; + } else { + ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; + lane_var = Var("lane", index.dtype().element_of()); + num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); + return UnwrapVectorExpr(index, lane_var.value()); + } + }); + + auto axis_vars = axis_var_lookup_.at(op->buffer); + PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value(); + + std::optional<With<ConstraintContext>> context; + if (lane_var.defined()) { + Var lanes = lane_var.value(); + // Restrict the symbolic lane to [0, num_lanes) while proving `predicate`. + PrimExpr lane_range = (IntImm(lanes.dtype(), 0) <= 
lanes) && (lanes < num_lanes); + context.emplace(analyzer_, lane_range); + } + + if (analyzer_->CanProve(predicate)) { + return SubstituteParamValues(axis_vars, op->indices, known.value).value(); + } + } + + return GetRef<PrimExpr>(op); + } + + private: + const Map<Buffer, Array<Var>>& axis_var_lookup_; + const std::vector<BufferTouch>& knowns_; +}; + +/*! \brief Extract the control-flow graph + * + * Walk through a statement, populating the control-flow graph. + */ +class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { + public: + static void Build(ControlFlowGraph* out, const Stmt& stmt) { + ControlFlowGraphBuilder extractor(out); + extractor.AppendControlBlock(); + extractor(stmt); + } + + private: + ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {} + + using Parent = IRVisitorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt_; + + void VisitStmt(const Stmt& stmt) override { + // Update the lookup table to determine which control-flow block + // contains the start of the specified statement. This is used + // later to determine which set of known values should be used to + // simplify a statement. 
+ out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock(); + Stmt prev_stmt = current_stmt_; + current_stmt_ = stmt; + Parent::VisitStmt(stmt); + current_stmt_ = prev_stmt; + } + + void VisitStmt_(const EvaluateNode* op) override { + if (auto* call = op->value.as<CallNode>()) { + if (call->op.same_as(builtin::assume())) { + Assume(call->args[0], true); + return; + } + } + + Parent::VisitStmt_(op); + } + + void Assume(PrimExpr assumption, bool from_assume_statement) { + for (const auto& expr : ExtractConstraints(assumption, false)) { + AssumeConstraintComponent(expr, from_assume_statement); + } + } + + void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) { + PrimExpr additional_predicate = Bool(true); + + std::vector<PrimExpr> buffer_exprs; + for (const auto& expr : ExtractComponents(assumption)) { + auto side_effect = tir::SideEffect(expr); + if (side_effect <= tir::CallEffectKind::kPure) { + // Pulling out portions of the assumption that do not depend + // on a buffer value allows the following two forms to be + // treated identically. + // + // if i < 3: T.assume(buf[i] == value) + // T.assume(i>=3 or buf[i] == value) + additional_predicate = additional_predicate && logical_not(expr); + } else if (side_effect == tir::CallEffectKind::kReadState) { + buffer_exprs.push_back(expr); + } else { + LOG(FATAL) << "Assumption must be pure or read-only"; + } + } + + if (buffer_exprs.empty()) { + out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption); + return; + } + + CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; + + auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>(); + CHECK(as_equal_node || !from_assume_statement) + << "T.assume buffer constraint must be of the form 'buffer[indices] == " + "value', but received " + << assumption; + if (!as_equal_node) { + // This assumption is an inequality a data-dependent + // conditional. 
Not an error for this to occur, but also not Review Comment: Thank you, and updated. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
