You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/20 12:16:48 UTC

[spark] branch master updated: [SPARK-42720][PS][SQL] Uses expression for distributed-sequence default index instead of plan

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 551cda9f3af [SPARK-42720][PS][SQL] Uses expression for distributed-sequence default index instead of plan
551cda9f3af is described below

commit 551cda9f3af228a11c9c3a1aea25184baa362d9c
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Mon Mar 20 21:16:03 2023 +0900

    [SPARK-42720][PS][SQL] Uses expression for distributed-sequence default index instead of plan
    
    ### What changes were proposed in this pull request?
    
    This PR reimplements `DataFrame.withSequenceColumn` internally as `DataFrame.select(distributed_sequence_column, col("*"))`, because this operation essentially attaches a column and should therefore be treated as a scalar expression at the logical level.
    
    This is used only by the pandas API on Spark, to generate its unique default index.
    
    ### Why are the changes needed?
    
    For better code readability and a cleaner definition of the Spark Connect protobuf message; see also https://github.com/apache/spark/pull/40270.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, this is an internal change only.
    
    ### How was this patch tested?
    
    Existing test cases in pandas API on Spark verify this change.
    
    Closes #40456 from HyukjinKwon/SPARK-42720.
    
    Lead-authored-by: Hyukjin Kwon <gu...@apache.org>
    Co-authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../analysis/ExtractDistributedSequenceID.scala    | 43 ++++++++++++++++++++++
 .../expressions/DistributedSequenceID.scala        | 42 +++++++++++++++++++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  6 +--
 .../org/apache/spark/sql/DataFrameSuite.scala      |  6 +--
 6 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d069d639a4a..3a2dff78cba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -322,7 +322,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveUnion ::
       RewriteDeleteFromTable ::
       typeCoercionRules ++
-      Seq(ResolveWithCTE) ++
+      Seq(
+        ResolveWithCTE,
+        ExtractDistributedSequenceID) ++
       extendedResolutionRules : _*),
     Batch("Remove TempResolvedColumn", Once, RemoveTempResolvedColumn),
     Batch("Post-Hoc Resolution", Once,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
new file mode 100644
index 00000000000..bf6ab8e5061
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DistributedSequenceID}
+import org.apache.spark.sql.catalyst.plans.logical.{AttachDistributedSequence, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DISTRIBUTED_SEQUENCE_ID
+import org.apache.spark.sql.types.LongType
+
+/**
+ * Extracts [[DistributedSequenceID]] from resolved logical plans: each child is wrapped in
+ * [[AttachDistributedSequence]], the expression is replaced with the attached attribute, and a
+ * [[Project]] restores the original output. This is needed because generating the sequence
+ * requires the context of the whole data, e.g., [[org.apache.spark.rdd.RDD.zipWithIndex]].
+ */
+object ExtractDistributedSequenceID extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.resolveOperatorsUpWithPruning(_.containsPattern(DISTRIBUTED_SEQUENCE_ID)) {
+      case plan: LogicalPlan if plan.resolved &&
+          plan.expressions.exists(_.exists(_.isInstanceOf[DistributedSequenceID])) =>
+        val attr = AttributeReference("distributed_sequence_id", LongType, nullable = false)()
+        val newPlan = plan.withNewChildren(plan.children.map(AttachDistributedSequence(attr, _)))
+          .transformExpressions { case _: DistributedSequenceID => attr }
+        Project(plan.output, newPlan)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
new file mode 100644
index 00000000000..5a0bff990e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DISTRIBUTED_SEQUENCE_ID, TreePattern}
+import org.apache.spark.sql.types.{DataType, LongType}
+
+/**
+ * Returns consecutive, increasing 64-bit integers starting from 0.
+ * Unevaluable: the analyzer rule `ExtractDistributedSequenceID` must rewrite it before execution.
+ *
+ * @note this expression is intended for use by the pandas API on Spark only.
+ */
+case class DistributedSequenceID() extends LeafExpression with Unevaluable with NonSQLExpression {
+
+  override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    DistributedSequenceID()
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = LongType
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(DISTRIBUTED_SEQUENCE_ID)
+
+  override def nodeName: String = "distributed_sequence_id"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 37d3ada5349..8e904cf3c16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -97,6 +97,7 @@ object TreePattern extends Enumeration  {
   val UPDATE_FIELDS: Value = Value
   val UPPER_OR_LOWER: Value = Value
   val UP_CAST: Value = Value
+  val DISTRIBUTED_SEQUENCE_ID: Value = Value
 
   // Logical plan patterns (alphabetically ordered)
   val AGGREGATE: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index edcfad0c798..57da3b5af60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3973,11 +3973,7 @@ class Dataset[T] private[sql](
    * This is for 'distributed-sequence' default index in pandas API on Spark.
    */
   private[sql] def withSequenceColumn(name: String) = {
-    Dataset.ofRows(
-      sparkSession,
-      AttachDistributedSequence(
-        AttributeReference(name, LongType, nullable = false)(),
-        logicalPlan))
+    select(Column(DistributedSequenceID()).alias(name), col("*"))  // ID is column 0
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bf8d7816e47..a15c049715b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3316,9 +3316,9 @@ class DataFrameSuite extends QueryTest
   }
 
   test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") {
-    val ids = spark.range(10).repartition(5)
-      .withSequenceColumn("default_index").collect().map(_.getLong(0))
-    assert(ids.toSet === Range(0, 10).toSet)
+    val ids = spark.range(10).repartition(5).withSequenceColumn("default_index")
+    assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)  // ID is column 0
+    assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)  // take(5) sees IDs 0-4
   }
 
   test("SPARK-35320: Reading JSON with key type different to String in a map should fail") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org