This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 562d2fe7f8313a14299263dacb29bde626b52a77
Author: Lyor Goldstein <[email protected]>
AuthorDate: Mon Feb 18 10:07:28 2019 +0200

    [SSHD-895] Implemented a sample DefaultClientKexExtensionHandler that 
updates the client's signature factories
---
 CHANGES.md                                         |   9 +
 .../sshd/cli/client/SshClientCliSupport.java       |  34 +++
 .../sshd/common/kex/extension/KexExtensions.java   |   5 +
 .../sshd/common/signature/SignatureFactory.java    | 105 ++++++++
 .../DefaultClientKexExtensionHandler.java          | 293 +++++++++++++++++++++
 .../common/kex/extension/KexExtensionHandler.java  |  53 +++-
 .../org/apache/sshd/common/session/Session.java    |   8 +
 .../common/session/helpers/AbstractSession.java    |  45 +++-
 .../sshd/common/session/helpers/SessionHelper.java |   1 +
 9 files changed, 544 insertions(+), 9 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 9815aa6..5a2777c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -18,6 +18,15 @@ current sesssion - client/server proposals and what has been 
negotiated.
 
 * The `Session` object provides a `KexExtensionHandler` for usage with [KEX 
extension negotiation](https://tools.wordtothewise.com/rfc/rfc8308)
 
+## Minor code helpers
+
+* The `Session` object provides an `isServerSession` method that can be used to 
distinguish between
+client/server instances without having to resort to `instanceof`.
+
+* When creating a CLI SSH client one can specify `-o KexExtensionHandler=XXX` 
option to initialize
+a client-side `KexExtensionHandler` using an FQCN. If `default` is specified 
as the option value,
+then the internal `DefaultClientKexExtensionHandler` is used.
+
 ## Behavioral changes and enhancements
 
 * [SSHD-882](https://issues.apache.org/jira/browse/SSHD-882) - Provide hooks 
to allow users to register a consumer
diff --git 
a/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientCliSupport.java 
b/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientCliSupport.java
index ebc8f78..8e68cc6 100644
--- a/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientCliSupport.java
+++ b/sshd-cli/src/main/java/org/apache/sshd/cli/client/SshClientCliSupport.java
@@ -70,12 +70,16 @@ import 
org.apache.sshd.common.config.ConfigFileReaderSupport;
 import org.apache.sshd.common.config.keys.BuiltinIdentities;
 import org.apache.sshd.common.config.keys.KeyUtils;
 import org.apache.sshd.common.config.keys.PublicKeyEntry;
+import org.apache.sshd.common.kex.KexFactoryManager;
+import org.apache.sshd.common.kex.extension.DefaultClientKexExtensionHandler;
+import org.apache.sshd.common.kex.extension.KexExtensionHandler;
 import org.apache.sshd.common.keyprovider.FileKeyPairProvider;
 import org.apache.sshd.common.mac.BuiltinMacs;
 import org.apache.sshd.common.mac.Mac;
 import org.apache.sshd.common.util.GenericUtils;
 import org.apache.sshd.common.util.OsUtils;
 import org.apache.sshd.common.util.io.NoCloseOutputStream;
+import org.apache.sshd.common.util.threads.ThreadUtils;
 
 /**
  * TODO Add javadoc
@@ -335,6 +339,7 @@ public abstract class SshClientCliSupport extends 
CliSupport {
 
             setupServerKeyVerifier(client, options, stdin, stdout, stderr);
             setupSessionUserInteraction(client, stdin, stdout, stderr);
+            setupSessionExtensions(client, options, stdin, stdout, stderr);
 
             Map<String, Object> props = client.getProperties();
             props.putAll(options);
@@ -422,6 +427,35 @@ public abstract class SshClientCliSupport extends 
CliSupport {
         return ui;
     }
 
+    public static void setupSessionExtensions(
+            KexFactoryManager manager, Map<String, ?> options, BufferedReader 
stdin, PrintStream stdout, PrintStream stderr)
+                throws Exception {
+        String kexExtension = 
Objects.toString(options.remove(KexExtensionHandler.class.getSimpleName()), 
null);
+        if (GenericUtils.isEmpty(kexExtension)) {
+            return;
+        }
+
+        if ("default".equalsIgnoreCase(kexExtension)) {
+            
manager.setKexExtensionHandler(DefaultClientKexExtensionHandler.INSTANCE);
+            stdout.println("Using " + 
DefaultClientKexExtensionHandler.class.getSimpleName());
+        } else {
+            ClassLoader cl = 
ThreadUtils.resolveDefaultClassLoader(KexExtensionHandler.class);
+            try {
+                Class<?> clazz = cl.loadClass(kexExtension);
+                KexExtensionHandler handler = 
KexExtensionHandler.class.cast(clazz.newInstance());
+                manager.setKexExtensionHandler(handler);
+            } catch (Exception e) {
+                stderr.append("ERROR: Failed 
(").append(e.getClass().getSimpleName()).append(')')
+                    .append(" to instantiate KEX extension 
handler=").append(kexExtension)
+                    .append(": ").println(e.getMessage());
+                stderr.flush();
+                throw e;
+            }
+
+            stdout.println("Using " + 
KexExtensionHandler.class.getSimpleName() + "=" + kexExtension);
+        }
+    }
+
     public static ServerKeyVerifier setupServerKeyVerifier(
             ClientAuthenticationManager manager, Map<String, ?> options, 
BufferedReader stdin, PrintStream stdout, PrintStream stderr) {
         ServerKeyVerifier current = manager.getServerKeyVerifier();
diff --git 
a/sshd-common/src/main/java/org/apache/sshd/common/kex/extension/KexExtensions.java
 
b/sshd-common/src/main/java/org/apache/sshd/common/kex/extension/KexExtensions.java
index 41fbce0..cdf484a 100644
--- 
a/sshd-common/src/main/java/org/apache/sshd/common/kex/extension/KexExtensions.java
+++ 
b/sshd-common/src/main/java/org/apache/sshd/common/kex/extension/KexExtensions.java
@@ -32,6 +32,7 @@ import java.util.NavigableSet;
 import java.util.Objects;
 import java.util.TreeMap;
 import java.util.function.Function;
+import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -57,6 +58,10 @@ public final class KexExtensions {
     public static final String CLIENT_KEX_EXTENSION = "ext-info-c";
     public static final String SERVER_KEX_EXTENSION = "ext-info-s";
 
+    @SuppressWarnings("checkstyle:Indentation")
+    public static final Predicate<String> IS_KEX_EXTENSION_SIGNAL =
+        n -> CLIENT_KEX_EXTENSION.equalsIgnoreCase(n) || 
SERVER_KEX_EXTENSION.equalsIgnoreCase(n);
+
     /**
      * A case <U>insensitive</U> map of all the default known {@link 
KexExtensionParser}
      * where key=the extension name
diff --git 
a/sshd-common/src/main/java/org/apache/sshd/common/signature/SignatureFactory.java
 
b/sshd-common/src/main/java/org/apache/sshd/common/signature/SignatureFactory.java
index b6273c0..25296d0 100644
--- 
a/sshd-common/src/main/java/org/apache/sshd/common/signature/SignatureFactory.java
+++ 
b/sshd-common/src/main/java/org/apache/sshd/common/signature/SignatureFactory.java
@@ -20,16 +20,20 @@
 package org.apache.sshd.common.signature;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
+import java.util.TreeMap;
 
 import org.apache.sshd.common.BuiltinFactory;
 import org.apache.sshd.common.NamedFactory;
 import org.apache.sshd.common.NamedResource;
 import org.apache.sshd.common.config.keys.KeyUtils;
+import org.apache.sshd.common.keyprovider.KeyPairProvider;
 import org.apache.sshd.common.util.GenericUtils;
 
 /**
@@ -37,6 +41,26 @@ import org.apache.sshd.common.util.GenericUtils;
  */
 public interface SignatureFactory extends BuiltinFactory<Signature> {
     /**
+     * ECC signature types in descending order of preference (i.e., most 
preferred 1st)
+     */
+    List<String> ECC_SIGNATURE_TYPE_PREFERENCES =
+        Collections.unmodifiableList(
+            Arrays.asList(
+                    KeyPairProvider.ECDSA_SHA2_NISTP521,
+                    KeyPairProvider.ECDSA_SHA2_NISTP384,
+                    KeyPairProvider.ECDSA_SHA2_NISTP256));
+
+    /**
+     * RSA signature types in descending order of preference (i.e., most 
preferred 1st)
+     */
+    List<String> RSA_SIGNATURE_TYPE_PREFERENCES =
+        Collections.unmodifiableList(
+            Arrays.asList(
+                KeyUtils.RSA_SHA512_KEY_TYPE_ALIAS,
+                KeyUtils.RSA_SHA256_KEY_TYPE_ALIAS,
+                KeyPairProvider.SSH_RSA));
+
+    /**
      * @param provided The provided signature key types
      * @param factories The available signature factories
      * @return A {@link List} of the matching available factories names
@@ -89,5 +113,86 @@ public interface SignatureFactory extends 
BuiltinFactory<Signature> {
 
         return supported;
     }
+
+    // returns -1 or a value >= size() if the factory should be appended to the end
+    static int resolvePreferredSignaturePosition(
+            List<? extends NamedFactory<Signature>> factories, 
NamedFactory<Signature> factory) {
+        if (GenericUtils.isEmpty(factories)) {
+            return -1;  // just add it to the end
+        }
+
+        String name = factory.getName();
+        if (KeyPairProvider.SSH_RSA.equalsIgnoreCase(name)) {
+            return -1;
+        }
+
+        int pos = RSA_SIGNATURE_TYPE_PREFERENCES.indexOf(name);
+        if (pos >= 0) {
+            Map<String, Integer> posMap = new 
TreeMap<>(String.CASE_INSENSITIVE_ORDER);
+            for (int index = 0, count = factories.size(); index < count; 
index++) {
+                NamedFactory<Signature> f = factories.get(index);
+                String keyType = f.getName();
+                String canonicalName = KeyUtils.getCanonicalKeyType(keyType);
+                if (!KeyPairProvider.SSH_RSA.equalsIgnoreCase(canonicalName)) {
+                    continue;   // debug breakpoint
+                }
+
+                posMap.put(keyType, index);
+            }
+
+            return 
resolvePreferredSignaturePosition(RSA_SIGNATURE_TYPE_PREFERENCES, pos, posMap);
+        }
+
+        pos = ECC_SIGNATURE_TYPE_PREFERENCES.indexOf(name);
+        if (pos >= 0) {
+            Map<String, Integer> posMap = new 
TreeMap<>(String.CASE_INSENSITIVE_ORDER);
+            for (int index = 0, count = factories.size(); index < count; 
index++) {
+                NamedFactory<Signature> f = factories.get(index);
+                String keyType = f.getName();
+                if (!ECC_SIGNATURE_TYPE_PREFERENCES.contains(keyType)) {
+                    continue;   // debug breakpoint
+                }
+
+                posMap.put(keyType, index);
+            }
+
+            return 
resolvePreferredSignaturePosition(ECC_SIGNATURE_TYPE_PREFERENCES, pos, posMap);
+        }
+
+        return -1;  // no special preference - stick it as last
+    }
+
+    static int resolvePreferredSignaturePosition(
+            List<String> preferredOrder, int prefValue, Map<String, Integer> 
posMap) {
+        if (GenericUtils.isEmpty(preferredOrder) || (prefValue < 0) || 
GenericUtils.isEmpty(posMap)) {
+            return -1;
+        }
+
+        int posValue = -1;
+        for (Map.Entry<String, Integer> pe : posMap.entrySet()) {
+            String name = pe.getKey();
+            int order = preferredOrder.indexOf(name);
+            if (order < 0) {
+                continue;   // should not happen, but tolerate
+            }
+
+            Integer curIndex = pe.getValue();
+            int resIndex;
+            if (order < prefValue) {
+                resIndex = curIndex.intValue() + 1;
+            } else if (order > prefValue) {
+                resIndex = curIndex.intValue(); // by using same index we 
insert in front of it in effect
+            } else {
+                continue;   // should not happen, but tolerate
+            }
+
+            // Preferred factories should be as close as possible to the 
beginning of the list
+            if ((posValue < 0) || (resIndex < posValue)) {
+                posValue = resIndex;
+            }
+        }
+
+        return posValue;
+    }
 }
 
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/DefaultClientKexExtensionHandler.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/DefaultClientKexExtensionHandler.java
new file mode 100644
index 0000000..8654a84
--- /dev/null
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/DefaultClientKexExtensionHandler.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.common.kex.extension;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NavigableSet;
+import java.util.Objects;
+import java.util.stream.Stream;
+
+import org.apache.sshd.common.AttributeRepository.AttributeKey;
+import org.apache.sshd.common.NamedFactory;
+import org.apache.sshd.common.NamedResource;
+import org.apache.sshd.common.OptionalFeature;
+import org.apache.sshd.common.config.keys.KeyUtils;
+import org.apache.sshd.common.kex.KexProposalOption;
+import org.apache.sshd.common.kex.extension.parser.ServerSignatureAlgorithms;
+import org.apache.sshd.common.session.Session;
+import org.apache.sshd.common.signature.BuiltinSignatures;
+import org.apache.sshd.common.signature.Signature;
+import org.apache.sshd.common.signature.SignatureFactory;
+import org.apache.sshd.common.util.GenericUtils;
+import org.apache.sshd.common.util.logging.AbstractLoggingBean;
+
+/**
+ * Detects if the server sends a <A 
HREF="https://tools.ietf.org/html/rfc8308#section-3.1">&quot;server-sig-algs&quot;</A>
+ * extension and updates the client session by adding the <A 
HREF="https://tools.ietf.org/html/rfc8332">&quot;rsa-sha2-256/512&quot;</A>
+ * signature factories (if not already added).
+ *
+ * <B>Note:</B> experimental - used for development purposes and as an example
+ * @author <a href="mailto:[email protected]">Apache MINA SSHD Project</a>
+ */
+public class DefaultClientKexExtensionHandler extends AbstractLoggingBean 
implements KexExtensionHandler {
+    /**
+     * Session {@link AttributeKey} used to store the client's proposal
+     */
+    public static final AttributeKey<Map<KexProposalOption, String>> 
PROPOSAL_KEY = new AttributeKey<>();
+
+    public static final NavigableSet<String> DEFAULT_EXTRA_SIGNATURES =
+        Collections.unmodifiableNavigableSet(
+            GenericUtils.asSortedSet(String.CASE_INSENSITIVE_ORDER,
+                KeyUtils.RSA_SHA256_KEY_TYPE_ALIAS,
+                KeyUtils.RSA_SHA512_KEY_TYPE_ALIAS));
+
+    public static final DefaultClientKexExtensionHandler INSTANCE = new 
DefaultClientKexExtensionHandler();
+
+    public DefaultClientKexExtensionHandler() {
+        super();
+    }
+
+    @Override
+    public boolean isKexExtensionsAvailable(Session session) throws 
IOException {
+        return (session != null) && (!session.isServerSession());
+    }
+
+    @Override
+    public void handleKexInitProposal(
+            Session session, boolean initiator, Map<KexProposalOption, String> 
proposal)
+                throws IOException {
+        if (session.isServerSession()) {
+            return; // just in case
+        }
+
+        boolean debugEnabled = log.isDebugEnabled();
+        Collection<String> extraAlgos = getExtraSignatureAlgorithms(session);
+        if (GenericUtils.isEmpty(extraAlgos)) {
+            if (debugEnabled) {
+                log.debug("handleKexInitProposal({}) no extra signatures to 
add to {}", session, proposal);
+            }
+            return;
+        }
+
+        Collection<? extends NamedResource> sigList = 
session.getSignatureFactories();
+        long existCount = sigList.stream().filter(f -> 
extraAlgos.contains(f.getName())).count();
+        if (existCount == extraAlgos.size()) {
+            if (debugEnabled) {
+                log.debug("handleKexInitProposal({}) required extra signatures 
({}) already supported for {}",
+                    session, extraAlgos, proposal);
+            }
+            return;
+        }
+
+        if (initiator) {
+            session.setAttribute(PROPOSAL_KEY, new EnumMap<>(proposal));
+            if (debugEnabled) {
+                log.debug("handleKexInitProposal({}) initial proposal={}", 
session, proposal);
+            }
+            return;
+        }
+
+        // Check if client already sent its proposal - if not, we can still 
influence it
+        Map<KexProposalOption, String> sentProposal = 
session.getAttribute(PROPOSAL_KEY);
+        if (GenericUtils.isNotEmpty(sentProposal)) {
+            if (debugEnabled) {
+                log.debug("handleKexInitProposal({}) already sent proposal={} 
(server={})",
+                    session, sentProposal, proposal);
+            }
+            return;
+        }
+
+        String algos = proposal.get(KexProposalOption.ALGORITHMS);
+        String extDeclared = Stream.of(GenericUtils.split(algos, ','))
+            .filter(s -> 
KexExtensions.SERVER_KEX_EXTENSION.equalsIgnoreCase(s))
+            .findFirst()
+            .orElse(null);
+        if (GenericUtils.isEmpty(extDeclared)) {
+            if (debugEnabled) {
+                log.debug("handleKexInitProposal({}) server proposal={} does 
not include extension indicator",
+                    session, proposal);
+            }
+            return;
+        }
+
+        updateAvailableSignatureFactories(session, extraAlgos);
+    }
+
+    protected Collection<String> getExtraSignatureAlgorithms(Session session) 
throws IOException {
+        return DEFAULT_EXTRA_SIGNATURES;
+    }
+
+    @Override
+    public boolean handleKexExtensionRequest(
+            Session session, int index, int count, String name, byte[] data)
+                throws IOException {
+        if (!ServerSignatureAlgorithms.NAME.equalsIgnoreCase(name)) {
+            return true;    // process next extension (if available)
+        }
+
+        Collection<String> sigAlgos = 
ServerSignatureAlgorithms.INSTANCE.parseExtension(data);
+        updateAvailableSignatureFactories(session, sigAlgos);
+        return false;   // don't care about any more extensions (for now)
+    }
+
+    public List<NamedFactory<Signature>> updateAvailableSignatureFactories(
+            Session session, Collection<String> extraAlgos)
+                throws IOException {
+        List<NamedFactory<Signature>> available = 
session.getSignatureFactories();
+        List<NamedFactory<Signature>> updated =
+            resolveUpdatedSignatureFactories(session, available, extraAlgos);
+        if (!GenericUtils.isSameReference(available, updated)) {
+            if (log.isDebugEnabled()) {
+                log.debug("updateAvailableSignatureFactories({}) available={}, 
updated={}",
+                    session, available, updated);
+            }
+            session.setSignatureFactories(updated);
+        }
+
+        return updated;
+    }
+
+    /**
+     * Checks if the extra signature algorithms are already included in the 
available ones,
+     * and adds the extra ones (if supported).
+     *
+     * @param session The {@link Session} for which the resolution occurs
+     * @param available The available signature factories
+     * @param extraAlgos The extra requested signatures - ignored if {@code 
null}/empty
+     * @return The resolved signature factories - same as input if nothing 
added
+     * @throws IOException If failed to resolve the factories
+     */
+    public List<NamedFactory<Signature>> resolveUpdatedSignatureFactories(
+            Session session, List<NamedFactory<Signature>> available, 
Collection<String> extraAlgos)
+                throws IOException {
+        boolean debugEnabled = log.isDebugEnabled();
+        List<NamedFactory<Signature>> toAdd =
+            resolveRequestedSignatureFactories(session, extraAlgos);
+        if (GenericUtils.isEmpty(toAdd)) {
+            if (debugEnabled) {
+                log.debug("resolveUpdatedSignatureFactories({}) Nothing to add 
to {} out of {}",
+                    session, NamedResource.getNames(available), extraAlgos);
+            }
+            return available;
+        }
+
+        for (int index = 0; index < toAdd.size(); index++) {
+            NamedFactory<Signature> f = toAdd.get(index);
+            String name = f.getName();
+            NamedFactory<Signature> a = available.stream()
+                .filter(s -> Objects.equals(name, s.getName()))
+                .findFirst()
+                .orElse(null);
+            if (a == null) {
+                continue;
+            }
+
+            if (debugEnabled) {
+                log.debug("resolveUpdatedSignatureFactories({}) skip {} - 
already available", session, name);
+            }
+
+            toAdd.remove(index);
+            index--;    // compensate for loop auto-increment
+        }
+
+        return updateAvailableSignatureFactories(session, available, toAdd);
+    }
+
+    public List<NamedFactory<Signature>> updateAvailableSignatureFactories(
+            Session session, List<NamedFactory<Signature>> available, 
Collection<? extends NamedFactory<Signature>> toAdd)
+                throws IOException {
+        boolean debugEnabled = log.isDebugEnabled();
+        if (GenericUtils.isEmpty(toAdd)) {
+            if (debugEnabled) {
+                log.debug("updateAvailableSignatureFactories({}) nothing to 
add to {}",
+                    session, NamedResource.getNames(available));
+            }
+            return available;
+        }
+
+        List<NamedFactory<Signature>> updated =
+            new ArrayList<>(available.size() + toAdd.size());
+        updated.addAll(available);
+
+        for (NamedFactory<Signature> f : toAdd) {
+            int index = resolvePreferredSignaturePosition(session, updated, f);
+            if (debugEnabled) {
+                log.debug("updateAvailableSignatureFactories({}) add {} at 
position={}", session, f, index);
+            }
+            if ((index < 0) || (index >= updated.size())) {
+                updated.add(f);
+            } else {
+                updated.add(index, f);
+            }
+        }
+
+        return updated;
+    }
+
+    public int resolvePreferredSignaturePosition(
+            Session session, List<? extends NamedFactory<Signature>> 
factories, NamedFactory<Signature> factory)
+                throws IOException {
+        return SignatureFactory.resolvePreferredSignaturePosition(factories, 
factory);
+    }
+
+    public List<NamedFactory<Signature>> resolveRequestedSignatureFactories(
+            Session session, Collection<String> extraAlgos)
+                throws IOException {
+        if (GenericUtils.isEmpty(extraAlgos)) {
+            return Collections.emptyList();
+        }
+
+        List<NamedFactory<Signature>> toAdd = Collections.emptyList();
+        boolean debugEnabled = log.isDebugEnabled();
+        for (String algo : extraAlgos) {
+            NamedFactory<Signature> factory = 
resolveRequestedSignatureFactory(session, algo);
+            if (factory == null) {
+                if (debugEnabled) {
+                    log.debug("resolveRequestedSignatureFactories({}) skip {} 
- no factory found", session, algo);
+                }
+                continue;
+            }
+
+            if ((factory instanceof OptionalFeature) && (!((OptionalFeature) 
factory).isSupported())) {
+                if (debugEnabled) {
+                    log.debug("resolveRequestedSignatureFactories({}) skip {} 
- not supported", session, algo);
+                }
+                continue;
+            }
+
+            if (toAdd.isEmpty()) {
+                toAdd = new ArrayList<>(extraAlgos.size());
+            }
+            toAdd.add(factory);
+        }
+
+        return toAdd;
+    }
+
+    public NamedFactory<Signature> resolveRequestedSignatureFactory(Session 
session, String name) throws IOException {
+        return BuiltinSignatures.fromFactoryName(name);
+    }
+}
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/KexExtensionHandler.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/KexExtensionHandler.java
index 0464bc3..45c6384 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/KexExtensionHandler.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/kex/extension/KexExtensionHandler.java
@@ -22,8 +22,10 @@ package org.apache.sshd.common.kex.extension;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.EnumSet;
+import java.util.Map;
 import java.util.Set;
 
+import org.apache.sshd.common.kex.KexProposalOption;
 import org.apache.sshd.common.session.Session;
 import org.apache.sshd.common.util.buffer.Buffer;
 
@@ -43,7 +45,7 @@ public interface KexExtensionHandler {
 
     /**
      * @param session The {@link Session} about to execute KEX
-     * @return {@code true} whether to declare KEX extensions availability for 
the session
+     * @return {@code true} if KEX extensions are supported/allowed 
for the session
      * @throws IOException If failed to process the request
      */
     default boolean isKexExtensionsAvailable(Session session) throws 
IOException {
@@ -51,8 +53,51 @@ public interface KexExtensionHandler {
     }
 
     /**
-     * Invoked in order to allow the handler to send an {@code 
SSH_MSG_EXT_INFO} message.
+     * Invoked when a peer is ready to send the KEX options proposal or has 
received
+     * such a proposal. <B>Note:</B> this method is called during the 
negotiation phase
+     * even if {@link #isKexExtensionsAvailable(Session)} returns {@code 
false} for the session.
      *
+     * @param session The {@link Session} initiating or receiving the proposal
+     * @param initiator {@code true} if the proposal is about to be sent, 
{@code false}
+     * if this is a proposal received from the peer.
+     * @param proposal The proposal contents - <B>Caveat emptor:</B> the 
proposal is
+     * <U>modifiable</U>, i.e., the handler can modify it before it is sent or 
before
+     * it is processed (if incoming)
+     * @throws IOException If failed to handle the request
+     */
+    default void handleKexInitProposal(
+            Session session, boolean initiator, Map<KexProposalOption, String> 
proposal)
+                throws IOException {
+        // ignored
+    }
+
+    /**
+     * Invoked during the KEX negotiation phase to inform about the option
+     * being negotiated. <B>Note:</B> this method is called during the
+     * negotiation phase even if {@link #isKexExtensionsAvailable(Session)}
+     * returns {@code false} for the session.
+     *
+     * @param session The {@link Session} executing the negotiation
+     * @param option The negotiated {@link KexProposalOption}
+     * @param nValue The negotiated option value (may be {@code null}/empty).
+     * @param c2sOptions The client proposals
+     * @param cValue The client-side value for the option (may be {@code 
null}/empty).
+     * @param s2cOptions The server proposals
+     * @param sValue The server-side value for the option (may be {@code 
null}/empty).
+     * @throws IOException If failed to handle the invocation
+     */
+    default void handleKexExtensionNegotiation(
+            Session session, KexProposalOption option, String nValue,
+            Map<KexProposalOption, String> c2sOptions, String cValue,
+            Map<KexProposalOption, String> s2cOptions, String sValue)
+                throws IOException {
+        // do nothing
+    }
+
+    /**
+     * Invoked in order to allow the handler to send an {@code 
SSH_MSG_EXT_INFO} message.
+     * <B>Note:</B> this method is called only if {@link 
#isKexExtensionsAvailable(Session)}
+     * returns {@code true} for the session.
      * @param session The {@link Session}
      * @param phase The phase at which the handler is invoked
      * @throws IOException If failed to handle the invocation
@@ -63,7 +108,9 @@ public interface KexExtensionHandler {
     }
 
     /**
-     * Parses the {@code SSH_MSG_EXT_INFO} message
+     * Parses the {@code SSH_MSG_EXT_INFO} message. <B>Note:</B> this method
+     * is called only if {@link #isKexExtensionsAvailable(Session)} returns
+     * {@code true} for the session.
      *
      * @param session The {@link Session} through which the message was 
received
      * @param buffer The message buffer
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/Session.java 
b/sshd-core/src/main/java/org/apache/sshd/common/session/Session.java
index 7943e6a..ee55a26 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/Session.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/Session.java
@@ -67,6 +67,14 @@ public interface Session
                 Closeable {
 
     /**
+     * Quick indication if this is a server or client session (instead of
+     * having to ask {@code instanceof}).
+     *
+     * @return {@code true} if this is a server session
+     */
+    boolean isServerSession();
+
+    /**
      * Timeout status.
      */
     enum TimeoutStatus {
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
index 1dfc381..11c6d87 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
@@ -1257,8 +1257,9 @@ public abstract class AbstractSession extends 
SessionHelper {
      * @param buffer   the {@link Buffer} containing the key exchange init 
packet
      * @param proposal the remote proposal to fill
      * @return the packet data
+     * @throws IOException If failed to handle the message
      */
-    protected byte[] receiveKexInit(Buffer buffer, Map<KexProposalOption, 
String> proposal) {
+    protected byte[] receiveKexInit(Buffer buffer, Map<KexProposalOption, 
String> proposal) throws IOException {
         // Recreate the packet payload which will be needed at a later time
         byte[] d = buffer.array();
         byte[] data = new byte[buffer.available() + 1 /* the opcode */];
@@ -1290,13 +1291,26 @@ public abstract class AbstractSession extends 
SessionHelper {
             size += readLen;
         }
 
+        KexExtensionHandler extHandler = getKexExtensionHandler();
+        if (extHandler != null) {
+            if (traceEnabled) {
+                log.trace("receiveKexInit({}) options before handler: {}", 
this, proposal);
+            }
+
+            extHandler.handleKexInitProposal(this, false, proposal);
+
+            if (traceEnabled) {
+                log.trace("receiveKexInit({}) options after handler: {}", 
this, proposal);
+            }
+        }
+
         firstKexPacketFollows = buffer.getBoolean();
         if (traceEnabled) {
             log.trace("receiveKexInit({}) first kex packet follows: {}", this, 
firstKexPacketFollows);
         }
 
         long reserved = buffer.getUInt();
-        if (reserved != 0) {
+        if (reserved != 0L) {
             if (traceEnabled) {
                 log.trace("receiveKexInit({}) non-zero reserved value: {}", 
this, reserved);
             }
@@ -1489,6 +1503,7 @@ public abstract class AbstractSession extends 
SessionHelper {
             boolean debugEnabled = log.isDebugEnabled();
             boolean traceEnabled = log.isTraceEnabled();
             SessionDisconnectHandler discHandler = 
getSessionDisconnectHandler();
+            KexExtensionHandler extHandler = getKexExtensionHandler();
             for (KexProposalOption paramType : KexProposalOption.VALUES) {
                 String clientParamValue = c2sOptions.get(paramType);
                 String serverParamValue = s2cOptions.get(paramType);
@@ -1517,6 +1532,11 @@ public abstract class AbstractSession extends 
SessionHelper {
 
                 // check if reached an agreement
                 String value = guess.get(paramType);
+                if (extHandler != null) {
+                    extHandler.handleKexExtensionNegotiation(
+                        this, paramType, value, c2sOptions, clientParamValue, 
s2cOptions, serverParamValue);
+                }
+
                 if (value != null) {
                     if (traceEnabled) {
                         log.trace("negotiate({})[{}] guess={} (client={} / 
server={})",
@@ -1553,8 +1573,7 @@ public abstract class AbstractSession extends 
SessionHelper {
              *      key exchange method, the parties MUST disconnect.
              */
             String kexOption = guess.get(KexProposalOption.ALGORITHMS);
-            if (KexExtensions.CLIENT_KEX_EXTENSION.equalsIgnoreCase(kexOption)
-                    || 
KexExtensions.SERVER_KEX_EXTENSION.equalsIgnoreCase(kexOption)) {
+            if (KexExtensions.IS_KEX_EXTENSION_SIGNAL.test(kexOption)) {
                 if ((discHandler != null)
                         && discHandler.handleKexDisconnectReason(
                                 this, c2sOptions, s2cOptions, negotiatedGuess, 
KexProposalOption.ALGORITHMS)) {
@@ -1923,17 +1942,31 @@ public abstract class AbstractSession extends 
SessionHelper {
         String resolvedAlgorithms = resolveAvailableSignaturesProposal();
         if (GenericUtils.isEmpty(resolvedAlgorithms)) {
             throw new 
SshException(SshConstants.SSH2_DISCONNECT_HOST_KEY_NOT_VERIFIABLE,
-                    "sendKexInit() no resolved signatures available");
+                "sendKexInit() no resolved signatures available");
         }
 
         Map<KexProposalOption, String> proposal = 
createProposal(resolvedAlgorithms);
+        KexExtensionHandler extHandler = getKexExtensionHandler();
+        boolean traceEnabled = log.isTraceEnabled();
+        if (extHandler != null) {
+            if (traceEnabled) {
+                log.trace("sendKexInit({}) options before handler: {}", this, 
proposal);
+            }
+
+            extHandler.handleKexInitProposal(this, true, proposal);
+
+            if (traceEnabled) {
+                log.trace("sendKexInit({}) options after handler: {}", this, 
proposal);
+            }
+        }
+
         byte[] seed;
         synchronized (kexState) {
             seed = sendKexInit(proposal);
             setKexSeed(seed);
         }
 
-        if (log.isTraceEnabled()) {
+        if (traceEnabled) {
             log.trace("sendKexInit({}) proposal={} seed: {}", this, proposal, 
BufferUtils.toHex(':', seed));
         }
         return seed;
diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
index 1bc8aea..86f479b 100644
--- 
a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
+++ 
b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java
@@ -131,6 +131,7 @@ public abstract class SessionHelper extends 
AbstractKexFactoryManager implements
         return ioSession;
     }
 
+    @Override
     public boolean isServerSession() {
         return serverSession;
     }

Reply via email to