On Thu, 6 Feb 2025 18:47:54 GMT, Ferenc Rakoczi <d...@openjdk.org> wrote:
>> By using the aarch64 vector registers the speed of the computation of the >> ML-DSA algorithms (key generation, document signing, signature verification) >> can be approximately doubled. > > Ferenc Rakoczi has updated the pull request incrementally with one additional > commit since the last revision: > > Adding comments + some code reorganization src/hotspot/cpu/aarch64/stubGenerator_aarch64.cpp line 4066: > 4064: } > 4065: > 4066: // Execute on round of keccak of two computations in parallel. Suggestion: It would be helpful to add comments that relate the register and instruction selection to the original Java source code. For example, change the header as follows: // Performs 2 keccak round transformations using vector parallelism // // Two sets of 25 * 64-bit input states a0[lo:hi]...a24[lo:hi] are passed in // the lower/upper halves of registers v0...v24 and the transformed states // are returned in the same registers. Intermediate 64-bit pairs // c0...c4 and d0...d4 are computed in registers v25...v30. v31 is // loaded with the required pair of 64-bit round constants. // During computation of the output states some intermediate results are // shuffled around registers v0...v30. Comments on each line indicate // how the values in registers correspond to variables ai, ci, di in // the Java source code, likewise how the generated machine instructions // correspond to Java source operations (n.b. rol means rotate left). 
Then annotate the generation steps as follows: __ eor3(v29, __ T16B, v4, v9, v14); // c4 = a4 ^ a9 ^ a14 __ eor3(v26, __ T16B, v1, v6, v11); // c1 = a1 ^ a6 ^ a11 __ eor3(v28, __ T16B, v3, v8, v13); // c3 = a3 ^ a8 ^ a13 __ eor3(v25, __ T16B, v0, v5, v10); // c0 = a0 ^ a5 ^ a10 __ eor3(v27, __ T16B, v2, v7, v12); // c2 = a2 ^ a7 ^ a12 __ eor3(v29, __ T16B, v29, v19, v24); // c4 ^= a19 ^ a24 __ eor3(v26, __ T16B, v26, v16, v21); // c1 ^= a16 ^ a21 __ eor3(v28, __ T16B, v28, v18, v23); // c3 ^= a18 ^ a23 __ eor3(v25, __ T16B, v25, v15, v20); // c0 ^= a15 ^ a20 __ eor3(v27, __ T16B, v27, v17, v22); // c2 ^= a17 ^ a22 __ rax1(v30, __ T2D, v29, v26); // d0 = c4 ^ rol(c1, 1) __ rax1(v26, __ T2D, v26, v28); // d2 = c1 ^ rol(c3, 1) __ rax1(v28, __ T2D, v28, v25); // d4 = c3 ^ rol(c0, 1) __ rax1(v25, __ T2D, v25, v27); // d1 = c0 ^ rol(c2, 1) __ rax1(v27, __ T2D, v27, v29); // d3 = c2 ^ rol(c4, 1) __ eor(v0, __ T16B, v0, v30); // a0 = a0 ^ d0 __ xar(v29, __ T2D, v1, v25, (64 - 1)); // a10' = rol((a1^d1), 1) __ xar(v1, __ T2D, v6, v25, (64 - 44)); // a1 = rol((a6^d1), 44) __ xar(v6, __ T2D, v9, v28, (64 - 20)); // a6 = rol((a9^d4), 20) __ xar(v9, __ T2D, v22, v26, (64 - 61)); // a9 = rol((a22^d2), 61) __ xar(v22, __ T2D, v14, v28, (64 - 39)); // a22 = rol((a14^d4), 39) __ xar(v14, __ T2D, v20, v30, (64 - 18)); // a14 = rol((a20^d0), 18) __ xar(v31, __ T2D, v2, v26, (64 - 62)); // a20' = rol((a2^d2), 62) __ xar(v2, __ T2D, v12, v26, (64 - 43)); // a2 = rol((a12^d2), 43) __ xar(v12, __ T2D, v13, v27, (64 - 25)); // a12 = rol((a13^d3), 25) __ xar(v13, __ T2D, v19, v28, (64 - 8)); // a13 = rol((a19^d4), 8) __ xar(v19, __ T2D, v23, v27, (64 - 56)); // a19 = rol((a23^d3), 56) __ xar(v23, __ T2D, v15, v30, (64 - 41)); // a23 = rol((a15^d0), 41) __ xar(v15, __ T2D, v4, v28, (64 - 27)); // a15 = rol((a4^d4), 27) __ xar(v28, __ T2D, v24, v28, (64 - 14)); // a4' = rol((a24^d4), 14) __ xar(v24, __ T2D, v21, v25, (64 - 2)); // a24 = rol((a21^d1), 2) __ xar(v8, __ T2D, v8, v27, (64 - 55)); 
// a21' = rol((a8^d3), 55) __ xar(v4, __ T2D, v16, v25, (64 - 45)); // a8' = rol((a16^d1), 45) __ xar(v16, __ T2D, v5, v30, (64 - 36)); // a16 = rol((a5^d0), 36) __ xar(v5, __ T2D, v3, v27, (64 - 28)); // a5 = rol((a3^d3), 28) __ xar(v27, __ T2D, v18, v27, (64 - 21)); // a3' = rol((a18^d3), 21) __ xar(v3, __ T2D, v17, v26, (64 - 15)); // a18' = rol((a17^d2), 15) __ xar(v25, __ T2D, v11, v25, (64 - 10)); // a17' = rol((a11^d1), 10) __ xar(v26, __ T2D, v7, v26, (64 - 6)); // a11' = rol((a7^d2), 6) __ xar(v30, __ T2D, v10, v30, (64 - 3)); // a7' = rol((a10^d0), 3) __ bcax(v20, __ T16B, v31, v22, v8); // a20 = a20' ^ (~a21' & a22) __ bcax(v21, __ T16B, v8, v23, v22); // a21 = a21' ^ (~a22 & a23) __ bcax(v22, __ T16B, v22, v24, v23); // a22 = a22 ^ (~a23 & a24) __ bcax(v23, __ T16B, v23, v31, v24); // a23 = a23 ^ (~a24 & a20') __ bcax(v24, __ T16B, v24, v8, v31); // a24 = a24 ^ (~a20' & a21') __ ld1r(v31, __ T2D, __ post(rscratch1, 8)); // rc = round_constants[i] __ bcax(v17, __ T16B, v25, v19, v3); // a17 = a17' ^ (~a18' & a19) __ bcax(v18, __ T16B, v3, v15, v19); // a18 = a18' ^ (~a19 & a15) __ bcax(v19, __ T16B, v19, v16, v15); // a19 = a19 ^ (~a15 & a16) __ bcax(v15, __ T16B, v15, v25, v16); // a15 = a15 ^ (~a16 & a17') __ bcax(v16, __ T16B, v16, v3, v25); // a16 = a16 ^ (~a17' & a18') __ bcax(v10, __ T16B, v29, v12, v26); // a10 = a10' ^ (~a11' & a12) __ bcax(v11, __ T16B, v26, v13, v12); // a11 = a11' ^ (~a12 & a13) __ bcax(v12, __ T16B, v12, v14, v13); // a12 = a12 ^ (~a13 & a14) __ bcax(v13, __ T16B, v13, v29, v14); // a13 = a13 ^ (~a14 & a10') __ bcax(v14, __ T16B, v14, v26, v29); // a14 = a14 ^ (~a10' & a11') __ bcax(v7, __ T16B, v30, v9, v4); // a7 = a7' ^ (~a8' & a9) __ bcax(v8, __ T16B, v4, v5, v9); // a8 = a8' ^ (~a9 & a5) __ bcax(v9, __ T16B, v9, v6, v5); // a9 = a9 ^ (~a5 & a6) __ bcax(v5, __ T16B, v5, v30, v6); // a5 = a5 ^ (~a6 & a7') __ bcax(v6, __ T16B, v6, v4, v30); // a6 = a6 ^ (~a7' & a8') __ bcax(v3, __ T16B, v27, v0, v28); // a3 = a3' ^ (~a4' & 
a0) __ bcax(v4, __ T16B, v28, v1, v0); // a4 = a4' ^ (~a0 & a1) __ bcax(v0, __ T16B, v0, v2, v1); // a0 = a0 ^ (~a1 & a2) __ bcax(v1, __ T16B, v1, v27, v2); // a1 = a1 ^ (~a2 & a3') __ bcax(v2, __ T16B, v2, v28, v27); // a2 = a2 ^ (~a3' & a4') __ eor(v0, __ T16B, v0, v31); // a0 = a0 ^ rc ------------- PR Review Comment: https://git.openjdk.org/jdk/pull/23300#discussion_r1959776475