Lunderberg commented on code in PR #13085:
URL: https://github.com/apache/tvm/pull/13085#discussion_r997068685


##########
src/tir/ir/index_map.cc:
##########
@@ -147,6 +147,37 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
   return inverse;
 }
 
+/*!
+ * \brief Evaluator to compute the mapped indices of a given index map.
+ */
+class IndexMapEvaluator : public DataTypeLegalizer {
+ public:
+  explicit IndexMapEvaluator(const IndexMap& index_map) : index_map_(index_map) {}
+
+  Array<PrimExpr> Eval(const Array<PrimExpr>& arguments, arith::Analyzer* analyzer) {
+    var_map_.clear();
+    ICHECK_EQ(arguments.size(), index_map_->initial_indices.size());
+    for (int i = 0; i < static_cast<int>(arguments.size()); ++i) {
+      var_map_.Set(index_map_->initial_indices[i], arguments[i]);
+    }
+    Array<PrimExpr> result = index_map_->final_indices;
+    result.MutateByApply([&](PrimExpr expr) { return analyzer->Simplify(this->VisitExpr(expr)); });

Review Comment:
   Since the simplification is only done as a post-processing step and isn't 
needed during the datatype rewrite itself, there wouldn't be a performance 
difference between simplifying inside or outside the `Eval` call.  
Could we remove the `Analyzer*` argument and instead perform simplification 
outside of this method?
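A minimal sketch of that split, using a hypothetical toy expression type in place of `PrimExpr` so it compiles without TVM. Here `Eval` only binds arguments and substitutes, and a hypothetical free function `MapIndices` runs a separate `Simplify` pass afterwards. None of these names are TVM's API; the sketch only illustrates where the simplification step would live.

```c++
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR standing in for TVM's PrimExpr: a node is a variable,
// an integer constant, or the sum of two sub-expressions.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
struct Expr {
  enum Kind { kVar, kConst, kAdd } kind;
  std::string name;  // set when kind == kVar
  long value;        // set when kind == kConst
  ExprPtr lhs, rhs;  // set when kind == kAdd
};

ExprPtr MakeVar(std::string name) {
  return std::make_shared<Expr>(Expr{Expr::kVar, std::move(name), 0, nullptr, nullptr});
}
ExprPtr MakeConst(long v) {
  return std::make_shared<Expr>(Expr{Expr::kConst, "", v, nullptr, nullptr});
}
ExprPtr MakeAdd(ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{Expr::kAdd, "", 0, std::move(a), std::move(b)});
}

// Pure substitution: replaces bound variables and never simplifies.
ExprPtr Substitute(const ExprPtr& e, const std::map<std::string, ExprPtr>& vmap) {
  switch (e->kind) {
    case Expr::kVar: {
      auto it = vmap.find(e->name);
      return it == vmap.end() ? e : it->second;
    }
    case Expr::kConst:
      return e;
    case Expr::kAdd:
      return MakeAdd(Substitute(e->lhs, vmap), Substitute(e->rhs, vmap));
  }
  return e;
}

// Separate post-processing pass: folds additions of two constants.
ExprPtr Simplify(const ExprPtr& e) {
  if (e->kind != Expr::kAdd) return e;
  ExprPtr a = Simplify(e->lhs);
  ExprPtr b = Simplify(e->rhs);
  if (a->kind == Expr::kConst && b->kind == Expr::kConst) {
    return MakeConst(a->value + b->value);
  }
  return MakeAdd(a, b);
}

// Evaluator without an analyzer parameter: binds arguments, then substitutes.
std::vector<ExprPtr> Eval(const std::vector<std::string>& initial_indices,
                          const std::vector<ExprPtr>& final_indices,
                          const std::vector<ExprPtr>& arguments) {
  assert(arguments.size() == initial_indices.size());  // mirrors the ICHECK_EQ
  std::map<std::string, ExprPtr> vmap;
  for (std::size_t i = 0; i < arguments.size(); ++i) {
    vmap[initial_indices[i]] = arguments[i];
  }
  std::vector<ExprPtr> result;
  result.reserve(final_indices.size());
  for (const ExprPtr& e : final_indices) result.push_back(Substitute(e, vmap));
  return result;
}

// Hypothetical caller: simplification is applied only after Eval returns.
std::vector<ExprPtr> MapIndices(const std::vector<std::string>& initial_indices,
                                const std::vector<ExprPtr>& final_indices,
                                const std::vector<ExprPtr>& indices) {
  std::vector<ExprPtr> output = Eval(initial_indices, final_indices, indices);
  for (ExprPtr& e : output) e = Simplify(e);
  return output;
}
```

With this split, a caller that wants the raw substituted expressions can skip the simplification pass entirely.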



##########
src/tir/ir/index_map.cc:
##########
@@ -147,6 +147,37 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
   return inverse;
 }
 
+/*!
+ * \brief Evaluator to compute the mapped indices of a given index map.
+ */
+class IndexMapEvaluator : public DataTypeLegalizer {

Review Comment:
   Should this be part of a more general utility?  It looks like this 
is effectively a variation of 
[`tvm::tir::Substitute`](https://github.com/apache/tvm/blob/main/include/tvm/tir/stmt_functor.h#L368)
that allows the substituted values to have a different datatype than the 
replaced var's type.  Rather than having a utility that is specific to 
`IndexMap`, should this functionality be an option when calling `Substitute`?
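One way such an option might look, sketched with a hypothetical toy IR in which every node carries an integer bit width. The `Substitute` function and its `allow_dtype_change` flag below are illustrative assumptions, not the signature of `tvm::tir::Substitute`. In strict mode the toy requires each replacement to keep the variable's width; with the flag set, it re-legalizes each `Add` by casting the narrower operand up, loosely analogous to what the PR's `DataTypeLegalizer` base class is presumably used for.

```c++
#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR: every node carries an integer bit width, loosely
// mirroring DataType::Int(bits). This is not TVM's actual API.
struct TExpr;
using TExprPtr = std::shared_ptr<const TExpr>;
struct TExpr {
  enum Kind { kVar, kCast, kAdd } kind;
  int bits;           // width of the node's integer type
  std::string name;   // kVar only
  TExprPtr lhs, rhs;  // kCast uses lhs; kAdd uses lhs and rhs
};

TExprPtr MakeVar(std::string name, int bits) {
  return std::make_shared<TExpr>(TExpr{TExpr::kVar, bits, std::move(name), nullptr, nullptr});
}

// Returns e unchanged when it already has the requested width.
TExprPtr MakeCast(TExprPtr e, int bits) {
  if (e->bits == bits) return e;
  return std::make_shared<TExpr>(TExpr{TExpr::kCast, bits, "", std::move(e), nullptr});
}

// A well-typed Add requires both operands to share one width.
TExprPtr MakeAdd(TExprPtr a, TExprPtr b) {
  assert(a->bits == b->bits);
  const int bits = a->bits;
  return std::make_shared<TExpr>(TExpr{TExpr::kAdd, bits, "", std::move(a), std::move(b)});
}

// Hypothetical Substitute with an opt-in flag for width-changing replacements.
TExprPtr Substitute(const TExprPtr& e, const std::map<std::string, TExprPtr>& vmap,
                    bool allow_dtype_change) {
  switch (e->kind) {
    case TExpr::kVar: {
      auto it = vmap.find(e->name);
      if (it == vmap.end()) return e;
      // Strict mode: the replacement must keep the variable's width.
      assert(allow_dtype_change || it->second->bits == e->bits);
      return it->second;
    }
    case TExpr::kCast:
      return MakeCast(Substitute(e->lhs, vmap, allow_dtype_change), e->bits);
    case TExpr::kAdd: {
      TExprPtr a = Substitute(e->lhs, vmap, allow_dtype_change);
      TExprPtr b = Substitute(e->rhs, vmap, allow_dtype_change);
      // Re-legalize: widen the narrower operand so the Add stays well-typed.
      const int bits = std::max(a->bits, b->bits);
      return MakeAdd(MakeCast(a, bits), MakeCast(b, bits));
    }
  }
  return e;
}
```

Keeping the strict behavior as the default means existing callers are unaffected, while `IndexMap` could opt in to the width-changing mode.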



##########
src/tir/ir/index_map.cc:
##########
@@ -162,18 +193,18 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
     analyzer = &local_analyzer;
   }
 
-  Array<PrimExpr> output = final_indices.Map(
-      [&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });
-
+  Array<PrimExpr> output = IndexMapEvaluator(GetRef<IndexMap>(this)).Eval(indices, analyzer);
   return output;
 }
 
 Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
   ICHECK_EQ(ranges.size(), initial_indices.size());
 
   Map<Var, Range> input_iters;
+  int max_bits = 0;

Review Comment:
   Also, can we move the `auto output_dtype = DataType::Int(max_bits);` line 
closer to the initialization of `max_bits`?  That way, it's clear to a reader 
that the sole purpose of `max_bits` is to generate `output_dtype`.  
We could even move it into a narrower scope, to make that explicit for the 
compiler as well.
   
   ```c++
   auto output_dtype = [&]() {
     int max_bits = 0;
     for (const auto& range : ranges) {
       max_bits = std::max(max_bits, range->extent.dtype().bits());
     }
     return DataType::Int(max_bits);
   }();
   ```
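A self-contained version of that immediately-invoked-lambda pattern, with hypothetical `ToyRange` and `ToyDataType` types standing in for TVM's `Range` and `DataType`, shows that only the computed dtype escapes the lambda:

```c++
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-ins: ToyRange keeps only the bit width of a range's
// extent dtype, and ToyDataType plays the role of DataType::Int(bits).
struct ToyRange { int extent_bits; };
struct ToyDataType { int bits; };

ToyDataType ComputeOutputDtype(const std::vector<ToyRange>& ranges) {
  // max_bits exists only inside the immediately invoked lambda, so no later
  // statement can read or modify it; only output_dtype escapes.
  const ToyDataType output_dtype = [&]() {
    int max_bits = 0;
    for (const auto& range : ranges) {
      max_bits = std::max(max_bits, range.extent_bits);
    }
    return ToyDataType{max_bits};
  }();
  return output_dtype;
}
```

Declaring the result `const` is an optional extra step; it additionally prevents `output_dtype` itself from being reassigned later in the function.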



##########
src/tir/ir/index_map.cc:
##########
@@ -162,18 +193,18 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
     analyzer = &local_analyzer;
   }
 
-  Array<PrimExpr> output = final_indices.Map(
-      [&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });
-
+  Array<PrimExpr> output = IndexMapEvaluator(GetRef<IndexMap>(this)).Eval(indices, analyzer);
   return output;
 }
 
 Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
   ICHECK_EQ(ranges.size(), initial_indices.size());
 
   Map<Var, Range> input_iters;
+  int max_bits = 0;

Review Comment:
   Can we separate the loop that computes `max_bits` from the loop that 
computes `input_iters`?  The two don't depend on each other, so interleaving 
their initialization is a bit confusing.  The computation of `max_bits` 
also doesn't depend on `initial_indices`, so it could use a range-based 
for-loop for readability instead of the C-style loop over `i`.
   
   ```c++
   int max_bits = 0;
   for (const auto& range : ranges) {
     max_bits = std::max(max_bits, range->extent.dtype().bits());
   }
   ```
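A sketch with both loops side by side, substituting hypothetical `std::map`/`std::vector` types and a made-up `Setup` function for TVM's `Map<Var, Range>`, `Array<Range>`, and the body of `MapRanges`. The indexed loop is kept only where `initial_indices[i]` and `ranges[i]` must be paired, and `max_bits` uses the range-based form:

```c++
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for TVM types: ToyRange models a Range plus the bit
// width of its extent's dtype, and the std containers replace Map/Array.
struct ToyRange {
  long min;
  long extent;
  int extent_bits;
};

struct MapRangesSetup {
  std::map<std::string, ToyRange> input_iters;
  int max_bits;
};

MapRangesSetup Setup(const std::vector<std::string>& initial_indices,
                     const std::vector<ToyRange>& ranges) {
  assert(ranges.size() == initial_indices.size());

  // Loop 1 pairs each index variable with its range, so it needs the shared index i.
  std::map<std::string, ToyRange> input_iters;
  for (std::size_t i = 0; i < ranges.size(); ++i) {
    input_iters[initial_indices[i]] = ranges[i];
  }

  // Loop 2 only reads ranges, so a range-based for-loop is enough.
  int max_bits = 0;
  for (const auto& range : ranges) {
    max_bits = std::max(max_bits, range.extent_bits);
  }

  return {input_iters, max_bits};
}
```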



##########
src/tir/ir/index_map.cc:
##########
@@ -147,6 +147,37 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
   return inverse;
 }
 
+/*!
+ * \brief Evaluator to compute the mapped indices of a given index map.
+ */
+class IndexMapEvaluator : public DataTypeLegalizer {
+ public:
+  explicit IndexMapEvaluator(const IndexMap& index_map) : index_map_(index_map) {}
+
+  Array<PrimExpr> Eval(const Array<PrimExpr>& arguments, arith::Analyzer* analyzer) {
+    var_map_.clear();
+    ICHECK_EQ(arguments.size(), index_map_->initial_indices.size());
+    for (int i = 0; i < static_cast<int>(arguments.size()); ++i) {
+      var_map_.Set(index_map_->initial_indices[i], arguments[i]);
+    }
+    Array<PrimExpr> result = index_map_->final_indices;

Review Comment:
   We can use `Array::Map` here instead of copying `final_indices` and then mutating the copy.
   
   ```c++
   return index_map_->final_indices.Map(...);
   ```
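The two styles, illustrated with `std::vector` and a hypothetical `Map` helper built on `std::transform` in place of TVM's `Array` and `Array::Map` (an analogy, not TVM code). Both functions produce the same result, but the map-style version reduces the body to a single `return` of the transformed sequence:

```c++
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Copy-then-mutate: a std::vector copy stands in for the Array copy that is
// followed by MutateByApply in the current Eval body.
std::vector<int> DoubleByMutation(const std::vector<int>& input) {
  std::vector<int> result = input;
  for (int& x : result) x *= 2;
  return result;
}

// Hypothetical helper mirroring the shape of Array::Map: returns a new
// vector holding fn applied to each element, leaving the input untouched.
template <typename T, typename F>
std::vector<T> Map(const std::vector<T>& input, F fn) {
  std::vector<T> result;
  result.reserve(input.size());
  std::transform(input.begin(), input.end(), std::back_inserter(result), fn);
  return result;
}

// Map style: the whole body is a single return statement.
std::vector<int> DoubleByMap(const std::vector<int>& input) {
  return Map(input, [](int x) { return x * 2; });
}
```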



-- 
This is an automated message from the Apache Git Service.