You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by bo...@apache.org on 2019/05/09 01:58:52 UTC

[drill] 02/05: DRILL-7148: Use improved join cardinality and ndv estimation with statistics

This is an automated email from the ASF dual-hosted git repository.

boaz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git

commit c1ca512ea54cc7fe1b7ee581f21aa84a95aab14f
Author: Gautam Parai <gp...@maprtech.com>
AuthorDate: Tue Mar 19 12:04:13 2019 -0700

    DRILL-7148: Use improved join cardinality and ndv estimation with statistics
    
    closes #1744
---
 .../exec/planner/common/DrillJoinRelBase.java      |  34 ++--
 .../drill/exec/planner/common/DrillRelOptUtil.java |  62 +++++++
 .../drill/exec/planner/common/DrillStatsTable.java |   2 +-
 .../planner/cost/DrillRelMdDistinctRowCount.java   | 181 +++++++++++++++++----
 .../exec/planner/cost/DrillRelMdRowCount.java      |  41 +++--
 .../exec/planner/cost/DrillRelMdSelectivity.java   |  40 +++--
 .../exec/planner/physical/PlannerSettings.java     |   6 +
 .../exec/server/options/SystemOptionManager.java   |   1 +
 .../java/org/apache/drill/exec/util/Utilities.java |   2 +-
 .../java-exec/src/main/resources/drill-module.conf |   1 +
 .../org/apache/drill/exec/sql/TestAnalyze.java     |  12 +-
 .../drill/metastore/ColumnStatisticsKind.java      |   2 +-
 12 files changed, 295 insertions(+), 89 deletions(-)

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillJoinRelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillJoinRelBase.java
index a6b6f4a..6150bf3 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillJoinRelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillJoinRelBase.java
@@ -17,6 +17,7 @@
  */
 package org.apache.drill.exec.planner.common;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
@@ -33,6 +34,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Pair;
 import org.apache.drill.exec.ExecConstants;
 import org.apache.drill.exec.expr.holders.IntHolder;
 import org.apache.drill.exec.physical.impl.join.JoinUtils;
@@ -102,26 +104,32 @@ public abstract class DrillJoinRelBase extends Join implements DrillJoin {
       return joinRowFactor * this.getLeft().estimateRowCount(mq) * this.getRight().estimateRowCount(mq);
     }
 
-    int[] joinFields = new int[2];
-
     LogicalJoin jr = LogicalJoin.create(this.getLeft(), this.getRight(), this.getCondition(),
             this.getVariablesSet(), this.getJoinType());
 
     if (!DrillRelOptUtil.guessRows(this)         //Statistics present for left and right side of the join
-        && jr.getJoinType() == JoinRelType.INNER
-        && DrillRelOptUtil.analyzeSimpleEquiJoin((Join)jr, joinFields)) {
-      ImmutableBitSet leq = ImmutableBitSet.of(joinFields[0]);
-      ImmutableBitSet req = ImmutableBitSet.of(joinFields[1]);
+        && jr.getJoinType() == JoinRelType.INNER) {
+      List<Pair<Integer, Integer>> joinConditions = DrillRelOptUtil.analyzeSimpleEquiJoin(jr);
+      if (!joinConditions.isEmpty()) {  // empty when the condition is not a simple column = column equi-join
+        List<Integer> leftSide = new ArrayList<>();
+        List<Integer> rightSide = new ArrayList<>();
+        for (Pair<Integer, Integer> condition : joinConditions) {
+          leftSide.add(condition.left);
+          rightSide.add(condition.right);
+        }
+        ImmutableBitSet leq = ImmutableBitSet.of(leftSide);  // all left join columns form one composite key
+        ImmutableBitSet req = ImmutableBitSet.of(rightSide);  // likewise for the right input
 
-      Double ldrc = mq.getDistinctRowCount(this.getLeft(), leq, null);
-      Double rdrc = mq.getDistinctRowCount(this.getRight(), req, null);
+        Double ldrc = mq.getDistinctRowCount(this.getLeft(), leq, null);
+        Double rdrc = mq.getDistinctRowCount(this.getRight(), req, null);
 
-      Double lrc = mq.getRowCount(this.getLeft());
-      Double rrc = mq.getRowCount(this.getRight());
+        Double lrc = mq.getRowCount(this.getLeft());
+        Double rrc = mq.getRowCount(this.getRight());
 
-      if (ldrc != null && rdrc != null && lrc != null && rrc != null) {
-        // Join cardinality = (lrc * rrc) / Math.max(ldrc, rdrc). Avoid overflow by dividing earlier
-        return (lrc / Math.max(ldrc, rdrc)) * rrc;
+        if (ldrc != null && rdrc != null && lrc != null && rrc != null) {
+          // Join cardinality = (lrc * rrc) / Math.max(ldrc, rdrc). Avoid overflow by dividing earlier
+          return (lrc / Math.max(ldrc, rdrc)) * rrc;
+        }
       }
     }
 
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillRelOptUtil.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillRelOptUtil.java
index 66499d6..3838bf9 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillRelOptUtil.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillRelOptUtil.java
@@ -667,4 +667,66 @@ public abstract class DrillRelOptUtil {
     }
     return drillTable;
   }
+
+  /**
+   * Collects (left column, right column) index pairs for the column = column equalities found under AND/OR in
+   * the join condition; returns an empty list if any equality is not a plain left-column = right-column test.
+   */
+  public static List<Pair<Integer, Integer>> analyzeSimpleEquiJoin(Join join) {
+    List<Pair<Integer, Integer>> joinConditions = new ArrayList<>();
+    int leftFieldCount = join.getLeft().getRowType().getFieldCount();
+    int totalFieldCount = leftFieldCount + join.getRight().getRowType().getFieldCount();
+    try {
+      RexVisitor<Void> visitor =
+          new RexVisitorImpl<Void>(true) {
+            @Override
+            public Void visitCall(RexCall call) {
+              if (call.getKind() == SqlKind.AND || call.getKind() == SqlKind.OR) {
+                super.visitCall(call);
+              } else if (call.getKind() == SqlKind.EQUALS) {
+                // Bail out if an operand is more complex than a column reference (e.g. CAST(col) or col + 1)
+                if (!(call.operands.get(0) instanceof RexInputRef && call.operands.get(1) instanceof RexInputRef)) {
+                  joinConditions.clear();
+                  throw new Util.FoundOne(call);
+                }
+                int leftIndex = ((RexInputRef) call.operands.get(0)).getIndex();
+                int rightIndex = ((RexInputRef) call.operands.get(1)).getIndex();
+                boolean leftFromLeftInput = leftIndex < leftFieldCount;
+                // Bail out on out-of-range references and on comparisons between two columns of the same input
+                if (leftIndex >= totalFieldCount || rightIndex >= totalFieldCount
+                    || leftFromLeftInput == (rightIndex < leftFieldCount)) {
+                  joinConditions.clear();
+                  throw new Util.FoundOne(call);
+                }
+                // Normalize to (left input column, right input column), right index relative to the right input
+                if (leftFromLeftInput) {
+                  joinConditions.add(Pair.of(leftIndex, rightIndex - leftFieldCount));
+                } else {
+                  joinConditions.add(Pair.of(rightIndex, leftIndex - leftFieldCount));
+                }
+              }
+              return null;
+            }
+          };
+      join.getCondition().accept(visitor);
+    } catch (Util.FoundOne ex) {
+      Util.swallow(ex, null);  // thrown on purpose to stop the visit; joinConditions was already cleared
+    }
+    return joinConditions;
+  }
+
+  /** Returns every {@link RexInputRef} in the expression, in visit order and including duplicates. */
+  public static List<RexInputRef> findAllRexInputRefs(final RexNode node) {
+    List<RexInputRef> rexRefs = new ArrayList<>();
+    RexVisitor<Void> visitor =
+            new RexVisitorImpl<Void>(true) {
+              @Override
+              public Void visitInputRef(RexInputRef inputRef) {
+                rexRefs.add(inputRef);
+                return super.visitInputRef(inputRef);
+              }
+            };
+    node.accept(visitor);
+    return rexRefs;
+  }
 }
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillStatsTable.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillStatsTable.java
index 78e87d0..7565abf 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillStatsTable.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillStatsTable.java
@@ -483,7 +483,7 @@ public class DrillStatsTable {
       Map<StatisticsKind, Object> statisticsValues = new HashMap<>();
       Double ndv = statsProvider.getNdv(fieldName);
       if (ndv != null) {
-        statisticsValues.put(ColumnStatisticsKind.NVD, ndv);
+        statisticsValues.put(ColumnStatisticsKind.NDV, ndv);  // NDV (number of distinct values) from statsProvider.getNdv()
       }
       Double nonNullCount = statsProvider.getNNRowCount(fieldName);
       if (nonNullCount != null) {
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdDistinctRowCount.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdDistinctRowCount.java
index 37cd55a..8b11a9a 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdDistinctRowCount.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdDistinctRowCount.java
@@ -17,14 +17,18 @@
  */
 package org.apache.drill.exec.planner.cost;
 
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.plan.volcano.RelSubset;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.SingleRel;
 import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.TableScan;
 import org.apache.calcite.rel.core.Window;
 import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMdDistinctRowCount;
@@ -33,23 +37,25 @@ import org.apache.calcite.rel.metadata.RelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.drill.common.expression.SchemaPath;
 import org.apache.drill.exec.planner.common.DrillJoinRelBase;
 import org.apache.drill.exec.planner.common.DrillRelOptUtil;
-import org.apache.drill.exec.planner.common.DrillScanRelBase;
 import org.apache.drill.exec.planner.logical.DrillScanRel;
 import org.apache.drill.exec.planner.logical.DrillTable;
+import org.apache.drill.exec.planner.physical.PlannerSettings;
+import org.apache.drill.exec.planner.physical.PrelUtil;
+import org.apache.drill.exec.util.Utilities;
 import org.apache.drill.metastore.ColumnStatistics;
 import org.apache.drill.metastore.ColumnStatisticsKind;
 import org.apache.drill.metastore.TableMetadata;
 
-import java.io.IOException;
-import org.apache.drill.exec.util.Utilities;
-
 public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
   private static final DrillRelMdDistinctRowCount INSTANCE =
       new DrillRelMdDistinctRowCount();
@@ -80,10 +86,10 @@ public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
 
   @Override
   public Double getDistinctRowCount(RelNode rel, RelMetadataQuery mq, ImmutableBitSet groupKey, RexNode predicate) {
-    if (rel instanceof DrillScanRelBase) {                  // Applies to both Drill Logical and Physical Rels
+    if (rel instanceof TableScan) {                   // Applies to Calcite/Drill logical and Drill physical rels
       if (!DrillRelOptUtil.guessRows(rel)) {
         DrillTable table = Utilities.getDrillTable(rel.getTable());
-        return getDistinctRowCountInternal(((DrillScanRelBase) rel), mq, table, groupKey, rel.getRowType(), predicate);
+        return getDistinctRowCountInternal(((TableScan) rel), mq, table, groupKey, rel.getRowType(), predicate);
       } else {
         /* If we are not using statistics OR there is no table or metadata (stats) table associated with scan,
          * estimate the distinct row count. Consistent with the estimation of Aggregate row count in
@@ -124,7 +130,7 @@ public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
    * set of columns indicated by groupKey.
    * column").
    */
-  private Double getDistinctRowCountInternal(DrillScanRelBase scan, RelMetadataQuery mq, DrillTable table,
+  private Double getDistinctRowCountInternal(TableScan scan, RelMetadataQuery mq, DrillTable table,
       ImmutableBitSet groupKey, RelDataType type, RexNode predicate) {
     double selectivity, rowCount;
     /* If predicate is present, determine its selectivity to estimate filtered rows.
@@ -136,7 +142,6 @@ public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
     if (groupKey.length() == 0) {
       return selectivity * rowCount;
     }
-
     /* If predicate is present, determine its selectivity to estimate filtered rows. Thereafter,
      * compute the number of distinct rows
      */
@@ -148,29 +153,39 @@ public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
       // Statistics cannot be obtained, use default behaviour
       return scan.estimateRowCount(mq) * 0.1;
     }
-    double s = 1.0;
 
+    double s = 1.0;
+    boolean allCols = true;
     for (int i = 0; i < groupKey.length(); i++) {
       final String colName = type.getFieldNames().get(i);
       // Skip NDV, if not available
       if (!groupKey.get(i)) {
-        continue;
+        // Not a group-by column (groupKey may have gaps): it contributes no NDV factor
+        continue;
       }
-      ColumnStatistics columnStatistics = tableMetadata != null ? tableMetadata.getColumnStatistics(SchemaPath.getSimplePath(colName)) : null;
-      Double ndv = columnStatistics != null ? (Double) columnStatistics.getStatistic(ColumnStatisticsKind.NVD) : null;
+      ColumnStatistics columnStatistics = tableMetadata != null ? tableMetadata.getColumnStatistics(SchemaPath.getSimplePath(colName)) : null;
+      Double ndv = columnStatistics != null ? (Double) columnStatistics.getStatistic(ColumnStatisticsKind.NDV) : null;
+      allCols &= ndv != null;  // a group-by column without an NDV statistic forces the guess below
       if (ndv == null) {
         continue;
       }
-      s *= 1 - ndv / rowCount;
+      s *= ndv;
+      double colSelectivity = getPredSelectivityContainingInputRef(predicate, i, mq, scan);
+      /* If predicate is on group-by column, scale down the NDV by selectivity. Consider the query
+       * select a, b from t where a = 10 group by a, b. Here, NDV(a) will be scaled down by SEL(a)
+       * whereas NDV(b) will not.
+       */
+      if (colSelectivity > 0) {
+        s *= colSelectivity;
+      }
     }
-    if (s > 0 && s < 1.0) {
-      return (1 - s) * selectivity * rowCount;
-    } else if (s == 1.0) {
+    s = Math.min(s, rowCount);  // sanity cap: NDV statistics may come from a different source than rowCount
+    if (!allCols) {
       // Could not get any NDV estimate from stats - probably stats not present for GBY cols. So Guess!
       return scan.estimateRowCount(mq) * 0.1;
     } else {
-      /* rowCount maybe less than NDV(different source), sanity check OR NDV not used at all */
-      return selectivity * rowCount;
+      // Every group-by column has an NDV statistic: the capped, selectivity-scaled product is the estimate
+      return s;
     }
   }
 
@@ -195,7 +210,7 @@ public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
     if (predicate != null) {
       List<RexNode> leftFilters = new ArrayList<>();
       List<RexNode> rightFilters = new ArrayList<>();
-      List<RexNode> joinFilters = new ArrayList<>();
+      List<RexNode> joinFilters = new ArrayList<>();
       List<RexNode> predList = RelOptUtil.conjunctions(predicate);
       RelOptUtil.classifyFilters(joinRel, predList, joinType, joinType == JoinRelType.INNER,
           !joinType.generatesNullsOnLeft(), !joinType.generatesNullsOnRight(), joinFilters,
@@ -205,28 +220,122 @@ public class DrillRelMdDistinctRowCount extends RelMdDistinctRowCount{
       rightPred = RexUtil.composeConjunction(rexBuilder, rightFilters, true);
     }
 
-    Double leftDistRowCount = null;
-    Double rightDistRowCount = null;
     double distRowCount = 1;
-    ImmutableBitSet lmb = leftMask.build();
-    ImmutableBitSet rmb = rightMask.build();
-    // Get NDV estimates for the left and right side predicates, if applicable
-    if (lmb.length() > 0) {
-      leftDistRowCount = mq.getDistinctRowCount(left, lmb, leftPred);
-      if (leftDistRowCount != null) {
-        distRowCount = leftDistRowCount;
+    int gbyCols = 0;
+    PlannerSettings plannerSettings = PrelUtil.getPlannerSettings(joinRel.getCluster().getPlanner());
+    /*
+     * The NDV for a multi-column GBY key past a join is determined as follows:
+     * GBY(s1, s2, s3) = CNDV(s1)*CNDV(s2)*CNDV(s3)
+     * where CNDV is determined as follows:
+     * A) If sX is present as a join column (sX = tX) CNDV(sX) = MIN(NDV(sX), NDV(tX)) where X =1, 2, 3, etc
+     * B) Otherwise, based on independence assumption CNDV(sX) = NDV(sX)
+     */
+    Set<ImmutableBitSet> joinFiltersSet = new HashSet<>();
+    for (RexNode filter : RelOptUtil.conjunctions(joinRel.getCondition())) {
+      final RelOptUtil.InputFinder inputFinder = RelOptUtil.InputFinder.analyze(filter);
+      joinFiltersSet.add(inputFinder.inputBitSet.build());
+    }
+    for (int idx = 0; idx < groupKey.length(); idx++) {
+      if (groupKey.get(idx)) {
+        // Determine CNDV of this GBY column using option A) or B) described above
+        double ndvSGby = Double.MAX_VALUE;
+        boolean presentInFilter = false;
+        ImmutableBitSet sGby = getSingleGbyKey(groupKey, idx);
+        if (sGby != null) {
+          for (ImmutableBitSet jFilter : joinFiltersSet) {
+            if (jFilter.contains(sGby)) {
+              presentInFilter = true;
+              // Found join condition containing this GBY key. Pick min NDV across all columns in this join
+              for (int fidx : jFilter) {
+                Double fNdv = fidx < left.getRowType().getFieldCount()
+                    ? mq.getDistinctRowCount(left, ImmutableBitSet.of(fidx), leftPred)
+                    : mq.getDistinctRowCount(right, ImmutableBitSet.of(fidx - left.getRowType().getFieldCount()), rightPred);
+                // A column without an NDV estimate (null) does not constrain the minimum
+                ndvSGby = fNdv == null ? ndvSGby : Math.min(ndvSGby, fNdv);
+              }
+              break;
+            }
+          }
+          // Did not find it in any join condition(s)
+          if (!presentInFilter) {
+            for (int sidx : sGby) {
+              Double sNdv = sidx < left.getRowType().getFieldCount()
+                  ? mq.getDistinctRowCount(left, ImmutableBitSet.of(sidx), leftPred)
+                  : mq.getDistinctRowCount(right, ImmutableBitSet.of(sidx - left.getRowType().getFieldCount()), rightPred);
+              // Without an NDV estimate ndvSGby stays at Double.MAX_VALUE and is capped by the join row count below
+              ndvSGby = sNdv == null ? ndvSGby : sNdv;
+            }
+          }
+          ++gbyCols;
+          // Multiply NDV(s) of different GBY cols; clamp so the product stays finite (Infinity * 0 would be NaN)
+          distRowCount = Math.min(distRowCount * ndvSGby, Double.MAX_VALUE);
+        }
       }
     }
-    if (rmb.length() > 0) {
-      rightDistRowCount = mq.getDistinctRowCount(right, rmb, rightPred);
-      if (rightDistRowCount != null) {
-        distRowCount = rightDistRowCount;
+    if (gbyCols > 1) { // Scale with multi-col NDV factor if more than one GBY cols were found
+      distRowCount *= plannerSettings.getStatisticsMultiColNdvAdjustmentFactor();
+    }
+    double joinRowCount = mq.getRowCount(joinRel);
+    // Cap NDV to join row count
+    distRowCount = Math.min(distRowCount, joinRowCount);
+    return RelMdUtil.numDistinctVals(distRowCount, joinRowCount);
+  }
+
+  private ImmutableBitSet getSingleGbyKey(ImmutableBitSet groupKey, int idx) {
+    if (groupKey.get(idx)) {
+      return ImmutableBitSet.builder().set(idx, idx+1).build();
+    } else {
+      return null;
+    }
+  }
+
+  private double getPredSelectivityContainingInputRef(RexNode predicate, int inputRef,
+      RelMetadataQuery mq, TableScan scan) {  // returns a value <= 0 when predicate does not restrict column inputRef
+    if (predicate instanceof RexCall) {
+      if (predicate.getKind() == SqlKind.AND) {
+        double andSel = 1.0;
+        for (RexNode op : ((RexCall) predicate).getOperands()) {
+          double sel = getPredSelectivityContainingInputRef(op, inputRef, mq, scan);
+          andSel *= sel > 0 ? sel : 1.0;  // conjuncts that do not restrict this column are ignored
+        }
+        return andSel;
+      } else if (predicate.getKind() == SqlKind.OR) {
+        // Every disjunct must restrict this column, otherwise the OR as a whole does not restrict it
+        double orSel = 0.0;
+        for (RexNode op : ((RexCall) predicate).getOperands()) {
+          double sel = getPredSelectivityContainingInputRef(op, inputRef, mq, scan);
+          if (sel <= 0) {
+            return -1.0;
+          }
+          orSel += sel;
+        }
+        return Math.min(orSel, 1.0);
+      } else {
+        // A simple predicate restricts this column only if it references exactly this one column (at any depth)
+        ImmutableBitSet refs = RelOptUtil.InputFinder.analyze(predicate).inputBitSet.build();
+        if (!refs.equals(ImmutableBitSet.of(inputRef))) {
+          return -1.0;
+        }
+        return mq.getSelectivity(scan, predicate);
       }
+    } else {
+      return -1.0;
     }
-    // Use max of NDVs from both sides of the join, if applicable
-    if (leftDistRowCount != null && rightDistRowCount != null) {
-      distRowCount = Math.max(leftDistRowCount, rightDistRowCount);
+  }
+
+  @Override
+  public Double getDistinctRowCount(RelSubset rel, RelMetadataQuery mq,
+      ImmutableBitSet groupKey, RexNode predicate) {
+    if (!DrillRelOptUtil.guessRows(rel)) {  // with statistics, delegate to the subset's best, else original, rel
+      final RelNode best = rel.getBest();
+      if (best != null) {
+        return mq.getDistinctRowCount(best, groupKey, predicate);
+      }
+      final RelNode original = rel.getOriginal();
+      if (original != null) {
+        return mq.getDistinctRowCount(original, groupKey, predicate);
+      }
     }
-    return RelMdUtil.numDistinctVals(distRowCount, mq.getRowCount(joinRel));
+    return super.getDistinctRowCount(rel, mq, groupKey, predicate);
   }
 }
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdRowCount.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdRowCount.java
index 814a96d..b65c582 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdRowCount.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdRowCount.java
@@ -18,7 +18,6 @@
 package org.apache.drill.exec.planner.cost;
 
 import java.io.IOException;
-import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.SingleRel;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.Filter;
@@ -35,7 +34,9 @@ import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.drill.exec.planner.common.DrillLimitRelBase;
 import org.apache.drill.exec.planner.common.DrillRelOptUtil;
+import org.apache.drill.exec.planner.common.DrillScanRelBase;
 import org.apache.drill.exec.planner.logical.DrillTable;
+import org.apache.drill.exec.planner.physical.AggPrelBase;
 import org.apache.drill.exec.planner.physical.PlannerSettings;
 import org.apache.drill.exec.planner.physical.PrelUtil;
 import org.apache.drill.exec.util.Utilities;
@@ -53,7 +54,21 @@ public class DrillRelMdRowCount extends RelMdRowCount{
 
     if (groupKey.isEmpty()) {
       return 1.0;
-    } else {
+    } else if (!DrillRelOptUtil.guessRows(rel)
+        && rel instanceof AggPrelBase
+        && ((AggPrelBase) rel).getOperatorPhase() == AggPrelBase.OperatorPhase.PHASE_1of2) {
+      // Phase 1 of a 2-phase Aggregate returns between NDV and input_rows rows, so estimate
+      // max(NDV, input_rows / 10), falling back to input_rows / 10 when NDV is unknown
+      double rowCount = mq.getRowCount(rel.getInput()) / 10;
+      Double ndv = mq.getDistinctRowCount(rel.getInput(), groupKey, null);
+      // Use max of NDV and input_rows/10
+      if (ndv != null) {
+        rowCount = Math.max(ndv, rowCount);
+      }
+      // Grouping sets multiply
+      rowCount *= rel.getGroupSets().size();
+      return rowCount;
+    } else {
       return super.getRowCount(rel, mq);
     }
   }
@@ -88,20 +103,13 @@ public class DrillRelMdRowCount extends RelMdRowCount{
   }
 
   @Override
-  public Double getRowCount(RelNode rel, RelMetadataQuery mq) {
-    if (rel instanceof TableScan) {
-      return getRowCountInternal((TableScan)rel, mq);
-    }
-    return super.getRowCount(rel, mq);
-  }
-
-  @Override
   public Double getRowCount(Filter rel, RelMetadataQuery mq) {
     // Need capped selectivity estimates. See the Filter getRows() method
     return rel.getRows();
   }
 
-  private Double getRowCountInternal(TableScan rel, RelMetadataQuery mq) {
+  @Override
+  public Double getRowCount(TableScan rel, RelMetadataQuery mq) {
     DrillTable table = Utilities.getDrillTable(rel.getTable());
     PlannerSettings settings = PrelUtil.getSettings(rel.getCluster());
     // If guessing, return selectivity from RelMDRowCount
@@ -112,14 +120,19 @@ public class DrillRelMdRowCount extends RelMdRowCount{
     try {
       if (table != null
           && table.getGroupScan().getTableMetadata() != null
-          && (boolean) TableStatisticsKind.HAS_STATISTICS.getValue(table.getGroupScan().getTableMetadata())
+          && (boolean) TableStatisticsKind.HAS_STATISTICS.getValue(table.getGroupScan().getTableMetadata())) {
           /* For GroupScan rely on accurate count from the scan, if available, instead of
            * statistics since partition pruning/filter pushdown might have occurred.
            * e.g. ParquetGroupScan returns accurate rowcount. The other way would be to
            * iterate over the rowgroups present in the GroupScan to compute the rowcount.
            */
-          && !(table.getGroupScan().getScanStats(settings).getGroupScanProperty().hasExactRowCount())) {
-        return (Double) TableStatisticsKind.EST_ROW_COUNT.getValue(table.getGroupScan().getTableMetadata());
+        if (!table.getGroupScan().getScanStats(settings).getGroupScanProperty().hasExactRowCount()) {
+          return (Double) TableStatisticsKind.EST_ROW_COUNT.getValue(table.getGroupScan().getTableMetadata());
+        } else if (!(rel instanceof DrillScanRelBase)) {
+          // Scan is not a DrillScanRelBase (e.g. a Calcite logical TableScan): use the group scan's exact
+          // record count directly; DrillScanRelBase scans fall through to the remaining logic below.
+          return table.getGroupScan().getScanStats(settings).getRecordCount();
+        }
       }
     } catch (IOException ex) {
       return super.getRowCount(rel, mq);
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdSelectivity.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdSelectivity.java
index d04b19e..533cd1e 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdSelectivity.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/cost/DrillRelMdSelectivity.java
@@ -19,12 +19,12 @@ package org.apache.drill.exec.planner.cost;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.EnumSet;
 import java.util.Set;
 import java.util.Map;
 import java.util.HashMap;
-import java.util.HashSet;
 
 import java.util.stream.Collectors;
 
@@ -180,7 +180,26 @@ public class DrillRelMdSelectivity extends RelMdSelectivity {
       double orSel = 0;
       for (RexNode orPred : RelOptUtil.disjunctions(pred)) {
         if (isMultiColumnPredicate(orPred) && !combinedRangePredicates.contains(orPred)) {
-          orSel += RelMdUtil.guessSelectivity(orPred);  //CALCITE guess
+          Set<RexInputRef> uniqueRefs = new HashSet<>(DrillRelOptUtil.findAllRexInputRefs(orPred));
+          if (uniqueRefs.size() == 1) {  // single distinct column: selectivity 1.0 if this disjunct has only equalities
+            try {
+              RexVisitor<Void> visitor = new RexVisitorImpl<Void>(true) {
+                @Override
+                public Void visitCall(RexCall call) {
+                  if (call.getKind() != SqlKind.EQUALS) {
+                    throw new Util.FoundOne(call);
+                  }
+                  return super.visitCall(call);
+                }
+              };
+              orPred.accept(visitor);
+              orSel += 1.0;
+            } catch (Util.FoundOne e) {
+              orSel += RelMdUtil.guessSelectivity(orPred);  //CALCITE guess
+            }
+          } else {
+            orSel += RelMdUtil.guessSelectivity(orPred);  //CALCITE guess
+          }
         } else if (orPred.isA(SqlKind.EQUALS)) {
           orSel += computeEqualsSelectivity(tableMetadata, orPred, fieldNames);
         } else if (orPred.isA(RANGE_PREDICATE) || combinedRangePredicates.contains(orPred)) {
@@ -276,7 +295,7 @@ public class DrillRelMdSelectivity extends RelMdSelectivity {
     SchemaPath col = getColumn(orPred, fieldNames);
     if (col != null) {
       ColumnStatistics columnStatistics = tableMetadata != null ? tableMetadata.getColumnStatistics(col) : null;
-      Double ndv = columnStatistics != null ? (Double) columnStatistics.getStatistic(ColumnStatisticsKind.NVD) : null;
+      Double ndv = columnStatistics != null ? (Double) columnStatistics.getStatistic(ColumnStatisticsKind.NDV) : null;
       if (ndv != null) {
         return 1.00 / ndv;
       }
@@ -423,19 +442,6 @@ public class DrillRelMdSelectivity extends RelMdSelectivity {
   }
 
   private boolean isMultiColumnPredicate(final RexNode node) {
-    return findAllRexInputRefs(node).size() > 1;
-  }
-
-  private static List<RexInputRef> findAllRexInputRefs(final RexNode node) {
-    List<RexInputRef> rexRefs = new ArrayList<>();
-      RexVisitor<Void> visitor =
-          new RexVisitorImpl<Void>(true) {
-            public Void visitInputRef(RexInputRef inputRef) {
-              rexRefs.add(inputRef);
-              return super.visitInputRef(inputRef);
-            }
-          };
-      node.accept(visitor);
-      return rexRefs;
+    return DrillRelOptUtil.findAllRexInputRefs(node).size() > 1;  // counts references, duplicates included
   }
 }
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java
index a5506d4..ee10a15 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/PlannerSettings.java
@@ -231,6 +231,8 @@ public class PlannerSettings implements Context{
 
   public static final BooleanValidator STATISTICS_USE = new BooleanValidator("planner.statistics.use", null);
 
+  public static final RangeDoubleValidator STATISTICS_MULTICOL_NDV_ADJUST_FACTOR = new RangeDoubleValidator("planner.statistics.multicol_ndv_adjustment_factor", 0.0, 1.0, null);  // allowed range [0.0, 1.0]
+
   public OptionManager options = null;
   public FunctionImplementationRegistry functionImplementationRegistry = null;
 
@@ -475,6 +477,10 @@ public class PlannerSettings implements Context{
     return options.getOption(STATISTICS_USE);
   }
 
+  public double getStatisticsMultiColNdvAdjustmentFactor() {  // multiplier for NDV estimates of multi-column GBY keys
+    return options.getOption(STATISTICS_MULTICOL_NDV_ADJUST_FACTOR);
+  }
+
   @Override
   public <T> T unwrap(Class<T> clazz) {
     if(clazz == PlannerSettings.class){
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java b/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java
index 61fefe7..4c28887 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/server/options/SystemOptionManager.java
@@ -121,6 +121,7 @@ public class SystemOptionManager extends BaseOptionManager implements AutoClosea
       new OptionDefinition(PlannerSettings.ENABLE_UNNEST_LATERAL),
       new OptionDefinition(PlannerSettings.FORCE_2PHASE_AGGR), // for testing
       new OptionDefinition(PlannerSettings.STATISTICS_USE),
+      new OptionDefinition(PlannerSettings.STATISTICS_MULTICOL_NDV_ADJUST_FACTOR), // planner.statistics.multicol_ndv_adjustment_factor
       new OptionDefinition(ExecConstants.HASHJOIN_NUM_PARTITIONS_VALIDATOR),
       new OptionDefinition(ExecConstants.HASHJOIN_MAX_MEMORY_VALIDATOR, new OptionMetaData(OptionValue.AccessibleScopes.SYSTEM, true, true)),
       new OptionDefinition(ExecConstants.HASHJOIN_NUM_ROWS_IN_BATCH_VALIDATOR),
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/util/Utilities.java b/exec/java-exec/src/main/java/org/apache/drill/exec/util/Utilities.java
index 2446ba7..f440cf1 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/util/Utilities.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/util/Utilities.java
@@ -105,7 +105,7 @@ public class Utilities {
    */
   public static DrillTable getDrillTable(RelOptTable table) {
     DrillTable drillTable = table.unwrap(DrillTable.class);
-    if (drillTable == null) {
+    if (drillTable == null && table.unwrap(DrillTranslatableTable.class) != null) {
       drillTable = table.unwrap(DrillTranslatableTable.class).getDrillTable();
     }
     return drillTable;
diff --git a/exec/java-exec/src/main/resources/drill-module.conf b/exec/java-exec/src/main/resources/drill-module.conf
index 5096680..ea254a5 100644
--- a/exec/java-exec/src/main/resources/drill-module.conf
+++ b/exec/java-exec/src/main/resources/drill-module.conf
@@ -591,6 +591,7 @@ drill.exec.options: {
     planner.producer_consumer_queue_size: 10,
     planner.slice_target: 100000,
     planner.statistics.use: false,
+    planner.statistics.multicol_ndv_adjustment_factor: 1.0,
     planner.store.parquet.rowgroup.filter.pushdown.enabled: true,
     planner.store.parquet.rowgroup.filter.pushdown.threshold: 10000,
     # Max per node should always be configured as zero and
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/sql/TestAnalyze.java b/exec/java-exec/src/test/java/org/apache/drill/exec/sql/TestAnalyze.java
index 94583de..1d404e1 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/sql/TestAnalyze.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/sql/TestAnalyze.java
@@ -258,14 +258,14 @@ public class TestAnalyze extends BaseTestQuery {
 
     query = " select emp.employee_id from dfs.tmp.employeeUseStat emp join dfs.tmp.departmentUseStat dept"
         + " on emp.department_id = dept.department_id";
-    String[] expectedPlan4 = {"HashJoin\\(condition.*\\).*rowcount = 1154.9999999999995,.*",
+    String[] expectedPlan4 = {"HashJoin\\(condition.*\\).*rowcount = 1155.0,.*",
             "Scan.*columns=\\[`department_id`, `employee_id`\\].*rowcount = 1155.0.*",
             "Scan.*columns=\\[`department_id`\\].*rowcount = 12.0.*"};
     PlanTestBase.testPlanWithAttributesMatchingPatterns(query, expectedPlan4, new String[]{});
 
     query = " select emp.employee_id from dfs.tmp.employeeUseStat emp join dfs.tmp.departmentUseStat dept"
             + " on emp.department_id = dept.department_id where dept.department_id = 5";
-    String[] expectedPlan5 = {"HashJoin\\(condition.*\\).*rowcount = 96.24999999999997,.*",
+    String[] expectedPlan5 = {"HashJoin\\(condition.*\\).*rowcount = 96.25,.*",
             "Scan.*columns=\\[`department_id`, `employee_id`\\].*rowcount = 1155.0.*",
             "Scan.*columns=\\[`department_id`\\].*rowcount = 12.0.*"};
     PlanTestBase.testPlanWithAttributesMatchingPatterns(query, expectedPlan5, new String[]{});
@@ -290,8 +290,8 @@ public class TestAnalyze extends BaseTestQuery {
     query = " select emp.employee_id from dfs.tmp.employeeUseStat emp join dfs.tmp.departmentUseStat dept"
             + " on emp.department_id = dept.department_id "
             + " group by emp.employee_id";
-    String[] expectedPlan8 = {"HashAgg\\(group=\\[\\{0\\}\\]\\).*rowcount = 730.0992454469839,.*",
-            "HashJoin\\(condition.*\\).*rowcount = 1154.9999999999995,.*",
+    String[] expectedPlan8 = {"HashAgg\\(group=\\[\\{0\\}\\]\\).*rowcount = 115.49475630811243,.*",
+            "HashJoin\\(condition.*\\).*rowcount = 1155.0,.*",
             "Scan.*columns=\\[`department_id`, `employee_id`\\].*rowcount = 1155.0.*",
             "Scan.*columns=\\[`department_id`\\].*rowcount = 12.0.*"};
     PlanTestBase.testPlanWithAttributesMatchingPatterns(query, expectedPlan8, new String[]{});
@@ -301,8 +301,8 @@ public class TestAnalyze extends BaseTestQuery {
             + " on emp.department_id = dept.department_id "
             + " group by emp.employee_id, emp.store_id, dept.department_description "
             + " having dept.department_description = 'FINANCE'";
-    String[] expectedPlan9 = {"HashAgg\\(group=\\[\\{0, 1, 2\\}\\]\\).*rowcount = 92.3487011031316.*",
-            "HashJoin\\(condition.*\\).*rowcount = 96.24999999999997,.*",
+    String[] expectedPlan9 = {"HashAgg\\(group=\\[\\{0, 1, 2\\}\\]\\).*rowcount = 60.84160378724867.*",
+            "HashJoin\\(condition.*\\).*rowcount = 96.25,.*",
             "Scan.*columns=\\[`department_id`, `employee_id`, `store_id`\\].*rowcount = 1155.0.*",
             "Filter\\(condition=\\[=\\(\\$1, 'FINANCE'\\)\\]\\).*rowcount = 1.0,.*",
             "Scan.*columns=\\[`department_id`, `department_description`\\].*rowcount = 12.0.*"};
diff --git a/metastore/metastore-api/src/main/java/org/apache/drill/metastore/ColumnStatisticsKind.java b/metastore/metastore-api/src/main/java/org/apache/drill/metastore/ColumnStatisticsKind.java
index f1c8196..51195f4 100644
--- a/metastore/metastore-api/src/main/java/org/apache/drill/metastore/ColumnStatisticsKind.java
+++ b/metastore/metastore-api/src/main/java/org/apache/drill/metastore/ColumnStatisticsKind.java
@@ -145,7 +145,7 @@ public enum ColumnStatisticsKind implements CollectableColumnStatisticsKind {
   /**
    * Column statistics kind which represents number of distinct values for the specific column.
    */
-  NVD(Statistic.NDV) {
+  NDV(Statistic.NDV) {
     @Override
     public Object mergeStatistics(List<? extends ColumnStatistics> statisticsList) {
       throw new UnsupportedOperationException("Cannot merge statistics for NDV");