On Mon, 21 Feb 2022, at 11:50 PM, Stefan van der Walt wrote:
> Just to play a bit of devil's advocate here, I'd have to say that most people
> will not expect
>
> x[0] + 200
>
> to often yield a number less than 200!
It's tricky though, because I would expect
np.uint8(255) + 1
to be equal to 0. (As does JAX, see below.)
I.e., someone, somewhere, is going to be surprised. I don't think we can help
that at all. So my argument is that we should prefer the surprising behaviour
that is at least consistent within some overarching framework, and that the
framework itself should be as parsimonious as possible. I'd prefer not to have
to write "except for scalars" in a bunch of places in the docs.
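To spell out what "consistent in some overarching framework" means here: fixed-width unsigned integers implement arithmetic modulo 2**8, and uint8 *arrays* already follow that rule (this array behaviour should be stable across NumPy versions):

```python
import numpy as np

# uint8 arithmetic is modular: results wrap around modulo 2**8.
a = np.array([253, 254, 255], dtype=np.uint8)
out = a + 3
print(out, out.dtype)   # [0 1 2] uint8

# The same values fall out of plain modular arithmetic:
assert all((int(v) + 3) % 256 == int(w) for v, w in zip(a, out))
```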
> I think uint8's are especially problematic because they overflow so quickly
> (you won't easily run into the same behavior with uint16 and higher). Of
> course, there is no way to pretend that NumPy integers are Python integers,
> but by changing the casting table for uint8 a bit we may be able to avoid
> many common errors.
See, I kinda hate the idea of special-casing one dtype. Common errors might be
a good thing — people can very quickly learn to be careful with uint8s. If we
try really hard to hide this reality, people will be surprised *later*, or
indeed errors may go unnoticed.
> Besides, coming from value based casting, users already have this expectation:
>
> In [1]: np.uint8(255) + 1
> Out[1]: 256
>
> Currently, NumPy scalars and arrays are treated differently. Arrays have
> stronger types than scalars, in that users expect:
>
> In [3]: np.array([253, 254, 255], dtype=np.uint8) + 3
> Out[3]: array([0, 1, 2], dtype=uint8)
I think the users that expect *both* of those behaviours are a small set.
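For concreteness, here are the two quoted behaviours side by side. Note that the scalar result depends on the NumPy version: under legacy value-based casting it is promoted to 256, while under the NEP 50 scheme it stays uint8 and wraps to 0 (with an overflow RuntimeWarning).

```python
import numpy as np

# Scalar case: version-dependent.
scalar = np.uint8(255) + 1
print(scalar)            # 256 (legacy value-based casting) or 0 (NEP 50)

# Array case: the same in both schemes.
arr = np.array([253, 254, 255], dtype=np.uint8) + 3
print(arr, arr.dtype)    # [0 1 2] uint8
```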
> So perhaps the real question is: how important is it to us that arrays and
> scalars behave the same in the new casting scheme? (JAX, from the docs you
> linked, also makes the scalar vs array distinction.)
No, as far as I can tell, they distinguish between *Python* scalars and arrays,
not between JAX scalars and arrays. They do have a concept of weakly typed
arrays, but I don't think that's what you get when you do jnp.uint8(x). Indeed
I just checked that
jnp.uint8(255) + 1
returns a uint8 scalar with value 0. (Or a 0-dimensional array? I'm not sure
how JAX handles scalars; the exact repr returned is DeviceArray(0, dtype=uint8).)
>> We have also increasingly encountered users surprised/annoyed that
>> scikit-image blew up their uint8 to a float64, using 8x the RAM.
>
> I know this used to be true, but my sense is that it is less and less so,
> especially now that almost all skimage functions use floatx internally.
Greg spent a long time last year making sure that we didn't promote float32 to
float64 for this reason. This has reduced some of the burden but not all, and
my point is broader: users will not be happy to have uint8 + Python int return
an int64 array implicitly. And to quote from the JAX document, which to me
seems to be the nail in the coffin for alternatives:
> The benefit of these semantics are that you can readily express sequences of
> operations with clean Python code, without having to explicitly cast scalars
> to the appropriate type. Imagine if rather than writing this:
>
> 3 * (x + 1) ** 2
> you had to write this:
>
> np.int32(3) * (x + np.int32(1)) ** np.int32(2)
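The same ergonomics argument can be checked directly in NumPy: Python int literals don't upcast a typed array, so the clean spelling and the explicit-cast spelling agree. (This holds under both value-based casting and NEP 50, since 3, 1, and 2 all fit in int32.)

```python
import numpy as np

# Python int literals adapt to the array's dtype instead of upcasting it,
# so the clean expression preserves int32.
x = np.arange(4, dtype=np.int32)
clean = 3 * (x + 1) ** 2
explicit = np.int32(3) * (x + np.int32(1)) ** np.int32(2)

print(clean, clean.dtype)   # [ 3 12 27 48] int32
assert clean.dtype == explicit.dtype == np.int32
assert (clean == explicit).all()
```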
Juan.
_______________________________________________
NumPy-Discussion mailing list -- numpy-discussion@python.org
To unsubscribe send an email to numpy-discussion-le...@python.org
https://mail.python.org/mailman3/lists/numpy-discussion.python.org/
Member address: arch...@mail-archive.com