You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@impala.apache.org by kw...@apache.org on 2017/12/12 23:51:51 UTC
[2/4] impala git commit: IMPALA-5310: Part 2: Add SAMPLED_NDV()
function.
http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
index 07699d3..f316410 100644
--- a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
+++ b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
@@ -22,7 +22,6 @@ import java.util.Collections;
import java.util.Map;
import org.apache.hadoop.hive.metastore.api.Database;
-
import org.apache.impala.analysis.ArithmeticExpr;
import org.apache.impala.analysis.BinaryPredicate;
import org.apache.impala.analysis.CaseExpr;
@@ -32,6 +31,7 @@ import org.apache.impala.analysis.InPredicate;
import org.apache.impala.analysis.IsNullPredicate;
import org.apache.impala.analysis.LikePredicate;
import org.apache.impala.builtins.ScalarBuiltins;
+
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
@@ -304,6 +304,30 @@ public class BuiltinsDb extends Db {
"9HllUpdateIN10impala_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE")
.build();
+ private static final Map<Type, String> SAMPLED_NDV_UPDATE_SYMBOL =
+ ImmutableMap.<Type, String>builder()
+ .put(Type.BOOLEAN,
+ "16SampledNdvUpdateIN10impala_udf10BooleanValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.TINYINT,
+ "16SampledNdvUpdateIN10impala_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.SMALLINT,
+ "16SampledNdvUpdateIN10impala_udf11SmallIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.INT,
+ "16SampledNdvUpdateIN10impala_udf6IntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.BIGINT,
+ "16SampledNdvUpdateIN10impala_udf9BigIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.FLOAT,
+ "16SampledNdvUpdateIN10impala_udf8FloatValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.DOUBLE,
+ "16SampledNdvUpdateIN10impala_udf9DoubleValEEEvPNS2_15FunctionContextERKT_RKS3_PNS2_9StringValE")
+ .put(Type.STRING,
+ "16SampledNdvUpdateIN10impala_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPS3_")
+ .put(Type.TIMESTAMP,
+ "16SampledNdvUpdateIN10impala_udf12TimestampValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .put(Type.DECIMAL,
+ "16SampledNdvUpdateIN10impala_udf10DecimalValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+ .build();
+
private static final Map<Type, String> PC_UPDATE_SYMBOL =
ImmutableMap.<Type, String>builder()
.put(Type.BOOLEAN,
@@ -788,6 +812,19 @@ public class BuiltinsDb extends Db {
"_Z20IncrementNdvFinalizePN10impala_udf15FunctionContextERKNS_9StringValE",
true, false, true));
+ // SAMPLED_NDV.
+ // Size needs to be kept in sync with SampledNdvState in the BE.
+ int NUM_HLL_BUCKETS = 32;
+ int size = 16 + NUM_HLL_BUCKETS * (8 + HLL_INTERMEDIATE_SIZE);
+ Type sampledIntermediateType = ScalarType.createFixedUdaIntermediateType(size);
+ db.addBuiltin(AggregateFunction.createBuiltin(db, "sampled_ndv",
+ Lists.newArrayList(t, Type.DOUBLE), Type.BIGINT, sampledIntermediateType,
+ prefix + "14SampledNdvInitEPN10impala_udf15FunctionContextEPNS1_9StringValE",
+ prefix + SAMPLED_NDV_UPDATE_SYMBOL.get(t),
+ prefix + "15SampledNdvMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_",
+ null,
+ prefix + "18SampledNdvFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true));
Type pcIntermediateType =
ScalarType.createFixedUdaIntermediateType(PC_INTERMEDIATE_SIZE);
http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java b/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
index c2417f6..4de25fe 100644
--- a/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
+++ b/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
@@ -2132,7 +2132,6 @@ public class HdfsTable extends Table {
parts[selectedIdx] = parts[numFilesRemaining - 1];
--numFilesRemaining;
}
-
return result;
}
}
http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
----------------------------------------------------------------------
diff --git a/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java b/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
index 54aa098..133b6e2 100644
--- a/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
+++ b/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
@@ -18,19 +18,23 @@
package org.apache.impala.analysis;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.lang.reflect.Field;
import java.util.List;
+import org.apache.impala.catalog.Column;
import org.apache.impala.catalog.PrimitiveType;
import org.apache.impala.catalog.ScalarType;
+import org.apache.impala.catalog.Table;
import org.apache.impala.catalog.Type;
import org.apache.impala.common.AnalysisException;
import org.apache.impala.common.RuntimeEnv;
import org.junit.Assert;
import org.junit.Test;
+import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@@ -2051,6 +2055,55 @@ public class AnalyzeStmtsTest extends AnalyzerTest {
}
@Test
+ public void TestSampledNdv() throws AnalysisException {
+ Table allScalarTypes = addAllScalarTypesTestTable();
+ String tblName = allScalarTypes.getFullName();
+
+ // Positive tests: Test all scalar types and valid sampling percents.
+ double validSamplePercs[] = new double[] { 0.0, 0.1, 0.2, 0.5, 0.8, 1.0 };
+ for (double perc: validSamplePercs) {
+ List<String> allAggFnCalls = Lists.newArrayList();
+ for (Column col: allScalarTypes.getColumns()) {
+ String aggFnCall = String.format("sampled_ndv(%s, %s)", col.getName(), perc);
+ allAggFnCalls.add(aggFnCall);
+ String stmtSql = String.format("select %s from %s", aggFnCall, tblName);
+ SelectStmt stmt = (SelectStmt) AnalyzesOk(stmtSql);
+ // Verify that the resolved function signature matches as expected.
+ Type[] args = stmt.getAggInfo().getAggregateExprs().get(0).getFn().getArgs();
+ assertEquals(args.length, 2);
+ assertTrue(col.getType().matchesType(args[0]) ||
+ col.getType().isStringType() && args[0].equals(Type.STRING));
+ assertEquals(Type.DOUBLE, args[1]);
+ }
+ // Test several calls in the same query block.
+ AnalyzesOk(String.format(
+ "select %s from %s", Joiner.on(",").join(allAggFnCalls), tblName));
+ }
+
+ // Negative tests: Incorrect number of args.
+ AnalysisError(
+ String.format("select sampled_ndv() from %s", tblName),
+ "No matching function with signature: sampled_ndv().");
+ AnalysisError(
+ String.format("select sampled_ndv(int_col) from %s", tblName),
+ "No matching function with signature: sampled_ndv(INT).");
+ AnalysisError(
+ String.format("select sampled_ndv(int_col, 0.1, 10) from %s", tblName),
+ "No matching function with signature: sampled_ndv(INT, DECIMAL(1,1), TINYINT).");
+
+ // Negative tests: Invalid sampling percent.
+ String invalidSamplePercs[] = new String[] {
+ "int_col", "double_col", "100 / 10", "-0.1", "1.1", "100", "50", "-50", "NULL"
+ };
+ for (String invalidPerc: invalidSamplePercs) {
+ AnalysisError(
+ String.format("select sampled_ndv(int_col, %s) from %s", invalidPerc, tblName),
+ "Second parameter of SAMPLED_NDV() must be a numeric literal in [0,1]: " +
+ invalidPerc);
+ }
+ }
+
+ @Test
public void TestGroupConcat() throws AnalysisException {
// Test valid and invalid parameters
AnalyzesOk("select group_concat(distinct name) from functional.testtbl");
http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
----------------------------------------------------------------------
diff --git a/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java b/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
index c014dff..6718cb4 100644
--- a/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
+++ b/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
@@ -58,8 +58,8 @@ import org.apache.impala.thrift.TFunctionBinaryType;
import org.apache.impala.thrift.TQueryCtx;
import org.apache.impala.thrift.TQueryOptions;
import org.junit.After;
-import org.junit.Assert;
import org.junit.AfterClass;
+import org.junit.Assert;
import org.junit.BeforeClass;
import com.google.common.base.Joiner;
@@ -233,6 +233,16 @@ public class FrontendTestBase {
return dummyView;
}
+ protected Table addAllScalarTypesTestTable() {
+ addTestDb("allscalartypesdb", "");
+ return addTestTable("create table allscalartypes (" +
+ "bool_col boolean, tinyint_col tinyint, smallint_col smallint, int_col int, " +
+ "bigint_col bigint, float_col float, double_col double, dec1 decimal(9,0), " +
+ "d2 decimal(10, 0), d3 decimal(20, 10), d4 decimal(38, 38), d5 decimal(10, 5), " +
+ "timestamp_col timestamp, string_col string, varchar_col varchar(50), " +
+ "char_col char (30))");
+ }
+
protected void clearTestTables() {
for (Table testTable: testTables_) {
testTable.getDb().removeTable(testTable.getName());
http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/tests/query_test/test_aggregation.py
----------------------------------------------------------------------
diff --git a/tests/query_test/test_aggregation.py b/tests/query_test/test_aggregation.py
index 9e0be6d..233c33a 100644
--- a/tests/query_test/test_aggregation.py
+++ b/tests/query_test/test_aggregation.py
@@ -275,6 +275,75 @@ class TestAggregationQueries(ImpalaTestSuite):
vector.get_value('exec_option')['batch_size'] = 1
self.run_test_case('QueryTest/parquet-stats-agg', vector, unique_database)
+ def test_sampled_ndv(self, vector, unique_database):
+ """The SAMPLED_NDV() function is inherently non-deterministic and cannot be
+ reasonably made deterministic with existing options so we test it separately.
+ The goal of this test is to ensure that SAMPLED_NDV() works on all data types
+ and returns approximately sensible estimates. It is not the goal of this test
+ to ensure tight error bounds on the NDV estimates. SAMPLED_NDV() is expected
+ to be inaccurate on small data sets like the ones we use in this test."""
+ if (vector.get_value('table_format').file_format != 'text' or
+ vector.get_value('table_format').compression_codec != 'none'):
+ # No need to run this test on all file formats
+ pytest.skip()
+
+ # NDV() is used as a baseline to compare SAMPLED_NDV(). Both NDV() and SAMPLED_NDV()
+ # are based on HyperLogLog so NDV() is roughly the best that SAMPLED_NDV() can do.
+ # Expectations: All columns except 'id' and 'timestamp_col' have low NDVs and are
+ # expected to be reasonably accurate with SAMPLED_NDV(). For the two high-NDV columns
+ # SAMPLED_NDV() is expected to have high variance and error.
+ ndv_stmt = """
+ select ndv(bool_col), ndv(tinyint_col),
+ ndv(smallint_col), ndv(int_col),
+ ndv(bigint_col), ndv(float_col),
+ ndv(double_col), ndv(string_col),
+ ndv(cast(double_col as decimal(3, 0))),
+ ndv(cast(double_col as decimal(10, 5))),
+ ndv(cast(double_col as decimal(20, 10))),
+ ndv(cast(double_col as decimal(38, 35))),
+ ndv(cast(string_col as varchar(20))),
+ ndv(cast(string_col as char(10))),
+ ndv(timestamp_col), ndv(id)
+ from functional_parquet.alltypesagg"""
+ ndv_result = self.execute_query(ndv_stmt)
+ ndv_vals = ndv_result.data[0].split('\t')
+
+ for sample_perc in [0.1, 0.2, 0.5, 1.0]:
+ sampled_ndv_stmt = """
+ select sampled_ndv(bool_col, {0}), sampled_ndv(tinyint_col, {0}),
+ sampled_ndv(smallint_col, {0}), sampled_ndv(int_col, {0}),
+ sampled_ndv(bigint_col, {0}), sampled_ndv(float_col, {0}),
+ sampled_ndv(double_col, {0}), sampled_ndv(string_col, {0}),
+ sampled_ndv(cast(double_col as decimal(3, 0)), {0}),
+ sampled_ndv(cast(double_col as decimal(10, 5)), {0}),
+ sampled_ndv(cast(double_col as decimal(20, 10)), {0}),
+ sampled_ndv(cast(double_col as decimal(38, 35)), {0}),
+ sampled_ndv(cast(string_col as varchar(20)), {0}),
+ sampled_ndv(cast(string_col as char(10)), {0}),
+ sampled_ndv(timestamp_col, {0}), sampled_ndv(id, {0})
+ from functional_parquet.alltypesagg""".format(sample_perc)
+ sampled_ndv_result = self.execute_query(sampled_ndv_stmt)
+ sampled_ndv_vals = sampled_ndv_result.data[0].split('\t')
+
+ assert len(sampled_ndv_vals) == len(ndv_vals)
+ # Low NDV columns. We expect a reasonably accurate estimate regardless of the
+ # sampling percent.
+ for i in xrange(0, 14):
+ self.__appx_equals(int(sampled_ndv_vals[i]), int(ndv_vals[i]), 0.1)
+ # High NDV columns. We expect the estimate to have high variance and error.
+ # Since we give NDV() and SAMPLED_NDV() the same input data, i.e., we are not
+ # actually sampling for SAMPLED_NDV(), we expect the result of SAMPLED_NDV() to
+ # be bigger than NDV(), inversely proportional to the sampling percent.
+ # For example, the column 'id' is a PK so we expect the result of SAMPLED_NDV()
+ # with a sampling percent of 0.1 to be approximately 10x of the NDV().
+ for i in xrange(14, 16):
+ self.__appx_equals(int(sampled_ndv_vals[i]) * sample_perc, int(ndv_vals[i]), 2.0)
+
+ def __appx_equals(self, a, b, diff_perc):
+ """Asserts that 'a' and 'b' differ by at most the fraction 'diff_perc' of the
+ larger of the two values. 'diff_perc' is a fraction: 0.1 allows a 10% difference."""
+ assert abs(a - b) / float(max(a, b)) <= diff_perc
+
class TestWideAggregationQueries(ImpalaTestSuite):
"""Test that aggregations with many grouping columns work"""
@classmethod