Repository: flink
Updated Branches:
  refs/heads/master aec6ded5e -> c08bcf1e0


http://git-wip-us.apache.org/repos/asf/flink/blob/f5957ce3/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
index 7d26643..ca8bcd9 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala
@@ -66,6 +66,22 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(
   }
 
   @Test
+  def testRangePartitionByTupleField(): Unit = {
+    /*
+     * Test range partition by tuple field
+     */
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val unique = ds.partitionByRange(1).mapPartition( _.map(_._2).toSet )
+
+    unique.writeAsText(resultPath, WriteMode.OVERWRITE)
+    env.execute()
+
+    expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
+  }
+
+  @Test
   def testHashPartitionByKeySelector(): Unit = {
     /*
      * Test hash partition by key selector
@@ -80,6 +96,20 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(
   }
 
   @Test
+  def testRangePartitionByKeySelector(): Unit = {
+    /*
+     * Test range partition by key selector
+     */
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+    val unique = ds.partitionByRange( _._2 ).mapPartition( _.map(_._2).toSet )
+
+    unique.writeAsText(resultPath, WriteMode.OVERWRITE)
+    env.execute()
+    expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
+  }
+
+  @Test
   def testForcedRebalancing(): Unit = {
     /*
      * Test forced rebalancing
@@ -129,6 +159,24 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(
   }
 
   @Test
+  def testMapPartitionAfterRepartitionHasCorrectParallelism2(): Unit = {
+    // Verify that mapPartition operation after repartition picks up correct
+    // parallelism
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+    env.setParallelism(1)
+
+    val unique = ds.partitionByRange(1)
+      .setParallelism(4)
+      .mapPartition( _.map(_._2).toSet )
+
+    unique.writeAsText(resultPath, WriteMode.OVERWRITE)
+    env.execute()
+
+    expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n"
+  }
+
+  @Test
   def testMapAfterRepartitionHasCorrectParallelism(): Unit = {
     // Verify that map operation after repartition picks up correct
     // parallelism
@@ -157,6 +205,35 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(
   }
 
   @Test
+  def testMapAfterRepartitionHasCorrectParallelism2(): Unit = {
+    // Verify that map operation after repartition picks up correct
+    // parallelism
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+    env.setParallelism(1)
+
+    val count = ds.partitionByRange(0).setParallelism(4).map(
+      new RichMapFunction[(Int, Long, String), Tuple1[Int]] {
+        var first = true
+        override def map(in: (Int, Long, String)): Tuple1[Int] = {
+          // emit 1 for the first record seen by this parallel task, 0 otherwise
+          if (first) {
+            first = false
+            Tuple1(1)
+          } else {
+            Tuple1(0)
+          }
+        }
+      }).sum(0)
+
+    count.writeAsText(resultPath, WriteMode.OVERWRITE)
+    env.execute()
+
+    expected = if (mode == TestExecutionMode.COLLECTION) "(1)\n" else "(4)\n"
+  }
+
+
+  @Test
   def testFilterAfterRepartitionHasCorrectParallelism(): Unit = {
     // Verify that filter operation after repartition picks up correct
     // parallelism
@@ -186,7 +263,36 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(
   }
 
   @Test
-  def testPartitionNestedPojo(): Unit = {
+  def testFilterAfterRepartitionHasCorrectParallelism2(): Unit = {
+    // Verify that filter operation after repartition picks up correct
+    // parallelism
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+    env.setParallelism(1)
+
+    val count = ds.partitionByRange(0).setParallelism(4).filter(
+      new RichFilterFunction[(Int, Long, String)] {
+        var first = true
+        override def filter(in: (Int, Long, String)): Boolean = {
+          // only let the first record seen by this parallel task pass
+          if (first) {
+            first = false
+            true
+          } else {
+            false
+          }
+        }
+      })
+      .map( _ => Tuple1(1)).sum(0)
+
+    count.writeAsText(resultPath, WriteMode.OVERWRITE)
+    env.execute()
+
+    expected = if (mode == TestExecutionMode.COLLECTION) "(1)\n" else "(4)\n"
+  }
+
+  @Test
+  def testHashPartitionNestedPojo(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     env.setParallelism(3)
     val ds = CollectionDataSets.getDuplicatePojoDataSet(env)
@@ -199,4 +305,19 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(
     env.execute()
     expected = "10000\n" + "20000\n" + "30000\n"
   }
+
+  @Test
+  def testRangePartitionNestedPojo(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(3)
+    val ds = CollectionDataSets.getDuplicatePojoDataSet(env)
+    val uniqLongs = ds
+      .partitionByRange("nestedPojo.longNumber")
+      .setParallelism(4)
+      .mapPartition( _.map(_.nestedPojo.longNumber).toSet )
+
+    uniqLongs.writeAsText(resultPath, WriteMode.OVERWRITE)
+    env.execute()
+    expected = "10000\n" + "20000\n" + "30000\n"
+  }
 }
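
For readers skimming the diff: the new tests exercise range partitioning
through the Scala DataSet API. Below is a minimal, self-contained sketch of
the same pattern outside the test harness. The object name and the inline
data are illustrative stand-ins (the tests use
CollectionDataSets.get3TupleDataSet and write to a temporary resultPath);
print() replaces the file sink.

import org.apache.flink.api.scala._

object RangePartitionSketch {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    // Illustrative stand-in for CollectionDataSets.get3TupleDataSet(env):
    // a small (Int, Long, String) data set with duplicates in field 1.
    val ds = env.fromElements(
      (1, 1L, "Hi"), (2, 2L, "Hello"), (3, 2L, "Hello world"),
      (4, 3L, "Hello world, how are you?"), (5, 3L, "I am fine."),
      (6, 3L, "Luke Skywalker"))

    // Range-partition on the second tuple field, then reduce each partition
    // to the set of distinct field values it received -- the same pattern as
    // testRangePartitionByTupleField. Range partitioning routes equal keys
    // to the same partition, so no value should appear in more than one set.
    val unique = ds
      .partitionByRange(1)
      .mapPartition(_.map(_._2).toSet)

    // The key-selector variant is equivalent: ds.partitionByRange(_._2)

    // print() triggers execution and prints the collected result.
    unique.print()
  }
}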
