junrushao1994 commented on a change in pull request #9860: URL: https://github.com/apache/tvm/pull/9860#discussion_r780522736
########## File path: src/meta_schedule/feature_extractor/per_store_feature.cc ########## @@ -0,0 +1,1307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <tvm/tir/transform.h> + +#include <cmath> +#include <memory> +#include <numeric> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +/*! \brief Type for multi-dimensional index */ +using MultiIndex = std::vector<PrimExpr>; +/*! \brief Vector of int64_t */ +using IntVec = std::vector<int64_t>; +/*! \brief Vector of for loops */ +using ForVec = std::vector<const ForNode*>; + +/*! + * \brief An unordered_map for (for, buffer) => V + * \tparam V The value type + */ +template <class V> +using ForBufferMap = std::unordered_map<const ForNode*, std::unordered_map<const BufferNode*, V>>; + +/*! \brief Given x, compute log2(|x| + 1); the sign of x is discarded, so x and -x give the same result */ +inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x + 1); } + +namespace utils { + +/*! 
+ * \brief Get the shape of the buffer + * \param buffer The buffer + * \param analyzer The analyzer + * \return The shape of the buffer + */ +std::vector<int64_t> GetBufferShape(const Buffer& buffer, arith::Analyzer* analyzer) { + int ndim = buffer->shape.size(); + std::vector<int64_t> result; + result.reserve(ndim); + for (const PrimExpr& i : buffer->shape) { + /* Extent is an integer immediate: record its value directly. */ if (const IntImmNode* int_imm = i.as<IntImmNode>()) { + result.push_back(int_imm->value); + continue; + } + /* Otherwise query the analyzer for a constant integer bound on the extent. */ arith::ConstIntBound bound = analyzer->const_int_bound(i); + /* Use the upper bound only when it is non-negative and strictly below kPosInf. */ if (0 <= bound->max_value && bound->max_value < arith::ConstIntBound::kPosInf) { + result.push_back(bound->max_value); + } else { + /* No usable upper bound: record 1 for this dimension. */ result.push_back(1); + } + } + return result; +} + +/*! + * \brief Look up the `pragma_auto_unroll_max_step` annotation of a loop + * \param loop The loop to be checked + * \return The integer value of the `pragma_auto_unroll_max_step` annotation if the lookup + * finds one, or -1 if it does not + */ +int64_t GetPragmaAutoUnroll(const ForNode* loop) { + if (Optional<IntImm> auto_unroll = GetAnn<IntImm>(loop, tir::attr::pragma_auto_unroll_max_step)) { + return auto_unroll.value()->value; + } + return -1; +} + +/*! + * \brief Given a list of loops, return the extent of the first loop if the list is not empty + * and the first loop has a constant extent. Otherwise return the given default value. + * Only the first loop in the list is inspected. + * \param loops The list of loops to be checked + * \param default_value The default value to be returned if the list is empty or the first loop + * does not have constant extent + * \return The extent of the first loop if the list is not empty and the first loop has constant + * extent. Otherwise returns the default value + */ +int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) { + if (!loops.empty()) { + if (const int64_t* extent = GetLoopIntExtent(loops[0])) { + return *extent; + } + } + return default_value; +} + +/*! 
+ * \brief Relax each of the multi-indexing pattern according to the domains bound in the analyzer, + * and then union them into a single region + * \param multi_index_pattern A list of multi-index pattern to be relaxed + * \param numel The size of the single region after union + * \param analyzer The analyzer that contains the domain information + * \return The relaxed and unioned region + */ +IntVec RelaxAndUnion(const std::vector<MultiIndex>& multi_indices, int64_t* numel, + arith::Analyzer* analyzer) { + if (multi_indices.empty()) { + return {}; + } + int n_indices = multi_indices.size(); + int ndim = multi_indices[0].size(); + IntVec access_shape(ndim, 0); + for (int i = 0; i < ndim; ++i) { + int64_t minimum = arith::ConstIntBound::kPosInf; + int64_t maximum = arith::ConstIntBound::kNegInf; + for (int j = 0; j < n_indices; ++j) { + arith::ConstIntBound bound = analyzer->const_int_bound(multi_indices[j][i]); + minimum = std::min(minimum, bound->min_value); + maximum = std::max(maximum, bound->max_value); + } + *numel *= maximum - minimum + 1; Review comment: good point! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
