You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@calcite.apache.org by li...@apache.org on 2022/03/04 13:24:27 UTC

[calcite] 23/41: [CALCITE-4996] In RelJson, add a readExpression method that can convert JSON to a RexNode expression

This is an automated email from the ASF dual-hosted git repository.

liyafan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit b4a576828714e3837d589033ece6acc494492bee
Author: Marieke Gueye <ma...@google.com>
AuthorDate: Wed Feb 2 00:52:40 2022 +0000

    [CALCITE-4996] In RelJson, add a readExpression method that can convert JSON to a RexNode expression
    
    Previously `RelJson` was only able to deserialize `RexNode`
    expressions that are part of a `RelNode` such as `Project`,
    `Filter` or `Join`; references to input fields would always
    be converted to a `RexInputRef`, referencing a field of the
    input `RelNode`'s output row type.
    
    But if the expression is not part of a `RelNode`, there is no
    input `RelNode`. So this change adds `interface
    InputTranslator` to specify how references to input fields
    are to be translated.
    
    In RuleMatchVisualizer, make method static, to appease lint.
    
    Close apache/calcite#2709
---
 .../plan/visualizer/RuleMatchVisualizer.java       |   3 +-
 .../apache/calcite/rel/externalize/RelJson.java    | 223 ++++++++++++++++++---
 .../org/apache/calcite/plan/RelWriterTest.java     |  70 +++++++
 3 files changed, 268 insertions(+), 28 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java b/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java
index 37d88a2..66a20e8 100644
--- a/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java
+++ b/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java
@@ -46,7 +46,6 @@ import java.text.DecimalFormat;
 import java.text.MessageFormat;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -173,7 +172,7 @@ public class RuleMatchVisualizer implements RelOptListener {
    * Get the inputs for a node, unwrapping {@link HepRelVertex} nodes.
    * (Workaround for HepPlanner)
    */
-  private Collection<RelNode> getInputs(final RelNode node) {
+  private static List<RelNode> getInputs(final RelNode node) {
     return node.getInputs().stream().map(n -> {
       if (n instanceof HepRelVertex) {
         return ((HepRelVertex) n).getCurrentRel();
diff --git a/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java b/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java
index 320c44f..d39947d 100644
--- a/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java
+++ b/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java
@@ -19,6 +19,8 @@ package org.apache.calcite.rel.externalize;
 import org.apache.calcite.avatica.AvaticaUtils;
 import org.apache.calcite.avatica.util.TimeUnit;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollationImpl;
 import org.apache.calcite.rel.RelCollations;
@@ -89,6 +91,7 @@ import static java.util.Objects.requireNonNull;
 public class RelJson {
   private final Map<String, Constructor> constructorMap = new HashMap<>();
   private final @Nullable JsonBuilder jsonBuilder;
+  private final InputTranslator inputTranslator;
 
   public static final List<String> PACKAGES =
       ImmutableList.of(
@@ -98,8 +101,24 @@ public class RelJson {
           "org.apache.calcite.adapter.jdbc.",
           "org.apache.calcite.adapter.jdbc.JdbcRules$");
 
-  public RelJson(@Nullable JsonBuilder jsonBuilder) {
+  /** Private constructor. */
+  private RelJson(@Nullable JsonBuilder jsonBuilder,
+      InputTranslator inputTranslator) {
     this.jsonBuilder = jsonBuilder;
+    this.inputTranslator = requireNonNull(inputTranslator, "inputTranslator");
+  }
+
+  /** Creates a RelJson. */
+  public RelJson(@Nullable JsonBuilder jsonBuilder) {
+    this(jsonBuilder, RelJson::translateInput);
+  }
+
+  /** Returns a RelJson with a given InputTranslator. */
+  public RelJson withInputTranslator(InputTranslator inputTranslator) {
+    if (inputTranslator == this.inputTranslator) {
+      return this;
+    }
+    return new RelJson(jsonBuilder, inputTranslator);
   }
 
   private JsonBuilder jsonBuilder() {
@@ -185,6 +204,32 @@ public class RelJson {
     return canonicalName;
   }
 
+  /** Default implementation of
+   * {@link InputTranslator#translateInput(RelJson, int, Map, RelInput)}. */
+  private static RexNode translateInput(RelJson relJson, int input,
+      Map<String, @Nullable Object> map, RelInput relInput) {
+    final RelOptCluster cluster = relInput.getCluster();
+    final RexBuilder rexBuilder = cluster.getRexBuilder();
+
+    // Check if it is a local ref.
+    if (map.containsKey("type")) {
+      final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+      final RelDataType type = relJson.toType(typeFactory, get(map, "type"));
+      return rexBuilder.makeLocalRef(type, input);
+    }
+    int i = input;
+    final List<RelNode> relNodes = relInput.getInputs();
+    for (RelNode inputNode : relNodes) {
+      final RelDataType rowType = inputNode.getRowType();
+      if (i < rowType.getFieldCount()) {
+        final RelDataTypeField field = rowType.getFieldList().get(i);
+        return rexBuilder.makeInputRef(field.getType(), input);
+      }
+      i -= rowType.getFieldCount();
+    }
+    throw new RuntimeException("input field " + input + " is out of range");
+  }
+
   public Object toJson(RelCollationImpl node) {
     final List<Object> list = new ArrayList<>();
     for (RelFieldCollation fieldCollation : node.getFieldCollations()) {
@@ -551,21 +596,21 @@ public class RelJson {
     return map;
   }
 
+  @SuppressWarnings({"rawtypes", "unchecked"})
   @PolyNull RexNode toRex(RelInput relInput, @PolyNull Object o) {
     final RelOptCluster cluster = relInput.getCluster();
     final RexBuilder rexBuilder = cluster.getRexBuilder();
     if (o == null) {
       return null;
     } else if (o instanceof Map) {
-      Map map = (Map) o;
-      final Map<String, @Nullable Object> opMap = (Map) map.get("op");
+      final Map<String, @Nullable Object> map = (Map) o;
       final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
-      if (opMap != null) {
+      if (map.containsKey("op")) {
+        final Map<String, @Nullable Object> opMap = get(map, "op");
         if (map.containsKey("class")) {
-          opMap.put("class", map.get("class"));
+          opMap.put("class", get(map, "class"));
         }
-        @SuppressWarnings("unchecked")
-        final List operands = get((Map<String, Object>) map, "operands");
+        final List operands = get(map, "operands");
         final List<RexNode> rexOperands = toRexList(relInput, operands);
         final Object jsonType = map.get("type");
         final Map window = (Map) map.get("window");
@@ -619,23 +664,7 @@ public class RelJson {
       }
       final Integer input = (Integer) map.get("input");
       if (input != null) {
-        // Check if it is a local ref.
-        if (map.containsKey("type")) {
-          final RelDataType type = toType(typeFactory, map.get("type"));
-          return rexBuilder.makeLocalRef(type, input);
-        }
-
-        List<RelNode> inputNodes = relInput.getInputs();
-        int i = input;
-        for (RelNode inputNode : inputNodes) {
-          final RelDataType rowType = inputNode.getRowType();
-          if (i < rowType.getFieldCount()) {
-            final RelDataTypeField field = rowType.getFieldList().get(i);
-            return rexBuilder.makeInputRef(field.getType(), input);
-          }
-          i -= rowType.getFieldCount();
-        }
-        throw new RuntimeException("input field " + input + " is out of range");
+        return inputTranslator.translateInput(this, input, map, relInput);
       }
       final String field = (String) map.get("field");
       if (field != null) {
@@ -651,16 +680,17 @@ public class RelJson {
       }
       if (map.containsKey("literal")) {
         Object literal = map.get("literal");
-        final RelDataType type = toType(typeFactory, map.get("type"));
         if (literal == null) {
+          final RelDataType type = toType(typeFactory, get(map, "type"));
           return rexBuilder.makeNullLiteral(type);
         }
-        if (type == null) {
+        if (!map.containsKey("type")) {
           // In previous versions, type was not specified for all literals.
           // To keep backwards compatibility, if type is not specified
           // we just interpret the literal
           return toRex(relInput, literal);
         }
+        final RelDataType type = toType(typeFactory, get(map, "type"));
         if (type.getSqlTypeName() == SqlTypeName.SYMBOL) {
           literal = RelEnumTypes.toEnum((String) literal);
         }
@@ -776,4 +806,145 @@ public class RelJson {
     map.put("syntax", operator.getSyntax().toString());
     return map;
   }
+
+  /**
+   * Translates a JSON expression into a RexNode,
+   * using a given {@link InputTranslator} to transform JSON objects that
+   * represent input references into RexNodes.
+   *
+   * @param cluster The optimization environment
+   * @param translator Input translator
+   * @param o JSON object
+   * @return the transformed RexNode
+   */
+  public static RexNode readExpression(RelOptCluster cluster,
+      InputTranslator translator, Map<String, Object> o) {
+    RelInput relInput = new RelInputForCluster(cluster);
+    return new RelJson(null, translator).toRex(relInput, o);
+  }
+
+  /**
+   * Special context from which a relational expression can be initialized,
+   * reading from a serialized form of the relational expression.
+   *
+   * <p>Contains only a cluster and an empty list of inputs;
+   * most methods throw {@link UnsupportedOperationException}.
+   */
+  private static class RelInputForCluster implements RelInput {
+    private final RelOptCluster cluster;
+
+    RelInputForCluster(RelOptCluster cluster) {
+      this.cluster = cluster;
+    }
+    @Override public RelOptCluster getCluster() {
+      return cluster;
+    }
+
+    @Override public RelTraitSet getTraitSet() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public RelOptTable getTable(String table) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public RelNode getInput() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public List<RelNode> getInputs() {
+      return ImmutableList.of();
+    }
+
+    @Override public @Nullable RexNode getExpression(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public ImmutableBitSet getBitSet(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable List<ImmutableBitSet> getBitSetList(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public List<AggregateCall> getAggregateCalls(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable Object get(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable String getString(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public float getFloat(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public <E extends Enum<E>> @Nullable E getEnum(
+        String tag, Class<E> enumClass) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable List<RexNode> getExpressionList(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable List<String> getStringList(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable List<Integer> getIntegerList(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public @Nullable List<List<Integer>> getIntegerListList(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public RelDataType getRowType(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public RelDataType getRowType(String expressionsTag, String fieldsTag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public RelCollation getCollation() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public RelDistribution getDistribution() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public ImmutableList<ImmutableList<RexLiteral>> getTuples(String tag) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override public boolean getBoolean(String tag, boolean default_) {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  /**
+   * Translates a JSON object that represents an input reference into a RexNode.
+   */
+  @FunctionalInterface
+  public interface InputTranslator {
+    /**
+     * Transforms an input reference into a RexNode.
+     *
+     * @param relJson RelJson
+     * @param input Ordinal of input field
+     * @param map JSON object representing an input reference
+     * @param relInput Description of input(s)
+     * @return RexNode representing an input reference
+     */
+    RexNode translateInput(RelJson relJson, int input,
+        Map<String, @Nullable Object> map, RelInput relInput);
+  }
 }
diff --git a/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java b/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
index 7903c18..6b5c553 100644
--- a/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
+++ b/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
@@ -24,6 +24,7 @@ import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.calcite.rel.RelDistributionTraitDef;
 import org.apache.calcite.rel.RelDistributions;
+import org.apache.calcite.rel.RelInput;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelShuttleImpl;
 import org.apache.calcite.rel.core.AggregateCall;
@@ -69,11 +70,16 @@ import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.JsonBuilder;
 import org.apache.calcite.util.TestUtil;
 
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 
 import org.checkerframework.checker.nullness.qual.Nullable;
+import org.hamcrest.Matcher;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.MethodSource;
@@ -83,7 +89,9 @@ import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Stream;
 
 import static org.apache.calcite.test.Matchers.isLinux;
@@ -706,6 +714,68 @@ class RelWriterTest {
             + "    LogicalTableScan(table=[[hr, emps]])\n"));
   }
 
+  @Test void testJsonToRex() throws JsonProcessingException {
+    // Test simple literal without inputs
+    final String jsonString1 = "{\n"
+        + "  \"literal\": 10,\n"
+        + "  \"type\": {\n"
+        + "    \"type\": \"INTEGER\",\n"
+        + "    \"nullable\": false\n"
+        + "  }\n"
+        + "}\n";
+
+    assertThatReadExpressionResult(jsonString1, is("10"));
+
+    // Test binary operator ('+') with an input and a literal
+    final String jsonString2 = "{ \"op\": \n"
+        + "  { \"name\": \"+\",\n"
+        + "    \"kind\": \"PLUS\",\n"
+        + "    \"syntax\": \"BINARY\"\n"
+        + "  },\n"
+        + "  \"operands\": [\n"
+        + "    {\n"
+        + "      \"input\": 1,\n"
+        + "      \"name\": \"$1\"\n"
+        + "    },\n"
+        + "    {\n"
+        + "      \"literal\": 2,\n"
+        + "      \"type\": { \"type\": \"INTEGER\", \"nullable\": false }\n"
+        + "    }\n"
+        + "  ]\n"
+        + "}";
+    assertThatReadExpressionResult(jsonString2, is("+(1001, 2)"));
+  }
+
+  private void assertThatReadExpressionResult(String json, Matcher<String> matcher) {
+    final FrameworkConfig config = RelBuilderTest.config().build();
+    final RelBuilder builder = RelBuilder.create(config);
+    final RelOptCluster cluster = builder.getCluster();
+    final ObjectMapper mapper = new ObjectMapper();
+    final TypeReference<LinkedHashMap<String, Object>> typeRef =
+        new TypeReference<LinkedHashMap<String, Object>>() {
+    };
+    final Map<String, Object> o;
+    try {
+      o = mapper
+          .configure(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS, true)
+          .readValue(json, typeRef);
+    } catch (JsonProcessingException e) {
+      throw TestUtil.rethrow(e);
+    }
+    RexNode e =
+        RelJson.readExpression(cluster, RelWriterTest::translateInput, o);
+    assertThat(e.toString(), is(matcher));
+  }
+
+  /** Intended as an instance of {@link RelJson.InputTranslator},
+   * translates input {@code input} into an INTEGER literal
+   * "{@code 1000 + input}". */
+  private static RexNode translateInput(RelJson relJson, int input,
+      Map<String, @Nullable Object> map, RelInput relInput) {
+    final RexBuilder rexBuilder = relInput.getCluster().getRexBuilder();
+    return rexBuilder.makeExactLiteral(BigDecimal.valueOf(1000 + input));
+  }
+
   @Test void testTrim() {
     final FrameworkConfig config = RelBuilderTest.config().build();
     final RelBuilder b = RelBuilder.create(config);