You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@shardingsphere.apache.org by du...@apache.org on 2022/11/18 10:14:56 UTC
[shardingsphere] branch master updated: Add case when segment, support sql node convert. (#22259)
This is an automated email from the ASF dual-hosted git repository.
duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 5822c5b6e43 Add case when segment, support sql node convert. (#22259)
5822c5b6e43 is described below
commit 5822c5b6e432144834c80c27dbab0c90552136be
Author: Chuxin Chen <ch...@qq.com>
AuthorDate: Fri Nov 18 18:14:47 2022 +0800
Add case when segment, support sql node convert. (#22259)
---
.../segment/expression/ExpressionConverter.java | 5 ++
.../segment/expression/impl/CaseWhenConverter.java | 72 ++++++++++++++++++++++
.../impl/OpenGaussStatementSQLVisitor.java | 22 +++++++
.../impl/PostgreSQLStatementSQLVisitor.java | 22 +++++++
.../common/segment/dml/expr/CaseWhenSegment.java | 45 ++++++++++++++
.../SQLNodeConverterEngineParameterizedTest.java | 1 +
.../segment/expression/ExpressionAssert.java | 29 ++++++++-
.../impl/expr/ExpectedCaseWhenExpression.java | 46 ++++++++++++++
.../segment/impl/expr/ExpectedExpression.java | 3 +
.../main/resources/case/dml/select-expression.xml | 58 +++++++++++++++++
.../sql/supported/dml/select-expression.xml | 2 +
11 files changed, 304 insertions(+), 1 deletion(-)
diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/ExpressionConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/ExpressionConverter.java
index 40091e17ead..b943744a04c 100644
--- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/ExpressionConverter.java
+++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/ExpressionConverter.java
@@ -22,6 +22,7 @@ import org.apache.shardingsphere.infra.util.exception.external.sql.type.generic.
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExistsSubqueryExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
@@ -36,6 +37,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.DataTypeS
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.SQLSegmentConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.BetweenExpressionConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.BinaryOperationExpressionConverter;
+import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.CaseWhenConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.ColumnConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.ExistsSubqueryExpressionConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl.FunctionConverter;
@@ -99,6 +101,9 @@ public final class ExpressionConverter implements SQLSegmentConverter<Expression
if (segment instanceof DataTypeSegment) {
return new DataTypeConverter().convert((DataTypeSegment) segment);
}
+ if (segment instanceof CaseWhenSegment) {
+ return new CaseWhenConverter().convert((CaseWhenSegment) segment);
+ }
throw new UnsupportedSQLOperationException("unsupported TableSegment type: " + segment.getClass());
}
}
diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/impl/CaseWhenConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/impl/CaseWhenConverter.java
new file mode 100644
index 00000000000..9db5040b00a
--- /dev/null
+++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/expression/impl/CaseWhenConverter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.impl;
+
+import org.apache.calcite.sql.SqlBasicCall;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.fun.SqlCase;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenSegment;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.SQLSegmentConverter;
+import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.Optional;
+
+/**
+ * Case when converter.
+ */
+public final class CaseWhenConverter implements SQLSegmentConverter<CaseWhenSegment, SqlNode> {
+
+ @Override
+ public Optional<SqlNode> convert(final CaseWhenSegment segment) {
+ Collection<SqlNode> whenList = convertWhenList(segment.getCaseArg(), segment.getWhenList());
+ Collection<SqlNode> thenList = new LinkedList<>();
+ segment.getThenList().forEach(each -> new ExpressionConverter().convert(each).ifPresent(thenList::add));
+ Optional<SqlNode> elseExpr = new ExpressionConverter().convert(segment.getElseExpression());
+ return Optional.of(new SqlCase(SqlParserPos.ZERO, null, new SqlNodeList(whenList, SqlParserPos.ZERO), new SqlNodeList(thenList, SqlParserPos.ZERO),
+ elseExpr.orElseGet(() -> SqlLiteral.createCharString("NULL", SqlParserPos.ZERO))));
+ }
+
+ private Collection<SqlNode> convertWhenList(final ExpressionSegment caseArg, final Collection<ExpressionSegment> whenList) {
+ Collection<SqlNode> result = new LinkedList<>();
+ for (ExpressionSegment each : whenList) {
+ if (null != caseArg) {
+ convertWithCaseArg(caseArg, each, result);
+ } else {
+ new ExpressionConverter().convert(each).ifPresent(result::add);
+ }
+ }
+ return result;
+ }
+
+ private void convertWithCaseArg(final ExpressionSegment caseArg, final ExpressionSegment expressionSegment, final Collection<SqlNode> result) {
+ Optional<SqlNode> leftExpr = new ExpressionConverter().convert(caseArg);
+ Optional<SqlNode> rightExpr = new ExpressionConverter().convert(expressionSegment);
+ if (leftExpr.isPresent() && rightExpr.isPresent()) {
+ new ExpressionConverter().convert(expressionSegment).ifPresent(optional -> result.add(
+ new SqlBasicCall(SqlStdOperatorTable.EQUALS, Arrays.asList(leftExpr.get(), rightExpr.get()), SqlParserPos.ZERO)));
+ }
+ }
+}
diff --git a/sql-parser/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/impl/OpenGaussStatementSQLVisitor.java b/sql-parser/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/impl/OpenGaussStatementSQLVisitor.java
index 479c3267e82..2252289e2e1 100644
--- a/sql-parser/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/impl/OpenGaussStatementSQLVisitor.java
+++ b/sql-parser/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/impl/OpenGaussStatementSQLVisitor.java
@@ -36,6 +36,7 @@ import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.Att
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.AttrsContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.BExprContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.CExprContext;
+import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.CaseExprContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.ColIdContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.ColumnNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.ColumnNamesContext;
@@ -102,6 +103,7 @@ import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.Tar
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.UnreservedWordContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.UpdateContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.ValuesClauseContext;
+import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WhenClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WhereClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WhereOrCurrentClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WindowClauseContext;
@@ -124,6 +126,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDupl
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExistsSubqueryExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
@@ -382,6 +385,9 @@ public abstract class OpenGaussStatementSQLVisitor extends OpenGaussStatementBas
if (null != ctx.selectWithParens()) {
return createSubqueryExpressionSegment(ctx);
}
+ if (null != ctx.caseExpr()) {
+ return visit(ctx.caseExpr());
+ }
super.visitCExpr(ctx);
String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
return new CommonExpressionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), text);
@@ -396,6 +402,19 @@ public abstract class OpenGaussStatementSQLVisitor extends OpenGaussStatementBas
return new SubqueryExpressionSegment(subquerySegment);
}
+    @Override
+    public ASTNode visitCaseExpr(final CaseExprContext ctx) {
+        Collection<ExpressionSegment> whenExpressions = new LinkedList<>();
+        Collection<ExpressionSegment> thenExpressions = new LinkedList<>();
+        for (WhenClauseContext each : ctx.whenClauseList().whenClause()) {
+            // The first aExpr of a WHEN clause feeds the WHEN list, the second feeds the THEN list (kept positionally paired).
+            whenExpressions.add((ExpressionSegment) visit(each.aExpr(0)));
+            thenExpressions.add((ExpressionSegment) visit(each.aExpr(1)));
+        }
+        ExpressionSegment caseArg = null != ctx.caseArg() ? (ExpressionSegment) visit(ctx.caseArg().aExpr()) : null;
+        ExpressionSegment defaultExpression = null != ctx.caseDefault() ? (ExpressionSegment) visit(ctx.caseDefault().aExpr()) : null;
+        return new CaseWhenSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), caseArg, whenExpressions, thenExpressions, defaultExpression);
+    }
+
@Override
public ASTNode visitFuncExpr(final FuncExprContext ctx) {
if (null != ctx.functionExprCommonSubexpr()) {
@@ -1066,6 +1085,9 @@ public abstract class OpenGaussStatementSQLVisitor extends OpenGaussStatementBas
if (projection instanceof LiteralExpressionSegment) {
return Optional.of(new ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), getOriginalText(expr), (LiteralExpressionSegment) projection));
}
+ if (projection instanceof CaseWhenSegment) {
+ return Optional.of(new ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), getOriginalText(expr), (CaseWhenSegment) projection));
+ }
return Optional.empty();
}
diff --git a/sql-parser/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/impl/PostgreSQLStatementSQLVisitor.java b/sql-parser/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/impl/PostgreSQLStatementSQLVisitor.java
index 901d9caf9ec..27e2443a7ab 100644
--- a/sql-parser/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/impl/PostgreSQLStatementSQLVisitor.java
+++ b/sql-parser/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/impl/PostgreSQLStatementSQLVisitor.java
@@ -34,6 +34,7 @@ import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.At
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.AttrsContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.BExprContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CExprContext;
+import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.CaseExprContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.ColIdContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.ColumnNameContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.ColumnNamesContext;
@@ -101,6 +102,7 @@ import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.Ta
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.UnreservedWordContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.UpdateContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.ValuesClauseContext;
+import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WhenClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WhereClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WhereOrCurrentClauseContext;
import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WindowClauseContext;
@@ -124,6 +126,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDupl
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExistsSubqueryExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.FunctionSegment;
@@ -382,6 +385,9 @@ public abstract class PostgreSQLStatementSQLVisitor extends PostgreSQLStatementP
if (null != ctx.selectWithParens()) {
return createSubqueryExpressionSegment(ctx);
}
+ if (null != ctx.caseExpr()) {
+ return visit(ctx.caseExpr());
+ }
super.visitCExpr(ctx);
String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
return new CommonExpressionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), text);
@@ -393,6 +399,19 @@ public abstract class PostgreSQLStatementSQLVisitor extends PostgreSQLStatementP
return null == ctx.EXISTS() ? new SubqueryExpressionSegment(subquerySegment) : new ExistsSubqueryExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), subquerySegment);
}
+    @Override
+    public ASTNode visitCaseExpr(final CaseExprContext ctx) {
+        Collection<ExpressionSegment> whenExpressions = new LinkedList<>();
+        Collection<ExpressionSegment> thenExpressions = new LinkedList<>();
+        for (WhenClauseContext each : ctx.whenClauseList().whenClause()) {
+            // The first aExpr of a WHEN clause feeds the WHEN list, the second feeds the THEN list (kept positionally paired).
+            whenExpressions.add((ExpressionSegment) visit(each.aExpr(0)));
+            thenExpressions.add((ExpressionSegment) visit(each.aExpr(1)));
+        }
+        ExpressionSegment caseArg = null != ctx.caseArg() ? (ExpressionSegment) visit(ctx.caseArg().aExpr()) : null;
+        ExpressionSegment defaultExpression = null != ctx.caseDefault() ? (ExpressionSegment) visit(ctx.caseDefault().aExpr()) : null;
+        return new CaseWhenSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), caseArg, whenExpressions, thenExpressions, defaultExpression);
+    }
+
@Override
public ASTNode visitFuncExpr(final FuncExprContext ctx) {
if (null != ctx.functionExprCommonSubexpr()) {
@@ -1033,6 +1052,9 @@ public abstract class PostgreSQLStatementSQLVisitor extends PostgreSQLStatementP
if (projection instanceof LiteralExpressionSegment) {
return Optional.of(new ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), getOriginalText(expr), (LiteralExpressionSegment) projection));
}
+ if (projection instanceof CaseWhenSegment) {
+ return Optional.of(new ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), getOriginalText(expr), (CaseWhenSegment) projection));
+ }
return Optional.empty();
}
diff --git a/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/CaseWhenSegment.java b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/CaseWhenSegment.java
new file mode 100644
index 00000000000..b6a8ddb4456
--- /dev/null
+++ b/sql-parser/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/CaseWhenSegment.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr;
+
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.ToString;
+
+import java.util.Collection;
+
+/**
+ * Case when segment, i.e. {@code CASE [caseArg] WHEN ... THEN ... [ELSE elseExpression] END}.
+ */
+@RequiredArgsConstructor
+@Getter
+@ToString
+public final class CaseWhenSegment implements ExpressionSegment {
+
+ private final int startIndex;
+
+ private final int stopIndex;
+
+ // Operand of a simple CASE; the PostgreSQL/openGauss visitors pass null when the CASE has no operand (searched CASE).
+ private final ExpressionSegment caseArg;
+
+ // WHEN expressions, positionally paired with thenList (one entry each per WHEN clause).
+ private final Collection<ExpressionSegment> whenList;
+
+ // THEN results, positionally paired with whenList.
+ private final Collection<ExpressionSegment> thenList;
+
+ // ELSE expression; the PostgreSQL/openGauss visitors pass null when there is no ELSE clause.
+ private final ExpressionSegment elseExpression;
+}
diff --git a/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java b/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java
index 2968d9f2159..b7b507ab70c 100644
--- a/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java
+++ b/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java
@@ -139,6 +139,7 @@ public final class SQLNodeConverterEngineParameterizedTest {
SUPPORTED_SQL_CASE_IDS.add("select_minus");
SUPPORTED_SQL_CASE_IDS.add("select_minus_order_by");
SUPPORTED_SQL_CASE_IDS.add("select_minus_order_by_limit");
+ SUPPORTED_SQL_CASE_IDS.add("select_projections_with_only_expr_for_postgres");
}
// CHECKSTYLE:ON
diff --git a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java
index fba7055956e..ecf5fc3b89e 100644
--- a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java
+++ b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java
@@ -22,6 +22,7 @@ import lombok.NoArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CollateExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExistsSubqueryExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
@@ -47,6 +48,7 @@ import org.apache.shardingsphere.test.sql.parser.internal.asserts.segment.projec
import org.apache.shardingsphere.test.sql.parser.internal.asserts.statement.dml.impl.SelectStatementAssert;
import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr.ExpectedBetweenExpression;
import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr.ExpectedBinaryOperationExpression;
+import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr.ExpectedCaseWhenExpression;
import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr.ExpectedCollateExpression;
import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr.ExpectedExistsSubquery;
import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr.ExpectedExpression;
@@ -63,9 +65,9 @@ import org.apache.shardingsphere.test.sql.parser.internal.jaxb.sql.SQLCaseType;
import java.util.Iterator;
import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
-import static org.hamcrest.MatcherAssert.assertThat;
/**
* Expression assert.
@@ -325,8 +327,31 @@ public final class ExpressionAssert {
}
}
+    /**
+     * Assert case when expression.
+     *
+     * <p>Checks the when/then list sizes, each when and then expression in order, then the case argument and the else expression.</p>
+     *
+     * @param assertContext assert context
+     * @param actual actual case when expression
+     * @param expected expected case when expression
+     */
+    public static void assertCaseWhenExpression(final SQLCaseAssertContext assertContext, final CaseWhenSegment actual, final ExpectedCaseWhenExpression expected) {
+        assertThat(assertContext.getText("When list size is not same!"), actual.getWhenList().size(), is(expected.getWhenList().size()));
+        assertThat(assertContext.getText("Then list size is not same!"), actual.getThenList().size(), is(expected.getThenList().size()));
+        Iterator<ExpressionSegment> actualWhens = actual.getWhenList().iterator();
+        Iterator<ExpectedExpression> expectedWhens = expected.getWhenList().iterator();
+        while (actualWhens.hasNext()) {
+            assertExpression(assertContext, actualWhens.next(), expectedWhens.next());
+        }
+        Iterator<ExpressionSegment> actualThens = actual.getThenList().iterator();
+        Iterator<ExpectedExpression> expectedThens = expected.getThenList().iterator();
+        while (actualThens.hasNext()) {
+            assertExpression(assertContext, actualThens.next(), expectedThens.next());
+        }
+        assertExpression(assertContext, actual.getCaseArg(), expected.getCaseArg());
+        assertExpression(assertContext, actual.getElseExpression(), expected.getElseExpr());
+    }
+
/**
* Assert expression by actual expression segment class type.
+ *
* @param assertContext assert context
* @param actual actual expression segment
* @param expected expected expression
@@ -372,6 +397,8 @@ public final class ExpressionAssert {
assertFunction(assertContext, (FunctionSegment) actual, expected.getFunction());
} else if (actual instanceof CollateExpression) {
assertCollateExpression(assertContext, (CollateExpression) actual, expected.getCollateExpression());
+ } else if (actual instanceof CaseWhenSegment) {
+ assertCaseWhenExpression(assertContext, (CaseWhenSegment) actual, expected.getCaseWhenExpression());
} else {
throw new UnsupportedOperationException(String.format("Unsupported expression: %s", actual.getClass().getName()));
}
diff --git a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedCaseWhenExpression.java b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedCaseWhenExpression.java
new file mode 100644
index 00000000000..69cc592965f
--- /dev/null
+++ b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedCaseWhenExpression.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.impl.expr;
+
+import lombok.Getter;
+import lombok.Setter;
+import org.apache.shardingsphere.test.sql.parser.internal.jaxb.cases.domain.segment.AbstractExpectedSQLSegment;
+
+import javax.xml.bind.annotation.XmlElement;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Expected case when expression.
+ */
+@Getter
+@Setter
+public final class ExpectedCaseWhenExpression extends AbstractExpectedSQLSegment implements ExpectedExpressionSegment {
+
+ // Expected operand of a simple CASE; stays null when the XML has no <case-arg> element.
+ @XmlElement(name = "case-arg")
+ private ExpectedExpression caseArg;
+
+ // NOTE(review): with @XmlElement on a List, each <when-list> element presumably binds to ONE ExpectedExpression,
+ // so a CASE with several WHEN branches would need repeated <when-list> elements — confirm against the JAXB binding.
+ @XmlElement(name = "when-list")
+ private List<ExpectedExpression> whenList = new LinkedList<>();
+
+ // NOTE(review): same per-element binding as whenList is assumed (one <then-list> per THEN result) — confirm.
+ @XmlElement(name = "then-list")
+ private List<ExpectedExpression> thenList = new LinkedList<>();
+
+ // Expected ELSE expression; stays null when the XML has no <else-expression> element.
+ @XmlElement(name = "else-expression")
+ private ExpectedExpression elseExpr;
+}
diff --git a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedExpression.java b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedExpression.java
index f91de4fa036..1e10de1c3ee 100644
--- a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedExpression.java
+++ b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/jaxb/cases/domain/segment/impl/expr/ExpectedExpression.java
@@ -86,4 +86,7 @@ public final class ExpectedExpression extends AbstractExpectedSQLSegment {
@XmlElement(name = "collate-expression")
private ExpectedCollateExpression collateExpression;
+
+ @XmlElement(name = "case-when-expression")
+ private ExpectedCaseWhenExpression caseWhenExpression;
}
diff --git a/test/parser/src/main/resources/case/dml/select-expression.xml b/test/parser/src/main/resources/case/dml/select-expression.xml
index 14065b95bf7..59d9bf7a538 100644
--- a/test/parser/src/main/resources/case/dml/select-expression.xml
+++ b/test/parser/src/main/resources/case/dml/select-expression.xml
@@ -1891,6 +1891,64 @@
</from>
</select>
+ <select sql-case-id="select_projections_with_expr_for_postgres">
+ <projections start-index="7" stop-index="58">
+ <expression-projection start-index="7" stop-index="11" text="10+20" />
+ <expression-projection start-index="13" stop-index="56" text="CASE order_id WHEN 1 THEN '11' ELSE '00' END">
+ <expr>
+ <case-when-expression>
+ <case-arg>
+ <column name="order_id" start-index="18" stop-index="25" />
+ </case-arg>
+ <when-list>
+ <literal-expression value="1" start-index="32" stop-index="32"/>
+ </when-list>
+ <then-list>
+ <literal-expression value="11" start-index="39" stop-index="42"/>
+ </then-list>
+ <else-expression>
+ <literal-expression value="00" start-index="49" stop-index="52" />
+ </else-expression>
+ </case-when-expression>
+ </expr>
+ </expression-projection>
+ <expression-projection start-index="58" stop-index="58" text="1">
+ <expr>
+ <literal-expression value="1" start-index="58" stop-index="58"/>
+ </expr>
+ </expression-projection>
+ </projections>
+ <from>
+ <simple-table name="t_order" start-index="65" stop-index="71" />
+ </from>
+ </select>
+
+ <select sql-case-id="select_projections_with_only_expr_for_postgres">
+ <projections start-index="7" stop-index="50">
+ <expression-projection start-index="7" stop-index="50" text="CASE order_id WHEN 1 THEN '11' ELSE '00' END">
+ <expr>
+ <case-when-expression>
+ <case-arg>
+ <column name="order_id" start-index="12" stop-index="19" />
+ </case-arg>
+ <when-list>
+ <literal-expression value="1" start-index="26" stop-index="26"/>
+ </when-list>
+ <then-list>
+ <literal-expression value="11" start-index="33" stop-index="36"/>
+ </then-list>
+ <else-expression>
+ <literal-expression value="00" start-index="43" stop-index="46" />
+ </else-expression>
+ </case-when-expression>
+ </expr>
+ </expression-projection>
+ </projections>
+ <from>
+ <simple-table name="t_order" start-index="57" stop-index="63" />
+ </from>
+ </select>
+
<select sql-case-id="select_with_amp">
<projections start-index="7" stop-index="11">
<expression-projection text="1 & 1" start-index="7" stop-index="11"/>
diff --git a/test/parser/src/main/resources/sql/supported/dml/select-expression.xml b/test/parser/src/main/resources/sql/supported/dml/select-expression.xml
index e77883e3d5b..a36e095c052 100644
--- a/test/parser/src/main/resources/sql/supported/dml/select-expression.xml
+++ b/test/parser/src/main/resources/sql/supported/dml/select-expression.xml
@@ -83,6 +83,8 @@
<sql-case id="select_where_with_subquery" value="SELECT last_name, department_id FROM employees WHERE department_id = (SELECT department_id FROM employees WHERE last_name = 'Lorentz') ORDER BY last_name, department_id" db-types="Oracle" />
<sql-case id="select_where_with_expr_with_not_in" value="SELECT * FROM employees WHERE department_id NOT IN (SELECT department_id FROM departments WHERE location_id = 1700) ORDER BY last_name" db-types="Oracle" />
<sql-case id="select_projections_with_expr" value="SELECT 10+20,CASE order_id WHEN 1 THEN '11' ELSE '00' END,1 FROM t_order" db-types="MySQL" />
+ <sql-case id="select_projections_with_expr_for_postgres" value="SELECT 10+20,CASE order_id WHEN 1 THEN '11' ELSE '00' END,1 FROM t_order" db-types="PostgreSQL,openGauss" />
+ <sql-case id="select_projections_with_only_expr_for_postgres" value="SELECT CASE order_id WHEN 1 THEN '11' ELSE '00' END FROM t_order" db-types="PostgreSQL,openGauss" />
<sql-case id="select_with_amp" value="select 1 & 1" db-types="PostgreSQL,openGauss" />
<sql-case id="select_with_vertical_bar" value="select 1 | 1" db-types="PostgreSQL,openGauss" />
<sql-case id="select_with_abs_function" value="SELECT ABS(1) FROM t_order WHERE ABS(1) > 1 GROUP BY ABS(1) ORDER BY ABS(1)" db-types="MySQL,Oracle" />