You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@impala.apache.org by kw...@apache.org on 2017/12/12 23:51:51 UTC

[2/4] impala git commit: IMPALA-5310: Part 2: Add SAMPLED_NDV() function.

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
index 07699d3..f316410 100644
--- a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
+++ b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
@@ -22,7 +22,6 @@ import java.util.Collections;
 import java.util.Map;
 
 import org.apache.hadoop.hive.metastore.api.Database;
-
 import org.apache.impala.analysis.ArithmeticExpr;
 import org.apache.impala.analysis.BinaryPredicate;
 import org.apache.impala.analysis.CaseExpr;
@@ -32,6 +31,7 @@ import org.apache.impala.analysis.InPredicate;
 import org.apache.impala.analysis.IsNullPredicate;
 import org.apache.impala.analysis.LikePredicate;
 import org.apache.impala.builtins.ScalarBuiltins;
+
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 
@@ -304,6 +304,30 @@ public class BuiltinsDb extends Db {
             "9HllUpdateIN10impala_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE")
         .build();
 
+  // Maps the type of SAMPLED_NDV()'s first argument to the mangled-name suffix of the
+  // matching SampledNdvUpdate<T> template instantiation in the BE; the suffix is
+  // appended to 'prefix' when the builtin is registered. Each entry decodes as
+  // SampledNdvUpdate<impala_udf::XxxVal>(FunctionContext*, const T&, const DoubleVal&,
+  // StringVal*). The DOUBLE and STRING entries are shorter only because the mangling
+  // back-references the template argument type (S3_) instead of spelling out
+  // DoubleVal / StringVal a second time.
+  private static final Map<Type, String> SAMPLED_NDV_UPDATE_SYMBOL =
+      ImmutableMap.<Type, String>builder()
+        .put(Type.BOOLEAN,
+            "16SampledNdvUpdateIN10impala_udf10BooleanValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.TINYINT,
+            "16SampledNdvUpdateIN10impala_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.SMALLINT,
+            "16SampledNdvUpdateIN10impala_udf11SmallIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.INT,
+            "16SampledNdvUpdateIN10impala_udf6IntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.BIGINT,
+            "16SampledNdvUpdateIN10impala_udf9BigIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.FLOAT,
+            "16SampledNdvUpdateIN10impala_udf8FloatValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        // DOUBLE: the second parameter's type equals T, hence 'RKS3_'.
+        .put(Type.DOUBLE,
+            "16SampledNdvUpdateIN10impala_udf9DoubleValEEEvPNS2_15FunctionContextERKT_RKS3_PNS2_9StringValE")
+        // STRING: the StringVal* intermediate parameter's pointee equals T, hence 'PS3_'.
+        .put(Type.STRING,
+            "16SampledNdvUpdateIN10impala_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPS3_")
+        .put(Type.TIMESTAMP,
+            "16SampledNdvUpdateIN10impala_udf12TimestampValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.DECIMAL,
+            "16SampledNdvUpdateIN10impala_udf10DecimalValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .build();
+
   private static final Map<Type, String> PC_UPDATE_SYMBOL =
       ImmutableMap.<Type, String>builder()
         .put(Type.BOOLEAN,
@@ -788,6 +812,19 @@ public class BuiltinsDb extends Db {
           "_Z20IncrementNdvFinalizePN10impala_udf15FunctionContextERKNS_9StringValE",
           true, false, true));
 
+      // SAMPLED_NDV.
+      // Size needs to be kept in sync with SampledNdvState in the BE.
+      int NUM_HLL_BUCKETS = 32;
+      int size = 16 + NUM_HLL_BUCKETS * (8 + HLL_INTERMEDIATE_SIZE);
+      Type sampledIntermediateType = ScalarType.createFixedUdaIntermediateType(size);
+      db.addBuiltin(AggregateFunction.createBuiltin(db, "sampled_ndv",
+          Lists.newArrayList(t, Type.DOUBLE), Type.BIGINT, sampledIntermediateType,
+          prefix + "14SampledNdvInitEPN10impala_udf15FunctionContextEPNS1_9StringValE",
+          prefix + SAMPLED_NDV_UPDATE_SYMBOL.get(t),
+          prefix + "15SampledNdvMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_",
+          null,
+          prefix + "18SampledNdvFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE",
+          true, false, true));
 
       Type pcIntermediateType =
           ScalarType.createFixedUdaIntermediateType(PC_INTERMEDIATE_SIZE);

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java b/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
index c2417f6..4de25fe 100644
--- a/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
+++ b/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
@@ -2132,7 +2132,6 @@ public class HdfsTable extends Table {
       parts[selectedIdx] = parts[numFilesRemaining - 1];
       --numFilesRemaining;
     }
-
     return result;
   }
 }

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
----------------------------------------------------------------------
diff --git a/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java b/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
index 54aa098..133b6e2 100644
--- a/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
+++ b/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
@@ -18,19 +18,23 @@
 package org.apache.impala.analysis;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.lang.reflect.Field;
 import java.util.List;
 
+import org.apache.impala.catalog.Column;
 import org.apache.impala.catalog.PrimitiveType;
 import org.apache.impala.catalog.ScalarType;
+import org.apache.impala.catalog.Table;
 import org.apache.impala.catalog.Type;
 import org.apache.impala.common.AnalysisException;
 import org.apache.impala.common.RuntimeEnv;
 import org.junit.Assert;
 import org.junit.Test;
 
+import com.google.common.base.Joiner;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
@@ -2051,6 +2055,55 @@ public class AnalyzeStmtsTest extends AnalyzerTest {
   }
 
   @Test
+  public void TestSampledNdv() throws AnalysisException {
+    // SAMPLED_NDV(expr, perc) must analyze for every scalar column type combined with a
+    // sampling-percent literal in [0,1], and must be rejected for a wrong number of
+    // arguments or for any other kind of second argument.
+    Table allScalarTypes = addAllScalarTypesTestTable();
+    String tblName = allScalarTypes.getFullName();
+
+    // Positive tests: Test all scalar types and valid sampling percents.
+    double[] validSamplePercs = new double[] { 0.0, 0.1, 0.2, 0.5, 0.8, 1.0 };
+    for (double perc: validSamplePercs) {
+      List<String> allAggFnCalls = Lists.newArrayList();
+      for (Column col: allScalarTypes.getColumns()) {
+        String aggFnCall = String.format("sampled_ndv(%s, %s)", col.getName(), perc);
+        allAggFnCalls.add(aggFnCall);
+        String stmtSql = String.format("select %s from %s", aggFnCall, tblName);
+        SelectStmt stmt = (SelectStmt) AnalyzesOk(stmtSql);
+        // Verify that the resolved function signature matches as expected.
+        Type[] args = stmt.getAggInfo().getAggregateExprs().get(0).getFn().getArgs();
+        // JUnit's assertEquals() takes the expected value first, the actual second.
+        assertEquals(2, args.length);
+        // String-typed columns (e.g. VARCHAR, CHAR) may resolve to the STRING overload.
+        assertTrue("Unexpected first argument type " + args[0] + " for column " +
+            col.getName(),
+            col.getType().matchesType(args[0]) ||
+            (col.getType().isStringType() && args[0].equals(Type.STRING)));
+        assertEquals(Type.DOUBLE, args[1]);
+      }
+      // Test several calls in the same query block.
+      AnalyzesOk(String.format(
+          "select %s from %s", Joiner.on(",").join(allAggFnCalls), tblName));
+    }
+
+    // Negative tests: Incorrect number of args.
+    AnalysisError(
+        String.format("select sampled_ndv() from %s", tblName),
+        "No matching function with signature: sampled_ndv().");
+    AnalysisError(
+        String.format("select sampled_ndv(int_col) from %s", tblName),
+        "No matching function with signature: sampled_ndv(INT).");
+    AnalysisError(
+        String.format("select sampled_ndv(int_col, 0.1, 10) from %s", tblName),
+        "No matching function with signature: sampled_ndv(INT, DECIMAL(1,1), TINYINT).");
+
+    // Negative tests: Invalid sampling percent (non-literals, out-of-range values, NULL).
+    String[] invalidSamplePercs = new String[] {
+        "int_col", "double_col", "100 / 10", "-0.1", "1.1", "100", "50", "-50", "NULL"
+    };
+    for (String invalidPerc: invalidSamplePercs) {
+      AnalysisError(
+          String.format("select sampled_ndv(int_col, %s) from %s", invalidPerc, tblName),
+          "Second parameter of SAMPLED_NDV() must be a numeric literal in [0,1]: " +
+          invalidPerc);
+    }
+  }
+
+  @Test
   public void TestGroupConcat() throws AnalysisException {
     // Test valid and invalid parameters
     AnalyzesOk("select group_concat(distinct name) from functional.testtbl");

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
----------------------------------------------------------------------
diff --git a/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java b/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
index c014dff..6718cb4 100644
--- a/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
+++ b/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
@@ -58,8 +58,8 @@ import org.apache.impala.thrift.TFunctionBinaryType;
 import org.apache.impala.thrift.TQueryCtx;
 import org.apache.impala.thrift.TQueryOptions;
 import org.junit.After;
-import org.junit.Assert;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 
 import com.google.common.base.Joiner;
@@ -233,6 +233,16 @@ public class FrontendTestBase {
     return dummyView;
   }
 
+  /**
+   * Adds the test database "allscalartypesdb" and creates in it the table
+   * "allscalartypes". The table has one column per scalar type: several DECIMAL
+   * precision/scale combinations, plus VARCHAR and CHAR. Returns the created table.
+   */
+  protected Table addAllScalarTypesTestTable() {
+    addTestDb("allscalartypesdb", "");
+    // Qualify the table with the database added above. Unqualified, the name resolves
+    // against the analyzer's default database and the new database goes unused.
+    return addTestTable("create table allscalartypesdb.allscalartypes (" +
+      "bool_col boolean, tinyint_col tinyint, smallint_col smallint, int_col int, " +
+      "bigint_col bigint, float_col float, double_col double, dec1 decimal(9,0), " +
+      "d2 decimal(10, 0), d3 decimal(20, 10), d4 decimal(38, 38), d5 decimal(10, 5), " +
+      "timestamp_col timestamp, string_col string, varchar_col varchar(50), " +
+      "char_col char (30))");
+  }
+
   protected void clearTestTables() {
     for (Table testTable: testTables_) {
       testTable.getDb().removeTable(testTable.getName());

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/tests/query_test/test_aggregation.py
----------------------------------------------------------------------
diff --git a/tests/query_test/test_aggregation.py b/tests/query_test/test_aggregation.py
index 9e0be6d..233c33a 100644
--- a/tests/query_test/test_aggregation.py
+++ b/tests/query_test/test_aggregation.py
@@ -275,6 +275,75 @@ class TestAggregationQueries(ImpalaTestSuite):
     vector.get_value('exec_option')['batch_size'] = 1
     self.run_test_case('QueryTest/parquet-stats-agg', vector, unique_database)
 
+  def test_sampled_ndv(self, vector, unique_database):
+    """The SAMPLED_NDV() function is inherently non-deterministic and cannot be
+    reasonably made deterministic with existing options so we test it separately.
+    The goal of this test is to ensure that SAMPLED_NDV() works on all data types
+    and returns approximately sensible estimates. It is not the goal of this test
+    to ensure tight error bounds on the NDV estimates. SAMPLED_NDV() is expected to
+    be inaccurate on small data sets like the ones we use in this test."""
+    if (vector.get_value('table_format').file_format != 'text' or
+        vector.get_value('table_format').compression_codec != 'none'):
+      # No need to run this test on all file formats
+      pytest.skip()
+
+    # NDV() is used as a baseline to compare SAMPLED_NDV(). Both NDV() and SAMPLED_NDV()
+    # are based on HyperLogLog so NDV() is roughly the best that SAMPLED_NDV() can do.
+    # Expectations: All columns except 'id' and 'timestamp_col' have low NDVs and are
+    # expected to be reasonably accurate with SAMPLED_NDV(). For the two high-NDV columns
+    # SAMPLED_NDV() is expected to have high variance and error.
+    ndv_stmt = """
+        select ndv(bool_col), ndv(tinyint_col),
+               ndv(smallint_col), ndv(int_col),
+               ndv(bigint_col), ndv(float_col),
+               ndv(double_col), ndv(string_col),
+               ndv(cast(double_col as decimal(3, 0))),
+               ndv(cast(double_col as decimal(10, 5))),
+               ndv(cast(double_col as decimal(20, 10))),
+               ndv(cast(double_col as decimal(38, 35))),
+               ndv(cast(string_col as varchar(20))),
+               ndv(cast(string_col as char(10))),
+               ndv(timestamp_col), ndv(id)
+        from functional_parquet.alltypesagg"""
+    ndv_result = self.execute_query(ndv_stmt)
+    ndv_vals = ndv_result.data[0].split('\t')
+
+    for sample_perc in [0.1, 0.2, 0.5, 1.0]:
+      sampled_ndv_stmt = """
+          select sampled_ndv(bool_col, {0}), sampled_ndv(tinyint_col, {0}),
+                 sampled_ndv(smallint_col, {0}), sampled_ndv(int_col, {0}),
+                 sampled_ndv(bigint_col, {0}), sampled_ndv(float_col, {0}),
+                 sampled_ndv(double_col, {0}), sampled_ndv(string_col, {0}),
+                 sampled_ndv(cast(double_col as decimal(3, 0)), {0}),
+                 sampled_ndv(cast(double_col as decimal(10, 5)), {0}),
+                 sampled_ndv(cast(double_col as decimal(20, 10)), {0}),
+                 sampled_ndv(cast(double_col as decimal(38, 35)), {0}),
+                 sampled_ndv(cast(string_col as varchar(20)), {0}),
+                 sampled_ndv(cast(string_col as char(10)), {0}),
+                 sampled_ndv(timestamp_col, {0}), sampled_ndv(id, {0})
+          from functional_parquet.alltypesagg""".format(sample_perc)
+      sampled_ndv_result = self.execute_query(sampled_ndv_stmt)
+      sampled_ndv_vals = sampled_ndv_result.data[0].split('\t')
+
+      assert len(sampled_ndv_vals) == len(ndv_vals)
+      # Low NDV columns. We expect a reasonably accurate estimate regardless of the
+      # sampling percent.
+      for i in xrange(0, 14):
+        self.__appx_equals(int(sampled_ndv_vals[i]), int(ndv_vals[i]), 0.1)
+      # High NDV columns. We expect the estimate to have high variance and error.
+      # Since we give NDV() and SAMPLED_NDV() the same input data, i.e., we are not
+      # actually sampling for SAMPLED_NDV(), we expect the result of SAMPLED_NDV() to
+      # be bigger than NDV() proportional to the sampling percent.
+      # For example, the column 'id' is a PK so we expect the result of SAMPLED_NDV()
+      # with a sampling percent of 0.1 to be approximately 10x of the NDV().
+      # The scaled estimate must be within a factor of 3 of NDV(), i.e. differ from it
+      # by at most 200% of the smaller value. A tolerance >= 1.0 passed to
+      # __appx_equals() would be vacuous: it divides by the larger value, so its ratio
+      # never exceeds 1 for non-negative inputs.
+      for i in xrange(14, 16):
+        scaled_sampled_ndv = int(sampled_ndv_vals[i]) * sample_perc
+        ndv = int(ndv_vals[i])
+        assert ndv / 3.0 <= scaled_sampled_ndv <= ndv * 3.0, \
+            "column %d: scaled SAMPLED_NDV()=%s, NDV()=%s, sample_perc=%s" % (
+                i, scaled_sampled_ndv, ndv, sample_perc)
+
+  def __appx_equals(self, a, b, diff_perc):
+    """Asserts that 'a' and 'b' differ by at most 'diff_perc' (a fraction, e.g. 0.1 for
+    10%) of the larger magnitude of the two; returns nothing. For non-negative inputs
+    the ratio never exceeds 1, so only 'diff_perc' in [0,1) yields a meaningful check."""
+    if a == b:
+      return  # Also avoids dividing by zero when both values are 0.
+    assert abs(a - b) / float(max(abs(a), abs(b))) <= diff_perc, \
+        "%s and %s differ by more than %s" % (a, b, diff_perc)
+
 class TestWideAggregationQueries(ImpalaTestSuite):
   """Test that aggregations with many grouping columns work"""
   @classmethod