You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2023/02/13 08:23:53 UTC
[iotdb] branch master updated: [IOTDB-5456]Implement COUNT_IF built-in aggregation function
This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 33719b66b3 [IOTDB-5456]Implement COUNT_IF built-in aggregation function
33719b66b3 is described below
commit 33719b66b3977b62166b3322e5c1ec0d4cc34624
Author: Weihao Li <60...@users.noreply.github.com>
AuthorDate: Mon Feb 13 16:23:47 2023 +0800
[IOTDB-5456]Implement COUNT_IF built-in aggregation function
---
docs/UserGuide/Operators-Functions/Aggregation.md | 98 +++++++--
.../UserGuide/Operators-Functions/Aggregation.md | 92 ++++++--
.../iotdb/db/it/aggregation/IoTDBCountIf2IT.java | 45 ++++
.../iotdb/db/it/aggregation/IoTDBCountIf3IT.java | 45 ++++
.../iotdb/db/it/aggregation/IoTDBCountIfIT.java | 241 +++++++++++++++++++++
.../udf/builtin/BuiltinAggregationFunction.java | 47 ++++
...tinScalarFunction.java => BuiltinFunction.java} | 19 +-
.../org/apache/iotdb/db/constant/SqlConstant.java | 1 +
.../db/mpp/aggregation/AccumulatorFactory.java | 56 ++++-
.../db/mpp/aggregation/CountIfAccumulator.java | 156 +++++++++++++
.../SlidingWindowAggregatorFactory.java | 9 +-
.../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java | 9 +-
.../db/mpp/plan/analyze/ExpressionAnalyzer.java | 52 ++++-
.../mpp/plan/analyze/ExpressionTypeAnalyzer.java | 4 +-
.../plan/expression/multi/FunctionExpression.java | 40 ++--
.../db/mpp/plan/expression/multi/FunctionType.java | 27 +++
.../visitor/ColumnTransformerVisitor.java | 6 +-
.../visitor/IntermediateLayerVisitor.java | 6 +-
.../iotdb/db/mpp/plan/parser/ASTVisitor.java | 67 ++++++
.../db/mpp/plan/planner/LogicalPlanBuilder.java | 19 +-
.../db/mpp/plan/planner/LogicalPlanVisitor.java | 20 +-
.../db/mpp/plan/planner/OperatorTreeGenerator.java | 28 ++-
.../plan/planner/distribution/SourceRewriter.java | 15 +-
.../plan/parameter/AggregationDescriptor.java | 89 +++++---
.../CrossSeriesAggregationDescriptor.java | 110 ++++++++--
.../org/apache/iotdb/db/utils/SchemaUtils.java | 1 +
.../apache/iotdb/db/utils/TypeInferenceUtils.java | 105 +++++++--
.../iotdb/db/mpp/aggregation/AccumulatorTest.java | 71 +++++-
.../operator/AggregationOperatorTest.java | 14 +-
.../AlignedSeriesAggregationScanOperatorTest.java | 112 ++++++++--
.../operator/HorizontallyConcatOperatorTest.java | 8 +-
.../mpp/execution/operator/OperatorMemoryTest.java | 24 +-
.../operator/RawDataAggregationOperatorTest.java | 8 +-
.../SeriesAggregationScanOperatorTest.java | 112 ++++++++--
.../SlidingWindowAggregationOperatorTest.java | 9 +-
.../plan/analyze/AggregationDescriptorTest.java | 53 +----
.../db/mpp/plan/plan/QueryLogicalPlanUtil.java | 18 ++
.../distribution/AggregationDistributionTest.java | 18 ++
.../node/process/GroupByLevelNodeSerdeTest.java | 2 +
.../plan/node/process/GroupByTagNodeSerdeTest.java | 4 +
thrift-commons/src/main/thrift/common.thrift | 3 +-
41 files changed, 1628 insertions(+), 235 deletions(-)
diff --git a/docs/UserGuide/Operators-Functions/Aggregation.md b/docs/UserGuide/Operators-Functions/Aggregation.md
index ca9e03d461..c20947908e 100644
--- a/docs/UserGuide/Operators-Functions/Aggregation.md
+++ b/docs/UserGuide/Operators-Functions/Aggregation.md
@@ -23,24 +23,27 @@
Aggregate functions are many-to-one functions. They perform aggregate calculations on a set of values, resulting in a single aggregated result.
-All aggregate functions except `COUNT()` ignore null values and return null when there are no input rows or all values are null. For example, `SUM()` returns null instead of zero, and `AVG()` does not include null values in the count.
+All aggregate functions except `COUNT()` and `COUNT_IF()` ignore null values and return null when there are no input rows or all values are null. For example, `SUM()` returns null instead of zero, and `AVG()` does not include null values in the count.
The aggregate functions supported by IoTDB are as follows:
-| Function Name | Function Description | Allowed Input Data Types | Output Data Types |
-| ------------- | ------------------------------------------------------------ | ------------------------ | ----------------------------------- |
-| SUM | Summation. | INT32 INT64 FLOAT DOUBLE | DOUBLE |
-| COUNT | Counts the number of data points. | All types | INT |
-| AVG | Average. | INT32 INT64 FLOAT DOUBLE | DOUBLE |
-| EXTREME | Finds the value with the largest absolute value. Returns a positive value if the maximum absolute value of positive and negative values is equal. | INT32 INT64 FLOAT DOUBLE | Consistent with the input data type |
-| MAX_VALUE | Find the maximum value. | INT32 INT64 FLOAT DOUBLE | Consistent with the input data type |
-| MIN_VALUE | Find the minimum value. | INT32 INT64 FLOAT DOUBLE | Consistent with the input data type |
-| FIRST_VALUE | Find the value with the smallest timestamp. | All data types | Consistent with input data type |
-| LAST_VALUE | Find the value with the largest timestamp. | All data types | Consistent with input data type |
-| MAX_TIME | Find the maximum timestamp. | All data Types | Timestamp |
-| MIN_TIME | Find the minimum timestamp. | All data Types | Timestamp |
-
-**Example:** Count Points
+| Function Name | Function Description | Allowed Input Data Types | Output Data Types [...]
+| ------------- |------------------------------------------------------------------------------------------------------------------------------------------------------| ------------------------ |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- [...]
+| SUM | Summation. | INT32 INT64 FLOAT DOUBLE | DOUBLE [...]
+| COUNT | Counts the number of data points. | All types | INT64 [...]
+| AVG | Average. | INT32 INT64 FLOAT DOUBLE | DOUBLE [...]
+| EXTREME | Finds the value with the largest absolute value. Returns a positive value if the maximum absolute value of positive and negative values is equal. | INT32 INT64 FLOAT DOUBLE | Consistent with the input data type [...]
+| MAX_VALUE | Find the maximum value. | INT32 INT64 FLOAT DOUBLE | Consistent with the input data type [...]
+| MIN_VALUE | Find the minimum value. | INT32 INT64 FLOAT DOUBLE | Consistent with the input data type [...]
+| FIRST_VALUE | Find the value with the smallest timestamp. | All data types | Consistent with input data type [...]
+| LAST_VALUE | Find the value with the largest timestamp. | All data types | Consistent with input data type [...]
+| MAX_TIME | Find the maximum timestamp. | All data Types | Timestamp [...]
+| MIN_TIME | Find the minimum timestamp. | All data Types | Timestamp [...]
+| COUNT_IF | Counts the number of times that data points consecutively satisfy the given condition and the length of that run (represented by keep) satisfies the specified threshold. | BOOLEAN | `[keep >=/>/=/!=/</<=]threshold`: The specified threshold or threshold condition; `threshold` used alone is equivalent to `keep >= threshold`, and its type is `INT64`<br/> `ignoreNull`: Optional, default value is `true`; if the value is `true`, null values [...]
+
+## COUNT
+
+### Example
```sql
select count(status) from root.ln.wf01.wt01;
@@ -55,4 +58,69 @@ Result:
+-------------------------------+
Total line number = 1
It costs 0.016s
+```
+
+## COUNT_IF
+
+### Syntax
+```sql
+count_if(predicate, [keep >=/>/=/!=/</<=]threshold[, 'ignoreNull'='true/false'])
+```
+`predicate`: a valid expression whose result type is `BOOLEAN`
+
+For the usage of `threshold` and `ignoreNull`, see the table above.
+
+> Note: `count_if` does not currently support sliding windows in `GROUP BY TIME` queries.
+
+### Example
+
+#### Raw data
+
+```
++-----------------------------+-------------+-------------+
+| Time|root.db.d1.s1|root.db.d1.s2|
++-----------------------------+-------------+-------------+
+|1970-01-01T08:00:00.001+08:00| 0| 0|
+|1970-01-01T08:00:00.002+08:00| null| 0|
+|1970-01-01T08:00:00.003+08:00| 0| 0|
+|1970-01-01T08:00:00.004+08:00| 0| 0|
+|1970-01-01T08:00:00.005+08:00| 1| 0|
+|1970-01-01T08:00:00.006+08:00| 1| 0|
+|1970-01-01T08:00:00.007+08:00| 1| 0|
+|1970-01-01T08:00:00.008+08:00| 0| 0|
+|1970-01-01T08:00:00.009+08:00| 0| 0|
+|1970-01-01T08:00:00.010+08:00| 0| 0|
++-----------------------------+-------------+-------------+
+```
+
+#### Without the `ignoreNull` attribute (null values are ignored)
+
+SQL:
+```sql
+select count_if(s1=0 & s2=0, 3), count_if(s1=1 & s2=0, 3) from root.db.d1
+```
+
+Result:
+```
++--------------------------------------------------+--------------------------------------------------+
+|count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3)|count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, 3)|
++--------------------------------------------------+--------------------------------------------------+
+| 2| 1|
++--------------------------------------------------+--------------------------------------------------+
+```
+
+#### With `'ignoreNull'='false'` (null values break continuity)
+
+SQL:
+```sql
+select count_if(s1=0 & s2=0, 3, 'ignoreNull'='false'), count_if(s1=1 & s2=0, 3, 'ignoreNull'='false') from root.db.d1
+```
+
+Result:
+```
++------------------------------------------------------------------------+------------------------------------------------------------------------+
+|count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3, "ignoreNull"="false")|count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, 3, "ignoreNull"="false")|
++------------------------------------------------------------------------+------------------------------------------------------------------------+
+| 1| 1|
++------------------------------------------------------------------------+------------------------------------------------------------------------+
```
\ No newline at end of file
diff --git a/docs/zh/UserGuide/Operators-Functions/Aggregation.md b/docs/zh/UserGuide/Operators-Functions/Aggregation.md
index 1cca558015..89182a0012 100644
--- a/docs/zh/UserGuide/Operators-Functions/Aggregation.md
+++ b/docs/zh/UserGuide/Operators-Functions/Aggregation.md
@@ -23,19 +23,85 @@
聚合函数是多对一函数。它们对一组值进行聚合计算,得到单个聚合结果。
-除了 `COUNT()` 之外,其他所有聚合函数都忽略空值,并在没有输入行或所有值为空时返回空值。 例如,`SUM()` 返回 null 而不是零,而 `AVG()` 在计数中不包括 null 值。
+除了 `COUNT()` 和 `COUNT_IF()` 之外,其他所有聚合函数都忽略空值,并在没有输入行或所有值为空时返回空值。 例如,`SUM()` 返回 null 而不是零,而 `AVG()` 在计数中不包括 null 值。
IoTDB 支持的聚合函数如下:
-| 函数名 | 功能描述 | 允许的输入类型 | 输出类型 |
-| ----------- | ------------------------------------------------------------ | ------------------------ | -------------- |
-| SUM | 求和。 | INT32 INT64 FLOAT DOUBLE | DOUBLE |
-| COUNT | 计算数据点数。 | 所有类型 | INT |
-| AVG | 求平均值。 | INT32 INT64 FLOAT DOUBLE | DOUBLE |
-| EXTREME | 求具有最大绝对值的值。如果正值和负值的最大绝对值相等,则返回正值。 | INT32 INT64 FLOAT DOUBLE | 与输入类型一致 |
-| MAX_VALUE | 求最大值。 | INT32 INT64 FLOAT DOUBLE | 与输入类型一致 |
-| MIN_VALUE | 求最小值。 | INT32 INT64 FLOAT DOUBLE | 与输入类型一致 |
-| FIRST_VALUE | 求时间戳最小的值。 | 所有类型 | 与输入类型一致 |
-| LAST_VALUE | 求时间戳最大的值。 | 所有类型 | 与输入类型一致 |
-| MAX_TIME | 求最大时间戳。 | 所有类型 | Timestamp |
-| MIN_TIME | 求最小时间戳。 | 所有类型 | Timestamp |
\ No newline at end of file
+| 函数名 | 功能描述 | 允许的输入类型 | 必要的属性参数 | 输出类型 |
+| ----------- |-----------------------------------------------| ------------------------ |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------|
+| SUM | 求和。 | INT32 INT64 FLOAT DOUBLE | 无 | DOUBLE |
+| COUNT | 计算数据点数。 | 所有类型 | 无 | INT64 |
+| AVG | 求平均值。 | INT32 INT64 FLOAT DOUBLE | 无 | DOUBLE |
+| EXTREME | 求具有最大绝对值的值。如果正值和负值的最大绝对值相等,则返回正值。 | INT32 INT64 FLOAT DOUBLE | 无 | 与输入类型一致 |
+| MAX_VALUE | 求最大值。 | INT32 INT64 FLOAT DOUBLE | 无 | 与输入类型一致 |
+| MIN_VALUE | 求最小值。 | INT32 INT64 FLOAT DOUBLE | 无 | 与输入类型一致 |
+| FIRST_VALUE | 求时间戳最小的值。 | 所有类型 | 无 | 与输入类型一致 |
+| LAST_VALUE | 求时间戳最大的值。 | 所有类型 | 无 | 与输入类型一致 |
+| MAX_TIME | 求最大时间戳。 | 所有类型 | 无 | Timestamp |
+| MIN_TIME | 求最小时间戳。 | 所有类型 | 无 | Timestamp |
+| COUNT_IF | 求数据点连续满足某一给定条件,且满足条件的数据点个数(用keep表示)满足指定阈值的次数。 | BOOLEAN | `[keep >=/>/=/!=/</<=]threshold`:被指定的阈值或阈值条件,若只使用`threshold`则等价于`keep >= threshold`,`threshold`类型为`INT64`<br/> `ignoreNull`:可选,默认为`true`;为`true`表示忽略null值,即如果中间出现null值,直接忽略,不会打断连续性;为`false`表示不忽略null值,即如果中间出现null值,会打断连续性 | INT64 |
+
+## COUNT_IF
+
+### 语法
+```sql
+count_if(predicate, [keep >=/>/=/!=/</<=]threshold[, 'ignoreNull'='true/false'])
+```
+predicate: 返回类型为`BOOLEAN`的合法表达式
+
+threshold 及 ignoreNull 用法见上表
+
+>注意: count_if 当前暂不支持与 group by time 的 SlidingWindow 一起使用
+
+### 使用示例
+
+#### 原始数据
+
+```
++-----------------------------+-------------+-------------+
+| Time|root.db.d1.s1|root.db.d1.s2|
++-----------------------------+-------------+-------------+
+|1970-01-01T08:00:00.001+08:00| 0| 0|
+|1970-01-01T08:00:00.002+08:00| null| 0|
+|1970-01-01T08:00:00.003+08:00| 0| 0|
+|1970-01-01T08:00:00.004+08:00| 0| 0|
+|1970-01-01T08:00:00.005+08:00| 1| 0|
+|1970-01-01T08:00:00.006+08:00| 1| 0|
+|1970-01-01T08:00:00.007+08:00| 1| 0|
+|1970-01-01T08:00:00.008+08:00| 0| 0|
+|1970-01-01T08:00:00.009+08:00| 0| 0|
+|1970-01-01T08:00:00.010+08:00| 0| 0|
++-----------------------------+-------------+-------------+
+```
+
+#### 不使用ignoreNull参数(忽略null)
+
+SQL:
+```sql
+select count_if(s1=0 & s2=0, 3), count_if(s1=1 & s2=0, 3) from root.db.d1
+```
+
+输出:
+```
++--------------------------------------------------+--------------------------------------------------+
+|count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3)|count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, 3)|
++--------------------------------------------------+--------------------------------------------------+
+| 2| 1|
++--------------------------------------------------+--------------------------------------------------+
+```
+
+#### 使用ignoreNull参数
+
+SQL:
+```sql
+select count_if(s1=0 & s2=0, 3, 'ignoreNull'='false'), count_if(s1=1 & s2=0, 3, 'ignoreNull'='false') from root.db.d1
+```
+
+输出:
+```
++------------------------------------------------------------------------+------------------------------------------------------------------------+
+|count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3, "ignoreNull"="false")|count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, 3, "ignoreNull"="false")|
++------------------------------------------------------------------------+------------------------------------------------------------------------+
+| 1| 1|
++------------------------------------------------------------------------+------------------------------------------------------------------------+
+```
\ No newline at end of file
diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIf2IT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIf2IT.java
new file mode 100644
index 0000000000..fb40a99126
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIf2IT.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.it.aggregation;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.ClusterIT;
+import org.apache.iotdb.itbase.category.LocalStandaloneIT;
+
+import org.junit.BeforeClass;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({LocalStandaloneIT.class, ClusterIT.class})
+public class IoTDBCountIf2IT extends IoTDBCountIfIT {
+ // Re-runs the IoTDBCountIfIT cases with 2 data region groups per database (CUSTOM policy); this static setUp shadows the parent's
+ @BeforeClass
+ public static void setUp() throws Exception {
+ EnvFactory.getEnv().getConfig().getCommonConfig().setDataRegionGroupExtensionPolicy("CUSTOM");
+ EnvFactory.getEnv().getConfig().getCommonConfig().setDefaultDataRegionGroupNumPerDatabase(2);
+ EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000);
+ EnvFactory.getEnv().initClusterEnvironment();
+ prepareData(SQLs);
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIf3IT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIf3IT.java
new file mode 100644
index 0000000000..39b00c569d
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIf3IT.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.it.aggregation;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.ClusterIT;
+import org.apache.iotdb.itbase.category.LocalStandaloneIT;
+
+import org.junit.BeforeClass;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({LocalStandaloneIT.class, ClusterIT.class})
+public class IoTDBCountIf3IT extends IoTDBCountIfIT {
+ // Re-runs the IoTDBCountIfIT cases with 3 data region groups per database (CUSTOM policy); this static setUp shadows the parent's
+ @BeforeClass
+ public static void setUp() throws Exception {
+ EnvFactory.getEnv().getConfig().getCommonConfig().setDataRegionGroupExtensionPolicy("CUSTOM");
+ EnvFactory.getEnv().getConfig().getCommonConfig().setDefaultDataRegionGroupNumPerDatabase(3);
+ EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000);
+ EnvFactory.getEnv().initClusterEnvironment();
+ prepareData(SQLs);
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIfIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIfIT.java
new file mode 100644
index 0000000000..1486172151
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCountIfIT.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.it.aggregation;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.ClusterIT;
+import org.apache.iotdb.itbase.category.LocalStandaloneIT;
+import org.apache.iotdb.rpc.TSStatusCode;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.assertTestFail;
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest;
+import static org.apache.iotdb.itbase.constant.TestConstant.DEVICE;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({LocalStandaloneIT.class, ClusterIT.class})
+public class IoTDBCountIfIT {
+ // 2 devices 4 regions
+ protected static final String[] SQLs =
+ new String[] {
+ "CREATE DATABASE root.db",
+ "CREATE TIMESERIES root.db.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN",
+ "CREATE TIMESERIES root.db.d1.s2 WITH DATATYPE=INT32, ENCODING=PLAIN",
+ "CREATE TIMESERIES root.db.d1.s3 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN",
+ "INSERT INTO root.db.d1(timestamp,s1,s2,s3) values(1, 0, 0, true)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(2, null, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(3, 0, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(4, 0, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(5, 1, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(6, 1, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(7, 1, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(8, 0, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(5000000000, 0, 0)",
+ "INSERT INTO root.db.d1(timestamp,s1,s2) values(5000000001, 0, 0)",
+ "CREATE TIMESERIES root.db.d2.s1 WITH DATATYPE=INT32, ENCODING=PLAIN",
+ "CREATE TIMESERIES root.db.d2.s2 WITH DATATYPE=INT32, ENCODING=PLAIN",
+ "INSERT INTO root.db.d2(timestamp,s1,s2) values(1, 0, 0)",
+ "INSERT INTO root.db.d2(timestamp,s1,s2) values(2, null, 0)",
+ "INSERT INTO root.db.d2(timestamp,s1,s2) values(5000000000, 0, 0)",
+ "INSERT INTO root.db.d2(timestamp,s1,s2) values(5000000001, 0, 0)",
+ "flush"
+ };
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000);
+ EnvFactory.getEnv().initClusterEnvironment();
+ prepareData(SQLs);
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ @Test
+ public void testCountIfIgnoreNull() {
+ // bare threshold constant, expected to behave like keep >= 3
+ String[] expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3, \"ignoreNull\"=\"true\")",
+ "Count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, 3)"
+ };
+ String[] retArray = new String[] {"2,1,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='true'), Count_if(s1=1 & s2=0, 3) from root.db.d1",
+ expectedHeader,
+ retArray);
+
+ // keep >= threshold
+ expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, keep >= 3, \"ignoreNull\"=\"true\")",
+ "Count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, keep >= 3)"
+ };
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, keep>=3, 'ignoreNull'='true'), Count_if(s1=1 & s2=0, keep>=3) from root.db.d1",
+ expectedHeader,
+ retArray);
+
+ // keep < threshold
+ expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, keep < 3, \"ignoreNull\"=\"true\")",
+ "Count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, keep < 3)"
+ };
+ retArray = new String[] {"0,0,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, keep<3, 'ignoreNull'='true'), Count_if(s1=1 & s2=0, keep<3) from root.db.d1",
+ expectedHeader,
+ retArray);
+ }
+
+ @Test
+ public void testCountIfRespectNull() {
+ // bare threshold constant, expected to behave like keep >= 3
+ String[] expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3, \"ignoreNull\"=\"false\")",
+ "Count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, 3, \"ignoreNull\"=\"false\")"
+ };
+ String[] retArray = new String[] {"1,1,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='false'), Count_if(s1=1 & s2=0, 3, 'ignoreNull'='false') from root.db.d1",
+ expectedHeader,
+ retArray);
+
+ // keep >= threshold
+ expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, keep >= 3, \"ignoreNull\"=\"false\")",
+ "Count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, keep >= 3, \"ignoreNull\"=\"false\")"
+ };
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, keep>=3, 'ignoreNull'='false'), Count_if(s1=1 & s2=0, keep>=3, 'ignoreNull'='false') from root.db.d1",
+ expectedHeader,
+ retArray);
+
+ // keep < threshold
+ expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, keep < 3, \"ignoreNull\"=\"false\")",
+ "Count_if(root.db.d1.s1 = 1 & root.db.d1.s2 = 0, keep < 3, \"ignoreNull\"=\"false\")"
+ };
+ retArray = new String[] {"2,0,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, keep<3, 'ignoreNull'='false'), Count_if(s1=1 & s2=0, keep<3, 'ignoreNull'='false') from root.db.d1",
+ expectedHeader,
+ retArray);
+ }
+
+ @Test
+ public void testCountIfAlignByDevice() {
+ String[] expectedHeader =
+ new String[] {
+ DEVICE,
+ "Count_if(s1 = 0 & s2 = 0, 3, \"ignoreNull\"=\"true\")",
+ "Count_if(s1 = 1 & s2 = 0, 3)"
+ };
+ String[] retArray = new String[] {"root.db.d1,2,1,", "root.db.d2,1,0,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='true'), Count_if(s1=1 & s2=0, 3) from root.db.* align by device",
+ expectedHeader,
+ retArray);
+ }
+
+ @Test
+ public void testCountIfInHaving() {
+ String[] expectedHeader =
+ new String[] {
+ "Count_if(root.db.d1.s1 = 0 & root.db.d1.s2 = 0, 3, \"ignoreNull\"=\"true\")"
+ };
+ String[] retArray = new String[] {};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='true') from root.db.d1 having Count_if(s1=1 & s2=0, 3) > 1",
+ expectedHeader,
+ retArray);
+
+ retArray = new String[] {"2,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='true') from root.db.d1 having Count_if(s1=1 & s2=0, 3) > 0",
+ expectedHeader,
+ retArray);
+
+ // align by device
+ expectedHeader = new String[] {DEVICE, "Count_if(s1 = 0 & s2 = 0, 3, \"ignoreNull\"=\"true\")"};
+ retArray = new String[] {};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='true') from root.db.d1 having Count_if(s1=1 & s2=0, 3) > 1 align by device",
+ expectedHeader,
+ retArray);
+
+ retArray = new String[] {"root.db.d1,2,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3, 'ignoreNull'='true') from root.db.d1 having Count_if(s1=1 & s2=0, 3) > 0 align by device",
+ expectedHeader,
+ retArray);
+ }
+
+ @Test
+ public void testCountIfWithoutTransform() {
+ String[] expectedHeader = new String[] {"Count_if(root.db.d1.s3, 1)"};
+ String[] retArray = new String[] {"1,"};
+ resultSetEqualTest("select Count_if(s3, 1) from root.db.d1", expectedHeader, retArray);
+ }
+
+ @Test
+ public void testCountIfWithGroupByLevel() {
+ String[] expectedHeader = new String[] {"Count_if(root.db.*.s1 = 0 & root.db.*.s2 = 0, 3)"};
+ String[] retArray = new String[] {"4,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3) from root.db.* group by level = 1",
+ expectedHeader,
+ retArray);
+
+ expectedHeader =
+ new String[] {
+ "Count_if(root.*.d1.s1 = 0 & root.*.d1.s2 = 0, 3)",
+ "Count_if(root.*.d1.s1 = 0 & root.*.d2.s2 = 0, 3)",
+ "Count_if(root.*.d2.s1 = 0 & root.*.d1.s2 = 0, 3)",
+ "Count_if(root.*.d2.s1 = 0 & root.*.d2.s2 = 0, 3)"
+ };
+ retArray = new String[] {"2,0,1,1,"};
+ resultSetEqualTest(
+ "select Count_if(s1=0 & s2=0, 3) from root.db.* group by level = 2",
+ expectedHeader,
+ retArray);
+ }
+
+ @Test
+ public void testCountIfWithSlidingWindow() {
+ assertTestFail(
+ "select count_if(s1>1,1) from root.db.d1 group by time([1,10),3ms,2ms)",
+ TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()
+ + ": COUNT_IF with slidingWindow is not supported now");
+ }
+}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java
index 7ea3d0d5df..e0977a7eb9 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java
@@ -35,6 +35,7 @@ public enum BuiltinAggregationFunction {
COUNT("count"),
AVG("avg"),
SUM("sum"),
+ COUNT_IF("count_if"),
;
private final String functionName;
@@ -56,4 +57,50 @@ public enum BuiltinAggregationFunction {
public static Set<String> getNativeFunctionNames() {
return NATIVE_FUNCTION_NAMES;
}
+
+ /** @return whether aggregation {@code name} (case-insensitive, locale-independent) can use statistics to optimize */
+ public static boolean canUseStatistics(String name) {
+ final String functionName = name.toLowerCase(java.util.Locale.ROOT);
+ switch (functionName) {
+ case "min_time":
+ case "max_time":
+ case "max_value":
+ case "min_value":
+ case "extreme":
+ case "first_value":
+ case "last_value":
+ case "count":
+ case "avg":
+ case "sum":
+ return true;
+ case "count_if":
+ return false;
+ default:
+ throw new IllegalArgumentException("Invalid Aggregation function: " + name);
+ }
+ }
+
+ // TODO Maybe we can merge this method with canUseStatistics(),
+ // new method returns three level push-down: No push-down, DataRegion, SeriesScan
+ /** @return whether aggregation {@code name} (case-insensitive, locale-independent) can be split into multiple phases */
+ public static boolean canSplitToMultiPhases(String name) {
+ final String functionName = name.toLowerCase(java.util.Locale.ROOT);
+ switch (functionName) {
+ case "min_time":
+ case "max_time":
+ case "max_value":
+ case "min_value":
+ case "extreme":
+ case "first_value":
+ case "last_value":
+ case "count":
+ case "avg":
+ case "sum":
+ return true;
+ case "count_if":
+ return false;
+ default:
+ throw new IllegalArgumentException("Invalid Aggregation function: " + name);
+ }
+ }
}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinScalarFunction.java b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinFunction.java
similarity index 77%
rename from node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinScalarFunction.java
rename to node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinFunction.java
index 24bc734400..1c64b99d15 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinScalarFunction.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinFunction.java
@@ -26,13 +26,13 @@ import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
-public enum BuiltinScalarFunction {
+public enum BuiltinFunction {
DIFF("diff"),
;
private final String functionName;
- BuiltinScalarFunction(String functionName) {
+ BuiltinFunction(String functionName) {
this.functionName = functionName;
}
@@ -42,8 +42,8 @@ public enum BuiltinScalarFunction {
private static final Set<String> NATIVE_FUNCTION_NAMES =
new HashSet<>(
- Arrays.stream(BuiltinScalarFunction.values())
- .map(BuiltinScalarFunction::getFunctionName)
+ Arrays.stream(BuiltinFunction.values())
+ .map(BuiltinFunction::getFunctionName)
.collect(Collectors.toList()));
/**
@@ -56,4 +56,15 @@ public enum BuiltinScalarFunction {
public static Set<String> getNativeFunctionNames() {
return NATIVE_FUNCTION_NAMES;
}
+
+  /**
+   * Tells whether a built-in function is "input one row, output one row".
+   *
+   * @param name function name, matched after lower-casing
+   * @return {@code true} for every supported built-in function (currently only {@code diff})
+   * @throws IllegalArgumentException if {@code name} is not a known built-in function
+   */
+  public static boolean isMappable(String name) {
+    final String lowerCaseName = name.toLowerCase();
+    if ("diff".equals(lowerCaseName)) {
+      return true;
+    }
+    throw new IllegalArgumentException("Invalid BuiltInFunction: " + name);
+  }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java b/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
index 1861d98eb7..42404c86c7 100644
--- a/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
+++ b/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
@@ -53,6 +53,7 @@ public class SqlConstant {
public static final String COUNT = "count";
public static final String AVG = "avg";
public static final String SUM = "sum";
+ public static final String COUNT_IF = "count_if";
// names of scalar functions
public static final String DIFF = "diff";
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorFactory.java b/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorFactory.java
index 455c58a41b..2689c2e401 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorFactory.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorFactory.java
@@ -20,16 +20,25 @@
package org.apache.iotdb.db.mpp.aggregation;
import org.apache.iotdb.common.rpc.thrift.TAggregationType;
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.expression.binary.CompareBinaryExpression;
+import org.apache.iotdb.db.mpp.plan.expression.leaf.ConstantOperand;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
public class AccumulatorFactory {
// TODO: Are we going to create different seriesScanOperator based on order by sequence?
public static Accumulator createAccumulator(
- TAggregationType aggregationType, TSDataType tsDataType, boolean ascending) {
+ TAggregationType aggregationType,
+ TSDataType tsDataType,
+ List<Expression> inputExpressions,
+ Map<String, String> inputAttributes,
+ boolean ascending) {
switch (aggregationType) {
case COUNT:
return new CountAccumulator();
@@ -55,17 +64,58 @@ public class AccumulatorFactory {
return ascending
? new FirstValueAccumulator(tsDataType)
: new FirstValueDescAccumulator(tsDataType);
+ case COUNT_IF:
+ return new CountIfAccumulator(
+ initKeepEvaluator(inputExpressions.get(1)),
+ Boolean.parseBoolean(inputAttributes.getOrDefault("ignoreNull", "true")));
default:
throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
}
}
public static List<Accumulator> createAccumulators(
- List<TAggregationType> aggregationTypes, TSDataType tsDataType, boolean ascending) {
+ List<TAggregationType> aggregationTypes,
+ TSDataType tsDataType,
+ List<Expression> inputExpressions,
+ Map<String, String> inputAttributes,
+ boolean ascending) {
List<Accumulator> accumulators = new ArrayList<>();
for (TAggregationType aggregationType : aggregationTypes) {
- accumulators.add(createAccumulator(aggregationType, tsDataType, ascending));
+ accumulators.add(
+ createAccumulator(
+ aggregationType, tsDataType, inputExpressions, inputAttributes, ascending));
}
return accumulators;
}
+
+ private static Function<Long, Boolean> initKeepEvaluator(Expression keepExpression) {
+ // We have check semantic in FE,
+ // keep expression must be ConstantOperand or CompareBinaryExpression here
+ if (keepExpression instanceof ConstantOperand) {
+ return keep -> keep >= Long.parseLong(keepExpression.toString());
+ } else {
+ long constant =
+ Long.parseLong(
+ ((CompareBinaryExpression) keepExpression)
+ .getRightExpression()
+ .getExpressionString());
+ switch (keepExpression.getExpressionType()) {
+ case LESS_THAN:
+ return keep -> keep < constant;
+ case LESS_EQUAL:
+ return keep -> keep <= constant;
+ case GREATER_THAN:
+ return keep -> keep > constant;
+ case GREATER_EQUAL:
+ return keep -> keep >= constant;
+ case EQUAL_TO:
+ return keep -> keep == constant;
+ case NON_EQUAL:
+ return keep -> keep != constant;
+ default:
+ throw new IllegalArgumentException(
+ "unsupported expression type: " + keepExpression.getExpressionType());
+ }
+ }
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/CountIfAccumulator.java b/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/CountIfAccumulator.java
new file mode 100644
index 0000000000..046e7e82c4
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/CountIfAccumulator.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.aggregation;
+
+import org.apache.iotdb.db.mpp.execution.operator.window.IWindow;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.file.metadata.statistics.Statistics;
+import org.apache.iotdb.tsfile.read.common.block.column.Column;
+import org.apache.iotdb.tsfile.read.common.block.column.ColumnBuilder;
+
+import java.util.function.Function;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class CountIfAccumulator implements Accumulator {
+
+ // number of the point segment that satisfies the KEEP expression
+ private long countValue = 0;
+
+ // number of the continues data points satisfy IF expression
+ private long keep;
+
+ private final Function<Long, Boolean> keepEvaluator;
+
+ private final boolean ignoreNull;
+
+ private boolean lastPointIsSatisfy;
+
+ public CountIfAccumulator(Function<Long, Boolean> keepEvaluator, boolean ignoreNull) {
+ this.keepEvaluator = keepEvaluator;
+ this.ignoreNull = ignoreNull;
+ }
+
+ // Column should be like: | ControlColumn | Time | Value |
+ @Override
+ public int addInput(Column[] column, IWindow curWindow, boolean ignoringNull) {
+ int curPositionCount = column[0].getPositionCount();
+ for (int i = 0; i < curPositionCount; i++) {
+ // skip null value in control column
+ // the input parameter 'ignoringNull' effects on ControlColumn
+ if (ignoringNull && column[0].isNull(i)) {
+ continue;
+ }
+ if (!curWindow.satisfy(column[0], i)) {
+ return i;
+ }
+ curWindow.mergeOnePoint(column, i);
+
+ if (column[2].isNull(i)) {
+ // the member variable 'ignoreNull' effects on calculation of ValueColumn
+ if (!this.ignoreNull) {
+ // data point segment was over, judge whether to count
+ if (lastPointIsSatisfy && keepEvaluator.apply(keep)) {
+ countValue++;
+ }
+ keep = 0;
+ lastPointIsSatisfy = false;
+ }
+ } else {
+ if (column[2].getBoolean(i)) {
+ keep++;
+ lastPointIsSatisfy = true;
+ } else {
+ // data point segment was over, judge whether to count
+ if (lastPointIsSatisfy && keepEvaluator.apply(keep)) {
+ countValue++;
+ }
+ keep = 0;
+ lastPointIsSatisfy = false;
+ }
+ }
+ }
+
+ return curPositionCount;
+ }
+
+ @Override
+ public void addIntermediate(Column[] partialResult) {
+ checkArgument(partialResult.length == 1, "partialResult of count_if should be 1");
+ if (partialResult[0].isNull(0)) {
+ return;
+ }
+ countValue += partialResult[0].getLong(0);
+ }
+
+ @Override
+ public void addStatistics(Statistics statistics) {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ // finalResult should be single column, like: | finalCountValue |
+ @Override
+ public void setFinal(Column finalResult) {
+ if (finalResult.isNull(0)) {
+ return;
+ }
+ countValue = finalResult.getLong(0);
+ }
+
+ @Override
+ public void outputIntermediate(ColumnBuilder[] columnBuilders) {
+ checkArgument(columnBuilders.length == 1, "partialResult of count_if should be 1");
+ // judge whether the last data point segment need to count
+ if (lastPointIsSatisfy && keepEvaluator.apply(keep)) {
+ countValue++;
+ }
+ columnBuilders[0].writeLong(countValue);
+ }
+
+ @Override
+ public void outputFinal(ColumnBuilder columnBuilder) {
+ // judge whether the last data point segment need to count
+ if (lastPointIsSatisfy && keepEvaluator.apply(keep)) {
+ countValue++;
+ }
+ columnBuilder.writeLong(countValue);
+ }
+
+ @Override
+ public void reset() {
+ this.countValue = 0;
+ this.keep = 0;
+ }
+
+ @Override
+ public boolean hasFinalResult() {
+ return false;
+ }
+
+ @Override
+ public TSDataType[] getIntermediateType() {
+ return new TSDataType[] {TSDataType.INT64};
+ }
+
+ @Override
+ public TSDataType getFinalType() {
+ return TSDataType.INT64;
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java b/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java
index 341c81a216..63ea59167c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/aggregation/slidingwindow/SlidingWindowAggregatorFactory.java
@@ -20,8 +20,10 @@
package org.apache.iotdb.db.mpp.aggregation.slidingwindow;
import org.apache.iotdb.common.rpc.thrift.TAggregationType;
+import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.mpp.aggregation.Accumulator;
import org.apache.iotdb.db.mpp.aggregation.AccumulatorFactory;
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.InputLocation;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
@@ -116,11 +118,14 @@ public class SlidingWindowAggregatorFactory {
public static SlidingWindowAggregator createSlidingWindowAggregator(
TAggregationType aggregationType,
TSDataType dataType,
+ List<Expression> inputExpressions,
+ Map<String, String> inputAttributes,
boolean ascending,
List<InputLocation[]> inputLocationList,
AggregationStep step) {
Accumulator accumulator =
- AccumulatorFactory.createAccumulator(aggregationType, dataType, ascending);
+ AccumulatorFactory.createAccumulator(
+ aggregationType, dataType, inputExpressions, inputAttributes, ascending);
switch (aggregationType) {
case SUM:
case AVG:
@@ -145,6 +150,8 @@ public class SlidingWindowAggregatorFactory {
return !ascending
? new NormalQueueSlidingWindowAggregator(accumulator, inputLocationList, step)
: new EmptyQueueSlidingWindowAggregator(accumulator, inputLocationList, step);
+ case COUNT_IF:
+ throw new SemanticException("COUNT_IF with slidingWindow is not supported now");
default:
throw new IllegalArgumentException("Invalid Aggregation Type: " + aggregationType);
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
index 20e2eec4db..6b3e7074c6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
@@ -878,7 +878,11 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
Set<Expression> aggregationExpressions = deviceToAggregationExpressions.get(deviceName);
Set<Expression> sourceTransformExpressions = new LinkedHashSet<>();
for (Expression expression : aggregationExpressions) {
- sourceTransformExpressions.addAll(expression.getExpressions());
+ // We just process first input Expression of AggregationFunction,
+ // keep other input Expressions as origin
+ // If AggregationFunction need more than one input series,
+ // we need to reconsider the process of it
+ sourceTransformExpressions.add(expression.getExpressions().get(0));
}
if (analysis.hasGroupByParameter()) {
sourceTransformExpressions.add(analysis.getDeviceToGroupByExpression().get(deviceName));
@@ -895,7 +899,8 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
Set<Expression> sourceTransformExpressions = new HashSet<>();
if (queryStatement.isAggregationQuery()) {
for (Expression expression : analysis.getAggregationExpressions()) {
- sourceTransformExpressions.addAll(expression.getExpressions());
+ // for AggregationExpression, only the first Expression of input need to transform
+ sourceTransformExpressions.add(expression.getExpressions().get(0));
}
} else {
sourceTransformExpressions = analysis.getSelectExpressions();
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
index 54acc3431a..20588ac297 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
@@ -24,7 +24,7 @@ import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.path.MeasurementPath;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.path.PathPatternTree;
-import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
+import org.apache.iotdb.commons.udf.builtin.BuiltinFunction;
import org.apache.iotdb.db.constant.SqlConstant;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.mpp.common.header.ColumnHeader;
@@ -73,6 +73,7 @@ import static org.apache.iotdb.db.mpp.plan.analyze.ExpressionUtils.reconstructTi
import static org.apache.iotdb.db.mpp.plan.analyze.ExpressionUtils.reconstructTimeSeriesOperands;
import static org.apache.iotdb.db.mpp.plan.analyze.ExpressionUtils.reconstructUnaryExpression;
import static org.apache.iotdb.db.mpp.plan.analyze.ExpressionUtils.reconstructUnaryExpressions;
+import static org.apache.iotdb.db.utils.TypeInferenceUtils.bindTypeForAggregationNonSeriesInputExpressions;
public class ExpressionAnalyzer {
/**
@@ -268,6 +269,18 @@ public class ExpressionAnalyzer {
for (Expression suffixExpression : expression.getExpressions()) {
extendedExpressions.add(
concatExpressionWithSuffixPaths(suffixExpression, prefixPaths, patternTree));
+
+ // We just process first input Expression of AggregationFunction,
+ // keep other input Expressions as origin
+ // If AggregationFunction need more than one input series,
+ // we need to reconsider the process of it
+ if (expression.isBuiltInAggregationFunctionExpression()) {
+ List<Expression> children = expression.getExpressions();
+ for (int i = 1; i < children.size(); i++) {
+ extendedExpressions.add(Collections.singletonList(children.get(i)));
+ }
+ break;
+ }
}
List<List<Expression>> childExpressionsList = new ArrayList<>();
cartesianProduct(extendedExpressions, childExpressionsList, 0, new ArrayList<>());
@@ -456,6 +469,17 @@ public class ExpressionAnalyzer {
return Collections.emptyList();
}
extendedExpressions.add(actualExpressions);
+
+ // We just process first input Expression of AggregationFunction,
+ // keep other input Expressions as origin and bind Type
+ // If AggregationFunction need more than one input series,
+ // we need to reconsider the process of it
+ if (expression.isBuiltInAggregationFunctionExpression()) {
+ List<Expression> children = expression.getExpressions();
+ bindTypeForAggregationNonSeriesInputExpressions(
+ ((FunctionExpression) expression).getFunctionName(), children, extendedExpressions);
+ break;
+ }
}
// Calculate the Cartesian product of extendedExpressions to get the actual expressions after
@@ -528,6 +552,17 @@ public class ExpressionAnalyzer {
for (Expression suffixExpression : predicate.getExpressions()) {
extendedExpressions.add(
removeWildcardInFilter(suffixExpression, prefixPaths, schemaTree, false));
+
+ // We just process first input Expression of AggregationFunction,
+ // keep other input Expressions as origin and bind Type
+ // If AggregationFunction need more than one input series,
+ // we need to reconsider the process of it
+ if (predicate.isBuiltInAggregationFunctionExpression()) {
+ List<Expression> children = predicate.getExpressions();
+ bindTypeForAggregationNonSeriesInputExpressions(
+ ((FunctionExpression) predicate).getFunctionName(), children, extendedExpressions);
+ break;
+ }
}
List<List<Expression>> childExpressionsList = new ArrayList<>();
cartesianProduct(extendedExpressions, childExpressionsList, 0, new ArrayList<>());
@@ -652,6 +687,17 @@ public class ExpressionAnalyzer {
if (concatedExpression != null && !concatedExpression.isEmpty()) {
extendedExpressions.add(concatedExpression);
}
+
+ // We just process first input Expression of AggregationFunction,
+ // keep other input Expressions as origin and bind Type
+ // If AggregationFunction need more than one input series,
+ // we need to reconsider the process of it
+ if (expression.isBuiltInAggregationFunctionExpression()) {
+ List<Expression> children = expression.getExpressions();
+ bindTypeForAggregationNonSeriesInputExpressions(
+ ((FunctionExpression) expression).getFunctionName(), children, extendedExpressions);
+ break;
+ }
}
List<List<Expression>> childExpressionsList = new ArrayList<>();
cartesianProduct(extendedExpressions, childExpressionsList, 0, new ArrayList<>());
@@ -1267,8 +1313,8 @@ public class ExpressionAnalyzer {
} else if (expression instanceof UnaryExpression) {
return isDeviceViewNeedSpecialProcess(((UnaryExpression) expression).getExpression());
} else if (expression instanceof FunctionExpression) {
- if (((FunctionExpression) expression).isBuiltInScalarFunction()
- && BuiltinScalarFunction.DEVICE_VIEW_SPECIAL_PROCESS_FUNCTIONS.contains(
+ if (((FunctionExpression) expression).isBuiltInFunction()
+ && BuiltinFunction.DEVICE_VIEW_SPECIAL_PROCESS_FUNCTIONS.contains(
((FunctionExpression) expression).getFunctionName().toLowerCase())) {
return true;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
index e7d841766c..0a6ec4cfcc 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
@@ -256,10 +256,10 @@ public class ExpressionTypeAnalyzer {
functionExpression.getFunctionName(),
expressionTypes.get(NodeRef.of(inputExpressions.get(0)))));
}
- if (functionExpression.isBuiltInScalarFunction()) {
+ if (functionExpression.isBuiltInFunction()) {
return setExpressionType(
functionExpression,
- TypeInferenceUtils.getScalarFunctionDataType(
+ TypeInferenceUtils.getBuiltInFunctionDataType(
functionExpression.getFunctionName(),
expressionTypes.get(NodeRef.of(inputExpressions.get(0)))));
} else {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
index 66b458ef68..a9910fa941 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
@@ -22,7 +22,7 @@ package org.apache.iotdb.db.mpp.plan.expression.multi;
import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
-import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
+import org.apache.iotdb.commons.udf.builtin.BuiltinFunction;
import org.apache.iotdb.db.mpp.common.NodeRef;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.db.mpp.plan.expression.ExpressionType;
@@ -50,13 +50,7 @@ import java.util.stream.Collectors;
public class FunctionExpression extends Expression {
- /**
- * true: aggregation function<br>
- * false: time series generating function
- */
- private Boolean isBuiltInAggregationFunctionExpressionCache;
-
- private Boolean isBuiltInScalarFunctionCache;
+ private FunctionType functionType;
private final String functionName;
private final LinkedHashMap<String, String> functionAttributes;
@@ -104,21 +98,30 @@ public class FunctionExpression extends Expression {
return visitor.visitFunctionExpression(this, context);
}
+  /**
+   * Lazily classifies this expression by its lower-cased function name: built-in aggregation
+   * names win first, then other built-in function names, and anything else is treated as a UDF.
+   */
+  private void initializeFunctionType() {
+    final String lowerCaseName = functionName.toLowerCase();
+    if (BuiltinAggregationFunction.getNativeFunctionNames().contains(lowerCaseName)) {
+      functionType = FunctionType.AGGREGATION_FUNCTION;
+      return;
+    }
+    functionType =
+        BuiltinFunction.getNativeFunctionNames().contains(lowerCaseName)
+            ? FunctionType.BUILT_IN_FUNCTION
+            : FunctionType.UDF;
+  }
+
@Override
public boolean isBuiltInAggregationFunctionExpression() {
- if (isBuiltInAggregationFunctionExpressionCache == null) {
- isBuiltInAggregationFunctionExpressionCache =
- BuiltinAggregationFunction.getNativeFunctionNames().contains(functionName.toLowerCase());
+ if (functionType == null) {
+ initializeFunctionType();
}
- return isBuiltInAggregationFunctionExpressionCache;
+ return functionType == FunctionType.AGGREGATION_FUNCTION;
}
- public Boolean isBuiltInScalarFunction() {
- if (isBuiltInScalarFunctionCache == null) {
- isBuiltInScalarFunctionCache =
- BuiltinScalarFunction.getNativeFunctionNames().contains(functionName.toLowerCase());
+ public Boolean isBuiltInFunction() {
+ if (functionType == null) {
+ initializeFunctionType();
}
- return isBuiltInScalarFunctionCache;
+ return functionType == FunctionType.BUILT_IN_FUNCTION;
}
@Override
@@ -208,7 +211,8 @@ public class FunctionExpression extends Expression {
@Override
public boolean isMappable(Map<NodeRef<Expression>, TSDataType> expressionTypes) {
- if (isBuiltInAggregationFunctionExpression() || isBuiltInScalarFunction()) {
+ if (isBuiltInAggregationFunctionExpression()
+ || (isBuiltInFunction() && BuiltinFunction.isMappable(functionName))) {
return true;
}
return new UDTFInformationInferrer(functionName)
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
new file mode 100644
index 0000000000..8d13d370a9
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.expression.multi;
+
+/** Category of a {@link FunctionExpression}, determined from its function name. */
+public enum FunctionType {
+  /** A built-in aggregation function, e.g. {@code count} or {@code count_if}. */
+  AGGREGATION_FUNCTION,
+  /** A built-in non-aggregation function, e.g. {@code diff}. */
+  BUILT_IN_FUNCTION,
+  /** A user-defined function: any name that is not a built-in function. */
+  UDF
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/ColumnTransformerVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/ColumnTransformerVisitor.java
index 3c46c36e5b..c980f79967 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/ColumnTransformerVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/ColumnTransformerVisitor.java
@@ -220,9 +220,9 @@ public class ColumnTransformerVisitor
.getValueColumnIndex());
context.leafList.add(identity);
context.cache.put(functionExpression, identity);
- } else if (functionExpression.isBuiltInScalarFunction()) {
+ } else if (functionExpression.isBuiltInFunction()) {
context.cache.put(
- functionExpression, getBuiltInScalarFunctionTransformer(functionExpression, context));
+ functionExpression, getBuiltInFunctionTransformer(functionExpression, context));
} else {
ColumnTransformer[] inputColumnTransformers =
expressions.stream()
@@ -259,7 +259,7 @@ public class ColumnTransformerVisitor
return res;
}
- private ColumnTransformer getBuiltInScalarFunctionTransformer(
+ private ColumnTransformer getBuiltInFunctionTransformer(
FunctionExpression expression, ColumnTransformerVisitorContext context) {
ColumnTransformer childColumnTransformer =
this.process(expression.getExpressions().get(0), context);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/IntermediateLayerVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/IntermediateLayerVisitor.java
index 5839fa5217..68681dfac7 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/IntermediateLayerVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/visitor/IntermediateLayerVisitor.java
@@ -200,8 +200,8 @@ public class IntermediateLayerVisitor
new TransparentTransformer(
context.rawTimeSeriesInputLayer.constructValuePointReader(
functionExpression.getInputColumnIndex()));
- } else if (functionExpression.isBuiltInScalarFunction()) {
- transformer = getBuiltInScalarFunctionTransformer(functionExpression, context);
+ } else if (functionExpression.isBuiltInFunction()) {
+ transformer = getBuiltInFunctionTransformer(functionExpression, context);
} else {
try {
IntermediateLayer udfInputIntermediateLayer =
@@ -223,7 +223,7 @@ public class IntermediateLayerVisitor
return context.expressionIntermediateLayerMap.get(functionExpression);
}
- private Transformer getBuiltInScalarFunctionTransformer(
+ private Transformer getBuiltInFunctionTransformer(
FunctionExpression expression, IntermediateLayerVisitorContext context) {
LayerPointReader childPointReader =
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
index 0dfb9407b4..c8c54811ab 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
@@ -192,6 +192,7 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.time.ZoneId;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
@@ -2318,9 +2319,75 @@ public class ASTVisitor extends IoTDBSqlParserBaseVisitor<Statement> {
"Invalid function expression, all the arguments are constant operands: "
+ functionClause.getText());
}
+
+ // check size of input expressions
+ // type check of input expressions is put in ExpressionTypeAnalyzer
+ if (functionExpression.isBuiltInAggregationFunctionExpression()) {
+ checkAggregationFunctionInput(functionExpression);
+ } else if (functionExpression.isBuiltInFunction()) {
+ checkBuiltInFunctionInput(functionExpression);
+ }
return functionExpression;
}
+  /**
+   * Validates the number of input expressions of a built-in aggregation function.
+   *
+   * @param functionExpression the parsed aggregation call
+   * @throws SemanticException if the number of inputs is wrong for the function
+   * @throws IllegalArgumentException if the name is not a known aggregation function
+   */
+  private void checkAggregationFunctionInput(FunctionExpression functionExpression) {
+    final String lowerCaseName = functionExpression.getFunctionName().toLowerCase();
+    final int expectedInputSize;
+    switch (lowerCaseName) {
+      case SqlConstant.COUNT_IF:
+        // series predicate plus KEEP expression
+        expectedInputSize = 2;
+        break;
+      case SqlConstant.MIN_TIME:
+      case SqlConstant.MAX_TIME:
+      case SqlConstant.COUNT:
+      case SqlConstant.MIN_VALUE:
+      case SqlConstant.LAST_VALUE:
+      case SqlConstant.FIRST_VALUE:
+      case SqlConstant.MAX_VALUE:
+      case SqlConstant.EXTREME:
+      case SqlConstant.AVG:
+      case SqlConstant.SUM:
+        expectedInputSize = 1;
+        break;
+      default:
+        throw new IllegalArgumentException(
+            "Invalid Aggregation function: " + functionExpression.getFunctionName());
+    }
+    checkFunctionExpressionInputSize(
+        functionExpression.getExpressionString(),
+        functionExpression.getExpressions().size(),
+        expectedInputSize);
+  }
+
+  /**
+   * Validates the number of input expressions of a built-in (non-aggregation) function.
+   *
+   * @param functionExpression the parsed function call
+   * @throws SemanticException if the number of inputs is wrong for the function
+   * @throws IllegalArgumentException if the name is not a known built-in function
+   */
+  private void checkBuiltInFunctionInput(FunctionExpression functionExpression) {
+    final String lowerCaseName = functionExpression.getFunctionName().toLowerCase();
+    if (!SqlConstant.DIFF.equals(lowerCaseName)) {
+      throw new IllegalArgumentException(
+          "Invalid BuiltInFunction: " + functionExpression.getFunctionName());
+    }
+    checkFunctionExpressionInputSize(
+        functionExpression.getExpressionString(), functionExpression.getExpressions().size(), 1);
+  }
+
+ private void checkFunctionExpressionInputSize(
+ String expressionString, int actual, int... expected) {
+ for (int expect : expected) {
+ if (expect == actual) {
+ return;
+ }
+ }
+ throw new SemanticException(
+ String.format(
+ "Error size of input expressions. expression: %s, actual size: %s, expected size: %s.",
+ expressionString, actual, Arrays.toString(expected)));
+ }
+
private Expression parseRegularExpression(ExpressionContext context, boolean inWithoutNull) {
return new RegularExpression(
parseExpression(context.unaryBeforeRegularOrLikeExpression, inWithoutNull),
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
index c16b3dd419..d3893276a8 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
@@ -357,7 +357,10 @@ public class LogicalPlanBuilder {
Map<PartialPath, List<AggregationDescriptor>> descendingAggregations) {
AggregationDescriptor aggregationDescriptor =
new AggregationDescriptor(
- sourceExpression.getFunctionName(), curStep, sourceExpression.getExpressions());
+ sourceExpression.getFunctionName(),
+ curStep,
+ sourceExpression.getExpressions(),
+ sourceExpression.getFunctionAttributes());
if (curStep.isOutputPartial()) {
updateTypeProviderByPartialAggregation(aggregationDescriptor, context.getTypeProvider());
}
@@ -692,16 +695,18 @@ public class LogicalPlanBuilder {
GroupByTimeParameter groupByTimeParameter,
Ordering scanOrder) {
List<CrossSeriesAggregationDescriptor> groupByLevelDescriptors = new ArrayList<>();
- for (Expression groupedExpression : groupByLevelExpressions.keySet()) {
+ for (Map.Entry<Expression, Set<Expression>> entry : groupByLevelExpressions.entrySet()) {
groupByLevelDescriptors.add(
new CrossSeriesAggregationDescriptor(
- ((FunctionExpression) groupedExpression).getFunctionName(),
+ ((FunctionExpression) entry.getKey()).getFunctionName(),
curStep,
- groupByLevelExpressions.get(groupedExpression).stream()
+ entry.getValue().stream()
.map(Expression::getExpressions)
.flatMap(List::stream)
.collect(Collectors.toList()),
- groupedExpression.getExpressions().get(0)));
+ entry.getValue().size(),
+ ((FunctionExpression) entry.getKey()).getFunctionAttributes(),
+ entry.getKey().getExpressions().get(0)));
}
updateTypeProvider(groupByLevelExpressions.keySet());
updateTypeProvider(
@@ -746,6 +751,7 @@ public class LogicalPlanBuilder {
functionName,
curStep,
groupedTimeseriesOperands.get(next),
+ ((FunctionExpression) next).getFunctionAttributes(),
next.getExpressions().get(0));
aggregationDescriptors.add(aggregationDescriptor);
} else {
@@ -809,7 +815,8 @@ public class LogicalPlanBuilder {
return new AggregationDescriptor(
((FunctionExpression) expression).getFunctionName(),
curStep,
- expression.getExpressions());
+ expression.getExpressions(),
+ ((FunctionExpression) expression).getFunctionAttributes());
})
.collect(Collectors.toList());
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
index ced7492805..7696aab20a 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
@@ -19,11 +19,13 @@
package org.apache.iotdb.db.mpp.plan.planner;
import org.apache.iotdb.commons.path.PartialPath;
+import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
import org.apache.iotdb.db.metadata.template.Template;
import org.apache.iotdb.db.mpp.common.MPPQueryContext;
import org.apache.iotdb.db.mpp.plan.analyze.Analysis;
import org.apache.iotdb.db.mpp.plan.analyze.ExpressionAnalyzer;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.expression.multi.FunctionExpression;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.load.LoadTsFileNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.write.ActivateTemplateNode;
@@ -75,6 +77,8 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.ShowPathsUsingTe
import org.apache.iotdb.db.mpp.plan.statement.sys.ShowQueriesStatement;
import org.apache.iotdb.tsfile.utils.Pair;
+import org.apache.commons.lang3.Validate;
+
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -215,8 +219,9 @@ public class LogicalPlanVisitor extends StatementVisitor<PlanNode, MPPQueryConte
// aggregation query
boolean isRawDataSource =
analysis.hasValueFilter()
+ || analysis.hasGroupByParameter()
|| needTransform(sourceTransformExpressions)
- || analysis.hasGroupByParameter();
+ || cannotUseStatistics(aggregationExpressions);
AggregationStep curStep;
if (isRawDataSource) {
planBuilder =
@@ -311,6 +316,19 @@ public class LogicalPlanVisitor extends StatementVisitor<PlanNode, MPPQueryConte
return false;
}
+  /**
+   * Returns whether any of the given aggregation expressions must be computed from raw data
+   * because its function cannot be answered from pre-computed statistics.
+   *
+   * @param expressions aggregation expressions of the query; each must be a {@link
+   *     FunctionExpression}
+   * @return {@code true} as soon as one function does not support statistics, {@code false} if
+   *     all of them do
+   * @throws IllegalArgumentException if an expression is not a {@link FunctionExpression}
+   */
+  private boolean cannotUseStatistics(Set<Expression> expressions) {
+    for (Expression expression : expressions) {
+      // Pass the expression string as a format argument instead of a pre-formatted message:
+      // Validate treats its message as a String.format template, so a pre-built message that
+      // contains '%' (e.g. a modulo expression) could be mis-formatted; the template is also
+      // only formatted on failure this way.
+      Validate.isTrue(
+          expression instanceof FunctionExpression,
+          "Invalid Aggregation Expression: %s",
+          expression.getExpressionString());
+      String functionName = ((FunctionExpression) expression).getFunctionName();
+      if (!BuiltinAggregationFunction.canUseStatistics(functionName)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
@Override
public PlanNode visitCreateTimeseries(
CreateTimeSeriesStatement createTimeSeriesStatement, MPPQueryContext context) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index 3d4f6acda1..5cf08f8d98 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -340,7 +340,11 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- o.getAggregationType(), node.getSeriesPath().getSeriesType(), ascending),
+ o.getAggregationType(),
+ node.getSeriesPath().getSeriesType(),
+ o.getInputExpressions(),
+ o.getInputAttributes(),
+ ascending),
o.getStep())));
GroupByTimeParameter groupByTimeParameter = node.getGroupByTimeParameter();
@@ -401,7 +405,11 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- descriptor.getAggregationType(), seriesDataType, ascending),
+ descriptor.getAggregationType(),
+ seriesDataType,
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes(),
+ ascending),
descriptor.getStep(),
Collections.singletonList(new InputLocation[] {new InputLocation(0, seriesIndex)})));
}
@@ -1186,7 +1194,11 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- descriptor.getAggregationType(), seriesDataType, ascending),
+ descriptor.getAggregationType(),
+ seriesDataType,
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes(),
+ ascending),
descriptor.getStep(),
inputLocationList));
}
@@ -1241,7 +1253,11 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- aggregationDescriptor.getAggregationType(), seriesDataType, ascending),
+ aggregationDescriptor.getAggregationType(),
+ seriesDataType,
+ aggregationDescriptor.getInputExpressions(),
+ aggregationDescriptor.getInputAttributes(),
+ ascending),
aggregationDescriptor.getStep(),
inputLocations));
}
@@ -1297,6 +1313,8 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
.getTypeProvider()
// get the type of first inputExpression
.getType(descriptor.getInputExpressions().get(0).toString()),
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes(),
ascending,
inputLocationList,
descriptor.getStep()));
@@ -1370,6 +1388,8 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
.getTypeProvider()
// get the type of first inputExpression
.getType(descriptor.getInputExpressions().get(0).toString()),
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes(),
ascending),
descriptor.getStep(),
inputLocationList));
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
index f66446194e..fd8d1a9293 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SourceRewriter.java
@@ -460,7 +460,8 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
new AggregationDescriptor(
descriptor.getAggregationFuncName(),
AggregationStep.PARTIAL,
- descriptor.getInputExpressions()));
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes()));
});
leafAggDescriptorList.forEach(
d ->
@@ -474,7 +475,8 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
new AggregationDescriptor(
descriptor.getAggregationFuncName(),
context.isRoot ? AggregationStep.FINAL : AggregationStep.INTERMEDIATE,
- descriptor.getInputExpressions()));
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes()));
});
AggregationNode aggregationNode =
@@ -733,7 +735,8 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
regionCountPerSeries.get(handle.getPartitionPath()) == 1
? AggregationStep.STATIC
: AggregationStep.FINAL,
- descriptor.getInputExpressions())));
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes())));
}
SeriesAggregationSourceNode seed = (SeriesAggregationSourceNode) root.getChildren().get(0);
newRoot =
@@ -759,7 +762,8 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
new AggregationDescriptor(
descriptor.getAggregationFuncName(),
AggregationStep.INTERMEDIATE,
- descriptor.getInputExpressions())));
+ descriptor.getInputExpressions(),
+ descriptor.getInputAttributes())));
}
SeriesAggregationSourceNode seed = (SeriesAggregationSourceNode) root.getChildren().get(0);
newRoot =
@@ -1023,7 +1027,8 @@ public class SourceRewriter extends SimplePlanNodeRewriter<DistributionPlanConte
new AggregationDescriptor(
v.getAggregationFuncName(),
AggregationStep.INTERMEDIATE,
- v.getInputExpressions()));
+ v.getInputExpressions(),
+ v.getInputAttributes()));
}));
parentOfGroup.setAggregationDescriptorList(childDescriptors);
if (sourceNodes.size() == 1) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
index 685aaba828..2a09ecc0ae 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
@@ -20,6 +20,7 @@
package org.apache.iotdb.db.mpp.plan.planner.plan.parameter;
import org.apache.iotdb.common.rpc.thrift.TAggregationType;
+import org.apache.iotdb.commons.utils.TestOnly;
import org.apache.iotdb.db.constant.SqlConstant;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
@@ -29,11 +30,10 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.stream.Collectors;
public class AggregationDescriptor {
@@ -56,14 +56,33 @@ public class AggregationDescriptor {
*/
protected List<Expression> inputExpressions;
- private String parametersString;
+ /** Function attributes of the call (key/value pairs), appended after the input expressions in column names. */
+ protected final Map<String, String> inputAttributes;
+
+ /** Lazily built parameter part of the output column name, cached by getParametersString(). */
+ protected String parametersString;
+  /**
+   * Creates a descriptor for one aggregation function call.
+   *
+   * @param aggregationFuncName function name; upper-cased and resolved with {@link
+   *     TAggregationType#valueOf(String)}
+   * @param step aggregation step this descriptor is used in
+   * @param inputExpressions input expressions of the function call
+   * @param inputAttributes function attributes; {@code null} is treated as no attributes
+   */
+  public AggregationDescriptor(
+      String aggregationFuncName,
+      AggregationStep step,
+      List<Expression> inputExpressions,
+      Map<String, String> inputAttributes) {
+    this.aggregationFuncName = aggregationFuncName;
+    // Locale.ROOT: under a default locale such as Turkish, "count_if".toUpperCase() produces a
+    // dotted capital I and TAggregationType.valueOf would reject the name.
+    this.aggregationType =
+        TAggregationType.valueOf(aggregationFuncName.toUpperCase(java.util.Locale.ROOT));
+    this.step = step;
+    this.inputExpressions = inputExpressions;
+    // Normalize null so appendAttributes() never dereferences a missing map.
+    this.inputAttributes =
+        inputAttributes == null ? Collections.<String, String>emptyMap() : inputAttributes;
+  }
+
+ /**
+ * Legacy constructor without function attributes; the descriptor gets an empty attribute map.
+ *
+ * @deprecated kept only so tests written before COUNT_IF was introduced still compile; use
+ * {@link #AggregationDescriptor(String, AggregationStep, List, Map)} instead.
+ */
+ @TestOnly
+ @Deprecated
public AggregationDescriptor(
String aggregationFuncName, AggregationStep step, List<Expression> inputExpressions) {
this.aggregationFuncName = aggregationFuncName;
+ // NOTE(review): toUpperCase() without a Locale uses the JVM default locale (Turkish maps 'i'
+ // to a dotted capital I); Locale.ROOT would make the enum lookup locale-independent.
this.aggregationType = TAggregationType.valueOf(aggregationFuncName.toUpperCase());
this.step = step;
this.inputExpressions = inputExpressions;
+ this.inputAttributes = Collections.emptyMap();
}
public AggregationDescriptor(AggregationDescriptor other) {
@@ -71,6 +90,7 @@ public class AggregationDescriptor {
this.aggregationType = other.getAggregationType();
this.step = other.getStep();
this.inputExpressions = other.getInputExpressions();
+ this.inputAttributes = other.inputAttributes;
}
public String getAggregationFuncName() {
@@ -88,38 +108,22 @@ public class AggregationDescriptor {
+ /**
+ * Returns the input column names this descriptor reads, as a single-element list.
+ *
+ * <p>For a raw-input step only the first input expression is a column (further expressions,
+ * such as a COUNT_IF keep condition, are not); otherwise the names come from
+ * getInputColumnNames().
+ *
+ * <p>NOTE(review): assumes at least one input expression; an empty list makes get(0) throw
+ * IndexOutOfBoundsException.
+ */
public List<List<String>> getInputColumnNamesList() {
if (step.isInputRaw()) {
- return inputExpressions.stream()
- .map(expression -> Collections.singletonList(expression.getExpressionString()))
- .collect(Collectors.toList());
+ return Collections.singletonList(
+ Collections.singletonList(inputExpressions.get(0).getExpressionString()));
}
- List<List<String>> inputColumnNames = new ArrayList<>();
- for (Expression expression : inputExpressions) {
- inputColumnNames.add(getInputColumnNames(expression));
- }
- return inputColumnNames;
+ return Collections.singletonList(getInputColumnNames());
}
+ /**
+ * Returns one input column name per actual aggregation name, each of the form
+ * {@code funcName(<parameters>)} where the parameters are those built by getParametersString().
+ *
+ * <p>NOTE(review): this replaces the public overload that took an Expression argument; any
+ * remaining callers of that overload must be updated.
+ */
- public List<String> getInputColumnNames(Expression inputExpression) {
+ public List<String> getInputColumnNames() {
List<String> inputAggregationNames = getActualAggregationNames(step.isInputPartial());
List<String> inputColumnNames = new ArrayList<>();
for (String funcName : inputAggregationNames) {
- inputColumnNames.add(funcName + "(" + inputExpression.getExpressionString() + ")");
+ inputColumnNames.add(funcName + "(" + getParametersString() + ")");
}
return inputColumnNames;
}
- public Map<String, Expression> getInputColumnCandidateMap() {
- Map<String, Expression> inputColumnNameToExpressionMap = new HashMap<>();
- for (Expression inputExpression : inputExpressions) {
- List<String> inputColumnNames = getInputColumnNames(inputExpression);
- for (String inputColumnName : inputColumnNames) {
- inputColumnNameToExpressionMap.put(inputColumnName, inputExpression);
- }
- }
- return inputColumnNameToExpressionMap;
- }
-
/** Keep the lower case of function name for partial result, and origin value for others. */
protected List<String> getActualAggregationNames(boolean isPartial) {
List<String> outputAggregationNames = new ArrayList<>();
@@ -155,7 +159,7 @@ public class AggregationDescriptor {
*
* <p>The parameter part -> root.sg.d.s1, sin(root.sg.d.s1)
*/
- public String getParametersString() {
+ protected String getParametersString() {
if (parametersString == null) {
StringBuilder builder = new StringBuilder();
if (!inputExpressions.isEmpty()) {
@@ -164,15 +168,45 @@ public class AggregationDescriptor {
builder.append(", ").append(inputExpressions.get(i).toString());
}
}
+ appendAttributes(builder);
parametersString = builder.toString();
}
return parametersString;
}
+  /**
+   * Appends the function attributes to {@code builder} as {@code , "key"="value"} pairs, in the
+   * map's iteration order. Appends nothing when there are no attributes.
+   *
+   * @param builder builder that already holds the expression part of the parameter string
+   */
+  protected void appendAttributes(StringBuilder builder) {
+    for (Map.Entry<String, String> entry : inputAttributes.entrySet()) {
+      // Every pair, including the first, is preceded by ", " - the same separator the callers
+      // use between the expressions written before it.
+      builder
+          .append(", \"")
+          .append(entry.getKey())
+          .append("\"=\"")
+          .append(entry.getValue())
+          .append('"');
+    }
+  }
+
+ /** Returns the input expressions of this aggregation; the internal list, not a copy. */
public List<Expression> getInputExpressions() {
return inputExpressions;
}
+ /** Returns the function attributes of this aggregation; the internal map, not a copy. */
+ public Map<String, String> getInputAttributes() {
+ return inputAttributes;
+ }
+
+ /** Returns the aggregation type resolved from the function name when the descriptor was built. */
public TAggregationType getAggregationType() {
return aggregationType;
}
@@ -200,6 +234,7 @@ public class AggregationDescriptor {
for (Expression expression : inputExpressions) {
Expression.serialize(expression, byteBuffer);
}
+ ReadWriteIOUtils.write(inputAttributes, byteBuffer);
}
public void serialize(DataOutputStream stream) throws IOException {
@@ -209,6 +244,7 @@ public class AggregationDescriptor {
for (Expression expression : inputExpressions) {
Expression.serialize(expression, stream);
}
+ ReadWriteIOUtils.write(inputAttributes, stream);
}
public static AggregationDescriptor deserialize(ByteBuffer byteBuffer) {
@@ -220,7 +256,8 @@ public class AggregationDescriptor {
inputExpressions.add(Expression.deserialize(byteBuffer));
inputExpressionsSize--;
}
- return new AggregationDescriptor(aggregationFuncName, step, inputExpressions);
+ Map<String, String> inputAttributes = ReadWriteIOUtils.readMap(byteBuffer);
+ return new AggregationDescriptor(aggregationFuncName, step, inputExpressions, inputAttributes);
}
@Override
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java
index e9eb082169..6dd6351e11 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java
@@ -20,10 +20,13 @@
package org.apache.iotdb.db.mpp.plan.planner.plan.parameter;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -32,65 +35,140 @@ public class CrossSeriesAggregationDescriptor extends AggregationDescriptor {
+ // Its expression string starts the parameter part of the output column name.
private final Expression outputExpression;
+ /**
+ * Number of expressions that make up one input, used to split inputExpressions when computing
+ * input column names.
+ *
+ * <p>E.g. for inputs [count_if(root.db.d1.s1, 3), count_if(root.db.d2.s1, 3)],
+ * expressionNumOfOneInput = 2.
+ */
+ private final int expressionNumOfOneInput;
+
+  /**
+   * Creates a descriptor that merges several inputs of the same aggregation (as built for
+   * GROUP BY LEVEL).
+   *
+   * @param aggregationFuncName name of the aggregation function
+   * @param step aggregation step this descriptor is used in
+   * @param inputExpressions expressions of all inputs, concatenated input by input
+   * @param numberOfInput how many inputs {@code inputExpressions} contains; every input must
+   *     contribute the same number of expressions
+   * @param inputAttributes function attributes shared by all inputs
+   * @param outputExpression expression whose string starts the parameter part of the output
+   *     column name
+   * @throws IllegalArgumentException if {@code numberOfInput} is not positive or does not evenly
+   *     divide the number of input expressions
+   */
+  public CrossSeriesAggregationDescriptor(
+      String aggregationFuncName,
+      AggregationStep step,
+      List<Expression> inputExpressions,
+      int numberOfInput,
+      Map<String, String> inputAttributes,
+      Expression outputExpression) {
+    super(aggregationFuncName, step, inputExpressions, inputAttributes);
+    // Fail fast: a zero count would throw ArithmeticException below, and an uneven split would
+    // silently truncate the stride, so getInputColumnNamesList() would read past the end or, with
+    // a stride of 0, never terminate.
+    if (numberOfInput <= 0 || inputExpressions.size() % numberOfInput != 0) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Cannot split %d input expressions of aggregation %s into %d inputs of equal size",
+              inputExpressions.size(), aggregationFuncName, numberOfInput));
+    }
+    this.outputExpression = outputExpression;
+    this.expressionNumOfOneInput = inputExpressions.size() / numberOfInput;
+  }
+
+ /**
+ * Creates a descriptor whose inputs each consist of exactly one expression
+ * ({@code expressionNumOfOneInput} is fixed to 1).
+ *
+ * <p>Do not use it for functions with extra non-series arguments, such as the COUNT_IF keep
+ * condition. The original note says GroupByTagNode is currently the only user - TODO confirm.
+ */
public CrossSeriesAggregationDescriptor(
String aggregationFuncName,
AggregationStep step,
List<Expression> inputExpressions,
+ Map<String, String> inputAttributes,
Expression outputExpression) {
- super(aggregationFuncName, step, inputExpressions);
+ super(aggregationFuncName, step, inputExpressions, inputAttributes);
this.outputExpression = outputExpression;
+ this.expressionNumOfOneInput = 1;
}
+ /**
+ * Creates a descriptor from the fields copied out of {@code aggregationDescriptor} (via its copy
+ * constructor) plus the given output expression and per-input expression count. Used by
+ * deepClone() and deserialize(ByteBuffer); expressionNumOfOneInput is not validated here.
+ */
public CrossSeriesAggregationDescriptor(
- AggregationDescriptor aggregationDescriptor, Expression outputExpression) {
+ AggregationDescriptor aggregationDescriptor,
+ Expression outputExpression,
+ int expressionNumOfOneInput) {
super(aggregationDescriptor);
this.outputExpression = outputExpression;
+ this.expressionNumOfOneInput = expressionNumOfOneInput;
}
+ /** Returns the expression whose string starts the parameter part of the output column name. */
public Expression getOutputExpression() {
return outputExpression;
}
+ /**
+ * Generates the parameter part of the output column name and caches it in parametersString.
+ *
+ * <p>The series part is outputExpression; the remaining expressions of one input are taken from
+ * the first input (indices 1 .. expressionNumOfOneInput - 1), followed by the attributes.
+ *
+ * <p>Example:
+ *
+ * <p>Full output column name -> count_if(root.*.*.s1, 3)
+ *
+ * <p>The parameter part -> root.*.*.s1, 3
+ *
+ * <p>NOTE(review): assumes all inputs share the same non-series expressions, since only the
+ * first input's are used.
+ */
@Override
- public String getParametersString() {
- return outputExpression.getExpressionString();
+ protected String getParametersString() {
+ if (parametersString == null) {
+ StringBuilder builder = new StringBuilder(outputExpression.getExpressionString());
+ // Start at 1: index 0 of the first input is represented by outputExpression above.
+ for (int i = 1; i < expressionNumOfOneInput; i++) {
+ builder.append(", ").append(inputExpressions.get(i).toString());
+ }
+ appendAttributes(builder);
+ parametersString = builder.toString();
+ }
+ return parametersString;
}
+ /**
+ * Returns one list of input column names per input.
+ *
+ * <p>For a raw-input step only the first input expression is returned. Otherwise
+ * inputExpressions is split into consecutive groups of expressionNumOfOneInput expressions, and
+ * each group yields the column names built by getInputColumnNames(Expression[]).
+ *
+ * <p>NOTE(review): a stride of 0 with a non-empty inputExpressions would never advance the loop.
+ */
@Override
- public Map<String, Expression> getInputColumnCandidateMap() {
- Map<String, Expression> inputColumnNameToExpressionMap = super.getInputColumnCandidateMap();
- List<String> outputColumnNames = getOutputColumnNames();
- for (String outputColumnName : outputColumnNames) {
- inputColumnNameToExpressionMap.put(outputColumnName, outputExpression);
+ public List<List<String>> getInputColumnNamesList() {
+ if (step.isInputRaw()) {
+ return Collections.singletonList(
+ Collections.singletonList(inputExpressions.get(0).getExpressionString()));
}
- return inputColumnNameToExpressionMap;
+
+ List<List<String>> inputColumnNamesList = new ArrayList<>();
+ // Reused for every group; getInputColumnNames() turns it into strings before it is overwritten.
+ Expression[] expressions = new Expression[expressionNumOfOneInput];
+ for (int i = 0; i < inputExpressions.size(); i += expressionNumOfOneInput) {
+ for (int j = 0; j < expressionNumOfOneInput; j++) {
+ expressions[j] = inputExpressions.get(i + j);
+ }
+ inputColumnNamesList.add(getInputColumnNames(expressions));
+ }
+ return inputColumnNamesList;
+ }
+
+  /**
+   * Builds the input column names for one input: each actual aggregation name followed by the
+   * input rendered by getInputString(Expression[]) in parentheses.
+   */
+  private List<String> getInputColumnNames(Expression[] expressions) {
+    List<String> columnNames = new ArrayList<>();
+    for (String aggregationName : getActualAggregationNames(step.isInputPartial())) {
+      String columnName = aggregationName + "(" + getInputString(expressions) + ")";
+      columnNames.add(columnName);
+    }
+    return columnNames;
+  }
+
+  /**
+   * Renders one input as the parameter part of a column name: its expressions separated by
+   * {@code ", "}, followed by the function attributes.
+   */
+  private String getInputString(Expression[] expressions) {
+    StringBuilder builder = new StringBuilder();
+    for (int i = 0; i < expressions.length; i++) {
+      if (i > 0) {
+        builder.append(", ");
+      }
+      builder.append(expressions[i].toString());
+    }
+    appendAttributes(builder);
+    return builder.toString();
}
+ /**
+ * Returns a copy made through the copy constructor.
+ *
+ * <p>NOTE(review): despite the name, the input expression list, attribute map and output
+ * expression are shared with this instance, not deep-copied.
+ */
@Override
public CrossSeriesAggregationDescriptor deepClone() {
- return new CrossSeriesAggregationDescriptor(
- this.getAggregationFuncName(),
- this.getStep(),
- this.getInputExpressions(),
- this.getOutputExpression());
+ return new CrossSeriesAggregationDescriptor(this, outputExpression, expressionNumOfOneInput);
}
+ /**
+ * Writes the base descriptor, then outputExpression, then expressionNumOfOneInput; the order must
+ * match deserialize(ByteBuffer).
+ */
@Override
public void serialize(ByteBuffer byteBuffer) {
super.serialize(byteBuffer);
Expression.serialize(outputExpression, byteBuffer);
+ ReadWriteIOUtils.write(expressionNumOfOneInput, byteBuffer);
}
+ /** Stream variant of serialize(ByteBuffer), writing the same fields in the same order. */
@Override
public void serialize(DataOutputStream stream) throws IOException {
super.serialize(stream);
Expression.serialize(outputExpression, stream);
+ ReadWriteIOUtils.write(expressionNumOfOneInput, stream);
}
+ /**
+ * Reads a descriptor written by serialize: the base descriptor, outputExpression, then
+ * expressionNumOfOneInput.
+ *
+ * <p>NOTE(review): expressionNumOfOneInput is a newly appended field, so bytes written without
+ * it cannot be read here - confirm serialized plans never cross versions.
+ */
public static CrossSeriesAggregationDescriptor deserialize(ByteBuffer byteBuffer) {
AggregationDescriptor aggregationDescriptor = AggregationDescriptor.deserialize(byteBuffer);
Expression outputExpression = Expression.deserialize(byteBuffer);
- return new CrossSeriesAggregationDescriptor(aggregationDescriptor, outputExpression);
+ int expressionNumOfOneInput = ReadWriteIOUtils.readInt(byteBuffer);
+ return new CrossSeriesAggregationDescriptor(
+ aggregationDescriptor, outputExpression, expressionNumOfOneInput);
}
@Override
diff --git a/server/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/server/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java
index eb75b4c714..fc8c29ffa0 100644
--- a/server/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java
+++ b/server/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java
@@ -187,6 +187,7 @@ public class SchemaUtils {
case COUNT:
case MIN_TIME:
case MAX_TIME:
+ case COUNT_IF:
return Collections.emptyList();
default:
throw new IllegalArgumentException(
diff --git a/server/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/server/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java
index 4b2da0c72e..c295d519cc 100644
--- a/server/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java
+++ b/server/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java
@@ -19,14 +19,23 @@
package org.apache.iotdb.db.utils;
+import org.apache.iotdb.commons.path.MeasurementPath;
import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.constant.SqlConstant;
import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.mpp.plan.analyze.ExpressionUtils;
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.expression.binary.CompareBinaryExpression;
+import org.apache.iotdb.db.mpp.plan.expression.leaf.ConstantOperand;
+import org.apache.iotdb.db.mpp.plan.expression.leaf.TimeSeriesOperand;
import org.apache.iotdb.tsfile.common.constant.TsFileConstant;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.commons.lang3.StringUtils;
+import java.util.Collections;
+import java.util.List;
+
public class TypeInferenceUtils {
private static final TSDataType booleanStringInferType =
@@ -114,15 +123,13 @@ public class TypeInferenceUtils {
if (aggrFuncName == null) {
throw new IllegalArgumentException("AggregateFunction Name must not be null");
}
- if (!verifyIsAggregationDataTypeMatched(aggrFuncName, dataType)) {
- throw new SemanticException(
- "Aggregate functions [AVG, SUM, EXTREME, MIN_VALUE, MAX_VALUE] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]");
- }
+ verifyIsAggregationDataTypeMatched(aggrFuncName, dataType);
switch (aggrFuncName.toLowerCase()) {
case SqlConstant.MIN_TIME:
case SqlConstant.MAX_TIME:
case SqlConstant.COUNT:
+ case SqlConstant.COUNT_IF:
return TSDataType.INT64;
case SqlConstant.MIN_VALUE:
case SqlConstant.LAST_VALUE:
@@ -138,11 +145,10 @@ public class TypeInferenceUtils {
}
}
- private static boolean verifyIsAggregationDataTypeMatched(
- String aggrFuncName, TSDataType dataType) {
+ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDataType dataType) {
// input is NullOperand, needn't check
if (dataType == null) {
- return true;
+ return;
}
switch (aggrFuncName.toLowerCase()) {
case SqlConstant.AVG:
@@ -150,23 +156,96 @@ public class TypeInferenceUtils {
case SqlConstant.EXTREME:
case SqlConstant.MIN_VALUE:
case SqlConstant.MAX_VALUE:
- return dataType.isNumeric();
+ if (dataType.isNumeric()) {
+ return;
+ }
+ throw new SemanticException(
+ "Aggregate functions [AVG, SUM, EXTREME, MIN_VALUE, MAX_VALUE] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]");
case SqlConstant.COUNT:
case SqlConstant.MIN_TIME:
case SqlConstant.MAX_TIME:
case SqlConstant.FIRST_VALUE:
case SqlConstant.LAST_VALUE:
- return true;
+ return;
+ case SqlConstant.COUNT_IF:
+ if (dataType != TSDataType.BOOLEAN) {
+ throw new SemanticException(
+ String.format(
+ "Input series of Aggregation function [%s] only supports data type [BOOLEAN]",
+ aggrFuncName));
+ }
+ return;
default:
throw new IllegalArgumentException("Invalid Aggregation function: " + aggrFuncName);
}
}
- public static TSDataType getScalarFunctionDataType(String funcName, TSDataType dataType) {
+  /**
+   * Binds types for the non-series input expressions of an aggregation function and checks their
+   * semantics.
+   *
+   * <p>Only COUNT_IF has such an input: its keep condition. E.g. for {@code COUNT_IF(s1>1, keep>2,
+   * 'ignoreNull'='false')} the {@code keep} operand is bound to {@link TSDataType#INT64}. For the
+   * other aggregation functions listed below this method does nothing.
+   *
+   * @param functionName aggregation function name, matched case-insensitively
+   * @param inputExpressions input expressions of the call; for COUNT_IF the keep condition is the
+   *     second one
+   * @param outputExpressionLists for COUNT_IF, receives a single-element list holding the keep
+   *     condition: the constant itself, or the comparison rebuilt with an INT64-typed keep operand
+   * @throws SemanticException if the COUNT_IF keep condition is missing, or is neither a constant
+   *     nor a comparison of {@code keep} with a constant
+   * @throws IllegalArgumentException if {@code functionName} is not a supported aggregation
+   */
+  public static void bindTypeForAggregationNonSeriesInputExpressions(
+      String functionName,
+      List<Expression> inputExpressions,
+      List<List<Expression>> outputExpressionLists) {
+    // Locale.ROOT: under a default locale such as Turkish, "COUNT_IF".toLowerCase() yields
+    // "count_ıf" (dotless i), which would fall through to the default branch.
+    switch (functionName.toLowerCase(java.util.Locale.ROOT)) {
+      case SqlConstant.AVG:
+      case SqlConstant.SUM:
+      case SqlConstant.EXTREME:
+      case SqlConstant.MIN_VALUE:
+      case SqlConstant.MAX_VALUE:
+      case SqlConstant.COUNT:
+      case SqlConstant.MIN_TIME:
+      case SqlConstant.MAX_TIME:
+      case SqlConstant.FIRST_VALUE:
+      case SqlConstant.LAST_VALUE:
+        // No non-series inputs to bind.
+        return;
+      case SqlConstant.COUNT_IF:
+        // Guard the get(1) below so a missing keep condition is reported as a semantic error.
+        if (inputExpressions.size() < 2) {
+          throw new SemanticException(
+              String.format(
+                  "Aggregation function [%s] requires an input series and a keep condition",
+                  functionName));
+        }
+        Expression keepExpression = inputExpressions.get(1);
+        if (keepExpression instanceof ConstantOperand) {
+          outputExpressionLists.add(Collections.singletonList(keepExpression));
+          return;
+        }
+        if (!(keepExpression instanceof CompareBinaryExpression)) {
+          throw new SemanticException(
+              String.format(
+                  "Keep condition of Aggregation function [%s] need to be constant or compare expression constructed by keep and a long number",
+                  functionName));
+        }
+        CompareBinaryExpression keepComparison = (CompareBinaryExpression) keepExpression;
+        Expression leftExpression = keepComparison.getLeftExpression();
+        Expression rightExpression = keepComparison.getRightExpression();
+        if (!(leftExpression instanceof TimeSeriesOperand
+            && leftExpression.getExpressionString().equalsIgnoreCase("keep")
+            && rightExpression.isConstantOperand())) {
+          throw new SemanticException(
+              String.format(
+                  "Please check input keep condition of Aggregation function [%s]", functionName));
+        }
+        // Rebuild the comparison with the same operator and constant, but with 'keep' typed as
+        // INT64.
+        TimeSeriesOperand typedKeepOperand =
+            new TimeSeriesOperand(
+                new MeasurementPath(
+                    ((TimeSeriesOperand) leftExpression).getPath(), TSDataType.INT64));
+        outputExpressionLists.add(
+            Collections.singletonList(
+                ExpressionUtils.reconstructBinaryExpression(
+                    keepExpression.getExpressionType(), typedKeepOperand, rightExpression)));
+        return;
+      default:
+        throw new IllegalArgumentException("Invalid Aggregation function: " + functionName);
+    }
+  }
+
+ public static TSDataType getBuiltInFunctionDataType(String funcName, TSDataType dataType) {
if (funcName == null) {
throw new IllegalArgumentException("ScalarFunction Name must not be null");
}
- verifyIsScalarFunctionDataTypeMatched(funcName, dataType);
+ verifyIsBuiltInFunctionDataTypeMatched(funcName, dataType);
switch (funcName.toLowerCase()) {
case SqlConstant.DIFF:
@@ -176,7 +255,7 @@ public class TypeInferenceUtils {
}
}
- private static void verifyIsScalarFunctionDataTypeMatched(String funcName, TSDataType dataType) {
+ private static void verifyIsBuiltInFunctionDataTypeMatched(String funcName, TSDataType dataType) {
// input is NullOperand, needn't check
if (dataType == null) {
return;
@@ -188,7 +267,7 @@ public class TypeInferenceUtils {
}
throw new SemanticException(
String.format(
- "Scalar function [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]",
+ "Input series of Scalar function [%s] only supports numeric data types [INT32, INT64, FLOAT, DOUBLE]",
funcName));
default:
throw new IllegalArgumentException("Invalid Scalar function: " + funcName);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorTest.java
index 63f6ce731e..8e79b3fb3c 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/aggregation/AccumulatorTest.java
@@ -38,6 +38,7 @@ import org.junit.Before;
import org.junit.Test;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
public class AccumulatorTest {
@@ -81,7 +82,12 @@ public class AccumulatorTest {
@Test
public void avgAccumulatorTest() {
Accumulator avgAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.AVG, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.AVG,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.INT64, avgAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.DOUBLE, avgAccumulator.getIntermediateType()[1]);
Assert.assertEquals(TSDataType.DOUBLE, avgAccumulator.getFinalType());
@@ -122,7 +128,12 @@ public class AccumulatorTest {
@Test
public void countAccumulatorTest() {
Accumulator countAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.INT64, countAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.INT64, countAccumulator.getFinalType());
// check returning null while no data
@@ -157,7 +168,12 @@ public class AccumulatorTest {
@Test
public void extremeAccumulatorTest() {
Accumulator extremeAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.EXTREME, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.EXTREME,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.DOUBLE, extremeAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.DOUBLE, extremeAccumulator.getFinalType());
// check returning null while no data
@@ -192,7 +208,12 @@ public class AccumulatorTest {
@Test
public void firstValueAccumulatorTest() {
Accumulator firstValueAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.FIRST_VALUE, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.FIRST_VALUE,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.DOUBLE, firstValueAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.INT64, firstValueAccumulator.getIntermediateType()[1]);
Assert.assertEquals(TSDataType.DOUBLE, firstValueAccumulator.getFinalType());
@@ -233,7 +254,12 @@ public class AccumulatorTest {
@Test
public void lastValueAccumulatorTest() {
Accumulator lastValueAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.LAST_VALUE, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.LAST_VALUE,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.DOUBLE, lastValueAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.INT64, lastValueAccumulator.getIntermediateType()[1]);
Assert.assertEquals(TSDataType.DOUBLE, lastValueAccumulator.getFinalType());
@@ -273,7 +299,12 @@ public class AccumulatorTest {
@Test
public void maxTimeAccumulatorTest() {
Accumulator maxTimeAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.MAX_TIME, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.MAX_TIME,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.INT64, maxTimeAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.INT64, maxTimeAccumulator.getFinalType());
// check returning null while no data
@@ -308,7 +339,12 @@ public class AccumulatorTest {
@Test
public void minTimeAccumulatorTest() {
Accumulator minTimeAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.MIN_TIME, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.MIN_TIME,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.INT64, minTimeAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.INT64, minTimeAccumulator.getFinalType());
// check returning null while no data
@@ -343,7 +379,12 @@ public class AccumulatorTest {
@Test
public void maxValueAccumulatorTest() {
Accumulator extremeAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.MAX_VALUE, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.MAX_VALUE,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.DOUBLE, extremeAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.DOUBLE, extremeAccumulator.getFinalType());
// check returning null while no data
@@ -378,7 +419,12 @@ public class AccumulatorTest {
@Test
public void minValueAccumulatorTest() {
Accumulator extremeAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.MIN_VALUE, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.MIN_VALUE,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.DOUBLE, extremeAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.DOUBLE, extremeAccumulator.getFinalType());
// check returning null while no data
@@ -413,7 +459,12 @@ public class AccumulatorTest {
@Test
public void sumAccumulatorTest() {
Accumulator sumAccumulator =
- AccumulatorFactory.createAccumulator(TAggregationType.SUM, TSDataType.DOUBLE, true);
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.SUM,
+ TSDataType.DOUBLE,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
Assert.assertEquals(TSDataType.DOUBLE, sumAccumulator.getIntermediateType()[0]);
Assert.assertEquals(TSDataType.DOUBLE, sumAccumulator.getFinalType());
// check returning null while no data
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AggregationOperatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AggregationOperatorTest.java
index a32d3e083a..deebb8f5a6 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AggregationOperatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AggregationOperatorTest.java
@@ -313,7 +313,12 @@ public class AggregationOperatorTest {
MeasurementPath measurementPath1 =
new MeasurementPath(AGGREGATION_OPERATOR_TEST_SG + ".device0.sensor0", TSDataType.INT32);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.PARTIAL)));
SeriesAggregationScanOperator seriesAggregationScanOperator1 =
new SeriesAggregationScanOperator(
@@ -366,7 +371,12 @@ public class AggregationOperatorTest {
List<Aggregator> finalAggregators = new ArrayList<>();
List<Accumulator> accumulators =
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true);
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
for (int i = 0; i < accumulators.size(); i++) {
finalAggregators.add(
new Aggregator(accumulators.get(i), AggregationStep.FINAL, inputLocations.get(i)));
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AlignedSeriesAggregationScanOperatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AlignedSeriesAggregationScanOperatorTest.java
index a91a296976..4232c4e5df 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AlignedSeriesAggregationScanOperatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/AlignedSeriesAggregationScanOperatorTest.java
@@ -102,7 +102,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, true),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -128,7 +133,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, false),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ false),
AggregationStep.SINGLE,
inputLocations));
}
@@ -157,7 +167,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(aggregationTypes.get(i), dataType, true),
+ AccumulatorFactory.createAccumulator(
+ aggregationTypes.get(i),
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -189,7 +204,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(aggregationTypes.get(i), dataType, true),
+ AccumulatorFactory.createAccumulator(
+ aggregationTypes.get(i),
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -226,7 +246,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(aggregationTypes.get(i), dataType, false),
+ AccumulatorFactory.createAccumulator(
+ aggregationTypes.get(i),
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ false),
AggregationStep.SINGLE,
inputLocations));
}
@@ -255,7 +280,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, true),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -283,7 +313,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, true),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -310,7 +345,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, true),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -343,7 +383,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(aggregationTypes.get(i), dataType, true),
+ AccumulatorFactory.createAccumulator(
+ aggregationTypes.get(i),
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -375,7 +420,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, true),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -408,7 +458,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
inputLocations.add(new InputLocation[] {new InputLocation(0, i)});
aggregators.add(
new Aggregator(
- AccumulatorFactory.createAccumulator(TAggregationType.COUNT, dataType, true),
+ AccumulatorFactory.createAccumulator(
+ TAggregationType.COUNT,
+ dataType,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
AggregationStep.SINGLE,
inputLocations));
}
@@ -448,7 +503,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
List<Aggregator> aggregators = new ArrayList<>();
List<InputLocation[]> inputLocations =
Collections.singletonList(new InputLocation[] {new InputLocation(0, 1)});
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE, inputLocations)));
AlignedSeriesAggregationScanOperator seriesAggregationScanOperator =
initAlignedSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -486,7 +546,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
List<Aggregator> aggregators = new ArrayList<>();
List<InputLocation[]> inputLocations =
Collections.singletonList(new InputLocation[] {new InputLocation(0, 1)});
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, false)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ false)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE, inputLocations)));
AlignedSeriesAggregationScanOperator seriesAggregationScanOperator =
initAlignedSeriesAggregationScanOperator(aggregators, null, false, groupByTimeParameter);
@@ -514,7 +579,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
List<Aggregator> aggregators = new ArrayList<>();
List<InputLocation[]> inputLocations =
Collections.singletonList(new InputLocation[] {new InputLocation(0, 1)});
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE, inputLocations)));
AlignedSeriesAggregationScanOperator seriesAggregationScanOperator =
initAlignedSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -540,7 +610,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
List<Aggregator> aggregators = new ArrayList<>();
List<InputLocation[]> inputLocations =
Collections.singletonList(new InputLocation[] {new InputLocation(0, 1)});
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE, inputLocations)));
AlignedSeriesAggregationScanOperator seriesAggregationScanOperator =
initAlignedSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -576,7 +651,12 @@ public class AlignedSeriesAggregationScanOperatorTest {
List<Aggregator> aggregators = new ArrayList<>();
List<InputLocation[]> inputLocations =
Collections.singletonList(new InputLocation[] {new InputLocation(0, 1)});
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE, inputLocations)));
AlignedSeriesAggregationScanOperator seriesAggregationScanOperator =
initAlignedSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/HorizontallyConcatOperatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/HorizontallyConcatOperatorTest.java
index 2168221610..58c6412f20 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/HorizontallyConcatOperatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/HorizontallyConcatOperatorTest.java
@@ -52,6 +52,7 @@ import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -121,7 +122,12 @@ public class HorizontallyConcatOperatorTest {
Arrays.asList(TAggregationType.COUNT, TAggregationType.SUM, TAggregationType.FIRST_VALUE);
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 10, 1, 1, true);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator1 =
new SeriesAggregationScanOperator(
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/OperatorMemoryTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/OperatorMemoryTest.java
index 7292f5f321..70746c3cc7 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/OperatorMemoryTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/OperatorMemoryTest.java
@@ -1205,7 +1205,11 @@ public class OperatorMemoryTest {
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- o.getAggregationType(), measurementPath.getSeriesType(), true),
+ o.getAggregationType(),
+ measurementPath.getSeriesType(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
o.getStep())));
ITimeRangeIterator timeRangeIterator = initTimeRangeIterator(groupByTimeParameter, true, true);
@@ -1255,7 +1259,11 @@ public class OperatorMemoryTest {
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- o.getAggregationType(), measurementPath.getSeriesType(), true),
+ o.getAggregationType(),
+ measurementPath.getSeriesType(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
o.getStep())));
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 1000, 10, 10, true);
@@ -1323,7 +1331,11 @@ public class OperatorMemoryTest {
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- o.getAggregationType(), measurementPath.getSeriesType(), true),
+ o.getAggregationType(),
+ measurementPath.getSeriesType(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
o.getStep())));
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 1000, 10, 5, true);
@@ -1397,7 +1409,11 @@ public class OperatorMemoryTest {
aggregators.add(
new Aggregator(
AccumulatorFactory.createAccumulator(
- o.getAggregationType(), measurementPath.getSeriesType(), true),
+ o.getAggregationType(),
+ measurementPath.getSeriesType(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true),
o.getStep())));
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 1000, 10, 10, true);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/RawDataAggregationOperatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/RawDataAggregationOperatorTest.java
index 726132a143..24fe260af0 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/RawDataAggregationOperatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/RawDataAggregationOperatorTest.java
@@ -64,6 +64,7 @@ import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -883,7 +884,12 @@ public class RawDataAggregationOperatorTest {
List<Aggregator> aggregators = new ArrayList<>();
List<Accumulator> accumulators =
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true);
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true);
for (int i = 0; i < accumulators.size(); i++) {
aggregators.add(
new Aggregator(accumulators.get(i), AggregationStep.SINGLE, inputLocations.get(i)));
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SeriesAggregationScanOperatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SeriesAggregationScanOperatorTest.java
index ee2b325d10..fbf731dae3 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SeriesAggregationScanOperatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SeriesAggregationScanOperatorTest.java
@@ -92,7 +92,12 @@ public class SeriesAggregationScanOperatorTest {
public void testAggregationWithoutTimeFilter() throws IllegalPathException {
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, null);
@@ -109,7 +114,12 @@ public class SeriesAggregationScanOperatorTest {
public void testAggregationWithoutTimeFilterOrderByTimeDesc() throws IllegalPathException {
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, false, null);
@@ -128,7 +138,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.COUNT);
aggregationTypes.add(TAggregationType.SUM);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, null);
@@ -152,7 +167,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.MAX_VALUE);
aggregationTypes.add(TAggregationType.MIN_VALUE);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, null);
@@ -181,7 +201,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.MAX_VALUE);
aggregationTypes.add(TAggregationType.MIN_VALUE);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, false)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ false)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, false, null);
@@ -203,7 +228,12 @@ public class SeriesAggregationScanOperatorTest {
public void testAggregationWithTimeFilter1() throws IllegalPathException {
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
Filter timeFilter = TimeFilter.gtEq(120);
SeriesAggregationScanOperator seriesAggregationScanOperator =
@@ -222,7 +252,12 @@ public class SeriesAggregationScanOperatorTest {
Filter timeFilter = TimeFilter.ltEq(379);
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, timeFilter, true, null);
@@ -240,7 +275,12 @@ public class SeriesAggregationScanOperatorTest {
Filter timeFilter = new AndFilter(TimeFilter.gtEq(100), TimeFilter.ltEq(399));
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, timeFilter, true, null);
@@ -263,7 +303,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.MAX_VALUE);
aggregationTypes.add(TAggregationType.MIN_VALUE);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
Filter timeFilter = new AndFilter(TimeFilter.gtEq(100), TimeFilter.ltEq(399));
SeriesAggregationScanOperator seriesAggregationScanOperator =
@@ -288,7 +333,12 @@ public class SeriesAggregationScanOperatorTest {
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 399, 100, 100, true);
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -312,7 +362,12 @@ public class SeriesAggregationScanOperatorTest {
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 399, 100, 100, true);
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, timeFilter, true, groupByTimeParameter);
@@ -345,7 +400,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.MIN_VALUE);
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 399, 100, 100, true);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -381,7 +441,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.MIN_VALUE);
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 399, 100, 100, true);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, false)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ false)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, false, groupByTimeParameter);
@@ -407,7 +472,12 @@ public class SeriesAggregationScanOperatorTest {
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 399, 100, 50, true);
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -431,7 +501,12 @@ public class SeriesAggregationScanOperatorTest {
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 149, 50, 30, true);
List<TAggregationType> aggregationTypes = Collections.singletonList(TAggregationType.COUNT);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
@@ -465,7 +540,12 @@ public class SeriesAggregationScanOperatorTest {
aggregationTypes.add(TAggregationType.MIN_VALUE);
GroupByTimeParameter groupByTimeParameter = new GroupByTimeParameter(0, 149, 50, 30, true);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(aggregationTypes, TSDataType.INT32, true)
+ AccumulatorFactory.createAccumulators(
+ aggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ true)
.forEach(o -> aggregators.add(new Aggregator(o, AggregationStep.SINGLE)));
SeriesAggregationScanOperator seriesAggregationScanOperator =
initSeriesAggregationScanOperator(aggregators, null, true, groupByTimeParameter);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SlidingWindowAggregationOperatorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SlidingWindowAggregationOperatorTest.java
index a291267353..45db8edc84 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SlidingWindowAggregationOperatorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/operator/SlidingWindowAggregationOperatorTest.java
@@ -225,7 +225,12 @@ public class SlidingWindowAggregationOperatorTest {
new MeasurementPath(AGGREGATION_OPERATOR_TEST_SG + ".device0.sensor0", TSDataType.INT32);
List<Aggregator> aggregators = new ArrayList<>();
- AccumulatorFactory.createAccumulators(leafAggregationTypes, TSDataType.INT32, ascending)
+ AccumulatorFactory.createAccumulators(
+ leafAggregationTypes,
+ TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ ascending)
.forEach(
accumulator -> aggregators.add(new Aggregator(accumulator, AggregationStep.PARTIAL)));
@@ -250,6 +255,8 @@ public class SlidingWindowAggregationOperatorTest {
SlidingWindowAggregatorFactory.createSlidingWindowAggregator(
rootAggregationTypes.get(i),
TSDataType.INT32,
+ Collections.emptyList(),
+ Collections.emptyMap(),
ascending,
inputLocations.get(i).stream()
.map(tmpInputLocations -> tmpInputLocations.toArray(new InputLocation[0]))
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
index 97e49107c1..bfd6a72a57 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType;
import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.path.MeasurementPath;
import org.apache.iotdb.commons.path.PartialPath;
-import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.db.mpp.plan.expression.leaf.TimeSeriesOperand;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationDescriptor;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
@@ -98,6 +97,8 @@ public class AggregationDescriptorTest {
Arrays.asList(
new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")),
new TimeSeriesOperand(pathMap.get("root.sg.d1.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(pathMap.get("root.sg.*.s1"))));
groupByLevelDescriptorList.add(
new CrossSeriesAggregationDescriptor(
@@ -106,6 +107,8 @@ public class AggregationDescriptorTest {
Arrays.asList(
new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")),
new TimeSeriesOperand(pathMap.get("root.sg.d2.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(pathMap.get("root.sg.*.s1"))));
groupByLevelDescriptorList.add(
new CrossSeriesAggregationDescriptor(
@@ -114,6 +117,8 @@ public class AggregationDescriptorTest {
Arrays.asList(
new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")),
new TimeSeriesOperand(pathMap.get("root.sg.d1.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(pathMap.get("root.sg.*.s1"))));
groupByLevelDescriptorList.add(
new CrossSeriesAggregationDescriptor(
@@ -122,6 +127,8 @@ public class AggregationDescriptorTest {
Arrays.asList(
new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")),
new TimeSeriesOperand(pathMap.get("root.sg.d2.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(pathMap.get("root.sg.*.s1"))));
}
@@ -197,48 +204,4 @@ public class AggregationDescriptorTest {
.map(CrossSeriesAggregationDescriptor::getInputColumnNamesList)
.collect(Collectors.toList()));
}
-
- @Test
- public void testGroupByLevelInputColumnCandidate() {
- List<Map<String, Expression>> expectedMapList =
- Arrays.asList(
- new HashMap<String, Expression>() {
- {
- put("count(root.sg.d2.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")));
- put("count(root.sg.d1.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")));
- put("count(root.sg.*.s1)", new TimeSeriesOperand(pathMap.get("root.sg.*.s1")));
- }
- },
- new HashMap<String, Expression>() {
- {
- put("avg(root.sg.*.s1)", new TimeSeriesOperand(pathMap.get("root.sg.*.s1")));
- put("count(root.sg.d2.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")));
- put("count(root.sg.d1.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")));
- put("sum(root.sg.d2.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")));
- put("sum(root.sg.d1.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")));
- }
- },
- new HashMap<String, Expression>() {
- {
- put("count(root.sg.d2.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")));
- put("count(root.sg.d1.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")));
- put("count(root.sg.*.s1)", new TimeSeriesOperand(pathMap.get("root.sg.*.s1")));
- }
- },
- new HashMap<String, Expression>() {
- {
- put("count(root.sg.d2.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")));
- put("count(root.sg.d1.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")));
- put("count(root.sg.*.s1)", new TimeSeriesOperand(pathMap.get("root.sg.*.s1")));
- put("sum(root.sg.d1.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d1.s1")));
- put("sum(root.sg.d2.s1)", new TimeSeriesOperand(pathMap.get("root.sg.d2.s1")));
- put("sum(root.sg.*.s1)", new TimeSeriesOperand(pathMap.get("root.sg.*.s1")));
- }
- });
- Assert.assertEquals(
- expectedMapList,
- groupByLevelDescriptorList.stream()
- .map(CrossSeriesAggregationDescriptor::getInputColumnCandidateMap)
- .collect(Collectors.toList()));
- }
}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/QueryLogicalPlanUtil.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/QueryLogicalPlanUtil.java
index d0668f54fe..9b2c86213d 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/QueryLogicalPlanUtil.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/QueryLogicalPlanUtil.java
@@ -612,12 +612,16 @@ public class QueryLogicalPlanUtil {
Arrays.asList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.s1")),
new TimeSeriesOperand(schemaMap.get("root.sg.d1.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.s1"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.COUNT.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.a.s1"))),
+ 1,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.*.s1"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.MAX_VALUE.name().toLowerCase(),
@@ -625,12 +629,16 @@ public class QueryLogicalPlanUtil {
Arrays.asList(
new TimeSeriesOperand(schemaMap.get("root.sg.d1.s2")),
new TimeSeriesOperand(schemaMap.get("root.sg.d2.s2"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.s2"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.MAX_VALUE.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.a.s2"))),
+ 1,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.*.s2"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.LAST_VALUE.name().toLowerCase(),
@@ -638,12 +646,16 @@ public class QueryLogicalPlanUtil {
Arrays.asList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.s1")),
new TimeSeriesOperand(schemaMap.get("root.sg.d1.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.s1"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.LAST_VALUE.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.a.s1"))),
+ 1,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.*.s1")))),
null,
Ordering.DESC);
@@ -863,6 +875,8 @@ public class QueryLogicalPlanUtil {
Arrays.asList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.s1")),
new TimeSeriesOperand(schemaMap.get("root.sg.d1.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.s1"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.MAX_VALUE.name().toLowerCase(),
@@ -870,6 +884,8 @@ public class QueryLogicalPlanUtil {
Arrays.asList(
new TimeSeriesOperand(schemaMap.get("root.sg.d1.s2")),
new TimeSeriesOperand(schemaMap.get("root.sg.d2.s2"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.s2"))),
new CrossSeriesAggregationDescriptor(
TAggregationType.LAST_VALUE.name().toLowerCase(),
@@ -877,6 +893,8 @@ public class QueryLogicalPlanUtil {
Arrays.asList(
new TimeSeriesOperand(schemaMap.get("root.sg.d2.s1")),
new TimeSeriesOperand(schemaMap.get("root.sg.d1.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(schemaMap.get("root.sg.*.s1")))),
null,
Ordering.DESC);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
index 66da5f0b79..3e6a6db0d0 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
@@ -302,6 +302,8 @@ public class AggregationDistributionTest {
Arrays.asList(
new TimeSeriesOperand(new PartialPath(d1s1Path)),
new TimeSeriesOperand(new PartialPath(d2s1Path))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPath)))),
null,
Ordering.ASC);
@@ -340,6 +342,8 @@ public class AggregationDistributionTest {
Arrays.asList(
new TimeSeriesOperand(new PartialPath(d3s1Path)),
new TimeSeriesOperand(new PartialPath(d4s1Path))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPath)))),
null,
Ordering.ASC);
@@ -402,6 +406,8 @@ public class AggregationDistributionTest {
Arrays.asList(
new TimeSeriesOperand(new PartialPath(d3s1Path)),
new TimeSeriesOperand(new PartialPath(d4s1Path))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPath)))),
null,
Ordering.ASC);
@@ -482,11 +488,15 @@ public class AggregationDistributionTest {
TAggregationType.COUNT.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s1Path))),
+ 1,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPathS1))),
new CrossSeriesAggregationDescriptor(
TAggregationType.COUNT.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s2Path))),
+ 1,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPathS2)))),
null,
Ordering.ASC);
@@ -544,11 +554,15 @@ public class AggregationDistributionTest {
Arrays.asList(
new TimeSeriesOperand(new PartialPath(d1s1Path)),
new TimeSeriesOperand(new PartialPath(d2s1Path))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPathS1))),
new CrossSeriesAggregationDescriptor(
TAggregationType.COUNT.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s2Path))),
+ 1,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPathS2)))),
null,
Ordering.ASC);
@@ -619,11 +633,15 @@ public class AggregationDistributionTest {
Arrays.asList(
new TimeSeriesOperand(new PartialPath(d1s1Path)),
new TimeSeriesOperand(new PartialPath(d2s1Path))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPathS1))),
new CrossSeriesAggregationDescriptor(
TAggregationType.COUNT.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(new TimeSeriesOperand(new PartialPath(d1s2Path))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath(groupedPathS2)))),
null,
Ordering.ASC);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByLevelNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByLevelNodeSerdeTest.java
index e0fa4f86c0..87af66e497 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByLevelNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByLevelNodeSerdeTest.java
@@ -91,6 +91,8 @@ public class GroupByLevelNodeSerdeTest {
Arrays.asList(
new TimeSeriesOperand(new PartialPath("root.sg.d1.s1")),
new TimeSeriesOperand(new PartialPath("root.sg.d2.s1"))),
+ 2,
+ Collections.emptyMap(),
new TimeSeriesOperand(new PartialPath("root.sg.*.s1")))),
groupByTimeParameter,
Ordering.ASC);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByTagNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByTagNodeSerdeTest.java
index edb423f1a3..0f35462af4 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByTagNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/GroupByTagNodeSerdeTest.java
@@ -60,6 +60,8 @@ public class GroupByTagNodeSerdeTest {
TAggregationType.MAX_TIME.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(new TimeSeriesOperand(new PartialPath("root.sg.d1.s1"))),
+ 1,
+ Collections.emptyMap(),
new FunctionExpression(
"max_time",
new LinkedHashMap<>(),
@@ -70,6 +72,8 @@ public class GroupByTagNodeSerdeTest {
TAggregationType.AVG.name().toLowerCase(),
AggregationStep.FINAL,
Collections.singletonList(new TimeSeriesOperand(new PartialPath("root.sg.d1.s1"))),
+ 1,
+ Collections.emptyMap(),
new FunctionExpression(
"avg",
new LinkedHashMap<>(),
diff --git a/thrift-commons/src/main/thrift/common.thrift b/thrift-commons/src/main/thrift/common.thrift
index 6f10259f6c..49ee669ffe 100644
--- a/thrift-commons/src/main/thrift/common.thrift
+++ b/thrift-commons/src/main/thrift/common.thrift
@@ -138,5 +138,6 @@ enum TAggregationType {
MIN_TIME,
MAX_VALUE,
MIN_VALUE,
- EXTREME
+ EXTREME,
+ COUNT_IF
}