wkcn commented on a change in pull request #16280: [Gluon] Support None
argument in HybridBlock
URL: https://github.com/apache/incubator-mxnet/pull/16280#discussion_r328397532
##########
File path: tests/python/unittest/test_gluon.py
##########
@@ -441,6 +441,63 @@ def test_sparse_hybrid_block():
# an exception is expected when forwarding a HybridBlock w/ sparse param
y = net(x)
+@with_seed()
+def test_hybrid_block_none_args():
+ class Foo(HybridBlock):
+ def hybrid_forward(self, F, a, b):
+ if a is None and b is not None:
+ return b
+ elif b is None and a is not None:
+ return a
+ elif a is not None and b is not None:
+ return a + b
+ else:
+ raise NotImplementedError
+
+ class FooNested(HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(FooNested, self).__init__(prefix=prefix, params=params)
+ self.f1 = Foo(prefix='foo1')
+ self.f2 = Foo(prefix='foo2')
+ self.f3 = Foo(prefix='foo3')
+
+ def hybrid_forward(self, F, a, b):
+ data = self.f1(a, b)
+ data = self.f2(a, data)
+ data = self.f3(data, b)
+ return data
+
+ for arg_inputs in [(None, mx.nd.ones((10,))),
+ (mx.nd.ones((10,)), mx.nd.ones((10,))),
+ (mx.nd.ones((10,)), None)]:
+ foo1 = FooNested(prefix='foo_nested_hybridized')
+ foo1.hybridize()
+ foo2 = FooNested(prefix='foo_nested_nohybrid')
+ for _ in range(2):
Review comment:
It would be better to add a comment explaining that the loop runs twice in
order to test the hybridized cache :)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services