Alex Nitz <[email protected]> writes:
> I have attached a patch that fixes a problem with using pow on a complex
> vector. The issue I found was that the kernel used the wrong function
> name. There is an if statement that sets the function name to "pow" for
> float64, and "powf" for everything else. The problem is that complex
> types also use "pow" for the function name.
>
> I've also attached several patches that address a few issues related to
> using a real GPUArray with a complex scalar. The main issue is that
> get_axbz_kernel sets the output (z) vector to the same dtype as the
> input one (x), and assumes the constant factors have the same dtype as
> well.
>
> So for real types the following operation makes sense:
>
>     z[i] = a * x[i] + b
>
> If "a" or "b" is complex, however, the code will complain that it has
> been given the wrong type. My patch changes the behavior so that "a",
> "b", and "z" have the same dtype, which can be set separately from the
> dtype of "x". For the various operations in GPUArray that call this
> function, I use the _get_common_dtype function even when "other" is a
> scalar. This applies to subtraction, addition, and multiplication.
> Division calls a different kernel, so I made a similar modification
> there as well.
>
> Finally, I modified the "dot" function to work when one argument is
> complex and the other is real. Using get_common_dtype fixed this issue
> as well.
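To make the dtype rules described in the quoted patches concrete, here is a host-side NumPy sketch. It never touches the GPU: pow_func_name, axbz_reference and dot_reference are hypothetical helpers, not PyCUDA functions, and np.result_type stands in for whatever common-dtype logic the patched kernels actually use.

```python
import numpy as np


def pow_func_name(dtype):
    """Hypothetical helper: pick the device-side pow name for a dtype.

    Only float32 gets "powf"; float64 and the complex types use "pow".
    """
    if np.dtype(dtype) == np.float32:
        return "powf"
    return "pow"


def axbz_reference(a, x, b):
    """Host-side reference for z[i] = a * x[i] + b (hypothetical helper).

    x keeps its own dtype; a, b and z share the common dtype of x, a and b,
    so a real x combined with a complex a or b yields a complex z.
    """
    x = np.asarray(x)
    z_dtype = np.result_type(x, a, b)
    a = z_dtype.type(a)  # cast the constant factors to z's dtype
    b = z_dtype.type(b)
    z = np.empty(x.shape, dtype=z_dtype)
    z[...] = a * x + b
    return z


def dot_reference(x, y):
    """Host-side dot product that accepts one real and one complex vector
    (hypothetical helper): both operands are promoted to a common dtype."""
    dtype = np.result_type(x.dtype, y.dtype)
    return np.sum(x.astype(dtype) * y.astype(dtype))
```

The design choice mirrored here is that x keeps its own dtype while a, b and z are promoted together, which is what allows a real vector to be combined with a complex scalar.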
Thanks for your patches, I've applied them. In the future, please stick to
PEP 8 (wrt commas and spaces, especially). I.e.

BAD: f(x,y)
GOOD: f(x, y)

Also, next time please run the tests to see whether they pass:

  File "/mnt/nfs-main/home/andreas/src/pycuda/pycuda/gpuarray.py", line 450, in __rmul__
    result = self._new_like_me(_get_common_dtype(self, other))
NameError: global name 'other' is not defined

(FTFY)

> Also, I wonder why in compyte/array.py the get_common_dtype function does
> not simply call numpy.find_common_dtype(vectors,scalars)?

I didn't know about numpy.find_common_dtype. Thanks for pointing it out.
But in any case, obj2 is allowed to be a plain Python scalar, for which I'd
rather let numpy do the special-case handling...

Andreas

_______________________________________________
PyCUDA mailing list
[email protected]
http://lists.tiker.net/listinfo/pycuda
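The point about plain Python scalars can be illustrated with a minimal sketch that lets NumPy's own arithmetic choose the result dtype. It is an assumption-laden illustration of the idea described in the reply, not the code in compyte/array.py: get_common_dtype and HostArray below are hypothetical host-side stand-ins, not PyCUDA's API. The __rmul__ method also shows the shape of the NameError fix: the helper has to be called with the method's own parameter name. (For reference, NumPy's function of this kind is spelled numpy.find_common_type, and newer NumPy releases deprecate it in favor of numpy.result_type.)

```python
import numpy as np


def get_common_dtype(obj1, obj2):
    """Hypothetical sketch: let NumPy arithmetic choose the result dtype.

    obj1 must have a .dtype; obj2 may have one or be a plain Python scalar,
    in which case NumPy applies its own scalar promotion rules.
    """
    zero1 = np.zeros(1, dtype=obj1.dtype)
    try:
        zero2 = np.zeros(1, dtype=obj2.dtype)
    except AttributeError:
        zero2 = obj2  # plain Python scalar: pass it through unchanged
    return (zero1 + zero2).dtype


class HostArray:
    """Hypothetical host-side stand-in for a GPUArray, for dtype checks only."""

    def __init__(self, data):
        self.data = np.asarray(data)
        self.dtype = self.data.dtype

    def __mul__(self, other):
        dtype = get_common_dtype(self, other)
        other_data = other.data if isinstance(other, HostArray) else other
        return HostArray((self.data * other_data).astype(dtype))

    def __rmul__(self, scalar):
        # Use the parameter that actually exists (scalar); referring to an
        # undefined name such as `other` here raises NameError when called.
        dtype = get_common_dtype(self, scalar)
        return HostArray((scalar * self.data).astype(dtype))
```

With this arrangement, `2j * HostArray(np.ones(3, np.float32))` dispatches to __rmul__ and yields a complex64 result, while multiplying by a plain float such as 3.0 keeps float32.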
