Lunderberg commented on code in PR #13085:
URL: https://github.com/apache/tvm/pull/13085#discussion_r997068685
##########
src/tir/ir/index_map.cc:
##########
@@ -147,6 +147,37 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
return inverse;
}
+/*!
+ * \brief Evaluator to compute the mapped indices of a given index map.
+ */
+class IndexMapEvaluator : public DataTypeLegalizer {
+ public:
+ explicit IndexMapEvaluator(const IndexMap& index_map) : index_map_(index_map) {}
+
+ Array<PrimExpr> Eval(const Array<PrimExpr>& arguments, arith::Analyzer* analyzer) {
+ var_map_.clear();
+ ICHECK_EQ(arguments.size(), index_map_->initial_indices.size());
+ for (int i = 0; i < static_cast<int>(arguments.size()); ++i) {
+ var_map_.Set(index_map_->initial_indices[i], arguments[i]);
+ }
+ Array<PrimExpr> result = index_map_->final_indices;
+ result.MutateByApply([&](PrimExpr expr) { return analyzer->Simplify(this->VisitExpr(expr)); });
Review Comment:
Since the simplification is only done as a post-processing step, and isn't
needed during the datatype rewrite itself, there wouldn't be a performance
difference between simplification inside or outside the `Eval` function call.
Could we remove the `Analyzer*` argument, and instead perform simplification
outside of this method?
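A minimal sketch of the suggested split, using a hypothetical toy expression type in place of `tir::PrimExpr` and a tiny constant folder in place of `arith::Analyzer::Simplify` (none of the names below are TVM APIs). `Eval` only substitutes, and its caller simplifies the returned indices afterwards, so `Eval` no longer needs an analyzer parameter.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy stand-in for tir::PrimExpr: variables, constants, addition.
struct Expr {
  enum Kind { kVar, kConst, kAdd };
  Kind kind;
  std::string name;            // set when kind == kVar
  long value = 0;              // set when kind == kConst
  std::shared_ptr<Expr> a, b;  // set when kind == kAdd
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(const std::string& name) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kVar;
  e->name = name;
  return e;
}

ExprPtr Const(long value) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kConst;
  e->value = value;
  return e;
}

ExprPtr Add(ExprPtr x, ExprPtr y) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kAdd;
  e->a = std::move(x);
  e->b = std::move(y);
  return e;
}

// Pure substitution: the rewrite itself needs no simplifier.
ExprPtr Substitute(const ExprPtr& e, const std::map<std::string, ExprPtr>& vmap) {
  switch (e->kind) {
    case Expr::kVar: {
      auto it = vmap.find(e->name);
      return it == vmap.end() ? e : it->second;
    }
    case Expr::kConst:
      return e;
    case Expr::kAdd:
      return Add(Substitute(e->a, vmap), Substitute(e->b, vmap));
  }
  return e;
}

// Post-processing step: fold additions whose operands are both constants.
ExprPtr Simplify(const ExprPtr& e) {
  if (e->kind != Expr::kAdd) return e;
  ExprPtr x = Simplify(e->a);
  ExprPtr y = Simplify(e->b);
  if (x->kind == Expr::kConst && y->kind == Expr::kConst) return Const(x->value + y->value);
  return Add(x, y);
}

// Analogue of IndexMapEvaluator::Eval with the Analyzer* parameter removed.
std::vector<ExprPtr> Eval(const std::vector<std::string>& initial_indices,
                          const std::vector<ExprPtr>& final_indices,
                          const std::vector<ExprPtr>& arguments) {
  assert(arguments.size() == initial_indices.size());  // mirrors the ICHECK_EQ
  std::map<std::string, ExprPtr> vmap;
  for (size_t i = 0; i < arguments.size(); ++i) vmap[initial_indices[i]] = arguments[i];
  std::vector<ExprPtr> result;
  for (const auto& index : final_indices) result.push_back(Substitute(index, vmap));
  return result;
}

// Analogue of the caller (MapIndices): simplification happens after Eval returns.
std::vector<ExprPtr> MapIndices(const std::vector<std::string>& initial_indices,
                                const std::vector<ExprPtr>& final_indices,
                                const std::vector<ExprPtr>& arguments) {
  std::vector<ExprPtr> output = Eval(initial_indices, final_indices, arguments);
  for (auto& index : output) index = Simplify(index);
  return output;
}
```

With this shape, callers that do not need simplified output can call `Eval` directly and skip the simplifier entirely.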
##########
src/tir/ir/index_map.cc:
##########
@@ -147,6 +147,37 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
return inverse;
}
+/*!
+ * \brief Evaluator to compute the mapped indices of a given index map.
+ */
+class IndexMapEvaluator : public DataTypeLegalizer {
Review Comment:
Should this be part of a more general utility? It looks like this
is effectively a variation of
[`tvm::tir::Substitute`](https://github.com/apache/tvm/blob/main/include/tvm/tir/stmt_functor.h#L368)
that allows the substituted values to have a different datatype than the
replaced var's type. Rather than having a utility that is specific to
`IndexMap`, should this functionality be an option when calling `Substitute`?
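To illustrate what such an option could look like, here is a sketch on a hypothetical toy IR in which every node carries a bit width as its dtype. The `allow_dtype_change` flag is an assumption for illustration, not an existing `Substitute` parameter. When it is set, parent nodes are rebuilt so that the narrower operand of an addition is widened with a cast, loosely modeled on what a datatype legalizer does; when it is not set, the toy requires each replacement to match the replaced variable's width.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical toy IR: every node carries an integer bit width as its "dtype".
struct Expr {
  enum Kind { kVar, kConst, kAdd, kCast };
  Kind kind;
  int bits;
  std::string name;            // kVar only
  long value = 0;              // kConst only
  std::shared_ptr<Expr> a, b;  // kAdd uses both operands, kCast uses `a`
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr MakeNode(Expr::Kind kind, int bits) {
  auto e = std::make_shared<Expr>();
  e->kind = kind;
  e->bits = bits;
  return e;
}

ExprPtr Var(const std::string& name, int bits) {
  auto e = MakeNode(Expr::kVar, bits);
  e->name = name;
  return e;
}

ExprPtr Const(long value, int bits) {
  auto e = MakeNode(Expr::kConst, bits);
  e->value = value;
  return e;
}

ExprPtr Cast(ExprPtr x, int bits) {
  auto e = MakeNode(Expr::kCast, bits);
  e->a = std::move(x);
  return e;
}

// Builds x + y, widening the narrower operand so both sides share one width.
ExprPtr Add(ExprPtr x, ExprPtr y) {
  int bits = std::max(x->bits, y->bits);
  if (x->bits < bits) x = Cast(x, bits);
  if (y->bits < bits) y = Cast(y, bits);
  auto e = MakeNode(Expr::kAdd, bits);
  e->a = std::move(x);
  e->b = std::move(y);
  return e;
}

// Hypothetical substitution with an opt-in flag. With allow_dtype_change ==
// false, a replacement must have the variable's width (checked by assert, so
// only in debug builds). With true, the width may differ and parent nodes are
// rebuilt through Add/Cast, which re-legalize their widths.
ExprPtr Substitute(const ExprPtr& e, const std::map<std::string, ExprPtr>& vmap,
                   bool allow_dtype_change) {
  switch (e->kind) {
    case Expr::kVar: {
      auto it = vmap.find(e->name);
      if (it == vmap.end()) return e;
      assert(allow_dtype_change || it->second->bits == e->bits);
      return it->second;
    }
    case Expr::kConst:
      return e;
    case Expr::kCast:
      return Cast(Substitute(e->a, vmap, allow_dtype_change), e->bits);
    case Expr::kAdd:
      return Add(Substitute(e->a, vmap, allow_dtype_change),
                 Substitute(e->b, vmap, allow_dtype_change));
  }
  return e;
}
```

For example, substituting a 64-bit `n` for a 32-bit `i` in `i + 1` with the flag set yields a 64-bit addition whose constant operand is wrapped in a cast to 64 bits.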
##########
src/tir/ir/index_map.cc:
##########
@@ -162,18 +193,18 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
analyzer = &local_analyzer;
}
- Array<PrimExpr> output = final_indices.Map(
- [&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });
-
+ Array<PrimExpr> output = IndexMapEvaluator(GetRef<IndexMap>(this)).Eval(indices, analyzer);
return output;
}
Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
ICHECK_EQ(ranges.size(), initial_indices.size());
Map<Var, Range> input_iters;
+ int max_bits = 0;
Review Comment:
Also, can we move the `auto output_dtype = DataType::Int(max_bits);` line
closer to the initialization of `max_bits`? That way, it's clear to a reader
that the purpose of `max_bits` is solely to generate the `output_dtype`.
Perhaps even move it into a narrower scope, to make that explicit to the
compiler as well.
```c++
auto output_dtype = [&]() {
  int max_bits = 0;
  for (const auto& range : ranges) {
    max_bits = std::max(max_bits, range->extent.dtype().bits());
  }
  return DataType::Int(max_bits);
}();
```
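The same pattern as a standalone, compilable sketch. `DataType` and `Range` here are hypothetical minimal stand-ins for the TVM classes, reduced to the one field the computation reads; the immediately-invoked lambda is what confines `max_bits` to the computation of `output_dtype`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical minimal stand-ins for tvm::DataType and tvm::Range.
struct DataType {
  int bits;
  static DataType Int(int bits) { return DataType{bits}; }
};
struct Range {
  DataType extent_dtype;  // plays the role of range->extent.dtype()
};

// max_bits exists only inside the lambda, so nothing else in the enclosing
// function can read or modify it after output_dtype has been computed.
DataType ComputeOutputDtype(const std::vector<Range>& ranges) {
  auto output_dtype = [&]() {
    int max_bits = 0;
    for (const auto& range : ranges) {
      max_bits = std::max(max_bits, range.extent_dtype.bits);
    }
    return DataType::Int(max_bits);
  }();
  return output_dtype;
}
```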
##########
src/tir/ir/index_map.cc:
##########
@@ -162,18 +193,18 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
analyzer = &local_analyzer;
}
- Array<PrimExpr> output = final_indices.Map(
- [&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });
-
+ Array<PrimExpr> output = IndexMapEvaluator(GetRef<IndexMap>(this)).Eval(indices, analyzer);
return output;
}
Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
ICHECK_EQ(ranges.size(), initial_indices.size());
Map<Var, Range> input_iters;
+ int max_bits = 0;
Review Comment:
Can we separate the loop that computes `max_bits` from the loop that
computes `input_iters`? The two don't depend on each other, so it feels a bit
odd to have their initialization be mixed together. The initialization of
`max_bits` also doesn't depend on `initial_indices`, so it could use a
range-based for-loop for readability instead of the C-style loop over `i`.
```c++
int max_bits = 0;
for (const auto& range : ranges) {
  max_bits = std::max(max_bits, range->extent.dtype().bits());
}
```
##########
src/tir/ir/index_map.cc:
##########
@@ -147,6 +147,37 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
return inverse;
}
+/*!
+ * \brief Evaluator to compute the mapped indices of a given index map.
+ */
+class IndexMapEvaluator : public DataTypeLegalizer {
+ public:
+ explicit IndexMapEvaluator(const IndexMap& index_map) : index_map_(index_map) {}
+
+ Array<PrimExpr> Eval(const Array<PrimExpr>& arguments, arith::Analyzer* analyzer) {
+ var_map_.clear();
+ ICHECK_EQ(arguments.size(), index_map_->initial_indices.size());
+ for (int i = 0; i < static_cast<int>(arguments.size()); ++i) {
+ var_map_.Set(index_map_->initial_indices[i], arguments[i]);
+ }
+ Array<PrimExpr> result = index_map_->final_indices;
Review Comment:
Can use `Array::Map` instead of making a copy that is then mutated.
```c++
return index_map_->final_indices.Map(...);
```
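The difference between the two shapes, sketched with `std::vector` and `std::transform` as an analogy only (this is not TVM's `Array` API): the map-style version builds the output directly from the input, instead of first copying the input and then overwriting every element.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Copy-then-mutate: copy the whole input, then overwrite each element in place.
std::vector<int> DoubleByMutation(const std::vector<int>& input) {
  std::vector<int> result = input;
  for (int& x : result) x = 2 * x;
  return result;
}

// Map-style: construct the output directly from the input in a single pass.
std::vector<int> DoubleByMap(const std::vector<int>& input) {
  std::vector<int> result;
  result.reserve(input.size());
  std::transform(input.begin(), input.end(), std::back_inserter(result),
                 [](int x) { return 2 * x; });
  return result;
}
```

Both produce the same values; the map-style form also lets the result be returned as a single expression, which is the shape the suggested `return index_map_->final_indices.Map(...);` takes.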