Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2021/12/28 17:18:12 UTC

[GitHub] [beam] ibzib commented on a change in pull request #16200: [BEAM-11808][BEAM-9879] Support aggregate functions with two arguments

ibzib commented on a change in pull request #16200:
URL: https://github.com/apache/beam/pull/16200#discussion_r771014170



##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       // aggregation?
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
-      if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
-        ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
-
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+      ImmutableList<ResolvedExpr> argumentList =

Review comment:
       Why do we need to copy to an ImmutableList?

##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
##########
@@ -105,4 +113,16 @@
           .put("nullif", new SqlNullIfOperatorRewriter())
           .put("$in", new SqlInOperatorRewriter())
           .build();
+
+  public static @Nullable SqlOperator create(

Review comment:
       Nit: this can/should be made package-private
   ```suggestion
     static @Nullable SqlOperator create(
   ```

##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -248,6 +249,8 @@ private AggregateCall convertAggCall(
           || expr.nodeKind() == RESOLVED_COLUMN_REF
           || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) {
         argList.add(columnRefOff);
+      } else if (expr.nodeKind() == RESOLVED_LITERAL) {

Review comment:
       We should have separate cases here:
   
   if i == 0, must be one of (RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD)
   else must be RESOLVED_LITERAL
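
A minimal sketch of the position-dependent check described above, in case it helps. `NodeKind` and `AggregateArgumentKinds` are hypothetical stand-ins so the snippet compiles on its own; the real code would compare `expr.nodeKind()` against the ZetaSQL constants and choose its own exception type.

```java
import java.util.List;

/** Sketch only: NodeKind is a stand-in enum, not the ZetaSQL node kind type. */
final class AggregateArgumentKinds {
  enum NodeKind {
    RESOLVED_CAST,
    RESOLVED_COLUMN_REF,
    RESOLVED_GET_STRUCT_FIELD,
    RESOLVED_LITERAL,
    RESOLVED_FUNCTION_CALL // listed only as an example of an unsupported kind
  }

  private AggregateArgumentKinds() {}

  /** Throws if any argument kind is not allowed at its position. */
  static void validate(String functionName, List<NodeKind> argumentKinds) {
    for (int i = 0; i < argumentKinds.size(); i++) {
      NodeKind kind = argumentKinds.get(i);
      if (i == 0) {
        // First argument: the value being aggregated.
        boolean supported =
            kind == NodeKind.RESOLVED_CAST
                || kind == NodeKind.RESOLVED_COLUMN_REF
                || kind == NodeKind.RESOLVED_GET_STRUCT_FIELD;
        if (!supported) {
          throw new UnsupportedOperationException(
              functionName + ": unsupported first argument kind " + kind);
        }
      } else if (kind != NodeKind.RESOLVED_LITERAL) {
        // Every later argument (for example a delimiter) must be a literal.
        throw new UnsupportedOperationException(
            functionName + ": argument " + i + " must be a literal, got " + kind);
      }
    }
  }
}
```

Splitting the cases this way means a literal in position 0, or a column reference in position 1, fails loudly instead of being silently accepted.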

##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java
##########
@@ -17,85 +17,93 @@
  */
 package org.apache.beam.sdk.extensions.sql.zetasql.translation;
 
+import com.google.zetasql.resolvedast.ResolvedNodes;
 import java.util.Map;
+import java.util.function.Function;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlOperator;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** SqlOperatorMappingTable. */
 class SqlOperatorMappingTable {
 
   // todo: Some of operators defined here are later overridden in ZetaSQLPlannerImpl.
   // We should remove them from this table and add generic way to provide custom
   // implementation. (Ex.: timestamp_add)
-  static final Map<String, SqlOperator> ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR =
-      ImmutableMap.<String, SqlOperator>builder()
-          // grouped window function
-          .put("TUMBLE", SqlStdOperatorTable.TUMBLE_OLD)
-          .put("HOP", SqlStdOperatorTable.HOP_OLD)
-          .put("SESSION", SqlStdOperatorTable.SESSION_OLD)
+  static final Map<String, Function<ResolvedNodes.ResolvedFunctionCallBase, SqlOperator>>
+      ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR =
+          ImmutableMap
+              .<String, Function<ResolvedNodes.ResolvedFunctionCallBase, SqlOperator>>builder()
+              // grouped window function
+              .put("TUMBLE", resolvedFunction -> SqlStdOperatorTable.TUMBLE_OLD)
+              .put("HOP", resolvedFunction -> SqlStdOperatorTable.HOP_OLD)
+              .put("SESSION", resolvedFunction -> SqlStdOperatorTable.SESSION_OLD)
 
-          // ZetaSQL functions
-          .put("$and", SqlStdOperatorTable.AND)
-          .put("$or", SqlStdOperatorTable.OR)
-          .put("$not", SqlStdOperatorTable.NOT)
-          .put("$equal", SqlStdOperatorTable.EQUALS)
-          .put("$not_equal", SqlStdOperatorTable.NOT_EQUALS)
-          .put("$greater", SqlStdOperatorTable.GREATER_THAN)
-          .put("$greater_or_equal", SqlStdOperatorTable.GREATER_THAN_OR_EQUAL)
-          .put("$less", SqlStdOperatorTable.LESS_THAN)
-          .put("$less_or_equal", SqlStdOperatorTable.LESS_THAN_OR_EQUAL)
-          .put("$like", SqlOperators.LIKE)
-          .put("$is_null", SqlStdOperatorTable.IS_NULL)
-          .put("$is_true", SqlStdOperatorTable.IS_TRUE)
-          .put("$is_false", SqlStdOperatorTable.IS_FALSE)
-          .put("$add", SqlStdOperatorTable.PLUS)
-          .put("$subtract", SqlStdOperatorTable.MINUS)
-          .put("$multiply", SqlStdOperatorTable.MULTIPLY)
-          .put("$unary_minus", SqlStdOperatorTable.UNARY_MINUS)
-          .put("$divide", SqlStdOperatorTable.DIVIDE)
-          .put("concat", SqlOperators.CONCAT)
-          .put("substr", SqlOperators.SUBSTR)
-          .put("substring", SqlOperators.SUBSTR)
-          .put("trim", SqlOperators.TRIM)
-          .put("replace", SqlOperators.REPLACE)
-          .put("char_length", SqlOperators.CHAR_LENGTH)
-          .put("starts_with", SqlOperators.START_WITHS)
-          .put("ends_with", SqlOperators.ENDS_WITH)
-          .put("ltrim", SqlOperators.LTRIM)
-          .put("rtrim", SqlOperators.RTRIM)
-          .put("reverse", SqlOperators.REVERSE)
-          .put("$count_star", SqlStdOperatorTable.COUNT)
-          .put("max", SqlStdOperatorTable.MAX)
-          .put("min", SqlStdOperatorTable.MIN)
-          .put("avg", SqlStdOperatorTable.AVG)
-          .put("sum", SqlStdOperatorTable.SUM)
-          .put("any_value", SqlStdOperatorTable.ANY_VALUE)
-          .put("count", SqlStdOperatorTable.COUNT)
-          .put("bit_and", SqlStdOperatorTable.BIT_AND)
-          .put("string_agg", SqlOperators.STRING_AGG_STRING_FN) // NULL values not supported
-          .put("array_agg", SqlOperators.ARRAY_AGG_FN)
-          .put("bit_or", SqlStdOperatorTable.BIT_OR)
-          .put("bit_xor", SqlOperators.BIT_XOR)
-          .put("ceil", SqlStdOperatorTable.CEIL)
-          .put("floor", SqlStdOperatorTable.FLOOR)
-          .put("mod", SqlStdOperatorTable.MOD)
-          .put("timestamp", SqlOperators.TIMESTAMP_OP)
-          .put("$case_no_value", SqlStdOperatorTable.CASE)
+              // ZetaSQL functions
+              .put("$and", resolvedFunction -> SqlStdOperatorTable.AND)
+              .put("$or", resolvedFunction -> SqlStdOperatorTable.OR)
+              .put("$not", resolvedFunction -> SqlStdOperatorTable.NOT)
+              .put("$equal", resolvedFunction -> SqlStdOperatorTable.EQUALS)
+              .put("$not_equal", resolvedFunction -> SqlStdOperatorTable.NOT_EQUALS)
+              .put("$greater", resolvedFunction -> SqlStdOperatorTable.GREATER_THAN)
+              .put(
+                  "$greater_or_equal",
+                  resolvedFunction -> SqlStdOperatorTable.GREATER_THAN_OR_EQUAL)
+              .put("$less", resolvedFunction -> SqlStdOperatorTable.LESS_THAN)
+              .put("$less_or_equal", resolvedFunction -> SqlStdOperatorTable.LESS_THAN_OR_EQUAL)
+              .put("$like", resolvedFunction -> SqlOperators.LIKE)
+              .put("$is_null", resolvedFunction -> SqlStdOperatorTable.IS_NULL)
+              .put("$is_true", resolvedFunction -> SqlStdOperatorTable.IS_TRUE)
+              .put("$is_false", resolvedFunction -> SqlStdOperatorTable.IS_FALSE)
+              .put("$add", resolvedFunction -> SqlStdOperatorTable.PLUS)
+              .put("$subtract", resolvedFunction -> SqlStdOperatorTable.MINUS)
+              .put("$multiply", resolvedFunction -> SqlStdOperatorTable.MULTIPLY)
+              .put("$unary_minus", resolvedFunction -> SqlStdOperatorTable.UNARY_MINUS)
+              .put("$divide", resolvedFunction -> SqlStdOperatorTable.DIVIDE)
+              .put("concat", resolvedFunction -> SqlOperators.CONCAT)
+              .put("substr", resolvedFunction -> SqlOperators.SUBSTR)
+              .put("substring", resolvedFunction -> SqlOperators.SUBSTR)
+              .put("trim", resolvedFunction -> SqlOperators.TRIM)
+              .put("replace", resolvedFunction -> SqlOperators.REPLACE)
+              .put("char_length", resolvedFunction -> SqlOperators.CHAR_LENGTH)
+              .put("starts_with", resolvedFunction -> SqlOperators.START_WITHS)
+              .put("ends_with", resolvedFunction -> SqlOperators.ENDS_WITH)
+              .put("ltrim", resolvedFunction -> SqlOperators.LTRIM)
+              .put("rtrim", resolvedFunction -> SqlOperators.RTRIM)
+              .put("reverse", resolvedFunction -> SqlOperators.REVERSE)
+              .put("$count_star", resolvedFunction -> SqlStdOperatorTable.COUNT)
+              .put("max", resolvedFunction -> SqlStdOperatorTable.MAX)
+              .put("min", resolvedFunction -> SqlStdOperatorTable.MIN)
+              .put("avg", resolvedFunction -> SqlStdOperatorTable.AVG)
+              .put("sum", resolvedFunction -> SqlStdOperatorTable.SUM)
+              .put("any_value", resolvedFunction -> SqlStdOperatorTable.ANY_VALUE)
+              .put("count", resolvedFunction -> SqlStdOperatorTable.COUNT)
+              .put("bit_and", resolvedFunction -> SqlStdOperatorTable.BIT_AND)
+              .put("string_agg", SqlOperators::createStringAggOperator) // NULL values not supported

Review comment:
       If I understand correctly, you parameterized the string_agg operator on resolvedFunction, since unlike other functions that support multiple arguments, there's no way to make string_agg generic (because String and byte[] don't share a base class like Number). I think this is fine, and probably an inevitable change.
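
To make sure we are talking about the same shape, here is a stripped-down sketch of a table whose values are factories rather than fixed operators. Everything in it is a stand-in: `Call` replaces `ResolvedFunctionCallBase`, operators are modeled as plain strings, and `OperatorTableSketch` is a hypothetical name.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Sketch only: operators are plain strings, Call stands in for the resolved call. */
final class OperatorTableSketch {
  /** Stand-in for a resolved function call; carries only argument type names. */
  static final class Call {
    final List<String> argumentTypes;

    Call(List<String> argumentTypes) {
      this.argumentTypes = argumentTypes;
    }
  }

  static final Map<String, Function<Call, String>> TABLE;

  static {
    Map<String, Function<Call, String>> table = new HashMap<>();
    // Most entries ignore the call and always return the same operator.
    table.put("max", call -> "MAX");
    // string_agg has to inspect the call to choose an implementation.
    table.put("string_agg", OperatorTableSketch::stringAggFor);
    TABLE = Collections.unmodifiableMap(table);
  }

  private OperatorTableSketch() {}

  static String stringAggFor(Call call) {
    String inputType = call.argumentTypes.get(0);
    switch (inputType) {
      case "BYTES":
        return "STRING_AGG_BYTES";
      case "STRING":
        return "STRING_AGG_STRING";
      default:
        throw new UnsupportedOperationException("Unsupported input type: " + inputType);
    }
  }

  /** Returns null when the function name has no mapping. */
  static String create(String functionName, Call call) {
    Function<Call, String> factory = TABLE.get(functionName);
    return factory == null ? null : factory.apply(call);
  }
}
```

The cost is the `resolvedFunction ->` noise on every constant entry; the benefit is that a single lookup path serves both fixed and call-dependent operators.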

##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       // aggregation?
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
-      if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
-        ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
-
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+      ImmutableList<ResolvedExpr> argumentList =
+          ImmutableList.copyOf(aggregateFunctionCall.getArgumentList());
+      if (argumentList != null && argumentList.size() >= 1) {
+        ResolvedExpr resolvedExpr = argumentList.get(0);
+        for (int i = 0; i < argumentList.size(); i++) {
+          if (i == 0) {
+            // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
+            // TODO: user might use multiple CAST so we need to handle this rare case.
+            projects.add(
+                getExpressionConverter()
+                    .convertRexNodeFromResolvedExpr(
+                        resolvedExpr,
+                        node.getInputScan().getColumnList(),
+                        input.getRowType().getFieldList(),
+                        ImmutableMap.of()));
+          } else {
+            projects.add(
+                getExpressionConverter().convertRexNodeFromResolvedExpr(argumentList.get(i)));
+          }
+          fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));

Review comment:
       This doesn't look like it belongs inside the for loop.

##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
##########
@@ -180,6 +176,43 @@
           null,
           new CastFunctionImpl());
 
+  public static SqlOperator createStringAggOperator(
+      ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) {
+    List<ResolvedNodes.ResolvedExpr> args = aggregateFunctionCall.getArgumentList();
+    String inputType = args.get(0).getType().typeName();
+    Value delimiter = null;
+    if (args.size() == 2) {
+      delimiter = ((ResolvedNodes.ResolvedLiteral) args.get(1)).getValue();
+    }
+    switch (inputType) {
+      case "BYTES":
+        if (delimiter != null) {
+          return SqlOperators.createUdafOperator(
+              "string_agg",
+              x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY),
+              new UdafImpl<>(new StringAgg.StringAggByte(delimiter.getBytesValue().toByteArray())));
+        }
+        return SqlOperators.createUdafOperator(
+            "string_agg",
+            x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY),
+            new UdafImpl<>(new StringAgg.StringAggByte()));
+      case "STRING":
+        if (delimiter != null) {
+          return SqlOperators.createUdafOperator(
+              "string_agg",
+              x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+              new UdafImpl<>(new StringAgg.StringAggString(delimiter.getStringValue())));
+        }
+        return SqlOperators.createUdafOperator(
+            "string_agg",
+            x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+            new UdafImpl<>(new StringAgg.StringAggString()));

Review comment:
       Nit: I prefer to keep all the logic here. That way we can make the default delimiter explicit rather than implicit, and drop the extra no-argument constructor StringAggString(). The same applies to the bytes case above.
   
   ```java
   return SqlOperators.createUdafOperator(
       "string_agg",
       x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR),
       new UdafImpl<>(
           new StringAgg.StringAggString(delimiter == null ? "," : delimiter.getStringValue())));
   ```
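
For illustration, a self-contained sketch of the same idea with no Beam types. `StringAggSketch` and its methods are hypothetical names, and the `","` default is the one proposed in the snippet above; the point is only that the default lives in one visible place and the aggregation itself always receives a concrete delimiter.

```java
import java.util.List;
import java.util.StringJoiner;

/** Sketch only: shows an explicit default delimiter resolved at the call site. */
final class StringAggSketch {
  static final String DEFAULT_DELIMITER = ",";

  private StringAggSketch() {}

  /** Returns the given delimiter, or the explicit default when none was supplied. */
  static String resolveDelimiter(String delimiterOrNull) {
    return delimiterOrNull == null ? DEFAULT_DELIMITER : delimiterOrNull;
  }

  /** Joins the values using the resolved delimiter. */
  static String aggregate(List<String> values, String delimiterOrNull) {
    StringJoiner joiner = new StringJoiner(resolveDelimiter(delimiterOrNull));
    for (String value : values) {
      joiner.add(value);
    }
    return joiner.toString();
  }
}
```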

##########
File path: sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java
##########
@@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       // aggregation?
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
-      if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
-        ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
-
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+      ImmutableList<ResolvedExpr> argumentList =
+          ImmutableList.copyOf(aggregateFunctionCall.getArgumentList());
+      if (argumentList != null && argumentList.size() >= 1) {

Review comment:
       It seems like there is an assumption here that argumentList will never be null and will never have size 0. If that's true, can we check and throw an error for those cases?
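
Something along these lines is what I have in mind. It is only a sketch: `ArgumentChecks.requireNonEmpty` is a hypothetical helper, and the exception type and message are placeholders. Note that Guava's `ImmutableList.copyOf` throws a `NullPointerException` when given a null collection, so a null check placed after the copy can never be false; the check has to run before the copy.

```java
import java.util.List;

/** Sketch only: a hypothetical helper that fails fast on a null or empty argument list. */
final class ArgumentChecks {
  private ArgumentChecks() {}

  /** Returns the list unchanged, or throws if it is null or empty. */
  static <T> List<T> requireNonEmpty(String functionName, List<T> argumentList) {
    if (argumentList == null || argumentList.isEmpty()) {
      throw new IllegalArgumentException(
          functionName + " must have at least one argument.");
    }
    return argumentList;
  }
}
```

With a guard like this at the top, the code that follows can drop the `if` and read `argumentList.get(0)` unconditionally.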

##########
File path: sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java
##########
@@ -73,4 +79,53 @@ public String extractOutput(String output) {
       return output;
     }
   }
+
+  /** A {@link CombineFn} that aggregates bytes with a byte as delimiter. */

Review comment:
       ```suggestion
     /** A {@link CombineFn} that aggregates bytes with a byte array as delimiter. */
   ```



