This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7d53953adc [Unity] Check the PassContext within RewriteCUDAGraph transform (#15861)
7d53953adc is described below
commit 7d53953adc472ec2b1adf9cd91d77d900646f9b8
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Oct 3 18:27:44 2023 -0500
[Unity] Check the PassContext within RewriteCUDAGraph transform (#15861)
Prior to this commit, the `RewriteCUDAGraph` pass would
unconditionally rewrite an `IRModule`, and was conditionally
included as a lowering pass for use in `relax.build`, based on the
current `PassContext`. This commit moves the check on the
`PassContext` from the `relax.build` method to the `RewriteCUDAGraph`
pass itself. This allows the pass to be part of a lowering
flow that is constructed once and applied later, even after
`PassContext.current()` has changed.
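
As a minimal sketch of what this enables (the `Module` below is
copied from the test added in this commit), the same pass object can
now be constructed once and applied under different `PassContext`
configurations:

    import tvm
    from tvm import relax
    from tvm.script import relax as R, ir as I

    @I.ir_module
    class Module:
        @R.function
        def main():
            storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
            alloc3 = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
            return R.tuple()

    # Construct the pass once; it now consults the PassContext when it runs.
    rewrite_cuda_graph = relax.transform.RewriteCUDAGraph()

    # With the flag enabled, the module is rewritten for CUDA graph capture...
    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
        rewritten = rewrite_cuda_graph(Module)

    # ...and with the flag disabled (or unset), the pass is a no-op.
    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}):
        unchanged = rewrite_cuda_graph(Module)

When the flag is disabled or unset, the returned module is
structurally equal to the input, as the new test below verifies.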
---
 python/tvm/relax/vm_build.py                   | 36 +++++++++++-----------
 src/relax/transform/rewrite_cuda_graph.cc      | 10 +++++-
 .../relax/test_transform_rewrite_cuda_graph.py | 27 ++++++++++++++++
 3 files changed, 54 insertions(+), 19 deletions(-)
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index e6331e436f..8b33379957 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -305,24 +305,24 @@ def build(
     if isinstance(target, str):
         target = tvm.target.Target(target)
 
-    passes = []
-    passes.append(relax.transform.RewriteDataflowReshape())
-    passes.append(relax.transform.ToNonDataflow())
-    passes.append(relax.transform.RemovePurityChecking())
-    passes.append(relax.transform.CallTIRRewrite())
-    passes.append(relax.transform.StaticPlanBlockMemory())
-
-    if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
-        passes.append(relax.transform.RewriteCUDAGraph())
-
-    passes.append(relax.transform.LowerAllocTensor())
-    passes.append(relax.transform.KillAfterLastUse())
-
-    passes.append(relax.transform.VMBuiltinLower())
-    passes.append(relax.transform.VMShapeLower())
-    passes.append(relax.transform.AttachGlobalSymbol())
-    seq = tvm.transform.Sequential(passes)
-    new_mod = seq(mod)
+    lowering_passes = tvm.transform.Sequential(
+        [
+            relax.transform.RewriteDataflowReshape(),
+            relax.transform.ToNonDataflow(),
+            relax.transform.RemovePurityChecking(),
+            relax.transform.CallTIRRewrite(),
+            relax.transform.StaticPlanBlockMemory(),
+            relax.transform.RewriteCUDAGraph(),
+            relax.transform.LowerAllocTensor(),
+            relax.transform.KillAfterLastUse(),
+            relax.transform.VMBuiltinLower(),
+            relax.transform.VMShapeLower(),
+            relax.transform.AttachGlobalSymbol(),
+        ],
+        name="relax.lower",
+    )
+
+    new_mod = lowering_passes(mod)
 
     # Extract external runtime modules if exist.
     attrs = dict(mod.attrs) if mod.attrs else {}
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index c2a0754462..402dd55545 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -554,7 +554,15 @@ namespace transform {
 
 Pass RewriteCUDAGraph() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
-      [=](IRModule m, PassContext pc) { return ::tvm::relax::RewriteCUDAGraph(std::move(m)); };
+      [=](IRModule mod, PassContext pc) {
+        bool use_cuda_graph =
+            pc->GetConfig<Bool>("relax.backend.use_cuda_graph").value_or(Bool(false))->value;
+        if (use_cuda_graph) {
+          mod = ::tvm::relax::RewriteCUDAGraph(std::move(mod));
+        }
+
+        return mod;
+      };
   return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {});
 }
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index b2c991a2ee..73aaf4dac5 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 from tvm import relax
 from tvm.script import tir as T, relax as R, ir as I
@@ -25,6 +27,13 @@ class BaseCompare(tvm.testing.CompareBeforeAfter):
     transform = relax.transform.RewriteCUDAGraph()
 
 
+@pytest.fixture(autouse=True)
+def enable_cuda_graph():
+    """Enable cuda graph transform for all tests in this file"""
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        yield
+
+
 def test_rewrite_cuda_graph():
     # fmt: off
     @I.ir_module
@@ -677,5 +686,23 @@ class TestNullValue(BaseCompare):
     expected = before
 
 
+def test_transform_is_no_op_when_disabled():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc3 = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
+            return R.tuple()
+
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        AfterWhenEnabled = relax.transform.RewriteCUDAGraph()(Before)
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}):
+        AfterWhenDisabled = relax.transform.RewriteCUDAGraph()(Before)
+
+    assert not tvm.ir.structural_equal(Before, AfterWhenEnabled)
+    tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
+
+
 if __name__ == "__main__":
     tvm.testing.main()