You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ib...@apache.org on 2021/01/30 01:45:06 UTC

[beam] branch master updated: [BEAM-10925] Add rule to replace Calc with BeamCalcRel for ZetaSQL UDFs.

This is an automated email from the ASF dual-hosted git repository.

ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5127734  [BEAM-10925] Add rule to replace Calc with BeamCalcRel for ZetaSQL UDFs.
     new a497ff2  Merge pull request #13841 from ibzib/calc-rule
5127734 is described below

commit 5127734d44e83d82776f49219bb662656e388b5b
Author: Kyle Weaver <kc...@google.com>
AuthorDate: Fri Jan 29 13:12:23 2021 -0800

    [BEAM-10925] Add rule to replace Calc with BeamCalcRel for ZetaSQL UDFs.
---
 ...taSqlCalcRule.java => BeamJavaUdfCalcRule.java} | 15 ++++----
 .../sql/zetasql/BeamZetaSqlCalcRule.java           |  2 +-
 .../sdk/extensions/sql/zetasql/SqlAnalyzer.java    |  3 ++
 .../sql/zetasql/ZetaSQLQueryPlanner.java           | 41 +++++++++++++++++++++-
 4 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java
similarity index 82%
copy from sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java
copy to sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java
index 2e7ea0f..23d0f76 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.extensions.sql.zetasql;
 
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamCalcRel;
 import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Convention;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule;
@@ -26,18 +27,18 @@ import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.convert.Con
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Calc;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalCalc;
 
-/** A {@code ConverterRule} to replace {@link Calc} with {@link BeamZetaSqlCalcRel}. */
-public class BeamZetaSqlCalcRule extends ConverterRule {
-  public static final BeamZetaSqlCalcRule INSTANCE = new BeamZetaSqlCalcRule();
+/** {@link ConverterRule} to replace {@link Calc} with {@link BeamCalcRel}. */
+public class BeamJavaUdfCalcRule extends ConverterRule {
+  public static final BeamJavaUdfCalcRule INSTANCE = new BeamJavaUdfCalcRule();
 
-  private BeamZetaSqlCalcRule() {
+  private BeamJavaUdfCalcRule() {
     super(
-        LogicalCalc.class, Convention.NONE, BeamLogicalConvention.INSTANCE, "BeamZetaSqlCalcRule");
+        LogicalCalc.class, Convention.NONE, BeamLogicalConvention.INSTANCE, "BeamJavaUdfCalcRule");
   }
 
   @Override
   public boolean matches(RelOptRuleCall x) {
-    return true;
+    return ZetaSQLQueryPlanner.hasUdfInProjects(x);
   }
 
   @Override
@@ -45,7 +46,7 @@ public class BeamZetaSqlCalcRule extends ConverterRule {
     final Calc calc = (Calc) rel;
     final RelNode input = calc.getInput();
 
-    return new BeamZetaSqlCalcRel(
+    return new BeamCalcRel(
         calc.getCluster(),
         calc.getTraitSet().replace(BeamLogicalConvention.INSTANCE),
         RelOptRule.convert(input, input.getTraitSet().replace(BeamLogicalConvention.INSTANCE)),
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java
index 2e7ea0f..2f6c60d 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java
@@ -37,7 +37,7 @@ public class BeamZetaSqlCalcRule extends ConverterRule {
 
   @Override
   public boolean matches(RelOptRuleCall x) {
-    return true;
+    return !ZetaSQLQueryPlanner.hasUdfInProjects(x);
   }
 
   @Override
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java
index f4db1f1..4889183 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java
@@ -78,6 +78,9 @@ public class SqlAnalyzer {
    */
   public static final String ZETASQL_FUNCTION_GROUP_NAME = "ZetaSQL";
 
+  public static final String USER_DEFINED_JAVA_SCALAR_FUNCTIONS =
+      "user_defined_java_scalar_functions";
+
   private static final ImmutableSet<ResolvedNodeKind> SUPPORTED_STATEMENT_KINDS =
       ImmutableSet.of(
           RESOLVED_QUERY_STMT, RESOLVED_CREATE_FUNCTION_STMT, RESOLVED_CREATE_TABLE_FUNCTION_STMT);
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java
index b943ab3..9ca5e83 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java
@@ -36,22 +36,29 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
 import org.apache.beam.sdk.extensions.sql.impl.rule.BeamCalcRule;
 import org.apache.beam.sdk.extensions.sql.impl.rule.BeamUncollectRule;
 import org.apache.beam.sdk.extensions.sql.impl.rule.BeamUnnestRule;
+import org.apache.beam.sdk.extensions.sql.zetasql.translation.ZetaSqlScalarFunctionImpl;
 import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUncollectRule;
 import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUnnestRule;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.config.CalciteConnectionConfig;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.CalciteSchema;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptUtil;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitDef;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.prepare.CalciteCatalogReader;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelRoot;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalCalc;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.ChainedRelMetadataProvider;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.JaninoRelMetadataProvider;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.FilterCalcMergeRule;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.JoinCommuteRule;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.ProjectCalcMergeRule;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperatorTable;
@@ -59,11 +66,14 @@ import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.fun.SqlStdO
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParserImplFactory;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.util.ChainedSqlOperatorTable;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.validate.SqlUserDefinedFunction;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSet;
 import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSets;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /** ZetaSQLQueryPlanner. */
 @SuppressWarnings({
@@ -71,6 +81,8 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Immutabl
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
 public class ZetaSQLQueryPlanner implements QueryPlanner {
+  private static final Logger LOG = LoggerFactory.getLogger(ZetaSQLQueryPlanner.class);
+
   private final ZetaSQLPlannerImpl plannerImpl;
 
   public ZetaSQLQueryPlanner(FrameworkConfig config) {
@@ -104,6 +116,30 @@ public class ZetaSQLQueryPlanner implements QueryPlanner {
     return modifyRuleSetsForZetaSql(BeamRuleSets.getRuleSets());
   }
 
+  /**
+   * Returns true if any {@link LogicalCalc} matched by the rule call evaluates a user-defined Java
+   * scalar function.
+   *
+   * <p>Every expression of each matched calc's program is inspected. A function from another
+   * function group does not end the search, so a Java UDF appearing after it is still detected.
+   *
+   * @param x the rule call whose matched rel nodes are inspected
+   * @return true if at least one expression calls a {@link ZetaSqlScalarFunctionImpl} whose
+   *     function group is {@link SqlAnalyzer#USER_DEFINED_JAVA_SCALAR_FUNCTIONS}; false otherwise
+   */
+  static boolean hasUdfInProjects(RelOptRuleCall x) {
+    for (RelNode relNode : x.getRelList()) {
+      if (!(relNode instanceof LogicalCalc)) {
+        continue;
+      }
+      for (RexNode rexNode : ((LogicalCalc) relNode).getProgram().getExprList()) {
+        if (!(rexNode instanceof RexCall)) {
+          continue;
+        }
+        RexCall call = (RexCall) rexNode;
+        if (!(call.getOperator() instanceof SqlUserDefinedFunction)) {
+          continue;
+        }
+        SqlUserDefinedFunction udf = (SqlUserDefinedFunction) call.getOperator();
+        if (!(udf.function instanceof ZetaSqlScalarFunctionImpl)) {
+          continue;
+        }
+        ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function;
+        // Only a Java UDF ends the search; returning the comparison result directly would stop at
+        // the first ZetaSQL function of any group and miss Java UDFs in later expressions.
+        if (SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS.equals(scalarFunction.functionGroup)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
   private static Collection<RuleSet> modifyRuleSetsForZetaSql(Collection<RuleSet> ruleSets) {
     ImmutableList.Builder<RuleSet> ret = ImmutableList.builder();
     for (RuleSet ruleSet : ruleSets) {
@@ -123,6 +159,7 @@ public class ZetaSQLQueryPlanner implements QueryPlanner {
           continue;
         } else if (rule instanceof BeamCalcRule) {
           bd.add(BeamZetaSqlCalcRule.INSTANCE);
+          bd.add(BeamJavaUdfCalcRule.INSTANCE);
         } else if (rule instanceof BeamUnnestRule) {
           bd.add(BeamZetaSqlUnnestRule.INSTANCE);
         } else if (rule instanceof BeamUncollectRule) {
@@ -196,7 +233,9 @@ public class ZetaSQLQueryPlanner implements QueryPlanner {
     RelMetadataQuery.THREAD_PROVIDERS.set(
         JaninoRelMetadataProvider.of(root.rel.getCluster().getMetadataProvider()));
     root.rel.getCluster().invalidateMetadataQuery();
-    return (BeamRelNode) plannerImpl.transform(0, desiredTraits, root.rel);
+    BeamRelNode beamRelNode = (BeamRelNode) plannerImpl.transform(0, desiredTraits, root.rel);
+    LOG.info("BEAMPlan>\n" + RelOptUtil.toString(beamRelNode));
+    return beamRelNode;
   }
 
   private static FrameworkConfig defaultConfig(