You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2020/07/09 08:27:38 UTC

[flink] 01/02: [FLINK-18524][table-common] Fix type inference for Scala varargs

This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 09b9162fd351fcb810cacbc2afa30ea7d4a108d4
Author: Timo Walther <tw...@apache.org>
AuthorDate: Wed Jul 8 11:13:57 2020 +0200

    [FLINK-18524][table-common] Fix type inference for Scala varargs
    
    This closes #12853.
---
 .../TypeInferenceExtractorScalaTest.scala          | 64 +++++++++++++++++++++-
 .../types/extraction/FunctionMappingExtractor.java | 43 ++++++++++++++-
 2 files changed, 102 insertions(+), 5 deletions(-)

diff --git a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
index df56b0c..abe1d84 100644
--- a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
+++ b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
@@ -34,6 +34,7 @@ import org.junit.runners.Parameterized.Parameters
 import org.junit.{Rule, Test}
 
 import scala.annotation.meta.getter
+import scala.annotation.varargs
 
 /**
  * Scala tests for [[TypeInferenceExtractor]].
@@ -78,7 +79,8 @@ object TypeInferenceExtractorScalaTest {
   def testData: Array[TestSpec] = Array(
 
     // Scala function with data type hint
-    TestSpec.forScalarFunction(classOf[ScalaScalarFunction])
+    TestSpec
+      .forScalarFunction(classOf[ScalaScalarFunction])
       .expectNamedArguments("i", "s", "d")
       .expectTypedArguments(
         DataTypes.INT.notNull().bridgedTo(classOf[Int]),
@@ -93,8 +95,42 @@ object TypeInferenceExtractorScalaTest {
             InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)))),
         TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
 
+    TestSpec
+      .forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction])
+      .expectOutputMapping(
+        InputTypeStrategies.varyingSequence(
+          Array[String]("i", "s", "d"),
+          Array[ArgumentTypeStrategy](
+            InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
+            InputTypeStrategies.explicit(DataTypes.STRING),
+            InputTypeStrategies.explicit(DataTypes.DOUBLE().notNull().bridgedTo(classOf[Double])))),
+        TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+
+    TestSpec
+      .forScalarFunction(classOf[ScalaBoxedVarArgScalarFunction])
+      .expectOutputMapping(
+        InputTypeStrategies.varyingSequence(
+          Array[String]("i", "s", "d"),
+          Array[ArgumentTypeStrategy](
+            InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
+            InputTypeStrategies.explicit(DataTypes.STRING),
+            InputTypeStrategies.explicit(DataTypes.DOUBLE()))),
+        TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+
+    TestSpec
+      .forScalarFunction(classOf[ScalaHintVarArgScalarFunction])
+      .expectOutputMapping(
+        InputTypeStrategies.varyingSequence(
+          Array[String]("i", "s", "d"),
+          Array[ArgumentTypeStrategy](
+            InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
+            InputTypeStrategies.explicit(DataTypes.STRING),
+            InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)))),
+        TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+
     // global output hint with local input overloading
-    TestSpec.forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
+    TestSpec
+      .forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
       .expectOutputMapping(
         InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT)),
         TypeStrategies.explicit(DataTypes.INT))
@@ -122,4 +158,28 @@ object TypeInferenceExtractorScalaTest {
     @FunctionHint(input = Array(new DataTypeHint("STRING")))
     def eval(n: String): Integer = null
   }
+
+  private class ScalaPrimitiveVarArgScalarFunction extends ScalarFunction {
+    @varargs
+    def eval(
+      i: Int,
+      s: String,
+      d: Double*): Boolean = false
+  }
+
+  private class ScalaBoxedVarArgScalarFunction extends ScalarFunction {
+    @varargs
+    def eval(
+      i: Int,
+      s: String,
+      d: java.lang.Double*): Boolean = false
+  }
+
+  private class ScalaHintVarArgScalarFunction extends ScalarFunction {
+    @varargs
+    def eval(
+      i: Int,
+      s: String,
+      @DataTypeHint("ARRAY<DECIMAL(10, 4)>") d: java.math.BigDecimal*): Boolean = false
+  }
 }
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
index becc597..1c6875c 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
@@ -33,6 +33,8 @@ import javax.annotation.Nullable;
 
 import java.lang.reflect.Method;
 import java.lang.reflect.Parameter;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -156,11 +158,13 @@ final class FunctionMappingExtractor {
 		}
 		for (Method method : methods) {
 			try {
+				final Method correctMethod = correctVarArgMethod(method);
+
 				final Map<FunctionSignatureTemplate, FunctionResultTemplate> collectedMappingsPerMethod =
-					collectMethodMappings(method, global, globalResultOnly, resultExtraction, accessor);
+					collectMethodMappings(correctMethod, global, globalResultOnly, resultExtraction, accessor);
 
 				// check if the method can be called
-				verifyMappingForMethod(method, collectedMappingsPerMethod, verification);
+				verifyMappingForMethod(correctMethod, collectedMappingsPerMethod, verification);
 
 				// check if method strategies conflict with function strategies
 				collectedMappingsPerMethod.forEach((signature, result) -> putMapping(collectedMappings, signature, result));
@@ -175,6 +179,39 @@ final class FunctionMappingExtractor {
 	}
 
 	/**
+	 * Special case for Scala which generates two methods when using var-args (a {@code Seq<String>}
+	 * and {@code String...}). Returns the Java-like variant, or the given method if none is found.
+	 */
+	private static Method correctVarArgMethod(Method method) {
+		final int paramCount = method.getParameterCount();
+		final Class<?>[] paramClasses = method.getParameterTypes();
+		if (paramCount > 0 && paramClasses[paramCount - 1].getName().equals("scala.collection.Seq")) {
+			final Type[] paramTypes = method.getGenericParameterTypes();
+			final Type elementType = ((ParameterizedType) paramTypes[paramCount - 1]).getActualTypeArguments()[0];
+			final Type varArgType = elementType instanceof ParameterizedType ?
+				((ParameterizedType) elementType).getRawType() : elementType;
+			return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName()).stream()
+				.filter(Method::isVarArgs)
+				.filter(candidate -> candidate.getParameterCount() == paramCount)
+				.filter(candidate -> {
+					final Type[] candidateParamTypes = candidate.getGenericParameterTypes();
+					for (int i = 0; i < paramCount - 1; i++) {
+						// generic types (e.g. List<String>) may be distinct instances per method; use equals()
+						if (!candidateParamTypes[i].equals(paramTypes[i])) {
+							return false;
+						}
+					}
+					final Class<?> candidateVarArgType = candidate.getParameterTypes()[paramCount - 1];
+					return candidateVarArgType.isArray() &&
+						// Object covers Scala primitives (e.g. Int); parameterized elements use their raw class
+						(varArgType == Object.class || candidateVarArgType.getComponentType() == varArgType);
+				})
+				.findAny().orElse(method);
+		}
+		return method;
+	}
+
+	/**
 	 * Extracts mappings from signature to result (either accumulator or output) for the given method. It
 	 * considers both global hints for the entire function and local hints just for this method.
 	 *
@@ -368,7 +405,7 @@ final class FunctionMappingExtractor {
 				return FunctionArgumentTemplate.of(((CollectionDataType) type).getElementDataType());
 			}
 			// special case for varargs that have been misinterpreted as BYTES
-			else {
+			else if (type.equals(DataTypes.BYTES())) {
 				return FunctionArgumentTemplate.of(DataTypes.TINYINT().notNull().bridgedTo(byte.class));
 			}
 		}