Posted to commits@spark.apache.org by hv...@apache.org on 2021/04/09 13:07:15 UTC

[spark] branch master updated: [SPARK-34989] Improve the performance of mapChildren and withNewChildren methods

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0945baf  [SPARK-34989] Improve the performance of mapChildren and withNewChildren methods
0945baf is described below

commit 0945baf90660a101ae0f86a39d4c91ca74ae5ee3
Author: Ali Afroozeh <al...@databricks.com>
AuthorDate: Fri Apr 9 15:06:26 2021 +0200

    [SPARK-34989] Improve the performance of mapChildren and withNewChildren methods
    
    ### What changes were proposed in this pull request?
    One of the main performance bottlenecks in query compilation is the overly generic tree transformation methods, namely `mapChildren` and `withNewChildren` (defined in `TreeNode`). These methods use an overly generic implementation to iterate over the children and rely on reflection to create new instances. We have observed that, especially for queries with large query plans, a significant number of CPU cycles is wasted in these methods. In this PR we make these methods more efficient, b [...]
    
    #### Problem detail
    The `mapChildren` method in `TreeNode` is overly generic and costly. To be more specific, this method:
    - iterates over all the fields of a node using Scala’s product iterator. While the iteration is not reflection-based, thanks to the Scala compiler generating code for `Product`, we create many anonymous functions and visit many nested structures (recursive calls).
    The anonymous functions (presumably compiled to Java anonymous inner classes) also rank high in object allocation profiles, so we are putting unnecessary pressure on the GC.
    - performs many comparisons. For each element returned by the product iterator, we check whether it is a child (contained in the list of children) and, if so, transform it. We could avoid this by iterating over the children only, but the current implementation needs to gather all the fields (transforming only the children) so that it can instantiate the object using reflection.
    - creates objects using reflection, by delegating to the `makeCopy` method, which is several orders of magnitude slower than using the constructor.
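    To make the cost sources above concrete, here is a minimal, self-contained sketch of the generic approach on a hypothetical toy tree. `Expr`, `Lit`, `Add` and `GenericMapChildren` are illustrative names, not Spark classes. The sketch walks every constructor field via `productIterator`, tests each field for membership in `children`, and rebuilds the node from the full argument list. Spark's `makeCopy` performs that rebuild through reflection. The sketch substitutes a hand-written `makeCopy` so it runs without reflection, which means it shows the extra iteration and comparisons but not the reflection cost itself.

    ```scala
    // Hypothetical toy expression tree; not Spark's TreeNode.
    sealed abstract class Expr extends Product {
      def children: Seq[Expr]
      // Stand-in for Spark's reflection-based makeCopy: rebuild this node from its
      // full constructor argument list. Hand-written here so no reflection is needed.
      def makeCopy(args: Seq[Any]): Expr
    }

    final case class Lit(value: Int) extends Expr {
      def children: Seq[Expr] = Nil
      def makeCopy(args: Seq[Any]): Expr = Lit(args(0).asInstanceOf[Int])
    }

    final case class Add(left: Expr, right: Expr) extends Expr {
      def children: Seq[Expr] = Seq(left, right)
      def makeCopy(args: Seq[Any]): Expr =
        Add(args(0).asInstanceOf[Expr], args(1).asInstanceOf[Expr])
    }

    object GenericMapChildren {
      // Old-style shape: visit every field, test whether it is a child,
      // transform it if so, then rebuild the node from all fields.
      def mapChildren(node: Expr)(f: Expr => Expr): Expr = {
        var changed = false
        val newArgs: List[Any] = node.productIterator.map {
          case e: Expr if node.children.contains(e) => // membership test per field
            val transformed = f(e)
            if (!(transformed eq e)) changed = true
            transformed
          case other => other // non-child field (e.g. an Int), passed through
        }.toList
        if (changed) node.makeCopy(newArgs) else node
      }
    }
    ```

    Even in this tiny example, every field is visited and compared, and the node is rebuilt from a generic `Seq[Any]`. Spark's real nodes have many more fields, and its rebuild goes through reflection.
    
    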
    
    #### Solution
    The proposed solution in this PR is rather straightforward: we rewrite the `mapChildren` method using the `children` and `withNewChildren` methods. The default `withNewChildren` method suffers from the same problems as `mapChildren`, so we need to make it more efficient by specializing it in concrete classes. Similar to how each concrete query plan node already defines its children, it should also define how they can be constructed given a new list of children. Actually, the implemen [...]
    ```
    override def withNewChildren(newChildren: Seq[LogicalPlan]): LogicalPlan = copy(children = newChildren)
    ```
    The current `withNewChildren` method has two properties that we should preserve:
    
    - It returns the same instance if the provided children are the same as its children, i.e., it preserves referential equality.
    - It copies tags and maintains the origin links when a new copy is created.
    
    These properties are hard to enforce in each concrete node type's implementation. Therefore, we propose a template method `withNewChildrenInternal` that the concrete classes override, and we let the `withNewChildren` method take care of referential equality and copying:
    ```
    override def withNewChildren(newChildren: Seq[LogicalPlan]): LogicalPlan = {
      if (childrenFastEquals(children, newChildren)) {
        this
      } else {
        CurrentOrigin.withOrigin(origin) {
          val res = withNewChildrenInternal(newChildren)
          res.copyTagsFrom(this)
          res
        }
      }
    }
    ```
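    The following is a minimal sketch of the same template-method idea on a hypothetical toy tree, not Spark's `TreeNode`. `Node`, `Leaf`, `Concat` and the mutable `tags` map are illustrative stand-ins, and `CurrentOrigin.withOrigin` is omitted. The `fastEquals` used here (reference check first, then `==`) is an assumption about what a "fast equals" check does. Concrete classes only implement `withNewChildrenInternal`, typically with `copy`. The final `withNewChildren` preserves referential equality and copies tags.

    ```scala
    import scala.collection.mutable

    // Hypothetical, simplified model of the template-method pattern; the names
    // mirror Spark's, but this is not Spark's TreeNode (no CurrentOrigin handling).
    abstract class Node {
      def children: Seq[Node]
      // Stand-in for TreeNode tags: metadata that must survive copies.
      val tags: mutable.Map[String, String] = mutable.Map.empty

      // Cheap reference check first, structural equality as a fallback.
      def fastEquals(other: Node): Boolean = (this eq other) || this == other

      // The only thing a concrete node has to provide: how to rebuild itself.
      protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node

      final def withNewChildren(newChildren: Seq[Node]): Node = {
        val unchanged = children.length == newChildren.length &&
          children.zip(newChildren).forall { case (a, b) => a.fastEquals(b) }
        if (unchanged) {
          this // preserve referential equality
        } else {
          val res = withNewChildrenInternal(newChildren.toIndexedSeq)
          res.tags ++= tags // plays the role of copyTagsFrom(this)
          res
        }
      }
    }

    final case class Leaf(name: String) extends Node {
      def children: Seq[Node] = Nil
      // Only reached with a non-empty list, which a leaf cannot accept.
      protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Leaf = {
        require(newChildren.isEmpty, "a leaf has no children")
        this
      }
    }

    final case class Concat(children: Seq[Node]) extends Node {
      protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Concat =
        copy(children = newChildren)
    }
    ```

    Because the equality check and tag copying live in the final method, a concrete class such as `Concat` cannot forget either one; it only says how to call its own constructor.
    
    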
    
    With the refactoring done in a previous PR (https://github.com/apache/spark/pull/31932), most tree node types fall into one of the categories `Leaf`, `Unary`, `Binary` or `Ternary`. These traits have a more efficient implementation of `mapChildren` and define a more specialized version of `withNewChildrenInternal` that avoids creating unnecessary lists. For example, the `mapChildren` method in `UnaryLike` is defined as follows:
    ```
      override final def mapChildren(f: T => T): T = {
        val newChild = f(child)
        if (newChild fastEquals child) {
          this.asInstanceOf[T]
        } else {
          CurrentOrigin.withOrigin(origin) {
            val res = withNewChildInternal(newChild)
            res.copyTagsFrom(this.asInstanceOf[T])
            res
          }
        }
      }
    ```
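    The following is a self-contained sketch of how a unary trait can specialize `mapChildren`. It uses a hypothetical toy hierarchy (`Plan`, `LeafLike`, `UnaryLike`, `Scan`, `Filter`) rather than Spark's classes. It calls `f` once, compares the result with the existing child, and only then asks the concrete class for a copy via `withNewChildInternal`. No intermediate list of children is built or compared. Origin handling is omitted, tags are modeled as an immutable `Map`, and `fastEquals` is assumed to be a reference check with an `==` fallback; all three are simplifications.

    ```scala
    // Hypothetical, simplified model of leaf/unary node traits; not Spark's
    // LeafLike/UnaryLike (no CurrentOrigin handling, tags as a plain Map).
    abstract class Plan {
      def children: Seq[Plan]
      def mapChildren(f: Plan => Plan): Plan
      var tags: Map[String, String] = Map.empty
      def fastEquals(other: Plan): Boolean = (this eq other) || this == other
    }

    trait LeafLike extends Plan {
      final def children: Seq[Plan] = Nil
      final def mapChildren(f: Plan => Plan): Plan = this // nothing to map
    }

    trait UnaryLike extends Plan {
      def child: Plan
      final def children: Seq[Plan] = child :: Nil
      protected def withNewChildInternal(newChild: Plan): Plan

      // One call to f and one equality check; no intermediate children list.
      final def mapChildren(f: Plan => Plan): Plan = {
        val newChild = f(child)
        if (newChild.fastEquals(child)) {
          this
        } else {
          val res = withNewChildInternal(newChild)
          res.tags = tags // plays the role of copyTagsFrom(this)
          res
        }
      }
    }

    final case class Scan(table: String) extends LeafLike
    final case class Filter(condition: String, child: Plan) extends UnaryLike {
      protected def withNewChildInternal(newChild: Plan): Filter = copy(child = newChild)
    }
    ```

    A binary or ternary trait would follow the same shape with two or three calls to `f` and a `withNewChildrenInternal` that takes the new children as separate parameters, as in the `MetricsAggregate` change in `Summarizer.scala` below.
    
    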
    
    #### Results
    With this PR, we have observed significant improvements in query compilation time, specifically in the analysis and optimization phases. The table below shows the TPC-DS queries whose compilation time improved by 25% or more. The biggest speedups are observed in queries with large query plans.
    | Query  | Speedup |
    | ------------- | ------------- |
    |q4    |29%|
    |q9    |81%|
    |q14a  |31%|
    |q14b  |28%|
    |q22   |33%|
    |q33   |29%|
    |q34   |25%|
    |q39   |27%|
    |q41   |27%|
    |q44   |26%|
    |q47   |28%|
    |q48   |76%|
    |q49   |46%|
    |q56   |26%|
    |q58   |43%|
    |q59   |46%|
    |q60   |50%|
    |q65   |59%|
    |q66   |46%|
    |q67   |52%|
    |q69   |31%|
    |q70   |30%|
    |q96   |26%|
    |q98   |32%|
    
    #### Binary incompatibility
    Changing `withNewChildren` in `TreeNode` breaks binary compatibility with code compiled against older versions of Spark, because concrete `TreeNode` subclasses are now all expected to implement the `withNewChildrenInternal` method. This is a problem, for example, when users write custom expressions. This change is nevertheless the right choice, since it forces all expressions newly added to Catalyst to implement it efficiently and will prevent future regressions.
    Please note that we have not completely removed the old implementation; it has been renamed to `legacyWithNewChildren`. This method will be removed in the future and for now helps with the transition. Some expressions, such as `UpdateFields`, define their children in a complex way, and writing `withNewChildren` for them requires refactoring the expression. For now, these expressions use the old, slow method. We will address them in a future PR.
    
    ### Does this PR introduce _any_ user-facing change?
    
    This PR does not introduce user-facing changes but may break binary compatibility of code compiled against older versions. See the Binary incompatibility section.
    
    ### How was this patch tested?
    
    This PR is mainly a refactoring and passes existing tests.
    
    Closes #32030 from dbaliafroozeh/ImprovedMapChildren.
    
    Authored-by: Ali Afroozeh <al...@databricks.com>
    Signed-off-by: herman <he...@databricks.com>
---
 .../apache/spark/sql/avro/AvroDataToCatalyst.scala |   3 +
 .../apache/spark/sql/avro/CatalystDataToAvro.scala |   3 +
 .../org/apache/spark/ml/stat/Summarizer.scala      |   4 +
 .../spark/sql/catalyst/analysis/unresolved.scala   |  30 ++++
 .../expressions/CallMethodViaReflection.scala      |   3 +
 .../spark/sql/catalyst/expressions/Cast.scala      |   6 +
 .../sql/catalyst/expressions/DynamicPruning.scala  |   6 +
 .../catalyst/expressions/PartitionTransforms.scala |   5 +
 .../spark/sql/catalyst/expressions/PythonUDF.scala |   3 +
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |   3 +
 .../spark/sql/catalyst/expressions/SortOrder.scala |   6 +
 .../expressions/SubExprEvaluationRuntime.scala     |   3 +
 .../sql/catalyst/expressions/TimeWindow.scala      |   6 +
 .../spark/sql/catalyst/expressions/TryCast.scala   |   3 +
 .../ApproxCountDistinctForIntervals.scala          |   4 +
 .../aggregate/ApproximatePercentile.scala          |   4 +
 .../catalyst/expressions/aggregate/Average.scala   |   3 +
 .../expressions/aggregate/CentralMomentAgg.scala   |  18 ++
 .../sql/catalyst/expressions/aggregate/Corr.scala  |   3 +
 .../sql/catalyst/expressions/aggregate/Count.scala |   3 +
 .../catalyst/expressions/aggregate/CountIf.scala   |   3 +
 .../expressions/aggregate/CountMinSketchAgg.scala  |   8 +
 .../expressions/aggregate/Covariance.scala         |   7 +
 .../sql/catalyst/expressions/aggregate/First.scala |   2 +
 .../aggregate/HyperLogLogPlusPlus.scala            |   3 +
 .../sql/catalyst/expressions/aggregate/Last.scala  |   2 +
 .../sql/catalyst/expressions/aggregate/Max.scala   |   2 +
 .../expressions/aggregate/MaxByAndMinBy.scala      |   6 +
 .../sql/catalyst/expressions/aggregate/Min.scala   |   2 +
 .../expressions/aggregate/Percentile.scala         |   7 +
 .../expressions/aggregate/PivotFirst.scala         |   4 +
 .../catalyst/expressions/aggregate/Product.scala   |   3 +
 .../sql/catalyst/expressions/aggregate/Sum.scala   |   2 +
 .../expressions/aggregate/UnevaluableAggs.scala    |   4 +
 .../expressions/aggregate/bitwiseAggregates.scala  |   9 +
 .../catalyst/expressions/aggregate/collect.scala   |   6 +
 .../expressions/aggregate/interfaces.scala         |  10 ++
 .../sql/catalyst/expressions/arithmetic.scala      |  36 ++++
 .../catalyst/expressions/bitwiseExpressions.scala  |  18 ++
 .../catalyst/expressions/codegen/javaCode.scala    |   8 +-
 .../expressions/collectionOperations.scala         |  96 ++++++++++
 .../catalyst/expressions/complexTypeCreator.scala  |  26 +++
 .../expressions/complexTypeExtractors.scala        |  14 ++
 .../expressions/conditionalExpressions.scala       |  10 ++
 .../expressions/constraintExpressions.scala        |   8 +-
 .../sql/catalyst/expressions/csvExpressions.scala  |   9 +
 .../catalyst/expressions/datetimeExpressions.scala | 162 +++++++++++++++++
 .../catalyst/expressions/decimalExpressions.scala  |  15 ++
 .../sql/catalyst/expressions/generators.scala      |  18 ++
 .../spark/sql/catalyst/expressions/grouping.scala  |  11 ++
 .../spark/sql/catalyst/expressions/hash.scala      |  18 ++
 .../expressions/higherOrderFunctions.scala         |  58 +++++-
 .../catalyst/expressions/intervalExpressions.scala |  67 ++++++-
 .../sql/catalyst/expressions/jsonExpressions.scala |  22 +++
 .../sql/catalyst/expressions/mathExpressions.scala | 114 ++++++++++--
 .../spark/sql/catalyst/expressions/misc.scala      |  11 ++
 .../catalyst/expressions/namedExpressions.scala    |   3 +
 .../sql/catalyst/expressions/nullExpressions.scala |  24 +++
 .../sql/catalyst/expressions/objects/objects.scala |  70 +++++++-
 .../sql/catalyst/expressions/predicates.scala      |  36 ++++
 .../catalyst/expressions/randomExpressions.scala   |   4 +
 .../catalyst/expressions/regexpExpressions.scala   |  30 ++++
 .../catalyst/expressions/stringExpressions.scala   | 134 ++++++++++++++
 .../spark/sql/catalyst/expressions/subquery.scala  |   9 +
 .../catalyst/expressions/windowExpressions.scala   |  39 ++++
 .../spark/sql/catalyst/expressions/xml/xpath.scala |  24 +++
 .../catalyst/optimizer/CostBasedJoinReorder.scala  |   3 +
 .../optimizer/NormalizeFloatingNumbers.scala       |   3 +
 .../plans/logical/EventTimeWatermark.scala         |   3 +
 .../plans/logical/ScriptTransformation.scala       |   3 +
 .../plans/logical/basicLogicalOperators.scala      |  73 ++++++++
 .../spark/sql/catalyst/plans/logical/hints.scala   |   6 +
 .../spark/sql/catalyst/plans/logical/object.scala  |  52 +++++-
 .../plans/logical/pythonLogicalOperators.scala     |  20 ++-
 .../sql/catalyst/plans/logical/statements.scala    |  11 +-
 .../sql/catalyst/plans/logical/v2Commands.scala    | 196 ++++++++++++++++++---
 .../sql/catalyst/plans/physical/partitioning.scala |  11 ++
 .../sql/catalyst/streaming/WriteToStream.scala     |   2 +
 .../streaming/WriteToStreamStatement.scala         |   3 +
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 151 +++++++++++++++-
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala |   2 +
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  |  10 ++
 .../analysis/UnsupportedOperationsSuite.scala      |   2 +
 .../SubexpressionEliminationSuite.scala            |   2 +
 .../optimizer/ConvertToLocalRelationSuite.scala    |   3 +
 .../sql/catalyst/plans/LogicalPlanSuite.scala      |   3 +
 .../plans/logical/LogicalPlanIntegritySuite.scala  |   2 +
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   |  12 +-
 .../spark/sql/execution/CollectMetricsExec.scala   |   3 +
 .../org/apache/spark/sql/execution/Columnar.scala  |   6 +
 .../apache/spark/sql/execution/ExpandExec.scala    |   3 +
 .../apache/spark/sql/execution/GenerateExec.scala  |   3 +
 .../org/apache/spark/sql/execution/SortExec.scala  |   3 +
 .../execution/SparkScriptTransformationExec.scala  |   3 +
 .../execution/SubqueryAdaptiveBroadcastExec.scala  |   3 +
 .../sql/execution/SubqueryBroadcastExec.scala      |   3 +
 .../sql/execution/WholeStageCodegenExec.scala      |   6 +
 .../adaptive/CustomShuffleReaderExec.scala         |   3 +
 .../execution/aggregate/HashAggregateExec.scala    |   3 +
 .../aggregate/ObjectHashAggregateExec.scala        |   3 +
 .../execution/aggregate/SortAggregateExec.scala    |   3 +
 .../aggregate/TypedAggregateExpression.scala       |   8 +
 .../spark/sql/execution/aggregate/udaf.scala       |   7 +
 .../sql/execution/basicPhysicalOperators.scala     |  18 ++
 .../execution/command/AnalyzeColumnCommand.scala   |   2 +-
 .../command/AnalyzePartitionCommand.scala          |   2 +-
 .../execution/command/AnalyzeTableCommand.scala    |   2 +-
 .../execution/command/AnalyzeTablesCommand.scala   |   2 +-
 .../command/InsertIntoDataSourceDirCommand.scala   |   2 +-
 .../spark/sql/execution/command/SetCommand.scala   |   5 +-
 .../apache/spark/sql/execution/command/cache.scala |   2 +-
 .../spark/sql/execution/command/commands.scala     |  12 +-
 .../execution/command/createDataSourceTables.scala |   5 +-
 .../apache/spark/sql/execution/command/ddl.scala   |  30 ++--
 .../spark/sql/execution/command/functions.scala    |  10 +-
 .../spark/sql/execution/command/resources.scala    |  13 +-
 .../spark/sql/execution/command/tables.scala       |  30 ++--
 .../apache/spark/sql/execution/command/views.scala |   6 +-
 .../execution/datasources/FileFormatWriter.scala   |   3 +
 .../datasources/InsertIntoDataSourceCommand.scala  |   4 +-
 .../InsertIntoHadoopFsRelationCommand.scala        |   3 +
 .../datasources/SaveIntoDataSourceCommand.scala    |   4 +-
 .../spark/sql/execution/datasources/ddl.scala      |  10 +-
 .../datasources/v2/WriteToDataSourceV2Exec.scala   |  32 +++-
 .../apache/spark/sql/execution/debug/package.scala |   3 +
 .../execution/exchange/BroadcastExchangeExec.scala |   3 +
 .../execution/exchange/ShuffleExchangeExec.scala   |   3 +
 .../execution/joins/BroadcastHashJoinExec.scala    |   4 +
 .../joins/BroadcastNestedLoopJoinExec.scala        |   4 +
 .../sql/execution/joins/CartesianProductExec.scala |   4 +
 .../sql/execution/joins/ShuffledHashJoinExec.scala |   4 +
 .../sql/execution/joins/SortMergeJoinExec.scala    |   4 +
 .../org/apache/spark/sql/execution/limit.scala     |  17 +-
 .../org/apache/spark/sql/execution/objects.scala   |  33 ++++
 .../execution/python/AggregateInPandasExec.scala   |   3 +
 .../sql/execution/python/ArrowEvalPythonExec.scala |   3 +
 .../sql/execution/python/BatchEvalPythonExec.scala |   3 +
 .../python/FlatMapCoGroupsInPandasExec.scala       |   4 +
 .../python/FlatMapGroupsInPandasExec.scala         |   3 +
 .../sql/execution/python/MapInPandasExec.scala     |   3 +
 .../sql/execution/python/WindowInPandasExec.scala  |   3 +
 .../streaming/EventTimeWatermarkExec.scala         |   3 +
 .../streaming/FlatMapGroupsWithStateExec.scala     |   3 +
 .../streaming/StreamingSymmetricHashJoinExec.scala |   4 +
 .../continuous/WriteToContinuousDataSource.scala   |   2 +
 .../WriteToContinuousDataSourceExec.scala          |   3 +
 .../sources/WriteToMicroBatchDataSource.scala      |   3 +
 .../execution/streaming/statefulOperators.scala    |   9 +
 .../sql/execution/streaming/streamingLimits.scala  |   6 +
 .../org/apache/spark/sql/execution/subquery.scala  |   3 +
 .../spark/sql/execution/window/WindowExec.scala    |   3 +
 .../approved-plans-v1_4/q14a.sf100/explain.txt     |   2 +-
 .../approved-plans-v1_4/q14a/explain.txt           |   2 +-
 .../approved-plans-v1_4/q5.sf100/explain.txt       |   2 +-
 .../approved-plans-v1_4/q5/explain.txt             |   2 +-
 .../approved-plans-v1_4/q77.sf100/explain.txt      |   2 +-
 .../approved-plans-v1_4/q77/explain.txt            |   2 +-
 .../approved-plans-v1_4/q80.sf100/explain.txt      |   2 +-
 .../approved-plans-v1_4/q80/explain.txt            |   2 +-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |   2 +
 .../apache/spark/sql/ExtraStrategiesSuite.scala    |   5 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     |  21 +++
 .../spark/sql/TypedImperativeAggregateSuite.scala  |   6 +-
 .../execution/BaseScriptTransformationSuite.scala  |   3 +
 .../spark/sql/execution/ColumnarRulesSuite.scala   |   1 +
 .../apache/spark/sql/execution/ExchangeSuite.scala |   3 +
 .../apache/spark/sql/execution/PlannerSuite.scala  |   2 +
 .../apache/spark/sql/execution/ReferenceSort.scala |   3 +
 .../spark/sql/util/DataFrameCallbackSuite.scala    |   4 +-
 .../execution/CreateHiveTableAsSelectCommand.scala |   6 +
 .../execution/HiveScriptTransformationExec.scala   |   3 +
 .../hive/execution/InsertIntoHiveDirCommand.scala  |   3 +
 .../sql/hive/execution/InsertIntoHiveTable.scala   |   3 +
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala |  12 ++
 .../sql/hive/execution/TestingTypedCount.scala     |   3 +
 175 files changed, 2213 insertions(+), 146 deletions(-)

diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
index 64fb588..b496500 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -134,4 +134,7 @@ private[avro] case class AvroDataToCatalyst(
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): AvroDataToCatalyst =
+    copy(child = newChild)
 }
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
index 53910b7..5d79c44 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
@@ -64,4 +64,7 @@ private[avro] case class CatalystDataToAvro(
     defineCodeGen(ctx, ev, input =>
       s"(byte[]) $expr.nullSafeEval($input)")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): CatalystDataToAvro =
+    copy(child = newChild)
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 109ccbd..a3dd133 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -374,6 +374,10 @@ private[spark] object SummaryBuilderImpl extends Logging {
     override def left: Expression = featuresExpr
     override def right: Expression = weightExpr
 
+    override protected def withNewChildrenInternal(
+        newLeft: Expression, newRight: Expression): MetricsAggregate =
+      copy(featuresExpr = newLeft, weightExpr = newRight)
+
     override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
       val features = vectorUDT.deserialize(featuresExpr.eval(row))
       val weight = weightExpr.eval(row).asInstanceOf[Double]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 3fc3db3..3b2f4ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -263,6 +263,9 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio
 
   override def terminate(): TraversableOnce[InternalRow] =
     throw QueryExecutionErrors.cannotTerminateGeneratorError(this)
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): UnresolvedGenerator = copy(children = newChildren)
 }
 
 case class UnresolvedFunction(
@@ -284,6 +287,15 @@ case class UnresolvedFunction(
     val distinct = if (isDistinct) "distinct " else ""
     s"'$name($distinct${children.mkString(", ")})"
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UnresolvedFunction = {
+    if (filter.isDefined) {
+      copy(arguments = newChildren.dropRight(1), filter = Some(newChildren.last))
+    } else {
+      copy(arguments = newChildren)
+    }
+  }
 }
 
 object UnresolvedFunction {
@@ -441,6 +453,8 @@ case class MultiAlias(child: Expression, names: Seq[String])
 
   override def toString: String = s"$child AS $names"
 
+  override protected def withNewChildInternal(newChild: Expression): MultiAlias =
+    copy(child = newChild)
 }
 
 /**
@@ -475,6 +489,11 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
 
   override def toString: String = s"$child[$extraction]"
   override def sql: String = s"${child.sql}[${extraction.sql}]"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): UnresolvedExtractValue = {
+      copy(child = newLeft, extraction = newRight)
+  }
 }
 
 /**
@@ -499,6 +518,9 @@ case class UnresolvedAlias(
   override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance")
 
   override lazy val resolved = false
+
+  override protected def withNewChildInternal(newChild: Expression): UnresolvedAlias =
+    copy(child = newChild)
 }
 
 /**
@@ -520,6 +542,9 @@ case class UnresolvedSubqueryColumnAliases(
   override def output: Seq[Attribute] = Nil
 
   override lazy val resolved = false
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): UnresolvedSubqueryColumnAliases = copy(child = newChild)
 }
 
 /**
@@ -541,6 +566,9 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override lazy val resolved = false
+
+  override protected def withNewChildInternal(newChild: Expression): UnresolvedDeserializer =
+    copy(deserializer = newChild)
 }
 
 case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression
@@ -587,6 +615,8 @@ case class UnresolvedHaving(
   extends UnaryNode {
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHaving =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
index 0de17d4..7cb830d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
@@ -114,6 +114,9 @@ case class CallMethodViaReflection(children: Seq[Expression])
 
   /** A temporary buffer used to hold intermediate results returned by children. */
   @transient private lazy val buffer = new Array[Object](argExprs.length)
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): CallMethodViaReflection = copy(children = newChildren)
 }
 
 object CallMethodViaReflection {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 879b154..1e1b7ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1812,6 +1812,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   } else {
     s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
 }
 
 /**
@@ -1841,6 +1843,8 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St
       Some(SQLConf.STORE_ASSIGNMENT_POLICY.key),
       Some(SQLConf.StoreAssignmentPolicy.LEGACY.toString))
 
+  override protected def withNewChildInternal(newChild: Expression): AnsiCast =
+    copy(child = newChild)
 }
 
 object AnsiCast {
@@ -1998,4 +2002,6 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S
     case DecimalType => DecimalType.SYSTEM_DEFAULT
     case _ => target.asInstanceOf[DataType]
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UpCast = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index 550fa4c..de4b874 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -78,6 +78,9 @@ case class DynamicPruningSubquery(
       buildKeys = buildKeys.map(_.canonicalized),
       exprId = ExprId(0))
   }
+
+  override protected def withNewChildInternal(newChild: Expression): DynamicPruningSubquery =
+    copy(pruningKey = newChild)
 }
 
 /**
@@ -94,4 +97,7 @@ case class DynamicPruningExpression(child: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.genCode(ctx)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): DynamicPruningExpression =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
index 05d5537..ab39061 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
@@ -43,6 +43,7 @@ abstract class PartitionTransformExpression extends Expression with Unevaluable
  */
 case class Years(child: Expression) extends PartitionTransformExpression {
   override def dataType: DataType = IntegerType
+  override protected def withNewChildInternal(newChild: Expression): Years = copy(child = newChild)
 }
 
 /**
@@ -50,6 +51,7 @@ case class Years(child: Expression) extends PartitionTransformExpression {
  */
 case class Months(child: Expression) extends PartitionTransformExpression {
   override def dataType: DataType = IntegerType
+  override protected def withNewChildInternal(newChild: Expression): Months = copy(child = newChild)
 }
 
 /**
@@ -57,6 +59,7 @@ case class Months(child: Expression) extends PartitionTransformExpression {
  */
 case class Days(child: Expression) extends PartitionTransformExpression {
   override def dataType: DataType = IntegerType
+  override protected def withNewChildInternal(newChild: Expression): Days = copy(child = newChild)
 }
 
 /**
@@ -64,6 +67,7 @@ case class Days(child: Expression) extends PartitionTransformExpression {
  */
 case class Hours(child: Expression) extends PartitionTransformExpression {
   override def dataType: DataType = IntegerType
+  override protected def withNewChildInternal(newChild: Expression): Hours = copy(child = newChild)
 }
 
 /**
@@ -71,4 +75,5 @@ case class Hours(child: Expression) extends PartitionTransformExpression {
  */
 case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression {
   override def dataType: DataType = IntegerType
+  override protected def withNewChildInternal(newChild: Expression): Bucket = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index da2e182..73f8c30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -73,4 +73,7 @@ case class PythonUDF(
     // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
     this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF =
+    copy(children = newChildren)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 4086e76..375ae95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -1195,4 +1195,7 @@ case class ScalaUDF(
 
     resultConverter(result)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDF =
+    copy(children = newChildren)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d9923b5..9aef25c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -88,6 +88,9 @@ case class SortOrder(
     children.exists(required.child.semanticEquals) &&
       direction == required.direction && nullOrdering == required.nullOrdering
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder =
+    copy(child = newChildren.head, sameOrderExpressions = newChildren.tail)
 }
 
 object SortOrder {
@@ -226,4 +229,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
   }
 
   override def dataType: DataType = LongType
+
+  override protected def withNewChildInternal(newChild: Expression): SortPrefix =
+    copy(child = newChild.asInstanceOf[SortOrder])
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
index a1f7ba30..0f224fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
@@ -140,6 +140,9 @@ case class ExpressionProxy(
   }
 
   override def hashCode(): Int = this.id.hashCode()
+
+  override protected def withNewChildInternal(newChild: Expression): ExpressionProxy =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index f7fe467..ed1d770 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -92,6 +92,9 @@ case class TimeWindow(
     }
     dataTypeCheck
   }
+
+  override protected def withNewChildInternal(newChild: Expression): TimeWindow =
+    copy(timeColumn = newChild)
 }
 
 object TimeWindow {
@@ -155,4 +158,7 @@ case class PreciseTimestampConversion(
        """.stripMargin)
   }
   override def nullSafeEval(input: Any): Any = input
+
+  override protected def withNewChildInternal(newChild: Expression): PreciseTimestampConversion =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
index 8856389..0f63de1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
@@ -85,6 +85,9 @@ case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[Str
   override def typeCheckFailureMessage: String =
     AnsiCast.typeCheckFailureMessage(child.dataType, dataType, None, None)
 
+  override protected def withNewChildInternal(newChild: Expression): TryCast =
+    copy(child = newChild)
+
   override def toString: String = {
     s"try_cast($child as ${dataType.simpleString})"
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index 42dc6f6..19e212d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -249,4 +249,8 @@ case class ApproxCountDistinctForIntervals(
     override def getLong(offset: Int): Long = array(offset)
     override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ApproxCountDistinctForIntervals =
+    copy(child = newLeft, endpointsExpression = newRight)
 }
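Every override in this diff follows the same pattern. The node rebuilds itself with its case-class `copy` method instead of going through the generic, reflection-based path described in the commit message, and the one- and two-child hooks (`withNewChildInternal(newChild)`, `withNewChildrenInternal(newLeft, newRight)`) name each child explicitly. Below is a minimal sketch of how such hooks could fit together. It assumes a simplified, hypothetical hierarchy: `Expr`, `UnaryLike`, `BinaryLike`, `Lit`, `Neg` and `Add` are illustrative stand-ins, not Spark's `TreeNode` or its traits.

```scala
// Hypothetical, simplified stand-ins; not Spark's TreeNode / UnaryLike / BinaryLike.
abstract class Expr {
  def children: IndexedSeq[Expr]

  // Each concrete node rebuilds itself from new children (typically via `copy`).
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr

  final def withNewChildren(newChildren: Seq[Expr]): Expr = {
    val newSeq = newChildren.toIndexedSeq
    require(newSeq.size == children.size, "Incorrect number of children")
    // If every child is the same instance as before, reuse this node.
    if (children.corresponds(newSeq)(_ eq _)) this
    else withNewChildrenInternal(newSeq)
  }

  final def mapChildren(f: Expr => Expr): Expr = withNewChildren(children.map(f))
}

trait UnaryLike extends Expr {
  def child: Expr
  override def children: IndexedSeq[Expr] = IndexedSeq(child)
  // Unary nodes only implement this single-child hook.
  protected def withNewChildInternal(newChild: Expr): Expr
  override protected final def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr =
    withNewChildInternal(newChildren.head)
}

trait BinaryLike extends Expr {
  def left: Expr
  def right: Expr
  override def children: IndexedSeq[Expr] = IndexedSeq(left, right)
  // Binary nodes only implement this two-child hook.
  protected def withNewChildrenInternal(newLeft: Expr, newRight: Expr): Expr
  override protected final def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr =
    withNewChildrenInternal(newChildren(0), newChildren(1))
}

case class Lit(value: Int) extends Expr {
  def children: IndexedSeq[Expr] = IndexedSeq.empty
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr = this
}

case class Neg(child: Expr) extends Expr with UnaryLike {
  override protected def withNewChildInternal(newChild: Expr): Neg = copy(child = newChild)
}

case class Add(left: Expr, right: Expr) extends Expr with BinaryLike {
  override protected def withNewChildrenInternal(newLeft: Expr, newRight: Expr): Add =
    copy(left = newLeft, right = newRight)
}
```

In this sketch a node is reallocated only when at least one child is a different instance. That identity check is an assumption made for illustration, not a copy of Spark's implementation.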
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 4e4a06a..38d8d7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -208,6 +208,10 @@ case class ApproximatePercentile(
   override def deserialize(bytes: Array[Byte]): PercentileDigest = {
     ApproximatePercentile.serializer.deserialize(bytes)
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): ApproximatePercentile =
+    copy(child = newFirst, percentageExpression = newSecond, accuracyExpression = newThird)
 }
 
 object ApproximatePercentile {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 36004b0..90e91ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -93,4 +93,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
       coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
     /* count = */ If(child.isNull, count, count + 1L)
   )
+
+  override protected def withNewChildInternal(newChild: Expression): Average =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 4ca933f..c5c78e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -167,6 +167,9 @@ case class StddevPop(
   }
 
   override def prettyName: String = "stddev_pop"
+
+  override protected def withNewChildInternal(newChild: Expression): StddevPop =
+    copy(child = newChild)
 }
 
 // Compute the sample standard deviation of a column
@@ -197,6 +200,9 @@ case class StddevSamp(
 
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp")
+
+  override protected def withNewChildInternal(newChild: Expression): StddevSamp =
+    copy(child = newChild)
 }
 
 // Compute the population variance of a column
@@ -223,6 +229,9 @@ case class VariancePop(
   }
 
   override def prettyName: String = "var_pop"
+
+  override protected def withNewChildInternal(newChild: Expression): VariancePop =
+    copy(child = newChild)
 }
 
 // Compute the sample variance of a column
@@ -250,6 +259,9 @@ case class VarianceSamp(
   }
 
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp")
+
+  override protected def withNewChildInternal(newChild: Expression): VarianceSamp =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -278,6 +290,9 @@ case class Skewness(
     If(n === 0.0, Literal.create(null, DoubleType),
       If(m2 === 0.0, divideByZeroEvalResult, sqrt(n) * m3 / sqrt(m2 * m2 * m2)))
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Skewness =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -306,4 +321,7 @@ case class Kurtosis(
   }
 
   override def prettyName: String = "kurtosis"
+
+  override protected def withNewChildInternal(newChild: Expression): Kurtosis =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index d819971..c798004 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -127,4 +127,7 @@ case class Corr(
   }
 
   override def prettyName: String = "corr"
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Corr =
+    copy(x = newLeft, y = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 189d216..1d13155 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -89,6 +89,9 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
       )
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Count =
+    copy(children = newChildren)
 }
 
 object Count {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
index c1c4c84..d4fdd51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
@@ -56,4 +56,7 @@ case class CountIf(predicate: Expression) extends UnevaluableAggregate with Impl
         s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}"
       )
   }
+
+  override protected def withNewChildInternal(newChild: Expression): CountIf =
+    copy(predicate = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index a838a0a..38d0db1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -154,4 +154,12 @@ case class CountMinSketchAgg(
   override def second: Expression = epsExpression
   override def third: Expression = confidenceExpression
   override def fourth: Expression = seedExpression
+
+  override protected def withNewChildrenInternal(first: Expression, second: Expression,
+      third: Expression, fourth: Expression): CountMinSketchAgg =
+    copy(
+      child = first,
+      epsExpression = second,
+      confidenceExpression = third,
+      seedExpression = fourth)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index 8fcee10..9ea9b378 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -109,6 +109,10 @@ case class CovPopulation(
     If(n === 0.0, Literal.create(null, DoubleType), ck / n)
   }
   override def prettyName: String = "covar_pop"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): CovPopulation =
+    copy(left = newLeft, right = newRight)
 }
 
 
@@ -135,4 +139,7 @@ case class CovSample(
       If(n === 1.0, divideByZeroEvalResult, ck / (n - 1.0)))
   }
   override def prettyName: String = "covar_samp"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): CovSample = copy(left = newLeft, right = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index accd15a..ea994af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -118,6 +118,8 @@ case class First(child: Expression, ignoreNulls: Boolean)
   override lazy val evaluateExpression: AttributeReference = first
 
   override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
+
+  override protected def withNewChildInternal(newChild: Expression): First = copy(child = newChild)
 }
 
 object FirstLast {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index 430c25c..9b0493f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -138,6 +138,9 @@ case class HyperLogLogPlusPlus(
   override def eval(buffer: InternalRow): Any = {
     hllppHelper.query(buffer, mutableAggBufferOffset)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): HyperLogLogPlusPlus =
+    copy(child = newChild)
 }
 
 object HyperLogLogPlusPlus {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index e3c427d..0fe6199 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -115,4 +115,6 @@ case class Last(child: Expression, ignoreNulls: Boolean)
   override lazy val evaluateExpression: AttributeReference = last
 
   override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
+
+  override protected def withNewChildInternal(newChild: Expression): Last = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index 42721ea..b802678 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -62,4 +62,6 @@ case class Max(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex
   }
 
   override lazy val evaluateExpression: AttributeReference = max
+
+  override protected def withNewChildInternal(newChild: Expression): Max = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
index e402bca..664bc32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala
@@ -110,6 +110,9 @@ case class MaxBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin
 
   override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
     greatest(oldExpr, newExpr)
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): MaxBy =
+    copy(valueExpr = newLeft, orderingExpr = newRight)
 }
 
 @ExpressionDescription(
@@ -130,4 +133,7 @@ case class MinBy(valueExpr: Expression, orderingExpr: Expression) extends MaxMin
 
   override protected def orderingUpdater(oldExpr: Expression, newExpr: Expression): Expression =
     least(oldExpr, newExpr)
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): MinBy =
+    copy(valueExpr = newLeft, orderingExpr = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 84410c7..9c5c7bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -62,4 +62,6 @@ case class Min(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex
   }
 
   override lazy val evaluateExpression: AttributeReference = min
+
+  override protected def withNewChildInternal(newChild: Expression): Min = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index b81c523..5bce4d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -304,4 +304,11 @@ case class Percentile(
       bis.close()
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Percentile = copy(
+    child = newFirst,
+    percentageExpression = newSecond,
+    frequencyExpression = newThird
+  )
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 422fcab..b90e46e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -153,5 +153,9 @@ case class PivotFirst(
 
   override val inputAggBufferAttributes: Seq[AttributeReference] =
     aggBufferAttributes.map(_.newInstance())
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): PivotFirst =
+    copy(pivotColumn = newLeft, valueColumn = newRight)
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
index 50c74f1..3af3944 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala
@@ -59,4 +59,7 @@ case class Product(child: Expression)
     Seq(coalesce(coalesce(product.left, one) * product.right, product.left))
 
   override lazy val evaluateExpression: Expression = product
+
+  override protected def withNewChildInternal(newChild: Expression): Product =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index f412a3e..56eebed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -148,4 +148,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
         CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
     case _ => sum
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala
index 5b914c4..878d853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala
@@ -56,6 +56,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
   since = "3.0.0")
 case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
   override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and")
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(arg = newChild)
 }
 
 @ExpressionDescription(
@@ -73,4 +75,6 @@ case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
   since = "3.0.0")
 case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
   override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or")
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(arg = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
index 5ffc0f6..86a16ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/bitwiseAggregates.scala
@@ -69,6 +69,9 @@ case class BitAndAgg(child: Expression) extends BitAggregate {
   override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
     BitwiseAnd(left, right)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): BitAndAgg =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -87,6 +90,9 @@ case class BitOrAgg(child: Expression) extends BitAggregate {
   override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
     BitwiseOr(left, right)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): BitOrAgg =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -105,4 +111,7 @@ case class BitXorAgg(child: Expression) extends BitAggregate {
   override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
     BitwiseXor(left, right)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): BitXorAgg =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index d8a76d7..a8db8211a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -125,6 +125,9 @@ case class CollectList(
   override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
     new GenericArrayData(buffer.toArray)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): CollectList =
+    copy(child = newChild)
 }
 
 /**
@@ -191,4 +194,7 @@ case class CollectSet(
   override def prettyName: String = "collect_set"
 
   override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
+
+  override protected def withNewChildInternal(newChild: Expression): CollectSet =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index e0c6ce7..281734c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -164,6 +164,16 @@ case class AggregateExpression(
       case _ => aggFuncStr
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): AggregateExpression =
+    if (filter.isDefined) {
+      copy(
+        aggregateFunction = newChildren(0).asInstanceOf[AggregateFunction],
+        filter = Some(newChildren(1)))
+    } else {
+      copy(aggregateFunction = newChildren(0).asInstanceOf[AggregateFunction])
+    }
 }
 
 /**
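As the `AggregateExpression` override above implies, that node has a variable number of children: the aggregate function is always present, and the filter counts as a child only when it is set. The flat `IndexedSeq` therefore has to be mapped back onto the fields differently in each case. Below is a minimal sketch of the same idea. It assumes a hypothetical two-field `AggExpr`, with `Col` as a stand-in leaf, rather than Spark's actual class.

```scala
// Hypothetical stand-ins; not Spark's AggregateExpression.
sealed trait Expr
case class Col(name: String) extends Expr

case class AggExpr(fn: Expr, filter: Option[Expr]) extends Expr {
  // The optional filter only contributes a child when it is defined.
  def children: IndexedSeq[Expr] = IndexedSeq(fn) ++ filter.toList

  // Map the flat children back onto the fields, with one branch per shape.
  def withNewChildren(newChildren: IndexedSeq[Expr]): AggExpr = {
    require(newChildren.size == children.size, "Incorrect number of children")
    if (filter.isDefined) copy(fn = newChildren(0), filter = Some(newChildren(1)))
    else copy(fn = newChildren(0))
  }
}
```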
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 64ea579..2885191 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -105,6 +105,9 @@ case class UnaryMinus(
       case funcName => s"$funcName(${child.sql})"
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnaryMinus =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -131,6 +134,9 @@ case class UnaryPositive(child: Expression)
   protected override def nullSafeEval(input: Any): Any = input
 
   override def sql: String = s"(+ ${child.sql})"
+
+  override protected def withNewChildInternal(newChild: Expression): UnaryPositive =
+    copy(child = newChild)
 }
 
 /**
@@ -183,6 +189,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled
   }
 
   protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
+
+  override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild)
 }
 
 abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
@@ -309,6 +317,9 @@ case class Add(
   }
 
   override def exactMathMethod: Option[String] = Some("addExact")
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Add =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -352,6 +363,9 @@ case class Subtract(
   }
 
   override def exactMathMethod: Option[String] = Some("subtractExact")
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Subtract = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -380,6 +394,9 @@ case class Multiply(
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
 
   override def exactMathMethod: Option[String] = Some("multiplyExact")
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight)
 }
 
 // Common base trait for Divide and Remainder, since these two classes are almost identical
@@ -506,6 +523,9 @@ case class Divide(
   }
 
   override def evalOperation(left: Any, right: Any): Any = div(left, right)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Divide = copy(left = newLeft, right = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -553,6 +573,10 @@ case class IntegralDivide(
   }
 
   override def evalOperation(left: Any, right: Any): Any = div(left, right)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): IntegralDivide =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -607,6 +631,9 @@ case class Remainder(
   }
 
   override def evalOperation(left: Any, right: Any): Any = mod(left, right)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Remainder = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -791,6 +818,9 @@ case class Pmod(
   }
 
   override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Pmod =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -866,6 +896,9 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression
          |$codes
       """.stripMargin)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Least =
+    copy(children = newChildren)
 }
 
 /**
@@ -941,4 +974,7 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress
          |$codes
       """.stripMargin)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Greatest =
+    copy(children = newChildren)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index a1fb68e..3940c65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -56,6 +56,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
   }
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -92,6 +95,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
   }
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -128,6 +134,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
   }
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -169,6 +178,9 @@ case class BitwiseNot(child: Expression)
   protected override def nullSafeEval(input: Any): Any = not(input)
 
   override def sql: String = s"~${child.sql}"
+
+  override protected def withNewChildInternal(newChild: Expression): BitwiseNot =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -204,6 +216,9 @@ case class BitwiseCount(child: Expression)
     case IntegerType => java.lang.Long.bitCount(input.asInstanceOf[Int])
     case LongType => java.lang.Long.bitCount(input.asInstanceOf[Long])
   }
+
+  override protected def withNewChildInternal(newChild: Expression): BitwiseCount =
+    copy(child = newChild)
 }
 
 object BitwiseGetUtil {
@@ -262,4 +277,7 @@ case class BitwiseGet(left: Expression, right: Expression)
 
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bit_get")
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): BitwiseGet = copy(left = newLeft, right = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 689858d..c840cdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -22,7 +22,7 @@ import java.lang.{Boolean => JBool}
 import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{LeafLike, TreeNode}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.{BooleanType, DataType}
 
@@ -298,11 +298,13 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends
     }
     buf.toString
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Block]): Block =
+    super.legacyWithNewChildren(newChildren)
 }
 
-case object EmptyBlock extends Block with Serializable {
+case object EmptyBlock extends Block with Serializable with LeafLike[Block] {
   override val code: String = ""
-  override def children: Seq[Block] = Seq.empty
 }
 
 /**
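The `EmptyBlock` change replaces a hand-written `children = Seq.empty` with the `LeafLike` mixin, so leaf nodes inherit their child handling from one shared trait. Below is a minimal sketch of that idea. It assumes hypothetical `CodeNode`, `LeafNode`, `EmptyNode`, `Line` and `Concat` types, which are simplified stand-ins rather than Spark's `Block` or `LeafLike`.

```scala
// Hypothetical, simplified stand-ins; not Spark's Block / LeafLike.
abstract class CodeNode {
  def code: String
  def children: Seq[CodeNode]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[CodeNode]): CodeNode

  def mapChildren(f: CodeNode => CodeNode): CodeNode = {
    val mapped = children.map(f).toIndexedSeq
    // Only rebuild the node if some child was actually replaced.
    if (children.corresponds(mapped)(_ eq _)) this else withNewChildrenInternal(mapped)
  }
}

// Shared leaf behaviour: no children, and mapping over them is a no-op.
trait LeafNode extends CodeNode {
  override final def children: Seq[CodeNode] = Nil
  override final def mapChildren(f: CodeNode => CodeNode): CodeNode = this
  override protected final def withNewChildrenInternal(
      newChildren: IndexedSeq[CodeNode]): CodeNode = this
}

case object EmptyNode extends CodeNode with LeafNode {
  override val code: String = ""
}

case class Line(code: String) extends CodeNode with LeafNode

case class Concat(parts: Seq[CodeNode]) extends CodeNode {
  def code: String = parts.map(_.code).mkString
  def children: Seq[CodeNode] = parts
  protected def withNewChildrenInternal(newChildren: IndexedSeq[CodeNode]): CodeNode =
    copy(parts = newChildren)
}
```

Marking the leaf overrides `final` is a design choice of this sketch. It prevents a leaf subclass from accidentally reintroducing children.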
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index d3fad8c..125e796 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -125,6 +125,8 @@ case class Size(child: Expression, legacySizeOfNull: Boolean)
       defineCodeGen(ctx, ev, c => s"($c).numElements()")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Size = copy(child = newChild)
 }
 
 object Size {
@@ -159,6 +161,9 @@ case class MapKeys(child: Expression)
   }
 
   override def prettyName: String = "map_keys"
+
+  override protected def withNewChildInternal(newChild: Expression): MapKeys =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -321,6 +326,9 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
   }
 
   override def prettyName: String = "arrays_zip"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ArraysZip =
+    copy(children = newChildren)
 }
 
 /**
@@ -351,6 +359,9 @@ case class MapValues(child: Expression)
   }
 
   override def prettyName: String = "map_values"
+
+  override protected def withNewChildInternal(newChild: Expression): MapValues =
+    copy(child = newChild)
 }
 
 /**
@@ -523,6 +534,8 @@ case class MapEntries(child: Expression)
   }
 
   override def prettyName: String = "map_entries"
+
+  override def withNewChildInternal(newChild: Expression): MapEntries = copy(child = newChild)
 }
 
 /**
@@ -642,6 +655,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
   }
 
   override def prettyName: String = "map_concat"
+
+  override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): MapConcat =
+    copy(children = newChildren)
 }
 
 /**
@@ -720,6 +736,9 @@ case class MapFromEntries(child: Expression) extends UnaryExpression with NullIn
   }
 
   override def prettyName: String = "map_from_entries"
+
+  override protected def withNewChildInternal(newChild: Expression): MapFromEntries =
+    copy(child = newChild)
 }
 
 
@@ -919,6 +938,10 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   }
 
   override def prettyName: String = "sort_array"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): SortArray =
+    copy(base = newLeft, ascendingOrder = newRight)
 }
 
 /**
@@ -1007,6 +1030,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
   }
 
   override def freshCopy(): Shuffle = Shuffle(child, randomSeed)
+
+  override def withNewChildInternal(newChild: Expression): Shuffle = copy(child = newChild)
 }
 
 /**
@@ -1083,6 +1108,9 @@ case class Reverse(child: Expression)
   }
 
   override def prettyName: String = "reverse"
+
+  override protected def withNewChildInternal(newChild: Expression): Reverse =
+    copy(child = newChild)
 }
 
 /**
@@ -1180,6 +1208,10 @@ case class ArrayContains(left: Expression, right: Expression)
   }
 
   override def prettyName: String = "array_contains"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayContains =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1403,6 +1435,10 @@ case class ArraysOverlap(left: Expression, right: Expression)
   }
 
   override def prettyName: String = "arrays_overlap"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArraysOverlap =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1516,6 +1552,10 @@ case class Slice(x: Expression, start: Expression, length: Expression)
        |}
      """.stripMargin
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Slice =
+    copy(x = newFirst, start = newSecond, length = newThird)
 }
 
 /**
@@ -1559,6 +1599,16 @@ case class ArrayJoin(
     Seq(array, delimiter)
   }
 
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    if (nullReplacement.isDefined) {
+      copy(
+        array = newChildren(0),
+        delimiter = newChildren(1),
+        nullReplacement = Some(newChildren(2)))
+    } else {
+      copy(array = newChildren(0), delimiter = newChildren(1))
+    }
+
   override def nullable: Boolean = children.exists(_.nullable)
 
   override def foldable: Boolean = children.forall(_.foldable)
@@ -1756,6 +1806,9 @@ case class ArrayMin(child: Expression)
   }
 
   override def prettyName: String = "array_min"
+
+  override protected def withNewChildInternal(newChild: Expression): ArrayMin =
+    copy(child = newChild)
 }
 
 /**
@@ -1824,6 +1877,9 @@ case class ArrayMax(child: Expression)
   }
 
   override def prettyName: String = "array_max"
+
+  override protected def withNewChildInternal(newChild: Expression): ArrayMax =
+    copy(child = newChild)
 }
 
 
@@ -1903,6 +1959,10 @@ case class ArrayPosition(left: Expression, right: Expression)
        """.stripMargin
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayPosition =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -2085,6 +2145,9 @@ case class ElementAt(
   }
 
   override def prettyName: String = "element_at"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -2291,6 +2354,9 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
   override def toString: String = s"concat(${children.mkString(", ")})"
 
   override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Concat =
+    copy(children = newChildren)
 }
 
 /**
@@ -2403,6 +2469,9 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
   }
 
   override def prettyName: String = "flatten"
+
+  override protected def withNewChildInternal(newChild: Expression): Flatten =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -2460,6 +2529,15 @@ case class Sequence(
 
   override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt
 
+  override def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): TimeZoneAwareExpression = {
+    if (stepOpt.isDefined) {
+      copy(start = newChildren(0), stop = newChildren(1), stepOpt = Some(newChildren(2)))
+    } else {
+      copy(start = newChildren(0), stop = newChildren(1))
+    }
+  }
+
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def nullable: Boolean = children.exists(_.nullable)
@@ -2949,6 +3027,8 @@ case class ArrayRepeat(left: Expression, right: Expression)
      """.stripMargin
   }
 
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ArrayRepeat = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -3063,6 +3143,9 @@ case class ArrayRemove(left: Expression, right: Expression)
   }
 
   override def prettyName: String = "array_remove"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ArrayRemove = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -3295,6 +3378,9 @@ case class ArrayDistinct(child: Expression)
   }
 
   override def prettyName: String = "array_distinct"
+
+  override protected def withNewChildInternal(newChild: Expression): ArrayDistinct =
+    copy(child = newChild)
 }
 
 /**
@@ -3497,6 +3583,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
   }
 
   override def prettyName: String = "array_union"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ArrayUnion = copy(left = newLeft, right = newRight)
 }
 
 object ArrayUnion {
@@ -3780,6 +3869,10 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina
   }
 
   override def prettyName: String = "array_intersect"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayIntersect =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -4004,4 +4097,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
   }
 
   override def prettyName: String = "array_except"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ArrayExcept = copy(left = newLeft, right = newRight)
 }
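Two non-fixed shapes appear in `collectionOperations.scala`. Variadic nodes such as `ArraysZip`, `MapConcat` and `Concat` can simply `copy(children = newChildren)`, because an `IndexedSeq` is a `Seq`. Nodes with an optional trailing child, such as `ArrayJoin` (`nullReplacement`) and `Sequence` (`stepOpt`), rebuild positionally and use the optional field's presence to decide whether slot 2 exists. The toy sketch below shows both. `ConcatLike`, `JoinLike` and `Ref` are hypothetical names, not Spark classes.

```scala
// Hypothetical toy nodes; they mirror the shapes in the diff, not Spark's code.
sealed trait Expr
case class Ref(name: String) extends Expr

// Variadic: the children field *is* the child list, so copy takes it directly.
case class ConcatLike(children: Seq[Expr]) extends Expr {
  def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): ConcatLike =
    copy(children = newChildren)
}

// Optional trailing child: it shows up in `children` only when defined.
case class JoinLike(array: Expr, delimiter: Expr, nullReplacement: Option[Expr]) extends Expr {
  def children: Seq[Expr] = Seq(array, delimiter) ++ nullReplacement

  // Index 2 exists only if nullReplacement was defined when `children` was built.
  def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): JoinLike =
    if (nullReplacement.isDefined) {
      copy(
        array = newChildren(0),
        delimiter = newChildren(1),
        nullReplacement = Some(newChildren(2)))
    } else {
      copy(array = newChildren(0), delimiter = newChildren(1))
    }
}
```

The positional rebuild is only safe because the new sequence has the same layout as `children`. That is why the generic entry point checks the child count before dispatching.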
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 3c016a7..f1456c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -102,6 +102,9 @@ case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolea
   }
 
   override def prettyName: String = "array"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CreateArray =
+    copy(children = newChildren)
 }
 
 object CreateArray {
@@ -254,6 +257,9 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
   }
 
   override def prettyName: String = "map"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CreateMap =
+    copy(children = newChildren)
 }
 
 object CreateMap {
@@ -314,6 +320,10 @@ case class MapFromArrays(left: Expression, right: Expression)
   }
 
   override def prettyName: String = "map_from_arrays"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): MapFromArrays =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -493,6 +503,9 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with
     val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ")
     s"$alias($childrenSQL)"
   }.getOrElse(super.sql)
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): CreateNamedStruct = copy(children = newChildren)
 }
 
 /**
@@ -576,6 +589,13 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
   }
 
   override def prettyName: String = "str_to_map"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(
+    text = newFirst,
+    pairDelim = newSecond,
+    keyValueDelim = newThird
+  )
 }
 
 /**
@@ -627,6 +647,9 @@ case class WithField(name: String, valExpr: Expression)
     "WithField.nullable should not be called.")
 
   override def prettyName: String = "WithField"
+
+  override protected def withNewChildInternal(newChild: Expression): WithField =
+    copy(valExpr = newChild)
 }
 
 /**
@@ -659,6 +682,9 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat
     case e: Expression => e
   }
 
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    super.legacyWithNewChildren(newChildren)
+
   override def dataType: StructType = StructType(newFields)
 
   override def nullable: Boolean = structExpr.nullable
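Three-child nodes in this diff (`StringToMap`, and elsewhere `Slice`, `If`, `MakeDate`, `MonthsBetween`) override a three-argument `withNewChildrenInternal(newFirst, newSecond, newThird)` and map each slot back to a named constructor field. A minimal toy sketch of the adapter trait, assuming `first`/`second`/`third` accessors that each node points at its own fields. `TernaryLike` here is a hypothetical stand-in, not Spark's trait.

```scala
// Hypothetical toy hierarchy; not Spark's TreeNode/TernaryLike.
abstract class Expr extends Product {
  def children: Seq[Expr]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr

  final def withNewChildren(newChildren: Seq[Expr]): Expr = {
    require(newChildren.length == children.length, "child count must not change")
    withNewChildrenInternal(newChildren.toIndexedSeq)
  }
}

trait TernaryLike extends Expr {
  def first: Expr
  def second: Expr
  def third: Expr
  final override def children: Seq[Expr] = Seq(first, second, third)

  protected def withNewChildrenInternal(newFirst: Expr, newSecond: Expr, newThird: Expr): Expr

  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr =
    withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2))
}

case class Lit(value: Int) extends Expr {
  override def children: Seq[Expr] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr = this
}

// Like `If` in the diff: fields keep their own names and are mapped onto the slots.
case class IfLike(predicate: Expr, trueValue: Expr, falseValue: Expr) extends TernaryLike {
  override def first: Expr = predicate
  override def second: Expr = trueValue
  override def third: Expr = falseValue

  override protected def withNewChildrenInternal(
      newFirst: Expr, newSecond: Expr, newThird: Expr): IfLike =
    copy(predicate = newFirst, trueValue = newSecond, falseValue = newThird)
}
```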
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 139d9a5..f64cc8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -138,6 +138,9 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
       }
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): GetStructField =
+    copy(child = newChild)
 }
 
 /**
@@ -212,6 +215,9 @@ case class GetArrayStructFields(
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): GetArrayStructFields =
+    copy(child = newChild)
 }
 
 /**
@@ -292,6 +298,10 @@ case class GetArrayItem(
       """
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): GetArrayItem =
+    copy(child = newLeft, ordinal = newRight)
 }
 
 /**
@@ -470,4 +480,8 @@ case class GetMapValue(
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType], failOnError)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): GetMapValue =
+    copy(child = newLeft, key = newRight)
 }
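With a direct `copy` available on every node, `mapChildren` no longer needs the product iterator. One further saving, assumed here and not shown in this excerpt, is to skip the copy entirely when every mapped child is the same object as before (reference equality), which is the common case in rule traversals that change nothing. The toy sketch below also uses field names that differ from `left`/`right`, as `GetArrayItem(child, ordinal)` does above. All names are hypothetical.

```scala
// Hypothetical toy tree; illustrates a reference-equality fast path, not Spark's code.
abstract class Expr extends Product {
  def children: Seq[Expr]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr

  final def withNewChildren(newChildren: Seq[Expr]): Expr = {
    require(newChildren.length == children.length, "child count must not change")
    // If every new child is the very same object, reuse this node instead of copying.
    val unchanged = children.zip(newChildren).forall { case (a, b) => a eq b }
    if (unchanged) this else withNewChildrenInternal(newChildren.toIndexedSeq)
  }

  // No reflection and no product iteration: map the children, then a direct copy.
  final def mapChildren(f: Expr => Expr): Expr =
    if (children.isEmpty) this else withNewChildren(children.map(f))
}

case class Lit(value: Int) extends Expr {
  override def children: Seq[Expr] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr = this
}

// Field names differ from left/right, as with GetArrayItem(child, ordinal) in the diff.
case class Item(child: Expr, ordinal: Expr) extends Expr {
  override def children: Seq[Expr] = Seq(child, ordinal)
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Item =
    copy(child = newChildren(0), ordinal = newChildren(1))
}
```

With this design an identity traversal allocates no new node. A real change still goes through the typed `copy`.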
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index a062dd4..e708d56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -95,6 +95,13 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def toString: String = s"if ($predicate) $trueValue else $falseValue"
 
   override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(
+    predicate = newFirst,
+    trueValue = newSecond,
+    falseValue = newThird
+  )
 }
 
 /**
@@ -132,6 +139,9 @@ case class CaseWhen(
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    super.legacyWithNewChildren(newChildren)
+
   // both then and else expressions should be considered.
   @transient
   override lazy val inputTypesForMerging: Seq[DataType] = {
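`CaseWhen` (like `UpdateFields` and `CodeBlock`) keeps the reflection-based fallback via `legacyWithNewChildren` in this diff. A plausible reason is that its `children` is a flattened view (`cond1, value1, cond2, value2, ..., [else]`) of `branches` plus an optional `elseValue`, not a list of constructor fields. For illustration only, here is a hand-written rebuild of that flattened layout on a hypothetical toy type. It is not something this commit implements.

```scala
// Hypothetical toy types; CaseWhenLike only mimics CaseWhen's child layout.
sealed trait Expr
case class Ref(name: String) extends Expr

case class CaseWhenLike(branches: Seq[(Expr, Expr)], elseValue: Option[Expr]) extends Expr {
  // Same flattening as CaseWhen.children in the diff.
  def children: Seq[Expr] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

  def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): CaseWhenLike = {
    require(newChildren.length == children.length, "child count must not change")
    // The else value, if present, occupies the last slot; the rest pair up in order.
    val branchSlots = if (elseValue.isDefined) newChildren.init else newChildren
    val newBranches = branchSlots.grouped(2).map(pair => (pair(0), pair(1))).toSeq
    val newElse = if (elseValue.isDefined) Some(newChildren.last) else None
    copy(branches = newBranches, elseValue = newElse)
  }
}
```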
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
index 5bfae7b..8feaf52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
@@ -36,6 +36,12 @@ case class KnownNotNull(child: Expression) extends TaggingExpression {
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.genCode(ctx).copy(isNull = FalseLiteral)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): KnownNotNull =
+    copy(child = newChild)
 }
 
-case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression
+case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression {
+  override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized =
+    copy(child = newChild)
+}
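`KnownFloatingPointNormalized` previously had no class body at all. It now needs one, presumably because the unary parent trait leaves `withNewChildInternal` abstract and every concrete unary node must supply its own one-line `copy`. The same override appears on `KnownNotNull`, `Size` and the CSV and datetime extractors. A minimal toy sketch of that unary adapter, with hypothetical names:

```scala
// Hypothetical toy hierarchy; not Spark's TreeNode/UnaryLike.
abstract class Expr extends Product {
  def children: Seq[Expr]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr

  final def withNewChildren(newChildren: Seq[Expr]): Expr = {
    require(newChildren.length == children.length, "child count must not change")
    withNewChildrenInternal(newChildren.toIndexedSeq)
  }
}

trait UnaryLike extends Expr {
  def child: Expr
  final override def children: Seq[Expr] = child :: Nil

  // Abstract: every concrete unary node must provide this, as in the diff.
  protected def withNewChildInternal(newChild: Expr): Expr

  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr =
    withNewChildInternal(newChildren.head)
}

case class Col(name: String) extends Expr {
  override def children: Seq[Expr] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Expr = this
}

// Analogue of KnownFloatingPointNormalized: a tagging wrapper around one child.
case class Normalized(child: Expr) extends UnaryLike {
  override protected def withNewChildInternal(newChild: Expr): Normalized =
    copy(child = newChild)
}
```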
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index ac47020..79bbc10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -140,6 +140,9 @@ case class CsvToStructs(
   override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
 
   override def prettyName: String = "from_csv"
+
+  override protected def withNewChildInternal(newChild: Expression): CsvToStructs =
+    copy(child = newChild)
 }
 
 /**
@@ -197,6 +200,9 @@ case class SchemaOfCsv(
   }
 
   override def prettyName: String = "schema_of_csv"
+
+  override protected def withNewChildInternal(newChild: Expression): SchemaOfCsv =
+    copy(child = newChild)
 }
 
 /**
@@ -264,4 +270,7 @@ case class StructsToCsv(
   override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
 
   override def prettyName: String = "to_csv"
+
+  override protected def withNewChildInternal(newChild: Expression): StructsToCsv =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 355064e..ba9d458 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -251,6 +251,9 @@ case class DateAdd(startDate: Expression, days: Expression)
   }
 
   override def prettyName: String = "date_add"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight)
 }
 
 /**
@@ -286,6 +289,9 @@ case class DateSub(startDate: Expression, days: Expression)
   }
 
   override def prettyName: String = "date_sub"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): DateSub = copy(startDate = newLeft, days = newRight)
 }
 
 trait GetTimeField extends UnaryExpression
@@ -323,6 +329,7 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) extends Ge
   override def withTimeZone(timeZoneId: String): Hour = copy(timeZoneId = Option(timeZoneId))
   override val func = DateTimeUtils.getHours
   override val funcName = "getHours"
+  override protected def withNewChildInternal(newChild: Expression): Hour = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -339,6 +346,7 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) extends
   override def withTimeZone(timeZoneId: String): Minute = copy(timeZoneId = Option(timeZoneId))
   override val func = DateTimeUtils.getMinutes
   override val funcName = "getMinutes"
+  override protected def withNewChildInternal(newChild: Expression): Minute = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -355,6 +363,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) extends
   override def withTimeZone(timeZoneId: String): Second = copy(timeZoneId = Option(timeZoneId))
   override val func = DateTimeUtils.getSeconds
   override val funcName = "getSeconds"
+  override protected def withNewChildInternal(newChild: Expression): Second =
+    copy(child = newChild)
 }
 
 case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None)
@@ -366,6 +376,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No
     copy(timeZoneId = Option(timeZoneId))
   override val func = DateTimeUtils.getSecondsWithFraction
   override val funcName = "getSecondsWithFraction"
+  override protected def withNewChildInternal(newChild: Expression): SecondWithFraction =
+    copy(child = newChild)
 }
 
 trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
@@ -398,6 +410,8 @@ trait GetDateField extends UnaryExpression with ImplicitCastInputTypes with Null
 case class DayOfYear(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getDayInYear
   override val funcName = "getDayInYear"
+  override protected def withNewChildInternal(newChild: Expression): DayOfYear =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -421,6 +435,9 @@ case class DateFromUnixDate(child: Expression) extends UnaryExpression
     defineCodeGen(ctx, ev, c => c)
 
   override def prettyName: String = "date_from_unix_date"
+
+  override protected def withNewChildInternal(newChild: Expression): DateFromUnixDate =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -444,6 +461,9 @@ case class UnixDate(child: Expression) extends UnaryExpression
     defineCodeGen(ctx, ev, c => c)
 
   override def prettyName: String = "unix_date"
+
+  override protected def withNewChildInternal(newChild: Expression): UnixDate =
+    copy(child = newChild)
 }
 
 abstract class IntegralToTimestampBase extends UnaryExpression
@@ -531,6 +551,9 @@ case class SecondsToTimestamp(child: Expression) extends UnaryExpression
   }
 
   override def prettyName: String = "timestamp_seconds"
+
+  override protected def withNewChildInternal(newChild: Expression): SecondsToTimestamp =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -550,6 +573,9 @@ case class MillisToTimestamp(child: Expression)
   override def upScaleFactor: Long = MICROS_PER_MILLIS
 
   override def prettyName: String = "timestamp_millis"
+
+  override protected def withNewChildInternal(newChild: Expression): MillisToTimestamp =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -569,6 +595,9 @@ case class MicrosToTimestamp(child: Expression)
   override def upScaleFactor: Long = 1L
 
   override def prettyName: String = "timestamp_micros"
+
+  override protected def withNewChildInternal(newChild: Expression): MicrosToTimestamp =
+    copy(child = newChild)
 }
 
 abstract class TimestampToLongBase extends UnaryExpression
@@ -608,6 +637,9 @@ case class UnixSeconds(child: Expression) extends TimestampToLongBase {
   override def scaleFactor: Long = MICROS_PER_SECOND
 
   override def prettyName: String = "unix_seconds"
+
+  override protected def withNewChildInternal(newChild: Expression): UnixSeconds =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -625,6 +657,9 @@ case class UnixMillis(child: Expression) extends TimestampToLongBase {
   override def scaleFactor: Long = MICROS_PER_MILLIS
 
   override def prettyName: String = "unix_millis"
+
+  override protected def withNewChildInternal(newChild: Expression): UnixMillis =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -642,6 +677,9 @@ case class UnixMicros(child: Expression) extends TimestampToLongBase {
   override def scaleFactor: Long = 1L
 
   override def prettyName: String = "unix_micros"
+
+  override protected def withNewChildInternal(newChild: Expression): UnixMicros =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -656,11 +694,15 @@ case class UnixMicros(child: Expression) extends TimestampToLongBase {
 case class Year(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getYear
   override val funcName = "getYear"
+  override protected def withNewChildInternal(newChild: Expression): Year =
+    copy(child = newChild)
 }
 
 case class YearOfWeek(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getWeekBasedYear
   override val funcName = "getWeekBasedYear"
+  override protected def withNewChildInternal(newChild: Expression): YearOfWeek =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -675,6 +717,8 @@ case class YearOfWeek(child: Expression) extends GetDateField {
 case class Quarter(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getQuarter
   override val funcName = "getQuarter"
+  override protected def withNewChildInternal(newChild: Expression): Quarter =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -689,6 +733,7 @@ case class Quarter(child: Expression) extends GetDateField {
 case class Month(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getMonth
   override val funcName = "getMonth"
+  override protected def withNewChildInternal(newChild: Expression): Month = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -703,6 +748,8 @@ case class Month(child: Expression) extends GetDateField {
 case class DayOfMonth(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getDayOfMonth
   override val funcName = "getDayOfMonth"
+  override protected def withNewChildInternal(newChild: Expression): DayOfMonth =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -719,6 +766,8 @@ case class DayOfMonth(child: Expression) extends GetDateField {
 case class DayOfWeek(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getDayOfWeek
   override val funcName = "getDayOfWeek"
+  override protected def withNewChildInternal(newChild: Expression): DayOfWeek =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -735,6 +784,8 @@ case class DayOfWeek(child: Expression) extends GetDateField {
 case class WeekDay(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getWeekDay
   override val funcName = "getWeekDay"
+  override protected def withNewChildInternal(newChild: Expression): WeekDay =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -751,6 +802,8 @@ case class WeekDay(child: Expression) extends GetDateField {
 case class WeekOfYear(child: Expression) extends GetDateField {
   override val func = DateTimeUtils.getWeekOfYear
   override val funcName = "getWeekOfYear"
+  override protected def withNewChildInternal(newChild: Expression): WeekOfYear =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -814,6 +867,10 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti
   override protected def formatString: Expression = right
 
   override protected def isParsing: Boolean = false
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DateFormatClass =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -859,6 +916,10 @@ case class ToUnixTimestamp(
   }
 
   override def prettyName: String = "to_unix_timestamp"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ToUnixTimestamp =
+    copy(timeExp = newLeft, format = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -915,6 +976,10 @@ case class UnixTimestamp(
   }
 
   override def prettyName: String = "unix_timestamp"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): UnixTimestamp =
+    copy(timeExp = newLeft, format = newRight)
 }
 
 abstract class ToTimestamp
@@ -1120,6 +1185,10 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
   override protected def formatString: Expression = format
 
   override protected def isParsing: Boolean = false
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): FromUnixTime =
+    copy(sec = newLeft, format = newRight)
 }
 
 /**
@@ -1152,6 +1221,9 @@ case class LastDay(startDate: Expression)
   }
 
   override def prettyName: String = "last_day"
+
+  override protected def withNewChildInternal(newChild: Expression): LastDay =
+    copy(startDate = newChild)
 }
 
 /**
@@ -1249,6 +1321,10 @@ case class NextDay(
   }
 
   override def prettyName: String = "next_day"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): NextDay =
+    copy(startDate = newLeft, dayOfWeek = newRight)
 }
 
 /**
@@ -1292,6 +1368,10 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
         })
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): TimeAdd =
+    copy(start = newLeft, interval = newRight)
 }
 
 /**
@@ -1305,6 +1385,8 @@ case class DatetimeSub(
   override def exprsReplaced: Seq[Expression] = Seq(start, interval)
   override def toString: String = s"$start - $interval"
   override def mkString(childrenString: Seq[String]): String = childrenString.mkString(" - ")
+  override protected def withNewChildInternal(newChild: Expression): DatetimeSub =
+    copy(child = newChild)
 }
 
 /**
@@ -1367,6 +1449,10 @@ case class DateAddInterval(
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DateAddInterval =
+    copy(start = newLeft, interval = newRight)
 }
 
 sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
@@ -1447,6 +1533,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) extends UTCTime
   override val func = DateTimeUtils.fromUTCTime
   override val funcName: String = "fromUTCTime"
   override val prettyName: String = "from_utc_timestamp"
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): FromUTCTimestamp =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1478,6 +1567,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) extends UTCTimest
   override val func = DateTimeUtils.toUTCTime
   override val funcName: String = "toUTCTime"
   override val prettyName: String = "to_utc_timestamp"
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ToUTCTimestamp =
+    copy(left = newLeft, right = newRight)
 }
 
 abstract class AddMonthsBase extends BinaryExpression with ImplicitCastInputTypes
@@ -1517,6 +1609,10 @@ case class AddMonths(startDate: Expression, numMonths: Expression) extends AddMo
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
 
   override def prettyName: String = "add_months"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): AddMonths =
+    copy(startDate = newLeft, numMonths = newRight)
 }
 
 // Adds the year-month interval to the date
@@ -1528,6 +1624,10 @@ case class DateAddYMInterval(date: Expression, interval: Expression) extends Add
 
   override def toString: String = s"$left + $right"
   override def sql: String = s"${left.sql} + ${right.sql}"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DateAddYMInterval =
+    copy(date = newLeft, interval = newRight)
 }
 
 // Adds the year-month interval to the timestamp
@@ -1562,6 +1662,10 @@ case class TimestampAddYMInterval(
       s"""$dtu.timestampAddMonths($micros, $months, $zid)"""
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): TimestampAddYMInterval =
+    copy(timestamp = newLeft, interval = newRight)
 }
 
 /**
@@ -1628,6 +1732,10 @@ case class MonthsBetween(
   }
 
   override def prettyName: String = "months_between"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): MonthsBetween =
+    copy(date1 = newFirst, date2 = newSecond, roundOff = newThird)
 }
 
 /**
@@ -1672,6 +1780,9 @@ case class ParseToDate(left: Expression, format: Option[Expression], child: Expr
   override def flatArguments: Iterator[Any] = Iterator(left, format)
 
   override def prettyName: String = "to_date"
+
+  override protected def withNewChildInternal(newChild: Expression): ParseToDate =
+    copy(child = newChild)
 }
 
 /**
@@ -1714,6 +1825,9 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
 
   override def prettyName: String = "to_timestamp"
   override def dataType: DataType = TimestampType
+
+  override protected def withNewChildInternal(newChild: Expression): ParseToTimestamp =
+    copy(child = newChild)
 }
 
 trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
@@ -1849,6 +1963,10 @@ case class TruncDate(date: Expression, format: Expression)
       (date: String, fmt: String) => s"truncDate($date, $fmt);"
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): TruncDate =
+    copy(date = newLeft, format = newRight)
 }
 
 /**
@@ -1920,6 +2038,10 @@ case class TruncTimestamp(
         s"truncTimestamp($date, $fmt, $zid);"
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): TruncTimestamp =
+    copy(format = newLeft, timestamp = newRight)
 }
 
 /**
@@ -1952,6 +2074,10 @@ case class DateDiff(endDate: Expression, startDate: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (end, start) => s"$end - $start")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DateDiff =
+    copy(endDate = newLeft, startDate = newRight)
 }
 
 /**
@@ -1969,6 +2095,10 @@ private case class GetTimestamp(
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): GetTimestamp =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -2032,6 +2162,10 @@ case class MakeDate(
   }
 
   override def prettyName: String = "make_date"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): MakeDate =
+    copy(year = newFirst, month = newSecond, day = newThird)
 }
 
 // scalastyle:off line.size.limit
@@ -2198,6 +2332,20 @@ case class MakeTimestamp(
   }
 
   override def prettyName: String = "make_timestamp"
+
+//  override def children: Seq[Expression] = Seq(year, month, day, hour, min, sec) ++ timezone
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): MakeTimestamp = {
+    val timezoneOpt = if (timezone.isDefined) Some(newChildren(6)) else None
+    copy(
+      year = newChildren(0),
+      month = newChildren(1),
+      day = newChildren(2),
+      hour = newChildren(3),
+      min = newChildren(4),
+      sec = newChildren(5),
+      timezone = timezoneOpt)
+  }
 }
 
 object DatePart {
@@ -2284,6 +2432,9 @@ case class DatePart(field: Expression, source: Expression, child: Expression)
   override def exprsReplaced: Seq[Expression] = Seq(field, source)
 
   override def prettyName: String = "date_part"
+
+  override protected def withNewChildInternal(newChild: Expression): DatePart =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -2349,6 +2500,9 @@ case class Extract(field: Expression, source: Expression, child: Expression)
   override def mkString(childrenString: Seq[String]): String = {
     prettyName + childrenString.mkString("(", " FROM ", ")")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Extract =
+    copy(child = newChild)
 }
 
 /**
@@ -2401,6 +2555,10 @@ case class SubtractTimestamps(
       defineCodeGen(ctx, ev, (end, start) =>
         s"new org.apache.spark.unsafe.types.CalendarInterval(0, 0, $end - $start)")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): SubtractTimestamps =
+    copy(left = newLeft, right = newRight)
 }
 
 object SubtractTimestamps {
@@ -2452,6 +2610,10 @@ case class SubtractDates(
         s"$dtu.subtractDates($leftDays, $rightDays)"
       })
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): SubtractDates =
+    copy(left = newLeft, right = newRight)
 }
 
 object SubtractDates {
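The `MakeTimestamp` hunk has to handle an optional seventh child. Its `children` are `Seq(year, month, day, hour, min, sec) ++ timezone`, so index 6 only exists when `timezone` is defined. Below is a minimal sketch of that positional rebuild. The `Expr`, `Lit` and `MakeTs` types are hypothetical stand-ins, not Spark's classes.

```scala
// Hypothetical simplified expression types; not Spark's Expression hierarchy.
sealed trait Expr
case class Lit(v: Int) extends Expr

// Modeled on MakeTimestamp: six required children plus an optional seventh.
case class MakeTs(
    year: Expr, month: Expr, day: Expr,
    hour: Expr, min: Expr, sec: Expr,
    timezone: Option[Expr]) extends Expr {

  // Same order as the commented-out `children` line in the diff.
  def children: IndexedSeq[Expr] =
    IndexedSeq(year, month, day, hour, min, sec) ++ timezone

  // Rebuild positionally; index 6 is only read when a timezone was present.
  def withNewChildren(newChildren: IndexedSeq[Expr]): MakeTs = {
    require(newChildren.size == children.size, "Incorrect number of children")
    copy(
      year = newChildren(0),
      month = newChildren(1),
      day = newChildren(2),
      hour = newChildren(3),
      min = newChildren(4),
      sec = newChildren(5),
      timezone = if (timezone.isDefined) Some(newChildren(6)) else None)
  }
}

val withTz = MakeTs(Lit(2021), Lit(4), Lit(9), Lit(15), Lit(6), Lit(26), Some(Lit(2)))
val noTz = withTz.copy(timezone = None)

// Increment every literal child; the Option shape of `timezone` is preserved either way.
val bump: Expr => Expr = {
  case Lit(v) => Lit(v + 1)
  case other => other
}
val a = withTz.withNewChildren(withTz.children.map(bump))
val b = noTz.withNewChildren(noTz.children.map(bump))
```

If `newChildren(6)` were read unconditionally, the six-child case would throw an `IndexOutOfBoundsException`. That is why the guard on `timezone.isDefined` matters.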
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index b987bed..7165bca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -40,6 +40,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression with NullInt
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnscaledValue =
+    copy(child = newChild)
 }
 
 /**
@@ -89,6 +92,9 @@ case class MakeDecimal(
          |""".stripMargin
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): MakeDecimal =
+    copy(child = newChild)
 }
 
 object MakeDecimal {
@@ -111,6 +117,9 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
   override def prettyName: String = "promote_precision"
   override def sql: String = child.sql
   override lazy val canonicalized: Expression = child.canonicalized
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
 }
 
 /**
@@ -145,6 +154,9 @@ case class CheckOverflow(
   override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"
 
   override def sql: String = child.sql
+
+  override protected def withNewChildInternal(newChild: Expression): CheckOverflow =
+    copy(child = newChild)
 }
 
 // A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
@@ -194,4 +206,7 @@ case class CheckOverflowInSum(
   override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
 
   override def sql: String = child.sql
+
+  override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
+    copy(child = newChild)
 }
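The decimal hunks only add one `withNewChildInternal` override per unary node. The commit message says the goal is to stop `mapChildren` from walking every field through the generic product iterator. Below is a simplified sketch of how a unary base trait can use such a per-class hook. It assumes hypothetical `Node`, `UnaryNode`, `Lit`, `Neg` and `Abs` types and is not Spark's actual `UnaryLike` code.

```scala
// Hypothetical, simplified model of the unary pattern; not Spark's TreeNode or UnaryLike.
abstract class Node {
  def children: Seq[Node]
  def mapChildren(f: Node => Node): Node
}

// Leaves have nothing to map.
case class Lit(v: Int) extends Node {
  def children: Seq[Node] = Nil
  def mapChildren(f: Node => Node): Node = this
}

trait UnaryNode extends Node {
  def child: Node
  final def children: Seq[Node] = child :: Nil

  // The only thing each concrete class supplies: how to rebuild itself with a new child.
  protected def withNewChildInternal(newChild: Node): Node

  // Apply f straight to the single child: no product iteration, no reflection.
  // If f returns the very same instance, reuse this node instead of copying it.
  final def mapChildren(f: Node => Node): Node = {
    val newChild = f(child)
    if (newChild eq child) this else withNewChildInternal(newChild)
  }
}

case class Neg(child: Node) extends UnaryNode {
  override protected def withNewChildInternal(newChild: Node): Neg = copy(child = newChild)
}

case class Abs(child: Node) extends UnaryNode {
  override protected def withNewChildInternal(newChild: Node): Abs = copy(child = newChild)
}

val tree = Neg(Abs(Lit(-3)))
val unchanged = tree.mapChildren(identity)                                  // same instance
val rewritten = tree.mapChildren { case Abs(c) => c; case other => other } // Neg(Lit(-3))
```

When `f` returns the same child instance, the sketch returns `this`. This avoids allocating a copy for subtrees that a rule leaves untouched.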
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index f10ceea..fef9bb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -118,6 +118,9 @@ case class UserDefinedGenerator(
   }
 
   override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): UserDefinedGenerator = copy(children = newChildren)
 }
 
 /**
@@ -227,6 +230,9 @@ case class Stack(children: Seq[Expression]) extends Generator {
          |$wrapperClass<InternalRow> ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
        """.stripMargin, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Stack =
+    copy(children = newChildren)
 }
 
 /**
@@ -253,6 +259,9 @@ case class ReplicateRows(children: Seq[Expression]) extends Generator with Codeg
       InternalRow(fields: _*)
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): ReplicateRows = copy(children = newChildren)
 }
 
 /**
@@ -269,6 +278,9 @@ case class GeneratorOuter(child: Generator) extends UnaryExpression with Generat
   override def elementSchema: StructType = child.elementSchema
 
   override lazy val resolved: Boolean = false
+
+  override protected def withNewChildInternal(newChild: Expression): GeneratorOuter =
+    copy(child = newChild.asInstanceOf[Generator])
 }
 
 /**
@@ -369,6 +381,8 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with
 // scalastyle:on line.size.limit
 case class Explode(child: Expression) extends ExplodeBase {
   override val position: Boolean = false
+  override protected def withNewChildInternal(newChild: Expression): Explode =
+    copy(child = newChild)
 }
 
 /**
@@ -394,6 +408,8 @@ case class Explode(child: Expression) extends ExplodeBase {
 // scalastyle:on line.size.limit line.contains.tab
 case class PosExplode(child: Expression) extends ExplodeBase {
   override val position = true
+  override protected def withNewChildInternal(newChild: Expression): PosExplode =
+    copy(child = newChild)
 }
 
 /**
@@ -445,4 +461,6 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     child.genCode(ctx)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Inline = copy(child = newChild)
 }
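`GeneratorOuter` is the one generator here whose field is narrower than `Expression` (`child: Generator`). Its override therefore has to cast the incoming child. Below is a minimal sketch of the same situation, using hypothetical `Expr`, `Gen`, `Col`, `Explode` and `Outer` types.

```scala
// Hypothetical simplified types; not Spark's Expression or Generator.
trait Expr
trait Gen extends Expr
case class Col(name: String) extends Expr
case class Explode(child: Expr) extends Gen

// Modeled on GeneratorOuter: the field is typed as the narrower Gen,
// but the rebuild hook only knows it received some Expr.
case class Outer(child: Gen) extends Expr {
  def withNewChildInternal(newChild: Expr): Outer =
    copy(child = newChild.asInstanceOf[Gen]) // checked cast: throws if newChild is not a Gen
}

val o = Outer(Explode(Col("a")))
val ok = o.withNewChildInternal(Explode(Col("b")))
val bad = scala.util.Try(o.withNewChildInternal(Col("c")))
```

The cast to a trait type is checked at runtime. If the hook receives a non-generator, it fails immediately with a `ClassCastException` rather than building a malformed node.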
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index bf28efa..0dd82be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -111,6 +111,8 @@ case class Cube(
     children: Seq[Expression]) extends BaseGroupingSets {
   override def groupingSets: Seq[Seq[Expression]] = groupingSetIndexes.map(_.map(children))
   override def selectedGroupByExprs: Seq[Seq[Expression]] = BaseGroupingSets.cubeExprs(groupingSets)
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Cube =
+    copy(children = newChildren)
 }
 
 object Cube {
@@ -125,6 +127,8 @@ case class Rollup(
   override def groupingSets: Seq[Seq[Expression]] = groupingSetIndexes.map(_.map(children))
   override def selectedGroupByExprs: Seq[Seq[Expression]] =
     BaseGroupingSets.rollupExprs(groupingSets)
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Rollup =
+    copy(children = newChildren)
 }
 
 object Rollup {
@@ -142,6 +146,9 @@ case class GroupingSets(
   // Includes the `userGivenGroupByExprs` in the children, which will be included in the final
   // GROUP BY expressions, so that `SELECT c ... GROUP BY (a, b, c) GROUPING SETS (a, b)` works.
   override def children: Seq[Expression] = flatGroupingSets ++ userGivenGroupByExprs
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): GroupingSets =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[GroupingSets]
 }
 
 object GroupingSets {
@@ -184,6 +191,8 @@ case class Grouping(child: Expression) extends Expression with Unevaluable
     AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
   override def dataType: DataType = ByteType
   override def nullable: Boolean = false
+  override protected def withNewChildInternal(newChild: Expression): Grouping =
+    copy(child = newChild)
 }
 
 /**
@@ -223,6 +232,8 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
   override def dataType: DataType = GroupingID.dataType
   override def nullable: Boolean = false
   override def prettyName: String = "grouping_id"
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): GroupingID =
+    copy(groupByExprs = newChildren)
 }
 
 object GroupingID {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 9738559..f23c1e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -69,6 +69,8 @@ case class Md5(child: Expression)
     defineCodeGen(ctx, ev, c =>
       s"UTF8String.fromString(${classOf[DigestUtils].getName}.md5Hex($c))")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Md5 = copy(child = newChild)
 }
 
 /**
@@ -152,6 +154,9 @@ case class Sha2(left: Expression, right: Expression)
       """
     })
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Sha2 =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -182,6 +187,8 @@ case class Sha1(child: Expression)
       s"UTF8String.fromString(${classOf[DigestUtils].getName}.sha1Hex($c))"
     )
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Sha1 = copy(child = newChild)
 }
 
 /**
@@ -221,6 +228,8 @@ case class Crc32(child: Expression)
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Crc32 = copy(child = newChild)
 }
 
 
@@ -598,6 +607,9 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpress
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
     Murmur3HashFunction.hash(value, dataType, seed).toInt
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash =
+    copy(children = newChildren)
 }
 
 object Murmur3HashFunction extends InterpretedHashFunction {
@@ -638,6 +650,9 @@ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpressio
   override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long = {
     XxHash64Function.hash(value, dataType, seed)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 =
+    copy(children = newChildren)
 }
 
 object XxHash64Function extends InterpretedHashFunction {
@@ -842,6 +857,9 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
        |$code
      """.stripMargin
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): HiveHash =
+    copy(children = newChildren)
 }
 
 object HiveHashFunction extends InterpretedHashFunction {
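In the variable-arity hash expressions, only `children` are tree children; `seed` is an ordinary constructor parameter. The small sketch below shows that `copy(children = newChildren)` carries the other parameters over unchanged. The `Expr`, `Col` and `Hash` types are hypothetical, modeled on `Murmur3Hash`.

```scala
// Hypothetical simplified types; not Spark's HashExpression.
sealed trait Expr
case class Col(name: String) extends Expr

// Modeled on Murmur3Hash(children, seed): only `children` are tree children,
// while `seed` is a plain parameter that copy() leaves as it was.
case class Hash(children: Seq[Expr], seed: Int) extends Expr {
  def withNewChildrenInternal(newChildren: IndexedSeq[Expr]): Hash =
    copy(children = newChildren)
}

val h = Hash(List(Col("a"), Col("b")), seed = 42)
val renamed = h.withNewChildrenInternal(Vector(Col("x"), Col("y")))
```

The rebuilt node stores whatever `IndexedSeq` it was given, which is a `Vector` in this sketch. Any code that later destructures `children` with `::` list patterns would therefore stop matching.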
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index bbfdf71..a0f9dc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -103,6 +103,12 @@ case class LambdaFunction(
   lazy val bound: Boolean = arguments.forall(_.resolved)
 
   override def eval(input: InternalRow): Any = function.eval(input)
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): LambdaFunction =
+    copy(
+      function = newChildren.head,
+      arguments = newChildren.tail.asInstanceOf[Seq[NamedExpression]])
 }
 
 object LambdaFunction {
@@ -219,6 +225,7 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expr
       nullSafeEval(inputRow, value)
     }
   }
+
 }
 
 trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
@@ -289,6 +296,10 @@ case class ArrayTransform(
   }
 
   override def prettyName: String = "transform"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayTransform =
+    copy(argument = newLeft, function = newRight)
 }
 
 /**
@@ -378,6 +389,10 @@ case class ArraySort(
   }
 
   override def prettyName: String = "array_sort"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArraySort =
+    copy(argument = newLeft, function = newRight)
 }
 
 object ArraySort {
@@ -448,6 +463,10 @@ case class MapFilter(
   override def functionType: AbstractDataType = BooleanType
 
   override def prettyName: String = "map_filter"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): MapFilter =
+    copy(argument = newLeft, function = newRight)
 }
 
 /**
@@ -513,6 +532,10 @@ case class ArrayFilter(
   }
 
   override def prettyName: String = "filter"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayFilter =
+    copy(argument = newLeft, function = newRight)
 }
 
 /**
@@ -594,6 +617,10 @@ case class ArrayExists(
   }
 
   override def prettyName: String = "exists"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayExists =
+    copy(argument = newLeft, function = newRight)
 }
 
 object ArrayExists {
@@ -670,6 +697,10 @@ case class ArrayForAll(
   }
 
   override def prettyName: String = "forall"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayForAll =
+    copy(argument = newLeft, function = newRight)
 }
 
 /**
@@ -767,6 +798,10 @@ case class ArrayAggregate(
   override def second: Expression = zero
   override def third: Expression = merge
   override def fourth: Expression = finish
+
+  override protected def withNewChildrenInternal(first: Expression, second: Expression,
+      third: Expression, fourth: Expression): ArrayAggregate =
+    copy(argument = first, zero = second, merge = third, finish = fourth)
 }
 
 /**
@@ -802,7 +837,7 @@ case class TransformKeys(
   }
 
   @transient lazy val LambdaFunction(
-    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
+    _, Seq(keyVar: NamedLambdaVariable, valueVar: NamedLambdaVariable), _) = function
 
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
 
@@ -821,6 +856,10 @@ case class TransformKeys(
   }
 
   override def prettyName: String = "transform_keys"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): TransformKeys =
+    copy(argument = newLeft, function = newRight)
 }
 
 /**
@@ -852,7 +891,7 @@ case class TransformValues(
   }
 
   @transient lazy val LambdaFunction(
-    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
+    _, Seq(keyVar: NamedLambdaVariable, valueVar: NamedLambdaVariable), _) = function
 
   override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
     val map = argumentValue.asInstanceOf[MapData]
@@ -869,6 +908,10 @@ case class TransformValues(
   }
 
   override def prettyName: String = "transform_values"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): TransformValues =
+    copy(argument = newLeft, function = newRight)
 }
 
 /**
@@ -1056,6 +1099,13 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
   override def first: Expression = left
   override def second: Expression = right
   override def third: Expression = function
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): MapZipWith =
+    copy(
+      left = newFirst,
+      right = newSecond,
+      function = newThird)
 }
 
 // scalastyle:off line.size.limit
@@ -1136,4 +1186,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
   override def first: Expression = left
   override def second: Expression = right
   override def third: Expression = function
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): ZipWith =
+    copy(left = newFirst, right = newSecond, function = newThird)
 }
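`LambdaFunction` is rebuilt from `newChildren.head` (the body) and `newChildren.tail` (the arguments). Separately, the `TransformKeys` and `TransformValues` hunks change their extractor from a `::` list pattern to `Seq(...)`. The likely reason is that the rebuilt `arguments` may now be an indexed sequence rather than a `List`, and a `::` pattern only matches a `List`. The sketch below uses hypothetical `Var`, `Call` and `Lambda` types to illustrate this.

```scala
// Hypothetical simplified types; not Spark's LambdaFunction or NamedLambdaVariable.
sealed trait Expr
case class Var(name: String) extends Expr
case class Call(fn: String, args: Seq[Expr]) extends Expr

// Modeled on LambdaFunction: children = function +: arguments.
case class Lambda(function: Expr, arguments: Seq[Var]) extends Expr {
  def children: Seq[Expr] = function +: arguments
  def withNewChildren(newChildren: IndexedSeq[Expr]): Lambda =
    copy(
      function = newChildren.head,
      // Unchecked cast, mirroring the diff: element types are erased at runtime.
      arguments = newChildren.tail.asInstanceOf[Seq[Var]])
}

val lam = Lambda(Call("add", List(Var("k"), Var("v"))), List(Var("k"), Var("v")))
// After the rebuild, `arguments` is whatever `tail` returned (a Vector here), not a List.
val rebuilt = lam.withNewChildren(lam.children.toIndexedSeq)

// Old style: only matches when `arguments` is a List.
def namesWithCons(l: Lambda): Option[(String, String)] = l.arguments match {
  case Var(k) :: Var(v) :: Nil => Some((k, v))
  case _ => None
}

// New style: the Seq extractor matches any two-element Seq, List or Vector.
def namesWithSeq(l: Lambda): Option[(String, String)] = l.arguments match {
  case Seq(Var(k), Var(v)) => Some((k, v))
  case _ => None
}
```

In the diff, the destructuring is a pattern `val`. A non-matching shape would therefore surface as a `MatchError` at first access, not as `None`.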
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 23cf0bc..4311b38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -49,22 +49,40 @@ abstract class ExtractIntervalPart(
 }
 
 case class ExtractIntervalYears(child: Expression)
-  extends ExtractIntervalPart(child, IntegerType, getYears, "getYears")
+  extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") {
+  override protected def withNewChildInternal(newChild: Expression): ExtractIntervalYears =
+    copy(child = newChild)
+}
 
 case class ExtractIntervalMonths(child: Expression)
-  extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths")
+  extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") {
+  override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMonths =
+    copy(child = newChild)
+}
 
 case class ExtractIntervalDays(child: Expression)
-  extends ExtractIntervalPart(child, IntegerType, getDays, "getDays")
+  extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") {
+  override protected def withNewChildInternal(newChild: Expression): ExtractIntervalDays =
+    copy(child = newChild)
+}
 
 case class ExtractIntervalHours(child: Expression)
-  extends ExtractIntervalPart(child, LongType, getHours, "getHours")
+  extends ExtractIntervalPart(child, LongType, getHours, "getHours") {
+  override protected def withNewChildInternal(newChild: Expression): ExtractIntervalHours =
+    copy(child = newChild)
+}
 
 case class ExtractIntervalMinutes(child: Expression)
-  extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes")
+  extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") {
+  override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMinutes =
+    copy(child = newChild)
+}
 
 case class ExtractIntervalSeconds(child: Expression)
-  extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds")
+  extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") {
+  override protected def withNewChildInternal(newChild: Expression): ExtractIntervalSeconds =
+    copy(child = newChild)
+}
 
 object ExtractIntervalPart {
 
@@ -119,6 +137,10 @@ case class MultiplyInterval(
     if (failOnError) multiplyExact else multiply
 
   override protected def operationName: String = if (failOnError) "multiplyExact" else "multiply"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): MultiplyInterval =
+    copy(interval = newLeft, num = newRight)
 }
 
 case class DivideInterval(
@@ -131,6 +153,10 @@ case class DivideInterval(
     if (failOnError) divideExact else divide
 
   override protected def operationName: String = if (failOnError) "divideExact" else "divide"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DivideInterval =
+    copy(interval = newLeft, num = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -251,6 +277,19 @@ case class MakeInterval(
   }
 
   override def prettyName: String = "make_interval"
+
+  // Seq(years, months, weeks, days, hours, mins, secs)
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): MakeInterval =
+    copy(
+      years = newChildren(0),
+      months = newChildren(1),
+      weeks = newChildren(2),
+      days = newChildren(3),
+      hours = newChildren(4),
+      mins = newChildren(5),
+      secs = newChildren(6)
+    )
 }
 
 // Multiply an year-month interval by a numeric
@@ -298,6 +337,10 @@ case class MultiplyYMInterval(
   }
 
   override def toString: String = s"($left * $right)"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): MultiplyYMInterval =
+    copy(interval = newLeft, num = newRight)
 }
 
 // Multiply a day-time interval by a numeric
@@ -340,6 +383,10 @@ case class MultiplyDTInterval(
   }
 
   override def toString: String = s"($left * $right)"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): MultiplyDTInterval =
+    copy(interval = newLeft, num = newRight)
 }
 
 // Divide an year-month interval by a numeric
@@ -394,6 +441,10 @@ case class DivideYMInterval(
   }
 
   override def toString: String = s"($left / $right)"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DivideYMInterval =
+    copy(interval = newLeft, num = newRight)
 }
 
 // Divide a day-time interval by a numeric
@@ -437,4 +488,8 @@ case class DivideDTInterval(
   }
 
   override def toString: String = s"($left / $right)"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): DivideDTInterval =
+    copy(interval = newLeft, num = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index b217110..6a56bbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -335,6 +335,10 @@ case class GetJsonObject(json: Expression, path: Expression)
         false
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): GetJsonObject =
+    copy(json = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit line.contains.tab
@@ -498,6 +502,9 @@ case class JsonTuple(children: Seq[Expression])
         generator.copyCurrentStructure(parser)
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): JsonTuple =
+    copy(children = newChildren)
 }
 
 /**
@@ -609,6 +616,9 @@ case class JsonToStructs(
   }
 
   override def prettyName: String = "from_json"
+
+  override protected def withNewChildInternal(newChild: Expression): JsonToStructs =
+    copy(child = newChild)
 }
 
 /**
@@ -731,6 +741,9 @@ case class StructsToJson(
   override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil
 
   override def prettyName: String = "to_json"
+
+  override protected def withNewChildInternal(newChild: Expression): StructsToJson =
+    copy(child = newChild)
 }
 
 /**
@@ -805,6 +818,9 @@ case class SchemaOfJson(
   }
 
   override def prettyName: String = "schema_of_json"
+
+  override protected def withNewChildInternal(newChild: Expression): SchemaOfJson =
+    copy(child = newChild)
 }
 
 /**
@@ -874,6 +890,9 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression
     }
     length
   }
+
+  override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray =
+    copy(child = newChild)
 }
 
 /**
@@ -943,4 +962,7 @@ case class JsonObjectKeys(child: Expression) extends UnaryExpression with Codege
     }
     new GenericArrayData(arrayBufferOfKeys.toArray)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 3b58f3d..516eeb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -187,7 +187,9 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
+case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") {
+  override protected def withNewChildInternal(newChild: Expression): Acos = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -203,7 +205,9 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
+case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") {
+  override protected def withNewChildInternal(newChild: Expression): Asin = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -217,7 +221,9 @@ case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
+case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") {
+  override protected def withNewChildInternal(newChild: Expression): Atan = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the cube root of `expr`.",
@@ -228,7 +234,9 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
+case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") {
+  override protected def withNewChildInternal(newChild: Expression): Cbrt = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the smallest integer not smaller than `expr`.",
@@ -267,6 +275,8 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Ceil = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -285,7 +295,9 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
+case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") {
+  override protected def withNewChildInternal(newChild: Expression): Cos = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -303,7 +315,9 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
+case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") {
+  override protected def withNewChildInternal(newChild: Expression): Cosh = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -324,6 +338,7 @@ case class Acosh(child: Expression)
     defineCodeGen(ctx, ev,
       c => s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c - 1.0))")
   }
+  override protected def withNewChildInternal(newChild: Expression): Acosh = copy(child = newChild)
 }
 
 /**
@@ -372,6 +387,10 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
        """
     )
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird)
 }
 
 @ExpressionDescription(
@@ -387,6 +406,7 @@ case class Exp(child: Expression) extends UnaryMathExpression(StrictMath.exp, "E
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.exp($c)")
   }
+  override protected def withNewChildInternal(newChild: Expression): Exp = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -402,6 +422,7 @@ case class Expm1(child: Expression) extends UnaryMathExpression(StrictMath.expm1
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.expm1($c)")
   }
+  override protected def withNewChildInternal(newChild: Expression): Expm1 = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -441,6 +462,8 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Floor = copy(child = newChild)
 }
 
 object Factorial {
@@ -514,6 +537,9 @@ case class Factorial(child: Expression)
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Factorial =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -527,6 +553,7 @@ case class Factorial(child: Expression)
   group = "math_funcs")
 case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") {
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln")
+  override protected def withNewChildInternal(newChild: Expression): Log = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -551,6 +578,7 @@ case class Log2(child: Expression)
       """
     )
   }
+  override protected def withNewChildInternal(newChild: Expression): Log2 = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -562,7 +590,9 @@ case class Log2(child: Expression)
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, "LOG10")
+case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10, "LOG10") {
+  override protected def withNewChildInternal(newChild: Expression): Log10 = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns log(1 + `expr`).",
@@ -575,6 +605,7 @@ case class Log10(child: Expression) extends UnaryLogExpression(StrictMath.log10,
   group = "math_funcs")
 case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p, "LOG1P") {
   protected override val yAsymptote: Double = -1.0
+  override protected def withNewChildInternal(newChild: Expression): Log1p = copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -591,6 +622,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p,
 case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
   override def funcName: String = "rint"
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint")
+  override protected def withNewChildInternal(newChild: Expression): Rint = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -602,7 +634,9 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
+case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") {
+  override protected def withNewChildInternal(newChild: Expression): Signum = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.",
@@ -617,7 +651,9 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
+case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") {
+  override protected def withNewChildInternal(newChild: Expression): Sin = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -634,7 +670,9 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") {
+  override protected def withNewChildInternal(newChild: Expression): Sinh = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -656,6 +694,7 @@ case class Asinh(child: Expression)
       s"$c == Double.NEGATIVE_INFINITY ? Double.NEGATIVE_INFINITY : " +
       s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c + 1.0))")
   }
+  override protected def withNewChildInternal(newChild: Expression): Asinh = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -667,7 +706,9 @@ case class Asinh(child: Expression)
   """,
   since = "1.1.1",
   group = "math_funcs")
-case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") {
+  override protected def withNewChildInternal(newChild: Expression): Sqrt = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -684,7 +725,9 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT"
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
+case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") {
+  override protected def withNewChildInternal(newChild: Expression): Tan = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -706,6 +749,7 @@ case class Cot(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.tan($c);")
   }
+  override protected def withNewChildInternal(newChild: Expression): Cot = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -724,7 +768,9 @@ case class Cot(child: Expression)
   """,
   since = "1.4.0",
   group = "math_funcs")
-case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
+case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") {
+  override protected def withNewChildInternal(newChild: Expression): Tanh = copy(child = newChild)
+}
 
 @ExpressionDescription(
   usage = """
@@ -747,6 +793,7 @@ case class Atanh(child: Expression)
     defineCodeGen(ctx, ev,
       c => s"0.5 * (java.lang.StrictMath.log1p($c) - java.lang.StrictMath.log1p(- $c))")
   }
+  override protected def withNewChildInternal(newChild: Expression): Atanh = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -764,6 +811,8 @@ case class Atanh(child: Expression)
   group = "math_funcs")
 case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
   override def funcName: String = "toDegrees"
+  override protected def withNewChildInternal(newChild: Expression): ToDegrees =
+    copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -781,6 +830,8 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre
   group = "math_funcs")
 case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
   override def funcName: String = "toRadians"
+  override protected def withNewChildInternal(newChild: Expression): ToRadians =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.size.limit
@@ -811,6 +862,8 @@ case class Bin(child: Expression)
     defineCodeGen(ctx, ev, (c) =>
       s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Bin = copy(child = newChild)
 }
 
 object Hex {
@@ -923,6 +976,8 @@ case class Hex(child: Expression)
       })
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Hex = copy(child = newChild)
 }
 
 /**
@@ -958,6 +1013,8 @@ case class Unhex(child: Expression)
        """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Unhex = copy(child = newChild)
 }
 
 
@@ -996,6 +1053,9 @@ case class Atan2(left: Expression, right: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1012,6 +1072,8 @@ case class Pow(left: Expression, right: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.StrictMath.pow($c1, $c2)")
   }
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 
@@ -1048,6 +1110,9 @@ case class ShiftLeft(left: Expression, right: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ShiftLeft = copy(left = newLeft, right = newRight)
 }
 
 
@@ -1084,6 +1149,9 @@ case class ShiftRight(left: Expression, right: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (left, right) => s"$left >> $right")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): ShiftRight = copy(left = newLeft, right = newRight)
 }
 
 
@@ -1120,6 +1188,10 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right")
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ShiftRightUnsigned =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1132,7 +1204,10 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
   since = "1.4.0",
   group = "math_funcs")
 case class Hypot(left: Expression, right: Expression)
-  extends BinaryMathExpression(math.hypot, "HYPOT")
+  extends BinaryMathExpression(math.hypot, "HYPOT") {
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Hypot =
+    copy(left = newLeft, right = newRight)
+}
 
 
 /**
@@ -1190,6 +1265,9 @@ case class Logarithm(left: Expression, right: Expression)
         """)
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Logarithm = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1387,6 +1465,8 @@ case class Round(child: Expression, scale: Expression)
   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP")
     with Serializable with ImplicitCastInputTypes {
   def this(child: Expression) = this(child, Literal(0))
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Round =
+    copy(child = newLeft, scale = newRight)
 }
 
 /**
@@ -1409,6 +1489,8 @@ case class BRound(child: Expression, scale: Expression)
   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN")
     with Serializable with ImplicitCastInputTypes {
   def this(child: Expression) = this(child, Literal(0))
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft, scale = newRight)
 }
 
 object WidthBucket {
@@ -1511,4 +1593,8 @@ case class WidthBucket(
   override def second: Expression = minValue
   override def third: Expression = maxValue
   override def fourth: Expression = numBucket
+
+  override protected def withNewChildrenInternal(
+      first: Expression, second: Expression, third: Expression, fourth: Expression): WidthBucket =
+    copy(value = first, minValue = second, maxValue = third, numBucket = fourth)
 }
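Every override added to this file has the same shape: a plain case-class `copy` that rebuilds the node from its new children, with no reflection involved. The sketch below models how a unary or binary base trait could combine such an override with a "nothing changed" fast path. It is a minimal, hypothetical model. `Node`, `UnaryLike`, `BinaryLike`, `Lit`, `Sqrt` and `Pow` here are simplified stand-ins, not Spark's `TreeNode` traits or expression classes. It also uses plain reference equality (`eq`) to detect unchanged children.

```scala
// Hypothetical, simplified model of the pattern; NOT Spark's TreeNode API.
abstract class Node {
  def children: Seq[Node]
  def mapChildren(f: Node => Node): Node
}

trait UnaryLike extends Node {
  def child: Node
  // Concrete classes implement this with a plain case-class copy (no reflection).
  protected def withNewChildInternal(newChild: Node): Node
  final def children: Seq[Node] = child :: Nil
  final def mapChildren(f: Node => Node): Node = {
    val newChild = f(child)
    // Reuse `this` when the child is unchanged, so untouched subtrees are not reallocated.
    if (newChild eq child) this else withNewChildInternal(newChild)
  }
}

trait BinaryLike extends Node {
  def left: Node
  def right: Node
  protected def withNewChildrenInternal(newLeft: Node, newRight: Node): Node
  final def children: Seq[Node] = left :: right :: Nil
  final def mapChildren(f: Node => Node): Node = {
    val newLeft = f(left)
    val newRight = f(right)
    if ((newLeft eq left) && (newRight eq right)) this
    else withNewChildrenInternal(newLeft, newRight)
  }
}

case class Lit(value: Double) extends Node {
  def children: Seq[Node] = Nil
  def mapChildren(f: Node => Node): Node = this
}

case class Sqrt(child: Node) extends UnaryLike {
  override protected def withNewChildInternal(newChild: Node): Sqrt = copy(child = newChild)
}

case class Pow(left: Node, right: Node) extends BinaryLike {
  override protected def withNewChildrenInternal(newLeft: Node, newRight: Node): Pow =
    copy(left = newLeft, right = newRight)
}

val tree = Pow(Sqrt(Lit(4.0)), Lit(2.0))
// Rewrite only the literal 2.0; the Sqrt subtree is passed through as-is.
val bumped = tree.mapChildren {
  case Lit(2.0) => Lit(3.0)
  case other    => other
}
```

Because the fixed-arity traits know exactly how many children a node has, they can avoid the generic product-iterator walk. The `copy` call keeps the constructor's field names, so the compiler checks that each new child lands in the right slot.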
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 6b3b949..9e854cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -51,6 +51,9 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
          | ${ev.value} = $c;
        """.stripMargin)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): PrintToStderr =
+    copy(child = newChild)
 }
 
 /**
@@ -100,6 +103,9 @@ case class RaiseError(child: Expression, dataType: DataType)
       value = JavaCode.defaultLiteral(dataType)
     )
   }
+
+  override protected def withNewChildInternal(newChild: Expression): RaiseError =
+    copy(child = newChild)
 }
 
 object RaiseError {
@@ -133,6 +139,9 @@ case class AssertTrue(left: Expression, right: Expression, child: Expression)
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): AssertTrue =
+    copy(child = newChild)
 }
 
 object AssertTrue {
@@ -268,4 +277,6 @@ case class TypeOf(child: Expression) extends UnaryExpression {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, _ => s"""UTF8String.fromString(${child.dataType.catalogString})""")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): TypeOf = copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index e73b024..b73a189 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -226,6 +226,9 @@ case class Alias(child: Expression, name: String)(
       if (qualifier.nonEmpty) qualifier.map(quoteIfNeeded).mkString(".") + "." else ""
     s"${child.sql} AS $qualifierPrefix${quoteIfNeeded(name)}"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Alias =
+    copy(child = newChild)(exprId, qualifier, explicitMetadata, nonInheritableMetadataKeys)
 }
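`Alias` is curried: `exprId`, `qualifier` and the metadata fields live in a second parameter list. Scala's generated `copy` only provides defaults for the first parameter list, so the override has to pass the second list through explicitly. Otherwise a rewritten alias could not retain its identity. The sketch below reproduces that constraint with hypothetical stand-ins (`Expr`, `Col`, `ExprId` and a two-field `Alias`), not Spark's real classes.

```scala
// Hypothetical stand-ins; not Spark's Alias/ExprId.
final case class ExprId(id: Long)

sealed trait Expr
final case class Col(name: String) extends Expr

// Second-list fields are excluded from the generated equals/hashCode/unapply,
// and copy() gives them no defaults, so they must be supplied explicitly.
case class Alias(child: Expr, name: String)(val exprId: ExprId, val qualifier: Seq[String])
  extends Expr {
  def withNewChild(newChild: Expr): Alias =
    copy(child = newChild)(exprId, qualifier) // keep the alias's identity
}

val original = Alias(Col("x"), "y")(ExprId(7L), Seq("t"))
val rewritten = original.withNewChild(Col("z"))
```

Dropping the second argument list would not compile, which is arguably the point: identity-carrying fields cannot be lost silently when a child is replaced.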
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index d508129..2c2df6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -120,6 +120,9 @@ case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpress
          |} while (false);
        """.stripMargin)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Coalesce =
+    copy(children = newChildren)
 }
 
 
@@ -141,6 +144,8 @@ case class IfNull(left: Expression, right: Expression, child: Expression)
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): IfNull = copy(child = newChild)
 }
 
 
@@ -162,6 +167,8 @@ case class NullIf(left: Expression, right: Expression, child: Expression)
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): NullIf = copy(child = newChild)
 }
 
 
@@ -182,6 +189,8 @@ case class Nvl(left: Expression, right: Expression, child: Expression) extends R
 
   override def flatArguments: Iterator[Any] = Iterator(left, right)
   override def exprsReplaced: Seq[Expression] = Seq(left, right)
+
+  override protected def withNewChildInternal(newChild: Expression): Nvl = copy(child = newChild)
 }
 
 
@@ -205,6 +214,8 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child:
 
   override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3)
   override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3)
+
+  override protected def withNewChildInternal(newChild: Expression): Nvl2 = copy(child = newChild)
 }
 
 
@@ -249,6 +260,8 @@ case class IsNaN(child: Expression) extends UnaryExpression
           ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): IsNaN = copy(child = newChild)
 }
 
 /**
@@ -311,6 +324,9 @@ case class NaNvl(left: Expression, right: Expression)
           }""")
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): NaNvl =
+    copy(left = newLeft, right = newRight)
 }
 
 
@@ -339,6 +355,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
   }
 
   override def sql: String = s"(${child.sql} IS NULL)"
+
+  override protected def withNewChildInternal(newChild: Expression): IsNull = copy(child = newChild)
 }
 
 
@@ -374,6 +392,9 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   }
 
   override def sql: String = s"(${child.sql} IS NOT NULL)"
+
+  override protected def withNewChildInternal(newChild: Expression): IsNotNull =
+    copy(child = newChild)
 }
 
 
@@ -466,4 +487,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
          |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
        """.stripMargin, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): AtLeastNNonNulls = copy(children = newChildren)
 }
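For nodes with a variable number of children (`Coalesce`, `AtLeastNNonNulls`), the override receives the new children as an `IndexedSeq`. Because that is also a `Seq`, it can be stored directly into the `children` field. The sketch below pairs such an override with a simplified `withNewChildren` that checks arity and skips the copy when every child is reference-identical. `E`, `Leaf`, `Coalesce` and the free function `withNewChildren` are hypothetical stand-ins for illustration, not Spark's `TreeNode` implementation.

```scala
// Hypothetical stand-ins; not Spark's classes or TreeNode.withNewChildren.
sealed trait E { def children: Seq[E] }
final case class Leaf(v: Int) extends E { def children: Seq[E] = Nil }

final case class Coalesce(children: Seq[E]) extends E {
  // An IndexedSeq is a Seq, so it can replace the constructor field as-is.
  def withNewChildrenInternal(newChildren: IndexedSeq[E]): Coalesce =
    copy(children = newChildren)
}

// Simplified driver: validate arity, return the same node if nothing changed.
def withNewChildren(node: Coalesce, newChildren: Seq[E]): Coalesce = {
  require(newChildren.size == node.children.size, "arity mismatch")
  val unchanged =
    newChildren.iterator.zip(node.children.iterator).forall { case (a, b) => a eq b }
  if (unchanged) node else node.withNewChildrenInternal(newChildren.toIndexedSeq)
}

val c = Coalesce(Seq(Leaf(1), Leaf(2)))
val c2 = withNewChildren(c, Seq(Leaf(1), Leaf(3)))
```

Using an indexed sequence makes positional access (`newChildren(0)`, `newChildren(4)`, as in the object-expression overrides) constant-time.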
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 5be5216..5ae0cef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 import java.lang.reflect.{Method, Modifier}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.{Builder, IndexedSeq, WrappedArray}
+import scala.collection.mutable.{Builder, WrappedArray}
 import scala.reflect.ClassTag
 import scala.util.{Properties, Try}
 
@@ -279,6 +279,9 @@ case class StaticInvoke(
      """
     ev.copy(code = code)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(arguments = newChildren)
 }
 
 /**
@@ -400,6 +403,9 @@ case class Invoke(
   }
 
   override def toString: String = s"$targetObject.$functionName"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Invoke =
+    copy(targetObject = newChildren.head, arguments = newChildren.tail)
 }
 
 object NewInstance {
@@ -506,6 +512,9 @@ case class NewInstance(
   }
 
   override def toString: String = s"newInstance($cls)"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): NewInstance =
+    copy(arguments = newChildren)
 }
 
 /**
@@ -543,6 +552,9 @@ case class UnwrapOption(
     """
     ev.copy(code = code)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnwrapOption =
+    copy(child = newChild)
 }
 
 /**
@@ -573,6 +585,9 @@ case class WrapOption(child: Expression, optType: DataType)
     """
     ev.copy(code = code, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): WrapOption =
+    copy(child = newChild)
 }
 
 object LambdaVariable {
@@ -659,6 +674,9 @@ case class UnresolvedMapObjects(
   override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
     throw QueryExecutionErrors.customCollectionClsNotResolvedError
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnresolvedMapObjects =
+    copy(child = newChild)
 }
 
 object MapObjects {
@@ -1025,6 +1043,13 @@ case class MapObjects private(
     """
     ev.copy(code = code, isNull = genInputData.isNull)
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    copy(
+      loopVar = newFirst.asInstanceOf[LambdaVariable],
+      lambdaFunction = newSecond,
+      inputData = newThird)
 }
 
 /**
@@ -1044,6 +1069,9 @@ case class UnresolvedCatalystToExternalMap(
   override lazy val resolved = false
 
   override def dataType: DataType = ObjectType(collClass)
+
+  override protected def withNewChildInternal(
+    newChild: Expression): UnresolvedCatalystToExternalMap = copy(child = newChild)
 }
 
 object CatalystToExternalMap {
@@ -1214,6 +1242,15 @@ case class CatalystToExternalMap private(
     """
     ev.copy(code = code, isNull = genInputData.isNull)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): CatalystToExternalMap =
+    copy(
+      keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable],
+      keyLambdaFunction = newChildren(1),
+      valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable],
+      valueLambdaFunction = newChildren(3),
+      inputData = newChildren(4))
 }
 
 object ExternalMapToCatalyst {
@@ -1437,6 +1474,15 @@ case class ExternalMapToCatalyst private(
       """
     ev.copy(code = code, isNull = inputMap.isNull)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): ExternalMapToCatalyst =
+    copy(
+      keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable],
+      keyConverter = newChildren(1),
+      valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable],
+      valueConverter = newChildren(3),
+      inputData = newChildren(4))
 }
 
 /**
@@ -1487,6 +1533,9 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
        """.stripMargin
     ev.copy(code = code, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): CreateExternalRow = copy(children = newChildren)
 }
 
 /**
@@ -1516,6 +1565,9 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
   }
 
   override def dataType: DataType = BinaryType
+
+  override protected def withNewChildInternal(newChild: Expression): EncodeUsingSerializer =
+    copy(child = newChild)
 }
 
 /**
@@ -1548,6 +1600,9 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
   }
 
   override def dataType: DataType = ObjectType(tag.runtimeClass)
+
+  override protected def withNewChildInternal(newChild: Expression): DecodeUsingSerializer[T] =
+    copy(child = newChild)
 }
 
 /**
@@ -1629,6 +1684,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
        """.stripMargin
     ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): InitializeJavaBean =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[InitializeJavaBean]
 }
 
 /**
@@ -1676,6 +1735,9 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
      """
     ev.copy(code = code, isNull = FalseLiteral, value = childGen.value)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): AssertNotNull =
+    copy(child = newChild)
 }
 
 /**
@@ -1727,6 +1789,9 @@ case class GetExternalRowField(
      """
     ev.copy(code = code, isNull = FalseLiteral)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): GetExternalRowField =
+    copy(child = newChild)
 }
 
 /**
@@ -1801,4 +1866,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
     """
     ev.copy(code = code, isNull = input.isNull)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): ValidateExternalType =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 33eb120..d9d0643 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -322,6 +322,8 @@ case class Not(child: Expression)
   }
 
   override def sql: String = s"(NOT ${child.sql})"
+
+  override protected def withNewChildInternal(newChild: Expression): Not = copy(child = newChild)
 }
 
 /**
@@ -379,6 +381,9 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
   override def nullable: Boolean = children.exists(_.nullable)
   override def toString: String = s"$value IN ($query)"
   override def sql: String = s"(${value.sql} IN (${query.sql}))"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): InSubquery =
+    copy(values = newChildren.dropRight(1), query = newChildren.last.asInstanceOf[ListQuery])
 }
 
 
@@ -520,6 +525,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
     val listSQL = list.map(_.sql).mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): In =
+    copy(value = newChildren.head, list = newChildren.tail)
 }
 
 /**
@@ -625,6 +633,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
       .mkString(", ")
     s"($valueSQL IN ($listSQL))"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): InSet = copy(child = newChild)
 }
 
 @ExpressionDescription(
@@ -708,6 +718,9 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
       """)
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): And =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -792,6 +805,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
       """)
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Or =
+    copy(left = newLeft, right = newRight)
 }
 
 
@@ -877,6 +893,9 @@ case class EqualTo(left: Expression, right: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): EqualTo = copy(left = newLeft, right = newRight)
 }
 
 // TODO: although map type is not orderable, technically map type should be able to be used
@@ -938,6 +957,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
         boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
            (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): EqualNullSafe =
+    copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -970,6 +993,9 @@ case class LessThan(left: Expression, right: Expression)
   override def symbol: String = "<"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1002,6 +1028,9 @@ case class LessThanOrEqual(left: Expression, right: Expression)
   override def symbol: String = "<="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1034,6 +1063,9 @@ case class GreaterThan(left: Expression, right: Expression)
   override def symbol: String = ">"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(left = newLeft, right = newRight)
 }
 
 @ExpressionDescription(
@@ -1066,6 +1098,10 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
   override def symbol: String = ">="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): GreaterThanOrEqual =
+    copy(left = newLeft, right = newRight)
 }
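`In` and `InSubquery` store their children in more than one constructor field, so their overrides have to split the flat child list back apart. `In` takes the head as the tested value and the tail as the IN-list. `InSubquery` treats the last child as the subquery and everything before it as values. The round-trip sketch below checks that each split inverts how `children` was assembled. The classes are hypothetical stand-ins (`X`, `Atom`, `ListQuery`, `In`, `InSubquery`) that only mirror the field layout, not Spark's real predicates.

```scala
// Hypothetical stand-ins that only mirror the field layout of In / InSubquery.
sealed trait X
final case class Atom(name: String) extends X
final case class ListQuery(plan: String) extends X

final case class In(value: X, list: Seq[X]) extends X {
  def children: IndexedSeq[X] = (value +: list).toIndexedSeq
  // First child is the tested value; the remaining children form the IN-list.
  def withNewChildrenInternal(newChildren: IndexedSeq[X]): In =
    copy(value = newChildren.head, list = newChildren.tail)
}

final case class InSubquery(values: Seq[X], query: ListQuery) extends X {
  def children: IndexedSeq[X] = (values :+ query).toIndexedSeq
  // The subquery is the last child; everything before it is a value.
  def withNewChildrenInternal(newChildren: IndexedSeq[X]): InSubquery =
    copy(values = newChildren.dropRight(1), query = newChildren.last.asInstanceOf[ListQuery])
}

val in = In(Atom("a"), Seq(Atom("b"), Atom("c")))
val sub = InSubquery(Seq(Atom("a")), ListQuery("SELECT 1"))
```

The `asInstanceOf[ListQuery]` cast assumes callers never replace the subquery child with a different node type. The same assumption is made by the casts to `LambdaVariable` in the object-expression overrides.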
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 0a4c6e2..d470cad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -111,6 +111,8 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
   override def sql: String = {
     s"rand(${if (hideSeed) "" else child.sql})"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild)
 }
 
 object Rand {
@@ -162,6 +164,8 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
   override def sql: String = {
     s"randn(${if (hideSeed) "" else child.sql})"
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Randn = copy(child = newChild)
 }
 
 object Randn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 9fdab35..13d00fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -180,6 +180,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
       })
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Like =
+    copy(left = newLeft, right = newRight)
 }
 
 sealed abstract class MultiLikeBase
@@ -268,10 +271,14 @@ sealed abstract class LikeAllBase extends MultiLikeBase {
 
 case class LikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
   override def isNotSpecified: Boolean = false
+  override protected def withNewChildInternal(newChild: Expression): LikeAll =
+    copy(child = newChild)
 }
 
 case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase {
   override def isNotSpecified: Boolean = true
+  override protected def withNewChildInternal(newChild: Expression): NotLikeAll =
+    copy(child = newChild)
 }
 
 /**
@@ -324,10 +331,14 @@ sealed abstract class LikeAnyBase extends MultiLikeBase {
 
 case class LikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase {
   override def isNotSpecified: Boolean = false
+  override protected def withNewChildInternal(newChild: Expression): LikeAny =
+    copy(child = newChild)
 }
 
 case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase {
   override def isNotSpecified: Boolean = true
+  override protected def withNewChildInternal(newChild: Expression): NotLikeAny =
+    copy(child = newChild)
 }
 
 // scalastyle:off line.contains.tab
@@ -409,6 +420,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
       })
     }
   }
+
+  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): RLike =
+    copy(left = newLeft, right = newRight)
 }
 
 
@@ -467,6 +481,10 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression)
   }
 
   override def prettyName: String = "split"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringSplit =
+    copy(str = newFirst, regex = newSecond, limit = newThird)
 }
 
 
@@ -622,6 +640,10 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   override def second: Expression = regexp
   override def third: Expression = rep
   override def fourth: Expression = pos
+
+  override protected def withNewChildrenInternal(
+      first: Expression, second: Expression, third: Expression, fourth: Expression): RegExpReplace =
+    copy(subject = first, regexp = second, rep = third, pos = fourth)
 }
 
 object RegExpReplace {
@@ -765,6 +787,10 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
       }"""
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtract =
+    copy(subject = newFirst, regexp = newSecond, idx = newThird)
 }
 
 /**
@@ -868,4 +894,8 @@ case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expres
          """
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtractAll =
+    copy(subject = newFirst, regexp = newSecond, idx = newThird)
 }
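Each override above rebuilds its node with a direct `copy(...)` call whose arity matches the node: one child, left/right, or first/second/third(/fourth). Below is a minimal sketch of how hooks like these can sit behind a single generic `withNewChildren`/`mapChildren` entry point. It uses a toy hierarchy; `Node`, `LeafLike`, `UnaryLike`, `BinaryLike`, `Lit`, `Neg` and `Add` are hypothetical names chosen for illustration and are not Spark's `TreeNode` classes.

```scala
// Hypothetical, simplified stand-ins for a tree-node hierarchy; not Spark's TreeNode.
abstract class Node {
  def children: Seq[Node]

  // Generic entry point: validate, skip the copy when no child changed,
  // otherwise delegate to the node-specific rebuild.
  final def withNewChildren(newChildren: Seq[Node]): Node = {
    require(newChildren.size == children.size, "number of children must not change")
    val unchanged = newChildren.zip(children).forall { case (n, o) => n eq o }
    if (unchanged) this else withNewChildrenInternal(newChildren.toIndexedSeq)
  }

  def mapChildren(f: Node => Node): Node = withNewChildren(children.map(f))

  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node
}

trait LeafLike extends Node {
  final def children: Seq[Node] = Nil
  final protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node = this
}

trait UnaryLike extends Node {
  def child: Node
  final def children: Seq[Node] = child :: Nil
  final protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node =
    withNewChildInternal(newChildren(0))
  // Concrete nodes implement this with a plain `copy`, no reflection.
  protected def withNewChildInternal(newChild: Node): Node
}

trait BinaryLike extends Node {
  def left: Node
  def right: Node
  final def children: Seq[Node] = Seq(left, right)
  final protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node =
    withNewChildrenInternal(newChildren(0), newChildren(1))
  protected def withNewChildrenInternal(newLeft: Node, newRight: Node): Node
}

case class Lit(value: Int) extends LeafLike

case class Neg(child: Node) extends UnaryLike {
  override protected def withNewChildInternal(newChild: Node): Neg = copy(child = newChild)
}

case class Add(left: Node, right: Node) extends BinaryLike {
  override protected def withNewChildrenInternal(newLeft: Node, newRight: Node): Add =
    copy(left = newLeft, right = newRight)
}

// Usage: rewrite only the direct children of `expr`.
val expr = Add(Lit(1), Neg(Lit(2)))
val rewritten = expr.mapChildren {
  case Neg(Lit(2)) => Neg(Lit(3))
  case other       => other
}
```

Returning `this` when every new child is reference-equal to the old one avoids a copy in the common no-op case. How Spark itself detects an unchanged child may differ from the `eq` check used here.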
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 714f1d6..3d5f812 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -227,6 +227,9 @@ case class ConcatWs(children: Seq[Expression])
       """)
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ConcatWs =
+    copy(children = newChildren)
 }
 
 /**
@@ -366,6 +369,9 @@ case class Elt(
          |final boolean ${ev.isNull} = ${ev.value} == null;
        """.stripMargin)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt =
+    copy(children = newChildren)
 }
 
 
@@ -403,6 +409,8 @@ case class Upper(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild)
 }
 
 /**
@@ -430,6 +438,8 @@ case class Lower(child: Expression)
 
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("lower")
+
+  override protected def withNewChildInternal(newChild: Expression): Lower = copy(child = newChild)
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
@@ -454,6 +464,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
   }
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -464,6 +476,8 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
   }
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -474,6 +488,8 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
   }
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -522,6 +538,10 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
   override def third: Expression = replaceExpr
 
   override def prettyName: String = "replace"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringReplace =
+    copy(srcExpr = newFirst, searchExpr = newSecond, replaceExpr = newThird)
 }
 
 object Overlay {
@@ -634,6 +654,10 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
   override def second: Expression = replace
   override def third: Expression = pos
   override def fourth: Expression = len
+
+  override protected def withNewChildrenInternal(
+      first: Expression, second: Expression, third: Expression, fourth: Expression): Overlay =
+    copy(input = first, replace = second, pos = third, len = fourth)
 }
 
 object StringTranslate {
@@ -731,6 +755,10 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
   override def second: Expression = matchingExpr
   override def third: Expression = replaceExpr
   override def prettyName: String = "translate"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringTranslate =
+    copy(srcExpr = newFirst, matchingExpr = newSecond, replaceExpr = newThird)
 }
 
 /**
@@ -769,6 +797,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
   override def dataType: DataType = IntegerType
 
   override def prettyName: String = "find_in_set"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): FindInSet = copy(left = newLeft, right = newRight)
 }
 
 trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
@@ -926,6 +957,11 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None)
     srcString.trim(trimString)
 
   override val trimMethod: String = "trim"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(
+      srcStr = newChildren.head,
+      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
 }
 
 /**
@@ -974,6 +1010,9 @@ case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression], child
   override def flatArguments: Iterator[Any] = Iterator(srcStr, trimStr)
 
   override def prettyName: String = "btrim"
+
+  override protected def withNewChildInternal(newChild: Expression): StringTrimBoth =
+    copy(child = newChild)
 }
 
 object StringTrimLeft {
@@ -1027,6 +1066,12 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None
     srcString.trimLeft(trimString)
 
   override val trimMethod: String = "trimLeft"
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): StringTrimLeft =
+    copy(
+      srcStr = newChildren.head,
+      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
 }
 
 object StringTrimRight {
@@ -1082,6 +1127,12 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non
     srcString.trimRight(trimString)
 
   override val trimMethod: String = "trimRight"
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): StringTrimRight =
+    copy(
+      srcStr = newChildren.head,
+      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
 }
 
 /**
@@ -1120,6 +1171,9 @@ case class StringInstr(str: Expression, substr: Expression)
     defineCodeGen(ctx, ev, (l, r) =>
       s"($l).indexOf($r, 0) + 1")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): StringInstr = copy(str = newLeft, substr = newRight)
 }
 
 /**
@@ -1164,6 +1218,10 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)")
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): SubstringIndex =
+    copy(strExpr = newFirst, delimExpr = newSecond, countExpr = newThird)
 }
 
 /**
@@ -1258,6 +1316,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
 
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("locate")
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringLocate =
+    copy(substr = newFirst, str = newSecond, start = newThird)
+
 }
 
 /**
@@ -1302,6 +1365,10 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
   }
 
   override def prettyName: String = "lpad"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringLPad =
+    copy(str = newFirst, len = newSecond, pad = newThird)
 }
 
 /**
@@ -1347,6 +1414,10 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
   }
 
   override def prettyName: String = "rpad"
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): StringRPad =
+    copy(str = newFirst, len = newSecond, pad = newThird)
 }
 
 object ParseUrl {
@@ -1519,6 +1590,9 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge
       }
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl =
+    copy(children = newChildren)
 }
 
 /**
@@ -1606,6 +1680,9 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
 
   override def prettyName: String = getTagValue(
     FunctionRegistry.FUNC_ALIAS).getOrElse("format_string")
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): FormatString = FormatString(newChildren: _*)
 }
 
 /**
@@ -1638,6 +1715,9 @@ case class InitCap(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): InitCap =
+    copy(child = newChild)
 }
 
 /**
@@ -1669,6 +1749,9 @@ case class StringRepeat(str: Expression, times: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): StringRepeat = copy(str = newLeft, times = newRight)
 }
 
 /**
@@ -1700,6 +1783,9 @@ case class StringSpace(child: Expression)
   }
 
   override def prettyName: String = "space"
+
+  override protected def withNewChildInternal(newChild: Expression): StringSpace =
+    copy(child = newChild)
 }
 
 /**
@@ -1767,6 +1853,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
       }
     })
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Substring =
+    copy(str = newFirst, pos = newSecond, len = newThird)
+
 }
 
 /**
@@ -1791,6 +1882,8 @@ case class Right(str: Expression, len: Expression, child: Expression) extends Ru
 
   override def flatArguments: Iterator[Any] = Iterator(str, len)
   override def exprsReplaced: Seq[Expression] = Seq(str, len)
+
+  override protected def withNewChildInternal(newChild: Expression): Right = copy(child = newChild)
 }
 
 /**
@@ -1814,6 +1907,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run
 
   override def flatArguments: Iterator[Any] = Iterator(str, len)
   override def exprsReplaced: Seq[Expression] = Seq(str, len)
+  override protected def withNewChildInternal(newChild: Expression): Left = copy(child = newChild)
 }
 
 /**
@@ -1851,6 +1945,8 @@ case class Length(child: Expression)
       case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
     }
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Length = copy(child = newChild)
 }
 
 /**
@@ -1883,6 +1979,9 @@ case class BitLength(child: Expression)
   }
 
   override def prettyName: String = "bit_length"
+
+  override protected def withNewChildInternal(newChild: Expression): BitLength =
+    copy(child = newChild)
 }
 
 /**
@@ -1916,6 +2015,9 @@ case class OctetLength(child: Expression)
   }
 
   override def prettyName: String = "octet_length"
+
+  override protected def withNewChildInternal(newChild: Expression): OctetLength =
+    copy(child = newChild)
 }
 
 /**
@@ -1943,6 +2045,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
     nullSafeCodeGen(ctx, ev, (left, right) =>
       s"${ev.value} = $left.levenshteinDistance($right);")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Levenshtein = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -1969,6 +2074,9 @@ case class SoundEx(child: Expression)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"$c.soundex()")
   }
+
+  override protected def withNewChildInternal(newChild: Expression): SoundEx =
+    copy(child = newChild)
 }
 
 /**
@@ -2012,6 +2120,8 @@ case class Ascii(child: Expression)
         }
        """})
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Ascii = copy(child = newChild)
 }
 
 /**
@@ -2060,6 +2170,8 @@ case class Chr(child: Expression)
       """
     })
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Chr = copy(child = newChild)
 }
 
 /**
@@ -2090,6 +2202,8 @@ case class Base64(child: Expression)
             ${classOf[CommonsBase64].getName}.encodeBase64($child));
        """})
   }
+
+  override protected def withNewChildInternal(newChild: Expression): Base64 = copy(child = newChild)
 }
 
 /**
@@ -2119,6 +2233,9 @@ case class UnBase64(child: Expression)
          ${ev.value} = ${classOf[CommonsBase64].getName}.decodeBase64($child.toString());
        """})
   }
+
+  override protected def withNewChildInternal(newChild: Expression): UnBase64 =
+    copy(child = newChild)
 }
 
 object Decode {
@@ -2178,6 +2295,8 @@ case class Decode(params: Seq[Expression], child: Expression) extends RuntimeRep
 
   override def flatArguments: Iterator[Any] = Iterator(params)
   override def exprsReplaced: Seq[Expression] = params
+
+  override protected def withNewChildInternal(newChild: Expression): Decode = copy(child = newChild)
 }
 
 /**
@@ -2219,6 +2338,10 @@ case class StringDecode(bin: Expression, charset: Expression)
         }
       """)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): StringDecode =
+    copy(bin = newLeft, charset = newRight)
 }
 
 /**
@@ -2259,6 +2382,9 @@ case class Encode(value: Expression, charset: Expression)
           org.apache.spark.unsafe.Platform.throwException(e);
         }""")
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Encode = copy(value = newLeft, charset = newRight)
 }
 
 /**
@@ -2439,6 +2565,9 @@ case class FormatNumber(x: Expression, d: Expression)
   }
 
   override def prettyName: String = "format_number"
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): FormatNumber = copy(x = newLeft, d = newRight)
 }
 
 /**
@@ -2509,4 +2638,9 @@ case class Sentences(
     }
     new GenericArrayData(result.toSeq)
   }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Sentences =
+    copy(str = newFirst, language = newSecond, country = newThird)
+
 }
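`StringTrim`, `StringTrimLeft` and `StringTrimRight` have an optional second child, so their overrides rebuild from a flat `IndexedSeq`. `head` is always the source string, and `last` is read back into `trimStr` only when the node already had a trim string. With a single child, `head` and `last` are the same element, which is why the `isDefined` guard matters. The sketch below reproduces that logic on hypothetical toy types (`Expr`, `Col`, `StrLit`, `Trim`) rather than Spark's expression classes.

```scala
// Hypothetical simplified expression types; not Spark's classes.
sealed trait Expr
case class Col(name: String) extends Expr
case class StrLit(value: String) extends Expr

// Mirrors the shape of StringTrim(srcStr, trimStr: Option[...]) in the hunks above.
case class Trim(srcStr: Expr, trimStr: Option[Expr] = None) {
  // The optional child, if present, comes after the mandatory one.
  def children: Seq[Expr] = srcStr +: trimStr.toList

  def withNewChildren(newChildren: IndexedSeq[Expr]): Trim = {
    require(newChildren.size == children.size, "number of children must not change")
    copy(
      srcStr = newChildren.head,
      // Only treat the last child as a trim string if this node had one.
      trimStr = if (trimStr.isDefined) Some(newChildren.last) else None)
  }
}

val withTrimChars = Trim(Col("s"), Some(StrLit("xy")))
val plain = Trim(Col("s"))

val renamed1 = withTrimChars.withNewChildren(IndexedSeq(Col("t"), StrLit("xy")))
val renamed2 = plain.withNewChildren(IndexedSeq(Col("t")))
```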
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ff88567..ea6e427 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -238,6 +238,9 @@ case class ScalarSubquery(
       children.map(_.canonicalized),
       ExprId(0))
   }
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
 }
 
 object ScalarSubquery {
@@ -283,6 +286,9 @@ case class ListQuery(
       ExprId(0),
       childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
+    copy(children = newChildren)
 }
 
 /**
@@ -325,4 +331,7 @@ case class Exists(
       children.map(_.canonicalized),
       ExprId(0))
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
+    copy(children = newChildren)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index fa027d1..ff486bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -47,6 +47,13 @@ case class WindowSpecDefinition(
 
   override def children: Seq[Expression] = partitionSpec ++ orderSpec :+ frameSpecification
 
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): WindowSpecDefinition =
+    copy(
+      partitionSpec = newChildren.take(partitionSpec.size),
+      orderSpec = newChildren.drop(partitionSpec.size).dropRight(1).asInstanceOf[Seq[SortOrder]],
+      frameSpecification = newChildren.last.asInstanceOf[WindowFrame])
+
   override lazy val resolved: Boolean =
     childrenResolved && checkInputDataTypes().isSuccess &&
       frameSpecification.isInstanceOf[SpecifiedWindowFrame]
@@ -266,6 +273,10 @@ case class SpecifiedWindowFrame(
       case _ => true
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): SpecifiedWindowFrame =
+    copy(lower = newLeft, upper = newRight)
 }
 
 case class UnresolvedWindowExpression(
@@ -275,6 +286,9 @@ case class UnresolvedWindowExpression(
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override lazy val resolved = false
+
+  override protected def withNewChildInternal(newChild: Expression): UnresolvedWindowExpression =
+    copy(child = newChild)
 }
 
 case class WindowExpression(
@@ -290,6 +304,10 @@ case class WindowExpression(
 
   override def toString: String = s"$windowFunction $windowSpec"
   override def sql: String = windowFunction.sql + " OVER " + windowSpec.sql
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): WindowExpression =
+    copy(windowFunction = newLeft, windowSpec = newRight.asInstanceOf[WindowSpecDefinition])
 }
 
 /**
@@ -458,6 +476,10 @@ case class Lead(
   override def first: Expression = input
   override def second: Expression = offset
   override def third: Expression = default
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Lead =
+    copy(input = newFirst, offset = newSecond, default = newThird)
 }
 
 /**
@@ -513,6 +535,10 @@ case class Lag(
   override def first: Expression = input
   override def second: Expression = inputOffset
   override def third: Expression = default
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression, newThird: Expression): Lag =
+    copy(input = newFirst, inputOffset = newSecond, default = newThird)
 }
 
 abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction {
@@ -698,6 +724,10 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
   override def prettyName: String = "nth_value"
   override def sql: String =
     s"$prettyName(${input.sql}, ${offset.sql})${if (ignoreNulls) " ignore nulls" else ""}"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): NthValue =
+    copy(input = newLeft, offset = newRight)
 }
 
 /**
@@ -800,6 +830,9 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow
   )
 
   override val evaluateExpression = bucket
+
+  override protected def withNewChildInternal(
+    newChild: Expression): NTile = copy(buckets = newChild)
 }
 
 /**
@@ -884,6 +917,8 @@ abstract class RankLike extends AggregateWindowFunction {
 case class Rank(children: Seq[Expression]) extends RankLike {
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Rank =
+    copy(children = newChildren)
 }
 
 /**
@@ -925,6 +960,8 @@ case class DenseRank(children: Seq[Expression]) extends RankLike {
   override val aggBufferAttributes = rank +: orderAttrs
   override val initialValues = zero +: orderInit
   override def prettyName: String = "dense_rank"
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): DenseRank =
+    copy(children = newChildren)
 }
 
 /**
@@ -966,4 +1003,6 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
   override val evaluateExpression =
     If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d)
   override def prettyName: String = "percent_rank"
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PercentRank =
+    copy(children = newChildren)
 }
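`WindowSpecDefinition` exposes its children as one flat list (`partitionSpec ++ orderSpec :+ frameSpecification`). Its override therefore has to cut the new list back into three fields, using the old `partitionSpec.size` and the fixed position of the frame at the end. Below is a sketch of that split, assuming hypothetical stand-in types (`Expr`, `Attr`, `Sort`, `Frame`, `WindowSpec`) in place of Catalyst's `Expression`, `SortOrder` and `WindowFrame`.

```scala
// Hypothetical simplified types standing in for Expression / SortOrder / WindowFrame.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Sort(child: Expr, ascending: Boolean) extends Expr
case class Frame(lower: Int, upper: Int) extends Expr

case class WindowSpec(partitionSpec: Seq[Expr], orderSpec: Seq[Sort], frame: Frame) {
  // Children are flattened in a fixed order: partitions, then orderings, then the frame.
  def children: Seq[Expr] = partitionSpec ++ orderSpec :+ frame

  def withNewChildren(newChildren: IndexedSeq[Expr]): WindowSpec = {
    require(newChildren.size == children.size, "number of children must not change")
    val n = partitionSpec.size
    copy(
      partitionSpec = newChildren.take(n),
      // Everything between the partitions and the trailing frame is an ordering.
      orderSpec = newChildren.slice(n, newChildren.size - 1).map(_.asInstanceOf[Sort]),
      frame = newChildren.last.asInstanceOf[Frame])
  }
}

val spec = WindowSpec(Seq(Attr("dept")), Seq(Sort(Attr("salary"), ascending = false)), Frame(-1, 1))

// Upper-case direct Attr children only; orderings and the frame pass through unchanged.
val rebuilt = spec.withNewChildren(spec.children.map {
  case Attr(n) => Attr(n.toUpperCase)
  case other   => other
}.toIndexedSeq)
```

This sketch casts each ordering element individually, so a child of the wrong type fails when the node is rebuilt. By contrast, a cast on the whole collection, such as `asInstanceOf[Seq[SortOrder]]`, is unchecked because of type erasure and would only fail when an element is later used.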
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index b8fc830..336dc7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -75,6 +75,9 @@ case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract
   override def nullSafeEval(xml: Any, path: Any): Any = {
     xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString)
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): XPathBoolean = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -96,6 +99,9 @@ case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
     val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
     if (ret eq null) null else ret.shortValue()
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): XPathShort = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -117,6 +123,9 @@ case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
     val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
     if (ret eq null) null else ret.intValue()
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -138,6 +147,9 @@ case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
     val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
     if (ret eq null) null else ret.longValue()
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): XPathLong = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -159,6 +171,9 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
     val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
     if (ret eq null) null else ret.floatValue()
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): XPathFloat = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -181,6 +196,9 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
     val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString)
     if (ret eq null) null else ret.doubleValue()
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): XPathDouble = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -202,6 +220,9 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
     val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString)
     UTF8String.fromString(ret)
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Expression = copy(xml = newLeft, path = newRight)
 }
 
 // scalastyle:off line.size.limit
@@ -233,4 +254,7 @@ case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
       null
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): XPathList = copy(xml = newLeft, path = newRight)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 828f768..2a288ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -107,6 +107,9 @@ case class OrderedJoin(
     joinType: JoinType,
     condition: Option[Expression]) extends BinaryNode {
   override def output: Seq[Attribute] = left.output ++ right.output
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): OrderedJoin =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index ac8766c..a6444b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -211,4 +211,7 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with E
 
     nullSafeCodeGen(ctx, ev, codeToNormalize)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): NormalizeNaNAndZero =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index b6bf7cd..bf3f93d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -61,4 +61,7 @@ case class EventTimeWatermark(
       a
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): EventTimeWatermark =
+    copy(child = newChild)
 }
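The logical-plan hunks apply the same pattern to `UnaryNode` and `BinaryNode` operators. A node with exactly one child also allows a specialised `mapChildren`: apply the function once, then either return `this` or call `withNewChildInternal`, without building an intermediate `Seq` of children. The following sketch shows that idea, plus a small bottom-up rewrite built on top of it. `Plan`, `UnaryPlan`, `Relation`, `Filter`, `Limit` and `transformUp` are hypothetical toy definitions, and the real Catalyst traits may be structured differently.

```scala
// Hypothetical, simplified plan hierarchy; not Spark's LogicalPlan.
abstract class Plan {
  def children: Seq[Plan]
  def mapChildren(f: Plan => Plan): Plan
}

case class Relation(name: String) extends Plan {
  def children: Seq[Plan] = Nil
  def mapChildren(f: Plan => Plan): Plan = this // a leaf has nothing to map
}

trait UnaryPlan extends Plan {
  def child: Plan
  final def children: Seq[Plan] = child :: Nil

  // Specialised for one child: no intermediate Seq, and no copy if the child is unchanged.
  final def mapChildren(f: Plan => Plan): Plan = {
    val newChild = f(child)
    if (newChild eq child) this else withNewChildInternal(newChild)
  }

  protected def withNewChildInternal(newChild: Plan): Plan
}

case class Filter(condition: String, child: Plan) extends UnaryPlan {
  override protected def withNewChildInternal(newChild: Plan): Filter = copy(child = newChild)
}

case class Limit(n: Int, child: Plan) extends UnaryPlan {
  override protected def withNewChildInternal(newChild: Plan): Limit = copy(child = newChild)
}

// Bottom-up rewrite built on mapChildren; returns the node itself where the rule does not apply.
def transformUp(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
  val withNewKids = plan.mapChildren(c => transformUp(c)(rule))
  rule.applyOrElse(withNewKids, (p: Plan) => p)
}

val plan = Limit(10, Filter("a > 1", Relation("t")))
val renamed = transformUp(plan) { case Relation("t") => Relation("t_v2") }
val untouched = transformUp(plan) { case Relation("nope") => Relation("x") }
```

In the second call no rule fires, so every level takes the `eq` fast path and the original `plan` instance is returned without any copies.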
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index 30bff88..6299976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -35,6 +35,9 @@ case class ScriptTransformation(
     ioschema: ScriptInputOutputSchema) extends UnaryNode {
   @transient
   override lazy val references: AttributeSet = AttributeSet(input.flatMap(_.references))
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 962ce93..ba54be7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -41,6 +41,8 @@ import org.apache.spark.util.random.RandomSampler
 case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): ReturnAnswer =
+    copy(child = newChild)
 }
 
 /**
@@ -52,6 +54,8 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
  */
 case class Subquery(child: LogicalPlan, correlated: Boolean) extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): Subquery =
+    copy(child = newChild)
 }
 
 object Subquery {
@@ -78,6 +82,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
 
   override lazy val validConstraints: ExpressionSet =
     getAllValidConstraints(projectList)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Project =
+    copy(child = newChild)
 }
 
 /**
@@ -136,6 +143,9 @@ case class Generate(
   }
 
   def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Generate =
+    copy(child = newChild)
 }
 
 case class Filter(condition: Expression, child: LogicalPlan)
@@ -149,6 +159,9 @@ case class Filter(condition: Expression, child: LogicalPlan)
       .filterNot(SubqueryExpression.hasCorrelatedSubquery)
     child.constraints.union(ExpressionSet(predicates))
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Filter =
+    copy(child = newChild)
 }
 
 abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -201,6 +214,9 @@ case class Intersect(
       Some(children.flatMap(_.maxRows).min)
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Intersect = copy(left = newLeft, right = newRight)
 }
 
 case class Except(
@@ -214,6 +230,9 @@ case class Except(
   override def metadataOutput: Seq[Attribute] = Nil
 
   override protected lazy val validConstraints: ExpressionSet = leftConstraints
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight)
 }
 
 /** Factory for constructing new `Union` nodes. */
@@ -326,6 +345,9 @@ case class Union(
       .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
       .reduce(merge(_, _))
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
+    copy(children = newChildren)
 }
 
 case class Join(
@@ -436,6 +458,9 @@ case class Join(
       || e.asInstanceOf[JoinHint].leftHint.isDefined
       || e.asInstanceOf[JoinHint].rightHint.isDefined)
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Join = copy(left = newLeft, right = newRight)
 }
 
 /**
@@ -461,6 +486,9 @@ case class InsertIntoDir(
   override def output: Seq[Attribute] = Seq.empty
   override def metadataOutput: Seq[Attribute] = Nil
   override lazy val resolved: Boolean = false
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoDir =
+    copy(child = newChild)
 }
 
 /**
@@ -515,6 +543,9 @@ case class View(
       case _ => false
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): View =
+    copy(child = newChild)
 }
 
 object View {
@@ -548,12 +579,16 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)])
   }
 
   override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): With = copy(child = newChild)
 }
 
 case class WithWindowDefinition(
     windowDefinitions: Map[String, WindowSpecDefinition],
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): WithWindowDefinition =
+    copy(child = newChild)
 }
 
 /**
@@ -569,6 +604,7 @@ case class Sort(
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
   override def outputOrdering: Seq[SortOrder] = order
+  override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
 }
 
 /** Factory for constructing new `Range` nodes. */
@@ -739,6 +775,9 @@ case class Aggregate(
     val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
     getAllValidConstraints(nonAgg)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Aggregate =
+    copy(child = newChild)
 }
 
 case class Window(
@@ -753,6 +792,9 @@ case class Window(
   override def producedAttributes: AttributeSet = windowOutputSet
 
   def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Window =
+    copy(child = newChild)
 }
 
 object Expand {
@@ -869,6 +911,9 @@ case class Expand(
   // This operator can reuse attributes (for example making them null when doing a roll up) so
   // the constraints of the child may no longer be valid.
   override protected lazy val validConstraints: ExpressionSet = ExpressionSet()
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Expand =
+    copy(child = newChild)
 }
 
 /**
@@ -901,6 +946,8 @@ case class Pivot(
     groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
   }
   override def metadataOutput: Seq[Attribute] = Nil
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Pivot = copy(child = newChild)
 }
 
 /**
@@ -950,6 +997,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderP
       case _ => None
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): GlobalLimit =
+    copy(child = newChild)
 }
 
 /**
@@ -967,6 +1017,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
       case _ => None
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LocalLimit =
+    copy(child = newChild)
 }
 
 /**
@@ -987,6 +1040,8 @@ case class Tail(limitExpr: Expression, child: LogicalPlan) extends OrderPreservi
       case _ => None
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Tail = copy(child = newChild)
 }
 
 /**
@@ -1013,6 +1068,9 @@ case class SubqueryAlias(
   }
 
   override def doCanonicalize(): LogicalPlan = child.canonicalized
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): SubqueryAlias =
+    copy(child = newChild)
 }
 
 object SubqueryAlias {
@@ -1066,6 +1124,9 @@ case class Sample(
 
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Sample =
+    copy(child = newChild)
 }
 
 /**
@@ -1074,6 +1135,8 @@ case class Sample(
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
+    copy(child = newChild)
 }
 
 /**
@@ -1104,6 +1167,8 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
       case _ => RoundRobinPartitioning(numPartitions)
     }
   }
+  override protected def withNewChildInternal(newChild: LogicalPlan): Repartition =
+    copy(child = newChild)
 }
 
 /**
@@ -1145,6 +1210,9 @@ case class RepartitionByExpression(
   }
 
   override def shuffle: Boolean = true
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): RepartitionByExpression =
+    copy(child = newChild)
 }
 
 object RepartitionByExpression {
@@ -1178,6 +1246,8 @@ case class Deduplicate(
     child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
+    copy(child = newChild)
 }
 
 /**
@@ -1206,4 +1276,7 @@ case class CollectMetrics(
   }
 
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics =
+    copy(child = newChild)
 }
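Every unary node in the file above receives the same one-line override, `copy(child = newChild)`. The sketch below illustrates why a per-node typed `copy` is cheaper than a generic rebuild. It is a minimal, self-contained toy under stated assumptions: `Plan`, `UnaryPlan`, `Scan`, `Filter` and `Limit` are hypothetical classes, not Spark's. Only the shape of the `withNewChildInternal` hook is taken from the diff. In this toy, `mapChildren` calls the node's own `copy` and skips the copy entirely when the child comes back as the same instance.

```scala
// Hypothetical toy hierarchy; none of these are Spark classes. Only the
// `withNewChildInternal` hook mirrors the shape used in the diff.
sealed trait Plan {
  def children: Seq[Plan]
  // Applies `f` to every child and rebuilds the node only if something changed.
  def mapChildren(f: Plan => Plan): Plan
}

case class Scan(table: String) extends Plan {
  def children: Seq[Plan] = Nil
  def mapChildren(f: Plan => Plan): Plan = this // a leaf has nothing to map
}

trait UnaryPlan extends Plan {
  def child: Plan
  def children: Seq[Plan] = child :: Nil
  // Each concrete node supplies this, typically as `copy(child = newChild)`.
  protected def withNewChildInternal(newChild: Plan): Plan

  def mapChildren(f: Plan => Plan): Plan = {
    val newChild = f(child)
    // Reference equality: reuse this node when `f` returned the same instance.
    if (newChild eq child) this else withNewChildInternal(newChild)
  }
}

case class Filter(condition: String, child: Plan) extends UnaryPlan {
  override protected def withNewChildInternal(newChild: Plan): Filter =
    copy(child = newChild)
}

case class Limit(n: Int, child: Plan) extends UnaryPlan {
  override protected def withNewChildInternal(newChild: Plan): Limit =
    copy(child = newChild)
}

val plan = Limit(10, Filter("a > 1", Scan("t")))

// Replace the scan two levels down; Filter and Limit are rebuilt via `copy`.
val rewritten = plan.mapChildren(_.mapChildren {
  case Scan(_) => Scan("t_v2")
  case other   => other
})
println(rewritten) // Limit(10,Filter(a > 1,Scan(t_v2)))

// An identity function hands back the very same instance: no copy is made.
println(plan.mapChildren(identity) eq plan) // true
```

In this sketch, reference equality (`eq`) serves as a cheap stand-in for "nothing changed". This part of the diff does not show whether, or how, Spark itself short-circuits unchanged children.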
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index 4b5e278..5bda94c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -31,6 +31,9 @@ case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan
 
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHint =
+    copy(child = newChild)
 }
 
 /**
@@ -43,6 +46,9 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
   override def output: Seq[Attribute] = child.output
 
   override def doCanonicalize(): LogicalPlan = child.canonicalized
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): ResolvedHint =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index d383532..6d61a86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -79,7 +79,10 @@ trait ObjectConsumer extends UnaryNode {
 case class DeserializeToObject(
     deserializer: Expression,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
+    child: LogicalPlan) extends UnaryNode with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DeserializeToObject =
+    copy(child = newChild)
+}
 
 /**
  * Takes the input object from child and turns it into unsafe row using the given serializer
@@ -90,6 +93,9 @@ case class SerializeFromObject(
     child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): SerializeFromObject =
+    copy(child = newChild)
 }
 
 object MapPartitions {
@@ -111,7 +117,10 @@ object MapPartitions {
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitions =
+    copy(child = newChild)
+}
 
 object MapPartitionsInR {
   def apply(
@@ -159,6 +168,9 @@ case class MapPartitionsInR(
 
   override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
     outputObjAttr, child)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitionsInR =
+    copy(child = newChild)
 }
 
 /**
@@ -182,6 +194,9 @@ case class MapPartitionsInRWithArrow(
     inputSchema, StructType.fromAttributes(output), child)
 
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapPartitionsInRWithArrow =
+    copy(child = newChild)
 }
 
 object MapElements {
@@ -207,7 +222,10 @@ case class MapElements(
     argumentClass: Class[_],
     argumentSchema: StructType,
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
+    child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapElements =
+    copy(child = newChild)
+}
 
 object TypedFilter {
   def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
@@ -251,6 +269,9 @@ case class TypedFilter(
     val funcObj = Literal.create(func, ObjectType(funcMethod._1))
     Invoke(funcObj, funcMethod._2, BooleanType, input :: Nil)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): TypedFilter =
+    copy(child = newChild)
 }
 
 object FunctionUtils {
@@ -334,6 +355,9 @@ case class AppendColumns(
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumns =
+    copy(child = newChild)
 }
 
 /**
@@ -346,6 +370,9 @@ case class AppendColumnsWithObject(
     child: LogicalPlan) extends ObjectConsumer {
 
   override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumnsWithObject =
+    copy(child = newChild)
 }
 
 /** Factory for constructing new `MapGroups` nodes. */
@@ -382,7 +409,10 @@ case class MapGroups(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
+    child: LogicalPlan) extends UnaryNode with ObjectProducer {
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapGroups =
+    copy(child = newChild)
+}
 
 /** Internal class representing State */
 trait LogicalGroupState[S]
@@ -453,6 +483,9 @@ case class FlatMapGroupsWithState(
   if (isMapGroupsWithState) {
     assert(outputMode == OutputMode.Update)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState =
+    copy(child = newChild)
 }
 
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
@@ -513,6 +546,9 @@ case class FlatMapGroupsInR(
   override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema,
     keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr,
     child)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInR =
+    copy(child = newChild)
 }
 
 /**
@@ -537,6 +573,9 @@ case class FlatMapGroupsInRWithArrow(
     inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child)
 
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInRWithArrow =
+    copy(child = newChild)
 }
 
 /** Factory for constructing new `CoGroup` nodes. */
@@ -584,4 +623,7 @@ case class CoGroup(
     rightAttr: Seq[Attribute],
     outputObjAttr: Attribute,
     left: LogicalPlan,
-    right: LogicalPlan) extends BinaryNode with ObjectProducer
+    right: LogicalPlan) extends BinaryNode with ObjectProducer {
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): CoGroup = copy(left = newLeft, right = newRight)
+}
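Binary nodes such as `CoGroup` here, and `Intersect`, `Except` and `Join` earlier, receive a two-argument `withNewChildrenInternal(newLeft, newRight)` instead. The hypothetical sketch below uses toy classes rather than Spark's. It shows one way a generic `withNewChildren(Seq)` entry point could dispatch to such a hook. It also shows how a node whose children are stored under other field names maps `left`/`right` back onto its own constructor fields, in the style of `MergeIntoTable`'s `targetTable`/`sourceTable` later in this diff.

```scala
// Hypothetical sketch; these are not Spark's classes.
sealed trait Plan {
  def children: Seq[Plan]
  def withNewChildren(newChildren: Seq[Plan]): Plan
}

case class Rel(name: String) extends Plan {
  def children: Seq[Plan] = Nil
  def withNewChildren(newChildren: Seq[Plan]): Plan = {
    require(newChildren.isEmpty, "a leaf has no children")
    this
  }
}

trait BinaryPlan extends Plan {
  def left: Plan
  def right: Plan
  def children: Seq[Plan] = Seq(left, right)
  // Two explicit arguments: no reflection and no generic loop over fields.
  protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): Plan

  def withNewChildren(newChildren: Seq[Plan]): Plan = {
    require(newChildren.length == 2, s"expected 2 children, got ${newChildren.length}")
    val newLeft = newChildren(0)
    val newRight = newChildren(1)
    // Reuse this instance when both children are the same objects as before.
    if ((newLeft eq left) && (newRight eq right)) this
    else withNewChildrenInternal(newLeft, newRight)
  }
}

// Children stored directly as `left` / `right` (cf. CoGroup, Join).
case class CoGroupLike(key: String, left: Plan, right: Plan) extends BinaryPlan {
  override protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): CoGroupLike =
    copy(left = newLeft, right = newRight)
}

// Children stored under other names (cf. MergeIntoTable): the hook maps
// left/right back onto the real constructor fields.
case class MergeLike(targetTable: Plan, sourceTable: Plan) extends BinaryPlan {
  def left: Plan = targetTable
  def right: Plan = sourceTable
  override protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): MergeLike =
    copy(targetTable = newLeft, sourceTable = newRight)
}

val cg = CoGroupLike("k", Rel("a"), Rel("b"))
println(cg.withNewChildren(Seq(Rel("a2"), cg.right))) // CoGroupLike(k,Rel(a2),Rel(b))
println(cg.withNewChildren(cg.children) eq cg)       // true

val merge = MergeLike(Rel("target"), Rel("source"))
println(merge.withNewChildren(Seq(Rel("t2"), Rel("s2")))) // MergeLike(Rel(t2),Rel(s2))
```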
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 62f2d59..ba8352cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -37,6 +37,9 @@ case class FlatMapGroupsInPandas(
    * from the input.
    */
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInPandas =
+    copy(child = newChild)
 }
 
 /**
@@ -49,6 +52,9 @@ case class MapInPandas(
     child: LogicalPlan) extends UnaryNode {
 
   override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): MapInPandas =
+    copy(child = newChild)
 }
 
 /**
@@ -70,6 +76,10 @@ case class FlatMapCoGroupsInPandas(
   def leftAttributes: Seq[Attribute] = left.output.take(leftGroupingLen)
 
   def rightAttributes: Seq[Attribute] = right.output.take(rightGroupingLen)
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapCoGroupsInPandas =
+    copy(left = newLeft, right = newRight)
 }
 
 trait BaseEvalPython extends UnaryNode {
@@ -89,7 +99,10 @@ trait BaseEvalPython extends UnaryNode {
 case class BatchEvalPython(
     udfs: Seq[PythonUDF],
     resultAttrs: Seq[Attribute],
-    child: LogicalPlan) extends BaseEvalPython
+    child: LogicalPlan) extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): BatchEvalPython =
+    copy(child = newChild)
+}
 
 /**
  * A logical plan that evaluates a [[PythonUDF]] with Apache Arrow.
@@ -98,4 +111,7 @@ case class ArrowEvalPython(
     udfs: Seq[PythonUDF],
     resultAttrs: Seq[Attribute],
     child: LogicalPlan,
-    evalType: Int) extends BaseEvalPython
+    evalType: Int) extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): ArrowEvalPython =
+    copy(child = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index d600c15..44550ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -167,6 +167,8 @@ case class CreateTableAsSelectStatement(
     ifNotExists: Boolean) extends UnaryParsedStatement {
 
   override def child: LogicalPlan = asSelect
+  override protected def withNewChildInternal(newChild: LogicalPlan): CreateTableAsSelectStatement =
+    copy(asSelect = newChild)
 }
 
 /**
@@ -181,7 +183,10 @@ case class CreateViewStatement(
     child: LogicalPlan,
     allowExisting: Boolean,
     replace: Boolean,
-    viewType: ViewType) extends UnaryParsedStatement
+    viewType: ViewType) extends UnaryParsedStatement {
+  override protected def withNewChildInternal(newChild: LogicalPlan): CreateViewStatement =
+    copy(child = newChild)
+}
 
 /**
  * A REPLACE TABLE command, as parsed from SQL.
@@ -220,6 +225,8 @@ case class ReplaceTableAsSelectStatement(
     orCreate: Boolean) extends UnaryParsedStatement {
 
   override def child: LogicalPlan = asSelect
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): ReplaceTableAsSelectStatement = copy(asSelect = newChild)
 }
 
 
@@ -300,6 +307,8 @@ case class InsertIntoStatement(
     "IF NOT EXISTS is only valid with static partitions")
 
   override def child: LogicalPlan = query
+  override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoStatement =
+    copy(query = newChild)
 }
 
 /**
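Nodes whose children consist of an optional element followed by a sequence need care about ordering. In the `MergeAction` subclasses in the `v2Commands.scala` changes that follow, `children` is `condition.toSeq ++ assignments`. When `condition` is `None`, the first element of `newChildren` is therefore already an assignment, and rebuilding the assignments from `newChildren.tail` would silently drop it. The simplified sketch below takes the trailing `assignments.length` elements instead, which is correct with or without a condition. `Expr`, `Col`, `Assign` and `UpdateLike` are hypothetical stand-ins, not Spark types.

```scala
// Hypothetical stand-ins for expression nodes; these are not Spark classes.
sealed trait Expr
case class Col(name: String) extends Expr
case class Assign(key: String, value: Expr) extends Expr

// Same child layout as UpdateAction: an optional condition, then assignments.
case class UpdateLike(condition: Option[Expr], assignments: Seq[Assign]) {
  // Flattened child order: the condition (if present) comes first.
  def children: Seq[Expr] = condition.toList ++ assignments

  def withNewChildren(newChildren: IndexedSeq[Expr]): UpdateLike = {
    require(newChildren.length == children.length, "the number of children must not change")
    copy(
      condition = if (condition.isDefined) Some(newChildren.head) else None,
      // The assignments are always the trailing elements, whether or not a
      // condition is present; `newChildren.tail` would be wrong without one.
      assignments = newChildren.takeRight(assignments.length).map(_.asInstanceOf[Assign]))
  }
}

val noCond = UpdateLike(None, Seq(Assign("a", Col("x")), Assign("b", Col("y"))))
val rebuilt = noCond.withNewChildren(noCond.children.toIndexedSeq)
println(rebuilt.assignments.length) // 2

val withCond = UpdateLike(Some(Col("c")), Seq(Assign("a", Col("x"))))
val renamed = withCond.withNewChildren(IndexedSeq(Col("c2"), Assign("a", Col("x2"))))
println(renamed.condition)        // Some(Col(c2))
println(renamed.assignments.head) // Assign(a,Col(x2))
```

Casting each element with `asInstanceOf[Assign]` fails fast on a wrongly typed child. By contrast, casting the whole collection to `Seq[Assignment]` goes unchecked because of type erasure.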
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 889509d..8b7f2db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -77,6 +77,8 @@ case class AppendData(
     write: Option[Write] = None) extends V2WriteCommand {
   override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
   override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
+  override protected def withNewChildInternal(newChild: LogicalPlan): AppendData =
+    copy(query = newChild)
 }
 
 object AppendData {
@@ -115,6 +117,9 @@ case class OverwriteByExpression(
   override def withNewTable(newTable: NamedRelation): OverwriteByExpression = {
     copy(table = newTable)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): OverwriteByExpression =
+    copy(query = newChild)
 }
 
 object OverwriteByExpression {
@@ -150,6 +155,9 @@ case class OverwritePartitionsDynamic(
   override def withNewTable(newTable: NamedRelation): OverwritePartitionsDynamic = {
     copy(table = newTable)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): OverwritePartitionsDynamic =
+    copy(query = newChild)
 }
 
 object OverwritePartitionsDynamic {
@@ -222,6 +230,9 @@ case class CreateTableAsSelect(
   override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
     this.copy(partitioning = rewritten)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): CreateTableAsSelect =
+    copy(query = newChild)
 }
 
 /**
@@ -272,6 +283,9 @@ case class ReplaceTableAsSelect(
   override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
     this.copy(partitioning = rewritten)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceTableAsSelect =
+    copy(query = newChild)
 }
 
 /**
@@ -291,6 +305,8 @@ case class DropNamespace(
     ifExists: Boolean,
     cascade: Boolean) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropNamespace =
+    copy(namespace = newChild)
 }
 
 /**
@@ -301,6 +317,8 @@ case class DescribeNamespace(
     extended: Boolean,
     override val output: Seq[Attribute] = DescribeNamespace.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeNamespace =
+    copy(namespace = newChild)
 }
 
 object DescribeNamespace {
@@ -319,6 +337,8 @@ case class SetNamespaceProperties(
     namespace: LogicalPlan,
     properties: Map[String, String]) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetNamespaceProperties =
+    copy(namespace = newChild)
 }
 
 /**
@@ -328,6 +348,8 @@ case class SetNamespaceLocation(
     namespace: LogicalPlan,
     location: String) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetNamespaceLocation =
+    copy(namespace = newChild)
 }
 
 /**
@@ -338,6 +360,8 @@ case class ShowNamespaces(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowNamespaces.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowNamespaces =
+    copy(namespace = newChild)
 }
 
 object ShowNamespaces {
@@ -355,6 +379,8 @@ case class DescribeRelation(
     isExtended: Boolean,
     override val output: Seq[Attribute] = DescribeRelation.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = relation
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeRelation =
+    copy(relation = newChild)
 }
 
 object DescribeRelation {
@@ -370,6 +396,8 @@ case class DescribeColumn(
     isExtended: Boolean,
     override val output: Seq[Attribute] = DescribeColumn.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = relation
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeColumn =
+    copy(relation = newChild)
 }
 
 object DescribeColumn {
@@ -383,6 +411,8 @@ case class DeleteFromTable(
     table: LogicalPlan,
     condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): DeleteFromTable =
+    copy(table = newChild)
 }
 
 /**
@@ -393,6 +423,8 @@ case class UpdateTable(
     assignments: Seq[Assignment],
     condition: Option[Expression]) extends UnaryCommand with SupportsSubquery {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
+    copy(table = newChild)
 }
 
 /**
@@ -407,6 +439,9 @@ case class MergeIntoTable(
   def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
   override def left: LogicalPlan = targetTable
   override def right: LogicalPlan = sourceTable
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): MergeIntoTable =
+    copy(targetTable = newLeft, sourceTable = newRight)
 }
 
 sealed abstract class MergeAction extends Expression with Unevaluable {
@@ -416,28 +451,49 @@ sealed abstract class MergeAction extends Expression with Unevaluable {
   override def children: Seq[Expression] = condition.toSeq
 }
 
-case class DeleteAction(condition: Option[Expression]) extends MergeAction
+case class DeleteAction(condition: Option[Expression]) extends MergeAction {
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): DeleteAction =
+    copy(condition = if (condition.isDefined) Some(newChildren(0)) else None)
+}
 
 case class UpdateAction(
     condition: Option[Expression],
     assignments: Seq[Assignment]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq ++ assignments
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UpdateAction =
+    copy(
+      condition = if (condition.isDefined) Some(newChildren.head) else None,
+      assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
 }
 
 case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq
   override lazy val resolved = false
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UpdateStarAction =
+    copy(condition = if (condition.isDefined) Some(newChildren(0)) else None)
 }
 
 case class InsertAction(
     condition: Option[Expression],
     assignments: Seq[Assignment]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq ++ assignments
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): InsertAction =
+    copy(
+      condition = if (condition.isDefined) Some(newChildren.head) else None,
+      assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
 }
 
 case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
   override def children: Seq[Expression] = condition.toSeq
   override lazy val resolved = false
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): InsertStarAction =
+    copy(condition = if (condition.isDefined) Some(newChildren(0)) else None)
 }
 
 case class Assignment(key: Expression, value: Expression) extends Expression
@@ -446,6 +502,8 @@ case class Assignment(key: Expression, value: Expression) extends Expression
   override def dataType: DataType = throw new UnresolvedException("nullable")
   override def left: Expression = key
   override def right: Expression = value
+  override protected def withNewChildrenInternal(
+    newLeft: Expression, newRight: Expression): Assignment = copy(key = newLeft, value = newRight)
 }
 
 /**
@@ -462,7 +520,10 @@ case class Assignment(key: Expression, value: Expression) extends Expression
 case class DropTable(
     child: LogicalPlan,
     ifExists: Boolean,
-    purge: Boolean) extends UnaryCommand
+    purge: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan for no-op command handling non-existing table.
@@ -509,7 +570,10 @@ case class AlterTable(
 case class RenameTable(
     child: LogicalPlan,
     newName: Seq[String],
-    isView: Boolean) extends UnaryCommand
+    isView: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RenameTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the SHOW TABLES command.
@@ -519,6 +583,8 @@ case class ShowTables(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowTables.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowTables =
+    copy(namespace = newChild)
 }
 
 object ShowTables {
@@ -537,6 +603,8 @@ case class ShowTableExtended(
     partitionSpec: Option[PartitionSpec],
     override val output: Seq[Attribute] = ShowTableExtended.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowTableExtended =
+    copy(namespace = newChild)
 }
 
 object ShowTableExtended {
@@ -558,6 +626,8 @@ case class ShowViews(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowViews.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowViews =
+    copy(namespace = newChild)
 }
 
 object ShowViews {
@@ -578,7 +648,10 @@ case class SetCatalogAndNamespace(
 /**
  * The logical plan of the REFRESH TABLE command.
  */
-case class RefreshTable(child: LogicalPlan) extends UnaryCommand
+case class RefreshTable(child: LogicalPlan) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RefreshTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the SHOW CURRENT NAMESPACE command.
@@ -597,6 +670,8 @@ case class ShowTableProperties(
     propertyKey: Option[String],
     override val output: Seq[Attribute] = ShowTableProperties.getOutputAttrs) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowTableProperties =
+    copy(table = newChild)
 }
 
 object ShowTableProperties {
@@ -615,7 +690,10 @@ object ShowTableProperties {
  * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment.
  *
  */
-case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand
+case class CommentOnNamespace(child: LogicalPlan, comment: String) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): CommentOnNamespace =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan that defines or changes the comment of an TABLE for v2 catalogs.
@@ -627,17 +705,26 @@ case class CommentOnNamespace(child: LogicalPlan, comment: String) extends Unary
  * where the `text` is the new comment written as a string literal; or `NULL` to drop the comment.
  *
  */
-case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand
+case class CommentOnTable(child: LogicalPlan, comment: String) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): CommentOnTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the REFRESH FUNCTION command.
  */
-case class RefreshFunction(child: LogicalPlan) extends UnaryCommand
+case class RefreshFunction(child: LogicalPlan) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RefreshFunction =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the DESCRIBE FUNCTION command.
  */
-case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand
+case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DescribeFunction =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the DROP FUNCTION command.
@@ -645,7 +732,10 @@ case class DescribeFunction(child: LogicalPlan, isExtended: Boolean) extends Una
 case class DropFunction(
     child: LogicalPlan,
     ifExists: Boolean,
-    isTemp: Boolean) extends UnaryCommand
+    isTemp: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropFunction =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the SHOW FUNCTIONS command.
@@ -657,6 +747,9 @@ case class ShowFunctions(
     pattern: Option[String],
     override val output: Seq[Attribute] = ShowFunctions.getOutputAttrs) extends Command {
   override def children: Seq[LogicalPlan] = child.toSeq
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): ShowFunctions =
+    copy(child = if (child.isDefined) Some(newChildren.head) else None)
 }
 
 object ShowFunctions {
@@ -671,7 +764,10 @@ object ShowFunctions {
 case class AnalyzeTable(
     child: LogicalPlan,
     partitionSpec: Map[String, Option[String]],
-    noScan: Boolean) extends UnaryCommand
+    noScan: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the ANALYZE TABLES command.
@@ -680,6 +776,8 @@ case class AnalyzeTables(
     namespace: LogicalPlan,
     noScan: Boolean) extends UnaryCommand {
   override def child: LogicalPlan = namespace
+  override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeTables =
+    copy(namespace = newChild)
 }
 
 /**
@@ -691,6 +789,9 @@ case class AnalyzeColumn(
     allColumns: Boolean) extends UnaryCommand {
   require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " +
     "mutually exclusive. Only one of them should be specified.")
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): AnalyzeColumn =
+    copy(child = newChild)
 }
 
 /**
@@ -705,7 +806,10 @@ case class AnalyzeColumn(
 case class AddPartitions(
     table: LogicalPlan,
     parts: Seq[PartitionSpec],
-    ifNotExists: Boolean) extends V2PartitionCommand
+    ifNotExists: Boolean) extends V2PartitionCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): AddPartitions =
+    copy(table = newChild)
+}
 
 /**
  * The logical plan of the ALTER TABLE DROP PARTITION command.
@@ -723,7 +827,10 @@ case class DropPartitions(
     table: LogicalPlan,
     parts: Seq[PartitionSpec],
     ifExists: Boolean,
-    purge: Boolean) extends V2PartitionCommand
+    purge: Boolean) extends V2PartitionCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropPartitions =
+    copy(table = newChild)
+}
 
 /**
  * The logical plan of the ALTER TABLE ... RENAME TO PARTITION command.
@@ -731,12 +838,18 @@ case class DropPartitions(
 case class RenamePartitions(
     table: LogicalPlan,
     from: PartitionSpec,
-    to: PartitionSpec) extends V2PartitionCommand
+    to: PartitionSpec) extends V2PartitionCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RenamePartitions =
+    copy(table = newChild)
+}
 
 /**
  * The logical plan of the ALTER TABLE ... RECOVER PARTITIONS command.
  */
-case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand
+case class RecoverPartitions(child: LogicalPlan) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RecoverPartitions =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the LOAD DATA INTO TABLE command.
@@ -746,7 +859,10 @@ case class LoadData(
     path: String,
     isLocal: Boolean,
     isOverwrite: Boolean,
-    partition: Option[TablePartitionSpec]) extends UnaryCommand
+    partition: Option[TablePartitionSpec]) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): LoadData =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the SHOW CREATE TABLE command.
@@ -754,7 +870,10 @@ case class LoadData(
 case class ShowCreateTable(
     child: LogicalPlan,
     asSerde: Boolean = false,
-    override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand
+    override val output: Seq[Attribute] = ShowCreateTable.getoutputAttrs) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowCreateTable =
+    copy(child = newChild)
+}
 
 object ShowCreateTable {
   def getoutputAttrs: Seq[Attribute] = {
@@ -768,7 +887,10 @@ object ShowCreateTable {
 case class ShowColumns(
     child: LogicalPlan,
     namespace: Option[Seq[String]],
-    override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand
+    override val output: Seq[Attribute] = ShowColumns.getOutputAttrs) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowColumns =
+    copy(child = newChild)
+}
 
 object ShowColumns {
   def getOutputAttrs: Seq[Attribute] = {
@@ -781,6 +903,8 @@ object ShowColumns {
  */
 case class TruncateTable(table: LogicalPlan) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): TruncateTable =
+    copy(table = newChild)
 }
 
 /**
@@ -790,6 +914,8 @@ case class TruncatePartition(
     table: LogicalPlan,
     partitionSpec: PartitionSpec) extends V2PartitionCommand {
   override def allowPartialPartitionSpec: Boolean = true
+  override protected def withNewChildInternal(newChild: LogicalPlan): TruncatePartition =
+    copy(table = newChild)
 }
 
 /**
@@ -801,6 +927,8 @@ case class ShowPartitions(
     override val output: Seq[Attribute] = ShowPartitions.getOutputAttrs)
   extends V2PartitionCommand {
   override def allowPartialPartitionSpec: Boolean = true
+  override protected def withNewChildInternal(newChild: LogicalPlan): ShowPartitions =
+    copy(table = newChild)
 }
 
 object ShowPartitions {
@@ -814,7 +942,10 @@ object ShowPartitions {
  */
 case class DropView(
     child: LogicalPlan,
-    ifExists: Boolean) extends UnaryCommand
+    ifExists: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): DropView =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the MSCK REPAIR TABLE command.
@@ -822,7 +953,10 @@ case class DropView(
 case class RepairTable(
     child: LogicalPlan,
     enableAddPartitions: Boolean,
-    enableDropPartitions: Boolean) extends UnaryCommand
+    enableDropPartitions: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): RepairTable =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the ALTER VIEW ... AS command.
@@ -833,6 +967,9 @@ case class AlterViewAs(
     query: LogicalPlan) extends BinaryCommand {
   override def left: LogicalPlan = child
   override def right: LogicalPlan = query
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan =
+    copy(child = newLeft, query = newRight)
 }
 
 /**
@@ -840,7 +977,10 @@ case class AlterViewAs(
  */
 case class SetViewProperties(
     child: LogicalPlan,
-    properties: Map[String, String]) extends UnaryCommand
+    properties: Map[String, String]) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetViewProperties =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the ALTER VIEW ... UNSET TBLPROPERTIES command.
@@ -848,7 +988,10 @@ case class SetViewProperties(
 case class UnsetViewProperties(
     child: LogicalPlan,
     propertyKeys: Seq[String],
-    ifExists: Boolean) extends UnaryCommand
+    ifExists: Boolean) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): UnsetViewProperties =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the ALTER TABLE ... SET [SERDE|SERDEPROPERTIES] command.
@@ -857,7 +1000,10 @@ case class SetTableSerDeProperties(
     child: LogicalPlan,
     serdeClassName: Option[String],
     serdeProperties: Option[Map[String, String]],
-    partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand
+    partitionSpec: Option[TablePartitionSpec]) extends UnaryCommand {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetTableSerDeProperties =
+    copy(child = newChild)
+}
 
 /**
  * The logical plan of the CACHE TABLE command.
@@ -894,6 +1040,8 @@ case class SetTableLocation(
     partitionSpec: Option[TablePartitionSpec],
     location: String) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): SetTableLocation =
+    copy(table = newChild)
 }
 
 /**
@@ -903,6 +1051,8 @@ case class SetTableProperties(
     table: LogicalPlan,
     properties: Map[String, String]) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
 }
 
 /**
@@ -913,4 +1063,6 @@ case class UnsetTableProperties(
     propertyKeys: Seq[String],
     ifExists: Boolean) extends UnaryCommand {
   override def child: LogicalPlan = table
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(table = newChild)
 }
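Every command above follows the same recipe. A UnaryCommand overrides withNewChildInternal and rebuilds itself with copy. The field passed to copy is whichever constructor argument `child` aliases: child, table, namespace or inputQuery.

The block below is a simplified model of that contract. Plan, UnaryPlan, DropThing and TruncateThing are hypothetical stand-ins for Spark's classes. Origin and tag handling are omitted, and the unchanged-child check uses plain reference equality rather than fastEquals.

```scala
// Hypothetical, simplified model of the UnaryLike / withNewChildInternal contract.
// None of these names are Spark classes.
trait Plan {
  def children: IndexedSeq[Plan]
}

case class Relation(name: String) extends Plan {
  def children: IndexedSeq[Plan] = IndexedSeq.empty
}

trait UnaryPlan extends Plan {
  def child: Plan
  final def children: IndexedSeq[Plan] = IndexedSeq(child)

  // Each concrete node states how to rebuild itself around a new child,
  // so no reflection over constructor arguments is needed.
  protected def withNewChildInternal(newChild: Plan): Plan

  // Simplification: skip the copy only when the child is the very same object.
  final def withNewChild(newChild: Plan): Plan =
    if (newChild eq child) this else withNewChildInternal(newChild)
}

// Same shape as DropFunction(child, ifExists, isTemp) in the diff.
case class DropThing(child: Plan, ifExists: Boolean, isTemp: Boolean) extends UnaryPlan {
  override protected def withNewChildInternal(newChild: Plan): DropThing =
    copy(child = newChild)
}

// Same shape as TruncateTable(table): `child` is an alias, so `table` is what gets copied.
case class TruncateThing(table: Plan) extends UnaryPlan {
  override def child: Plan = table
  override protected def withNewChildInternal(newChild: Plan): TruncateThing =
    copy(table = newChild)
}
```

The copy call keeps every non-child field, such as ifExists and isTemp, without having to enumerate them.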
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index c4002aa..0f8c788 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -235,6 +235,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
    * than numPartitions) based on hashing expressions.
    */
   def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
 }
 
 /**
@@ -284,6 +287,10 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): RangePartitioning =
+    copy(ordering = newChildren.asInstanceOf[Seq[SortOrder]])
 }
 
 /**
@@ -326,6 +333,10 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): PartitioningCollection =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection]
 }
 
 /**
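RangePartitioning rebuilds its ordering with `newChildren.asInstanceOf[Seq[SortOrder]]`. Because of JVM type erasure, that cast is unchecked and succeeds for any sequence. A non-SortOrder element would only surface as a ClassCastException later, when it is read as a SortOrder.

The sketch below demonstrates this behavior with hypothetical stand-ins: Expr, Col, Sort and RangeLike. It illustrates the Scala language semantics, not a bug observed in Spark.

```scala
// Hypothetical stand-ins for Expression, AttributeReference and SortOrder.
trait Expr
case class Col(name: String) extends Expr
case class Sort(child: Expr, ascending: Boolean) extends Expr

case class RangeLike(ordering: Seq[Sort]) {
  def children: IndexedSeq[Expr] = ordering.toIndexedSeq

  // Same shape as RangePartitioning.withNewChildrenInternal. The element type is
  // erased at runtime, so this cast only checks that the value is some Seq.
  def withNewChildren(newChildren: IndexedSeq[Expr]): RangeLike =
    copy(ordering = newChildren.asInstanceOf[Seq[Sort]])
}
```

In practice the children handed back by a transformation of SortOrder nodes are still SortOrders, which is presumably why the cheap unchecked cast is acceptable here.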
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala
index 990ae30..2a29137 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStream.scala
@@ -39,5 +39,7 @@ case class WriteToStream(
 
   override def child: LogicalPlan = inputQuery
 
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteToStream =
+    copy(inputQuery = newChild)
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
index 34a4c13..407c70a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
@@ -57,5 +57,8 @@ case class WriteToStreamStatement(
   override def output: Seq[Attribute] = Nil
 
   override def child: LogicalPlan = inputQuery
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteToStreamStatement =
+    copy(inputQuery = newChild)
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 8fc6238..3fab95c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -246,11 +246,50 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     arr
   }
 
+  private def childrenFastEquals(
+      originalChildren: IndexedSeq[BaseType], newChildren: IndexedSeq[BaseType]): Boolean = {
+    val size = originalChildren.size
+    var i = 0
+    while (i < size) {
+      if (!originalChildren(i).fastEquals(newChildren(i))) return false
+      i += 1
+    }
+    true
+  }
+
+  // This is a temporary solution, we will change the type of children to IndexedSeq in a
+  // followup PR
+  private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = {
+    if (seq.isInstanceOf[IndexedSeq[BaseType]]) {
+      seq.asInstanceOf[IndexedSeq[BaseType]]
+    } else {
+      seq.toIndexedSeq
+    }
+  }
+
+  final def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
+    val childrenIndexedSeq = asIndexedSeq(children)
+    val newChildrenIndexedSeq = asIndexedSeq(newChildren)
+    assert(newChildrenIndexedSeq.size == childrenIndexedSeq.size, "Incorrect number of children")
+    if (childrenIndexedSeq.isEmpty ||
+        childrenFastEquals(newChildrenIndexedSeq, childrenIndexedSeq)) {
+      this
+    } else {
+      CurrentOrigin.withOrigin(origin) {
+        val res = withNewChildrenInternal(newChildrenIndexedSeq)
+        res.copyTagsFrom(this)
+        res
+      }
+    }
+  }
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[BaseType]): BaseType
+
   /**
    * Returns a copy of this node with the children replaced.
    * TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
    */
-  def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
+  protected final def legacyWithNewChildren(newChildren: Seq[BaseType]): BaseType = {
     assert(newChildren.size == children.size, "Incorrect number of children")
     var changed = false
     val remainingNewChildren = newChildren.toBuffer
@@ -355,7 +394,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   def mapChildren(f: BaseType => BaseType): BaseType = {
     if (containsChild.nonEmpty) {
-      mapChildren(f, forceCopy = false)
+      withNewChildren(children.map(f))
     } else {
       this
     }
@@ -844,24 +883,96 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
 trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   override final def children: Seq[T] = Nil
+  override final def mapChildren(f: T => T): T = this.asInstanceOf[T]
+  override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = this.asInstanceOf[T]
 }
 
 trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   def child: T
-  @transient override final lazy val children: Seq[T] = child :: Nil
+  @transient override final lazy val children: Seq[T] = IndexedSeq(child)
+
+  override final def mapChildren(f: T => T): T = {
+    val newChild = f(child)
+    if (newChild fastEquals child) {
+      this.asInstanceOf[T]
+    } else {
+      CurrentOrigin.withOrigin(origin) {
+        val res = withNewChildInternal(newChild)
+        res.copyTagsFrom(this.asInstanceOf[T])
+        res
+      }
+    }
+  }
+
+  override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = {
+    assert(newChildren.size == 1, "Incorrect number of children")
+    withNewChildInternal(newChildren.head)
+  }
+
+  protected def withNewChildInternal(newChild: T): T
 }
 
 trait BinaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   def left: T
   def right: T
-  @transient override final lazy val children: Seq[T] = left :: right :: Nil
+  @transient override final lazy val children: Seq[T] = IndexedSeq(left, right)
+
+  override final def mapChildren(f: T => T): T = {
+    var newLeft = f(left)
+    newLeft = if (newLeft fastEquals left) left else newLeft
+    var newRight = f(right)
+    newRight = if (newRight fastEquals right) right else newRight
+
+    if (newLeft.eq(left) && newRight.eq(right)) {
+      this.asInstanceOf[T]
+    } else {
+      CurrentOrigin.withOrigin(origin) {
+        val res = withNewChildrenInternal(newLeft, newRight)
+        res.copyTagsFrom(this.asInstanceOf[T])
+        res
+      }
+    }
+  }
+
+  override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = {
+    assert(newChildren.size == 2, "Incorrect number of children")
+    withNewChildrenInternal(newChildren(0), newChildren(1))
+  }
+
+  protected def withNewChildrenInternal(newLeft: T, newRight: T): T
 }
 
 trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   def first: T
   def second: T
   def third: T
-  @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
+  @transient override final lazy val children: Seq[T] = IndexedSeq(first, second, third)
+
+  override final def mapChildren(f: T => T): T = {
+    var newFirst = f(first)
+    newFirst = if (newFirst fastEquals first) first else newFirst
+    var newSecond = f(second)
+    newSecond = if (newSecond fastEquals second) second else newSecond
+    var newThird = f(third)
+    newThird = if (newThird fastEquals third) third else newThird
+
+    if (newFirst.eq(first) && newSecond.eq(second) && newThird.eq(third)) {
+      this.asInstanceOf[T]
+    } else {
+      CurrentOrigin.withOrigin(origin) {
+        val res = withNewChildrenInternal(newFirst, newSecond, newThird)
+        res.copyTagsFrom(this.asInstanceOf[T])
+        res
+      }
+    }
+  }
+
+  override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = {
+    assert(newChildren.size == 3, "Incorrect number of children")
+    withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2))
+  }
+
+  protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T): T
 }
 
 trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
@@ -869,5 +980,33 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   def second: T
   def third: T
   def fourth: T
-  @transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
+  @transient override final lazy val children: Seq[T] = IndexedSeq(first, second, third, fourth)
+
+  override final def mapChildren(f: T => T): T = {
+    var newFirst = f(first)
+    newFirst = if (newFirst fastEquals first) first else newFirst
+    var newSecond = f(second)
+    newSecond = if (newSecond fastEquals second) second else newSecond
+    var newThird = f(third)
+    newThird = if (newThird fastEquals third) third else newThird
+    var newFourth = f(fourth)
+    newFourth = if (newFourth fastEquals fourth) fourth else newFourth
+
+    if (newFirst.eq(first) && newSecond.eq(second) && newThird.eq(third) && newFourth.eq(fourth)) {
+      this.asInstanceOf[T]
+    } else {
+      CurrentOrigin.withOrigin(origin) {
+        val res = withNewChildrenInternal(newFirst, newSecond, newThird, newFourth)
+        res.copyTagsFrom(this.asInstanceOf[T])
+        res
+      }
+    }
+  }
+
+  override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T = {
+    assert(newChildren.size == 4, "Incorrect number of children")
+    withNewChildrenInternal(newChildren(0), newChildren(1), newChildren(2), newChildren(3))
+  }
+
+  protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: T, newFourth: T): T
 }
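The core of the change is in TreeNode. withNewChildren now does the following:

- converts both child lists to IndexedSeq;
- compares them element by element with fastEquals in a plain while loop;
- calls the node-specific withNewChildrenInternal only when something actually changed.

The block below is a minimal, self-contained sketch of that control flow, not Spark's implementation. Node, Leaf, Add and Call are hypothetical, fastEquals is modeled as reference-or-structural equality, and the CurrentOrigin/copyTagsFrom bookkeeping is left out.

```scala
// Hypothetical toy model of the fast paths added to TreeNode; not Spark's code.
abstract class Node {
  def children: Seq[Node]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node

  // Stand-in for TreeNode.fastEquals: reference equality first, then structural.
  def fastEquals(other: Node): Boolean = (this eq other) || this == other

  // Vector and other indexed sequences pass through; a List is copied once.
  private def asIndexedSeq(seq: Seq[Node]): IndexedSeq[Node] = seq match {
    case is: IndexedSeq[Node] => is
    case other => other.toIndexedSeq
  }

  final def withNewChildren(newChildren: Seq[Node]): Node = {
    val oldKids = asIndexedSeq(children)
    val newKids = asIndexedSeq(newChildren)
    require(newKids.size == oldKids.size, "Incorrect number of children")
    // Indexed while loop: no closures or iterators allocated per call.
    var i = 0
    var same = true
    while (same && i < oldKids.size) {
      same = oldKids(i).fastEquals(newKids(i))
      i += 1
    }
    if (same) this else withNewChildrenInternal(newKids)
  }

  def mapChildren(f: Node => Node): Node = withNewChildren(children.map(f))
}

case class Leaf(v: Int) extends Node {
  def children: Seq[Node] = Nil
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node = this
}

case class Add(left: Node, right: Node) extends Node {
  def children: Seq[Node] = IndexedSeq(left, right)
  protected def withNewChildrenInternal(c: IndexedSeq[Node]): Node =
    copy(left = c(0), right = c(1))
}

// Children backed by a List, to exercise the toIndexedSeq branch.
case class Call(args: List[Node]) extends Node {
  def children: Seq[Node] = args
  protected def withNewChildrenInternal(c: IndexedSeq[Node]): Node = copy(args = c.toList)
}
```

Returning `this` for unchanged children, even when they are structurally equal copies, avoids allocating a new node and lets later passes keep using reference equality as a cheap first check.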
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index a9d9acd..aecbf24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -88,6 +88,8 @@ case class TestFunction(
   extends Expression with ImplicitCastInputTypes with Unevaluable {
   override def nullable: Boolean = true
   override def dataType: DataType = StringType
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 case class UnresolvedTestPlan() extends LeafNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index a6145c5..9058e3e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1623,12 +1623,16 @@ object TypeCoercionSuite {
     extends UnaryExpression with ExpectsInputTypes with Unevaluable {
     override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
     override def dataType: DataType = NullType
+    override protected def withNewChildInternal(newChild: Expression): AnyTypeUnaryExpression =
+      copy(child = newChild)
   }
 
   case class NumericTypeUnaryExpression(child: Expression)
     extends UnaryExpression with ExpectsInputTypes with Unevaluable {
     override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
     override def dataType: DataType = NullType
+    override protected def withNewChildInternal(newChild: Expression): NumericTypeUnaryExpression =
+      copy(child = newChild)
   }
 
   case class AnyTypeBinaryOperator(left: Expression, right: Expression)
@@ -1636,6 +1640,9 @@ object TypeCoercionSuite {
     override def dataType: DataType = NullType
     override def inputType: AbstractDataType = AnyDataType
     override def symbol: String = "anytype"
+    override protected def withNewChildrenInternal(
+        newLeft: Expression, newRight: Expression): AnyTypeBinaryOperator =
+      copy(left = newLeft, right = newRight)
   }
 
   case class NumericTypeBinaryOperator(left: Expression, right: Expression)
@@ -1643,5 +1650,8 @@ object TypeCoercionSuite {
     override def dataType: DataType = NullType
     override def inputType: AbstractDataType = NumericType
     override def symbol: String = "numerictype"
+    override protected def withNewChildrenInternal(
+        newLeft: Expression, newRight: Expression): NumericTypeBinaryOperator =
+      copy(left = newLeft, right = newRight)
   }
 }
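The test operators above implement only the two-argument withNewChildrenInternal(newLeft, newRight). BinaryLike supplies a specialized mapChildren on top of it that does three things:

- calls f on left and right directly, without building a children sequence;
- keeps the original reference for a side that f returned equal;
- returns the node itself when neither side changed.

The block below is a hedged sketch of that logic. Expr, Lit and Plus are hypothetical types, and origin and tag handling are omitted.

```scala
// Hypothetical toy model of BinaryLike.mapChildren; not Spark's code.
trait Expr {
  def fastEquals(o: Expr): Boolean = (this eq o) || this == o
}

trait BinaryExpr extends Expr {
  def left: Expr
  def right: Expr
  protected def withNewChildrenInternal(newLeft: Expr, newRight: Expr): Expr

  final def mapChildren(f: Expr => Expr): Expr = {
    // Keep the original reference when f returned an equal node, so unchanged
    // subtrees stay shared instead of being duplicated.
    val l0 = f(left)
    val newLeft = if (l0 fastEquals left) left else l0
    val r0 = f(right)
    val newRight = if (r0 fastEquals right) right else r0
    if ((newLeft eq left) && (newRight eq right)) this
    else withNewChildrenInternal(newLeft, newRight)
  }
}

case class Lit(v: Int) extends Expr

case class Plus(left: Expr, right: Expr) extends BinaryExpr {
  protected def withNewChildrenInternal(newLeft: Expr, newRight: Expr): Plus =
    copy(left = newLeft, right = newRight)
}
```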
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 71993e1..dc62841 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -998,6 +998,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
   case class StreamingPlanWrapper(child: LogicalPlan) extends UnaryNode {
     override def output: Seq[Attribute] = child.output
     override def isStreaming: Boolean = true
+    override protected def withNewChildInternal(newChild: LogicalPlan): StreamingPlanWrapper =
+      copy(child = newChild)
   }
 
   case class TestStreamingRelation(output: Seq[Attribute]) extends LeafNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 65671d2..9bfe69b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -314,4 +314,6 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
 case class CodegenFallbackExpression(child: Expression)
   extends UnaryExpression with CodegenFallback {
   override def dataType: DataType = child.dataType
+  override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpression =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
index 43579d4..02b6eed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -104,4 +104,7 @@ case class ExprReuseOutput(child: Expression) extends UnaryExpression {
     row.update(0, child.eval(input))
     row
   }
+
+  override protected def withNewChildInternal(newChild: Expression): ExprReuseOutput =
+    copy(child = newChild)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 8445239..3784f40 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -66,6 +66,9 @@ class LogicalPlanSuite extends SparkFunSuite {
 
     case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
       override def output: Seq[Attribute] = left.output ++ right.output
+      override protected def withNewChildrenInternal(
+          newLeft: LogicalPlan, newRight: LogicalPlan): LogicalPlan =
+        copy(left = newLeft, right = newRight)
     }
 
     require(relation.isStreaming === false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala
index 6f342b8..009e2a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala
@@ -28,6 +28,8 @@ class LogicalPlanIntegritySuite extends PlanTest {
 
   case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
     override val analyzed = true
+    override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+      copy(child = newChild)
   }
 
   test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4ad8475..0d31677 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -47,6 +47,8 @@ case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFall
   override def dataType: NullType = NullType
   override lazy val resolved = true
   override def eval(input: InternalRow): Any = null.asInstanceOf[Any]
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(optKey = if (optKey.isDefined) Some(newChildren(0)) else None)
 }
 
 case class ComplexPlan(exprs: Seq[Seq[Expression]])
@@ -59,6 +61,8 @@ case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
   override def nullable: Boolean = true
   override def dataType: NullType = NullType
   override lazy val resolved = true
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    super.legacyWithNewChildren(newChildren)
 }
 
 case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
@@ -67,6 +71,9 @@ case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
   override def nullable: Boolean = true
   override def dataType: NullType = NullType
   override lazy val resolved = true
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    super.legacyWithNewChildren(newChildren)
 }
 
 case class JsonTestTreeNode(arg: Any) extends LeafNode {
@@ -738,7 +745,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   object MalformedClassObject extends Serializable {
-    case class MalformedNameExpression(child: Expression) extends TaggingExpression
+    case class MalformedNameExpression(child: Expression) extends TaggingExpression {
+      override protected def withNewChildInternal(newChild: Expression): Expression =
+        copy(child = newChild)
+    }
   }
 
   test("SPARK-32999: TreeNode.nodeName should not throw malformed class name error") {
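ShowFunctions, and Dummy in TreeNodeSuite, have an optional child. Their children sequence therefore has zero or one elements, and withNewChildrenInternal must rebuild the Option to match.

The block below is a minimal sketch of that pattern. Plan, Table and ShowThings are hypothetical types. The explicit size check stands in for the one that TreeNode.withNewChildren performs in the diff before delegating.

```scala
// Hypothetical model of a node whose single child is optional, like
// ShowFunctions(child: Option[LogicalPlan], ...) or Dummy(optKey: Option[Expression]).
trait Plan
case class Table(name: String) extends Plan

case class ShowThings(child: Option[Plan], pattern: Option[String]) extends Plan {
  // Zero children when child is None, one when it is Some.
  def children: IndexedSeq[Plan] = child.toList.toIndexedSeq

  // Public here only so the sketch can be exercised directly.
  def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): ShowThings = {
    require(newChildren.size == children.size, "Incorrect number of children")
    copy(child = if (child.isDefined) Some(newChildren.head) else None)
  }
}
```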
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
index b0bbb52..500425e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala
@@ -78,6 +78,9 @@ case class CollectMetricsExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CollectMetricsExec =
+    copy(child = newChild)
 }
 
 object CollectMetricsExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 8d54279..6bdd93e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -201,6 +201,9 @@ case class ColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransition w
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarToRowExec =
+    copy(child = newChild)
 }
 
 /**
@@ -486,6 +489,9 @@ case class RowToColumnarExec(child: SparkPlan) extends RowToColumnarTransition {
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): RowToColumnarExec =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 6f5bf15..3fd6531 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -203,4 +203,7 @@ case class ExpandExec(
        |}
      """.stripMargin
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ExpandExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 0d5ec2d..6c79294 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -325,4 +325,7 @@ case class GenerateExec(
     if (condition) Seq(code)
     else Seq.empty
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): GenerateExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 6b6ca53..984a45c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -202,4 +202,7 @@ case class SortExec(
     }
     super.cleanupResources()
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SortExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
index 75c9166..7f36289 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
@@ -72,6 +72,9 @@ case class SparkScriptTransformationExec(
 
     outputIterator
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkScriptTransformationExec =
+    copy(child = newChild)
 }
 
 case class SparkScriptTransformationWriterThread(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
index cece4309..a735d91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
@@ -39,4 +39,7 @@ case class SubqueryAdaptiveBroadcastExec(
     throw new UnsupportedOperationException(
       "SubqueryAdaptiveBroadcastExec does not support the execute() code path.")
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SubqueryAdaptiveBroadcastExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
index 70ba135..47cb70d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -113,6 +113,9 @@ case class SubqueryBroadcastExec(
   }
 
   override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]")
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SubqueryBroadcastExec =
+    copy(child = newChild)
 }
 
 object SubqueryBroadcastExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 9c50dc9..85bc98d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -554,6 +554,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCod
   }
 
   override def needCopyResult: Boolean = false
+
+  override protected def withNewChildInternal(newChild: SparkPlan): InputAdapter =
+    copy(child = newChild)
 }
 
 object WholeStageCodegenExec {
@@ -829,6 +832,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
   override def limitNotReachedChecks: Seq[String] = Nil
 
   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WholeStageCodegenExec =
+    copy(child = newChild)(codegenStageId)
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
index 4639ccc..f2eefbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffleReaderExec.scala
@@ -195,4 +195,7 @@ case class CustomShuffleReaderExec private(
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
     shuffleRDD.asInstanceOf[RDD[ColumnarBatch]]
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CustomShuffleReaderExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 7d45638..6e23a28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -1108,6 +1108,9 @@ case class HashAggregateExec(
           s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt"
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
+    copy(child = newChild)
 }
 
 object HashAggregateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index e5f59e0..559f545 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -138,6 +138,9 @@ case class ObjectHashAggregateExec(
       s"ObjectHashAggregate(keys=$keyString, functions=$functionString)"
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ObjectHashAggregateExec =
+    copy(child = newChild)
 }
 
 object ObjectHashAggregateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 2400cee..4fb0f44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -101,4 +101,7 @@ case class SortAggregateExec(
       s"SortAggregate(key=$keyString, functions=$functionString)"
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SortAggregateExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index ea44c60..d958790 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -203,6 +203,10 @@ case class SimpleTypedAggregateExpression(
       schema: StructType): TypedAggregateExpression = {
     copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): SimpleTypedAggregateExpression =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[SimpleTypedAggregateExpression]
 }
 
 case class ComplexTypedAggregateExpression(
@@ -285,4 +289,8 @@ case class ComplexTypedAggregateExpression(
       schema: StructType): TypedAggregateExpression = {
     copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
   }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): ComplexTypedAggregateExpression =
+    super.legacyWithNewChildren(newChildren).asInstanceOf[ComplexTypedAggregateExpression]
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index e6851a9..1aae76e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -454,6 +454,9 @@ case class ScalaUDAF(
   override def nodeName: String = name
 
   override def name: String = udafName.getOrElse(udaf.getClass.getSimpleName)
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDAF =
+    copy(children = newChildren)
 }
 
 case class ScalaAggregator[IN, BUF, OUT](
@@ -520,6 +523,10 @@ case class ScalaAggregator[IN, BUF, OUT](
   override def nodeName: String = name
 
   override def name: String = aggregatorName.getOrElse(agg.getClass.getSimpleName)
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): ScalaAggregator[IN, BUF, OUT] =
+    copy(children = newChildren)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index abd3360..b537040 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -107,6 +107,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
        |${ExplainUtils.generateFieldString("Input", child.output)}
        |""".stripMargin
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ProjectExec =
+    copy(child = newChild)
 }
 
 trait GeneratePredicateHelper extends PredicateHelper {
@@ -286,6 +289,9 @@ case class FilterExec(condition: Expression, child: SparkPlan)
        |Condition : ${condition}
        |""".stripMargin
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FilterExec =
+    copy(child = newChild)
 }
 
 /**
@@ -392,6 +398,9 @@ case class SampleExec(
        """.stripMargin.trim
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SampleExec =
+    copy(child = newChild)
 }
 
 
@@ -687,6 +696,9 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
 
   protected override def doExecute(): RDD[InternalRow] =
     sparkContext.union(children.map(_.execute()))
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): UnionExec =
+    copy(children = newChildren)
 }
 
 /**
@@ -720,6 +732,9 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
       child.execute().coalesce(numPartitions, shuffle = false)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CoalesceExec =
+    copy(child = newChild)
 }
 
 object CoalesceExec {
@@ -849,6 +864,9 @@ case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: Option[Int]
   }
 
   override def stringArgs: Iterator[Any] = Iterator(name, child) ++ Iterator(s"[id=#$id]")
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SubqueryExec =
+    copy(child = newChild)
 }
 
 object SubqueryExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 641bd26..e3c2e90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types._
 case class AnalyzeColumnCommand(
     tableIdent: TableIdentifier,
     columnNames: Option[Seq[String]],
-    allColumns: Boolean) extends RunnableCommand {
+    allColumns: Boolean) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
index 51d4c5f..5b3cb74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.util.PartitioningUtils
 case class AnalyzePartitionCommand(
     tableIdent: TableIdentifier,
     partitionSpec: Map[String, Option[String]],
-    noscan: Boolean = true) extends RunnableCommand {
+    noscan: Boolean = true) extends LeafRunnableCommand {
 
   private def getPartitionSpec(table: CatalogTable): Option[TablePartitionSpec] = {
     val normalizedPartitionSpec =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index d114ca0..157554e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
  */
 case class AnalyzeTableCommand(
     tableIdent: TableIdentifier,
-    noScan: Boolean = true) extends RunnableCommand {
+    noScan: Boolean = true) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     CommandUtils.analyzeTable(sparkSession, tableIdent, noScan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala
index ef07019..c9b22a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTablesCommand.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.{Row, SparkSession}
  */
 case class AnalyzeTablesCommand(
     databaseName: Option[String],
-    noScan: Boolean) extends RunnableCommand {
+    noScan: Boolean) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
index d065bc0..be680a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
@@ -42,7 +42,7 @@ case class InsertIntoDataSourceDirCommand(
     storage: CatalogStorageFormat,
     provider: String,
     query: LogicalPlan,
-    overwrite: Boolean) extends RunnableCommand {
+    overwrite: Boolean) extends LeafRunnableCommand {
 
   override def innerChildren: Seq[LogicalPlan] = query :: Nil
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index 7d92e6e..0ebc927 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -34,7 +34,8 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}
  *   set;
  * }}}
  */
-case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging {
+case class SetCommand(kv: Option[(String, Option[String])])
+  extends LeafRunnableCommand with Logging {
 
   private def keyValueOutput: Seq[Attribute] = {
     val schema = StructType(
@@ -169,7 +170,7 @@ object SetCommand {
  *   reset spark.sql.session.timeZone;
  * }}}
  */
-case class ResetCommand(config: Option[String]) extends RunnableCommand with IgnoreCachedData {
+case class ResetCommand(config: Option[String]) extends LeafRunnableCommand with IgnoreCachedData {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val globalInitialConfigs = sparkSession.sharedState.conf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 2f72af7..de5dbdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
 /**
  * Clear all cached data from the in-memory cache.
  */
-case object ClearCacheCommand extends RunnableCommand with IgnoreCachedData {
+case object ClearCacheCommand extends LeafRunnableCommand with IgnoreCachedData {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.catalog.clearCache()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 8bc3ced..7f4f816 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -48,6 +49,8 @@ trait RunnableCommand extends Command {
   def run(sparkSession: SparkSession): Seq[Row]
 }
 
+trait LeafRunnableCommand extends RunnableCommand with LeafLike[LogicalPlan]
+
 /**
  * A physical operator that executes the run method of a `RunnableCommand` and
  * saves the result to prevent multiple executions.
@@ -132,6 +135,9 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
   protected override def doExecute(): RDD[InternalRow] = {
     sqlContext.sparkContext.parallelize(sideEffectResult, 1)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): DataWritingCommandExec =
+    copy(child = newChild)
 }
 
 /**
@@ -150,7 +156,7 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
 case class ExplainCommand(
     logicalPlan: LogicalPlan,
     mode: ExplainMode)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override val output: Seq[Attribute] =
     Seq(AttributeReference("plan", StringType, nullable = true)())
@@ -167,7 +173,7 @@ case class ExplainCommand(
 /** An explain command for users to see how a streaming batch is executed. */
 case class StreamingExplainCommand(
     queryExecution: IncrementalExecution,
-    extended: Boolean) extends RunnableCommand {
+    extended: Boolean) extends LeafRunnableCommand {
 
   override val output: Seq[Attribute] =
     Seq(AttributeReference("plan", StringType, nullable = true)())
@@ -193,7 +199,7 @@ case class StreamingExplainCommand(
 case class ExternalCommandExecutor(
     runner: ExternalCommandRunner,
     command: String,
-    options: Map[String, String]) extends RunnableCommand {
+    options: Map[String, String]) extends LeafRunnableCommand {
 
   override def output: Seq[Attribute] =
     Seq(AttributeReference("command_output", StringType)())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index bb54457..bb3869d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType
  * }}}
  */
 case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     assert(table.tableType != CatalogTableType.VIEW)
@@ -227,4 +227,7 @@ case class CreateDataSourceTableAsSelectCommand(
         throw ex
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 7330f5b..c7456cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -69,7 +69,7 @@ case class CreateDatabaseCommand(
     path: Option[String],
     comment: Option[String],
     props: Map[String, String])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -105,7 +105,7 @@ case class DropDatabaseCommand(
     databaseName: String,
     ifExists: Boolean,
     cascade: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
@@ -125,7 +125,7 @@ case class DropDatabaseCommand(
 case class AlterDatabasePropertiesCommand(
     databaseName: String,
     props: Map[String, String])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -146,7 +146,7 @@ case class AlterDatabasePropertiesCommand(
  * }}}
  */
 case class AlterDatabaseSetLocationCommand(databaseName: String, location: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -171,7 +171,7 @@ case class DescribeDatabaseCommand(
     databaseName: String,
     extended: Boolean,
     override val output: Seq[Attribute])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val dbMetadata: CatalogDatabase =
@@ -211,7 +211,7 @@ case class DropTableCommand(
     tableName: TableIdentifier,
     ifExists: Boolean,
     isView: Boolean,
-    purge: Boolean) extends RunnableCommand {
+    purge: Boolean) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -264,7 +264,7 @@ case class AlterTableSetPropertiesCommand(
     tableName: TableIdentifier,
     properties: Map[String, String],
     isView: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -295,7 +295,7 @@ case class AlterTableUnsetPropertiesCommand(
     propKeys: Seq[String],
     ifExists: Boolean,
     isView: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -333,7 +333,7 @@ case class AlterTableUnsetPropertiesCommand(
 case class AlterTableChangeColumnCommand(
     tableName: TableIdentifier,
     columnName: String,
-    newColumn: StructField) extends RunnableCommand {
+    newColumn: StructField) extends LeafRunnableCommand {
 
   // TODO: support change column name/dataType/metadata/position.
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -402,7 +402,7 @@ case class AlterTableSerDePropertiesCommand(
     serdeClassName: Option[String],
     serdeProperties: Option[Map[String, String]],
     partSpec: Option[TablePartitionSpec])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   // should never happen if we parsed things correctly
   require(serdeClassName.isDefined || serdeProperties.isDefined,
@@ -454,7 +454,7 @@ case class AlterTableAddPartitionCommand(
     tableName: TableIdentifier,
     partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])],
     ifNotExists: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -509,7 +509,7 @@ case class AlterTableRenamePartitionCommand(
     tableName: TableIdentifier,
     oldPartition: TablePartitionSpec,
     newPartition: TablePartitionSpec)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -556,7 +556,7 @@ case class AlterTableDropPartitionCommand(
     ifExists: Boolean,
     purge: Boolean,
     retainData: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -600,7 +600,7 @@ case class RepairTableCommand(
     tableName: TableIdentifier,
     enableAddPartitions: Boolean,
     enableDropPartitions: Boolean,
-    cmd: String = "MSCK REPAIR TABLE") extends RunnableCommand {
+    cmd: String = "MSCK REPAIR TABLE") extends LeafRunnableCommand {
 
   // These are list of statistics that can be collected quickly without requiring a scan of the data
   // see https://github.com/apache/hive/blob/master/
@@ -833,7 +833,7 @@ case class AlterTableSetLocationCommand(
     tableName: TableIdentifier,
     partitionSpec: Option[TablePartitionSpec],
     location: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
index af5ba48..0eda90a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
@@ -55,7 +55,7 @@ case class CreateFunctionCommand(
     isTemp: Boolean,
     ignoreIfExists: Boolean,
     replace: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   if (ignoreIfExists && replace) {
     throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" +
@@ -112,7 +112,7 @@ case class CreateFunctionCommand(
  */
 case class DescribeFunctionCommand(
     functionName: FunctionIdentifier,
-    isExtended: Boolean) extends RunnableCommand {
+    isExtended: Boolean) extends LeafRunnableCommand {
 
   override val output: Seq[Attribute] = {
     val schema = StructType(StructField("function_desc", StringType, nullable = false) :: Nil)
@@ -177,7 +177,7 @@ case class DropFunctionCommand(
     functionName: String,
     ifExists: Boolean,
     isTemp: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -216,7 +216,7 @@ case class ShowFunctionsCommand(
     pattern: Option[String],
     showUserFunctions: Boolean,
     showSystemFunctions: Boolean,
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val dbName = db.getOrElse(sparkSession.sessionState.catalog.getCurrentDatabase)
@@ -255,7 +255,7 @@ case class ShowFunctionsCommand(
 case class RefreshFunctionCommand(
     databaseName: Option[String],
     functionName: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
index 691837f..af053f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types.StringType
 /**
  * Adds a jar to the current session so it can be used (for UDFs or serdes).
  */
-case class AddJarCommand(path: String) extends RunnableCommand {
+case class AddJarCommand(path: String) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sessionState.resourceLoader.addJar(path)
     Seq.empty[Row]
@@ -39,7 +39,7 @@ case class AddJarCommand(path: String) extends RunnableCommand {
 /**
  * Adds a file to the current session so it can be used.
  */
-case class AddFileCommand(path: String) extends RunnableCommand {
+case class AddFileCommand(path: String) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val recursive = !sparkSession.sessionState.conf.addSingleFileInAddFile
     sparkSession.sparkContext.addFile(path, recursive)
@@ -50,7 +50,7 @@ case class AddFileCommand(path: String) extends RunnableCommand {
 /**
  * Adds an archive to the current session so it can be used.
  */
-case class AddArchiveCommand(path: String) extends RunnableCommand {
+case class AddArchiveCommand(path: String) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sparkContext.addArchive(path)
     Seq.empty[Row]
@@ -61,7 +61,7 @@ case class AddArchiveCommand(path: String) extends RunnableCommand {
  * Returns a list of file paths that are added to resources.
  * If file paths are provided, return the ones that are added to resources.
  */
-case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends RunnableCommand {
+case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("Results", StringType, nullable = false)() :: Nil
   }
@@ -88,7 +88,7 @@ case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends Runn
  * Returns a list of jar files that are added to resources.
  * If jar files are provided, return the ones that are added to resources.
  */
-case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
+case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends LeafRunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("Results", StringType, nullable = false)() :: Nil
   }
@@ -109,7 +109,8 @@ case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends Runnab
  * Returns a list of archive paths that are added to resources.
  * If archive paths are provided, return the ones that are added to resources.
  */
-case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String]) extends RunnableCommand {
+case class ListArchivesCommand(archives: Seq[String] = Seq.empty[String])
+  extends LeafRunnableCommand {
   override val output: Seq[Attribute] = {
     AttributeReference("Results", StringType, nullable = false)() :: Nil
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 488c628..72168f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -82,7 +82,7 @@ case class CreateTableLikeCommand(
     fileFormat: CatalogStorageFormat,
     provider: Option[String],
     properties: Map[String, String] = Map.empty,
-    ifNotExists: Boolean) extends RunnableCommand {
+    ifNotExists: Boolean) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -161,7 +161,7 @@ case class CreateTableLikeCommand(
  */
 case class CreateTableCommand(
     table: CatalogTable,
-    ignoreIfExists: Boolean) extends RunnableCommand {
+    ignoreIfExists: Boolean) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.sessionState.catalog.createTable(table, ignoreIfExists)
@@ -183,7 +183,7 @@ case class AlterTableRenameCommand(
     oldName: TableIdentifier,
     newName: TableIdentifier,
     isView: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -224,7 +224,7 @@ case class AlterTableRenameCommand(
 */
 case class AlterTableAddColumnsCommand(
     table: TableIdentifier,
-    colsToAdd: Seq[StructField]) extends RunnableCommand {
+    colsToAdd: Seq[StructField]) extends LeafRunnableCommand {
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
     val catalogTable = verifyAlterTableAddColumn(sparkSession.sessionState.conf, catalog, table)
@@ -300,7 +300,7 @@ case class LoadDataCommand(
     path: String,
     isLocal: Boolean,
     isOverwrite: Boolean,
-    partition: Option[TablePartitionSpec]) extends RunnableCommand {
+    partition: Option[TablePartitionSpec]) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -441,7 +441,7 @@ object LoadDataCommand {
  */
 case class TruncateTableCommand(
     tableName: TableIdentifier,
-    partitionSpec: Option[TablePartitionSpec]) extends RunnableCommand {
+    partitionSpec: Option[TablePartitionSpec]) extends LeafRunnableCommand {
 
   override def run(spark: SparkSession): Seq[Row] = {
     val catalog = spark.sessionState.catalog
@@ -580,7 +580,7 @@ case class TruncateTableCommand(
   }
 }
 
-abstract class DescribeCommandBase extends RunnableCommand {
+abstract class DescribeCommandBase extends LeafRunnableCommand {
   protected def describeSchema(
       schema: StructType,
       buffer: ArrayBuffer[Row],
@@ -745,7 +745,7 @@ case class DescribeColumnCommand(
     colNameParts: Seq[String],
     isExtended: Boolean,
     override val output: Seq[Attribute])
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -828,7 +828,7 @@ case class ShowTablesCommand(
     tableIdentifierPattern: Option[String],
     override val output: Seq[Attribute],
     isExtended: Boolean = false,
-    partitionSpec: Option[TablePartitionSpec] = None) extends RunnableCommand {
+    partitionSpec: Option[TablePartitionSpec] = None) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     // Since we need to return a Seq of rows, we will call getTables directly
@@ -888,7 +888,7 @@ case class ShowTablesCommand(
 case class ShowTablePropertiesCommand(
     table: TableIdentifier,
     propertyKey: Option[String],
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -924,7 +924,7 @@ case class ShowTablePropertiesCommand(
 case class ShowColumnsCommand(
     databaseName: Option[String],
     tableName: TableIdentifier,
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -955,7 +955,7 @@ case class ShowColumnsCommand(
 case class ShowPartitionsCommand(
     tableName: TableIdentifier,
     override val output: Seq[Attribute],
-    spec: Option[TablePartitionSpec]) extends RunnableCommand {
+    spec: Option[TablePartitionSpec]) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -1080,7 +1080,7 @@ trait ShowCreateTableCommandBase {
 case class ShowCreateTableCommand(
     table: TableIdentifier,
     override val output: Seq[Attribute])
-    extends RunnableCommand with ShowCreateTableCommandBase {
+    extends LeafRunnableCommand with ShowCreateTableCommandBase {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -1234,7 +1234,7 @@ case class ShowCreateTableCommand(
 case class ShowCreateTableAsSerdeCommand(
     table: TableIdentifier,
     override val output: Seq[Attribute])
-    extends RunnableCommand with ShowCreateTableCommandBase {
+    extends LeafRunnableCommand with ShowCreateTableCommandBase {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
@@ -1354,7 +1354,7 @@ case class ShowCreateTableAsSerdeCommand(
  * }}}
  */
 case class RefreshTableCommand(tableIdent: TableIdentifier)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     // Refresh the given table's metadata. If this table is cached as an InMemoryRelation,
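Every command in the hunks above switches from `RunnableCommand` to `LeafRunnableCommand`. None of these commands exposes plan children. Commands that wrap a query, such as `InsertIntoDataSourceCommand` in this commit, report that query through `innerChildren` instead. So they can share one leaf implementation that never rebuilds the node. The following is a minimal sketch of the idea on a hypothetical, much-simplified `Plan` class. `Plan`, `LeafLike`, and the `run()` signature are illustrative assumptions, not Spark's actual definitions.

```scala
// Hypothetical simplified plan tree; names echo Spark's but this is NOT Spark's API.
abstract class Plan {
  def children: Seq[Plan]

  // Per-class hook that rebuilds this node with new children (no reflection).
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan

  final def withNewChildren(newChildren: Seq[Plan]): Plan = {
    require(newChildren.size == children.size, "child count mismatch")
    // Reuse this node when every new child is reference-equal to the old one.
    if (newChildren.zip(children).forall { case (n, o) => n eq o }) this
    else withNewChildrenInternal(newChildren.toIndexedSeq)
  }

  def mapChildren(f: Plan => Plan): Plan =
    if (children.isEmpty) this else withNewChildren(children.map(f))
}

// Leaf nodes have no children, so every rebuild is a no-op returning `this`.
trait LeafLike extends Plan {
  final override def children: Seq[Plan] = Nil
  final override def mapChildren(f: Plan => Plan): Plan = this
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan = this
}

// Hypothetical command hierarchy mirroring the rename in this diff.
trait RunnableCommand extends Plan {
  def run(): Seq[String]
}

trait LeafRunnableCommand extends RunnableCommand with LeafLike

case class ListJarsCommand(jars: Seq[String] = Seq.empty) extends LeafRunnableCommand {
  override def run(): Seq[String] = jars
}
```

Because `children`, `mapChildren`, and `withNewChildrenInternal` are `final` in the leaf trait, a transformation that visits a leaf command returns the same instance without allocating a copy.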
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index b302b26..93ea226 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -66,7 +66,7 @@ case class CreateViewCommand(
     allowExisting: Boolean,
     replace: Boolean,
     viewType: ViewType)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   import ViewHelper._
 
@@ -233,7 +233,7 @@ case class CreateViewCommand(
 case class AlterViewAsCommand(
     name: TableIdentifier,
     originalText: String,
-    query: LogicalPlan) extends RunnableCommand {
+    query: LogicalPlan) extends LeafRunnableCommand {
 
   import ViewHelper._
 
@@ -301,7 +301,7 @@ case class AlterViewAsCommand(
 case class ShowViewsCommand(
     databaseName: String,
     tableIdentifierPattern: Option[String],
-    override val output: Seq[Attribute]) extends RunnableCommand {
+    override val output: Seq[Attribute]) extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 5f01955..6300e10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -68,6 +68,9 @@ object FileFormatWriter extends Logging {
            |}""".stripMargin
       })
     }
+
+    override protected def withNewChildInternal(newChild: Expression): Empty2Null =
+      copy(child = newChild)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index bd9cc0e..789b1d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.sources.InsertableRelation
 
 
@@ -31,7 +31,7 @@ case class InsertIntoDataSourceCommand(
     logicalRelation: LogicalRelation,
     query: LogicalPlan,
     overwrite: Boolean)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index b29ccb85..267b360 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -270,4 +270,7 @@ case class InsertIntoHadoopFsRelationCommand(
       }
     }.toMap
   }
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 5195bb2..486f73c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.sources.CreatableRelationProvider
 
 /**
@@ -36,7 +36,7 @@ case class SaveIntoDataSourceCommand(
     query: LogicalPlan,
     dataSource: CreatableRelationProvider,
     options: Map[String, String],
-    mode: SaveMode) extends RunnableCommand {
+    mode: SaveMode) extends LeafRunnableCommand {
 
   override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 137e502..221db20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.command.{DDLUtils, RunnableCommand}
+import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand}
 import org.apache.spark.sql.execution.command.ViewHelper.createTemporaryViewRelation
 import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.types._
@@ -52,6 +52,10 @@ case class CreateTable(
   override def children: Seq[LogicalPlan] = query.toSeq
   override def output: Seq[Attribute] = Seq.empty
   override lazy val resolved: Boolean = false
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): LogicalPlan =
+    copy(query = if (query.isDefined) Some(newChildren.head) else None)
 }
 
 /**
@@ -63,7 +67,7 @@ case class CreateTempViewUsing(
     replace: Boolean,
     global: Boolean,
     provider: String,
-    options: Map[String, String]) extends RunnableCommand {
+    options: Map[String, String]) extends LeafRunnableCommand {
 
   if (tableIdent.database.isDefined) {
     throw new AnalysisException(
@@ -123,7 +127,7 @@ case class CreateTempViewUsing(
 }
 
 case class RefreshResource(path: String)
-  extends RunnableCommand {
+  extends LeafRunnableCommand {
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     sparkSession.catalog.refreshByPath(path)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 2ed0e06..764b63d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -47,6 +47,8 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan)
   extends UnaryNode {
   override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteToDataSourceV2 =
+    copy(query = newChild)
 }
 
 /**
@@ -82,6 +84,9 @@ case class CreateTableAsSelectExec(
       partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CreateTableAsSelectExec =
+    copy(query = newChild)
 }
 
 /**
@@ -116,6 +121,9 @@ case class AtomicCreateTableAsSelectExec(
       ident, schema, partitioning.toArray, properties.asJava)
     writeToTable(catalog, stagedTable, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AtomicCreateTableAsSelectExec =
+    copy(query = newChild)
 }
 
 /**
@@ -160,6 +168,9 @@ case class ReplaceTableAsSelectExec(
       ident, schema, partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ReplaceTableAsSelectExec =
+    copy(query = newChild)
 }
 
 /**
@@ -207,6 +218,9 @@ case class AtomicReplaceTableAsSelectExec(
     }
     writeToTable(catalog, staged, writeOptions, ident)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AtomicReplaceTableAsSelectExec =
+    copy(query = newChild)
 }
 
 /**
@@ -217,7 +231,10 @@ case class AtomicReplaceTableAsSelectExec(
 case class AppendDataExec(
     query: SparkPlan,
     refreshCache: () => Unit,
-    write: Write) extends V2ExistingTableWriteExec
+    write: Write) extends V2ExistingTableWriteExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec =
+    copy(query = newChild)
+}
 
 /**
  * Physical plan node for overwrite into a v2 table.
@@ -232,7 +249,10 @@ case class AppendDataExec(
 case class OverwriteByExpressionExec(
     query: SparkPlan,
     refreshCache: () => Unit,
-    write: Write) extends V2ExistingTableWriteExec
+    write: Write) extends V2ExistingTableWriteExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): OverwriteByExpressionExec =
+    copy(query = newChild)
+}
 
 /**
  * Physical plan node for dynamic partition overwrite into a v2 table.
@@ -246,7 +266,10 @@ case class OverwriteByExpressionExec(
 case class OverwritePartitionsDynamicExec(
     query: SparkPlan,
     refreshCache: () => Unit,
-    write: Write) extends V2ExistingTableWriteExec
+    write: Write) extends V2ExistingTableWriteExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): OverwritePartitionsDynamicExec =
+    copy(query = newChild)
+}
 
 case class WriteToDataSourceV2Exec(
     batchWrite: BatchWrite,
@@ -255,6 +278,9 @@ case class WriteToDataSourceV2Exec(
   override protected def run(): Seq[InternalRow] = {
     writeWithV2(batchWrite)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WriteToDataSourceV2Exec =
+    copy(query = newChild)
 }
 
 trait V2ExistingTableWriteExec extends V2TableWriteExec {
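Several unary nodes in `WriteToDataSourceV2Exec.scala` store their child in a field named `query` and expose it through `child`. Their `withNewChildInternal` must therefore call `copy(query = newChild)` rather than `copy(child = newChild)`. The sketch below shows how a unary helper trait can route the generic `withNewChildrenInternal(IndexedSeq)` call to a single-child hook. `Plan`, `UnaryLike`, `Source`, and `AppendData` are hypothetical simplifications, and Spark's own `UnaryLike` may differ in detail.

```scala
// Hypothetical simplified base class; not Spark's TreeNode.
abstract class Plan {
  def children: Seq[Plan]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan

  final def withNewChildren(newChildren: Seq[Plan]): Plan = {
    require(newChildren.size == children.size, "child count mismatch")
    withNewChildrenInternal(newChildren.toIndexedSeq)
  }
}

// Unary helper: subclasses only implement `child` and `withNewChildInternal`.
trait UnaryLike extends Plan {
  def child: Plan
  protected def withNewChildInternal(newChild: Plan): Plan

  final override def children: Seq[Plan] = child :: Nil

  // Route the generic call to the single-child hook.
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan =
    withNewChildInternal(newChildren.head)
}

// Hypothetical leaf used only to build example trees.
case class Source(name: String) extends Plan {
  override def children: Seq[Plan] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan = this
}

// The child lives in `query`, so the copy must target `query`.
case class AppendData(query: Plan, table: String) extends UnaryLike {
  override def child: Plan = query
  override protected def withNewChildInternal(newChild: Plan): AppendData =
    copy(query = newChild)
}
```

Narrowing the return type to `AppendData` is a legal covariant override. The `limit.scala` hunks in this commit keep the wider `SparkPlan` return type instead, which is equally valid.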
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 3cbebca..6c744e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -288,5 +288,8 @@ package object debug {
     }
 
     override def supportsColumnar: Boolean = child.supportsColumnar
+
+    override protected def withNewChildInternal(newChild: SparkPlan): DebugExec =
+      copy(child = newChild)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index ca640c4..94a8a8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -205,6 +205,9 @@ case class BroadcastExchangeExec(
           ex)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): BroadcastExchangeExec =
+    copy(child = newChild)
 }
 
 object BroadcastExchangeExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 2a7b12f..6ec3767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -166,6 +166,9 @@ case class ShuffleExchangeExec(
     }
     cachedShuffleRDD
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec =
+    copy(child = newChild)
 }
 
 object ShuffleExchangeExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index cec1286..ccbcaa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -254,4 +254,8 @@ case class BroadcastHashJoinExec(
       super.codegenAnti(ctx, input)
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index fa1a57a..acdd346 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -548,4 +548,8 @@ case class BroadcastNestedLoopJoinExec(
      """.stripMargin
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): BroadcastNestedLoopJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index b6386d0..1b2d373 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -101,4 +101,8 @@ case class CartesianProductExec(
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): CartesianProductExec =
+    copy(left = newLeft, right = newRight)
 }
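The join operators implement the two-argument overload `withNewChildrenInternal(newLeft, newRight)`. A binary helper trait can unpack the generic child sequence by position and delegate to that overload, so each operator only writes a plain `copy`. A minimal sketch follows. `Plan`, `BinaryLike`, `Source`, and `Join` are hypothetical simplifications, not Spark's classes.

```scala
// Hypothetical simplified base class; not Spark's TreeNode.
abstract class Plan {
  def children: Seq[Plan]
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan

  final def withNewChildren(newChildren: Seq[Plan]): Plan = {
    require(newChildren.size == children.size, "child count mismatch")
    withNewChildrenInternal(newChildren.toIndexedSeq)
  }
}

// Binary helper: subclasses implement only the two-argument overload.
trait BinaryLike extends Plan {
  def left: Plan
  def right: Plan
  protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): Plan

  final override def children: Seq[Plan] = Seq(left, right)

  // Positional unpacking: index 0 is the left child, index 1 the right child.
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan =
    withNewChildrenInternal(newChildren(0), newChildren(1))
}

// Hypothetical leaf used only to build example trees.
case class Source(name: String) extends Plan {
  override def children: Seq[Plan] = Nil
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Plan]): Plan = this
}

case class Join(left: Plan, right: Plan, joinType: String) extends BinaryLike {
  override protected def withNewChildrenInternal(newLeft: Plan, newRight: Plan): Join =
    copy(left = newLeft, right = newRight)
}
```

The `require` in `withNewChildren` runs before the positional unpacking, so a wrong number of children fails with a clear message instead of an index error.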
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index cd57408..8514fc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -318,4 +318,8 @@ case class ShuffledHashJoinExec(
       v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true)
     HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index eabbdc8..8e0b717 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -633,6 +633,10 @@ case class SortMergeJoinExec(
        |$eagerCleanup
      """.stripMargin
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
+    copy(left = newLeft, right = newRight)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index e5a2995..5114c07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -73,6 +73,9 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends LimitExec {
       singlePartitionRDD.mapPartitionsInternal(_.take(limit))
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
 
 /**
@@ -95,6 +98,9 @@ case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
     // job launch, we might just have to mimic the implementation of `CollectLimitExec`.
     sparkContext.parallelize(executeCollect(), numSlices = 1)
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
 
 object BaseLimitExec {
@@ -160,7 +166,10 @@ trait BaseLimitExec extends LimitExec with CodegenSupport {
 /**
  * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
  */
-case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec
+case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+}
 
 /**
  * Take the first `limit` elements of the child's single output partition.
@@ -168,6 +177,9 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec
 case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
 
   override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
 
 /**
@@ -249,4 +261,7 @@ case class TakeOrderedAndProjectExec(
 
     s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index c08db13..fa46f75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -99,6 +99,9 @@ case class DeserializeToObjectExec(
       iter.map(projection)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): DeserializeToObjectExec =
+    copy(child = newChild)
 }
 
 /**
@@ -135,6 +138,9 @@ case class SerializeFromObjectExec(
       iter.map(projection)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SerializeFromObjectExec =
+    copy(child = newChild)
 }
 
 /**
@@ -195,6 +201,9 @@ case class MapPartitionsExec(
       func(iter.map(getObject)).map(outputObject)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapPartitionsExec =
+    copy(child = newChild)
 }
 
 /**
@@ -252,6 +261,9 @@ case class MapPartitionsInRWithArrowExec(
       }.map(outputProject)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapPartitionsInRWithArrowExec =
+    copy(child = newChild)
 }
 
 /**
@@ -304,6 +316,9 @@ case class MapElementsExec(
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapElementsExec =
+    copy(child = newChild)
 }
 
 /**
@@ -333,6 +348,9 @@ case class AppendColumnsExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AppendColumnsExec =
+    copy(child = newChild)
 }
 
 /**
@@ -366,6 +384,9 @@ case class AppendColumnsWithObjectExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): AppendColumnsWithObjectExec =
+    copy(child = newChild)
 }
 
 /**
@@ -405,6 +426,9 @@ case class MapGroupsExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapGroupsExec =
+    copy(child = newChild)
 }
 
 object MapGroupsExec {
@@ -495,6 +519,9 @@ case class FlatMapGroupsInRExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInRExec =
+    copy(child = newChild)
 }
 
 /**
@@ -577,6 +604,9 @@ case class FlatMapGroupsInRWithArrowExec(
       }.map(outputProject)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInRWithArrowExec =
+    copy(child = newChild)
 }
 
 /**
@@ -623,4 +653,7 @@ case class CoGroupExec(
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+    newLeft: SparkPlan, newRight: SparkPlan): CoGroupExec = copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index dadf1129..5019008e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -154,4 +154,7 @@ case class AggregateInPandasExec(
       }
     }}
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 67f075f..096712c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -94,4 +94,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       batch.rowIterator.asScala
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 2ab7262..10f7966 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -103,4 +103,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index b079405..e830ea6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -103,4 +103,8 @@ case class FlatMapCoGroupsInPandasExec(
       }
     }
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInPandasExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 5032bc8..3a3a602 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -94,4 +94,7 @@ case class FlatMapGroupsInPandasExec(
       executePython(data, output, runner)
     }}
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInPandasExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
index 71f51f1..0434710 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -93,4 +93,7 @@ case class MapInPandasExec(
       }.map(unsafeProj)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): MapInPandasExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index 983fe9d..909a026 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -401,4 +401,7 @@ case class WindowInPandasExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WindowInPandasExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 20fb06a..7e094fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -125,4 +125,7 @@ case class EventTimeWatermarkExec(
       a
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): EventTimeWatermarkExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 747094b..fe788dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -246,4 +246,7 @@ case class FlatMapGroupsWithStateExec(
       CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsWithStateExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 73d2f82..b2c8141 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -620,4 +620,8 @@ case class StreamingSymmetricHashJoinExec(
 
     def numUpdatedStateRows: Long = updatedStateRowsCount
   }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): StreamingSymmetricHashJoinExec =
+    copy(left = newLeft, right = newRight)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
index 1923fc9..ceb52f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
@@ -28,4 +28,6 @@ case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan
   extends UnaryNode {
   override def child: LogicalPlan = query
   override def output: Seq[Attribute] = Nil
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): WriteToContinuousDataSource = copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index f1898ad..1e0caf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -70,4 +70,7 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPl
 
     sparkContext.emptyRDD
   }
+
+  override protected def withNewChildInternal(
+    newChild: SparkPlan): WriteToContinuousDataSourceExec = copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
index 4bacd71..7989b94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
@@ -36,4 +36,7 @@ case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan
   def createPlan(batchId: Long): WriteToDataSourceV2 = {
     WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query)
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): WriteToMicroBatchDataSource =
+    copy(query = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index e52f2a1..b52603e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -281,6 +281,9 @@ case class StateStoreRestoreExec(
       ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StateStoreRestoreExec =
+    copy(child = newChild)
 }
 
 /**
@@ -436,6 +439,9 @@ case class StateStoreSaveExec(
       eventTimeWatermark.isDefined &&
       newMetadata.batchWatermarkMs > eventTimeWatermark.get
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StateStoreSaveExec =
+    copy(child = newChild)
 }
 
 /** Physical operator for executing streaming Deduplicate. */
@@ -509,6 +515,9 @@ case class StreamingDeduplicateExec(
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
     eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec =
+    copy(child = newChild)
 }
 
 object StreamingDeduplicateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index e53e064..51723a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -95,6 +95,9 @@ case class StreamingGlobalLimitExec(
   private def getValueRow(value: Long): UnsafeRow = {
     UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StreamingGlobalLimitExec =
+    copy(child = newChild)
 }
 
 
@@ -133,4 +136,7 @@ case class StreamingLocalLimitExec(limit: Int, child: SparkPlan)
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: SparkPlan): StreamingLocalLimitExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 9c950fd..15b8501 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -166,6 +166,9 @@ case class InSubqueryExec(
       exprId = ExprId(0),
       resultBroadcast = null)
   }
+
+  override protected def withNewChildInternal(newChild: Expression): InSubqueryExec =
+    copy(child = newChild)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 6e0e36c..8011c80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -211,4 +211,7 @@ case class WindowExec(
       }
     }
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): WindowExec =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt
index e4ec487..0c19121 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a.sf100/explain.txt
@@ -720,7 +720,7 @@ Input [6]: [i_brand_id#104, i_class_id#105, i_category_id#106, sales#116, number
 
 (130) Expand [codegen id : 130]
 Input [6]: [sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56]
-Arguments: [List(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56, 0), List(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, null, 1), List(sales#68, number_sales#69, channel#73, i_brand_id#54, null, null, 3), List(sales#68, number_sales#69, channel#73, null, null, null, 7), List(sales#68, number_sales#69, null, null, null, null, 15)], [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark [...]
+Arguments: [ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, i_category_id#56, 0), ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, i_class_id#55, null, 1), ArrayBuffer(sales#68, number_sales#69, channel#73, i_brand_id#54, null, null, 3), ArrayBuffer(sales#68, number_sales#69, channel#73, null, null, null, 7), ArrayBuffer(sales#68, number_sales#69, null, null, null, null, 15)], [sales#68, number_sales#69, channel#120, i_brand_id#121, i_cla [...]
 
 (131) HashAggregate [codegen id : 130]
 Input [7]: [sales#68, number_sales#69, channel#120, i_brand_id#121, i_class_id#122, i_category_id#123, spark_grouping_id#124]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt
index 6f61fc8..ffcbef4 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q14a/explain.txt
@@ -625,7 +625,7 @@ Input [6]: [i_brand_id#96, i_class_id#97, i_category_id#98, sales#109, number_sa
 
 (111) Expand [codegen id : 79]
 Input [6]: [sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48]
-Arguments: [List(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48, 0), List(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, null, 1), List(sales#63, number_sales#64, channel#68, i_brand_id#46, null, null, 3), List(sales#63, number_sales#64, channel#68, null, null, null, 7), List(sales#63, number_sales#64, null, null, null, null, 15)], [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark [...]
+Arguments: [ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, i_category_id#48, 0), ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, i_class_id#47, null, 1), ArrayBuffer(sales#63, number_sales#64, channel#68, i_brand_id#46, null, null, 3), ArrayBuffer(sales#63, number_sales#64, channel#68, null, null, null, 7), ArrayBuffer(sales#63, number_sales#64, null, null, null, null, 15)], [sales#63, number_sales#64, channel#113, i_brand_id#114, i_cla [...]
 
 (112) HashAggregate [codegen id : 79]
 Input [7]: [sales#63, number_sales#64, channel#113, i_brand_id#114, i_class_id#115, i_category_id#116, spark_grouping_id#117]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt
index 28a4572..c9a772d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5.sf100/explain.txt
@@ -429,7 +429,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#95))#129,17,2) AS sales#
 
 (77) Expand [codegen id : 23]
 Input [5]: [sales#41, RETURNS#42, profit#43, channel#44, id#45]
-Arguments: [List(sales#41, returns#42, profit#43, channel#44, id#45, 0), List(sales#41, returns#42, profit#43, channel#44, null, 1), List(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140]
+Arguments: [ArrayBuffer(sales#41, returns#42, profit#43, channel#44, id#45, 0), ArrayBuffer(sales#41, returns#42, profit#43, channel#44, null, 1), ArrayBuffer(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140]
 
 (78) HashAggregate [codegen id : 23]
 Input [6]: [sales#41, returns#42, profit#43, channel#138, id#139, spark_grouping_id#140]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt
index cb130ce..c01302b 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q5/explain.txt
@@ -414,7 +414,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(sales_price#95))#128,17,2) AS sales#
 
 (74) Expand [codegen id : 20]
 Input [5]: [sales#41, RETURNS#42, profit#43, channel#44, id#45]
-Arguments: [List(sales#41, returns#42, profit#43, channel#44, id#45, 0), List(sales#41, returns#42, profit#43, channel#44, null, 1), List(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139]
+Arguments: [ArrayBuffer(sales#41, returns#42, profit#43, channel#44, id#45, 0), ArrayBuffer(sales#41, returns#42, profit#43, channel#44, null, 1), ArrayBuffer(sales#41, returns#42, profit#43, null, null, 3)], [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139]
 
 (75) HashAggregate [codegen id : 20]
 Input [6]: [sales#41, returns#42, profit#43, channel#137, id#138, spark_grouping_id#139]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt
index 4b2299c..dc5a7fc 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77.sf100/explain.txt
@@ -488,7 +488,7 @@ Input [6]: [wp_web_page_sk#77, sales#86, profit#87, wp_web_page_sk#92, returns#1
 
 (85) Expand [codegen id : 23]
 Input [5]: [sales#18, returns#37, profit#38, channel#39, id#40]
-Arguments: [List(sales#18, returns#37, profit#38, channel#39, id#40, 0), List(sales#18, returns#37, profit#38, channel#39, null, 1), List(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
+Arguments: [ArrayBuffer(sales#18, returns#37, profit#38, channel#39, id#40, 0), ArrayBuffer(sales#18, returns#37, profit#38, channel#39, null, 1), ArrayBuffer(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
 
 (86) HashAggregate [codegen id : 23]
 Input [6]: [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt
index 618da39..62bd5ab 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q77/explain.txt
@@ -488,7 +488,7 @@ Input [6]: [wp_web_page_sk#77, sales#86, profit#87, wp_web_page_sk#93, returns#1
 
 (85) Expand [codegen id : 23]
 Input [5]: [sales#18, returns#37, profit#38, channel#39, id#40]
-Arguments: [List(sales#18, returns#37, profit#38, channel#39, id#40, 0), List(sales#18, returns#37, profit#38, channel#39, null, 1), List(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
+Arguments: [ArrayBuffer(sales#18, returns#37, profit#38, channel#39, id#40, 0), ArrayBuffer(sales#18, returns#37, profit#38, channel#39, null, 1), ArrayBuffer(sales#18, returns#37, profit#38, null, null, 3)], [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
 
 (86) HashAggregate [codegen id : 23]
 Input [6]: [sales#18, returns#37, profit#38, channel#108, id#109, spark_grouping_id#110]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt
index bdb1a52..040407d 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80.sf100/explain.txt
@@ -590,7 +590,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#90))#117,17,2) AS
 
 (107) Expand [codegen id : 31]
 Input [5]: [sales#42, returns#43, profit#44, channel#45, id#46]
-Arguments: [List(sales#42, returns#43, profit#44, channel#45, id#46, 0), List(sales#42, returns#43, profit#44, channel#45, null, 1), List(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
+Arguments: [ArrayBuffer(sales#42, returns#43, profit#44, channel#45, id#46, 0), ArrayBuffer(sales#42, returns#43, profit#44, channel#45, null, 1), ArrayBuffer(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
 
 (108) HashAggregate [codegen id : 31]
 Input [6]: [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt
index aa15d27..467127a 100644
--- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt
+++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q80/explain.txt
@@ -590,7 +590,7 @@ Results [5]: [MakeDecimal(sum(UnscaledValue(ws_ext_sales_price#90))#117,17,2) AS
 
 (107) Expand [codegen id : 31]
 Input [5]: [sales#42, returns#43, profit#44, channel#45, id#46]
-Arguments: [List(sales#42, returns#43, profit#44, channel#45, id#46, 0), List(sales#42, returns#43, profit#44, channel#45, null, 1), List(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
+Arguments: [ArrayBuffer(sales#42, returns#43, profit#44, channel#45, id#46, 0), ArrayBuffer(sales#42, returns#43, profit#44, channel#45, null, 1), ArrayBuffer(sales#42, returns#43, profit#44, null, null, 3)], [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
 
 (108) HashAggregate [codegen id : 31]
 Input [6]: [sales#42, returns#43, profit#44, channel#125, id#126, spark_grouping_id#127]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6914330..70dc0d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3637,5 +3637,7 @@ object DataFrameFunctionsSuite {
     override def dataType: DataType = child.dataType
     override lazy val resolved = true
     override def eval(input: InternalRow): Any = child.eval(input)
+    override protected def withNewChildInternal(newChild: Expression): CodegenFallbackExpr =
+      copy(child = newChild)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 9192370..bec68fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.test.SharedSparkSession
 
-case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
+case class FastOperator(output: Seq[Attribute]) extends LeafExecNode {
 
   override protected def doExecute(): RDD[InternalRow] = {
     val str = Literal("so fast").value
@@ -35,7 +35,6 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
   }
 
   override def producedAttributes: AttributeSet = outputSet
-  override def children: Seq[SparkPlan] = Nil
 }
 
 object TestStrategy extends Strategy {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 35d2513..d4a6d84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -582,6 +582,10 @@ class ColumnarAlias(child: ColumnarExpression, name: String)(
   with ColumnarExpression {
 
   override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
+
+  override protected def withNewChildInternal(newChild: Expression): ColumnarAlias =
+    new ColumnarAlias(newChild.asInstanceOf[ColumnarExpression], name)(exprId, qualifier,
+      explicitMetadata, nonInheritableMetadataKeys)
 }
 
 class ColumnarAttributeReference(
@@ -641,6 +645,9 @@ class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def hashCode(): Int = super.hashCode()
+
+  override def withNewChildInternal(newChild: SparkPlan): ColumnarProjectExec =
+    new ColumnarProjectExec(projectList, newChild)
 }
 
 /**
@@ -705,6 +712,12 @@ class BrokenColumnarAdd(
     }
     ret
   }
+
+  override def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): BrokenColumnarAdd =
+    new BrokenColumnarAdd(
+      left = newLeft.asInstanceOf[ColumnarExpression],
+      right = newRight.asInstanceOf[ColumnarExpression], failOnError)
 }
 
 class CannotReplaceException(str: String) extends RuntimeException(str) {
@@ -781,6 +794,8 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
   override def child: SparkPlan = delegate.child
   override protected def doExecute(): RDD[InternalRow] = delegate.execute()
   override def outputPartitioning: Partitioning = delegate.outputPartitioning
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    super.legacyWithNewChildren(Seq(newChild))
 }
 
 /**
@@ -798,6 +813,9 @@ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends Broa
   override protected def doExecute(): RDD[InternalRow] = delegate.execute()
   override def doExecuteBroadcast[T](): Broadcast[T] = delegate.executeBroadcast()
   override def outputPartitioning: Partitioning = delegate.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    super.legacyWithNewChildren(Seq(newChild))
 }
 
 class ReplacedRowToColumnarExec(override val child: SparkPlan)
@@ -815,6 +833,9 @@ class ReplacedRowToColumnarExec(override val child: SparkPlan)
   }
 
   override def hashCode(): Int = super.hashCode()
+
+  override def withNewChildInternal(newChild: SparkPlan): ReplacedRowToColumnarExec =
+    new ReplacedRowToColumnarExec(newChild)
 }
 
 case class MyPostRule() extends Rule[SparkPlan] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index abe94c2..986e625 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -233,8 +233,7 @@ object TypedImperativeAggregateSuite {
       nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
       inputAggBufferOffset: Int = 0)
-    extends TypedImperativeAggregate[MaxValue]
-    with ImplicitCastInputTypes
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes
     with UnaryLike[Expression] {
 
     override def createAggregationBuffer(): MaxValue = {
@@ -297,6 +296,9 @@ object TypedImperativeAggregateSuite {
       val value = stream.readInt()
       new MaxValue(value, isValueSet)
     }
+
+    override protected def withNewChildInternal(newChild: Expression): TypedMax =
+      copy(child = newChild)
   }
 
   private class MaxValue(var value: Int, var isValueSet: Boolean = false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index cef870b..2011d05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -600,6 +600,9 @@ case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
   override def output: Seq[Attribute] = child.output
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ExceptionInjectingOperator =
+    copy(child = newChild)
 }
 
 @SQLUserDefinedType(udt = classOf[SimpleTupleUDT])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala
index dd27900..df08acd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ColumnarRulesSuite.scala
@@ -60,4 +60,5 @@ case class LeafOp(override val supportsColumnar: Boolean) extends LeafExecNode {
 case class UnaryOp(child: SparkPlan, override val supportsColumnar: Boolean) extends UnaryExecNode {
   override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
   override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp = copy(child = newChild)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index fb97e15..9776e76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -40,6 +40,9 @@ case class ColumnarExchange(child: SparkPlan) extends Exchange {
   override protected def doExecute(): RDD[InternalRow] = throw new RanRowBased
 
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = throw new RanColumnar
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExchange =
+    copy(child = newChild)
 }
 
 class ExchangeSuite extends SparkPlanTest with SharedSparkSession {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 1724f78..0b30b8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1264,4 +1264,6 @@ private case class DummySparkPlan(
   ) extends SparkPlan {
   override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException
   override def output: Seq[Attribute] = Seq.empty
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+    copy(children = newChildren)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
index a31e238..1592949 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
@@ -58,4 +58,7 @@ case class ReferenceSort(
   override def outputOrdering: Seq[SortOrder] = sortOrder
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ReferenceSort =
+    copy(child = newChild)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b17c935..b3d29df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoStatement, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.command.LeafRunnableCommand
 import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.test.SharedSparkSession
@@ -302,7 +302,7 @@ class DataFrameCallbackSuite extends QueryTest
 }
 
 /** A test command that throws `java.lang.Error` during execution. */
-case class ErrorTestCommand(foo: String) extends RunnableCommand {
+case class ErrorTestCommand(foo: String) extends LeafRunnableCommand {
 
   override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)())
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index 283c254..fe5d74f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -130,6 +130,9 @@ case class CreateHiveTableAsSelectCommand(
 
   override def writingCommandClassName: String =
     Utils.getSimpleName(classOf[InsertIntoHiveTable])
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = newChild)
 }
 
 /**
@@ -177,4 +180,7 @@ case class OptimizedCreateHiveTableAsSelectCommand(
 
   override def writingCommandClassName: String =
     Utils.getSimpleName(classOf[InsertIntoHadoopFsRelationCommand])
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = copy(query = newChild)
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
index 2059f5b..27fdb22 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
@@ -184,6 +184,9 @@ private[hive] case class HiveScriptTransformationExec(
 
     outputIterator
   }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): HiveScriptTransformationExec =
+    copy(child = newChild)
 }
 
 private[hive] case class HiveScriptTransformationWriterThread(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
index 7ef637e..09aa1e8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
@@ -137,5 +137,8 @@ case class InsertIntoHiveDirCommand(
 
     Seq.empty[Row]
   }
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): InsertIntoHiveDirCommand = copy(query = newChild)
 }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index bfb24cf..fcd11e6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -343,4 +343,7 @@ case class InsertIntoHiveTable(
         isSrcLocal = false)
     }
   }
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoHiveTable =
+    copy(query = newChild)
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 7717e6e..7c3d161 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -110,6 +110,9 @@ private[hive] case class HiveSimpleUDF(
   override def prettyName: String = name
 
   override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 // Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -186,6 +189,9 @@ private[hive] case class HiveGenericUDF(
   override def toString: String = {
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 /**
@@ -279,6 +285,9 @@ private[hive] case class HiveGenericUDTF(
   }
 
   override def prettyName: String = name
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 /**
@@ -528,6 +537,9 @@ private[hive] case class HiveUDAFFunction(
       buffer
     }
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(children = newChildren)
 }
 
 case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
index 0ef7b33..ee233fb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -78,6 +78,9 @@ case class TestingTypedCount(
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   override val prettyName: String = "typed_count"
+
+  override protected def withNewChildInternal(newChild: Expression): TestingTypedCount =
+    copy(child = newChild)
 }
 
 object TestingTypedCount {
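
Every hunk above applies the same pattern. A node overrides `withNewChildInternal` if it has one child, or `withNewChildrenInternal` if it has several, and rebuilds itself with its case class's `copy` method. This avoids a generic, reflection-based constructor lookup. The sketch below is a minimal, self-contained Scala model of that idea. It uses a hypothetical hierarchy (`Node`, `UnaryNode`, `Leaf`, `Project`, `Union`), which is illustrative only and is not Spark's `TreeNode` API. The reference-equality check that skips rebuilding unchanged nodes is a simplification made for this sketch.

```scala
// Hypothetical, simplified tree hierarchy (NOT Spark's TreeNode API).
abstract class Node {
  def children: IndexedSeq[Node]

  // Type-specific rebuild hook: concrete nodes implement it with copy(...).
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node

  // Generic entry point: rebuild only if some child is a different object
  // (reference comparison, an assumption of this sketch); otherwise return this.
  final def withNewChildren(newChildren: Seq[Node]): Node = {
    val newIdx = newChildren.toIndexedSeq
    require(newIdx.length == children.length, "wrong number of children")
    val changed = children.zip(newIdx).exists { case (o, n) => !(o eq n) }
    if (changed) withNewChildrenInternal(newIdx) else this
  }

  final def mapChildren(f: Node => Node): Node = withNewChildren(children.map(f))
}

// Single-child nodes implement the narrower withNewChildInternal hook.
abstract class UnaryNode extends Node {
  def child: Node
  final override def children: IndexedSeq[Node] = IndexedSeq(child)
  protected def withNewChildInternal(newChild: Node): Node
  final override protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node =
    withNewChildInternal(newChildren.head)
}

case class Leaf(name: String) extends Node {
  def children: IndexedSeq[Node] = IndexedSeq.empty
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node = this
}

// Analogous to the unary overrides in this diff: copy(child = newChild).
case class Project(label: String, child: Node) extends UnaryNode {
  protected def withNewChildInternal(newChild: Node): Project = copy(child = newChild)
}

// Analogous to the Hive UDF overrides: copy(children = newChildren).
case class Union(children: IndexedSeq[Node]) extends Node {
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Node]): Node =
    copy(children = newChildren)
}

// Usage: replace Leaf("b") with Leaf("c") underneath the Union.
val a = Leaf("a")
val b = Leaf("b")
val plan = Project("p", Union(IndexedSeq(a, b)))

val rewritten = plan.mapChildren {
  case u: Union => u.mapChildren {
    case Leaf("b") => Leaf("c")
    case other     => other
  }
  case other => other
}
```

Only the path from the root to the changed leaf is rebuilt. The unchanged `Leaf("a")` is shared by reference with the original tree, and calling `mapChildren(identity)` returns the original node itself without allocating a new one.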
