wrongtest-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r987854894
##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -403,6 +403,15 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
+ /*!
+ * \brief Create 2 blocks that read&write a buffer region into a read/write
cache.
+ * \param block_rv The block operates on the target buffer.
+ * \param read_buffer_index The index of the buffer in block's read region.
+ * \param storage_scope The target storage scope
+ * \return The reindex stage block.
Review Comment:
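The `\return` comment needs to be updated: it still says "The reindex stage block", but this primitive creates two blocks (the cache read and cache write stages).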
##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef&
block_sref, int write_bu
return result_block_sref;
}
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index,
Review Comment:
If I understand correctly, shouldn't the second parameter be named `read_buffer_index`? It is looked up with `BufferIndexType::kRead`.
##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef&
block_sref, int write_bu
return result_block_sref;
}
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index,
+ const String& storage_scope) {
+ /*!
+ * Do cache read then cache write
+ */
+
+ // Check 0. Check the input storage scope.
+ CheckStorageScope(self, storage_scope);
+
+ // Check 1. Check index, get the target buffer and the parent scope
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+ Buffer buffer =
+ GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index,
BufferIndexType::kRead);
+ StmtSRef scope_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
+
+ // Check 3. Check required region cover for cache_read
+ CheckRegionCover(self, scope_sref);
+
+ Array<StmtSRef> results_block_sref;
+ Buffer new_buffer = WithScope(buffer, storage_scope);
+
+ // Do cache read
+ // Cache read step 0. Create CacheStageInfo
+ CacheStageInfo info;
+ info.read_buffer = buffer;
+ // Create the corresponding buffer to be written for cache_read
+ info.write_buffer = new_buffer;
+ // Create the corresponding buffer allocation
+ info.alloc = info.write_buffer;
+ // Indicate which buffers should consume the cache.
+ info.consumer_blocks.push_back(block_sref);
+
+ // Cache read step 1. Update cache stage info for cache_read.
+ BufferRegion cache_region{nullptr};
+ Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref,
buffer);
+
+ StmtSRef write_block_sref = _write_block_sref.value();
+ const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+ // Find the producing region
+ BufferRegion region = GetBufferRegionFromBuffer(write_block->writes,
buffer).value();
+ StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+ // Detect insert position
+ CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+ cache_region = RelaxBufferRegion(self, region, write_block_sref,
parent_sref, info.loc_sref);
+
+ // Cache read step 2. Making new cache stage block and rewrite readers.
+ Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region,
/*info=*/&info,
+ /*storage_scope=*/storage_scope);
+ Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref,
/*info=*/&info);
+
+ // Cache read step 3. Replacing and updating flags for cache read.
+ self->Replace(scope_sref, new_scope, info.block_reuse);
Review Comment:
Could we merge this `Replace` call with the one at line 1324, so that the schedule state is updated atomically?
##########
src/tir/schedule/primitive.h:
##########
@@ -267,6 +267,17 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const
StmtSRef& block_sref, int r
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index,
const String& storage_scope);
+/*!
+ *!
+ * \brief Create 2 blocks that read&write a buffer region into a read/write
cache.
+ * \param self The state of the schedule
+ * \param block_sref The block operates on the target buffer.
+ * \param read_buffer_index The index of the buffer in block's read region.
+ * \param storage_scope The target storage scope
+ * \return The reindex stage block.
Review Comment:
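The `\return` comment needs to be updated: it still says "The reindex stage block", but this primitive creates two blocks (the cache read and cache write stages).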
##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef&
block_sref, int write_bu
return result_block_sref;
}
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index,
+ const String& storage_scope) {
+ /*!
+ * Do cache read then cache write
+ */
+
+ // Check 0. Check the input storage scope.
+ CheckStorageScope(self, storage_scope);
+
+ // Check 1. Check index, get the target buffer and the parent scope
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+ Buffer buffer =
+ GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index,
BufferIndexType::kRead);
+ StmtSRef scope_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
+
+ // Check 3. Check required region cover for cache_read
+ CheckRegionCover(self, scope_sref);
+
+ Array<StmtSRef> results_block_sref;
+ Buffer new_buffer = WithScope(buffer, storage_scope);
+
+ // Do cache read
+ // Cache read step 0. Create CacheStageInfo
+ CacheStageInfo info;
+ info.read_buffer = buffer;
+ // Create the corresponding buffer to be written for cache_read
+ info.write_buffer = new_buffer;
+ // Create the corresponding buffer allocation
+ info.alloc = info.write_buffer;
+ // Indicate which buffers should consume the cache.
+ info.consumer_blocks.push_back(block_sref);
+
+ // Cache read step 1. Update cache stage info for cache_read.
+ BufferRegion cache_region{nullptr};
+ Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref,
buffer);
+
+ StmtSRef write_block_sref = _write_block_sref.value();
+ const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+ // Find the producing region
+ BufferRegion region = GetBufferRegionFromBuffer(write_block->writes,
buffer).value();
Review Comment:
Could we check that the write region exists, as the API documentation requires, instead of calling `.value()` on it unconditionally?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]