You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@asterixdb.apache.org by mb...@apache.org on 2023/02/24 18:50:47 UTC

[asterixdb] 03/11: [ASTERIXDB-3101][COMP] Optimize pushing assign ops down

This is an automated email from the ASF dual-hosted git repository.

mblow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/asterixdb.git

commit 174ee2ad880d513d6bea64493df96c21e01e6b90
Author: Ali Alsuliman <al...@gmail.com>
AuthorDate: Fri Jan 27 01:43:10 2023 -0800

    [ASTERIXDB-3101][COMP] Optimize pushing assign ops down
    
    - user model changes: no
    - storage format changes: no
    - interface changes: yes
    
    Details:
    One of the things that PushFieldAccessRule attempts to do is
    push an assign operator down as close as possible to the respective
    data scan operator. The assign operator is pushed recursively
    through the operators below it, one by one, until the data scan is
    reached. This becomes expensive when there is a large number
    of assigns. When all the operators below the assign operator
    are other assign operators, the assign operator can instead be
    moved directly above the data scan, skipping all the intermediate
    assign operators.
    
    - add a default method to IAlgebraicRewriteRule that lets rules
      know whether they are about to rewrite a nested plan root.
    
    Optimize ExtractCommonExpressionsRule, since its current traversal
    of operators becomes expensive with a large number of operators.
    - optimize ExtractCommonExpressionsRule to work only on plan roots,
      since the implementation already descends to children recursively.
    - check whether the operator was already rewritten (i.e. is in the
      don't-apply set) only after descending to the children, so that a
      post-order traversal from the root is still performed.
    
    Change-Id: I035b72089f973bb08dccf5f9305f8b06da7fc458
    Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/17316
    Integration-Tests: Jenkins <je...@fulliautomatix.ics.uci.edu>
    Tested-by: Jenkins <je...@fulliautomatix.ics.uci.edu>
    Reviewed-by: Michael Blow <mb...@apache.org>
    (cherry picked from commit 964ff7be6d2c026704001bd00430ae3a78bc66f6)
    Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/17245
    Reviewed-by: Ali Alsuliman <al...@gmail.com>
---
 .../optimizer/rules/PushFieldAccessRule.java       | 233 +++++++++++++--------
 ...entialFirstRuleCheckFixpointRuleController.java |   2 +-
 .../SequentialFixpointRuleController.java          |   2 +-
 .../SequentialOnceRuleController.java              |   2 +-
 .../core/rewriter/base/AbstractRuleController.java |   9 +-
 .../core/rewriter/base/IAlgebraicRewriteRule.java  |   9 +
 .../rules/ExtractCommonExpressionsRule.java        |  28 ++-
 7 files changed, 185 insertions(+), 100 deletions(-)

diff --git a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java
index c82aa33993..371f460113 100644
--- a/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java
+++ b/asterixdb/asterix-algebra/src/main/java/org/apache/asterix/optimizer/rules/PushFieldAccessRule.java
@@ -22,6 +22,7 @@ import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Set;
 
 import org.apache.asterix.algebra.base.OperatorAnnotation;
 import org.apache.asterix.common.config.DatasetConfig.DatasetType;
@@ -176,11 +177,11 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
         return e1.equals(e2);
     }
 
-    private boolean pushDownFieldAccessRec(Mutable<ILogicalOperator> opRef, IOptimizationContext context,
+    private boolean pushDownFieldAccessRec(Mutable<ILogicalOperator> assignOpRef, IOptimizationContext context,
             String finalAnnot) throws AlgebricksException {
-        AssignOperator assignOp = (AssignOperator) opRef.getValue();
-        Mutable<ILogicalOperator> opRef2 = assignOp.getInputs().get(0);
-        AbstractLogicalOperator inputOp = (AbstractLogicalOperator) opRef2.getValue();
+        AssignOperator assignOp = (AssignOperator) assignOpRef.getValue();
+        Mutable<ILogicalOperator> inputOpRef = assignOp.getInputs().get(0);
+        AbstractLogicalOperator inputOp = (AbstractLogicalOperator) inputOpRef.getValue();
         // If it's not an indexed field, it is pushed so that scan can be rewritten into index search.
         if (inputOp.getOperatorTag() == LogicalOperatorTag.PROJECT
                 || context.checkAndAddToAlreadyCompared(assignOp, inputOp)
@@ -196,24 +197,31 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
             return false;
         }
         if (testAndModifyRedundantOp(assignOp, inputOp)) {
-            pushDownFieldAccessRec(opRef2, context, finalAnnot);
+            pushDownFieldAccessRec(inputOpRef, context, finalAnnot);
             return true;
         }
-        HashSet<LogicalVariable> usedInAccess = new HashSet<>();
+        Set<LogicalVariable> usedInAccess = new HashSet<>();
         VariableUtilities.getUsedVariables(assignOp, usedInAccess);
-
-        HashSet<LogicalVariable> produced2 = new HashSet<>();
+        if (usedInAccess.isEmpty()) {
+            return false;
+        }
+        Set<LogicalVariable> produced = new HashSet<>();
+        ILogicalOperator dataScanOp =
+                getDataScanOp(assignOpRef, assignOp, inputOpRef, inputOp, usedInAccess, produced, context);
+        if (dataScanOp != null) {
+            // in this case, we don't need to keep pushing the assign op through all the assign operators below it since
+            // this is unnecessary. we just need to try replacing field access by the primary key if it refers to one
+            return rewriteFieldAccessToPK(context, finalAnnot, assignOp, dataScanOp);
+        }
+        produced.clear();
         if (inputOp.getOperatorTag() == LogicalOperatorTag.GROUP) {
-            VariableUtilities.getLiveVariables(inputOp, produced2);
+            VariableUtilities.getLiveVariables(inputOp, produced);
         } else {
-            VariableUtilities.getProducedVariables(inputOp, produced2);
+            VariableUtilities.getProducedVariables(inputOp, produced);
         }
         boolean pushItDown = false;
         HashSet<LogicalVariable> inter = new HashSet<>(usedInAccess);
-        if (inter.isEmpty()) { // ground value
-            return false;
-        }
-        inter.retainAll(produced2);
+        inter.retainAll(produced);
         if (inter.isEmpty()) {
             pushItDown = true;
         } else if (inputOp.getOperatorTag() == LogicalOperatorTag.GROUP) {
@@ -254,18 +262,18 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
             if (inputOp.getOperatorTag() == LogicalOperatorTag.NESTEDTUPLESOURCE) {
                 Mutable<ILogicalOperator> childOfSubplan =
                         ((NestedTupleSourceOperator) inputOp).getDataSourceReference().getValue().getInputs().get(0);
-                pushAccessDown(opRef, inputOp, childOfSubplan, context, finalAnnot);
+                pushAccessDown(assignOpRef, inputOp, childOfSubplan, context, finalAnnot);
                 return true;
             }
             if (inputOp.getInputs().size() == 1 && !inputOp.hasNestedPlans()) {
-                pushAccessDown(opRef, inputOp, inputOp.getInputs().get(0), context, finalAnnot);
+                pushAccessDown(assignOpRef, inputOp, inputOp.getInputs().get(0), context, finalAnnot);
                 return true;
             } else {
                 for (Mutable<ILogicalOperator> inp : inputOp.getInputs()) {
                     HashSet<LogicalVariable> v2 = new HashSet<>();
                     VariableUtilities.getLiveVariables(inp.getValue(), v2);
                     if (v2.containsAll(usedInAccess)) {
-                        pushAccessDown(opRef, inputOp, inp, context, finalAnnot);
+                        pushAccessDown(assignOpRef, inputOp, inp, context, finalAnnot);
                         return true;
                     }
                 }
@@ -277,7 +285,7 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
                         HashSet<LogicalVariable> v2 = new HashSet<>();
                         VariableUtilities.getLiveVariables(root.getValue(), v2);
                         if (v2.containsAll(usedInAccess)) {
-                            pushAccessDown(opRef, inputOp, root, context, finalAnnot);
+                            pushAccessDown(assignOpRef, inputOp, root, context, finalAnnot);
                             return true;
                         }
                     }
@@ -286,73 +294,124 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
             return false;
         } else {
             // check if the accessed field is one of the partitioning key fields. If yes, we can equate the 2 variables
-            if (inputOp.getOperatorTag() == LogicalOperatorTag.DATASOURCESCAN) {
-                DataSourceScanOperator scan = (DataSourceScanOperator) inputOp;
-                IDataSource<DataSourceId> dataSource = (IDataSource<DataSourceId>) scan.getDataSource();
-                byte dsType = ((DataSource) dataSource).getDatasourceType();
-                if (dsType != DataSource.Type.INTERNAL_DATASET && dsType != DataSource.Type.EXTERNAL_DATASET) {
-                    return false;
-                }
-                DataSourceId asid = dataSource.getId();
-                MetadataProvider mp = (MetadataProvider) context.getMetadataProvider();
-                Dataset dataset = mp.findDataset(asid.getDataverseName(), asid.getDatasourceName());
-                if (dataset == null) {
-                    throw new CompilationException(ErrorCode.UNKNOWN_DATASET_IN_DATAVERSE, scan.getSourceLocation(),
-                            asid.getDatasourceName(), asid.getDataverseName());
-                }
-                if (dataset.getDatasetType() != DatasetType.INTERNAL) {
-                    setAsFinal(assignOp, context, finalAnnot);
-                    return false;
-                }
+            return rewriteFieldAccessToPK(context, finalAnnot, assignOp, inputOp);
+        }
+    }
 
-                List<LogicalVariable> allVars = scan.getVariables();
-                LogicalVariable dataRecVarInScan = ((DataSource) dataSource).getDataRecordVariable(allVars);
-                LogicalVariable metaRecVarInScan = ((DataSource) dataSource).getMetaVariable(allVars);
+    /**
+     * Tries to rewrite field access to its equivalent PK. For example, a data scan operator of dataset "ds" produces
+     * the following variables: $PK1, $PK2,.., $ds, ($meta_var?). Given field access: $$ds.getField("id") and given that
+     * the field "id" is one of the primary keys of ds, the field access $$ds.getField("id") is replaced by the primary
+     * key variable (one of the $PKs).
+     * @return true if the field access in the assign operator was replaced by the primary key variable.
+     */
+    private boolean rewriteFieldAccessToPK(IOptimizationContext context, String finalAnnot, AssignOperator assignOp,
+            ILogicalOperator inputOp) throws AlgebricksException {
+        if (inputOp.getOperatorTag() == LogicalOperatorTag.DATASOURCESCAN) {
+            DataSourceScanOperator scan = (DataSourceScanOperator) inputOp;
+            IDataSource<DataSourceId> dataSource = (IDataSource<DataSourceId>) scan.getDataSource();
+            byte dsType = ((DataSource) dataSource).getDatasourceType();
+            if (dsType != DataSource.Type.INTERNAL_DATASET && dsType != DataSource.Type.EXTERNAL_DATASET) {
+                return false;
+            }
+            DataSourceId asid = dataSource.getId();
+            MetadataProvider mp = (MetadataProvider) context.getMetadataProvider();
+            Dataset dataset = mp.findDataset(asid.getDataverseName(), asid.getDatasourceName());
+            if (dataset == null) {
+                throw new CompilationException(ErrorCode.UNKNOWN_DATASET_IN_DATAVERSE, scan.getSourceLocation(),
+                        asid.getDatasourceName(), asid.getDataverseName());
+            }
+            if (dataset.getDatasetType() != DatasetType.INTERNAL) {
+                setAsFinal(assignOp, context, finalAnnot);
+                return false;
+            }
 
-                // data part
-                String dataTypeName = dataset.getItemTypeName();
-                IAType dataType = mp.findType(dataset.getItemTypeDataverseName(), dataTypeName);
-                if (dataType.getTypeTag() != ATypeTag.OBJECT) {
-                    return false;
-                }
-                ARecordType dataRecType = (ARecordType) dataType;
-                Pair<ILogicalExpression, List<String>> fieldPathAndVar = getFieldExpression(assignOp, dataRecType);
-                ILogicalExpression targetRecVar = fieldPathAndVar.first;
-                List<String> targetFieldPath = fieldPathAndVar.second;
-                boolean rewrite = false;
-                boolean fieldFromMeta = false;
-                if (sameRecords(targetRecVar, dataRecVarInScan)) {
-                    rewrite = true;
-                } else {
-                    // check meta part
-                    IAType metaType = mp.findMetaType(dataset); // could be null
-                    if (metaType != null && metaType.getTypeTag() == ATypeTag.OBJECT) {
-                        fieldPathAndVar = getFieldExpression(assignOp, (ARecordType) metaType);
-                        targetRecVar = fieldPathAndVar.first;
-                        targetFieldPath = fieldPathAndVar.second;
-                        if (sameRecords(targetRecVar, metaRecVarInScan)) {
-                            rewrite = true;
-                            fieldFromMeta = true;
-                        }
+            List<LogicalVariable> allVars = scan.getVariables();
+            LogicalVariable dataRecVarInScan = ((DataSource) dataSource).getDataRecordVariable(allVars);
+            LogicalVariable metaRecVarInScan = ((DataSource) dataSource).getMetaVariable(allVars);
+
+            // data part
+            String dataTypeName = dataset.getItemTypeName();
+            IAType dataType = mp.findType(dataset.getItemTypeDataverseName(), dataTypeName);
+            if (dataType.getTypeTag() != ATypeTag.OBJECT) {
+                return false;
+            }
+            ARecordType dataRecType = (ARecordType) dataType;
+            Pair<ILogicalExpression, List<String>> fieldPathAndVar = getFieldExpression(assignOp, dataRecType);
+            ILogicalExpression targetRecVar = fieldPathAndVar.first;
+            List<String> targetFieldPath = fieldPathAndVar.second;
+            boolean rewrite = false;
+            boolean fieldFromMeta = false;
+            if (sameRecords(targetRecVar, dataRecVarInScan)) {
+                rewrite = true;
+            } else {
+                // check meta part
+                IAType metaType = mp.findMetaType(dataset); // could be null
+                if (metaType != null && metaType.getTypeTag() == ATypeTag.OBJECT) {
+                    fieldPathAndVar = getFieldExpression(assignOp, (ARecordType) metaType);
+                    targetRecVar = fieldPathAndVar.first;
+                    targetFieldPath = fieldPathAndVar.second;
+                    if (sameRecords(targetRecVar, metaRecVarInScan)) {
+                        rewrite = true;
+                        fieldFromMeta = true;
                     }
                 }
+            }
 
-                if (rewrite) {
-                    int p = DatasetUtil.getPositionOfPartitioningKeyField(dataset, targetFieldPath, fieldFromMeta);
-                    if (p < 0) { // not one of the partitioning fields
-                        setAsFinal(assignOp, context, finalAnnot);
-                        return false;
-                    }
-                    LogicalVariable keyVar = scan.getVariables().get(p);
-                    VariableReferenceExpression keyVarRef = new VariableReferenceExpression(keyVar);
-                    keyVarRef.setSourceLocation(targetRecVar.getSourceLocation());
-                    assignOp.getExpressions().get(0).setValue(keyVarRef);
-                    return true;
+            if (rewrite) {
+                int p = DatasetUtil.getPositionOfPartitioningKeyField(dataset, targetFieldPath, fieldFromMeta);
+                if (p < 0) { // not one of the partitioning fields
+                    setAsFinal(assignOp, context, finalAnnot);
+                    return false;
                 }
+                LogicalVariable keyVar = scan.getVariables().get(p);
+                VariableReferenceExpression keyVarRef = new VariableReferenceExpression(keyVar);
+                keyVarRef.setSourceLocation(targetRecVar.getSourceLocation());
+                assignOp.getExpressions().get(0).setValue(keyVarRef);
+                return true;
             }
-            setAsFinal(assignOp, context, finalAnnot);
-            return false;
         }
+        setAsFinal(assignOp, context, finalAnnot);
+        return false;
+    }
+
+    /**
+     * Looks for a data scan operator where the data scan operator is below only assign operators. Then, if
+     * applicable, the assign operator is moved down and placed above the data-scan.
+     *
+     * @return the data scan operator if it exists below multiple assign operators only and the assign operator is now
+     * above the data-scan.
+     */
+    private ILogicalOperator getDataScanOp(Mutable<ILogicalOperator> assignOpRef, AssignOperator assignOp,
+            Mutable<ILogicalOperator> assignInputRef, ILogicalOperator assignInput, Set<LogicalVariable> usedInAssign,
+            Set<LogicalVariable> producedByInput, IOptimizationContext context) throws AlgebricksException {
+        ILogicalOperator firstInput = assignInput;
+        while (assignInput.getOperatorTag() == LogicalOperatorTag.ASSIGN) {
+            if (isRedundantAssign(assignOp, assignInput)) {
+                return null;
+            }
+            assignInputRef = assignInput.getInputs().get(0);
+            assignInput = assignInputRef.getValue();
+        }
+        if (assignInput.getOperatorTag() != LogicalOperatorTag.DATASOURCESCAN) {
+            return null;
+        }
+        VariableUtilities.getProducedVariables(assignInput, producedByInput);
+        if (!producedByInput.containsAll(usedInAssign)) {
+            return null;
+        }
+        if (firstInput == assignInput) {
+            // the input to the assign operator is already a data-scan
+            return assignInput;
+        }
+        // move the assign op down, place it above the data-scan
+        assignOpRef.setValue(firstInput);
+        List<Mutable<ILogicalOperator>> assignInputs = assignOp.getInputs();
+        assignInputs.get(0).setValue(assignInput);
+        assignInputRef.setValue(assignOp);
+        context.computeAndSetTypeEnvironmentForOperator(assignOp);
+        context.computeAndSetTypeEnvironmentForOperator(firstInput);
+        return assignInput;
     }
 
     /**
@@ -398,12 +457,9 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
     }
 
     private boolean testAndModifyRedundantOp(AssignOperator access, AbstractLogicalOperator op2) {
-        if (op2.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
-            return false;
-        }
-        AssignOperator a2 = (AssignOperator) op2;
-        ILogicalExpression accessExpr0 = getFirstExpr(access);
-        if (accessExpr0.equals(getFirstExpr(a2))) {
+        if (isRedundantAssign(access, op2)) {
+            AssignOperator a2 = (AssignOperator) op2;
+            ILogicalExpression accessExpr0 = getFirstExpr(access);
             VariableReferenceExpression varRef = new VariableReferenceExpression(a2.getVariables().get(0));
             varRef.setSourceLocation(accessExpr0.getSourceLocation());
             access.getExpressions().get(0).setValue(varRef);
@@ -413,6 +469,14 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
         }
     }
 
+    private static boolean isRedundantAssign(AssignOperator assignOp, ILogicalOperator inputOp) {
+        if (inputOp.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
+            return false;
+        }
+        ILogicalExpression assignOpExpr = getFirstExpr(assignOp);
+        return assignOpExpr.equals(getFirstExpr((AssignOperator) inputOp));
+    }
+
     // indirect recursivity with pushDownFieldAccessRec
     private void pushAccessDown(Mutable<ILogicalOperator> fldAccessOpRef, ILogicalOperator op2,
             Mutable<ILogicalOperator> inputOfOp2, IOptimizationContext context, String finalAnnot)
@@ -429,8 +493,7 @@ public class PushFieldAccessRule implements IAlgebraicRewriteRule {
         pushDownFieldAccessRec(inputOfOp2, context, finalAnnot);
     }
 
-    private ILogicalExpression getFirstExpr(AssignOperator assign) {
+    private static ILogicalExpression getFirstExpr(AssignOperator assign) {
         return assign.getExpressions().get(0).getValue();
     }
-
 }
diff --git a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java
index 29c178a238..79ec0fa317 100644
--- a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFirstRuleCheckFixpointRuleController.java
@@ -72,7 +72,7 @@ public class SequentialFirstRuleCheckFixpointRuleController extends AbstractRule
         do {
             anyChange = false;
             for (int i = 0; i < rules.size(); i++) {
-                boolean ruleFired = rewriteOperatorRef(root, rules.get(i), true, fullDfs);
+                boolean ruleFired = rewriteOperatorRef(root, rules.get(i), true, fullDfs, false);
                 // If the first rule returns false in the first iteration, stops applying the rules at all.
                 if (!firstRuleChecked && i == 0 && !ruleFired) {
                     return ruleFired;
diff --git a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java
index 1fef33e866..bbe281d715 100644
--- a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialFixpointRuleController.java
@@ -49,7 +49,7 @@ public class SequentialFixpointRuleController extends AbstractRuleController {
         do {
             anyChange = false;
             for (IAlgebraicRewriteRule rule : ruleCollection) {
-                boolean ruleFired = rewriteOperatorRef(root, rule, true, fullDfs);
+                boolean ruleFired = rewriteOperatorRef(root, rule, true, fullDfs, false);
                 if (ruleFired) {
                     anyChange = true;
                     anyRuleFired = true;
diff --git a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java
index bcbc20727c..1090fe1bd1 100644
--- a/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-compiler/src/main/java/org/apache/hyracks/algebricks/compiler/rewriter/rulecontrollers/SequentialOnceRuleController.java
@@ -40,7 +40,7 @@ public class SequentialOnceRuleController extends AbstractRuleController {
             throws AlgebricksException {
         boolean fired = false;
         for (IAlgebraicRewriteRule rule : rules) {
-            if (rewriteOperatorRef(root, rule, enterNestedPlans, true)) {
+            if (rewriteOperatorRef(root, rule, enterNestedPlans, true, false)) {
                 fired = true;
             }
         }
diff --git a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java
index 02611063db..9a47b8aeed 100644
--- a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java
+++ b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/AbstractRuleController.java
@@ -67,14 +67,15 @@ public abstract class AbstractRuleController {
      */
     protected boolean rewriteOperatorRef(Mutable<ILogicalOperator> opRef, IAlgebraicRewriteRule rule)
             throws AlgebricksException {
-        return rewriteOperatorRef(opRef, rule, true, false);
+        return rewriteOperatorRef(opRef, rule, true, false, false);
     }
 
     protected boolean rewriteOperatorRef(Mutable<ILogicalOperator> opRef, IAlgebraicRewriteRule rule,
-            boolean enterNestedPlans, boolean fullDFS) throws AlgebricksException {
+            boolean enterNestedPlans, boolean fullDFS, boolean enteredNestedPlanRoot) throws AlgebricksException {
 
         String preBeforePlan = getPlanString(opRef);
         sanityCheckBeforeRewrite(rule, opRef);
+        rule.enteredNestedPlan(enteredNestedPlanRoot);
         if (rule.rewritePre(opRef, context)) {
             String preAfterPlan = getPlanString(opRef);
             printRuleApplication(rule, "fired", preBeforePlan, preAfterPlan);
@@ -88,7 +89,7 @@ public abstract class AbstractRuleController {
         AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
 
         for (Mutable<ILogicalOperator> inp : op.getInputs()) {
-            if (rewriteOperatorRef(inp, rule, enterNestedPlans, fullDFS)) {
+            if (rewriteOperatorRef(inp, rule, enterNestedPlans, fullDFS, false)) {
                 rewritten = true;
                 if (!fullDFS) {
                     break;
@@ -100,7 +101,7 @@ public abstract class AbstractRuleController {
             AbstractOperatorWithNestedPlans o2 = (AbstractOperatorWithNestedPlans) op;
             for (ILogicalPlan p : o2.getNestedPlans()) {
                 for (Mutable<ILogicalOperator> r : p.getRoots()) {
-                    if (rewriteOperatorRef(r, rule, enterNestedPlans, fullDFS)) {
+                    if (rewriteOperatorRef(r, rule, enterNestedPlans, fullDFS, true)) {
                         rewritten = true;
                         if (!fullDFS) {
                             break;
diff --git a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java
index 128c372c00..33bc4a9a8a 100644
--- a/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java
+++ b/hyracks-fullstack/algebricks/algebricks-core/src/main/java/org/apache/hyracks/algebricks/core/rewriter/base/IAlgebraicRewriteRule.java
@@ -54,4 +54,13 @@ public interface IAlgebraicRewriteRule {
             throws AlgebricksException {
         return false;
     }
+
+    /**
+     * Called before calling {@link #rewritePre} to designate if the {@code opRef} is a nested plan root.
+     *
+     * @param enteredNestedPlanRoot whether the operator to be rewritten is a nested plan root.
+     */
+    default void enteredNestedPlan(boolean enteredNestedPlanRoot) {
+        // no op
+    }
 }
diff --git a/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java b/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java
index 942049870b..e2ba5571b7 100644
--- a/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java
+++ b/hyracks-fullstack/algebricks/algebricks-rewriter/src/main/java/org/apache/hyracks/algebricks/rewriter/rules/ExtractCommonExpressionsRule.java
@@ -77,17 +77,17 @@ import org.apache.hyracks.api.exceptions.SourceLocation;
  */
 public class ExtractCommonExpressionsRule implements IAlgebraicRewriteRule {
 
-    private final List<ILogicalExpression> originalAssignExprs = new ArrayList<ILogicalExpression>();
+    private final List<ILogicalExpression> originalAssignExprs = new ArrayList<>();
 
     private final CommonExpressionSubstitutionVisitor substVisitor = new CommonExpressionSubstitutionVisitor();
-    private final Map<ILogicalExpression, ExprEquivalenceClass> exprEqClassMap =
-            new HashMap<ILogicalExpression, ExprEquivalenceClass>();
+    private final Map<ILogicalExpression, ExprEquivalenceClass> exprEqClassMap = new HashMap<>();
 
     private final List<LogicalVariable> tmpLiveVars = new ArrayList<>();
     private final List<LogicalVariable> tmpProducedVars = new ArrayList<>();
+    private boolean enteredNestedPlan = false;
 
     // Set of operators for which common subexpression elimination should not be performed.
-    private static final Set<LogicalOperatorTag> ignoreOps = new HashSet<LogicalOperatorTag>(6);
+    private static final Set<LogicalOperatorTag> ignoreOps = new HashSet<>(6);
 
     static {
         ignoreOps.add(LogicalOperatorTag.UNNEST);
@@ -99,6 +99,11 @@ public class ExtractCommonExpressionsRule implements IAlgebraicRewriteRule {
         ignoreOps.add(LogicalOperatorTag.WINDOW); //TODO: can extract from partition/order/frame expressions
     }
 
+    @Override
+    public void enteredNestedPlan(boolean enteredNestedPlanRoot) {
+        this.enteredNestedPlan = enteredNestedPlanRoot;
+    }
+
     @Override
     public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
@@ -108,6 +113,14 @@ public class ExtractCommonExpressionsRule implements IAlgebraicRewriteRule {
     @Override
     public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
+        ILogicalOperator op = opRef.getValue();
+        if (enteredNestedPlan) {
+            enteredNestedPlan = false;
+        } else if (op.getOperatorTag() != LogicalOperatorTag.DISTRIBUTE_RESULT
+                && op.getOperatorTag() != LogicalOperatorTag.SINK
+                && op.getOperatorTag() != LogicalOperatorTag.DELEGATE_OPERATOR) {
+            return false;
+        }
         exprEqClassMap.clear();
         substVisitor.setContext(context);
         boolean modified = removeCommonExpressions(opRef, context);
@@ -155,9 +168,6 @@ public class ExtractCommonExpressionsRule implements IAlgebraicRewriteRule {
     private boolean removeCommonExpressions(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
             throws AlgebricksException {
         AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
-        if (context.checkIfInDontApplySet(this, opRef.getValue())) {
-            return false;
-        }
 
         boolean modified = false;
         // Recurse into children.
@@ -166,7 +176,9 @@ public class ExtractCommonExpressionsRule implements IAlgebraicRewriteRule {
                 modified = true;
             }
         }
-
+        if (context.checkIfInDontApplySet(this, opRef.getValue())) {
+            return modified;
+        }
         // TODO: Deal with replicate properly. Currently, we just clear the expr equivalence map,
         // since we want to avoid incorrect expression replacement
         // (the resulting new variables should be assigned live below a replicate/split).