Posted to commits@spark.apache.org by ge...@apache.org on 2022/06/16 04:08:58 UTC

[spark] branch master updated: [SPARK-39383][SQL] Refactor DEFAULT column support to skip passing the primary Analyzer around

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b4739eb2c6 [SPARK-39383][SQL] Refactor DEFAULT column support to skip passing the primary Analyzer around
0b4739eb2c6 is described below

commit 0b4739eb2c66ce69ffc16ad05ee0f12fe51d150b
Author: Daniel Tenedorio <da...@databricks.com>
AuthorDate: Wed Jun 15 21:08:39 2022 -0700

    [SPARK-39383][SQL] Refactor DEFAULT column support to skip passing the primary Analyzer around
    
    ### What changes were proposed in this pull request?
    
    Refactor DEFAULT column support so that the primary `Analyzer` no longer needs to be passed around. Instead, `ResolveDefaultColumnsUtil.scala` gains a dedicated `DefaultColumnAnalyzer` (backed by a `BuiltInFunctionCatalog`) that resolves DEFAULT value expressions using built-in functions only.
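    
    As an illustration of the new call shape (a minimal sketch, not code from this commit; the column name and DEFAULT text below are made up), callers now invoke `ResolveDefaultColumns.analyze` with just the field and the statement type:
    
    ```scala
    import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
    import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StructField}
    
    // Hypothetical column "a" whose DEFAULT expression is stored as text in the
    // field metadata, the way CREATE TABLE records a "DEFAULT 41 + 1" clause.
    val metadata: Metadata = new MetadataBuilder()
      .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "41 + 1")
      .build()
    val field = StructField("a", IntegerType, nullable = true, metadata)
    
    // Before this commit: the session's primary Analyzer had to be threaded through:
    //   ResolveDefaultColumns.analyze(sparkSession.sessionState.analyzer, field, "INSERT")
    
    // After this commit: the util resolves the expression with its own built-in-only
    // DefaultColumnAnalyzer, so only the field and the statement type are passed.
    val folded = ResolveDefaultColumns.analyze(field, "INSERT")
    // `folded` is the analyzed, constant-folded expression, here Literal(42).
    ```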
    
    ### Why are the changes needed?
    
    This cleans up the code by reducing the amount of state that has to be passed between methods and classes. The idea came out of code review on https://github.com/apache/spark/pull/36771.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it is a refactoring-only change.
    
    ### How was this patch tested?
    
    DEFAULT column support is covered by existing tests.
    
    Closes #36880 from dtenedor/refactor-default-col-analyzers.
    
    Authored-by: Daniel Tenedorio <da...@databricks.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  2 +-
 .../catalyst/analysis/ResolveDefaultColumns.scala  | 19 +++------
 .../catalyst/util/ResolveDefaultColumnsUtil.scala  | 48 ++++++++++++++++++----
 .../sql/catalyst/catalog/SessionCatalogSuite.scala | 13 +++---
 .../apache/spark/sql/execution/command/ddl.scala   |  3 +-
 .../spark/sql/execution/command/tables.scala       |  3 +-
 .../execution/datasources/DataSourceStrategy.scala |  2 +-
 .../datasources/v2/DataSourceV2Strategy.scala      |  4 +-
 8 files changed, 58 insertions(+), 36 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b8fa6e421ca..446bc46d9b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -313,7 +313,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       ResolveAggregateFunctions ::
       TimeWindowing ::
       SessionWindowing ::
-      ResolveDefaultColumns(this, v1SessionCatalog) ::
+      ResolveDefaultColumns(v1SessionCatalog) ::
       ResolveInlineTables ::
       ResolveLambdaVariables ::
       ResolveTimeZone ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
index e47a1230a7e..30f6fc9ea1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
@@ -47,12 +47,9 @@ import org.apache.spark.sql.types._
  * (1, 5)
  * (4, 6)
  *
- * @param analyzer analyzer to use for processing DEFAULT values stored as text.
  * @param catalog  the catalog to use for looking up the schema of INSERT INTO table objects.
  */
-case class ResolveDefaultColumns(
-    analyzer: Analyzer,
-    catalog: SessionCatalog) extends Rule[LogicalPlan] {
+case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     plan.resolveOperatorsWithPruning(
       (_ => SQLConf.get.enableDefaultColumns), ruleId) {
@@ -111,7 +108,7 @@ case class ResolveDefaultColumns(
       val expanded: UnresolvedInlineTable =
         addMissingDefaultValuesForInsertFromInlineTable(table, schema)
       val replaced: Option[LogicalPlan] =
-        replaceExplicitDefaultValuesForInputOfInsertInto(analyzer, schema, expanded)
+        replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
       replaced.map { r: LogicalPlan =>
         node = r
         for (child <- children.reverse) {
@@ -135,7 +132,7 @@ case class ResolveDefaultColumns(
       val expanded: Project =
         addMissingDefaultValuesForInsertFromProject(project, schema)
       val replaced: Option[LogicalPlan] =
-        replaceExplicitDefaultValuesForInputOfInsertInto(analyzer, schema, expanded)
+        replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
       replaced.map { r =>
         regenerated.copy(query = r)
       }.getOrElse(i)
@@ -156,8 +153,7 @@ case class ResolveDefaultColumns(
     val schemaForTargetTable: Option[StructType] = getSchemaForTargetTable(u.table)
     schemaForTargetTable.map { schema =>
       val defaultExpressions: Seq[Expression] = schema.fields.map {
-        case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
-          analyze(analyzer, f, "UPDATE")
+        case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "UPDATE")
         case _ => Literal(null)
       }
       // Create a map from each column name in the target table to its DEFAULT expression.
@@ -187,8 +183,7 @@ case class ResolveDefaultColumns(
       }
     }
     val defaultExpressions: Seq[Expression] = schema.fields.map {
-      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
-        analyze(analyzer, f, "MERGE")
+      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "MERGE")
       case _ => Literal(null)
     }
     val columnNamesToExpressions: Map[String, Expression] =
@@ -323,13 +318,11 @@ case class ResolveDefaultColumns(
    * command from a logical plan.
    */
   private def replaceExplicitDefaultValuesForInputOfInsertInto(
-      analyzer: Analyzer,
       insertTableSchemaWithoutPartitionColumns: StructType,
       input: LogicalPlan): Option[LogicalPlan] = {
     val schema = insertTableSchemaWithoutPartitionColumns
     val defaultExpressions: Seq[Expression] = schema.fields.map {
-      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
-        analyze(analyzer, f, "INSERT")
+      case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT")
       case _ => Literal(null)
     }
     // Check the type of `input` and replace its expressions accordingly.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 319095e541c..2885f986236 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -19,15 +19,20 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.Analyzer
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.{Literal => ExprLiteral}
 import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog.{CatalogManager, FunctionCatalog, Identifier}
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.connector.V1Function
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
  * This object contains fields to help process DEFAULT columns.
@@ -81,14 +86,12 @@ object ResolveDefaultColumns {
    * data source then takes responsibility to provide the constant-folded value in the
    * EXISTS_DEFAULT metadata for such columns where the value is not present in storage.
    *
-   * @param analyzer      used for analyzing the result of parsing the expression stored as text.
    * @param tableSchema   represents the names and types of the columns of the statement to process.
    * @param tableProvider provider of the target table to store default values for, if any.
    * @param statementType name of the statement being processed, such as INSERT; useful for errors.
    * @return a copy of `tableSchema` with field metadata updated with the constant-folded values.
    */
   def constantFoldCurrentDefaultsToExistDefaults(
-      analyzer: Analyzer,
       tableSchema: StructType,
       tableProvider: Option[String],
       statementType: String): StructType = {
@@ -103,7 +106,7 @@ object ResolveDefaultColumns {
           if (!allowedTableProviders.contains(givenTableProvider)) {
             throw QueryCompilationErrors.defaultReferencesNotAllowedInDataSource(givenTableProvider)
           }
-          val analyzed: Expression = analyze(analyzer, field, statementType)
+          val analyzed: Expression = analyze(field, statementType)
           val newMetadata: Metadata = new MetadataBuilder().withMetadata(field.metadata)
             .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, analyzed.sql).build()
           field.copy(metadata = newMetadata)
@@ -125,10 +128,7 @@ object ResolveDefaultColumns {
    * @param statementType which type of statement we are running, such as INSERT; useful for errors.
    * @return Result of the analysis and constant-folding operation.
    */
-  def analyze(
-      analyzer: Analyzer,
-      field: StructField,
-      statementType: String): Expression = {
+  def analyze(field: StructField, statementType: String): Expression = {
     // Parse the expression.
     val colText: String = field.metadata.getString(CURRENT_DEFAULT_COLUMN_METADATA_KEY)
     lazy val parser = new CatalystSqlParser()
@@ -143,6 +143,7 @@ object ResolveDefaultColumns {
     }
     // Analyze the parse result.
     val plan = try {
+      val analyzer: Analyzer = DefaultColumnAnalyzer
       val analyzed = analyzer.execute(Project(Seq(Alias(parsed, field.name)()), OneRowRelation()))
       analyzer.checkAnalysis(analyzed)
       ConstantFolding(analyzed)
@@ -241,4 +242,35 @@ object ResolveDefaultColumns {
       }
     }
   }
+
+  /**
+   * This is an Analyzer for processing default column values using built-in functions only.
+   */
+  object DefaultColumnAnalyzer extends Analyzer(
+    new CatalogManager(BuiltInFunctionCatalog, BuiltInFunctionCatalog.v1Catalog)) {
+  }
+
+  /**
+   * This is a FunctionCatalog for performing analysis using built-in functions only. It is a helper
+   * for the DefaultColumnAnalyzer above.
+   */
+  object BuiltInFunctionCatalog extends FunctionCatalog {
+    val v1Catalog = new SessionCatalog(
+      new InMemoryCatalog, FunctionRegistry.builtin, TableFunctionRegistry.builtin) {
+      override def createDatabase(
+          dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
+    }
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+    override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
+    override def name(): String = CatalogManager.SESSION_CATALOG_NAME
+    override def listFunctions(namespace: Array[String]): Array[Identifier] = {
+      throw new UnsupportedOperationException()
+    }
+    override def loadFunction(ident: Identifier): UnboundFunction = {
+      V1Function(v1Catalog.lookupPersistentFunction(ident.asFunctionIdentifier))
+    }
+    override def functionExists(ident: Identifier): Boolean = {
+      v1Catalog.isPersistentFunction(ident.asFunctionIdentifier)
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index bf9ce8791b6..da5e07d33c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -156,25 +156,24 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
       assert(defaultValueColumnE == "41 + 1")
 
       // Analyze the default column values.
-      val analyzer = new Analyzer(new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin))
       val statementType = "CREATE TABLE"
-      assert(ResolveDefaultColumns.analyze(analyzer, columnA, statementType).sql == "42")
-      assert(ResolveDefaultColumns.analyze(analyzer, columnB, statementType).sql == "'abc'")
+      assert(ResolveDefaultColumns.analyze(columnA, statementType).sql == "42")
+      assert(ResolveDefaultColumns.analyze(columnB, statementType).sql == "'abc'")
       assert(intercept[AnalysisException] {
-        ResolveDefaultColumns.analyze(analyzer, columnC, statementType)
+        ResolveDefaultColumns.analyze(columnC, statementType)
       }.getMessage.contains("fails to parse as a valid expression"))
       assert(intercept[AnalysisException] {
-        ResolveDefaultColumns.analyze(analyzer, columnD, statementType)
+        ResolveDefaultColumns.analyze(columnD, statementType)
       }.getMessage.contains("fails to resolve as a valid expression"))
       assert(intercept[AnalysisException] {
-        ResolveDefaultColumns.analyze(analyzer, columnE, statementType)
+        ResolveDefaultColumns.analyze(columnE, statementType)
       }.getMessage.contains("statement provided a value of incompatible type"))
 
       // Make sure that constant-folding default values does not take place when the feature is
       // disabled.
       withSQLConf(SQLConf.ENABLE_DEFAULT_COLUMNS.key -> "false") {
         val result: StructType = ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          analyzer, db1tbl3.schema, db1tbl3.provider, "CREATE TABLE")
+          db1tbl3.schema, db1tbl3.provider, "CREATE TABLE")
         val columnEWithFeatureDisabled: StructField = findField("e", result)
         // No constant-folding has taken place to the EXISTS_DEFAULT metadata.
         assert(!columnEWithFeatureDisabled.metadata.contains("EXISTS_DEFAULT"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 5cdcf33d6cd..3432258f4ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -364,8 +364,7 @@ case class AlterTableChangeColumnCommand(
             // Check that the proposed default value parses and analyzes correctly, and that the
             // type of the resulting expression is equivalent or coercible to the destination column
             // type.
-            ResolveDefaultColumns.analyze(
-              sparkSession.sessionState.analyzer, result, "ALTER TABLE ALTER COLUMN")
+            ResolveDefaultColumns.analyze(result, "ALTER TABLE ALTER COLUMN")
             result
           } else {
             withNewComment.clearCurrentDefaultValue()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 07246033265..d8c5c04082b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -290,8 +290,7 @@ case class AlterTableAddColumnsCommand(
     colsToAdd.map { col: StructField =>
       if (col.metadata.contains(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY)) {
         val foldedStructType = ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          sparkSession.sessionState.analyzer, StructType(Seq(col)), tableProvider,
-          "ALTER TABLE ADD COLUMNS")
+          StructType(Seq(col)), tableProvider, "ALTER TABLE ADD COLUMNS")
         foldedStructType.fields(0)
       } else {
         col
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a82c222ea1c..294889ec449 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -149,7 +149,7 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] {
     case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
       val newSchema: StructType =
         ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          analyzer, tableDesc.schema, tableDesc.provider, "CREATE TABLE")
+          tableDesc.schema, tableDesc.provider, "CREATE TABLE")
       val newTableDesc = tableDesc.copy(schema = newSchema)
       CreateDataSourceTableCommand(newTableDesc, ignoreIfExists = mode == SaveMode.Ignore)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 401427ac41e..2add527a359 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -173,7 +173,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         tableSpec, ifNotExists) =>
       val newSchema: StructType =
         ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          session.sessionState.analyzer, schema, tableSpec.provider, "CREATE TABLE")
+          schema, tableSpec.provider, "CREATE TABLE")
       CreateTableExec(catalog.asTableCatalog, ident.asIdentifier, newSchema,
         partitioning, qualifyLocInTableSpec(tableSpec), ifNotExists) :: Nil
 
@@ -195,7 +195,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case ReplaceTable(ResolvedDBObjectName(catalog, ident), schema, parts, tableSpec, orCreate) =>
       val newSchema: StructType =
         ResolveDefaultColumns.constantFoldCurrentDefaultsToExistDefaults(
-          session.sessionState.analyzer, schema, tableSpec.provider, "CREATE TABLE")
+          schema, tableSpec.provider, "CREATE TABLE")
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicReplaceTableExec(staging, ident.asIdentifier, newSchema, parts,

