ashutosh-arm commented on code in PR #15506:
URL: https://github.com/apache/tvm/pull/15506#discussion_r1286837751
##########
tests/python/relay/strategy/test_select_implementation.py:
##########
@@ -187,5 +189,38 @@ def test_dense(target, expected_valid_impl, expected_impl):
assert selected_impl.name == expected_impl
+@pytest.mark.parametrize(
+    "target,schedule_func",
+    [
+        ("llvm -device=arm_cpu", topi.x86),
+        ("c -device=arm_cpu -mcpu=cortex-m55", topi.arm_cpu),
+    ],
+)
+def test_pool2d(target, schedule_func, monkeypatch):
+    target = tvm.target.Target(target)
+
+    data_shape = (1, 2, 2, 4)
+    dtype = "float32"
+
+    out = relay.nn.avg_pool2d(relay.var("data", shape=data_shape, dtype=dtype))
+    out = relay.Function(relay.analysis.free_vars(out), out)
+    out = tvm.IRModule.from_expr(out)
+
+    # Since pool does not use OpStrategy to determine the relevant schedule,
+    # we cannot simply check the schedule name that was selected with
+    # `select_implementation`. With this implementation of schedule selection,
+    # "pool.arm_cpu" will always be the schedule name, regardless of what schedule
+    # was selected. Instead, this test checks that the relevant schedule function
+    # is called when building the module.
+    mock_schedule = MagicMock()
+    mock_schedule.side_effect = lambda outs, layout: topi.generic.schedule_pool(outs, layout)
Review Comment:
question: `side_effect` is new to me. Does this work in the same way as
described in the following places?
https://docs.pytest.org/en/7.1.x/how-to/monkeypatch.html#monkeypatching-functions
https://docs.python.org/3/library/unittest.mock-examples.html#side-effect-functions-and-iterables
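As a rough illustration of the behavior described in the `unittest.mock` link: when `side_effect` is set to a function, each call to the mock is forwarded to that function and its return value becomes the mock's return value, while the mock still records the call. The pytest page covers the other half: `monkeypatch.setattr` swaps an attribute for the duration of a test, and the mock is what gets swapped in. In this minimal sketch, `generic_schedule_pool` is a hypothetical stand-in for `topi.generic.schedule_pool`, not TVM's API.

```python
from unittest.mock import MagicMock


def generic_schedule_pool(outs, layout):
    # Hypothetical stand-in for topi.generic.schedule_pool.
    return ("generic", layout)


mock_schedule = MagicMock()
# With a callable side_effect, each call to the mock is forwarded to the
# callable, and the callable's return value is returned by the mock.
mock_schedule.side_effect = lambda outs, layout: generic_schedule_pool(outs, layout)

result = mock_schedule(["out"], "NHWC")
print(result)  # ('generic', 'NHWC')

# The mock still records how it was called, so a test can assert on it.
mock_schedule.assert_called_once_with(["out"], "NHWC")
print(mock_schedule.call_count)  # 1
```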
Would it be better to have a mock `schedule_pool()` within the test and check
the value it returns or the output it prints? That would confirm that the
replacement happened only when the x86 schedule was invoked.
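One way to read this suggestion, sketched under assumptions: replace only the x86 schedule with a hand-written fake that records its calls and returns a sentinel, then check that sentinel for both kinds of target. Everything here is hypothetical (`x86`, `arm_cpu`, and `build_pool` stand in for TVM's topi namespaces and build path), and `unittest.mock.patch.object` plays the role that pytest's `monkeypatch.setattr` would play inside a real test.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Hypothetical stand-ins for the topi.x86 and topi.arm_cpu namespaces.
x86 = SimpleNamespace(schedule_pool=lambda outs, layout: ("x86", layout))
arm_cpu = SimpleNamespace(schedule_pool=lambda outs, layout: ("arm_cpu", layout))


def build_pool(use_dsp, outs, layout):
    # Hypothetical dispatch mirroring the strategy change under review:
    # DSP-capable targets use the arm_cpu schedule, everything else falls
    # back to the x86 schedule.
    if use_dsp:
        return arm_cpu.schedule_pool(outs, layout)
    return x86.schedule_pool(outs, layout)


def check_x86_fallback(use_dsp):
    calls = []

    def fake_schedule_pool(outs, layout):
        # Hand-written fake: records the call and returns a sentinel value.
        calls.append(layout)
        return "fake-x86-schedule"

    # Only the x86 attribute is replaced, and only inside this block.
    with patch.object(x86, "schedule_pool", fake_schedule_pool):
        result = build_pool(use_dsp, ["out"], "NHWC")
    return result, calls


print(check_x86_fallback(use_dsp=False))  # ('fake-x86-schedule', ['NHWC'])
print(check_x86_fallback(use_dsp=True))   # (('arm_cpu', 'NHWC'), [])
```

The sentinel only shows up when the x86 path runs, so it distinguishes the two targets. In an actual TVM build, the returned schedule is presumably consumed by later lowering steps, so a fake there would likely still need to delegate to a real schedule, which is what the `side_effect` lambda in the diff does.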
##########
python/tvm/relay/op/strategy/arm_cpu.py:
##########
@@ -74,8 +74,7 @@ def schedule_pool_arm_cpu(attrs, outs, target):
and layout in ("NWC", "NHWC")
):
return topi.arm_cpu.schedule_pool(outs, layout)
- logger.warning("pool is not optimized for arm cpu.")
- return topi.generic.schedule_pool(outs, layout)
+ return topi.x86.schedule_pool(outs, layout)
Review Comment:
question: should we still wrap this in an `arm_cpu` schedule inside topi?
That way, the schedule can later be optimized further for the specific
device if needed.
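A minimal sketch of the wrapping idea, assuming a thin arm_cpu-owned fallback that simply delegates to the x86 schedule for now. All names are hypothetical stand-ins (`x86_schedule_pool` for `topi.x86.schedule_pool`, `arm_cpu_dsp_schedule_pool` for the existing DSP path, `arm_cpu_schedule_pool_fallback` for the proposed wrapper), and the DSP condition from the diff is reduced to two booleans.

```python
def x86_schedule_pool(outs, layout):
    # Hypothetical stand-in for topi.x86.schedule_pool.
    return ("x86", layout)


def arm_cpu_dsp_schedule_pool(outs, layout):
    # Hypothetical stand-in for the existing DSP-specific arm_cpu schedule.
    return ("arm_cpu_dsp", layout)


def arm_cpu_schedule_pool_fallback(outs, layout):
    """Hypothetical arm_cpu-owned entry point for the non-DSP case.

    It currently just reuses the x86 schedule, but Arm-specific tuning could
    later be added here without changing the strategy function.
    """
    return x86_schedule_pool(outs, layout)


def schedule_pool_arm_cpu(has_dsp, layout_ok, outs, layout):
    # Simplified strategy function: the DSP layout check is one boolean.
    if has_dsp and layout_ok:
        return arm_cpu_dsp_schedule_pool(outs, layout)
    return arm_cpu_schedule_pool_fallback(outs, layout)


print(schedule_pool_arm_cpu(True, True, ["out"], "NHWC"))   # ('arm_cpu_dsp', 'NHWC')
print(schedule_pool_arm_cpu(False, True, ["out"], "NHWC"))  # ('x86', 'NHWC')
```

With this shape, the strategy only ever calls into the arm_cpu namespace, and a test could patch the single fallback function instead of reaching into `topi.x86`.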