On Sat, 22 Mar 2025 20:02:31 GMT, Ferenc Rakoczi <d...@openjdk.org> wrote:
>> By using the AVX-512 vector registers the speed of the computation of the
>> ML-DSA algorithms (key generation, document signing, signature verification)
>> can be approximately doubled.
>
> Ferenc Rakoczi has updated the pull request incrementally with two additional
> commits since the last revision:
>
>  - Further readability improvements.
>  - Added asserts for array sizes

I still need to have a look at the sha3 changes, but I think I am done with the most complex part of the review. This was a really interesting bit of code to review!

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 270:

> 268: }
> 269:
> 270: static void loadPerm(int destinationRegs[], Register perms,

`replXmm`? i.e. this function is replicating (any) Xmm register, not just perm?..

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 327:

> 325: //
> 326: //
> 327: static address generate_dilithiumAlmostNtt_avx512(StubGenerator *stubgen,

Similar comments as for `generate_dilithiumAlmostInverseNtt_avx512`:

- Similar comment about the 'pair-wise' operation, updating `[j]` and `[j+l]` at a time.
- Somehow I had less trouble following the flow through registers here; perhaps I am getting used to it. FYI, I ended up renaming some as:

      // xmm16_27 = Temp1
      // xmm0_3   = Coeffs1
      // xmm4_7   = Coeffs2
      // xmm8_11  = Coeffs3
      // xmm12_15 = Coeffs4 = Temp2
      // xmm16_27 = Scratch

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 421:

> 419: for (int i = 0; i < 8; i += 2) {
> 420:   __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 421: }

I wish there were a more 'abstract' way to arrange this, so it's obvious from the shape of the code which registers are inputs and which are outputs (i.e. using the register arrays). Even though it's just 'elementary index operations', `i / 2 + 12` is still 'clever'. I couldn't think of anything myself, though (same elsewhere in this function for the table permutes).
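Since Montgomery multiplication (`montMul64`, `montMulByConst`, the implicit `montQInvModR`/`dilithium_q` constants) comes up in most of the comments below, here is a scalar Java sketch of what a single vector lane computes, for any next reader. The constants are the standard ML-DSA (FIPS 204) / reference-Dilithium values; the class and field names are mine, not the ones used in ML_DSA.java:

```java
public class MontMulSketch {
    // ML-DSA modulus and its inverse mod 2^32 (reference-implementation values).
    static final int Q = 8380417;
    static final int Q_INV = 58728449;   // Q * Q_INV == 1 (mod 2^32)
    static final int R_MOD_Q = 4193792;  // 2^32 mod Q

    // Montgomery multiply: returns a * b * 2^-32 (mod Q),
    // with the result in the open interval (-Q, Q).
    static int montMul(int a, int b) {
        long prod = (long) a * b;
        int t = (int) prod * Q_INV;      // t == prod * Q^-1 (mod 2^32)
        // prod - t*Q is divisible by 2^32 by construction of t.
        return (int) ((prod - (long) t * Q) >> 32);
    }

    public static void main(String[] args) {
        // Multiplying by (2^32 mod Q) cancels the 2^-32 factor exactly.
        System.out.println(montMul(1, R_MOD_Q));   // prints 1
    }
}
```

The vectorized `montMul64` does this lane-wise over 64 ints at a time, which is why `dilithium_q` and `montQInvModR` sit in broadcast registers throughout.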
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 509:

> 507: // coeffs (int[256]) = c_rarg0
> 508: // zetas (int[256]) = c_rarg1
> 509: static address generate_dilithiumAlmostInverseNtt_avx512(StubGenerator *stubgen,

Done with this function. Perhaps the 'permute table' is a common vector-algorithm pattern, but this is really clever! Some general comments first, rest inline.

- The array names for registers helped a lot. And so did the new helper functions!

- The Java version of this code is quite intimidating to vectorize: a 3D loop with geometric iteration variables, and the literature is even more intimidating (discrete convolutions, which I haven't touched in two decades, FFTs, NTTs, etc.). Here is my attempt at a comment to 'un-scare' the next reader, though feel free to reword however you like:

  The core of the (Java) loop is this 'pair-wise' operation:

      int a = coeffs[j];
      int b = coeffs[j + offset];
      coeffs[j] = (a + b);
      coeffs[j + offset] = montMul(a - b, -MONT_ZETAS_FOR_NTT[m]);

  There are 8 'levels' (0-7); 'levels' are equivalent to (unrolling) the outer (Java) loop. At each level, the 'pair-wise' offset doubles (2^l: 1, 2, 4, 8, 16, 32, 64, 128). To vectorize this Java code, observe that at each level, REGARDLESS of the offset, half of the operations are the SUM, and the other half are the Montgomery MULTIPLICATION (of the pair-difference with a constant). At each level, one 'just' has to shuffle the coefficients so that the SUMs and MULTIPLICATIONs line up accordingly. Otherwise, this pattern is 'lightly similar' to a discrete convolution (compute the integral/summation of two functions at every offset).

- I still would prefer (more) symbolic register names. I wouldn't hold my approval over it, so I won't object if nobody else does, but register numbers are harder to 'see' through the flow. I ended up search/replacing/'annotating' to make it easier on myself to follow the flow of data:

      // xmm8_11  = Perms1
      // xmm12_15 = Perms2
      // xmm16_27 = Scratch
      // xmm0_3   = CoeffsPlus
      // xmm4_7   = CoeffsMul
      // xmm24_27 = CoeffsMinus (overlaps with Scratch)

  (I made a similar comment, but I think it is now hidden after the last refactor.)

- I would prefer to see the helper functions get ALL of their registers passed in explicitly (currently `montMulPerm`, `montQInvModR`, `dilithium_q`, `xmm29` are implicit). As a general rule, I've tried to set up all the registers at the 'entry' function (`generate_dilithium*` in this case) and from there on use symbolic names. Not always reasonable, but it is what I've grown used to seeing.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 554:

> 552: for (int i = 0; i < 8; i += 2) {
> 553:   __ evpermi2d(xmm(i / 2 + 8), xmm(i), xmm(i + 1), Assembler::AVX_512bit);
> 554:   __ evpermi2d(xmm(i / 2 + 12), xmm(i), xmm(i + 1), Assembler::AVX_512bit);

It took me a bit to unscramble the flow, so is a comment needed? The purpose is 'fairly obvious' once I got the general shape of the level/algorithm (as per my top-level comment), but something like "shuffle xmm0-7 into xmm8-15"?
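To make the 'un-scare' comment above concrete, here is a runnable plain-Java model of just the level structure (8 levels, pair-wise offset doubling). The twiddle factor is deliberately stubbed to 1 and a plain modular multiply stands in for the real `montMul`, so only the loop *shape* matches the actual ML_DSA.java code; all names here are mine:

```java
public class InvNttShapeSketch {
    static final int Q = 8380417;   // ML-DSA modulus

    // Stand-in for montMul(a - b, -MONT_ZETAS_FOR_NTT[m]): a plain modular
    // multiply with the twiddle stubbed out, purely to show the loop shape.
    static int mulStub(int x, int zeta) {
        return (int) (((long) x * zeta) % Q);
    }

    // The 8 'levels': at each level the pair-wise offset doubles
    // (1, 2, 4, ..., 128); half of the work is the SUM, the other half the
    // MULTIPLICATION of the pair-difference with a constant.
    static void levels(int[] coeffs) {
        for (int offset = 1; offset < 256; offset *= 2) {
            for (int start = 0; start < 256; start += 2 * offset) {
                for (int j = start; j < start + offset; j++) {
                    int a = coeffs[j];
                    int b = coeffs[j + offset];
                    coeffs[j] = a + b;
                    coeffs[j + offset] = mulStub(a - b, 1); // real code: montMul(a - b, zeta)
                }
            }
        }
    }

    public static void main(String[] args) {
        int[] c = new int[256];
        c[0] = 1;                // delta at index 0
        levels(c);
        // with the stubbed twiddle, the delta spreads to every coefficient
        System.out.println(c[0] + " " + c[255]);   // prints "1 1"
    }
}
```

The vector code performs the same 8 levels, but shuffles coefficients first (the permute tables) so that within each 512-bit register every lane does the same half of the pair-wise operation.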
src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 572:

> 570: load4Xmms(xmm4_7, zetas, 512, _masm);
> 571: sub_add(xmm24_27, xmm0_3, xmm8_11, xmm12_15, _masm);
> 572: montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm);

From my annotated version, levels 1-4 are fairly 'straightforward':

    // level 1
    replXmm(Perms1, perms, nttInvL1PermsIdx, _masm);
    replXmm(Perms2, perms, nttInvL1PermsIdx + 64, _masm);
    for (int i = 0; i < 4; i++) {
      __ evpermi2d(xmm(Perms1[i]), xmm(CoeffsPlus[i]), xmm(CoeffsMul[i]), Assembler::AVX_512bit);
      __ evpermi2d(xmm(Perms2[i]), xmm(CoeffsPlus[i]), xmm(CoeffsMul[i]), Assembler::AVX_512bit);
    }
    load4Xmms(CoeffsMul, zetas, 512, _masm);
    sub_add(CoeffsMinus, CoeffsPlus, Perms1, Perms2, _masm);
    montMul64(CoeffsMul, CoeffsMinus, CoeffsMul, Scratch, _masm);

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 613:

> 611: montMul64(xmm4_7, xmm24_27, xmm4_7, xmm16_27, _masm);
> 612:
> 613: // level 5

"// No shuffling for level 5 and 6; can just rearrange full registers"

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 656:

> 654: for (int i = 0; i < 8; i++) {
> 655:   __ evpsubd(xmm(i), k0, xmm(i + 8), xmm(i), false, Assembler::AVX_512bit);
> 656: }

Fairly clean as is, but this could also be two sub_add calls, I think (you would have to swap the order of add/sub in the helper to be able to clobber `xmm(i)`... or swap the register usage downstream, so perhaps not... but it would be cleaner):

    sub_add(CoeffsPlus, Scratch, Perms1, CoeffsPlus, _masm);
    sub_add(CoeffsMul, &Scratch[4], Perms2, CoeffsMul, _masm);

If nothing else, I would have preferred to see the use of the register array variables.

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 660:

> 658: store4Xmms(coeffs, 0, xmm16_19, _masm);
> 659: store4Xmms(coeffs, 4 * XMMBYTES, xmm20_23, _masm);
> 660: montMulByConst128(_masm);

I would prefer explicit parameters here. But I think this could also be two `montMul64` calls?
    montMul64(xmm0_3, xmm0_3, xmm29_29, Scratch, _masm);
    montMul64(xmm4_7, xmm4_7, xmm29_29, Scratch, _masm);

(I think there is one other use of `montMulByConst128` where the same applies; then you could delete both `montMulByConst128` and `montmulEven`.)

src/hotspot/cpu/x86/stubGenerator_x86_64_dilithium.cpp line 871:

> 869: __ evpaddd(xmm5, k0, xmm1, barrettAddend, false, Assembler::AVX_512bit);
> 870: __ evpaddd(xmm6, k0, xmm2, barrettAddend, false, Assembler::AVX_512bit);
> 871: __ evpaddd(xmm7, k0, xmm3, barrettAddend, false, Assembler::AVX_512bit);

A fairly 'straightforward' transcription of the Java code; no comments from me. At first glance, using `xmm0_3`, `xmm4_7`, etc. might have been a good idea, but you only save one line per 4x group. (Unless you have one big loop, but I suspect that would give you worse performance? Is that something you tried already? Might be worth it otherwise.)

src/java.base/share/classes/sun/security/provider/ML_DSA.java line 1418:

> 1416:                            int twoGamma2, int multiplier) {
> 1417:     assert (input.length == ML_DSA_N) && (lowPart.length == ML_DSA_N)
> 1418:         && (highPart.length == ML_DSA_N);

I wrote this test to check java-to-intrinsic correspondence. Might be good to include it (and add the other 4 intrinsics).
This is very similar to all my other *Fuzz* tests I've been adding for my own intrinsics (and you made this test FAR easier to write by breaking out the Java implementation; I need to 'copy' that pattern myself).

    import java.util.Arrays;
    import java.util.Random;
    import java.lang.invoke.MethodHandle;
    import java.lang.invoke.MethodHandles;
    import java.lang.reflect.Constructor;
    import java.lang.reflect.Method;

    public class ML_DSA_Intrinsic_Test {
        public static void main(String[] args) throws Exception {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            Class<?> kClazz = Class.forName("sun.security.provider.ML_DSA");
            Constructor<?> constructor = kClazz.getDeclaredConstructor(int.class);
            constructor.setAccessible(true);

            Method m = kClazz.getDeclaredMethod("mlDsaNttMultiply",
                    int[].class, int[].class, int[].class);
            m.setAccessible(true);
            MethodHandle mult = lookup.unreflect(m);

            m = kClazz.getDeclaredMethod("implDilithiumNttMultJava",
                    int[].class, int[].class, int[].class);
            m.setAccessible(true);
            MethodHandle multJava = lookup.unreflect(m);

            Random rnd = new Random();
            long seed = rnd.nextLong();
            rnd.setSeed(seed);

            // Note: it might be useful to increase this number during
            // development of new intrinsics
            final int repeat = 1000000;
            int[] coeffs1 = new int[ML_DSA_N];
            int[] coeffs2 = new int[ML_DSA_N];
            int[] prod1 = new int[ML_DSA_N];
            int[] prod2 = new int[ML_DSA_N];
            try {
                for (int i = 0; i < repeat; i++) {
                    run(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
                }
                System.out.println("Fuzz Success");
            } catch (Throwable e) {
                System.out.println("Fuzz Failed: " + e);
            }
        }

        private static final int ML_DSA_N = 256;

        public static void run(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
                MethodHandle mult, MethodHandle multJava,
                Random rnd, long seed, int i) throws Throwable {
            for (int j = 0; j < ML_DSA_N; j++) {
                coeffs1[j] = rnd.nextInt();
                coeffs2[j] = rnd.nextInt();
            }
            mult.invoke(prod1, coeffs1, coeffs2);
            multJava.invoke(prod2, coeffs1, coeffs2);
            if (!Arrays.equals(prod1, prod2)) {
                throw new RuntimeException("[Seed " + seed + "@" + i + "] Result mismatch: "
                        + Arrays.toString(prod1) + " != " + Arrays.toString(prod2));
            }
        }
    }

    // java --add-opens java.base/sun.security.provider=ALL-UNNAMED \
    //     -XX:+UseDilithiumIntrinsics \
    //     test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java

-------------

PR Review: https://git.openjdk.org/jdk/pull/23860#pullrequestreview-2708301954
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2008921783
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009415317
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009477186
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009428310
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009433467
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009435329
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009435791
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009437669
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009438921
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2009486160
PR Review Comment: https://git.openjdk.org/jdk/pull/23860#discussion_r2010355575