[FLINK-5184] [table] Fix compareSerialized() of RowComparator. This closes #2894
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0bb68479 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0bb68479 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0bb68479 Branch: refs/heads/master Commit: 0bb684797dfb3e03dd4f9761a6bf1eb8ce9d1c0d Parents: db441de Author: godfreyhe <[email protected]> Authored: Tue Nov 29 19:27:58 2016 +0800 Committer: Fabian Hueske <[email protected]> Committed: Tue Nov 29 13:19:26 2016 +0100 ---------------------------------------------------------------------- .../api/table/typeutils/RowComparator.scala | 16 +++- .../flink/api/table/typeutils/RowTypeInfo.scala | 1 + .../RowComparatorWithManyFieldsTest.scala | 82 ++++++++++++++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/0bb68479/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala index cc97656..8bbe4d8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala @@ -32,6 +32,8 @@ import org.apache.flink.types.KeyFieldOutOfBoundsException * Comparator for [[Row]]. */ class RowComparator private ( + /** the number of fields of the Row */ + val numberOfFields: Int, /** key positions describe which fields are keys in what order */ val keyPositions: Array[Int], /** null-aware comparators for the key fields, in the same order as the key fields */ @@ -43,8 +45,8 @@ class RowComparator private ( extends CompositeTypeComparator[Row] with Serializable { // null masks for serialized comparison - private val nullMask1 = new Array[Boolean](serializers.length) - private val nullMask2 = new Array[Boolean](serializers.length) + private val nullMask1 = new Array[Boolean](numberOfFields) + private val nullMask2 = new Array[Boolean](numberOfFields) // cache for the deserialized key field objects @transient @@ -63,10 +65,12 @@ class RowComparator private ( * Intermediate constructor for creating auxiliary fields. */ def this( + numberOfFields: Int, keyPositions: Array[Int], comparators: Array[NullAwareComparator[Any]], serializers: Array[TypeSerializer[Any]]) = { this( + numberOfFields, keyPositions, comparators, serializers, @@ -76,6 +80,7 @@ class RowComparator private ( /** * General constructor for RowComparator. * + * @param numberOfFields the number of fields of the Row * @param keyPositions key positions describe which fields are keys in what order * @param comparators non-null-aware comparators for the key fields, in the same order as * the key fields @@ -83,11 +88,13 @@ class RowComparator private ( * @param orders sorting orders for the fields */ def this( + numberOfFields: Int, keyPositions: Array[Int], comparators: Array[TypeComparator[Any]], serializers: Array[TypeSerializer[Any]], orders: Array[Boolean]) = { this( + numberOfFields, keyPositions, makeNullAware(comparators, orders), serializers) @@ -133,8 +140,8 @@ class RowComparator private ( val len = serializers.length val keyLen = keyPositions.length - readIntoNullMask(len, firstSource, nullMask1) - readIntoNullMask(len, secondSource, nullMask2) + readIntoNullMask(numberOfFields, firstSource, nullMask1) + readIntoNullMask(numberOfFields, secondSource, nullMask2) // deserialize var i = 0 @@ -217,6 +224,7 @@ class RowComparator private ( val serializersCopy = serializers.map(_.duplicate()) new RowComparator( + numberOfFields, keyPositions, comparatorsCopy, serializersCopy, http://git-wip-us.apache.org/repos/asf/flink/blob/0bb68479/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala index 489edca..711bb49 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala @@ -96,6 +96,7 @@ class RowTypeInfo(fieldTypes: Seq[TypeInformation[_]]) val maxIndex = logicalKeyFields.max new RowComparator( + getArity, logicalKeyFields.toArray, fieldComparators.toArray.asInstanceOf[Array[TypeComparator[Any]]], types.take(maxIndex + 1).map(_.createSerializer(config).asInstanceOf[TypeSerializer[Any]]), http://git-wip-us.apache.org/repos/asf/flink/blob/0bb68479/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala new file mode 100644 index 0000000..33715c1 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.typeutils + +import org.apache.flink.api.common.ExecutionConfig +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeutils.{ComparatorTestBase, TypeComparator, TypeSerializer} +import org.apache.flink.api.table.Row +import org.apache.flink.util.Preconditions +import org.junit.Assert._ + +/** + * Tests [[RowComparator]] for wide rows. + */ +class RowComparatorWithManyFieldsTest extends ComparatorTestBase[Row] { + val numberOfFields = 10 + val fieldTypes = new Array[TypeInformation[_]](numberOfFields) + for (i <- 0 until numberOfFields) { + fieldTypes(i) = BasicTypeInfo.STRING_TYPE_INFO + } + val typeInfo = new RowTypeInfo(fieldTypes) + + val data: Array[Row] = Array( + createRow(Array(null, "b0", "c0", "d0", "e0", "f0", "g0", "h0", "i0", "j0")), + createRow(Array("a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1", "i1", "j1")), + createRow(Array("a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2", "i2", "j2")), + createRow(Array("a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3", "i3", "j3")) + ) + + override protected def deepEquals(message: String, should: Row, is: Row): Unit = { + val arity = should.productArity + assertEquals(message, arity, is.productArity) + var index = 0 + while (index < arity) { + val copiedValue: Any = should.productElement(index) + val element: Any = is.productElement(index) + assertEquals(message, element, copiedValue) + index += 1 + } + } + + override protected def createComparator(ascending: Boolean): TypeComparator[Row] = { + typeInfo.createComparator( + Array(0), + Array(ascending), + 0, + new ExecutionConfig()) + } + + override protected def createSerializer(): TypeSerializer[Row] = { + typeInfo.createSerializer(new ExecutionConfig()) + } + + override protected def getSortedTestData: Array[Row] = { + data + } + + override protected def supportsNullKeys: Boolean = true + + private def createRow(values: Array[_]): Row = { + Preconditions.checkArgument(values.length == numberOfFields) + val r: Row = new Row(numberOfFields) + values.zipWithIndex.foreach { case (e, i) => r.setField(i, e) } + r + } +}
