You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sh...@apache.org on 2022/07/01 02:29:36 UTC

[flink] branch master updated: [FLINK-26361][hive] Create LogicalFilter with CorrelationId to fix failed to rewrite subquery in hive dialect (#18920)

This is an automated email from the ASF dual-hosted git repository.

shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 202eacb2ee9 [FLINK-26361][hive] Create LogicalFilter with CorrelationId to fix failed to rewrite subquery in hive dialect (#18920)
202eacb2ee9 is described below

commit 202eacb2ee96607be8c7c8c569db62a296539e3a
Author: yuxia Luo <lu...@alumni.sjtu.edu.cn>
AuthorDate: Fri Jul 1 10:29:26 2022 +0800

    [FLINK-26361][hive] Create LogicalFilter with CorrelationId to fix failed to rewrite subquery in hive dialect (#18920)
---
 .../delegation/hive/HiveParserCalcitePlanner.java  | 15 ++++-
 .../planner/delegation/hive/HiveParserUtils.java   | 37 ++++++++++++
 .../hive/copy/HiveParserBaseSemanticAnalyzer.java  | 70 ++++++++++++++++++++++
 .../src/test/resources/query-test/sub_query.q      | 17 ++++++
 4 files changed, 136 insertions(+), 3 deletions(-)

diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java
index 60f5cda3283..69f4a8d7a23 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java
@@ -75,7 +75,6 @@ import org.apache.calcite.rel.core.SetOp;
 import org.apache.calcite.rel.core.Sort;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalCorrelate;
-import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalIntersect;
 import org.apache.calcite.rel.logical.LogicalJoin;
 import org.apache.calcite.rel.logical.LogicalMinus;
@@ -910,7 +909,12 @@ public class HiveParserCalcitePlanner {
         RexNode factoredFilterExpr =
                 RexUtil.pullFactors(cluster.getRexBuilder(), convertedFilterExpr)
                         .accept(funcConverter);
-        RelNode filterRel = LogicalFilter.create(srcRel, factoredFilterExpr);
+        RelNode filterRel =
+                HiveParserUtils.genFilterRelNode(
+                        srcRel,
+                        factoredFilterExpr,
+                        HiveParserBaseSemanticAnalyzer.getVariablesSetForFilter(
+                                factoredFilterExpr));
         relToRowResolver.put(filterRel, relToRowResolver.get(srcRel));
         relToHiveColNameCalcitePosMap.put(filterRel, hiveColNameToCalcitePos);
 
@@ -1070,7 +1074,12 @@ public class HiveParserCalcitePlanner {
                             .convert(subQueryExpr)
                             .accept(funcConverter);
 
-            RelNode filterRel = LogicalFilter.create(srcRel, convertedFilterLHS);
+            RelNode filterRel =
+                    HiveParserUtils.genFilterRelNode(
+                            srcRel,
+                            convertedFilterLHS,
+                            HiveParserBaseSemanticAnalyzer.getVariablesSetForFilter(
+                                    convertedFilterLHS));
 
             relToHiveColNameCalcitePosMap.put(filterRel, relToHiveColNameCalcitePosMap.get(srcRel));
             relToRowResolver.put(filterRel, relToRowResolver.get(srcRel));
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java
index ae682206abe..4504572afbd 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java
@@ -63,6 +63,7 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.CorrelationId;
 import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalValues;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
@@ -162,6 +163,13 @@ public class HiveParserUtils {
                     "org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList");
     private static final boolean useShadedImmutableList = shadedImmutableListClz != null;
 
+    private static final Class immutableSetClz =
+            HiveReflectionUtils.tryGetClass("com.google.common.collect.ImmutableSet"); // unshaded guava ImmutableSet
+    private static final Class shadedImmutableSetClz =
+            HiveReflectionUtils.tryGetClass(
+                    "org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableSet"); // calcite-shaded guava ImmutableSet
+    private static final boolean useShadedImmutableSet = shadedImmutableSetClz != null; // true when the shaded class lookup returned non-null
+
     private HiveParserUtils() {}
 
     public static void removeASTChild(HiveParserASTNode node) {
@@ -317,6 +325,17 @@ public class HiveParserUtils {
         }
     }
 
+    // converts a collection to a guava ImmutableSet via reflective copyOf, preferring the shaded class
+    private static Object toImmutableSet(Collection collection) {
+        try {
+            Class clz = useShadedImmutableSet ? shadedImmutableSetClz : immutableSetClz;
+            return HiveReflectionUtils.invokeMethod(
+                    clz, null, "copyOf", new Class[] {Collection.class}, new Object[] {collection});
+        } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+            throw new FlinkHiveException("Failed to create immutable set", e);
+        }
+    }
+
     // creates LogicalValues node
     public static RelNode genValuesRelNode(
             RelOptCluster cluster, RelDataType rowType, List<List<RexLiteral>> rows) {
@@ -339,6 +358,24 @@ public class HiveParserUtils {
         }
     }
 
+    // creates a LogicalFilter node through a reflective lookup of LogicalFilter.create, passing the given variables
+    public static RelNode genFilterRelNode(
+            RelNode relNode, RexNode rexNode, Collection<CorrelationId> variables) {
+        Class[] argTypes =
+                new Class[] {
+                    RelNode.class,
+                    RexNode.class,
+                    useShadedImmutableSet ? shadedImmutableSetClz : immutableSetClz
+                };
+        Method method = HiveReflectionUtils.tryGetMethod(LogicalFilter.class, "create", argTypes);
+        Preconditions.checkState(method != null, "Cannot get the method to create a LogicalFilter");
+        try {
+            return (LogicalFilter) method.invoke(null, relNode, rexNode, toImmutableSet(variables));
+        } catch (IllegalAccessException | InvocationTargetException e) {
+            throw new FlinkHiveException("Failed to create LogicalFilter", e);
+        }
+    }
+
     /** Proxy to {@link RexSubQuery#in(RelNode, com.google.common.collect.ImmutableList)}. */
     public static RexSubQuery rexSubQueryIn(RelNode relNode, Collection<RexNode> rexNodes) {
         Class[] argTypes = new Class[] {RelNode.class, null};
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java
index b4e6d4d4672..4a9fd748312 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java
@@ -40,12 +40,17 @@ import org.antlr.runtime.tree.TreeVisitorAction;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.logical.LogicalFilter;
+import org.apache.calcite.rel.logical.LogicalJoin;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexFieldCollation;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.rex.RexWindowBound;
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlKind;
@@ -74,6 +79,7 @@ import org.apache.hadoop.hive.ql.metadata.InvalidTableException;
 import org.apache.hadoop.hive.ql.metadata.Partition;
 import org.apache.hadoop.hive.ql.metadata.Table;
 import org.apache.hadoop.hive.ql.metadata.VirtualColumn;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
 import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.WindowingSpec;
@@ -1862,6 +1868,70 @@ public class HiveParserBaseSemanticAnalyzer {
                 rows);
     }
 
+    /**
+     * Traverses the given node to find all correlated variables. The main logic is from {@link
+     * HiveFilter#getVariablesSet()}.
+     */
+    public static Set<CorrelationId> getVariablesSetForFilter(RexNode rexNode) {
+        Set<CorrelationId> correlationVariables = new HashSet<>();
+        if (rexNode instanceof RexSubQuery) {
+            RexSubQuery rexSubQuery = (RexSubQuery) rexNode;
+            // we expect correlated variables in Filter only for now.
+            // also check the case where the operator has 0 inputs, e.g. TableScan
+            if (rexSubQuery.rel.getInputs().isEmpty()) {
+                return correlationVariables;
+            }
+            RelNode input = rexSubQuery.rel.getInput(0);
+            while (input != null
+                    && !(input instanceof LogicalFilter)
+                    && input.getInputs().size() >= 1) {
+                // we don't expect corr vars within UNION for now
+                if (input.getInputs().size() > 1) {
+                    if (input instanceof LogicalJoin) {
+                        correlationVariables.addAll(
+                                findCorrelatedVar(((LogicalJoin) input).getCondition()));
+                    }
+                    // TODO: throw Unsupported exception when the input isn't a LogicalJoin and
+                    // contains correlated variables in FLINK-28317
+                    return correlationVariables;
+                }
+                input = input.getInput(0);
+            }
+            if (input instanceof LogicalFilter) {
+                correlationVariables.addAll(
+                        findCorrelatedVar(((LogicalFilter) input).getCondition()));
+            }
+            return correlationVariables;
+        }
+        // AND, NOT, etc.
+        if (rexNode instanceof RexCall) {
+            int numOperands = ((RexCall) rexNode).getOperands().size();
+            for (int i = 0; i < numOperands; i++) {
+                RexNode op = ((RexCall) rexNode).getOperands().get(i);
+                correlationVariables.addAll(getVariablesSetForFilter(op));
+            }
+        }
+        return correlationVariables;
+    }
+
+    private static Set<CorrelationId> findCorrelatedVar(RexNode node) { // collects ids of correlation variables reached via RexFieldAccess operands
+        Set<CorrelationId> allVars = new HashSet<>();
+        if (node instanceof RexCall) { // NOTE(review): a node that is not a RexCall yields an empty set - confirm this is intended
+            RexCall nd = (RexCall) node;
+            for (RexNode rn : nd.getOperands()) {
+                if (rn instanceof RexFieldAccess) {
+                    final RexNode ref = ((RexFieldAccess) rn).getReferenceExpr();
+                    if (ref instanceof RexCorrelVariable) {
+                        allVars.add(((RexCorrelVariable) ref).id);
+                    }
+                } else {
+                    allVars.addAll(findCorrelatedVar(rn)); // recurse into nested calls
+                }
+            }
+        }
+        return allVars;
+    }
+
     private static void validatePartColumnType(
             Table tbl,
             Map<String, String> partSpec,
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/query-test/sub_query.q b/flink-connectors/flink-connector-hive/src/test/resources/query-test/sub_query.q
new file mode 100644
index 00000000000..26a55cb0446
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/resources/query-test/sub_query.q
@@ -0,0 +1,17 @@
+-- SORT_QUERY_RESULTS
+
+select * from src where src.key in (select c.key from (select * from src b where exists (select a.key from src a where b.value = a.value)) c);
+
+[+I[1, val1], +I[2, val2], +I[3, val3]]
+
+select * from src x where x.key in (select y.key from src y where exists (select z.key from src z where y.key = z.key));
+
+[+I[1, val1], +I[2, val2], +I[3, val3]]
+
+select * from src x join src y on x.key = y.key where exists (select * from src z where z.value = x.value and z.value = y.value);
+
+[+I[1, val1, 1, val1], +I[2, val2, 2, val2], +I[3, val3, 3, val3]]
+
+select * from (select x.key from src x);
+
+[+I[1], +I[2], +I[3]]