You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the original archive page) is not preserved in this plain-text version.
Posted to commits@druid.apache.org by cw...@apache.org on 2021/11/08 08:33:29 UTC
[druid] branch master updated: complex typed expressions (#11853)
This is an automated email from the ASF dual-hosted git repository.
cwylie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new 7237dc8 complex typed expressions (#11853)
7237dc8 is described below
commit 7237dc837cf335558fa7981c4297d2d3b40a730e
Author: Clint Wylie <cw...@apache.org>
AuthorDate: Mon Nov 8 00:33:06 2021 -0800
complex typed expressions (#11853)
* complex typed expressions
* add built-in hll collector expressions to get coverage on druid-processing, more types, more better
* rampage!!!
* more javadoc
* adjustments
* oops
* lol
* remove unused dependency
* contradiction?
* more test
---
core/pom.xml | 4 -
.../org/apache/druid/math/expr/antlr/Expr.g4 | 3 +
.../org/apache/druid/math/expr/ApplyFunction.java | 214 +++--
.../org/apache/druid/math/expr/ConstantExpr.java | 197 ++---
.../main/java/org/apache/druid/math/expr/Expr.java | 6 +-
.../java/org/apache/druid/math/expr/ExprEval.java | 957 +++++++--------------
.../apache/druid/math/expr/ExprListenerImpl.java | 58 +-
.../org/apache/druid/math/expr/ExprMacroTable.java | 37 +-
.../java/org/apache/druid/math/expr/ExprType.java | 37 +-
.../druid/math/expr/ExpressionProcessing.java | 68 ++
.../math/expr/ExpressionProcessingConfig.java | 25 +-
.../math/expr/ExpressionProcessingModule.java | 23 +-
.../org/apache/druid/math/expr/ExpressionType.java | 10 +
.../java/org/apache/druid/math/expr/Function.java | 221 ++---
.../org/apache/druid/math/expr/IdentifierExpr.java | 2 +-
.../org/apache/druid/math/expr/InputBindings.java | 106 ++-
.../java/org/apache/druid/math/expr/Parser.java | 13 +-
.../druid/math/expr/SettableObjectBinding.java | 14 +
.../druid/segment/column/ObjectByteStrategy.java | 40 +-
.../org/apache/druid/segment/column/Types.java | 615 +++++++++++++
.../apache/druid/math/expr/ApplyFunctionTest.java | 12 +-
.../org/apache/druid/math/expr/ExprEvalTest.java | 143 ++-
.../java/org/apache/druid/math/expr/ExprTest.java | 33 +-
.../org/apache/druid/math/expr/FunctionTest.java | 81 +-
.../org/apache/druid/math/expr/ParserTest.java | 140 ++-
.../org/apache/druid/segment/column/TypesTest.java | 443 ++++++++++
.../druid/testing/InitializedNullHandlingTest.java | 2 +
.../druid/guice/BloomFilterExtensionModule.java | 6 +-
.../bloom/BloomFilterAggregatorFactory.java | 9 +-
.../query/aggregation/bloom/BloomFilterSerde.java | 45 +-
.../query/expressions/BloomFilterExprMacro.java | 138 ---
.../query/expressions/BloomFilterExpressions.java | 366 ++++++++
.../filter/sql/BloomFilterOperatorConversion.java | 6 +-
.../expressions/BloomFilterExpressionsTest.java | 226 +++++
.../query/filter/sql/BloomDimFilterSqlTest.java | 4 +-
.../druid/query/expressions/SleepExprTest.java | 21 +-
.../org/apache/druid/indexer/InputRowSerde.java | 3 +-
pom.xml | 5 -
.../org/apache/druid/guice/GuiceInjectors.java | 2 +
.../aggregation/ExpressionLambdaAggregator.java | 93 +-
.../ExpressionLambdaAggregatorFactory.java | 42 +-
.../ExpressionLambdaAggregatorInputBindings.java | 11 +
.../ExpressionLambdaBufferAggregator.java | 19 +-
.../apache/druid/query/expression/ExprUtils.java | 22 +-
.../query/expression/HyperUniqueExpressions.java | 333 +++++++
.../query/expression/TimestampCeilExprMacro.java | 3 +-
.../query/expression/TimestampFloorExprMacro.java | 3 +-
.../query/expression/TimestampShiftExprMacro.java | 7 +-
.../druid/query/expression/TrimExprMacro.java | 3 +-
.../epinephelinae/RowBasedGrouperHelper.java | 2 +-
.../timeseries/TimeseriesQueryQueryToolChest.java | 2 +-
.../segment/RowBasedColumnSelectorFactory.java | 14 +-
.../org/apache/druid/segment/RowBasedCursor.java | 2 +-
.../apache/druid/segment/data/ObjectStrategy.java | 26 +-
.../druid/segment/filter/ExpressionFilter.java | 16 +-
.../segment/incremental/IncrementalIndex.java | 17 +-
.../druid/segment/join/JoinConditionAnalysis.java | 6 +-
.../apache/druid/segment/serde/ComplexMetrics.java | 2 +
.../segment/transform/ExpressionTransform.java | 13 +-
.../druid/segment/transform/Transformer.java | 2 +-
.../druid/segment/virtual/ExpressionSelectors.java | 84 +-
.../segment/virtual/ExpressionVectorSelectors.java | 8 +-
.../RowBasedExpressionColumnValueSelector.java | 4 +-
.../druid/segment/virtual/SingleInputBindings.java | 15 +
...gInputCachingExpressionColumnValueSelector.java | 3 +-
...gInputCachingExpressionColumnValueSelector.java | 4 +-
...erredEvaluationExpressionDimensionSelector.java | 3 +-
.../ExpressionLambdaAggregatorFactoryTest.java | 87 +-
.../ExpressionLambdaAggregatorTest.java | 108 +++
.../expression/CaseInsensitiveExprMacroTest.java | 4 +-
.../query/expression/ContainsExprMacroTest.java | 4 +-
.../expression/HyperUniqueExpressionsTest.java | 256 ++++++
.../expression/IPv4AddressMatchExprMacroTest.java | 3 +-
.../expression/IPv4AddressParseExprMacroTest.java | 3 +-
.../IPv4AddressStringifyExprMacroTest.java | 3 +-
.../expression/RegexpExtractExprMacroTest.java | 4 +-
.../query/expression/RegexpLikeExprMacroTest.java | 4 +-
.../druid/query/expression/TestExprMacroTable.java | 6 +-
.../expression/TimestampExtractExprMacroTest.java | 11 +-
.../query/expression/TimestampShiftMacroTest.java | 25 +-
.../apache/druid/query/filter/InDimFilterTest.java | 2 +-
.../query/groupby/GroupByQueryRunnerTest.java | 85 ++
.../druid/query/topn/TopNQueryRunnerTest.java | 65 ++
.../java/org/apache/druid/segment/TestHelper.java | 2 +-
.../druid/segment/filter/BaseFilterTest.java | 2 +-
.../virtual/ExpressionVirtualColumnTest.java | 8 +-
.../ListFilteredVirtualColumnSelectorTest.java | 2 +-
.../org/apache/druid/guice/ExpressionModule.java | 5 +
.../builtin/ArrayContainsOperatorConversion.java | 3 +-
.../builtin/ArrayOverlapOperatorConversion.java | 3 +-
.../MultiValueStringOperatorConversions.java | 4 +-
.../sql/calcite/planner/DruidRexExecutor.java | 11 +-
.../apache/druid/sql/calcite/rel/QueryMaker.java | 2 +
.../druid/sql/calcite/CalciteArraysQueryTest.java | 49 +-
.../calcite/expression/ExpressionTestHelper.java | 2 +-
.../druid/sql/calcite/util/CalciteTestBase.java | 2 +
96 files changed, 4351 insertions(+), 1508 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index f56e0d2..5cda7a4 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -188,10 +188,6 @@
<artifactId>fastutil-core</artifactId>
</dependency>
<dependency>
- <groupId>it.unimi.dsi</groupId>
- <artifactId>fastutil-extra</artifactId>
- </dependency>
- <dependency>
<groupId>io.netty</groupId>
<artifactId>netty-buffer</artifactId>
</dependency>
diff --git a/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4 b/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4
index 1665302..1d52a13 100644
--- a/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4
+++ b/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4
@@ -36,6 +36,7 @@ expr : NULL # null
| '<LONG>' '[' (numericElement (',' numericElement)*)? ']' # explicitLongArray
| '<DOUBLE>'? '[' (numericElement (',' numericElement)*)? ']' # doubleArray
| '<STRING>' '[' (literalElement (',' literalElement)*)? ']' # explicitStringArray
+ | ARRAY_TYPE '[' (literalElement (',' literalElement)*)? ']' # explicitArray
;
lambda : (IDENTIFIER | '(' ')' | '(' IDENTIFIER (',' IDENTIFIER)* ')') '->' expr
@@ -52,6 +53,8 @@ numericElement : (LONG | DOUBLE | NULL);
literalElement : (STRING | LONG | DOUBLE | NULL);
+ARRAY_TYPE : 'ARRAY<' ( 'LONG' | 'DOUBLE' | 'STRING' | ('COMPLEX<' IDENTIFIER '>')| ARRAY_TYPE ) '>';
+
NULL : 'null';
LONG : [0-9]+;
EXP: [eE] [-]? LONG;
diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java
index 08e5e91f..0bd6010 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java
@@ -25,7 +25,6 @@ import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import org.apache.druid.java.util.common.IAE;
-import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
@@ -138,54 +137,16 @@ public interface ApplyFunction
/**
* Evaluate {@link LambdaExpr} against every index position of an {@link IndexableMapLambdaObjectBinding}
*/
- ExprEval applyMap(LambdaExpr expr, IndexableMapLambdaObjectBinding bindings)
+ ExprEval applyMap(@Nullable ExpressionType arrayType, LambdaExpr expr, IndexableMapLambdaObjectBinding bindings)
{
final int length = bindings.getLength();
- String[] stringsOut = null;
- Long[] longsOut = null;
- Double[] doublesOut = null;
-
- ExpressionType elementType = null;
+ Object[] out = new Object[length];
for (int i = 0; i < length; i++) {
ExprEval evaluated = expr.eval(bindings.withIndex(i));
- if (elementType == null) {
- elementType = evaluated.type();
- switch (elementType.getType()) {
- case STRING:
- stringsOut = new String[length];
- break;
- case LONG:
- longsOut = new Long[length];
- break;
- case DOUBLE:
- doublesOut = new Double[length];
- break;
- default:
- throw new RE("Unhandled map function output type [%s]", elementType);
- }
- }
-
- Function.ArrayConstructorFunction.setArrayOutputElement(
- stringsOut,
- longsOut,
- doublesOut,
- elementType,
- i,
- evaluated
- );
- }
-
- switch (elementType.getType()) {
- case STRING:
- return ExprEval.ofStringArray(stringsOut);
- case LONG:
- return ExprEval.ofLongArray(longsOut);
- case DOUBLE:
- return ExprEval.ofDoubleArray(doublesOut);
- default:
- throw new RE("Unhandled map function output type [%s]", elementType);
+ arrayType = Function.ArrayConstructorFunction.setArrayOutput(arrayType, out, i, evaluated);
}
+ return ExprEval.ofArray(arrayType, out);
}
}
@@ -216,8 +177,9 @@ public interface ApplyFunction
return arrayEval;
}
- MapLambdaBinding lambdaBinding = new MapLambdaBinding(array, lambdaExpr, bindings);
- return applyMap(lambdaExpr, lambdaBinding);
+ MapLambdaBinding lambdaBinding = new MapLambdaBinding(arrayEval.elementType(), array, lambdaExpr, bindings);
+ ExpressionType lambdaType = lambdaExpr.getOutputType(lambdaBinding);
+ return applyMap(lambdaType == null ? null : ExpressionTypeFactory.getInstance().ofArray(lambdaType), lambdaExpr, lambdaBinding);
}
@Override
@@ -261,6 +223,7 @@ public interface ApplyFunction
List<List<Object>> arrayInputs = new ArrayList<>();
boolean hadNull = false;
boolean hadEmpty = false;
+ ExpressionType elementType = null;
for (Expr expr : argsExpr) {
ExprEval arrayEval = expr.eval(bindings);
Object[] array = arrayEval.asArray();
@@ -268,6 +231,7 @@ public interface ApplyFunction
hadNull = true;
continue;
}
+ elementType = arrayEval.elementType();
if (array.length == 0) {
hadEmpty = true;
continue;
@@ -282,8 +246,9 @@ public interface ApplyFunction
}
List<List<Object>> product = CartesianList.create(arrayInputs);
- CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(product, lambdaExpr, bindings);
- return applyMap(lambdaExpr, lambdaBinding);
+ CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(elementType, product, lambdaExpr, bindings);
+ ExpressionType lambdaType = lambdaExpr.getOutputType(lambdaBinding);
+ return applyMap(ExpressionType.asArrayType(lambdaType), lambdaExpr, lambdaBinding);
}
@Override
@@ -324,7 +289,7 @@ public interface ApplyFunction
if (accumulator instanceof Boolean) {
return ExprEval.ofLongBoolean((boolean) accumulator);
}
- return ExprEval.bestEffortOf(accumulator);
+ return ExprEval.ofType(bindings.getAccumulatorType(), accumulator);
}
@Override
@@ -372,7 +337,14 @@ public interface ApplyFunction
}
Object accumulator = accEval.value();
- FoldLambdaBinding lambdaBinding = new FoldLambdaBinding(array, accumulator, lambdaExpr, bindings);
+ FoldLambdaBinding lambdaBinding = new FoldLambdaBinding(
+ arrayEval.elementType(),
+ array,
+ accEval.type(),
+ accumulator,
+ lambdaExpr,
+ bindings
+ );
return applyFold(lambdaExpr, accumulator, lambdaBinding);
}
@@ -415,6 +387,7 @@ public interface ApplyFunction
List<List<Object>> arrayInputs = new ArrayList<>();
boolean hadNull = false;
boolean hadEmpty = false;
+ ExpressionType arrayElementType = null;
for (int i = 0; i < argsExpr.size() - 1; i++) {
Expr expr = argsExpr.get(i);
ExprEval arrayEval = expr.eval(bindings);
@@ -423,6 +396,7 @@ public interface ApplyFunction
hadNull = true;
continue;
}
+ arrayElementType = arrayEval.elementType();
if (array.length == 0) {
hadEmpty = true;
continue;
@@ -444,7 +418,7 @@ public interface ApplyFunction
Object accumulator = accEval.value();
CartesianFoldLambdaBinding lambdaBindings =
- new CartesianFoldLambdaBinding(product, accumulator, lambdaExpr, bindings);
+ new CartesianFoldLambdaBinding(arrayElementType, product, accEval.type(), accumulator, lambdaExpr, bindings);
return applyFold(lambdaExpr, accumulator, lambdaBindings);
}
@@ -495,23 +469,9 @@ public interface ApplyFunction
return ExprEval.of(null);
}
- SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings);
- switch (arrayEval.elementType().getType()) {
- case STRING:
- String[] filteredString =
- this.filter(arrayEval.asStringArray(), lambdaExpr, lambdaBinding).toArray(String[]::new);
- return ExprEval.ofStringArray(filteredString);
- case LONG:
- Long[] filteredLong =
- this.filter(arrayEval.asLongArray(), lambdaExpr, lambdaBinding).toArray(Long[]::new);
- return ExprEval.ofLongArray(filteredLong);
- case DOUBLE:
- Double[] filteredDouble =
- this.filter(arrayEval.asDoubleArray(), lambdaExpr, lambdaBinding).toArray(Double[]::new);
- return ExprEval.ofDoubleArray(filteredDouble);
- default:
- throw new RE("Unhandled filter function input type [%s]", arrayEval.type());
- }
+ SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(arrayEval.elementType(), lambdaExpr, bindings);
+ Object[] filtered = filter(arrayEval.asArray(), lambdaExpr, lambdaBinding).toArray();
+ return ExprEval.ofArray(arrayEval.asArrayType(), filtered);
}
@Override
@@ -565,7 +525,7 @@ public interface ApplyFunction
return ExprEval.ofLongBoolean(false);
}
- SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings);
+ SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(arrayEval.elementType(), lambdaExpr, bindings);
return match(array, lambdaExpr, lambdaBinding);
}
@@ -654,14 +614,16 @@ public interface ApplyFunction
{
private final Expr.ObjectBinding bindings;
private final Map<String, Object> lambdaBindings;
+ private final ExpressionType elementType;
- SettableLambdaBinding(LambdaExpr expr, Expr.ObjectBinding bindings)
+ SettableLambdaBinding(ExpressionType elementType, LambdaExpr expr, Expr.ObjectBinding bindings)
{
+ this.elementType = elementType;
this.lambdaBindings = new HashMap<>();
for (String lambdaIdentifier : expr.getIdentifiers()) {
lambdaBindings.put(lambdaIdentifier, null);
}
- this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
+ this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
}
@Nullable
@@ -679,6 +641,16 @@ public interface ApplyFunction
this.lambdaBindings.put(key, value);
return this;
}
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ if (lambdaBindings.containsKey(name)) {
+ return elementType;
+ }
+ return bindings.getType(name);
+ }
}
/**
@@ -707,17 +679,19 @@ public interface ApplyFunction
class MapLambdaBinding implements IndexableMapLambdaObjectBinding
{
private final Expr.ObjectBinding bindings;
+ private final ExpressionType arrayElementType;
@Nullable
private final String lambdaIdentifier;
private final Object[] arrayValues;
private int index = 0;
private final boolean scoped;
- MapLambdaBinding(Object[] arrayValues, LambdaExpr expr, Expr.ObjectBinding bindings)
+ MapLambdaBinding(ExpressionType elementType, Object[] arrayValues, LambdaExpr expr, Expr.ObjectBinding bindings)
{
this.lambdaIdentifier = expr.getIdentifier();
+ this.arrayElementType = elementType;
this.arrayValues = arrayValues;
- this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
+ this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
this.scoped = lambdaIdentifier != null;
}
@@ -743,6 +717,16 @@ public interface ApplyFunction
this.index = index;
return this;
}
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ if (scoped && name.equals(lambdaIdentifier)) {
+ return arrayElementType;
+ }
+ return bindings.getType(name);
+ }
}
/**
@@ -753,14 +737,16 @@ public interface ApplyFunction
class CartesianMapLambdaBinding implements IndexableMapLambdaObjectBinding
{
private final Expr.ObjectBinding bindings;
+ private final ExpressionType arrayElementType;
private final Object2IntMap<String> lambdaIdentifiers;
private final List<List<Object>> lambdaInputs;
private final boolean scoped;
private int index = 0;
- CartesianMapLambdaBinding(List<List<Object>> inputs, LambdaExpr expr, Expr.ObjectBinding bindings)
+ CartesianMapLambdaBinding(ExpressionType arrayElementType, List<List<Object>> inputs, LambdaExpr expr, Expr.ObjectBinding bindings)
{
this.lambdaInputs = inputs;
+ this.arrayElementType = arrayElementType;
List<String> ids = expr.getIdentifiers();
this.scoped = ids.size() > 0;
this.lambdaIdentifiers = new Object2IntArrayMap<>(ids.size());
@@ -768,7 +754,7 @@ public interface ApplyFunction
lambdaIdentifiers.put(ids.get(i), i);
}
- this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
+ this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
}
@Nullable
@@ -793,6 +779,16 @@ public interface ApplyFunction
this.index = index;
return this;
}
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ if (scoped && lambdaIdentifiers.containsKey(name)) {
+ return arrayElementType;
+ }
+ return bindings.getType(name);
+ }
}
/**
@@ -803,6 +799,8 @@ public interface ApplyFunction
*/
interface IndexableFoldLambdaBinding extends Expr.ObjectBinding
{
+ ExpressionType getAccumulatorType();
+
/**
* Total number of bindings in this binding
*/
@@ -821,20 +819,31 @@ public interface ApplyFunction
class FoldLambdaBinding implements IndexableFoldLambdaBinding
{
private final Expr.ObjectBinding bindings;
+ private final ExpressionType arrayElementType;
+ private final ExpressionType accumulatorType;
private final String elementIdentifier;
private final Object[] arrayValues;
private final String accumulatorIdentifier;
private Object accumulatorValue;
private int index;
- FoldLambdaBinding(Object[] arrayValues, Object initialAccumulator, LambdaExpr expr, Expr.ObjectBinding bindings)
+ FoldLambdaBinding(
+ ExpressionType arrayElementType,
+ Object[] arrayValues,
+ ExpressionType accumulatorType,
+ Object initialAccumulator,
+ LambdaExpr expr,
+ Expr.ObjectBinding bindings
+ )
{
List<String> ids = expr.getIdentifiers();
this.elementIdentifier = ids.get(0);
+ this.arrayElementType = arrayElementType;
+ this.accumulatorType = accumulatorType;
this.accumulatorIdentifier = ids.get(1);
this.arrayValues = arrayValues;
this.accumulatorValue = initialAccumulator;
- this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
+ this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
}
@Nullable
@@ -850,6 +859,12 @@ public interface ApplyFunction
}
@Override
+ public ExpressionType getAccumulatorType()
+ {
+ return accumulatorType;
+ }
+
+ @Override
public int getLength()
{
return arrayValues.length;
@@ -862,6 +877,18 @@ public interface ApplyFunction
this.accumulatorValue = acc;
return this;
}
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ if (name.equals(elementIdentifier)) {
+ return arrayElementType;
+ } else if (name.equals(accumulatorIdentifier)) {
+ return accumulatorType;
+ }
+ return bindings.getType(name);
+ }
}
/**
@@ -871,14 +898,25 @@ public interface ApplyFunction
class CartesianFoldLambdaBinding implements IndexableFoldLambdaBinding
{
private final Expr.ObjectBinding bindings;
+ private final ExpressionType arrayElementType;
+ private final ExpressionType accumulatorType;
private final Object2IntMap<String> lambdaIdentifiers;
private final List<List<Object>> lambdaInputs;
private final String accumulatorIdentifier;
private Object accumulatorValue;
private int index = 0;
- CartesianFoldLambdaBinding(List<List<Object>> inputs, Object accumulatorValue, LambdaExpr expr, Expr.ObjectBinding bindings)
- {
+ CartesianFoldLambdaBinding(
+ @Nullable ExpressionType arrayElementType,
+ List<List<Object>> inputs,
+ ExpressionType accumulatorType,
+ Object accumulatorValue,
+ LambdaExpr expr,
+ Expr.ObjectBinding bindings
+ )
+ {
+ this.arrayElementType = arrayElementType;
+ this.accumulatorType = accumulatorType;
this.lambdaInputs = inputs;
List<String> ids = expr.getIdentifiers();
this.lambdaIdentifiers = new Object2IntArrayMap<>(ids.size());
@@ -886,7 +924,7 @@ public interface ApplyFunction
lambdaIdentifiers.put(ids.get(i), i);
}
this.accumulatorIdentifier = ids.get(ids.size() - 1);
- this.bindings = bindings != null ? bindings : Collections.emptyMap()::get;
+ this.bindings = bindings != null ? bindings : InputBindings.nilBindings();
this.accumulatorValue = accumulatorValue;
}
@@ -903,6 +941,12 @@ public interface ApplyFunction
}
@Override
+ public ExpressionType getAccumulatorType()
+ {
+ return accumulatorType;
+ }
+
+ @Override
public int getLength()
{
return lambdaInputs.size();
@@ -915,6 +959,18 @@ public interface ApplyFunction
this.accumulatorValue = acc;
return this;
}
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ if (lambdaIdentifiers.containsKey(name)) {
+ return arrayElementType;
+ } else if (accumulatorIdentifier.equals(name)) {
+ return accumulatorType;
+ }
+ return bindings.getType(name);
+ }
}
/**
diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
index 52ea878..61770e9 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
@@ -22,9 +22,12 @@ package org.apache.druid.math.expr;
import com.google.common.base.Preconditions;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.VectorProcessors;
+import org.apache.druid.segment.column.ObjectByteStrategy;
+import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
import java.util.Arrays;
@@ -82,7 +85,7 @@ abstract class ConstantExpr<T> implements Expr
@Override
public BindingAnalysis analyzeInputs()
{
- return new BindingAnalysis();
+ return BindingAnalysis.EMTPY;
}
@Override
@@ -181,60 +184,6 @@ class NullLongExpr extends ConstantExpr<Long>
}
}
-class LongArrayExpr extends ConstantExpr<Long[]>
-{
- LongArrayExpr(@Nullable Long[] value)
- {
- super(ExpressionType.LONG_ARRAY, value);
- }
-
- @Override
- public String toString()
- {
- return Arrays.toString(value);
- }
-
- @Override
- public ExprEval eval(ObjectBinding bindings)
- {
- return ExprEval.ofLongArray(value);
- }
-
- @Override
- public boolean canVectorize(InputBindingInspector inspector)
- {
- return false;
- }
-
- @Override
- public String stringify()
- {
- if (value.length == 0) {
- return "<LONG>[]";
- }
- return StringUtils.format("<LONG>%s", toString());
- }
-
- @Override
- public boolean equals(Object o)
- {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- LongArrayExpr that = (LongArrayExpr) o;
- return Arrays.equals(value, that.value);
- }
-
- @Override
- public int hashCode()
- {
- return Arrays.hashCode(value);
- }
-}
-
class DoubleExpr extends ConstantExpr<Double>
{
DoubleExpr(Double value)
@@ -318,38 +267,36 @@ class NullDoubleExpr extends ConstantExpr<Double>
}
}
-class DoubleArrayExpr extends ConstantExpr<Double[]>
+class StringExpr extends ConstantExpr<String>
{
- DoubleArrayExpr(@Nullable Double[] value)
+ StringExpr(@Nullable String value)
{
- super(ExpressionType.DOUBLE_ARRAY, value);
+ super(ExpressionType.STRING, NullHandling.emptyToNullIfNeeded(value));
}
@Override
public String toString()
{
- return Arrays.toString(value);
+ return value;
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
- return ExprEval.ofDoubleArray(value);
+ return ExprEval.of(value);
}
@Override
- public boolean canVectorize(InputBindingInspector inspector)
+ public <T> ExprVectorProcessor<T> buildVectorized(VectorInputBindingInspector inspector)
{
- return false;
+ return VectorProcessors.constantString(value, inspector.getMaxVectorSize());
}
@Override
public String stringify()
{
- if (value.length == 0) {
- return "<DOUBLE>[]";
- }
- return StringUtils.format("<DOUBLE>%s", toString());
+ // escape as javascript string since string literals are wrapped in single quotes
+ return value == null ? NULL_LITERAL : StringUtils.format("'%s'", StringEscapeUtils.escapeJavaScript(value));
}
@Override
@@ -361,47 +308,79 @@ class DoubleArrayExpr extends ConstantExpr<Double[]>
if (o == null || getClass() != o.getClass()) {
return false;
}
- DoubleArrayExpr that = (DoubleArrayExpr) o;
- return Arrays.equals(value, that.value);
+ StringExpr that = (StringExpr) o;
+ return Objects.equals(value, that.value);
}
@Override
public int hashCode()
{
- return Arrays.hashCode(value);
+ return Objects.hash(value);
}
}
-class StringExpr extends ConstantExpr<String>
+class ArrayExpr extends ConstantExpr<Object[]>
{
- StringExpr(@Nullable String value)
- {
- super(ExpressionType.STRING, NullHandling.emptyToNullIfNeeded(value));
- }
-
- @Override
- public String toString()
+ public ArrayExpr(ExpressionType outputType, @Nullable Object[] value)
{
- return value;
+ super(outputType, value);
+ Preconditions.checkArgument(outputType.isArray());
+ ExpressionType.checkNestedArrayAllowed(outputType);
}
@Override
public ExprEval eval(ObjectBinding bindings)
{
- return ExprEval.of(value);
+ return ExprEval.ofArray(outputType, value);
}
@Override
- public <T> ExprVectorProcessor<T> buildVectorized(VectorInputBindingInspector inspector)
+ public boolean canVectorize(InputBindingInspector inspector)
{
- return VectorProcessors.constantString(value, inspector.getMaxVectorSize());
+ return false;
}
@Override
public String stringify()
{
- // escape as javascript string since string literals are wrapped in single quotes
- return value == null ? NULL_LITERAL : StringUtils.format("'%s'", StringEscapeUtils.escapeJavaScript(value));
+ if (value == null) {
+ return NULL_LITERAL;
+ }
+ if (value.length == 0) {
+ return outputType.asTypeString() + "[]";
+ }
+ if (outputType.getElementType().is(ExprType.STRING)) {
+ return StringUtils.format(
+ "%s[%s]",
+ outputType.asTypeString(),
+ ARG_JOINER.join(
+ Arrays.stream(value)
+ .map(s -> s == null
+ ? NULL_LITERAL
+ // escape as javascript string since string literals are wrapped in single quotes
+ : StringUtils.format("'%s'", StringEscapeUtils.escapeJavaScript((String) s))
+ )
+ .iterator()
+ )
+ );
+ } else if (outputType.getElementType().isNumeric()) {
+ return outputType.asTypeString() + Arrays.toString(value);
+ } else if (outputType.getElementType().is(ExprType.COMPLEX)) {
+ Object[] stringified = new Object[value.length];
+ for (int i = 0; i < value.length; i++) {
+ stringified[i] = new ComplexExpr((ExpressionType) outputType.getElementType(), value[i]).stringify();
+ }
+ // use array function to rebuild since we can't stringify complex types directly
+ return StringUtils.format("array(%s)", Arrays.toString(stringified));
+ } else if (outputType.getElementType().isArray()) {
+ // use array function to rebuild since the parser can't yet recognize nested arrays e.g. [['foo', 'bar'],['baz']]
+ Object[] stringified = new Object[value.length];
+ for (int i = 0; i < value.length; i++) {
+ stringified[i] = new ArrayExpr((ExpressionType) outputType.getElementType(), (Object[]) value[i]).stringify();
+ }
+ return StringUtils.format("array(%s)", Arrays.toString(stringified));
+ }
+ throw new IAE("cannot stringify array type %s", outputType);
}
@Override
@@ -413,22 +392,14 @@ class StringExpr extends ConstantExpr<String>
if (o == null || getClass() != o.getClass()) {
return false;
}
- StringExpr that = (StringExpr) o;
- return Objects.equals(value, that.value);
+ ArrayExpr that = (ArrayExpr) o;
+ return outputType.equals(that.outputType) && Arrays.equals(value, that.value);
}
@Override
public int hashCode()
{
- return Objects.hash(value);
- }
-}
-
-class StringArrayExpr extends ConstantExpr<String[]>
-{
- StringArrayExpr(@Nullable String[] value)
- {
- super(ExpressionType.STRING_ARRAY, value);
+ return Objects.hash(outputType, Arrays.hashCode(value));
}
@Override
@@ -436,11 +407,19 @@ class StringArrayExpr extends ConstantExpr<String[]>
{
return Arrays.toString(value);
}
+}
+
+class ComplexExpr extends ConstantExpr<Object>
+{
+ protected ComplexExpr(ExpressionType outputType, @Nullable Object value)
+ {
+ super(outputType, value);
+ }
@Override
public ExprEval eval(ObjectBinding bindings)
{
- return ExprEval.ofStringArray(value);
+ return ExprEval.ofComplex(outputType, value);
}
@Override
@@ -452,21 +431,17 @@ class StringArrayExpr extends ConstantExpr<String[]>
@Override
public String stringify()
{
- if (value.length == 0) {
- return "<STRING>[]";
+ if (value == null) {
+ return StringUtils.format("complex_decode_base64('%s', %s)", outputType.getComplexTypeName(), NULL_LITERAL);
+ }
+ ObjectByteStrategy strategy = Types.getStrategy(outputType.getComplexTypeName());
+ if (strategy == null) {
+ throw new IAE("Cannot stringify type[%s]", outputType.asTypeString());
}
-
return StringUtils.format(
- "<STRING>[%s]",
- ARG_JOINER.join(
- Arrays.stream(value)
- .map(s -> s == null
- ? NULL_LITERAL
- // escape as javascript string since string literals are wrapped in single quotes
- : StringUtils.format("'%s'", StringEscapeUtils.escapeJavaScript(s))
- )
- .iterator()
- )
+ "complex_decode_base64('%s', '%s')",
+ outputType.getComplexTypeName(),
+ StringUtils.encodeBase64String(strategy.toBytes(value))
);
}
@@ -479,13 +454,13 @@ class StringArrayExpr extends ConstantExpr<String[]>
if (o == null || getClass() != o.getClass()) {
return false;
}
- StringArrayExpr that = (StringArrayExpr) o;
- return Arrays.equals(value, that.value);
+ ComplexExpr that = (ComplexExpr) o;
+ return outputType.equals(that.outputType) && Objects.equals(value, that.value);
}
@Override
public int hashCode()
{
- return Arrays.hashCode(value);
+ return Objects.hash(outputType, value);
}
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java
index 3f3ff14..43471b3 100644
--- a/core/src/main/java/org/apache/druid/math/expr/Expr.java
+++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java
@@ -284,7 +284,7 @@ public interface Expr extends Cacheable
/**
* Mechanism to supply values to back {@link IdentifierExpr} during expression evaluation
*/
- interface ObjectBinding
+ interface ObjectBinding extends InputBindingInspector
{
/**
* Get value binding for string identifier of {@link IdentifierExpr}
@@ -364,13 +364,15 @@ public interface Expr extends Cacheable
@SuppressWarnings("JavadocReference")
class BindingAnalysis
{
+ public static final BindingAnalysis EMTPY = new BindingAnalysis();
+
private final ImmutableSet<IdentifierExpr> freeVariables;
private final ImmutableSet<IdentifierExpr> scalarVariables;
private final ImmutableSet<IdentifierExpr> arrayVariables;
private final boolean hasInputArrays;
private final boolean isOutputArray;
- BindingAnalysis()
+ public BindingAnalysis()
{
this(ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), false, false);
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
index 4ab5b10..04394b2 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
@@ -19,145 +19,65 @@
package org.apache.druid.math.expr;
+import com.google.common.base.Preconditions;
import com.google.common.primitives.Doubles;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.GuavaUtils;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
+import org.apache.druid.segment.column.ObjectByteStrategy;
+import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
-import java.util.Objects;
/**
* Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value to allow
*/
public abstract class ExprEval<T>
{
- private static final int NULL_LENGTH = -1;
-
/**
* Deserialize an expression stored in a bytebuffer, e.g. for an agg.
*
* This should be refactored to be consolidated with some of the standard type handling of aggregators probably
*/
- public static ExprEval deserialize(ByteBuffer buffer, int position)
+ public static ExprEval deserialize(ByteBuffer buffer, int offset, ExpressionType type)
{
- final ExprType type = ExprType.fromByte(buffer.get(position));
- return deserialize(buffer, position + 1, type);
- }
-
- /**
- * Deserialize an expression stored in a bytebuffer, e.g. for an agg.
- *
- * This should be refactored to be consolidated with some of the standard type handling of aggregators probably
- */
- public static ExprEval deserialize(ByteBuffer buffer, int offset, ExprType type)
- {
- // | expression bytes |
- switch (type) {
+ switch (type.getType()) {
case LONG:
- // | expression type (byte) | is null (byte) | long bytes |
- if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
- return of(buffer.getLong(offset));
+ if (Types.isNullableNull(buffer, offset)) {
+ return ofLong(null);
}
- return ofLong(null);
+ return of(Types.readNullableLong(buffer, offset));
case DOUBLE:
- // | expression type (byte) | is null (byte) | double bytes |
- if (buffer.get(offset++) == NullHandling.IS_NOT_NULL_BYTE) {
- return of(buffer.getDouble(offset));
+ if (Types.isNullableNull(buffer, offset)) {
+ return ofDouble(null);
}
- return ofDouble(null);
+ return of(Types.readNullableDouble(buffer, offset));
case STRING:
- // | expression type (byte) | string length (int) | string bytes |
- final int length = buffer.getInt(offset);
- if (length < 0) {
+ if (Types.isNullableNull(buffer, offset)) {
return of(null);
}
- final byte[] stringBytes = new byte[length];
- final int oldPosition = buffer.position();
- buffer.position(offset + Integer.BYTES);
- buffer.get(stringBytes, 0, length);
- buffer.position(oldPosition);
+ final byte[] stringBytes = Types.readNullableVariableBlob(buffer, offset);
return of(StringUtils.fromUtf8(stringBytes));
case ARRAY:
- final ExprType elementType = ExprType.fromByte(buffer.get(offset++));
- switch (elementType) {
+ switch (type.getElementType().getType()) {
case LONG:
- // | expression type (byte) | array element type (byte) | array length (int) | array bytes |
- final int longArrayLength = buffer.getInt(offset);
- offset += Integer.BYTES;
- if (longArrayLength < 0) {
- return ofLongArray(null);
- }
- final Long[] longs = new Long[longArrayLength];
- for (int i = 0; i < longArrayLength; i++) {
- final byte isNull = buffer.get(offset);
- offset += Byte.BYTES;
- if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
- // | is null (byte) | long bytes |
- longs[i] = buffer.getLong(offset);
- offset += Long.BYTES;
- } else {
- // | is null (byte) |
- longs[i] = null;
- }
- }
- return ofLongArray(longs);
+ return ofLongArray(Types.readNullableLongArray(buffer, offset));
case DOUBLE:
- // | expression type (byte) | array element type (byte) | array length (int) | array bytes |
- final int doubleArrayLength = buffer.getInt(offset);
- offset += Integer.BYTES;
- if (doubleArrayLength < 0) {
- return ofDoubleArray(null);
- }
- final Double[] doubles = new Double[doubleArrayLength];
- for (int i = 0; i < doubleArrayLength; i++) {
- final byte isNull = buffer.get(offset);
- offset += Byte.BYTES;
- if (isNull == NullHandling.IS_NOT_NULL_BYTE) {
- // | is null (byte) | double bytes |
- doubles[i] = buffer.getDouble(offset);
- offset += Double.BYTES;
- } else {
- // | is null (byte) |
- doubles[i] = null;
- }
- }
- return ofDoubleArray(doubles);
+ return ofDoubleArray(Types.readNullableDoubleArray(buffer, offset));
case STRING:
- // | expression type (byte) | array element type (byte) | array length (int) | array bytes |
- final int stringArrayLength = buffer.getInt(offset);
- offset += Integer.BYTES;
- if (stringArrayLength < 0) {
- return ofStringArray(null);
- }
- final String[] stringArray = new String[stringArrayLength];
- for (int i = 0; i < stringArrayLength; i++) {
- final int stringElementLength = buffer.getInt(offset);
- offset += Integer.BYTES;
- if (stringElementLength < 0) {
- // | string length (int) |
- stringArray[i] = null;
- } else {
- // | string length (int) | string bytes |
- final byte[] stringElementBytes = new byte[stringElementLength];
- final int oldPosition2 = buffer.position();
- buffer.position(offset);
- buffer.get(stringElementBytes, 0, stringElementLength);
- buffer.position(oldPosition2);
- stringArray[i] = StringUtils.fromUtf8(stringElementBytes);
- offset += stringElementLength;
- }
- }
- return ofStringArray(stringArray);
+ return ofStringArray(Types.readNullableStringArray(buffer, offset));
default:
- throw new UOE("Cannot deserialize expression array of type %s", elementType);
+ throw new UOE("Cannot deserialize expression array of type %s", type);
}
+ case COMPLEX:
+ return ofComplex(type, Types.readNullableComplexType(buffer, offset, type));
default:
throw new UOE("Cannot deserialize expression type %s", type);
}
@@ -173,206 +93,53 @@ public abstract class ExprEval<T>
public static void serialize(ByteBuffer buffer, int position, ExprEval<?> eval, int maxSizeBytes)
{
int offset = position;
- buffer.put(offset++, eval.type().getType().getId());
switch (eval.type().getType()) {
case LONG:
if (eval.isNumericNull()) {
- buffer.put(offset, NullHandling.IS_NULL_BYTE);
+ Types.writeNull(buffer, offset);
} else {
- buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
- buffer.putLong(offset, eval.asLong());
+ Types.writeNullableLong(buffer, offset, eval.asLong());
}
break;
case DOUBLE:
if (eval.isNumericNull()) {
- buffer.put(offset, NullHandling.IS_NULL_BYTE);
+ Types.writeNull(buffer, offset);
} else {
- buffer.put(offset++, NullHandling.IS_NOT_NULL_BYTE);
- buffer.putDouble(offset, eval.asDouble());
+ Types.writeNullableDouble(buffer, offset, eval.asDouble());
}
break;
case STRING:
final byte[] stringBytes = StringUtils.toUtf8Nullable(eval.asString());
if (stringBytes != null) {
- // | expression type (byte) | string length (int) | string bytes |
- checkMaxBytes(eval.type(), 1 + Integer.BYTES + stringBytes.length, maxSizeBytes);
- buffer.putInt(offset, stringBytes.length);
- offset += Integer.BYTES;
- final int oldPosition = buffer.position();
- buffer.position(offset);
- buffer.put(stringBytes, 0, stringBytes.length);
- buffer.position(oldPosition);
+ Types.writeNullableVariableBlob(buffer, offset, stringBytes, eval.type(), maxSizeBytes);
} else {
- checkMaxBytes(eval.type(), 1 + Integer.BYTES, maxSizeBytes);
- buffer.putInt(offset, NULL_LENGTH);
+ Types.writeNull(buffer, offset);
}
break;
case ARRAY:
- // | expression type (byte) | array type (byte) | array length (int) | array bytes |
- buffer.put(offset++, eval.type().getElementType().getType().getId());
switch (eval.type().getElementType().getType()) {
case LONG:
Long[] longs = eval.asLongArray();
- if (longs == null) {
- // | expression type (byte) | array type (byte) | array length (int) |
- checkMaxBytes(eval.type(), 2 + Integer.BYTES, maxSizeBytes);
- buffer.putInt(offset, NULL_LENGTH);
- } else {
- // | expression type (byte) | array type (byte) | array length (int) | array bytes |
- final int sizeBytes = 2 + Integer.BYTES + (Long.BYTES * longs.length);
- checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
- buffer.putInt(offset, longs.length);
- offset += Integer.BYTES;
- for (Long aLong : longs) {
- if (aLong != null) {
- buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
- offset++;
- buffer.putLong(offset, aLong);
- offset += Long.BYTES;
- } else {
- buffer.put(offset++, NullHandling.IS_NULL_BYTE);
- }
- }
- }
+ Types.writeNullableLongArray(buffer, offset, longs, maxSizeBytes);
break;
case DOUBLE:
Double[] doubles = eval.asDoubleArray();
- if (doubles == null) {
- // | expression type (byte) | array type (byte) | array length (int) |
- checkMaxBytes(eval.type(), 2 + Integer.BYTES, maxSizeBytes);
- buffer.putInt(offset, NULL_LENGTH);
- } else {
- // | expression type (byte) | array type (byte) | array length (int) | array bytes |
- final int sizeBytes = 2 + Integer.BYTES + (Double.BYTES * doubles.length);
- checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
- buffer.putInt(offset, doubles.length);
- offset += Integer.BYTES;
-
- for (Double aDouble : doubles) {
- if (aDouble != null) {
- buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
- offset++;
- buffer.putDouble(offset, aDouble);
- offset += Long.BYTES;
- } else {
- buffer.put(offset++, NullHandling.IS_NULL_BYTE);
- }
- }
- }
+ Types.writeNullableDoubleArray(buffer, offset, doubles, maxSizeBytes);
break;
case STRING:
String[] strings = eval.asStringArray();
- if (strings == null) {
- // | expression type (byte) | array type (byte) | array length (int) |
- checkMaxBytes(eval.type(), 2 + Integer.BYTES, maxSizeBytes);
- buffer.putInt(offset, NULL_LENGTH);
- } else {
- // | expression type (byte) | array type (byte) | array length (int) | array bytes |
- buffer.putInt(offset, strings.length);
- offset += Integer.BYTES;
- int sizeBytes = 2 + Integer.BYTES;
- for (String string : strings) {
- if (string == null) {
- // | string length (int) |
- sizeBytes += Integer.BYTES;
- checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
- buffer.putInt(offset, NULL_LENGTH);
- offset += Integer.BYTES;
- } else {
- // | string length (int) | string bytes |
- final byte[] stringElementBytes = StringUtils.toUtf8(string);
- sizeBytes += Integer.BYTES + stringElementBytes.length;
- checkMaxBytes(eval.type(), sizeBytes, maxSizeBytes);
- buffer.putInt(offset, stringElementBytes.length);
- offset += Integer.BYTES;
- final int oldPosition = buffer.position();
- buffer.position(offset);
- buffer.put(stringElementBytes, 0, stringElementBytes.length);
- buffer.position(oldPosition);
- offset += stringElementBytes.length;
- }
- }
- }
+ Types.writeNullableStringArray(buffer, offset, strings, maxSizeBytes);
break;
default:
throw new UOE("Cannot serialize expression array type %s", eval.type());
}
break;
- default:
- throw new UOE("Cannot serialize expression type %s", eval.type());
- }
- }
-
- public static void checkMaxBytes(ExpressionType type, int sizeBytes, int maxSizeBytes)
- {
- if (sizeBytes > maxSizeBytes) {
- throw new ISE("Unable to serialize [%s], size [%s] is larger than max [%s]", type, sizeBytes, maxSizeBytes);
- }
- }
-
- /**
- * Used to estimate the size in bytes to {@link #serialize} the {@link ExprEval} value, checking against a maximum
- * size and failing with an {@link ISE} if the estimate is over the maximum.
- */
- public static void estimateAndCheckMaxBytes(ExprEval eval, int maxSizeBytes)
- {
- final int estimated;
- switch (eval.type().getType()) {
- case STRING:
- String stringValue = eval.asString();
- estimated = 1 + Integer.BYTES + (stringValue == null ? 0 : StringUtils.estimatedBinaryLengthAsUTF8(stringValue));
- break;
- case LONG:
- case DOUBLE:
- estimated = 1 + (NullHandling.sqlCompatible() ? 1 + Long.BYTES : Long.BYTES);
- break;
- case ARRAY:
- switch (eval.type().getElementType().getType()) {
- case STRING:
- String[] stringArray = eval.asStringArray();
- if (stringArray == null) {
- estimated = 2 + Integer.BYTES;
- } else {
- final int elementsSize = Arrays.stream(stringArray)
- .filter(Objects::nonNull)
- .mapToInt(StringUtils::estimatedBinaryLengthAsUTF8)
- .sum();
- // since each value is variably sized, there is an integer per element
- estimated = 2 + Integer.BYTES + (Integer.BYTES * stringArray.length) + elementsSize;
- }
- break;
- case LONG:
- Long[] longArray = eval.asLongArray();
- if (longArray == null) {
- estimated = 2 + Integer.BYTES;
- } else {
- final int elementsSize = Arrays.stream(longArray)
- .filter(Objects::nonNull)
- .mapToInt(x -> Long.BYTES)
- .sum();
- estimated = 2 + Integer.BYTES + (NullHandling.sqlCompatible() ? longArray.length : 0) + elementsSize;
- }
- break;
- case DOUBLE:
- Double[] doubleArray = eval.asDoubleArray();
- if (doubleArray == null) {
- estimated = 2 + Integer.BYTES;
- } else {
- final int elementsSize = Arrays.stream(doubleArray)
- .filter(Objects::nonNull)
- .mapToInt(x -> Long.BYTES)
- .sum();
- estimated = 2 + Integer.BYTES + (NullHandling.sqlCompatible() ? doubleArray.length : 0) + elementsSize;
- }
- break;
- default:
- throw new ISE("Unsupported array type: %s", eval.type());
- }
+ case COMPLEX:
+ Types.writeNullableComplexType(buffer, offset, eval.type(), eval.value(), maxSizeBytes);
break;
default:
- throw new ISE("Unsupported type: %s", eval.type());
+ throw new UOE("Cannot serialize expression type %s", eval.type());
}
- checkMaxBytes(eval.type(), estimated, maxSizeBytes);
}
/**
@@ -382,7 +149,7 @@ public abstract class ExprEval<T>
* If homogenizeMultiValueStrings is true, null and [] will be converted to [null], otherwise they will retain
*/
@Nullable
- public static Object coerceListToArray(@Nullable List<?> val, boolean homogenizeMultiValueStrings)
+ public static NonnullPair<ExpressionType, Object[]> coerceListToArray(@Nullable List<?> val, boolean homogenizeMultiValueStrings)
{
// if value is not null and has at least 1 element, conversion is unambigous regardless of the selector
if (val != null && val.size() > 0) {
@@ -395,24 +162,58 @@ public abstract class ExprEval<T>
}
if (coercedType == Long.class || coercedType == Integer.class) {
- return val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray(Long[]::new);
+ return new NonnullPair<>(
+ ExpressionType.LONG_ARRAY,
+ val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray()
+ );
}
if (coercedType == Float.class || coercedType == Double.class) {
- return val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray(Double[]::new);
+ return new NonnullPair<>(
+ ExpressionType.DOUBLE_ARRAY,
+ val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray()
+ );
}
// default to string
- return val.stream().map(x -> x != null ? x.toString() : null).toArray(String[]::new);
+ return new NonnullPair<>(
+ ExpressionType.STRING_ARRAY,
+ val.stream().map(x -> x != null ? x.toString() : null).toArray()
+ );
}
if (homogenizeMultiValueStrings) {
- return new String[]{null};
+ return new NonnullPair<>(ExpressionType.STRING_ARRAY, new Object[]{null});
} else {
if (val != null) {
- return new String[0];
+ return new NonnullPair<>(ExpressionType.STRING_ARRAY, new Object[0]);
}
return null;
}
}
+  /**
+   * Infer the array {@link ExpressionType} of an {@code Object[]} from the classes of its non-null elements, which are
+   * folded into a single common class with {@code convertType}. The element-type checks mirror those used by
+   * {@link #coerceListToArray} for lists.
+   *
+   * A common class of Long or Integer yields {@link ExpressionType#LONG_ARRAY}, Float or Double yields
+   * {@link ExpressionType#DOUBLE_ARRAY}, and anything else (including an array whose elements are all null) falls
+   * back to {@link ExpressionType#STRING_ARRAY}.
+   *
+   * @param val array to inspect, may be null
+   * @return the inferred array type, or null if {@code val} is null or empty
+   */
+  @Nullable
+  public static ExpressionType findArrayType(@Nullable Object[] val)
+  {
+    // only a non-null, non-empty array has elements to infer from; otherwise return null so the caller picks a default
+    if (val != null && val.length > 0) {
+      Class<?> coercedType = null;
+
+      for (Object elem : val) {
+        // null elements carry no type information, so they do not participate in the common type
+        if (elem != null) {
+          coercedType = convertType(coercedType, elem.getClass());
+        }
+      }
+
+      if (coercedType == Long.class || coercedType == Integer.class) {
+        return ExpressionType.LONG_ARRAY;
+      }
+      if (coercedType == Float.class || coercedType == Double.class) {
+        return ExpressionType.DOUBLE_ARRAY;
+      }
+      // default to string
+      return ExpressionType.STRING_ARRAY;
+    }
+    return null;
+  }
+
/**
* Find the common type to use between 2 types, useful for choosing the appropriate type for an array given a set
* of objects with unknown type, following rules similar to Java, our own native Expr, and SQL implicit type
@@ -500,25 +301,32 @@ public abstract class ExprEval<T>
public static ExprEval ofLongArray(@Nullable Long[] longValue)
{
if (longValue == null) {
- return LongArrayExprEval.OF_NULL;
+ return ArrayExprEval.OF_NULL_LONG;
}
- return new LongArrayExprEval(longValue);
+ return new ArrayExprEval(ExpressionType.LONG_ARRAY, longValue);
}
public static ExprEval ofDoubleArray(@Nullable Double[] doubleValue)
{
if (doubleValue == null) {
- return DoubleArrayExprEval.OF_NULL;
+ return ArrayExprEval.OF_NULL_DOUBLE;
}
- return new DoubleArrayExprEval(doubleValue);
+ return new ArrayExprEval(ExpressionType.DOUBLE_ARRAY, doubleValue);
}
public static ExprEval ofStringArray(@Nullable String[] stringValue)
{
if (stringValue == null) {
- return StringArrayExprEval.OF_NULL;
+ return ArrayExprEval.OF_NULL_STRING;
}
- return new StringArrayExprEval(stringValue);
+ return new ArrayExprEval(ExpressionType.STRING_ARRAY, stringValue);
+ }
+
+
+  /**
+   * Create an array {@link ExprEval} of the given array type. {@code value} is wrapped as-is: it is neither copied nor
+   * are its elements converted to match {@code outputType}.
+   *
+   * @throws IllegalArgumentException if {@code outputType} is not an array type
+   */
+  public static ExprEval ofArray(ExpressionType outputType, Object[] value)
+  {
+    Preconditions.checkArgument(outputType.isArray());
+    return new ArrayExprEval(outputType, value);
  }
/**
@@ -546,6 +354,11 @@ public abstract class ExprEval<T>
return ExprEval.of(Evals.asLong(value));
}
+  /**
+   * Create a complex-typed {@link ExprEval} that wraps {@code value} as-is with the supplied type. This method does
+   * not check that {@code outputType} is a complex type or that {@code value} is an instance of it.
+   */
+  public static ExprEval ofComplex(ExpressionType outputType, @Nullable Object value)
+  {
+    return new ComplexExprEval(outputType, value);
+  }
+
/**
* Examine java type to find most appropriate expression type
*/
@@ -554,6 +367,9 @@ public abstract class ExprEval<T>
if (val instanceof ExprEval) {
return (ExprEval) val;
}
+ if (val instanceof String) {
+ return new StringExprEval((String) val);
+ }
if (val instanceof Number) {
if (val instanceof Float || val instanceof Double) {
return new DoubleExprEval((Number) val);
@@ -561,26 +377,93 @@ public abstract class ExprEval<T>
return new LongExprEval((Number) val);
}
if (val instanceof Long[]) {
- return new LongArrayExprEval((Long[]) val);
+ return new ArrayExprEval(ExpressionType.LONG_ARRAY, (Long[]) val);
}
if (val instanceof Double[]) {
- return new DoubleArrayExprEval((Double[]) val);
+ return new ArrayExprEval(ExpressionType.DOUBLE_ARRAY, (Double[]) val);
}
if (val instanceof Float[]) {
- return new DoubleArrayExprEval(Arrays.stream((Float[]) val).map(Float::doubleValue).toArray(Double[]::new));
+ return new ArrayExprEval(ExpressionType.DOUBLE_ARRAY, Arrays.stream((Float[]) val).map(Float::doubleValue).toArray());
}
if (val instanceof String[]) {
- return new StringArrayExprEval((String[]) val);
+ return new ArrayExprEval(ExpressionType.STRING_ARRAY, (String[]) val);
+ }
+ if (val instanceof Object[]) {
+ ExpressionType arrayType = findArrayType((Object[]) val);
+ if (arrayType != null) {
+ return new ArrayExprEval(arrayType, (Object[]) val);
+ }
+ // default to string if array is empty
+ return new ArrayExprEval(ExpressionType.STRING_ARRAY, (Object[]) val);
}
if (val instanceof List) {
// do not convert empty lists to arrays with a single null element here, because that should have been done
// by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately
// empty
- return bestEffortOf(coerceListToArray((List<?>) val, false));
+ NonnullPair<ExpressionType, Object[]> coerced = coerceListToArray((List<?>) val, false);
+ if (coerced == null) {
+ return bestEffortOf(null);
+ }
+ return ofArray(coerced.lhs, coerced.rhs);
}
- return new StringExprEval(val == null ? null : String.valueOf(val));
+ if (val != null) {
+ // is this cool?
+ return new ComplexExprEval(ExpressionType.UNKNOWN_COMPLEX, val);
+ }
+
+ return new StringExprEval(null);
+ }
+
+  /**
+   * Create an {@link ExprEval} for {@code value}, using {@code type} as a hint. A null type defers to
+   * {@link #bestEffortOf}.
+   *
+   * <ul>
+   *   <li>STRING: String[] and Object[] become {@link ExpressionType#STRING_ARRAY} evals, Lists go through
+   *   {@link #bestEffortOf}, and anything else is cast to String</li>
+   *   <li>LONG and DOUBLE: the value is cast to {@link Number}</li>
+   *   <li>COMPLEX: a String is treated as base64. It and a byte[] are decoded with the {@link ObjectByteStrategy}
+   *   for the type's complex type name; any other value is wrapped as-is</li>
+   *   <li>ARRAY: an Object[] is wrapped with {@link #ofArray}; anything else goes through {@link #bestEffortOf}</li>
+   * </ul>
+   *
+   * @throws IAE if a complex value needs decoding but no {@link ObjectByteStrategy} is available for its type, or if
+   *             the type is not one of the handled kinds
+   * @throws ClassCastException if {@code value} does not fit the cast the type requires (e.g. a non-Number for LONG)
+   */
+  public static ExprEval ofType(@Nullable ExpressionType type, @Nullable Object value)
+  {
+    if (type == null) {
+      return bestEffortOf(value);
+    }
+    switch (type.getType()) {
+      case STRING:
+        // not all who claim to be "STRING" are always a String, prepare ourselves...
+        if (value instanceof String[]) {
+          return new ArrayExprEval(ExpressionType.STRING_ARRAY, (String[]) value);
+        }
+        if (value instanceof Object[]) {
+          return new ArrayExprEval(ExpressionType.STRING_ARRAY, (Object[]) value);
+        }
+        if (value instanceof List) {
+          return bestEffortOf(value);
+        }
+        return of((String) value);
+      case LONG:
+        return ofLong((Number) value);
+      case DOUBLE:
+        return ofDouble((Number) value);
+      case COMPLEX:
+        byte[] bytes = null;
+        if (value instanceof String) {
+          bytes = StringUtils.decodeBase64String((String) value);
+        } else if (value instanceof byte[]) {
+          bytes = (byte[]) value;
+        }
+
+        if (bytes != null) {
+          ObjectByteStrategy<?> strategy = Types.getStrategy(type.getComplexTypeName());
+          // explicit check instead of 'assert': assertions are normally disabled at runtime, and a missing strategy
+          // would then surface as an opaque NullPointerException at fromByteBuffer below
+          if (strategy == null) {
+            throw new IAE("Cannot deserialize complex type[%s], no strategy available", type);
+          }
+          ByteBuffer bb = ByteBuffer.wrap(bytes);
+          return ofComplex(type, strategy.fromByteBuffer(bb, bytes.length));
+        }
+
+        return ofComplex(type, value);
+      case ARRAY:
+        if (value instanceof Object[]) {
+          return ofArray(type, (Object[]) value);
+        }
+        // in a better world, we might get an object that matches the type signature for arrays and could do a switch
+        // statement here, but this is not that world yet, and things that are array typed might also be non-arrays,
+        // e.g. we might get a String instead of String[], so just fallback to bestEffortOf
+        return bestEffortOf(value);
+    }
+    throw new IAE("Cannot create type [%s]", type);
  }
@Nullable
@@ -620,6 +503,11 @@ public abstract class ExprEval<T>
return type().isArray() ? (ExpressionType) type().getElementType() : type();
}
+  /**
+   * The array form of this value's type: {@link #type()} itself when it is already an array type, otherwise an array
+   * type whose element type is {@link #type()}, obtained from {@link ExpressionTypeFactory}.
+   */
+  public ExpressionType asArrayType()
+  {
+    return type().isArray() ? type() : ExpressionTypeFactory.getInstance().ofArray(type());
+  }
+
@Nullable
public T value()
{
@@ -1097,11 +985,26 @@ public abstract class ExprEval<T>
}
}
- abstract static class ArrayExprEval<T> extends ExprEval<T[]>
+ static class ArrayExprEval extends ExprEval<Object[]>
{
- private ArrayExprEval(@Nullable T[] value)
+ public static final ExprEval OF_NULL_LONG = new ArrayExprEval(ExpressionType.LONG_ARRAY, null);
+ public static final ExprEval OF_NULL_DOUBLE = new ArrayExprEval(ExpressionType.DOUBLE_ARRAY, null);
+ public static final ExprEval OF_NULL_STRING = new ArrayExprEval(ExpressionType.STRING_ARRAY, null);
+
+ private final ExpressionType arrayType;
+
+ private ArrayExprEval(ExpressionType arrayType, @Nullable Object[] value)
{
super(value);
+ this.arrayType = arrayType;
+ Preconditions.checkArgument(arrayType.isArray());
+ ExpressionType.checkNestedArrayAllowed(arrayType);
+ }
+
+ @Override
+ public ExpressionType type()
+ {
+ return arrayType;
}
@Override
@@ -1129,6 +1032,10 @@ public abstract class ExprEval<T>
public boolean isNumericNull()
{
if (isScalar()) {
+ if (arrayType.getElementType().is(ExprType.STRING)) {
+ Number n = computeNumber((String) getScalarValue());
+ return n == null;
+ }
return getScalarValue() == null;
}
@@ -1144,150 +1051,156 @@ public abstract class ExprEval<T>
@Override
public int asInt()
{
- return 0;
- }
-
- @Override
- public long asLong()
- {
- return 0;
- }
-
- @Override
- public double asDouble()
- {
- return 0;
- }
-
- @Override
- public boolean asBoolean()
- {
- return false;
- }
-
- @Nullable
- @Override
- public T[] asArray()
- {
- return value;
- }
-
- @Nullable
- public T getIndex(int index)
- {
- return value == null ? null : value[index];
- }
-
- protected boolean isScalar()
- {
- return value != null && value.length == 1;
- }
-
- @Nullable
- protected T getScalarValue()
- {
- assert value != null && value.length == 1;
- return value[0];
- }
- }
-
- private static class LongArrayExprEval extends ArrayExprEval<Long>
- {
- private static final LongArrayExprEval OF_NULL = new LongArrayExprEval(null);
-
- private LongArrayExprEval(@Nullable Long[] value)
- {
- super(value);
- }
-
- @Override
- public ExpressionType type()
- {
- return ExpressionType.LONG_ARRAY;
- }
-
- @Override
- public int asInt()
- {
if (isScalar()) {
- Number scalar = getScalarValue();
+ Number scalar = null;
+ if (arrayType.getElementType().isNumeric()) {
+ scalar = (Number) getScalarValue();
+ } else if (arrayType.getElementType().is(ExprType.STRING)) {
+ scalar = computeNumber((String) getScalarValue());
+ }
if (scalar == null) {
assert NullHandling.replaceWithDefault();
return 0;
}
return scalar.intValue();
}
- return super.asInt();
+ return 0;
}
@Override
public long asLong()
{
if (isScalar()) {
- Number scalar = getScalarValue();
+ Number scalar = null;
+ if (arrayType.getElementType().isNumeric()) {
+ scalar = (Number) getScalarValue();
+ } else if (arrayType.getElementType().is(ExprType.STRING)) {
+ scalar = computeNumber((String) getScalarValue());
+ }
if (scalar == null) {
assert NullHandling.replaceWithDefault();
return 0;
}
return scalar.longValue();
}
- return super.asLong();
+ return 0L;
}
@Override
public double asDouble()
{
if (isScalar()) {
- Number scalar = getScalarValue();
+ Number scalar = null;
+ if (arrayType.getElementType().isNumeric()) {
+ scalar = (Number) getScalarValue();
+ } else if (arrayType.getElementType().is(ExprType.STRING)) {
+ scalar = computeNumber((String) getScalarValue());
+ }
if (scalar == null) {
assert NullHandling.replaceWithDefault();
- return 0;
+ return 0.0;
}
return scalar.doubleValue();
}
- return super.asDouble();
+ return 0.0;
}
@Override
public boolean asBoolean()
{
if (isScalar()) {
- Number scalarValue = getScalarValue();
- if (scalarValue == null) {
- assert NullHandling.replaceWithDefault();
- return false;
+ if (arrayType.getElementType().isNumeric()) {
+ Number scalarValue = (Number) getScalarValue();
+ if (scalarValue == null) {
+ assert NullHandling.replaceWithDefault();
+ return false;
+ }
+ return Evals.asBoolean(scalarValue.longValue());
+ }
+ if (arrayType.getElementType().is(ExprType.STRING)) {
+ return Evals.asBoolean((String) getScalarValue());
}
- return Evals.asBoolean(scalarValue.longValue());
}
- return super.asBoolean();
+ return false;
+ }
+
+ @Nullable
+ @Override
+ public Object[] asArray()
+ {
+ return value;
}
@Nullable
@Override
public String[] asStringArray()
{
- return value == null ? null : Arrays.stream(value).map(x -> x != null ? x.toString() : null).toArray(String[]::new);
+ if (value != null) {
+ if (arrayType.getElementType().is(ExprType.STRING)) {
+ return Arrays.stream(value).map(v -> (String) v).toArray(String[]::new);
+ } else if (arrayType.getElementType().isNumeric()) {
+ return Arrays.stream(value).map(x -> x != null ? x.toString() : null).toArray(String[]::new);
+ }
+ }
+ return null;
}
@Nullable
@Override
public Long[] asLongArray()
{
- return value;
+ if (arrayType.getElementType().is(ExprType.LONG)) {
+ return Arrays.stream(value).map(v -> (Long) v).toArray(Long[]::new);
+ } else if (arrayType.getElementType().is(ExprType.DOUBLE)) {
+ return value == null ? null : Arrays.stream(value).map(v -> ((Double) v).longValue()).toArray(Long[]::new);
+ } else if (arrayType.getElementType().is(ExprType.STRING)) {
+ return Arrays.stream(value).map(v -> {
+ if (v == null) {
+ return null;
+ }
+ Long lv = GuavaUtils.tryParseLong((String) v);
+ if (lv == null) {
+ Double d = Doubles.tryParse((String) v);
+ if (d != null) {
+ lv = d.longValue();
+ }
+ }
+ return lv;
+ }).toArray(Long[]::new);
+ }
+ return null;
}
@Nullable
@Override
public Double[] asDoubleArray()
{
- return value == null ? null : Arrays.stream(value).map(Long::doubleValue).toArray(Double[]::new);
+ if (arrayType.getElementType().is(ExprType.DOUBLE)) {
+ return Arrays.stream(value).map(v -> (Double) v).toArray(Double[]::new);
+ } else if (arrayType.getElementType().is(ExprType.LONG)) {
+ return value == null ? null : Arrays.stream(value).map(v -> ((Long) v).doubleValue()).toArray(Double[]::new);
+ } else if (arrayType.getElementType().is(ExprType.STRING)) {
+ if (value == null) {
+ return null;
+ }
+ return Arrays.stream(value).map(val -> {
+ if (val == null) {
+ return null;
+ }
+ return Doubles.tryParse((String) val);
+ }).toArray(Double[]::new);
+ }
+ return new Double[0];
}
@Override
public ExprEval castTo(ExpressionType castTo)
{
if (value == null) {
- return StringExprEval.OF_NULL;
+ if (castTo.isArray()) {
+ return new ArrayExprEval(castTo, null);
+ }
+ return ExprEval.ofType(castTo, null);
}
switch (castTo.getType()) {
case STRING:
@@ -1306,15 +1219,12 @@ public abstract class ExprEval<T>
}
break;
case ARRAY:
- switch (castTo.getElementType().getType()) {
- case LONG:
- return this;
- case DOUBLE:
- return ExprEval.ofDoubleArray(asDoubleArray());
- case STRING:
- return ExprEval.ofStringArray(asStringArray());
+ ExpressionType elementType = (ExpressionType) castTo.getElementType();
+ Object[] cast = new Object[value.length];
+ for (int i = 0; i < value.length; i++) {
+ cast[i] = ExprEval.ofType(elementType(), value[i]).castTo(elementType).value();
}
- break;
+ return ExprEval.ofArray(castTo, cast);
}
throw new IAE("invalid type " + castTo);
@@ -1323,293 +1233,107 @@ public abstract class ExprEval<T>
@Override
public Expr toExpr()
{
- return new LongArrayExpr(value);
- }
- }
-
- private static class DoubleArrayExprEval extends ArrayExprEval<Double>
- {
- private static final DoubleArrayExprEval OF_NULL = new DoubleArrayExprEval(null);
-
- private DoubleArrayExprEval(@Nullable Double[] value)
- {
- super(value);
- }
-
- @Override
- public ExpressionType type()
- {
- return ExpressionType.DOUBLE_ARRAY;
- }
-
- @Override
- public int asInt()
- {
- if (isScalar()) {
- Number scalar = getScalarValue();
- if (scalar == null) {
- assert NullHandling.replaceWithDefault();
- return 0;
- }
- return scalar.intValue();
- }
- return super.asInt();
- }
-
- @Override
- public long asLong()
- {
- if (isScalar()) {
- Number scalar = getScalarValue();
- if (scalar == null) {
- assert NullHandling.replaceWithDefault();
- return 0;
- }
- return scalar.longValue();
- }
- return super.asLong();
- }
-
- @Override
- public double asDouble()
- {
- if (isScalar()) {
- Number scalar = getScalarValue();
- if (scalar == null) {
- assert NullHandling.replaceWithDefault();
- return 0;
- }
- return scalar.doubleValue();
- }
- return super.asDouble();
- }
-
- @Override
- public boolean asBoolean()
- {
- if (isScalar()) {
- Number scalarValue = getScalarValue();
- if (scalarValue == null) {
- assert NullHandling.replaceWithDefault();
- return false;
- }
- return Evals.asBoolean(scalarValue.longValue());
- }
- return super.asBoolean();
+ return new ArrayExpr(arrayType, value); // one ArrayExpr, parameterized by array type, replaces the per-element-type Exprs
}
@Nullable
- @Override
- public String[] asStringArray()
+ public Object getIndex(int index)
{
- return value == null
- ? null
- : Arrays.stream(value).map(x -> x != null ? x.toString() : null).toArray(String[]::new);
+ return value == null ? null : value[index]; // no bounds check here; index must be valid for the array
}
- @Nullable
- @Override
- public Long[] asLongArray()
+ protected boolean isScalar()
{
- return value == null ? null : Arrays.stream(value).map(Double::longValue).toArray(Long[]::new);
+ return value != null && value.length == 1; // a single-element array may be treated as a scalar
}
@Nullable
- @Override
- public Double[] asDoubleArray()
+ protected Object getScalarValue()
{
- return value;
- }
-
- @Override
- public ExprEval castTo(ExpressionType castTo)
- {
- if (value == null) {
- return StringExprEval.OF_NULL;
- }
- switch (castTo.getType()) {
- case STRING:
- if (value.length == 1) {
- return ExprEval.of(asString());
- }
- break;
- case LONG:
- if (value.length == 1) {
- return isNumericNull() ? ExprEval.ofLong(null) : ExprEval.ofLong(asLong());
- }
- break;
- case DOUBLE:
- if (value.length == 1) {
- return isNumericNull() ? ExprEval.ofDouble(null) : ExprEval.ofDouble(asDouble());
- }
- break;
- case ARRAY:
- switch (castTo.getElementType().getType()) {
- case LONG:
- return ExprEval.ofLongArray(asLongArray());
- case DOUBLE:
- return this;
- case STRING:
- return ExprEval.ofStringArray(asStringArray());
- }
- }
-
- throw new IAE("invalid type " + castTo);
- }
-
- @Override
- public Expr toExpr()
- {
- return new DoubleArrayExpr(value);
+ assert value != null && value.length == 1; // only valid when isScalar() is true
+ return value[0];
}
}
- private static class StringArrayExprEval extends ArrayExprEval<String>
+ private static class ComplexExprEval extends ExprEval<Object>
{
- private static final StringArrayExprEval OF_NULL = new StringArrayExprEval(null);
-
- private boolean longValueValid = false;
- private boolean doubleValueValid = false;
- @Nullable
- private Long[] longValues;
- @Nullable
- private Double[] doubleValues;
- @Nullable
- private Number computedNumericScalar;
- private boolean isScalarNumberValid;
+ private final ExpressionType expressionType; // the complex type this value was created with
- private StringArrayExprEval(@Nullable String[] value)
+ private ComplexExprEval(ExpressionType expressionType, @Nullable Object value)
{
super(value);
+ this.expressionType = expressionType;
}
@Override
public ExpressionType type()
{
- return ExpressionType.STRING_ARRAY;
+ return expressionType;
}
@Override
public boolean isNumericNull()
{
- if (isScalar()) {
- computeScalarNumericIfNeeded();
- return computedNumericScalar == null;
- }
- return true;
+ return false; // complex values never report a numeric null
}
@Override
public int asInt()
{
- if (isScalar()) {
- computeScalarNumericIfNeeded();
- if (computedNumericScalar == null) {
- assert NullHandling.replaceWithDefault();
- return 0;
- }
- return computedNumericScalar.intValue();
- }
- return super.asInt();
+ return 0; // complex values have no numeric representation
}
@Override
public long asLong()
{
- if (isScalar()) {
- computeScalarNumericIfNeeded();
- if (computedNumericScalar == null) {
- assert NullHandling.replaceWithDefault();
- return 0L;
- }
- return computedNumericScalar.longValue();
- }
- return super.asLong();
+ return 0; // complex values have no numeric representation
}
@Override
public double asDouble()
{
- if (isScalar()) {
- computeScalarNumericIfNeeded();
- if (computedNumericScalar == null) {
- assert NullHandling.replaceWithDefault();
- return 0.0;
- }
- return computedNumericScalar.doubleValue();
- }
- return super.asDouble();
+ return 0; // complex values have no numeric representation
}
@Override
public boolean asBoolean()
{
- if (isScalar()) {
- return Evals.asBoolean(getScalarValue());
- }
- return super.asBoolean();
+ return false; // complex values are never truthy
+ }
+
+ @Nullable
+ @Override
+ public Object[] asArray()
+ {
+ return new Object[0]; // NOTE(review): empty array rather than null despite @Nullable -- confirm intended
}
@Nullable
@Override
public String[] asStringArray()
{
- return value;
+ return new String[0];
}
@Nullable
@Override
public Long[] asLongArray()
{
- if (!longValueValid) {
- longValues = computeLongs();
- longValueValid = true;
- }
- return longValues;
+ return new Long[0];
}
@Nullable
@Override
public Double[] asDoubleArray()
{
- if (!doubleValueValid) {
- doubleValues = computeDoubles();
- doubleValueValid = true;
- }
- return doubleValues;
+ return new Double[0];
}
@Override
public ExprEval castTo(ExpressionType castTo)
{
- if (value == null) {
- return StringExprEval.OF_NULL;
- }
- switch (castTo.getType()) {
- case STRING:
- if (value.length == 1) {
- return ExprEval.of(asString());
- }
- break;
- case LONG:
- if (value.length == 1) {
- return isNumericNull() ? ExprEval.ofLong(null) : ExprEval.ofLong(asLong());
- }
- break;
- case DOUBLE:
- if (value.length == 1) {
- return isNumericNull() ? ExprEval.ofDouble(null) : ExprEval.ofDouble(asDouble());
- }
- break;
- case ARRAY:
- switch (castTo.getElementType().getType()) {
- case STRING:
- return this;
- case LONG:
- return ExprEval.ofLongArray(asLongArray());
- case DOUBLE:
- return ExprEval.ofDoubleArray(asDoubleArray());
- }
+ if (expressionType.equals(castTo)) { // only casting to the identical complex type is supported
+ return this;
}
throw new IAE("invalid type " + castTo);
}
@@ -1617,54 +1341,7 @@ public abstract class ExprEval<T>
@Override
public Expr toExpr()
{
- return new StringArrayExpr(value);
- }
-
- @Nullable
- private Long[] computeLongs()
- {
- if (value == null) {
- return null;
- }
- return Arrays.stream(value).map(value -> {
- if (value == null) {
- return null;
- }
- Long lv = GuavaUtils.tryParseLong(value);
- if (lv == null) {
- Double d = Doubles.tryParse(value);
- if (d != null) {
- lv = d.longValue();
- }
- }
- return lv;
- }).toArray(Long[]::new);
- }
-
- @Nullable
- private Double[] computeDoubles()
- {
- if (value == null) {
- return null;
- }
- return Arrays.stream(value).map(val -> {
- if (val == null) {
- return null;
- }
- return Doubles.tryParse(val);
- }).toArray(Double[]::new);
- }
-
-
- /**
- * must not be called unless array has a single element
- */
- private void computeScalarNumericIfNeeded()
- {
- if (!isScalarNumberValid) {
- computedNumericScalar = computeNumber(getScalarValue());
- isScalarNumberValid = true;
- }
+ return new ComplexExpr(expressionType, value); // carries the complex type along with the value
}
}
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java
index 617499e..db02562 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java
@@ -390,7 +390,7 @@ public class ExprListenerImpl extends ExprBaseListener
@Override
public void exitDoubleArray(ExprParser.DoubleArrayContext ctx)
{
- Double[] values = new Double[ctx.numericElement().size()];
+ Object[] values = new Object[ctx.numericElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.numericElement(i).NULL() != null) {
values[i] = null;
@@ -402,13 +402,51 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a double", ctx.numericElement(i).getText());
}
}
- nodes.put(ctx, new DoubleArrayExpr(values));
+ nodes.put(ctx, new ArrayExpr(ExpressionType.DOUBLE_ARRAY, values));
+ }
+
+ @Override
+ public void exitExplicitArray(ExprParser.ExplicitArrayContext ctx)
+ {
+ ExpressionType type = ExpressionType.fromString(ctx.ARRAY_TYPE().getText());
+ if (type == null || !type.isArray()) { // element parsing below needs an array type; a non-array would NPE on getElementType()
+ throw new RE("Failed to convert array type %s to expression type", ctx.ARRAY_TYPE().getText());
+ }
+ Object[] values = new Object[ctx.literalElement().size()];
+ for (int i = 0; i < values.length; i++) {
+ if (ctx.literalElement(i).NULL() != null) {
+ values[i] = null;
+ } else {
+ final ExprParser.LiteralElementContext elementContext = ctx.literalElement(i);
+ // if value is a string, escape quoting
+ final String toParse;
+ if (elementContext.STRING() != null) {
+ toParse = escapeStringLiteral(elementContext.STRING().getText());
+ } else {
+ toParse = elementContext.getText();
+ }
+ switch (type.getElementType().getType()) {
+ case LONG:
+ values[i] = Numbers.parseLongObject(toParse);
+ break;
+ case DOUBLE:
+ values[i] = Numbers.parseDoubleObject(toParse);
+ break;
+ case STRING:
+ values[i] = toParse;
+ break;
+ default:
+ throw new RE("Failed to parse array element %s as a %s", toParse, type.getElementType().asTypeString());
+ }
+ }
+ }
+ nodes.put(ctx, new ArrayExpr(type, values));
}
@Override
public void exitLongArray(ExprParser.LongArrayContext ctx)
{
- Long[] values = new Long[ctx.longElement().size()];
+ Object[] values = new Object[ctx.longElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.longElement(i).NULL() != null) {
values[i] = null;
@@ -418,13 +456,13 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a long", ctx.longElement(i).getText());
}
}
- nodes.put(ctx, new LongArrayExpr(values));
+ nodes.put(ctx, new ArrayExpr(ExpressionType.LONG_ARRAY, values));
}
@Override
public void exitExplicitLongArray(ExprParser.ExplicitLongArrayContext ctx)
{
- Long[] values = new Long[ctx.numericElement().size()];
+ Object[] values = new Object[ctx.numericElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.numericElement(i).NULL() != null) {
values[i] = null;
@@ -436,13 +474,13 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a long", ctx.numericElement(i).getText());
}
}
- nodes.put(ctx, new LongArrayExpr(values));
+ nodes.put(ctx, new ArrayExpr(ExpressionType.LONG_ARRAY, values));
}
@Override
public void exitStringArray(ExprParser.StringArrayContext ctx)
{
- String[] values = new String[ctx.stringElement().size()];
+ Object[] values = new Object[ctx.stringElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.stringElement(i).NULL() != null) {
values[i] = null;
@@ -452,13 +490,13 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array: element %s is not a string", ctx.stringElement(i).getText());
}
}
- nodes.put(ctx, new StringArrayExpr(values));
+ nodes.put(ctx, new ArrayExpr(ExpressionType.STRING_ARRAY, values));
}
@Override
public void exitExplicitStringArray(ExprParser.ExplicitStringArrayContext ctx)
{
- String[] values = new String[ctx.literalElement().size()];
+ Object[] values = new Object[ctx.literalElement().size()];
for (int i = 0; i < values.length; i++) {
if (ctx.literalElement(i).NULL() != null) {
values[i] = null;
@@ -472,7 +510,7 @@ public class ExprListenerImpl extends ExprBaseListener
throw new RE("Failed to parse array element %s as a string", ctx.literalElement(i).getText());
}
}
- nodes.put(ctx, new StringArrayExpr(values));
+ nodes.put(ctx, new ArrayExpr(ExpressionType.STRING_ARRAY, values));
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java
index 8ff0d3d..94d9729 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java
@@ -94,9 +94,17 @@ public class ExprMacroTable
}
/**
+ * Stub interface that lets {@link Parser#flatten(Expr)} recognize macro function exprs that implement it.
+ */
+ public interface ExprMacroFunctionExpr extends Expr
+ {
+ List<Expr> getArgs(); // the argument expressions of this macro function call
+ }
+
+ /**
* Base class for single argument {@link ExprMacro} function {@link Expr}
*/
- public abstract static class BaseScalarUnivariateMacroFunctionExpr implements Expr
+ public abstract static class BaseScalarUnivariateMacroFunctionExpr implements ExprMacroFunctionExpr
{
protected final String name;
protected final Expr arg;
@@ -112,6 +120,12 @@ public class ExprMacroTable
}
@Override
+ public List<Expr> getArgs()
+ {
+ return Collections.singletonList(arg); // single argument exposed as a one-element list
+ }
+
+ @Override
public BindingAnalysis analyzeInputs()
{
return analyzeInputsSupplier.get();
@@ -147,12 +161,19 @@ public class ExprMacroTable
{
return arg.analyzeInputs().withScalarArguments(ImmutableSet.of(arg));
}
+
+
+ @Override
+ public String toString()
+ {
+ return StringUtils.format("(%s %s)", name, getArgs()); // e.g. "(name [arg])"
+ }
}
/**
* Base class for multi-argument {@link ExprMacro} function {@link Expr}
*/
- public abstract static class BaseScalarMacroFunctionExpr implements Expr
+ public abstract static class BaseScalarMacroFunctionExpr implements ExprMacroFunctionExpr
{
protected final String name;
protected final List<Expr> args;
@@ -168,6 +189,12 @@ public class ExprMacroTable
}
@Override
+ public List<Expr> getArgs()
+ {
+ return args;
+ }
+
+ @Override
public String stringify()
{
return StringUtils.format(
@@ -213,5 +240,11 @@ public class ExprMacroTable
}
return accumulator.withScalarArguments(argSet);
}
+
+ @Override
+ public String toString()
+ {
+ return StringUtils.format("(%s %s)", name, getArgs()); // debug form: name followed by the argument list
+ }
}
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java
index f0e2576..022c4a4 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java
@@ -19,8 +19,6 @@
package org.apache.druid.math.expr;
-import it.unimi.dsi.fastutil.bytes.Byte2ObjectArrayMap;
-import it.unimi.dsi.fastutil.bytes.Byte2ObjectMap;
import org.apache.druid.segment.column.TypeDescriptor;
/**
@@ -28,31 +26,11 @@ import org.apache.druid.segment.column.TypeDescriptor;
*/
public enum ExprType implements TypeDescriptor
{
- DOUBLE((byte) 0x01),
- LONG((byte) 0x02),
- STRING((byte) 0x03),
- ARRAY((byte) 0x04),
- COMPLEX((byte) 0x05);
-
- private static final Byte2ObjectMap<ExprType> TYPE_BYTES = new Byte2ObjectArrayMap<>(ExprType.values().length);
-
- static {
- for (ExprType type : ExprType.values()) {
- TYPE_BYTES.put(type.getId(), type);
- }
- }
-
- final byte id;
-
- ExprType(byte id)
- {
- this.id = id;
- }
-
- public byte getId()
- {
- return id;
- }
+ DOUBLE, // NOTE(review): explicit byte ids and fromByte() were removed; avoid ordinal() if a stable id is needed
+ LONG,
+ STRING,
+ ARRAY,
+ COMPLEX;
@Override
public boolean isNumeric()
@@ -71,9 +49,4 @@ public enum ExprType implements TypeDescriptor
{
return this == ExprType.ARRAY;
}
-
- public static ExprType fromByte(byte id)
- {
- return TYPE_BYTES.get(id);
- }
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java
new file mode 100644
index 0000000..c1b7d7b
--- /dev/null
+++ b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessing.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.math.expr;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.inject.Inject;
+
+import javax.annotation.Nullable;
+
+/**
+ * Like {@link org.apache.druid.common.config.NullHandling}, but for expression processing configuration
+ */
+public class ExpressionProcessing
+{
+ /**
+ * INSTANCE is injected using static injection to avoid adding JacksonInject annotations all over the code.
+ * It is not injected in unit tests that do not use Guice; use {@link #initializeForTests} when modules are
+ * not available.
+ *
+ * @see ExpressionProcessingModule
+ */
+ @Inject
+ private static ExpressionProcessingConfig INSTANCE;
+
+
+ /**
+ * Many unit tests do not set up modules for this value to be injected; this method provides a manual way to
+ * initialize {@link #INSTANCE}.
+ * @param allowNestedArrays whether nested array types are allowed; if null, the system property default is used
+ */
+ @VisibleForTesting
+ public static void initializeForTests(@Nullable Boolean allowNestedArrays)
+ {
+ INSTANCE = new ExpressionProcessingConfig(allowNestedArrays);
+ }
+
+ /**
+ * Whether nested array types (arrays of arrays) are allowed; throws IllegalStateException if not initialized.
+ */
+ public static boolean allowNestedArrays()
+ {
+ // this should only be null in a unit test context
+ // in production this will be injected by the expression processing module
+ if (INSTANCE == null) {
+ throw new IllegalStateException(
+ "Expressions module not initialized, call ExpressionProcessing.initializeForTests()"
+ );
+ }
+ return INSTANCE.allowNestedArrays();
+ }
+}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java
similarity index 53%
copy from processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
copy to core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java
index 2525e09..8dc5b84 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingConfig.java
@@ -17,25 +17,30 @@
* under the License.
*/
-package org.apache.druid.segment.virtual;
+package org.apache.druid.math.expr;
-import org.apache.druid.math.expr.Expr;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
import javax.annotation.Nullable;
-public class SingleInputBindings implements Expr.ObjectBinding
+public class ExpressionProcessingConfig // bound from the "druid.expressions" property prefix
{
- @Nullable
- private Object value;
+ public static final String NESTED_ARRAYS_CONFIG_STRING = "druid.expressions.allowNestedArrays"; // system property fallback
- @Override
- public Object get(final String name)
+ @JsonProperty("allowNestedArrays")
+ private final boolean allowNestedArrays;
+
+ @JsonCreator
+ public ExpressionProcessingConfig(@JsonProperty("allowNestedArrays") @Nullable Boolean allowNestedArrays)
{
- return value;
+ this.allowNestedArrays = allowNestedArrays == null // unset: read the system property, defaulting to false
+ ? Boolean.parseBoolean(System.getProperty(NESTED_ARRAYS_CONFIG_STRING, "false"))
+ : allowNestedArrays;
}
- public void set(@Nullable final Object value)
+ public boolean allowNestedArrays()
{
- this.value = value;
+ return allowNestedArrays;
}
}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingModule.java
similarity index 67%
copy from processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
copy to core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingModule.java
index 2525e09..d6d197f 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExpressionProcessingModule.java
@@ -17,25 +17,18 @@
* under the License.
*/
-package org.apache.druid.segment.virtual;
+package org.apache.druid.math.expr;
-import org.apache.druid.math.expr.Expr;
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import org.apache.druid.guice.JsonConfigProvider;
-import javax.annotation.Nullable;
-
-public class SingleInputBindings implements Expr.ObjectBinding
+public class ExpressionProcessingModule implements Module
{
- @Nullable
- private Object value;
-
@Override
- public Object get(final String name)
- {
- return value;
- }
-
- public void set(@Nullable final Object value)
+ public void configure(Binder binder)
{
- this.value = value;
+ JsonConfigProvider.bind(binder, "druid.expressions", ExpressionProcessingConfig.class); // config read from the "druid.expressions" prefix
+ binder.requestStaticInjection(ExpressionProcessing.class); // fills ExpressionProcessing's static @Inject INSTANCE field
}
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/ExpressionType.java b/core/src/main/java/org/apache/druid/math/expr/ExpressionType.java
index 635aca0..adf4806 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExpressionType.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExpressionType.java
@@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
+import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.BaseTypeSignature;
import org.apache.druid.segment.column.ColumnType;
@@ -51,6 +52,8 @@ public class ExpressionType extends BaseTypeSignature<ExprType>
new ExpressionType(ExprType.ARRAY, null, LONG);
public static final ExpressionType DOUBLE_ARRAY =
new ExpressionType(ExprType.ARRAY, null, DOUBLE);
+ public static final ExpressionType UNKNOWN_COMPLEX = // COMPLEX type with no complex type name
+ new ExpressionType(ExprType.COMPLEX, null, null);
@JsonCreator
public ExpressionType(
@@ -205,4 +208,11 @@ public class ExpressionType extends BaseTypeSignature<ExprType>
throw new ISE("Unsupported expression type[%s]", exprType);
}
}
+
+ public static void checkNestedArrayAllowed(ExpressionType outputType) // throws IAE for array-of-array types unless enabled
+ {
+ if (outputType.isArray() && outputType.getElementType().isArray() && !ExpressionProcessing.allowNestedArrays()) {
+ throw new IAE("Cannot create a nested array type [%s], 'druid.expressions.allowNestedArrays' must be set to true", outputType);
+ }
+ }
diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java
index e73749e..2d45567 100644
--- a/core/src/main/java/org/apache/druid/math/expr/Function.java
+++ b/core/src/main/java/org/apache/druid/math/expr/Function.java
@@ -24,7 +24,6 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.IAE;
-import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.vector.CastToTypeVectorProcessor;
@@ -32,6 +31,8 @@ import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.VectorMathProcessors;
import org.apache.druid.math.expr.vector.VectorProcessors;
import org.apache.druid.math.expr.vector.VectorStringProcessors;
+import org.apache.druid.segment.column.ObjectByteStrategy;
+import org.apache.druid.segment.column.Types;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.format.DateTimeFormat;
@@ -39,6 +40,7 @@ import org.joda.time.format.DateTimeFormat;
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -501,26 +503,12 @@ public interface Function
@Override
ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
{
- switch (arrayExpr.elementType().getType()) {
- case STRING:
- return ExprEval.ofStringArray(add(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new));
- case LONG:
- return ExprEval.ofLongArray(
- add(
- arrayExpr.asLongArray(),
- scalarExpr.isNumericNull() ? null : scalarExpr.asLong()
- ).toArray(Long[]::new)
- );
- case DOUBLE:
- return ExprEval.ofDoubleArray(
- add(
- arrayExpr.asDoubleArray(),
- scalarExpr.isNumericNull() ? null : scalarExpr.asDouble()
- ).toArray(Double[]::new)
- );
+ if (!scalarExpr.type().equals(arrayExpr.elementType())) {
+ // scalar type differs from the array element type: cast it before adding
+ ExprEval coerced = scalarExpr.castTo(arrayExpr.elementType());
+ return ExprEval.ofArray(arrayExpr.asArrayType(), add(arrayExpr.asArray(), coerced.value()).toArray());
}
-
- throw new RE("Unable to add to unknown array type %s", arrayExpr.type());
+ return ExprEval.ofArray(arrayExpr.asArrayType(), add(arrayExpr.asArray(), scalarExpr.value()).toArray());
}
abstract <T> Stream<T> add(T[] array, @Nullable T val);
@@ -564,21 +552,13 @@ public interface Function
return lhsExpr;
}
- switch (lhsExpr.elementType().getType()) {
- case STRING:
- return ExprEval.ofStringArray(
- merge(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new)
- );
- case LONG:
- return ExprEval.ofLongArray(
- merge(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new)
- );
- case DOUBLE:
- return ExprEval.ofDoubleArray(
- merge(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new)
- );
+ if (!lhsExpr.asArrayType().equals(rhsExpr.asArrayType())) {
+ // array types differ: cast rhs to the lhs array type and merge the coerced values
+ ExprEval coerced = rhsExpr.castTo(lhsExpr.asArrayType());
+ return ExprEval.ofArray(lhsExpr.asArrayType(), merge(lhsExpr.asArray(), coerced.asArray()).toArray());
}
- throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type());
+
+ return ExprEval.ofArray(lhsExpr.asArrayType(), merge(lhsExpr.asArray(), rhsExpr.asArray()).toArray());
}
abstract <T> Stream<T> merge(T[] array1, T[] array2);
@@ -2925,71 +2905,17 @@ public interface Function
// this is copied from 'BaseMapFunction.applyMap', need to find a better way to consolidate, or construct arrays,
// or.. something...
final int length = args.size();
- String[] stringsOut = null;
- Long[] longsOut = null;
- Double[] doublesOut = null;
+ Object[] out = new Object[length];
- ExpressionType elementType = null;
+ ExpressionType arrayType = null;
for (int i = 0; i < length; i++) {
ExprEval<?> evaluated = args.get(i).eval(bindings);
- if (elementType == null) {
- elementType = evaluated.type();
- switch (elementType.getType()) {
- case STRING:
- stringsOut = new String[length];
- break;
- case LONG:
- longsOut = new Long[length];
- break;
- case DOUBLE:
- doublesOut = new Double[length];
- break;
- default:
- throw new RE("Unhandled array constructor element type [%s]", elementType);
- }
- }
-
- setArrayOutputElement(stringsOut, longsOut, doublesOut, elementType, i, evaluated);
+ arrayType = setArrayOutput(arrayType, out, i, evaluated); // array type is inferred from the first element
}
- // There should be always at least one argument and thus elementType is never null.
- // See validateArguments().
- //noinspection ConstantConditions
- switch (elementType.getType()) {
- case STRING:
- return ExprEval.ofStringArray(stringsOut);
- case LONG:
- return ExprEval.ofLongArray(longsOut);
- case DOUBLE:
- return ExprEval.ofDoubleArray(doublesOut);
- default:
- throw new RE("Unhandled array constructor element type [%s]", elementType);
- }
+ return ExprEval.ofArray(arrayType, out); // assumes args is non-empty so arrayType was set
}
- static void setArrayOutputElement(
- String[] stringsOut,
- Long[] longsOut,
- Double[] doublesOut,
- ExpressionType elementType,
- int i,
- ExprEval evaluated
- )
- {
- switch (elementType.getType()) {
- case STRING:
- stringsOut[i] = evaluated.asString();
- break;
- case LONG:
- longsOut[i] = evaluated.isNumericNull() ? null : evaluated.asLong();
- break;
- case DOUBLE:
- doublesOut[i] = evaluated.isNumericNull() ? null : evaluated.asDouble();
- break;
- }
- }
-
-
@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
@@ -3026,6 +2952,29 @@ public interface Function
}
return ExpressionType.asArrayType(type);
}
+
+ /**
+ * Set an array element to the output array, checking for null if the array is numeric. If the type of the evaluated
+ * array element does not match the array element type, this method will attempt to call {@link ExprEval#castTo}
+ * to the array element type, else will set the element as is. If the type of the array is unknown, it will be
+ * detected and defined from the first element. Returns the type of the array, which will be identical to the input
+ * type, unless the input type was null.
+ */
+ static ExpressionType setArrayOutput(@Nullable ExpressionType arrayType, Object[] out, int i, ExprEval evaluated)
+ {
+ if (arrayType == null) {
+ arrayType = ExpressionTypeFactory.getInstance().ofArray(evaluated.type());
+ }
+ ExpressionType.checkNestedArrayAllowed(arrayType); // rejects array-of-array unless nested arrays are enabled
+ if (arrayType.getElementType().isNumeric() && evaluated.isNumericNull()) {
+ out[i] = null;
+ } else if (!evaluated.asArrayType().equals(arrayType)) {
+ out[i] = evaluated.castTo((ExpressionType) arrayType.getElementType()).value();
+ } else {
+ out[i] = evaluated.value();
+ }
+ return arrayType;
+ }
}
class ArrayLengthFunction implements Function
@@ -3186,7 +3135,7 @@ public interface Function
final int position = scalarExpr.asInt();
if (array.length > position) {
- return ExprEval.bestEffortOf(array[position]);
+ return ExprEval.ofType(arrayExpr.elementType(), array[position]); // NOTE(review): negative position is not guarded above and would throw
}
return ExprEval.of(null);
}
@@ -3214,7 +3163,7 @@ public interface Function
final int position = scalarExpr.asInt() - 1;
if (array.length > position) {
- return ExprEval.bestEffortOf(array[position]);
+ return ExprEval.ofType(arrayExpr.elementType(), array[position]); // NOTE(review): ordinal <= 0 gives a negative position, not guarded above
}
return ExprEval.of(null);
}
@@ -3521,15 +3470,7 @@ public interface Function
return ExprEval.of(null);
}
- switch (expr.elementType().getType()) {
- case STRING:
- return ExprEval.ofStringArray(Arrays.copyOfRange(expr.asStringArray(), start, end));
- case LONG:
- return ExprEval.ofLongArray(Arrays.copyOfRange(expr.asLongArray(), start, end));
- case DOUBLE:
- return ExprEval.ofDoubleArray(Arrays.copyOfRange(expr.asDoubleArray(), start, end));
- }
- throw new RE("Unable to slice to unknown type %s", expr.type());
+ return ExprEval.ofArray(expr.asArrayType(), Arrays.copyOfRange(expr.asArray(), start, end)); // slice keeps the source array type
}
}
@@ -3631,4 +3572,78 @@ public interface Function
return HumanReadableBytes.UnitSystem.DECIMAL;
}
}
+
+ class ComplexDecodeBase64Function implements Function // complex_decode_base64(typeName, base64String)
+ {
+ @Override
+ public String name()
+ {
+ return "complex_decode_base64";
+ }
+
+ @Override
+ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+ {
+ ExprEval arg0 = args.get(0).eval(bindings);
+ if (!arg0.type().is(ExprType.STRING)) {
+ throw new IAE(
+ "Function[%s] first argument must be constant 'STRING' expression containing a valid complex type name",
+ name()
+ );
+ }
+ ExpressionType complexType = ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
+ ObjectByteStrategy strategy = Types.getStrategy(complexType.getComplexTypeName());
+ if (strategy == null) {
+ throw new IAE(
+ "Function[%s] first argument must be a valid complex type name, unknown complex type [%s]",
+ name(),
+ complexType.asTypeString()
+ );
+ }
+ ExprEval base64String = args.get(1).eval(bindings);
+ if (!base64String.type().is(ExprType.STRING)) {
+ throw new IAE(
+ "Function[%s] second argument must be a base64 encoded 'STRING' value",
+ name()
+ );
+ }
+ if (base64String.value() == null) { // null input yields a typed null complex value
+ return ExprEval.ofComplex(complexType, null);
+ }
+
+ final byte[] base64 = StringUtils.decodeBase64String(base64String.asString()); // NOTE(review): malformed base64 is not caught here
+ return ExprEval.ofComplex(complexType, strategy.fromByteBuffer(ByteBuffer.wrap(base64), base64.length));
+ }
+
+ @Override
+ public void validateArguments(List<Expr> args)
+ {
+ if (args.size() != 2) {
+ throw new IAE("Function[%s] needs 2 arguments", name());
+ }
+ if (!args.get(0).isLiteral() || args.get(0).isNullLiteral()) {
+ throw new IAE(
+ "Function[%s] first argument must be constant 'STRING' expression containing a valid complex type name",
+ name()
+ );
+ }
+ }
+
+ @Nullable
+ @Override
+ public ExpressionType getOutputType(
+ Expr.InputBindingInspector inspector,
+ List<Expr> args
+ )
+ {
+ ExpressionType arg0Type = args.get(0).getOutputType(inspector);
+ if (arg0Type == null || !arg0Type.is(ExprType.STRING)) {
+ throw new IAE(
+ "Function[%s] first argument must be constant 'STRING' expression containing a valid complex type name",
+ name()
+ );
+ }
+ return ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue()); // type named by the literal first argument
+ }
+ }
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java
index 4abfd55..49e9c96 100644
--- a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java
+++ b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java
@@ -122,7 +122,7 @@ class IdentifierExpr implements Expr
@Override
public ExprEval eval(ObjectBinding bindings)
{
- return ExprEval.bestEffortOf(bindings.get(binding));
+ return ExprEval.ofType(bindings.getType(binding), bindings.get(binding)); // use the binding's reported type instead of guessing from the value
}
@Override
diff --git a/core/src/main/java/org/apache/druid/math/expr/InputBindings.java b/core/src/main/java/org/apache/druid/math/expr/InputBindings.java
index 5174603..7b2639d 100644
--- a/core/src/main/java/org/apache/druid/math/expr/InputBindings.java
+++ b/core/src/main/java/org/apache/druid/math/expr/InputBindings.java
@@ -20,12 +20,36 @@
package org.apache.druid.math.expr;
import com.google.common.base.Supplier;
+import org.apache.druid.java.util.common.Pair;
import javax.annotation.Nullable;
import java.util.Map;
+import java.util.function.Function;
public class InputBindings
{
+  /**
+   * Shared, stateless binding in which every identifier resolves to a null value of unknown (null) type.
+   */
+  private static final Expr.ObjectBinding NIL_BINDINGS = new Expr.ObjectBinding()
+  {
+    @Nullable
+    @Override
+    public ExpressionType getType(String name)
+    {
+      return null;
+    }
+
+    @Nullable
+    @Override
+    public Object get(String name)
+    {
+      return null;
+    }
+  };
+
+  /**
+   * Returns an {@link Expr.ObjectBinding} that binds nothing: every identifier has a null value and a null type.
+   * Suitable for evaluating expressions that have no free identifiers, such as constant folding.
+   */
+  public static Expr.ObjectBinding nilBindings()
+  {
+    return NIL_BINDINGS;
+  }
+
/**
* Create an {@link Expr.InputBindingInspector} backed by a map of binding identifiers to their {@link ExprType}
*/
@@ -42,23 +66,95 @@ public class InputBindings
};
}
+  /**
+   * Create an {@link Expr.ObjectBinding} whose values come from {@code valueFn}. The same {@code type} is reported
+   * for every identifier, whether or not {@code valueFn} produces a value for it.
+   */
+  public static Expr.ObjectBinding singleProvider(final ExpressionType type, final Function<String, ?> valueFn)
+  {
+    return new Expr.ObjectBinding()
+    {
+      @Nullable
+      @Override
+      public ExpressionType getType(String name)
+      {
+        // all identifiers share the one supplied type
+        return type;
+      }
+
+      @Nullable
+      @Override
+      public Object get(String name)
+      {
+        return valueFn.apply(name);
+      }
+    };
+  }
+
+  /**
+   * Create an {@link Expr.ObjectBinding} backed by an arbitrary lookup function. No types are declared, so the type
+   * of an identifier is inferred from the value {@code valueFn} returns for it.
+   */
+  public static Expr.ObjectBinding forFunction(final Function<String, ?> valueFn)
+  {
+    return new Expr.ObjectBinding()
+    {
+      @Nullable
+      @Override
+      public ExpressionType getType(String name)
+      {
+        // valueFn is applied again here, independently of get()
+        final Object value = valueFn.apply(name);
+        return ExprEval.bestEffortOf(value).type();
+      }
+
+      @Nullable
+      @Override
+      public Object get(String name)
+      {
+        return valueFn.apply(name);
+      }
+    };
+  }
+
/**
* Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr}
+   *
+   * Values are read with {@link Map#get}. No types are declared, so the type of an identifier is inferred from its
+   * current value; this simply delegates to {@link #forFunction}.
*/
public static Expr.ObjectBinding withMap(final Map<String, ?> bindings)
{
- return bindings::get;
+    // identical semantics to forFunction, so delegate instead of duplicating the anonymous ObjectBinding
+    return forFunction(bindings::get);
}
/**
* Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate
* {@link Expr}
+   *
+   * Each entry pairs the identifier's declared {@link ExpressionType} ({@code lhs}) with a {@link Supplier} of its
+   * value ({@code rhs}). Identifiers missing from the map have a null value and a null type. An entry whose supplier
+   * is null has a null value but still reports its declared type.
*/
- public static Expr.ObjectBinding withSuppliers(final Map<String, Supplier<Object>> bindings)
+ public static Expr.ObjectBinding withTypedSuppliers(final Map<String, Pair<ExpressionType, Supplier<Object>>> bindings)
{
- return (String name) -> {
- Supplier<Object> supplier = bindings.get(name);
- return supplier == null ? null : supplier.get();
+ return new Expr.ObjectBinding()
+ {
+ @Nullable
+ @Override
+ public Object get(String name)
+ {
+ Pair<ExpressionType, Supplier<Object>> binding = bindings.get(name);
+ return binding == null || binding.rhs == null ? null : binding.rhs.get();
+ }
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ Pair<ExpressionType, Supplier<Object>> binding = bindings.get(name);
+ if (binding == null) {
+ return null;
+ }
+ return binding.lhs;
+ }
};
}
}
diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java
index 0a42bfa..b338043 100644
--- a/core/src/main/java/org/apache/druid/math/expr/Parser.java
+++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java
@@ -153,28 +153,33 @@ public class Parser
if (childExpr instanceof BinaryOpExprBase) {
BinaryOpExprBase binary = (BinaryOpExprBase) childExpr;
if (Evals.isAllConstants(binary.left, binary.right)) {
- return childExpr.eval(null).toExpr();
+ return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
} else if (childExpr instanceof UnaryExpr) {
UnaryExpr unary = (UnaryExpr) childExpr;
if (unary.expr instanceof ConstantExpr) {
- return childExpr.eval(null).toExpr();
+ return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
} else if (childExpr instanceof FunctionExpr) {
FunctionExpr functionExpr = (FunctionExpr) childExpr;
List<Expr> args = functionExpr.args;
if (Evals.isAllConstants(args)) {
- return childExpr.eval(null).toExpr();
+ return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
} else if (childExpr instanceof ApplyFunctionExpr) {
ApplyFunctionExpr applyFunctionExpr = (ApplyFunctionExpr) childExpr;
List<Expr> args = applyFunctionExpr.argsExpr;
if (Evals.isAllConstants(args)) {
if (applyFunctionExpr.analyzeInputs().getFreeVariables().size() == 0) {
- return childExpr.eval(null).toExpr();
+ return childExpr.eval(InputBindings.nilBindings()).toExpr();
}
}
+ } else if (childExpr instanceof ExprMacroTable.ExprMacroFunctionExpr) {
+ ExprMacroTable.ExprMacroFunctionExpr macroFn = (ExprMacroTable.ExprMacroFunctionExpr) childExpr;
+ if (Evals.isAllConstants(macroFn.getArgs())) {
+ return childExpr.eval(InputBindings.nilBindings()).toExpr();
+ }
}
return childExpr;
});
diff --git a/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java b/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java
index 600e1e2..d97a23d 100644
--- a/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java
+++ b/core/src/main/java/org/apache/druid/math/expr/SettableObjectBinding.java
@@ -32,6 +32,7 @@ import java.util.Map;
public class SettableObjectBinding implements Expr.ObjectBinding
{
private final Map<String, Object> bindings;
+ // source of types for getType(); defaults to nil bindings, which report a null type for every name
+ private Expr.InputBindingInspector inspector = InputBindings.nilBindings();
public SettableObjectBinding()
{
@@ -50,12 +51,25 @@ public class SettableObjectBinding implements Expr.ObjectBinding
return bindings.get(name);
}
+ /**
+  * Types are delegated to the configured {@link Expr.InputBindingInspector}. Until {@link #withInspector} is called
+  * this is {@link InputBindings#nilBindings()}, which reports a null type for every name.
+  */
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ return inspector.getType(name);
+ }
+
public SettableObjectBinding withBinding(String name, @Nullable Object value)
{
bindings.put(name, value);
return this;
}
+ /**
+  * Sets the inspector used by {@link #getType} and returns this binding for chaining.
+  * NOTE(review): a null inspector is not rejected here and would make getType throw - confirm callers never pass null.
+  */
+ public SettableObjectBinding withInspector(Expr.InputBindingInspector inspector)
+ {
+ this.inspector = inspector;
+ return this;
+ }
+
@VisibleForTesting
public Map<String, Object> asMap()
{
diff --git a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java b/core/src/main/java/org/apache/druid/segment/column/ObjectByteStrategy.java
similarity index 66%
copy from processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
copy to core/src/main/java/org/apache/druid/segment/column/ObjectByteStrategy.java
index ed4410b..d47432b 100644
--- a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
+++ b/core/src/main/java/org/apache/druid/segment/column/ObjectByteStrategy.java
@@ -17,18 +17,23 @@
* under the License.
*/
-package org.apache.druid.segment.data;
-
-import org.apache.druid.guice.annotations.ExtensionPoint;
-import org.apache.druid.segment.writeout.WriteOutBytes;
+package org.apache.druid.segment.column;
import javax.annotation.Nullable;
-import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Comparator;
-@ExtensionPoint
-public interface ObjectStrategy<T> extends Comparator<T>
+/**
+ * Core interface extracted from the 'ObjectStrategy' interface that lives in 'druid-processing' (hence the similar
+ * name). It provides basic methods for converting objects of some type to a binary form, reading that binary form
+ * back into an object from a {@link ByteBuffer}, and comparing objects.
+ *
+ * Complex types register an implementation with {@link Types#registerStrategy}; it can then be retrieved by the
+ * complex type name to convert values to and from their binary format, and to compare them.
+ *
+ * This could be merged back into 'ObjectStrategy' if the two modules are ever combined.
+ */
+public interface ObjectByteStrategy<T> extends Comparator<T>
{
Class<? extends T> getClazz();
@@ -50,25 +55,4 @@ public interface ObjectStrategy<T> extends Comparator<T>
@Nullable
byte[] toBytes(@Nullable T val);
-
- /**
- * Reads 4-bytes numBytes from the given buffer, and then delegates to {@link #fromByteBuffer(ByteBuffer, int)}.
- */
- default T fromByteBufferWithSize(ByteBuffer buffer)
- {
- int size = buffer.getInt();
- ByteBuffer bufferToUse = buffer.asReadOnlyBuffer();
- bufferToUse.limit(bufferToUse.position() + size);
- buffer.position(bufferToUse.limit());
-
- return fromByteBuffer(bufferToUse, size);
- }
-
- default void writeTo(T val, WriteOutBytes out) throws IOException
- {
- byte[] bytes = toBytes(val);
- if (bytes != null) {
- out.write(bytes);
- }
- }
}
diff --git a/core/src/main/java/org/apache/druid/segment/column/Types.java b/core/src/main/java/org/apache/druid/segment/column/Types.java
index 44c7b13..d8c1d21 100644
--- a/core/src/main/java/org/apache/druid/segment/column/Types.java
+++ b/core/src/main/java/org/apache/druid/segment/column/Types.java
@@ -20,14 +20,23 @@
package org.apache.druid.segment.column;
import com.google.common.base.Preconditions;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+import java.util.concurrent.ConcurrentHashMap;
public class Types
{
private static final String ARRAY_PREFIX = "ARRAY<";
private static final String COMPLEX_PREFIX = "COMPLEX<";
+ // 'nullable' values are stored after a single leading null byte: | null (byte) | value |
+ private static final int VALUE_OFFSET = Byte.BYTES;
+ // total serialized sizes (null byte + value) of non-null nullable primitives
+ private static final int NULLABLE_LONG_SIZE = Byte.BYTES + Long.BYTES;
+ private static final int NULLABLE_DOUBLE_SIZE = Byte.BYTES + Double.BYTES;
+ private static final int NULLABLE_FLOAT_SIZE = Byte.BYTES + Float.BYTES;
+ // registry of complex type name -> ObjectByteStrategy, populated by registerStrategy and read by getStrategy
+ private static final ConcurrentHashMap<String, ObjectByteStrategy<?>> STRATEGIES = new ConcurrentHashMap<>();
/**
* Create a {@link TypeSignature} given the value of {@link TypeSignature#asTypeString()} and a {@link TypeFactory}
@@ -112,4 +121,610 @@ public class Types
return (typeSignature1 != null && typeSignature1.is(typeDescriptor)) ||
(typeSignature2 != null && typeSignature2.is(typeDescriptor));
}
+
+  /**
+   * Looks up the {@link ObjectByteStrategy} registered for a complex type name (a value of
+   * {@link TypeSignature#getComplexTypeName()}). Returns null if nothing has been registered under that name.
+   */
+  @Nullable
+  public static ObjectByteStrategy<?> getStrategy(String type)
+  {
+    final ObjectByteStrategy<?> registered = STRATEGIES.get(type);
+    return registered;
+  }
+
+  /**
+   * hmm... this might look familiar... (see ComplexMetrics)
+   *
+   * Register a complex type name -> {@link ObjectByteStrategy} mapping.
+   *
+   * Registration is idempotent per strategy class. If the name is already mapped to a strategy of the same class, the
+   * existing instance is kept. If it is mapped to a strategy of a different class, an {@link ISE} is thrown.
+   *
+   * @param typeName complex type name, as returned by {@link TypeSignature#getComplexTypeName()}
+   * @param strategy the {@link ObjectByteStrategy} to associate with {@code typeName}
+   *
+   * @throws NullPointerException if {@code typeName} or {@code strategy} is null
+   * @throws ISE                  if a strategy of a different class is already registered for {@code typeName}
+   */
+  public static void registerStrategy(String typeName, ObjectByteStrategy<?> strategy)
+  {
+    Preconditions.checkNotNull(typeName, "typeName");
+    // without this check a null strategy is silently dropped when nothing is registered (compute() treats a null
+    // result as "no mapping"). If something is already registered, it fails with an opaque NPE in the comparison below.
+    Preconditions.checkNotNull(strategy, "strategy");
+    STRATEGIES.compute(typeName, (key, existing) -> {
+      if (existing == null) {
+        return strategy;
+      }
+      // strategies are compared by class name, not by instance
+      if (!existing.getClass().getName().equals(strategy.getClass().getName())) {
+        throw new ISE(
+            "Incompatible strategy for type[%s] already exists. Expected [%s], found [%s].",
+            key,
+            strategy.getClass().getName(),
+            existing.getClass().getName()
+        );
+      }
+      return existing;
+    });
+  }
+
+  /**
+   * Marks the 'nullable' value at {@code offset} as null by writing {@link NullHandling#IS_NULL_BYTE} there. This
+   * overwrites whatever byte was present, including any flag bits.
+   *
+   * Nullable types are stored as a leading null byte followed, when not null, by the value bytes.
+   *
+   * layout: | null (byte) | value |
+   *
+   * Only an absolute put is used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   *
+   * @return the number of bytes written, which is always 1
+   */
+  public static int writeNull(ByteBuffer buffer, int offset)
+  {
+    final int bytesWritten = Byte.BYTES;
+    buffer.put(offset, NullHandling.IS_NULL_BYTE);
+    return bytesWritten;
+  }
+
+  /**
+   * Returns true if the 'nullable' value at {@code offset} is null, i.e. its null byte has the
+   * {@link NullHandling#IS_NULL_BYTE} bit(s) set. Only those bits are tested, so the remaining higher bits of the null
+   * byte can carry caller-defined flags (e.g. an aggregator's "value has been set" marker).
+   *
+   * Every nullable writer in {@link Types} overwrites the whole null byte with either
+   * {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE}, discarding any such flags.
+   *
+   * layout: | null (byte) | value |
+   */
+  public static boolean isNullableNull(ByteBuffer buffer, int offset)
+  {
+    // mask instead of comparing the whole byte, so callers may pack extra information into the other bits
+    final int nullBits = buffer.get(offset) & NullHandling.IS_NULL_BYTE;
+    return nullBits == NullHandling.IS_NULL_BYTE;
+  }
+
+  /**
+   * Writes a non-null long as a 'nullable' value at {@code offset}. The null byte is set to
+   * {@link NullHandling#IS_NOT_NULL_BYTE}, and the long occupies the following 8 bytes.
+   *
+   * layout: | null (byte) | long |
+   *
+   * Only absolute puts are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   *
+   * @return the number of bytes written, which is always 9
+   */
+  public static int writeNullableLong(ByteBuffer buffer, int offset, long value)
+  {
+    buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
+    buffer.putLong(offset + VALUE_OFFSET, value);
+    return NULLABLE_LONG_SIZE;
+  }
+
+  /**
+   * Reads the non-null long stored as a 'nullable' value at {@code offset}. Callers must first confirm with
+   * {@link #isNullableNull} that this value is not null; with assertions enabled this is verified.
+   *
+   * layout: | null (byte) | long |
+   *
+   * Only absolute reads are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   */
+  public static long readNullableLong(ByteBuffer buffer, int offset)
+  {
+    assert !isNullableNull(buffer, offset);
+    final int valuePosition = offset + VALUE_OFFSET;
+    return buffer.getLong(valuePosition);
+  }
+
+  /**
+   * Writes a non-null double as a 'nullable' value at {@code offset}. The null byte is set to
+   * {@link NullHandling#IS_NOT_NULL_BYTE}, and the double occupies the following 8 bytes.
+   *
+   * layout: | null (byte) | double |
+   *
+   * Only absolute puts are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   *
+   * @return the number of bytes written, which is always 9
+   */
+  public static int writeNullableDouble(ByteBuffer buffer, int offset, double value)
+  {
+    buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
+    buffer.putDouble(offset + VALUE_OFFSET, value);
+    return NULLABLE_DOUBLE_SIZE;
+  }
+
+  /**
+   * Reads the non-null double stored as a 'nullable' value at {@code offset}. Callers must first confirm with
+   * {@link #isNullableNull} that this value is not null; with assertions enabled this is verified.
+   *
+   * layout: | null (byte) | double |
+   *
+   * Only absolute reads are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   */
+  public static double readNullableDouble(ByteBuffer buffer, int offset)
+  {
+    assert !isNullableNull(buffer, offset);
+    final int valuePosition = offset + VALUE_OFFSET;
+    return buffer.getDouble(valuePosition);
+  }
+
+  /**
+   * Writes a non-null float as a 'nullable' value at {@code offset}. The null byte is set to
+   * {@link NullHandling#IS_NOT_NULL_BYTE}, and the float occupies the following 4 bytes.
+   *
+   * layout: | null (byte) | float |
+   *
+   * Only absolute puts are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   *
+   * @return the number of bytes written, which is always 5
+   */
+  public static int writeNullableFloat(ByteBuffer buffer, int offset, float value)
+  {
+    buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
+    buffer.putFloat(offset + VALUE_OFFSET, value);
+    return NULLABLE_FLOAT_SIZE;
+  }
+
+  /**
+   * Reads the non-null float stored as a 'nullable' value at {@code offset}. Callers must first confirm with
+   * {@link #isNullableNull} that this value is not null; with assertions enabled this is verified.
+   *
+   * layout: | null (byte) | float |
+   *
+   * Only absolute reads are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   */
+  public static float readNullableFloat(ByteBuffer buffer, int offset)
+  {
+    assert !isNullableNull(buffer, offset);
+    final int valuePosition = offset + VALUE_OFFSET;
+    return buffer.getFloat(valuePosition);
+  }
+
+  /**
+   * Size-checked variant of {@link #writeNullableVariableBlob(ByteBuffer, int, byte[])}.
+   *
+   * A null {@code value} is written as a single null byte, with no size check. Otherwise the full serialized size
+   * (1 null byte, 4 length bytes, and the payload) is checked against {@code maxSizeBytes} with
+   * {@link #checkMaxBytes} before anything is written. That check throws an {@link ISE} naming {@code type} if the
+   * value is too large.
+   *
+   * layout: | null (byte) | size (int) | byte[] |
+   *
+   * The buffer must have at least {@code maxSizeBytes} available from {@code offset}. Its position and limit are the
+   * same after a successful call as before. The buffer may belong to someone else (e.g. buffer aggregators).
+   *
+   * @return number of bytes written: 1 if null, otherwise 5 + the length of {@code value}
+   */
+  public static int writeNullableVariableBlob(
+      ByteBuffer buffer,
+      int offset,
+      @Nullable byte[] value,
+      TypeSignature<?> type,
+      int maxSizeBytes
+  )
+  {
+    if (value != null) {
+      // | null (byte) | length (int) | bytes |
+      final int requiredBytes = Byte.BYTES + Integer.BYTES + value.length;
+      checkMaxBytes(type, requiredBytes, maxSizeBytes);
+      return writeNullableVariableBlob(buffer, offset, value);
+    }
+    return writeNull(buffer, offset);
+  }
+
+  /**
+   * Write a variably lengthed byte[] value to a {@link ByteBuffer} at the supplied offset.
+   *
+   * The first byte is set to {@link NullHandling#IS_NULL_BYTE} or {@link NullHandling#IS_NOT_NULL_BYTE} as
+   * appropriate. If the value is not null, its length is then written as a 4 byte int, followed by the bytes
+   * themselves.
+   *
+   * layout: | null (byte) | size (int) | byte[] |
+   *
+   * This method does not limit how many bytes are written. Either use
+   * {@link #writeNullableVariableBlob(ByteBuffer, int, byte[], TypeSignature, int)}, or first make sure that the
+   * length of the array plus 5 bytes is available in the buffer.
+   *
+   * The buffer's position is always restored before returning, including when a write fails because the buffer is
+   * too small, and its limit is never changed. The buffer may belong to someone else (e.g. buffer aggregators).
+   * Positioning the buffer at {@code offset} discards its mark if that mark lies beyond {@code offset}.
+   *
+   * @return number of bytes written: 1 if null, otherwise 5 + the length of {@code value}
+   */
+  public static int writeNullableVariableBlob(ByteBuffer buffer, int offset, @Nullable byte[] value)
+  {
+    // | null (byte) | length (int) | bytes |
+    if (value == null) {
+      return writeNull(buffer, offset);
+    }
+    // Relative puts are used: the absolute bulk put(int, byte[]) only exists since Java 16, and a duplicate() would
+    // reset the byte order to big endian and change how the length is encoded. So save and always restore the position.
+    final int oldPosition = buffer.position();
+    try {
+      buffer.position(offset);
+      buffer.put(NullHandling.IS_NOT_NULL_BYTE);
+      buffer.putInt(value.length);
+      buffer.put(value, 0, value.length);
+      return buffer.position() - offset;
+    }
+    finally {
+      buffer.position(oldPosition);
+    }
+  }
+
+  /**
+   * Reads a nullable variably lengthed byte[] value from a {@link ByteBuffer} at the supplied offset.
+   *
+   * Returns null if the null byte has the null bit set (see {@link #isNullableNull}). Otherwise it reads the 4 byte
+   * length that follows the null byte, then that many bytes of payload.
+   *
+   * layout: | null (byte) | size (int) | byte[] |
+   *
+   * The payload is read through a duplicate of the buffer, so the caller's buffer position, limit, and mark are never
+   * modified, even if the read fails part way. This matters because the buffer may not be owned by the caller
+   * (i.e. buffer aggs).
+   */
+  @Nullable
+  public static byte[] readNullableVariableBlob(ByteBuffer buffer, int offset)
+  {
+    // | null (byte) | length (int) | bytes |
+    if (isNullableNull(buffer, offset)) {
+      // previously the length was read from whatever follows a null byte, returning garbage or throwing
+      return null;
+    }
+    final int length = buffer.getInt(offset + VALUE_OFFSET);
+    final byte[] blob = new byte[length];
+    // A duplicate shares content but has its own position/limit/mark. Its byte order does not matter for a byte[] get.
+    final ByteBuffer view = buffer.duplicate();
+    view.position(offset + VALUE_OFFSET + Integer.BYTES);
+    view.get(blob, 0, length);
+    return blob;
+  }
+
+  /**
+   * Write a nullable Long[] value to a {@link ByteBuffer} at the supplied offset.
+   *
+   * A null array is written as a single {@link NullHandling#IS_NULL_BYTE}. Otherwise the layout is:
+   * - a {@link NullHandling#IS_NOT_NULL_BYTE}
+   * - the number of elements, as a 4 byte int
+   * - each element, written with {@link #writeNull} if null (1 byte) or {@link #writeNullableLong} if not (9 bytes)
+   *
+   * layout: | null (byte) | length (int) | {| null (byte) | long |, | null (byte) |, ... |null (byte) | long |} |
+   *
+   * The running size, including the 5 byte header, is checked against {@code maxSizeBytes} with
+   * {@link #checkMaxBytes} before each write. An {@link ISE} is thrown rather than writing past that limit, although
+   * anything written before the limit was hit stays in the buffer. The buffer must have at least
+   * {@code maxSizeBytes} free from {@code offset}.
+   *
+   * This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
+   * given to it (i.e. buffer aggs).
+   *
+   * @return number of bytes written: 1 if the array is null, otherwise 5 plus the bytes used by the elements
+   */
+  public static int writeNullableLongArray(ByteBuffer buffer, int offset, @Nullable Long[] array, int maxSizeBytes)
+  {
+    // | null (byte) | array length (int) | array bytes |
+    if (array == null) {
+      return writeNull(buffer, offset);
+    }
+    int sizeBytes = 1 + Integer.BYTES;
+    // the header itself must fit within the limit, not just the elements that follow it
+    checkMaxBytes(ColumnType.LONG_ARRAY, sizeBytes, maxSizeBytes);
+    buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
+    buffer.putInt(offset + 1, array.length);
+    for (Long element : array) {
+      if (element != null) {
+        checkMaxBytes(ColumnType.LONG_ARRAY, sizeBytes + 1 + Long.BYTES, maxSizeBytes);
+        sizeBytes += writeNullableLong(buffer, offset + sizeBytes, element);
+      } else {
+        checkMaxBytes(ColumnType.LONG_ARRAY, sizeBytes + 1, maxSizeBytes);
+        sizeBytes += writeNull(buffer, offset + sizeBytes);
+      }
+    }
+    return sizeBytes;
+  }
+
+  /**
+   * Reads a nullable Long[] from a {@link ByteBuffer} at the supplied offset.
+   *
+   * Returns null if the array's null byte is set. Otherwise it reads the element count from the next 4 bytes, then
+   * that many elements. Each element is either a lone null byte (a null element) or a null byte followed by an
+   * 8 byte long.
+   *
+   * layout: | null (byte) | size (int) | {| null (byte) | long |, | null (byte) |, ... |null (byte) | long |} |
+   *
+   * Only absolute reads are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   */
+  @Nullable
+  public static Long[] readNullableLongArray(ByteBuffer buffer, int offset)
+  {
+    // | null (byte) | array length (int) | array bytes |
+    if (isNullableNull(buffer, offset)) {
+      return null;
+    }
+    int position = offset + Byte.BYTES;
+    final Long[] longs = new Long[buffer.getInt(position)];
+    position += Integer.BYTES;
+    for (int i = 0; i < longs.length; i++) {
+      if (!isNullableNull(buffer, position)) {
+        longs[i] = readNullableLong(buffer, position);
+        position += NULLABLE_LONG_SIZE;
+      } else {
+        // null elements are a single null byte; the array slot is already null
+        position += Byte.BYTES;
+      }
+    }
+    return longs;
+  }
+
+  /**
+   * Write a nullable Double[] value to a {@link ByteBuffer} at the supplied offset.
+   *
+   * A null array is written as a single {@link NullHandling#IS_NULL_BYTE}. Otherwise the layout is:
+   * - a {@link NullHandling#IS_NOT_NULL_BYTE}
+   * - the number of elements, as a 4 byte int
+   * - each element, written with {@link #writeNull} if null (1 byte) or {@link #writeNullableDouble} if not (9 bytes)
+   *
+   * layout: | null (byte) | length (int) | {| null (byte) | double |, | null (byte) |, ... |null (byte) | double |} |
+   *
+   * The running size, including the 5 byte header, is checked against {@code maxSizeBytes} with
+   * {@link #checkMaxBytes} before each write. An {@link ISE} is thrown rather than writing past that limit, although
+   * anything written before the limit was hit stays in the buffer. The buffer must have at least
+   * {@code maxSizeBytes} free from {@code offset}.
+   *
+   * This method does not change the buffer position, limit, or mark, because it does not expect to own the buffer
+   * given to it (i.e. buffer aggs).
+   *
+   * @return number of bytes written: 1 if the array is null, otherwise 5 plus the bytes used by the elements
+   */
+  public static int writeNullableDoubleArray(ByteBuffer buffer, int offset, @Nullable Double[] array, int maxSizeBytes)
+  {
+    // | null (byte) | array length (int) | array bytes |
+    if (array == null) {
+      return writeNull(buffer, offset);
+    }
+    int sizeBytes = 1 + Integer.BYTES;
+    // the header itself must fit within the limit, not just the elements that follow it
+    checkMaxBytes(ColumnType.DOUBLE_ARRAY, sizeBytes, maxSizeBytes);
+    buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
+    buffer.putInt(offset + 1, array.length);
+    for (Double element : array) {
+      if (element != null) {
+        checkMaxBytes(ColumnType.DOUBLE_ARRAY, sizeBytes + 1 + Double.BYTES, maxSizeBytes);
+        sizeBytes += writeNullableDouble(buffer, offset + sizeBytes, element);
+      } else {
+        checkMaxBytes(ColumnType.DOUBLE_ARRAY, sizeBytes + 1, maxSizeBytes);
+        sizeBytes += writeNull(buffer, offset + sizeBytes);
+      }
+    }
+    return sizeBytes;
+  }
+
+  /**
+   * Reads a nullable Double[] from a {@link ByteBuffer} at the supplied offset.
+   *
+   * Returns null if the array's null byte is set. Otherwise it reads the element count from the next 4 bytes, then
+   * that many elements. Each element is either a lone null byte (a null element) or a null byte followed by an
+   * 8 byte double.
+   *
+   * layout: | null (byte) | size (int) | {| null (byte) | double |, | null (byte) |, ... |null (byte) | double |} |
+   *
+   * Only absolute reads are used, so the buffer's position, limit and mark are untouched. The buffer may belong to
+   * someone else (e.g. buffer aggregators).
+   */
+  @Nullable
+  public static Double[] readNullableDoubleArray(ByteBuffer buffer, int offset)
+  {
+    // | null (byte) | array length (int) | array bytes |
+    if (isNullableNull(buffer, offset)) {
+      return null;
+    }
+    int position = offset + Byte.BYTES;
+    final Double[] doubles = new Double[buffer.getInt(position)];
+    position += Integer.BYTES;
+    for (int i = 0; i < doubles.length; i++) {
+      if (!isNullableNull(buffer, position)) {
+        doubles[i] = readNullableDouble(buffer, position);
+        position += NULLABLE_DOUBLE_SIZE;
+      } else {
+        // null elements are a single null byte; the array slot is already null
+        position += Byte.BYTES;
+      }
+    }
+    return doubles;
+  }
+
+  /**
+   * Write a nullable String[] value to a {@link ByteBuffer} at the supplied offset.
+   *
+   * A null array is written as a single {@link NullHandling#IS_NULL_BYTE}. Otherwise the layout is:
+   * - a {@link NullHandling#IS_NOT_NULL_BYTE}
+   * - the number of elements, as a 4 byte int
+   * - each element: {@link #writeNull} if null (1 byte), otherwise its {@link StringUtils#toUtf8} bytes written with
+   *   {@link #writeNullableVariableBlob} (5 bytes plus the length of the UTF-8 bytes)
+   *
+   * layout: | null (byte) | length (int) | {| null (byte) | size (int) | byte[] |, | null (byte) |, ... } |
+   *
+   * The running size, including the 5 byte header, is checked against {@code maxSizeBytes} with
+   * {@link #checkMaxBytes} before each write. An {@link ISE} is thrown rather than writing past that limit, although
+   * anything written before the limit was hit stays in the buffer. The buffer must have at least
+   * {@code maxSizeBytes} free from {@code offset}.
+   *
+   * This method leaves the buffer position and limit as it found them, because it does not expect to own the buffer
+   * given to it (i.e. buffer aggs).
+   *
+   * @return number of bytes written: 1 if the array is null, otherwise 5 plus the bytes used by the elements
+   */
+  public static int writeNullableStringArray(ByteBuffer buffer, int offset, @Nullable String[] array, int maxSizeBytes)
+  {
+    // | null (byte) | array length (int) | array bytes |
+    if (array == null) {
+      return writeNull(buffer, offset);
+    }
+    int sizeBytes = 1 + Integer.BYTES;
+    // the header itself must fit within the limit, not just the elements that follow it
+    checkMaxBytes(ColumnType.STRING_ARRAY, sizeBytes, maxSizeBytes);
+    buffer.put(offset, NullHandling.IS_NOT_NULL_BYTE);
+    buffer.putInt(offset + 1, array.length);
+    for (String element : array) {
+      if (element != null) {
+        final byte[] stringElementBytes = StringUtils.toUtf8(element);
+        checkMaxBytes(
+            ColumnType.STRING_ARRAY,
+            sizeBytes + 1 + Integer.BYTES + stringElementBytes.length,
+            maxSizeBytes
+        );
+        sizeBytes += writeNullableVariableBlob(buffer, offset + sizeBytes, stringElementBytes);
+      } else {
+        checkMaxBytes(ColumnType.STRING_ARRAY, sizeBytes + 1, maxSizeBytes);
+        sizeBytes += writeNull(buffer, offset + sizeBytes);
+      }
+    }
+    return sizeBytes;
+  }
+
+  /**
+   * Reads a nullable String[] from a {@link ByteBuffer} at the supplied offset.
+   *
+   * Returns null if the array's null byte is set. Otherwise it reads the element count from the next 4 bytes, then
+   * that many elements. Each element is either a lone null byte (a null element) or a blob read with
+   * {@link #readNullableVariableBlob} and decoded with {@link StringUtils#fromUtf8}.
+   *
+   * layout: | null (byte) | size (int) | {| null (byte) | size (int) | byte[] |, | null (byte) |, ... } |
+   *
+   * The buffer's position and limit are the same afterwards as before. The buffer may belong to someone else
+   * (e.g. buffer aggregators).
+   */
+  @Nullable
+  public static String[] readNullableStringArray(ByteBuffer buffer, int offset)
+  {
+    // | null (byte) | array length (int) | array bytes |
+    if (isNullableNull(buffer, offset)) {
+      return null;
+    }
+    int position = offset + Byte.BYTES;
+    final String[] strings = new String[buffer.getInt(position)];
+    position += Integer.BYTES;
+    for (int i = 0; i < strings.length; i++) {
+      if (isNullableNull(buffer, position)) {
+        // null elements are a single null byte; the array slot is already null
+        position += Byte.BYTES;
+        continue;
+      }
+      final byte[] utf8 = readNullableVariableBlob(buffer, position);
+      strings[i] = StringUtils.fromUtf8(utf8);
+      // skip the null byte, the 4 byte length, and the utf8 payload
+      position += Byte.BYTES + Integer.BYTES + utf8.length;
+    }
+    return strings;
+  }
+
+  /**
+   * Write a variably lengthed byte[] value derived from some {@link ObjectByteStrategy} for a complex
+   * {@link TypeSignature} to a {@link ByteBuffer} at the supplied offset.
+   *
+   * A null value is written as a single {@link NullHandling#IS_NULL_BYTE}. Otherwise the value is converted with
+   * {@link ObjectByteStrategy#toBytes}. The result is written with the size-checked
+   * {@link #writeNullableVariableBlob(ByteBuffer, int, byte[], TypeSignature, int)}.
+   *
+   * layout: | null (byte) | size (int) | byte[] |
+   *
+   * The {@link TypeSignature#getComplexTypeName()} MUST have an {@link ObjectByteStrategy} registered with
+   * {@link #registerStrategy}. Otherwise a {@link NullPointerException} is thrown, even when {@code value} is null.
+   *
+   * At most {@code maxSizeBytes} are written (an {@link ISE} is thrown otherwise). The buffer must have at least that
+   * many bytes free from {@code offset}. Its position and limit are the same afterwards as before. The buffer may
+   * belong to someone else (e.g. buffer aggregators).
+   *
+   * @return number of bytes written: 1 if null, otherwise 5 + the size of the serialized byte[]
+   */
+  public static <T> int writeNullableComplexType(
+      ByteBuffer buffer,
+      int offset,
+      TypeSignature<?> type,
+      @Nullable T value,
+      int maxSizeBytes
+  )
+  {
+    final ObjectByteStrategy strategy = getStrategy(type.getComplexTypeName());
+    if (strategy == null) {
+      // Build the message only on failure. Preconditions.checkNotNull(ref, String.format(...)) formatted it on every
+      // call, and this method may be called once per row.
+      throw new NullPointerException(
+          StringUtils.format(
+              "Type %s has not registered an ObjectByteStrategy and cannot be written",
+              type.asTypeString()
+          )
+      );
+    }
+    if (value == null) {
+      return writeNull(buffer, offset);
+    }
+    final byte[] complexBytes = strategy.toBytes(value);
+    return writeNullableVariableBlob(buffer, offset, complexBytes, type, maxSizeBytes);
+  }
+
+  /**
+   * Read a possibly null, variably lengthed value serialized by some {@link ObjectByteStrategy} for a complex
+   * {@link TypeSignature} from a {@link ByteBuffer} at the supplied offset.
+   *
+   * If the null byte is set, null is returned. Otherwise the next 4 bytes hold the size. The strategy's
+   * {@link ObjectByteStrategy#fromByteBuffer} is then given a duplicate of the buffer, positioned at the value and
+   * limited to exactly that many bytes.
+   *
+   * layout: | null (byte) | size (int) | byte[] |
+   *
+   * For non-null values, the {@link TypeSignature#getComplexTypeName()} MUST have an {@link ObjectByteStrategy}
+   * registered with {@link #registerStrategy}, else a {@link NullPointerException} is thrown.
+   *
+   * The caller's buffer position, limit, and mark are untouched, because the buffer may not be owned by the caller
+   * (i.e. buffer aggs).
+   */
+  @Nullable
+  public static Object readNullableComplexType(ByteBuffer buffer, int offset, TypeSignature<?> type)
+  {
+    if (isNullableNull(buffer, offset)) {
+      return null;
+    }
+    final ObjectByteStrategy strategy = getStrategy(type.getComplexTypeName());
+    if (strategy == null) {
+      // build the message only on failure, rather than formatting it eagerly on every read
+      throw new NullPointerException(
+          StringUtils.format(
+              "Type %s has not registered an ObjectByteStrategy and cannot be read",
+              type.asTypeString()
+          )
+      );
+    }
+    final int lengthOffset = offset + VALUE_OFFSET;
+    final int complexLength = buffer.getInt(lengthOffset);
+    final int valueOffset = lengthOffset + Integer.BYTES;
+    // give the strategy its own view restricted to exactly the value bytes, leaving the caller's buffer untouched
+    final ByteBuffer dupe = buffer.duplicate();
+    dupe.position(valueOffset);
+    dupe.limit(valueOffset + complexLength);
+    return strategy.fromByteBuffer(dupe, complexLength);
+  }
+
+  /**
+   * Guards against serializing more than {@code maxSizeBytes}. If {@code sizeBytes} exceeds the maximum, throws an
+   * {@link ISE} whose message names the type and both sizes; otherwise does nothing.
+   */
+  public static void checkMaxBytes(TypeSignature<?> type, int sizeBytes, int maxSizeBytes)
+  {
+    if (sizeBytes <= maxSizeBytes) {
+      return;
+    }
+    throw new ISE(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        type.asTypeString(),
+        sizeBytes,
+        maxSizeBytes
+    );
+  }
}
diff --git a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java
index a63f0ec..d20fffe 100644
--- a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java
@@ -205,25 +205,25 @@ public class ApplyFunctionTest extends InitializedNullHandlingTest
private void assertExpr(final String expression, final Double[] expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
- Double[] result = (Double[]) expr.eval(bindings).value();
+ Object[] result = expr.eval(bindings).asArray();
Assert.assertEquals(expectedResult.length, result.length);
for (int i = 0; i < result.length; i++) {
- Assert.assertEquals(expression, expectedResult[i], result[i], 0.00001); // something is lame somewhere..
+ Assert.assertEquals(expression, expectedResult[i], (Double) result[i], 0.00001); // something is lame somewhere..
}
final Expr exprNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), ExprMacroTable.nil());
- Double[] resultRoundTrip = (Double[]) roundTrip.eval(bindings).value();
+ Object[] resultRoundTrip = (Object[]) roundTrip.eval(bindings).value();
Assert.assertEquals(expectedResult.length, resultRoundTrip.length);
for (int i = 0; i < resultRoundTrip.length; i++) {
- Assert.assertEquals(expression, expectedResult[i], resultRoundTrip[i], 0.00001);
+ Assert.assertEquals(expression, expectedResult[i], (Double) resultRoundTrip[i], 0.00001);
}
final Expr roundTripFlatten = Parser.parse(expr.stringify(), ExprMacroTable.nil());
- Double[] resultRoundTripFlatten = (Double[]) roundTripFlatten.eval(bindings).value();
+ Object[] resultRoundTripFlatten = (Object[]) roundTripFlatten.eval(bindings).value();
Assert.assertEquals(expectedResult.length, resultRoundTripFlatten.length);
for (int i = 0; i < resultRoundTripFlatten.length; i++) {
- Assert.assertEquals(expression, expectedResult[i], resultRoundTripFlatten[i], 0.00001);
+ Assert.assertEquals(expression, expectedResult[i], (Double) resultRoundTripFlatten[i], 0.00001);
}
Assert.assertEquals(expr.stringify(), roundTrip.stringify());
diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
index a486586..9056225 100644
--- a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
@@ -20,11 +20,14 @@
package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableList;
-import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.column.Types;
+import org.apache.druid.segment.column.TypesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
+import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -42,6 +45,12 @@ public class ExprEvalTest extends InitializedNullHandlingTest
ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
+ @BeforeClass
+ public static void setup()
+ {
+ // testComplexEval/testComplexEvalTooBig serialize this complex type, which needs its byte strategy registered
+ final String pairTypeName = TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName();
+ Types.registerStrategy(pairTypeName, new TypesTest.PairObjectByteStrategy());
+ }
+
@Test
public void testStringSerde()
{
@@ -109,10 +118,10 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.STRING_ARRAY,
- 28,
+ 15,
10
));
- assertEstimatedBytes(ExprEval.ofStringArray(new String[]{"hello", "hi", "hey"}), 10);
+ assertExpr(0, ExprEval.ofStringArray(new String[]{"hello", "hi", "hey"}), 10);
}
@Test
@@ -130,7 +139,7 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.LONG_ARRAY,
- 30,
+ 14,
10
));
assertExpr(0, ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
@@ -143,10 +152,10 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.LONG_ARRAY,
- NullHandling.sqlCompatible() ? 33 : 30,
+ 14,
10
));
- assertEstimatedBytes(ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
+ assertExpr(0, ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
}
@Test
@@ -164,7 +173,7 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.DOUBLE_ARRAY,
- 30,
+ 14,
10
));
assertExpr(0, ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
@@ -177,86 +186,144 @@ public class ExprEvalTest extends InitializedNullHandlingTest
expectedException.expectMessage(StringUtils.format(
"Unable to serialize [%s], size [%s] is larger than max [%s]",
ExpressionType.DOUBLE_ARRAY,
- NullHandling.sqlCompatible() ? 33 : 30,
+ 14,
10
));
- assertEstimatedBytes(ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
+ assertExpr(0, ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
+ }
+
+ @Test
+ public void testComplexEval()
+ {
+ final ExpressionType complexType = ExpressionType.fromColumnType(TypesTest.NULLABLE_TEST_PAIR_TYPE);
+ // check the complex value both at the start of the buffer and at a non-zero position
+ for (int position : new int[]{0, 1024}) {
+ assertExpr(position, ExprEval.ofComplex(complexType, new TypesTest.NullableLongPair(1234L, 5678L)));
+ }
+ }
+
+ @Test
+ public void testComplexEvalTooBig()
+ {
+ final ExpressionType complexType = ExpressionType.fromColumnType(TypesTest.NULLABLE_TEST_PAIR_TYPE);
+ // the serialized pair is reported as 23 bytes, which cannot fit in a 10 byte limit
+ final String expectedMessage = StringUtils.format(
+ "Unable to serialize [%s], size [%s] is larger than max [%s]",
+ complexType.asTypeString(),
+ 23,
+ 10
+ );
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage(expectedMessage);
+ assertExpr(0, ExprEval.ofComplex(complexType, new TypesTest.NullableLongPair(1234L, 5678L)), 10);
}
@Test
public void test_coerceListToArray()
{
Assert.assertNull(ExprEval.coerceListToArray(null, false));
- Assert.assertArrayEquals(new Object[0], (Object[]) ExprEval.coerceListToArray(ImmutableList.of(), false));
- Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(null, true));
- Assert.assertArrayEquals(new String[]{null}, (String[]) ExprEval.coerceListToArray(ImmutableList.of(), true));
+
+ NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray(ImmutableList.of(), false);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[0], coerced.rhs);
+
+ coerced = ExprEval.coerceListToArray(null, true);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{null}, coerced.rhs);
+
+ coerced = ExprEval.coerceListToArray(ImmutableList.of(), true);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{null}, coerced.rhs);
List<Long> longList = ImmutableList.of(1L, 2L, 3L);
- Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(longList, false));
+ coerced = ExprEval.coerceListToArray(longList, false);
+ Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{1L, 2L, 3L}, coerced.rhs);
List<Integer> intList = ImmutableList.of(1, 2, 3);
- Assert.assertArrayEquals(new Long[]{1L, 2L, 3L}, (Long[]) ExprEval.coerceListToArray(intList, false));
+ // Integer elements are widened to a LONG array; the result must be re-assigned or the asserts below would
+ // just re-check the previous (longList) result
+ coerced = ExprEval.coerceListToArray(intList, false);
+ Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{1L, 2L, 3L}, coerced.rhs);
List<Float> floatList = ImmutableList.of(1.0f, 2.0f, 3.0f);
- Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(floatList, false));
+ coerced = ExprEval.coerceListToArray(floatList, false);
+ Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{1.0, 2.0, 3.0}, coerced.rhs);
List<Double> doubleList = ImmutableList.of(1.0, 2.0, 3.0);
- Assert.assertArrayEquals(new Double[]{1.0, 2.0, 3.0}, (Double[]) ExprEval.coerceListToArray(doubleList, false));
+ coerced = ExprEval.coerceListToArray(doubleList, false);
+ Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{1.0, 2.0, 3.0}, coerced.rhs);
List<String> stringList = ImmutableList.of("a", "b", "c");
- Assert.assertArrayEquals(new String[]{"a", "b", "c"}, (String[]) ExprEval.coerceListToArray(stringList, false));
+ coerced = ExprEval.coerceListToArray(stringList, false);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{"a", "b", "c"}, coerced.rhs);
List<String> withNulls = new ArrayList<>();
withNulls.add("a");
withNulls.add(null);
withNulls.add("c");
- Assert.assertArrayEquals(new String[]{"a", null, "c"}, (String[]) ExprEval.coerceListToArray(withNulls, false));
+ coerced = ExprEval.coerceListToArray(withNulls, false);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{"a", null, "c"}, coerced.rhs);
List<Long> withNumberNulls = new ArrayList<>();
withNumberNulls.add(1L);
withNumberNulls.add(null);
withNumberNulls.add(3L);
- Assert.assertArrayEquals(new Long[]{1L, null, 3L}, (Long[]) ExprEval.coerceListToArray(withNumberNulls, false));
+ coerced = ExprEval.coerceListToArray(withNumberNulls, false);
+ Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
+ Assert.assertArrayEquals(new Object[]{1L, null, 3L}, coerced.rhs);
List<Object> withStringMix = ImmutableList.of(1L, "b", 3L);
+ coerced = ExprEval.coerceListToArray(withStringMix, false);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
- new String[]{"1", "b", "3"},
- (String[]) ExprEval.coerceListToArray(withStringMix, false)
+ new Object[]{"1", "b", "3"},
+ coerced.rhs
);
List<Number> withIntsAndLongs = ImmutableList.of(1, 2L, 3);
+ coerced = ExprEval.coerceListToArray(withIntsAndLongs, false);
+ Assert.assertEquals(ExpressionType.LONG_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
- new Long[]{1L, 2L, 3L},
- (Long[]) ExprEval.coerceListToArray(withIntsAndLongs, false)
+ new Object[]{1L, 2L, 3L},
+ coerced.rhs
);
List<Number> withFloatsAndLongs = ImmutableList.of(1, 2L, 3.0f);
+ coerced = ExprEval.coerceListToArray(withFloatsAndLongs, false);
+ Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
- new Double[]{1.0, 2.0, 3.0},
- (Double[]) ExprEval.coerceListToArray(withFloatsAndLongs, false)
+ new Object[]{1.0, 2.0, 3.0},
+ coerced.rhs
);
List<Number> withDoublesAndLongs = ImmutableList.of(1, 2L, 3.0);
+ coerced = ExprEval.coerceListToArray(withDoublesAndLongs, false);
+ Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
- new Double[]{1.0, 2.0, 3.0},
- (Double[]) ExprEval.coerceListToArray(withDoublesAndLongs, false)
+ new Object[]{1.0, 2.0, 3.0},
+ coerced.rhs
);
List<Number> withFloatsAndDoubles = ImmutableList.of(1L, 2.0f, 3.0);
+ coerced = ExprEval.coerceListToArray(withFloatsAndDoubles, false);
+ Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
- new Double[]{1.0, 2.0, 3.0},
- (Double[]) ExprEval.coerceListToArray(withFloatsAndDoubles, false)
+ new Object[]{1.0, 2.0, 3.0},
+ coerced.rhs
);
List<String> withAllNulls = new ArrayList<>();
withAllNulls.add(null);
withAllNulls.add(null);
withAllNulls.add(null);
+ coerced = ExprEval.coerceListToArray(withAllNulls, false);
+ Assert.assertEquals(ExpressionType.STRING_ARRAY, coerced.lhs);
Assert.assertArrayEquals(
- new String[]{null, null, null},
- (String[]) ExprEval.coerceListToArray(withAllNulls, false)
+ new Object[]{null, null, null},
+ coerced.rhs
);
}
@Test
@@ -299,27 +366,21 @@ public class ExprEvalTest extends InitializedNullHandlingTest
if (expected.type().isArray()) {
Assert.assertArrayEquals(
expected.asArray(),
- ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).asArray()
+ ExprEval.deserialize(buffer, position, expected.type()).asArray()
);
Assert.assertArrayEquals(
expected.asArray(),
- ExprEval.deserialize(buffer, position).asArray()
+ ExprEval.deserialize(buffer, position, expected.type()).asArray()
);
} else {
Assert.assertEquals(
expected.value(),
- ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).value()
+ ExprEval.deserialize(buffer, position, expected.type()).value()
);
Assert.assertEquals(
expected.value(),
- ExprEval.deserialize(buffer, position).value()
+ ExprEval.deserialize(buffer, position, expected.type()).value()
);
}
- assertEstimatedBytes(expected, maxSizeBytes);
- }
-
- private void assertEstimatedBytes(ExprEval eval, int maxSizeBytes)
- {
- ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
}
}
diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java
index 05c1929..e3b2945 100644
--- a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java
@@ -160,34 +160,27 @@ public class ExprTest
}
@Test
- public void testEqualsContractForStringArrayExpr()
+ public void testEqualsContractForArrayExpr()
{
+ // a single ArrayExpr replaces the per-element-type array exprs; outputType is no longer ignored by equals and is
+ // instead declared non-null
- EqualsVerifier.forClass(StringArrayExpr.class)
- .withIgnoredFields("outputType")
- .withPrefabValues(Object.class, new String[]{"foo"}, new String[0])
- .withPrefabValues(ExpressionType.class, ExpressionType.STRING_ARRAY, ExpressionType.LONG_ARRAY)
- .usingGetClass()
- .verify();
- }
-
- @Test
- public void testEqualsContractForLongArrayExpr()
- {
- EqualsVerifier.forClass(LongArrayExpr.class)
- .withIgnoredFields("outputType")
- .withPrefabValues(Object.class, new Long[]{1L}, new Long[0])
+ EqualsVerifier.forClass(ArrayExpr.class)
+ .withPrefabValues(Object.class, new Object[]{1L}, new Object[0])
.withPrefabValues(ExpressionType.class, ExpressionType.LONG_ARRAY, ExpressionType.DOUBLE_ARRAY)
+ .withNonnullFields("outputType")
.usingGetClass()
.verify();
}
@Test
- public void testEqualsContractForDoubleArrayExpr()
+ public void testEqualsContractForComplexExpr()
{
+ // two distinct COMPLEX types ("foo" and "bar") serve as the prefab ExpressionType values EqualsVerifier needs
- EqualsVerifier.forClass(DoubleArrayExpr.class)
- .withIgnoredFields("outputType")
- .withPrefabValues(Object.class, new Double[]{1.0}, new Double[0])
- .withPrefabValues(ExpressionType.class, ExpressionType.DOUBLE_ARRAY, ExpressionType.STRING_ARRAY)
+ EqualsVerifier.forClass(ComplexExpr.class)
+ .withPrefabValues(Object.class, new Object[]{1L}, new Object[0])
+ .withPrefabValues(
+ ExpressionType.class,
+ ExpressionTypeFactory.getInstance().ofComplex("foo"),
+ ExpressionTypeFactory.getInstance().ofComplex("bar")
+ )
+ .withNonnullFields("outputType")
.usingGetClass()
.verify();
}
diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
index c05422a..bc7afa3 100644
--- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
@@ -25,9 +25,13 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.column.ObjectByteStrategy;
+import org.apache.druid.segment.column.Types;
+import org.apache.druid.segment.column.TypesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
+import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -45,6 +49,12 @@ public class FunctionTest extends InitializedNullHandlingTest
private Expr.ObjectBinding bindings;
+ @BeforeClass
+ public static void setupClass()
+ {
+ // testComplexDecode looks the strategy up again via Types.getStrategy, so it must be registered up front
+ final String testPairTypeName = TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName();
+ Types.registerStrategy(testPairTypeName, new TypesTest.PairObjectByteStrategy());
+ }
+
@Before
public void setup()
{
@@ -64,7 +74,8 @@ public class FunctionTest extends InitializedNullHandlingTest
.put("of", 0F)
.put("a", new String[] {"foo", "bar", "baz", "foobar"})
.put("b", new Long[] {1L, 2L, 3L, 4L, 5L})
- .put("c", new Double[] {3.1, 4.2, 5.3});
+ .put("c", new Double[] {3.1, 4.2, 5.3})
+ .put("someComplex", new TypesTest.NullableLongPair(1L, 2L));
bindings = InputBindings.withMap(builder.build());
}
@@ -281,7 +292,7 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArrayAppend()
{
assertArrayExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
+ // 'bar' is not a long, so the appended element is expected to be NullHandling.defaultLongValue(); the expected
+ // value therefore depends on the null handling mode
- assertArrayExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null});
+ assertArrayExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, NullHandling.defaultLongValue()});
assertArrayExpr("array_append([], 1)", new String[]{"1"});
assertArrayExpr("array_append(<LONG>[], 1)", new Long[]{1L});
}
@@ -300,11 +311,11 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArraySetAdd()
{
assertArrayExpr("array_set_add([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L});
+ // 'bar' is not a long, so the added element is expected to be NullHandling.defaultLongValue()
- assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{null, 1L, 2L, 3L});
+ assertArrayExpr("array_set_add([1, 2, 3], 'bar')", new Long[]{NullHandling.defaultLongValue(), 1L, 2L, 3L});
assertArrayExpr("array_set_add([1, 2, 2], 1)", new Long[]{1L, 2L});
assertArrayExpr("array_set_add([], 1)", new String[]{"1"});
assertArrayExpr("array_set_add(<LONG>[], 1)", new Long[]{1L});
+ // adding null to a LONG array likewise yields NullHandling.defaultLongValue()
- assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{null});
+ assertArrayExpr("array_set_add(<LONG>[], null)", new Long[]{NullHandling.defaultLongValue()});
}
@Test
@@ -358,7 +369,7 @@ public class FunctionTest extends InitializedNullHandlingTest
public void testArrayPrepend()
{
assertArrayExpr("array_prepend(4, [1, 2, 3])", new Long[]{4L, 1L, 2L, 3L});
- assertArrayExpr("array_prepend('bar', [1, 2, 3])", new Long[]{null, 1L, 2L, 3L});
+ assertArrayExpr("array_prepend('bar', [1, 2, 3])", new Long[]{NullHandling.defaultLongValue(), 1L, 2L, 3L});
assertArrayExpr("array_prepend(1, [])", new String[]{"1"});
assertArrayExpr("array_prepend(1, <LONG>[])", new Long[]{1L});
assertArrayExpr("array_prepend(1, <DOUBLE>[])", new Double[]{1.0});
@@ -792,6 +803,66 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("repeat(nonexistent, 10)", null);
}
+ @Test
+ public void testComplexDecode()
+ {
+ final String typeName = TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName();
+ final TypesTest.NullableLongPair expected = new TypesTest.NullableLongPair(1L, 2L);
+ final ObjectByteStrategy strategy = Types.getStrategy(typeName);
+ // encode the pair with its registered strategy, then expect the expression to decode it back to an equal value
+ final String encoded = StringUtils.encodeBase64String(strategy.toBytes(expected));
+ assertExpr(StringUtils.format("complex_decode_base64('%s', '%s')", typeName, encoded), expected);
+ }
+
+ @Test
+ public void testComplexDecodeNull()
+ {
+ // a null payload is expected to decode to null rather than fail
+ final String expression = StringUtils.format(
+ "complex_decode_base64('%s', null)",
+ TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName()
+ );
+ assertExpr(expression, null);
+ }
+
+ @Test
+ public void testComplexDecodeBaseWrongArgCount()
+ {
+ // the function takes a type name and a payload; a single argument must be rejected
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[complex_decode_base64] needs 2 arguments");
+ assertExpr("complex_decode_base64(string)", null);
+ }
+
+ @Test
+ public void testComplexDecodeBaseArg0BadType()
+ {
+ // the type name argument must be a constant string; a numeric literal must be rejected
+ final String expectedMessage =
+ "Function[complex_decode_base64] first argument must be constant 'STRING' expression containing a valid complex type name";
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage(expectedMessage);
+ assertExpr("complex_decode_base64(1, string)", null);
+ }
+
+ @Test
+ public void testComplexDecodeBaseArg0Unknown()
+ {
+ // 'unknown' does not name a known complex type
+ final String expectedMessage =
+ "Function[complex_decode_base64] first argument must be a valid complex type name, unknown complex type [COMPLEX<unknown>]";
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage(expectedMessage);
+ assertExpr("complex_decode_base64('unknown', string)", null);
+ }
+
private void assertExpr(final String expression, @Nullable final Object expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java
index 82405cb..afa02dd 100644
--- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java
@@ -22,9 +22,15 @@ package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.RE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.segment.column.ObjectByteStrategy;
+import org.apache.druid.segment.column.Types;
+import org.apache.druid.segment.column.TypesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
+import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -43,6 +49,12 @@ public class ParserTest extends InitializedNullHandlingTest
VectorExprSanityTest.SettableVectorInputBinding emptyBinding = new VectorExprSanityTest.SettableVectorInputBinding(8);
+ @BeforeClass
+ public static void setup()
+ {
+ // complex_decode_base64 expressions in this class need the test pair type's byte strategy registered
+ final String nullableLongPairName = TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName();
+ Types.registerStrategy(nullableLongPairName, new TypesTest.PairObjectByteStrategy());
+ }
+
@Test
public void testSimple()
{
@@ -222,76 +234,154 @@ public class ParserTest extends InitializedNullHandlingTest
@Test
public void testLiteralArraysHomogeneousElements()
{
+ // implicitly typed literal arrays take their element type from their (homogeneous) elements
- validateConstantExpression("[1.0, 2.345]", new Double[]{1.0, 2.345});
- validateConstantExpression("[1, 3]", new Long[]{1L, 3L});
- validateConstantExpression("['hello', 'world']", new String[]{"hello", "world"});
+ validateConstantExpression("[1.0, 2.345]", new Object[]{1.0, 2.345});
+ validateConstantExpression("[1, 3]", new Object[]{1L, 3L});
+ validateConstantExpression("['hello', 'world']", new Object[]{"hello", "world"});
}
@Test
public void testLiteralArraysHomogeneousOrNullElements()
{
- validateConstantExpression("[1.0, null, 2.345]", new Double[]{1.0, null, 2.345});
- validateConstantExpression("[null, 1, 3]", new Long[]{null, 1L, 3L});
- validateConstantExpression("['hello', 'world', null]", new String[]{"hello", "world", null});
+ validateConstantExpression("[1.0, null, 2.345]", new Object[]{1.0, null, 2.345});
+ validateConstantExpression("[null, 1, 3]", new Object[]{null, 1L, 3L});
+ validateConstantExpression("['hello', 'world', null]", new Object[]{"hello", "world", null});
}
@Test
public void testLiteralArraysEmptyAndAllNullImplicitAreString()
{
+ // with no typed element to infer from, implicitly typed arrays default to STRING
- validateConstantExpression("[]", new String[0]);
- validateConstantExpression("[null, null, null]", new String[]{null, null, null});
+ validateConstantExpression("[]", new Object[0]);
+ validateConstantExpression("[null, null, null]", new Object[]{null, null, null});
}
@Test
public void testLiteralArraysImplicitTypedNumericMixed()
{
// implicit typed numeric arrays with mixed elements are doubles
- validateConstantExpression("[1, null, 2000.0]", new Double[]{1.0, null, 2000.0});
- validateConstantExpression("[1.0, null, 2000]", new Double[]{1.0, null, 2000.0});
+ validateConstantExpression("[1, null, 2000.0]", new Object[]{1.0, null, 2000.0});
+ validateConstantExpression("[1.0, null, 2000]", new Object[]{1.0, null, 2000.0});
}
@Test
public void testLiteralArraysExplicitTypedEmpties()
{
- validateConstantExpression("<STRING>[]", new String[0]);
- validateConstantExpression("<DOUBLE>[]", new Double[0]);
- validateConstantExpression("<LONG>[]", new Long[0]);
+ // legacy explicit array format
+ validateConstantExpression("<STRING>[]", new Object[0]);
+ validateConstantExpression("<DOUBLE>[]", new Object[0]);
+ validateConstantExpression("<LONG>[]", new Object[0]);
}
@Test
public void testLiteralArraysExplicitAllNull()
{
- validateConstantExpression("<DOUBLE>[null, null, null]", new Double[]{null, null, null});
- validateConstantExpression("<LONG>[null, null, null]", new Long[]{null, null, null});
- validateConstantExpression("<STRING>[null, null, null]", new String[]{null, null, null});
+ // legacy explicit array format
+ validateConstantExpression("<DOUBLE>[null, null, null]", new Object[]{null, null, null});
+ validateConstantExpression("<LONG>[null, null, null]", new Object[]{null, null, null});
+ validateConstantExpression("<STRING>[null, null, null]", new Object[]{null, null, null});
}
@Test
public void testLiteralArraysExplicitTypes()
{
- validateConstantExpression("<DOUBLE>[1.0, null, 2000.0]", new Double[]{1.0, null, 2000.0});
- validateConstantExpression("<LONG>[3, null, 4]", new Long[]{3L, null, 4L});
- validateConstantExpression("<STRING>['foo', 'bar', 'baz']", new String[]{"foo", "bar", "baz"});
+ // legacy explicit array format
+ validateConstantExpression("<DOUBLE>[1.0, null, 2000.0]", new Object[]{1.0, null, 2000.0});
+ validateConstantExpression("<LONG>[3, null, 4]", new Object[]{3L, null, 4L});
+ validateConstantExpression("<STRING>['foo', 'bar', 'baz']", new Object[]{"foo", "bar", "baz"});
}
@Test
public void testLiteralArraysExplicitTypesMixedElements()
{
+ // legacy explicit array format
// explicit typed numeric arrays mixed numeric types should coerce to the correct explicit type
- validateConstantExpression("<DOUBLE>[3, null, 4, 2.345]", new Double[]{3.0, null, 4.0, 2.345});
- validateConstantExpression("<LONG>[1.0, null, 2000.0]", new Long[]{1L, null, 2000L});
+ validateConstantExpression("<DOUBLE>[3, null, 4, 2.345]", new Object[]{3.0, null, 4.0, 2.345});
+ validateConstantExpression("<LONG>[1.0, null, 2000.0]", new Object[]{1L, null, 2000L});
// explicit typed string arrays should accept any literal and convert to string
- validateConstantExpression("<STRING>['1', null, 2000, 1.1]", new String[]{"1", null, "2000", "1.1"});
+ validateConstantExpression("<STRING>['1', null, 2000, 1.1]", new Object[]{"1", null, "2000", "1.1"});
+ }
+
+ @Test
+ public void testLiteralExplicitTypedArrays()
+ {
+ ExpressionProcessing.initializeForTests(true);
+ validateConstantExpression("ARRAY<DOUBLE>[1.0, 2.0, null, 3.0]", new Object[]{1.0, 2.0, null, 3.0});
+ validateConstantExpression("ARRAY<LONG>[1, 2, null, 3]", new Object[]{1L, 2L, null, 3L});
+ validateConstantExpression("ARRAY<STRING>['1', '2', null, '3.0']", new Object[]{"1", "2", null, "3.0"});
+
+ // mixed type tests
+ validateConstantExpression("ARRAY<DOUBLE>[3, null, 4, 2.345]", new Object[]{3.0, null, 4.0, 2.345});
+ validateConstantExpression("ARRAY<LONG>[1.0, null, 2000.0]", new Object[]{1L, null, 2000L});
+
+ // explicit typed string arrays should accept any literal and convert
+ validateConstantExpression("ARRAY<STRING>['1', null, 2000, 1.1]", new Object[]{"1", null, "2000", "1.1"});
+ validateConstantExpression("ARRAY<LONG>['1', null, 2000, 1.1]", new Object[]{1L, null, 2000L, 1L});
+ validateConstantExpression("ARRAY<DOUBLE>['1', null, 2000, 1.1]", new Object[]{1.0, null, 2000.0, 1.1});
+
+ // the gramar isn't cool enough yet to parse populated nested-arrays or complex arrays..., but empty ones can
+ // be defined...
+ validateConstantExpression("ARRAY<COMPLEX<nullableLongPair>>[]", new Object[]{});
+ validateConstantExpression("ARRAY<ARRAY<LONG>>[]", new Object[]{});
+ ExpressionProcessing.initializeForTests(null);
}
@Test
+ public void testConstantComplexAndNestedArrays()
+ {
+ ExpressionProcessing.initializeForTests(true);
+ // they can be built with array builder functions though...
+ validateConstantExpression(
+ "array(['foo', 'bar', 'baz'], ['baz','foo','bar'])",
+ new Object[]{new Object[]{"foo", "bar", "baz"}, new Object[]{"baz", "foo", "bar"}}
+ );
+ // nested arrays cannot be mixed types, the first element choo-choo-chooses for you
+ validateConstantExpression(
+ "array(['foo', 'bar', 'baz'], ARRAY<LONG>[1,2,3])",
+ new Object[]{new Object[]{"foo", "bar", "baz"}, new Object[]{"1", "2", "3"}}
+ );
+
+ // complex types too
+ TypesTest.NullableLongPair l1 = new TypesTest.NullableLongPair(1L, 2L);
+ TypesTest.NullableLongPair l2 = new TypesTest.NullableLongPair(2L, 3L);
+ ObjectByteStrategy byteStrategy = Types.getStrategy(TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
+ String l1String = StringUtils.format(
+ "complex_decode_base64('%s', '%s')",
+ TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
+ StringUtils.encodeBase64String(byteStrategy.toBytes(l1))
+ );
+ String l2String = StringUtils.format(
+ "complex_decode_base64('%s', '%s')",
+ TypesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
+ StringUtils.encodeBase64String(byteStrategy.toBytes(l2))
+ );
+ validateConstantExpression(
+ l1String,
+ l1
+ );
+
+ validateConstantExpression(
+ StringUtils.format("array(%s,%s)", l1String, l2String),
+ new Object[]{l1, l2}
+ );
+ ExpressionProcessing.initializeForTests(null);
+ }
+
+ @Test
+ public void nestedArraysExplodeIfNotEnabled()
+ {
+ // without ExpressionProcessing.initializeForTests(true), nested array types must be rejected
+ final String expectedMessage =
+ "Cannot create a nested array type [ARRAY<ARRAY<LONG>>], 'druid.expressions.allowNestedArrays' must be set to true";
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage(expectedMessage);
+ validateConstantExpression("ARRAY<ARRAY<LONG>>[]", new Object[]{});
+ }
+
+ @Test
public void testLiteralArrayImplicitStringParseException()
{
- // implicit typed string array cannot handle literals thate are not null or string
+ // implicitly typed string arrays cannot handle literals that are not null or string
expectedException.expect(RE.class);
expectedException.expectMessage("Failed to parse array: element 2000 is not a string");
- validateConstantExpression("['1', null, 2000, 1.1]", new String[]{"1", null, "2000", "1.1"});
+ validateConstantExpression("['1', null, 2000, 1.1]", new Object[]{"1", null, "2000", "1.1"});
}
@Test
@@ -300,7 +390,7 @@ public class ParserTest extends InitializedNullHandlingTest
// explicit typed long arrays only handle numeric types
expectedException.expect(RE.class);
expectedException.expectMessage("Failed to parse array element '2000' as a long");
- validateConstantExpression("<LONG>[1, null, '2000']", new Long[]{1L, null, 2000L});
+ validateConstantExpression("<LONG>[1, null, '2000']", new Object[]{1L, null, 2000L});
}
@Test
@@ -309,7 +399,7 @@ public class ParserTest extends InitializedNullHandlingTest
// explicit typed double arrays only handle numeric types
expectedException.expect(RE.class);
expectedException.expectMessage("Failed to parse array element '2000.0' as a double");
- validateConstantExpression("<DOUBLE>[1.0, null, '2000.0']", new Double[]{1.0, null, 2000.0});
+ validateConstantExpression("<DOUBLE>[1.0, null, '2000.0']", new Object[]{1.0, null, 2000.0});
}
@Test
diff --git a/core/src/test/java/org/apache/druid/segment/column/TypesTest.java b/core/src/test/java/org/apache/druid/segment/column/TypesTest.java
new file mode 100644
index 0000000..8a1929b
--- /dev/null
+++ b/core/src/test/java/org/apache/druid/segment/column/TypesTest.java
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.segment.column;
+
+import com.google.common.primitives.Longs;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.Comparators;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+public class TypesTest
+{
+ ByteBuffer buffer = ByteBuffer.allocate(1 << 16);
+
+ public static ColumnType NULLABLE_TEST_PAIR_TYPE = ColumnType.ofComplex("nullableLongPair");
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @BeforeClass
+ public static void setup()
+ {
+ Types.registerStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new PairObjectByteStrategy());
+ }
+
+ @Test
+ public void testIs()
+ {
+ Assert.assertTrue(Types.is(ColumnType.LONG, ValueType.LONG));
+ Assert.assertTrue(Types.is(ColumnType.DOUBLE, ValueType.DOUBLE));
+ Assert.assertTrue(Types.is(ColumnType.FLOAT, ValueType.FLOAT));
+ Assert.assertTrue(Types.is(ColumnType.STRING, ValueType.STRING));
+ Assert.assertTrue(Types.is(ColumnType.LONG_ARRAY, ValueType.ARRAY));
+ Assert.assertTrue(Types.is(ColumnType.LONG_ARRAY.getElementType(), ValueType.LONG));
+ Assert.assertTrue(Types.is(ColumnType.DOUBLE_ARRAY, ValueType.ARRAY));
+ Assert.assertTrue(Types.is(ColumnType.DOUBLE_ARRAY.getElementType(), ValueType.DOUBLE));
+ Assert.assertTrue(Types.is(ColumnType.STRING_ARRAY, ValueType.ARRAY));
+ Assert.assertTrue(Types.is(ColumnType.STRING_ARRAY.getElementType(), ValueType.STRING));
+ Assert.assertTrue(Types.is(NULLABLE_TEST_PAIR_TYPE, ValueType.COMPLEX));
+
+ Assert.assertFalse(Types.is(ColumnType.LONG, ValueType.DOUBLE));
+ Assert.assertFalse(Types.is(ColumnType.DOUBLE, ValueType.FLOAT));
+
+ Assert.assertFalse(Types.is(null, ValueType.STRING));
+ Assert.assertTrue(Types.isNullOr(null, ValueType.STRING));
+ }
+
+ @Test
+ public void testNullOrAnyOf()
+ {
+ Assert.assertTrue(Types.isNullOrAnyOf(ColumnType.LONG, ValueType.STRING, ValueType.LONG, ValueType.DOUBLE));
+ Assert.assertFalse(Types.isNullOrAnyOf(ColumnType.DOUBLE, ValueType.STRING, ValueType.LONG, ValueType.FLOAT));
+ Assert.assertTrue(Types.isNullOrAnyOf(null, ValueType.STRING, ValueType.LONG, ValueType.FLOAT));
+ }
+
+ @Test
+ public void testEither()
+ {
+ Assert.assertTrue(Types.either(ColumnType.LONG, ColumnType.DOUBLE, ValueType.DOUBLE));
+ Assert.assertFalse(Types.either(ColumnType.LONG, ColumnType.STRING, ValueType.DOUBLE));
+ }
+
+ @Test
+ public void testRegister()
+ {
+ ObjectByteStrategy<?> strategy = Types.getStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
+ Assert.assertNotNull(strategy);
+ Assert.assertTrue(strategy instanceof PairObjectByteStrategy);
+ }
+
+ @Test
+ public void testRegisterDuplicate()
+ {
+ Types.registerStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new PairObjectByteStrategy());
+ ObjectByteStrategy<?> strategy = Types.getStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName());
+ Assert.assertNotNull(strategy);
+ Assert.assertTrue(strategy instanceof PairObjectByteStrategy);
+ }
+
+ @Test
+ public void testConflicting()
+ {
+ expectedException.expect(IllegalStateException.class);
+ expectedException.expectMessage(
+ "Incompatible strategy for type[nullableLongPair] already exists."
+ + " Expected [org.apache.druid.segment.column.TypesTest$1],"
+ + " found [org.apache.druid.segment.column.TypesTest$PairObjectByteStrategy]."
+ );
+
+ Types.registerStrategy(NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), new ObjectByteStrategy<String>()
+ {
+ @Override
+ public int compare(String o1, String o2)
+ {
+ return 0;
+ }
+
+ @Override
+ public Class<? extends String> getClazz()
+ {
+ return null;
+ }
+
+ @Nullable
+ @Override
+ public String fromByteBuffer(ByteBuffer buffer, int numBytes)
+ {
+ return null;
+ }
+
+ @Nullable
+ @Override
+ public byte[] toBytes(@Nullable String val)
+ {
+ return new byte[0];
+ }
+ });
+ }
+
+ @Test
+ public void testNulls()
+ {
+ int offset = 0;
+ Types.writeNull(buffer, offset);
+ Assert.assertTrue(Types.isNullableNull(buffer, offset));
+
+ // test non-zero offset
+ offset = 128;
+ Types.writeNull(buffer, offset);
+ Assert.assertTrue(Types.isNullableNull(buffer, offset));
+ }
+
+ @Test
+ public void testNonNullNullableLongBinary()
+ {
+ final long someLong = 12345567L;
+ int offset = 0;
+ int bytesWritten = Types.writeNullableLong(buffer, offset, someLong);
+ Assert.assertEquals(1 + Long.BYTES, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(someLong, Types.readNullableLong(buffer, offset));
+
+ // test non-zero offset
+ offset = 1024;
+ bytesWritten = Types.writeNullableLong(buffer, offset, someLong);
+ Assert.assertEquals(1 + Long.BYTES, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(someLong, Types.readNullableLong(buffer, offset));
+ }
+
+ @Test
+ public void testNonNullNullableDoubleBinary()
+ {
+ final double someDouble = 1.234567;
+ int offset = 0;
+ int bytesWritten = Types.writeNullableDouble(buffer, offset, someDouble);
+ Assert.assertEquals(1 + Double.BYTES, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(someDouble, Types.readNullableDouble(buffer, offset), 0);
+
+ // test non-zero offset
+ offset = 1024;
+ bytesWritten = Types.writeNullableDouble(buffer, offset, someDouble);
+ Assert.assertEquals(1 + Double.BYTES, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(someDouble, Types.readNullableDouble(buffer, offset), 0);
+ }
+
+ @Test
+ public void testNonNullNullableFloatBinary()
+ {
+ final float someFloat = 12345567L;
+ int offset = 0;
+ int bytesWritten = Types.writeNullableFloat(buffer, offset, someFloat);
+ Assert.assertEquals(1 + Float.BYTES, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(someFloat, Types.readNullableFloat(buffer, offset), 0);
+
+ // test non-zero offset
+ offset = 1024;
+ bytesWritten = Types.writeNullableFloat(buffer, offset, someFloat);
+ Assert.assertEquals(1 + Float.BYTES, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(someFloat, Types.readNullableFloat(buffer, offset), 0);
+ }
+
+ @Test
+ public void testNullableVariableBlob()
+ {
+ String someString = "hello";
+ byte[] stringBytes = StringUtils.toUtf8(someString);
+ int offset = 0;
+ int bytesWritten = Types.writeNullableVariableBlob(buffer, offset, stringBytes);
+ Assert.assertEquals(1 + Integer.BYTES + stringBytes.length, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(stringBytes, Types.readNullableVariableBlob(buffer, offset));
+
+ // test non-zero offset
+ offset = 1024;
+ bytesWritten = Types.writeNullableVariableBlob(buffer, offset, stringBytes);
+ Assert.assertEquals(1 + Integer.BYTES + stringBytes.length, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(stringBytes, Types.readNullableVariableBlob(buffer, offset));
+
+ // test null
+ bytesWritten = Types.writeNullableVariableBlob(buffer, offset, null);
+ Assert.assertEquals(1, bytesWritten);
+ Assert.assertTrue(Types.isNullableNull(buffer, offset));
+ }
+
+ @Test
+ public void testNullableVariableBlobTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [STRING], size [10] is larger than max [5]");
+ String someString = "hello";
+ byte[] stringBytes = StringUtils.toUtf8(someString);
+ int offset = 0;
+ Types.writeNullableVariableBlob(buffer, offset, stringBytes, ColumnType.STRING, stringBytes.length);
+ }
+
+ @Test
+ public void testArrays()
+ {
+ final Long[] longArray = new Long[]{1L, 1234567L, null, 10L};
+ final Double[] doubleArray = new Double[]{1.23, 4.567, null, 8.9};
+ final String[] stringArray = new String[]{"hello", "world", null, ""};
+
+ int bytesWritten;
+ int offset = 0;
+ bytesWritten = Types.writeNullableLongArray(buffer, offset, longArray, buffer.limit());
+ Assert.assertEquals(33, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(longArray, Types.readNullableLongArray(buffer, offset));
+
+ bytesWritten = Types.writeNullableDoubleArray(buffer, offset, doubleArray, buffer.limit());
+ Assert.assertEquals(33, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(doubleArray, Types.readNullableDoubleArray(buffer, offset));
+
+ bytesWritten = Types.writeNullableStringArray(buffer, offset, stringArray, buffer.limit());
+ Assert.assertEquals(31, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(stringArray, Types.readNullableStringArray(buffer, offset));
+
+ offset = 1024;
+ bytesWritten = Types.writeNullableLongArray(buffer, offset, longArray, buffer.limit());
+ Assert.assertEquals(33, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(longArray, Types.readNullableLongArray(buffer, offset));
+
+ bytesWritten = Types.writeNullableDoubleArray(buffer, offset, doubleArray, buffer.limit());
+ Assert.assertEquals(33, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(doubleArray, Types.readNullableDoubleArray(buffer, offset));
+
+ bytesWritten = Types.writeNullableStringArray(buffer, offset, stringArray, buffer.limit());
+ Assert.assertEquals(31, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertArrayEquals(stringArray, Types.readNullableStringArray(buffer, offset));
+
+ // test nulls
+ bytesWritten = Types.writeNullableLongArray(buffer, offset, null, buffer.limit());
+ Assert.assertEquals(1, bytesWritten);
+ Assert.assertTrue(Types.isNullableNull(buffer, offset));
+
+ bytesWritten = Types.writeNullableDoubleArray(buffer, offset, null, buffer.limit());
+ Assert.assertEquals(1, bytesWritten);
+ Assert.assertTrue(Types.isNullableNull(buffer, offset));
+
+ bytesWritten = Types.writeNullableStringArray(buffer, offset, null, buffer.limit());
+ Assert.assertEquals(1, bytesWritten);
+ Assert.assertTrue(Types.isNullableNull(buffer, offset));
+ }
+
+ @Test
+ public void testLongArrayToBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [ARRAY<LONG>], size [14] is larger than max [10]");
+ final Long[] longArray = new Long[]{1L, 1234567L, null, 10L};
+ Types.writeNullableLongArray(buffer, 0, longArray, 10);
+ }
+
+ @Test
+ public void testDoubleArrayToBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [ARRAY<DOUBLE>], size [14] is larger than max [10]");
+ final Double[] doubleArray = new Double[]{1.23, 4.567, null, 8.9};
+ Types.writeNullableDoubleArray(buffer, 0, doubleArray, 10);
+ }
+
+ @Test
+ public void testStringArrayToBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [ARRAY<STRING>], size [15] is larger than max [10]");
+ final String[] stringArray = new String[]{"hello", "world", null, ""};
+ Types.writeNullableStringArray(buffer, 0, stringArray, 10);
+ }
+
+
+ @Test
+ public void testComplex()
+ {
+ NullableLongPair lp1 = new NullableLongPair(null, 1L);
+ NullableLongPair lp2 = new NullableLongPair(1234L, 5678L);
+ NullableLongPair lp3 = new NullableLongPair(1234L, null);
+
+ int bytesWritten;
+ int offset = 0;
+ bytesWritten = Types.writeNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE, lp1, buffer.limit());
+ // 1 (not null) + 4 (length) + 1 (null) + 0 (lhs) + 1 (not null) + 8 (rhs)
+ Assert.assertEquals(15, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(lp1, Types.readNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE));
+
+ // 1 (not null) + 4 (length) + 1 (not null) + 8 (lhs) + 1 (not null) + 8 (rhs)
+ bytesWritten = Types.writeNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE, lp2, buffer.limit());
+ Assert.assertEquals(23, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(lp2, Types.readNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE));
+
+ // 1 (not null) + 4 (length) + 1 (not null) + 8 (lhs) + 1 (null) + 0 (rhs)
+ bytesWritten = Types.writeNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE, lp3, buffer.limit());
+ Assert.assertEquals(15, bytesWritten);
+ Assert.assertFalse(Types.isNullableNull(buffer, offset));
+ Assert.assertEquals(lp3, Types.readNullableComplexType(buffer, offset, NULLABLE_TEST_PAIR_TYPE));
+ }
+
+ @Test
+ public void testComplexTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [COMPLEX<nullableLongPair>], size [23] is larger than max [10]");
+ Types.writeNullableComplexType(
+ buffer,
+ 0,
+ NULLABLE_TEST_PAIR_TYPE,
+ new NullableLongPair(1234L, 5678L),
+ 10
+ );
+ }
+
+ public static class PairObjectByteStrategy implements ObjectByteStrategy<NullableLongPair>
+ {
+ @Override
+ public Class<? extends NullableLongPair> getClazz()
+ {
+ return NullableLongPair.class;
+ }
+
+ @Nullable
+ @Override
+ public NullableLongPair fromByteBuffer(ByteBuffer buffer, int numBytes)
+ {
+ int position = buffer.position();
+ Long lhs = null;
+ Long rhs = null;
+ if (!Types.isNullableNull(buffer, position)) {
+ lhs = Types.readNullableLong(buffer, position);
+ position += 1 + Long.BYTES;
+ } else {
+ position++;
+ }
+ if (!Types.isNullableNull(buffer, position)) {
+ rhs = Types.readNullableLong(buffer, position);
+ }
+ return new NullableLongPair(lhs, rhs);
+ }
+
+ @Nullable
+ @Override
+ public byte[] toBytes(@Nullable NullableLongPair val)
+ {
+ byte[] bytes = new byte[1 + Long.BYTES + 1 + Long.BYTES];
+ ByteBuffer buffer = ByteBuffer.wrap(bytes);
+ int position = 0;
+ if (val != null) {
+ if (val.lhs != null) {
+ position += Types.writeNullableLong(buffer, position, val.lhs);
+ } else {
+ position += Types.writeNull(buffer, position);
+ }
+ if (val.rhs != null) {
+ position += Types.writeNullableLong(buffer, position, val.rhs);
+ } else {
+ position += Types.writeNull(buffer, position);
+ }
+ return Arrays.copyOfRange(bytes, 0, position);
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public int compare(NullableLongPair o1, NullableLongPair o2)
+ {
+ return Comparators.<NullableLongPair>naturalNullsFirst().compare(o1, o2);
+ }
+ }
+
+ public static class NullableLongPair extends Pair<Long, Long> implements Comparable<NullableLongPair>
+ {
+ public NullableLongPair(@Nullable Long lhs, @Nullable Long rhs)
+ {
+ super(lhs, rhs);
+ }
+
+ @Override
+ public int compareTo(NullableLongPair o)
+ {
+ return Comparators.<Long>naturalNullsFirst().thenComparing(Longs::compare).compare(this.lhs, o.lhs);
+ }
+ }
+}
diff --git a/core/src/test/java/org/apache/druid/testing/InitializedNullHandlingTest.java b/core/src/test/java/org/apache/druid/testing/InitializedNullHandlingTest.java
index a4a737b..1ba482d 100644
--- a/core/src/test/java/org/apache/druid/testing/InitializedNullHandlingTest.java
+++ b/core/src/test/java/org/apache/druid/testing/InitializedNullHandlingTest.java
@@ -20,10 +20,12 @@
 package org.apache.druid.testing;
 import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.math.expr.ExpressionProcessing;
+/**
+ * Base class for tests that need null handling and expression processing initialized: the static initializer runs
+ * NullHandling.initializeForTests() and ExpressionProcessing.initializeForTests(null) when this class is loaded.
+ * NOTE(review): the null argument presumably selects the default expression processing config; confirm.
+ */
 public class InitializedNullHandlingTest
 {
 static {
 NullHandling.initializeForTests();
+ ExpressionProcessing.initializeForTests(null);
 }
 }
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/guice/BloomFilterExtensionModule.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/guice/BloomFilterExtensionModule.java
index 8bef477..41fa0ac 100644
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/guice/BloomFilterExtensionModule.java
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/guice/BloomFilterExtensionModule.java
@@ -23,7 +23,7 @@ import com.fasterxml.jackson.databind.Module;
import com.google.inject.Binder;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.query.aggregation.bloom.sql.BloomFilterSqlAggregator;
-import org.apache.druid.query.expressions.BloomFilterExprMacro;
+import org.apache.druid.query.expressions.BloomFilterExpressions;
import org.apache.druid.query.filter.sql.BloomFilterOperatorConversion;
import org.apache.druid.sql.guice.SqlBindings;
@@ -44,6 +44,8 @@ public class BloomFilterExtensionModule implements DruidModule
{
SqlBindings.addOperatorConversion(binder, BloomFilterOperatorConversion.class);
SqlBindings.addAggregator(binder, BloomFilterSqlAggregator.class);
- ExpressionModule.addExprMacro(binder, BloomFilterExprMacro.class);
+ ExpressionModule.addExprMacro(binder, BloomFilterExpressions.CreateExprMacro.class);
+ ExpressionModule.addExprMacro(binder, BloomFilterExpressions.AddExprMacro.class);
+ ExpressionModule.addExprMacro(binder, BloomFilterExpressions.TestExprMacro.class);
}
}
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterAggregatorFactory.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterAggregatorFactory.java
index 3365624..947fdfc 100644
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterAggregatorFactory.java
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterAggregatorFactory.java
@@ -52,7 +52,7 @@ public class BloomFilterAggregatorFactory extends AggregatorFactory
public static final ColumnType TYPE = ColumnType.ofComplex(BloomFilterSerializersModule.BLOOM_FILTER_TYPE_NAME);
private static final int DEFAULT_NUM_ENTRIES = 1500;
- private static final Comparator COMPARATOR = Comparator.nullsFirst((o1, o2) -> {
+ public static final Comparator COMPARATOR = Comparator.nullsFirst((o1, o2) -> {
if (o1 instanceof ByteBuffer && o2 instanceof ByteBuffer) {
ByteBuffer buf1 = (ByteBuffer) o1;
ByteBuffer buf2 = (ByteBuffer) o2;
@@ -60,6 +60,13 @@ public class BloomFilterAggregatorFactory extends AggregatorFactory
BloomKFilter.getNumSetBits(buf1, buf1.position()),
BloomKFilter.getNumSetBits(buf2, buf2.position())
);
+ } else if (o1 instanceof BloomKFilter && o2 instanceof BloomKFilter) {
+ BloomKFilter f1 = (BloomKFilter) o1;
+ BloomKFilter f2 = (BloomKFilter) o2;
+ return Integer.compare(
+ f1.getNumSetBits(),
+ f2.getNumSetBits()
+ );
} else {
throw new RE("Unable to compare unexpected types [%s]", o1.getClass().getName());
}
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterSerde.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterSerde.java
index 227fe70..69494bc 100644
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterSerde.java
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/BloomFilterSerde.java
@@ -28,6 +28,8 @@ import org.apache.druid.segment.serde.ComplexMetricExtractor;
import org.apache.druid.segment.serde.ComplexMetricSerde;
import org.apache.druid.segment.writeout.SegmentWriteOutMedium;
+import javax.annotation.Nullable;
+import java.io.IOException;
import java.nio.ByteBuffer;
/**
@@ -37,6 +39,8 @@ import java.nio.ByteBuffer;
*/
public class BloomFilterSerde extends ComplexMetricSerde
{
+ /** Single {@link BloomFilterObjectStrategy} instance (it holds no state) returned by {@link #getObjectStrategy()}. */
+ private static final BloomFilterObjectStrategy STRATEGY = new BloomFilterObjectStrategy();
+
@Override
public String getTypeName()
{
@@ -64,6 +68,45 @@ public class BloomFilterSerde extends ComplexMetricSerde
+ /**
+ * Returns the shared strategy that converts {@link BloomKFilter} values to and from their serialized bytes.
+ */
 @Override
 public ObjectStrategy<BloomKFilter> getObjectStrategy()
 {
- throw new UnsupportedOperationException("Bloom filter aggregators are query-time only");
+ return STRATEGY;
+ }
+
+  /**
+   * {@link ObjectStrategy} for {@link BloomKFilter}. Serialization is delegated to
+   * {@link BloomFilterSerializersModule#bloomKFilterToBytes}, and deserialization to {@link BloomKFilter#deserialize}
+   * starting at the buffer's current position ({@code numBytes} is not consulted). Any {@link IOException} is
+   * rethrown wrapped in a {@link RuntimeException}. Ordering is delegated to
+   * {@link BloomFilterAggregatorFactory#COMPARATOR}.
+   */
+  private static class BloomFilterObjectStrategy implements ObjectStrategy<BloomKFilter>
+  {
+    @Override
+    public int compare(BloomKFilter left, BloomKFilter right)
+    {
+      return BloomFilterAggregatorFactory.COMPARATOR.compare(left, right);
+    }
+
+    @Override
+    public Class<? extends BloomKFilter> getClazz()
+    {
+      return BloomKFilter.class;
+    }
+
+    @Nullable
+    @Override
+    public BloomKFilter fromByteBuffer(ByteBuffer buf, int numBytes)
+    {
+      final int start = buf.position();
+      try {
+        return BloomKFilter.deserialize(buf, start);
+      }
+      catch (IOException ioe) {
+        throw new RuntimeException(ioe);
+      }
+    }
+
+    /**
+     * NOTE(review): a null {@code filter} is handed straight to bloomKFilterToBytes; confirm how that helper treats it.
+     */
+    @Nullable
+    @Override
+    public byte[] toBytes(@Nullable BloomKFilter filter)
+    {
+      final byte[] serialized;
+      try {
+        serialized = BloomFilterSerializersModule.bloomKFilterToBytes(filter);
+      }
+      catch (IOException ioe) {
+        throw new RuntimeException(ioe);
+      }
+      return serialized;
+    }
+  }
}
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java
deleted file mode 100644
index 637daff..0000000
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.druid.query.expressions;
-
-import org.apache.druid.guice.BloomFilterSerializersModule;
-import org.apache.druid.java.util.common.IAE;
-import org.apache.druid.java.util.common.StringUtils;
-import org.apache.druid.math.expr.Expr;
-import org.apache.druid.math.expr.ExprEval;
-import org.apache.druid.math.expr.ExprMacroTable;
-import org.apache.druid.math.expr.ExpressionType;
-import org.apache.druid.query.filter.BloomKFilter;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-import java.io.IOException;
-import java.util.List;
-
-public class BloomFilterExprMacro implements ExprMacroTable.ExprMacro
-{
- public static final String FN_NAME = "bloom_filter_test";
-
- @Override
- public String name()
- {
- return FN_NAME;
- }
-
- @Override
- public Expr apply(List<Expr> args)
- {
- if (args.size() != 2) {
- throw new IAE("Function[%s] must have 2 arguments", name());
- }
-
- final Expr arg = args.get(0);
- final Expr filterExpr = args.get(1);
-
- if (!filterExpr.isLiteral() || filterExpr.getLiteralValue() == null) {
- throw new IAE("Function[%s] second argument must be a base64 serialized bloom filter", name());
- }
-
-
- final String serializedFilter = filterExpr.getLiteralValue().toString();
- final byte[] decoded = StringUtils.decodeBase64String(serializedFilter);
- BloomKFilter filter;
- try {
- filter = BloomFilterSerializersModule.bloomKFilterFromBytes(decoded);
- }
- catch (IOException ioe) {
- throw new RuntimeException("Failed to deserialize bloom filter", ioe);
- }
-
- class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
- {
- private BloomExpr(Expr arg)
- {
- super(FN_NAME, arg);
- }
-
- @Nonnull
- @Override
- public ExprEval eval(final ObjectBinding bindings)
- {
- ExprEval evaluated = arg.eval(bindings);
-
- boolean matches = false;
- switch (evaluated.type().getType()) {
- case STRING:
- String stringVal = (String) evaluated.value();
- if (stringVal == null) {
- matches = nullMatch();
- } else {
- matches = filter.testString(stringVal);
- }
- break;
- case DOUBLE:
- Double doubleVal = (Double) evaluated.value();
- if (doubleVal == null) {
- matches = nullMatch();
- } else {
- matches = filter.testDouble(doubleVal);
- }
- break;
- case LONG:
- Long longVal = (Long) evaluated.value();
- if (longVal == null) {
- matches = nullMatch();
- } else {
- matches = filter.testLong(longVal);
- }
- break;
- }
-
- return ExprEval.ofLongBoolean(matches);
- }
-
- private boolean nullMatch()
- {
- return filter.testBytes(null, 0, 0);
- }
-
-
- @Override
- public Expr visit(Shuttle shuttle)
- {
- Expr newArg = arg.visit(shuttle);
- return shuttle.visit(new BloomExpr(newArg));
- }
-
- @Nullable
- @Override
- public ExpressionType getOutputType(InputBindingInspector inspector)
- {
- return ExpressionType.LONG;
- }
- }
-
- return new BloomExpr(arg);
- }
-}
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExpressions.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExpressions.java
new file mode 100644
index 0000000..f48baa6
--- /dev/null
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExpressions.java
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.expressions;
+
+import org.apache.druid.guice.BloomFilterSerializersModule;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExprType;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
+import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory;
+import org.apache.druid.query.filter.BloomKFilter;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class BloomFilterExpressions
+{
+ /**
+ * Expression type of bloom filter values, converted from {@link BloomFilterAggregatorFactory#TYPE} with
+ * {@link ExpressionType#fromColumnTypeStrict}.
+ */
+ public static final ExpressionType BLOOM_FILTER_TYPE = ExpressionType.fromColumnTypeStrict(
+ BloomFilterAggregatorFactory.TYPE
+ );
+
+  /**
+   * Macro for {@code bloom_filter(expectedSize)}. It evaluates to a new, empty {@link BloomKFilter} of type
+   * {@link #BLOOM_FILTER_TYPE}, constructed with {@code expectedSize}. The argument must be a non-null literal. It is
+   * evaluated (as an int) once, when the expression is built.
+   */
+  public static class CreateExprMacro implements ExprMacroTable.ExprMacro
+  {
+    public static final String FN_NAME = "bloom_filter";
+
+    @Override
+    public String name()
+    {
+      return FN_NAME;
+    }
+
+    /**
+     * @throws IAE if there is not exactly one argument, or if it is not a non-null literal
+     */
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() != 1) {
+        throw new IAE("Function[%s] must have 1 argument", name());
+      }
+
+      final Expr expectedSizeArg = args.get(0);
+
+      if (!expectedSizeArg.isLiteral() || expectedSizeArg.getLiteralValue() == null) {
+        throw new IAE("Function[%s] argument must be a LONG constant", name());
+      }
+
+      class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
+      {
+        private final int expectedSize;
+
+        public BloomExpr(Expr arg)
+        {
+          super(FN_NAME, arg);
+          // the argument was checked to be a literal, so it can be resolved once without any row bindings
+          this.expectedSize = arg.eval(InputBindings.nilBindings()).asInt();
+        }
+
+        @Override
+        public ExprEval eval(ObjectBinding bindings)
+        {
+          // a new, empty filter is built on every evaluation
+          return ExprEval.ofComplex(
+              BLOOM_FILTER_TYPE,
+              new BloomKFilter(expectedSize)
+          );
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          // the only child is the literal size argument, so just this node is offered to the shuttle
+          return shuttle.visit(this);
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return BLOOM_FILTER_TYPE;
+        }
+      }
+
+      return new BloomExpr(expectedSizeArg);
+    }
+  }
+
+  /**
+   * Macro for {@code bloom_filter_add(value, filter)}. It adds {@code value} to the {@link BloomKFilter} produced by
+   * the second argument and evaluates to that same filter object, typed as {@link #BLOOM_FILTER_TYPE}.
+   * <ul>
+   *   <li>STRING, DOUBLE and LONG values are added with the matching typed add method.</li>
+   *   <li>A null value is recorded with {@code addBytes(null, 0, 0)}.</li>
+   *   <li>A bloom filter value is merged in.</li>
+   * </ul>
+   * Any other input type, or a second argument that is not a bloom filter, fails with an {@link IAE}.
+   */
+  public static class AddExprMacro implements ExprMacroTable.ExprMacro
+  {
+    public static final String FN_NAME = "bloom_filter_add";
+
+    @Override
+    public String name()
+    {
+      return FN_NAME;
+    }
+
+    /**
+     * Whether an evaluated expression holds a bloom filter.
+     *
+     * <p>The check is deliberately permissive. Complete complex type information is not yet retained everywhere, so
+     * any COMPLEX-typed value whose object is a {@link BloomKFilter} is accepted. This includes
+     * {@link #BLOOM_FILTER_TYPE}, which is itself a COMPLEX type. Requiring the object to be a
+     * {@link BloomKFilter} also guarantees it is non-null and safe to cast.
+     */
+    private static boolean isBloomFilter(ExprEval eval)
+    {
+      return eval.type().is(ExprType.COMPLEX) && eval.value() instanceof BloomKFilter;
+    }
+
+    /**
+     * @throws IAE if there are not exactly two arguments
+     */
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() != 2) {
+        throw new IAE("Function[%s] must have 2 arguments", name());
+      }
+
+      class BloomExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
+      {
+        private BloomExpr(List<Expr> args)
+        {
+          super(FN_NAME, args);
+        }
+
+        @Override
+        public ExprEval eval(final ObjectBinding bindings)
+        {
+          final ExprEval bloomy = args.get(1).eval(bindings);
+          if (!isBloomFilter(bloomy)) {
+            throw new IAE("Function[%s] must take a bloom filter as the second argument", FN_NAME);
+          }
+          // NOTE(review): the filter is updated in place, not copied, so whatever produced the second argument sees
+          // the additions; confirm that is acceptable for every possible source of that argument.
+          final BloomKFilter filter = (BloomKFilter) bloomy.value();
+          final ExprEval input = args.get(0).eval(bindings);
+
+          if (input.value() == null) {
+            filter.addBytes(null, 0, 0);
+          } else {
+            switch (input.type().getType()) {
+              case STRING:
+                filter.addString(input.asString());
+                break;
+              case DOUBLE:
+                filter.addDouble(input.asDouble());
+                break;
+              case LONG:
+                filter.addLong(input.asLong());
+                break;
+              case COMPLEX:
+                // merge only when the input itself holds a bloom filter; other complex values must not reach the cast
+                if (isBloomFilter(input)) {
+                  filter.merge((BloomKFilter) input.value());
+                  break;
+                }
+                // fall through: any other complex type cannot be added to a bloom filter
+              default:
+                throw new IAE("Function[%s] cannot add [%s] to a bloom filter", FN_NAME, input.type());
+            }
+          }
+
+          return ExprEval.ofComplex(BLOOM_FILTER_TYPE, filter);
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
+          return shuttle.visit(new BloomExpr(newArgs));
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return BLOOM_FILTER_TYPE;
+        }
+      }
+
+      return new BloomExpr(args);
+    }
+  }
+
+  /**
+   * {@code bloom_filter_test(value, filter)}: evaluates to a LONG boolean that is true when the bloom filter's test
+   * method reports {@code value} as present. Only STRING, DOUBLE and LONG values are tested; values of any other
+   * type never match. Null STRING/DOUBLE/LONG values are tested as empty byte input.
+   */
+  public static class TestExprMacro implements ExprMacroTable.ExprMacro
+  {
+    public static final String FN_NAME = "bloom_filter_test";
+
+    @Override
+    public String name()
+    {
+      return FN_NAME;
+    }
+
+    /**
+     * Builds the test expression. When the second argument is a string literal it is decoded once, at parse time,
+     * as a base64 serialized {@link BloomKFilter}; otherwise the second argument is evaluated for every row and must
+     * produce a bloom filter value.
+     *
+     * @throws IAE if the number of arguments is not exactly 2
+     */
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() != 2) {
+        throw new IAE("Function[%s] must have 2 arguments", name());
+      }
+
+      // Tests against a bloom filter that was deserialized once from a string literal.
+      class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
+      {
+        private final BloomKFilter filter;
+
+        private BloomExpr(BloomKFilter filter, Expr arg)
+        {
+          super(FN_NAME, arg);
+          this.filter = filter;
+        }
+
+        @Nonnull
+        @Override
+        public ExprEval eval(final ObjectBinding bindings)
+        {
+          ExprEval evaluated = arg.eval(bindings);
+          return ExprEval.ofLongBoolean(testBloomFilter(filter, evaluated));
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          Expr newArg = arg.visit(shuttle);
+          return shuttle.visit(new BloomExpr(filter, newArg));
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return ExpressionType.LONG;
+        }
+      }
+
+      // Tests against a bloom filter obtained by evaluating the second argument for every row.
+      class DynamicBloomExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
+      {
+        public DynamicBloomExpr(List<Expr> args)
+        {
+          super(FN_NAME, args);
+        }
+
+        @Override
+        public ExprEval eval(final ObjectBinding bindings)
+        {
+          ExprEval bloomy = args.get(1).eval(bindings);
+          // be permissive for now, we can count more on this later when we are better at retaining complete complex
+          // type information everywhere: accept the exact bloom filter type, or any complex value that actually holds
+          // a BloomKFilter. The previous check (`!A || !B && C`) bound as `!A || (!B && C)`, which rejected every
+          // value not typed exactly as BLOOM_FILTER_TYPE and let a null bloom value through to a NullPointerException.
+          final boolean bloomCompatibleType =
+              BLOOM_FILTER_TYPE.equals(bloomy.type()) || bloomy.type().is(ExprType.COMPLEX);
+          if (!bloomCompatibleType || !(bloomy.value() instanceof BloomKFilter)) {
+            throw new IAE("Function[%s] must take a bloom filter as the second argument", FN_NAME);
+          }
+          BloomKFilter filter = (BloomKFilter) bloomy.value();
+          ExprEval input = args.get(0).eval(bindings);
+          return ExprEval.ofLongBoolean(testBloomFilter(filter, input));
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
+          return shuttle.visit(new DynamicBloomExpr(newArgs));
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return ExpressionType.LONG;
+        }
+      }
+
+      final Expr arg = args.get(0);
+      final Expr filterExpr = args.get(1);
+
+      if (filterExpr.isLiteral() && filterExpr.getLiteralValue() instanceof String) {
+        // a string literal holds a base64 encoded serialized filter: decode it once here instead of once per row
+        final String serializedFilter = (String) filterExpr.getLiteralValue();
+        final byte[] decoded = StringUtils.decodeBase64String(serializedFilter);
+        BloomKFilter filter;
+        try {
+          filter = BloomFilterSerializersModule.bloomKFilterFromBytes(decoded);
+        }
+        catch (IOException ioe) {
+          throw new RuntimeException("Failed to deserialize bloom filter", ioe);
+        }
+        return new BloomExpr(filter, arg);
+      } else {
+        return new DynamicBloomExpr(args);
+      }
+    }
+
+    /**
+     * Tests one evaluated value against {@code filter}; shared by the literal and dynamic filter expressions.
+     *
+     * @return the filter's test result for STRING, DOUBLE and LONG values (nulls are tested as empty byte input),
+     *         and false for any other value type
+     */
+    private static boolean testBloomFilter(BloomKFilter filter, ExprEval input)
+    {
+      switch (input.type().getType()) {
+        case STRING:
+          final String stringVal = (String) input.value();
+          return stringVal == null ? filter.testBytes(null, 0, 0) : filter.testString(stringVal);
+        case DOUBLE:
+          final Double doubleVal = (Double) input.value();
+          return doubleVal == null ? filter.testBytes(null, 0, 0) : filter.testDouble(doubleVal);
+        case LONG:
+          final Long longVal = (Long) input.value();
+          return longVal == null ? filter.testBytes(null, 0, 0) : filter.testLong(longVal);
+        default:
+          // arrays, complex values, etc. are not supported by the test and never match
+          return false;
+      }
+    }
+  }
+}
diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/filter/sql/BloomFilterOperatorConversion.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/filter/sql/BloomFilterOperatorConversion.java
index edd6430..cb5e823 100644
--- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/filter/sql/BloomFilterOperatorConversion.java
+++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/filter/sql/BloomFilterOperatorConversion.java
@@ -28,7 +28,7 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.guice.BloomFilterSerializersModule;
import org.apache.druid.java.util.common.StringUtils;
-import org.apache.druid.query.expressions.BloomFilterExprMacro;
+import org.apache.druid.query.expressions.BloomFilterExpressions;
import org.apache.druid.query.filter.BloomDimFilter;
import org.apache.druid.query.filter.BloomKFilter;
import org.apache.druid.query.filter.BloomKFilterHolder;
@@ -49,14 +49,14 @@ import java.util.List;
public class BloomFilterOperatorConversion extends DirectOperatorConversion
{
private static final SqlFunction SQL_FUNCTION = OperatorConversions
- .operatorBuilder(StringUtils.toUpperCase(BloomFilterExprMacro.FN_NAME))
+ .operatorBuilder(StringUtils.toUpperCase(BloomFilterExpressions.TestExprMacro.FN_NAME))
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER)
.returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE)
.build();
public BloomFilterOperatorConversion()
{
- super(SQL_FUNCTION, BloomFilterExprMacro.FN_NAME);
+ super(SQL_FUNCTION, BloomFilterExpressions.TestExprMacro.FN_NAME);
}
@Override
diff --git a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/expressions/BloomFilterExpressionsTest.java b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/expressions/BloomFilterExpressionsTest.java
new file mode 100644
index 0000000..6a6be4e
--- /dev/null
+++ b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/expressions/BloomFilterExpressionsTest.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.expressions;
+
+import com.google.common.base.Supplier;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
+import org.apache.druid.math.expr.Parser;
+import org.apache.druid.query.filter.BloomKFilter;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+
+public class BloomFilterExpressionsTest extends InitializedNullHandlingTest
+{
+ private static final String SOME_STRING = "foo";
+ private static final long SOME_LONG = 1234L;
+ private static final double SOME_DOUBLE = 1.234;
+ private static final String[] SOME_STRING_ARRAY = new String[]{"hello", "world"};
+ private static final Long[] SOME_LONG_ARRAY = new Long[]{1L, 2L, 3L, 4L};
+ private static final Double[] SOME_DOUBLE_ARRAY = new Double[]{1.2, 3.4};
+
+ BloomFilterExpressions.CreateExprMacro createMacro = new BloomFilterExpressions.CreateExprMacro();
+ BloomFilterExpressions.AddExprMacro addMacro = new BloomFilterExpressions.AddExprMacro();
+ BloomFilterExpressions.TestExprMacro testMacro = new BloomFilterExpressions.TestExprMacro();
+ ExprMacroTable macroTable = new ExprMacroTable(ImmutableList.of(createMacro, addMacro, testMacro));
+
+ Expr.ObjectBinding inputBindings = InputBindings.withTypedSuppliers(
+ new ImmutableMap.Builder<String, Pair<ExpressionType, Supplier<Object>>>()
+ .put("bloomy", new Pair<>(BloomFilterExpressions.BLOOM_FILTER_TYPE, () -> new BloomKFilter(100)))
+ .put("string", new Pair<>(ExpressionType.STRING, () -> SOME_STRING))
+ .put("long", new Pair<>(ExpressionType.LONG, () -> SOME_LONG))
+ .put("double", new Pair<>(ExpressionType.DOUBLE, () -> SOME_DOUBLE))
+ .put("string_array", new Pair<>(ExpressionType.STRING_ARRAY, () -> SOME_STRING_ARRAY))
+ .put("long_array", new Pair<>(ExpressionType.LONG_ARRAY, () -> SOME_LONG_ARRAY))
+ .put("double_array", new Pair<>(ExpressionType.DOUBLE_ARRAY, () -> SOME_DOUBLE_ARRAY))
+ .build()
+ );
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testCreate()
+ {
+ Expr expr = Parser.parse("bloom_filter(100)", macroTable);
+ ExprEval eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertEquals(1024, ((BloomKFilter) eval.value()).getBitSize());
+ }
+
+ @Test
+ public void testAddString()
+ {
+ Expr expr = Parser.parse("bloom_filter_add('foo', bloomy)", macroTable);
+ ExprEval eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertTrue(((BloomKFilter) eval.value()).testString(SOME_STRING));
+
+ expr = Parser.parse("bloom_filter_add(string, bloomy)", macroTable);
+ eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertTrue(((BloomKFilter) eval.value()).testString(SOME_STRING));
+ }
+
+ @Test
+ public void testAddLong()
+ {
+ Expr expr = Parser.parse("bloom_filter_add(1234, bloomy)", macroTable);
+ ExprEval eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertTrue(((BloomKFilter) eval.value()).testLong(SOME_LONG));
+
+ expr = Parser.parse("bloom_filter_add(long, bloomy)", macroTable);
+ eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertTrue(((BloomKFilter) eval.value()).testLong(SOME_LONG));
+ }
+
+ @Test
+ public void testAddDouble()
+ {
+ Expr expr = Parser.parse("bloom_filter_add(1.234, bloomy)", macroTable);
+ ExprEval eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertTrue(((BloomKFilter) eval.value()).testDouble(SOME_DOUBLE));
+
+ expr = Parser.parse("bloom_filter_add(double, bloomy)", macroTable);
+ eval = expr.eval(inputBindings);
+
+ Assert.assertEquals(BloomFilterExpressions.BLOOM_FILTER_TYPE, eval.type());
+ Assert.assertTrue(eval.value() instanceof BloomKFilter);
+ Assert.assertTrue(((BloomKFilter) eval.value()).testDouble(SOME_DOUBLE));
+ }
+
+ @Test
+ public void testFilter()
+ {
+ Expr expr = Parser.parse("bloom_filter_test(1.234, bloom_filter_add(1.234, bloomy))", macroTable);
+ ExprEval eval = expr.eval(inputBindings);
+ Assert.assertEquals(ExpressionType.LONG, eval.type());
+ Assert.assertTrue(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test(1234, bloom_filter_add(1234, bloomy))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertTrue(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test('foo', bloom_filter_add('foo', bloomy))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertTrue(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test('bar', bloom_filter_add('foo', bloomy))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertFalse(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test(1234, bloom_filter_add('foo', bloomy))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertFalse(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test(1.23, bloom_filter_add('foo', bloomy))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertFalse(eval.asBoolean());
+
+
+ expr = Parser.parse("bloom_filter_test(1234, bloom_filter_add(1234, bloom_filter(100)))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertTrue(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test(4321, bloom_filter_add(1234, bloom_filter(100)))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertFalse(eval.asBoolean());
+
+ expr = Parser.parse("bloom_filter_test(4321, bloom_filter_add(bloom_filter_add(1234, bloom_filter(100)), bloom_filter_add(4321, bloom_filter(100))))", macroTable);
+ eval = expr.eval(inputBindings);
+ Assert.assertTrue(eval.asBoolean());
+ }
+
+
+ @Test
+ public void testCreateWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[bloom_filter] must have 1 argument");
+ Parser.parse("bloom_filter()", macroTable);
+ }
+
+ @Test
+ public void testAddWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[bloom_filter_add] must have 2 arguments");
+ Parser.parse("bloom_filter_add(1)", macroTable);
+ }
+
+ @Test
+ public void testAddWrongArgType()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[bloom_filter_add] must take a bloom filter as the second argument");
+ Parser.parse("bloom_filter_add(1, 2)", macroTable);
+ }
+
+ @Test
+ public void testAddWrongArgType2()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[bloom_filter_add] cannot add [ARRAY<LONG>] to a bloom filter");
+ Expr expr = Parser.parse("bloom_filter_add(ARRAY<LONG>[], bloomy)", macroTable);
+ expr.eval(inputBindings);
+ }
+
+ @Test
+ public void testTestWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[bloom_filter_test] must have 2 arguments");
+ Parser.parse("bloom_filter_test(1)", macroTable);
+ }
+
+ @Test
+ public void testTestWrongArgsType()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[bloom_filter_test] must take a bloom filter as the second argument");
+ Parser.parse("bloom_filter_test(1, 2)", macroTable);
+ }
+}
diff --git a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java
index cc1b7c8..eba95ce 100644
--- a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java
+++ b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java
@@ -31,7 +31,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.expression.LookupExprMacro;
-import org.apache.druid.query.expressions.BloomFilterExprMacro;
+import org.apache.druid.query.expressions.BloomFilterExpressions;
import org.apache.druid.query.filter.BloomDimFilter;
import org.apache.druid.query.filter.BloomKFilter;
import org.apache.druid.query.filter.BloomKFilterHolder;
@@ -74,7 +74,7 @@ public class BloomDimFilterSqlTest extends BaseCalciteQueryTest
exprMacros.add(CalciteTests.INJECTOR.getInstance(clazz));
}
exprMacros.add(CalciteTests.INJECTOR.getInstance(LookupExprMacro.class));
- exprMacros.add(new BloomFilterExprMacro());
+ exprMacros.add(new BloomFilterExpressions.TestExprMacro());
return new ExprMacroTable(exprMacros);
}
diff --git a/extensions-core/testing-tools/src/test/java/org/apache/druid/query/expressions/SleepExprTest.java b/extensions-core/testing-tools/src/test/java/org/apache/druid/query/expressions/SleepExprTest.java
index 8aa2be7..f63de93 100644
--- a/extensions-core/testing-tools/src/test/java/org/apache/druid/query/expressions/SleepExprTest.java
+++ b/extensions-core/testing-tools/src/test/java/org/apache/druid/query/expressions/SleepExprTest.java
@@ -21,28 +21,17 @@ package org.apache.druid.query.expressions;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
-import org.apache.druid.math.expr.Expr.ObjectBinding;
import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Test;
-import javax.annotation.Nullable;
import java.util.Collections;
public class SleepExprTest extends InitializedNullHandlingTest
{
- private final ObjectBinding bindings = new ObjectBinding()
- {
- @Nullable
- @Override
- public Object get(String name)
- {
- return null;
- }
- };
-
private final ExprMacroTable exprMacroTable = new ExprMacroTable(Collections.singletonList(new SleepExprMacro()));
@Test
@@ -66,7 +55,7 @@ public class SleepExprTest extends InitializedNullHandlingTest
final long detla = 50;
final long before = System.currentTimeMillis();
final Expr expr = Parser.parse(expression, exprMacroTable);
- expr.eval(bindings).value();
+ expr.eval(InputBindings.nilBindings()).value();
final long after = System.currentTimeMillis();
final long elapsed = after - before;
Assert.assertTrue(
@@ -79,14 +68,14 @@ public class SleepExprTest extends InitializedNullHandlingTest
private void assertExpr(final String expression)
{
final Expr expr = Parser.parse(expression, exprMacroTable);
- Assert.assertNull(expression, expr.eval(bindings).value());
+ Assert.assertNull(expression, expr.eval(InputBindings.nilBindings()).value());
final Expr exprNoFlatten = Parser.parse(expression, exprMacroTable, false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), exprMacroTable);
- Assert.assertNull(expr.stringify(), roundTrip.eval(bindings).value());
+ Assert.assertNull(expr.stringify(), roundTrip.eval(InputBindings.nilBindings()).value());
final Expr roundTripFlatten = Parser.parse(expr.stringify(), exprMacroTable);
- Assert.assertNull(expr.stringify(), roundTripFlatten.eval(bindings).value());
+ Assert.assertNull(expr.stringify(), roundTripFlatten.eval(InputBindings.nilBindings()).value());
Assert.assertEquals(expr.stringify(), roundTrip.stringify());
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
diff --git a/indexing-hadoop/src/main/java/org/apache/druid/indexer/InputRowSerde.java b/indexing-hadoop/src/main/java/org/apache/druid/indexer/InputRowSerde.java
index 846826a..cb2c2bb 100644
--- a/indexing-hadoop/src/main/java/org/apache/druid/indexer/InputRowSerde.java
+++ b/indexing-hadoop/src/main/java/org/apache/druid/indexer/InputRowSerde.java
@@ -39,6 +39,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.incremental.IncrementalIndex;
import org.apache.druid.segment.serde.ComplexMetricSerde;
@@ -330,7 +331,7 @@ public class InputRowSerde
writeString(k, out);
try (Aggregator agg = aggFactory.factorize(
- IncrementalIndex.makeColumnSelectorFactory(VirtualColumns.EMPTY, aggFactory, supplier, true)
+ IncrementalIndex.makeColumnSelectorFactory(RowSignature::empty, VirtualColumns.EMPTY, aggFactory, supplier, true)
)) {
try {
agg.aggregate();
diff --git a/pom.xml b/pom.xml
index b02322a..134f8b6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -771,11 +771,6 @@
<artifactId>fastutil-core</artifactId>
<version>${fastutil.version}</version>
</dependency>
- <dependency>
- <groupId>it.unimi.dsi</groupId>
- <artifactId>fastutil-extra</artifactId>
- <version>${fastutil.version}</version>
- </dependency>
<dependency>
<groupId>com.opencsv</groupId>
<artifactId>opencsv</artifactId>
diff --git a/processing/src/main/java/org/apache/druid/guice/GuiceInjectors.java b/processing/src/main/java/org/apache/druid/guice/GuiceInjectors.java
index 5a7811d..7ee10df 100644
--- a/processing/src/main/java/org/apache/druid/guice/GuiceInjectors.java
+++ b/processing/src/main/java/org/apache/druid/guice/GuiceInjectors.java
@@ -24,6 +24,7 @@ import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import org.apache.druid.jackson.JacksonModule;
+import org.apache.druid.math.expr.ExpressionProcessingModule;
import java.util.ArrayList;
import java.util.Arrays;
@@ -43,6 +44,7 @@ public class GuiceInjectors
new RuntimeInfoModule(),
new ConfigModule(),
new NullHandlingModule(),
+ new ExpressionProcessingModule(),
binder -> {
binder.bind(DruidSecondaryModule.class);
JsonConfigProvider.bind(binder, "druid.extensions", ExtensionsConfig.class);
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
index 3e56b5d..7fadc7e 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
@@ -19,10 +19,17 @@
package org.apache.druid.query.aggregation;
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.segment.column.ObjectByteStrategy;
+import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
+import java.util.Arrays;
+import java.util.Objects;
public class ExpressionLambdaAggregator implements Aggregator
{
@@ -48,7 +55,7 @@ public class ExpressionLambdaAggregator implements Aggregator
public void aggregate()
{
final ExprEval<?> eval = lambda.eval(bindings);
- ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
+ estimateAndCheckMaxBytes(eval, maxSizeBytes);
bindings.accumulate(eval);
hasValue = true;
}
@@ -89,4 +96,88 @@ public class ExpressionLambdaAggregator implements Aggregator
{
// nothing to close
}
+
+  /**
+   * Estimates the number of bytes {@code eval} would occupy when written by the {@link Types} binary methods that
+   * {@link ExpressionLambdaBufferAggregator} uses. The estimate, plus one byte for the null marker, is passed to
+   * {@link Types#checkMaxBytes} together with {@code maxSizeBytes}, so the heap based aggregator is held to a size
+   * limit consistent with the buffer based one.
+   *
+   * @throws ISE if the value type (or array element type, or complex type without a registered strategy) is not
+   *             supported
+   */
+  @VisibleForTesting
+  public static void estimateAndCheckMaxBytes(ExprEval eval, int maxSizeBytes)
+  {
+    final int estimated;
+    switch (eval.type().getType()) {
+      case STRING:
+        final String stringValue = eval.asString();
+        final int utf8Length = stringValue == null ? 0 : StringUtils.estimatedBinaryLengthAsUTF8(stringValue);
+        estimated = Integer.BYTES + utf8Length;
+        break;
+      case LONG:
+      case DOUBLE:
+        estimated = Long.BYTES;
+        break;
+      case ARRAY:
+        estimated = estimateArrayBytes(eval);
+        break;
+      case COMPLEX:
+        estimated = estimateComplexBytes(eval);
+        break;
+      default:
+        throw new ISE("Unsupported type: %s", eval.type());
+    }
+    // +1 for the null byte
+    Types.checkMaxBytes(eval.type(), 1 + estimated, maxSizeBytes);
+  }
+
+  // Estimated size of an array value, excluding the leading null byte (added by the caller).
+  private static int estimateArrayBytes(ExprEval eval)
+  {
+    switch (eval.type().getElementType().getType()) {
+      case STRING: {
+        final String[] strings = eval.asStringArray();
+        if (strings == null) {
+          return Integer.BYTES;
+        }
+        final int payload = Arrays.stream(strings)
+                                  .filter(Objects::nonNull)
+                                  .mapToInt(StringUtils::estimatedBinaryLengthAsUTF8)
+                                  .sum();
+        // values are variably sized: array length int, plus an integer length per element, plus the payload
+        return Integer.BYTES + (Integer.BYTES * strings.length) + payload;
+      }
+      case LONG: {
+        final Long[] longs = eval.asLongArray();
+        if (longs == null) {
+          return Integer.BYTES;
+        }
+        final int presentCount = (int) Arrays.stream(longs).filter(Objects::nonNull).count();
+        // length int + null byte per element + fixed size per non-null element
+        return Integer.BYTES + longs.length + (Long.BYTES * presentCount);
+      }
+      case DOUBLE: {
+        final Double[] doubles = eval.asDoubleArray();
+        if (doubles == null) {
+          return Integer.BYTES;
+        }
+        final int presentCount = (int) Arrays.stream(doubles).filter(Objects::nonNull).count();
+        // length int + null byte per element + fixed size per non-null element
+        return Integer.BYTES + doubles.length + (Long.BYTES * presentCount);
+      }
+      default:
+        throw new ISE("Unsupported array type: %s", eval.type());
+    }
+  }
+
+  // Estimated size of a complex value: | null (byte) | length (int) | complex type bytes |, null byte added by caller.
+  private static int estimateComplexBytes(ExprEval eval)
+  {
+    final ObjectByteStrategy strategy = Types.getStrategy(eval.type().getComplexTypeName());
+    if (strategy == null) {
+      throw new ISE("Unsupported type: %s", eval.type());
+    }
+    final Object complexValue = eval.value();
+    if (complexValue == null) {
+      return Integer.BYTES;
+    }
+    return Integer.BYTES + strategy.toBytes(complexValue).length;
+  }
}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
index 686288a..4c80a53 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
@@ -39,12 +39,12 @@ import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.math.expr.SettableObjectBinding;
import org.apache.druid.query.cache.CacheKeyBuilder;
-import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.virtual.ExpressionPlan;
import org.apache.druid.segment.virtual.ExpressionPlanner;
import org.apache.druid.segment.virtual.ExpressionSelectors;
@@ -92,12 +92,9 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
private final Supplier<Expr> finalizeExpression;
private final HumanReadableBytes maxSizeBytes;
- private final Supplier<SettableObjectBinding> compareBindings =
- Suppliers.memoize(() -> new SettableObjectBinding(2));
- private final Supplier<SettableObjectBinding> combineBindings =
- Suppliers.memoize(() -> new SettableObjectBinding(2));
- private final Supplier<SettableObjectBinding> finalizeBindings =
- Suppliers.memoize(() -> new SettableObjectBinding(1));
+ private final Supplier<SettableObjectBinding> compareBindings;
+ private final Supplier<SettableObjectBinding> combineBindings;
+ private final Supplier<SettableObjectBinding> finalizeBindings;
private final Supplier<Expr.InputBindingInspector> finalizeInspector;
@JsonCreator
@@ -145,12 +142,12 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
this.initialValue = Suppliers.memoize(() -> {
Expr parsed = Parser.parse(initialValue, macroTable);
Preconditions.checkArgument(parsed.isLiteral(), "initial value must be constant");
- return parsed.eval(ExprUtils.nilBindings());
+ return parsed.eval(InputBindings.nilBindings());
});
this.initialCombineValue = Suppliers.memoize(() -> {
Expr parsed = Parser.parse(this.initialCombineValueExpressionString, macroTable);
Preconditions.checkArgument(parsed.isLiteral(), "initial combining value must be constant");
- return parsed.eval(ExprUtils.nilBindings());
+ return parsed.eval(InputBindings.nilBindings());
});
this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable);
this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable);
@@ -160,6 +157,29 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
ImmutableMap.of(FINALIZE_IDENTIFIER, this.initialCombineValue.get().type())
)
);
+ this.compareBindings = Suppliers.memoize(
+ () -> new SettableObjectBinding(2).withInspector(
+ InputBindings.inspectorFromTypeMap(
+ ImmutableMap.of(
+ COMPARE_O1, this.initialCombineValue.get().type(),
+ COMPARE_O2, this.initialCombineValue.get().type()
+ )
+ )
+ )
+ );
+ this.combineBindings = Suppliers.memoize(
+ () -> new SettableObjectBinding(2).withInspector(
+ InputBindings.inspectorFromTypeMap(
+ ImmutableMap.of(
+ accumulatorId, this.initialCombineValue.get().type(),
+ name, this.initialCombineValue.get().type()
+ )
+ )
+ )
+ );
+ this.finalizeBindings = Suppliers.memoize(
+ () -> new SettableObjectBinding(1).withInspector(finalizeInspector.get())
+ );
this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable);
this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES;
Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES);
@@ -285,11 +305,13 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
return (o1, o2) ->
compareExpr.eval(compareBindings.get().withBinding(COMPARE_O1, o1).withBinding(COMPARE_O2, o2)).asInt();
}
- switch (initialValue.get().type().getType()) {
+ switch (initialCombineValue.get().type().getType()) {
case LONG:
return LongSumAggregator.COMPARATOR;
case DOUBLE:
return DoubleSumAggregator.COMPARATOR;
+ case COMPLEX:
+ return Types.getStrategy(initialCombineValue.get().type().getComplexTypeName());
default:
return Comparators.naturalNullsFirst();
}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java
index 5e4864e..36232ff 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorInputBindings.java
@@ -21,6 +21,7 @@ package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
@@ -57,6 +58,16 @@ public class ExpressionLambdaAggregatorInputBindings implements Expr.ObjectBindi
return inputBindings.get(name);
}
+  /**
+   * For the accumulator identifier, reports the type of the most recently accumulated value; every other identifier
+   * is resolved by the wrapped input bindings.
+   */
+  @Nullable
+  @Override
+  public ExpressionType getType(String name)
+  {
+    return accumlatorIdentifier.equals(name) ? accumulator.type() : inputBindings.getType(name);
+  }
+
public void accumulate(ExprEval<?> eval)
{
accumulator = eval;
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java
index 82b954e..cff4aff 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java
@@ -21,7 +21,7 @@ package org.apache.druid.query.aggregation;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
-import org.apache.druid.math.expr.ExprType;
+import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
import java.nio.ByteBuffer;
@@ -35,6 +35,7 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
private final ExpressionLambdaAggregatorInputBindings bindings;
private final int maxSizeBytes;
private final boolean isNullUnlessAggregated;
+ private final ExpressionType outputType;
public ExpressionLambdaBufferAggregator(
Expr lambda,
@@ -46,6 +47,7 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
{
this.lambda = lambda;
this.initialValue = initialValue;
+ this.outputType = initialValue.type();
this.bindings = bindings;
this.isNullUnlessAggregated = isNullUnlessAggregated;
this.maxSizeBytes = maxSizeBytes;
@@ -64,7 +66,7 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
- ExprEval<?> acc = ExprEval.deserialize(buf, position + 1, getType(buf, position));
+ ExprEval<?> acc = ExprEval.deserialize(buf, position, outputType);
bindings.setAccumulator(acc);
ExprEval<?> newAcc = lambda.eval(bindings);
ExprEval.serialize(buf, position, newAcc, maxSizeBytes);
@@ -79,25 +81,25 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
if (isNullUnlessAggregated && (buf.get(position) & NOT_AGGREGATED_BIT) != 0) {
return null;
}
- return ExprEval.deserialize(buf, position + 1, getType(buf, position)).value();
+ return ExprEval.deserialize(buf, position, outputType).value();
}
@Override
public float getFloat(ByteBuffer buf, int position)
{
- return (float) ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble();
+ return (float) ExprEval.deserialize(buf, position, outputType).asDouble();
}
@Override
public double getDouble(ByteBuffer buf, int position)
{
- return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble();
+ return ExprEval.deserialize(buf, position, outputType).asDouble();
}
@Override
public long getLong(ByteBuffer buf, int position)
{
- return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asLong();
+ return ExprEval.deserialize(buf, position, outputType).asLong();
}
@Override
@@ -105,9 +107,4 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator
{
// nothing to close
}
-
- private static ExprType getType(ByteBuffer buf, int position)
- {
- return ExprType.fromByte((byte) (buf.get(position) & IS_AGGREGATED_MASK));
- }
}
diff --git a/processing/src/main/java/org/apache/druid/query/expression/ExprUtils.java b/processing/src/main/java/org/apache/druid/query/expression/ExprUtils.java
index 184ff53..4aa2317 100644
--- a/processing/src/main/java/org/apache/druid/query/expression/ExprUtils.java
+++ b/processing/src/main/java/org/apache/druid/query/expression/ExprUtils.java
@@ -35,13 +35,6 @@ import javax.annotation.Nullable;
public class ExprUtils
{
- private static final Expr.ObjectBinding NIL_BINDINGS = name -> null;
-
- public static Expr.ObjectBinding nilBindings()
- {
- return NIL_BINDINGS;
- }
-
static DateTimeZone toTimeZone(final Expr timeZoneArg)
{
if (!timeZoneArg.isLiteral()) {
@@ -110,19 +103,6 @@ public class ExprUtils
static boolean isStringLiteral(final Expr expr)
{
return (expr.isLiteral() && expr.getLiteralValue() instanceof String)
- || (NullHandling.replaceWithDefault() && isNullLiteral(expr));
- }
-
- /**
- * True if Expr is a null literal.
- *
- * In non-SQL-compliant null handling mode, this method will return true for either a null literal or an empty string
- * literal (they are treated equivalently and we cannot tell the difference).
- *
- * In SQL-compliant null handling mode, this method will only return true for an actual null literal.
- */
- static boolean isNullLiteral(final Expr expr)
- {
- return expr.isLiteral() && expr.getLiteralValue() == null;
+ || (NullHandling.replaceWithDefault() && expr.isNullLiteral());
}
}
diff --git a/processing/src/main/java/org/apache/druid/query/expression/HyperUniqueExpressions.java b/processing/src/main/java/org/apache/druid/query/expression/HyperUniqueExpressions.java
new file mode 100644
index 0000000..32335ef
--- /dev/null
+++ b/processing/src/main/java/org/apache/druid/query/expression/HyperUniqueExpressions.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.expression;
+
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.hll.HyperLogLogCollector;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExprType;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.query.aggregation.cardinality.CardinalityAggregator;
+import org.apache.druid.query.aggregation.cardinality.types.StringCardinalityAggregatorColumnSelectorStrategy;
+import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
+
+import javax.annotation.Nullable;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Expression macros for {@link HyperLogLogCollector} values: {@code hyper_unique()} creates an empty collector,
+ * {@code hyper_unique_add(value, collector)} adds a value (or folds in another collector) and returns the collector,
+ * and {@code hyper_unique_estimate} / {@code hyper_unique_round_estimate} return its estimated cardinality.
+ */
+public class HyperUniqueExpressions
+{
+  public static final ExpressionType TYPE = ExpressionType.fromColumnType(HyperUniquesAggregatorFactory.TYPE);
+
+  /**
+   * Whether an evaluated value holds a {@link HyperLogLogCollector}. This is deliberately permissive: any complex
+   * value holding a collector is accepted, rather than requiring the exact {@link #TYPE}, because complete complex
+   * type information is not yet retained everywhere. A null value is rejected.
+   */
+  private static boolean isHllCollector(ExprEval<?> eval)
+  {
+    return eval.type().is(ExprType.COMPLEX) && eval.value() instanceof HyperLogLogCollector;
+  }
+
+  public static class HllCreateExprMacro implements ExprMacroTable.ExprMacro
+  {
+    private static final String NAME = "hyper_unique";
+
+    @Override
+    public String name()
+    {
+      return NAME;
+    }
+
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() > 0) {
+        throw new IAE("Function[%s] must have no arguments", name());
+      }
+      class HllExpression implements ExprMacroTable.ExprMacroFunctionExpr
+      {
+        @Override
+        public ExprEval eval(ObjectBinding bindings)
+        {
+          // a new collector per evaluation: hyper_unique_add mutates its collector in place, so a single shared
+          // instance would accumulate values across evaluations (e.g. across rows)
+          return ExprEval.ofComplex(TYPE, HyperLogLogCollector.makeLatestCollector());
+        }
+
+        @Override
+        public String stringify()
+        {
+          return StringUtils.format("%s()", NAME);
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          return shuttle.visit(this);
+        }
+
+        @Override
+        public BindingAnalysis analyzeInputs()
+        {
+          return BindingAnalysis.EMTPY;
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return TYPE;
+        }
+
+        @Override
+        public List<Expr> getArgs()
+        {
+          return Collections.emptyList();
+        }
+
+        @Override
+        public int hashCode()
+        {
+          return Objects.hashCode(NAME);
+        }
+
+        @Override
+        public boolean equals(Object obj)
+        {
+          if (this == obj) {
+            return true;
+          }
+          if (obj == null || getClass() != obj.getClass()) {
+            return false;
+          }
+          return true;
+        }
+
+        @Override
+        public String toString()
+        {
+          return StringUtils.format("(%s)", NAME);
+        }
+      }
+      return new HllExpression();
+    }
+  }
+
+  public static class HllAddExprMacro implements ExprMacroTable.ExprMacro
+  {
+    private static final String NAME = "hyper_unique_add";
+
+    @Override
+    public String name()
+    {
+      return NAME;
+    }
+
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() != 2) {
+        throw new IAE("Function[%s] must have 2 arguments", name());
+      }
+
+      class HllExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr
+      {
+        public HllExpr(List<Expr> args)
+        {
+          super(NAME, args);
+        }
+
+        /**
+         * Adds the first argument to the collector produced by the second argument, mutating and returning that
+         * collector. Strings and numbers are hashed with {@link CardinalityAggregator#HASH_FUNCTION}, counting nulls
+         * only when {@link NullHandling#replaceWithDefault()} is true; a complex input must be a collector to fold in.
+         */
+        @Override
+        public ExprEval eval(ObjectBinding bindings)
+        {
+          ExprEval hllCollector = args.get(1).eval(bindings);
+          if (!isHllCollector(hllCollector)) {
+            throw new IAE("Function[%s] must take a hyper-log-log collector as the second argument", NAME);
+          }
+          HyperLogLogCollector collector = (HyperLogLogCollector) hllCollector.value();
+          ExprEval input = args.get(0).eval(bindings);
+          switch (input.type().getType()) {
+            case STRING:
+              if (input.value() == null) {
+                if (NullHandling.replaceWithDefault()) {
+                  collector.add(
+                      CardinalityAggregator.HASH_FUNCTION.hashUnencodedChars(
+                          StringCardinalityAggregatorColumnSelectorStrategy.CARDINALITY_AGG_NULL_STRING
+                      ).asBytes()
+                  );
+                }
+              } else {
+                collector.add(CardinalityAggregator.HASH_FUNCTION.hashUnencodedChars(input.asString()).asBytes());
+              }
+              break;
+            case DOUBLE:
+              if (NullHandling.replaceWithDefault() || !input.isNumericNull()) {
+                collector.add(CardinalityAggregator.HASH_FUNCTION.hashLong(Double.doubleToLongBits(input.asDouble())).asBytes());
+              }
+              break;
+            case LONG:
+              if (NullHandling.replaceWithDefault() || !input.isNumericNull()) {
+                collector.add(CardinalityAggregator.HASH_FUNCTION.hashLong(input.asLong()).asBytes());
+              }
+              break;
+            case COMPLEX:
+              // check the input itself, not the already validated target; a null complex value adds nothing
+              if (input.value() == null) {
+                break;
+              }
+              if (input.value() instanceof HyperLogLogCollector) {
+                collector.fold((HyperLogLogCollector) input.value());
+                break;
+              }
+              // fall through: any other complex value cannot be added to a collector
+            default:
+              throw new IAE("Function[%s] cannot add [%s] to hyper-log-log collector", NAME, input.type());
+          }
+
+          return ExprEval.ofComplex(TYPE, collector);
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
+          return shuttle.visit(new HllExpr(newArgs));
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return TYPE;
+        }
+      }
+      return new HllExpr(args);
+    }
+  }
+
+  public static class HllEstimateExprMacro implements ExprMacroTable.ExprMacro
+  {
+    public static final String NAME = "hyper_unique_estimate";
+
+    @Override
+    public String name()
+    {
+      return NAME;
+    }
+
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() != 1) {
+        throw new IAE("Function[%s] must have 1 argument", name());
+      }
+      class HllExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
+      {
+        public HllExpr(Expr arg)
+        {
+          super(NAME, arg);
+        }
+
+        @Override
+        public ExprEval eval(ObjectBinding bindings)
+        {
+          ExprEval hllCollector = args.get(0).eval(bindings);
+          if (!isHllCollector(hllCollector)) {
+            throw new IAE("Function[%s] must take a hyper-log-log collector as input", NAME);
+          }
+          HyperLogLogCollector collector = (HyperLogLogCollector) hllCollector.value();
+          return ExprEval.ofDouble(collector.estimateCardinality());
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          Expr newArg = arg.visit(shuttle);
+          return shuttle.visit(new HllExpr(newArg));
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return ExpressionType.DOUBLE;
+        }
+      }
+      return new HllExpr(args.get(0));
+    }
+  }
+
+  public static class HllRoundEstimateExprMacro implements ExprMacroTable.ExprMacro
+  {
+    public static final String NAME = "hyper_unique_round_estimate";
+
+    @Override
+    public String name()
+    {
+      return NAME;
+    }
+
+    @Override
+    public Expr apply(List<Expr> args)
+    {
+      if (args.size() != 1) {
+        throw new IAE("Function[%s] must have 1 argument", name());
+      }
+
+      class HllExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
+      {
+        public HllExpr(Expr arg)
+        {
+          super(NAME, arg);
+        }
+
+        @Override
+        public ExprEval eval(ObjectBinding bindings)
+        {
+          ExprEval hllCollector = args.get(0).eval(bindings);
+          if (!isHllCollector(hllCollector)) {
+            throw new IAE("Function[%s] must take a hyper-log-log collector as input", NAME);
+          }
+          HyperLogLogCollector collector = (HyperLogLogCollector) hllCollector.value();
+          return ExprEval.ofLong(collector.estimateCardinalityRound());
+        }
+
+        @Override
+        public Expr visit(Shuttle shuttle)
+        {
+          Expr newArg = arg.visit(shuttle);
+          return shuttle.visit(new HllExpr(newArg));
+        }
+
+        @Nullable
+        @Override
+        public ExpressionType getOutputType(InputBindingInspector inspector)
+        {
+          return ExpressionType.LONG;
+        }
+      }
+      return new HllExpr(args.get(0));
+    }
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java
index 9e137f2..39ad178 100644
--- a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java
+++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java
@@ -27,6 +27,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -66,7 +67,7 @@ public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro
TimestampCeilExpr(final List<Expr> args)
{
super(FN_NAME, args);
- this.granularity = getGranularity(args, ExprUtils.nilBindings());
+ this.granularity = getGranularity(args, InputBindings.nilBindings());
}
@Nonnull
diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java
index 1b9fa58..20507ad 100644
--- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java
+++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java
@@ -25,6 +25,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.vector.CastToTypeVectorProcessor;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.LongOutLongInFunctionVectorProcessor;
@@ -76,7 +77,7 @@ public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro
TimestampFloorExpr(final List<Expr> args)
{
super(FN_NAME, args);
- this.granularity = computeGranularity(args, ExprUtils.nilBindings());
+ this.granularity = computeGranularity(args, InputBindings.nilBindings());
}
/**
diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java
index 765ad89..1c60492 100644
--- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java
+++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java
@@ -25,6 +25,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.joda.time.Chronology;
import org.joda.time.Period;
import org.joda.time.chrono.ISOChronology;
@@ -90,9 +91,9 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
TimestampShiftExpr(final List<Expr> args)
{
super(FN_NAME, args);
- period = getPeriod(args, ExprUtils.nilBindings());
- chronology = getTimeZone(args, ExprUtils.nilBindings());
- step = getStep(args, ExprUtils.nilBindings());
+ period = getPeriod(args, InputBindings.nilBindings());
+ chronology = getTimeZone(args, InputBindings.nilBindings());
+ step = getStep(args, InputBindings.nilBindings());
}
@Nonnull
diff --git a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java
index f07807e..fca9e21 100644
--- a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java
+++ b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java
@@ -27,6 +27,7 @@ import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -97,7 +98,7 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
} else {
final Expr charsArg = args.get(1);
if (charsArg.isLiteral()) {
- final String charsString = charsArg.eval(ExprUtils.nilBindings()).asString();
+ final String charsString = charsArg.eval(InputBindings.nilBindings()).asString();
final char[] chars = charsString == null ? EMPTY_CHARS : charsString.toCharArray();
return new TrimStaticCharsExpr(mode, args.get(0), chars, charsArg);
} else {
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
index dddc1eb..e585398 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
@@ -378,7 +378,7 @@ public class RowBasedGrouperHelper
return RowBasedColumnSelectorFactory.create(
adapter,
supplier::get,
- query.getResultRowSignature(),
+ () -> query.getResultRowSignature(),
false
);
}
diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java
index 34254b6..89c317a 100644
--- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java
+++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java
@@ -230,7 +230,7 @@ public class TimeseriesQueryQueryToolChest extends QueryToolChest<Result<Timeser
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(null, null),
- RowSignature.empty(),
+ () -> RowSignature.builder().addAggregators(aggregatorSpecs).build(),
false
)
);
diff --git a/processing/src/main/java/org/apache/druid/segment/RowBasedColumnSelectorFactory.java b/processing/src/main/java/org/apache/druid/segment/RowBasedColumnSelectorFactory.java
index ea1a48e..6c5b3ef 100644
--- a/processing/src/main/java/org/apache/druid/segment/RowBasedColumnSelectorFactory.java
+++ b/processing/src/main/java/org/apache/druid/segment/RowBasedColumnSelectorFactory.java
@@ -49,19 +49,19 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
{
private final Supplier<T> supplier;
private final RowAdapter<T> adapter;
- private final RowSignature rowSignature;
+ private final Supplier<RowSignature> rowSignatureSupplier;
private final boolean throwParseExceptions;
private RowBasedColumnSelectorFactory(
final Supplier<T> supplier,
final RowAdapter<T> adapter,
- final RowSignature rowSignature,
+ final Supplier<RowSignature> rowSignatureSupplier,
final boolean throwParseExceptions
)
{
this.supplier = supplier;
this.adapter = adapter;
- this.rowSignature = Preconditions.checkNotNull(rowSignature, "rowSignature must be nonnull");
+ this.rowSignatureSupplier = Preconditions.checkNotNull(rowSignatureSupplier, "rowSignature must be nonnull");
this.throwParseExceptions = throwParseExceptions;
}
@@ -70,7 +70,7 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
*
* @param adapter adapter for these row objects
* @param supplier supplier of row objects
- * @param signature will be used for reporting available columns and their capabilities. Note that the this
+ * @param signatureSupplier will be used for reporting available columns and their capabilities. Note that the this
* factory will still allow creation of selectors on any named field in the rows, even if
* it doesn't appear in "rowSignature". (It only needs to be accessible via
* {@link RowAdapter#columnFunction}.) As a result, you can achieve an untyped mode by
@@ -81,11 +81,11 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
public static <RowType> RowBasedColumnSelectorFactory<RowType> create(
final RowAdapter<RowType> adapter,
final Supplier<RowType> supplier,
- final RowSignature signature,
+ final Supplier<RowSignature> signatureSupplier,
final boolean throwParseExceptions
)
{
- return new RowBasedColumnSelectorFactory<>(supplier, adapter, signature, throwParseExceptions);
+ return new RowBasedColumnSelectorFactory<>(supplier, adapter, signatureSupplier, throwParseExceptions);
}
@Nullable
@@ -452,6 +452,6 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
@Override
public ColumnCapabilities getColumnCapabilities(String columnName)
{
- return getColumnCapabilities(rowSignature, columnName);
+ return getColumnCapabilities(rowSignatureSupplier.get(), columnName);
}
}
diff --git a/processing/src/main/java/org/apache/druid/segment/RowBasedCursor.java b/processing/src/main/java/org/apache/druid/segment/RowBasedCursor.java
index d4447a8..863b725 100644
--- a/processing/src/main/java/org/apache/druid/segment/RowBasedCursor.java
+++ b/processing/src/main/java/org/apache/druid/segment/RowBasedCursor.java
@@ -66,7 +66,7 @@ public class RowBasedCursor<RowType> implements Cursor
RowBasedColumnSelectorFactory.create(
rowAdapter,
rowWalker::currentRow,
- rowSignature,
+ () -> rowSignature,
false
)
);
diff --git a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java b/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
index ed4410b..ce0b44e 100644
--- a/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
+++ b/processing/src/main/java/org/apache/druid/segment/data/ObjectStrategy.java
@@ -20,37 +20,15 @@
package org.apache.druid.segment.data;
import org.apache.druid.guice.annotations.ExtensionPoint;
+import org.apache.druid.segment.column.ObjectByteStrategy;
import org.apache.druid.segment.writeout.WriteOutBytes;
-import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
-import java.util.Comparator;
@ExtensionPoint
-public interface ObjectStrategy<T> extends Comparator<T>
+public interface ObjectStrategy<T> extends ObjectByteStrategy<T>
{
- Class<? extends T> getClazz();
-
- /**
- * Convert values from their underlying byte representation.
- *
- * Implementations of this method <i>may</i> change the given buffer's mark, or limit, and position.
- *
- * Implementations of this method <i>may not</i> store the given buffer in a field of the "deserialized" object,
- * need to use {@link ByteBuffer#slice()}, {@link ByteBuffer#asReadOnlyBuffer()} or {@link ByteBuffer#duplicate()} in
- * this case.
- *
- * @param buffer buffer to read value from
- * @param numBytes number of bytes used to store the value, starting at buffer.position()
- * @return an object created from the given byte buffer representation
- */
- @Nullable
- T fromByteBuffer(ByteBuffer buffer, int numBytes);
-
- @Nullable
- byte[] toBytes(@Nullable T val);
-
/**
* Reads 4-bytes numBytes from the given buffer, and then delegates to {@link #fromByteBuffer(ByteBuffer, int)}.
*/
diff --git a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java
index ae31539..4ec947a 100644
--- a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java
+++ b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java
@@ -29,8 +29,8 @@ import org.apache.druid.math.expr.Evals;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.BitmapResultFactory;
-import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.query.filter.BitmapIndexSelector;
import org.apache.druid.query.filter.DruidDoublePredicate;
import org.apache.druid.query.filter.DruidFloatPredicate;
@@ -143,7 +143,7 @@ public class ExpressionFilter implements Filter
// or not.
return BooleanVectorValueMatcher.of(
factory.getReadableVectorInspector(),
- theExpr.eval(ExprUtils.nilBindings()).asBoolean()
+ theExpr.eval(InputBindings.nilBindings()).asBoolean()
);
}
@@ -181,14 +181,14 @@ public class ExpressionFilter implements Filter
final ExprEval eval = selector.getObject();
if (eval.type().isArray()) {
- switch (eval.type().getElementType().getType()) {
+ switch (eval.elementType().getType()) {
case LONG:
final Long[] lResult = eval.asLongArray();
if (lResult == null) {
return false;
}
- return Arrays.stream(lResult).anyMatch(Evals::asBoolean);
+ return Arrays.stream(lResult).filter(Objects::nonNull).anyMatch(Evals::asBoolean);
case STRING:
final String[] sResult = eval.asStringArray();
if (sResult == null) {
@@ -202,7 +202,7 @@ public class ExpressionFilter implements Filter
return false;
}
- return Arrays.stream(dResult).anyMatch(Evals::asBoolean);
+ return Arrays.stream(dResult).filter(Objects::nonNull).anyMatch(Evals::asBoolean);
}
}
return eval.asBoolean();
@@ -248,7 +248,7 @@ public class ExpressionFilter implements Filter
{
if (bindingDetails.get().getRequiredBindings().isEmpty()) {
// Constant expression.
- if (expr.get().eval(ExprUtils.nilBindings()).asBoolean()) {
+ if (expr.get().eval(InputBindings.nilBindings()).asBoolean()) {
return bitmapResultFactory.wrapAllTrue(Filters.allTrue(selector));
} else {
return bitmapResultFactory.wrapAllFalse(Filters.allFalse(selector));
@@ -263,12 +263,12 @@ public class ExpressionFilter implements Filter
column,
selector,
bitmapResultFactory,
- value -> expr.get().eval(identifierName -> {
+ value -> expr.get().eval(InputBindings.forFunction(identifierName -> {
// There's only one binding, and it must be the single column, so it can safely be ignored in production.
assert column.equals(identifierName);
// convert null to Empty before passing to expressions if needed.
return NullHandling.nullToEmptyIfNeeded(value);
- }).asBoolean()
+ })).asBoolean()
);
}
}
diff --git a/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java b/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java
index 0cfadf3..ec4aaaf 100644
--- a/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java
+++ b/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java
@@ -109,6 +109,7 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
* @return column selector factory
*/
public static ColumnSelectorFactory makeColumnSelectorFactory(
+ final Supplier<RowSignature> rowSignatureSupplier,
final VirtualColumns virtualColumns,
final AggregatorFactory agg,
final Supplier<InputRow> in,
@@ -118,7 +119,7 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
final RowBasedColumnSelectorFactory<InputRow> baseSelectorFactory = RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
in::get,
- RowSignature.empty(),
+ rowSignatureSupplier::get,
true
);
@@ -264,6 +265,8 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
this.deserializeComplexMetrics = deserializeComplexMetrics;
this.timeAndMetricsColumnCapabilities = new HashMap<>();
+ this.metricDescs = Maps.newLinkedHashMap();
+ this.dimensionDescs = Maps.newLinkedHashMap();
this.metadata = new Metadata(
null,
getCombiningAggregators(metrics),
@@ -274,7 +277,6 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
initAggs(metrics, rowSupplier, deserializeComplexMetrics, concurrentEventAdd);
- this.metricDescs = Maps.newLinkedHashMap();
for (AggregatorFactory metric : metrics) {
MetricDesc metricDesc = new MetricDesc(metricDescs.size(), metric);
metricDescs.put(metricDesc.getName(), metricDesc);
@@ -282,7 +284,6 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
}
DimensionsSpec dimensionsSpec = incrementalIndexSchema.getDimensionsSpec();
- this.dimensionDescs = Maps.newLinkedHashMap();
this.dimensionDescsList = new ArrayList<>();
for (DimensionSchema dimSchema : dimensionsSpec.getDimensions()) {
@@ -986,7 +987,15 @@ public abstract class IncrementalIndex extends AbstractIndex implements Iterable
final boolean deserializeComplexMetrics
)
{
- return makeColumnSelectorFactory(virtualColumns, agg, in, deserializeComplexMetrics);
+ Supplier<RowSignature> signatureSupplier = () -> {
+ Map<String, ColumnCapabilities> capabilitiesMap = getColumnCapabilities();
+ RowSignature.Builder bob = RowSignature.builder();
+ for (Map.Entry<String, ColumnCapabilities> capabilitiesEntry : capabilitiesMap.entrySet()) {
+ bob.add(capabilitiesEntry.getKey(), capabilitiesEntry.getValue().toColumnType());
+ }
+ return bob.build();
+ };
+ return makeColumnSelectorFactory(signatureSupplier, virtualColumns, agg, in, deserializeComplexMetrics);
}
protected final Comparator<IncrementalIndexRow> dimsComparator()
diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java
index 53460b7..4a35998 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java
@@ -24,8 +24,8 @@ import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.Exprs;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
-import org.apache.druid.query.expression.ExprUtils;
import java.util.ArrayList;
import java.util.Collections;
@@ -74,12 +74,12 @@ public class JoinConditionAnalysis
this.nonEquiConditions = Collections.unmodifiableList(nonEquiConditions);
// if any nonEquiCondition is an expression and it evaluates to false
isAlwaysFalse = nonEquiConditions.stream()
- .anyMatch(expr -> expr.isLiteral() && !expr.eval(ExprUtils.nilBindings())
+ .anyMatch(expr -> expr.isLiteral() && !expr.eval(InputBindings.nilBindings())
.asBoolean());
// if there are no equiConditions and all nonEquiConditions are literals and the evaluate to true
isAlwaysTrue = equiConditions.isEmpty() && nonEquiConditions.stream()
.allMatch(expr -> expr.isLiteral() && expr.eval(
- ExprUtils.nilBindings()).asBoolean());
+ InputBindings.nilBindings()).asBoolean());
canHashJoin = nonEquiConditions.stream().allMatch(Expr::isLiteral);
rightKeyColumns = getEquiConditions().stream().map(Equality::getRightColumn).collect(Collectors.toSet());
requiredColumns = computeRequiredColumns(rightPrefix, equiConditions, nonEquiConditions);
diff --git a/processing/src/main/java/org/apache/druid/segment/serde/ComplexMetrics.java b/processing/src/main/java/org/apache/druid/segment/serde/ComplexMetrics.java
index 4f40ff6..03cd44b 100644
--- a/processing/src/main/java/org/apache/druid/segment/serde/ComplexMetrics.java
+++ b/processing/src/main/java/org/apache/druid/segment/serde/ComplexMetrics.java
@@ -20,6 +20,7 @@
package org.apache.druid.segment.serde;
import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.segment.column.Types;
import javax.annotation.Nullable;
import java.util.concurrent.ConcurrentHashMap;
@@ -62,6 +63,7 @@ public class ComplexMetrics
value.getClass().getName()
);
} else {
+ Types.registerStrategy(type, serde.getObjectStrategy());
return value;
}
}
diff --git a/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java b/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java
index caa8daf..56ac2a7 100644
--- a/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java
+++ b/processing/src/main/java/org/apache/druid/segment/transform/ExpressionTransform.java
@@ -25,9 +25,12 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import org.apache.druid.data.input.Row;
+import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.virtual.ExpressionSelectors;
@@ -96,7 +99,9 @@ public class ExpressionTransform implements Transform
@Override
public Object eval(final Row row)
{
- return ExpressionSelectors.coerceEvalToSelectorObject(expr.eval(name -> getValueFromRow(row, name)));
+ return ExpressionSelectors.coerceEvalToSelectorObject(
+ expr.eval(InputBindings.forFunction(name -> getValueFromRow(row, name)))
+ );
}
}
@@ -107,7 +112,11 @@ public class ExpressionTransform implements Transform
} else {
Object raw = row.getRaw(column);
if (raw instanceof List) {
- return ExprEval.coerceListToArray((List) raw, true);
+ NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray((List) raw, true);
+ if (coerced == null) {
+ return null;
+ }
+ return coerced.rhs;
}
return raw;
}
diff --git a/processing/src/main/java/org/apache/druid/segment/transform/Transformer.java b/processing/src/main/java/org/apache/druid/segment/transform/Transformer.java
index 9db7fe6..27bcba5 100644
--- a/processing/src/main/java/org/apache/druid/segment/transform/Transformer.java
+++ b/processing/src/main/java/org/apache/druid/segment/transform/Transformer.java
@@ -59,7 +59,7 @@ public class Transformer
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
rowSupplierForValueMatcher::get,
- RowSignature.empty(),
+ RowSignature::empty, // sad
false
)
);
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
index 889ec66..27b4ad0 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java
@@ -24,11 +24,13 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.java.util.common.NonnullPair;
+import org.apache.druid.java.util.common.Pair;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
-import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.query.extraction.ExtractionFn;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
@@ -159,7 +161,7 @@ public class ExpressionSelectors
final Expr.ObjectBinding bindings = createBindings(plan.getAnalysis(), columnSelectorFactory);
// Optimization for constant expressions
- if (bindings.equals(ExprUtils.nilBindings())) {
+ if (bindings.equals(InputBindings.nilBindings())) {
return new ConstantExprEvalSelector(plan.getExpression().eval(bindings));
}
@@ -261,11 +263,12 @@ public class ExpressionSelectors
List<String> columns
)
{
- final Map<String, Supplier<Object>> suppliers = new HashMap<>();
+ final Map<String, Pair<ExpressionType, Supplier<Object>>> suppliers = new HashMap<>();
for (String columnName : columns) {
final ColumnCapabilities columnCapabilities = columnSelectorFactory.getColumnCapabilities(columnName);
final boolean multiVal = columnCapabilities != null && columnCapabilities.hasMultipleValues().isTrue();
final Supplier<Object> supplier;
+ final ExpressionType expressionType = ExpressionType.fromColumnType(columnCapabilities);
if (columnCapabilities == null || columnCapabilities.isArray()) {
// Unknown ValueType or array type. Try making an Object selector and see if that gives us anything useful.
@@ -285,30 +288,48 @@ public class ExpressionSelectors
multiVal
);
} else {
- // Unhandleable ValueType (COMPLEX).
- supplier = null;
+ // complex type just pass straight through
+ ColumnValueSelector<?> selector = columnSelectorFactory.makeColumnValueSelector(columnName);
+ if (!(selector instanceof NilColumnValueSelector)) {
+ supplier = selector::getObject;
+ } else {
+ supplier = null;
+ }
}
if (supplier != null) {
- suppliers.put(columnName, supplier);
+ suppliers.put(columnName, new Pair<>(expressionType, supplier));
}
}
if (suppliers.isEmpty()) {
- return ExprUtils.nilBindings();
+ return InputBindings.nilBindings();
} else if (suppliers.size() == 1 && columns.size() == 1) {
// If there's only one column (and it has a supplier), we can skip the Map and just use that supplier when
// asked for something.
final String column = Iterables.getOnlyElement(suppliers.keySet());
- final Supplier<Object> supplier = Iterables.getOnlyElement(suppliers.values());
+ final Pair<ExpressionType, Supplier<Object>> supplier = Iterables.getOnlyElement(suppliers.values());
- return identifierName -> {
- // There's only one binding, and it must be the single column, so it can safely be ignored in production.
- assert column.equals(identifierName);
- return supplier.get();
+ return new Expr.ObjectBinding()
+ {
+ @Nullable
+ @Override
+ public Object get(String name)
+ {
+ // There's only one binding, and it must be the single column, so it can safely be ignored in production.
+ assert column.equals(name);
+ return supplier.rhs.get();
+ }
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ return supplier.lhs;
+ }
};
} else {
- return InputBindings.withSuppliers(suppliers);
+ return InputBindings.withTypedSuppliers(suppliers);
}
}
@@ -350,9 +371,9 @@ public class ExpressionSelectors
} else {
// column selector factories hate you and use [] and [null] interchangeably for nullish data
if (row.size() == 0) {
- return new String[]{null};
+ return new Object[]{null};
}
- final String[] strings = new String[row.size()];
+ final Object[] strings = new Object[row.size()];
// noinspection SSBasedInspection
for (int i = 0; i < row.size(); i++) {
strings[i] = selector.lookupName(row.get(i));
@@ -382,25 +403,31 @@ public class ExpressionSelectors
// Might be Numbers and Strings. Use a selector that double-checks.
return () -> {
final Object val = selector.getObject();
- if (val instanceof Number || val instanceof String || (val != null && val.getClass().isArray())) {
- return val;
- } else if (val instanceof List) {
- return ExprEval.coerceListToArray((List) val, true);
+ if (val instanceof List) {
+ NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray((List) val, true);
+ if (coerced == null) {
+ return null;
+ }
+ return coerced.rhs;
} else {
- return null;
+ return val;
}
};
} else if (clazz.isAssignableFrom(List.class)) {
return () -> {
final Object val = selector.getObject();
if (val != null) {
- return ExprEval.coerceListToArray((List) val, true);
+ NonnullPair<ExpressionType, Object[]> coerced = ExprEval.coerceListToArray((List) val, true);
+ if (coerced == null) {
+ return null;
+ }
+ return coerced.rhs;
}
return null;
};
} else {
- // No numbers or strings.
- return null;
+ // No numbers or strings, just pass it through
+ return selector::getObject;
}
}
@@ -412,16 +439,7 @@ public class ExpressionSelectors
public static Object coerceEvalToSelectorObject(ExprEval eval)
{
if (eval.type().isArray()) {
- switch (eval.type().getElementType().getType()) {
- case STRING:
- return Arrays.stream(eval.asStringArray()).collect(Collectors.toList());
- case DOUBLE:
- return Arrays.stream(eval.asDoubleArray()).collect(Collectors.toList());
- case LONG:
- return Arrays.stream(eval.asLongArray()).collect(Collectors.toList());
- default:
-
- }
+ return Arrays.stream(eval.asArray()).collect(Collectors.toList());
}
return eval.value();
}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java
index 0e2e3cd..3ee12ff 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVectorSelectors.java
@@ -23,9 +23,9 @@ import com.google.common.base.Preconditions;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
-import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.vector.ConstantVectorSelectors;
import org.apache.druid.segment.vector.SingleValueDimensionVectorSelector;
@@ -52,7 +52,7 @@ public class ExpressionVectorSelectors
// only constant expressions are currently supported, nothing else should get here
if (plan.isConstant()) {
- String constant = plan.getExpression().eval(ExprUtils.nilBindings()).asString();
+ String constant = plan.getExpression().eval(InputBindings.nilBindings()).asString();
return ConstantVectorSelectors.singleValueDimensionVectorSelector(factory.getReadableVectorInspector(), constant);
}
if (plan.is(ExpressionPlan.Trait.SINGLE_INPUT_SCALAR) && (plan.getOutputType() != null && plan.getOutputType().is(ExprType.STRING))) {
@@ -75,7 +75,7 @@ public class ExpressionVectorSelectors
if (plan.isConstant()) {
return ConstantVectorSelectors.vectorValueSelector(
factory.getReadableVectorInspector(),
- (Number) plan.getExpression().eval(ExprUtils.nilBindings()).value()
+ (Number) plan.getExpression().eval(InputBindings.nilBindings()).value()
);
}
final Expr.VectorInputBinding bindings = createVectorBindings(plan.getAnalysis(), factory);
@@ -94,7 +94,7 @@ public class ExpressionVectorSelectors
if (plan.isConstant()) {
return ConstantVectorSelectors.vectorObjectSelector(
factory.getReadableVectorInspector(),
- plan.getExpression().eval(ExprUtils.nilBindings()).value()
+ plan.getExpression().eval(InputBindings.nilBindings()).value()
);
}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java
index 84117a4..9838103 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java
@@ -34,7 +34,7 @@ import java.util.stream.Collectors;
* Expression column value selector that examines a set of 'unknown' type input bindings on a row by row basis,
* transforming the expression to handle multi-value list typed inputs as they are encountered.
*
- * Currently, string dimensions are the only bindings which might appear as a {@link String} or a {@link String[]}, so
+ * Currently, string dimensions are the only bindings which might appear as a {@link String} or a {@link Object[]}, so
* numbers are eliminated from the set of 'unknown' bindings to check as they are encountered.
*/
public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector
@@ -94,7 +94,7 @@ public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValue
{
Object binding = bindings.get(x);
if (binding != null) {
- if (binding instanceof String[]) {
+ if (binding instanceof Object[] && ((Object[]) binding).length > 0) {
return true;
} else if (binding instanceof Number) {
ignoredColumns.add(x);
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
index 2525e09..fad855b 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleInputBindings.java
@@ -20,6 +20,7 @@
package org.apache.druid.segment.virtual;
import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExpressionType;
import javax.annotation.Nullable;
@@ -28,6 +29,13 @@ public class SingleInputBindings implements Expr.ObjectBinding
@Nullable
private Object value;
+ private final ExpressionType type;
+
+ public SingleInputBindings(ExpressionType type)
+ {
+ this.type = type;
+ }
+
@Override
public Object get(final String name)
{
@@ -38,4 +46,11 @@ public class SingleInputBindings implements Expr.ObjectBinding
{
this.value = value;
}
+
+ @Nullable
+ @Override
+ public ExpressionType getType(String name)
+ {
+ return type;
+ }
}
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java
index a388b8a..d88c4fb 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java
@@ -24,6 +24,7 @@ import it.unimi.dsi.fastutil.longs.Long2ObjectLinkedOpenHashMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.ColumnValueSelector;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
@@ -41,7 +42,7 @@ public class SingleLongInputCachingExpressionColumnValueSelector implements Colu
private final ColumnValueSelector selector;
private final Expr expression;
- private final SingleInputBindings bindings = new SingleInputBindings();
+ private final SingleInputBindings bindings = new SingleInputBindings(ExpressionType.LONG);
@Nullable
private final LruEvalCache lruEvalCache;
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java
index cbd98cf..68a8522 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java
@@ -25,6 +25,8 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectLinkedOpenHashMap;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.DimensionDictionarySelector;
@@ -63,7 +65,7 @@ public class SingleStringInputCachingExpressionColumnValueSelector implements Co
this.expression = Preconditions.checkNotNull(expression, "expression");
final Supplier<Object> inputSupplier = ExpressionSelectors.supplierFromDimensionSelector(selector, false);
- this.bindings = name -> inputSupplier.get();
+ this.bindings = InputBindings.singleProvider(ExpressionType.STRING, name -> inputSupplier.get());
if (selector.getValueCardinality() == DimensionDictionarySelector.CARDINALITY_UNKNOWN) {
throw new ISE("Selector must have a dictionary");
diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDeferredEvaluationExpressionDimensionSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDeferredEvaluationExpressionDimensionSelector.java
index f4f5838..d89033a 100644
--- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDeferredEvaluationExpressionDimensionSelector.java
+++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDeferredEvaluationExpressionDimensionSelector.java
@@ -23,6 +23,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.filter.ValueMatcher;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.DimensionDictionarySelector;
@@ -48,7 +49,7 @@ public class SingleStringInputDeferredEvaluationExpressionDimensionSelector impl
{
private final DimensionSelector selector;
private final Expr expression;
- private final SingleInputBindings bindings = new SingleInputBindings();
+ private final SingleInputBindings bindings = new SingleInputBindings(ExpressionType.STRING);
public SingleStringInputDeferredEvaluationExpressionDimensionSelector(
final DimensionSelector selector,
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
index 1c3b4c72..cd7835d 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java
@@ -26,6 +26,7 @@ import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
+import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.query.expression.TestExprMacroTable;
@@ -189,7 +190,7 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
ImmutableSet.of("x"),
null,
"0",
- null,
+ "<LONG>[]",
true,
"array_set_add(__acc, x)",
"array_set_add_all(__acc, expr_agg_name)",
@@ -411,6 +412,52 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
}
@Test
+ public void testComplexType()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(some_column, __acc)",
+ "hyper_unique_add(__acc, expr_agg_name)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getType());
+ Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getCombiningFactory().getType());
+ Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getFinalizedType());
+ }
+
+ @Test
+ public void testComplexTypeFinalized()
+ {
+ ExpressionLambdaAggregatorFactory agg = new ExpressionLambdaAggregatorFactory(
+ "expr_agg_name",
+ ImmutableSet.of("some_column"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(some_column, __acc)",
+ "hyper_unique_add(__acc, expr_agg_name)",
+ null,
+ "hyper_unique_estimate(o)",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ );
+
+ Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getType());
+ Assert.assertEquals(HyperUniquesAggregatorFactory.TYPE, agg.getCombiningFactory().getType());
+ Assert.assertEquals(ColumnType.DOUBLE, agg.getFinalizedType());
+ }
+
+ @Test
public void testResultArraySignature()
{
final TimeseriesQuery query =
@@ -544,6 +591,34 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
"fold((x, acc) -> x + acc, o, 0)",
new HumanReadableBytes(2048),
TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "complex_expr",
+ ImmutableSet.of("some_column"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(some_column, __acc)",
+ "hyper_unique_add(__acc, expr_agg_name)",
+ null,
+ null,
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "complex_expr_finalized",
+ ImmutableSet.of("some_column"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(some_column, __acc)",
+ "hyper_unique_add(__acc, expr_agg_name)",
+ null,
+ "hyper_unique_estimate(o)",
+ new HumanReadableBytes(2048),
+ TestExprMacroTable.INSTANCE
)
)
.postAggregators(
@@ -552,7 +627,9 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
new FieldAccessPostAggregator("double-array-expr-access", "double_array_expr_finalized"),
new FinalizingFieldAccessPostAggregator("double-array-expr-finalize", "double_array_expr_finalized"),
new FieldAccessPostAggregator("long-array-expr-access", "long_array_expr_finalized"),
- new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized")
+ new FinalizingFieldAccessPostAggregator("long-array-expr-finalize", "long_array_expr_finalized"),
+ new FieldAccessPostAggregator("complex-expr-access", "complex_expr_finalized"),
+ new FinalizingFieldAccessPostAggregator("complex-expr-finalize", "complex_expr_finalized")
)
.build();
@@ -576,6 +653,10 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
.add("double_array_expr_finalized", null)
// long because fold type equals finalized type, even though merge type is array
.add("long_array_expr_finalized", ColumnType.LONG)
+ .add("complex_expr", HyperUniquesAggregatorFactory.TYPE)
+ // type does not equal finalized type. (combining factory type does equal finalized type,
+ // but this signature doesn't use combining factory)
+ .add("complex_expr_finalized", null)
// fold type is string
.add("string-array-expr-access", ColumnType.STRING)
// finalized type is string
@@ -588,6 +669,8 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
.add("long-array-expr-access", ColumnType.LONG)
// finalized type is long
.add("long-array-expr-finalize", ColumnType.LONG)
+ .add("complex-expr-access", HyperUniquesAggregatorFactory.TYPE)
+ .add("complex-expr-finalize", ColumnType.DOUBLE)
.build(),
new TimeseriesQueryQueryToolChest().resultArraySignature(query)
);
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorTest.java
new file mode 100644
index 0000000..b652656
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation;
+
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class ExpressionLambdaAggregatorTest extends InitializedNullHandlingTest
+{
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testEstimateString()
+ {
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(ExprEval.ofType(ExpressionType.STRING, "hello"), 10);
+ }
+
+ @Test
+ public void testEstimateStringTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [STRING], size [12] is larger than max [5]");
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(ExprEval.ofType(ExpressionType.STRING, "too big"), 5);
+ }
+
+ @Test
+ public void testEstimateStringArray()
+ {
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
+ ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"a", "b", "c", "d"}),
+ 30
+ );
+ }
+
+ @Test
+ public void testEstimateStringArrayTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [ARRAY<STRING>], size [25] is larger than max [15]");
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
+ ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"a", "b", "c", "d"}),
+ 15
+ );
+ }
+
+ @Test
+ public void testEstimateLongArray()
+ {
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
+ ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L, 4L}),
+ 64
+ );
+ }
+
+ @Test
+ public void testEstimateLongArrayTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [ARRAY<LONG>], size [41] is larger than max [24]");
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
+ ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L, 4L}),
+ 24
+ );
+ }
+
+ @Test
+ public void testEstimateDoubleArray()
+ {
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
+ ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0, 4.0}),
+ 64
+ );
+ }
+
+ @Test
+ public void testEstimateDoubleArrayTooBig()
+ {
+ expectedException.expect(ISE.class);
+ expectedException.expectMessage("Unable to serialize [ARRAY<DOUBLE>], size [41] is larger than max [24]");
+ ExpressionLambdaAggregator.estimateAndCheckMaxBytes(
+ ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0, 4.0}),
+ 24
+ );
+ }
+}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java
index 54a1539..ca03bbe 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java
@@ -135,7 +135,7 @@ public class CaseInsensitiveExprMacroTest extends MacroTestBase
final ExprEval<?> result = eval(
"icontains_string(a, null)",
- InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))
+ InputBindings.nilBindings()
);
Assert.assertEquals(
ExprEval.ofBoolean(true, ExprType.LONG).value(),
@@ -146,7 +146,7 @@ public class CaseInsensitiveExprMacroTest extends MacroTestBase
@Test
public void testEmptyStringSearchOnNull()
{
- final ExprEval<?> result = eval("icontains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("icontains_string(a, '')", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(),
result.value()
diff --git a/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java
index decd899..883a435 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java
@@ -123,7 +123,7 @@ public class ContainsExprMacroTest extends MacroTestBase
expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal");
}
- final ExprEval<?> result = eval("contains_string(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("contains_string(a, null)", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofBoolean(true, ExprType.LONG).value(),
result.value()
@@ -133,7 +133,7 @@ public class ContainsExprMacroTest extends MacroTestBase
@Test
public void testEmptyStringSearchOnNull()
{
- final ExprEval<?> result = eval("contains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("contains_string(a, '')", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(),
result.value()
diff --git a/processing/src/test/java/org/apache/druid/query/expression/HyperUniqueExpressionsTest.java b/processing/src/test/java/org/apache/druid/query/expression/HyperUniqueExpressionsTest.java
new file mode 100644
index 0000000..a20099b
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/expression/HyperUniqueExpressionsTest.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.expression;
+
+import com.google.common.base.Supplier;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.hll.HyperLogLogCollector;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
+import org.apache.druid.math.expr.Parser;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+/**
+ * Tests for the built-in hyperUnique expression macros {@code hyper_unique}, {@code hyper_unique_add},
+ * {@code hyper_unique_estimate} and {@code hyper_unique_round_estimate}.
+ */
+public class HyperUniqueExpressionsTest extends InitializedNullHandlingTest
+{
+ private static final ExprMacroTable MACRO_TABLE = new ExprMacroTable(
+ ImmutableList.of(
+ new HyperUniqueExpressions.HllCreateExprMacro(),
+ new HyperUniqueExpressions.HllAddExprMacro(),
+ new HyperUniqueExpressions.HllEstimateExprMacro(),
+ new HyperUniqueExpressions.HllRoundEstimateExprMacro()
+ )
+ );
+
+ private static final String SOME_STRING = "foo";
+ private static final long SOME_LONG = 1234L;
+ private static final double SOME_DOUBLE = 1.234;
+
+ /**
+ * Tolerance for comparing HyperLogLog cardinality estimates, which are approximate.
+ */
+ private static final double ESTIMATE_DELTA = 0.01;
+
+ /**
+ * Typed identifier bindings shared by the tests: a collector ("hll"), one non-null value of each primitive type
+ * ("string", "long", "double") and one null value of each primitive type ("nullString", "nullLong", "nullDouble").
+ */
+ private final Expr.ObjectBinding inputBindings = InputBindings.withTypedSuppliers(
+ new ImmutableMap.Builder<String, Pair<ExpressionType, Supplier<Object>>>()
+ .put("hll", new Pair<>(HyperUniqueExpressions.TYPE, HyperLogLogCollector::makeLatestCollector))
+ .put("string", new Pair<>(ExpressionType.STRING, () -> SOME_STRING))
+ .put("long", new Pair<>(ExpressionType.LONG, () -> SOME_LONG))
+ .put("double", new Pair<>(ExpressionType.DOUBLE, () -> SOME_DOUBLE))
+ .put("nullString", new Pair<>(ExpressionType.STRING, () -> null))
+ .put("nullLong", new Pair<>(ExpressionType.LONG, () -> null))
+ .put("nullDouble", new Pair<>(ExpressionType.DOUBLE, () -> null))
+ .build()
+ );
+
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testCreate()
+ {
+ final ExprEval<?> result = eval("hyper_unique()");
+
+ Assert.assertEquals(HyperUniqueExpressions.TYPE, result.type());
+ Assert.assertTrue(result.value() instanceof HyperLogLogCollector);
+ // nothing has been added yet, so the estimate is compared exactly rather than within ESTIMATE_DELTA
+ Assert.assertEquals(0.0, ((HyperLogLogCollector) result.value()).estimateCardinality(), 0);
+ }
+
+ @Test
+ public void testString()
+ {
+ assertCollectorEstimate(eval("hyper_unique_add('foo', hyper_unique())"), 1.0);
+ assertCollectorEstimate(eval("hyper_unique_add('bar', hyper_unique_add('foo', hyper_unique()))"), 2.0);
+ assertCollectorEstimate(eval("hyper_unique_add(string, hyper_unique())"), 1.0);
+ assertCollectorEstimate(eval("hyper_unique_add(nullString, hyper_unique())"), expectedNullEstimate());
+ }
+
+ @Test
+ public void testLong()
+ {
+ assertCollectorEstimate(eval("hyper_unique_add(1234, hyper_unique())"), 1.0);
+ assertCollectorEstimate(eval("hyper_unique_add(1234, hyper_unique_add(5678, hyper_unique()))"), 2.0);
+ assertCollectorEstimate(eval("hyper_unique_add(long, hyper_unique())"), 1.0);
+ assertCollectorEstimate(eval("hyper_unique_add(nullLong, hyper_unique())"), expectedNullEstimate());
+ }
+
+ @Test
+ public void testDouble()
+ {
+ assertCollectorEstimate(eval("hyper_unique_add(1.234, hyper_unique())"), 1.0);
+ assertCollectorEstimate(eval("hyper_unique_add(1.234, hyper_unique_add(5.678, hyper_unique()))"), 2.0);
+ assertCollectorEstimate(eval("hyper_unique_add(double, hyper_unique())"), 1.0);
+ assertCollectorEstimate(eval("hyper_unique_add(nullDouble, hyper_unique())"), expectedNullEstimate());
+ }
+
+ @Test
+ public void testEstimate()
+ {
+ final ExprEval<?> result = eval("hyper_unique_estimate(hyper_unique_add(1.234, hyper_unique()))");
+
+ Assert.assertEquals(ExpressionType.DOUBLE, result.type());
+ Assert.assertEquals(1.0, result.asDouble(), ESTIMATE_DELTA);
+ }
+
+ @Test
+ public void testEstimateRound()
+ {
+ final ExprEval<?> result = eval("hyper_unique_round_estimate(hyper_unique_add(1.234, hyper_unique()))");
+
+ Assert.assertEquals(ExpressionType.LONG, result.type());
+ // the rounded estimate is a long, so compare it exactly instead of widening to double with a delta
+ Assert.assertEquals(1L, result.asLong());
+ }
+
+ @Test
+ public void testCreateWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique] must have no arguments");
+ Parser.parse("hyper_unique(100)", MACRO_TABLE);
+ }
+
+ @Test
+ public void testAddWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique_add] must have 2 arguments");
+ Parser.parse("hyper_unique_add(100, hyper_unique(), 100)", MACRO_TABLE);
+ }
+
+ @Test
+ public void testAddWrongArgType()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique_add] must take a hyper-log-log collector as the second argument");
+ eval("hyper_unique_add(long, string)");
+ }
+
+ @Test
+ public void testEstimateWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique_estimate] must have 1 argument");
+ Parser.parse("hyper_unique_estimate(hyper_unique(), 100)", MACRO_TABLE);
+ }
+
+ @Test
+ public void testEstimateWrongArgTypes()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique_estimate] must take a hyper-log-log collector as input");
+ eval("hyper_unique_estimate(100)");
+ }
+
+ @Test
+ public void testRoundEstimateWrongArgsCount()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique_round_estimate] must have 1 argument");
+ Parser.parse("hyper_unique_round_estimate(hyper_unique(), 100)", MACRO_TABLE);
+ }
+
+ @Test
+ public void testRoundEstimateWrongArgTypes()
+ {
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Function[hyper_unique_round_estimate] must take a hyper-log-log collector as input");
+ eval("hyper_unique_round_estimate(string)");
+ }
+
+ /**
+ * Parses {@code expression} with {@link #MACRO_TABLE} and evaluates it against {@link #inputBindings}.
+ */
+ private ExprEval<?> eval(String expression)
+ {
+ return Parser.parse(expression, MACRO_TABLE).eval(inputBindings);
+ }
+
+ /**
+ * Asserts that {@code result} is a hyperUnique-typed {@link HyperLogLogCollector} whose cardinality estimate is
+ * within {@link #ESTIMATE_DELTA} of {@code expectedEstimate}.
+ */
+ private static void assertCollectorEstimate(ExprEval<?> result, double expectedEstimate)
+ {
+ Assert.assertEquals(HyperUniqueExpressions.TYPE, result.type());
+ Assert.assertTrue(result.value() instanceof HyperLogLogCollector);
+ final HyperLogLogCollector collector = (HyperLogLogCollector) result.value();
+ Assert.assertEquals(expectedEstimate, collector.estimateCardinality(), ESTIMATE_DELTA);
+ }
+
+ /**
+ * Expected estimate after adding a single null value to an empty collector: one distinct value when nulls are
+ * replaced with defaults, and nothing otherwise.
+ */
+ private static double expectedNullEstimate()
+ {
+ return NullHandling.replaceWithDefault() ? 1.0 : 0.0;
+ }
+}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java
index aa5bd91..d62d248 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java
@@ -22,6 +22,7 @@ package org.apache.druid.query.expression;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Test;
@@ -179,7 +180,7 @@ public class IPv4AddressMatchExprMacroTest extends MacroTestBase
private boolean eval(Expr... args)
{
Expr expr = apply(Arrays.asList(args));
- ExprEval eval = expr.eval(ExprUtils.nilBindings());
+ ExprEval eval = expr.eval(InputBindings.nilBindings());
return eval.asBoolean();
}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressParseExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressParseExprMacroTest.java
index 0d70b2c..0bfaa5d 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressParseExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressParseExprMacroTest.java
@@ -22,6 +22,7 @@ package org.apache.druid.query.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Test;
@@ -151,7 +152,7 @@ public class IPv4AddressParseExprMacroTest extends MacroTestBase
private Object eval(Expr arg)
{
Expr expr = apply(Collections.singletonList(arg));
- ExprEval eval = expr.eval(ExprUtils.nilBindings());
+ ExprEval eval = expr.eval(InputBindings.nilBindings());
return eval.value();
}
}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacroTest.java
index 1b4235b..fc6d893 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacroTest.java
@@ -22,6 +22,7 @@ package org.apache.druid.query.expression;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Test;
@@ -147,7 +148,7 @@ public class IPv4AddressStringifyExprMacroTest extends MacroTestBase
private Object eval(Expr arg)
{
Expr expr = apply(Collections.singletonList(arg));
- ExprEval eval = expr.eval(ExprUtils.nilBindings());
+ ExprEval eval = expr.eval(InputBindings.nilBindings());
return eval.value();
}
}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java
index 8d7a322..bf432a2 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java
@@ -128,14 +128,14 @@ public class RegexpExtractExprMacroTest extends MacroTestBase
expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal");
}
- final ExprEval<?> result = eval("regexp_extract(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("regexp_extract(a, null)", InputBindings.nilBindings());
Assert.assertNull(result.value());
}
@Test
public void testEmptyStringPatternOnNull()
{
- final ExprEval<?> result = eval("regexp_extract(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("regexp_extract(a, '')", InputBindings.nilBindings());
Assert.assertNull(result.value());
}
}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java
index fb6d99f..77eea92 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java
@@ -122,7 +122,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase
expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal");
}
- final ExprEval<?> result = eval("regexp_like(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("regexp_like(a, null)", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofLongBoolean(true).value(),
result.value()
@@ -132,7 +132,7 @@ public class RegexpLikeExprMacroTest extends MacroTestBase
@Test
public void testEmptyStringPatternOnNull()
{
- final ExprEval<?> result = eval("regexp_like(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)));
+ final ExprEval<?> result = eval("regexp_like(a, '')", InputBindings.nilBindings());
Assert.assertEquals(
ExprEval.ofLongBoolean(NullHandling.replaceWithDefault()).value(),
result.value()
diff --git a/processing/src/test/java/org/apache/druid/query/expression/TestExprMacroTable.java b/processing/src/test/java/org/apache/druid/query/expression/TestExprMacroTable.java
index c617099..424faa6 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/TestExprMacroTable.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/TestExprMacroTable.java
@@ -43,7 +43,11 @@ public class TestExprMacroTable extends ExprMacroTable
new TimestampShiftExprMacro(),
new TrimExprMacro.BothTrimExprMacro(),
new TrimExprMacro.LeftTrimExprMacro(),
- new TrimExprMacro.RightTrimExprMacro()
+ new TrimExprMacro.RightTrimExprMacro(),
+ new HyperUniqueExpressions.HllCreateExprMacro(),
+ new HyperUniqueExpressions.HllAddExprMacro(),
+ new HyperUniqueExpressions.HllEstimateExprMacro(),
+ new HyperUniqueExpressions.HllRoundEstimateExprMacro()
)
);
}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/TimestampExtractExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/TimestampExtractExprMacroTest.java
index 60be22e..c5f3e0f 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/TimestampExtractExprMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/TimestampExtractExprMacroTest.java
@@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.InputBindings;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -53,7 +54,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2001-02-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.DECADE.toString()).toExpr()
));
- Assert.assertEquals(200, expression.eval(ExprUtils.nilBindings()).asInt());
+ Assert.assertEquals(200, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@@ -64,7 +65,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2000-12-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.CENTURY.toString()).toExpr()
));
- Assert.assertEquals(20, expression.eval(ExprUtils.nilBindings()).asInt());
+ Assert.assertEquals(20, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@@ -75,7 +76,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2001-02-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.CENTURY.toString()).toExpr()
));
- Assert.assertEquals(21, expression.eval(ExprUtils.nilBindings()).asInt());
+ Assert.assertEquals(21, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@@ -86,7 +87,7 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2000-12-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.MILLENNIUM.toString()).toExpr()
));
- Assert.assertEquals(2, expression.eval(ExprUtils.nilBindings()).asInt());
+ Assert.assertEquals(2, expression.eval(InputBindings.nilBindings()).asInt());
}
@Test
@@ -97,6 +98,6 @@ public class TimestampExtractExprMacroTest
ExprEval.of("2001-02-16").toExpr(),
ExprEval.of(TimestampExtractExprMacro.Unit.MILLENNIUM.toString()).toExpr()
));
- Assert.assertEquals(3, expression.eval(ExprUtils.nilBindings()).asInt());
+ Assert.assertEquals(3, expression.eval(InputBindings.nilBindings()).asInt());
}
}
diff --git a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java
index 05945b1..44d9494 100644
--- a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java
+++ b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java
@@ -26,6 +26,8 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
+import org.apache.druid.math.expr.ExpressionType;
+import org.apache.druid.math.expr.InputBindings;
import org.joda.time.DateTime;
import org.joda.time.Days;
import org.joda.time.Minutes;
@@ -102,7 +104,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Months.ONE, step).getMillis(),
- expr.eval(ExprUtils.nilBindings()).asLong()
+ expr.eval(InputBindings.nilBindings()).asLong()
);
}
@@ -119,7 +121,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Months.ONE, step).getMillis(),
- expr.eval(ExprUtils.nilBindings()).asLong()
+ expr.eval(InputBindings.nilBindings()).asLong()
);
}
@@ -136,7 +138,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Months.ONE, step).getMillis(),
- expr.eval(ExprUtils.nilBindings()).asLong()
+ expr.eval(InputBindings.nilBindings()).asLong()
);
}
@@ -152,7 +154,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Minutes.ONE, 1).getMillis(),
- expr.eval(ExprUtils.nilBindings()).asLong()
+ expr.eval(InputBindings.nilBindings()).asLong()
);
}
@@ -168,7 +170,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.withPeriodAdded(Days.ONE, 1).getMillis(),
- expr.eval(ExprUtils.nilBindings()).asLong()
+ expr.eval(InputBindings.nilBindings()).asLong()
);
}
@@ -185,7 +187,7 @@ public class TimestampShiftMacroTest extends MacroTestBase
Assert.assertEquals(
timestamp.toDateTime(DateTimes.inferTzFromString("America/Los_Angeles")).withPeriodAdded(Years.ONE, 1).getMillis(),
- expr.eval(ExprUtils.nilBindings()).asLong()
+ expr.eval(InputBindings.nilBindings()).asLong()
);
}
@@ -208,6 +210,13 @@ public class TimestampShiftMacroTest extends MacroTestBase
{
@Nullable
@Override
+ public ExpressionType getType(String name)
+ {
+ // this test binding reports no type for any identifier
+ return null;
+ }
+
+ @Nullable
+ @Override
public Object get(String name)
{
if ("step".equals(name)) {
@@ -232,9 +241,9 @@ public class TimestampShiftMacroTest extends MacroTestBase
);
if (NullHandling.replaceWithDefault()) {
- Assert.assertEquals(2678400000L, expr.eval(ExprUtils.nilBindings()).value());
+ Assert.assertEquals(2678400000L, expr.eval(InputBindings.nilBindings()).value());
} else {
- Assert.assertNull(expr.eval(ExprUtils.nilBindings()).value());
+ Assert.assertNull(expr.eval(InputBindings.nilBindings()).value());
}
}
diff --git a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java
index 0c933b6..cb5016a 100644
--- a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java
+++ b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java
@@ -217,7 +217,7 @@ public class InDimFilterTest extends InitializedNullHandlingTest
final RowBasedColumnSelectorFactory<MapBasedRow> columnSelectorFactory = RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(0, row),
- RowSignature.builder().add("dim", ColumnType.STRING).build(),
+ () -> RowSignature.builder().add("dim", ColumnType.STRING).build(),
true
);
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java
index 5b3d3f1..f8b1793 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java
@@ -11443,6 +11443,91 @@ public class GroupByQueryRunnerTest extends InitializedNullHandlingTest
}
@Test
+ public void testGroupByWithExpressionAggregatorWithComplex()
+ {
+ // expression agg not yet vectorized
+ cannotVectorize();
+ // "carExpr" accumulates quality values into a hyperUnique collector using expressions and finalizes it with
+ // hyper_unique_estimate, so it should produce the same value as the cardinality aggregator "car"
+ final GroupByQuery query = makeQueryBuilder()
+ .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .setQuerySegmentSpec(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+ .setDimensions(Collections.emptyList())
+ .setAggregatorSpecs(
+ new CardinalityAggregatorFactory(
+ "car",
+ ImmutableList.of(new DefaultDimensionSpec("quality", "quality")),
+ false
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "carExpr",
+ ImmutableSet.of("quality"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(quality, __acc)",
+ "hyper_unique_add(carExpr, __acc)",
+ null,
+ "hyper_unique_estimate(o)",
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ .setGranularity(QueryRunnerTestHelper.ALL_GRAN)
+ .build();
+
+ List<ResultRow> expectedResults = Collections.singletonList(
+ makeRow(query, "1970-01-01", "car", QueryRunnerTestHelper.UNIQUES_9, "carExpr", QueryRunnerTestHelper.UNIQUES_9)
+ );
+ Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+ // this query has no subquery, so label the assertion accordingly
+ TestHelper.assertExpectedObjects(expectedResults, results, "expression-cardinality");
+ }
+
+ @Test
+ public void testGroupByWithExpressionAggregatorWithComplexOnSubquery()
+ {
+ // inner query: one row per (market, quality) combination, with a row count and an index sum
+ final GroupByQuery subquery = makeQueryBuilder()
+ .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .setQuerySegmentSpec(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+ .setDimensions(new DefaultDimensionSpec("market", "market"), new DefaultDimensionSpec("quality", "quality"))
+ .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("index", "index"))
+ .setGranularity(QueryRunnerTestHelper.ALL_GRAN)
+ .build();
+
+ // outer aggregators: the built-in cardinality aggregator, and an expression aggregator that builds a
+ // hyperUnique collector from quality and finalizes it with hyper_unique_estimate
+ final CardinalityAggregatorFactory cardinalityAgg = new CardinalityAggregatorFactory(
+ "car",
+ ImmutableList.of(new DefaultDimensionSpec("quality", "quality")),
+ false
+ );
+ final ExpressionLambdaAggregatorFactory expressionAgg = new ExpressionLambdaAggregatorFactory(
+ "carExpr",
+ ImmutableSet.of("quality"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(quality, __acc)",
+ null,
+ null,
+ "hyper_unique_estimate(o)",
+ null,
+ TestExprMacroTable.INSTANCE
+ );
+
+ final GroupByQuery query = makeQueryBuilder()
+ .setDataSource(subquery)
+ .setQuerySegmentSpec(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+ .setDimensions(Collections.emptyList())
+ .setAggregatorSpecs(cardinalityAgg, expressionAgg)
+ .setGranularity(QueryRunnerTestHelper.ALL_GRAN)
+ .build();
+
+ // both aggregators are expected to agree on the distinct quality count
+ final List<ResultRow> expectedResults = Collections.singletonList(
+ makeRow(query, "1970-01-01", "car", QueryRunnerTestHelper.UNIQUES_9, "carExpr", QueryRunnerTestHelper.UNIQUES_9)
+ );
+ final Iterable<ResultRow> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+ TestHelper.assertExpectedObjects(expectedResults, results, "subquery-cardinality");
+ }
+
+ @Test
public void testGroupByWithExpressionAggregatorWithArrays()
{
// expression agg not yet vectorized
diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
index 23f724a..341fb17 100644
--- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
@@ -6075,6 +6075,71 @@ public class TopNQueryRunnerTest extends InitializedNullHandlingTest
assertExpectedResults(expectedResults, query);
}
+ @Test
+ public void testExpressionAggregatorComplex()
+ {
+ // sorted by the hyperUnique expression aggregator "carExpr", which is expected to match the cardinality
+ // aggregator "car" for every market
+ TopNQuery query = new TopNQueryBuilder()
+ .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+ .granularity(QueryRunnerTestHelper.ALL_GRAN)
+ .dimension(QueryRunnerTestHelper.MARKET_DIMENSION)
+ .metric("carExpr")
+ .threshold(4)
+ .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC)
+ .aggregators(
+ ImmutableList.of(
+ new CardinalityAggregatorFactory(
+ "car",
+ ImmutableList.of(new DefaultDimensionSpec("quality", "quality")),
+ false
+ ),
+ new ExpressionLambdaAggregatorFactory(
+ "carExpr",
+ ImmutableSet.of("quality"),
+ null,
+ "hyper_unique()",
+ null,
+ null,
+ "hyper_unique_add(quality, __acc)",
+ "hyper_unique_add(carExpr, __acc)",
+ null,
+ "hyper_unique_estimate(o)",
+ null,
+ TestExprMacroTable.INSTANCE
+ )
+ )
+ )
+ .build();
+
+ List<Result<TopNResultValue>> expectedResults = Collections.singletonList(
+ new Result<>(
+ DateTimes.of("2011-01-12T00:00:00.000Z"),
+ new TopNResultValue(
+ Arrays.<Map<String, Object>>asList(
+ ImmutableMap.<String, Object>builder()
+ .put(QueryRunnerTestHelper.MARKET_DIMENSION, "spot")
+ .put("car", 9.019833517963864)
+ .put("carExpr", 9.019833517963864)
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put(QueryRunnerTestHelper.MARKET_DIMENSION, "total_market")
+ .put("car", 2.000977198748901)
+ .put("carExpr", 2.000977198748901)
+ .build(),
+ ImmutableMap.<String, Object>builder()
+ .put(QueryRunnerTestHelper.MARKET_DIMENSION, "upfront")
+ .put("car", 2.000977198748901)
+ .put("carExpr", 2.000977198748901)
+ .build()
+ )
+ )
+ )
+ );
+ assertExpectedResults(expectedResults, query);
+ }
+
private static Map<String, Object> makeRowWithNulls(
String dimName,
@Nullable Object dimValue,
diff --git a/processing/src/test/java/org/apache/druid/segment/TestHelper.java b/processing/src/test/java/org/apache/druid/segment/TestHelper.java
index 59f739a..40dac73 100644
--- a/processing/src/test/java/org/apache/druid/segment/TestHelper.java
+++ b/processing/src/test/java/org/apache/druid/segment/TestHelper.java
@@ -392,7 +392,7 @@ public class TestHelper
Assert.assertEquals(
message,
(Object[]) expectedValue,
- (Object[]) ExprEval.coerceListToArray((List) actualValue, true)
+ (Object[]) ExprEval.coerceListToArray((List) actualValue, true).rhs
);
} else {
Assert.assertArrayEquals(
diff --git a/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java b/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java
index 7e21bcf..3c8b70a 100644
--- a/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/filter/BaseFilterTest.java
@@ -754,7 +754,7 @@ public abstract class BaseFilterTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
rowSupplier::get,
- rowSignatureBuilder.build(),
+ rowSignatureBuilder::build,
false
)
)
diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java
index a5dac67..c7a1c8e 100644
--- a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java
@@ -203,7 +203,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
private static final ColumnSelectorFactory COLUMN_SELECTOR_FACTORY = RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
- RowSignature.empty(),
+ RowSignature::empty,
false
);
@@ -743,7 +743,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
- RowSignature.builder().add("x", ColumnType.LONG).build(),
+ RowSignature.builder().add("x", ColumnType.LONG)::build,
false
),
Parser.parse(SCALE_LONG.getExpression(), TestExprMacroTable.INSTANCE)
@@ -766,7 +766,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
- RowSignature.builder().add("x", ColumnType.DOUBLE).build(),
+ RowSignature.builder().add("x", ColumnType.DOUBLE)::build,
false
),
Parser.parse(SCALE_FLOAT.getExpression(), TestExprMacroTable.INSTANCE)
@@ -789,7 +789,7 @@ public class ExpressionVirtualColumnTest extends InitializedNullHandlingTest
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
CURRENT_ROW::get,
- RowSignature.builder().add("x", ColumnType.FLOAT).build(),
+ RowSignature.builder().add("x", ColumnType.FLOAT)::build,
false
),
Parser.parse(SCALE_FLOAT.getExpression(), TestExprMacroTable.INSTANCE)
diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ListFilteredVirtualColumnSelectorTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ListFilteredVirtualColumnSelectorTest.java
index 9dbc25d..0438634 100644
--- a/processing/src/test/java/org/apache/druid/segment/virtual/ListFilteredVirtualColumnSelectorTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/virtual/ListFilteredVirtualColumnSelectorTest.java
@@ -276,7 +276,7 @@ public class ListFilteredVirtualColumnSelectorTest extends InitializedNullHandli
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(0L, ImmutableMap.of(COLUMN_NAME, ImmutableList.of("a", "b", "c", "d"))),
- rowSignature,
+ () -> rowSignature,
false
),
VirtualColumns.create(ImmutableList.of(virtualColumn))
diff --git a/server/src/main/java/org/apache/druid/guice/ExpressionModule.java b/server/src/main/java/org/apache/druid/guice/ExpressionModule.java
index 7a25f92..9cc829f 100644
--- a/server/src/main/java/org/apache/druid/guice/ExpressionModule.java
+++ b/server/src/main/java/org/apache/druid/guice/ExpressionModule.java
@@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.expression.CaseInsensitiveContainsExprMacro;
import org.apache.druid.query.expression.ContainsExprMacro;
import org.apache.druid.query.expression.GuiceExprMacroTable;
+import org.apache.druid.query.expression.HyperUniqueExpressions;
import org.apache.druid.query.expression.IPv4AddressMatchExprMacro;
import org.apache.druid.query.expression.IPv4AddressParseExprMacro;
import org.apache.druid.query.expression.IPv4AddressStringifyExprMacro;
@@ -65,6 +66,10 @@ public class ExpressionModule implements DruidModule
.add(TrimExprMacro.BothTrimExprMacro.class)
.add(TrimExprMacro.LeftTrimExprMacro.class)
.add(TrimExprMacro.RightTrimExprMacro.class)
+ .add(HyperUniqueExpressions.HllCreateExprMacro.class)
+ .add(HyperUniqueExpressions.HllAddExprMacro.class)
+ .add(HyperUniqueExpressions.HllEstimateExprMacro.class)
+ .add(HyperUniqueExpressions.HllRoundEstimateExprMacro.class)
.build();
@Override
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java
index 705da9a..9e33d1d 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java
@@ -27,6 +27,7 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.DimFilter;
@@ -102,7 +103,7 @@ public class ArrayContainsOperatorConversion extends BaseExpressionDimFilterOper
if (expr.isLiteral()) {
// Evaluate the expression to get out the array elements.
// We can safely pass a noop ObjectBinding if the expression is literal.
- ExprEval<?> exprEval = expr.eval(name -> null);
+ ExprEval<?> exprEval = expr.eval(InputBindings.nilBindings());
String[] arrayElements = exprEval.asStringArray();
if (arrayElements == null || arrayElements.length == 0) {
// If arrayElements is empty which means rightExpr is an empty array,
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java
index cb1ddf5..78edea3 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java
@@ -28,6 +28,7 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.InDimFilter;
@@ -109,7 +110,7 @@ public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOpera
if (expr.isLiteral()) {
// Evaluate the expression to take out the array elements.
// We can safely pass null if the expression is literal.
- ExprEval<?> exprEval = expr.eval(name -> null);
+ ExprEval<?> exprEval = expr.eval(InputBindings.nilBindings());
String[] arrayElements = exprEval.asStringArray();
if (arrayElements == null || arrayElements.length == 0) {
// If arrayElements is empty which means complexExpr is an empty array,
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java
index 6b5e5fb..46f9814 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java
@@ -30,8 +30,8 @@ import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
-import org.apache.druid.query.expression.ExprUtils;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.virtual.ListFilteredVirtualColumn;
import org.apache.druid.sql.calcite.expression.AliasedOperatorConversion;
@@ -334,7 +334,7 @@ public class MultiValueStringOperatorConversions
if (!expr.isLiteral()) {
return null;
}
- String[] lit = expr.eval(ExprUtils.nilBindings()).asStringArray();
+ String[] lit = expr.eval(InputBindings.nilBindings()).asStringArray();
if (lit == null || lit.length == 0) {
return null;
}
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java
index a27871d..ff61050 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java
@@ -29,6 +29,7 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprType;
+import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
@@ -74,10 +75,12 @@ public class DruidRexExecutor implements RexExecutor
final Expr expr = Parser.parse(druidExpression.getExpression(), plannerContext.getExprMacroTable());
final ExprEval exprResult = expr.eval(
- name -> {
- // Sanity check. Bindings should not be used for a constant expression.
- throw new UnsupportedOperationException();
- }
+ InputBindings.forFunction(
+ name -> {
+ // Sanity check. Bindings should not be used for a constant expression.
+ throw new UnsupportedOperationException();
+ }
+ )
);
final RexNode literal;
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/QueryMaker.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/QueryMaker.java
index 0a28320..9b625b5 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/QueryMaker.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/QueryMaker.java
@@ -340,6 +340,8 @@ public class QueryMaker
coercedValue = Arrays.asList((Long[]) value);
} else if (value instanceof Double[]) {
coercedValue = Arrays.asList((Double[]) value);
+ } else if (value instanceof Object[]) {
+ coercedValue = Arrays.asList((Object[]) value);
} else {
throw new ISE("Cannot coerce[%s] to %s", value.getClass().getName(), sqlType);
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java
index 55367fb..d4a9157 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java
@@ -24,8 +24,11 @@ import com.google.common.collect.ImmutableSet;
import junitparams.JUnitParamsRunner;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.HumanReadableBytes;
+import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.Druids;
+import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@@ -141,21 +144,49 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
}
@Test
- public void testSelectNonConstantArrayExpressionFromTableFailForMultival() throws Exception
+ public void testSelectNonConstantArrayExpressionFromTableForMultival() throws Exception
{
- // without expression output type inference to prevent this, the automatic translation will try to turn this into
+ final String sql = "SELECT ARRAY[CONCAT(dim3, 'word'),'up'] as arr, dim1 FROM foo LIMIT 5";
+ final Query<?> scanQuery = newScanQueryBuilder()
+ .dataSource(CalciteTests.DATASOURCE1)
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .virtualColumns(expressionVirtualColumn("v0", "array(concat(\"dim3\",'word'),'up')", ColumnType.STRING))
+ .columns("dim1", "v0")
+ .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+ .limit(5)
+ .context(QUERY_CONTEXT_DEFAULT)
+ .build();
+
+
+ ExpressionProcessing.initializeForTests(true);
+ // if nested arrays are allowed, dim3 is a multi-valued string column, so the automatic translation will turn this
+ // expression into
//
// `map((dim3) -> array(concat(dim3,'word'),'up'), dim3)`
//
- // This error message will get better in the future. The error without translation would be:
- //
- // org.apache.druid.java.util.common.RE: Unhandled array constructor element type [ARRAY<STRING>]
+ // this works, but we still translate the output into a string since that is the current output type
+ // in some future this might not auto-convert to a string type (when we support grouping on arrays maybe?)
- expectedException.expect(RuntimeException.class);
- expectedException.expectMessage("Unhandled map function output type [ARRAY<STRING>]");
testQuery(
- "SELECT ARRAY[CONCAT(dim3, 'word'),'up'] as arr, dim1 FROM foo LIMIT 5",
- ImmutableList.of(),
+ sql,
+ ImmutableList.of(scanQuery),
+ ImmutableList.of(
+ new Object[]{"[[\"aword\",\"up\"],[\"bword\",\"up\"]]", ""},
+ new Object[]{"[[\"bword\",\"up\"],[\"cword\",\"up\"]]", "10.1"},
+ new Object[]{"[[\"dword\",\"up\"]]", "2"},
+ new Object[]{"[[\"word\",\"up\"]]", "1"},
+ useDefault ? new Object[]{"[[\"word\",\"up\"]]", "def"} : new Object[]{"[[null,\"up\"]]", "def"}
+ )
+ );
+
+ ExpressionProcessing.initializeForTests(null);
+
+ // if nested arrays are not enabled, this doesn't work
+ expectedException.expect(IAE.class);
+ expectedException.expectMessage("Cannot create a nested array type [ARRAY<ARRAY<STRING>>], 'druid.expressions.allowNestedArrays' must be set to true");
+ testQuery(
+ sql,
+ ImmutableList.of(scanQuery),
ImmutableList.of()
);
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java
index 3c049b4..3101eb0 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java
@@ -301,7 +301,7 @@ class ExpressionTestHelper
RowBasedColumnSelectorFactory.create(
RowAdapters.standardRow(),
() -> new MapBasedRow(0L, bindings),
- rowSignature,
+ () -> rowSignature,
false
),
VirtualColumns.create(virtualColumns)
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java
index f8f5307..22c8aec 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java
@@ -21,6 +21,7 @@ package org.apache.druid.sql.calcite.util;
import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.http.SqlParameter;
import org.junit.BeforeClass;
@@ -36,5 +37,6 @@ public abstract class CalciteTestBase
{
Calcites.setSystemProperties();
NullHandling.initializeForTests();
+ ExpressionProcessing.initializeForTests(null);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org