This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch custos-signer in repository https://gitbox.apache.org/repos/asf/airavata-custos.git
commit f219899f1c3d096ad3d7a8eb1fc51f2036e1307c Author: lahiruj <[email protected]> AuthorDate: Thu Dec 11 16:39:32 2025 -0500 Fix SSH cert public key handling and OpenSSH private key encoding for RSA/ECDSA --- .../apache/custos/signer/sdk/util/SshKeyUtils.java | 146 +++++++++++++++++++-- .../signer/service/ca/SshCertificateSigner.java | 19 ++- 2 files changed, 150 insertions(+), 15 deletions(-) diff --git a/signer/signer-sdk-core/src/main/java/org/apache/custos/signer/sdk/util/SshKeyUtils.java b/signer/signer-sdk-core/src/main/java/org/apache/custos/signer/sdk/util/SshKeyUtils.java index 00f4d09f8..678ef0e38 100644 --- a/signer/signer-sdk-core/src/main/java/org/apache/custos/signer/sdk/util/SshKeyUtils.java +++ b/signer/signer-sdk-core/src/main/java/org/apache/custos/signer/sdk/util/SshKeyUtils.java @@ -30,7 +30,9 @@ import java.math.BigInteger; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.PublicKey; +import java.security.SecureRandom; import java.security.interfaces.ECPublicKey; +import java.security.interfaces.RSAPrivateCrtKey; import java.security.interfaces.RSAPublicKey; import java.security.spec.ECGenParameterSpec; import java.security.spec.ECParameterSpec; @@ -87,13 +89,17 @@ public final class SshKeyUtils { throw new IllegalArgumentException("KeyPair or private key is null"); } - // For Ed25519, generate OpenSSH format + // Generate OpenSSH format for all key types String algorithm = keyPair.getPrivate().getAlgorithm(); if ("EdDSA".equals(algorithm) || "Ed25519".equals(algorithm)) { return keyPairToOpenSshPrivateKey(keyPair); + } else if ("RSA".equalsIgnoreCase(algorithm)) { + return keyPairToOpenSshPrivateKeyRsa(keyPair); + } else if ("EC".equalsIgnoreCase(algorithm)) { + return keyPairToOpenSshPrivateKeyEcdsa(keyPair); } - // For other key types, use standard PKCS#8 format + // Fallback to PKCS#8 for unsupported types StringWriter writer = new StringWriter(); try (JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) { pemWriter.writeObject(keyPair.getPrivate()); @@ -178,27 +184,150 @@ public final class SshKeyUtils { privOut.flush(); byte[] privateKeyBlobBytes = privateBuf.toByteArray(); - // Build OpenSSH key format + return formatOpenSshKey(publicKeyBlobBytes, privateKeyBlobBytes); + } + + /** + * Convert RSA key pair to OpenSSH private key format. + */ + private static String keyPairToOpenSshPrivateKeyRsa(KeyPair keyPair) throws Exception { + if (!(keyPair.getPrivate() instanceof RSAPrivateCrtKey)) { + throw new IllegalArgumentException("RSA private key must be RSAPrivateCrtKey to extract prime factors"); + } + RSAPrivateCrtKey rsaPrivate = (RSAPrivateCrtKey) keyPair.getPrivate(); + RSAPublicKey rsaPublic = (RSAPublicKey) keyPair.getPublic(); + + ByteArrayOutputStream publicKeyBlob = new ByteArrayOutputStream(); + DataOutputStream publicKeyOut = new DataOutputStream(publicKeyBlob); + writeString(publicKeyOut, "ssh-rsa".getBytes(java.nio.charset.StandardCharsets.UTF_8)); + writeMpInt(publicKeyOut, rsaPublic.getPublicExponent()); + writeMpInt(publicKeyOut, rsaPublic.getModulus()); + publicKeyOut.flush(); + byte[] publicKeyBlobBytes = publicKeyBlob.toByteArray(); + + // Build private section + SecureRandom random = new java.security.SecureRandom(); + int checkInt = random.nextInt(); + + ByteArrayOutputStream privateBuf = new ByteArrayOutputStream(); + DataOutputStream privOut = new DataOutputStream(privateBuf); + + privOut.writeInt(checkInt); + privOut.writeInt(checkInt); + + writeString(privOut, "ssh-rsa".getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + // Public key fields (per PROTOCOL.key) + writeMpInt(privOut, rsaPrivate.getModulus()); // n + writeMpInt(privOut, rsaPublic.getPublicExponent()); // e + + // Private key fields + writeMpInt(privOut, rsaPrivate.getPrivateExponent()); // d + writeMpInt(privOut, rsaPrivate.getPrimeQ().modInverse(rsaPrivate.getPrimeP())); // iqmp + writeMpInt(privOut, rsaPrivate.getPrimeP()); // p + writeMpInt(privOut, rsaPrivate.getPrimeQ()); // q + + // comment + writeString(privOut, "custos-generated".getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + // padding + int paddingNeeded = (8 - (privateBuf.size() % 8)) % 8; + for (int i = 1; i <= paddingNeeded; i++) { + privOut.write(i); + } + privOut.flush(); + byte[] privateKeyBlobBytes = privateBuf.toByteArray(); + + return formatOpenSshKey(publicKeyBlobBytes, privateKeyBlobBytes); + } + + /** + * Convert ECDSA key pair to OpenSSH private key format. + */ + private static String keyPairToOpenSshPrivateKeyEcdsa(KeyPair keyPair) throws Exception { + java.security.interfaces.ECPrivateKey ecPrivate = (java.security.interfaces.ECPrivateKey) keyPair.getPrivate(); + ECPublicKey ecPublic = (ECPublicKey) keyPair.getPublic(); + + // Determine curve name + String curveName = "nistp256"; + ECParameterSpec params = ecPublic.getParams(); + int fieldSize = (params.getCurve().getField().getFieldSize() + 7) / 8; + if (fieldSize == 32) { + curveName = "nistp256"; + } else if (fieldSize == 48) { + curveName = "nistp384"; + } else if (fieldSize == 66) { + curveName = "nistp521"; + } + + String keyType = "ecdsa-sha2-" + curveName; + + ByteArrayOutputStream publicKeyBlob = new ByteArrayOutputStream(); + DataOutputStream publicKeyOut = new DataOutputStream(publicKeyBlob); + writeString(publicKeyOut, keyType.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + writeString(publicKeyOut, curveName.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + byte[] q = encodeEcPointUncompressed(ecPublic.getW(), params); + writeString(publicKeyOut, q); + publicKeyOut.flush(); + byte[] publicKeyBlobBytes = publicKeyBlob.toByteArray(); + + // Build private section + java.security.SecureRandom random = new java.security.SecureRandom(); + int checkInt = random.nextInt(); + + ByteArrayOutputStream privateBuf = new ByteArrayOutputStream(); + DataOutputStream privOut = new DataOutputStream(privateBuf); + + privOut.writeInt(checkInt); + privOut.writeInt(checkInt); + + writeString(privOut, keyType.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + writeString(privOut, curveName.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + // Public key fields + writeString(privOut, q); + + // Private key + writeMpInt(privOut, ecPrivate.getS()); + + // comment + writeString(privOut, "custos-generated".getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + // padding + int paddingNeeded = (8 - (privateBuf.size() % 8)) % 8; + for (int i = 1; i <= paddingNeeded; i++) { + privOut.write(i); + } + privOut.flush(); + byte[] privateKeyBlobBytes = privateBuf.toByteArray(); + + return formatOpenSshKey(publicKeyBlobBytes, privateKeyBlobBytes); + } + + /** + * Format OpenSSH key from public and private key blobs. + */ + private static String formatOpenSshKey(byte[] publicKeyBlobBytes, byte[] privateKeyBlobBytes) throws Exception { ByteArrayOutputStream opensshKey = new ByteArrayOutputStream(); DataOutputStream out = new DataOutputStream(opensshKey); - // Magic string: "openssh-key-v1\0" + // Magic string - "openssh-key-v1\0" out.write("openssh-key-v1\0".getBytes(java.nio.charset.StandardCharsets.UTF_8)); - // Cipher name: "none" (unencrypted) + // Cipher name - "none" (unencrypted) byte[] cipherName = "none".getBytes(java.nio.charset.StandardCharsets.UTF_8); out.writeInt(cipherName.length); out.write(cipherName); - // KDF name: "none" + // KDF name - "none" byte[] kdfName = "none".getBytes(java.nio.charset.StandardCharsets.UTF_8); out.writeInt(kdfName.length); out.write(kdfName); - // KDF options: empty (4 bytes for length = 0) + // KDF options - empty (4 bytes for length = 0) out.writeInt(0); - // Number of keys: 1 + // Number of keys - 1 out.writeInt(1); // Public key (with length prefix) @@ -211,7 +340,6 @@ public final class SshKeyUtils { String base64Key = Base64.getEncoder().encodeToString(opensshKey.toByteArray()); - // Format as PEM StringBuilder pem = new StringBuilder(); pem.append("-----BEGIN OPENSSH PRIVATE KEY-----\n"); // Split base64 into 70-character lines diff --git a/signer/signer-service/src/main/java/org/apache/custos/signer/service/ca/SshCertificateSigner.java b/signer/signer-service/src/main/java/org/apache/custos/signer/service/ca/SshCertificateSigner.java index c054dfd7d..bf5db240e 100644 --- a/signer/signer-service/src/main/java/org/apache/custos/signer/service/ca/SshCertificateSigner.java +++ b/signer/signer-service/src/main/java/org/apache/custos/signer/service/ca/SshCertificateSigner.java @@ -69,9 +69,7 @@ public class SshCertificateSigner { /** * Sign an SSH public key and return the certificate */ - public SshCertificateResult signCertificate(String tenantId, String clientId, - String principal, int ttlSeconds, - byte[] publicKeyBytes, String caFingerprint) { + public SshCertificateResult signCertificate(String tenantId, String clientId, String principal, int ttlSeconds, byte[] publicKeyBytes, String caFingerprint) { try { // Parse the public key SshPublicKey publicKey = parseSshPublicKey(publicKeyBytes); @@ -139,6 +137,7 @@ public class SshCertificateSigner { byte[] decoded = Base64.getDecoder().decode(parts[1]); byte[] rawKeyBytes; + byte[] keyDataForCert; if (SSH_KEY_TYPE_ED25519.equals(keyType)) { ByteBuffer buf = ByteBuffer.wrap(decoded); int typeLen = buf.getInt(); @@ -154,6 +153,12 @@ public class SshCertificateSigner { } rawKeyBytes = new byte[pkLen]; buf.get(rawKeyBytes); + // For certificates, include the length prefix for the raw 32-byte key + ByteBuffer keyBuf = ByteBuffer.allocate(4 + pkLen); + keyBuf.putInt(pkLen); + keyBuf.put(rawKeyBytes); + keyDataForCert = keyBuf.array(); + } else { ByteBuffer buf = ByteBuffer.wrap(decoded); int typeLen = buf.getInt(); @@ -166,6 +171,8 @@ public class SshCertificateSigner { byte[] remaining = new byte[buf.remaining()]; buf.get(remaining); rawKeyBytes = remaining; + // For RSA/ECDSA the certificate needs the full SSH wire-format public key blob + keyDataForCert = rawKeyBytes; } // Calculate fingerprint (SHA256 hash) @@ -173,7 +180,7 @@ public class SshCertificateSigner { byte[] hash = digest.digest(rawKeyBytes); String fingerprint = Base64.getEncoder().encodeToString(hash); - return new SshPublicKey(keyType, rawKeyBytes, fingerprint); + return new SshPublicKey(keyType, keyDataForCert, fingerprint); } /** @@ -266,7 +273,7 @@ public class SshCertificateSigner { writeBytes(out, certificate.getNonce()); // Subject public key (SSH wire public key blob) - writeBytes(out, certificate.getPublicKey()); + out.write(certificate.getPublicKey()); // Serial writeUint64(out, certificate.getSerial()); @@ -319,7 +326,7 @@ public class SshCertificateSigner { writeBytes(out, certificate.getNonce()); // Subject public key (SSH wire public key blob) - writeBytes(out, certificate.getPublicKey()); + out.write(certificate.getPublicKey()); // Serial writeUint64(out, certificate.getSerial());
