This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1c52e633c7 [TIR][Schedule] Method returning the function being worked on (#14593)
1c52e633c7 is described below
commit 1c52e633c79afa4a6e5cf90fc97445448e486bf5
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Apr 11 21:43:42 2023 -0400
[TIR][Schedule] Method returning the function being worked on (#14593)
PR #11999 introduced the sugar method `work_on` to TIR Schedule, along
with a new field `func_working_on_` in ScheduleNode. In some cases we
want to know which function a ScheduleNode is working on, but this was
not previously exposed.
Therefore, this PR adds a method to ScheduleNode that returns the
function (more precisely, its GlobalVar) currently being worked on.
With it, callers can query which function the schedule is working on.
---
include/tvm/tir/schedule/schedule.h | 2 ++
python/tvm/tir/schedule/schedule.py | 7 ++++++-
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/schedule.cc | 2 ++
tests/python/unittest/test_tir_schedule_utilities.py | 1 +
5 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index c294d0ae87..69f0520117 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object {
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution
*/
virtual Optional<Trace> trace() const = 0;
+ /*! \return The GlobalVar of the func that the schedule is currently working
on */
+ virtual Optional<GlobalVar> func_working_on() const = 0;
/*!
* \brief Instruct the schedule to work on a function in the IRModule.
*
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index b19e30848f..34fd649a5d 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -19,7 +19,7 @@ from typing import Callable, Dict, List, Optional, Tuple,
Union
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
-from tvm.ir import IRModule, PrimExpr
+from tvm.ir import GlobalVar, IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
@@ -207,6 +207,11 @@ class Schedule(Object):
"""Returns the internally maintained trace of scheduling program
execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint:
disable=no-member
+ @property
+ def func_working_on(self) -> Optional[GlobalVar]:
+ """Returns the GlobalVar of the func that the schedule is currently
working on"""
+ return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore #
pylint: disable=no-member
+
def work_on(self, func_name: str) -> None:
"""Instruct the schedule to work on a function in the IRModule.
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index eb7c38753c..16065df3cd 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
+ Optional<GlobalVar> func_working_on() const final { return func_working_on_;
}
void WorkOn(const String& func_name) final;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed) final;
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 20a044439b..8663ac2b97 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") //
.set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") //
.set_body_method<Schedule>(&ScheduleNode::trace);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") //
+ .set_body_method<Schedule>(&ScheduleNode::func_working_on);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
.set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py
b/tests/python/unittest/test_tir_schedule_utilities.py
index ba2c134def..a8be97488b 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -193,6 +193,7 @@ def test_tir_schedule_work_on():
sch.get_block(name="init")
sch.work_on(func_name="vector_add")
sch.get_block(name="init")
+ assert sch.func_working_on == sch.mod.get_global_var("vector_add")
def test_tir_schedule_get_loops(use_block_name):