You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2020/07/09 08:27:38 UTC
[flink] 01/02: [FLINK-18524][table-common] Fix type inference for Scala varargs
This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 09b9162fd351fcb810cacbc2afa30ea7d4a108d4
Author: Timo Walther <tw...@apache.org>
AuthorDate: Wed Jul 8 11:13:57 2020 +0200
[FLINK-18524][table-common] Fix type inference for Scala varargs
This closes #12853.
---
.../TypeInferenceExtractorScalaTest.scala | 64 +++++++++++++++++++++-
.../types/extraction/FunctionMappingExtractor.java | 43 ++++++++++++++-
2 files changed, 102 insertions(+), 5 deletions(-)
diff --git a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
index df56b0c..abe1d84 100644
--- a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
+++ b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
@@ -34,6 +34,7 @@ import org.junit.runners.Parameterized.Parameters
import org.junit.{Rule, Test}
import scala.annotation.meta.getter
+import scala.annotation.varargs
/**
* Scala tests for [[TypeInferenceExtractor]].
@@ -78,7 +79,8 @@ object TypeInferenceExtractorScalaTest {
def testData: Array[TestSpec] = Array(
// Scala function with data type hint
- TestSpec.forScalarFunction(classOf[ScalaScalarFunction])
+ TestSpec
+ .forScalarFunction(classOf[ScalaScalarFunction])
.expectNamedArguments("i", "s", "d")
.expectTypedArguments(
DataTypes.INT.notNull().bridgedTo(classOf[Int]),
@@ -93,8 +95,42 @@ object TypeInferenceExtractorScalaTest {
InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)))),
TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+ TestSpec
+ .forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction])
+ .expectOutputMapping(
+ InputTypeStrategies.varyingSequence(
+ Array[String]("i", "s", "d"),
+ Array[ArgumentTypeStrategy](
+ InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
+ InputTypeStrategies.explicit(DataTypes.STRING),
+ InputTypeStrategies.explicit(DataTypes.DOUBLE().notNull().bridgedTo(classOf[Double])))),
+ TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+
+ TestSpec
+ .forScalarFunction(classOf[ScalaBoxedVarArgScalarFunction])
+ .expectOutputMapping(
+ InputTypeStrategies.varyingSequence(
+ Array[String]("i", "s", "d"),
+ Array[ArgumentTypeStrategy](
+ InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
+ InputTypeStrategies.explicit(DataTypes.STRING),
+ InputTypeStrategies.explicit(DataTypes.DOUBLE()))),
+ TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+
+ TestSpec
+ .forScalarFunction(classOf[ScalaHintVarArgScalarFunction])
+ .expectOutputMapping(
+ InputTypeStrategies.varyingSequence(
+ Array[String]("i", "s", "d"),
+ Array[ArgumentTypeStrategy](
+ InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
+ InputTypeStrategies.explicit(DataTypes.STRING),
+ InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)))),
+ TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
+
// global output hint with local input overloading
- TestSpec.forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
+ TestSpec
+ .forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
.expectOutputMapping(
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT)),
TypeStrategies.explicit(DataTypes.INT))
@@ -122,4 +158,28 @@ object TypeInferenceExtractorScalaTest {
@FunctionHint(input = Array(new DataTypeHint("STRING")))
def eval(n: String): Integer = null
}
+
+ private class ScalaPrimitiveVarArgScalarFunction extends ScalarFunction { // var-arg over a Scala primitive (Double)
+ @varargs // requests a Java-like var-arg variant of eval in addition to the Seq-based one
+ def eval(
+ i: Int,
+ s: String,
+ d: Double*): Boolean = false // expected per-argument type: DOUBLE().notNull() bridged to classOf[Double]
+ }
+
+ private class ScalaBoxedVarArgScalarFunction extends ScalarFunction { // var-arg over a boxed java.lang.Double
+ @varargs // requests a Java-like var-arg variant of eval in addition to the Seq-based one
+ def eval(
+ i: Int,
+ s: String,
+ d: java.lang.Double*): Boolean = false // expected per-argument type: DOUBLE() without notNull()
+ }
+
+ private class ScalaHintVarArgScalarFunction extends ScalarFunction { // var-arg whose type comes from a data type hint
+ @varargs // requests a Java-like var-arg variant of eval in addition to the Seq-based one
+ def eval(
+ i: Int,
+ s: String,
+ @DataTypeHint("ARRAY<DECIMAL(10, 4)>") d: java.math.BigDecimal*): Boolean = false // hint covers the whole var-arg array; DECIMAL(10, 4) is expected per argument
+ }
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
index becc597..1c6875c 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
@@ -33,6 +33,8 @@ import javax.annotation.Nullable;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -156,11 +158,13 @@ final class FunctionMappingExtractor {
}
for (Method method : methods) {
try {
+ final Method correctMethod = correctVarArgMethod(method);
+
final Map<FunctionSignatureTemplate, FunctionResultTemplate> collectedMappingsPerMethod =
- collectMethodMappings(method, global, globalResultOnly, resultExtraction, accessor);
+ collectMethodMappings(correctMethod, global, globalResultOnly, resultExtraction, accessor);
// check if the method can be called
- verifyMappingForMethod(method, collectedMappingsPerMethod, verification);
+ verifyMappingForMethod(correctMethod, collectedMappingsPerMethod, verification);
// check if method strategies conflict with function strategies
collectedMappingsPerMethod.forEach((signature, result) -> putMapping(collectedMappings, signature, result));
@@ -175,6 +179,39 @@ final class FunctionMappingExtractor {
}
/**
+ * Special case for Scala which generates two methods when using var-args (a {@code Seq<String>} and a
+ * {@code String...} variant). Returns the Java-like variant if one matches, otherwise the given method.
+ */
+ private static Method correctVarArgMethod(Method method) {
+ final int paramCount = method.getParameterCount();
+ final Type[] paramTypes = method.getGenericParameterTypes();
+ // a raw Seq (no generic signature) carries no element type that could be matched, so keep the method
+ if (paramCount == 0 || !method.getParameterTypes()[paramCount - 1].getName().equals("scala.collection.Seq")
+ || !(paramTypes[paramCount - 1] instanceof ParameterizedType)) {
+ return method;
+ }
+ final Type varArgType = ((ParameterizedType) paramTypes[paramCount - 1]).getActualTypeArguments()[0];
+ return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName()).stream()
+ .filter(Method::isVarArgs)
+ .filter(candidate -> candidate.getParameterCount() == paramCount)
+ .filter(candidate -> {
+ final Type[] candidateParamTypes = candidate.getGenericParameterTypes();
+ // generic types (e.g. List<String>) are separate instances per method, so compare with equals()
+ for (int i = 0; i < paramCount - 1; i++) {
+ if (!candidateParamTypes[i].equals(paramTypes[i])) {
+ return false;
+ }
+ }
+ final Class<?> candidateVarArgType = candidate.getParameterTypes()[paramCount - 1];
+ // check for Object is needed in case of Scala primitives (e.g. Int) which appear as Object in the Seq
+ return candidateVarArgType.isArray() &&
+ (varArgType == Object.class || varArgType.equals(candidateVarArgType.getComponentType()));
+ })
+ .findAny()
+ .orElse(method);
+ }
+
+ /**
* Extracts mappings from signature to result (either accumulator or output) for the given method. It
* considers both global hints for the entire function and local hints just for this method.
*
@@ -368,7 +405,7 @@ final class FunctionMappingExtractor {
return FunctionArgumentTemplate.of(((CollectionDataType) type).getElementDataType());
}
// special case for varargs that have been misinterpreted as BYTES
- else {
+ else if (type.equals(DataTypes.BYTES())) {
return FunctionArgumentTemplate.of(DataTypes.TINYINT().notNull().bridgedTo(byte.class));
}
}