You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by sr...@apache.org on 2019/02/11 22:08:52 UTC
[samza] branch master updated: Support for types in Samza SQL UDF
(#911)
This is an automated email from the ASF dual-hosted git repository.
srinivasulu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git
The following commit(s) were added to refs/heads/master by this push:
new 01ec54d Support for types in Samza SQL UDF (#911)
01ec54d is described below
commit 01ec54dd6ed3c2e7a6558a42732ffd8a0e4e0dc8
Author: Srinivasulu Punuru <sr...@users.noreply.github.com>
AuthorDate: Mon Feb 11 14:08:47 2019 -0800
Support for types in Samza SQL UDF (#911)
* Support for types in udf
* Option to disable argument check for dynamic functions
* Added comments
* Update based on review comments
* Adding license
* Added some comments
* fix for test
---
.../org/apache/samza/sql/udfs/SamzaSqlUdf.java | 5 +++
.../apache/samza/sql/udfs/SamzaSqlUdfMethod.java | 6 +++
.../java/org/apache/samza/sql/udfs/ScalarUdf.java | 5 +--
.../samza/sql/data/SamzaSqlExecutionContext.java | 18 ++++++--
.../apache/samza/sql/fn/BuildOutputRecordUdf.java | 4 +-
.../apache/samza/sql/fn/ConvertToStringUdf.java | 9 ++--
.../java/org/apache/samza/sql/fn/FlattenUdf.java | 8 ++--
.../org/apache/samza/sql/fn/GetSqlFieldUdf.java | 18 ++++----
.../org/apache/samza/sql/fn/RegexMatchUdf.java | 9 ++--
.../samza/sql/impl/ConfigBasedUdfResolver.java | 16 ++++---
.../apache/samza/sql/interfaces/UdfMetadata.java | 26 +++++++++--
.../org/apache/samza/sql/planner/QueryPlanner.java | 2 +-
.../sql/planner/SamzaSqlScalarFunctionImpl.java | 35 +++++++++------
.../sql/planner/SamzaSqlUdfOperatorTable.java | 50 +++++++++++++---------
.../sql/runner/TestSamzaSqlApplicationConfig.java | 3 +-
.../org/apache/samza/sql/util/MyTestArrayUdf.java | 8 ++--
.../util/{MyTestUdf.java => MyTestPolyUdf.java} | 23 +++++-----
.../java/org/apache/samza/sql/util/MyTestUdf.java | 15 +++++--
.../apache/samza/sql/util/SamzaSqlTestConfig.java | 2 +-
.../samza/test/samzasql/TestSamzaSqlEndToEnd.java | 24 +++++++++++
20 files changed, 195 insertions(+), 91 deletions(-)
diff --git a/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdf.java b/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdf.java
index de1821e..6a55c07 100644
--- a/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdf.java
+++ b/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdf.java
@@ -37,6 +37,11 @@ public @interface SamzaSqlUdf {
String name();
/**
+ * Description of the UDF
+ */
+ String description();
+
+ /**
* Whether the UDF is enabled or not.
*/
boolean enabled() default true;
diff --git a/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdfMethod.java b/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdfMethod.java
index 9b1c7ef..cba2749 100644
--- a/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdfMethod.java
+++ b/samza-api/src/main/java/org/apache/samza/sql/udfs/SamzaSqlUdfMethod.java
@@ -34,6 +34,12 @@ import org.apache.samza.sql.schema.SamzaSqlFieldType;
public @interface SamzaSqlUdfMethod {
/**
+ * Whether the argument check needs to be disabled. This is useful if the udf takes in
+ * a dynamic number of arguments.
+ */
+ boolean disableArgumentCheck() default false;
+
+ /**
* Type of the arguments for the Samza SQL udf method
*/
SamzaSqlFieldType[] params() default {};
diff --git a/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java b/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
index 3307dc0..ff67487 100644
--- a/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
+++ b/samza-api/src/main/java/org/apache/samza/sql/udfs/ScalarUdf.java
@@ -23,8 +23,8 @@ import org.apache.samza.config.Config;
/**
- * The base class for the Scalar UDFs. All the scalar UDF classes needs to extend this and implement a method named
- * "execute". The number and type of arguments for the execute method in the UDF class should match the number and type of fields
+ * The base interface for the Scalar UDFs. All the scalar UDF classes need to implement this.
+ * The number and type of arguments for the method annotated with {@link SamzaSqlUdfMethod} in the UDF class should match the number and type of fields
* used while invoking this UDF in SQL statement.
* Say for e.g. User creates a UDF class with signature int execute(int var1, String var2). It can be used in a SQL query
* select myudf(id, name) from profile
@@ -36,5 +36,4 @@ public interface ScalarUdf {
* @param udfConfig Config specific to the udf.
*/
void init(Config udfConfig);
-
}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
index 091ca62..2a8f92e 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java
@@ -19,7 +19,9 @@
package org.apache.samza.sql.data;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -38,7 +40,10 @@ public class SamzaSqlExecutionContext implements Cloneable {
* The variables that are shared among all cloned instance of {@link SamzaSqlExecutionContext}
*/
private final SamzaSqlApplicationConfig sqlConfig;
- private final Map<String, UdfMetadata> udfMetadata;
+
+ // Maps the UDF name to the list of all UDF methods associated with that name.
+ // Since we support polymorphism, there can be multiple UdfMetadata entries associated with a single name.
+ private final Map<String, List<UdfMetadata>> udfMetadata;
/**
* The variable that are not shared among all cloned instance of {@link SamzaSqlExecutionContext}
@@ -52,8 +57,11 @@ public class SamzaSqlExecutionContext implements Cloneable {
public SamzaSqlExecutionContext(SamzaSqlApplicationConfig config) {
this.sqlConfig = config;
- udfMetadata =
- this.sqlConfig.getUdfMetadata().stream().collect(Collectors.toMap(UdfMetadata::getName, Function.identity()));
+ udfMetadata = new HashMap<>();
+ for(UdfMetadata udf : this.sqlConfig.getUdfMetadata()) {
+ udfMetadata.putIfAbsent(udf.getName(), new ArrayList<>());
+ udfMetadata.get(udf.getName()).add(udf);
+ }
}
public ScalarUdf getOrCreateUdf(String clazz, String udfName) {
@@ -61,7 +69,9 @@ public class SamzaSqlExecutionContext implements Cloneable {
}
public ScalarUdf createInstance(String clazz, String udfName) {
- Config udfConfig = udfMetadata.get(udfName).getUdfConfig();
+
+ // Configs should be the same for all the UDF methods within a UDF. Hence taking the first one.
+ Config udfConfig = udfMetadata.get(udfName).get(0).getUdfConfig();
ScalarUdf scalarUdf = ReflectionUtils.createInstance(clazz);
if (scalarUdf == null) {
String msg = String.format("Couldn't create udf %s of class %s", udfName, clazz);
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
index dc928ab..e0c34f1 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/BuildOutputRecordUdf.java
@@ -61,13 +61,13 @@ import org.apache.samza.sql.udfs.ScalarUdf;
* If no args is provided, it returns an empty SamzaSqlRelRecord (with empty field names and values list).
*/
-@SamzaSqlUdf(name="BuildOutputRecord")
+@SamzaSqlUdf(name = "BuildOutputRecord", description = "Creates an Output record.")
public class BuildOutputRecordUdf implements ScalarUdf {
@Override
public void init(Config udfConfig) {
}
- @SamzaSqlUdfMethod
+ @SamzaSqlUdfMethod(disableArgumentCheck = true)
public SamzaSqlRelRecord execute(Object... args) {
int numOfArgs = args.length;
Validate.isTrue(numOfArgs % 2 == 0, "numOfArgs should be an even number");
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
index dc482d8..659f7e3 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/ConvertToStringUdf.java
@@ -20,6 +20,7 @@
package org.apache.samza.sql.fn;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
@@ -28,15 +29,15 @@ import org.apache.samza.sql.udfs.ScalarUdf;
/**
* UDF that converts an object to it's string representation.
*/
-@SamzaSqlUdf(name = "convertToString")
+@SamzaSqlUdf(name = "convertToString", description = "Converts the object to string.")
public class ConvertToStringUdf implements ScalarUdf {
@Override
public void init(Config udfConfig) {
}
- @SamzaSqlUdfMethod
- public String execute(Object... args) {
- return args[0].toString();
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
+ public String execute(Object args) {
+ return args.toString();
}
}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
index fa3d15e..0734a3a 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/FlattenUdf.java
@@ -21,20 +21,20 @@ package org.apache.samza.sql.fn;
import java.util.List;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
-@SamzaSqlUdf(name = "Flatten")
+@SamzaSqlUdf(name = "Flatten", description = "Flattens the array.")
public class FlattenUdf implements ScalarUdf {
@Override
public void init(Config udfConfig) {
}
- @SamzaSqlUdfMethod
- public Object execute(Object... arg) {
- List value = (List) arg[0];
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ARRAY)
+ public Object execute(List value) {
return value != null && !value.isEmpty() ? value.get(0) : value;
}
}
\ No newline at end of file
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
index 8f5704c..de56fa0 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/GetSqlFieldUdf.java
@@ -24,6 +24,7 @@ import java.util.Map;
import org.apache.commons.lang.Validate;
import org.apache.samza.config.Config;
import org.apache.samza.sql.SamzaSqlRelRecord;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
@@ -51,22 +52,21 @@ import org.apache.samza.sql.udfs.ScalarUdf;
* - sessionKey (Scalar)
*
*/
-@SamzaSqlUdf(name = "GetSqlField")
+@SamzaSqlUdf(name = "GetSqlField", description = "Get an element from complex Sql field as a String.")
public class GetSqlFieldUdf implements ScalarUdf {
@Override
public void init(Config udfConfig) {
}
- @SamzaSqlUdfMethod
- public String execute(Object... args) {
- Object currentFieldOrValue = args[0];
+ @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.ANY, SamzaSqlFieldType.STRING})
+ public String execute(Object field, String fieldName) {
+ Object currentFieldOrValue = field;
Validate.isTrue(currentFieldOrValue == null
|| currentFieldOrValue instanceof SamzaSqlRelRecord);
- if (currentFieldOrValue != null && args.length > 1) {
- String[] fieldNameChain = ((String) args[1]).split("\\.");
- for (int i = 0; i < fieldNameChain.length && currentFieldOrValue != null; i++) {
- currentFieldOrValue = extractField(fieldNameChain[i], currentFieldOrValue);
- }
+
+ String[] fieldNameChain = fieldName.split("\\.");
+ for (int i = 0; i < fieldNameChain.length && currentFieldOrValue != null; i++) {
+ currentFieldOrValue = extractField(fieldNameChain[i], currentFieldOrValue);
}
if (currentFieldOrValue != null) {
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java b/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
index 00b5775..c157112 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/fn/RegexMatchUdf.java
@@ -21,6 +21,7 @@ package org.apache.samza.sql.fn;
import java.util.regex.Pattern;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
@@ -29,15 +30,15 @@ import org.apache.samza.sql.udfs.ScalarUdf;
/**
* Simple RegexMatch Udf.
*/
-@SamzaSqlUdf(name="RegexMatch")
+@SamzaSqlUdf(name="RegexMatch", description = "Function to perform the regex match.")
public class RegexMatchUdf implements ScalarUdf {
@Override
public void init(Config config) {
}
- @SamzaSqlUdfMethod
- public Boolean execute(Object... args) {
- return Pattern.matches((String) args[0], (String) args[1]);
+ @SamzaSqlUdfMethod(params = {SamzaSqlFieldType.STRING, SamzaSqlFieldType.STRING})
+ public Boolean match(String regexPattern, String input) {
+ return Pattern.matches(regexPattern, input);
}
}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/impl/ConfigBasedUdfResolver.java b/samza-sql/src/main/java/org/apache/samza/sql/impl/ConfigBasedUdfResolver.java
index d21c1a6..1319a85 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/impl/ConfigBasedUdfResolver.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/impl/ConfigBasedUdfResolver.java
@@ -23,7 +23,9 @@ import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
@@ -31,6 +33,7 @@ import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.interfaces.UdfResolver;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
@@ -74,16 +77,15 @@ public class ConfigBasedUdfResolver implements UdfResolver {
}
SamzaSqlUdf sqlUdf;
+ Map<SamzaSqlUdfMethod, Method> udfMethods = new HashMap<>();
SamzaSqlUdfMethod sqlUdfMethod = null;
- Method udfMethod = null;
sqlUdf = udfClass.getAnnotation(SamzaSqlUdf.class);
Method[] methods = udfClass.getMethods();
for (Method method : methods) {
sqlUdfMethod = method.getAnnotation(SamzaSqlUdfMethod.class);
if (sqlUdfMethod != null) {
- udfMethod = method;
- break;
+ udfMethods.put(sqlUdfMethod, method);
}
}
@@ -93,7 +95,7 @@ public class ConfigBasedUdfResolver implements UdfResolver {
throw new SamzaException(msg);
}
- if (sqlUdfMethod == null) {
+ if (udfMethods.isEmpty()) {
String msg = String.format("UdfClass %s doesn't have any methods annotated with SamzaSqlUdfMethod", udfClass);
LOG.error(msg);
throw new SamzaException(msg);
@@ -101,7 +103,11 @@ public class ConfigBasedUdfResolver implements UdfResolver {
if (sqlUdf.enabled()) {
String udfName = sqlUdf.name();
- udfs.add(new UdfMetadata(udfName, udfMethod, udfConfig.subset(udfName + ".")));
+ for (Map.Entry<SamzaSqlUdfMethod, Method> udfMethod : udfMethods.entrySet()) {
+ List<SamzaSqlFieldType> params = Arrays.asList(udfMethod.getKey().params());
+ udfs.add(new UdfMetadata(udfName, udfMethod.getValue(), udfConfig.subset(udfName + "."), params,
+ udfMethod.getKey().disableArgumentCheck()));
+ }
}
}
}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/interfaces/UdfMetadata.java b/samza-sql/src/main/java/org/apache/samza/sql/interfaces/UdfMetadata.java
index b1a2d6d..4adb5ea 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/interfaces/UdfMetadata.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/interfaces/UdfMetadata.java
@@ -21,7 +21,9 @@ package org.apache.samza.sql.interfaces;
import java.lang.reflect.Method;
+import java.util.List;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
/**
@@ -30,15 +32,18 @@ import org.apache.samza.config.Config;
public class UdfMetadata {
private final String name;
-
private final Method udfMethod;
-
private final Config udfConfig;
+ private final boolean disableArgCheck;
+ private final List<SamzaSqlFieldType> arguments;
- public UdfMetadata(String name, Method udfMethod, Config udfConfig) {
+ public UdfMetadata(String name, Method udfMethod, Config udfConfig, List<SamzaSqlFieldType> arguments,
+ boolean disableArgCheck) {
this.name = name;
this.udfMethod = udfMethod;
this.udfConfig = udfConfig;
+ this.arguments = arguments;
+ this.disableArgCheck = disableArgCheck;
}
public Config getUdfConfig() {
@@ -58,4 +63,19 @@ public class UdfMetadata {
public String getName() {
return name;
}
+
+ /**
+ * @return Returns the list of arguments that the udf should take.
+ */
+ public List<SamzaSqlFieldType> getArguments() {
+ return arguments;
+ }
+
+ /**
+ * @return Returns whether the argument check needs to be disabled.
+ */
+ public boolean isDisableArgCheck() {
+ return disableArgCheck;
+ }
+
}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/planner/QueryPlanner.java b/samza-sql/src/main/java/org/apache/samza/sql/planner/QueryPlanner.java
index b860b20..8bccc2e 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/planner/QueryPlanner.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/planner/QueryPlanner.java
@@ -113,7 +113,7 @@ public class QueryPlanner {
}
List<SamzaSqlScalarFunctionImpl> samzaSqlFunctions = udfMetadata.stream()
- .map(x -> new SamzaSqlScalarFunctionImpl(x.getName(), x.getUdfMethod()))
+ .map(x -> new SamzaSqlScalarFunctionImpl(x))
.collect(Collectors.toList());
final List<RelTraitDef> traitDefs = new ArrayList<>();
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
index 6894c86..c5d0121 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlScalarFunctionImpl.java
@@ -35,22 +35,27 @@ import org.apache.calcite.schema.ImplementableFunction;
import org.apache.calcite.schema.ScalarFunction;
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
import org.apache.samza.sql.data.SamzaSqlExecutionContext;
+import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.udfs.ScalarUdf;
-
+/**
+ * Calcite implementation for Samza SQL UDF.
+ * This class contains logic to generate the java code to execute {@link org.apache.samza.sql.udfs.SamzaSqlUdf}.
+ */
public class SamzaSqlScalarFunctionImpl implements ScalarFunction, ImplementableFunction {
private final ScalarFunction myIncFunction;
private final Method udfMethod;
private final Method getUdfMethod;
-
-
private final String udfName;
+ private final UdfMetadata udfMetadata;
+
+ public SamzaSqlScalarFunctionImpl(UdfMetadata udfMetadata) {
- public SamzaSqlScalarFunctionImpl(String udfName, Method udfMethod) {
- myIncFunction = ScalarFunctionImpl.create(udfMethod);
- this.udfName = udfName;
- this.udfMethod = udfMethod;
+ myIncFunction = ScalarFunctionImpl.create(udfMetadata.getUdfMethod());
+ this.udfMetadata = udfMetadata;
+ this.udfName = udfMetadata.getName();
+ this.udfMethod = udfMetadata.getUdfMethod();
this.getUdfMethod = Arrays.stream(SamzaSqlExecutionContext.class.getMethods())
.filter(x -> x.getName().equals("getOrCreateUdf"))
.findFirst()
@@ -61,17 +66,23 @@ public class SamzaSqlScalarFunctionImpl implements ScalarFunction, Implementable
return udfName;
}
+ public int numberOfArguments() {
+ return udfMetadata.getArguments().size();
+ }
+
+ public UdfMetadata getUdfMetadata() {
+ return udfMetadata;
+ }
+
@Override
public CallImplementor getImplementor() {
return RexImpTable.createImplementor((translator, call, translatedOperands) -> {
final Expression context = Expressions.parameter(SamzaSqlExecutionContext.class, "context");
final Expression getUdfInstance = Expressions.call(ScalarUdf.class, context, getUdfMethod,
Expressions.constant(udfMethod.getDeclaringClass().getName()), Expressions.constant(udfName));
- final Expression callExpression = Expressions.convert_(Expressions.call(Expressions.convert_(getUdfInstance, udfMethod.getDeclaringClass()), udfMethod,
- translatedOperands), Object.class);
- // The Janino compiler which is used to compile the expressions doesn't seem to understand the Type of the ScalarUdf.execute
- // because it is a generic. To work around that we are explicitly casting it to the return type.
- return Expressions.convert_(callExpression, udfMethod.getReturnType());
+ final Expression callExpression = Expressions.call(Expressions.convert_(getUdfInstance, udfMethod.getDeclaringClass()), udfMethod,
+ translatedOperands);
+ return callExpression;
}, NullPolicy.NONE, false);
}
diff --git a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlUdfOperatorTable.java b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlUdfOperatorTable.java
index 476e9b0..6ee10f8 100644
--- a/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlUdfOperatorTable.java
+++ b/samza-sql/src/main/java/org/apache/samza/sql/planner/SamzaSqlUdfOperatorTable.java
@@ -1,27 +1,26 @@
/*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements. See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership. The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied. See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
package org.apache.samza.sql.planner;
import java.util.List;
import java.util.stream.Collectors;
-
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlOperator;
@@ -30,6 +29,7 @@ import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.ListSqlOperatorTable;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
+import org.apache.samza.sql.interfaces.UdfMetadata;
public class SamzaSqlUdfOperatorTable implements SqlOperatorTable {
@@ -45,8 +45,18 @@ public class SamzaSqlUdfOperatorTable implements SqlOperatorTable {
}
private SqlOperator getSqlOperator(SamzaSqlScalarFunctionImpl scalarFunction) {
- return new SqlUserDefinedFunction(new SqlIdentifier(scalarFunction.getUdfName(), SqlParserPos.ZERO),
- o -> scalarFunction.getReturnType(o.getTypeFactory()), null, Checker.ANY_CHECKER, null, scalarFunction);
+ int numArguments = scalarFunction.numberOfArguments();
+ UdfMetadata udfMetadata = scalarFunction.getUdfMetadata();
+
+ if(udfMetadata.isDisableArgCheck()) {
+ return new SqlUserDefinedFunction(new SqlIdentifier(scalarFunction.getUdfName(), SqlParserPos.ZERO),
+ o -> scalarFunction.getReturnType(o.getTypeFactory()), null, Checker.ANY_CHECKER,
+ null, scalarFunction);
+ } else {
+ return new SqlUserDefinedFunction(new SqlIdentifier(scalarFunction.getUdfName(), SqlParserPos.ZERO),
+ o -> scalarFunction.getReturnType(o.getTypeFactory()), null, Checker.getChecker(numArguments, numArguments),
+ null, scalarFunction);
+ }
}
@Override
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationConfig.java b/samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationConfig.java
index 8d2c588..c6fb357 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationConfig.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/runner/TestSamzaSqlApplicationConfig.java
@@ -55,7 +55,8 @@ public class TestSamzaSqlApplicationConfig {
.collect(Collectors.toList()),
queryInfo.stream().map(SamzaSqlQueryParser.QueryInfo::getSink).collect(Collectors.toList()));
- Assert.assertEquals(numUdfs, samzaSqlApplicationConfig.getUdfMetadata().size());
+ // Two of the UDFs have an overload, hence + 2.
+ Assert.assertEquals(numUdfs + 2, samzaSqlApplicationConfig.getUdfMetadata().size());
Assert.assertEquals(1, samzaSqlApplicationConfig.getInputSystemStreamConfigBySource().size());
Assert.assertEquals(1, samzaSqlApplicationConfig.getOutputSystemStreamConfigsBySource().size());
}
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
index c71813b..7f6ee50 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestArrayUdf.java
@@ -23,20 +23,20 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
-@SamzaSqlUdf(name = "MyTestArray")
+@SamzaSqlUdf(name = "MyTestArray", description = "Test udf that returns an array")
public class MyTestArrayUdf implements ScalarUdf {
@Override
public void init(Config udfConfig) {
}
- @SamzaSqlUdfMethod
- public List<String> execute(Object... args) {
- Integer value = (Integer) args[0];
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
+ public List<String> execute(Integer value) {
return IntStream.range(0, value).mapToObj(String::valueOf).collect(Collectors.toList());
}
}
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
similarity index 72%
copy from samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
copy to samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
index 6b714b4..f4afbd6 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestPolyUdf.java
@@ -16,10 +16,10 @@
* specific language governing permissions and limitations
* under the License.
*/
-
package org.apache.samza.sql.util;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
@@ -28,22 +28,25 @@ import org.slf4j.LoggerFactory;
/**
- * Test UDF used by unit and integration tests.
+ * UDF to test polymorphism.
*/
-@SamzaSqlUdf(name = "MyTest")
-public class MyTestUdf implements ScalarUdf {
+@SamzaSqlUdf(name = "MyTestPoly", description = "Test Polymorphism UDF.")
+public class MyTestPolyUdf implements ScalarUdf {
+ private static final Logger LOG = LoggerFactory.getLogger(MyTestPolyUdf.class);
- private static final Logger LOG = LoggerFactory.getLogger(MyTestUdf.class);
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
+ public Integer execute(Integer value) {
+ return value * 2;
+ }
- @SamzaSqlUdfMethod
- public Integer execute(Object... value) {
- return ((Integer) value[0]) * 2;
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
+ public Integer execute(String value) {
+ return value.length() * 2;
}
+
@Override
public void init(Config udfConfig) {
LOG.info("Init called with {}", udfConfig);
}
}
-
-
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
index 6b714b4..35a44e3 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/MyTestUdf.java
@@ -20,6 +20,7 @@
package org.apache.samza.sql.util;
import org.apache.samza.config.Config;
+import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdf;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.apache.samza.sql.udfs.ScalarUdf;
@@ -30,16 +31,22 @@ import org.slf4j.LoggerFactory;
/**
* Test UDF used by unit and integration tests.
*/
-@SamzaSqlUdf(name = "MyTest")
+@SamzaSqlUdf(name = "MyTest", description = "Test UDF.")
public class MyTestUdf implements ScalarUdf {
private static final Logger LOG = LoggerFactory.getLogger(MyTestUdf.class);
- @SamzaSqlUdfMethod
- public Integer execute(Object... value) {
- return ((Integer) value[0]) * 2;
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.INT32)
+ public Integer execute(Integer value) {
+ return value * 2;
}
+ @SamzaSqlUdfMethod(params = SamzaSqlFieldType.ANY)
+ public Integer execute(Object value) {
+ return ((Integer) value) * 2;
+ }
+
+
@Override
public void init(Config udfConfig) {
LOG.info("Init called with {}", udfConfig);
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java b/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
index 19a8638..627dc65 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/SamzaSqlTestConfig.java
@@ -96,7 +96,7 @@ public class SamzaSqlTestConfig {
ConfigBasedUdfResolver.class.getName());
staticConfigs.put(configUdfResolverDomain + ConfigBasedUdfResolver.CFG_UDF_CLASSES, Joiner.on(",")
.join(MyTestUdf.class.getName(), RegexMatchUdf.class.getName(), FlattenUdf.class.getName(),
- MyTestArrayUdf.class.getName(), BuildOutputRecordUdf.class.getName()));
+ MyTestArrayUdf.class.getName(), BuildOutputRecordUdf.class.getName(), MyTestPolyUdf.class.getName()));
String avroSystemConfigPrefix =
String.format(ConfigBasedIOResolverFactory.CFG_FMT_SAMZA_PREFIX, SAMZA_SYSTEM_TEST_AVRO);
diff --git a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java
index 76119e4..9deb561 100644
--- a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java
+++ b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlEndToEnd.java
@@ -369,6 +369,30 @@ public class TestSamzaSqlEndToEnd extends SamzaSqlIntegrationTestHarness {
}
@Test
+ public void testEndToEndUdfPolymorphism() throws Exception {
+ int numMessages = 20;
+ TestAvroSystemFactory.messages.clear();
+ Map<String, String> staticConfigs = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(configs, numMessages);
+ String sql1 = "Insert into testavro.outputTopic(id, long_value) "
+ + "select MyTestPoly(id) as long_value, MyTestPoly(name) as id from testavro.SIMPLE1";
+ List<String> sqlStmts = Collections.singletonList(sql1);
+ staticConfigs.put(SamzaSqlApplicationConfig.CFG_SQL_STMTS_JSON, JsonUtil.toJson(sqlStmts));
+ runApplication(new MapConfig(staticConfigs));
+
+ LOG.info("output Messages " + TestAvroSystemFactory.messages);
+
+ List<Integer> outMessages = TestAvroSystemFactory.messages.stream()
+ .map(x -> Integer.valueOf(((GenericRecord) x.getMessage()).get("long_value").toString()))
+ .sorted()
+ .collect(Collectors.toList());
+ Assert.assertEquals(outMessages.size(), numMessages);
+ MyTestUdf udf = new MyTestUdf();
+
+ Assert.assertTrue(
+ IntStream.range(0, numMessages).map(udf::execute).boxed().collect(Collectors.toList()).equals(outMessages));
+ }
+
+ @Test
public void testRegexMatchUdfInWhereClause() throws Exception {
int numMessages = 20;
TestAvroSystemFactory.messages.clear();