Lunderberg commented on code in PR #14254:
URL: https://github.com/apache/tvm/pull/14254#discussion_r1139148547
##########
include/tvm/tir/builtin.h:
##########
@@ -726,15 +726,26 @@ TVM_DLL const Op& texture2d_store();
TVM_DLL const Op& texture2d_load();
/*!
- * \brief Initiate a non-blocking DMA copy from source to destination
+ * \brief Initiate a non-blocking DMA copy from source to destination; a DMA
copy outside of a group
+ * has a defacto group size of one
*/
TVM_DLL const Op& dma_copy();
/*!
- * \brief Wait until the number of DMAs in flight is less than or equal to
some maximum
+ * \brief Wait until the number of DMA groups in flight is less than or equal
to some maximum
*/
TVM_DLL const Op& dma_wait();
+/*!
Review Comment:
Nit: We should document the behavior of nested `dma_start_group()` calls.
```c++
/*!
 * \brief Start a group of DMA copies
 *
 * Any call to `dma_copy()` that occurs after `dma_start_group()` will
 * be deferred until the next call to `dma_end_group()`, rather than
 * being launched immediately.
 *
 * Only one DMA group may be active at a given time. Calling
 * `dma_start_group()` while a group is already active is unsupported.
 */
```
##########
include/tvm/tir/builtin.h:
##########
@@ -726,15 +726,26 @@ TVM_DLL const Op& texture2d_store();
TVM_DLL const Op& texture2d_load();
/*!
- * \brief Initiate a non-blocking DMA copy from source to destination
+ * \brief Initiate a non-blocking DMA copy from source to destination; a DMA
copy outside of a group
Review Comment:
Nit: Can we explicitly specify the behavior, both for when a group is active
and when a group is inactive?
```c++
/*!
 * \brief Initiate a non-blocking DMA copy from source to destination
 *
 * If a `dma_start_group()` call is active, the copy will be collected,
 * and will be launched when the next `dma_end_group()` call occurs.
 *
 * If no `dma_start_group()` call is active, the copy will be launched
 * immediately.
 */
```
##########
src/runtime/hexagon/hexagon_device_api.cc:
##########
@@ -233,6 +233,19 @@
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
Review Comment:
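Nit: When the input/output types are fixed, the `.set_body_typed()` method
can be used to avoid manual argument wrangling. (The current body discards
the result of `StartGroup()` and always returns `0`. If `StartGroup()` does
not return a value, keep that behavior: call it, then `return 0;` explicitly.)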
```c++
.set_body_typed([](int queue_id) -> int32_t {
return HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
});
```
##########
src/tir/transforms/lower_async_dma.cc:
##########
@@ -22,26 +22,61 @@
*/
#include <tvm/arith/analyzer.h>
+#include <tvm/arith/bound.h>
#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include <optional>
+
+#include "../../arith/ir_mutator_with_analyzer.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
-class AsyncDMALowerer : public StmtExprMutator {
+class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
public:
- explicit AsyncDMALowerer(bool dma_bypass_cache) :
dma_bypass_cache_(dma_bypass_cache) {}
+ explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
+ : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ // if for loop is not within async_commit_queue_scope
+ if (!async_queue_id_.has_value()) {
+ return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
+ }
- // Create member statement to track a mapping from iter var to iter range
- Stmt VisitStmt_(const ForNode* op) final {
- input_iters.Set(op->loop_var, Range(op->min, op->extent));
- return StmtExprMutator::VisitStmt_(op);
+ // if for loop is not a memcpy of a contiguous region
+ std::optional<tvm::tir::MemCpyDetails> mem_copy =
IdentifyMemCpy(GetRef<For>(loop), analyzer_);
+ if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
+ mem_copy->source->region.size() != 1) {
+ LOG(FATAL) << "Unable to lower async dma due to non contiguous memory
access";
+ }
+
+ // now that we are about to perform the `copy` transform
+ // save queue ID for inspection in `wait` transform
+ // and, increment the number of DMA copies in the group
+ queue_ids_.insert(async_queue_id_.value());
+ dmas_in_group_++;
+
+ tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
+ tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
+ tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;
+
+ auto src = BufferLoad(mem_copy->source->buffer, {src_min});
+ auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
+ return Evaluate(
+ Call(DataType::Int(32), builtin::dma_copy(),
+ {async_queue_id_.value(), Call(DataType::Handle(),
builtin::address_of(), {dst}),
+ Call(DataType::Handle(), builtin::address_of(), {src}),
+ dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
+ arith::IRMutatorWithAnalyzer::VisitStmt_(op);
Review Comment:
Can we add a comment explaining why the `AttrStmt` needs to be pre-visited here? As written, the result of this call is discarded.
##########
include/tvm/tir/builtin.h:
##########
@@ -726,15 +726,26 @@ TVM_DLL const Op& texture2d_store();
TVM_DLL const Op& texture2d_load();
/*!
- * \brief Initiate a non-blocking DMA copy from source to destination
+ * \brief Initiate a non-blocking DMA copy from source to destination; a DMA
copy outside of a group
+ * has a defacto group size of one
*/
TVM_DLL const Op& dma_copy();
/*!
- * \brief Wait until the number of DMAs in flight is less than or equal to
some maximum
+ * \brief Wait until the number of DMA groups in flight is less than or equal
to some maximum
*/
TVM_DLL const Op& dma_wait();
+/*!
+ * \brief Start a group of DMA copies
+ */
+TVM_DLL const Op& dma_start_group();
+
+/*!
Review Comment:
Similarly, we should document what usage is supported. This would also
be a good place to specify that empty groups still count as "in-flight", since
the `dma_copy()` calls between a start/end pair may have been optimized out.
```c++
/*!
 * \brief End a group of DMA copies
 *
 * Launch all calls to `dma_copy()` that occurred since the preceding
 * `dma_start_group()`. Calling `dma_end_group()` without an active
 * group is unsupported.
 *
 * Note: A group of DMA calls may be empty, and will still contribute
 * to the count of in-flight groups used by `dma_wait()`.
 */
```
##########
src/tir/transforms/lower_async_dma.cc:
##########
@@ -63,22 +98,21 @@ class AsyncDMALowerer : public StmtExprMutator {
DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed
in the "
"`async_wait_queue_scope` transform has not been
previously observed in the "
"`async_commit_queue_scope` transform";
- return StmtExprMutator::VisitStmt_(op);
+ return arith::IRMutatorWithAnalyzer::VisitStmt_(op);
Review Comment:
Do we need to re-visit with `IRMutatorWithAnalyzer` in this case? If the
result of the initial visit is still valid, we should save it and return it
here instead of visiting the node a second time.
##########
src/runtime/hexagon/hexagon_user_dma.cc:
##########
@@ -103,15 +104,15 @@ int HexagonUserDMA::Copy(int queue_id, void* dst, void*
src, uint32_t length, bo
return DMA_SUCCESS;
}
-void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
+void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
// wait (forever) until max DMAs in flight <= actual DMAs in flight
while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
}
}
-uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }
+uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return
DMAsInFlight(queue_id); }
-uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
+uint32_t HexagonUserDMA::DMAsInFlight(uint32_t queue_id) {
Review Comment:
Nit: Can we update the name from `DMAsInFlight` to `DMAGroupsInFlight`, since the wait now counts DMA groups rather than individual DMAs?
##########
src/runtime/hexagon/hexagon_device_api.cc:
##########
@@ -233,6 +233,19 @@
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});
+TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
+ .set_body([](TVMArgs args, TVMRetValue* rv) {
+ int queue_id = args[0];
+ HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
+ *rv = static_cast<int32_t>(0);
Review Comment:
Should this use `uint32_t` for the `queue_id` and the return value, to match the
internal methods (e.g. `HexagonUserDMA::Wait(uint32_t queue_id, ...)`)?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]