You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@phoenix.apache.org by ch...@apache.org on 2020/11/24 09:36:58 UTC
[phoenix] branch 4.x updated: PHOENIX-6224 Support Correlated IN
Subquery
This is an automated email from the ASF dual-hosted git repository.
chenglei pushed a commit to branch 4.x
in repository https://gitbox.apache.org/repos/asf/phoenix.git
The following commit(s) were added to refs/heads/4.x by this push:
new a0471a6 PHOENIX-6224 Support Correlated IN Subquery
a0471a6 is described below
commit a0471a61626bab6fc39d13d6b74038a7e4af4371
Author: chenglei <ch...@apache.org>
AuthorDate: Tue Nov 24 17:36:07 2020 +0800
PHOENIX-6224 Support Correlated IN Subquery
---
.../apache/phoenix/end2end/join/SubqueryIT.java | 128 ++++++++-
.../end2end/join/SubqueryUsingSortMergeJoinIT.java | 106 +++++++
.../apache/phoenix/compile/SubqueryRewriter.java | 309 +++++++++++++++++----
.../org/apache/phoenix/parse/ParseNodeFactory.java | 11 +
.../org/apache/phoenix/parse/SelectStatement.java | 2 +-
.../apache/phoenix/compile/QueryCompilerTest.java | 129 +++++++++
6 files changed, 633 insertions(+), 52 deletions(-)
diff --git a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java
index 0d2ade5..69e8008 100644
--- a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java
+++ b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java
@@ -663,7 +663,133 @@ public class SubqueryIT extends BaseJoinIT {
conn.close();
}
}
-
+
+ @Test
+ public void testCorrelatedInSubqueryBug6224() throws Exception {
+ Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
+ final Connection conn = DriverManager.getConnection(getUrl(), props);
+ String tableName1 = getTableName(conn, JOIN_ITEM_TABLE_FULL_NAME);
+ String tableName3 = getTableName(conn, JOIN_CUSTOMER_TABLE_FULL_NAME);
+ String tableName4 = getTableName(conn, JOIN_ORDER_TABLE_FULL_NAME);
+ try {
+ String query = "SELECT \"order_id\", name FROM " + tableName4 +
+ " o JOIN " + tableName1 +
+ " i ON o.\"item_id\" = i.\"item_id\" WHERE quantity in (SELECT max(quantity) FROM " +
+ tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\")";
+ PreparedStatement statement = conn.prepareStatement(query);
+ ResultSet rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertEquals(rs.getString(2), "T1");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertEquals(rs.getString(2), "T2");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertEquals(rs.getString(2), "T6");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertEquals(rs.getString(2), "T3");
+ assertFalse(rs.next());
+
+ query = "SELECT \"order_id\", name FROM " + tableName4 +
+ " o JOIN " + tableName1 +
+ " i ON o.\"item_id\" = i.\"item_id\" WHERE quantity in (SELECT max(quantity) FROM " +
+ tableName1 + " i2 JOIN " + tableName4 +
+ " q ON i2.\"item_id\" = q.\"item_id\" WHERE o.\"item_id\" = i2.\"item_id\")";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertEquals(rs.getString(2), "T1");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertEquals(rs.getString(2), "T2");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertEquals(rs.getString(2), "T6");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertEquals(rs.getString(2), "T3");
+ assertFalse(rs.next());
+
+ query = "SELECT name from " + tableName3 +
+ " WHERE \"customer_id\" IN (SELECT \"customer_id\" FROM " +
+ tableName1 + " i JOIN " + tableName4 +
+ " o ON o.\"item_id\" = i.\"item_id\" WHERE i.name = 'T2' OR quantity in (SELECT max(quantity) FROM " +
+ tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\" and q.\"item_id\" = '0000000006'))";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "C2");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "C4");
+ assertFalse(rs.next());
+
+ query = "SELECT \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT quantity FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000004')";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+
+ query = "SELECT \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT quantity FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003')";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+
+ query = "SELECT \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT max(quantity) FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000004' GROUP BY \"order_id\")";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+
+ query = "SELECT \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT max(quantity) FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003' GROUP BY \"order_id\")";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+ } finally {
+ conn.close();
+ }
+ }
+
@Test
public void testAnyAllComparisonSubquery() throws Exception {
Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
diff --git a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
index 1b54422..9d98a04 100644
--- a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
+++ b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java
@@ -493,6 +493,112 @@ public class SubqueryUsingSortMergeJoinIT extends BaseJoinIT {
}
@Test
+ public void testCorrelatedInSubqueryBug6224() throws Exception {
+ Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
+ Connection conn = DriverManager.getConnection(getUrl(), props);
+ String tableName1 = getTableName(conn, JOIN_ITEM_TABLE_FULL_NAME);
+ String tableName3 = getTableName(conn, JOIN_CUSTOMER_TABLE_FULL_NAME);
+ String tableName4 = getTableName(conn, JOIN_ORDER_TABLE_FULL_NAME);
+ try {
+ String query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", name FROM " +
+ tableName4 + " o JOIN " + tableName1 +
+ " i ON o.\"item_id\" = i.\"item_id\" WHERE quantity in (SELECT max(quantity) FROM " + tableName4 +
+ " q WHERE o.\"item_id\" = q.\"item_id\") order by \"order_id\"";
+ PreparedStatement statement = conn.prepareStatement(query);
+ ResultSet rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertEquals(rs.getString(2), "T1");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertEquals(rs.getString(2), "T2");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertEquals(rs.getString(2), "T6");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertEquals(rs.getString(2), "T3");
+ assertFalse(rs.next());
+
+ query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ name from " + tableName3 +
+ " WHERE \"customer_id\" IN "+
+ "(SELECT \"customer_id\" FROM " + tableName1 + " i JOIN " + tableName4 +
+ " o ON o.\"item_id\" = i.\"item_id\" WHERE i.name = 'T2' OR quantity in (SELECT max(quantity) FROM " + tableName4 +
+ " q WHERE o.\"item_id\" = q.\"item_id\" and q.\"item_id\" = '0000000006')) order by name";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "C2");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "C4");
+ assertFalse(rs.next());
+
+ query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT quantity FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000004') order by \"order_id\"";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+
+ query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT quantity FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003') order by \"order_id\"";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+
+ query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT max(quantity) FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000004' GROUP BY \"order_id\") order by \"order_id\"";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000003");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+
+ assertFalse(rs.next());
+
+ query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 +
+ " o WHERE quantity in (SELECT max(quantity) FROM " + tableName4 +
+ " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003' GROUP BY \"order_id\") order by \"order_id\"";
+ statement = conn.prepareStatement(query);
+ rs = statement.executeQuery();
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000001");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000002");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000004");
+ assertTrue (rs.next());
+ assertEquals(rs.getString(1), "000000000000005");
+ assertFalse(rs.next());
+ } finally {
+ conn.close();
+ }
+ }
+
+ @Test
public void testAnyAllComparisonSubquery() throws Exception {
Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES);
Connection conn = DriverManager.getConnection(getUrl(), props);
diff --git a/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java b/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java
index dd9e62b..e2f6f58 100644
--- a/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java
+++ b/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java
@@ -32,7 +32,6 @@ import org.apache.phoenix.parse.AndParseNode;
import org.apache.phoenix.parse.AndRewriterBooleanParseNodeVisitor;
import org.apache.phoenix.parse.ArrayAllComparisonNode;
import org.apache.phoenix.parse.ArrayAnyComparisonNode;
-import org.apache.phoenix.parse.BooleanParseNodeVisitor;
import org.apache.phoenix.parse.ColumnParseNode;
import org.apache.phoenix.parse.ComparisonParseNode;
import org.apache.phoenix.parse.CompoundParseNode;
@@ -134,25 +133,167 @@ public class SubqueryRewriter extends ParseNodeRewriter {
});
}
+ /**
+ * <pre>
+ * Rewrite the In Subquery to a semi/anti/left join for both NonCorrelated and Correlated subqueries.
+ *
+ * 1. If the {@link InParseNode} is the only node in the where clause or is an ANDed part of the where clause,
+ * then we rewrite the In Subquery to a semi/anti join:
+ * For NonCorrelated subquery, an example is:
+ * SELECT item_id, name FROM item i WHERE i.item_id IN
+ * (SELECT item_id FROM order o where o.price > 8)
+ *
+ * The above sql would be rewritten as:
+ * SELECT ITEM_ID,NAME FROM item I Semi JOIN
+ * (SELECT DISTINCT 1 $35,ITEM_ID $36 FROM order O WHERE O.PRICE > 8) $34
+ * ON (I.ITEM_ID = $34.$36)
+ *
+ * For Correlated subquery, an example is:
+ * SELECT item_id, name FROM item i WHERE i.item_id IN
+ * (SELECT item_id FROM order o where o.price = i.price)
+ *
+ * The above sql would be rewritten as:
+ * SELECT ITEM_ID,NAME FROM item I Semi JOIN
+ * (SELECT DISTINCT 1 $3,ITEM_ID $4,O.PRICE $2 FROM order O ) $1
+ * ON ((I.ITEM_ID = $1.$4 AND $1.$2 = I.PRICE))
+ *
+ * 2. If the {@link InParseNode} is an ORed part of the where clause, then we rewrite the In Subquery to a
+ * Left Join.
+ *
+ * For NonCorrelated subquery, an example is:
+ * SELECT item_id, name FROM item i WHERE i.item_id IN
+ * (SELECT max(item_id) FROM order o where o.price > 8 group by o.customer_id,o.item_id) or i.discount1 > 10
+ *
+ * The above sql would be rewritten as:
+ * SELECT ITEM_ID,NAME FROM item I Left JOIN
+ * (SELECT DISTINCT 1 $56, MAX(ITEM_ID) $57 FROM order O WHERE O.PRICE > 8 GROUP BY O.CUSTOMER_ID,O.ITEM_ID) $55
+ * ON (I.ITEM_ID = $55.$57) WHERE ($55.$56 IS NOT NULL OR I.DISCOUNT1 > 10)
+ *
+ * For Correlated subquery, an example is:
+ * SELECT item_id, name FROM item i WHERE i.item_id IN
+ * (SELECT max(item_id) FROM order o where o.price = i.price group by o.customer_id) or i.discount1 > 10;
+ *
+ * The above sql would be rewritten as:
+ * SELECT ITEM_ID,NAME FROM item I Left JOIN
+ * (SELECT DISTINCT 1 $28, MAX(ITEM_ID) $29,O.PRICE $27 FROM order O GROUP BY O.PRICE,O.CUSTOMER_ID) $26
+ * ON ((I.ITEM_ID = $26.$29 AND $26.$27 = I.PRICE)) WHERE ($26.$28 IS NOT NULL OR I.DISCOUNT1 > 10)
+ * </pre>
+ */
@Override
- public ParseNode visitLeave(InParseNode node, List<ParseNode> l) throws SQLException {
- boolean isTopNode = topNode == node;
+ public ParseNode visitLeave(InParseNode inParseNode, List<ParseNode> childParseNodes) throws SQLException {
+ boolean isTopNode = topNode == inParseNode;
if (isTopNode) {
topNode = null;
}
-
- SubqueryParseNode subqueryNode = (SubqueryParseNode) l.get(1);
- SelectStatement subquery = fixSubqueryStatement(subqueryNode.getSelectNode());
- String rhsTableAlias = ParseNodeFactory.createTempAlias();
- List<AliasedNode> selectNodes = fixAliasedNodes(subquery.getSelect(), true);
- subquery = NODE_FACTORY.select(subquery, !node.isSubqueryDistinct(), selectNodes);
- ParseNode onNode = getJoinConditionNode(l.get(0), selectNodes, rhsTableAlias);
- TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, subquery);
- JoinType joinType = isTopNode ? (node.isNegate() ? JoinType.Anti : JoinType.Semi) : JoinType.Left;
- ParseNode ret = isTopNode ? null : NODE_FACTORY.isNull(NODE_FACTORY.column(NODE_FACTORY.table(null, rhsTableAlias), selectNodes.get(0).getAlias(), null), !node.isNegate());
- tableNode = NODE_FACTORY.join(joinType, tableNode, rhsTable, onNode, false);
-
- return ret;
+
+ SubqueryParseNode subqueryParseNode = (SubqueryParseNode) childParseNodes.get(1);
+ SelectStatement subquerySelectStatementToUse = fixSubqueryStatement(subqueryParseNode.getSelectNode());
+ String subqueryTableTempAlias = ParseNodeFactory.createTempAlias();
+
+ JoinConditionExtractor joinConditionExtractor = new JoinConditionExtractor(
+ subquerySelectStatementToUse,
+ resolver,
+ connection,
+ subqueryTableTempAlias);
+
+ List<AliasedNode> newSubquerySelectAliasedNodes = null;
+ ParseNode extractedJoinConditionParseNode = null;
+ int extractedSelectAliasNodeCount = 0;
+ List<AliasedNode> oldSubqueryAliasedNodes = subquerySelectStatementToUse.getSelect();
+ ParseNode whereParseNodeAfterExtract =
+ subquerySelectStatementToUse.getWhere() == null ?
+ null :
+ subquerySelectStatementToUse.getWhere().accept(joinConditionExtractor);
+ if (whereParseNodeAfterExtract == subquerySelectStatementToUse.getWhere()) {
+ /**
+ * It is a NonCorrelated subquery.
+ */
+ newSubquerySelectAliasedNodes = Lists.<AliasedNode> newArrayListWithExpectedSize(
+ oldSubqueryAliasedNodes.size() + 1);
+
+ newSubquerySelectAliasedNodes.add(
+ NODE_FACTORY.aliasedNode(
+ ParseNodeFactory.createTempAlias(),
+ LiteralParseNode.ONE));
+ this.addNewAliasedNodes(newSubquerySelectAliasedNodes, oldSubqueryAliasedNodes);
+ subquerySelectStatementToUse = NODE_FACTORY.select(
+ subquerySelectStatementToUse,
+ !inParseNode.isSubqueryDistinct(),
+ newSubquerySelectAliasedNodes,
+ whereParseNodeAfterExtract);
+ } else {
+ /**
+ * It is a Correlated subquery.
+ */
+ List<AliasedNode> extractedAdditionalSelectAliasNodes =
+ joinConditionExtractor.getAdditionalSelectNodes();
+ extractedSelectAliasNodeCount = extractedAdditionalSelectAliasNodes.size();
+ newSubquerySelectAliasedNodes = Lists.<AliasedNode> newArrayListWithExpectedSize(
+ oldSubqueryAliasedNodes.size() + 1 +
+ extractedAdditionalSelectAliasNodes.size());
+
+ newSubquerySelectAliasedNodes.add(NODE_FACTORY.aliasedNode(
+ ParseNodeFactory.createTempAlias(),
+ LiteralParseNode.ONE));
+ this.addNewAliasedNodes(newSubquerySelectAliasedNodes, oldSubqueryAliasedNodes);
+ newSubquerySelectAliasedNodes.addAll(extractedAdditionalSelectAliasNodes);
+ extractedJoinConditionParseNode = joinConditionExtractor.getJoinCondition();
+
+ boolean isAggregate = subquerySelectStatementToUse.isAggregate();
+ if(!isAggregate) {
+ subquerySelectStatementToUse =
+ NODE_FACTORY.select(
+ subquerySelectStatementToUse,
+ !inParseNode.isSubqueryDistinct(),
+ newSubquerySelectAliasedNodes,
+ whereParseNodeAfterExtract);
+ } else {
+ /**
+ * If the subquery is an aggregate, we must add the nodes extracted from the correlated join condition
+ * to both the groupBy clause and the select list of the subquery.
+ */
+ List<ParseNode> newGroupByParseNodes = this.createNewGroupByParseNodes(
+ extractedAdditionalSelectAliasNodes,
+ subquerySelectStatementToUse);
+
+ subquerySelectStatementToUse = NODE_FACTORY.select(
+ subquerySelectStatementToUse,
+ !inParseNode.isSubqueryDistinct(),
+ newSubquerySelectAliasedNodes,
+ whereParseNodeAfterExtract,
+ newGroupByParseNodes,
+ true);
+ }
+ }
+
+ ParseNode joinOnConditionParseNode = getJoinConditionNodeForInSubquery(
+ childParseNodes.get(0),
+ newSubquerySelectAliasedNodes,
+ subqueryTableTempAlias,
+ extractedJoinConditionParseNode,
+ extractedSelectAliasNodeCount);
+ TableNode rhsTableNode = NODE_FACTORY.derivedTable(
+ subqueryTableTempAlias,
+ subquerySelectStatementToUse);
+ JoinType joinType = isTopNode ?
+ (inParseNode.isNegate() ? JoinType.Anti : JoinType.Semi) :
+ JoinType.Left;
+ ParseNode resultWhereParseNode = isTopNode ?
+ null :
+ NODE_FACTORY.isNull(
+ NODE_FACTORY.column(
+ NODE_FACTORY.table(null, subqueryTableTempAlias),
+ newSubquerySelectAliasedNodes.get(0).getAlias(),
+ null),
+ !inParseNode.isNegate());
+ tableNode = NODE_FACTORY.join(
+ joinType,
+ tableNode,
+ rhsTableNode,
+ joinOnConditionParseNode,
+ false);
+
+ return resultWhereParseNode;
}
@Override
@@ -236,11 +377,9 @@ public class SubqueryRewriter extends ParseNodeRewriter {
if (!isAggregate) {
subquery = NODE_FACTORY.select(subquery, subquery.isDistinct(), selectNodes, where);
} else {
- List<ParseNode> groupbyNodes = Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + subquery.getGroupBy().size());
- for (AliasedNode aliasedNode : additionalSelectNodes) {
- groupbyNodes.add(aliasedNode.getNode());
- }
- groupbyNodes.addAll(subquery.getGroupBy());
+ List<ParseNode> groupbyNodes = this.createNewGroupByParseNodes(
+ additionalSelectNodes,
+ subquery);
subquery = NODE_FACTORY.select(subquery, subquery.isDistinct(), selectNodes, where, groupbyNodes, true);
}
@@ -299,7 +438,7 @@ public class SubqueryRewriter extends ParseNodeRewriter {
String derivedTableAlias = null;
if (!subquery.getGroupBy().isEmpty()) {
derivedTableAlias = ParseNodeFactory.createTempAlias();
- aliasedNodes = fixAliasedNodes(aliasedNodes, false);
+ aliasedNodes = createNewAliasedNodes(aliasedNodes);
}
if (aliasedNodes.size() == 1) {
@@ -373,42 +512,112 @@ public class SubqueryRewriter extends ParseNodeRewriter {
select.getBindCount(), false, false, Collections.<SelectStatement> emptyList(),
select.getUdfParseNodes());
}
-
- private List<AliasedNode> fixAliasedNodes(List<AliasedNode> nodes, boolean addSelectOne) {
- List<AliasedNode> normNodes = Lists.<AliasedNode> newArrayListWithExpectedSize(nodes.size() + (addSelectOne ? 1 : 0));
- if (addSelectOne) {
- normNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), LiteralParseNode.ONE));
- }
- for (int i = 0; i < nodes.size(); i++) {
- AliasedNode aliasedNode = nodes.get(i);
- normNodes.add(NODE_FACTORY.aliasedNode(
- ParseNodeFactory.createTempAlias(), aliasedNode.getNode()));
+
+ /**
+ * Create a new {@link AliasedNode} for every {@link ParseNode} in subquerySelectAliasedNodes, giving each
+ * one a fresh alias generated by {@link ParseNodeFactory#createTempAlias}.
+ *
+ * @param subquerySelectAliasedNodes the select nodes of the subquery whose parse nodes are re-aliased
+ * @return a new list holding the re-aliased nodes, in the same order as
+ * subquerySelectAliasedNodes
+ */
+ private List<AliasedNode> createNewAliasedNodes(List<AliasedNode> subquerySelectAliasedNodes) {
+ List<AliasedNode> newAliasedNodes = Lists.<AliasedNode> newArrayListWithExpectedSize(
+ subquerySelectAliasedNodes.size());
+
+ this.addNewAliasedNodes(newAliasedNodes, subquerySelectAliasedNodes);
+ return newAliasedNodes;
+ }
+
+ /**
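+ * Wrap every {@link ParseNode} in oldSelectAliasedNodes in a new {@link AliasedNode} whose alias is generated by
+ * {@link ParseNodeFactory#createTempAlias}, and append it to newSelectAliasedNodes.
+ *
+ * @param newSelectAliasedNodes the list that receives the newly aliased nodes
+ * @param oldSelectAliasedNodes the nodes whose parse nodes are re-aliased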
+ */
+ private void addNewAliasedNodes(List<AliasedNode> newSelectAliasedNodes, List<AliasedNode> oldSelectAliasedNodes) {
+ for (int index = 0; index < oldSelectAliasedNodes.size(); index++) {
+ AliasedNode oldSelectAliasedNode = oldSelectAliasedNodes.get(index);
+ newSelectAliasedNodes.add(NODE_FACTORY.aliasedNode(
+ ParseNodeFactory.createTempAlias(),
+ oldSelectAliasedNode.getNode()));
}
- return normNodes;
}
-
- private ParseNode getJoinConditionNode(ParseNode lhs, List<AliasedNode> rhs, String rhsTableAlias) throws SQLException {
- List<ParseNode> lhsNodes;
- if (lhs instanceof RowValueConstructorParseNode) {
- lhsNodes = ((RowValueConstructorParseNode) lhs).getChildren();
+
+ /**
+ * Get the join conditions in order to rewrite InSubquery to Join.
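+ * @param lhsParseNode the left side of the IN, possibly a {@link RowValueConstructorParseNode}
+ * @param rhsSubquerySelectAliasedNodes the first element is {@link LiteralParseNode#ONE}.
+ * @param rhsSubqueryTableAlias the table alias used to qualify the subquery columns in the join condition
+ * @param extractedJoinConditionParseNode For NonCorrelated subquery, it is null.
+ * @param extractedSelectAliasNodeCount For NonCorrelated subquery, it is 0.
+ * @throws SQLException if the left side and the subquery select different numbers of fields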
+ */
+ private ParseNode getJoinConditionNodeForInSubquery(
+ ParseNode lhsParseNode,
+ List<AliasedNode> rhsSubquerySelectAliasedNodes,
+ String rhsSubqueryTableAlias,
+ ParseNode extractedJoinConditionParseNode,
+ int extractedSelectAliasNodeCount) throws SQLException {
+ List<ParseNode> lhsParseNodes;
+ if (lhsParseNode instanceof RowValueConstructorParseNode) {
+ lhsParseNodes = ((RowValueConstructorParseNode) lhsParseNode).getChildren();
} else {
- lhsNodes = Collections.singletonList(lhs);
+ lhsParseNodes = Collections.singletonList(lhsParseNode);
}
- if (lhsNodes.size() != (rhs.size() - 1))
- throw new SQLExceptionInfo.Builder(SQLExceptionCode.SUBQUERY_RETURNS_DIFFERENT_NUMBER_OF_FIELDS).build().buildException();
-
- int count = lhsNodes.size();
- TableName rhsTableName = NODE_FACTORY.table(null, rhsTableAlias);
- List<ParseNode> equalNodes = Lists.newArrayListWithExpectedSize(count);
- for (int i = 0; i < count; i++) {
- ParseNode rhsNode = NODE_FACTORY.column(rhsTableName, rhs.get(i + 1).getAlias(), null);
- equalNodes.add(NODE_FACTORY.equal(lhsNodes.get(i), rhsNode));
+
+ if (lhsParseNodes.size() !=
+ (rhsSubquerySelectAliasedNodes.size() - 1 - extractedSelectAliasNodeCount)) {
+ throw new SQLExceptionInfo.Builder(
+ SQLExceptionCode.SUBQUERY_RETURNS_DIFFERENT_NUMBER_OF_FIELDS)
+ .build().buildException();
}
-
- return count == 1 ? equalNodes.get(0) : NODE_FACTORY.and(equalNodes);
+
+ int count = lhsParseNodes.size();
+ TableName rhsSubqueryTableName = NODE_FACTORY.table(null, rhsSubqueryTableAlias);
+ List<ParseNode> joinEqualParseNodes = Lists.newArrayListWithExpectedSize(
+ count + (extractedJoinConditionParseNode == null ? 0: 1));
+ for (int index = 0; index < count; index++) {
+ /**
+ * The +1 is to skip the first {@link LiteralParseNode#ONE}
+ */
+ ParseNode rhsNode = NODE_FACTORY.column(
+ rhsSubqueryTableName,
+ rhsSubquerySelectAliasedNodes.get(index + 1).getAlias(),
+ null);
+ joinEqualParseNodes.add(NODE_FACTORY.equal(lhsParseNodes.get(index), rhsNode));
+ }
+
+ if(extractedJoinConditionParseNode != null) {
+ joinEqualParseNodes.add(extractedJoinConditionParseNode);
+ }
+
+ return joinEqualParseNodes.size() == 1 ? joinEqualParseNodes.get(0) : NODE_FACTORY.and(joinEqualParseNodes);
}
+ /**
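+ * Combine the {@link ParseNode}s of extractedAdditionalSelectAliasNodes with the GroupBy clause of
+ * subquerySelectStatementToUse to get the new GroupBy ParseNodes.
+ * @param extractedAdditionalSelectAliasNodes the extracted nodes, placed first in the result
+ * @param subquerySelectStatementToUse the subquery whose existing GroupBy nodes are appended
+ * @return a new list: the extracted nodes followed by the original GroupBy nodes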
+ */
+ private List<ParseNode> createNewGroupByParseNodes(
+ List<AliasedNode> extractedAdditionalSelectAliasNodes,
+ SelectStatement subquerySelectStatementToUse) {
+ List<ParseNode> newGroupByParseNodes = Lists.newArrayListWithExpectedSize(
+ extractedAdditionalSelectAliasNodes.size() +
+ subquerySelectStatementToUse.getGroupBy().size());
+
+ for (AliasedNode aliasedNode : extractedAdditionalSelectAliasNodes) {
+ newGroupByParseNodes.add(aliasedNode.getNode());
+ }
+ newGroupByParseNodes.addAll(subquerySelectStatementToUse.getGroupBy());
+ return newGroupByParseNodes;
+ }
+
private static class JoinConditionExtractor extends AndRewriterBooleanParseNodeVisitor {
private final TableName tableName;
private ColumnResolveVisitor columnResolveVisitor;
diff --git a/phoenix-core/src/main/java/org/apache/phoenix/parse/ParseNodeFactory.java b/phoenix-core/src/main/java/org/apache/phoenix/parse/ParseNodeFactory.java
index 349c0b0..6577aac 100644
--- a/phoenix-core/src/main/java/org/apache/phoenix/parse/ParseNodeFactory.java
+++ b/phoenix-core/src/main/java/org/apache/phoenix/parse/ParseNodeFactory.java
@@ -27,6 +27,7 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ArrayListMultimap;
import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp;
import org.apache.hadoop.hbase.util.Pair;
@@ -198,6 +199,16 @@ public class ParseNodeFactory {
}
private static AtomicInteger tempAliasCounter = new AtomicInteger(0);
+
+ @VisibleForTesting
+ public static int getTempAliasCounterValue() {
+ return tempAliasCounter.get();
+ }
+
+ @VisibleForTesting
+ public static void setTempAliasCounterValue(int newValue) {
+ tempAliasCounter.set(newValue);
+ }
public static String createTempAlias() {
return "$" + tempAliasCounter.incrementAndGet();
diff --git a/phoenix-core/src/main/java/org/apache/phoenix/parse/SelectStatement.java b/phoenix-core/src/main/java/org/apache/phoenix/parse/SelectStatement.java
index d4f079b..8f937a9 100644
--- a/phoenix-core/src/main/java/org/apache/phoenix/parse/SelectStatement.java
+++ b/phoenix-core/src/main/java/org/apache/phoenix/parse/SelectStatement.java
@@ -307,7 +307,7 @@ public class SelectStatement implements FilterableStatement {
}
/**
- * Gets the group-by, containing at least 1 element, or null, if none.
+ * Gets the group-by, containing at least 1 element, or an empty list if none.
*/
public List<ParseNode> getGroupBy() {
return groupBy;
diff --git a/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java b/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java
index b49aaf8..e31849c 100644
--- a/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java
+++ b/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java
@@ -36,6 +36,7 @@ import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
+import java.sql.SQLFeatureNotSupportedException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
@@ -80,6 +81,7 @@ import org.apache.phoenix.filter.EncodedQualifiersColumnProjectionFilter;
import org.apache.phoenix.jdbc.PhoenixConnection;
import org.apache.phoenix.jdbc.PhoenixPreparedStatement;
import org.apache.phoenix.jdbc.PhoenixStatement;
+import org.apache.phoenix.parse.ParseNodeFactory;
import org.apache.phoenix.query.BaseConnectionlessQueryTest;
import org.apache.phoenix.query.QueryConstants;
import org.apache.phoenix.query.QueryServices;
@@ -104,6 +106,7 @@ import org.apache.phoenix.util.QueryUtil;
import org.apache.phoenix.util.ScanUtil;
import org.apache.phoenix.util.SchemaUtil;
import org.apache.phoenix.util.TestUtil;
+import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
@@ -126,6 +129,11 @@ import com.google.common.collect.Lists;
justification="Test code.")
public class QueryCompilerTest extends BaseConnectionlessQueryTest {
+ @Before
+ public void setUp() {
+ ParseNodeFactory.setTempAliasCounterValue(0);
+ }
+
@Test
public void testParameterUnbound() throws Exception {
try {
@@ -6468,4 +6476,125 @@ public class QueryCompilerTest extends BaseConnectionlessQueryTest {
conn.close();
}
}
+
+ @Test
+ public void testInSubqueryBug6224() throws Exception { // checks the statement each IN-subquery form is rewritten to (Semi JOIN, or Left JOIN when ORed)
+ Connection conn = null;
+ try {
+ conn = DriverManager.getConnection(getUrl());
+ String itemTableName = "item_table";
+ String sql ="create table " + itemTableName +
+ " (item_id varchar not null primary key, " +
+ " name varchar, " +
+ " price integer, " +
+ " discount1 integer, " +
+ " discount2 integer, " +
+ " supplier_id varchar, " +
+ " description varchar)";
+ conn.createStatement().execute(sql);
+
+ String orderTableName = "order_table";
+ sql = "create table " + orderTableName +
+ " (order_id varchar not null primary key, " +
+ " customer_id varchar, " +
+ " item_id varchar, " +
+ " price integer, " +
+ " quantity integer, " +
+ " date timestamp)";
+ conn.createStatement().execute(sql);
+ // test a simple correlated subquery: the correlated equality is expected in the join's ON clause
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT item_id FROM " + orderTableName + " o where o.price = i.price) ORDER BY name";
+ QueryPlan queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN (SELECT DISTINCT 1 $3,ITEM_ID $4,O.PRICE $2 FROM ORDER_TABLE O ) $1 "+
+ "ON ((I.ITEM_ID = $1.$4 AND $1.$2 = I.PRICE)) ORDER BY NAME");
+
+ // test a correlated subquery with an aggregate function but no GROUP BY: expected to group by the correlated column
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price = i.price) ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN "+
+ "(SELECT DISTINCT 1 $11, MAX(ITEM_ID) $12,O.PRICE $10 FROM ORDER_TABLE O GROUP BY O.PRICE) $9 "+
+ "ON ((I.ITEM_ID = $9.$12 AND $9.$10 = I.PRICE)) ORDER BY NAME");
+
+ // test a correlated subquery with an aggregate function and a GROUP BY: the correlated column is expected first in the GROUP BY
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price = i.price group by o.customer_id) ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN "+
+ "(SELECT DISTINCT 1 $19, MAX(ITEM_ID) $20,O.PRICE $18 FROM ORDER_TABLE O GROUP BY O.PRICE,O.CUSTOMER_ID) $17 "+
+ "ON ((I.ITEM_ID = $17.$20 AND $17.$18 = I.PRICE)) ORDER BY NAME");
+
+ // for a correlated subquery, the extracted join condition must be an equality expression.
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price = i.price or o.quantity > 1 group by o.customer_id) ORDER BY name";
+ try {
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ fail();
+ } catch(SQLFeatureNotSupportedException expected) {
+ // expected: the correlated condition is ORed with "o.quantity > 1", so this query must be rejected
+ }
+
+ // test a correlated subquery with an aggregate function and a GROUP BY that is an ORed part of the outer where clause.
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price = i.price group by o.customer_id) or i.discount1 > 10 ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Left JOIN "+
+ "(SELECT DISTINCT 1 $28, MAX(ITEM_ID) $29,O.PRICE $27 FROM ORDER_TABLE O GROUP BY O.PRICE,O.CUSTOMER_ID) $26 "+
+ "ON ((I.ITEM_ID = $26.$29 AND $26.$27 = I.PRICE)) WHERE ($26.$28 IS NOT NULL OR I.DISCOUNT1 > 10) ORDER BY NAME");
+
+ // test non-correlated subqueries: the subquery's own WHERE / GROUP BY are expected to stay inside the subquery
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT item_id FROM " + orderTableName + " o where o.price > 8) ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN "+
+ "(SELECT DISTINCT 1 $35,ITEM_ID $36 FROM ORDER_TABLE O WHERE O.PRICE > 8) $34 ON (I.ITEM_ID = $34.$36) ORDER BY NAME");
+
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price > 8) ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN "+
+ "(SELECT DISTINCT 1 $42, MAX(ITEM_ID) $43 FROM ORDER_TABLE O WHERE O.PRICE > 8) $41 ON (I.ITEM_ID = $41.$43) ORDER BY NAME");
+
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price > 8 group by o.customer_id,o.item_id) ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN "+
+ "(SELECT DISTINCT 1 $49, MAX(ITEM_ID) $50 FROM ORDER_TABLE O WHERE O.PRICE > 8 GROUP BY O.CUSTOMER_ID,O.ITEM_ID) $48 "+
+ "ON (I.ITEM_ID = $48.$50) ORDER BY NAME");
+
+ sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
+ "(SELECT max(item_id) FROM " + orderTableName + " o where o.price > 8 group by o.customer_id,o.item_id) or i.discount1 > 10 ORDER BY name";
+ queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+ assertTrue(queryPlan instanceof HashJoinPlan);
+ TestUtil.assertSelectStatement(
+ queryPlan.getStatement(),
+ "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Left JOIN "+
+ "(SELECT DISTINCT 1 $56, MAX(ITEM_ID) $57 FROM ORDER_TABLE O WHERE O.PRICE > 8 GROUP BY O.CUSTOMER_ID,O.ITEM_ID) $55 "+
+ "ON (I.ITEM_ID = $55.$57) WHERE ($55.$56 IS NOT NULL OR I.DISCOUNT1 > 10) ORDER BY NAME");
+ } finally {
+ if (conn != null) { conn.close(); } // conn is still null if getConnection() threw; avoid masking that failure with an NPE
+ }
+ }
}