szha closed pull request #10618: [MXNET-338] Fix symbol boolean evaluation
URL: https://github.com/apache/incubator-mxnet/pull/10618
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 915dfc9245e..1ab7cf87bf5 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -103,6 +103,11 @@ def __add__(self, other):
         else:
             raise TypeError('type %s not supported' % str(type(other)))
 
+    def __bool__(self):
+        raise NotImplementedForSymbol(self.__bool__, 'bool')
+
+    __nonzero__ = __bool__
+
     def __iadd__(self, other):
         raise NotImplementedForSymbol(self.__iadd__, '+=', other, 1)
 
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 11c2eba7609..387428ab296 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -20,7 +20,8 @@
 import re
 import mxnet as mx
 import numpy as np
-from common import models
+from common import assertRaises, models
+from mxnet.base import NotImplementedForSymbol
 from mxnet.test_utils import discard_stderr
 import pickle as pkl
 
@@ -31,6 +32,10 @@ def test_symbol_basic():
     m.list_arguments()
     m.list_outputs()
 
+def test_symbol_bool():
+    x = mx.symbol.Variable('x')
+    assertRaises(NotImplementedForSymbol, bool, x)
+
 def test_symbol_compose():
     data = mx.symbol.Variable('data')
     net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services