You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2022/02/09 21:35:00 UTC
[calcite] 01/01: [CALCITE-4996] In RelJson, add a readExpression method that can convert JSON to a RexNode expression
This is an automated email from the ASF dual-hosted git repository.
jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
commit 812e3e98eae518cf85cd1b6b7f055fb96784a423
Author: Marieke Gueye <ma...@google.com>
AuthorDate: Wed Feb 2 00:52:40 2022 +0000
[CALCITE-4996] In RelJson, add a readExpression method that can convert JSON to a RexNode expression
Previously `RelJson` was only able to deserialize `RexNode`
expressions that are part of a `RelNode` such as `Project`,
`Filter` or `Join`; references to input fields would always
be converted to a `RexInputRef`, referencing a field of the
input `RelNode`'s output row type.
But if the expression is not part of a `RelNode`, there is no
input `RelNode`. So this change adds `interface
InputTranslator` to specify how references to input fields
are to be translated.
In RuleMatchVisualizer, make method `getInputs` static, to appease lint.
Close apache/calcite#2709
---
.../plan/visualizer/RuleMatchVisualizer.java | 3 +-
.../apache/calcite/rel/externalize/RelJson.java | 223 ++++++++++++++++++---
.../org/apache/calcite/plan/RelWriterTest.java | 70 +++++++
3 files changed, 268 insertions(+), 28 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java b/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java
index 37d88a2..66a20e8 100644
--- a/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java
+++ b/core/src/main/java/org/apache/calcite/plan/visualizer/RuleMatchVisualizer.java
@@ -46,7 +46,6 @@ import java.text.DecimalFormat;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
@@ -173,7 +172,7 @@ public class RuleMatchVisualizer implements RelOptListener {
* Get the inputs for a node, unwrapping {@link HepRelVertex} nodes.
* (Workaround for HepPlanner)
*/
- private Collection<RelNode> getInputs(final RelNode node) {
+ private static List<RelNode> getInputs(final RelNode node) {
return node.getInputs().stream().map(n -> {
if (n instanceof HepRelVertex) {
return ((HepRelVertex) n).getCurrentRel();
diff --git a/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java b/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java
index 320c44f..d39947d 100644
--- a/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java
+++ b/core/src/main/java/org/apache/calcite/rel/externalize/RelJson.java
@@ -19,6 +19,8 @@ package org.apache.calcite.rel.externalize;
import org.apache.calcite.avatica.AvaticaUtils;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationImpl;
import org.apache.calcite.rel.RelCollations;
@@ -89,6 +91,7 @@ import static java.util.Objects.requireNonNull;
public class RelJson {
private final Map<String, Constructor> constructorMap = new HashMap<>();
private final @Nullable JsonBuilder jsonBuilder;
+ private final InputTranslator inputTranslator;
public static final List<String> PACKAGES =
ImmutableList.of(
@@ -98,8 +101,24 @@ public class RelJson {
"org.apache.calcite.adapter.jdbc.",
"org.apache.calcite.adapter.jdbc.JdbcRules$");
- public RelJson(@Nullable JsonBuilder jsonBuilder) {
+ /** Private constructor. */
+ private RelJson(@Nullable JsonBuilder jsonBuilder,
+ InputTranslator inputTranslator) {
this.jsonBuilder = jsonBuilder;
+ this.inputTranslator = requireNonNull(inputTranslator, "inputTranslator");
+ }
+
+ /** Creates a RelJson. */
+ public RelJson(@Nullable JsonBuilder jsonBuilder) {
+ this(jsonBuilder, RelJson::translateInput);
+ }
+
+ /** Returns a RelJson with a given InputTranslator. */
+ public RelJson withInputTranslator(InputTranslator inputTranslator) {
+ if (inputTranslator == this.inputTranslator) {
+ return this;
+ }
+ return new RelJson(jsonBuilder, inputTranslator);
}
private JsonBuilder jsonBuilder() {
@@ -185,6 +204,32 @@ public class RelJson {
return canonicalName;
}
+ /** Default implementation of
+ * {@link InputTranslator#translateInput(RelJson, int, Map, RelInput)}.
+ *
+ * <p>If {@code map} has a "type" entry, returns a local reference of that
+ * type with index {@code input}. Otherwise treats {@code input} as an
+ * ordinal into the concatenated row types of the inputs of {@code relInput}
+ * and returns an input reference with the type of the matching field.
+ *
+ * @throws RuntimeException if {@code input} lies beyond the last field of
+ * the last input */
+ private static RexNode translateInput(RelJson relJson, int input,
+ Map<String, @Nullable Object> map, RelInput relInput) {
+ final RelOptCluster cluster = relInput.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+
+ // Check if it is a local ref.
+ if (map.containsKey("type")) {
+ final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+ final RelDataType type = relJson.toType(typeFactory, get(map, "type"));
+ return rexBuilder.makeLocalRef(type, input);
+ }
+ int i = input; // ordinal relative to the input currently being examined
+ final List<RelNode> relNodes = relInput.getInputs();
+ for (RelNode inputNode : relNodes) {
+ final RelDataType rowType = inputNode.getRowType();
+ if (i < rowType.getFieldCount()) {
+ final RelDataTypeField field = rowType.getFieldList().get(i);
+ return rexBuilder.makeInputRef(field.getType(), input); // index stays the overall ordinal
+ }
+ i -= rowType.getFieldCount(); // skip past this input's fields
+ }
+ throw new RuntimeException("input field " + input + " is out of range");
+ }
+
public Object toJson(RelCollationImpl node) {
final List<Object> list = new ArrayList<>();
for (RelFieldCollation fieldCollation : node.getFieldCollations()) {
@@ -551,21 +596,21 @@ public class RelJson {
return map;
}
+ @SuppressWarnings({"rawtypes", "unchecked"})
@PolyNull RexNode toRex(RelInput relInput, @PolyNull Object o) {
final RelOptCluster cluster = relInput.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
if (o == null) {
return null;
} else if (o instanceof Map) {
- Map map = (Map) o;
- final Map<String, @Nullable Object> opMap = (Map) map.get("op");
+ final Map<String, @Nullable Object> map = (Map) o;
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
- if (opMap != null) {
+ if (map.containsKey("op")) {
+ final Map<String, @Nullable Object> opMap = get(map, "op");
if (map.containsKey("class")) {
- opMap.put("class", map.get("class"));
+ opMap.put("class", get(map, "class"));
}
- @SuppressWarnings("unchecked")
- final List operands = get((Map<String, Object>) map, "operands");
+ final List operands = get(map, "operands");
final List<RexNode> rexOperands = toRexList(relInput, operands);
final Object jsonType = map.get("type");
final Map window = (Map) map.get("window");
@@ -619,23 +664,7 @@ public class RelJson {
}
final Integer input = (Integer) map.get("input");
if (input != null) {
- // Check if it is a local ref.
- if (map.containsKey("type")) {
- final RelDataType type = toType(typeFactory, map.get("type"));
- return rexBuilder.makeLocalRef(type, input);
- }
-
- List<RelNode> inputNodes = relInput.getInputs();
- int i = input;
- for (RelNode inputNode : inputNodes) {
- final RelDataType rowType = inputNode.getRowType();
- if (i < rowType.getFieldCount()) {
- final RelDataTypeField field = rowType.getFieldList().get(i);
- return rexBuilder.makeInputRef(field.getType(), input);
- }
- i -= rowType.getFieldCount();
- }
- throw new RuntimeException("input field " + input + " is out of range");
+ return inputTranslator.translateInput(this, input, map, relInput);
}
final String field = (String) map.get("field");
if (field != null) {
@@ -651,16 +680,17 @@ public class RelJson {
}
if (map.containsKey("literal")) {
Object literal = map.get("literal");
- final RelDataType type = toType(typeFactory, map.get("type"));
if (literal == null) {
+ final RelDataType type = toType(typeFactory, get(map, "type"));
return rexBuilder.makeNullLiteral(type);
}
- if (type == null) {
+ if (!map.containsKey("type")) {
// In previous versions, type was not specified for all literals.
// To keep backwards compatibility, if type is not specified
// we just interpret the literal
return toRex(relInput, literal);
}
+ final RelDataType type = toType(typeFactory, get(map, "type"));
if (type.getSqlTypeName() == SqlTypeName.SYMBOL) {
literal = RelEnumTypes.toEnum((String) literal);
}
@@ -776,4 +806,145 @@ public class RelJson {
map.put("syntax", operator.getSyntax().toString());
return map;
}
+
+ /**
+ * Translates a JSON expression into a RexNode,
+ * using a given {@link InputTranslator} to transform JSON objects that
+ * represent input references into RexNodes.
+ *
+ * <p>The expression is read on its own, not as part of a {@code RelNode}:
+ * the {@link RelInput} handed to {@code translator} exposes only
+ * {@code cluster} and an empty list of inputs.
+ *
+ * @param cluster The optimization environment
+ * @param translator Input translator
+ * @param o JSON object
+ * @return the transformed RexNode
+ */
+ public static RexNode readExpression(RelOptCluster cluster,
+ InputTranslator translator, Map<String, Object> o) {
+ RelInput relInput = new RelInputForCluster(cluster);
+ return new RelJson(null, translator).toRex(relInput, o); // no JsonBuilder is supplied
+ }
+
+ /**
+ * Special context from which a relational expression can be initialized,
+ * reading from a serialized form of the relational expression.
+ *
+ * <p>Contains only a cluster and an empty list of inputs;
+ * most methods throw {@link UnsupportedOperationException}.
+ *
+ * <p>Created by {@link RelJson#readExpression} to read an expression that
+ * is not part of a relational expression; only {@link #getCluster()} and
+ * {@link #getInputs()} return normally.
+ */
+ private static class RelInputForCluster implements RelInput {
+ private final RelOptCluster cluster;
+
+ RelInputForCluster(RelOptCluster cluster) {
+ this.cluster = cluster;
+ }
+ @Override public RelOptCluster getCluster() {
+ return cluster;
+ }
+
+ @Override public RelTraitSet getTraitSet() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public RelOptTable getTable(String table) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public RelNode getInput() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public List<RelNode> getInputs() {
+ return ImmutableList.of(); // always empty: there are no input RelNodes
+ }
+
+ @Override public @Nullable RexNode getExpression(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public ImmutableBitSet getBitSet(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable List<ImmutableBitSet> getBitSetList(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public List<AggregateCall> getAggregateCalls(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable Object get(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable String getString(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public float getFloat(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public <E extends Enum<E>> @Nullable E getEnum(
+ String tag, Class<E> enumClass) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable List<RexNode> getExpressionList(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable List<String> getStringList(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable List<Integer> getIntegerList(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public @Nullable List<List<Integer>> getIntegerListList(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public RelDataType getRowType(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public RelDataType getRowType(String expressionsTag, String fieldsTag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public RelCollation getCollation() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public RelDistribution getDistribution() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public ImmutableList<ImmutableList<RexLiteral>> getTuples(String tag) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override public boolean getBoolean(String tag, boolean default_) {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ /**
+ * Translates a JSON object that represents an input reference into a RexNode.
+ *
+ * <p>{@link RelJson} calls the translator when it reads a JSON object that
+ * has a non-null "input" entry. A translator can be supplied via
+ * {@link RelJson#withInputTranslator} or {@link RelJson#readExpression}.
+ */
+ @FunctionalInterface
+ public interface InputTranslator {
+ /**
+ * Transforms an input reference into a RexNode.
+ *
+ * @param relJson RelJson that is reading the expression
+ * @param input Ordinal of input field (the value of the "input" entry
+ * of {@code map})
+ * @param map JSON object representing an input reference
+ * @param relInput Description of input(s)
+ * @return RexNode representing an input reference
+ */
+ RexNode translateInput(RelJson relJson, int input,
+ Map<String, @Nullable Object> map, RelInput relInput);
+ }
}
diff --git a/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java b/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
index 7903c18..6b5c553 100644
--- a/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
+++ b/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
@@ -24,6 +24,7 @@ import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelDistributionTraitDef;
import org.apache.calcite.rel.RelDistributions;
+import org.apache.calcite.rel.RelInput;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.AggregateCall;
@@ -69,11 +70,16 @@ import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.JsonBuilder;
import org.apache.calcite.util.TestUtil;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
+import org.hamcrest.Matcher;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
@@ -83,7 +89,9 @@ import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
import java.util.stream.Stream;
import static org.apache.calcite.test.Matchers.isLinux;
@@ -706,6 +714,68 @@ class RelWriterTest {
+ " LogicalTableScan(table=[[hr, emps]])\n"));
}
+ @Test void testJsonToRex() throws JsonProcessingException {
+ // Test simple literal without inputs
+ final String jsonString1 = "{\n"
+ + " \"literal\": 10,\n"
+ + " \"type\": {\n"
+ + " \"type\": \"INTEGER\",\n"
+ + " \"nullable\": false\n"
+ + " }\n"
+ + "}\n";
+
+ assertThatReadExpressionResult(jsonString1, is("10"));
+
+ // Test binary operator ('+') with an input and a literal; translateInput maps input 1 to 1001
+ final String jsonString2 = "{ \"op\": \n"
+ + " { \"name\": \"+\",\n"
+ + " \"kind\": \"PLUS\",\n"
+ + " \"syntax\": \"BINARY\"\n"
+ + " },\n"
+ + " \"operands\": [\n"
+ + " {\n"
+ + " \"input\": 1,\n"
+ + " \"name\": \"$1\"\n"
+ + " },\n"
+ + " {\n"
+ + " \"literal\": 2,\n"
+ + " \"type\": { \"type\": \"INTEGER\", \"nullable\": false }\n"
+ + " }\n"
+ + " ]\n"
+ + "}";
+ assertThatReadExpressionResult(jsonString2, is("+(1001, 2)"));
+ }
+
+ private void assertThatReadExpressionResult(String json, Matcher<String> matcher) {
+ final FrameworkConfig config = RelBuilderTest.config().build();
+ final RelBuilder builder = RelBuilder.create(config);
+ final RelOptCluster cluster = builder.getCluster(); // the builder is only needed for its cluster
+ final ObjectMapper mapper = new ObjectMapper();
+ final TypeReference<LinkedHashMap<String, Object>> typeRef =
+ new TypeReference<LinkedHashMap<String, Object>>() {
+ };
+ final Map<String, Object> o;
+ try {
+ o = mapper
+ .configure(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS, true) // JSON floats -> BigDecimal
+ .readValue(json, typeRef);
+ } catch (JsonProcessingException e) {
+ throw TestUtil.rethrow(e);
+ }
+ RexNode e =
+ RelJson.readExpression(cluster, RelWriterTest::translateInput, o); // input refs use the test translator
+ assertThat(e.toString(), is(matcher)); // compare the string form of the RexNode
+ }
+
+ /** Intended as an instance of {@link RelJson.InputTranslator},
+ * translates input {@code input} into an INTEGER literal
+ * "{@code 1000 + input}".
+ *
+ * <p>The {@code relJson} and {@code map} arguments are ignored;
+ * {@code relInput} is used only to obtain the {@link RexBuilder}. */
+ private static RexNode translateInput(RelJson relJson, int input,
+ Map<String, @Nullable Object> map, RelInput relInput) {
+ final RexBuilder rexBuilder = relInput.getCluster().getRexBuilder();
+ return rexBuilder.makeExactLiteral(BigDecimal.valueOf(1000 + input));
+ }
+
@Test void testTrim() {
final FrameworkConfig config = RelBuilderTest.config().build();
final RelBuilder b = RelBuilder.create(config);