Repository: spark
Updated Branches:
  refs/heads/master 1d4f35520 -> f55218aeb


http://git-wip-us.apache.org/repos/asf/spark/blob/f55218ae/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
new file mode 100644
index 0000000..b3bcf5f
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ExternalShuffleIntegrationSuite {
+
+  static String APP_ID = "app-id";
+  static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager";
+  static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager";
+
+  // Executor 0 is sort-based
+  static TestShuffleDataContext dataContext0;
+  // Executor 1 is hash-based
+  static TestShuffleDataContext dataContext1;
+
+  static ExternalShuffleBlockHandler handler;
+  static TransportServer server;
+  static TransportConf conf;
+
+  static byte[][] exec0Blocks = new byte[][] {
+    new byte[123],
+    new byte[12345],
+    new byte[1234567],
+  };
+
+  static byte[][] exec1Blocks = new byte[][] {
+    new byte[321],
+    new byte[54321],
+  };
+
+  @BeforeClass
+  public static void beforeAll() throws IOException {
+    Random rand = new Random();
+
+    for (byte[] block : exec0Blocks) {
+      rand.nextBytes(block);
+    }
+    for (byte[] block : exec1Blocks) {
+      rand.nextBytes(block);
+    }
+
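+    // Create on-disk shuffle data for two executors: exec-0 uses the sort-based layout,
+    // exec-1 the hash-based layout.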
+    dataContext0 = new TestShuffleDataContext(2, 5);
+    dataContext0.create();
+    dataContext0.insertSortShuffleData(0, 0, exec0Blocks);
+
+    dataContext1 = new TestShuffleDataContext(6, 2);
+    dataContext1.create();
+    dataContext1.insertHashShuffleData(1, 0, exec1Blocks);
+
+    conf = new TransportConf(new SystemPropertyConfigProvider());
+    handler = new ExternalShuffleBlockHandler();
+    TransportContext transportContext = new TransportContext(conf, handler);
+    server = transportContext.createServer();
+  }
+
+  @AfterClass
+  public static void afterAll() {
+    dataContext0.cleanup();
+    dataContext1.cleanup();
+    server.close();
+  }
+
+  @After
+  public void afterEach() {
+    handler.clearRegisteredExecutors();
+  }
+
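+  /** Groups the outcome of a fetch: which block IDs succeeded or failed, plus the retained buffers. */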
+  class FetchResult {
+    public Set<String> successBlocks;
+    public Set<String> failedBlocks;
+    public List<ManagedBuffer> buffers;
+
+    public void releaseBuffers() {
+      for (ManagedBuffer buffer : buffers) {
+        buffer.release();
+      }
+    }
+  }
+
+  // Fetch a set of blocks from a pre-registered executor.
+  private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception {
+    return fetchBlocks(execId, blockIds, server.getPort());
+  }
+
+  // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port,
+  // to allow connecting to invalid servers.
+  private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception {
+    final FetchResult res = new FetchResult();
+    res.successBlocks = Collections.synchronizedSet(new HashSet<String>());
+    res.failedBlocks = Collections.synchronizedSet(new HashSet<String>());
+    res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
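+    // Released once per block, on success or failure alike, so we can wait for every response.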
+    final Semaphore requestsRemaining = new Semaphore(0);
+
+    ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+    client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
+      new BlockFetchingListener() {
+        @Override
+        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+          synchronized (this) {
+            if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+              data.retain();
+              res.successBlocks.add(blockId);
+              res.buffers.add(data);
+              requestsRemaining.release();
+            }
+          }
+        }
+
+        @Override
+        public void onBlockFetchFailure(String blockId, Throwable exception) {
+          synchronized (this) {
+            if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+              res.failedBlocks.add(blockId);
+              requestsRemaining.release();
+            }
+          }
+        }
+      });
+
+    if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
+      fail("Timeout getting response from the server");
+    }
+    return res;
+  }
+
+  @Test
+  public void testFetchOneSort() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { 
"shuffle_0_0_0" });
+    assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks);
+    assertTrue(exec0Fetch.failedBlocks.isEmpty());
+    assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0]));
+    exec0Fetch.releaseBuffers();
+  }
+
+  @Test
+  public void testFetchThreeSort() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult exec0Fetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" });
+    assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", 
"shuffle_0_0_2"),
+      exec0Fetch.successBlocks);
+    assertTrue(exec0Fetch.failedBlocks.isEmpty());
+    assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks));
+    exec0Fetch.releaseBuffers();
+  }
+
+  @Test
+  public void testFetchHash() throws Exception {
+    registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-1",
+      new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+    assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), 
execFetch.successBlocks);
+    assertTrue(execFetch.failedBlocks.isEmpty());
+    assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks));
+    execFetch.releaseBuffers();
+  }
+
+  @Test
+  public void testFetchWrongShuffle() throws Exception {
+    registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* 
wrong manager */));
+    FetchResult execFetch = fetchBlocks("exec-1",
+      new String[] { "shuffle_1_0_0", "shuffle_1_0_1" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), 
execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchInvalidShuffle() throws Exception {
+    registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort 
manager"));
+    FetchResult execFetch = fetchBlocks("exec-1",
+      new String[] { "shuffle_1_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchWrongBlockId() throws Exception {
+    registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* 
wrong manager */));
+    FetchResult execFetch = fetchBlocks("exec-1",
+      new String[] { "rdd_1_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchNonexistent() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_2_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchWrongExecutor() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ 
});
+    // Both still fail, as we start by checking for all block.
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), 
execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchUnregisteredExecutor() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-2",
+      new String[] { "shuffle_0_0_0", "shuffle_1_0_0" });
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), 
execFetch.failedBlocks);
+  }
+
+  @Test
+  public void testFetchNoServer() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult execFetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
+    assertTrue(execFetch.successBlocks.isEmpty());
+    assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), 
execFetch.failedBlocks);
+  }
+
+  private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+    ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+    client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
+      executorId, executorInfo);
+  }
+
+  private void assertBufferListsEqual(List<ManagedBuffer> list0, List<byte[]> list1)
+    throws Exception {
+    assertEquals(list0.size(), list1.size());
+    for (int i = 0; i < list0.size(); i ++) {
+      assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i))));
+    }
+  }
+
+  private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+    ByteBuffer nio0 = buffer0.nioByteBuffer();
+    ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+    int len = nio0.remaining();
+    assertEquals(nio0.remaining(), nio1.remaining());
+    for (int i = 0; i < len; i ++) {
+      assertEquals(nio0.get(), nio1.get());
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f55218ae/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
new file mode 100644
index 0000000..c18346f
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.util.JavaUtils;
+
+public class OneForOneBlockFetcherSuite {
+  @Test
+  public void testFetchOne() {
+    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new 
byte[0])));
+
+    BlockFetchingListener listener = fetchBlocks(blocks);
+
+    verify(listener).onBlockFetchSuccess("shuffle_0_0_0", 
blocks.get("shuffle_0_0_0"));
+  }
+
+  @Test
+  public void testFetchThree() {
+    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+    blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new 
byte[23])));
+
+    BlockFetchingListener listener = fetchBlocks(blocks);
+
+    for (int i = 0; i < 3; i ++) {
+      verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + 
i));
+    }
+  }
+
+  @Test
+  public void testFailure() {
+    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("b1", null);
+    blocks.put("b2", null);
+
+    BlockFetchingListener listener = fetchBlocks(blocks);
+
+    // Each failure causes a failure callback to be invoked for all remaining block fetches.
+    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) 
any());
+    verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) 
any());
+  }
+
+  @Test
+  public void testFailureAndSuccess() {
+    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("b1", null);
+    blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
+
+    BlockFetchingListener listener = fetchBlocks(blocks);
+
+    // We may call both success and failure for the same block.
+    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) 
any());
+    verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
+    verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) 
any());
+  }
+
+  @Test
+  public void testEmptyBlockFetch() {
+    try {
+      fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap());
+      fail();
+    } catch (IllegalArgumentException e) {
+      assertEquals("Zero-sized blockIds array", e.getMessage());
+    }
+  }
+
+  /**
+   * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
+   * simply returns the given (BlockId, Block) pairs.
+   * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order
+   * in which they were inserted.
+   *
+   * If a block's buffer is "null", an exception will be thrown instead.
+   */
+  private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
+    TransportClient client = mock(TransportClient.class);
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+    OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener);
+
+    // Respond to the "OpenBlocks" message with an appropirate 
ShuffleStreamHandle with streamId 123
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]);
+        RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
+        callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size())));
+        assertEquals("OpenZeBlocks", message);
+        return null;
+      }
+    }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+
+    // Respond to each chunk request with a single buffer from our blocks array.
+    final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
+    final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocation) throws Throwable {
+        try {
+          long streamId = (Long) invocation.getArguments()[0];
+          int myChunkIndex = (Integer) invocation.getArguments()[1];
+          assertEquals(123, streamId);
+          assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);
+
+          ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
+          ManagedBuffer result = blockIterator.next();
+          if (result != null) {
+            callback.onSuccess(myChunkIndex, result);
+          } else {
+            callback.onFailure(myChunkIndex, new RuntimeException("Failed " + 
myChunkIndex));
+          }
+        } catch (Exception e) {
+          e.printStackTrace();
+          fail("Unexpected failure");
+        }
+        return null;
+      }
+    }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
+
+    fetcher.start("OpenZeBlocks");
+    return listener;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f55218ae/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
new file mode 100644
index 0000000..ee9482b
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.util.JavaUtils;
+
+import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*;
+
+public class ShuffleMessagesSuite {
+  @Test
+  public void serializeOpenShuffleBlocks() {
+    OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2",
+      new String[] { "block0", "block1" });
+    OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+    assertEquals(msg, msg2);
+  }
+
+  @Test
+  public void serializeRegisterExecutor() {
+    RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new 
ExecutorShuffleInfo(
+      new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"));
+    RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+    assertEquals(msg, msg2);
+  }
+
+  @Test
+  public void serializeShuffleStreamHandle() {
+    ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16);
+    ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg));
+    assertEquals(msg, msg2);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f55218ae/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
new file mode 100644
index 0000000..442b756
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+import com.google.common.io.Files;
+
+/**
+ * Manages some sort- and hash-based shuffle data, including the creation
+ * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}.
+ */
+public class TestShuffleDataContext {
+  private final String[] localDirs;
+  private final int subDirsPerLocalDir;
+
+  public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) {
+    this.localDirs = new String[numLocalDirs];
+    this.subDirsPerLocalDir = subDirsPerLocalDir;
+  }
+
+  public void create() {
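+    // Each local dir gets subDirsPerLocalDir subdirectories with two-digit hex names.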
+    for (int i = 0; i < localDirs.length; i ++) {
+      localDirs[i] = Files.createTempDir().getAbsolutePath();
+
+      for (int p = 0; p < subDirsPerLocalDir; p ++) {
+        new File(localDirs[i], String.format("%02x", p)).mkdirs();
+      }
+    }
+  }
+
+  public void cleanup() {
+    for (String localDir : localDirs) {
+      deleteRecursively(new File(localDir));
+    }
+  }
+
+  /** Creates reducer blocks in a sort-based data format within our local dirs. */
+  public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
+    String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0";
+
+    OutputStream dataStream = new FileOutputStream(
+      ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data"));
+    DataOutputStream indexStream = new DataOutputStream(new FileOutputStream(
+      ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index")));
+
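+    // The index file holds (blocks.length + 1) longs: cumulative offsets delimiting each
+    // block within the single .data file.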
+    long offset = 0;
+    indexStream.writeLong(offset);
+    for (byte[] block : blocks) {
+      offset += block.length;
+      dataStream.write(block);
+      indexStream.writeLong(offset);
+    }
+
+    dataStream.close();
+    indexStream.close();
+  }
+
+  /** Creates reducer blocks in a hash-based data format within our local dirs. */
+  public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException {
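+    // Hash-based shuffle: each block lives in its own file, named by its block ID.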
+    for (int i = 0; i < blocks.length; i ++) {
+      String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i;
+      Files.write(blocks[i],
+        ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId));
+    }
+  }
+
+  /**
+   * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this
+   * context's directories.
+   */
+  public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) {
+    return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager);
+  }
+
+  private static void deleteRecursively(File f) {
+    assert f != null;
+    if (f.isDirectory()) {
+      File[] children = f.listFiles();
+      if (children != null) {
+        for (File child : children) {
+          deleteRecursively(child);
+        }
+      }
+    }
+    f.delete();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f55218ae/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 4c7806c..61a508a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -92,6 +92,7 @@
     <module>mllib</module>
     <module>tools</module>
     <module>network/common</module>
+    <module>network/shuffle</module>
     <module>streaming</module>
     <module>sql/catalyst</module>
     <module>sql/core</module>

http://git-wip-us.apache.org/repos/asf/spark/blob/f55218ae/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 7708351..33618f5 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -31,11 +31,12 @@ object BuildCommons {
   private val buildLocation = file(".").getAbsoluteFile.getParentFile
 
  val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
-  sql, networkCommon, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt,
-  streamingTwitter, streamingZeromq) =
+  sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
+  streamingMqtt, streamingTwitter, streamingZeromq) =
    Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
-      "sql", "network-common", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka",
-      "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _))
+      "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
+      "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
+      "streaming-zeromq").map(ProjectRef(buildLocation, _))
 
  val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) =
    Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl")
@@ -142,7 +143,7 @@ object SparkBuild extends PomBuild {
 
   // TODO: Add Sql to mima checks
  allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
-    streamingFlumeSink, networkCommon).contains(x)).foreach {
+    streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach {
       x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
     }
 

