You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/25 21:22:05 UTC
[iceberg] branch master updated: Core, Spark: Add Aggregate expressions (#5961)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 8271791061 Core, Spark: Add Aggregate expressions (#5961)
8271791061 is described below
commit 8271791061e72dbf554028e477746e89fca9ec02
Author: Huaxin Gao <hu...@apple.com>
AuthorDate: Tue Oct 25 14:21:57 2022 -0700
Core, Spark: Add Aggregate expressions (#5961)
---
.../org/apache/iceberg/expressions/Aggregate.java | 58 +++++++++++
.../org/apache/iceberg/expressions/Binder.java | 10 ++
.../apache/iceberg/expressions/BoundAggregate.java | 47 +++++++++
.../org/apache/iceberg/expressions/Expression.java | 6 +-
.../iceberg/expressions/ExpressionVisitors.java | 14 +++
.../apache/iceberg/expressions/Expressions.java | 16 +++
.../iceberg/expressions/UnboundAggregate.java | 57 ++++++++++
.../iceberg/expressions/TestAggregateBinding.java | 116 +++++++++++++++++++++
.../org/apache/iceberg/spark/SparkAggregates.java | 80 ++++++++++++++
.../java/org/apache/iceberg/spark/SparkUtil.java | 8 ++
.../org/apache/iceberg/spark/SparkV2Filters.java | 39 +++----
.../iceberg/spark/source/TestSparkAggregates.java | 76 ++++++++++++++
12 files changed, 503 insertions(+), 24 deletions(-)
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Aggregate.java b/api/src/main/java/org/apache/iceberg/expressions/Aggregate.java
new file mode 100644
index 0000000000..7db1822e49
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/expressions/Aggregate.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+/**
+ * Base class for the aggregate expressions that can be pushed and evaluated in Iceberg. The
+ * operations handled here are COUNT, COUNT_STAR, MAX and MIN.
+ */
+public abstract class Aggregate<C extends Term> implements Expression {
+  private final Operation op; // aggregate function applied by this expression
+  private final C term; // aggregated term; not used when rendering count(*)
+
+  Aggregate(Operation op, C term) {
+    this.term = term;
+    this.op = op;
+  }
+
+  @Override
+  public Operation op() {
+    return this.op;
+  }
+
+  public C term() {
+    return this.term;
+  }
+
+  @Override
+  public String toString() {
+    switch (op()) {
+      case COUNT_STAR:
+        return "count(*)";
+      case COUNT:
+        return "count(" + term() + ")";
+      case MIN:
+        return "min(" + term() + ")";
+      case MAX:
+        return "max(" + term() + ")";
+      default: // not an aggregate operation
+        throw new UnsupportedOperationException("Invalid aggregate: " + op());
+    }
+  }
+}
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Binder.java b/api/src/main/java/org/apache/iceberg/expressions/Binder.java
index d2a7b1d09e..3454fa14e0 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Binder.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Binder.java
@@ -158,6 +158,16 @@ public class Binder {
public <T> Expression predicate(UnboundPredicate<T> pred) {
return pred.bind(struct, caseSensitive);
}
+
+ @Override
+ public <T> Expression aggregate(UnboundAggregate<T> agg) {
+ return agg.bind(struct, caseSensitive); // bind against this visitor's struct and case setting
+ }
+
+ @Override
+ public <T, C> Expression aggregate(BoundAggregate<T, C> agg) {
+ throw new IllegalStateException("Found already bound aggregate: " + agg); // never re-bind
+ }
}
private static class ReferenceVisitor extends ExpressionVisitor<Set<Integer>> {
diff --git a/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
new file mode 100644
index 0000000000..650271b3b7
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/expressions/BoundAggregate.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
+
+public class BoundAggregate<T, C> extends Aggregate<BoundTerm<T>> implements Bound<C> {
+ protected BoundAggregate(Operation op, BoundTerm<T> term) {
+ super(op, term);
+ }
+
+ @Override
+ public C eval(StructLike struct) {
+ throw new UnsupportedOperationException(this.getClass().getName() + " does not implement eval");
+ }
+
+ @Override
+ public BoundReference<?> ref() {
+ return term().ref();
+ }
+
+ public Type type() {
+ if (op() == Operation.COUNT || op() == Operation.COUNT_STAR) {
+ return Types.LongType.get();
+ } else {
+ return term().type();
+ }
+ }
+}
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expression.java b/api/src/main/java/org/apache/iceberg/expressions/Expression.java
index cd82aa07ad..dc88172c59 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Expression.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Expression.java
@@ -43,7 +43,11 @@ public interface Expression extends Serializable {
AND,
OR,
STARTS_WITH,
- NOT_STARTS_WITH;
+ NOT_STARTS_WITH,
+ COUNT,
+ COUNT_STAR,
+ MAX,
+ MIN;
public static Operation fromString(String operationType) {
Preconditions.checkArgument(null != operationType, "Invalid operation type: null");
diff --git a/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java b/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java
index 4076d7febc..79ca6a7128 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/ExpressionVisitors.java
@@ -55,6 +55,14 @@ public class ExpressionVisitors {
public <T> R predicate(UnboundPredicate<T> pred) {
return null;
}
+
+ public <T, C> R aggregate(BoundAggregate<T, C> agg) {
+ throw new UnsupportedOperationException("Cannot visit aggregate expression"); // default: reject
+ }
+
+ public <T> R aggregate(UnboundAggregate<T> agg) {
+ throw new UnsupportedOperationException("Cannot visit aggregate expression"); // default: reject
+ }
}
public abstract static class BoundExpressionVisitor<R> extends ExpressionVisitor<R> {
@@ -338,6 +346,12 @@ public class ExpressionVisitors {
} else {
return visitor.predicate((UnboundPredicate<?>) expr);
}
+ } else if (expr instanceof Aggregate) {
+ if (expr instanceof BoundAggregate) {
+ return visitor.aggregate((BoundAggregate<?, ?>) expr);
+ } else {
+ return visitor.aggregate((UnboundAggregate<?>) expr);
+ }
} else {
switch (expr.op()) {
case TRUE:
diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java
index 7fad8324c4..171da823cc 100644
--- a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java
+++ b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java
@@ -308,4 +308,20 @@ public class Expressions {
public static <T> UnboundTerm<T> transform(String name, Transform<?, T> transform) {
return new UnboundTransform<>(ref(name), transform);
}
+
+ public static UnboundAggregate<String> count(String name) { // NOTE(review): <String> is fixed whatever the column type — confirm
+ return new UnboundAggregate<>(Operation.COUNT, ref(name)); // COUNT over a reference to the named column
+ }
+
+ public static UnboundAggregate<String> countStar() {
+ return new UnboundAggregate<>(Operation.COUNT_STAR, null); // count(*) carries no term, hence null
+ }
+
+ public static UnboundAggregate<String> max(String name) {
+ return new UnboundAggregate<>(Operation.MAX, ref(name)); // MAX over a reference to the named column
+ }
+
+ public static UnboundAggregate<String> min(String name) {
+ return new UnboundAggregate<>(Operation.MIN, ref(name)); // MIN over a reference to the named column
+ }
}
diff --git a/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java b/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java
new file mode 100644
index 0000000000..5e4cce06c7
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/expressions/UnboundAggregate.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.types.Types;
+
+public class UnboundAggregate<T> extends Aggregate<UnboundTerm<T>>
+    implements Unbound<T, Expression> {
+
+  UnboundAggregate(Operation op, UnboundTerm<T> term) {
+    super(op, term); // term may be null: count(*) has no term to aggregate
+  }
+
+  /** Returns the reference of the aggregated term, or null when there is no term (count(*)). */
+  @Override
+  public NamedReference<?> ref() {
+    return term() == null ? null : term().ref();
+  }
+
+  /**
+   * Binds this aggregate to a struct. A COUNT_STAR aggregate is bound without resolving a term.
+   *
+   * @param struct the {@link Types.StructType struct type} used to resolve the term by name
+   * @param caseSensitive whether the term's name is resolved case-sensitively
+   * @return a {@link BoundAggregate} with the same operation
+   * @throws IllegalArgumentException if the term is null and the operation is not COUNT_STAR
+   * @throws ValidationException if the term cannot be bound to {@code struct}
+   */
+  @Override
+  public Expression bind(Types.StructType struct, boolean caseSensitive) {
+    if (op() == Operation.COUNT_STAR) {
+      return new BoundAggregate<>(op(), null);
+    }
+
+    Preconditions.checkArgument(term() != null, "Invalid aggregate term: null");
+    BoundTerm<T> bound = term().bind(struct, caseSensitive);
+    return new BoundAggregate<>(op(), bound);
+  }
+}
diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java b/api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java
new file mode 100644
index 0000000000..7bbe8ad7da
--- /dev/null
+++ b/api/src/test/java/org/apache/iceberg/expressions/TestAggregateBinding.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.expressions;
+
+import java.util.Arrays;
+import java.util.List;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.types.Types;
+import org.apache.iceberg.types.Types.StructType;
+import org.assertj.core.api.Assertions;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestAggregateBinding {
+ private static final List<Expression.Operation> AGGREGATES =
+ Arrays.asList(Expression.Operation.COUNT, Expression.Operation.MAX, Expression.Operation.MIN);
+ private static final StructType struct =
+ StructType.of(Types.NestedField.required(10, "x", Types.IntegerType.get()));
+
+ @Test
+ public void testAggregateBinding() {
+ for (Expression.Operation op : AGGREGATES) {
+ UnboundAggregate unbound = null;
+ switch (op) {
+ case COUNT:
+ unbound = Expressions.count("x");
+ break;
+ case MAX:
+ unbound = Expressions.max("x");
+ break;
+ case MIN:
+ unbound = Expressions.min("x");
+ break;
+ default:
+ throw new UnsupportedOperationException("Invalid aggregate: " + op);
+ }
+
+ Expression expr = unbound.bind(struct, true);
+ BoundAggregate bound = assertAndUnwrapAggregate(expr);
+
+ Assert.assertEquals("Should reference correct field ID", 10, bound.ref().fieldId());
+ Assert.assertEquals("Should not change the comparison operation", op, bound.op());
+ }
+ }
+
+ @Test
+ public void testCountStarBinding() {
+ UnboundAggregate unbound = Expressions.countStar();
+ Expression expr = unbound.bind(null, false);
+ BoundAggregate bound = assertAndUnwrapAggregate(expr);
+
+ Assert.assertEquals(
+ "Should not change the comparison operation", Expression.Operation.COUNT_STAR, bound.op());
+ }
+
+ @Test
+ public void testBoundAggregateFails() {
+ Expression unbound = Expressions.count("x");
+ Assertions.assertThatThrownBy(() -> Binder.bind(struct, Binder.bind(struct, unbound)))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessageContaining("Found already bound aggregate");
+ }
+
+ @Test
+ public void testCaseInsensitiveReference() {
+ Expression expr = Expressions.max("X");
+ Expression boundExpr = Binder.bind(struct, expr, false);
+ BoundAggregate bound = assertAndUnwrapAggregate(boundExpr);
+ Assert.assertEquals("Should reference correct field ID", 10, bound.ref().fieldId());
+ Assert.assertEquals(
+ "Should not change the comparison operation", Expression.Operation.MAX, bound.op());
+ }
+
+ @Test
+ public void testCaseSensitiveReference() {
+ Expression expr = Expressions.max("X");
+ Assertions.assertThatThrownBy(() -> Binder.bind(struct, expr, true))
+ .isInstanceOf(ValidationException.class)
+ .hasMessageContaining("Cannot find field 'X' in struct");
+ }
+
+ @Test
+ public void testMissingField() {
+ UnboundAggregate unbound = Expressions.count("missing");
+ try {
+ unbound.bind(struct, false);
+ Assert.fail("Binding a missing field should fail");
+ } catch (ValidationException e) {
+ Assert.assertTrue(
+ "Validation should complain about missing field",
+ e.getMessage().contains("Cannot find field 'missing' in struct:"));
+ }
+ }
+
+ private static <T, C> BoundAggregate<T, C> assertAndUnwrapAggregate(Expression expr) {
+ Assert.assertTrue(
+ "Expression should be a bound aggregate: " + expr, expr instanceof BoundAggregate);
+ return (BoundAggregate<T, C>) expr;
+ }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java
new file mode 100644
index 0000000000..6741e33fa1
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkAggregates.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import java.util.Map;
+import org.apache.iceberg.expressions.Expression;
+import org.apache.iceberg.expressions.Expression.Operation;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+
+public class SparkAggregates {
+  private static final Map<Class<? extends AggregateFunc>, Operation> AGGREGATES =
+      ImmutableMap.<Class<? extends AggregateFunc>, Operation>builder()
+          .put(Count.class, Operation.COUNT)
+          .put(CountStar.class, Operation.COUNT_STAR)
+          .put(Max.class, Operation.MAX)
+          .put(Min.class, Operation.MIN)
+          .build();
+
+  private SparkAggregates() {}
+
+  /**
+   * Converts a Spark aggregate into the matching Iceberg aggregate expression. Returns null when
+   * it cannot be converted: an unmapped function, a distinct count, or a non-reference input.
+   */
+  public static Expression convert(AggregateFunc aggregate) {
+    Operation op = AGGREGATES.get(aggregate.getClass());
+    if (op == null) {
+      return null;
+    }
+
+    switch (op) {
+      case COUNT_STAR:
+        return Expressions.countStar();
+      case COUNT:
+        Count countAgg = (Count) aggregate;
+        // distinct counts are not in manifest files; only plain references are converted
+        if (countAgg.isDistinct() || !(countAgg.column() instanceof NamedReference)) {
+          return null;
+        }
+        return Expressions.count(SparkUtil.toColumnName((NamedReference) countAgg.column()));
+      case MAX:
+        Max maxAgg = (Max) aggregate;
+        if (!(maxAgg.column() instanceof NamedReference)) {
+          return null;
+        }
+        return Expressions.max(SparkUtil.toColumnName((NamedReference) maxAgg.column()));
+      case MIN:
+        Min minAgg = (Min) aggregate;
+        if (!(minAgg.column() instanceof NamedReference)) {
+          return null;
+        }
+        return Expressions.min(SparkUtil.toColumnName((NamedReference) minAgg.column()));
+      default:
+        return null;
+    }
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
index 950ed7bc87..2e8312fd97 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkUtil.java
@@ -32,6 +32,7 @@ import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopConfigurable;
import org.apache.iceberg.io.FileIO;
+import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.transforms.Transform;
@@ -45,6 +46,7 @@ import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
@@ -73,6 +75,8 @@ public class SparkUtil {
private static final String SPARK_CATALOG_HADOOP_CONF_OVERRIDE_FMT_STR =
SPARK_CATALOG_CONF_PREFIX + ".%s.hadoop.";
+ private static final Joiner DOT = Joiner.on(".");
+
private SparkUtil() {}
public static FileIO serializableFileIO(Table table) {
@@ -287,4 +291,8 @@ public class SparkUtil {
return filterExpressions;
}
+
+ public static String toColumnName(NamedReference ref) {
+ return DOT.join(ref.fieldNames()); // field name parts joined with "."
+ }
}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
index 2f09e6e9c9..072c14c08b 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java
@@ -40,7 +40,6 @@ import java.util.stream.Collectors;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expression.Operation;
import org.apache.iceberg.expressions.Expressions;
-import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.util.NaNUtil;
@@ -54,8 +53,6 @@ import org.apache.spark.unsafe.types.UTF8String;
public class SparkV2Filters {
- private static final Joiner DOT = Joiner.on(".");
-
private static final String TRUE = "ALWAYS_TRUE";
private static final String FALSE = "ALWAYS_FALSE";
private static final String EQ = "=";
@@ -105,17 +102,17 @@ public class SparkV2Filters {
return Expressions.alwaysFalse();
case IS_NULL:
- return isRef(child(predicate)) ? isNull(toColumnName(child(predicate))) : null;
+ return isRef(child(predicate)) ? isNull(SparkUtil.toColumnName(child(predicate))) : null;
case NOT_NULL:
- return isRef(child(predicate)) ? notNull(toColumnName(child(predicate))) : null;
+ return isRef(child(predicate)) ? notNull(SparkUtil.toColumnName(child(predicate))) : null;
case LT:
if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
- String columnName = toColumnName(leftChild(predicate));
+ String columnName = SparkUtil.toColumnName(leftChild(predicate));
return lessThan(columnName, convertLiteral(rightChild(predicate)));
} else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
- String columnName = toColumnName(rightChild(predicate));
+ String columnName = SparkUtil.toColumnName(rightChild(predicate));
return greaterThan(columnName, convertLiteral(leftChild(predicate)));
} else {
return null;
@@ -123,10 +120,10 @@ public class SparkV2Filters {
case LT_EQ:
if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
- String columnName = toColumnName(leftChild(predicate));
+ String columnName = SparkUtil.toColumnName(leftChild(predicate));
return lessThanOrEqual(columnName, convertLiteral(rightChild(predicate)));
} else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
- String columnName = toColumnName(rightChild(predicate));
+ String columnName = SparkUtil.toColumnName(rightChild(predicate));
return greaterThanOrEqual(columnName, convertLiteral(leftChild(predicate)));
} else {
return null;
@@ -134,10 +131,10 @@ public class SparkV2Filters {
case GT:
if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
- String columnName = toColumnName(leftChild(predicate));
+ String columnName = SparkUtil.toColumnName(leftChild(predicate));
return greaterThan(columnName, convertLiteral(rightChild(predicate)));
} else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
- String columnName = toColumnName(rightChild(predicate));
+ String columnName = SparkUtil.toColumnName(rightChild(predicate));
return lessThan(columnName, convertLiteral(leftChild(predicate)));
} else {
return null;
@@ -145,10 +142,10 @@ public class SparkV2Filters {
case GT_EQ:
if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
- String columnName = toColumnName(leftChild(predicate));
+ String columnName = SparkUtil.toColumnName(leftChild(predicate));
return greaterThanOrEqual(columnName, convertLiteral(rightChild(predicate)));
} else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
- String columnName = toColumnName(rightChild(predicate));
+ String columnName = SparkUtil.toColumnName(rightChild(predicate));
return lessThanOrEqual(columnName, convertLiteral(leftChild(predicate)));
} else {
return null;
@@ -158,10 +155,10 @@ public class SparkV2Filters {
Object value;
String columnName;
if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
- columnName = toColumnName(leftChild(predicate));
+ columnName = SparkUtil.toColumnName(leftChild(predicate));
value = convertLiteral(rightChild(predicate));
} else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
- columnName = toColumnName(rightChild(predicate));
+ columnName = SparkUtil.toColumnName(rightChild(predicate));
value = convertLiteral(leftChild(predicate));
} else {
return null;
@@ -183,7 +180,7 @@ public class SparkV2Filters {
case IN:
if (isSupportedInPredicate(predicate)) {
return in(
- toColumnName(childAtIndex(predicate, 0)),
+ SparkUtil.toColumnName(childAtIndex(predicate, 0)),
Arrays.stream(predicate.children())
.skip(1)
.map(val -> convertLiteral(((Literal<?>) val)))
@@ -202,13 +199,13 @@ public class SparkV2Filters {
// col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg
Expression notIn =
notIn(
- toColumnName(childAtIndex(childPredicate, 0)),
+ SparkUtil.toColumnName(childAtIndex(childPredicate, 0)),
Arrays.stream(childPredicate.children())
.skip(1)
.map(val -> convertLiteral(((Literal<?>) val)))
.filter(Objects::nonNull)
.collect(Collectors.toList()));
- return and(notNull(toColumnName(childAtIndex(childPredicate, 0))), notIn);
+ return and(notNull(SparkUtil.toColumnName(childAtIndex(childPredicate, 0))), notIn);
} else if (hasNoInFilter(childPredicate)) {
Expression child = convert(childPredicate);
if (child != null) {
@@ -240,7 +237,7 @@ public class SparkV2Filters {
}
case STARTS_WITH:
- String colName = toColumnName(leftChild(predicate));
+ String colName = SparkUtil.toColumnName(leftChild(predicate));
return startsWith(colName, convertLiteral(rightChild(predicate)).toString());
}
}
@@ -248,10 +245,6 @@ public class SparkV2Filters {
return null;
}
- private static String toColumnName(NamedReference ref) {
- return DOT.join(ref.fieldNames());
- }
-
@SuppressWarnings("unchecked")
private static <T> T child(Predicate predicate) {
org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children();
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java
new file mode 100644
index 0000000000..e2d6f744f5
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkAggregates.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Map;
+import org.apache.iceberg.expressions.Expression;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.SparkAggregates;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.aggregate.Count;
+import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
+import org.apache.spark.sql.connector.expressions.aggregate.Max;
+import org.apache.spark.sql.connector.expressions.aggregate.Min;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestSparkAggregates {
+
+  /**
+   * Converts Spark max, min and count over quoted column references, plus count(*), and compares
+   * the string form of each result with the expected Iceberg aggregate.
+   */
+  @Test
+  public void testAggregates() {
+    // Spark-quoted column reference -> column name expected in the Iceberg expression
+    Map<String, String> attrMap = Maps.newHashMap();
+    attrMap.put("id", "id");
+    attrMap.put("`i.d`", "i.d");
+    attrMap.put("`i``d`", "i`d");
+    attrMap.put("`d`.b.`dd```", "d.b.dd`");
+    attrMap.put("a.`aa```.c", "a.aa`.c");
+
+    for (Map.Entry<String, String> attr : attrMap.entrySet()) {
+      NamedReference namedReference = FieldReference.apply(attr.getKey());
+      String unquoted = attr.getValue();
+
+      Expression actualMax = SparkAggregates.convert(new Max(namedReference));
+      Assert.assertEquals(
+          "Max must match", Expressions.max(unquoted).toString(), actualMax.toString());
+
+      Expression actualMin = SparkAggregates.convert(new Min(namedReference));
+      Assert.assertEquals(
+          "Min must match", Expressions.min(unquoted).toString(), actualMin.toString());
+
+      Expression actualCount = SparkAggregates.convert(new Count(namedReference, false));
+      Assert.assertEquals(
+          "Count must match", Expressions.count(unquoted).toString(), actualCount.toString());
+
+      Expression convertedCountDistinct = SparkAggregates.convert(new Count(namedReference, true));
+      Assert.assertNull("Count Distinct is converted to null", convertedCountDistinct);
+    }
+
+    // count(*) references no column, so it is checked once instead of once per column name
+    Expression actualCountStar = SparkAggregates.convert(new CountStar());
+    Assert.assertEquals(
+        "CountStar must match", Expressions.countStar().toString(), actualCountStar.toString());
+  }
+}