http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
index 07699d3..f316410 100644
--- a/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
+++ b/fe/src/main/java/org/apache/impala/catalog/BuiltinsDb.java
@@ -22,7 +22,6 @@ import java.util.Collections;
 import java.util.Map;
 
 import org.apache.hadoop.hive.metastore.api.Database;
-
 import org.apache.impala.analysis.ArithmeticExpr;
 import org.apache.impala.analysis.BinaryPredicate;
 import org.apache.impala.analysis.CaseExpr;
@@ -32,6 +31,7 @@ import org.apache.impala.analysis.InPredicate;
 import org.apache.impala.analysis.IsNullPredicate;
 import org.apache.impala.analysis.LikePredicate;
 import org.apache.impala.builtins.ScalarBuiltins;
+
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 
@@ -304,6 +304,30 @@ public class BuiltinsDb extends Db {
           "9HllUpdateIN10impala_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE")
         .build();
 
+  private static final Map<Type, String> SAMPLED_NDV_UPDATE_SYMBOL =
+      ImmutableMap.<Type, String>builder()
+        .put(Type.BOOLEAN,
+            "16SampledNdvUpdateIN10impala_udf10BooleanValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.TINYINT,
+            "16SampledNdvUpdateIN10impala_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.SMALLINT,
+            "16SampledNdvUpdateIN10impala_udf11SmallIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.INT,
+            "16SampledNdvUpdateIN10impala_udf6IntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.BIGINT,
+            "16SampledNdvUpdateIN10impala_udf9BigIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.FLOAT,
+            "16SampledNdvUpdateIN10impala_udf8FloatValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.DOUBLE,
+            "16SampledNdvUpdateIN10impala_udf9DoubleValEEEvPNS2_15FunctionContextERKT_RKS3_PNS2_9StringValE")
+        .put(Type.STRING,
+            "16SampledNdvUpdateIN10impala_udf9StringValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPS3_")
+        .put(Type.TIMESTAMP,
+            "16SampledNdvUpdateIN10impala_udf12TimestampValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .put(Type.DECIMAL,
+            "16SampledNdvUpdateIN10impala_udf10DecimalValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE")
+        .build();
+
   private static final Map<Type, String> PC_UPDATE_SYMBOL =
       ImmutableMap.<Type, String>builder()
         .put(Type.BOOLEAN,
@@ -788,6 +812,19 @@ public class BuiltinsDb extends Db {
         "_Z20IncrementNdvFinalizePN10impala_udf15FunctionContextERKNS_9StringValE",
         true, false, true));
 
+      // SAMPLED_NDV.
+      // Size needs to be kept in sync with SampledNdvState in the BE.
+      int NUM_HLL_BUCKETS = 32;
+      int size = 16 + NUM_HLL_BUCKETS * (8 + HLL_INTERMEDIATE_SIZE);
+      Type sampledIntermediateType = ScalarType.createFixedUdaIntermediateType(size);
+      db.addBuiltin(AggregateFunction.createBuiltin(db, "sampled_ndv",
+          Lists.newArrayList(t, Type.DOUBLE), Type.BIGINT, sampledIntermediateType,
+          prefix + "14SampledNdvInitEPN10impala_udf15FunctionContextEPNS1_9StringValE",
+          prefix + SAMPLED_NDV_UPDATE_SYMBOL.get(t),
+          prefix + "15SampledNdvMergeEPN10impala_udf15FunctionContextERKNS1_9StringValEPS4_",
+          null,
+          prefix + "18SampledNdvFinalizeEPN10impala_udf15FunctionContextERKNS1_9StringValE",
+          true, false, true));
 
       Type pcIntermediateType = ScalarType.createFixedUdaIntermediateType(PC_INTERMEDIATE_SIZE);
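For intuition, the fixed intermediate size computed above mirrors the BE's SampledNdvState layout: a 16-byte fixed portion plus, for each of the 32 HLL buckets, an 8-byte slot and one HLL intermediate buffer. A minimal sketch of the arithmetic, assuming HLL_INTERMEDIATE_SIZE is 1024 bytes; that constant is defined elsewhere in BuiltinsDb.java and is not shown in this diff, so substitute the real value:

    # Sketch (not part of the patch): cross-check of the intermediate size that
    # the sampled_ndv() registration above allocates per aggregation slot.
    NUM_HLL_BUCKETS = 32
    HLL_INTERMEDIATE_SIZE = 1024  # assumed value, not taken from this patch
    size = 16 + NUM_HLL_BUCKETS * (8 + HLL_INTERMEDIATE_SIZE)
    print(size)  # 33040 bytes under this assumption

Because the intermediate type is fixed-width, createFixedUdaIntermediateType() can size the slot up front; this is why the comment warns that the Java-side arithmetic must stay in sync with SampledNdvState in the BE.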
http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java b/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
index c2417f6..4de25fe 100644
--- a/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
+++ b/fe/src/main/java/org/apache/impala/catalog/HdfsTable.java
@@ -2132,7 +2132,6 @@ public class HdfsTable extends Table {
       parts[selectedIdx] = parts[numFilesRemaining - 1];
       --numFilesRemaining;
     }
-
     return result;
   }
 }

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
----------------------------------------------------------------------
diff --git a/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java b/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
index 54aa098..133b6e2 100644
--- a/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
+++ b/fe/src/test/java/org/apache/impala/analysis/AnalyzeStmtsTest.java
@@ -18,19 +18,23 @@ package org.apache.impala.analysis;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.lang.reflect.Field;
 import java.util.List;
 
+import org.apache.impala.catalog.Column;
 import org.apache.impala.catalog.PrimitiveType;
 import org.apache.impala.catalog.ScalarType;
+import org.apache.impala.catalog.Table;
 import org.apache.impala.catalog.Type;
 import org.apache.impala.common.AnalysisException;
 import org.apache.impala.common.RuntimeEnv;
 import org.junit.Assert;
 import org.junit.Test;
 
+import com.google.common.base.Joiner;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
@@ -2051,6 +2055,55 @@ public class AnalyzeStmtsTest extends AnalyzerTest {
   }
 
   @Test
+  public void TestSampledNdv() throws AnalysisException {
+    Table allScalarTypes = addAllScalarTypesTestTable();
+    String tblName = allScalarTypes.getFullName();
+
+    // Positive tests: Test all scalar types and valid sampling percents.
+    double validSamplePercs[] = new double[] { 0.0, 0.1, 0.2, 0.5, 0.8, 1.0 };
+    for (double perc: validSamplePercs) {
+      List<String> allAggFnCalls = Lists.newArrayList();
+      for (Column col: allScalarTypes.getColumns()) {
+        String aggFnCall = String.format("sampled_ndv(%s, %s)", col.getName(), perc);
+        allAggFnCalls.add(aggFnCall);
+        String stmtSql = String.format("select %s from %s", aggFnCall, tblName);
+        SelectStmt stmt = (SelectStmt) AnalyzesOk(stmtSql);
+        // Verify that the resolved function signature matches as expected.
+        Type[] args = stmt.getAggInfo().getAggregateExprs().get(0).getFn().getArgs();
+        assertEquals(args.length, 2);
+        assertTrue(col.getType().matchesType(args[0]) ||
+            col.getType().isStringType() && args[0].equals(Type.STRING));
+        assertEquals(Type.DOUBLE, args[1]);
+      }
+      // Test several calls in the same query block.
+      AnalyzesOk(String.format(
+          "select %s from %s", Joiner.on(",").join(allAggFnCalls), tblName));
+    }
+
+    // Negative tests: Incorrect number of args.
+    AnalysisError(
+        String.format("select sampled_ndv() from %s", tblName),
+        "No matching function with signature: sampled_ndv().");
+    AnalysisError(
+        String.format("select sampled_ndv(int_col) from %s", tblName),
+        "No matching function with signature: sampled_ndv(INT).");
+    AnalysisError(
+        String.format("select sampled_ndv(int_col, 0.1, 10) from %s", tblName),
+        "No matching function with signature: sampled_ndv(INT, DECIMAL(1,1), TINYINT).");
+
+    // Negative tests: Invalid sampling percent.
+    String invalidSamplePercs[] = new String[] {
+        "int_col", "double_col", "100 / 10", "-0.1", "1.1", "100", "50", "-50", "NULL"
+    };
+    for (String invalidPerc: invalidSamplePercs) {
+      AnalysisError(
+          String.format("select sampled_ndv(int_col, %s) from %s", invalidPerc, tblName),
+          "Second parameter of SAMPLED_NDV() must be a numeric literal in [0,1]: " +
+          invalidPerc);
+    }
+  }
+
+  @Test
   public void TestGroupConcat() throws AnalysisException {
     // Test valid and invalid parameters
     AnalyzesOk("select group_concat(distinct name) from functional.testtbl");

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
----------------------------------------------------------------------
diff --git a/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java b/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
index c014dff..6718cb4 100644
--- a/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
+++ b/fe/src/test/java/org/apache/impala/common/FrontendTestBase.java
@@ -58,8 +58,8 @@ import org.apache.impala.thrift.TFunctionBinaryType;
 import org.apache.impala.thrift.TQueryCtx;
 import org.apache.impala.thrift.TQueryOptions;
 import org.junit.After;
-import org.junit.Assert;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 
 import com.google.common.base.Joiner;
@@ -233,6 +233,16 @@ public class FrontendTestBase {
     return dummyView;
   }
 
+  protected Table addAllScalarTypesTestTable() {
+    addTestDb("allscalartypesdb", "");
+    return addTestTable("create table allscalartypes (" +
+        "bool_col boolean, tinyint_col tinyint, smallint_col smallint, int_col int, " +
+        "bigint_col bigint, float_col float, double_col double, dec1 decimal(9,0), " +
+        "d2 decimal(10, 0), d3 decimal(20, 10), d4 decimal(38, 38), d5 decimal(10, 5), " +
+        "timestamp_col timestamp, string_col string, varchar_col varchar(50), " +
+        "char_col char (30))");
+  }
+
   protected void clearTestTables() {
     for (Table testTable: testTables_) {
       testTable.getDb().removeTable(testTable.getName());

http://git-wip-us.apache.org/repos/asf/impala/blob/0936e329/tests/query_test/test_aggregation.py
----------------------------------------------------------------------
diff --git a/tests/query_test/test_aggregation.py b/tests/query_test/test_aggregation.py
index 9e0be6d..233c33a 100644
--- a/tests/query_test/test_aggregation.py
+++ b/tests/query_test/test_aggregation.py
@@ -275,6 +275,75 @@ class TestAggregationQueries(ImpalaTestSuite):
       vector.get_value('exec_option')['batch_size'] = 1
       self.run_test_case('QueryTest/parquet-stats-agg', vector, unique_database)
 
+  def test_sampled_ndv(self, vector, unique_database):
+    """The SAMPLED_NDV() function is inherently non-deterministic and cannot be
+    reasonably made deterministic with existing options, so we test it separately.
+    The goal of this test is to ensure that SAMPLED_NDV() works on all data types
+    and returns approximately sensible estimates. It is not the goal of this test
+    to ensure tight error bounds on the NDV estimates. SAMPLED_NDV() is expected
+    to be inaccurate on small data sets like the ones we use in this test."""
+    if (vector.get_value('table_format').file_format != 'text' or
+        vector.get_value('table_format').compression_codec != 'none'):
+      # No need to run this test on all file formats
+      pytest.skip()
+
+    # NDV() is used as a baseline to compare SAMPLED_NDV() against. Both NDV() and
+    # SAMPLED_NDV() are based on HyperLogLog, so NDV() is roughly the best that
+    # SAMPLED_NDV() can do.
+    # Expectations: All columns except 'id' and 'timestamp_col' have low NDVs and are
+    # expected to be reasonably accurate with SAMPLED_NDV(). For the two high-NDV
+    # columns SAMPLED_NDV() is expected to have high variance and error.
+    ndv_stmt = """
+        select ndv(bool_col), ndv(tinyint_col),
+               ndv(smallint_col), ndv(int_col),
+               ndv(bigint_col), ndv(float_col),
+               ndv(double_col), ndv(string_col),
+               ndv(cast(double_col as decimal(3, 0))),
+               ndv(cast(double_col as decimal(10, 5))),
+               ndv(cast(double_col as decimal(20, 10))),
+               ndv(cast(double_col as decimal(38, 35))),
+               ndv(cast(string_col as varchar(20))),
+               ndv(cast(string_col as char(10))),
+               ndv(timestamp_col), ndv(id)
+        from functional_parquet.alltypesagg"""
+    ndv_result = self.execute_query(ndv_stmt)
+    ndv_vals = ndv_result.data[0].split('\t')
+
+    for sample_perc in [0.1, 0.2, 0.5, 1.0]:
+      sampled_ndv_stmt = """
+          select sampled_ndv(bool_col, {0}), sampled_ndv(tinyint_col, {0}),
+                 sampled_ndv(smallint_col, {0}), sampled_ndv(int_col, {0}),
+                 sampled_ndv(bigint_col, {0}), sampled_ndv(float_col, {0}),
+                 sampled_ndv(double_col, {0}), sampled_ndv(string_col, {0}),
+                 sampled_ndv(cast(double_col as decimal(3, 0)), {0}),
+                 sampled_ndv(cast(double_col as decimal(10, 5)), {0}),
+                 sampled_ndv(cast(double_col as decimal(20, 10)), {0}),
+                 sampled_ndv(cast(double_col as decimal(38, 35)), {0}),
+                 sampled_ndv(cast(string_col as varchar(20)), {0}),
+                 sampled_ndv(cast(string_col as char(10)), {0}),
+                 sampled_ndv(timestamp_col, {0}), sampled_ndv(id, {0})
+          from functional_parquet.alltypesagg""".format(sample_perc)
+      sampled_ndv_result = self.execute_query(sampled_ndv_stmt)
+      sampled_ndv_vals = sampled_ndv_result.data[0].split('\t')
+
+      assert len(sampled_ndv_vals) == len(ndv_vals)
+      # Low NDV columns. We expect a reasonably accurate estimate regardless of the
+      # sampling percent.
+      for i in xrange(0, 14):
+        self.__appx_equals(int(sampled_ndv_vals[i]), int(ndv_vals[i]), 0.1)
+      # High NDV columns. We expect the estimate to have high variance and error.
+      # Since we give NDV() and SAMPLED_NDV() the same input data, i.e., we are not
+      # actually sampling for SAMPLED_NDV(), we expect the result of SAMPLED_NDV() to
+      # be bigger than NDV() proportional to the sampling percent.
+      # For example, the column 'id' is a PK so we expect the result of SAMPLED_NDV()
+      # with a sampling percent of 0.1 to be approximately 10x of the NDV().
+      for i in xrange(14, 16):
+        self.__appx_equals(int(sampled_ndv_vals[i]) * sample_perc, int(ndv_vals[i]), 2.0)
+
+  def __appx_equals(self, a, b, diff_perc):
+    """Asserts that 'a' and 'b' are within 'diff_perc' percent of each other,
+    where 'diff_perc' is a non-negative float."""
+    assert abs(a - b) / float(max(a, b)) <= diff_perc
+
 class TestWideAggregationQueries(ImpalaTestSuite):
   """Test that aggregations with many grouping columns work"""
   @classmethod
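The acceptance logic in the new test is compact enough to restate standalone. The sketch below restates the relative-error check as a boolean helper, in the same spirit as __appx_equals() above, together with the high-NDV scaling argument from the test's comments; all numbers are made up for illustration, not actual query results:

    # Sketch (not part of the patch): the relative-error check and the expected
    # scaling for high-NDV columns.
    def appx_equals(a, b, diff_perc):
      """True if 'a' and 'b' are within 'diff_perc' percent of each other."""
      return abs(a - b) / float(max(a, b)) <= diff_perc

    # Low-NDV column: the estimate must land within 10% of NDV().
    assert appx_equals(98, 100, 0.1)

    # High-NDV column 'id' with sample_perc = 0.1: SAMPLED_NDV() extrapolates
    # from the declared sampling percent, so feeding it the full input yields
    # roughly NDV() / sample_perc; scaling back by sample_perc should pass.
    sampled_ndv, ndv, sample_perc = 95000, 10000, 0.1
    assert appx_equals(sampled_ndv * sample_perc, ndv, 2.0)

Note that for nonnegative a and b, abs(a - b) / max(a, b) cannot exceed 1.0, so the 2.0 bound on the high-NDV columns acts as a smoke test rather than a tight check, consistent with the stated expectation of high variance and error.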
