MasterJH5574 commented on code in PR #11088:
URL: https://github.com/apache/tvm/pull/11088#discussion_r855853168
##########
python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py:
##########
@@ -82,3 +82,37 @@ def __init__(
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
+
+
+@register_object("meta_schedule.MultiLevelTilingWithIntrin")
+class MultiLevelTilingWithIntrin(ScheduleRule):
+ """Multi-level tiling with reuse.
+
+ Parameters
+ ----------
+ intrin_name : str
+ The name of a tensor intrinsic, must be registerd via
TensorIntrin.register(...) beforehand
Review Comment:
Should we also document the other parameters here?
##########
src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc:
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../tir/schedule/transform.h"
+#include "../utils.h"
+#include "multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Tile a subset of loops in the block according to the given tensor
intrinsic, and annotate
+ * the tiled block for tensorization by postproc rewrite.
+ */
+tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const
std::string& intrin_name) {
+ Optional<tir::LoopRV> tiled_loop_rv = TileWithTensorIntrin(sch, block,
intrin_name);
+ ICHECK(tiled_loop_rv.defined());
+ tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
+ sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize,
String(intrin_name));
+ return outer_block;
+}
+
+/*!
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with a single
intrinsic.
+ */
+class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
+ protected:
+ // Override ApplySubRules to tile the inner loops according to the given
tensor intrinsic, then
+ // tile the outerloops.
+ virtual std::vector<State> ApplySubRules(std::vector<State> states) {
+ states = SubRule(std::move(states), [&](State state) {
+ state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
+ return std::vector<State>(1, state);
+ });
+ return MultiLevelTilingNode::ApplySubRules(states);
+ }
+
+ public:
+ /*! \brief The name of a tensor intrinsic. */
+ String intrin_name;
+
+ static constexpr const char* _type_key =
"meta_schedule.MultiLevelTilingWithIntrin";
+ TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode,
MultiLevelTilingNode);
+};
+
+ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
+ String intrin_name, String structure, Optional<Array<String>> tile_binds,
+ Optional<Integer> max_innermost_factor, Optional<Array<Integer>>
vector_load_lens,
+ Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String,
ObjectRef>> reuse_write) {
+ ICHECK(tir::TensorIntrin::Get(intrin_name).defined())
+ << "Provided tensor intrinsic " << intrin_name << " is not registered.";
+ auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>(
+ structure, tile_binds, max_innermost_factor, vector_load_lens,
reuse_read, reuse_write);
Review Comment:
I’m a little bit confused: where are the loops tiled according to the tiling
structure in this rule?
##########
src/meta_schedule/schedule_rule/multi_level_tiling.cc:
##########
@@ -25,6 +25,7 @@
#include <vector>
#include "../utils.h"
+#include "tvm/meta_schedule/schedule_rule.h"
Review Comment:
Hmm, what is this line intended for?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]