You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/11/29 14:42:56 UTC

[2/3] flink git commit: [FLINK-5184] [table] Fix compareSerialized() of RowComparator.

[FLINK-5184] [table] Fix compareSerialized() of RowComparator.

This closes #2894


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0bb68479
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0bb68479
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0bb68479

Branch: refs/heads/master
Commit: 0bb684797dfb3e03dd4f9761a6bf1eb8ce9d1c0d
Parents: db441de
Author: godfreyhe <go...@163.com>
Authored: Tue Nov 29 19:27:58 2016 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Nov 29 13:19:26 2016 +0100

----------------------------------------------------------------------
 .../api/table/typeutils/RowComparator.scala     | 16 +++-
 .../flink/api/table/typeutils/RowTypeInfo.scala |  1 +
 .../RowComparatorWithManyFieldsTest.scala       | 82 ++++++++++++++++++++
 3 files changed, 95 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0bb68479/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala
index cc97656..8bbe4d8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowComparator.scala
@@ -32,6 +32,8 @@ import org.apache.flink.types.KeyFieldOutOfBoundsException
  * Comparator for [[Row]].
  */
 class RowComparator private (
+    /** the number of fields of the Row */
+    val numberOfFields: Int,
     /** key positions describe which fields are keys in what order */
     val keyPositions: Array[Int],
     /** null-aware comparators for the key fields, in the same order as the key fields */
@@ -43,8 +45,8 @@ class RowComparator private (
   extends CompositeTypeComparator[Row] with Serializable {
 
   // null masks for serialized comparison
-  private val nullMask1 = new Array[Boolean](serializers.length)
-  private val nullMask2 = new Array[Boolean](serializers.length)
+  private val nullMask1 = new Array[Boolean](numberOfFields)
+  private val nullMask2 = new Array[Boolean](numberOfFields)
 
   // cache for the deserialized key field objects
   @transient
@@ -63,10 +65,12 @@ class RowComparator private (
    * Intermediate constructor for creating auxiliary fields.
    */
   def this(
+      numberOfFields: Int,
       keyPositions: Array[Int],
       comparators: Array[NullAwareComparator[Any]],
       serializers: Array[TypeSerializer[Any]]) = {
     this(
+      numberOfFields,
       keyPositions,
       comparators,
       serializers,
@@ -76,6 +80,7 @@ class RowComparator private (
   /**
    * General constructor for RowComparator.
    *
+   * @param numberOfFields the number of fields of the Row
    * @param keyPositions key positions describe which fields are keys in what order
    * @param comparators non-null-aware comparators for the key fields, in the same order as
    *   the key fields
@@ -83,11 +88,13 @@ class RowComparator private (
    * @param orders sorting orders for the fields
    */
   def this(
+      numberOfFields: Int,
       keyPositions: Array[Int],
       comparators: Array[TypeComparator[Any]],
       serializers: Array[TypeSerializer[Any]],
       orders: Array[Boolean]) = {
     this(
+      numberOfFields,
       keyPositions,
       makeNullAware(comparators, orders),
       serializers)
@@ -133,8 +140,8 @@ class RowComparator private (
     val len = serializers.length
     val keyLen = keyPositions.length
 
-    readIntoNullMask(len, firstSource, nullMask1)
-    readIntoNullMask(len, secondSource, nullMask2)
+    readIntoNullMask(numberOfFields, firstSource, nullMask1)
+    readIntoNullMask(numberOfFields, secondSource, nullMask2)
 
     // deserialize
     var i = 0
@@ -217,6 +224,7 @@ class RowComparator private (
     val serializersCopy = serializers.map(_.duplicate())
 
     new RowComparator(
+      numberOfFields,
       keyPositions,
       comparatorsCopy,
       serializersCopy,

http://git-wip-us.apache.org/repos/asf/flink/blob/0bb68479/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala
index 489edca..711bb49 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/typeutils/RowTypeInfo.scala
@@ -96,6 +96,7 @@ class RowTypeInfo(fieldTypes: Seq[TypeInformation[_]])
       val maxIndex = logicalKeyFields.max
 
       new RowComparator(
+        getArity,
         logicalKeyFields.toArray,
         fieldComparators.toArray.asInstanceOf[Array[TypeComparator[Any]]],
         types.take(maxIndex + 1).map(_.createSerializer(config).asInstanceOf[TypeSerializer[Any]]),

http://git-wip-us.apache.org/repos/asf/flink/blob/0bb68479/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala
new file mode 100644
index 0000000..33715c1
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/typeutils/RowComparatorWithManyFieldsTest.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.typeutils
+
+import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeutils.{ComparatorTestBase, TypeComparator, TypeSerializer}
+import org.apache.flink.api.table.Row
+import org.apache.flink.util.Preconditions
+import org.junit.Assert._
+
+/**
+  * Tests [[RowComparator]] for wide rows (FLINK-5184).
+  *
+  * The comparator keys only on field 0, so [[RowTypeInfo]] creates serializers for the key
+  * prefix only, while a serialized row carries a null mask over all of its fields. Presumably
+  * more than 8 fields are used so that the null mask spans more than one byte.
+  */
+class RowComparatorWithManyFieldsTest extends ComparatorTestBase[Row] {
+  val numberOfFields = 10
+  val fieldTypes = Array.fill[TypeInformation[_]](numberOfFields)(BasicTypeInfo.STRING_TYPE_INFO)
+  val typeInfo = new RowTypeInfo(fieldTypes)
+
+  // must be sorted ascending on the key field 0; the first row has a null key
+  val data: Array[Row] = Array(
+    createRow(Array(null, "b0", "c0", "d0", "e0", "f0", "g0", "h0", "i0", "j0")),
+    createRow(Array("a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1", "i1", "j1")),
+    createRow(Array("a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2", "i2", "j2")),
+    createRow(Array("a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3", "i3", "j3"))
+  )
+
+  /** Compares all fields; JUnit expects the expected value (`should`) before the actual. */
+  override protected def deepEquals(message: String, should: Row, is: Row): Unit = {
+    assertEquals(message, should.productArity, is.productArity)
+    for (index <- 0 until should.productArity) {
+      assertEquals(message, should.productElement(index), is.productElement(index))
+    }
+  }
+
+  /** Creates a comparator on key field 0 in the given sort direction. */
+  override protected def createComparator(ascending: Boolean): TypeComparator[Row] = {
+    typeInfo.createComparator(
+      Array(0),
+      Array(ascending),
+      0,
+      new ExecutionConfig())
+  }
+
+  override protected def createSerializer(): TypeSerializer[Row] = {
+    typeInfo.createSerializer(new ExecutionConfig())
+  }
+
+  override protected def getSortedTestData: Array[Row] = {
+    data
+  }
+
+  override protected def supportsNullKeys: Boolean = true
+
+  /** Builds a row with exactly `numberOfFields` fields from `values`. */
+  private def createRow(values: Array[_]): Row = {
+    Preconditions.checkArgument(values.length == numberOfFields)
+    val r: Row = new Row(numberOfFields)
+    values.zipWithIndex.foreach { case (e, i) => r.setField(i, e) }
+    r
+  }
+}