You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@iceberg.apache.org by GitBox <gi...@apache.org> on 2022/07/18 22:07:33 UTC

[GitHub] [iceberg] huaxingao commented on a diff in pull request #5302: Add SparkV2Filters

huaxingao commented on code in PR #5302:
URL: https://github.com/apache/iceberg/pull/5302#discussion_r923900082


##########
spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java:
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.expressions.Expression;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.util.NaNUtil;
+import org.apache.spark.sql.connector.expressions.LiteralValue;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.And;
+import org.apache.spark.sql.connector.expressions.filter.Not;
+import org.apache.spark.sql.connector.expressions.filter.Or;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import static org.apache.iceberg.expressions.Expressions.and;
+import static org.apache.iceberg.expressions.Expressions.equal;
+import static org.apache.iceberg.expressions.Expressions.greaterThan;
+import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.in;
+import static org.apache.iceberg.expressions.Expressions.isNaN;
+import static org.apache.iceberg.expressions.Expressions.isNull;
+import static org.apache.iceberg.expressions.Expressions.lessThan;
+import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.not;
+import static org.apache.iceberg.expressions.Expressions.notIn;
+import static org.apache.iceberg.expressions.Expressions.notNull;
+import static org.apache.iceberg.expressions.Expressions.or;
+import static org.apache.iceberg.expressions.Expressions.startsWith;
+
+public class SparkV2Filters {
+
+  // Matches a backtick together with the single character that follows it (or the end of the name);
+  // unquote() replaces each match with just that following character.
+  private static final Pattern BACKTICKS_PATTERN = Pattern.compile("([`])(.|$)");
+
+  /**
+   * Spark V2 predicate names mapped to the Iceberg operation each one converts to. Both "=" and the
+   * null-safe "<=>" map to EQ; convert() distinguishes them by predicate name.
+   */
+  private static final Map<String, Expression.Operation> FILTERS;
+
+  static {
+    ImmutableMap.Builder<String, Expression.Operation> names = ImmutableMap.builder();
+    names.put("ALWAYS_TRUE", Expression.Operation.TRUE);
+    names.put("ALWAYS_FALSE", Expression.Operation.FALSE);
+    names.put("=", Expression.Operation.EQ);
+    names.put("<=>", Expression.Operation.EQ);
+    names.put(">", Expression.Operation.GT);
+    names.put(">=", Expression.Operation.GT_EQ);
+    names.put("<", Expression.Operation.LT);
+    names.put("<=", Expression.Operation.LT_EQ);
+    names.put("IN", Expression.Operation.IN);
+    names.put("IS_NULL", Expression.Operation.IS_NULL);
+    names.put("IS_NOT_NULL", Expression.Operation.NOT_NULL);
+    names.put("AND", Expression.Operation.AND);
+    names.put("OR", Expression.Operation.OR);
+    names.put("NOT", Expression.Operation.NOT);
+    names.put("STARTS_WITH", Expression.Operation.STARTS_WITH);
+    FILTERS = names.build();
+  }
+
+  // static utility class; not meant to be instantiated
+  private SparkV2Filters() {
+  }
+
+  /**
+   * Converts an array of Spark V2 predicates into one Iceberg expression: the conjunction (AND) of every
+   * converted predicate, starting from {@code alwaysTrue()}.
+   *
+   * @param predicates Spark predicates to combine
+   * @return the AND of all converted predicates, or {@code alwaysTrue()} when the array is empty
+   * @throws IllegalArgumentException if any predicate cannot be converted
+   */
+  public static Expression convert(Predicate[] predicates) {
+    Expression conjunction = Expressions.alwaysTrue();
+    for (int i = 0; i < predicates.length; i += 1) {
+      Predicate current = predicates[i];
+      Expression icebergExpr = convert(current);
+      Preconditions.checkArgument(icebergExpr != null, "Cannot convert predicate to Iceberg: %s", current);
+      conjunction = Expressions.and(conjunction, icebergExpr);
+    }
+
+    return conjunction;
+  }
+
+  @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"})
+  public static Expression convert(Predicate predicate) {
+    if (checkIfPredicateValid(predicate) == null) {
+      return null;
+    }
+
+    Expression.Operation op = FILTERS.get(predicate.name());
+    if (op != null) {
+      switch (op) {
+        case TRUE:
+          return Expressions.alwaysTrue();
+
+        case FALSE:
+          return Expressions.alwaysFalse();
+
+        case IS_NULL:
+          return isNull(unquote(getAttrName(predicate.children()[0])));
+
+        case NOT_NULL:
+          return notNull(unquote(getAttrName(predicate.children()[0])));
+
+        case LT:
+          if (predicate.children()[1] instanceof LiteralValue) {
+            return lessThan(unquote(getAttrName(predicate.children()[0])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[1]).value()));
+          } else {
+            return greaterThan(unquote(getAttrName(predicate.children()[1])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[0]).value()));
+          }
+
+        case LT_EQ:
+          if (predicate.children()[1] instanceof LiteralValue) {
+            return lessThanOrEqual(unquote(getAttrName(predicate.children()[0])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[1]).value()));
+          } else {
+            return greaterThanOrEqual(unquote(getAttrName(predicate.children()[1])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[0]).value()));
+          }
+
+        case GT:
+          if (predicate.children()[1] instanceof LiteralValue) {
+            return greaterThan(unquote(getAttrName(predicate.children()[0])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[1]).value()));
+          } else {
+            return lessThan(unquote(getAttrName(predicate.children()[1])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[0]).value()));
+          }
+
+        case GT_EQ:
+          if (predicate.children()[1] instanceof LiteralValue) {
+            return greaterThanOrEqual(unquote(getAttrName(predicate.children()[0])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[1]).value()));
+          } else {
+            return lessThanOrEqual(unquote(getAttrName(predicate.children()[1])),
+                convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[0]).value()));
+          }
+
+        case EQ: // used for both eq and null-safe-eq
+          Object value = null;
+          String attributeName = "";
+          if (predicate.children()[1] instanceof LiteralValue) {
+            attributeName = getAttrName(predicate.children()[0]);
+            value = convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[1]).value());
+          } else {
+            attributeName = getAttrName(predicate.children()[1]);
+            value = convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[0]).value());
+          }
+          if (predicate.name().equals("=")) {
+            // comparison with null in normal equality is always null. this is probably a mistake.
+            Preconditions.checkNotNull(value,
+                "Expression is always false (eq is not null-safe): %s", predicate);
+            return handleEqual(unquote(attributeName), value);
+          } else if (predicate.name().equals("<=>")) {
+            if (value == null) {
+              return isNull(unquote(attributeName));
+            } else {
+              return handleEqual(unquote(attributeName), value);
+            }
+          }
+          break;
+
+        case IN:
+          Object[] inValues = new Object[predicate.children().length - 1];
+          for (int i = 1; i < predicate.children().length; i++) {
+            inValues[i - 1] = convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[i]).value());
+          }
+          return in(unquote(getAttrName(predicate.children()[0])),
+              Stream.of(inValues)
+                  .filter(Objects::nonNull)
+                  .collect(Collectors.toList()));
+
+        case NOT:
+          Not notFilter = (Not) predicate;
+          Predicate childFilter = notFilter.child();
+          if (childFilter.name().equals("IN")) {
+            // infer an extra notNull predicate for Spark NOT IN filters
+            // as Iceberg expressions don't follow the 3-value SQL boolean logic
+            // col NOT IN (1, 2) in Spark is equivalent to notNull(col) && notIn(col, 1, 2) in Iceberg
+            Object[] notInValues = new Object[childFilter.children().length - 1];
+            for (int i = 1; i < childFilter.children().length; i++) {
+              notInValues[i - 1] = convertUTF8StringIfNecessary(((LiteralValue) childFilter.children()[i]).value());
+            }
+            Expression notIn = notIn(unquote(getAttrName(childFilter.children()[0])),
+                Stream.of(notInValues).collect(Collectors.toList()));
+            return and(notNull(unquote(getAttrName(childFilter.children()[0]))), notIn);
+          } else if (hasNoInFilter(childFilter)) {
+            Expression child = convert(childFilter);
+            if (child != null) {
+              return not(child);
+            }
+          }
+          return null;
+
+        case AND: {
+          And andPredicate = (And) predicate;
+          Expression left = convert(andPredicate.left());
+          Expression right = convert(andPredicate.right());
+          if (left != null && right != null) {
+            return and(left, right);
+          }
+          return null;
+        }
+
+        case OR: {
+          Or orPredicate = (Or) predicate;
+          Expression left = convert(orPredicate.left());
+          Expression right = convert(orPredicate.right());
+          if (left != null && right != null) {
+            return or(left, right);
+          }
+          return null;
+        }
+
+        case STARTS_WITH: {
+          return startsWith(unquote(getAttrName(predicate.children()[0])),
+              convertUTF8StringIfNecessary(((LiteralValue) predicate.children()[1]).value()).toString());
+        }
+      }
+    }
+
+    return null;
+  }
+
+  /**
+   * Normalizes a literal value: Spark string literals may arrive as {@link UTF8String}, which is turned
+   * into a Java {@link String}; every other value (including null) is returned unchanged.
+   */
+  private static Object convertUTF8StringIfNecessary(Object value) {
+    return value instanceof UTF8String ? value.toString() : value;
+  }
+
+  /**
+   * Builds an equality predicate on {@code attribute}. A NaN value is turned into {@code isNaN(attribute)}
+   * rather than {@code equal(attribute, NaN)}, because NaN never compares equal to anything, itself included.
+   */
+  private static Expression handleEqual(String attribute, Object value) {
+    return NaNUtil.isNaN(value) ? isNaN(attribute) : equal(attribute, value);
+  }
+
+  /**
+   * Strips backtick quoting from an attribute name. Every backtick is removed and the character right after
+   * it is kept, so a quoted name like {@code `a`} becomes {@code a} and a doubled backtick collapses to one.
+   */
+  private static String unquote(String attributeName) {
+    return BACKTICKS_PATTERN.matcher(attributeName).replaceAll("$2");
+  }
+
+  /**
+   * Checks whether a predicate tree contains no IN predicate.
+   * <p>
+   * Only AND, OR and NOT nodes are descended into; any other recognized predicate counts as IN-free.
+   * A predicate whose name is not in FILTERS is treated like an IN and yields false.
+   *
+   * @param predicate the Spark predicate to inspect
+   * @return true if no IN (or unrecognized) predicate is found along the AND/OR/NOT structure
+   */
+  private static boolean hasNoInFilter(Predicate predicate) {
+    Expression.Operation op = FILTERS.get(predicate.name());
+    if (op == null || op == Expression.Operation.IN) {
+      return false;
+    }
+
+    if (op == Expression.Operation.NOT) {
+      return hasNoInFilter(((Not) predicate).child());
+    }
+
+    if (op == Expression.Operation.AND) {
+      And conjunction = (And) predicate;
+      return hasNoInFilter(conjunction.left()) && hasNoInFilter(conjunction.right());
+    }
+
+    if (op == Expression.Operation.OR) {
+      Or disjunction = (Or) predicate;
+      return hasNoInFilter(disjunction.left()) && hasNoInFilter(disjunction.right());
+    }
+
+    // leaf predicates other than IN
+    return true;
+  }
+
+  /**
+   * Returns the dotted name of a Spark named reference, e.g. field names {@code ["a", "b"]} become
+   * {@code "a.b"}. The expression must be a {@link NamedReference}.
+   */
+  private static String getAttrName(org.apache.spark.sql.connector.expressions.Expression expr) {
+    NamedReference reference = (NamedReference) expr;
+    return Stream.of(reference.fieldNames()).collect(Collectors.joining("."));
+  }
+
+  @SuppressWarnings("checkstyle:CyclomaticComplexity")
+  private static Predicate checkIfPredicateValid(Predicate predicate) {

Review Comment:
   I think Spark should have this check to make sure all V2 filters are in a valid format. Otherwise, this kind of invalid [filter](https://github.com/apache/iceberg/blob/master/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java#L433) will get through.
   
   I used to have this check on the Spark side, but it somehow got removed during a refactor. I will add it back. Until that is available, I will keep this check here and remove it once it is in Spark.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@iceberg.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@iceberg.apache.org
For additional commands, e-mail: issues-help@iceberg.apache.org