You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2020/11/06 22:43:26 UTC
[iceberg] branch master updated: Spark: Add Spark3 extensions
module (#1728)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new da8b85c Spark: Add Spark3 extensions module (#1728)
da8b85c is described below
commit da8b85c5d3f00815ebc9c570c04c2581a8f51eb7
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Fri Nov 6 14:43:16 2020 -0800
Spark: Add Spark3 extensions module (#1728)
---
.github/labeler.yml | 1 +
LICENSE | 2 +
baseline.gradle | 1 +
build.gradle | 43 ++++
project/scalastyle_config.xml | 147 ++++++++++++++
settings.gradle | 2 +
.../IcebergSqlExtensions.g4 | 219 +++++++++++++++++++++
.../extensions/IcebergSparkSessionExtensions.scala | 32 +++
.../IcebergSparkSqlExtensionsParser.scala | 180 +++++++++++++++++
.../IcebergSqlExtensionsAstBuilder.scala | 76 +++++++
.../sql/catalyst/plans/logical/statements.scala | 32 +++
.../v2/ExtendedDataSourceV2Strategy.scala | 33 ++++
.../spark/extensions/TestCallStatementParser.java | 160 +++++++++++++++
versions.props | 1 +
14 files changed, 929 insertions(+)
diff --git a/.github/labeler.yml b/.github/labeler.yml
index 01ecf31..8b5778a 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -63,6 +63,7 @@ SPARK:
- spark/**/*
- spark2/**/*
- spark3/**/*
+ - spark3-extensions/**/*
FLINK:
- flink-runtime/**/*
- flink/**/*
diff --git a/LICENSE b/LICENSE
index a21e6b9..997618c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -248,6 +248,7 @@ License: http://www.apache.org/licenses/LICENSE-2.0
This product includes code from Presto.
* Retry wait and jitter logic in Tasks.java
+* SQL grammar rules for parsing CALL statements in IcebergSqlExtensions.g4
Copyright: 2016 Facebook and contributors
Home page: https://prestodb.io/
@@ -279,6 +280,7 @@ This product includes code from Apache Spark.
* dev/check-license script
* vectorized reading of definition levels in BaseVectorizedParquetValuesReader.java
+* portions of the extensions parser
Copyright: 2011-2018 The Apache Software Foundation
Home page: https://spark.apache.org/
diff --git a/baseline.gradle b/baseline.gradle
index a47e5e9..84cc627 100644
--- a/baseline.gradle
+++ b/baseline.gradle
@@ -36,6 +36,7 @@ subprojects {
apply plugin: 'com.palantir.baseline-checkstyle'
apply plugin: 'com.palantir.baseline-error-prone'
}
+ apply plugin: 'com.palantir.baseline-scalastyle'
apply plugin: 'com.palantir.baseline-class-uniqueness'
apply plugin: 'com.palantir.baseline-reproducibility'
apply plugin: 'com.palantir.baseline-exact-dependencies'
diff --git a/build.gradle b/build.gradle
index eb4a7cb..984523d 100644
--- a/build.gradle
+++ b/build.gradle
@@ -884,6 +884,49 @@ project(':iceberg-spark3') {
}
}
+// Spark 3 SQL extensions module: Scala sources plus an ANTLR grammar (IcebergSqlExtensions.g4)
+// that is turned into lexer/parser sources by the antlr plugin.
+project(":iceberg-spark3-extensions") {
+ apply plugin: 'java'
+ apply plugin: 'scala'
+ apply plugin: 'antlr'
+
+ configurations {
+ /*
+ The Gradle Antlr plugin erroneously adds both antlr-build and runtime dependencies to the runtime path. This
+ bug https://github.com/gradle/gradle/issues/820 exists because older versions of Antlr do not have separate
+ runtime and compile dependencies and they do not want to break backwards compatibility. So to only end up with
+ the runtime dependency on the runtime classpath we remove the dependencies added by the plugin here. Then add
+ the runtime dependency back to only the runtime configuration manually.
+ */
+ compile {
+ extendsFrom = extendsFrom.findAll { it != configurations.antlr }
+ }
+ }
+
+ dependencies {
+ compileOnly project(':iceberg-spark3')
+
+ compileOnly "org.scala-lang:scala-library"
+ compileOnly("org.apache.spark:spark-hive_2.12") {
+ exclude group: 'org.apache.avro', module: 'avro'
+ exclude group: 'org.apache.arrow'
+ }
+
+ testCompile project(path: ':iceberg-api', configuration: 'testArtifacts')
+ testCompile project(path: ':iceberg-hive-metastore', configuration: 'testArtifacts')
+ testCompile project(path: ':iceberg-spark', configuration: 'testArtifacts')
+ testCompile project(path: ':iceberg-spark3', configuration: 'testArtifacts')
+
+ // Required because we remove antlr plugin dependencies from the compile configuration, see note above
+ runtime "org.antlr:antlr4-runtime:4.7.1"
+ antlr "org.antlr:antlr4:4.7.1"
+ }
+
+ generateGrammarSource {
+ maxHeapSize = "64m"
+ // -visitor: also generate the visitor classes (IcebergSqlExtensionsAstBuilder extends the generated base visitor)
+ // -package: put the generated sources into the package the Scala parser code lives in
+ arguments += ['-visitor', '-package', 'org.apache.spark.sql.catalyst.parser.extensions']
+ }
+}
+
project(':iceberg-spark3-runtime') {
apply plugin: 'com.github.johnrengelman.shadow'
diff --git a/project/scalastyle_config.xml b/project/scalastyle_config.xml
new file mode 100644
index 0000000..abb919f
--- /dev/null
+++ b/project/scalastyle_config.xml
@@ -0,0 +1,147 @@
+<!--
+ ~ Licensed under the Apache License, Version 2.0 (the "License");
+ ~ you may not use this file except in compliance with the License.
+ ~ You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<scalastyle commentFilter="enabled">
+ <name>Iceberg Scalastyle configuration</name>
+ <check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="true">
+ <parameters>
+ <parameter name="maxFileLength"><![CDATA[800]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
+ <parameters>
+ <parameter name="regex">true</parameter>
+ <parameter name="header">(?m)^/\*$\n^ \* Licensed to the Apache Software Foundation \(ASF\) under one$</parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
+ <parameters>
+ <parameter name="maxLineLength"><![CDATA[120]]></parameter>
+ <parameter name="tabSize"><![CDATA[4]]></parameter>
+ <parameter name="ignoreImports">true</parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
+ <parameters>
+ <parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
+ <parameters>
+ <parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
+ <parameters>
+ <parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="true">
+ <parameters>
+ <parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
+ <parameters>
+ <parameter name="maxParameters"><![CDATA[8]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="true">
+ <parameters>
+ <parameter name="ignore"><![CDATA[-1,0,1,2,3]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
+ <parameters>
+ <parameter name="regex"><![CDATA[println]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="true">
+ <parameters>
+ <parameter name="maxTypes"><![CDATA[30]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="true">
+ <parameters>
+ <parameter name="maximum"><![CDATA[10]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="true">
+ <parameters>
+ <parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
+ <parameter name="doubleLineAllowed"><![CDATA[false]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="true">
+ <parameters>
+ <parameter name="maxLength"><![CDATA[50]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="true">
+ <parameters>
+ <parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="false">
+ <parameters>
+ <parameter name="maxMethods"><![CDATA[30]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.WhileChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.VarFieldChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.VarLocalChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.RedundantIfChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="false">
+ <parameters>
+ <parameter name="regex"><![CDATA[println]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.DeprecatedJavaChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.EmptyClassChecker" enabled="true"/>
+ <check level="error" class="org.scalastyle.scalariform.ClassTypeParameterChecker" enabled="true">
+ <parameters>
+ <parameter name="regex"><![CDATA[^[A-Z_]$]]></parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.UnderscoreImportChecker" enabled="false"/>
+ <check level="error" class="org.scalastyle.scalariform.ImportOrderChecker" enabled="true">
+ <parameters>
+ <parameter name="groups">all</parameter>
+ <parameter name="group.all">.+</parameter>
+ </parameters>
+ </check>
+ <check level="error" class="org.scalastyle.scalariform.DisallowSpaceBeforeTokenChecker" enabled="true">
+ <parameters>
+ <parameter name="tokens">COMMA</parameter>
+ </parameters>
+ </check>
+</scalastyle>
diff --git a/settings.gradle b/settings.gradle
index 0377d31..9e4b58b 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -32,6 +32,7 @@ include 'parquet'
include 'bundled-guava'
include 'spark'
include 'spark3'
+include 'spark3-extensions'
include 'spark3-runtime'
include 'pig'
include 'hive-metastore'
@@ -50,6 +51,7 @@ project(':parquet').name = 'iceberg-parquet'
project(':bundled-guava').name = 'iceberg-bundled-guava'
project(':spark').name = 'iceberg-spark'
project(':spark3').name = 'iceberg-spark3'
+project(':spark3-extensions').name = 'iceberg-spark3-extensions'
project(':spark3-runtime').name = 'iceberg-spark3-runtime'
project(':pig').name = 'iceberg-pig'
project(':hive-metastore').name = 'iceberg-hive-metastore'
diff --git a/spark3-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark3-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
new file mode 100644
index 0000000..1b11000
--- /dev/null
+++ b/spark3-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * This file is an adaptation of Presto's and Spark's grammar files.
+ */
+
+grammar IcebergSqlExtensions;
+
+@lexer::members {
+ /**
+ * Verify whether current token is a valid decimal token (which contains dot).
+ * Returns true if the character that follows the token is not a digit or letter or underscore.
+ *
+ * For example:
+ * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'.
+ * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'.
+ * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'.
+ * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed
+ * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+'
+ * which is not a digit or letter or underscore.
+ */
+  public boolean isValidDecimal() {
+    // look at the character right after the candidate token
+    int following = _input.LA(1);
+    boolean upperCaseLetter = following >= 'A' && following <= 'Z';
+    boolean digit = following >= '0' && following <= '9';
+    // the token is only valid when it is not glued to a letter, a digit or an underscore
+    return !(upperCaseLetter || digit || following == '_');
+  }
+
+ /**
+ * This method will be called when we see '/*' and try to match it as a bracketed comment.
+ * If the next character is '+', it should be parsed as hint later, and we cannot match
+ * it as a bracketed comment.
+ *
+ * Returns true if the next character is '+'.
+ */
+  public boolean isHint() {
+    // a plus sign right after the comment opener marks a hint rather than a bracketed comment
+    return _input.LA(1) == '+';
+  }
+}
+
+// Entry rule invoked by IcebergSparkSqlExtensionsParser.parsePlan: one statement followed by end of input.
+singleStatement
+ : statement EOF
+ ;
+
+statement
+ // CALL <multi-part name> ( [argument [, argument]*] )
+ : CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call
+ // any other token sequence; the AST builder returns null for this alternative so that
+ // parsePlan falls back to the wrapped Spark parser
+ | .*? #nonIcebergCommand
+ ;
+
+callArgument
+ : expression #positionalArgument
+ | identifier '=>' expression #namedArgument
+ ;
+
+expression
+ : constant
+ ;
+
+constant
+ : number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ | identifier STRING #typeConstructor
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+number
+ : MINUS? EXPONENT_VALUE #exponentLiteral
+ | MINUS? DECIMAL_VALUE #decimalLiteral
+ | MINUS? INTEGER_VALUE #integerLiteral
+ | MINUS? BIGINT_LITERAL #bigIntLiteral
+ | MINUS? SMALLINT_LITERAL #smallIntLiteral
+ | MINUS? TINYINT_LITERAL #tinyIntLiteral
+ | MINUS? DOUBLE_LITERAL #doubleLiteral
+ | MINUS? FLOAT_LITERAL #floatLiteral
+ | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
+ ;
+
+multipartIdentifier
+ : parts+=identifier ('.' parts+=identifier)*
+ ;
+
+// An identifier is a plain IDENTIFIER token, a back-quoted identifier, or a non-reserved keyword.
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+// IcebergSqlExtensionsPostProcessor.exitQuotedIdentifier replaces this token with an IDENTIFIER token
+// whose enclosing back ticks are stripped and whose doubled back ticks are unescaped.
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+// Keywords that may still be used as identifiers;
+// IcebergSqlExtensionsPostProcessor.exitNonReserved replaces them with IDENTIFIER tokens.
+nonReserved
+ : CALL
+ | TRUE | FALSE
+ ;
+
+CALL: 'CALL';
+
+TRUE: 'TRUE';
+FALSE: 'FALSE';
+
+PLUS: '+';
+MINUS: '-';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '"' ( ~('"'|'\\') | ('\\' .) )* '"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+EXPONENT_VALUE
+ : DIGIT+ EXPONENT
+ | DECIMAL_DIGITS EXPONENT {isValidDecimal()}?
+ ;
+
+DECIMAL_VALUE
+ : DECIMAL_DIGITS {isValidDecimal()}?
+ ;
+
+FLOAT_LITERAL
+ : DIGIT+ EXPONENT? 'F'
+ | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}?
+ ;
+
+DOUBLE_LITERAL
+ : DIGIT+ EXPONENT? 'D'
+ | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}?
+ ;
+
+BIGDECIMAL_LITERAL
+ : DIGIT+ EXPONENT? 'BD'
+ | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}?
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment DECIMAL_DIGITS
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
diff --git a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
new file mode 100644
index 0000000..add3899
--- /dev/null
+++ b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser
+import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
+
+/**
+ * Session extensions hook that installs the Iceberg SQL parser and planner strategy.
+ */
+class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+
+  override def apply(extensions: SparkSessionExtensions): Unit = {
+    // wrap the parser handed over by Spark; the session argument is not needed
+    extensions.injectParser((_, sparkParser) => new IcebergSparkSqlExtensionsParser(sparkParser))
+    // the strategy is a singleton object, so the session argument is ignored here as well
+    extensions.injectPlannerStrategy(_ => ExtendedDataSourceV2Strategy)
+  }
+}
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
new file mode 100644
index 0000000..a5b4c23
--- /dev/null
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
+import org.antlr.v4.runtime.tree.TerminalNodeImpl
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface, UpperCaseCharStream}
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * A parser that handles the statements of the Iceberg extensions grammar itself and forwards
+ * everything else to the wrapped parser.
+ */
+class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserInterface {
+
+  private lazy val substitutor = new VariableSubstitution(SQLConf.get)
+  private lazy val astBuilder = new IcebergSqlExtensionsAstBuilder(delegate)
+
+  /** Forwards DataType parsing to the wrapped parser. */
+  override def parseDataType(sqlText: String): DataType =
+    delegate.parseDataType(sqlText)
+
+  /** Forwards raw DataType parsing (no CHAR/VARCHAR replacement) to the wrapped parser. */
+  override def parseRawDataType(sqlText: String): DataType =
+    delegate.parseRawDataType(sqlText)
+
+  /** Forwards Expression parsing to the wrapped parser. */
+  override def parseExpression(sqlText: String): Expression =
+    delegate.parseExpression(sqlText)
+
+  /** Forwards TableIdentifier parsing to the wrapped parser. */
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    delegate.parseTableIdentifier(sqlText)
+
+  /** Forwards FunctionIdentifier parsing to the wrapped parser. */
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+    delegate.parseFunctionIdentifier(sqlText)
+
+  /** Forwards multi-part identifier parsing to the wrapped parser. */
+  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
+    delegate.parseMultipartIdentifier(sqlText)
+
+  /** Forwards table schema parsing (comma separated field definitions) to the wrapped parser. */
+  override def parseTableSchema(sqlText: String): StructType =
+    delegate.parseTableSchema(sqlText)
+
+  /**
+   * Parse a string to a LogicalPlan.
+   *
+   * The text is first run through the extensions grammar; when that does not yield a plan,
+   * the original text is passed to the wrapped parser.
+   */
+  override def parsePlan(sqlText: String): LogicalPlan = {
+    val substituted = substitutor.substitute(sqlText)
+    parse(substituted) { extensionsParser =>
+      val parsed = astBuilder.visit(extensionsParser.singleStatement())
+      parsed match {
+        case plan: LogicalPlan => plan
+        // not an extensions statement: the wrapped parser receives the text as it was passed in
+        case _ => delegate.parsePlan(sqlText)
+      }
+    }
+  }
+
+  /**
+   * Builds a lexer and parser for the command and applies `toResult` to the parser,
+   * first in SLL prediction mode and, if that is cancelled, again in LL mode.
+   */
+  protected def parse[T](command: String)(toResult: IcebergSqlExtensionsParser => T): T = {
+    val charStream = new UpperCaseCharStream(CharStreams.fromString(command))
+    val lexer = new IcebergSqlExtensionsLexer(charStream)
+    lexer.removeErrorListeners()
+    lexer.addErrorListener(ParseErrorListener)
+
+    val tokens = new CommonTokenStream(lexer)
+    val parser = new IcebergSqlExtensionsParser(tokens)
+    parser.addParseListener(IcebergSqlExtensionsPostProcessor)
+    parser.removeErrorListeners()
+    parser.addErrorListener(ParseErrorListener)
+
+    def parseWith(mode: PredictionMode): T = {
+      parser.getInterpreter.setPredictionMode(mode)
+      toResult(parser)
+    }
+
+    try {
+      try {
+        // the potentially faster mode goes first
+        parseWith(PredictionMode.SLL)
+      } catch {
+        case _: ParseCancellationException =>
+          // start over from the first token and retry in LL mode
+          tokens.seek(0)
+          parser.reset()
+          parseWith(PredictionMode.LL)
+      }
+    } catch {
+      case e: ParseException if e.command.isDefined => throw e
+      case e: ParseException => throw e.withCommand(command)
+      case e: AnalysisException =>
+        val position = Origin(e.line, e.startPosition)
+        throw new ParseException(Option(command), e.message, position, position)
+    }
+  }
+}
+
+/**
+ * Parse listener that validates and cleans up identifier tokens while the tree is being built.
+ *
+ * ParseErrorListener and ParseException are reused as they are, whereas the post-processor has to be
+ * a modified copy because it depends directly on classes generated from the extensions grammar file.
+ */
+case object IcebergSqlExtensionsPostProcessor extends IcebergSqlExtensionsBaseListener {
+
+  /** Drops the enclosing back ticks of a quoted identifier and unescapes doubled back ticks. */
+  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+    replaceTokenByIdentifier(ctx, 1)(unescapeBackTicks)
+  }
+
+  /** Replaces a non-reserved keyword with an identifier token covering the same text. */
+  override def exitNonReserved(ctx: NonReservedContext): Unit = {
+    replaceTokenByIdentifier(ctx, 0)(token => token)
+  }
+
+  // turns every doubled back tick in the token text into a single one
+  private def unescapeBackTicks(token: CommonToken): CommonToken = {
+    token.setText(token.getText.replace("``", "`"))
+    token
+  }
+
+  // swaps the last child of the parent of ctx for an IDENTIFIER terminal built from the first token of ctx,
+  // trimmed by stripMargins characters on both sides and passed through f
+  private def replaceTokenByIdentifier(
+      ctx: ParserRuleContext,
+      stripMargins: Int)(
+      f: CommonToken => CommonToken = identity): Unit = {
+    val original = ctx.getChild(0).getPayload.asInstanceOf[Token]
+    val identifierToken = new CommonToken(
+      new org.antlr.v4.runtime.misc.Pair(original.getTokenSource, original.getInputStream),
+      IcebergSqlExtensionsParser.IDENTIFIER,
+      original.getChannel,
+      original.getStartIndex + stripMargins,
+      original.getStopIndex - stripMargins)
+    val enclosing = ctx.getParent
+    enclosing.removeLastChild()
+    enclosing.addChild(new TerminalNodeImpl(f(identifierToken)))
+  }
+}
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
new file mode 100644
index 0000000..0094bc6
--- /dev/null
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.parser.ParserUtils._
+import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser._
+import org.apache.spark.sql.catalyst.plans.logical.{CallArgument, CallStatement, LogicalPlan, NamedArgument, PositionalArgument}
+import scala.collection.JavaConverters._
+
+/**
+ * Visitor that converts a parse tree of the extensions grammar into Spark plans and expressions.
+ * Expressions are not built here; their text is handed to the given delegate parser.
+ */
+class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) extends IcebergSqlExtensionsBaseVisitor[AnyRef] {
+
+  /**
+   * Create a CallStatement from the multi-part identifier and the arguments of a CALL command.
+   *
+   * Wrapped in withOrigin like the other visit methods so the statement is created
+   * with the origin of the context it was parsed from.
+   */
+  override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
+    val name = ctx.multipartIdentifier.parts.asScala.map(_.getText)
+    val args = ctx.callArgument.asScala.map(typedVisit[CallArgument])
+    CallStatement(name, args)
+  }
+
+  /** Create a PositionalArgument from an argument given as a bare expression. */
+  override def visitPositionalArgument(ctx: PositionalArgumentContext): CallArgument = withOrigin(ctx) {
+    val expr = typedVisit[Expression](ctx.expression)
+    PositionalArgument(expr)
+  }
+
+  /** Create a NamedArgument from an argument given as `identifier => expression`. */
+  override def visitNamedArgument(ctx: NamedArgumentContext): CallArgument = withOrigin(ctx) {
+    val name = ctx.identifier.getText
+    val expr = typedVisit[Expression](ctx.expression)
+    NamedArgument(name, expr)
+  }
+
+  // return null for any statement we cannot handle so it can be parsed with the main Spark parser
+  override def visitNonIcebergCommand(ctx: NonIcebergCommandContext): LogicalPlan = null
+
+  /** Visit the wrapped statement; the result is null when the statement is not an extensions command. */
+  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+    visit(ctx.statement).asInstanceOf[LogicalPlan]
+  }
+
+  /** Build an Expression by handing the text of the expression to the delegate parser. */
+  override def visitExpression(ctx: ExpressionContext): Expression = {
+    // reconstruct the SQL string and parse it using the main Spark parser
+    // while we can avoid the logic to build Spark expressions, we still have to parse them
+    // we cannot call ctx.getText directly since it will not render spaces correctly
+    // that's why we need to recurse down the tree in reconstructSqlString
+    val sqlString = reconstructSqlString(ctx)
+    delegate.parseExpression(sqlString)
+  }
+
+  // joins the text of all terminal nodes below ctx, in tree order, separated by single spaces
+  private def reconstructSqlString(ctx: ParserRuleContext): String = {
+    ctx.children.asScala.map {
+      case c: ParserRuleContext => reconstructSqlString(c)
+      case t: TerminalNode => t.getText
+    }.mkString(" ")
+  }
+
+  // visits ctx with this builder and casts the result to the type the caller expects
+  private def typedVisit[T](ctx: ParseTree): T = {
+    ctx.accept(this).asInstanceOf[T]
+  }
+}
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
new file mode 100644
index 0000000..85db3dc
--- /dev/null
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+/**
+ * A parsed CALL statement: the parts of the multi-part identifier that follows CALL and its arguments.
+ */
+case class CallStatement(name: Seq[String], args: Seq[CallArgument]) extends ParsedStatement
+
+/**
+ * An argument of a CALL statement. Sealed, so implementations can only be declared in this file.
+ */
+sealed trait CallArgument {
+ def expr: Expression
+}
+
+/** An argument written as `name => expr`. */
+case class NamedArgument(name: String, expr: Expression) extends CallArgument
+
+/** An argument written as a bare expression, without a name. */
+case class PositionalArgument(expr: Expression) extends CallArgument
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
new file mode 100644
index 0000000..7dc0b18
--- /dev/null
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.{AnalysisException, Strategy}
+import org.apache.spark.sql.catalyst.plans.logical.{CallStatement, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Planner strategy for the statements added by the Iceberg SQL extensions.
+ *
+ * For now it only recognizes [[CallStatement]] and rejects it, because CALL statements
+ * can be parsed but cannot be executed yet.
+ */
+object ExtendedDataSourceV2Strategy extends Strategy {
+
+  /**
+   * Fails with an `AnalysisException` for a [[CallStatement]]; produces no physical plans
+   * for any other logical plan.
+   */
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+    plan match {
+      case _: CallStatement =>
+        throw new AnalysisException("CALL statements are not currently supported")
+      case _ =>
+        Nil
+    }
+  }
+}
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java
new file mode 100644
index 0000000..cf83ecd
--- /dev/null
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCallStatementParser.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.extensions;
+
+import java.math.BigDecimal;
+import java.sql.Timestamp;
+import java.time.Instant;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.catalyst.parser.ParserInterface;
+import org.apache.spark.sql.catalyst.plans.logical.CallArgument;
+import org.apache.spark.sql.catalyst.plans.logical.CallStatement;
+import org.apache.spark.sql.catalyst.plans.logical.NamedArgument;
+import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import scala.collection.JavaConverters;
+
+public class TestCallStatementParser {
+
+ // JUnit temporary-folder rule; none of the tests in this class reference it.
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ // Shared Spark session: assigned in startSpark() and reset to null in stopSpark().
+ private static SparkSession spark = null;
+ // SQL parser taken from the shared session's state; every test parses its CALL statement with it.
+ private static ParserInterface parser = null;
+
+ @BeforeClass
+ public static void startSpark() {
+ TestCallStatementParser.spark = SparkSession.builder()
+ .master("local[2]")
+ .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName())
+ .config("spark.extra.prop", "value")
+ .getOrCreate();
+ TestCallStatementParser.parser = spark.sessionState().sqlParser();
+ }
+
+ /**
+  * Stops the shared Spark session and clears the static references to it and to its parser.
+  *
+  * <p>The session is still null when {@code startSpark()} failed before assigning it. JUnit runs
+  * {@code @AfterClass} methods even then, so the null check keeps a NullPointerException here from
+  * being reported on top of the original setup failure.
+  */
+ @AfterClass
+ public static void stopSpark() {
+   SparkSession currentSpark = TestCallStatementParser.spark;
+   TestCallStatementParser.spark = null;
+   TestCallStatementParser.parser = null;
+   if (currentSpark != null) {
+     currentSpark.stop();
+   }
+ }
+
+ /** Checks that positional literal arguments of several types parse to the expected values and types. */
+ @Test
+ public void testCallWithPositionalArgs() throws ParseException {
+   String sql = "CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)";
+   CallStatement stmt = (CallStatement) parser.parsePlan(sql);
+
+   Assert.assertEquals(ImmutableList.of("c", "n", "func"), JavaConverters.seqAsJavaList(stmt.name()));
+   Assert.assertEquals(7, stmt.args().size());
+
+   checkArg(stmt, 0, 1, DataTypes.IntegerType);
+   checkArg(stmt, 1, "2", DataTypes.StringType);
+   checkArg(stmt, 2, 3L, DataTypes.LongType);
+   checkArg(stmt, 3, true, DataTypes.BooleanType);
+   checkArg(stmt, 4, 1.0D, DataTypes.DoubleType);
+   checkArg(stmt, 5, 9.0e1, DataTypes.DoubleType);
+   checkArg(stmt, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1));
+ }
+
+ /** Checks that {@code name => value} arguments parse to named arguments with the expected literals. */
+ @Test
+ public void testCallWithNamedArgs() throws ParseException {
+   String sql = "CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)";
+   CallStatement stmt = (CallStatement) parser.parsePlan(sql);
+
+   Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(stmt.name()));
+   Assert.assertEquals(3, stmt.args().size());
+
+   checkArg(stmt, 0, "c1", 1, DataTypes.IntegerType);
+   checkArg(stmt, 1, "c2", "2", DataTypes.StringType);
+   checkArg(stmt, 2, "c3", true, DataTypes.BooleanType);
+ }
+
+ /** Checks that a named argument followed by a positional argument is accepted by the parser. */
+ @Test
+ public void testCallWithMixedArgs() throws ParseException {
+   String sql = "CALL cat.system.func(c1 => 1, '2')";
+   CallStatement stmt = (CallStatement) parser.parsePlan(sql);
+
+   Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(stmt.name()));
+   Assert.assertEquals(2, stmt.args().size());
+
+   checkArg(stmt, 0, "c1", 1, DataTypes.IntegerType);
+   checkArg(stmt, 1, "2", DataTypes.StringType);
+ }
+
+ /** Checks that a TIMESTAMP literal argument parses to a timestamp literal for the same instant. */
+ @Test
+ public void testCallWithTimestampArg() throws ParseException {
+   String sql = "CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')";
+   CallStatement stmt = (CallStatement) parser.parsePlan(sql);
+
+   Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(stmt.name()));
+   Assert.assertEquals(1, stmt.args().size());
+
+   Timestamp expectedTimestamp = Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z"));
+   checkArg(stmt, 0, expectedTimestamp, DataTypes.TimestampType);
+ }
+
+ /**
+  * Checks that {@code ${spark.extra.prop}} inside a string argument is replaced with the value
+  * configured for that property in {@code startSpark()}.
+  */
+ @Test
+ public void testCallWithVarSubstitution() throws ParseException {
+   String sql = "CALL cat.system.func('${spark.extra.prop}')";
+   CallStatement stmt = (CallStatement) parser.parsePlan(sql);
+
+   Assert.assertEquals(ImmutableList.of("cat", "system", "func"), JavaConverters.seqAsJavaList(stmt.name()));
+   Assert.assertEquals(1, stmt.args().size());
+
+   checkArg(stmt, 0, "value", DataTypes.StringType);
+ }
+
+ /**
+  * Checks a positional argument: delegates to the five-argument overload with a null expected
+  * name, which makes it assert that the argument is a {@code PositionalArgument}.
+  */
+ private void checkArg(CallStatement call, int index, Object expectedValue, DataType expectedType) {
+ checkArg(call, index, null, expectedValue, expectedType);
+ }
+
+ /**
+  * Asserts the kind, name, type and value of the call argument at {@code index}.
+  *
+  * @param call the parsed statement whose argument is checked
+  * @param index zero-based position of the argument in {@code call.args()}
+  * @param expectedName expected argument name, or null if the argument must be positional
+  * @param expectedValue value the argument's literal expression must hold
+  * @param expectedType Spark data type the argument's expression must have
+  */
+ private void checkArg(CallStatement call, int index, String expectedName,
+     Object expectedValue, DataType expectedType) {
+
+   CallArgument actualArg = call.args().apply(index);
+
+   if (expectedName == null) {
+     checkCast(actualArg, PositionalArgument.class);
+   } else {
+     NamedArgument namedArg = checkCast(actualArg, NamedArgument.class);
+     Assert.assertEquals(expectedName, namedArg.name());
+   }
+
+   Expression actualExpr = actualArg.expr();
+   Expression expectedExpr = toSparkLiteral(expectedValue, expectedType);
+   Assert.assertEquals("Arg types must match", expectedExpr.dataType(), actualExpr.dataType());
+   Assert.assertEquals("Arg must match", expectedExpr, actualExpr);
+ }
+
+ /**
+  * Creates the expected Spark {@code Literal} for a value and data type by calling {@code create}
+  * on the Scala companion object ({@code Literal$.MODULE$}).
+  */
+ private Literal toSparkLiteral(Object value, DataType dataType) {
+ return Literal$.MODULE$.create(value, dataType);
+ }
+
+ /**
+  * Asserts that {@code value} is an instance of {@code expectedClass} and returns it cast to that type.
+  */
+ private <T> T checkCast(Object value, Class<T> expectedClass) {
+   String failureMessage = "Expected instance of " + expectedClass.getName();
+   boolean isExpectedType = expectedClass.isInstance(value);
+   Assert.assertTrue(failureMessage, isExpectedType);
+   return expectedClass.cast(value);
+ }
+}
diff --git a/versions.props b/versions.props
index 2d587c0..7e03df5 100644
--- a/versions.props
+++ b/versions.props
@@ -17,6 +17,7 @@ com.github.ben-manes.caffeine:caffeine = 2.7.0
org.apache.arrow:arrow-vector = 1.0.0
org.apache.arrow:arrow-memory-netty = 1.0.0
com.github.stephenc.findbugs:findbugs-annotations = 1.3.9-1
+org.scala-lang:scala-library = 2.12.10
# test deps
junit:junit = 4.12