Lunderberg commented on code in PR #13130:
URL: https://github.com/apache/tvm/pull/13130#discussion_r1023118214


##########
src/tir/analysis/control_flow_graph.cc:
##########
@@ -0,0 +1,1642 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file control_flow_graph.cc
+ * \brief Utility to deduce bound of expression
+ */
+
+#include "control_flow_graph.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <numeric>
+#include <optional>
+#include <queue>
+#include <set>
+#include <sstream>
+#include <unordered_set>
+
+#include "../../arith/conjunctive_normal_form.h"
+#include "../../arith/constraint_extract.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+#include "../../arith/narrow_predicate_expression.h"
+#include "../../arith/unwrap_vector_expr.h"
+
+namespace tvm {
+namespace tir {
+
+using namespace arith;
+
+namespace {
+bool HasBufferLoad(PrimExpr expr) {
+  struct Visitor : public ExprVisitor {
+    void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = 
true; }
+    bool found_buffer_load{false};
+  };
+
+  Visitor visitor;
+  visitor(expr);
+  return visitor.found_buffer_load;
+}
+
+Optional<PrimExpr> SubstituteParamValues(const Array<Var>& param_vars,
+                                         const Array<PrimExpr>& param_values,
+                                         const PrimExpr& expr) {
+  ICHECK_EQ(param_vars.size(), param_values.size())
+      << "Expression was defined as having " << param_vars.size() << " 
parameters, but received "
+      << param_values.size() << " arguments.";
+
+  Map<tir::Var, PrimExpr> var_map;
+  for (size_t i = 0; i < param_values.size(); i++) {
+    var_map.Set(param_vars[i], param_values[i]);
+  }
+
+  return Substitute(expr, var_map);
+}
+}  // namespace
+
+PrimExpr BufferTouch::BeforeLoopIteration() const {
+  PrimExpr loop_predicate = Bool(true);
+  for (auto it = loop_var_expressions.rbegin(); it != 
loop_var_expressions.rend(); it++) {
+    const Var& loop_var = it->first;
+    const PrimExpr& loop_expr = it->second;
+    loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && 
loop_predicate);
+  }
+  return loop_predicate;
+}
+
+PrimExpr BufferTouch::AtLoopIteration() const {
+  PrimExpr loop_predicate = Bool(true);
+  for (auto it = loop_var_expressions.rbegin(); it != 
loop_var_expressions.rend(); it++) {
+    const Var& loop_var = it->first;
+    const PrimExpr& loop_expr = it->second;
+    loop_predicate = (loop_var == loop_expr) && loop_predicate;
+  }
+  return loop_predicate;
+}
+
+PrimExpr BufferTouch::AfterLoopIteration() const {
+  PrimExpr loop_predicate = Bool(true);
+  for (auto it = loop_var_expressions.rbegin(); it != 
loop_var_expressions.rend(); it++) {
+    const Var& loop_var = it->first;
+    const PrimExpr& loop_expr = it->second;
+    loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && 
loop_predicate);
+  }
+  return loop_predicate;
+}
+
+bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) 
const {
+  if (this->buffer.same_as(other.buffer)) {
+    With<ConstraintContext> constraint(analyzer, predicate);
+
+    return analyzer->CanProve(other.predicate);
+  } else {
+    return false;
+  }
+}
+
+bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) 
const {
+  if (this->buffer.same_as(other.buffer)) {
+    With<ConstraintContext> constraint(analyzer, predicate);
+
+    return analyzer->CanProve(!other.predicate);
+  } else {
+    return true;
+  }
+}
+
+std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) {
+  auto touch_type = [&]() {
+    if (tp.touch_type == BufferTouch::AccessType::Read) {
+      return "read";
+    } else if (tp.touch_type == BufferTouch::AccessType::Write) {
+      return "write";
+    } else if (tp.touch_type == BufferTouch::AccessType::Assume) {
+      return "assume";
+    } else {
+      return "???";
+    }
+  }();
+
+  os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << 
tp.predicate
+     << ", value = " << tp.value << ")";
+  return os;
+}
+
+class BufferConstraintApply : public IRMutatorWithAnalyzer {
+ public:
+  using Parent = IRMutatorWithAnalyzer;
+
+  BufferConstraintApply(const Map<Buffer, Array<Var>>& axis_var_lookup,
+                        const std::vector<BufferTouch>& knowns, Analyzer* 
analyzer)
+      : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {}
+
+  using Parent::VisitExpr_;
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
+    for (const auto& known : knowns_) {
+      if (!op->buffer.same_as(known.buffer)) {
+        continue;
+      }
+
+      Optional<Var> lane_var = NullOpt;
+      IntImm num_lanes;
+
+      Array<PrimExpr> indices = op->indices.Map([&](const auto& index) {
+        if (index.dtype().lanes() == 1) {
+          return index;
+        } else {
+          ICHECK(!lane_var) << "Multiple indices found with non-scalar values";
+          lane_var = Var("lane", index.dtype().element_of());
+          num_lanes = IntImm(index.dtype().element_of(), 
index.dtype().lanes());
+          return UnwrapVectorExpr(index, lane_var.value());
+        }
+      });
+
+      auto axis_vars = axis_var_lookup_.at(op->buffer);
+      PrimExpr predicate = SubstituteParamValues(axis_vars, indices, 
known.predicate).value();
+
+      std::optional<With<ConstraintContext>> context;
+      if (lane_var.defined()) {
+        Var lanes = lane_var.value();
+        PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < 
num_lanes);
+        context.emplace(analyzer_, known);
+      }
+
+      if (analyzer_->CanProve(predicate)) {
+        return SubstituteParamValues(axis_vars, op->indices, 
known.value).value();
+      }
+    }
+
+    return GetRef<PrimExpr>(op);
+  }
+
+ private:
+  const Map<Buffer, Array<Var>>& axis_var_lookup_;
+  const std::vector<BufferTouch>& knowns_;
+};
+
+/*! \brief Extract the control-flow graph
+ *
+ * Walk through a statement, populating the control-flow graph.
+ */
+class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer {
+ public:
+  static void Build(ControlFlowGraph* out, const Stmt& stmt) {
+    ControlFlowGraphBuilder extractor(out);
+    extractor.AppendControlBlock();
+    extractor(stmt);
+  }
+
+ private:
+  ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {}
+
+  using Parent = IRVisitorWithAnalyzer;
+  using Parent::VisitExpr_;
+  using Parent::VisitStmt_;
+
+  void VisitStmt(const Stmt& stmt) override {
+    // Update the lookup table to determine which control-flow block
+    // contains the start of the specified statement.  This is used
+    // later to determine which set of known values should be used to
+    // simplify a statement.
+    out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock();
+    Stmt prev_stmt = current_stmt_;
+    current_stmt_ = stmt;
+    Parent::VisitStmt(stmt);
+    current_stmt_ = prev_stmt;
+  }
+
+  void VisitStmt_(const EvaluateNode* op) override {
+    if (auto* call = op->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::assume())) {
+        Assume(call->args[0], true);
+        return;
+      }
+    }
+
+    Parent::VisitStmt_(op);
+  }
+
+  void Assume(PrimExpr assumption, bool from_assume_statement) {
+    for (const auto& expr : ExtractConstraints(assumption, false)) {
+      AssumeConstraintComponent(expr, from_assume_statement);
+    }
+  }
+
+  void AssumeConstraintComponent(PrimExpr assumption, bool 
from_assume_statement) {
+    PrimExpr additional_predicate = Bool(true);
+
+    std::vector<PrimExpr> buffer_exprs;
+    for (const auto& expr : ExtractComponents(assumption)) {
+      auto side_effect = tir::SideEffect(expr);
+      if (side_effect <= tir::CallEffectKind::kPure) {
+        // Pulling out portions of the assumption that do not depend
+        // on a buffer value allows the following two forms to be
+        // treated identically.
+        //
+        // if i < 3: T.assume(buf[i] == value)
+        // T.assume(i>=3 or buf[i] == value)
+        additional_predicate = additional_predicate && logical_not(expr);
+      } else if (side_effect == tir::CallEffectKind::kReadState) {
+        buffer_exprs.push_back(expr);
+      } else {
+        LOG(FATAL) << "Assumption must be pure or read-only, but contained 
expression " << expr
+                   << " with side-effect \'" << side_effect << "\'";
+      }
+    }
+
+    if (buffer_exprs.empty()) {
+      out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || 
assumption);
+      return;
+    }
+
+    CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single 
buffer expression";
+
+    auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>();
+    CHECK(as_equal_node || !from_assume_statement)
+        << "T.assume buffer constraint must be of the form 'buffer[indices] == 
"
+           "value', but received "
+        << assumption;
+    if (!as_equal_node) {
+      // This assumption is an inequality on a data-dependent
+      // conditional.  Not an error for this to occur, but also not
+      // something that is currently supported.
+      return;
+    }

Review Comment:
   If the data-dependent inequality occurs in `T.assume()`, it raises an error 
because the primary purpose of `T.assume()` is to provide information for this 
analysis.  The comment applies specifically to scoped contexts that have a 
data-dependent conditional (e.g. `if buf[i] > 10: ...`), which are valid TIR 
but not something that is used.
   
   I've added a description of this limitation to the docstring of 
`ControlFlowGraph`, stating that only fully-constrained values are tracked, so that the 
limitation is more visible.  However, since it is entirely legal TIR to write 
these data-dependent inequalities, I don't think it should raise an error when 
encountering them.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to