Hi, I have encountered some very weird behaviour with einsum. I am trying to compute something like R*A*R', where * denotes a kind of "matrix multiplication". However, for particular shapes of R and A, the results are badly wrong.
I compare two einsum results. First, I compute the product in two einsum calls as (R*A)*R'. Second, I compute the whole result in one einsum call. For some shapes, the two results are significantly different. My test:

    import numpy as np
    for D in range(30):
        A = np.random.randn(100,D,D)
        R = np.random.randn(D,D)
        Y1 = np.einsum('...ik,...kj->...ij', R, A)
        Y1 = np.einsum('...ik,...kj->...ij', Y1, R.T)
        Y2 = np.einsum('...ik,...kl,...lj->...ij', R, A, R.T)
        print("D=%d" % D, np.allclose(Y1,Y2), np.linalg.norm(Y1-Y2))

Output:

    D=0 True 0.0
    D=1 True 0.0
    D=2 True 8.40339658678e-15
    D=3 True 8.09995399928e-15
    D=4 True 3.59428803435e-14
    D=5 False 34.755610184
    D=6 False 28.3576558351
    D=7 False 41.5402690906
    D=8 True 2.31709582841e-13
    D=9 False 36.0161112799
    D=10 True 4.76237746912e-13
    D=11 True 4.57944440782e-13
    D=12 True 4.90302218301e-13
    D=13 True 6.96175851271e-13
    D=14 True 1.10067181384e-12
    D=15 True 1.29095933163e-12
    D=16 True 1.3466837332e-12
    D=17 True 1.52265065763e-12
    D=18 True 2.05407923852e-12
    D=19 True 2.33327630748e-12
    D=20 True 2.96849358082e-12
    D=21 True 3.31063706175e-12
    D=22 True 4.28163620455e-12
    D=23 True 3.58951880681e-12
    D=24 True 4.69973694769e-12
    D=25 True 5.47385264567e-12
    D=26 True 5.49643316347e-12
    D=27 True 6.75132988402e-12
    D=28 True 7.86435437892e-12
    D=29 True 7.85453681029e-12

So, for D={5,6,7,9}, allclose returns False and the error norm is HUGE. This doesn't look like a small numerical inaccuracy, because the error norm is so large. I don't know which result is correct (Y1 or Y2), but at least one of them must be wrong.

I ran the same test several times, and each time the same values of D fail. If I change the shapes somehow, the failing values of D might change too, but I usually get several failing values.

I'm running the latest version from github (commit bd7104cef4) under Python 3.2.3. With NumPy 1.6.1 under Python 2.7.3, the test crashes and Python exits printing "Floating point exception".
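One way to find out which of Y1 and Y2 is wrong might be to compare both against a reference that does not use einsum at all. The following is only a sketch of that idea: the helper `reference_product` is a name made up for this example, the seed is fixed purely for repeatability, and D=5 is simply one of the failing values from the output above.

```python
import numpy as np

def reference_product(R, A):
    """Compute R * A[n] * R' for every slice n using plain np.dot.

    Hypothetical helper for this check only; assumes R is (D, D)
    and A is (N, D, D).
    """
    out = np.empty_like(A)
    for n in range(A.shape[0]):
        out[n] = np.dot(np.dot(R, A[n]), R.T)
    return out

# Fixed seed so that repeated runs use the same data (an assumption
# for this sketch; the original test used unseeded random numbers).
rng = np.random.RandomState(0)
D = 5  # one of the failing sizes reported above
A = rng.randn(100, D, D)
R = rng.randn(D, D)

Y_ref = reference_product(R, A)

# Two-step einsum, as in the original test.
Y1 = np.einsum('...ik,...kj->...ij', R, A)
Y1 = np.einsum('...ik,...kj->...ij', Y1, R.T)

# Single three-operand einsum, as in the original test.
Y2 = np.einsum('...ik,...kl,...lj->...ij', R, A, R.T)

print("two-step vs reference:", np.linalg.norm(Y1 - Y_ref))
print("one-step vs reference:", np.linalg.norm(Y2 - Y_ref))
```

Whichever norm is large should point at the einsum call that misbehaves, on the assumption that np.dot itself is correct for these shapes.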
This seems so weird to me that I wonder if I'm just doing something stupid.

Thanks a lot for any help!
Jaakko