This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7d53953adc [Unity] Check the PassContext within RewriteCUDAGraph transform (#15861)
7d53953adc is described below

commit 7d53953adc472ec2b1adf9cd91d77d900646f9b8
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Oct 3 18:27:44 2023 -0500

    [Unity] Check the PassContext within RewriteCUDAGraph transform (#15861)
    
    Prior to this commit, the `RewriteCUDAGraph` pass would
    unconditionally rewrite an `IRModule`, and was conditionally
    included as a lowering pass used in `relax.build`, based on the
    current `PassContext`.  This commit moves the check on the
    `PassContext` from the `relax.build` method into the
    `RewriteCUDAGraph` pass itself.  This allows the pass to be part
    of a lowering flow that is constructed once and applied later,
    when `PassContext.current()` may have changed.
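
    For illustration, a minimal sketch of the behavior after this
    change (the module below just reuses the shape of the test case
    added in this commit; any relax `IRModule` would do):

        import tvm
        from tvm import relax
        from tvm.script import relax as R, ir as I

        @I.ir_module
        class Before:
            @R.function
            def main():
                storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
                alloc = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
                return R.tuple()

        # The pass can now be constructed once, ahead of time...
        rewrite = relax.transform.RewriteCUDAGraph()

        # ...and the PassContext current at application time decides
        # whether the module is rewritten or returned unchanged.
        with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
            rewritten = rewrite(Before)  # rewritten for CUDA graph capture
        with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}):
            unchanged = rewrite(Before)  # pass is a no-op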
---
 python/tvm/relax/vm_build.py                       | 36 +++++++++++-----------
 src/relax/transform/rewrite_cuda_graph.cc          | 10 +++++-
 .../relax/test_transform_rewrite_cuda_graph.py     | 27 ++++++++++++++++
 3 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index e6331e436f..8b33379957 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -305,24 +305,24 @@ def build(
     if isinstance(target, str):
         target = tvm.target.Target(target)
 
-    passes = []
-    passes.append(relax.transform.RewriteDataflowReshape())
-    passes.append(relax.transform.ToNonDataflow())
-    passes.append(relax.transform.RemovePurityChecking())
-    passes.append(relax.transform.CallTIRRewrite())
-    passes.append(relax.transform.StaticPlanBlockMemory())
-
-    if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
-        passes.append(relax.transform.RewriteCUDAGraph())
-
-    passes.append(relax.transform.LowerAllocTensor())
-    passes.append(relax.transform.KillAfterLastUse())
-
-    passes.append(relax.transform.VMBuiltinLower())
-    passes.append(relax.transform.VMShapeLower())
-    passes.append(relax.transform.AttachGlobalSymbol())
-    seq = tvm.transform.Sequential(passes)
-    new_mod = seq(mod)
+    lowering_passes = tvm.transform.Sequential(
+        [
+            relax.transform.RewriteDataflowReshape(),
+            relax.transform.ToNonDataflow(),
+            relax.transform.RemovePurityChecking(),
+            relax.transform.CallTIRRewrite(),
+            relax.transform.StaticPlanBlockMemory(),
+            relax.transform.RewriteCUDAGraph(),
+            relax.transform.LowerAllocTensor(),
+            relax.transform.KillAfterLastUse(),
+            relax.transform.VMBuiltinLower(),
+            relax.transform.VMShapeLower(),
+            relax.transform.AttachGlobalSymbol(),
+        ],
+        name="relax.lower",
+    )
+
+    new_mod = lowering_passes(mod)
 
     # Extract external runtime modules if exist.
     attrs = dict(mod.attrs) if mod.attrs else {}
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index c2a0754462..402dd55545 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -554,7 +554,15 @@ namespace transform {
 
 Pass RewriteCUDAGraph() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
-      [=](IRModule m, PassContext pc) { return ::tvm::relax::RewriteCUDAGraph(std::move(m)); };
+      [=](IRModule mod, PassContext pc) {
+        bool use_cuda_graph =
+            pc->GetConfig<Bool>("relax.backend.use_cuda_graph").value_or(Bool(false))->value;
+        if (use_cuda_graph) {
+          mod = ::tvm::relax::RewriteCUDAGraph(std::move(mod));
+        }
+
+        return mod;
+      };
   return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {});
 }
 
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index b2c991a2ee..73aaf4dac5 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 from tvm import relax
 from tvm.script import tir as T, relax as R, ir as I
@@ -25,6 +27,13 @@ class BaseCompare(tvm.testing.CompareBeforeAfter):
     transform = relax.transform.RewriteCUDAGraph()
 
 
[email protected](autouse=True)
+def enable_cuda_graph():
+    """Enable cuda graph transform for all tests in this file"""
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        yield
+
+
 def test_rewrite_cuda_graph():
     # fmt: off
     @I.ir_module
@@ -677,5 +686,23 @@ class TestNullValue(BaseCompare):
     expected = before
 
 
+def test_transform_is_no_op_when_disabled():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc3 = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
+            return R.tuple()
+
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
+        AfterWhenEnabled = relax.transform.RewriteCUDAGraph()(Before)
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}):
+        AfterWhenDisabled = relax.transform.RewriteCUDAGraph()(Before)
+
+    assert not tvm.ir.structural_equal(Before, AfterWhenEnabled)
+    tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
