This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 3b64e2f  [SPARK-30129][CORE][2.4] Set client's id in TransportClient after successful auth
3b64e2f is described below

commit 3b64e2f35657ef0a4001a4f0926ead8cd9226a28
Author: Marcelo Vanzin <van...@cloudera.com>
AuthorDate: Thu Dec 5 09:02:10 2019 -0800

    [SPARK-30129][CORE][2.4] Set client's id in TransportClient after successful auth
    
    The new auth code was missing this bit, so it was not possible to know which
    app a client belonged to when auth was on.
    
    I also refactored the SASL test that checks for this so it also checks the
    new protocol (test failed before the fix, passes now).
    
    Closes #26764 from vanzin/SPARK-30129-2.4.
    
    Authored-by: Marcelo Vanzin <van...@cloudera.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/network/crypto/AuthClientBootstrap.java  |   1 +
 .../spark/network/crypto/AuthRpcHandler.java       |   1 +
 .../spark/network/sasl/SaslIntegrationSuite.java   | 117 -------------
 .../spark/network/shuffle/AppIsolationSuite.java   | 184 +++++++++++++++++++++
 4 files changed, 186 insertions(+), 117 deletions(-)
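
For context on the two one-line fixes below: TransportClient exposes setClientId/getClientId, and server-side handlers rely on that id to tell applications apart. The changes make the new auth path populate it, as the SASL path already did. A minimal sketch of the kind of per-app check this enables; the handler class and its expectedAppId field are hypothetical, not part of this patch:

    import java.nio.ByteBuffer;

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.OneForOneStreamManager;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.StreamManager;

    // Hypothetical handler: rejects requests from any app other than the one
    // it serves. Before this fix, getClientId() returned null on the new auth
    // path, so a check like this could never identify the calling app.
    class PerAppRpcHandler extends RpcHandler {
      private final String expectedAppId;
      private final StreamManager streamManager = new OneForOneStreamManager();

      PerAppRpcHandler(String expectedAppId) {
        this.expectedAppId = expectedAppId;
      }

      @Override
      public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
        String appId = client.getClientId();  // set by the auth bootstrap after this fix
        if (appId == null || !appId.equals(expectedAppId)) {
          throw new SecurityException("Unauthorized app: " + appId);
        }
        callback.onSuccess(ByteBuffer.allocate(0));  // trivial success response for the sketch
      }

      @Override
      public StreamManager getStreamManager() {
        return streamManager;
      }
    }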

diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
index 3c26378..737e187 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java
@@ -77,6 +77,7 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
 
     try {
       doSparkAuth(client, channel);
+      client.setClientId(appId);
     } catch (GeneralSecurityException | IOException e) {
       throw Throwables.propagate(e);
     } catch (RuntimeException e) {
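
With the added line, the client-side bootstrap records the app id as soon as doSparkAuth succeeds, mirroring what SaslClientBootstrap already does. A hedged sketch of the observable effect; factory, host, port, and appId stand in for whatever the caller configured:

    // Hypothetical check, e.g. inside a test: a client created through a
    // factory that applies AuthClientBootstrap now carries its app id.
    TransportClient client = factory.createClient(host, port);
    assertEquals(appId, client.getClientId());  // was null before this fix
    client.close();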
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
index fb44dbb..821cc7a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -125,6 +125,7 @@ class AuthRpcHandler extends RpcHandler {
       response.encode(responseData);
       callback.onSuccess(responseData.nioBuffer());
       engine.sessionCipher().addToChannel(channel);
+      client.setClientId(challenge.appId);
     } catch (Exception e) {
      // This is a fatal error: authentication has failed. Close the channel explicitly.
      LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
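
On the server side, the connection is now tagged with the app id carried in the auth challenge, which is what stream-level authorization compares against. A simplified illustration of that comparison; the real check lives in OneForOneStreamManager, and the names below are approximations:

    import org.apache.spark.network.client.TransportClient;

    final class StreamAuthSketch {
      // If the connection was authenticated as some app, only streams
      // registered by that same app may be read from it.
      static void checkAuthorization(TransportClient client, String streamAppId) {
        String clientAppId = client.getClientId();
        if (clientAppId != null && !clientAppId.equals(streamAppId)) {
          throw new SecurityException(String.format(
            "Client %s not authorized to read stream of app %s.", clientAppId, streamAppId));
        }
      }
    }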
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 02e6eb3..0ef01ea 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -21,8 +21,6 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.atomic.AtomicReference;
 
 import org.junit.After;
 import org.junit.AfterClass;
@@ -34,8 +32,6 @@ import static org.mockito.Mockito.*;
 
 import org.apache.spark.network.TestUtils;
 import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientFactory;
@@ -44,15 +40,6 @@ import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
-import org.apache.spark.network.shuffle.BlockFetchingListener;
-import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
-import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
-import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
-import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
-import org.apache.spark.network.shuffle.protocol.OpenBlocks;
-import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
-import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
@@ -163,104 +150,6 @@ public class SaslIntegrationSuite {
     }
   }
 
-  /**
-   * This test is not actually testing SASL behavior, but testing that the shuffle service
-   * performs correct authorization checks based on the SASL authentication data.
-   */
-  @Test
-  public void testAppIsolation() throws Exception {
-    // Start a new server with the correct RPC handler to serve block data.
-    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
-    ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
-      new OneForOneStreamManager(), blockResolver);
-    TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
-    TransportContext blockServerContext = new TransportContext(conf, blockHandler);
-    TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
-
-    TransportClient client1 = null;
-    TransportClient client2 = null;
-    TransportClientFactory clientFactory2 = null;
-    try {
-      // Create a client, and make a request to fetch blocks from a different app.
-      clientFactory = blockServerContext.createClientFactory(
-          Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
-      client1 = clientFactory.createClient(TestUtils.getLocalHost(),
-        blockServer.getPort());
-
-      AtomicReference<Throwable> exception = new AtomicReference<>();
-
-      CountDownLatch blockFetchLatch = new CountDownLatch(1);
-      BlockFetchingListener listener = new BlockFetchingListener() {
-        @Override
-        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
-          blockFetchLatch.countDown();
-        }
-        @Override
-        public void onBlockFetchFailure(String blockId, Throwable t) {
-          exception.set(t);
-          blockFetchLatch.countDown();
-        }
-      };
-
-      String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
-      OneForOneBlockFetcher fetcher =
-          new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
-      fetcher.start();
-      blockFetchLatch.await();
-      checkSecurityException(exception.get());
-
-      // Register an executor so that the next steps work.
-      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
-        new String[] { System.getProperty("java.io.tmpdir") }, 1,
-          "org.apache.spark.shuffle.sort.SortShuffleManager");
-      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
-      client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
-
-      // Make a successful request to fetch blocks, which creates a new stream. But do not actually
-      // fetch any blocks, to keep the stream open.
-      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
-      ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
-      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
-      long streamId = stream.streamId;
-
-      // Create a second client, authenticated with a different app ID, and try to read from
-      // the stream created for the previous app.
-      clientFactory2 = blockServerContext.createClientFactory(
-          Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
-      client2 = clientFactory2.createClient(TestUtils.getLocalHost(),
-        blockServer.getPort());
-
-      CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
-      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
-        @Override
-        public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-          chunkReceivedLatch.countDown();
-        }
-        @Override
-        public void onFailure(int chunkIndex, Throwable t) {
-          exception.set(t);
-          chunkReceivedLatch.countDown();
-        }
-      };
-
-      exception.set(null);
-      client2.fetchChunk(streamId, 0, callback);
-      chunkReceivedLatch.await();
-      checkSecurityException(exception.get());
-    } finally {
-      if (client1 != null) {
-        client1.close();
-      }
-      if (client2 != null) {
-        client2.close();
-      }
-      if (clientFactory2 != null) {
-        clientFactory2.close();
-      }
-      blockServer.close();
-    }
-  }
-
   /** RPC handler which simply responds with the message it received. */
   public static class TestRpcHandler extends RpcHandler {
     @Override
@@ -273,10 +162,4 @@ public class SaslIntegrationSuite {
       return new OneForOneStreamManager();
     }
   }
-
-  private static void checkSecurityException(Throwable t) {
-    assertNotNull("No exception was caught.", t);
-    assertTrue("Expected SecurityException.",
-      t.getMessage().contains(SecurityException.class.getName()));
-  }
 }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/AppIsolationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/AppIsolationSuite.java
new file mode 100644
index 0000000..4a3bebb
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/AppIsolationSuite.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
+import org.apache.spark.network.crypto.AuthServerBootstrap;
+import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class AppIsolationSuite {
+
+  // Use a long timeout to account for slow / overloaded build machines. In the normal case,
+  // tests should finish way before the timeout expires.
+  private static final long TIMEOUT_MS = 10_000;
+
+  private static SecretKeyHolder secretKeyHolder;
+  private static TransportConf conf;
+
+  @BeforeClass
+  public static void beforeAll() {
+    Map<String, String> confMap = new HashMap<>();
+    confMap.put("spark.network.crypto.enabled", "true");
+    confMap.put("spark.network.crypto.saslFallback", "false");
+    conf = new TransportConf("shuffle", new MapConfigProvider(confMap));
+
+    secretKeyHolder = mock(SecretKeyHolder.class);
+    when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+    when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+  }
+
+  @Test
+  public void testSaslAppIsolation() throws Exception {
+    testAppIsolation(
+      () -> new SaslServerBootstrap(conf, secretKeyHolder),
+      appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
+  }
+
+  @Test
+  public void testAuthEngineAppIsolation() throws Exception {
+    testAppIsolation(
+      () -> new AuthServerBootstrap(conf, secretKeyHolder),
+      appId -> new AuthClientBootstrap(conf, appId, secretKeyHolder));
+  }
+
+  private void testAppIsolation(
+      Supplier<TransportServerBootstrap> serverBootstrap,
+      Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
+    // Start a new server with the correct RPC handler to serve block data.
+    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+    ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler(
+      new OneForOneStreamManager(), blockResolver);
+    TransportServerBootstrap bootstrap = serverBootstrap.get();
+    TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+
+    try (
+      TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+      // Create a client, and make a request to fetch blocks from a different app.
+      TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
+          Arrays.asList(clientBootstrapFactory.apply("app-1")));
+      TransportClient client1 = clientFactory1.createClient(
+          TestUtils.getLocalHost(), blockServer.getPort())) {
+
+      AtomicReference<Throwable> exception = new AtomicReference<>();
+
+      CountDownLatch blockFetchLatch = new CountDownLatch(1);
+      BlockFetchingListener listener = new BlockFetchingListener() {
+        @Override
+        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+          blockFetchLatch.countDown();
+        }
+        @Override
+        public void onBlockFetchFailure(String blockId, Throwable t) {
+          exception.set(t);
+          blockFetchLatch.countDown();
+        }
+      };
+
+      String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
+      OneForOneBlockFetcher fetcher =
+          new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
+      fetcher.start();
+      blockFetchLatch.await();
+      checkSecurityException(exception.get());
+
+      // Register an executor so that the next steps work.
+      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+        new String[] { System.getProperty("java.io.tmpdir") }, 1,
+          "org.apache.spark.shuffle.sort.SortShuffleManager");
+      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+      client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
+
+      // Make a successful request to fetch blocks, which creates a new stream. But do not actually
+      // fetch any blocks, to keep the stream open.
+      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+      ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
+      long streamId = stream.streamId;
+
+      try (
+        // Create a second client, authenticated with a different app ID, and try to read from
+        // the stream created for the previous app.
+        TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
+            Arrays.asList(clientBootstrapFactory.apply("app-2")));
+        TransportClient client2 = clientFactory2.createClient(
+            TestUtils.getLocalHost(), blockServer.getPort())
+      ) {
+        CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
+        ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+          @Override
+          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+            chunkReceivedLatch.countDown();
+          }
+
+          @Override
+          public void onFailure(int chunkIndex, Throwable t) {
+            exception.set(t);
+            chunkReceivedLatch.countDown();
+          }
+        };
+
+        exception.set(null);
+        client2.fetchChunk(streamId, 0, callback);
+        chunkReceivedLatch.await();
+        checkSecurityException(exception.get());
+      }
+    }
+  }
+
+  private static void checkSecurityException(Throwable t) {
+    assertNotNull("No exception was caught.", t);
+    assertTrue("Expected SecurityException.",
+      t.getMessage().contains(SecurityException.class.getName()));
+  }
+}
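
The new suite exercises the same app-isolation scenario twice, once per authentication protocol, by injecting the server and client bootstraps through a Supplier and a Function; this is what lets the AuthEngine path reuse the checks that were previously SASL-only. To run it alone, an invocation along the lines of the usual single-Java-test pattern should work (assuming the standard Maven setup for this module): build/mvn -pl common/network-shuffle test -DwildcardSuites=none -Dtest=AppIsolationSuite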

