[FLINK-3234] [dataSet] Add KeySelector support to sortPartition operation.

This closes #1585


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0a63797a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0a63797a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0a63797a

Branch: refs/heads/tableOnCalcite
Commit: 0a63797a6a5418b2363bca25bd77c33c217ff257
Parents: 572855d
Author: Chiwan Park <chiwanp...@apache.org>
Authored: Thu Feb 4 20:46:10 2016 +0900
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Wed Feb 10 11:51:26 2016 +0100

----------------------------------------------------------------------
 .../java/org/apache/flink/api/java/DataSet.java |  18 ++
 .../java/operators/SortPartitionOperator.java   | 174 +++++++++++++------
 .../api/java/operator/SortPartitionTest.java    |  82 +++++++++
 .../org/apache/flink/api/scala/DataSet.scala    |  25 +++
 .../api/scala/PartitionSortedDataSet.scala      |  22 ++-
 .../javaApiOperators/SortPartitionITCase.java   |  61 +++++++
 .../scala/operators/SortPartitionITCase.scala   |  59 +++++++
 7 files changed, 385 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java 
b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
index bfb97f4..c315920 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
@@ -1381,6 +1381,24 @@ public abstract class DataSet<T> {
                return new SortPartitionOperator<>(this, field, order, 
Utils.getCallLocationName());
        }
 
+       /**
+        * Locally sorts the partitions of the DataSet on the extracted key in 
the specified order.
+        * The DataSet can be sorted on multiple values by returning a tuple 
from the KeySelector.
+        *
+        * Note that no additional sort keys can be appended to a KeySelector 
sort key. To sort
+        * the partitions by multiple values using KeySelector, the KeySelector 
must return a tuple
+        * consisting of the values.
+        *
+        * @param keyExtractor The KeySelector function which extracts the key 
values from the DataSet
+        *                     on which the DataSet is sorted.
+        * @param order The order in which the DataSet is sorted.
+        * @return The DataSet with sorted local partitions.
+        */
+       public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> 
keyExtractor, Order order) {
+               final TypeInformation<K> keyType = 
TypeExtractor.getKeySelectorTypes(keyExtractor, getType());
+               return new SortPartitionOperator<>(this, new 
Keys.SelectorFunctionKeys<>(clean(keyExtractor), getType(), keyType), order, 
Utils.getCallLocationName());
+       }
+
        // 
--------------------------------------------------------------------------------------------
        //  Top-K
        // 
--------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
----------------------------------------------------------------------
diff --git 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
 
b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
index 354a0cd..7f30a30 100644
--- 
a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
+++ 
b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
@@ -26,9 +26,13 @@ import org.apache.flink.api.common.operators.Order;
 import org.apache.flink.api.common.operators.Ordering;
 import org.apache.flink.api.common.operators.UnaryOperatorInformation;
 import org.apache.flink.api.common.operators.base.SortPartitionOperatorBase;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
 
-import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.List;
 
 /**
  * This operator represents a DataSet with locally sorted partitions.
@@ -38,27 +42,58 @@ import java.util.Arrays;
 @Public
 public class SortPartitionOperator<T> extends SingleInputOperator<T, T, 
SortPartitionOperator<T>> {
 
-       private int[] sortKeyPositions;
+       private List<Keys<T>> keys;
 
-       private Order[] sortOrders;
+       private List<Order> orders;
 
        private final String sortLocationName;
 
+       private boolean useKeySelector;
 
-       public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order 
sortOrder, String sortLocationName) {
+       private SortPartitionOperator(DataSet<T> dataSet, String 
sortLocationName) {
                super(dataSet, dataSet.getType());
+
+               keys = new ArrayList<>();
+               orders = new ArrayList<>();
                this.sortLocationName = sortLocationName;
+       }
+
+
+       public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order 
sortOrder, String sortLocationName) {
+               this(dataSet, sortLocationName);
+               this.useKeySelector = false;
+
+               ensureSortableKey(sortField);
 
-               int[] flatOrderKeys = getFlatFields(sortField);
-               this.appendSorting(flatOrderKeys, sortOrder);
+               keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
+               orders.add(sortOrder);
        }
 
        public SortPartitionOperator(DataSet<T> dataSet, String sortField, 
Order sortOrder, String sortLocationName) {
-               super(dataSet, dataSet.getType());
-               this.sortLocationName = sortLocationName;
+               this(dataSet, sortLocationName);
+               this.useKeySelector = false;
+
+               ensureSortableKey(sortField);
+
+               keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
+               orders.add(sortOrder);
+       }
+
+       public <K> SortPartitionOperator(DataSet<T> dataSet, 
Keys.SelectorFunctionKeys<T, K> sortKey, Order sortOrder, String 
sortLocationName) {
+               this(dataSet, sortLocationName);
+               this.useKeySelector = true;
+
+               ensureSortableKey(sortKey);
 
-               int[] flatOrderKeys = getFlatFields(sortField);
-               this.appendSorting(flatOrderKeys, sortOrder);
+               keys.add(sortKey);
+               orders.add(sortOrder);
+       }
+
+       /**
+        * Returns whether this operator sorts on a key extracted by a KeySelector.
+     */
+       public boolean useKeySelector() {
+               return useKeySelector;
        }
 
        /**
@@ -70,9 +105,14 @@ public class SortPartitionOperator<T> extends 
SingleInputOperator<T, T, SortPart
         * @return The DataSet with sorted local partitions.
         */
        public SortPartitionOperator<T> sortPartition(int field, Order order) {
+               if (useKeySelector) {
+                       throw new InvalidProgramException("Expression keys 
cannot be appended after a KeySelector");
+               }
+
+               ensureSortableKey(field);
+               keys.add(new Keys.ExpressionKeys<>(field, getType()));
+               orders.add(order);
 
-               int[] flatOrderKeys = getFlatFields(field);
-               this.appendSorting(flatOrderKeys, order);
                return this;
        }
 
@@ -81,58 +121,41 @@ public class SortPartitionOperator<T> extends 
SingleInputOperator<T, T, SortPart
         * local partition sorting of the DataSet.
         *
         * @param field The field expression referring to the field of the 
additional sort order of
-        *                 the local partition sorting.
-        * @param order The order  of the additional sort order of the local 
partition sorting.
+        *              the local partition sorting.
+        * @param order The order of the additional sort order of the local 
partition sorting.
         * @return The DataSet with sorted local partitions.
         */
        public SortPartitionOperator<T> sortPartition(String field, Order 
order) {
-               int[] flatOrderKeys = getFlatFields(field);
-               this.appendSorting(flatOrderKeys, order);
+               if (useKeySelector) {
+                       throw new InvalidProgramException("Expression keys 
cannot be appended after a KeySelector");
+               }
+
+               ensureSortableKey(field);
+               keys.add(new Keys.ExpressionKeys<>(field, getType()));
+               orders.add(order);
+
                return this;
        }
 
-       // 
--------------------------------------------------------------------------------------------
-       //  Key Extraction
-       // 
--------------------------------------------------------------------------------------------
-
-       private int[] getFlatFields(int field) {
+       public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> 
keyExtractor, Order order) {
+               throw new InvalidProgramException("KeySelector cannot be 
chained.");
+       }
 
-               if (!Keys.ExpressionKeys.isSortKey(field, super.getType())) {
+       private void ensureSortableKey(int field) throws 
InvalidProgramException {
+               if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
                        throw new InvalidProgramException("Selected sort key is 
not a sortable type");
                }
-
-               Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(field, 
super.getType());
-               return ek.computeLogicalKeyPositions();
        }
 
-       private int[] getFlatFields(String fields) {
-
-               if (!Keys.ExpressionKeys.isSortKey(fields, super.getType())) {
+       private void ensureSortableKey(String field) throws 
InvalidProgramException {
+               if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
                        throw new InvalidProgramException("Selected sort key is 
not a sortable type");
                }
-
-               Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(fields, 
super.getType());
-               return ek.computeLogicalKeyPositions();
        }
 
-       private void appendSorting(int[] flatOrderFields, Order order) {
-
-               if(this.sortKeyPositions == null) {
-                       // set sorting info
-                       this.sortKeyPositions = flatOrderFields;
-                       this.sortOrders = new Order[flatOrderFields.length];
-                       Arrays.fill(this.sortOrders, order);
-               } else {
-                       // append sorting info to exising info
-                       int oldLength = this.sortKeyPositions.length;
-                       int newLength = oldLength + flatOrderFields.length;
-                       this.sortKeyPositions = 
Arrays.copyOf(this.sortKeyPositions, newLength);
-                       this.sortOrders = Arrays.copyOf(this.sortOrders, 
newLength);
-
-                       for(int i=0; i<flatOrderFields.length; i++) {
-                               this.sortKeyPositions[oldLength+i] = 
flatOrderFields[i];
-                               this.sortOrders[oldLength+i] = order;
-                       }
+       private <K> void ensureSortableKey(Keys.SelectorFunctionKeys<T, K> 
sortKey) {
+               if (!sortKey.getKeyType().isSortKeyType()) {
+                       throw new InvalidProgramException("Selected sort key is 
not a sortable type");
                }
        }
 
@@ -144,16 +167,33 @@ public class SortPartitionOperator<T> extends 
SingleInputOperator<T, T, SortPart
 
                String name = "Sort at " + sortLocationName;
 
+               if (useKeySelector) {
+                       return translateToDataFlowWithKeyExtractor(input, 
(Keys.SelectorFunctionKeys<T, ?>) keys.get(0), orders.get(0), name);
+               }
+
+               // flatten sort key positions
+               List<Integer> allKeyPositions = new ArrayList<>();
+               List<Order> allOrders = new ArrayList<>();
+               for (int i = 0, length = keys.size(); i < length; i++) {
+                       int[] sortKeyPositions = 
keys.get(i).computeLogicalKeyPositions();
+                       Order order = orders.get(i);
+
+                       for (int sortKeyPosition : sortKeyPositions) {
+                               allKeyPositions.add(sortKeyPosition);
+                               allOrders.add(order);
+                       }
+               }
+
                Ordering partitionOrdering = new Ordering();
-               for (int i = 0; i < this.sortKeyPositions.length; i++) {
-                       
partitionOrdering.appendOrdering(this.sortKeyPositions[i], null, 
this.sortOrders[i]);
+               for (int i = 0, length = allKeyPositions.size(); i < length; 
i++) {
+                       
partitionOrdering.appendOrdering(allKeyPositions.get(i), null, 
allOrders.get(i));
                }
 
                // distinguish between partition types
                UnaryOperatorInformation<T, T> operatorInfo = new 
UnaryOperatorInformation<>(getType(), getType());
-               SortPartitionOperatorBase<T> noop = new  
SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
+               SortPartitionOperatorBase<T> noop = new 
SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
                noop.setInput(input);
-               if(this.getParallelism() < 0) {
+               if (this.getParallelism() < 0) {
                        // use parallelism of input if not explicitly specified
                        noop.setParallelism(input.getParallelism());
                } else {
@@ -165,4 +205,32 @@ public class SortPartitionOperator<T> extends 
SingleInputOperator<T, T, SortPart
 
        }
 
+       private <K> 
org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> 
translateToDataFlowWithKeyExtractor(
+               Operator<T> input, Keys.SelectorFunctionKeys<T, K> keys, Order 
order, String name) {
+               TypeInformation<Tuple2<K, T>> typeInfoWithKey = 
KeyFunctions.createTypeWithKey(keys);
+               Keys.ExpressionKeys<Tuple2<K, T>> newKey = new 
Keys.ExpressionKeys<>(0, typeInfoWithKey);
+
+               Operator<Tuple2<K, T>> keyedInput = 
KeyFunctions.appendKeyExtractor(input, keys);
+
+               int[] sortKeyPositions = newKey.computeLogicalKeyPositions();
+               Ordering partitionOrdering = new Ordering();
+               for (int keyPosition : sortKeyPositions) {
+                       partitionOrdering.appendOrdering(keyPosition, null, 
order);
+               }
+
+               // distinguish between partition types
+               UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> 
operatorInfo = new UnaryOperatorInformation<>(typeInfoWithKey, typeInfoWithKey);
+               SortPartitionOperatorBase<Tuple2<K, T>> noop = new 
SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
+               noop.setInput(keyedInput);
+               if (this.getParallelism() < 0) {
+                       // use parallelism of input if not explicitly specified
+                       noop.setParallelism(input.getParallelism());
+               } else {
+                       // use explicitly specified parallelism
+                       noop.setParallelism(this.getParallelism());
+               }
+
+               return KeyFunctions.appendKeyRemover(noop, keys);
+       }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
----------------------------------------------------------------------
diff --git 
a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
 
b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
index a4e2bbc..3540e6a 100644
--- 
a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
+++ 
b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
@@ -169,6 +169,88 @@ public class SortPartitionTest {
                tupleDs.sortPartition("f3", Order.ASCENDING);
        }
 
+       @Test
+       public void testSortPartitionWithKeySelector1() {
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = 
env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+               // should work
+               try {
+                       tupleDs.sortPartition(new KeySelector<Tuple4<Integer, 
Long, CustomType, Long[]>, Integer>() {
+                               @Override
+                               public Integer getKey(Tuple4<Integer, Long, 
CustomType, Long[]> value) throws Exception {
+                                       return value.f0;
+                               }
+                       }, Order.ASCENDING);
+               } catch (Exception e) {
+                       Assert.fail();
+               }
+       }
+
+       @Test(expected = InvalidProgramException.class)
+       public void testSortPartitionWithKeySelector2() {
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = 
env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+               // must not work
+               tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, 
CustomType, Long[]>, Long[]>() {
+                       @Override
+                       public Long[] getKey(Tuple4<Integer, Long, CustomType, 
Long[]> value) throws Exception {
+                               return value.f3;
+                       }
+               }, Order.ASCENDING);
+       }
+
+       @Test(expected = InvalidProgramException.class)
+       public void testSortPartitionWithKeySelector3() {
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = 
env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+               // must not work
+               tupleDs
+                       .sortPartition("f1", Order.ASCENDING)
+                       .sortPartition(new KeySelector<Tuple4<Integer, Long, 
CustomType, Long[]>, CustomType>() {
+                               @Override
+                               public CustomType getKey(Tuple4<Integer, Long, 
CustomType, Long[]> value) throws Exception {
+                                       return value.f2;
+                               }
+                       }, Order.ASCENDING);
+       }
+
+       @Test
+       public void testSortPartitionWithKeySelector4() {
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = 
env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+               // should work
+               try {
+                       tupleDs.sortPartition(new 
KeySelector<Tuple4<Integer,Long,CustomType,Long[]>, Tuple2<Integer, Long>>() {
+                               @Override
+                               public Tuple2<Integer, Long> 
getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
+                                       return new Tuple2<>(value.f0, value.f1);
+                               }
+                       }, Order.ASCENDING);
+               } catch (Exception e) {
+                       Assert.fail();
+               }
+       }
+
+       @Test(expected = InvalidProgramException.class)
+       public void testSortPartitionWithKeySelector5() {
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = 
env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+               // must not work
+               tupleDs
+                       .sortPartition(new KeySelector<Tuple4<Integer, Long, 
CustomType, Long[]>, CustomType>() {
+                               @Override
+                               public CustomType getKey(Tuple4<Integer, Long, 
CustomType, Long[]> value) throws Exception {
+                                       return value.f2;
+                               }
+                       }, Order.ASCENDING)
+                       .sortPartition("f1", Order.ASCENDING);
+       }
+
        public static class CustomType implements Serializable {
                
                public static class Nest {

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index e47bc42..5735b32 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -1511,6 +1511,31 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
       new SortPartitionOperator[T](javaSet, field, order, 
getCallLocationName()))
   }
 
+  /**
+    * Locally sorts the partitions of the DataSet on the extracted key in the 
specified order.
+    * The DataSet can be sorted on multiple values by returning a tuple from 
the KeySelector.
+    *
+    * Note that no additional sort keys can be appended to a KeySelector sort 
key. To sort
+    * the partitions by multiple values using KeySelector, the KeySelector 
must return a tuple
+    * consisting of the values.
+    */
+  def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] 
={
+    val keyExtractor = new KeySelector[T, K] {
+      val cleanFun = clean(fun)
+      def getKey(in: T) = cleanFun(in)
+    }
+
+    val keyType = implicitly[TypeInformation[K]]
+    new PartitionSortedDataSet[T](
+      new SortPartitionOperator[T](javaSet,
+        new Keys.SelectorFunctionKeys[T, K](
+          keyExtractor,
+          javaSet.getType,
+          keyType),
+        order,
+        getCallLocationName()))
+  }
+
   // 
--------------------------------------------------------------------------------------------
   //  Result writing
   // 
--------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
----------------------------------------------------------------------
diff --git 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
index c924a76..a402dd9 100644
--- 
a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
+++ 
b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
@@ -18,7 +18,9 @@
 package org.apache.flink.api.scala
 
 import org.apache.flink.annotation.Public
+import org.apache.flink.api.common.InvalidProgramException
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.operators.SortPartitionOperator
 
 import scala.reflect.ClassTag
@@ -37,16 +39,30 @@ class PartitionSortedDataSet[T: ClassTag](set: 
SortPartitionOperator[T])
    * Appends the given field and order to the sort-partition operator.
    */
   override def sortPartition(field: Int, order: Order): DataSet[T] = {
+    if (set.useKeySelector()) {
+      throw new InvalidProgramException("Expression keys cannot be appended 
after selector " +
+        "function keys")
+    }
+
     this.set.sortPartition(field, order)
     this
   }
 
-/**
- * Appends the given field and order to the sort-partition operator.
- */
+  /**
+   * Appends the given field and order to the sort-partition operator.
+   */
   override def sortPartition(field: String, order: Order): DataSet[T] = {
+    if (set.useKeySelector()) {
+      throw new InvalidProgramException("Expression keys cannot be appended 
after selector " +
+        "function keys")
+    }
+
     this.set.sortPartition(field, order)
     this
   }
 
+  override def sortPartition[K: TypeInformation](fun: T => K, order: Order): 
DataSet[T] = {
+    throw new InvalidProgramException("KeySelector cannot be chained.")
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
index 2423420..c7f07f6 100644
--- 
a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
@@ -23,6 +23,7 @@ import 
org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.operators.Order;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
@@ -197,6 +198,58 @@ public class SortPartitionITCase extends 
MultipleProgramsTestBase {
                compareResultAsText(result, expected);
        }
 
+       @Test
+       public void testSortPartitionWithKeySelector1() throws Exception {
+               /*
+                * Test sort partition on an extracted key
+                */
+
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(4);
+
+               DataSet<Tuple3<Integer, Long, String>> ds = 
CollectionDataSets.get3TupleDataSet(env);
+               List<Tuple1<Boolean>> result = ds
+                       .map(new IdMapper<Tuple3<Integer, Long, 
String>>()).setParallelism(4) // parallelize input
+                       .sortPartition(new KeySelector<Tuple3<Integer, Long, 
String>, Long>() {
+                               @Override
+                               public Long getKey(Tuple3<Integer, Long, 
String> value) throws Exception {
+                                       return value.f1;
+                               }
+                       }, Order.ASCENDING)
+                       .mapPartition(new OrderCheckMapper<>(new 
Tuple3AscendingChecker()))
+                       .distinct().collect();
+
+               String expected = "(true)\n";
+
+               compareResultAsText(result, expected);
+       }
+
+       @Test
+       public void testSortPartitionWithKeySelector2() throws Exception {
+               /*
+                * Test sort partition on an extracted composite (tuple) key in descending order
+                */
+
+               final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment();
+               env.setParallelism(4);
+
+               DataSet<Tuple3<Integer, Long, String>> ds = 
CollectionDataSets.get3TupleDataSet(env);
+               List<Tuple1<Boolean>> result = ds
+                       .map(new IdMapper<Tuple3<Integer, Long, 
String>>()).setParallelism(4) // parallelize input
+                       .sortPartition(new KeySelector<Tuple3<Integer, Long, 
String>, Tuple2<Integer, Long>>() {
+                               @Override
+                               public Tuple2<Integer, Long> 
getKey(Tuple3<Integer, Long, String> value) throws Exception {
+                                       return new Tuple2<>(value.f0, value.f1);
+                               }
+                       }, Order.DESCENDING)
+                       .mapPartition(new OrderCheckMapper<>(new 
Tuple3Checker()))
+                       .distinct().collect();
+
+               String expected = "(true)\n";
+
+               compareResultAsText(result, expected);
+       }
+
        public interface OrderChecker<T> extends Serializable {
                boolean inOrder(T t1, T t2);
        }
@@ -210,6 +263,14 @@ public class SortPartitionITCase extends 
MultipleProgramsTestBase {
        }
 
        @SuppressWarnings("serial")
+       public static class Tuple3AscendingChecker implements 
OrderChecker<Tuple3<Integer, Long, String>> {
+               @Override
+               public boolean inOrder(Tuple3<Integer, Long, String> t1, 
Tuple3<Integer, Long, String> t2) {
+                       return t1.f1 <= t2.f1;
+               }
+       }
+
+       @SuppressWarnings("serial")
        public static class Tuple5Checker implements 
OrderChecker<Tuple5<Integer, Long, Integer, String, Long>> {
                @Override
                public boolean inOrder(Tuple5<Integer, Long, Integer, String, 
Long> t1,

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
 
b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
index 3f67063..cda8f4f 100644
--- 
a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
+++ 
b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.flink.api.common.functions.MapPartitionFunction
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.util.CollectionDataSets
 import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
@@ -166,6 +167,58 @@ class SortPartitionITCase(mode: TestExecutionMode) extends 
MultipleProgramsTestB
     TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
 
+  @Test
+  def testSortPartitionWithKeySelector1(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val result = ds
+      .map { x => x }.setParallelism(4)
+      .sortPartition(_._2, Order.ASCENDING)
+      .mapPartition(new OrderCheckMapper(new Tuple3AscendingChecker))
+      .distinct()
+      .collect()
+
+    val expected: String = "(true)\n"
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
+  @Test
+  def testSortPartitionWithKeySelector2(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val result = ds
+      .map { x => x }.setParallelism(4)
+      .sortPartition(x => (x._2, x._1), Order.DESCENDING)
+      .mapPartition(new OrderCheckMapper(new Tuple3Checker))
+      .distinct()
+      .collect()
+
+    val expected: String = "(true)\n"
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
+  @Test(expected = classOf[InvalidProgramException])
+  def testSortPartitionWithKeySelector3(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val result = ds
+      .map { x => x }.setParallelism(4)
+      .sortPartition(x => (x._2, x._1), Order.DESCENDING)
+      .sortPartition(0, Order.DESCENDING)
+      .mapPartition(new OrderCheckMapper(new Tuple3Checker))
+      .distinct()
+      .collect()
+
+    val expected: String = "(true)\n"
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
 }
 
 trait OrderChecker[T] extends Serializable {
@@ -178,6 +231,12 @@ class Tuple3Checker extends OrderChecker[(Int, Long, 
String)] {
   }
 }
 
+class Tuple3AscendingChecker extends OrderChecker[(Int, Long, String)] {
+  def inOrder(t1: (Int, Long, String), t2: (Int, Long, String)): Boolean = {
+    t1._2 <= t2._2
+  }
+}
+
 class Tuple5Checker extends OrderChecker[(Int, Long, Int, String, Long)] {
   def inOrder(t1: (Int, Long, Int, String, Long), t2: (Int, Long, Int, String, 
Long)): Boolean = {
     t1._5 < t2._5 || t1._5 == t2._5 && t1._3 >= t2._3

Reply via email to