wrongtest commented on a change in pull request #10732:
URL: https://github.com/apache/tvm/pull/10732#discussion_r833829656



##########
File path: src/tir/schedule/analysis/layout.cc
##########
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Calculate the strides of the buffer
+ * \param buffer The buffer
+ * \return The strides
+ */
+Array<PrimExpr> GetStrides(const Buffer& buffer) {
+  if (!buffer->strides.empty()) {
+    ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
+    return buffer->strides;
+  }
+  int ndim = buffer->shape.size();
+  if (ndim == 0) {
+    return {};
+  }
+  Array<PrimExpr> strides(ndim, PrimExpr{nullptr});
+  PrimExpr stride = make_const(buffer->DefaultIndexType(), 1);
+  for (int i = ndim - 1; i >= 0; --i) {
+    strides.Set(i, stride);
+    stride = stride * buffer->shape[i];
+  }
+  return strides;
+}
+
+/*!
+ * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern
+ * to help decision making in layout transformation
+ */
+class SplitExprCollector {
+ public:
+  /*!
+   * \brief The corresponding IterSplitExpr, simplified for our case
+   * The pattern is `source // lower_factor % extent * scale`
+   */
+  struct SplitExpr {
+    /*! \brief The source variable */
+    Var source;
+    /*! \brief The lower factor of the split expression */
+    int64_t lower_factor;
+    /*! \brief The extent of the split expression */
+    int64_t extent;
+  };
+
+  /*!
+   * \brief Collect the split expressions in the indexing pattern
+   * \param index The indexing pattern
+   * \param input_iters The input iterators' domain
+   * \param predicate The predicate of the affine map
+   * \param require_bijective Whether the affine map is required to be bijective
+   * \param analyzer The analyzer
+   * \return The collected split expressions
+   */
+  static std::vector<SplitExpr> Collect(const PrimExpr& index,
+                                        const Map<Var, Range>& input_iters,  //
+                                        const PrimExpr& predicate,           //
+                                        bool require_bijective,              //
+                                        arith::Analyzer* analyzer) {
+    DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+    Array<arith::IterSumExpr> iter_sum_exprs = arith::DetectIterMap(
+        {analyzer->Simplify(index)}, input_iters, predicate, 
require_bijective, analyzer, diag_ctx);

Review comment:
       @junrushao1994 Hi, do you have any suggestions for the debug info in `DetectIterMap`?
If the diagnostic context (`diag_ctx`) is not preferred, we could either remove the debug info or switch it to logging.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to