Posted to commits@kyuubi.apache.org by ch...@apache.org on 2023/10/09 02:51:38 UTC
[kyuubi] branch branch-1.8 updated: [KYUUBI #5336] Spark extension supports Spark 3.5
This is an automated email from the ASF dual-hosted git repository.
chengpan pushed a commit to branch branch-1.8
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/branch-1.8 by this push:
new 9417c21c3 [KYUUBI #5336] Spark extension supports Spark 3.5
9417c21c3 is described below
commit 9417c21c38b2582eae670a6796a1c2146870c3a3
Author: wforget <64...@qq.com>
AuthorDate: Mon Oct 9 10:51:07 2023 +0800
[KYUUBI #5336] Spark extension supports Spark 3.5
### _Why are the changes needed?_
The new `kyuubi-extension-spark-3-5` module is basically copied from `kyuubi-extension-spark-3-4`.
### _How was this patch tested?_
Compiled successfully:
```
build/mvn clean install -DskipTests -Pflink-provided,spark-provided,hive-provided,spark-3.5
```
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [ ] [Run test](https://kyuubi.readthedocs.io/en/master/contributing/code/testing.html#running-tests) locally before making a pull request
### _Was this patch authored or co-authored using generative AI tooling?_
No
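For context, the new module is wired into Spark through the standard `spark.sql.extensions` mechanism, like the existing extension modules; a minimal sketch for a Spark 3.5 session (the extension class name comes from this diff, the session setup is illustrative):
```
import org.apache.spark.sql.SparkSession

// Minimal sketch: register the Kyuubi extension via Spark's standard
// `spark.sql.extensions` entry point (the extension jar must be on the classpath).
val spark = SparkSession.builder()
  .appName("kyuubi-extension-spark-3-5-demo")
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  .getOrCreate()
```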
Closes #5336 from wForget/dev_spark_3_5.
7ba99804a [wforget] remove iceberg.version in spark-3.5 profile
a18ce166a [wforget] Regenerate KyuubiEnsureRequirements based on EnsureRequirements in spark 3.5
4725c4701 [wforget] fix iceberg version
f5a8ea934 [wforget] Bump iceberg 1.4.0
06558dcfa [wforget] make kyuubi-spark-authz plugin compatible with Spark3.5
90d0e4c70 [wforget] make kyuubi-spark-authz plugin compatible with Spark3.5
4bc8d24d6 [wforget] add ci
1b3f2d916 [wforget] Make kyuubi spark extension compatible with Spark3.5
Authored-by: wforget <64...@qq.com>
Signed-off-by: Cheng Pan <ch...@apache.org>
(cherry picked from commit d2c072b7c2a4ea76b050db7fcec8916b70aa25f3)
Signed-off-by: Cheng Pan <ch...@apache.org>
---
.github/workflows/license.yml | 2 +-
.github/workflows/master.yml | 1 +
.github/workflows/style.yml | 1 +
build/dist | 2 +-
dev/kyuubi-codecov/pom.xml | 10 +
dev/reformat | 2 +-
.../spark/kyuubi-extension-spark-3-5/pom.xml | 206 ++++++
.../antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4 | 191 +++++
.../apache/kyuubi/sql/DropIgnoreNonexistent.scala | 49 ++
.../kyuubi/sql/InferRebalanceAndSortOrders.scala | 110 +++
.../kyuubi/sql/InsertShuffleNodeBeforeJoin.scala | 91 +++
.../kyuubi/sql/KyuubiEnsureRequirements.scala | 464 +++++++++++++
.../kyuubi/sql/KyuubiQueryStagePreparation.scala | 194 ++++++
.../org/apache/kyuubi/sql/KyuubiSQLConf.scala | 276 ++++++++
.../kyuubi/sql/KyuubiSQLExtensionException.scala | 28 +
.../kyuubi/sql/KyuubiSparkSQLAstBuilder.scala | 174 +++++
.../kyuubi/sql/KyuubiSparkSQLCommonExtension.scala | 49 ++
.../kyuubi/sql/KyuubiSparkSQLExtension.scala | 46 ++
.../apache/kyuubi/sql/KyuubiSparkSQLParser.scala | 140 ++++
.../apache/kyuubi/sql/RebalanceBeforeWriting.scala | 77 +++
.../kyuubi/sql/RepartitionBeforeWritingBase.scala | 125 ++++
.../scala/org/apache/kyuubi/sql/WriteUtils.scala | 34 +
.../sql/watchdog/ForcedMaxOutputRowsBase.scala | 90 +++
.../sql/watchdog/ForcedMaxOutputRowsRule.scala | 46 ++
.../sql/watchdog/KyuubiWatchDogException.scala | 30 +
.../kyuubi/sql/watchdog/MaxScanStrategy.scala | 305 ++++++++
.../sql/zorder/InsertZorderBeforeWriting.scala | 177 +++++
.../sql/zorder/InsertZorderBeforeWritingBase.scala | 155 +++++
.../sql/zorder/OptimizeZorderCommandBase.scala | 78 +++
.../sql/zorder/OptimizeZorderStatementBase.scala | 34 +
.../kyuubi/sql/zorder/ResolveZorderBase.scala | 79 +++
.../org/apache/kyuubi/sql/zorder/ZorderBase.scala | 95 +++
.../kyuubi/sql/zorder/ZorderBytesUtils.scala | 517 ++++++++++++++
.../spark/sql/FinalStageResourceManager.scala | 289 ++++++++
.../spark/sql/InjectCustomResourceProfile.scala | 60 ++
.../spark/sql/PruneFileSourcePartitionHelper.scala | 46 ++
.../sql/execution/CustomResourceProfileExec.scala | 112 +++
.../src/test/resources/log4j2-test.xml | 43 ++
.../spark/sql/DropIgnoreNonexistentSuite.scala | 45 ++
.../spark/sql/FinalStageConfigIsolationSuite.scala | 203 ++++++
.../spark/sql/FinalStageResourceManagerSuite.scala | 62 ++
.../spark/sql/InjectResourceProfileSuite.scala | 79 +++
.../sql/InsertShuffleNodeBeforeJoinSuite.scala | 19 +
.../sql/InsertShuffleNodeBeforeJoinSuiteBase.scala | 98 +++
.../spark/sql/KyuubiSparkSQLExtensionTest.scala | 124 ++++
.../spark/sql/RebalanceBeforeWritingSuite.scala | 271 ++++++++
.../scala/org/apache/spark/sql/WatchDogSuite.scala | 20 +
.../org/apache/spark/sql/WatchDogSuiteBase.scala | 601 ++++++++++++++++
.../org/apache/spark/sql/ZorderCoreBenchmark.scala | 117 ++++
.../scala/org/apache/spark/sql/ZorderSuite.scala | 123 ++++
.../org/apache/spark/sql/ZorderSuiteBase.scala | 768 +++++++++++++++++++++
.../spark/sql/benchmark/KyuubiBenchmarkBase.scala | 71 ++
.../src/main/resources/table_command_spec.json | 8 +-
.../plugin/spark/authz/gen/TableCommands.scala | 4 +-
pom.xml | 13 +
55 files changed, 7045 insertions(+), 9 deletions(-)
diff --git a/.github/workflows/license.yml b/.github/workflows/license.yml
index 91c53a7a1..55ef485f8 100644
--- a/.github/workflows/license.yml
+++ b/.github/workflows/license.yml
@@ -45,7 +45,7 @@ jobs:
- run: >-
build/mvn org.apache.rat:apache-rat-plugin:check
-Ptpcds -Pspark-block-cleaner -Pkubernetes-it
- -Pspark-3.1 -Pspark-3.2 -Pspark-3.3 -Pspark-3.4
+ -Pspark-3.1 -Pspark-3.2 -Pspark-3.3 -Pspark-3.4 -Pspark-3.5
- name: Upload rat report
if: failure()
uses: actions/upload-artifact@v3
diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index 3b85530d4..586db9c8e 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -52,6 +52,7 @@ jobs:
- '3.2'
- '3.3'
- '3.4'
+ - '3.5'
spark-archive: [""]
exclude-tags: [""]
comment: ["normal"]
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 6f575302e..21cacbc1d 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -69,6 +69,7 @@ jobs:
build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-1 -Pspark-3.1
build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-3,extensions/spark/kyuubi-spark-connector-hive -Pspark-3.3
build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-4 -Pspark-3.4
+ build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-5 -Pspark-3.5
- name: Scalastyle with maven
id: scalastyle-check
diff --git a/build/dist b/build/dist
index b81a2661e..df9498008 100755
--- a/build/dist
+++ b/build/dist
@@ -335,7 +335,7 @@ if [[ -f "$KYUUBI_HOME/tools/spark-block-cleaner/target/spark-block-cleaner_${SC
fi
# Copy Kyuubi Spark extension
-SPARK_EXTENSION_VERSIONS=('3-1' '3-2' '3-3' '3-4')
+SPARK_EXTENSION_VERSIONS=('3-1' '3-2' '3-3' '3-4' '3-5')
# shellcheck disable=SC2068
for SPARK_EXTENSION_VERSION in ${SPARK_EXTENSION_VERSIONS[@]}; do
if [[ -f $"$KYUUBI_HOME/extensions/spark/kyuubi-extension-spark-$SPARK_EXTENSION_VERSION/target/kyuubi-extension-spark-${SPARK_EXTENSION_VERSION}_${SCALA_VERSION}-${VERSION}.jar" ]]; then
diff --git a/dev/kyuubi-codecov/pom.xml b/dev/kyuubi-codecov/pom.xml
index 0d265a006..c2ec1e729 100644
--- a/dev/kyuubi-codecov/pom.xml
+++ b/dev/kyuubi-codecov/pom.xml
@@ -219,5 +219,15 @@
</dependency>
</dependencies>
</profile>
+ <profile>
+ <id>spark-3.5</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.kyuubi</groupId>
+ <artifactId>kyuubi-extension-spark-3-5_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ </dependencies>
+ </profile>
</profiles>
</project>
diff --git a/dev/reformat b/dev/reformat
index 6346e68f6..31e8f49ad 100755
--- a/dev/reformat
+++ b/dev/reformat
@@ -20,7 +20,7 @@ set -x
KYUUBI_HOME="$(cd "`dirname "$0"`/.."; pwd)"
-PROFILES="-Pflink-provided,hive-provided,spark-provided,spark-block-cleaner,spark-3.4,spark-3.3,spark-3.2,spark-3.1,tpcds"
+PROFILES="-Pflink-provided,hive-provided,spark-provided,spark-block-cleaner,spark-3.5,spark-3.4,spark-3.3,spark-3.2,spark-3.1,tpcds"
# python style checks rely on `black` in path
if ! command -v black &> /dev/null
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/pom.xml b/extensions/spark/kyuubi-extension-spark-3-5/pom.xml
new file mode 100644
index 000000000..e78a88a80
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/pom.xml
@@ -0,0 +1,206 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.kyuubi</groupId>
+ <artifactId>kyuubi-parent</artifactId>
+ <version>1.9.0-SNAPSHOT</version>
+ <relativePath>../../../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>kyuubi-extension-spark-3-5_${scala.binary.version}</artifactId>
+ <packaging>jar</packaging>
+ <name>Kyuubi Dev Spark Extensions (for Spark 3.5)</name>
+ <url>https://kyuubi.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
+ <artifactId>scala-library</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.binary.version}</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.kyuubi</groupId>
+ <artifactId>kyuubi-download</artifactId>
+ <version>${project.version}</version>
+ <type>pom</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.kyuubi</groupId>
+ <artifactId>kyuubi-util-scala_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.scalatestplus</groupId>
+ <artifactId>scalacheck-1-17_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client-runtime</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <!--
+ Spark requires `commons-collections` and `commons-io` but gets them from transitive
+ dependencies of `hadoop-client`. As we are using the Hadoop shaded client, we need to
+ add them explicitly. See more details at SPARK-33212.
+ -->
+ <dependency>
+ <groupId>commons-collections</groupId>
+ <artifactId>commons-collections</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>jakarta.xml.bind</groupId>
+ <artifactId>jakarta.xml.bind-api</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.logging.log4j</groupId>
+ <artifactId>log4j-slf4j-impl</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>regex-property</id>
+ <goals>
+ <goal>regex-property</goal>
+ </goals>
+ <configuration>
+ <name>spark.home</name>
+ <value>${project.basedir}/../../../externals/kyuubi-download/target/${spark.archive.name}</value>
+ <regex>(.+)\.tgz</regex>
+ <replacement>$1</replacement>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <environmentVariables>
+ <!--
+ Some tests run Spark in local-cluster mode.
+ This mode uses SPARK_HOME and SPARK_SCALA_VERSION to build command to launch a Spark Standalone Cluster.
+ -->
+ <SPARK_HOME>${spark.home}</SPARK_HOME>
+ <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION>
+ </environmentVariables>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-maven-plugin</artifactId>
+ <configuration>
+ <visitor>true</visitor>
+ <sourceDirectory>${project.basedir}/src/main/antlr4</sourceDirectory>
+ </configuration>
+ </plugin>
+
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <configuration>
+ <shadedArtifactAttached>false</shadedArtifactAttached>
+ <artifactSet>
+ <includes>
+ <include>org.apache.kyuubi:*</include>
+ </includes>
+ </artifactSet>
+ </configuration>
+ <executions>
+ <execution>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <phase>package</phase>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4 b/extensions/spark/kyuubi-extension-spark-3-5/src/main/antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4
new file mode 100644
index 000000000..e52b7f5cf
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+grammar KyuubiSparkSQL;
+
+@members {
+ /**
+ * Verify whether current token is a valid decimal token (which contains dot).
+ * Returns true if the character that follows the token is not a digit or letter or underscore.
+ *
+ * For example:
+ * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
+ * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
+ * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
+ * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
+ * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
+ * which is not a digit or letter or underscore.
+ */
+ public boolean isValidDecimal() {
+ int nextChar = _input.LA(1);
+ if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' ||
+ nextChar == '_') {
+ return false;
+ } else {
+ return true;
+ }
+ }
+ }
+
+tokens {
+ DELIMITER
+}
+
+singleStatement
+ : statement EOF
+ ;
+
+statement
+ : OPTIMIZE multipartIdentifier whereClause? zorderClause #optimizeZorder
+ | .*? #passThrough
+ ;
+
+whereClause
+ : WHERE partitionPredicate = predicateToken
+ ;
+
+zorderClause
+ : ZORDER BY order+=multipartIdentifier (',' order+=multipartIdentifier)*
+ ;
+
+// We don't have an expression rule in our grammar here, so we just grab the tokens and defer
+// parsing them to later.
+predicateToken
+ : .+?
+ ;
+
+multipartIdentifier
+ : parts+=identifier ('.' parts+=identifier)*
+ ;
+
+identifier
+ : strictIdentifier
+ ;
+
+strictIdentifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+nonReserved
+ : AND
+ | BY
+ | FALSE
+ | DATE
+ | INTERVAL
+ | OPTIMIZE
+ | OR
+ | TABLE
+ | TIMESTAMP
+ | TRUE
+ | WHERE
+ | ZORDER
+ ;
+
+AND: 'AND';
+BY: 'BY';
+FALSE: 'FALSE';
+DATE: 'DATE';
+INTERVAL: 'INTERVAL';
+OPTIMIZE: 'OPTIMIZE';
+OR: 'OR';
+TABLE: 'TABLE';
+TIMESTAMP: 'TIMESTAMP';
+TRUE: 'TRUE';
+WHERE: 'WHERE';
+ZORDER: 'ZORDER';
+
+MINUS: '-';
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+DECIMAL_VALUE
+ : DIGIT+ EXPONENT
+ | DECIMAL_DIGITS EXPONENT? {isValidDecimal()}?
+ ;
+
+DOUBLE_LITERAL
+ : DIGIT+ EXPONENT? 'D'
+ | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
+ ;
+
+BIGDECIMAL_LITERAL
+ : DIGIT+ EXPONENT? 'BD'
+ | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+fragment DECIMAL_DIGITS
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' .*? '*/' -> channel(HIDDEN)
+ ;
+
+WS : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
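For readers skimming the grammar above: the `statement` rule recognizes `OPTIMIZE ... [WHERE ...] ZORDER BY ...` and passes everything else through, with the `WHERE` predicate captured as raw tokens (`predicateToken`) and parsed later. A hypothetical statement it accepts (table and column names are made up):
```
// Hypothetical example of a statement matched by the optimizeZorder rule;
// the partition predicate after WHERE is deferred to Spark's own parser.
spark.sql("OPTIMIZE db1.t1 WHERE p = 'a' ZORDER BY c1, c2")
```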
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DropIgnoreNonexistent.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DropIgnoreNonexistent.scala
new file mode 100644
index 000000000..e33632b8b
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DropIgnoreNonexistent.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunctionName, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{DropFunction, DropNamespace, LogicalPlan, NoopCommand, UncacheTable}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.command.{AlterTableDropPartitionCommand, DropTableCommand}
+
+import org.apache.kyuubi.sql.KyuubiSQLConf._
+
+case class DropIgnoreNonexistent(session: SparkSession) extends Rule[LogicalPlan] {
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(DROP_IGNORE_NONEXISTENT)) {
+ plan match {
+ case i @ AlterTableDropPartitionCommand(_, _, false, _, _) =>
+ i.copy(ifExists = true)
+ case i @ DropTableCommand(_, false, _, _) =>
+ i.copy(ifExists = true)
+ case i @ DropNamespace(_, false, _) =>
+ i.copy(ifExists = true)
+ case UncacheTable(u: UnresolvedRelation, false, _) =>
+ NoopCommand("UNCACHE TABLE", u.multipartIdentifier)
+ case DropFunction(u: UnresolvedFunctionName, false) =>
+ NoopCommand("DROP FUNCTION", u.multipartIdentifier)
+ case _ => plan
+ }
+ } else {
+ plan
+ }
+ }
+
+}
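A sketch of the rule's effect, assuming `DROP_IGNORE_NONEXISTENT` maps to the config key `spark.sql.optimizer.dropIgnoreNonExistent` (the key is defined in `KyuubiSQLConf`, outside this excerpt, so treat it as an assumption):
```
// Assumed config key; with the flag on, DROP statements against missing objects
// are rewritten to their IF EXISTS / no-op forms instead of failing.
spark.conf.set("spark.sql.optimizer.dropIgnoreNonExistent", "true")
spark.sql("DROP TABLE t_not_exists")    // rewritten with ifExists = true
spark.sql("DROP FUNCTION f_not_exists") // rewritten to a NoopCommand
```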
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InferRebalanceAndSortOrders.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InferRebalanceAndSortOrders.scala
new file mode 100644
index 000000000..fcbf5c0a1
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InferRebalanceAndSortOrders.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import scala.annotation.tailrec
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, UnaryExpression}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort, SubqueryAlias, View}
+
+/**
+ * Infer the columns for Rebalance and Sort to improve the compression ratio.
+ *
+ * For example
+ * {{{
+ * INSERT INTO TABLE t PARTITION(p='a')
+ * SELECT * FROM t1 JOIN t2 on t1.c1 = t2.c1
+ * }}}
+ * the inferred columns are: t1.c1
+ */
+object InferRebalanceAndSortOrders {
+
+ type PartitioningAndOrdering = (Seq[Expression], Seq[Expression])
+
+ private def getAliasMap(named: Seq[NamedExpression]): Map[Expression, Attribute] = {
+ @tailrec
+ def throughUnary(e: Expression): Expression = e match {
+ case u: UnaryExpression if u.deterministic =>
+ throughUnary(u.child)
+ case _ => e
+ }
+
+ named.flatMap {
+ case a @ Alias(child, _) =>
+ Some((throughUnary(child).canonicalized, a.toAttribute))
+ case _ => None
+ }.toMap
+ }
+
+ def infer(plan: LogicalPlan): Option[PartitioningAndOrdering] = {
+ def candidateKeys(
+ input: LogicalPlan,
+ output: AttributeSet = AttributeSet.empty): Option[PartitioningAndOrdering] = {
+ input match {
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _, _) =>
+ joinType match {
+ case LeftSemi | LeftAnti | LeftOuter => Some((leftKeys, leftKeys))
+ case RightOuter => Some((rightKeys, rightKeys))
+ case Inner | FullOuter =>
+ if (output.isEmpty) {
+ Some((leftKeys ++ rightKeys, leftKeys ++ rightKeys))
+ } else {
+ assert(leftKeys.length == rightKeys.length)
+ val keys = leftKeys.zip(rightKeys).flatMap { case (left, right) =>
+ if (left.references.subsetOf(output)) {
+ Some(left)
+ } else if (right.references.subsetOf(output)) {
+ Some(right)
+ } else {
+ None
+ }
+ }
+ Some((keys, keys))
+ }
+ case _ => None
+ }
+ case agg: Aggregate =>
+ val aliasMap = getAliasMap(agg.aggregateExpressions)
+ Some((
+ agg.groupingExpressions.map(p => aliasMap.getOrElse(p.canonicalized, p)),
+ agg.groupingExpressions.map(o => aliasMap.getOrElse(o.canonicalized, o))))
+ case s: Sort => Some((s.order.map(_.child), s.order.map(_.child)))
+ case p: Project =>
+ val aliasMap = getAliasMap(p.projectList)
+ candidateKeys(p.child, p.references).map { case (partitioning, ordering) =>
+ (
+ partitioning.map(p => aliasMap.getOrElse(p.canonicalized, p)),
+ ordering.map(o => aliasMap.getOrElse(o.canonicalized, o)))
+ }
+ case f: Filter => candidateKeys(f.child, output)
+ case s: SubqueryAlias => candidateKeys(s.child, output)
+ case v: View => candidateKeys(v.child, output)
+
+ case _ => None
+ }
+ }
+
+ candidateKeys(plan).map { case (partitioning, ordering) =>
+ (
+ partitioning.filter(_.references.subsetOf(plan.outputSet)),
+ ordering.filter(_.references.subsetOf(plan.outputSet)))
+ }
+ }
+}
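A usage sketch of the entry point above (plan construction is illustrative and assumes a SparkSession `spark` with tables `t1` and `t2`; per the scaladoc, an inner join yields the join keys of both sides):
```
import org.apache.kyuubi.sql.InferRebalanceAndSortOrders

// Illustrative: infer rebalance/sort keys from an analyzed query plan.
val analyzed =
  spark.sql("SELECT * FROM t1 JOIN t2 ON t1.c1 = t2.c1").queryExecution.analyzed
InferRebalanceAndSortOrders.infer(analyzed).foreach { case (partitioning, ordering) =>
  println(s"rebalance by: $partitioning, sort by: $ordering")
}
```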
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InsertShuffleNodeBeforeJoin.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InsertShuffleNodeBeforeJoin.scala
new file mode 100644
index 000000000..1a02e8c1e
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InsertShuffleNodeBeforeJoin.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.kyuubi.sql.KyuubiSQLConf._
+
+/**
+ * Insert a shuffle node before the join if one doesn't exist, to make `OptimizeSkewedJoin` work.
+ */
+object InsertShuffleNodeBeforeJoin extends Rule[SparkPlan] {
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ // this rule has no meaning without AQE
+ if (!conf.getConf(FORCE_SHUFFLE_BEFORE_JOIN) ||
+ !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) {
+ return plan
+ }
+
+ val newPlan = insertShuffleBeforeJoin(plan)
+ if (plan.fastEquals(newPlan)) {
+ plan
+ } else {
+ // make sure the output partitioning and ordering will not be broken.
+ KyuubiEnsureRequirements.apply(newPlan)
+ }
+ }
+
+ // Since Spark 3.3, insertShuffleBeforeJoin shouldn't be applied if the join is skewed.
+ private def insertShuffleBeforeJoin(plan: SparkPlan): SparkPlan = plan transformUp {
+ case smj @ SortMergeJoinExec(_, _, _, _, l, r, isSkewJoin) if !isSkewJoin =>
+ smj.withNewChildren(checkAndInsertShuffle(smj.requiredChildDistribution.head, l) ::
+ checkAndInsertShuffle(smj.requiredChildDistribution(1), r) :: Nil)
+
+ case shj: ShuffledHashJoinExec if !shj.isSkewJoin =>
+ if (!shj.left.isInstanceOf[Exchange] && !shj.right.isInstanceOf[Exchange]) {
+ shj.withNewChildren(withShuffleExec(shj.requiredChildDistribution.head, shj.left) ::
+ withShuffleExec(shj.requiredChildDistribution(1), shj.right) :: Nil)
+ } else if (!shj.left.isInstanceOf[Exchange]) {
+ shj.withNewChildren(
+ withShuffleExec(shj.requiredChildDistribution.head, shj.left) :: shj.right :: Nil)
+ } else if (!shj.right.isInstanceOf[Exchange]) {
+ shj.withNewChildren(
+ shj.left :: withShuffleExec(shj.requiredChildDistribution(1), shj.right) :: Nil)
+ } else {
+ shj
+ }
+ }
+
+ private def checkAndInsertShuffle(
+ distribution: Distribution,
+ child: SparkPlan): SparkPlan = child match {
+ case SortExec(_, _, _: Exchange, _) =>
+ child
+ case SortExec(_, _, _: QueryStageExec, _) =>
+ child
+ case sort @ SortExec(_, _, agg: BaseAggregateExec, _) =>
+ sort.withNewChildren(withShuffleExec(distribution, agg) :: Nil)
+ case _ =>
+ withShuffleExec(distribution, child)
+ }
+
+ private def withShuffleExec(distribution: Distribution, child: SparkPlan): SparkPlan = {
+ val numPartitions = distribution.requiredNumPartitions
+ .getOrElse(conf.numShufflePartitions)
+ ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
+ }
+}
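The rule above is gated on two flags, both checked at the top of `apply`; a minimal sketch of enabling it (the Kyuubi key is defined in `KyuubiSQLConf` later in this diff, the AQE key is standard Spark):
```
// Without both flags the rule is a no-op: AQE must be on, and the Kyuubi
// force-shuffle flag must be explicitly enabled (it defaults to false).
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.optimizer.forceShuffleBeforeJoin.enabled", "true")
```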
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala
new file mode 100644
index 000000000..586cad838
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Copied from Apache Spark's `EnsureRequirements`, with two changes:
+ * 1. remove join predicate reordering
+ * 2. remove shuffle pruning
+ */
+object KyuubiEnsureRequirements extends Rule[SparkPlan] {
+
+ private def ensureDistributionAndOrdering(
+ parent: Option[SparkPlan],
+ originalChildren: Seq[SparkPlan],
+ requiredChildDistributions: Seq[Distribution],
+ requiredChildOrderings: Seq[Seq[SortOrder]],
+ shuffleOrigin: ShuffleOrigin): Seq[SparkPlan] = {
+ assert(requiredChildDistributions.length == originalChildren.length)
+ assert(requiredChildOrderings.length == originalChildren.length)
+ // Ensure that the operator's children satisfy their output distribution requirements.
+ var children = originalChildren.zip(requiredChildDistributions).map {
+ case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+ child
+ case (child, BroadcastDistribution(mode)) =>
+ BroadcastExchangeExec(mode, child)
+ case (child, distribution) =>
+ val numPartitions = distribution.requiredNumPartitions
+ .getOrElse(conf.numShufflePartitions)
+ ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
+ }
+
+ // Get the indexes of children which have specified distribution requirements and need to be
+ // co-partitioned.
+ val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
+ case (_: ClusteredDistribution, _) => true
+ case _ => false
+ }.map(_._2)
+
+ // Special case: if all sides of the join are single partition and their physical size is less
+ // than or equal to spark.sql.maxSinglePartitionBytes.
+ val preferSinglePartition = childrenIndexes.forall { i =>
+ children(i).outputPartitioning == SinglePartition &&
+ children(i).logicalLink
+ .forall(_.stats.sizeInBytes <= conf.getConf(SQLConf.MAX_SINGLE_PARTITION_BYTES))
+ }
+
+ // If there is more than one child, we'll need to check their partitioning & distribution
+ // and see if extra shuffles are necessary.
+ if (childrenIndexes.length > 1 && !preferSinglePartition) {
+ val specs = childrenIndexes.map(i => {
+ val requiredDist = requiredChildDistributions(i)
+ assert(
+ requiredDist.isInstanceOf[ClusteredDistribution],
+ s"Expected ClusteredDistribution but found ${requiredDist.getClass.getSimpleName}")
+ i -> children(i).outputPartitioning.createShuffleSpec(
+ requiredDist.asInstanceOf[ClusteredDistribution])
+ }).toMap
+
+ // Find out the shuffle spec that gives better parallelism. Currently this is done by
+ // picking the spec with the largest number of partitions.
+ //
+ // NOTE: this is not optimal for the case when there are more than 2 children. Consider:
+ // (10, 10, 11)
+ // where the numbers represent the number of partitions for each child, it's better to pick 10
+ // here since we only need to shuffle one side - we'd need to shuffle two sides if we pick 11.
+ //
+ // However this should be sufficient for now since in Spark nodes with multiple children
+ // always have exactly 2 children.
+
+ // Whether we should consider `spark.sql.shuffle.partitions` and ensure enough parallelism
+ // during shuffle. To achieve a good trade-off between parallelism and shuffle cost, we only
+ // consider the minimum parallelism iff ALL children need to be re-shuffled.
+ //
+ // A child needs to be re-shuffled iff either one of below is true:
+ // 1. It can't create partitioning by itself, i.e., `canCreatePartitioning` returns false
+ // (as for the case of `RangePartitioning`), therefore it needs to be re-shuffled
+ // according to other shuffle spec.
+ // 2. It already has `ShuffleExchangeLike`, so we can re-use existing shuffle without
+ // introducing extra shuffle.
+ //
+ // On the other hand, in scenarios such as:
+ // HashPartitioning(5) <-> HashPartitioning(6)
+ // while `spark.sql.shuffle.partitions` is 10, we'll only re-shuffle the left side and make it
+ // HashPartitioning(6).
+ val shouldConsiderMinParallelism = specs.forall(p =>
+ !p._2.canCreatePartitioning || children(p._1).isInstanceOf[ShuffleExchangeLike])
+ // Choose all the specs that can be used to shuffle other children
+ val candidateSpecs = specs
+ .filter(_._2.canCreatePartitioning)
+ .filter(p =>
+ !shouldConsiderMinParallelism ||
+ children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions)
+ val bestSpecOpt = if (candidateSpecs.isEmpty) {
+ None
+ } else {
+ // When choosing specs, we should consider those children with no `ShuffleExchangeLike` node
+ // first. For instance, if we have:
+ // A: (No_Exchange, 100) <---> B: (Exchange, 120)
+ // it's better to pick A and change B to (Exchange, 100) instead of picking B and insert a
+ // new shuffle for A.
+ val candidateSpecsWithoutShuffle = candidateSpecs.filter { case (k, _) =>
+ !children(k).isInstanceOf[ShuffleExchangeLike]
+ }
+ val finalCandidateSpecs = if (candidateSpecsWithoutShuffle.nonEmpty) {
+ candidateSpecsWithoutShuffle
+ } else {
+ candidateSpecs
+ }
+ // Pick the spec with the best parallelism
+ Some(finalCandidateSpecs.values.maxBy(_.numPartitions))
+ }
+
+ // Check if the following conditions are satisfied:
+ // 1. There are exactly two children (e.g., join). Note that Spark doesn't support
+ // multi-way join at the moment, so this check should be sufficient.
+ // 2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other
+ // If both are true, skip shuffle.
+ val isKeyGroupCompatible = parent.isDefined &&
+ children.length == 2 && childrenIndexes.length == 2 && {
+ val left = children.head
+ val right = children(1)
+ val newChildren = checkKeyGroupCompatible(
+ parent.get,
+ left,
+ right,
+ requiredChildDistributions)
+ if (newChildren.isDefined) {
+ children = newChildren.get
+ }
+ newChildren.isDefined
+ }
+
+ children = children.zip(requiredChildDistributions).zipWithIndex.map {
+ case ((child, _), idx) if isKeyGroupCompatible || !childrenIndexes.contains(idx) =>
+ child
+ case ((child, dist), idx) =>
+ if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
+ child
+ } else {
+ val newPartitioning = bestSpecOpt.map { bestSpec =>
+ // Use the best spec to create a new partitioning to re-shuffle this child
+ val clustering = dist.asInstanceOf[ClusteredDistribution].clustering
+ bestSpec.createPartitioning(clustering)
+ }.getOrElse {
+ // No best spec available, so we create default partitioning from the required
+ // distribution
+ val numPartitions = dist.requiredNumPartitions
+ .getOrElse(conf.numShufflePartitions)
+ dist.createPartitioning(numPartitions)
+ }
+
+ child match {
+ case ShuffleExchangeExec(_, c, so, ps) =>
+ ShuffleExchangeExec(newPartitioning, c, so, ps)
+ case _ => ShuffleExchangeExec(newPartitioning, child)
+ }
+ }
+ }
+ }
+
+ // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+ // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
+ if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
+ child
+ } else {
+ SortExec(requiredOrdering, global = false, child = child)
+ }
+ }
+
+ children
+ }
+
+ /**
+ * Checks whether two children, `left` and `right`, of a join operator have compatible
+ * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
+ *
+ * Returns the updated new children if the check is successful, otherwise `None`.
+ */
+ private def checkKeyGroupCompatible(
+ parent: SparkPlan,
+ left: SparkPlan,
+ right: SparkPlan,
+ requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+ parent match {
+ case smj: SortMergeJoinExec =>
+ checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution)
+ case sj: ShuffledHashJoinExec =>
+ checkKeyGroupCompatible(left, right, sj.joinType, requiredChildDistribution)
+ case _ =>
+ None
+ }
+ }
+
+ private def checkKeyGroupCompatible(
+ left: SparkPlan,
+ right: SparkPlan,
+ joinType: JoinType,
+ requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+ assert(requiredChildDistribution.length == 2)
+
+ var newLeft = left
+ var newRight = right
+
+ val specs = Seq(left, right).zip(requiredChildDistribution).map { case (p, d) =>
+ if (!d.isInstanceOf[ClusteredDistribution]) return None
+ val cd = d.asInstanceOf[ClusteredDistribution]
+ val specOpt = createKeyGroupedShuffleSpec(p.outputPartitioning, cd)
+ if (specOpt.isEmpty) return None
+ specOpt.get
+ }
+
+ val leftSpec = specs.head
+ val rightSpec = specs(1)
+
+ var isCompatible = false
+ if (!conf.v2BucketingPushPartValuesEnabled) {
+ isCompatible = leftSpec.isCompatibleWith(rightSpec)
+ } else {
+ logInfo("Pushing common partition values for storage-partitioned join")
+ isCompatible = leftSpec.areKeysCompatible(rightSpec)
+
+ // Partition expressions are compatible. Regardless of whether partition values
+ // match from both sides of children, we can calculate a superset of partition values and
+ // push-down to respective data sources so they can adjust their output partitioning by
+ // filling missing partition keys with empty partitions. As a result, we can still avoid
+ // shuffle.
+ //
+ // For instance, if two sides of a join have partition expressions
+ // `day(a)` and `day(b)` respectively
+ // (the join query could be `SELECT ... FROM t1 JOIN t2 on t1.a = t2.b`), but
+ // with different partition values:
+ // `day(a)`: [0, 1]
+ // `day(b)`: [1, 2, 3]
+ // Following case 2 above, we don't have to shuffle either side, but instead can
+ // just push the common set of partition values: `[0, 1, 2, 3]` down to the two data
+ // sources.
+ if (isCompatible) {
+ val leftPartValues = leftSpec.partitioning.partitionValues
+ val rightPartValues = rightSpec.partitioning.partitionValues
+
+ logInfo(
+ s"""
+ |Left side # of partitions: ${leftPartValues.size}
+ |Right side # of partitions: ${rightPartValues.size}
+ |""".stripMargin)
+
+ // As partition keys are compatible, we can pick either left or right as partition
+ // expressions
+ val partitionExprs = leftSpec.partitioning.expressions
+
+ var mergedPartValues = InternalRowComparableWrapper
+ .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs)
+ .map(v => (v, 1))
+
+ logInfo(s"After merging, there are ${mergedPartValues.size} partitions")
+
+ var replicateLeftSide = false
+ var replicateRightSide = false
+ var applyPartialClustering = false
+
+ // This means we allow partitions that are not clustered on their values,
+ // that is, multiple partitions with the same partition value. In the
+ // following, we calculate how many partitions that each distinct partition
+ // value has, and pushdown the information to scans, so they can adjust their
+ // final input partitions respectively.
+ if (conf.v2BucketingPartiallyClusteredDistributionEnabled) {
+ logInfo("Calculating partially clustered distribution for " +
+ "storage-partitioned join")
+
+ // Similar to `OptimizeSkewedJoin`, we need to check join type and decide
+ // whether partially clustered distribution can be applied. For instance, the
+ // optimization cannot be applied to a left outer join, where the left hand
+ // side is chosen as the side to replicate partitions according to stats.
+ // Otherwise, query result could be incorrect.
+ val canReplicateLeft = canReplicateLeftSide(joinType)
+ val canReplicateRight = canReplicateRightSide(joinType)
+
+ if (!canReplicateLeft && !canReplicateRight) {
+ logInfo("Skipping partially clustered distribution as it cannot be applied for " +
+ s"join type '$joinType'")
+ } else {
+ val leftLink = left.logicalLink
+ val rightLink = right.logicalLink
+
+ replicateLeftSide =
+ if (leftLink.isDefined && rightLink.isDefined &&
+ leftLink.get.stats.sizeInBytes > 1 &&
+ rightLink.get.stats.sizeInBytes > 1) {
+ logInfo(
+ s"""
+ |Using plan statistics to determine which side of join to fully
+ |cluster partition values:
+ |Left side size (in bytes): ${leftLink.get.stats.sizeInBytes}
+ |Right side size (in bytes): ${rightLink.get.stats.sizeInBytes}
+ |""".stripMargin)
+ leftLink.get.stats.sizeInBytes < rightLink.get.stats.sizeInBytes
+ } else {
+ // As a simple heuristic, we pick the side with fewer number of partitions
+ // to apply the grouping & replication of partitions
+ logInfo("Using number of partitions to determine which side of join " +
+ "to fully cluster partition values")
+ leftPartValues.size < rightPartValues.size
+ }
+
+ replicateRightSide = !replicateLeftSide
+
+ // Similar to skewed join, we need to check the join type to see whether replication
+ // of partitions can be applied. For instance, replication should not be allowed for
+ // the left-hand side of a right outer join.
+ if (replicateLeftSide && !canReplicateLeft) {
+ logInfo("Left-hand side is picked but cannot be applied to join type " +
+ s"'$joinType'. Skipping partially clustered distribution.")
+ replicateLeftSide = false
+ } else if (replicateRightSide && !canReplicateRight) {
+ logInfo("Right-hand side is picked but cannot be applied to join type " +
+ s"'$joinType'. Skipping partially clustered distribution.")
+ replicateRightSide = false
+ } else {
+ val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
+ val numExpectedPartitions = partValues
+ .map(InternalRowComparableWrapper(_, partitionExprs))
+ .groupBy(identity)
+ .mapValues(_.size)
+
+ mergedPartValues = mergedPartValues.map { case (partVal, numParts) =>
+ (
+ partVal,
+ numExpectedPartitions.getOrElse(
+ InternalRowComparableWrapper(partVal, partitionExprs),
+ numParts))
+ }
+
+ logInfo("After applying partially clustered distribution, there are " +
+ s"${mergedPartValues.map(_._2).sum} partitions.")
+ applyPartialClustering = true
+ }
+ }
+ }
+
+ // Now we need to push-down the common partition key to the scan in each child
+ newLeft = populatePartitionValues(
+ left,
+ mergedPartValues,
+ applyPartialClustering,
+ replicateLeftSide)
+ newRight = populatePartitionValues(
+ right,
+ mergedPartValues,
+ applyPartialClustering,
+ replicateRightSide)
+ }
+ }
+
+ if (isCompatible) Some(Seq(newLeft, newRight)) else None
+ }
+
+ // Similar to `OptimizeSkewedJoin.canSplitRightSide`
+ private def canReplicateLeftSide(joinType: JoinType): Boolean = {
+ joinType == Inner || joinType == Cross || joinType == RightOuter
+ }
+
+ // Similar to `OptimizeSkewedJoin.canSplitLeftSide`
+ private def canReplicateRightSide(joinType: JoinType): Boolean = {
+ joinType == Inner || joinType == Cross || joinType == LeftSemi ||
+ joinType == LeftAnti || joinType == LeftOuter
+ }
+
+ // Populate the common partition values down to the scan nodes
+ private def populatePartitionValues(
+ plan: SparkPlan,
+ values: Seq[(InternalRow, Int)],
+ applyPartialClustering: Boolean,
+ replicatePartitions: Boolean): SparkPlan = plan match {
+ case scan: BatchScanExec =>
+ scan.copy(spjParams = scan.spjParams.copy(
+ commonPartitionValues = Some(values),
+ applyPartialClustering = applyPartialClustering,
+ replicatePartitions = replicatePartitions))
+ case node =>
+ node.mapChildren(child =>
+ populatePartitionValues(
+ child,
+ values,
+ applyPartialClustering,
+ replicatePartitions))
+ }
+
+ /**
+ * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if
+ * the partitioning is a [[KeyGroupedPartitioning]] (either directly or indirectly), and
+ * satisfies the given distribution.
+ */
+ private def createKeyGroupedShuffleSpec(
+ partitioning: Partitioning,
+ distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
+ def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
+ val attributes = partitioning.expressions.flatMap(_.collectLeaves())
+ val clustering = distribution.clustering
+
+ val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
+ attributes.length == clustering.length && attributes.zip(clustering).forall {
+ case (l, r) => l.semanticEquals(r)
+ }
+ } else {
+ partitioning.satisfies(distribution)
+ }
+
+ if (satisfies) {
+ Some(partitioning.createShuffleSpec(distribution).asInstanceOf[KeyGroupedShuffleSpec])
+ } else {
+ None
+ }
+ }
+
+ partitioning match {
+ case p: KeyGroupedPartitioning => check(p)
+ case PartitioningCollection(partitionings) =>
+ val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution))
+ assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined))
+ specs.head
+ case _ => None
+ }
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator: SparkPlan =>
+ val newChildren = ensureDistributionAndOrdering(
+ Some(operator),
+ operator.children,
+ operator.requiredChildDistribution,
+ operator.requiredChildOrdering,
+ ENSURE_REQUIREMENTS)
+ operator.withNewChildren(newChildren)
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala
new file mode 100644
index 000000000..a7fcbecd4
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.command.{ResetCommand, SetCommand}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.kyuubi.sql.KyuubiSQLConf._
+
+/**
+ * This rule splits stages into two parts:
+ * 1. previous stage
+ * 2. final stage
+ * For the final stage, we can inject extra configs. This is useful when we use repartition to
+ * merge small files, which needs a bigger shuffle partition size than previous stages.
+ *
+ * Let's say we have a query with 3 stages; then the logic works like:
+ *
+ * Set/Reset Command -> cleanup previousStage config if user set the spark config.
+ * Query -> AQE -> stage1 -> preparation (use previousStage to overwrite spark config)
+ * -> AQE -> stage2 -> preparation (use spark config)
+ * -> AQE -> stage3 -> preparation (use finalStage config to overwrite spark config,
+ * store spark config to previousStage.)
+ *
+ * An example of the new finalStage config:
+ * `spark.sql.adaptive.advisoryPartitionSizeInBytes` ->
+ * `spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`
+ */
+case class FinalStageConfigIsolation(session: SparkSession) extends Rule[SparkPlan] {
+ import FinalStageConfigIsolation._
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ // this rule has no meaning without AQE
+ if (!conf.getConf(FINAL_STAGE_CONFIG_ISOLATION) ||
+ !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) {
+ return plan
+ }
+
+ if (isFinalStage(plan)) {
+ // We cannot get the whole plan at the query preparation phase to detect whether the current
+ // plan is for writing, so we depend on a tag that has been injected at the post-resolution
+ // phase.
+ // Note: we should still clean up the previous config for non-final stages, to cover the case
+ // where the first statement is a write but the second statement is a query.
+ if (conf.getConf(FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY) &&
+ !WriteUtils.isWrite(session, plan)) {
+ return plan
+ }
+
+ // set config for final stage
+ session.conf.getAll.filter(_._1.startsWith(FINAL_STAGE_CONFIG_PREFIX)).foreach {
+ case (k, v) =>
+ val sparkConfigKey = s"spark.sql.${k.substring(FINAL_STAGE_CONFIG_PREFIX.length)}"
+ val previousStageConfigKey =
+ s"$PREVIOUS_STAGE_CONFIG_PREFIX${k.substring(FINAL_STAGE_CONFIG_PREFIX.length)}"
+ // store the previous config only if we have not stored it yet, so that a query with
+ // only one stage does not overwrite the real config.
+ if (!session.sessionState.conf.contains(previousStageConfigKey)) {
+ val originalValue =
+ if (session.conf.getOption(sparkConfigKey).isDefined) {
+ session.sessionState.conf.getConfString(sparkConfigKey)
+ } else {
+ // the default value of the config is None, so we need to use an internal tag
+ INTERNAL_UNSET_CONFIG_TAG
+ }
+ logInfo(s"Store config: $sparkConfigKey to previousStage, " +
+ s"original value: $originalValue ")
+ session.sessionState.conf.setConfString(previousStageConfigKey, originalValue)
+ }
+ logInfo(s"For final stage: set $sparkConfigKey = $v.")
+ session.conf.set(sparkConfigKey, v)
+ }
+ } else {
+ // reset config for previous stage
+ session.conf.getAll.filter(_._1.startsWith(PREVIOUS_STAGE_CONFIG_PREFIX)).foreach {
+ case (k, v) =>
+ val sparkConfigKey = s"spark.sql.${k.substring(PREVIOUS_STAGE_CONFIG_PREFIX.length)}"
+ logInfo(s"For previous stage: set $sparkConfigKey = $v.")
+ if (v == INTERNAL_UNSET_CONFIG_TAG) {
+ session.conf.unset(sparkConfigKey)
+ } else {
+ session.conf.set(sparkConfigKey, v)
+ }
+ // unset config so that we do not need to reset configs for every previous stage
+ session.conf.unset(k)
+ }
+ }
+
+ plan
+ }
+
+ /**
+ * The current formula depends on AQE internals as of Spark 3.1.1; it may not hold in future versions.
+ */
+ private def isFinalStage(plan: SparkPlan): Boolean = {
+ var shuffleNum = 0
+ var broadcastNum = 0
+ var reusedNum = 0
+ var queryStageNum = 0
+
+ def collectNumber(p: SparkPlan): SparkPlan = {
+ p transform {
+ case shuffle: ShuffleExchangeLike =>
+ shuffleNum += 1
+ shuffle
+
+ case broadcast: BroadcastExchangeLike =>
+ broadcastNum += 1
+ broadcast
+
+ case reusedExchangeExec: ReusedExchangeExec =>
+ reusedNum += 1
+ reusedExchangeExec
+
+ // a query stage is a leaf node, so we need to transform it manually
+ // compatible with Spark 3.5:
+ // SPARK-42101: table cache is an independent query stage, so we do not need to include it.
+ case queryStage: QueryStageExec if queryStage.nodeName != "TableCacheQueryStage" =>
+ queryStageNum += 1
+ collectNumber(queryStage.plan)
+ queryStage
+ }
+ }
+ collectNumber(plan)
+
+ if (shuffleNum == 0) {
+ // we do not care about broadcast stages here since they won't change the partition number.
+ true
+ } else if (shuffleNum + broadcastNum + reusedNum == queryStageNum) {
+ true
+ } else {
+ false
+ }
+ }
+}
+object FinalStageConfigIsolation {
+ final val SQL_PREFIX = "spark.sql."
+ final val FINAL_STAGE_CONFIG_PREFIX = "spark.sql.finalStage."
+ final val PREVIOUS_STAGE_CONFIG_PREFIX = "spark.sql.previousStage."
+ final val INTERNAL_UNSET_CONFIG_TAG = "__INTERNAL_UNSET_CONFIG_TAG__"
+
+ def getPreviousStageConfigKey(configKey: String): Option[String] = {
+ if (configKey.startsWith(SQL_PREFIX)) {
+ Some(s"$PREVIOUS_STAGE_CONFIG_PREFIX${configKey.substring(SQL_PREFIX.length)}")
+ } else {
+ None
+ }
+ }
+}
+
+case class FinalStageConfigIsolationCleanRule(session: SparkSession) extends Rule[LogicalPlan] {
+ import FinalStageConfigIsolation._
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case set @ SetCommand(Some((k, Some(_)))) if k.startsWith(SQL_PREFIX) =>
+ checkAndUnsetPreviousStageConfig(k)
+ set
+
+ case reset @ ResetCommand(Some(k)) if k.startsWith(SQL_PREFIX) =>
+ checkAndUnsetPreviousStageConfig(k)
+ reset
+
+ case other => other
+ }
+
+ private def checkAndUnsetPreviousStageConfig(configKey: String): Unit = {
+ getPreviousStageConfigKey(configKey).foreach { previousStageConfigKey =>
+ if (session.sessionState.conf.contains(previousStageConfigKey)) {
+ logInfo(s"For previous stage: unset $previousStageConfigKey")
+ session.conf.unset(previousStageConfigKey)
+ }
+ }
+ }
+}
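A sketch of final-stage config isolation in practice, using the key-mapping convention documented above (the 512m value is illustrative):
```
// Enable isolation, then give only the final stage a larger advisory partition
// size so a trailing repartition merges small files into bigger ones.
spark.conf.set("spark.sql.optimizer.finalStageConfigIsolation.enabled", "true")
spark.conf.set("spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes", "512m")
```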
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala
new file mode 100644
index 000000000..6f45dae12
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf._
+
+object KyuubiSQLConf {
+
+ val INSERT_REPARTITION_BEFORE_WRITE =
+ buildConf("spark.sql.optimizer.insertRepartitionBeforeWrite.enabled")
+ .doc("Add repartition node at the top of query plan. An approach of merging small files.")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val INSERT_REPARTITION_NUM =
+ buildConf("spark.sql.optimizer.insertRepartitionNum")
+ .doc(s"The partition number if ${INSERT_REPARTITION_BEFORE_WRITE.key} is enabled. " +
+ s"If AQE is disabled, the default value is ${SQLConf.SHUFFLE_PARTITIONS.key}. " +
+ "If AQE is enabled, the default value is none that means depend on AQE. " +
+ "This config is used for Spark 3.1 only.")
+ .version("1.2.0")
+ .intConf
+ .createOptional
+
+ val DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM =
+ buildConf("spark.sql.optimizer.dynamicPartitionInsertionRepartitionNum")
+ .doc(s"The partition number of each dynamic partition if " +
+ s"${INSERT_REPARTITION_BEFORE_WRITE.key} is enabled. " +
+ "We will repartition by dynamic partition columns to reduce the small file but that " +
+ "can cause data skew. This config is to extend the partition of dynamic " +
+ "partition column to avoid skew but may generate some small files.")
+ .version("1.2.0")
+ .intConf
+ .createWithDefault(100)
+
+ val FORCE_SHUFFLE_BEFORE_JOIN =
+ buildConf("spark.sql.optimizer.forceShuffleBeforeJoin.enabled")
+ .doc("Ensure shuffle node exists before shuffled join (shj and smj) to make AQE " +
+ "`OptimizeSkewedJoin` works (complex scenario join, multi table join).")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val FINAL_STAGE_CONFIG_ISOLATION =
+ buildConf("spark.sql.optimizer.finalStageConfigIsolation.enabled")
+ .doc("If true, the final stage support use different config with previous stage. " +
+ "The prefix of final stage config key should be `spark.sql.finalStage.`." +
+ "For example, the raw spark config: `spark.sql.adaptive.advisoryPartitionSizeInBytes`, " +
+ "then the final stage config should be: " +
+ "`spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`.")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val SQL_CLASSIFICATION = "spark.sql.analyzer.classification"
+ val SQL_CLASSIFICATION_ENABLED =
+ buildConf("spark.sql.analyzer.classification.enabled")
+ .doc("When true, allows Kyuubi engine to judge this SQL's classification " +
+ s"and set `$SQL_CLASSIFICATION` back into sessionConf. " +
+ "Through this configuration item, Spark can optimizing configuration dynamic")
+ .version("1.4.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val INSERT_ZORDER_BEFORE_WRITING =
+ buildConf("spark.sql.optimizer.insertZorderBeforeWriting.enabled")
+ .doc("When true, we will follow target table properties to insert zorder or not. " +
+ "The key properties are: 1) kyuubi.zorder.enabled; if this property is true, we will " +
+ "insert zorder before writing data. 2) kyuubi.zorder.cols; string split by comma, we " +
+ "will zorder by these cols.")
+ .version("1.4.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ZORDER_GLOBAL_SORT_ENABLED =
+ buildConf("spark.sql.optimizer.zorderGlobalSort.enabled")
+ .doc("When true, we do a global sort using zorder. Note that, it can cause data skew " +
+ "issue if the zorder columns have less cardinality. When false, we only do local sort " +
+ "using zorder.")
+ .version("1.4.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val REBALANCE_BEFORE_ZORDER =
+ buildConf("spark.sql.optimizer.rebalanceBeforeZorder.enabled")
+ .doc("When true, we do a rebalance before zorder in case data skew. " +
+ "Note that, if the insertion is dynamic partition we will use the partition " +
+ "columns to rebalance. Note that, this config only affects with Spark 3.3.x")
+ .version("1.6.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val REBALANCE_ZORDER_COLUMNS_ENABLED =
+ buildConf("spark.sql.optimizer.rebalanceZorderColumns.enabled")
+ .doc(s"When true and ${REBALANCE_BEFORE_ZORDER.key} is true, we do rebalance before " +
+ s"Z-Order. If it's dynamic partition insert, the rebalance expression will include " +
+ s"both partition columns and Z-Order columns. Note that, this config only " +
+ s"affects with Spark 3.3.x")
+ .version("1.6.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val TWO_PHASE_REBALANCE_BEFORE_ZORDER =
+ buildConf("spark.sql.optimizer.twoPhaseRebalanceBeforeZorder.enabled")
+ .doc(s"When true and ${REBALANCE_BEFORE_ZORDER.key} is true, we do two phase rebalance " +
+ s"before Z-Order for the dynamic partition write. The first phase rebalance using " +
+ s"dynamic partition column; The second phase rebalance using dynamic partition column + " +
+ s"Z-Order columns. Note that, this config only affects with Spark 3.3.x")
+ .version("1.6.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val ZORDER_USING_ORIGINAL_ORDERING_ENABLED =
+ buildConf("spark.sql.optimizer.zorderUsingOriginalOrdering.enabled")
+ .doc(s"When true and ${REBALANCE_BEFORE_ZORDER.key} is true, we do sort by " +
+ s"the original ordering i.e. lexicographical order. Note that, this config only " +
+ s"affects with Spark 3.3.x")
+ .version("1.6.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val WATCHDOG_MAX_PARTITIONS =
+ buildConf("spark.sql.watchdog.maxPartitions")
+ .doc("Set the max partition number when spark scans a data source. " +
+ "Enable maxPartitions Strategy by specifying this configuration. " +
+ "Add maxPartitions Strategy to avoid scan excessive partitions " +
+ "on partitioned table, it's optional that works with defined")
+ .version("1.4.0")
+ .intConf
+ .createOptional
+
+ val WATCHDOG_MAX_FILE_SIZE =
+ buildConf("spark.sql.watchdog.maxFileSize")
+ .doc("Set the maximum size in bytes of files when spark scans a data source. " +
+ "Enable maxFileSize Strategy by specifying this configuration. " +
+ "Add maxFileSize Strategy to avoid scan excessive size of files," +
+ " it's optional that works with defined")
+ .version("1.8.0")
+ .bytesConf(ByteUnit.BYTE)
+ .createOptional
+
+ val WATCHDOG_FORCED_MAXOUTPUTROWS =
+ buildConf("spark.sql.watchdog.forcedMaxOutputRows")
+ .doc("Add ForcedMaxOutputRows rule to avoid huge output rows of non-limit query " +
+ "unexpectedly, it's optional that works with defined")
+ .version("1.4.0")
+ .intConf
+ .createOptional
+
+ val DROP_IGNORE_NONEXISTENT =
+ buildConf("spark.sql.optimizer.dropIgnoreNonExistent")
+ .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " +
+ "a non-existent database/table/view/function/partition")
+ .version("1.5.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val INFER_REBALANCE_AND_SORT_ORDERS =
+ buildConf("spark.sql.optimizer.inferRebalanceAndSortOrders.enabled")
+ .doc("When ture, infer columns for rebalance and sort orders from original query, " +
+ "e.g. the join keys from join. It can avoid compression ratio regression.")
+ .version("1.7.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val INFER_REBALANCE_AND_SORT_ORDERS_MAX_COLUMNS =
+ buildConf("spark.sql.optimizer.inferRebalanceAndSortOrdersMaxColumns")
+ .doc("The max columns of inferred columns.")
+ .version("1.7.0")
+ .intConf
+ .checkValue(_ > 0, "must be positive number")
+ .createWithDefault(3)
+
+ val INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE =
+ buildConf("spark.sql.optimizer.insertRepartitionBeforeWriteIfNoShuffle.enabled")
+ .doc("When true, add repartition even if the original plan does not have shuffle.")
+ .version("1.7.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY =
+ buildConf("spark.sql.optimizer.finalStageConfigIsolationWriteOnly.enabled")
+ .doc("When true, only enable final stage isolation for writing.")
+ .version("1.7.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED =
+ buildConf("spark.sql.finalWriteStage.eagerlyKillExecutors.enabled")
+ .doc("When true, eagerly kill redundant executors before running final write stage.")
+ .version("1.8.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_KILL_ALL =
+ buildConf("spark.sql.finalWriteStage.eagerlyKillExecutors.killAll")
+ .doc("When true, eagerly kill all executors before running final write stage. " +
+ "Mainly for test.")
+ .version("1.8.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val FINAL_WRITE_STAGE_SKIP_KILLING_EXECUTORS_FOR_TABLE_CACHE =
+ buildConf("spark.sql.finalWriteStage.skipKillingExecutorsForTableCache")
+ .doc("When true, skip killing executors if the plan has table caches.")
+ .version("1.8.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val FINAL_WRITE_STAGE_PARTITION_FACTOR =
+ buildConf("spark.sql.finalWriteStage.retainExecutorsFactor")
+ .doc("If the target executors * factor < active executors, and " +
+ "target executors * factor > min executors, then kill redundant executors.")
+ .version("1.8.0")
+ .doubleConf
+ .checkValue(_ >= 1, "must be bigger than or equal to 1")
+ .createWithDefault(1.2)
+
+ val FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED =
+ buildConf("spark.sql.finalWriteStage.resourceIsolation.enabled")
+      .doc(
+        "When true, isolate the final write stage's resources using a custom RDD resource profile.")
+ .version("1.8.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val FINAL_WRITE_STAGE_EXECUTOR_CORES =
+ buildConf("spark.sql.finalWriteStage.executorCores")
+ .doc("Specify the executor core request for final write stage. " +
+ "It would be passed to the RDD resource profile.")
+ .version("1.8.0")
+ .intConf
+ .createOptional
+
+ val FINAL_WRITE_STAGE_EXECUTOR_MEMORY =
+ buildConf("spark.sql.finalWriteStage.executorMemory")
+ .doc("Specify the executor on heap memory request for final write stage. " +
+ "It would be passed to the RDD resource profile.")
+ .version("1.8.0")
+ .stringConf
+ .createOptional
+
+ val FINAL_WRITE_STAGE_EXECUTOR_MEMORY_OVERHEAD =
+ buildConf("spark.sql.finalWriteStage.executorMemoryOverhead")
+ .doc("Specify the executor memory overhead request for final write stage. " +
+ "It would be passed to the RDD resource profile.")
+ .version("1.8.0")
+ .stringConf
+ .createOptional
+
+ val FINAL_WRITE_STAGE_EXECUTOR_OFF_HEAP_MEMORY =
+ buildConf("spark.sql.finalWriteStage.executorOffHeapMemory")
+ .doc("Specify the executor off heap memory request for final write stage. " +
+ "It would be passed to the RDD resource profile.")
+ .version("1.8.0")
+ .stringConf
+ .createOptional
+}
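
The entries above are ordinary SQL configs, so they can be toggled per session and read back through `SQLConf`. A small sketch (assuming a live SparkSession named `spark` with the extension jar on the classpath):

```
// Sketch only: set a watchdog limit for the current session and read it back.
import org.apache.kyuubi.sql.KyuubiSQLConf

spark.sql(s"SET ${KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key}=2000")
val maxPartitions: Option[Int] =
  spark.sessionState.conf.getConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS)
// Some(2000) once set; None means the maxPartitions strategy stays disabled
```
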
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLExtensionException.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLExtensionException.scala
new file mode 100644
index 000000000..88c5a988f
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLExtensionException.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import java.sql.SQLException
+
+class KyuubiSQLExtensionException(reason: String, cause: Throwable)
+ extends SQLException(reason, cause) {
+
+ def this(reason: String) = {
+ this(reason, null)
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLAstBuilder.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLAstBuilder.scala
new file mode 100644
index 000000000..cc00bf88e
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLAstBuilder.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.mutable.ListBuffer
+
+import org.antlr.v4.runtime.ParserRuleContext
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.ParseTree
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Sort}
+
+import org.apache.kyuubi.sql.KyuubiSparkSQLParser._
+import org.apache.kyuubi.sql.zorder.{OptimizeZorderStatement, Zorder}
+
+class KyuubiSparkSQLAstBuilder extends KyuubiSparkSQLBaseVisitor[AnyRef] with SQLConfHelper {
+
+ def buildOptimizeStatement(
+ unparsedPredicateOptimize: UnparsedPredicateOptimize,
+ parseExpression: String => Expression): LogicalPlan = {
+
+ val UnparsedPredicateOptimize(tableIdent, tablePredicate, orderExpr) =
+ unparsedPredicateOptimize
+
+ val predicate = tablePredicate.map(parseExpression)
+ verifyPartitionPredicates(predicate)
+ val table = UnresolvedRelation(tableIdent)
+ val tableWithFilter = predicate match {
+ case Some(expr) => Filter(expr, table)
+ case None => table
+ }
+ val query =
+ Sort(
+ SortOrder(orderExpr, Ascending, NullsLast, Seq.empty) :: Nil,
+ conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED),
+ Project(Seq(UnresolvedStar(None)), tableWithFilter))
+ OptimizeZorderStatement(tableIdent, query)
+ }
+
+ private def verifyPartitionPredicates(predicates: Option[Expression]): Unit = {
+ predicates.foreach {
+ case p if !isLikelySelective(p) =>
+ throw new KyuubiSQLExtensionException(s"unsupported partition predicates: ${p.sql}")
+ case _ =>
+ }
+ }
+
+ /**
+   * Forked from Apache Spark's org.apache.spark.sql.catalyst.expressions.PredicateHelper.
+   * `PredicateHelper.isLikelySelective()` is available since Spark 3.3; it is forked here
+   * for Spark versions lower than 3.3.
+ *
+ * Returns whether an expression is likely to be selective
+ */
+ private def isLikelySelective(e: Expression): Boolean = e match {
+ case Not(expr) => isLikelySelective(expr)
+ case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
+ case Or(l, r) => isLikelySelective(l) && isLikelySelective(r)
+ case _: StringRegexExpression => true
+ case _: BinaryComparison => true
+ case _: In | _: InSet => true
+ case _: StringPredicate => true
+ case BinaryPredicate(_) => true
+ case _: MultiLikeBase => true
+ case _ => false
+ }
+
+ private object BinaryPredicate {
+ def unapply(expr: Expression): Option[Expression] = expr match {
+ case _: Contains => Option(expr)
+ case _: StartsWith => Option(expr)
+ case _: EndsWith => Option(expr)
+ case _ => None
+ }
+ }
+
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+ * visitor and only takes care of typing (We assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ protected def multiPart(ctx: ParserRuleContext): Seq[String] = typedVisit(ctx)
+
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = {
+ visit(ctx.statement()).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitOptimizeZorder(
+ ctx: OptimizeZorderContext): UnparsedPredicateOptimize = withOrigin(ctx) {
+ val tableIdent = multiPart(ctx.multipartIdentifier())
+
+ val predicate = Option(ctx.whereClause())
+ .map(_.partitionPredicate)
+ .map(extractRawText(_))
+
+ val zorderCols = ctx.zorderClause().order.asScala
+ .map(visitMultipartIdentifier)
+ .map(UnresolvedAttribute(_))
+ .toSeq
+
+ val orderExpr =
+ if (zorderCols.length == 1) {
+ zorderCols.head
+ } else {
+ Zorder(zorderCols)
+ }
+ UnparsedPredicateOptimize(tableIdent, predicate, orderExpr)
+ }
+
+ override def visitPassThrough(ctx: PassThroughContext): LogicalPlan = null
+
+ override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
+ withOrigin(ctx) {
+ ctx.parts.asScala.map(_.getText).toSeq
+ }
+
+ override def visitZorderClause(ctx: ZorderClauseContext): Seq[UnresolvedAttribute] =
+ withOrigin(ctx) {
+ val res = ListBuffer[UnresolvedAttribute]()
+ ctx.multipartIdentifier().forEach { identifier =>
+ res += UnresolvedAttribute(identifier.parts.asScala.map(_.getText).toSeq)
+ }
+ res.toSeq
+ }
+
+ private def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ private def extractRawText(exprContext: ParserRuleContext): String = {
+ // Extract the raw expression which will be parsed later
+ exprContext.getStart.getInputStream.getText(new Interval(
+ exprContext.getStart.getStartIndex,
+ exprContext.getStop.getStopIndex))
+ }
+}
+
+/**
+ * A logical plan that contains an unparsed expression, which will be parsed by Spark later.
+ */
+trait UnparsedExpressionLogicalPlan extends LogicalPlan {
+ override def output: Seq[Attribute] = throw new UnsupportedOperationException()
+
+ override def children: Seq[LogicalPlan] = throw new UnsupportedOperationException()
+
+ protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[LogicalPlan]): LogicalPlan =
+ throw new UnsupportedOperationException()
+}
+
+case class UnparsedPredicateOptimize(
+ tableIdent: Seq[String],
+ tablePredicate: Option[String],
+ orderExpr: Expression) extends UnparsedExpressionLogicalPlan {}
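
The builder above backs Kyuubi's `OPTIMIZE ... ZORDER BY` statement. A usage sketch (hypothetical table and columns; assumes the extension is registered): the WHERE predicate must be likely-selective per the forked check, otherwise a `KyuubiSQLExtensionException` is thrown.

```
// Sketch only: a selective predicate (a BinaryComparison) passes the check,
// so the statement compiles into a zorder sort over c1 and c2.
spark.sql("OPTIMIZE db.events WHERE p_date = '2023-10-09' ZORDER BY c1, c2")
// A non-selective predicate, e.g. a bare boolean column, would be rejected
// with "unsupported partition predicates".
```
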
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLCommonExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLCommonExtension.scala
new file mode 100644
index 000000000..f39ad3cc3
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLCommonExtension.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSessionExtensions
+
+import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource33, InsertZorderBeforeWritingHive33, ResolveZorder}
+
+class KyuubiSparkSQLCommonExtension extends (SparkSessionExtensions => Unit) {
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions)
+ }
+}
+
+object KyuubiSparkSQLCommonExtension {
+ def injectCommonExtensions(extensions: SparkSessionExtensions): Unit = {
+ // inject zorder parser and related rules
+ extensions.injectParser { case (_, parser) => new SparkKyuubiSparkSQLParser(parser) }
+ extensions.injectResolutionRule(ResolveZorder)
+
+ // Note that:
+ // InsertZorderBeforeWritingDatasource and InsertZorderBeforeWritingHive
+ // should be applied before
+ // RepartitionBeforeWriting and RebalanceBeforeWriting
+ // because we can only apply one of them (i.e. Global Sort or Repartition/Rebalance)
+ extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingDatasource33)
+ extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingHive33)
+ extensions.injectPostHocResolutionRule(FinalStageConfigIsolationCleanRule)
+
+ extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin)
+
+ extensions.injectQueryStagePrepRule(FinalStageConfigIsolation(_))
+ }
+}
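
Besides setting `spark.sql.extensions`, the common extension can be wired up programmatically when building a session, since it extends `SparkSessionExtensions => Unit`. A sketch (assuming a local test session):

```
// Sketch only: register the common extension while building the session.
import org.apache.spark.sql.SparkSession
import org.apache.kyuubi.sql.KyuubiSparkSQLCommonExtension

val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions(new KyuubiSparkSQLCommonExtension)
  .getOrCreate()
```
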
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
new file mode 100644
index 000000000..792315d89
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions}
+
+import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxScanStrategy}
+
+// scalastyle:off line.size.limit
+/**
+ * Built on the Spark SQL extension framework, this extension can be enabled in two steps:
+ * 1. move this jar into $SPARK_HOME/jars
+ * 2. add config into `spark-defaults.conf`: `spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension`
+ */
+// scalastyle:on line.size.limit
+class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions)
+
+ extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource)
+ extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive)
+ extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
+
+ // watchdog extension
+ extensions.injectOptimizerRule(ForcedMaxOutputRowsRule)
+ extensions.injectPlannerStrategy(MaxScanStrategy)
+
+ extensions.injectQueryStagePrepRule(FinalStageResourceManager(_))
+ extensions.injectQueryStagePrepRule(InjectCustomResourceProfile)
+ }
+}
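
The same registration also works through the session config rather than `spark-defaults.conf`. A sketch:

```
// Sketch only: enable the full extension through the session config.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  .getOrCreate()
```
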
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLParser.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLParser.scala
new file mode 100644
index 000000000..c4418c33c
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLParser.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface, PostProcessor}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.{DataType, StructType}
+
+abstract class KyuubiSparkSQLParserBase extends ParserInterface with SQLConfHelper {
+ def delegate: ParserInterface
+ def astBuilder: KyuubiSparkSQLAstBuilder
+
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visit(parser.singleStatement()) match {
+ case optimize: UnparsedPredicateOptimize =>
+ astBuilder.buildOptimizeStatement(optimize, delegate.parseExpression)
+ case plan: LogicalPlan => plan
+ case _ => delegate.parsePlan(sqlText)
+ }
+ }
+
+ protected def parse[T](command: String)(toResult: KyuubiSparkSQLParser => T): T = {
+ val lexer = new KyuubiSparkSQLLexer(
+ new UpperCaseCharStream(CharStreams.fromString(command)))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new KyuubiSparkSQLParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
+ } catch {
+ case _: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.seek(0) // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
+ }
+ } catch {
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
+ }
+ }
+
+ override def parseExpression(sqlText: String): Expression = {
+ delegate.parseExpression(sqlText)
+ }
+
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = {
+ delegate.parseTableIdentifier(sqlText)
+ }
+
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ delegate.parseFunctionIdentifier(sqlText)
+ }
+
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ delegate.parseMultipartIdentifier(sqlText)
+ }
+
+ override def parseTableSchema(sqlText: String): StructType = {
+ delegate.parseTableSchema(sqlText)
+ }
+
+ override def parseDataType(sqlText: String): DataType = {
+ delegate.parseDataType(sqlText)
+ }
+
+ /**
+   * This function was introduced in Spark 3.3; for more details, please see
+ * https://github.com/apache/spark/pull/34543
+ */
+ override def parseQuery(sqlText: String): LogicalPlan = {
+ delegate.parseQuery(sqlText)
+ }
+}
+
+class SparkKyuubiSparkSQLParser(
+ override val delegate: ParserInterface)
+ extends KyuubiSparkSQLParserBase {
+ def astBuilder: KyuubiSparkSQLAstBuilder = new KyuubiSparkSQLAstBuilder
+}
+
+/* Copied from Apache Spark to avoid a dependency on Spark internals */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
+ override def consume(): Unit = wrapped.consume()
+ override def getSourceName(): String = wrapped.getSourceName
+ override def index(): Int = wrapped.index
+ override def mark(): Int = wrapped.mark
+ override def release(marker: Int): Unit = wrapped.release(marker)
+ override def seek(where: Int): Unit = wrapped.seek(where)
+ override def size(): Int = wrapped.size
+
+ override def getText(interval: Interval): String = wrapped.getText(interval)
+
+ // scalastyle:off
+ override def LA(i: Int): Int = {
+ val la = wrapped.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+ // scalastyle:on
+}
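
The parser is a thin front end: only statements matched by the Kyuubi grammar are handled locally, everything else falls through to Spark's delegate parser. A behavior sketch (assuming the extension is installed and `t` is a hypothetical table):

```
// Sketch only: both calls go through the injected parser.
val parser = spark.sessionState.sqlParser
parser.parsePlan("SELECT 1")                    // delegated to Spark
parser.parsePlan("OPTIMIZE t ZORDER BY c1, c2") // matched by the Kyuubi grammar
```
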
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RebalanceBeforeWriting.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RebalanceBeforeWriting.scala
new file mode 100644
index 000000000..3cbacdd2f
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RebalanceBeforeWriting.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical._
+
+trait RepartitionBuilderWithRebalance extends RepartitionBuilder {
+ override def buildRepartition(
+ dynamicPartitionColumns: Seq[Attribute],
+ query: LogicalPlan): LogicalPlan = {
+ if (!conf.getConf(KyuubiSQLConf.INFER_REBALANCE_AND_SORT_ORDERS) ||
+ dynamicPartitionColumns.nonEmpty) {
+ RebalancePartitions(dynamicPartitionColumns, query)
+ } else {
+ val maxColumns = conf.getConf(KyuubiSQLConf.INFER_REBALANCE_AND_SORT_ORDERS_MAX_COLUMNS)
+ val inferred = InferRebalanceAndSortOrders.infer(query)
+ if (inferred.isDefined) {
+ val (partitioning, ordering) = inferred.get
+ val rebalance = RebalancePartitions(partitioning.take(maxColumns), query)
+ if (ordering.nonEmpty) {
+ val sortOrders = ordering.take(maxColumns).map(o => SortOrder(o, Ascending))
+ Sort(sortOrders, false, rebalance)
+ } else {
+ rebalance
+ }
+ } else {
+ RebalancePartitions(dynamicPartitionColumns, query)
+ }
+ }
+ }
+
+ override def canInsertRepartitionByExpression(plan: LogicalPlan): Boolean = {
+ super.canInsertRepartitionByExpression(plan) && {
+ plan match {
+ case _: RebalancePartitions => false
+ case _ => true
+ }
+ }
+ }
+}
+
+/**
+ * For a datasource table, two commands can write data to the table:
+ * 1. InsertIntoHadoopFsRelationCommand
+ * 2. CreateDataSourceTableAsSelectCommand
+ * This rule adds a RebalancePartitions node between the write command and the query.
+ */
+case class RebalanceBeforeWritingDatasource(session: SparkSession)
+ extends RepartitionBeforeWritingDatasourceBase
+ with RepartitionBuilderWithRebalance {}
+
+/**
+ * For a Hive table, two commands can write data to the table:
+ * 1. InsertIntoHiveTable
+ * 2. CreateHiveTableAsSelectCommand
+ * This rule adds a RebalancePartitions node between the write command and the query.
+ */
+case class RebalanceBeforeWritingHive(session: SparkSession)
+ extends RepartitionBeforeWritingHiveBase
+ with RepartitionBuilderWithRebalance {}
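
The effect of the rebalance rules shows up in the optimized logical plan. A sketch (hypothetical tables; AQE and the rule enabled):

```
// Sketch only: this dynamic partition insert gains a RebalancePartitions
// node keyed by the dynamic partition column p_date.
spark.sql("SET spark.sql.adaptive.enabled=true")
spark.sql(
  "INSERT OVERWRITE TABLE sink PARTITION (p_date) " +
    "SELECT key, count(*) AS cnt, p_date FROM src GROUP BY key, p_date")
```
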
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala
new file mode 100644
index 000000000..3ebb9740f
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+import org.apache.spark.sql.internal.StaticSQLConf
+
+trait RepartitionBuilder extends Rule[LogicalPlan] with RepartitionBeforeWriteHelper {
+ def buildRepartition(
+ dynamicPartitionColumns: Seq[Attribute],
+ query: LogicalPlan): LogicalPlan
+}
+
+/**
+ * For a datasource table, two commands can write data to the table:
+ * 1. InsertIntoHadoopFsRelationCommand
+ * 2. CreateDataSourceTableAsSelectCommand
+ * This rule adds a repartition node between the write command and the query.
+ */
+abstract class RepartitionBeforeWritingDatasourceBase extends RepartitionBuilder {
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE)) {
+ addRepartition(plan)
+ } else {
+ plan
+ }
+ }
+
+ private def addRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+ case i @ InsertIntoHadoopFsRelationCommand(_, sp, _, pc, bucket, _, _, query, _, _, _, _)
+ if query.resolved && bucket.isEmpty && canInsertRepartitionByExpression(query) =>
+ val dynamicPartitionColumns = pc.filterNot(attr => sp.contains(attr.name))
+ i.copy(query = buildRepartition(dynamicPartitionColumns, query))
+
+ case u @ Union(children, _, _) =>
+ u.copy(children = children.map(addRepartition))
+
+ case _ => plan
+ }
+}
+
+/**
+ * For a Hive table, two commands can write data to the table:
+ * 1. InsertIntoHiveTable
+ * 2. CreateHiveTableAsSelectCommand
+ * This rule adds a repartition node between the write command and the query.
+ */
+abstract class RepartitionBeforeWritingHiveBase extends RepartitionBuilder {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" &&
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE)) {
+ addRepartition(plan)
+ } else {
+ plan
+ }
+ }
+
+ def addRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+ case i @ InsertIntoHiveTable(table, partition, query, _, _, _, _, _, _, _, _)
+ if query.resolved && table.bucketSpec.isEmpty && canInsertRepartitionByExpression(query) =>
+ val dynamicPartitionColumns = partition.filter(_._2.isEmpty).keys
+ .flatMap(name => query.output.find(_.name == name)).toSeq
+ i.copy(query = buildRepartition(dynamicPartitionColumns, query))
+
+ case u @ Union(children, _, _) =>
+ u.copy(children = children.map(addRepartition))
+
+ case _ => plan
+ }
+}
+
+trait RepartitionBeforeWriteHelper extends Rule[LogicalPlan] {
+ private def hasBenefit(plan: LogicalPlan): Boolean = {
+ def probablyHasShuffle: Boolean = plan.find {
+ case _: Join => true
+ case _: Aggregate => true
+ case _: Distinct => true
+ case _: Deduplicate => true
+ case _: Window => true
+ case s: Sort if s.global => true
+ case _: RepartitionOperation => true
+ case _: GlobalLimit => true
+ case _ => false
+ }.isDefined
+
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE) || probablyHasShuffle
+ }
+
+ def canInsertRepartitionByExpression(plan: LogicalPlan): Boolean = {
+ def canInsert(p: LogicalPlan): Boolean = p match {
+ case Project(_, child) => canInsert(child)
+ case SubqueryAlias(_, child) => canInsert(child)
+ case Limit(_, _) => false
+ case _: Sort => false
+ case _: RepartitionByExpression => false
+ case _: Repartition => false
+ case _ => true
+ }
+
+    // 1. make sure AQE is enabled, otherwise there is no point in adding a shuffle
+    // 2. make sure it does not break the semantics of the original plan
+    // 3. try to avoid adding a shuffle if it has a potential performance regression
+ conf.adaptiveExecutionEnabled && canInsert(plan) && hasBenefit(plan)
+ }
+}
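
A decision sketch for `canInsertRepartitionByExpression` (hypothetical tables): the rule fires only when AQE is on, the query tail is not already sorted, limited, or repartitioned, and the plan likely contains a shuffle.

```
// Sketch only.
spark.sql("SET spark.sql.adaptive.enabled=true")
// qualifies: the aggregate implies a shuffle and the query tail is plain
spark.sql("INSERT INTO t1 SELECT k, count(*) FROM src GROUP BY k")
// does not qualify: a trailing global sort would be undone by a repartition
spark.sql("INSERT INTO t1 SELECT k, v FROM src ORDER BY k")
```
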
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/WriteUtils.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/WriteUtils.scala
new file mode 100644
index 000000000..89dd83194
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/WriteUtils.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
+
+object WriteUtils {
+ def isWrite(session: SparkSession, plan: SparkPlan): Boolean = {
+ plan match {
+ case _: DataWritingCommandExec => true
+ case _: V2TableWriteExec => true
+ case u: UnionExec if u.children.nonEmpty => u.children.forall(isWrite(session, _))
+ case _ => false
+ }
+ }
+}
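
Other rules in this module call `WriteUtils.isWrite` to restrict themselves to write plans, e.g. for write-only final stage isolation. A sketch (assuming a hypothetical DataFrame `df`):

```
// Sketch only: check whether an executed plan is a write plan.
import org.apache.kyuubi.sql.WriteUtils

val isWritePlan = WriteUtils.isWrite(spark, df.queryExecution.executedPlan)
```
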
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala
new file mode 100644
index 000000000..4f897d1b6
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.watchdog
+
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.command.DataWritingCommand
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+/*
+ * Add the ForcedMaxOutputRows rule to limit output rows,
+ * avoiding unexpectedly huge result sets from non-limit queries.
+ * It mainly applies to cases such as:
+ *
+ * case 1:
+ * {{{
+ * SELECT [c1, c2, ...]
+ * }}}
+ *
+ * case 2:
+ * {{{
+ * WITH CTE AS (
+ * ...)
+ * SELECT [c1, c2, ...] FROM CTE ...
+ * }}}
+ *
+ * This logical rule adds a GlobalLimit node above the root project.
+ */
+trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] {
+
+ protected def isChildAggregate(a: Aggregate): Boolean
+
+ protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
+ case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
+ case agg: Aggregate => !isChildAggregate(agg)
+ case _: RepartitionByExpression => true
+ case _: Distinct => true
+ case _: Filter => true
+ case _: Project => true
+ case Limit(_, _) => true
+ case _: Sort => true
+    case Union(children, _, _) =>
+      !children.exists(_.isInstanceOf[DataWritingCommand])
+ case _: MultiInstanceRelation => true
+ case _: Join => true
+ case _ => false
+ }
+
+ protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
+ maxOutputRowsOpt match {
+ case Some(forcedMaxOutputRows) => canInsertLimitInner(p) &&
+ !p.maxRows.exists(_ <= forcedMaxOutputRows)
+ case None => false
+ }
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS)
+ plan match {
+ case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) =>
+ Limit(
+ maxOutputRowsOpt.get,
+ plan)
+ case _ => plan
+ }
+ }
+}
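
The net effect, as a sketch (hypothetical table): once `spark.sql.watchdog.forcedMaxOutputRows` is set, a non-limited query is capped by an injected GlobalLimit, while an explicit smaller LIMIT is left untouched because `p.maxRows` already satisfies the bound.

```
// Sketch only.
spark.sql("SET spark.sql.watchdog.forcedMaxOutputRows=1000")
spark.sql("SELECT * FROM big_table").collect()          // at most 1000 rows
spark.sql("SELECT * FROM big_table LIMIT 10").collect() // still 10 rows
```
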
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala
new file mode 100644
index 000000000..a3d990b10
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.watchdog
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CommandResult, LogicalPlan, Union, WithCTE}
+import org.apache.spark.sql.execution.command.DataWritingCommand
+
+case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
+
+ override protected def isChildAggregate(a: Aggregate): Boolean = false
+
+ override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
+ case WithCTE(plan, _) => this.canInsertLimitInner(plan)
+ case plan: LogicalPlan => plan match {
+ case Union(children, _, _) => !children.exists {
+ case _: DataWritingCommand => true
+ case p: CommandResult if p.commandLogicalPlan.isInstanceOf[DataWritingCommand] => true
+ case _ => false
+ }
+ case _ => super.canInsertLimitInner(plan)
+ }
+ }
+
+ override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
+ p match {
+ case WithCTE(plan, _) => this.canInsertLimit(plan, maxOutputRowsOpt)
+ case _ => super.canInsertLimit(p, maxOutputRowsOpt)
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala
new file mode 100644
index 000000000..e44309192
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.watchdog
+
+import org.apache.kyuubi.sql.KyuubiSQLExtensionException
+
+final class MaxPartitionExceedException(
+ private val reason: String = "",
+ private val cause: Throwable = None.orNull)
+ extends KyuubiSQLExtensionException(reason, cause)
+
+final class MaxFileSizeExceedException(
+ private val reason: String = "",
+ private val cause: Throwable = None.orNull)
+ extends KyuubiSQLExtensionException(reason, cause)
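
Both watchdog exceptions extend `KyuubiSQLExtensionException`, itself a `java.sql.SQLException`, so clients can catch them uniformly. A handling sketch (hypothetical table):

```
// Sketch only: surface the watchdog's advice instead of failing opaquely.
import org.apache.kyuubi.sql.watchdog.{MaxFileSizeExceedException, MaxPartitionExceedException}

try {
  spark.sql("SELECT * FROM warehouse.events").collect()
} catch {
  case e @ (_: MaxPartitionExceedException | _: MaxFileSizeExceedException) =>
    println(s"query rejected by the watchdog: ${e.getMessage}")
}
```
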
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxScanStrategy.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxScanStrategy.scala
new file mode 100644
index 000000000..1ed55ebc2
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxScanStrategy.scala
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.watchdog
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.{PruneFileSourcePartitionHelper, SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
+import org.apache.spark.sql.catalyst.planning.ScanOperation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
+import org.apache.spark.sql.types.StructType
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+/**
+ * Add MaxScanStrategy to avoid scanning excessive partitions or files:
+ * 1. Check whether the scan exceeds maxPartitions for a partitioned table
+ * 2. Check whether the scan exceeds maxFileSize (calculated from Hive table and
+ *    partition statistics)
+ * This planner strategy runs after the logical optimizer.
+ * @param session the active Spark session
+ */
+case class MaxScanStrategy(session: SparkSession)
+ extends Strategy
+ with SQLConfHelper
+ with PruneFileSourcePartitionHelper {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+ val maxScanPartitionsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS)
+ val maxFileSizeOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE)
+ if (maxScanPartitionsOpt.isDefined || maxFileSizeOpt.isDefined) {
+ checkScan(plan, maxScanPartitionsOpt, maxFileSizeOpt)
+ }
+ Nil
+ }
+
+ private def checkScan(
+ plan: LogicalPlan,
+ maxScanPartitionsOpt: Option[Int],
+ maxFileSizeOpt: Option[Long]): Unit = {
+ plan match {
+ case ScanOperation(_, _, _, relation: HiveTableRelation) =>
+ if (relation.isPartitioned) {
+ relation.prunedPartitions match {
+ case Some(prunedPartitions) =>
+ if (maxScanPartitionsOpt.exists(_ < prunedPartitions.size)) {
+ throw new MaxPartitionExceedException(
+ s"""
+                  |SQL job scans ${prunedPartitions.size} hive partitions,
+                  |exceeding the hive scan maxPartition limit ${maxScanPartitionsOpt.get}.
+                  |You should optimize your SQL logic according to the partition structure
+                  |or shorten the query scope (e.g. by p_date). Details below:
+ |Table: ${relation.tableMeta.qualifiedName}
+ |Owner: ${relation.tableMeta.owner}
+ |Partition Structure: ${relation.partitionCols.map(_.name).mkString(", ")}
+ |""".stripMargin)
+ }
+ lazy val scanFileSize = prunedPartitions.flatMap(_.stats).map(_.sizeInBytes).sum
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw partTableMaxFileExceedError(
+ scanFileSize,
+ maxFileSizeOpt.get,
+ Some(relation.tableMeta),
+ prunedPartitions.flatMap(_.storage.locationUri).map(_.toString),
+ relation.partitionCols.map(_.name))
+ }
+ case _ =>
+ lazy val scanPartitions: Int = session
+ .sessionState.catalog.externalCatalog.listPartitionNames(
+ relation.tableMeta.database,
+ relation.tableMeta.identifier.table).size
+ if (maxScanPartitionsOpt.exists(_ < scanPartitions)) {
+ throw new MaxPartitionExceedException(
+ s"""
+                |Your SQL job scans a whole huge table without any partition filter.
+                |You should optimize your SQL logic according to the partition structure
+                |or shorten the query scope (e.g. by p_date). Details below:
+ |Table: ${relation.tableMeta.qualifiedName}
+ |Owner: ${relation.tableMeta.owner}
+ |Partition Structure: ${relation.partitionCols.map(_.name).mkString(", ")}
+ |""".stripMargin)
+ }
+
+ lazy val scanFileSize: BigInt =
+ relation.tableMeta.stats.map(_.sizeInBytes).getOrElse {
+ session
+ .sessionState.catalog.externalCatalog.listPartitions(
+ relation.tableMeta.database,
+ relation.tableMeta.identifier.table).flatMap(_.stats).map(_.sizeInBytes).sum
+ }
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw new MaxFileSizeExceedException(
+ s"""
+                |Your SQL job scans a whole huge table without any partition filter.
+                |You should optimize your SQL logic according to the partition structure
+                |or shorten the query scope (e.g. by p_date). Details below:
+ |Table: ${relation.tableMeta.qualifiedName}
+ |Owner: ${relation.tableMeta.owner}
+ |Partition Structure: ${relation.partitionCols.map(_.name).mkString(", ")}
+ |""".stripMargin)
+ }
+ }
+ } else {
+ lazy val scanFileSize = relation.tableMeta.stats.map(_.sizeInBytes).sum
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw nonPartTableMaxFileExceedError(
+ scanFileSize,
+ maxFileSizeOpt.get,
+ Some(relation.tableMeta))
+ }
+ }
+ case ScanOperation(
+ _,
+ _,
+ filters,
+ relation @ LogicalRelation(
+ fsRelation @ HadoopFsRelation(
+ fileIndex: InMemoryFileIndex,
+ partitionSchema,
+ _,
+ _,
+ _,
+ _),
+ _,
+ _,
+ _)) =>
+ if (fsRelation.partitionSchema.nonEmpty) {
+ val (partitionKeyFilters, dataFilter) =
+ getPartitionKeyFiltersAndDataFilters(
+ SparkSession.active,
+ relation,
+ partitionSchema,
+ filters,
+ relation.output)
+ val prunedPartitions = fileIndex.listFiles(
+ partitionKeyFilters.toSeq,
+ dataFilter)
+ if (maxScanPartitionsOpt.exists(_ < prunedPartitions.size)) {
+ throw maxPartitionExceedError(
+ prunedPartitions.size,
+ maxScanPartitionsOpt.get,
+ relation.catalogTable,
+ fileIndex.rootPaths,
+ fsRelation.partitionSchema)
+ }
+ lazy val scanFileSize = prunedPartitions.flatMap(_.files).map(_.getLen).sum
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw partTableMaxFileExceedError(
+ scanFileSize,
+ maxFileSizeOpt.get,
+ relation.catalogTable,
+ fileIndex.rootPaths.map(_.toString),
+ fsRelation.partitionSchema.map(_.name))
+ }
+ } else {
+ lazy val scanFileSize = fileIndex.sizeInBytes
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw nonPartTableMaxFileExceedError(
+ scanFileSize,
+ maxFileSizeOpt.get,
+ relation.catalogTable)
+ }
+ }
+ case ScanOperation(
+ _,
+ _,
+ filters,
+ logicalRelation @ LogicalRelation(
+ fsRelation @ HadoopFsRelation(
+ catalogFileIndex: CatalogFileIndex,
+ partitionSchema,
+ _,
+ _,
+ _,
+ _),
+ _,
+ _,
+ _)) =>
+ if (fsRelation.partitionSchema.nonEmpty) {
+ val (partitionKeyFilters, _) =
+ getPartitionKeyFiltersAndDataFilters(
+ SparkSession.active,
+ logicalRelation,
+ partitionSchema,
+ filters,
+ logicalRelation.output)
+
+ val fileIndex = catalogFileIndex.filterPartitions(
+ partitionKeyFilters.toSeq)
+
+ lazy val prunedPartitionSize = fileIndex.partitionSpec().partitions.size
+ if (maxScanPartitionsOpt.exists(_ < prunedPartitionSize)) {
+ throw maxPartitionExceedError(
+ prunedPartitionSize,
+ maxScanPartitionsOpt.get,
+ logicalRelation.catalogTable,
+ catalogFileIndex.rootPaths,
+ fsRelation.partitionSchema)
+ }
+
+ lazy val scanFileSize = fileIndex
+ .listFiles(Nil, Nil).flatMap(_.files).map(_.getLen).sum
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw partTableMaxFileExceedError(
+ scanFileSize,
+ maxFileSizeOpt.get,
+ logicalRelation.catalogTable,
+ catalogFileIndex.rootPaths.map(_.toString),
+ fsRelation.partitionSchema.map(_.name))
+ }
+ } else {
+ lazy val scanFileSize = catalogFileIndex.sizeInBytes
+ if (maxFileSizeOpt.exists(_ < scanFileSize)) {
+ throw nonPartTableMaxFileExceedError(
+ scanFileSize,
+ maxFileSizeOpt.get,
+ logicalRelation.catalogTable)
+ }
+ }
+ case _ =>
+ }
+ }
+
+ def maxPartitionExceedError(
+ prunedPartitionSize: Int,
+ maxPartitionSize: Int,
+ tableMeta: Option[CatalogTable],
+ rootPaths: Seq[Path],
+ partitionSchema: StructType): Throwable = {
+ val truncatedPaths =
+ if (rootPaths.length > 5) {
+ rootPaths.slice(0, 5).mkString(",") + """... """ + (rootPaths.length - 5) + " more paths"
+ } else {
+ rootPaths.mkString(",")
+ }
+
+ new MaxPartitionExceedException(
+ s"""
+         |SQL job scans $prunedPartitionSize data source partitions,
+         |exceeding the data source scan maxPartition limit $maxPartitionSize.
+         |You should optimize your SQL logic according to the partition structure
+         |or shorten the query scope (e.g. by p_date). Details below:
+ |Table: ${tableMeta.map(_.qualifiedName).getOrElse("")}
+ |Owner: ${tableMeta.map(_.owner).getOrElse("")}
+ |RootPaths: $truncatedPaths
+ |Partition Structure: ${partitionSchema.map(_.name).mkString(", ")}
+ |""".stripMargin)
+ }
+
+ private def partTableMaxFileExceedError(
+ scanFileSize: Number,
+ maxFileSize: Long,
+ tableMeta: Option[CatalogTable],
+ rootPaths: Seq[String],
+ partitions: Seq[String]): Throwable = {
+ val truncatedPaths =
+ if (rootPaths.length > 5) {
+ rootPaths.slice(0, 5).mkString(",") + "... " + (rootPaths.length - 5) + " more paths"
+ } else {
+ rootPaths.mkString(",")
+ }
+
+ new MaxFileSizeExceedException(
+ s"""
+ |SQL job scans $scanFileSize bytes of files,
+ |which exceeds the table scan limit maxFileSize=$maxFileSize.
+ |You should optimize your SQL logic according to the partition structure
+ |or narrow the query scope (e.g. by p_date). Details below:
+ |Table: ${tableMeta.map(_.qualifiedName).getOrElse("")}
+ |Owner: ${tableMeta.map(_.owner).getOrElse("")}
+ |RootPaths: $truncatedPaths
+ |Partition Structure: ${partitions.mkString(", ")}
+ |""".stripMargin)
+ }
+
+ private def nonPartTableMaxFileExceedError(
+ scanFileSize: Number,
+ maxFileSize: Long,
+ tableMeta: Option[CatalogTable]): Throwable = {
+ new MaxFileSizeExceedException(
+ s"""
+ |SQL job scans $scanFileSize bytes of files,
+ |which exceeds the table scan limit maxFileSize=$maxFileSize.
+ |Details below:
+ |Table: ${tableMeta.map(_.qualifiedName).getOrElse("")}
+ |Owner: ${tableMeta.map(_.owner).getOrElse("")}
+ |Location: ${tableMeta.map(_.location).getOrElse("")}
+ |""".stripMargin)
+ }
+}
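
For context, a minimal usage sketch of the watchdog limits enforced above. The conf key names are assumptions about the `KyuubiSQLConf` entries backing `maxScanPartitionsOpt` and `maxFileSizeOpt`; this diff does not confirm them.

```
import org.apache.spark.sql.SparkSession

// Hedged sketch: verify the conf keys against KyuubiSQLConf before use.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  .getOrCreate()

// Reject queries that would scan more than 100 partitions ...
spark.conf.set("spark.sql.watchdog.maxPartitions", "100")
// ... or more than ~10 GiB of files.
spark.conf.set("spark.sql.watchdog.maxFileSize", (10L << 30).toString)

try {
  spark.sql("SELECT * FROM big_partitioned_table").collect()
} catch {
  // MaxPartitionExceedException / MaxFileSizeExceedException, defined above
  case e: Exception => println(s"Rejected by the watchdog: ${e.getMessage}")
}
```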
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWriting.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWriting.scala
new file mode 100644
index 000000000..b3f98ec6d
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWriting.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, NullsLast, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+
+import org.apache.kyuubi.sql.{KyuubiSQLConf, KyuubiSQLExtensionException}
+
+trait InsertZorderHelper33 extends Rule[LogicalPlan] with ZorderBuilder {
+ private val KYUUBI_ZORDER_ENABLED = "kyuubi.zorder.enabled"
+ private val KYUUBI_ZORDER_COLS = "kyuubi.zorder.cols"
+
+ def isZorderEnabled(props: Map[String, String]): Boolean = {
+ props.contains(KYUUBI_ZORDER_ENABLED) &&
+ "true".equalsIgnoreCase(props(KYUUBI_ZORDER_ENABLED)) &&
+ props.contains(KYUUBI_ZORDER_COLS)
+ }
+
+ def getZorderColumns(props: Map[String, String]): Seq[String] = {
+ val cols = props.get(KYUUBI_ZORDER_COLS)
+ assert(cols.isDefined)
+ cols.get.split(",").map(_.trim)
+ }
+
+ def canInsertZorder(query: LogicalPlan): Boolean = query match {
+ case Project(_, child) => canInsertZorder(child)
+ // TODO: actually, we could force zorder even if a shuffle already exists
+ case _: Sort => false
+ case _: RepartitionByExpression => false
+ case _: Repartition => false
+ case _ => true
+ }
+
+ def insertZorder(
+ catalogTable: CatalogTable,
+ plan: LogicalPlan,
+ dynamicPartitionColumns: Seq[Attribute]): LogicalPlan = {
+ if (!canInsertZorder(plan)) {
+ return plan
+ }
+ val cols = getZorderColumns(catalogTable.properties)
+ val resolver = session.sessionState.conf.resolver
+ val output = plan.output
+ val bound = cols.flatMap(col => output.find(attr => resolver(attr.name, col)))
+ if (bound.size < cols.size) {
+ logWarning(s"target table does not contain all zorder cols: ${cols.mkString(",")}, " +
+ s"please check your table properties ${KYUUBI_ZORDER_COLS}.")
+ plan
+ } else {
+ if (conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED) &&
+ conf.getConf(KyuubiSQLConf.REBALANCE_BEFORE_ZORDER)) {
+ throw new KyuubiSQLExtensionException(s"${KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key} " +
+ s"and ${KyuubiSQLConf.REBALANCE_BEFORE_ZORDER.key} can not be enabled together.")
+ }
+ if (conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED) &&
+ dynamicPartitionColumns.nonEmpty) {
+ logWarning(s"Dynamic partition insertion with global sort may produce small files.")
+ }
+
+ val zorderExpr =
+ if (bound.length == 1) {
+ bound
+ } else if (conf.getConf(KyuubiSQLConf.ZORDER_USING_ORIGINAL_ORDERING_ENABLED)) {
+ bound.asInstanceOf[Seq[Expression]]
+ } else {
+ buildZorder(bound) :: Nil
+ }
+ val (global, orderExprs, child) =
+ if (conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED)) {
+ (true, zorderExpr, plan)
+ } else if (conf.getConf(KyuubiSQLConf.REBALANCE_BEFORE_ZORDER)) {
+ val rebalanceExpr =
+ if (dynamicPartitionColumns.isEmpty) {
+ // static partition insert
+ bound
+ } else if (conf.getConf(KyuubiSQLConf.REBALANCE_ZORDER_COLUMNS_ENABLED)) {
+ // improve data compression ratio
+ dynamicPartitionColumns.asInstanceOf[Seq[Expression]] ++ bound
+ } else {
+ dynamicPartitionColumns.asInstanceOf[Seq[Expression]]
+ }
+ // for dynamic partition inserts, Spark always sorts the partition columns,
+ // so here we sort partition columns + zorder.
+ val rebalance =
+ if (dynamicPartitionColumns.nonEmpty &&
+ conf.getConf(KyuubiSQLConf.TWO_PHASE_REBALANCE_BEFORE_ZORDER)) {
+ // improve compression ratio
+ RebalancePartitions(
+ rebalanceExpr,
+ RebalancePartitions(dynamicPartitionColumns, plan))
+ } else {
+ RebalancePartitions(rebalanceExpr, plan)
+ }
+ (false, dynamicPartitionColumns.asInstanceOf[Seq[Expression]] ++ zorderExpr, rebalance)
+ } else {
+ (false, zorderExpr, plan)
+ }
+ val order = orderExprs.map { expr =>
+ SortOrder(expr, Ascending, NullsLast, Seq.empty)
+ }
+ Sort(order, global, child)
+ }
+ }
+
+ override def buildZorder(children: Seq[Expression]): ZorderBase = Zorder(children)
+
+ def session: SparkSession
+ def applyInternal(plan: LogicalPlan): LogicalPlan
+
+ final override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING)) {
+ applyInternal(plan)
+ } else {
+ plan
+ }
+ }
+}
+
+case class InsertZorderBeforeWritingDatasource33(session: SparkSession)
+ extends InsertZorderHelper33 {
+ override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match {
+ case insert: InsertIntoHadoopFsRelationCommand
+ if insert.query.resolved &&
+ insert.bucketSpec.isEmpty && insert.catalogTable.isDefined &&
+ isZorderEnabled(insert.catalogTable.get.properties) =>
+ val dynamicPartition =
+ insert.partitionColumns.filterNot(attr => insert.staticPartitions.contains(attr.name))
+ val newQuery = insertZorder(insert.catalogTable.get, insert.query, dynamicPartition)
+ if (newQuery.eq(insert.query)) {
+ insert
+ } else {
+ insert.copy(query = newQuery)
+ }
+
+ case _ => plan
+ }
+}
+
+case class InsertZorderBeforeWritingHive33(session: SparkSession)
+ extends InsertZorderHelper33 {
+ override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match {
+ case insert: InsertIntoHiveTable
+ if insert.query.resolved &&
+ insert.table.bucketSpec.isEmpty && isZorderEnabled(insert.table.properties) =>
+ val dynamicPartition = insert.partition.filter(_._2.isEmpty).keys
+ .flatMap(name => insert.query.output.find(_.name == name)).toSeq
+ val newQuery = insertZorder(insert.table, insert.query, dynamicPartition)
+ if (newQuery.eq(insert.query)) {
+ insert
+ } else {
+ insert.copy(query = newQuery)
+ }
+
+ case _ => plan
+ }
+}
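
A hedged sketch of how the two rules above are expected to fire: zorder-on-write is driven entirely by the `kyuubi.zorder.enabled` and `kyuubi.zorder.cols` table properties read by `InsertZorderHelper33`; the table and column names below are illustrative.

```
// Assumes a session with the Kyuubi extension installed (see the earlier sketch).
spark.sql("""
  CREATE TABLE t (c1 INT, c2 STRING) USING parquet
  TBLPROPERTIES ('kyuubi.zorder.enabled' = 'true', 'kyuubi.zorder.cols' = 'c1,c2')
""")

// Per the rules above, this insert should be rewritten into a (rebalance +) sort
// by zorder(c1, c2) before the actual write happens.
spark.sql("INSERT OVERWRITE TABLE t SELECT c1, c2 FROM src")
```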
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWritingBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWritingBase.scala
new file mode 100644
index 000000000..2c59d148e
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWritingBase.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import java.util.Locale
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, NullsLast, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+/**
+ * TODO: shall we forbid zorder for dynamic partition inserts?
+ * Insert zorder before writing to a datasource if the target table's properties enable zorder
+ */
+abstract class InsertZorderBeforeWritingDatasourceBase
+ extends InsertZorderHelper {
+ override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match {
+ case insert: InsertIntoHadoopFsRelationCommand
+ if insert.query.resolved && insert.bucketSpec.isEmpty && insert.catalogTable.isDefined &&
+ isZorderEnabled(insert.catalogTable.get.properties) =>
+ val newQuery = insertZorder(insert.catalogTable.get, insert.query)
+ if (newQuery.eq(insert.query)) {
+ insert
+ } else {
+ insert.copy(query = newQuery)
+ }
+ case _ => plan
+ }
+}
+
+/**
+ * TODO: shall we forbid zorder for dynamic partition inserts?
+ * Insert zorder before writing to Hive if the target table's properties enable zorder
+ */
+abstract class InsertZorderBeforeWritingHiveBase
+ extends InsertZorderHelper {
+ override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match {
+ case insert: InsertIntoHiveTable
+ if insert.query.resolved && insert.table.bucketSpec.isEmpty &&
+ isZorderEnabled(insert.table.properties) =>
+ val newQuery = insertZorder(insert.table, insert.query)
+ if (newQuery.eq(insert.query)) {
+ insert
+ } else {
+ insert.copy(query = newQuery)
+ }
+ case _ => plan
+ }
+}
+
+trait ZorderBuilder {
+ def buildZorder(children: Seq[Expression]): ZorderBase
+}
+
+trait InsertZorderHelper extends Rule[LogicalPlan] with ZorderBuilder {
+ private val KYUUBI_ZORDER_ENABLED = "kyuubi.zorder.enabled"
+ private val KYUUBI_ZORDER_COLS = "kyuubi.zorder.cols"
+
+ def isZorderEnabled(props: Map[String, String]): Boolean = {
+ props.contains(KYUUBI_ZORDER_ENABLED) &&
+ "true".equalsIgnoreCase(props(KYUUBI_ZORDER_ENABLED)) &&
+ props.contains(KYUUBI_ZORDER_COLS)
+ }
+
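+ // Note: unlike InsertZorderHelper33 above, this base helper lower-cases the configured
+ // column names and later matches output attributes by exact name rather than with the
+ // session resolver.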
+ def getZorderColumns(props: Map[String, String]): Seq[String] = {
+ val cols = props.get(KYUUBI_ZORDER_COLS)
+ assert(cols.isDefined)
+ cols.get.split(",").map(_.trim.toLowerCase(Locale.ROOT))
+ }
+
+ def canInsertZorder(query: LogicalPlan): Boolean = query match {
+ case Project(_, child) => canInsertZorder(child)
+ // TODO: actually, we could force zorder even if a shuffle already exists
+ case _: Sort => false
+ case _: RepartitionByExpression => false
+ case _: Repartition => false
+ case _ => true
+ }
+
+ def insertZorder(catalogTable: CatalogTable, plan: LogicalPlan): LogicalPlan = {
+ if (!canInsertZorder(plan)) {
+ return plan
+ }
+ val cols = getZorderColumns(catalogTable.properties)
+ val attrs = plan.output.map(attr => (attr.name, attr)).toMap
+ if (cols.exists(!attrs.contains(_))) {
+ logWarning(s"target table does not contain all zorder cols: ${cols.mkString(",")}, " +
+ s"please check your table properties ${KYUUBI_ZORDER_COLS}.")
+ plan
+ } else {
+ val bound = cols.map(attrs(_))
+ val orderExpr =
+ if (bound.length == 1) {
+ bound.head
+ } else {
+ buildZorder(bound)
+ }
+ // TODO: We could rebalance partitions before the local sort of zorder since Spark 3.3,
+ // see https://github.com/apache/spark/pull/34542
+ Sort(
+ SortOrder(orderExpr, Ascending, NullsLast, Seq.empty) :: Nil,
+ conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED),
+ plan)
+ }
+ }
+
+ def applyInternal(plan: LogicalPlan): LogicalPlan
+
+ final override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING)) {
+ applyInternal(plan)
+ } else {
+ plan
+ }
+ }
+}
+
+/**
+ * TODO: shall we forbid zorder for dynamic partition inserts?
+ * Insert zorder before writing to a datasource if the target table's properties enable zorder
+ */
+case class InsertZorderBeforeWritingDatasource(session: SparkSession)
+ extends InsertZorderBeforeWritingDatasourceBase {
+ override def buildZorder(children: Seq[Expression]): ZorderBase = Zorder(children)
+}
+
+/**
+ * TODO: shall we forbid zorder for dynamic partition inserts?
+ * Insert zorder before writing to Hive if the target table's properties enable zorder
+ */
+case class InsertZorderBeforeWritingHive(session: SparkSession)
+ extends InsertZorderBeforeWritingHiveBase {
+ override def buildZorder(children: Seq[Expression]): ZorderBase = Zorder(children)
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderCommandBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderCommandBase.scala
new file mode 100644
index 000000000..21d1cf2a2
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderCommandBase.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.DataWritingCommand
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+
+import org.apache.kyuubi.sql.KyuubiSQLExtensionException
+
+/**
+ * A runnable command for zorder; we delegate execution to a real writing command
+ */
+abstract class OptimizeZorderCommandBase extends DataWritingCommand {
+ def catalogTable: CatalogTable
+
+ override def outputColumnNames: Seq[String] = query.output.map(_.name)
+
+ private def isHiveTable: Boolean = {
+ catalogTable.provider.isEmpty ||
+ (catalogTable.provider.isDefined && "hive".equalsIgnoreCase(catalogTable.provider.get))
+ }
+
+ private def getWritingCommand(session: SparkSession): DataWritingCommand = {
+ // TODO: Support converting Hive relations to datasource relations, see
+ // [[org.apache.spark.sql.hive.RelationConversions]]
+ InsertIntoHiveTable(
+ catalogTable,
+ catalogTable.partitionColumnNames.map(p => (p, None)).toMap,
+ query,
+ overwrite = true,
+ ifPartitionNotExists = false,
+ outputColumnNames)
+ }
+
+ override def run(session: SparkSession, child: SparkPlan): Seq[Row] = {
+ // TODO: Support datasource relation
+ // TODO: Support read and insert overwrite the same table for some table format
+ if (!isHiveTable) {
+ throw new KyuubiSQLExtensionException("only support hive table")
+ }
+
+ val command = getWritingCommand(session)
+ command.run(session, child)
+ DataWritingCommand.propogateMetrics(session.sparkContext, command, metrics)
+ Seq.empty
+ }
+}
+
+/**
+ * A runnable command for zorder; we delegate execution to a real writing command
+ */
+case class OptimizeZorderCommand(
+ catalogTable: CatalogTable,
+ query: LogicalPlan)
+ extends OptimizeZorderCommandBase {
+ protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
+ copy(query = newChild)
+ }
+}
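
The command above is the execution half of the `OPTIMIZE ... ZORDER BY` statement introduced in `KyuubiSparkSQL.g4`; a hedged usage sketch follows, with the syntax inferred from the grammar and the checks in `ResolveZorderBase` below rather than verified here.

```
// Rewrite the whole table in the z-order of c1, c2 (Hive tables only, per the check above):
spark.sql("OPTIMIZE db.t ZORDER BY c1, c2")

// With a filter; only partition-column predicates are allowed (see ResolveZorderBase):
spark.sql("OPTIMIZE db.t WHERE p_date = '2023-10-08' ZORDER BY c1, c2")
```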
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderStatementBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderStatementBase.scala
new file mode 100644
index 000000000..895f9e24b
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderStatementBase.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
+
+/**
+ * A zorder statement parsed from SQL.
+ * The Analyzer should convert this plan to a concrete command.
+ */
+case class OptimizeZorderStatement(
+ tableIdentifier: Seq[String],
+ query: LogicalPlan) extends UnaryNode {
+ override def child: LogicalPlan = query
+ override def output: Seq[Attribute] = child.output
+ protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+ copy(query = newChild)
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ResolveZorderBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ResolveZorderBase.scala
new file mode 100644
index 000000000..9f735caa7
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ResolveZorderBase.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+import org.apache.kyuubi.sql.KyuubiSQLExtensionException
+
+/**
+ * Resolve `OptimizeZorderStatement` to `OptimizeZorderCommand`
+ */
+abstract class ResolveZorderBase extends Rule[LogicalPlan] {
+ def session: SparkSession
+ def buildOptimizeZorderCommand(
+ catalogTable: CatalogTable,
+ query: LogicalPlan): OptimizeZorderCommandBase
+
+ protected def checkQueryAllowed(query: LogicalPlan): Unit = query foreach {
+ case Filter(condition, SubqueryAlias(_, tableRelation: HiveTableRelation)) =>
+ if (tableRelation.partitionCols.isEmpty) {
+ throw new KyuubiSQLExtensionException("Filters are only supported for partitioned table")
+ }
+
+ val partitionKeyIds = AttributeSet(tableRelation.partitionCols)
+ if (condition.references.isEmpty || !condition.references.subsetOf(partitionKeyIds)) {
+ throw new KyuubiSQLExtensionException("Only partition column filters are allowed")
+ }
+
+ case _ =>
+ }
+
+ protected def getTableIdentifier(tableIdent: Seq[String]): TableIdentifier = tableIdent match {
+ case Seq(tbl) => TableIdentifier.apply(tbl)
+ case Seq(db, tbl) => TableIdentifier.apply(tbl, Some(db))
+ case _ => throw new KyuubiSQLExtensionException(
+ "only support session catalog table, please use db.table instead")
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case statement: OptimizeZorderStatement if statement.query.resolved =>
+ checkQueryAllowed(statement.query)
+ val tableIdentifier = getTableIdentifier(statement.tableIdentifier)
+ val catalogTable = session.sessionState.catalog.getTableMetadata(tableIdentifier)
+ buildOptimizeZorderCommand(catalogTable, statement.query)
+
+ case _ => plan
+ }
+}
+
+/**
+ * Resolve `OptimizeZorderStatement` to `OptimizeZorderCommand`
+ */
+case class ResolveZorder(session: SparkSession) extends ResolveZorderBase {
+ override def buildOptimizeZorderCommand(
+ catalogTable: CatalogTable,
+ query: LogicalPlan): OptimizeZorderCommandBase = {
+ OptimizeZorderCommand(catalogTable, query)
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBase.scala
new file mode 100644
index 000000000..e4d98ccbe
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBase.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types.{BinaryType, DataType}
+
+import org.apache.kyuubi.sql.KyuubiSQLExtensionException
+
+abstract class ZorderBase extends Expression {
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def nullable: Boolean = false
+ override def dataType: DataType = BinaryType
+ override def prettyName: String = "zorder"
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ try {
+ defaultNullValues
+ TypeCheckResult.TypeCheckSuccess
+ } catch {
+ case e: KyuubiSQLExtensionException =>
+ TypeCheckResult.TypeCheckFailure(e.getMessage)
+ }
+ }
+
+ @transient
+ private[this] lazy val defaultNullValues: Array[Any] =
+ children.map(_.dataType)
+ .map(ZorderBytesUtils.defaultValue)
+ .toArray
+
+ override def eval(input: InternalRow): Any = {
+ val childrenValues = children.zipWithIndex.map {
+ case (child: Expression, index) =>
+ val v = child.eval(input)
+ if (v == null) {
+ defaultNullValues(index)
+ } else {
+ v
+ }
+ }
+ ZorderBytesUtils.interleaveBits(childrenValues.toArray)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val evals = children.map(_.genCode(ctx))
+ val defaultValues = ctx.addReferenceObj("defaultValues", defaultNullValues)
+ val values = ctx.freshName("values")
+ val util = ZorderBytesUtils.getClass.getName.stripSuffix("$")
+ val inputs = evals.zipWithIndex.map {
+ case (eval, index) =>
+ s"""
+ |${eval.code}
+ |if (${eval.isNull}) {
+ | $values[$index] = $defaultValues[$index];
+ |} else {
+ | $values[$index] = ${eval.value};
+ |}
+ |""".stripMargin
+ }
+ ev.copy(
+ code =
+ code"""
+ |byte[] ${ev.value} = null;
+ |Object[] $values = new Object[${evals.length}];
+ |${inputs.mkString("\n")}
+ |${ev.value} = $util.interleaveBits($values);
+ |""".stripMargin,
+ isNull = FalseLiteral)
+ }
+}
+
+case class Zorder(children: Seq[Expression]) extends ZorderBase {
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+ copy(children = newChildren)
+}
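
A hedged sketch of evaluating the expression directly; normally it is only injected by the rewrite rules, and wrapping it with `new Column(expr)` is an assumption about the Spark API in use.

```
import org.apache.spark.sql.Column
import org.apache.kyuubi.sql.zorder.Zorder

// Illustrative only: build zorder(id, id2) by hand and inspect the binary output.
val df = spark.range(0, 4).selectExpr("id", "id * 7 AS id2")
val z = new Column(Zorder(Seq(df("id").expr, df("id2").expr)))
df.select(df("id"), df("id2"), z.as("z")).show(truncate = false)
```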
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBytesUtils.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBytesUtils.scala
new file mode 100644
index 000000000..d249f1dc3
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBytesUtils.scala
@@ -0,0 +1,517 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.zorder
+
+import java.lang.{Double => jDouble, Float => jFloat}
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.kyuubi.sql.KyuubiSQLExtensionException
+
+object ZorderBytesUtils {
+ final private val BIT_8_MASK = 1 << 7
+ final private val BIT_16_MASK = 1 << 15
+ final private val BIT_32_MASK = 1 << 31
+ final private val BIT_64_MASK = 1L << 63
+
+ def interleaveBits(inputs: Array[Any]): Array[Byte] = {
+ inputs.length match {
+ // a faster approach that runs in O(8 * 8),
+ // see http://graphics.stanford.edu/~seander/bithacks.html#InterleaveTableObvious
+ case 1 => longToByte(toLong(inputs(0)))
+ case 2 => interleave2Longs(toLong(inputs(0)), toLong(inputs(1)))
+ case 3 => interleave3Longs(toLong(inputs(0)), toLong(inputs(1)), toLong(inputs(2)))
+ case 4 =>
+ interleave4Longs(toLong(inputs(0)), toLong(inputs(1)), toLong(inputs(2)), toLong(inputs(3)))
+ case 5 => interleave5Longs(
+ toLong(inputs(0)),
+ toLong(inputs(1)),
+ toLong(inputs(2)),
+ toLong(inputs(3)),
+ toLong(inputs(4)))
+ case 6 => interleave6Longs(
+ toLong(inputs(0)),
+ toLong(inputs(1)),
+ toLong(inputs(2)),
+ toLong(inputs(3)),
+ toLong(inputs(4)),
+ toLong(inputs(5)))
+ case 7 => interleave7Longs(
+ toLong(inputs(0)),
+ toLong(inputs(1)),
+ toLong(inputs(2)),
+ toLong(inputs(3)),
+ toLong(inputs(4)),
+ toLong(inputs(5)),
+ toLong(inputs(6)))
+ case 8 => interleave8Longs(
+ toLong(inputs(0)),
+ toLong(inputs(1)),
+ toLong(inputs(2)),
+ toLong(inputs(3)),
+ toLong(inputs(4)),
+ toLong(inputs(5)),
+ toLong(inputs(6)),
+ toLong(inputs(7)))
+
+ case _ =>
+ // the default approach runs in O(64 * n), where n is the number of inputs
+ interleaveBitsDefault(inputs.map(toByteArray))
+ }
+ }
+
+ private def interleave2Longs(l1: Long, l2: Long): Array[Byte] = {
+ // output 8 * 16 bits
+ val result = new Array[Byte](16)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toShort
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toShort
+
+ var z = 0
+ var j = 0
+ while (j < 8) {
+ val x_masked = tmp1 & (1 << j)
+ val y_masked = tmp2 & (1 << j)
+ z |= (x_masked << j)
+ z |= (y_masked << (j + 1))
+ j = j + 1
+ }
+ result((7 - i) * 2 + 1) = (z & 0xFF).toByte
+ result((7 - i) * 2) = ((z >> 8) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
+ private def interleave3Longs(l1: Long, l2: Long, l3: Long): Array[Byte] = {
+ // output 8 * 24 bits
+ val result = new Array[Byte](24)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toInt
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toInt
+ val tmp3 = ((l3 >> (i * 8)) & 0xFF).toInt
+
+ var z = 0
+ var j = 0
+ while (j < 8) {
+ val r1_mask = tmp1 & (1 << j)
+ val r2_mask = tmp2 & (1 << j)
+ val r3_mask = tmp3 & (1 << j)
+ z |= (r1_mask << (2 * j)) | (r2_mask << (2 * j + 1)) | (r3_mask << (2 * j + 2))
+ j = j + 1
+ }
+ result((7 - i) * 3 + 2) = (z & 0xFF).toByte
+ result((7 - i) * 3 + 1) = ((z >> 8) & 0xFF).toByte
+ result((7 - i) * 3) = ((z >> 16) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
+ private def interleave4Longs(l1: Long, l2: Long, l3: Long, l4: Long): Array[Byte] = {
+ // output 8 * 32 bits
+ val result = new Array[Byte](32)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toInt
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toInt
+ val tmp3 = ((l3 >> (i * 8)) & 0xFF).toInt
+ val tmp4 = ((l4 >> (i * 8)) & 0xFF).toInt
+
+ var z = 0
+ var j = 0
+ while (j < 8) {
+ val r1_mask = tmp1 & (1 << j)
+ val r2_mask = tmp2 & (1 << j)
+ val r3_mask = tmp3 & (1 << j)
+ val r4_mask = tmp4 & (1 << j)
+ z |= (r1_mask << (3 * j)) | (r2_mask << (3 * j + 1)) | (r3_mask << (3 * j + 2)) |
+ (r4_mask << (3 * j + 3))
+ j = j + 1
+ }
+ result((7 - i) * 4 + 3) = (z & 0xFF).toByte
+ result((7 - i) * 4 + 2) = ((z >> 8) & 0xFF).toByte
+ result((7 - i) * 4 + 1) = ((z >> 16) & 0xFF).toByte
+ result((7 - i) * 4) = ((z >> 24) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
+ private def interleave5Longs(
+ l1: Long,
+ l2: Long,
+ l3: Long,
+ l4: Long,
+ l5: Long): Array[Byte] = {
+ // output 8 * 40 bits
+ val result = new Array[Byte](40)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong
+ val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong
+ val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong
+ val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong
+
+ var z = 0L
+ var j = 0
+ while (j < 8) {
+ val r1_mask = tmp1 & (1 << j)
+ val r2_mask = tmp2 & (1 << j)
+ val r3_mask = tmp3 & (1 << j)
+ val r4_mask = tmp4 & (1 << j)
+ val r5_mask = tmp5 & (1 << j)
+ z |= (r1_mask << (4 * j)) | (r2_mask << (4 * j + 1)) | (r3_mask << (4 * j + 2)) |
+ (r4_mask << (4 * j + 3)) | (r5_mask << (4 * j + 4))
+ j = j + 1
+ }
+ result((7 - i) * 5 + 4) = (z & 0xFF).toByte
+ result((7 - i) * 5 + 3) = ((z >> 8) & 0xFF).toByte
+ result((7 - i) * 5 + 2) = ((z >> 16) & 0xFF).toByte
+ result((7 - i) * 5 + 1) = ((z >> 24) & 0xFF).toByte
+ result((7 - i) * 5) = ((z >> 32) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
+ private def interleave6Longs(
+ l1: Long,
+ l2: Long,
+ l3: Long,
+ l4: Long,
+ l5: Long,
+ l6: Long): Array[Byte] = {
+ // output 8 * 48 bits
+ val result = new Array[Byte](48)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong
+ val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong
+ val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong
+ val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong
+ val tmp6 = ((l6 >> (i * 8)) & 0xFF).toLong
+
+ var z = 0L
+ var j = 0
+ while (j < 8) {
+ val r1_mask = tmp1 & (1 << j)
+ val r2_mask = tmp2 & (1 << j)
+ val r3_mask = tmp3 & (1 << j)
+ val r4_mask = tmp4 & (1 << j)
+ val r5_mask = tmp5 & (1 << j)
+ val r6_mask = tmp6 & (1 << j)
+ z |= (r1_mask << (5 * j)) | (r2_mask << (5 * j + 1)) | (r3_mask << (5 * j + 2)) |
+ (r4_mask << (5 * j + 3)) | (r5_mask << (5 * j + 4)) | (r6_mask << (5 * j + 5))
+ j = j + 1
+ }
+ result((7 - i) * 6 + 5) = (z & 0xFF).toByte
+ result((7 - i) * 6 + 4) = ((z >> 8) & 0xFF).toByte
+ result((7 - i) * 6 + 3) = ((z >> 16) & 0xFF).toByte
+ result((7 - i) * 6 + 2) = ((z >> 24) & 0xFF).toByte
+ result((7 - i) * 6 + 1) = ((z >> 32) & 0xFF).toByte
+ result((7 - i) * 6) = ((z >> 40) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
+ private def interleave7Longs(
+ l1: Long,
+ l2: Long,
+ l3: Long,
+ l4: Long,
+ l5: Long,
+ l6: Long,
+ l7: Long): Array[Byte] = {
+ // output 8 * 56 bits
+ val result = new Array[Byte](56)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong
+ val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong
+ val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong
+ val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong
+ val tmp6 = ((l6 >> (i * 8)) & 0xFF).toLong
+ val tmp7 = ((l7 >> (i * 8)) & 0xFF).toLong
+
+ var z = 0L
+ var j = 0
+ while (j < 8) {
+ val r1_mask = tmp1 & (1 << j)
+ val r2_mask = tmp2 & (1 << j)
+ val r3_mask = tmp3 & (1 << j)
+ val r4_mask = tmp4 & (1 << j)
+ val r5_mask = tmp5 & (1 << j)
+ val r6_mask = tmp6 & (1 << j)
+ val r7_mask = tmp7 & (1 << j)
+ z |= (r1_mask << (6 * j)) | (r2_mask << (6 * j + 1)) | (r3_mask << (6 * j + 2)) |
+ (r4_mask << (6 * j + 3)) | (r5_mask << (6 * j + 4)) | (r6_mask << (6 * j + 5)) |
+ (r7_mask << (6 * j + 6))
+ j = j + 1
+ }
+ result((7 - i) * 7 + 6) = (z & 0xFF).toByte
+ result((7 - i) * 7 + 5) = ((z >> 8) & 0xFF).toByte
+ result((7 - i) * 7 + 4) = ((z >> 16) & 0xFF).toByte
+ result((7 - i) * 7 + 3) = ((z >> 24) & 0xFF).toByte
+ result((7 - i) * 7 + 2) = ((z >> 32) & 0xFF).toByte
+ result((7 - i) * 7 + 1) = ((z >> 40) & 0xFF).toByte
+ result((7 - i) * 7) = ((z >> 48) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
+ private def interleave8Longs(
+ l1: Long,
+ l2: Long,
+ l3: Long,
+ l4: Long,
+ l5: Long,
+ l6: Long,
+ l7: Long,
+ l8: Long): Array[Byte] = {
+ // output 8 * 64 bits
+ val result = new Array[Byte](64)
+ var i = 0
+ while (i < 8) {
+ val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong
+ val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong
+ val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong
+ val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong
+ val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong
+ val tmp6 = ((l6 >> (i * 8)) & 0xFF).toLong
+ val tmp7 = ((l7 >> (i * 8)) & 0xFF).toLong
+ val tmp8 = ((l8 >> (i * 8)) & 0xFF).toLong
+
+ var z = 0L
+ var j = 0
+ while (j < 8) {
+ val r1_mask = tmp1 & (1 << j)
+ val r2_mask = tmp2 & (1 << j)
+ val r3_mask = tmp3 & (1 << j)
+ val r4_mask = tmp4 & (1 << j)
+ val r5_mask = tmp5 & (1 << j)
+ val r6_mask = tmp6 & (1 << j)
+ val r7_mask = tmp7 & (1 << j)
+ val r8_mask = tmp8 & (1 << j)
+ z |= (r1_mask << (7 * j)) | (r2_mask << (7 * j + 1)) | (r3_mask << (7 * j + 2)) |
+ (r4_mask << (7 * j + 3)) | (r5_mask << (7 * j + 4)) | (r6_mask << (7 * j + 5)) |
+ (r7_mask << (7 * j + 6)) | (r8_mask << (7 * j + 7))
+ j = j + 1
+ }
+ result((7 - i) * 8 + 7) = (z & 0xFF).toByte
+ result((7 - i) * 8 + 6) = ((z >> 8) & 0xFF).toByte
+ result((7 - i) * 8 + 5) = ((z >> 16) & 0xFF).toByte
+ result((7 - i) * 8 + 4) = ((z >> 24) & 0xFF).toByte
+ result((7 - i) * 8 + 3) = ((z >> 32) & 0xFF).toByte
+ result((7 - i) * 8 + 2) = ((z >> 40) & 0xFF).toByte
+ result((7 - i) * 8 + 1) = ((z >> 48) & 0xFF).toByte
+ result((7 - i) * 8) = ((z >> 56) & 0xFF).toByte
+ i = i + 1
+ }
+ result
+ }
+
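+ // Generic fallback: interleave bits across variable-length byte arrays, reading each
+ // input from its least-significant bit and writing the result from the least-significant
+ // end of the output buffer.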
+ def interleaveBitsDefault(arrays: Array[Array[Byte]]): Array[Byte] = {
+ var totalLength = 0
+ var maxLength = 0
+ arrays.foreach { array =>
+ totalLength += array.length
+ maxLength = maxLength.max(array.length * 8)
+ }
+ val result = new Array[Byte](totalLength)
+ var resultBit = 0
+
+ var bit = 0
+ while (bit < maxLength) {
+ val bytePos = bit / 8
+ val bitPos = bit % 8
+
+ for (arr <- arrays) {
+ val len = arr.length
+ if (bytePos < len) {
+ val resultBytePos = totalLength - 1 - resultBit / 8
+ val resultBitPos = resultBit % 8
+ result(resultBytePos) =
+ updatePos(result(resultBytePos), resultBitPos, arr(len - 1 - bytePos), bitPos)
+ resultBit += 1
+ }
+ }
+ bit += 1
+ }
+ result
+ }
+
+ def updatePos(a: Byte, apos: Int, b: Byte, bpos: Int): Byte = {
+ var temp = (b & (1 << bpos)).toByte
+ if (apos > bpos) {
+ temp = (temp << (apos - bpos)).toByte
+ } else if (apos < bpos) {
+ temp = (temp >> (bpos - apos)).toByte
+ }
+ val atemp = (a & (1 << apos)).toByte
+ if (atemp == temp) {
+ return a
+ }
+ (a ^ (1 << apos)).toByte
+ }
+
+ def toLong(a: Any): Long = {
+ a match {
+ case b: Boolean => (if (b) 1 else 0).toLong ^ BIT_64_MASK
+ case b: Byte => b.toLong ^ BIT_64_MASK
+ case s: Short => s.toLong ^ BIT_64_MASK
+ case i: Int => i.toLong ^ BIT_64_MASK
+ case l: Long => l ^ BIT_64_MASK
+ case f: Float => java.lang.Float.floatToRawIntBits(f).toLong ^ BIT_64_MASK
+ case d: Double => java.lang.Double.doubleToRawLongBits(d) ^ BIT_64_MASK
+ case str: UTF8String => str.getPrefix
+ case dec: Decimal => dec.toLong ^ BIT_64_MASK
+ case other: Any =>
+ throw new KyuubiSQLExtensionException("Unsupported z-order type: " + other.getClass)
+ }
+ }
+
+ def toByteArray(a: Any): Array[Byte] = {
+ a match {
+ case bo: Boolean =>
+ booleanToByte(bo)
+ case b: Byte =>
+ byteToByte(b)
+ case s: Short =>
+ shortToByte(s)
+ case i: Int =>
+ intToByte(i)
+ case l: Long =>
+ longToByte(l)
+ case f: Float =>
+ floatToByte(f)
+ case d: Double =>
+ doubleToByte(d)
+ case str: UTF8String =>
+ // truncate or pad str to 8 bytes
+ paddingTo8Byte(str.getBytes)
+ case dec: Decimal =>
+ longToByte(dec.toLong)
+ case other: Any =>
+ throw new KyuubiSQLExtensionException("Unsupported z-order type: " + other.getClass)
+ }
+ }
+
+ def booleanToByte(a: Boolean): Array[Byte] = {
+ if (a) {
+ byteToByte(1.toByte)
+ } else {
+ byteToByte(0.toByte)
+ }
+ }
+
+ def byteToByte(a: Byte): Array[Byte] = {
+ val tmp = (a ^ BIT_8_MASK).toByte
+ Array(tmp)
+ }
+
+ def shortToByte(a: Short): Array[Byte] = {
+ val tmp = a ^ BIT_16_MASK
+ Array(((tmp >> 8) & 0xFF).toByte, (tmp & 0xFF).toByte)
+ }
+
+ def intToByte(a: Int): Array[Byte] = {
+ val result = new Array[Byte](4)
+ var i = 0
+ val tmp = a ^ BIT_32_MASK
+ while (i <= 3) {
+ val offset = i * 8
+ result(3 - i) = ((tmp >> offset) & 0xFF).toByte
+ i += 1
+ }
+ result
+ }
+
+ def longToByte(a: Long): Array[Byte] = {
+ val result = new Array[Byte](8)
+ var i = 0
+ val tmp = a ^ BIT_64_MASK
+ while (i <= 7) {
+ val offset = i * 8
+ result(7 - i) = ((tmp >> offset) & 0xFF).toByte
+ i += 1
+ }
+ result
+ }
+
+ def floatToByte(a: Float): Array[Byte] = {
+ val fi = jFloat.floatToRawIntBits(a)
+ intToByte(fi)
+ }
+
+ def doubleToByte(a: Double): Array[Byte] = {
+ val dl = jDouble.doubleToRawLongBits(a)
+ longToByte(dl)
+ }
+
+ def paddingTo8Byte(a: Array[Byte]): Array[Byte] = {
+ val len = a.length
+ if (len == 8) {
+ a
+ } else if (len > 8) {
+ val result = new Array[Byte](8)
+ System.arraycopy(a, 0, result, 0, 8)
+ result
+ } else {
+ val result = new Array[Byte](8)
+ System.arraycopy(a, 0, result, 8 - len, len)
+ result
+ }
+ }
+
+ def defaultByteArrayValue(dataType: DataType): Array[Byte] = toByteArray {
+ defaultValue(dataType)
+ }
+
+ def defaultValue(dataType: DataType): Any = {
+ dataType match {
+ case BooleanType =>
+ true
+ case ByteType =>
+ Byte.MaxValue
+ case ShortType =>
+ Short.MaxValue
+ case IntegerType | DateType =>
+ Int.MaxValue
+ case LongType | TimestampType | _: DecimalType =>
+ Long.MaxValue
+ case FloatType =>
+ Float.MaxValue
+ case DoubleType =>
+ Double.MaxValue
+ case StringType =>
+ // we pad strings to 8 bytes so they are comparable to longs
+ UTF8String.fromBytes(longToByte(Long.MaxValue))
+ case other: Any =>
+ throw new KyuubiSQLExtensionException(s"Unsupported z-order type: ${other.catalogString}")
+ }
+ }
+}
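
A small self-contained check of the fallback interleaver above; the expected bytes are reasoned from the code, not taken from a test in this diff.

```
import org.apache.kyuubi.sql.zorder.ZorderBytesUtils

// Interleaving all-ones with all-zeros should yield alternating bits: 0x55 0x55.
val out = ZorderBytesUtils.interleaveBitsDefault(
  Array(Array(0xFF.toByte), Array(0x00.toByte)))
println(out.map(b => f"${b & 0xFF}%02X").mkString(" ")) // expected: 55 55
```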
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/FinalStageResourceManager.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/FinalStageResourceManager.scala
new file mode 100644
index 000000000..81873476c
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/FinalStageResourceManager.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.annotation.tailrec
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{ExecutorAllocationClient, MapOutputTrackerMaster, SparkContext, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SortExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive._
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
+import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec}
+
+import org.apache.kyuubi.sql.{KyuubiSQLConf, WriteUtils}
+
+/**
+ * This rule assumes the final write stage requires fewer cores than the previous stages;
+ * otherwise it takes no effect.
+ *
+ * It provides one feature:
+ * 1. Kill redundant executors before running the final write stage
+ */
+case class FinalStageResourceManager(session: SparkSession)
+ extends Rule[SparkPlan] with FinalRebalanceStageHelper {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED)) {
+ return plan
+ }
+
+ if (!WriteUtils.isWrite(session, plan)) {
+ return plan
+ }
+
+ val sc = session.sparkContext
+ val dra = sc.getConf.getBoolean("spark.dynamicAllocation.enabled", false)
+ val coresPerExecutor = sc.getConf.getInt("spark.executor.cores", 1)
+ val minExecutors = sc.getConf.getInt("spark.dynamicAllocation.minExecutors", 0)
+ val maxExecutors = sc.getConf.getInt("spark.dynamicAllocation.maxExecutors", Int.MaxValue)
+ val factor = conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_PARTITION_FACTOR)
+ val hasImprovementRoom = maxExecutors - 1 > minExecutors * factor
+ // Fail fast if:
+ // 1. DRA is off
+ // 2. the scheduler backend is not coarse-grained (only YARN and K8s work)
+ // 3. maxExecutors is not larger than minExecutors * factor
+ if (!dra || !sc.schedulerBackend.isInstanceOf[CoarseGrainedSchedulerBackend] ||
+ !hasImprovementRoom) {
+ return plan
+ }
+
+ val stageOpt = findFinalRebalanceStage(plan)
+ if (stageOpt.isEmpty) {
+ return plan
+ }
+
+ // It's not safe to kill executors if this plan contains table cache.
+ // If an executor is lost, the cached RDD would have to recompute those partitions.
+ if (hasTableCache(plan) &&
+ conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_SKIP_KILLING_EXECUTORS_FOR_TABLE_CACHE)) {
+ return plan
+ }
+
+ // TODO: move this to query stage optimizer when updating Spark to 3.5.x
+ // Since we are in `prepareQueryStage`, the AQE shuffle read has not been applied.
+ // So we need to apply it ourselves.
+ val shuffleRead = queryStageOptimizerRules.foldLeft(stageOpt.get.asInstanceOf[SparkPlan]) {
+ case (latest, rule) => rule.apply(latest)
+ }
+ val (targetCores, stage) = shuffleRead match {
+ case AQEShuffleReadExec(stage: ShuffleQueryStageExec, partitionSpecs) =>
+ (partitionSpecs.length, stage)
+ case stage: ShuffleQueryStageExec =>
+ // we can still kill executors if no AQE shuffle read, e.g., `.repartition(2)`
+ (stage.shuffle.numPartitions, stage)
+ case _ =>
+ // this should never happen in current Spark; to be safe, do nothing if it does
+ logWarning("BUG, please report to the Apache Kyuubi community")
+ return plan
+ }
+ // The conditions for killing executors:
+ // - target executors < active executors
+ // - active executors - target executors > min executors
+ val numActiveExecutors = sc.getExecutorIds().length
+ val targetExecutors = (math.ceil(targetCores.toFloat / coresPerExecutor) * factor).toInt
+ .max(1)
+ val hasBenefits = targetExecutors < numActiveExecutors &&
+ (numActiveExecutors - targetExecutors) > minExecutors
+ logInfo(s"The snapshot of current executors view, " +
+ s"active executors: $numActiveExecutors, min executor: $minExecutors, " +
+ s"target executors: $targetExecutors, has benefits: $hasBenefits")
+ if (hasBenefits) {
+ val shuffleId = stage.plan.asInstanceOf[ShuffleExchangeExec].shuffleDependency.shuffleId
+ val numReduce = stage.plan.asInstanceOf[ShuffleExchangeExec].numPartitions
+ // At this point, only the final rebalance stage is waiting to execute and all tasks of
+ // the previous stage have finished. Kill redundant existing executors eagerly so the
+ // tasks of the final stage can be scheduled in a centralized way.
+ killExecutors(sc, targetExecutors, shuffleId, numReduce)
+ }
+
+ plan
+ }
+
+ /**
+ * The priority for killing executors is:
+ * 1. kill younger executors first (the longer an executor lives, the better the JIT works)
+ * 2. kill executors that produced less shuffle data first
+ */
+ private def findExecutorToKill(
+ sc: SparkContext,
+ targetExecutors: Int,
+ shuffleId: Int,
+ numReduce: Int): Seq[String] = {
+ val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val shuffleStatusOpt = tracker.shuffleStatuses.get(shuffleId)
+ if (shuffleStatusOpt.isEmpty) {
+ return Seq.empty
+ }
+ val shuffleStatus = shuffleStatusOpt.get
+ val executorToBlockSize = new mutable.HashMap[String, Long]
+ shuffleStatus.withMapStatuses { mapStatus =>
+ mapStatus.foreach { status =>
+ var i = 0
+ var sum = 0L
+ while (i < numReduce) {
+ sum += status.getSizeForBlock(i)
+ i += 1
+ }
+ executorToBlockSize.getOrElseUpdate(status.location.executorId, sum)
+ }
+ }
+
+ val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend]
+ val executorsWithRegistrationTs = backend.getExecutorsWithRegistrationTs()
+ val existedExecutors = executorsWithRegistrationTs.keys.toSet
+ val expectedNumExecutorToKill = existedExecutors.size - targetExecutors
+ if (expectedNumExecutorToKill < 1) {
+ return Seq.empty
+ }
+
+ val executorIdsToKill = new ArrayBuffer[String]()
+ // First kill executors that hold no shuffle blocks. This can happen when the last stage
+ // runs fast and finishes in a short time: the existing executors are left over from
+ // previous stages and have not been killed by DRA, so they can not be found by tracking
+ // shuffle status.
+ // Evict executors by their registration time first, and retain all executors that
+ // have better locality for shuffle blocks.
+ executorsWithRegistrationTs.toSeq.sortBy(_._2).foreach { case (id, _) =>
+ if (executorIdsToKill.length < expectedNumExecutorToKill &&
+ !executorToBlockSize.contains(id)) {
+ executorIdsToKill.append(id)
+ }
+ }
+
+ // Evict the rest executors according to the shuffle block size
+ executorToBlockSize.toSeq.sortBy(_._2).foreach { case (id, _) =>
+ if (executorIdsToKill.length < expectedNumExecutorToKill && existedExecutors.contains(id)) {
+ executorIdsToKill.append(id)
+ }
+ }
+
+ executorIdsToKill.toSeq
+ }
+
+ private def killExecutors(
+ sc: SparkContext,
+ targetExecutors: Int,
+ shuffleId: Int,
+ numReduce: Int): Unit = {
+ val executorAllocationClient = sc.schedulerBackend.asInstanceOf[ExecutorAllocationClient]
+
+ val executorsToKill =
+ if (conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_KILL_ALL)) {
+ executorAllocationClient.getExecutorIds()
+ } else {
+ findExecutorToKill(sc, targetExecutors, shuffleId, numReduce)
+ }
+ logInfo(s"Request to kill executors, total count ${executorsToKill.size}, " +
+ s"[${executorsToKill.mkString(", ")}].")
+ if (executorsToKill.isEmpty) {
+ return
+ }
+
+ // Note, `SparkContext#killExecutors` is not allowed with DRA enabled,
+ // see `https://github.com/apache/spark/pull/20604`.
+ // It may leave the status in `ExecutorAllocationManager` inconsistent with
+ // `CoarseGrainedSchedulerBackend` for a while, but they should eventually converge.
+ //
+ // We should adjust the target number of executors, otherwise `YarnAllocator` might
+ // re-request the original target executors if DRA has not updated the target yet.
+ // Note, DRA would re-adjust the executors if there are more tasks to execute, so we are safe.
+ //
+ // * We kill executors
+ // * YarnAllocator re-requests the target executors
+ // * DRA can not release the executors since they are newly added
+ // ----------------------------------------------------------------> timeline
+ executorAllocationClient.killExecutors(
+ executorIds = executorsToKill,
+ adjustTargetNumExecutors = true,
+ countFailures = false,
+ force = false)
+
+ FinalStageResourceManager.getAdjustedTargetExecutors(sc)
+ .filter(_ < targetExecutors).foreach { adjustedExecutors =>
+ val delta = targetExecutors - adjustedExecutors
+ logInfo(s"Target executors after kill ($adjustedExecutors) is lower than required " +
+ s"($targetExecutors). Requesting $delta additional executor(s).")
+ executorAllocationClient.requestExecutors(delta)
+ }
+ }
+
+ @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
+ OptimizeSkewInRebalancePartitions,
+ CoalesceShufflePartitions(session),
+ OptimizeShuffleWithLocalRead)
+}
+
+object FinalStageResourceManager extends Logging {
+
+ private[sql] def getAdjustedTargetExecutors(sc: SparkContext): Option[Int] = {
+ sc.schedulerBackend match {
+ case schedulerBackend: CoarseGrainedSchedulerBackend =>
+ try {
+ val field = classOf[CoarseGrainedSchedulerBackend]
+ .getDeclaredField("requestedTotalExecutorsPerResourceProfile")
+ field.setAccessible(true)
+ schedulerBackend.synchronized {
+ val requestedTotalExecutorsPerResourceProfile =
+ field.get(schedulerBackend).asInstanceOf[mutable.HashMap[ResourceProfile, Int]]
+ val defaultRp = sc.resourceProfileManager.defaultResourceProfile
+ requestedTotalExecutorsPerResourceProfile.get(defaultRp)
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to get requestedTotalExecutors of Default ResourceProfile", e)
+ None
+ }
+ case _ => None
+ }
+ }
+}
+
+trait FinalRebalanceStageHelper extends AdaptiveSparkPlanHelper {
+ @tailrec
+ final protected def findFinalRebalanceStage(plan: SparkPlan): Option[ShuffleQueryStageExec] = {
+ plan match {
+ case write: DataWritingCommandExec => findFinalRebalanceStage(write.child)
+ case write: V2TableWriteExec => findFinalRebalanceStage(write.child)
+ case write: WriteFilesExec => findFinalRebalanceStage(write.child)
+ case p: ProjectExec => findFinalRebalanceStage(p.child)
+ case f: FilterExec => findFinalRebalanceStage(f.child)
+ case s: SortExec if !s.global => findFinalRebalanceStage(s.child)
+ case stage: ShuffleQueryStageExec
+ if stage.isMaterialized && stage.mapStats.isDefined &&
+ stage.plan.isInstanceOf[ShuffleExchangeExec] &&
+ stage.plan.asInstanceOf[ShuffleExchangeExec].shuffleOrigin != ENSURE_REQUIREMENTS =>
+ Some(stage)
+ case _ => None
+ }
+ }
+
+ final protected def hasTableCache(plan: SparkPlan): Boolean = {
+ find(plan) {
+ case _: InMemoryTableScanExec => true
+ case _ => false
+ }.isDefined
+ }
+}
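
A hedged configuration sketch for `FinalStageResourceManager`. The conf key names are assumptions about the `KyuubiSQLConf` entries referenced above (`FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED`, `FINAL_WRITE_STAGE_PARTITION_FACTOR`); verify them before use.

```
// The rule only acts with dynamic allocation plus a coarse-grained backend (YARN/K8s).
spark.conf.set("spark.dynamicAllocation.enabled", "true")
// Hypothetical key for FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED:
spark.conf.set("spark.sql.finalWriteStage.eagerlyKillExecutors.enabled", "true")
// Hypothetical key for FINAL_WRITE_STAGE_PARTITION_FACTOR: retain roughly
// factor * targetCores / coresPerExecutor executors for the final write stage.
spark.conf.set("spark.sql.finalWriteStage.retainExecutorsFactor", "1.2")
```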
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/InjectCustomResourceProfile.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/InjectCustomResourceProfile.scala
new file mode 100644
index 000000000..64421d6bf
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/InjectCustomResourceProfile.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{CustomResourceProfileExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive._
+
+import org.apache.kyuubi.sql.{KyuubiSQLConf, WriteUtils}
+
+/**
+ * Inject a custom resource profile for the final write stage, so that custom
+ * executor resource configs can be specified.
+ */
+case class InjectCustomResourceProfile(session: SparkSession)
+ extends Rule[SparkPlan] with FinalRebalanceStageHelper {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED)) {
+ return plan
+ }
+
+ if (!WriteUtils.isWrite(session, plan)) {
+ return plan
+ }
+
+ val stage = findFinalRebalanceStage(plan)
+ if (stage.isEmpty) {
+ return plan
+ }
+
+ // TODO: Ideally, we could call `CoarseGrainedSchedulerBackend.requestTotalExecutors` eagerly
+ // to reduce the task submission pending time, but it may lose task locality.
+ //
+ // By default, executors are requested when the stage submit event is caught.
+ injectCustomResourceProfile(plan, stage.get.id)
+ }
+
+ private def injectCustomResourceProfile(plan: SparkPlan, id: Int): SparkPlan = {
+ plan match {
+ case stage: ShuffleQueryStageExec if stage.id == id =>
+ CustomResourceProfileExec(stage)
+ case _ => plan.mapChildren(child => injectCustomResourceProfile(child, id))
+ }
+ }
+}
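For reference, a minimal sketch of how this rule is exercised (assumptions: a local
session and a hypothetical table `t`; this mirrors the InjectResourceProfileSuite test
added below and is not part of the patch itself):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.kyuubi.sql.KyuubiSQLConf

val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  .config(KyuubiSQLConf.FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED.key, "true")
  .getOrCreate()

spark.sql("CREATE TABLE t (c1 INT, c2 STRING) USING PARQUET")
// The explicit rebalance hint yields a final ShuffleQueryStageExec whose shuffle
// origin is not ENSURE_REQUIREMENTS, so the rule wraps it in CustomResourceProfileExec.
spark.sql("INSERT INTO TABLE t SELECT /*+ rebalance */ * FROM VALUES (1, 'a')")
```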
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala
new file mode 100644
index 000000000..ce496eb47
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.types.StructType
+
+trait PruneFileSourcePartitionHelper extends PredicateHelper {
+
+ def getPartitionKeyFiltersAndDataFilters(
+ sparkSession: SparkSession,
+ relation: LeafNode,
+ partitionSchema: StructType,
+ filters: Seq[Expression],
+ output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
+ val normalizedFilters = DataSourceStrategy.normalizeExprs(
+ filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
+ output)
+ val partitionColumns =
+ relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
+ val partitionSet = AttributeSet(partitionColumns)
+ val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
+ f.references.subsetOf(partitionSet))
+ val extraPartitionFilter =
+ dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
+
+ (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters)
+ }
+}
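To illustrate what the helper computes, consider a hypothetical table `t` partitioned
by `p` (a sketch, not part of the patch):

```scala
// For a scan of t with a mixed predicate:
val df = spark.sql("SELECT * FROM t WHERE p = 1 AND c1 > 0")
// getPartitionKeyFiltersAndDataFilters would yield, roughly:
//   partition key filters: ExpressionSet(p = 1)  // used to prune partitions at planning time
//   data filters:          Seq(c1 > 0)           // evaluated against the file rows
// Only deterministic, subquery-free predicates are considered, and partition-only
// predicates extracted from the data filters via extractPredicatesWithinOutputSet
// are added to the partition key filter set.
```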
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/execution/CustomResourceProfileExec.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/execution/CustomResourceProfileExec.scala
new file mode 100644
index 000000000..3698140fb
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/execution/CustomResourceProfileExec.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfileBuilder}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
+
+import org.apache.kyuubi.sql.KyuubiSQLConf._
+
+/**
+ * This node wraps the final executed plan and injects a custom resource profile into its RDD.
+ * It assumes that the produced RDD creates the `ResultStage` in `DAGScheduler`,
+ * so it isolates resources between the previous stages and the final stage.
+ *
+ * Note that Spark does not support configuring `minExecutors` per resource profile,
+ * which means it retains `minExecutors` for each resource profile.
+ * It is therefore suggested to set `spark.dynamicAllocation.minExecutors` to 0 when
+ * enabling this feature.
+ */
+case class CustomResourceProfileExec(child: SparkPlan) extends UnaryExecNode {
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def supportsColumnar: Boolean = child.supportsColumnar
+ override def supportsRowBased: Boolean = child.supportsRowBased
+ override protected def doCanonicalize(): SparkPlan = child.canonicalized
+
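+ // Resolve the final write stage executor resources, falling back to the session-level
+ // executor configs when the dedicated final-write-stage overrides are unset.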
+ private val executorCores = conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_CORES).getOrElse(
+ sparkContext.getConf.getInt("spark.executor.cores", 1))
+ private val executorMemory = conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_MEMORY).getOrElse(
+ sparkContext.getConf.get("spark.executor.memory", "2G"))
+ private val executorMemoryOverhead =
+ conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_MEMORY_OVERHEAD)
+ .getOrElse(sparkContext.getConf.get("spark.executor.memoryOverhead", "1G"))
+ private val executorOffHeapMemory = conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_OFF_HEAP_MEMORY)
+
+ override lazy val metrics: Map[String, SQLMetric] = {
+ val base = Map(
+ "executorCores" -> SQLMetrics.createMetric(sparkContext, "executor cores"),
+ "executorMemory" -> SQLMetrics.createMetric(sparkContext, "executor memory (MiB)"),
+ "executorMemoryOverhead" -> SQLMetrics.createMetric(
+ sparkContext,
+ "executor memory overhead (MiB)"))
+ val addition = executorOffHeapMemory.map(_ =>
+ "executorOffHeapMemory" ->
+ SQLMetrics.createMetric(sparkContext, "executor off heap memory (MiB)")).toMap
+ base ++ addition
+ }
+
+ private def wrapResourceProfile[T](rdd: RDD[T]): RDD[T] = {
+ if (Utils.isTesting) {
+ // do nothing for local testing
+ return rdd
+ }
+
+ metrics("executorCores") += executorCores
+ metrics("executorMemory") += JavaUtils.byteStringAs(executorMemory, ByteUnit.MiB)
+ metrics("executorMemoryOverhead") += JavaUtils.byteStringAs(
+ executorMemoryOverhead,
+ ByteUnit.MiB)
+ executorOffHeapMemory.foreach(m =>
+ metrics("executorOffHeapMemory") += JavaUtils.byteStringAs(m, ByteUnit.MiB))
+
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
+ val resourceProfileBuilder = new ResourceProfileBuilder()
+ val executorResourceRequests = new ExecutorResourceRequests()
+ executorResourceRequests.cores(executorCores)
+ executorResourceRequests.memory(executorMemory)
+ executorResourceRequests.memoryOverhead(executorMemoryOverhead)
+ executorOffHeapMemory.foreach(executorResourceRequests.offHeapMemory)
+ resourceProfileBuilder.require(executorResourceRequests)
+ rdd.withResources(resourceProfileBuilder.build())
+ rdd
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val rdd = child.execute()
+ wrapResourceProfile(rdd)
+ }
+
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val rdd = child.executeColumnar()
+ wrapResourceProfile(rdd)
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
+ this.copy(child = newChild)
+ }
+}
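The node builds on the stock stage-level scheduling API; a standalone sketch of that
API, independent of this patch:

```scala
import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfileBuilder}

val requests = new ExecutorResourceRequests()
requests.cores(2)             // executor cores for the final stage
requests.memory("4g")         // executor heap memory
requests.memoryOverhead("1g") // executor memory overhead
val profile = new ResourceProfileBuilder().require(requests).build()
// rdd.withResources(profile) binds the profile to the RDD that backs the
// ResultStage, which is what isolates the final stage's executors.
```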
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/resources/log4j2-test.xml b/extensions/spark/kyuubi-extension-spark-3-5/src/test/resources/log4j2-test.xml
new file mode 100644
index 000000000..bfc40dd6d
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/resources/log4j2-test.xml
@@ -0,0 +1,43 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<!-- Extra logging related to initialization of Log4j.
+ Set to debug or trace if log4j initialization is failing. -->
+<Configuration status="WARN">
+ <Appenders>
+ <Console name="stdout" target="SYSTEM_OUT">
+ <PatternLayout pattern="%d{HH:mm:ss.SSS} %p %c: %m%n"/>
+ <Filters>
+ <ThresholdFilter level="FATAL"/>
+ <RegexFilter regex=".*Thrift error occurred during processing of message.*" onMatch="DENY" onMismatch="NEUTRAL"/>
+ </Filters>
+ </Console>
+ <File name="file" fileName="target/unit-tests.log">
+ <PatternLayout pattern="%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n"/>
+ <Filters>
+ <RegexFilter regex=".*Thrift error occurred during processing of message.*" onMatch="DENY" onMismatch="NEUTRAL"/>
+ </Filters>
+ </File>
+ </Appenders>
+ <Loggers>
+ <Root level="INFO">
+ <AppenderRef ref="stdout"/>
+ <AppenderRef ref="file"/>
+ </Root>
+ </Loggers>
+</Configuration>
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DropIgnoreNonexistentSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DropIgnoreNonexistentSuite.scala
new file mode 100644
index 000000000..bbc61fb44
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DropIgnoreNonexistentSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.logical.{DropNamespace, NoopCommand}
+import org.apache.spark.sql.execution.command._
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+class DropIgnoreNonexistentSuite extends KyuubiSparkSQLExtensionTest {
+
+ test("drop ignore nonexistent") {
+ withSQLConf(KyuubiSQLConf.DROP_IGNORE_NONEXISTENT.key -> "true") {
+ // drop nonexistent database
+ val df1 = sql("DROP DATABASE nonexistent_database")
+ assert(df1.queryExecution.analyzed.asInstanceOf[DropNamespace].ifExists == true)
+
+ // drop nonexistent function
+ val df4 = sql("DROP FUNCTION nonexistent_function")
+ assert(df4.queryExecution.analyzed.isInstanceOf[NoopCommand])
+
+ // drop nonexistent PARTITION
+ withTable("test") {
+ sql("CREATE TABLE IF NOT EXISTS test(i int) PARTITIONED BY (p int)")
+ val df5 = sql("ALTER TABLE test DROP PARTITION (p = 1)")
+ assert(df5.queryExecution.analyzed
+ .asInstanceOf[AlterTableDropPartitionCommand].ifExists == true)
+ }
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageConfigIsolationSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageConfigIsolationSuite.scala
new file mode 100644
index 000000000..96c8ae6e8
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageConfigIsolationSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, QueryStageExec}
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.kyuubi.sql.{FinalStageConfigIsolation, KyuubiSQLConf}
+
+class FinalStageConfigIsolationSuite extends KyuubiSparkSQLExtensionTest {
+ override protected def beforeAll(): Unit = {
+ super.beforeAll()
+ setupData()
+ }
+
+ test("final stage config set reset check") {
+ withSQLConf(
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true",
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY.key -> "false",
+ "spark.sql.finalStage.adaptive.coalescePartitions.minPartitionNum" -> "1",
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "100") {
+ // use a loop to double-check that the final stage configs from one SQL query do not
+ // affect another
+ (1 to 3).foreach { _ =>
+ sql("SELECT COUNT(*) FROM VALUES(1) as t(c)").collect()
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.previousStage.adaptive.coalescePartitions.minPartitionNum") ===
+ FinalStageConfigIsolation.INTERNAL_UNSET_CONFIG_TAG)
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.adaptive.coalescePartitions.minPartitionNum") ===
+ "1")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.finalStage.adaptive.coalescePartitions.minPartitionNum") ===
+ "1")
+
+ // 64MB
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes") ===
+ "67108864b")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.adaptive.advisoryPartitionSizeInBytes") ===
+ "100")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes") ===
+ "100")
+ }
+
+ sql("SET spark.sql.adaptive.advisoryPartitionSizeInBytes=1")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.adaptive.advisoryPartitionSizeInBytes") ===
+ "1")
+ assert(!spark.sessionState.conf.contains(
+ "spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes"))
+
+ sql("SET a=1")
+ assert(spark.sessionState.conf.getConfString("a") === "1")
+
+ sql("RESET spark.sql.adaptive.coalescePartitions.minPartitionNum")
+ assert(!spark.sessionState.conf.contains(
+ "spark.sql.adaptive.coalescePartitions.minPartitionNum"))
+ assert(!spark.sessionState.conf.contains(
+ "spark.sql.previousStage.adaptive.coalescePartitions.minPartitionNum"))
+
+ sql("RESET a")
+ assert(!spark.sessionState.conf.contains("a"))
+ }
+ }
+
+ test("final stage config isolation") {
+ def checkPartitionNum(
+ sqlString: String,
+ previousPartitionNum: Int,
+ finalPartitionNum: Int): Unit = {
+ val df = sql(sqlString)
+ df.collect()
+ val shuffleReaders = collect(df.queryExecution.executedPlan) {
+ case customShuffleReader: AQEShuffleReadExec => customShuffleReader
+ }
+ assert(shuffleReaders.nonEmpty)
+ // reorder stages by stage id to ensure we get the right stage
+ val sortedShuffleReaders = shuffleReaders.sortWith {
+ case (s1, s2) =>
+ s1.child.asInstanceOf[QueryStageExec].id < s2.child.asInstanceOf[QueryStageExec].id
+ }
+ if (sortedShuffleReaders.length > 1) {
+ assert(sortedShuffleReaders.head.partitionSpecs.length === previousPartitionNum)
+ }
+ assert(sortedShuffleReaders.last.partitionSpecs.length === finalPartitionNum)
+ assert(df.rdd.partitions.length === finalPartitionNum)
+ }
+
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true",
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY.key -> "false",
+ "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1",
+ "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1",
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "10000000") {
+
+ // use a loop to double-check that the final stage configs from one SQL query do not
+ // affect another
+ (1 to 3).foreach { _ =>
+ checkPartitionNum(
+ "SELECT c1, count(*) FROM t1 GROUP BY c1",
+ 1,
+ 1)
+
+ checkPartitionNum(
+ "SELECT c2, count(*) FROM (SELECT c1, count(*) as c2 FROM t1 GROUP BY c1) GROUP BY c2",
+ 3,
+ 1)
+
+ checkPartitionNum(
+ "SELECT t1.c1, count(*) FROM t1 JOIN t2 ON t1.c2 = t2.c2 GROUP BY t1.c1",
+ 3,
+ 1)
+
+ checkPartitionNum(
+ """
+ | SELECT /*+ REPARTITION */
+ | t1.c1, count(*) FROM t1
+ | JOIN t2 ON t1.c2 = t2.c2
+ | JOIN t3 ON t1.c1 = t3.c1
+ | GROUP BY t1.c1
+ |""".stripMargin,
+ 3,
+ 1)
+
+ // one shuffle reader
+ checkPartitionNum(
+ """
+ | SELECT /*+ BROADCAST(t1) */
+ | t1.c1, t2.c2 FROM t1
+ | JOIN t2 ON t1.c2 = t2.c2
+ | DISTRIBUTE BY c1
+ |""".stripMargin,
+ 1,
+ 1)
+
+ // test ReusedExchange
+ checkPartitionNum(
+ """
+ |SELECT /*+ REPARTITION */ t0.c2 FROM (
+ |SELECT t1.c1, (count(*) + c1) as c2 FROM t1 GROUP BY t1.c1
+ |) t0 JOIN (
+ |SELECT t1.c1, (count(*) + c1) as c2 FROM t1 GROUP BY t1.c1
+ |) t1 ON t0.c2 = t1.c2
+ |""".stripMargin,
+ 3,
+ 1)
+
+ // one shuffle reader
+ checkPartitionNum(
+ """
+ |SELECT t0.c1 FROM (
+ |SELECT t1.c1 FROM t1 GROUP BY t1.c1
+ |) t0 JOIN (
+ |SELECT t1.c1 FROM t1 GROUP BY t1.c1
+ |) t1 ON t0.c1 = t1.c1
+ |""".stripMargin,
+ 1,
+ 1)
+ }
+ }
+ }
+
+ test("final stage config isolation write only") {
+ withSQLConf(
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true",
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY.key -> "true",
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "7") {
+ sql("set spark.sql.adaptive.advisoryPartitionSizeInBytes=5")
+ sql("SELECT * FROM t1").count()
+ assert(spark.conf.getOption("spark.sql.adaptive.advisoryPartitionSizeInBytes")
+ .contains("5"))
+
+ withTable("tmp") {
+ sql("CREATE TABLE t1 USING PARQUET SELECT /*+ repartition */ 1 AS c1, 'a' AS c2")
+ assert(spark.conf.getOption("spark.sql.adaptive.advisoryPartitionSizeInBytes")
+ .contains("7"))
+ }
+
+ sql("SELECT * FROM t1").count()
+ assert(spark.conf.getOption("spark.sql.adaptive.advisoryPartitionSizeInBytes")
+ .contains("5"))
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageResourceManagerSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageResourceManagerSuite.scala
new file mode 100644
index 000000000..4b9991ef6
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageResourceManagerSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.scalatest.time.{Minutes, Span}
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+import org.apache.kyuubi.tags.SparkLocalClusterTest
+
+@SparkLocalClusterTest
+class FinalStageResourceManagerSuite extends KyuubiSparkSQLExtensionTest {
+
+ override def sparkConf(): SparkConf = {
+ // It is difficult to run spark in local-cluster mode when spark.testing is set.
+ sys.props.remove("spark.testing")
+
+ super.sparkConf().set("spark.master", "local-cluster[3, 1, 1024]")
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.dynamicAllocation.initialExecutors", "3")
+ .set("spark.dynamicAllocation.minExecutors", "1")
+ .set("spark.dynamicAllocation.shuffleTracking.enabled", "true")
+ .set(KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key, "true")
+ .set(KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED.key, "true")
+ }
+
+ test("[KYUUBI #5136][Bug] Final Stage hangs forever") {
+ // Prerequisites to reproduce the bug:
+ // 1. Dynamic allocation is enabled.
+ // 2. Dynamic allocation min executors is 1.
+ // 3. target executors < active executors.
+ // 4. No active executor is left after FinalStageResourceManager killed executors.
+ // This is possible because the executors retained by FinalStageResourceManager may
+ // already have been requested to be killed but have not yet died.
+ // 5. Final Stage required executors is 1.
+ withSQLConf(
+ (KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_KILL_ALL.key, "true")) {
+ withTable("final_stage") {
+ eventually(timeout(Span(10, Minutes))) {
+ sql(
+ "CREATE TABLE final_stage AS SELECT id, count(*) as num FROM (SELECT 0 id) GROUP BY id")
+ }
+ assert(FinalStageResourceManager.getAdjustedTargetExecutors(spark.sparkContext).get == 1)
+ }
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InjectResourceProfileSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InjectResourceProfileSuite.scala
new file mode 100644
index 000000000..b0767b187
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InjectResourceProfileSuite.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+class InjectResourceProfileSuite extends KyuubiSparkSQLExtensionTest {
+ private def checkCustomResourceProfile(sqlString: String, exists: Boolean): Unit = {
+ @volatile var lastEvent: SparkListenerSQLAdaptiveExecutionUpdate = null
+ val listener = new SparkListener {
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case e: SparkListenerSQLAdaptiveExecutionUpdate => lastEvent = e
+ case _ =>
+ }
+ }
+ }
+
+ spark.sparkContext.addSparkListener(listener)
+ try {
+ sql(sqlString).collect()
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(lastEvent != null)
+ var current = lastEvent.sparkPlanInfo
+ var shouldStop = false
+ while (!shouldStop) {
+ if (current.nodeName != "CustomResourceProfile") {
+ if (current.children.isEmpty) {
+ assert(!exists)
+ shouldStop = true
+ } else {
+ current = current.children.head
+ }
+ } else {
+ assert(exists)
+ shouldStop = true
+ }
+ }
+ } finally {
+ spark.sparkContext.removeSparkListener(listener)
+ }
+ }
+
+ test("Inject resource profile") {
+ withTable("t") {
+ withSQLConf(
+ "spark.sql.adaptive.forceApply" -> "true",
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true",
+ KyuubiSQLConf.FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED.key -> "true") {
+
+ sql("CREATE TABLE t (c1 int, c2 string) USING PARQUET")
+
+ checkCustomResourceProfile("INSERT INTO TABLE t VALUES(1, 'a')", false)
+ checkCustomResourceProfile("SELECT 1", false)
+ checkCustomResourceProfile(
+ "INSERT INTO TABLE t SELECT /*+ rebalance */ * FROM VALUES(1, 'a')",
+ true)
+ }
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuite.scala
new file mode 100644
index 000000000..f0d384657
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuite.scala
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+class InsertShuffleNodeBeforeJoinSuite extends InsertShuffleNodeBeforeJoinSuiteBase
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuiteBase.scala
new file mode 100644
index 000000000..c657dee49
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuiteBase.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike}
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+trait InsertShuffleNodeBeforeJoinSuiteBase extends KyuubiSparkSQLExtensionTest {
+ override protected def beforeAll(): Unit = {
+ super.beforeAll()
+ setupData()
+ }
+
+ override def sparkConf(): SparkConf = {
+ super.sparkConf()
+ .set(
+ StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
+ "org.apache.kyuubi.sql.KyuubiSparkSQLCommonExtension")
+ }
+
+ test("force shuffle before join") {
+ def checkShuffleNodeNum(sqlString: String, num: Int): Unit = {
+ var expectedResult: Seq[Row] = Seq.empty
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ expectedResult = sql(sqlString).collect()
+ }
+ val df = sql(sqlString)
+ checkAnswer(df, expectedResult)
+ assert(
+ collect(df.queryExecution.executedPlan) {
+ case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin == ENSURE_REQUIREMENTS =>
+ shuffle
+ }.size == num)
+ }
+
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ KyuubiSQLConf.FORCE_SHUFFLE_BEFORE_JOIN.key -> "true") {
+ Seq("SHUFFLE_HASH", "MERGE").foreach { joinHint =>
+ // positive case
+ checkShuffleNodeNum(
+ s"""
+ |SELECT /*+ $joinHint(t2, t3) */ t1.c1, t1.c2, t2.c1, t3.c1 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN t3 ON t1.c1 = t3.c1
+ | """.stripMargin,
+ 4)
+
+ // negative case
+ checkShuffleNodeNum(
+ s"""
+ |SELECT /*+ $joinHint(t2, t3) */ t1.c1, t1.c2, t2.c1, t3.c1 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN t3 ON t1.c2 = t3.c2
+ | """.stripMargin,
+ 4)
+ }
+
+ checkShuffleNodeNum(
+ """
+ |SELECT t1.c1, t2.c1, t3.c2 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN (
+ | SELECT c2, count(*) FROM t1 GROUP BY c2
+ | ) t3 ON t1.c1 = t3.c2
+ | """.stripMargin,
+ 5)
+
+ checkShuffleNodeNum(
+ """
+ |SELECT t1.c1, t2.c1, t3.c1 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN (
+ | SELECT c1, count(*) FROM t1 GROUP BY c1
+ | ) t3 ON t1.c1 = t3.c1
+ | """.stripMargin,
+ 5)
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala
new file mode 100644
index 000000000..dd9ffbf16
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec}
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.util.QueryExecutionListener
+import org.apache.spark.util.Utils
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+trait KyuubiSparkSQLExtensionTest extends QueryTest
+ with SQLTestUtils
+ with AdaptiveSparkPlanHelper {
+ sys.props.put("spark.testing", "1")
+
+ private var _spark: Option[SparkSession] = None
+ protected def spark: SparkSession = _spark.getOrElse {
+ throw new RuntimeException("test spark session was not initialized before using it")
+ }
+
+ override protected def beforeAll(): Unit = {
+ if (_spark.isEmpty) {
+ _spark = Option(SparkSession.builder()
+ .master("local[1]")
+ .config(sparkConf)
+ .enableHiveSupport()
+ .getOrCreate())
+ }
+ super.beforeAll()
+ }
+
+ override protected def afterAll(): Unit = {
+ super.afterAll()
+ cleanupData()
+ _spark.foreach(_.stop)
+ }
+
+ protected def setupData(): Unit = {
+ val self = spark
+ import self.implicits._
+ spark.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)),
+ 10)
+ .toDF("c1", "c2").createOrReplaceTempView("t1")
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, i.toString)),
+ 5)
+ .toDF("c1", "c2").createOrReplaceTempView("t2")
+ spark.sparkContext.parallelize(
+ (1 to 50).map(i => TestData(i, i.toString)),
+ 2)
+ .toDF("c1", "c2").createOrReplaceTempView("t3")
+ }
+
+ private def cleanupData(): Unit = {
+ spark.sql("DROP VIEW IF EXISTS t1")
+ spark.sql("DROP VIEW IF EXISTS t2")
+ spark.sql("DROP VIEW IF EXISTS t3")
+ }
+
+ def sparkConf(): SparkConf = {
+ val basePath = Utils.createTempDir() + "/" + getClass.getCanonicalName
+ val metastorePath = basePath + "/metastore_db"
+ val warehousePath = basePath + "/warehouse"
+ new SparkConf()
+ .set(
+ StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
+ "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
+ .set(KyuubiSQLConf.SQL_CLASSIFICATION_ENABLED.key, "true")
+ .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+ .set("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
+ .set("spark.hadoop.hive.metastore.client.capability.check", "false")
+ .set(
+ ConfVars.METASTORECONNECTURLKEY.varname,
+ s"jdbc:derby:;databaseName=$metastorePath;create=true")
+ .set(StaticSQLConf.WAREHOUSE_PATH, warehousePath)
+ .set("spark.ui.enabled", "false")
+ }
+
+ def withListener(sqlString: String)(callback: DataWritingCommand => Unit): Unit = {
+ withListener(sql(sqlString))(callback)
+ }
+
+ def withListener(df: => DataFrame)(callback: DataWritingCommand => Unit): Unit = {
+ val listener = new QueryExecutionListener {
+ override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+
+ override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ qe.executedPlan match {
+ case write: DataWritingCommandExec => callback(write.cmd)
+ case _ =>
+ }
+ }
+ }
+ spark.listenerManager.register(listener)
+ try {
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty()
+ } finally {
+ spark.listenerManager.unregister(listener)
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala
new file mode 100644
index 000000000..1d9630f49
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, Sort}
+import org.apache.spark.sql.execution.command.DataWritingCommand
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+
+class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest {
+
+ test("check rebalance exists") {
+ def check(df: => DataFrame, expectedRebalanceNum: Int = 1): Unit = {
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
+ withListener(df) { write =>
+ assert(write.collect {
+ case r: RebalancePartitions => r
+ }.size == expectedRebalanceNum)
+ }
+ }
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "false") {
+ withListener(df) { write =>
+ assert(write.collect {
+ case r: RebalancePartitions => r
+ }.isEmpty)
+ }
+ }
+ }
+
+ // It's better to set the config explicitly in case the default value changes.
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true") {
+ Seq("USING PARQUET", "").foreach { storage =>
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2='a') " +
+ "SELECT * FROM VALUES(1),(2) AS t(c1)"))
+ }
+
+ withTable("tmp1", "tmp2") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ sql(s"CREATE TABLE tmp2 (c1 int) $storage PARTITIONED BY (c2 string)")
+ check(
+ sql(
+ """FROM VALUES(1),(2)
+ |INSERT INTO TABLE tmp1 PARTITION(c2='a') SELECT *
+ |INSERT INTO TABLE tmp2 PARTITION(c2='a') SELECT *
+ |""".stripMargin),
+ 2)
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage")
+ check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)"))
+ }
+
+ withTable("tmp1", "tmp2") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage")
+ sql(s"CREATE TABLE tmp2 (c1 int) $storage")
+ check(
+ sql(
+ """FROM VALUES(1),(2),(3)
+ |INSERT INTO TABLE tmp1 SELECT *
+ |INSERT INTO TABLE tmp2 SELECT *
+ |""".stripMargin),
+ 2)
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 $storage AS SELECT * FROM VALUES(1),(2),(3) AS t(c1)")
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 $storage PARTITIONED BY(c2) AS " +
+ s"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")
+ }
+ }
+ }
+ }
+
+ test("check rebalance does not exists") {
+ def check(df: DataFrame): Unit = {
+ withListener(df) { write =>
+ assert(write.collect {
+ case r: RebalancePartitions => r
+ }.isEmpty)
+ }
+ }
+
+ withSQLConf(
+ KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true",
+ KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
+ // test no write command
+ check(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+ check(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+
+ // test not supported plan
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) PARTITIONED BY (c2 string)")
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT /*+ repartition(10) */ * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) ORDER BY c1"))
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) LIMIT 10"))
+ }
+ }
+
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "false") {
+ Seq("USING PARQUET", "").foreach { storage =>
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage")
+ check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)"))
+ }
+ }
+ }
+ }
+
+ test("test dynamic partition write") {
+ def checkRepartitionExpression(sqlString: String): Unit = {
+ withListener(sqlString) { write =>
+ assert(write.isInstanceOf[InsertIntoHiveTable])
+ assert(write.collect {
+ case r: RebalancePartitions if r.partitionExpressions.size == 1 =>
+ assert(r.partitionExpressions.head.asInstanceOf[Attribute].name === "c2")
+ r
+ }.size == 1)
+ }
+ }
+
+ withSQLConf(
+ KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true",
+ KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM.key -> "2",
+ KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
+ Seq("USING PARQUET", "").foreach { storage =>
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ checkRepartitionExpression("INSERT INTO TABLE tmp1 SELECT 1 as c1, 'a' as c2 ")
+ }
+
+ withTable("tmp1") {
+ checkRepartitionExpression(
+ "CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT 1 as c1, 'a' as c2")
+ }
+ }
+ }
+ }
+
+ test("OptimizedCreateHiveTableAsSelectCommand") {
+ withSQLConf(
+ HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true",
+ HiveUtils.CONVERT_METASTORE_CTAS.key -> "true",
+ KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") {
+ withTable("t") {
+ withListener("CREATE TABLE t STORED AS parquet AS SELECT 1 as a") { write =>
+ assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand])
+ assert(write.collect {
+ case _: RebalancePartitions => true
+ }.size == 1)
+ }
+ }
+ }
+ }
+
+ test("Infer rebalance and sorder orders") {
+ def checkShuffleAndSort(dataWritingCommand: LogicalPlan, sSize: Int, rSize: Int): Unit = {
+ assert(dataWritingCommand.isInstanceOf[DataWritingCommand])
+ val plan = dataWritingCommand.asInstanceOf[DataWritingCommand].query
+ assert(plan.collect {
+ case s: Sort => s
+ }.size == sSize)
+ assert(plan.collect {
+ case r: RebalancePartitions if r.partitionExpressions.size == rSize => r
+ }.nonEmpty || rSize == 0)
+ }
+
+ withView("v") {
+ withTable("t", "input1", "input2") {
+ withSQLConf(KyuubiSQLConf.INFER_REBALANCE_AND_SORT_ORDERS.key -> "true") {
+ sql(s"CREATE TABLE t (c1 int, c2 long) USING PARQUET PARTITIONED BY (p string)")
+ sql(s"CREATE TABLE input1 USING PARQUET AS SELECT * FROM VALUES(1,2),(1,3)")
+ sql(s"CREATE TABLE input2 USING PARQUET AS SELECT * FROM VALUES(1,3),(1,3)")
+ sql(s"CREATE VIEW v as SELECT col1, count(*) as col2 FROM input1 GROUP BY col1")
+
+ val df0 = sql(
+ s"""
+ |INSERT INTO TABLE t PARTITION(p='a')
+ |SELECT /*+ broadcast(input2) */ input1.col1, input2.col1
+ |FROM input1
+ |JOIN input2
+ |ON input1.col1 = input2.col1
+ |""".stripMargin)
+ checkShuffleAndSort(df0.queryExecution.analyzed, 1, 1)
+
+ val df1 = sql(
+ s"""
+ |INSERT INTO TABLE t PARTITION(p='a')
+ |SELECT /*+ broadcast(input2) */ input1.col1, input1.col2
+ |FROM input1
+ |LEFT JOIN input2
+ |ON input1.col1 = input2.col1 and input1.col2 = input2.col2
+ |""".stripMargin)
+ checkShuffleAndSort(df1.queryExecution.analyzed, 1, 2)
+
+ val df2 = sql(
+ s"""
+ |INSERT INTO TABLE t PARTITION(p='a')
+ |SELECT col1 as c1, count(*) as c2
+ |FROM input1
+ |GROUP BY col1
+ |HAVING count(*) > 0
+ |""".stripMargin)
+ checkShuffleAndSort(df2.queryExecution.analyzed, 1, 1)
+
+ // dynamic partition
+ val df3 = sql(
+ s"""
+ |INSERT INTO TABLE t PARTITION(p)
+ |SELECT /*+ broadcast(input2) */ input1.col1, input1.col2, input1.col2
+ |FROM input1
+ |JOIN input2
+ |ON input1.col1 = input2.col1
+ |""".stripMargin)
+ checkShuffleAndSort(df3.queryExecution.analyzed, 0, 1)
+
+ // non-deterministic
+ val df4 = sql(
+ s"""
+ |INSERT INTO TABLE t PARTITION(p='a')
+ |SELECT col1 + rand(), count(*) as c2
+ |FROM input1
+ |GROUP BY col1
+ |""".stripMargin)
+ checkShuffleAndSort(df4.queryExecution.analyzed, 0, 0)
+
+ // view
+ val df5 = sql(
+ s"""
+ |INSERT INTO TABLE t PARTITION(p='a')
+ |SELECT * FROM v
+ |""".stripMargin)
+ checkShuffleAndSort(df5.queryExecution.analyzed, 1, 1)
+ }
+ }
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala
new file mode 100644
index 000000000..957089340
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+class WatchDogSuite extends WatchDogSuiteBase {}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala
new file mode 100644
index 000000000..a202e813c
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan}
+
+import org.apache.kyuubi.sql.KyuubiSQLConf
+import org.apache.kyuubi.sql.watchdog.{MaxFileSizeExceedException, MaxPartitionExceedException}
+
+trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
+ override protected def beforeAll(): Unit = {
+ super.beforeAll()
+ setupData()
+ }
+
+ case class LimitAndExpected(limit: Int, expected: Int)
+
+ val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
+
+ private def checkMaxPartition: Unit = {
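+ // Queries touching at most WATCHDOG_MAX_PARTITIONS partitions should plan normally,
+ // while anything over the limit should fail at planning time with
+ // MaxPartitionExceedException.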
+ withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") {
+ checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil)
+ }
+ withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") {
+ sql("SELECT * FROM test where p=1").queryExecution.sparkPlan
+
+ sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})")
+ .queryExecution.sparkPlan
+
+ intercept[MaxPartitionExceedException](
+ sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan)
+
+ intercept[MaxPartitionExceedException](
+ sql("SELECT * FROM test").queryExecution.sparkPlan)
+
+ intercept[MaxPartitionExceedException](sql(
+ s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})")
+ .queryExecution.sparkPlan)
+ }
+ }
+
+ test("watchdog with scan maxPartitions -- hive") {
+ Seq("textfile", "parquet").foreach { format =>
+ withTable("test", "temp") {
+ sql(
+ s"""
+ |CREATE TABLE test(i int)
+ |PARTITIONED BY (p int)
+ |STORED AS $format""".stripMargin)
+ spark.range(0, 10, 1).selectExpr("id as col")
+ .createOrReplaceTempView("temp")
+
+ for (part <- Range(0, 10)) {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE test PARTITION (p='$part')
+ |select col from temp""".stripMargin)
+ }
+ checkMaxPartition
+ }
+ }
+ }
+
+ test("watchdog with scan maxPartitions -- data source") {
+ withTempDir { dir =>
+ withTempView("test") {
+ spark.range(10).selectExpr("id", "id as p")
+ .write
+ .partitionBy("p")
+ .mode("overwrite")
+ .save(dir.getCanonicalPath)
+ spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test")
+ checkMaxPartition
+ }
+ }
+ }
+
+ test("test watchdog: simple SELECT STATEMENT") {
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+ List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
+ List("", " DISTINCT").foreach { distinct =>
+ assert(sql(
+ s"""
+ |SELECT $distinct *
+ |FROM t1
+ |$sort
+ |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+
+ limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+ List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
+ List("", "DISTINCT").foreach { distinct =>
+ assert(sql(
+ s"""
+ |SELECT $distinct *
+ |FROM t1
+ |$sort
+ |LIMIT $limit
+ |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
+ }
+ }
+ }
+ }
+ }
+
+ test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") {
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+ assert(!sql("SELECT count(*) FROM t1")
+ .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+
+ val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+ val havingConditions = List("", "HAVING cnt > 1")
+
+ havingConditions.foreach { having =>
+ sorts.foreach { sort =>
+ assert(sql(
+ s"""
+ |SELECT c1, COUNT(*) as cnt
+ |FROM t1
+ |GROUP BY c1
+ |$having
+ |$sort
+ |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+
+ limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+ havingConditions.foreach { having =>
+ sorts.foreach { sort =>
+ assert(sql(
+ s"""
+ |SELECT c1, COUNT(*) as cnt
+ |FROM t1
+ |GROUP BY c1
+ |$having
+ |$sort
+ |LIMIT $limit
+ |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
+ }
+ }
+ }
+ }
+ }
+
+ test("test watchdog: SELECT with CTE forceMaxOutputRows") {
+ // simple CTE
+ val q1 =
+ """
+ |WITH t2 AS (
+ | SELECT * FROM t1
+ |)
+ |""".stripMargin
+
+ // nested CTE
+ val q2 =
+ """
+ |WITH
+ | t AS (SELECT * FROM t1),
+ | t2 AS (
+ | WITH t3 AS (SELECT * FROM t1)
+ | SELECT * FROM t3
+ | )
+ |""".stripMargin
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+ val sorts = List("", "ORDER BY c1", "ORDER BY c2")
+
+ sorts.foreach { sort =>
+ Seq(q1, q2).foreach { withQuery =>
+ assert(sql(
+ s"""
+ |$withQuery
+ |SELECT * FROM t2
+ |$sort
+ |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+
+ limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+ sorts.foreach { sort =>
+ Seq(q1, q2).foreach { withQuery =>
+ assert(sql(
+ s"""
+ |$withQuery
+ |SELECT * FROM t2
+ |$sort
+ |LIMIT $limit
+ |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
+ }
+ }
+ }
+ }
+ }
+
+ test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") {
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+ assert(!sql(
+ """
+ |WITH custom_cte AS (
+ |SELECT * FROM t1
+ |)
+ |
+ |SELECT COUNT(*)
+ |FROM custom_cte
+ |""".stripMargin).queryExecution
+ .analyzed.isInstanceOf[GlobalLimit])
+
+ val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+ val havingConditions = List("", "HAVING cnt > 1")
+
+ havingConditions.foreach { having =>
+ sorts.foreach { sort =>
+ assert(sql(
+ s"""
+ |WITH custom_cte AS (
+ |SELECT * FROM t1
+ |)
+ |
+ |SELECT c1, COUNT(*) as cnt
+ |FROM custom_cte
+ |GROUP BY c1
+ |$having
+ |$sort
+ |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+
+ limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+ havingConditions.foreach { having =>
+ sorts.foreach { sort =>
+ assert(sql(
+ s"""
+ |WITH custom_cte AS (
+ |SELECT * FROM t1
+ |)
+ |
+ |SELECT c1, COUNT(*) as cnt
+ |FROM custom_cte
+ |GROUP BY c1
+ |$having
+ |$sort
+ |LIMIT $limit
+ |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
+ }
+ }
+ }
+ }
+ }
+
+ test("test watchdog: UNION Statement for forceMaxOutputRows") {
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+ List("", "ALL").foreach { x =>
+ assert(sql(
+ s"""
+ |SELECT c1, c2 FROM t1
+ |UNION $x
+ |SELECT c1, c2 FROM t2
+ |UNION $x
+ |SELECT c1, c2 FROM t3
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+
+ val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+ val havingConditions = List("", "HAVING cnt > 1")
+
+ List("", "ALL").foreach { x =>
+ havingConditions.foreach { having =>
+ sorts.foreach { sort =>
+ assert(sql(
+ s"""
+ |SELECT c1, count(c2) as cnt
+ |FROM t1
+ |GROUP BY c1
+ |$having
+ |UNION $x
+ |SELECT c1, COUNT(c2) as cnt
+ |FROM t2
+ |GROUP BY c1
+ |$having
+ |UNION $x
+ |SELECT c1, COUNT(c2) as cnt
+ |FROM t3
+ |GROUP BY c1
+ |$having
+ |$sort
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+ }
+
+ limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+ assert(sql(
+ s"""
+ |SELECT c1, c2 FROM t1
+ |UNION
+ |SELECT c1, c2 FROM t2
+ |UNION
+ |SELECT c1, c2 FROM t3
+ |LIMIT $limit
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.maxRows.contains(expected))
+ }
+ }
+ }
+
+ test("test watchdog: Select View Statement for forceMaxOutputRows") {
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") {
+ withTable("tmp_table", "tmp_union") {
+ withView("tmp_view", "tmp_view2") {
+ sql(s"create table tmp_table (a int, b int)")
+ sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
+ sql(s"create table tmp_union (a int, b int)")
+ sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)")
+ sql(s"create view tmp_view2 as select * from tmp_union")
+ assert(!sql(
+ s"""
+ |CREATE VIEW tmp_view
+ |as
+ |SELECT * FROM
+ |tmp_table
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+
+ assert(sql(
+ s"""
+ |SELECT * FROM
+ |tmp_view
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.maxRows.contains(3))
+
+ assert(sql(
+ s"""
+ |SELECT * FROM
+ |tmp_view
+ |limit 11
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.maxRows.contains(3))
+
+ assert(sql(
+ s"""
+ |SELECT * FROM
+ |(select * from tmp_view
+ |UNION
+ |select * from tmp_view2)
+ |ORDER BY a
+ |DESC
+ |""".stripMargin)
+ .collect().head.get(0) === 10)
+ }
+ }
+ }
+ }
+
+ test("test watchdog: Insert Statement for forceMaxOutputRows") {
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+ withTable("tmp_table", "tmp_insert") {
+ spark.sql(s"create table tmp_table (a int, b int)")
+ spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
+ val multiInsertTableName1: String = "tmp_tbl1"
+ val multiInsertTableName2: String = "tmp_tbl2"
+ sql(s"drop table if exists $multiInsertTableName1")
+ sql(s"drop table if exists $multiInsertTableName2")
+ sql(s"create table $multiInsertTableName1 like tmp_table")
+ sql(s"create table $multiInsertTableName2 like tmp_table")
+ assert(!sql(
+ s"""
+ |FROM tmp_table
+ |insert into $multiInsertTableName1 select * limit 2
+ |insert into $multiInsertTableName2 select *
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+ }
+
+ test("test watchdog: Distribute by for forceMaxOutputRows") {
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+ withTable("tmp_table") {
+ spark.sql(s"create table tmp_table (a int, b int)")
+ spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
+ assert(sql(
+ s"""
+ |SELECT *
+ |FROM tmp_table
+ |DISTRIBUTE BY a
+ |""".stripMargin)
+ .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ }
+ }
+ }
+
+ test("test watchdog: Subquery for forceMaxOutputRows") {
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") {
+ withTable("tmp_table1") {
+ sql("CREATE TABLE spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET")
+ sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " +
+ "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')")
+ assert(
+ sql("select * from tmp_table1").queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
+ val testSqlText =
+ """
+ |select count(*)
+ |from tmp_table1
+ |where tmp_table1.key in (
+ |select distinct tmp_table1.key
+ |from tmp_table1
+ |where tmp_table1.value = "aa"
+ |)
+ |""".stripMargin
+ val plan = sql(testSqlText).queryExecution.optimizedPlan
+ assert(!findGlobalLimit(plan))
+ checkAnswer(sql(testSqlText), Row(3) :: Nil)
+ }
+
+ def findGlobalLimit(plan: LogicalPlan): Boolean = plan match {
+ case _: GlobalLimit => true
+ case p if p.children.isEmpty => false
+ case p => p.children.exists(findGlobalLimit)
+ }
+
+ }
+ }
+
+ test("test watchdog: Join for forceMaxOutputRows") {
+ withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") {
+ withTable("tmp_table1", "tmp_table2") {
+ sql("CREATE TABLE spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET")
+ sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " +
+ "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')")
+ sql("CREATE TABLE spark_catalog.`default`.tmp_table2(KEY INT, VALUE STRING) USING PARQUET")
+ sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table2 " +
+ "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')")
+ val testSqlText =
+ """
+ |select a.*,b.*
+ |from tmp_table1 a
+ |join
+ |tmp_table2 b
+ |on a.KEY = b.KEY
+ |""".stripMargin
+ val plan = sql(testSqlText).queryExecution.optimizedPlan
+ assert(findGlobalLimit(plan))
+ }
+
+ def findGlobalLimit(plan: LogicalPlan): Boolean = plan match {
+ case _: GlobalLimit => true
+ case p if p.children.isEmpty => false
+ case p => p.children.exists(findGlobalLimit)
+ }
+ }
+ }
+
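+ // Verifies WATCHDOG_MAX_FILE_SIZE: scans within the threshold plan normally,
+ // while scans beyond it should fail planning with MaxFileSizeExceedException.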
+ private def checkMaxFileSize(tableSize: Long, nonPartTableSize: Long): Unit = {
+ withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> tableSize.toString) {
+ checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil)
+ }
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> (tableSize / 2).toString) {
+ sql("SELECT * FROM test where p=1").queryExecution.sparkPlan
+
+ sql(s"SELECT * FROM test WHERE p in (${Range(0, 3).toList.mkString(",")})")
+ .queryExecution.sparkPlan
+
+ intercept[MaxFileSizeExceedException](
+ sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan)
+
+ intercept[MaxFileSizeExceedException](
+ sql("SELECT * FROM test").queryExecution.sparkPlan)
+
+ intercept[MaxFileSizeExceedException](sql(
+ s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})")
+ .queryExecution.sparkPlan)
+ }
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> nonPartTableSize.toString) {
+ checkAnswer(sql("SELECT count(*) FROM test_non_part"), Row(10000) :: Nil)
+ }
+
+ withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> (nonPartTableSize - 1).toString) {
+ intercept[MaxFileSizeExceedException](
+ sql("SELECT * FROM test_non_part").queryExecution.sparkPlan)
+ }
+ }
+
+ test("watchdog with scan maxFileSize -- hive") {
+ Seq(false).foreach { convertMetastoreParquet =>
+ withTable("test", "test_non_part", "temp") {
+ spark.range(10000).selectExpr("id as col")
+ .createOrReplaceTempView("temp")
+
+ // partitioned table
+ sql(
+ s"""
+ |CREATE TABLE test(i int)
+ |PARTITIONED BY (p int)
+ |STORED AS parquet""".stripMargin)
+ for (part <- Range(0, 10)) {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE test PARTITION (p='$part')
+ |select col from temp""".stripMargin)
+ }
+
+ val tablePath = new File(spark.sessionState.catalog.externalCatalog
+ .getTable("default", "test").location)
+ val tableSize = FileUtils.listFiles(tablePath, Array("parquet"), true).asScala
+ .map(_.length()).sum
+ assert(tableSize > 0)
+
+ // non-partitioned table
+ sql(
+ s"""
+ |CREATE TABLE test_non_part(i int)
+ |STORED AS parquet""".stripMargin)
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE test_non_part
+ |select col from temp""".stripMargin)
+ sql("ANALYZE TABLE test_non_part COMPUTE STATISTICS")
+
+ val nonPartTablePath = new File(spark.sessionState.catalog.externalCatalog
+ .getTable("default", "test_non_part").location)
+ val nonPartTableSize = FileUtils.listFiles(nonPartTablePath, Array("parquet"), true).asScala
+ .map(_.length()).sum
+ assert(nonPartTableSize > 0)
+
+ // check
+ withSQLConf("spark.sql.hive.convertMetastoreParquet" -> convertMetastoreParquet.toString) {
+ checkMaxFileSize(tableSize, nonPartTableSize)
+ }
+ }
+ }
+ }
+
+ test("watchdog with scan maxFileSize -- data source") {
+ withTempDir { dir =>
+ withTempView("test", "test_non_part") {
+ // partitioned table
+ val tablePath = new File(dir, "test")
+ spark.range(10).selectExpr("id", "id as p")
+ .write
+ .partitionBy("p")
+ .mode("overwrite")
+ .parquet(tablePath.getCanonicalPath)
+ spark.read.load(tablePath.getCanonicalPath).createOrReplaceTempView("test")
+
+ val tableSize = FileUtils.listFiles(tablePath, Array("parquet"), true).asScala
+ .map(_.length()).sum
+ assert(tableSize > 0)
+
+ // non-partitioned table
+ val nonPartTablePath = new File(dir, "test_non_part")
+ spark.range(10000).selectExpr("id", "id as p")
+ .write
+ .mode("overwrite")
+ .parquet(nonPartTablePath.getCanonicalPath)
+ spark.read.load(nonPartTablePath.getCanonicalPath).createOrReplaceTempView("test_non_part")
+
+ val nonPartTableSize = FileUtils.listFiles(nonPartTablePath, Array("parquet"), true).asScala
+ .map(_.length()).sum
+ assert(nonPartTableSize > 0)
+
+ // check
+ checkMaxFileSize(tableSize, nonPartTableSize)
+ }
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderCoreBenchmark.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderCoreBenchmark.scala
new file mode 100644
index 000000000..9b1614fce
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderCoreBenchmark.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.benchmark.KyuubiBenchmarkBase
+import org.apache.spark.sql.internal.StaticSQLConf
+
+import org.apache.kyuubi.sql.zorder.ZorderBytesUtils
+
+/**
+ * Benchmark to measure performance with zorder core.
+ *
+ * {{{
+ * RUN_BENCHMARK=1 ./build/mvn clean test \
+ * -pl extensions/spark/kyuubi-extension-spark-3-5 -am \
+ * -Pspark-3.5 \
+ * -Dtest=none -DwildcardSuites=org.apache.spark.sql.ZorderCoreBenchmark
+ * }}}
+ */
+class ZorderCoreBenchmark extends KyuubiSparkSQLExtensionTest with KyuubiBenchmarkBase {
+ private val runBenchmark = sys.env.contains("RUN_BENCHMARK")
+ private val numRows = 1 * 1000 * 1000
+
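+ // Each generator builds numRows rows of identical int (or long) values to feed ZorderBytesUtils.interleaveBits.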
+ private def randomInt(numColumns: Int): Seq[Array[Any]] = {
+ (1 to numRows).map { l =>
+ val arr = new Array[Any](numColumns)
+ (0 until numColumns).foreach(col => arr(col) = l)
+ arr
+ }
+ }
+
+ private def randomLong(numColumns: Int): Seq[Array[Any]] = {
+ (1 to numRows).map { l =>
+ val arr = new Array[Any](numColumns)
+ (0 until numColumns).foreach(col => arr(col) = l.toLong)
+ arr
+ }
+ }
+
+ private def interleaveMultiByteArrayBenchmark(): Unit = {
+ val benchmark =
+ new Benchmark(s"$numRows rows zorder core benchmark", numRows, output = output)
+ benchmark.addCase("2 int columns benchmark", 3) { _ =>
+ randomInt(2).foreach(ZorderBytesUtils.interleaveBits)
+ }
+
+ benchmark.addCase("3 int columns benchmark", 3) { _ =>
+ randomInt(3).foreach(ZorderBytesUtils.interleaveBits)
+ }
+
+ benchmark.addCase("4 int columns benchmark", 3) { _ =>
+ randomInt(4).foreach(ZorderBytesUtils.interleaveBits)
+ }
+
+ benchmark.addCase("2 long columns benchmark", 3) { _ =>
+ randomLong(2).foreach(ZorderBytesUtils.interleaveBits)
+ }
+
+ benchmark.addCase("3 long columns benchmark", 3) { _ =>
+ randomLong(3).foreach(ZorderBytesUtils.interleaveBits)
+ }
+
+ benchmark.addCase("4 long columns benchmark", 3) { _ =>
+ randomLong(4).foreach(ZorderBytesUtils.interleaveBits)
+ }
+
+ benchmark.run()
+ }
+
+ private def paddingTo8ByteBenchmark(): Unit = {
+ val iterations = 10 * 1000 * 1000
+
+ val b2 = Array('a'.toByte, 'b'.toByte)
+ val benchmark =
+ new Benchmark(s"$iterations iterations paddingTo8Byte benchmark", iterations, output = output)
+ benchmark.addCase("2 length benchmark", 3) { _ =>
+ (1 to iterations).foreach(_ => ZorderBytesUtils.paddingTo8Byte(b2))
+ }
+
+ val b16 = Array.tabulate(16) { i => i.toByte }
+ benchmark.addCase("16 length benchmark", 3) { _ =>
+ (1 to iterations).foreach(_ => ZorderBytesUtils.paddingTo8Byte(b16))
+ }
+
+ benchmark.run()
+ }
+
+ test("zorder core benchmark") {
+ assume(runBenchmark)
+
+ withHeader {
+ interleaveMultiByteArrayBenchmark()
+ paddingTo8ByteBenchmark()
+ }
+ }
+
+ override def sparkConf(): SparkConf = {
+ super.sparkConf().remove(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key)
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuite.scala
new file mode 100644
index 000000000..c2fa16197
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.{RebalancePartitions, Sort}
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.kyuubi.sql.{KyuubiSQLConf, SparkKyuubiSparkSQLParser}
+import org.apache.kyuubi.sql.zorder.Zorder
+
+trait ZorderSuiteSpark extends ZorderSuiteBase {
+
+ test("Add rebalance before zorder") {
+ Seq("true" -> false, "false" -> true).foreach { case (useOriginalOrdering, zorder) =>
+ withSQLConf(
+ KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key -> "false",
+ KyuubiSQLConf.REBALANCE_BEFORE_ZORDER.key -> "true",
+ KyuubiSQLConf.REBALANCE_ZORDER_COLUMNS_ENABLED.key -> "true",
+ KyuubiSQLConf.ZORDER_USING_ORIGINAL_ORDERING_ENABLED.key -> useOriginalOrdering) {
+ withTable("t") {
+ sql(
+ """
+ |CREATE TABLE t (c1 int, c2 string) PARTITIONED BY (d string)
+ | TBLPROPERTIES (
+ |'kyuubi.zorder.enabled'= 'true',
+ |'kyuubi.zorder.cols'= 'c1,C2')
+ |""".stripMargin)
+ val p = sql("INSERT INTO TABLE t PARTITION(d='a') SELECT * FROM VALUES(1,'a')")
+ .queryExecution.analyzed
+ assert(p.collect {
+ case sort: Sort
+ if !sort.global &&
+ ((sort.order.exists(_.child.isInstanceOf[Zorder]) && zorder) ||
+ (!sort.order.exists(_.child.isInstanceOf[Zorder]) && !zorder)) => sort
+ }.size == 1)
+ assert(p.collect {
+ case rebalance: RebalancePartitions
+ if rebalance.references.map(_.name).exists(_.equals("c1")) => rebalance
+ }.size == 1)
+
+ val p2 = sql("INSERT INTO TABLE t PARTITION(d) SELECT * FROM VALUES(1,'a','b')")
+ .queryExecution.analyzed
+ assert(p2.collect {
+ case sort: Sort
+ if (!sort.global && Seq("c1", "c2", "d").forall(x =>
+ sort.references.map(_.name).exists(_.equals(x)))) &&
+ ((sort.order.exists(_.child.isInstanceOf[Zorder]) && zorder) ||
+ (!sort.order.exists(_.child.isInstanceOf[Zorder]) && !zorder)) => sort
+ }.size == 1)
+ assert(p2.collect {
+ case rebalance: RebalancePartitions
+ if Seq("c1", "c2", "d").forall(x =>
+ rebalance.references.map(_.name).exists(_.equals(x))) => rebalance
+ }.size == 1)
+ }
+ }
+ }
+ }
+
+ test("Two phase rebalance before Z-Order") {
+ withSQLConf(
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+ "org.apache.spark.sql.catalyst.optimizer.CollapseRepartition",
+ KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key -> "false",
+ KyuubiSQLConf.REBALANCE_BEFORE_ZORDER.key -> "true",
+ KyuubiSQLConf.TWO_PHASE_REBALANCE_BEFORE_ZORDER.key -> "true",
+ KyuubiSQLConf.REBALANCE_ZORDER_COLUMNS_ENABLED.key -> "true") {
+ withTable("t") {
+ sql(
+ """
+ |CREATE TABLE t (c1 int) PARTITIONED BY (d string)
+ | TBLPROPERTIES (
+ |'kyuubi.zorder.enabled'= 'true',
+ |'kyuubi.zorder.cols'= 'c1')
+ |""".stripMargin)
+ val p = sql("INSERT INTO TABLE t PARTITION(d) SELECT * FROM VALUES(1,'a')")
+ val rebalance = p.queryExecution.optimizedPlan.innerChildren
+ .flatMap(_.collect { case r: RebalancePartitions => r })
+ assert(rebalance.size == 2)
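+ // One rebalance should shuffle by both the dynamic partition column and the zorder column,
+ // the other by the partition column alone.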
+ assert(rebalance.head.partitionExpressions.flatMap(_.references.map(_.name))
+ .contains("d"))
+ assert(rebalance.head.partitionExpressions.flatMap(_.references.map(_.name))
+ .contains("c1"))
+
+ assert(rebalance(1).partitionExpressions.flatMap(_.references.map(_.name))
+ .contains("d"))
+ assert(!rebalance(1).partitionExpressions.flatMap(_.references.map(_.name))
+ .contains("c1"))
+ }
+ }
+ }
+}
+
+trait ParserSuite { self: ZorderSuiteBase =>
+ override def createParser: ParserInterface = {
+ new SparkKyuubiSparkSQLParser(spark.sessionState.sqlParser)
+ }
+}
+
+class ZorderWithCodegenEnabledSuite
+ extends ZorderWithCodegenEnabledSuiteBase
+ with ZorderSuiteSpark
+ with ParserSuite {}
+class ZorderWithCodegenDisabledSuite
+ extends ZorderWithCodegenDisabledSuiteBase
+ with ZorderSuiteSpark
+ with ParserSuite {}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala
new file mode 100644
index 000000000..2d3eec957
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala
@@ -0,0 +1,768 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, EqualTo, Expression, ExpressionEvalHelper, Literal, NullsLast, SortOrder}
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Project, Sort}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.execution.InsertIntoHiveTable
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.types._
+
+import org.apache.kyuubi.sql.{KyuubiSQLConf, KyuubiSQLExtensionException}
+import org.apache.kyuubi.sql.zorder.{OptimizeZorderCommandBase, OptimizeZorderStatement, Zorder, ZorderBytesUtils}
+
+trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHelper {
+ override def sparkConf(): SparkConf = {
+ super.sparkConf()
+ .set(
+ StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
+ "org.apache.kyuubi.sql.KyuubiSparkSQLCommonExtension")
+ }
+
+ test("optimize unpartitioned table") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ withTable("up") {
+ sql(s"DROP TABLE IF EXISTS up")
+
+ val target = Seq(
+ Seq(0, 0),
+ Seq(1, 0),
+ Seq(0, 1),
+ Seq(1, 1),
+ Seq(2, 0),
+ Seq(3, 0),
+ Seq(2, 1),
+ Seq(3, 1),
+ Seq(0, 2),
+ Seq(1, 2),
+ Seq(0, 3),
+ Seq(1, 3),
+ Seq(2, 2),
+ Seq(3, 2),
+ Seq(2, 3),
+ Seq(3, 3))
+ sql(s"CREATE TABLE up (c1 INT, c2 INT, c3 INT)")
+ sql(s"INSERT INTO TABLE up VALUES" +
+ "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
+ "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
+ "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
+ "(3,0,3),(3,1,4),(3,2,9),(3,3,0)")
+
+ val e = intercept[KyuubiSQLExtensionException] {
+ sql("OPTIMIZE up WHERE c1 > 1 ZORDER BY c1, c2")
+ }
+ assert(e.getMessage == "Filters are only supported for partitioned table")
+
+ sql("OPTIMIZE up ZORDER BY c1, c2")
+ val res = sql("SELECT c1, c2 FROM up").collect()
+
+ assert(res.length == 16)
+
+ for (i <- target.indices) {
+ val t = target(i)
+ val r = res(i)
+ assert(t(0) == r.getInt(0))
+ assert(t(1) == r.getInt(1))
+ }
+ }
+ }
+ }
+
+ test("optimize partitioned table") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ withTable("p") {
+ sql("DROP TABLE IF EXISTS p")
+
+ val target = Seq(
+ Seq(0, 0),
+ Seq(1, 0),
+ Seq(0, 1),
+ Seq(1, 1),
+ Seq(2, 0),
+ Seq(3, 0),
+ Seq(2, 1),
+ Seq(3, 1),
+ Seq(0, 2),
+ Seq(1, 2),
+ Seq(0, 3),
+ Seq(1, 3),
+ Seq(2, 2),
+ Seq(3, 2),
+ Seq(2, 3),
+ Seq(3, 3))
+
+ sql(s"CREATE TABLE p (c1 INT, c2 INT, c3 INT) PARTITIONED BY (id INT)")
+ sql(s"ALTER TABLE p ADD PARTITION (id = 1)")
+ sql(s"ALTER TABLE p ADD PARTITION (id = 2)")
+ sql(s"INSERT INTO TABLE p PARTITION (id = 1) VALUES" +
+ "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
+ "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
+ "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
+ "(3,0,3),(3,1,4),(3,2,9),(3,3,0)")
+ sql(s"INSERT INTO TABLE p PARTITION (id = 2) VALUES" +
+ "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
+ "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
+ "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
+ "(3,0,3),(3,1,4),(3,2,9),(3,3,0)")
+
+ sql(s"OPTIMIZE p ZORDER BY c1, c2")
+
+ val res1 = sql(s"SELECT c1, c2 FROM p WHERE id = 1").collect()
+ val res2 = sql(s"SELECT c1, c2 FROM p WHERE id = 2").collect()
+
+ assert(res1.length == 16)
+ assert(res2.length == 16)
+
+ for (i <- target.indices) {
+ val t = target(i)
+ val r1 = res1(i)
+ assert(t(0) == r1.getInt(0))
+ assert(t(1) == r1.getInt(1))
+
+ val r2 = res2(i)
+ assert(t(0) == r2.getInt(0))
+ assert(t(1) == r2.getInt(1))
+ }
+ }
+ }
+ }
+
+ test("optimize partitioned table with filters") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ withTable("p") {
+ sql("DROP TABLE IF EXISTS p")
+
+ val target1 = Seq(
+ Seq(0, 0),
+ Seq(1, 0),
+ Seq(0, 1),
+ Seq(1, 1),
+ Seq(2, 0),
+ Seq(3, 0),
+ Seq(2, 1),
+ Seq(3, 1),
+ Seq(0, 2),
+ Seq(1, 2),
+ Seq(0, 3),
+ Seq(1, 3),
+ Seq(2, 2),
+ Seq(3, 2),
+ Seq(2, 3),
+ Seq(3, 3))
+ val target2 = Seq(
+ Seq(0, 0),
+ Seq(0, 1),
+ Seq(0, 2),
+ Seq(0, 3),
+ Seq(1, 0),
+ Seq(1, 1),
+ Seq(1, 2),
+ Seq(1, 3),
+ Seq(2, 0),
+ Seq(2, 1),
+ Seq(2, 2),
+ Seq(2, 3),
+ Seq(3, 0),
+ Seq(3, 1),
+ Seq(3, 2),
+ Seq(3, 3))
+ sql(s"CREATE TABLE p (c1 INT, c2 INT, c3 INT) PARTITIONED BY (id INT)")
+ sql(s"ALTER TABLE p ADD PARTITION (id = 1)")
+ sql(s"ALTER TABLE p ADD PARTITION (id = 2)")
+ sql(s"INSERT INTO TABLE p PARTITION (id = 1) VALUES" +
+ "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
+ "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
+ "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
+ "(3,0,3),(3,1,4),(3,2,9),(3,3,0)")
+ sql(s"INSERT INTO TABLE p PARTITION (id = 2) VALUES" +
+ "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
+ "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
+ "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
+ "(3,0,3),(3,1,4),(3,2,9),(3,3,0)")
+
+ val e = intercept[KyuubiSQLExtensionException](
+ sql(s"OPTIMIZE p WHERE id = 1 AND c1 > 1 ZORDER BY c1, c2"))
+ assert(e.getMessage == "Only partition column filters are allowed")
+
+ sql(s"OPTIMIZE p WHERE id = 1 ZORDER BY c1, c2")
+
+ val res1 = sql(s"SELECT c1, c2 FROM p WHERE id = 1").collect()
+ val res2 = sql(s"SELECT c1, c2 FROM p WHERE id = 2").collect()
+
+ assert(res1.length == 16)
+ assert(res2.length == 16)
+
+ for (i <- target1.indices) {
+ val t1 = target1(i)
+ val r1 = res1(i)
+ assert(t1(0) == r1.getInt(0))
+ assert(t1(1) == r1.getInt(1))
+
+ val t2 = target2(i)
+ val r2 = res2(i)
+ assert(t2(0) == r2.getInt(0))
+ assert(t2(1) == r2.getInt(1))
+ }
+ }
+ }
+ }
+
+ test("optimize zorder with datasource table") {
+ // TODO: remove this once datasource tables are supported
+ withTable("t") {
+ sql("CREATE TABLE t (c1 int, c2 int) USING PARQUET")
+ val msg = intercept[KyuubiSQLExtensionException] {
+ sql("OPTIMIZE t ZORDER BY c1, c2")
+ }.getMessage
+ assert(msg.contains("only support hive table"))
+ }
+ }
+
+ private def checkZorderTable(
+ enabled: Boolean,
+ cols: String,
+ planHasRepartition: Boolean,
+ resHasSort: Boolean): Unit = {
+ def checkSort(plan: LogicalPlan): Unit = {
+ assert(plan.isInstanceOf[Sort] === resHasSort)
+ plan match {
+ case sort: Sort =>
+ val colArr = cols.split(",")
+ val refs =
+ if (colArr.length == 1) {
+ sort.order.head
+ .child.asInstanceOf[AttributeReference] :: Nil
+ } else {
+ sort.order.head
+ .child.asInstanceOf[Zorder].children.map(_.references.head)
+ }
+ assert(refs.size === colArr.size)
+ refs.zip(colArr).foreach { case (ref, col) =>
+ assert(ref.name === col.trim)
+ }
+ case _ =>
+ }
+ }
+
+ val repartition =
+ if (planHasRepartition) {
+ "/*+ repartition */"
+ } else {
+ ""
+ }
+ withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+ // hive
+ withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "false") {
+ withTable("zorder_t1", "zorder_t2_true", "zorder_t2_false") {
+ sql(
+ s"""
+ |CREATE TABLE zorder_t1 (c1 int, c2 string, c3 long, c4 double) STORED AS PARQUET
+ |TBLPROPERTIES (
+ | 'kyuubi.zorder.enabled' = '$enabled',
+ | 'kyuubi.zorder.cols' = '$cols')
+ |""".stripMargin)
+ val df1 = sql(s"""
+ |INSERT INTO TABLE zorder_t1
+ |SELECT $repartition * FROM VALUES(1,'a',2,4D),(2,'b',3,6D)
+ |""".stripMargin)
+ assert(df1.queryExecution.analyzed.isInstanceOf[InsertIntoHiveTable])
+ checkSort(df1.queryExecution.analyzed.children.head)
+
+ Seq("true", "false").foreach { optimized =>
+ withSQLConf(
+ "spark.sql.hive.convertMetastoreCtas" -> optimized,
+ "spark.sql.hive.convertMetastoreParquet" -> optimized) {
+
+ withListener(
+ s"""
+ |CREATE TABLE zorder_t2_$optimized STORED AS PARQUET
+ |TBLPROPERTIES (
+ | 'kyuubi.zorder.enabled' = '$enabled',
+ | 'kyuubi.zorder.cols' = '$cols')
+ |
+ |SELECT $repartition * FROM
+ |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4)
+ |""".stripMargin) { write =>
+ if (optimized.toBoolean) {
+ assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand])
+ } else {
+ assert(write.isInstanceOf[InsertIntoHiveTable])
+ }
+ checkSort(write.query)
+ }
+ }
+ }
+ }
+ }
+
+ // datasource
+ withTable("zorder_t3", "zorder_t4") {
+ sql(
+ s"""
+ |CREATE TABLE zorder_t3 (c1 int, c2 string, c3 long, c4 double) USING PARQUET
+ |TBLPROPERTIES (
+ | 'kyuubi.zorder.enabled' = '$enabled',
+ | 'kyuubi.zorder.cols' = '$cols')
+ |""".stripMargin)
+ val df1 = sql(s"""
+ |INSERT INTO TABLE zorder_t3
+ |SELECT $repartition * FROM VALUES(1,'a',2,4D),(2,'b',3,6D)
+ |""".stripMargin)
+ assert(df1.queryExecution.analyzed.isInstanceOf[InsertIntoHadoopFsRelationCommand])
+ checkSort(df1.queryExecution.analyzed.children.head)
+
+ withListener(
+ s"""
+ |CREATE TABLE zorder_t4 USING PARQUET
+ |TBLPROPERTIES (
+ | 'kyuubi.zorder.enabled' = '$enabled',
+ | 'kyuubi.zorder.cols' = '$cols')
+ |
+ |SELECT $repartition * FROM
+ |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4)
+ |""".stripMargin) { write =>
+ assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand])
+ checkSort(write.query)
+ }
+ }
+ }
+ }
+
+ test("Support insert zorder by table properties") {
+ withSQLConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING.key -> "false") {
+ checkZorderTable(true, "c1", false, false)
+ checkZorderTable(false, "c1", false, false)
+ }
+ withSQLConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING.key -> "true") {
+ checkZorderTable(true, "", false, false)
+ checkZorderTable(true, "c5", false, false)
+ checkZorderTable(true, "c1,c5", false, false)
+ checkZorderTable(false, "c3", false, false)
+ checkZorderTable(true, "c3", true, false)
+ checkZorderTable(true, "c3", false, true)
+ checkZorderTable(true, "c2,c4", false, true)
+ checkZorderTable(true, "c4, c2, c1, c3", false, true)
+ }
+ }
+
+ test("zorder: check unsupported data type") {
+ def checkZorderPlan(zorder: Expression): Unit = {
+ val msg = intercept[AnalysisException] {
+ val plan = Project(Seq(Alias(zorder, "c")()), OneRowRelation())
+ spark.sessionState.analyzer.checkAnalysis(plan)
+ }.getMessage
+ // Before Spark 3.2.0 the catalog string for NullType was "null"; since 3.2.0 it is "void",
+ // see https://github.com/apache/spark/pull/33437
+ assert(msg.contains("Unsupported z-order type:") &&
+ (msg.contains("null") || msg.contains("void")))
+ }
+
+ checkZorderPlan(Zorder(Seq(Literal(null, NullType))))
+ checkZorderPlan(Zorder(Seq(Literal(1, IntegerType), Literal(null, NullType))))
+ }
+
+ test("zorder: check supported data type") {
+ val children = Seq(
+ Literal.create(false, BooleanType),
+ Literal.create(null, BooleanType),
+ Literal.create(1.toByte, ByteType),
+ Literal.create(null, ByteType),
+ Literal.create(1.toShort, ShortType),
+ Literal.create(null, ShortType),
+ Literal.create(1, IntegerType),
+ Literal.create(null, IntegerType),
+ Literal.create(1L, LongType),
+ Literal.create(null, LongType),
+ Literal.create(1f, FloatType),
+ Literal.create(null, FloatType),
+ Literal.create(1d, DoubleType),
+ Literal.create(null, DoubleType),
+ Literal.create("1", StringType),
+ Literal.create(null, StringType),
+ Literal.create(1L, TimestampType),
+ Literal.create(null, TimestampType),
+ Literal.create(1, DateType),
+ Literal.create(null, DateType),
+ Literal.create(BigDecimal(1, 1), DecimalType(1, 1)),
+ Literal.create(null, DecimalType(1, 1)))
+ val zorder = Zorder(children)
+ val plan = Project(Seq(Alias(zorder, "c")()), OneRowRelation())
+ spark.sessionState.analyzer.checkAnalysis(plan)
+ assert(zorder.foldable)
+
+// // scalastyle:off
+// val resultGen = org.apache.commons.codec.binary.Hex.encodeHex(
+// zorder.eval(InternalRow.fromSeq(children)).asInstanceOf[Array[Byte]], false)
+// resultGen.grouped(2).zipWithIndex.foreach { case (char, i) =>
+// print("0x" + char(0) + char(1) + ", ")
+// if ((i + 1) % 10 == 0) {
+// println()
+// }
+// }
+// // scalastyle:on
+
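+ // Expected interleaved bytes; regenerate with the commented-out snippet above if ZorderBytesUtils changes.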
+ val expected = Array(
+ 0xFB, 0xEA, 0xAA, 0xBA, 0xAE, 0xAB, 0xAA, 0xEA, 0xBA, 0xAE, 0xAB, 0xAA, 0xEA, 0xBA, 0xA6,
+ 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
+ 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBA, 0xBB, 0xAA, 0xAA, 0xAA,
+ 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0x9A, 0xAA, 0xAA,
+ 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xEA,
+ 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
+ 0xAA, 0xAA, 0xBE, 0xAA, 0xAA, 0x8A, 0xBA, 0xAA, 0x2A, 0xEA, 0xA8, 0xAA, 0xAA, 0xA2, 0xAA,
+ 0xAA, 0x8A, 0xAA, 0xAA, 0x2F, 0xEB, 0xFE)
+ .map(_.toByte)
+ checkEvaluation(zorder, expected, InternalRow.fromSeq(children))
+ }
+
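+ // Round-trips the input through parquet, then sorts by Zorder(c1, c2) and checks the resulting row order.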
+ private def checkSort(input: DataFrame, expected: Seq[Row], dataType: Array[DataType]): Unit = {
+ withTempDir { dir =>
+ input.repartition(3).write.mode("overwrite").format("parquet").save(dir.getCanonicalPath)
+ val df = spark.read.format("parquet")
+ .load(dir.getCanonicalPath)
+ .repartition(1)
+ assert(df.schema.fields.map(_.dataType).sameElements(dataType))
+ val exprs = Seq("c1", "c2").map(col).map(_.expr)
+ val sortOrder = SortOrder(Zorder(exprs), Ascending, NullsLast, Seq.empty)
+ val zorderSort = Sort(Seq(sortOrder), true, df.logicalPlan)
+ val result = Dataset.ofRows(spark, zorderSort)
+ checkAnswer(result, expected)
+ }
+ }
+
+ test("sort with zorder -- boolean column") {
+ val schema = StructType(StructField("c1", BooleanType) :: StructField("c2", BooleanType) :: Nil)
+ val nonNullDF = spark.createDataFrame(
+ spark.sparkContext.parallelize(
+ Seq(Row(false, false), Row(false, true), Row(true, false), Row(true, true))),
+ schema)
+ val expected =
+ Row(false, false) :: Row(true, false) :: Row(false, true) :: Row(true, true) :: Nil
+ checkSort(nonNullDF, expected, Array(BooleanType, BooleanType))
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(
+ Seq(Row(false, false), Row(false, null), Row(null, false), Row(null, null))),
+ schema)
+ val expected2 =
+ Row(false, false) :: Row(null, false) :: Row(false, null) :: Row(null, null) :: Nil
+ checkSort(df, expected2, Array(BooleanType, BooleanType))
+ }
+
+ test("sort with zorder -- int column") {
+ // TODO: add unit tests for more data types
+ val session = spark
+ import session.implicits._
+ // generate 4 * 4 matrix
+ val len = 3
+ val input = spark.range(len + 1).selectExpr("cast(id as int) as c1")
+ .select($"c1", explode(sequence(lit(0), lit(len))) as "c2")
+ val expected =
+ Row(0, 0) :: Row(1, 0) :: Row(0, 1) :: Row(1, 1) ::
+ Row(2, 0) :: Row(3, 0) :: Row(2, 1) :: Row(3, 1) ::
+ Row(0, 2) :: Row(1, 2) :: Row(0, 3) :: Row(1, 3) ::
+ Row(2, 2) :: Row(3, 2) :: Row(2, 3) :: Row(3, 3) :: Nil
+ checkSort(input, expected, Array(IntegerType, IntegerType))
+
+ // case containing null values
+ val nullDF = spark.range(1).selectExpr("cast(null as int) as c1")
+ val input2 = spark.range(len).selectExpr("cast(id as int) as c1")
+ .union(nullDF)
+ .select(
+ $"c1",
+ explode(concat(sequence(lit(0), lit(len - 1)), array(lit(null)))) as "c2")
+ val expected2 = Row(0, 0) :: Row(1, 0) :: Row(0, 1) :: Row(1, 1) ::
+ Row(2, 0) :: Row(2, 1) :: Row(0, 2) :: Row(1, 2) ::
+ Row(2, 2) :: Row(null, 0) :: Row(null, 1) :: Row(null, 2) ::
+ Row(0, null) :: Row(1, null) :: Row(2, null) :: Row(null, null) :: Nil
+ checkSort(input2, expected2, Array(IntegerType, IntegerType))
+ }
+
+ test("sort with zorder -- string column") {
+ val schema = StructType(StructField("c1", StringType) :: StructField("c2", StringType) :: Nil)
+ val rdd = spark.sparkContext.parallelize(Seq(
+ Row("a", "a"),
+ Row("a", "b"),
+ Row("a", "c"),
+ Row("a", "d"),
+ Row("b", "a"),
+ Row("b", "b"),
+ Row("b", "c"),
+ Row("b", "d"),
+ Row("c", "a"),
+ Row("c", "b"),
+ Row("c", "c"),
+ Row("c", "d"),
+ Row("d", "a"),
+ Row("d", "b"),
+ Row("d", "c"),
+ Row("d", "d")))
+ val input = spark.createDataFrame(rdd, schema)
+ val expected = Row("a", "a") :: Row("b", "a") :: Row("c", "a") :: Row("a", "b") ::
+ Row("a", "c") :: Row("b", "b") :: Row("c", "b") :: Row("b", "c") ::
+ Row("c", "c") :: Row("d", "a") :: Row("d", "b") :: Row("d", "c") ::
+ Row("a", "d") :: Row("b", "d") :: Row("c", "d") :: Row("d", "d") :: Nil
+ checkSort(input, expected, Array(StringType, StringType))
+
+ val rdd2 = spark.sparkContext.parallelize(Seq(
+ Row(null, "a"),
+ Row("a", "b"),
+ Row("a", "c"),
+ Row("a", null),
+ Row("b", "a"),
+ Row(null, "b"),
+ Row("b", null),
+ Row("b", "d"),
+ Row("c", "a"),
+ Row("c", null),
+ Row(null, "c"),
+ Row("c", "d"),
+ Row("d", null),
+ Row("d", "b"),
+ Row("d", "c"),
+ Row(null, "d"),
+ Row(null, null)))
+ val input2 = spark.createDataFrame(rdd2, schema)
+ val expected2 = Row("b", "a") :: Row("c", "a") :: Row("a", "b") :: Row("a", "c") ::
+ Row("d", "b") :: Row("d", "c") :: Row("b", "d") :: Row("c", "d") ::
+ Row(null, "a") :: Row(null, "b") :: Row(null, "c") :: Row(null, "d") ::
+ Row("a", null) :: Row("b", null) :: Row("c", null) :: Row("d", null) ::
+ Row(null, null) :: Nil
+ checkSort(input2, expected2, Array(StringType, StringType))
+ }
+
+ test("test special value of short int long type") {
+ val df1 = spark.createDataFrame(Seq(
+ (-1, -1L),
+ (Int.MinValue, Int.MinValue.toLong),
+ (1, 1L),
+ (Int.MaxValue - 1, Int.MaxValue.toLong),
+ (Int.MaxValue - 1, Int.MaxValue.toLong - 1),
+ (Int.MaxValue, Int.MaxValue.toLong + 1),
+ (Int.MaxValue, Int.MaxValue.toLong))).toDF("c1", "c2")
+ val expected1 =
+ Row(Int.MinValue, Int.MinValue.toLong) ::
+ Row(-1, -1L) ::
+ Row(1, 1L) ::
+ Row(Int.MaxValue - 1, Int.MaxValue.toLong - 1) ::
+ Row(Int.MaxValue - 1, Int.MaxValue.toLong) ::
+ Row(Int.MaxValue, Int.MaxValue.toLong) ::
+ Row(Int.MaxValue, Int.MaxValue.toLong + 1) :: Nil
+ checkSort(df1, expected1, Array(IntegerType, LongType))
+
+ val df2 = spark.createDataFrame(Seq(
+ (-1, -1.toShort),
+ (Short.MinValue.toInt, Short.MinValue),
+ (1, 1.toShort),
+ (Short.MaxValue.toInt, (Short.MaxValue - 1).toShort),
+ (Short.MaxValue.toInt + 1, (Short.MaxValue - 1).toShort),
+ (Short.MaxValue.toInt, Short.MaxValue),
+ (Short.MaxValue.toInt + 1, Short.MaxValue))).toDF("c1", "c2")
+ val expected2 =
+ Row(Short.MinValue.toInt, Short.MinValue) ::
+ Row(-1, -1.toShort) ::
+ Row(1, 1.toShort) ::
+ Row(Short.MaxValue.toInt, Short.MaxValue - 1) ::
+ Row(Short.MaxValue.toInt, Short.MaxValue) ::
+ Row(Short.MaxValue.toInt + 1, Short.MaxValue - 1) ::
+ Row(Short.MaxValue.toInt + 1, Short.MaxValue) :: Nil
+ checkSort(df2, expected2, Array(IntegerType, ShortType))
+
+ val df3 = spark.createDataFrame(Seq(
+ (-1L, -1.toShort),
+ (Short.MinValue.toLong, Short.MinValue),
+ (1L, 1.toShort),
+ (Short.MaxValue.toLong, (Short.MaxValue - 1).toShort),
+ (Short.MaxValue.toLong + 1, (Short.MaxValue - 1).toShort),
+ (Short.MaxValue.toLong, Short.MaxValue),
+ (Short.MaxValue.toLong + 1, Short.MaxValue))).toDF("c1", "c2")
+ val expected3 =
+ Row(Short.MinValue.toLong, Short.MinValue) ::
+ Row(-1L, -1.toShort) ::
+ Row(1L, 1.toShort) ::
+ Row(Short.MaxValue.toLong, Short.MaxValue - 1) ::
+ Row(Short.MaxValue.toLong, Short.MaxValue) ::
+ Row(Short.MaxValue.toLong + 1, Short.MaxValue - 1) ::
+ Row(Short.MaxValue.toLong + 1, Short.MaxValue) :: Nil
+ checkSort(df3, expected3, Array(LongType, ShortType))
+ }
+
+ test("skip zorder if only requires one column") {
+ withTable("t") {
+ withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "false") {
+ sql("CREATE TABLE t (c1 int, c2 string) stored as parquet")
+ val order1 = sql("OPTIMIZE t ZORDER BY c1").queryExecution.analyzed
+ .asInstanceOf[OptimizeZorderCommandBase].query.asInstanceOf[Sort].order.head.child
+ assert(!order1.isInstanceOf[Zorder])
+ assert(order1.isInstanceOf[AttributeReference])
+ }
+ }
+ }
+
+ test("Add config to control if zorder using global sort") {
+ withTable("t") {
+ withSQLConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key -> "false") {
+ sql(
+ """
+ |CREATE TABLE t (c1 int, c2 string) TBLPROPERTIES (
+ |'kyuubi.zorder.enabled'= 'true',
+ |'kyuubi.zorder.cols'= 'c1,c2')
+ |""".stripMargin)
+ val p1 = sql("OPTIMIZE t ZORDER BY c1, c2").queryExecution.analyzed
+ assert(p1.collect {
+ case shuffle: Sort if !shuffle.global => shuffle
+ }.size == 1)
+
+ val p2 = sql("INSERT INTO TABLE t SELECT * FROM VALUES(1,'a')").queryExecution.analyzed
+ assert(p2.collect {
+ case shuffle: Sort if !shuffle.global => shuffle
+ }.size == 1)
+ }
+ }
+ }
+
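+ // Cross-checks the optimized interleaveBits fast path against the generic interleaveBitsDefault implementation.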
+ test("fast approach test") {
+ Seq[Seq[Any]](
+ Seq(1L, 2L),
+ Seq(1L, 2L, 3L),
+ Seq(1L, 2L, 3L, 4L),
+ Seq(1L, 2L, 3L, 4L, 5L),
+ Seq(1L, 2L, 3L, 4L, 5L, 6L),
+ Seq(1L, 2L, 3L, 4L, 5L, 6L, 7L),
+ Seq(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L))
+ .foreach { inputs =>
+ assert(java.util.Arrays.equals(
+ ZorderBytesUtils.interleaveBits(inputs.toArray),
+ ZorderBytesUtils.interleaveBitsDefault(inputs.map(ZorderBytesUtils.toByteArray).toArray)))
+ }
+ }
+
+ test("OPTIMIZE command is parsed as expected") {
+ val parser = createParser
+ val globalSort = spark.conf.get(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED)
+
+ assert(parser.parsePlan("OPTIMIZE p zorder by c1") ===
+ OptimizeZorderStatement(
+ Seq("p"),
+ Sort(
+ SortOrder(UnresolvedAttribute("c1"), Ascending, NullsLast, Seq.empty) :: Nil,
+ globalSort,
+ Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("p"))))))
+
+ assert(parser.parsePlan("OPTIMIZE p zorder by c1, c2") ===
+ OptimizeZorderStatement(
+ Seq("p"),
+ Sort(
+ SortOrder(
+ Zorder(Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2"))),
+ Ascending,
+ NullsLast,
+ Seq.empty) :: Nil,
+ globalSort,
+ Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("p"))))))
+
+ assert(parser.parsePlan("OPTIMIZE p where id = 1 zorder by c1") ===
+ OptimizeZorderStatement(
+ Seq("p"),
+ Sort(
+ SortOrder(UnresolvedAttribute("c1"), Ascending, NullsLast, Seq.empty) :: Nil,
+ globalSort,
+ Project(
+ Seq(UnresolvedStar(None)),
+ Filter(
+ EqualTo(UnresolvedAttribute("id"), Literal(1)),
+ UnresolvedRelation(TableIdentifier("p")))))))
+
+ assert(parser.parsePlan("OPTIMIZE p where id = 1 zorder by c1, c2") ===
+ OptimizeZorderStatement(
+ Seq("p"),
+ Sort(
+ SortOrder(
+ Zorder(Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2"))),
+ Ascending,
+ NullsLast,
+ Seq.empty) :: Nil,
+ globalSort,
+ Project(
+ Seq(UnresolvedStar(None)),
+ Filter(
+ EqualTo(UnresolvedAttribute("id"), Literal(1)),
+ UnresolvedRelation(TableIdentifier("p")))))))
+
+ assert(parser.parsePlan("OPTIMIZE p where id = current_date() zorder by c1") ===
+ OptimizeZorderStatement(
+ Seq("p"),
+ Sort(
+ SortOrder(UnresolvedAttribute("c1"), Ascending, NullsLast, Seq.empty) :: Nil,
+ globalSort,
+ Project(
+ Seq(UnresolvedStar(None)),
+ Filter(
+ EqualTo(
+ UnresolvedAttribute("id"),
+ UnresolvedFunction("current_date", Seq.empty, false)),
+ UnresolvedRelation(TableIdentifier("p")))))))
+
+ // TODO: support the following cases
+ intercept[ParseException] {
+ parser.parsePlan("OPTIMIZE p zorder by (c1)")
+ }
+
+ intercept[ParseException] {
+ parser.parsePlan("OPTIMIZE p zorder by (c1, c2)")
+ }
+ }
+
+ test("OPTIMIZE partition predicates constraint") {
+ withTable("p") {
+ sql("CREATE TABLE p (c1 INT, c2 INT) PARTITIONED BY (event_date DATE)")
+ val e1 = intercept[KyuubiSQLExtensionException] {
+ sql("OPTIMIZE p WHERE event_date = current_date as c ZORDER BY c1, c2")
+ }
+ assert(e1.getMessage.contains("unsupported partition predicates"))
+
+ val e2 = intercept[KyuubiSQLExtensionException] {
+ sql("OPTIMIZE p WHERE c1 = 1 ZORDER BY c1, c2")
+ }
+ assert(e2.getMessage == "Only partition column filters are allowed")
+ }
+ }
+
+ def createParser: ParserInterface
+}
+
+trait ZorderWithCodegenEnabledSuiteBase extends ZorderSuiteBase {
+ override def sparkConf(): SparkConf = {
+ val conf = super.sparkConf
+ conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
+ conf
+ }
+}
+
+trait ZorderWithCodegenDisabledSuiteBase extends ZorderSuiteBase {
+ override def sparkConf(): SparkConf = {
+ val conf = super.sparkConf
+ conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false")
+ conf.set(SQLConf.CODEGEN_FACTORY_MODE.key, "NO_CODEGEN")
+ conf
+ }
+}
diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/benchmark/KyuubiBenchmarkBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/benchmark/KyuubiBenchmarkBase.scala
new file mode 100644
index 000000000..b891a7224
--- /dev/null
+++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/benchmark/KyuubiBenchmarkBase.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.benchmark
+
+import java.io.{File, FileOutputStream, OutputStream}
+
+import scala.collection.JavaConverters._
+
+import com.google.common.reflect.ClassPath
+import org.scalatest.Assertions._
+
+trait KyuubiBenchmarkBase {
+ var output: Option[OutputStream] = None
+
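+ // Locates the module directory (walking up from a compiled *Benchmark class) so results land under <module>/benchmarks/.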
+ private val prefix = {
+ val benchmarkClasses = ClassPath.from(Thread.currentThread.getContextClassLoader)
+ .getTopLevelClassesRecursive("org.apache.spark.sql").asScala.toArray
+ assert(benchmarkClasses.nonEmpty)
+ val benchmark = benchmarkClasses.find(_.load().getName.endsWith("Benchmark"))
+ val targetDirOrProjDir =
+ new File(benchmark.get.load().getProtectionDomain.getCodeSource.getLocation.toURI)
+ .getParentFile.getParentFile
+ if (targetDirOrProjDir.getName == "target") {
+ targetDirOrProjDir.getParentFile.getCanonicalPath + "/"
+ } else {
+ targetDirOrProjDir.getCanonicalPath + "/"
+ }
+ }
+
+ def withHeader(func: => Unit): Unit = {
+ val version = System.getProperty("java.version").split("\\D+")(0).toInt
+ val jdkString = if (version > 8) s"-jdk$version" else ""
+ val resultFileName =
+ s"${this.getClass.getSimpleName.replace("$", "")}$jdkString-results.txt"
+ val dir = new File(s"${prefix}benchmarks/")
+ if (!dir.exists()) {
+ // scalastyle:off println
+ println(s"Creating ${dir.getAbsolutePath} for benchmark results.")
+ // scalastyle:on println
+ dir.mkdirs()
+ }
+ val file = new File(dir, resultFileName)
+ if (!file.exists()) {
+ file.createNewFile()
+ }
+ output = Some(new FileOutputStream(file))
+
+ func
+
+ output.foreach { o =>
+ if (o != null) {
+ o.close()
+ }
+ }
+ }
+}
diff --git a/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json b/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json
index de0e03cac..dad13baa1 100644
--- a/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json
+++ b/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json
@@ -143,7 +143,7 @@
}, {
"classname" : "org.apache.spark.sql.catalyst.plans.logical.CreateTableAsSelect",
"tableDescs" : [ {
- "fieldName" : "left",
+ "fieldName" : "name",
"fieldExtractor" : "ResolvedIdentifierTableExtractor",
"columnDesc" : null,
"actionTypeDesc" : null,
@@ -164,7 +164,7 @@
"isInput" : false,
"setCurrentDatabaseIfMissing" : false
}, {
- "fieldName" : "left",
+ "fieldName" : "name",
"fieldExtractor" : "ResolvedDbObjectNameTableExtractor",
"columnDesc" : null,
"actionTypeDesc" : null,
@@ -494,7 +494,7 @@
}, {
"classname" : "org.apache.spark.sql.catalyst.plans.logical.ReplaceTableAsSelect",
"tableDescs" : [ {
- "fieldName" : "left",
+ "fieldName" : "name",
"fieldExtractor" : "ResolvedIdentifierTableExtractor",
"columnDesc" : null,
"actionTypeDesc" : null,
@@ -515,7 +515,7 @@
"isInput" : false,
"setCurrentDatabaseIfMissing" : false
}, {
- "fieldName" : "left",
+ "fieldName" : "name",
"fieldExtractor" : "ResolvedDbObjectNameTableExtractor",
"columnDesc" : null,
"actionTypeDesc" : null,
diff --git a/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala b/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala
index ca2ee9294..6a6800210 100644
--- a/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala
+++ b/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala
@@ -234,9 +234,9 @@ object TableCommands {
TableCommandSpec(
cmd,
Seq(
- resolvedIdentifierTableDesc.copy(fieldName = "left"),
+ resolvedIdentifierTableDesc.copy(fieldName = "name"),
tableDesc,
- resolvedDbObjectNameDesc.copy(fieldName = "left")),
+ resolvedDbObjectNameDesc.copy(fieldName = "name")),
CREATETABLE_AS_SELECT,
Seq(queryQueryDesc))
}
diff --git a/pom.xml b/pom.xml
index 94349d21b..12b41bbad 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2241,6 +2241,19 @@
</properties>
</profile>
+ <profile>
+ <id>spark-3.5</id>
+ <modules>
+ <module>extensions/spark/kyuubi-extension-spark-3-5</module>
+ </modules>
+ <properties>
+ <delta.version>2.4.0</delta.version>
+ <spark.version>3.5.0</spark.version>
+ <spark.binary.version>3.5</spark.binary.version>
+ <maven.plugin.scalatest.exclude.tags>org.scalatest.tags.Slow,org.apache.kyuubi.tags.DeltaTest,org.apache.kyuubi.tags.IcebergTest,org.apache.kyuubi.tags.PySparkTest</maven.plugin.scalatest.exclude.tags>
+ </properties>
+ </profile>
+
<profile>
<id>spark-master</id>
<properties>