You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2022/09/12 17:04:27 UTC

[pinot] branch master updated: [multistage] add calcite function catalog (#9375)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 987480ba2d [multistage] add calcite function catalog (#9375)
987480ba2d is described below

commit 987480ba2d6c6575ccce06aa10702a73e6a28f9a
Author: Rong Rong <wa...@gmail.com>
AuthorDate: Mon Sep 12 10:04:19 2022 -0700

    [multistage] add calcite function catalog (#9375)
    
    * planner can parse custom functions
    * use a chained operator table
    also:
    * fix typo in partition key carrying
    * fix rules in singleton exchange optimization.
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../pinot/common/function/FunctionRegistry.java    | 27 ++++++++++++++++++++++
 .../apache/calcite/jdbc/CalciteSchemaBuilder.java  | 14 ++++++++++-
 .../org/apache/pinot/query/QueryEnvironment.java   |  8 ++++++-
 .../apache/pinot/query/catalog/PinotCatalog.java   |  5 ++--
 .../apache/pinot/query/context/PlannerContext.java |  3 +--
 .../query/parser/CalciteRexExpressionParser.java   |  5 ++--
 .../pinot/query/planner/logical/RexExpression.java |  1 +
 .../pinot/query/planner/logical/StagePlanner.java  |  4 ++--
 .../apache/pinot/query/QueryCompilationTest.java   |  2 +-
 .../pinot/query/QueryEnvironmentTestBase.java      |  2 ++
 .../pinot/query/runtime/QueryRunnerTest.java       | 16 +++++++++++--
 11 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
index 7633e2391c..4d786b2ed5 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java
@@ -21,10 +21,15 @@ package org.apache.pinot.common.function;
 import com.google.common.base.Preconditions;
 import java.lang.reflect.Method;
 import java.lang.reflect.Modifier;
+import java.util.Collection;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
+import org.apache.calcite.schema.Function;
+import org.apache.calcite.schema.impl.ScalarFunctionImpl;
+import org.apache.calcite.util.NameMultimap;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.spi.annotations.ScalarFunction;
 import org.apache.pinot.spi.utils.PinotReflectionUtils;
@@ -41,7 +46,12 @@ public class FunctionRegistry {
   }
 
   private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class);
+
+  // TODO: consolidate the following two maps.
+  // This FUNCTION_INFO_MAP is used by the Pinot server to look up functions by number of arguments.
   private static final Map<String, Map<Integer, FunctionInfo>> FUNCTION_INFO_MAP = new HashMap<>();
+  // This FUNCTION_MAP is used by the Calcite function catalog to look up functions by function signature.
+  private static final NameMultimap<Function> FUNCTION_MAP = new NameMultimap<>();
 
   /**
    * Registers the scalar functions via reflection.
@@ -86,6 +96,11 @@ public class FunctionRegistry {
    */
   public static void registerFunction(Method method, boolean nullableParameters) {
     registerFunction(method.getName(), method, nullableParameters);
+
+    // Calcite ScalarFunctionImpl doesn't allow customized named functions. TODO: fix me.
+    if (method.getAnnotation(Deprecated.class) == null) {
+      FUNCTION_MAP.put(method.getName(), ScalarFunctionImpl.create(method));
+    }
   }
 
   /**
@@ -99,6 +114,18 @@ public class FunctionRegistry {
         "Function: %s with %s parameters is already registered", functionName, method.getParameterCount());
   }
 
+  public static Map<String, List<Function>> getRegisteredCalciteFunctionMap() {
+    return FUNCTION_MAP.map();
+  }
+
+  public static Collection<Function> getRegisteredCalciteFunctions(String name) {
+    return FUNCTION_MAP.map().get(name);
+  }
+
+  public static Set<String> getRegisteredCalciteFunctionNames() {
+    return FUNCTION_MAP.map().keySet();
+  }
+
   /**
    * Returns {@code true} if the given function name is registered, {@code false} otherwise.
    */
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java
index ce3d1c99f9..edb2d74bf0 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/jdbc/CalciteSchemaBuilder.java
@@ -18,7 +18,12 @@
  */
 package org.apache.calcite.jdbc;
 
+import java.util.List;
+import java.util.Map;
+import org.apache.calcite.schema.Function;
 import org.apache.calcite.schema.Schema;
+import org.apache.calcite.schema.SchemaPlus;
+import org.apache.pinot.common.function.FunctionRegistry;
 
 
 /**
@@ -47,6 +52,13 @@ public class CalciteSchemaBuilder {
    * @return calcite schema with given schema as the root
    */
   public static CalciteSchema asRootSchema(Schema root) {
-    return new SimpleCalciteSchema(null, root, "");
+    CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, "", root);
+    SchemaPlus schemaPlus = rootSchema.plus();
+    for (Map.Entry<String, List<Function>> e : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) {
+      for (Function f : e.getValue()) {
+        schemaPlus.add(e.getKey(), f);
+      }
+    }
+    return rootSchema;
   }
 }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index c1797381ba..1f84101e53 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query;
 
 import com.google.common.annotations.VisibleForTesting;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Properties;
 import org.apache.calcite.config.CalciteConnectionConfigImpl;
@@ -42,6 +43,8 @@ import org.apache.calcite.sql.SqlExplainFormat;
 import org.apache.calcite.sql.SqlExplainLevel;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
 import org.apache.calcite.sql2rel.SqlToRelConverter;
 import org.apache.calcite.sql2rel.StandardConvertletTable;
 import org.apache.calcite.tools.FrameworkConfig;
@@ -80,7 +83,6 @@ public class QueryEnvironment {
     _typeFactory = typeFactory;
     _rootSchema = rootSchema;
     _workerManager = workerManager;
-    _config = Frameworks.newConfigBuilder().traitDefs().build();
 
     // catalog
     Properties catalogReaderConfigProperties = new Properties();
@@ -88,6 +90,10 @@ public class QueryEnvironment {
     _catalogReader = new CalciteCatalogReader(_rootSchema, _rootSchema.path(null), _typeFactory,
         new CalciteConnectionConfigImpl(catalogReaderConfigProperties));
 
+    _config = Frameworks.newConfigBuilder().traitDefs()
+        .operatorTable(new ChainedSqlOperatorTable(Arrays.asList(SqlStdOperatorTable.instance(), _catalogReader)))
+        .defaultSchema(_rootSchema.plus()).build();
+
     // optimizer rules
     _logicalRuleSet = PinotQueryRuleSets.LOGICAL_OPT_RULES;
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java
index 34673a274b..3d4ae4ac69 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java
@@ -31,6 +31,7 @@ import org.apache.calcite.schema.SchemaVersion;
 import org.apache.calcite.schema.Schemas;
 import org.apache.calcite.schema.Table;
 import org.apache.pinot.common.config.provider.TableCache;
+import org.apache.pinot.common.function.FunctionRegistry;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
 
 import static java.util.Objects.requireNonNull;
@@ -86,12 +87,12 @@ public class PinotCatalog implements Schema {
 
   @Override
   public Collection<Function> getFunctions(String name) {
-    return Collections.emptyList();
+    return FunctionRegistry.getRegisteredCalciteFunctions(name);
   }
 
   @Override
   public Set<String> getFunctionNames() {
-    return Collections.emptySet();
+    return FunctionRegistry.getRegisteredCalciteFunctionNames();
   }
 
   @Override
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/context/PlannerContext.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/context/PlannerContext.java
index 859b5e3d60..564fc17468 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/context/PlannerContext.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/context/PlannerContext.java
@@ -25,7 +25,6 @@ import org.apache.calcite.plan.hep.HepProgram;
 import org.apache.calcite.prepare.PlannerImpl;
 import org.apache.calcite.prepare.Prepare;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.validate.SqlValidator;
 import org.apache.calcite.tools.FrameworkConfig;
 import org.apache.pinot.query.planner.logical.LogicalPlanner;
@@ -50,7 +49,7 @@ public class PlannerContext implements AutoCloseable {
   public PlannerContext(FrameworkConfig config, Prepare.CatalogReader catalogReader, RelDataTypeFactory typeFactory,
       HepProgram hepProgram) {
     _planner = new PlannerImpl(config);
-    _validator = new Validator(SqlStdOperatorTable.instance(), catalogReader, typeFactory);
+    _validator = new Validator(config.getOperatorTable(), catalogReader, typeFactory);
     _relOptPlanner = new LogicalPlanner(hepProgram, Contexts.EMPTY_CONTEXT);
   }
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index 804825cff1..f9021154a9 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -168,10 +168,9 @@ public class CalciteRexExpressionParser {
         return compileAndExpression(rexCall, pinotQuery);
       case OR:
         return compileOrExpression(rexCall, pinotQuery);
-      case COUNT:
-      case OTHER:
       case OTHER_FUNCTION:
-      case DOT:
+        functionName = rexCall.getFunctionName();
+        break;
       default:
         functionName = functionKind.name();
         break;
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
index 778907e40a..d9b04ec06f 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -96,6 +96,7 @@ public interface RexExpression {
         return FieldSpec.DataType.FLOAT;
       case DOUBLE:
         return FieldSpec.DataType.DOUBLE;
+      case CHAR:
       case VARCHAR:
         return FieldSpec.DataType.STRING;
       case BOOLEAN:
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
index ac30996efa..6212f68fa0 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/StagePlanner.java
@@ -215,10 +215,10 @@ public class StagePlanner {
         int leftIndex = leftJoinKeySelector.getColumnIndices().get(i);
         int rightIndex = rightJoinKeySelector.getColumnIndices().get(i);
         if (leftPartitionKeys.contains(leftIndex)) {
-          newPartitionKeys.add(i);
+          newPartitionKeys.add(leftIndex);
         }
         if (rightPartitionKeys.contains(rightIndex)) {
-          newPartitionKeys.add(leftDataSchemaSize + i);
+          newPartitionKeys.add(leftDataSchemaSize + rightIndex);
         }
       }
       node.setPartitionKeys(newPartitionKeys);
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 94c62c856e..bace47a3ed 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -141,7 +141,7 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   @Test
   public void testQueryProjectFilterPushDownForJoin() {
     String query = "SELECT a.col1, a.ts, b.col2, b.col3 FROM a JOIN b ON a.col1 = b.col2 "
-        + "WHERE a.col3 >= 0 AND a.col2 IN  ('a', 'b') AND b.col3 < 0";
+        + "WHERE a.col3 >= 0 AND a.col2 IN ('b') AND b.col3 < 0";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
     List<StageNode> intermediateStageRoots =
         queryPlan.getStageMetadataMap().entrySet().stream().filter(e -> e.getValue().getScannedTables().size() == 0)
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index 65ea646e9c..d2939b1b7f 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -63,6 +63,8 @@ public class QueryEnvironmentTestBase {
         new Object[]{"SELECT a.col1, COUNT(*), SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1 "
             + "HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 AND MIN(a.col3) < 20 AND SUM(a.col3) <= 10 "
             + "AND AVG(a.col3) = 5"},
+        new Object[]{"SELECT dateTrunc('DAY', ts) FROM a LIMIT 10"},
+        new Object[]{"SELECT dateTrunc('DAY', a.ts + b.ts) FROM a JOIN b on a.col1 = b.col1 AND a.col2 = b.col2"},
     };
   }
 }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
index 670710590d..22d21106b7 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/QueryRunnerTest.java
@@ -92,7 +92,7 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
         // Because:
         //   - MOD(a.col3, 2) will have 6 (42)s equal to 0 and 9 (1)s equals to 1
         //   - MOD(b.col3, 3) will have 2 (42)s equal to 0 and 3 (1)s equals to 1;
-        // final results are 6 * 2 + 9 * 3 = 27 rows
+        // final results are 6 * 2 + 9 * 3 = 39 rows
         new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON MOD(a.col3, 2) = MOD(b.col3, 3)", 39},
 
         // Specifically table A has 15 rows (10 on server1 and 5 on server2) and table B has 5 rows (all on server1),
@@ -141,9 +141,14 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
             + " WHERE a.col3 >= 0 GROUP BY a.col1, a.col2", 5},
 
         // GROUP BY after JOIN
-        // only 3 GROUP BY key exist because b.col2 cycles between "foo", "bar", "alice".
+        //   - optimizable transport for GROUP BY key after JOIN, using SINGLETON exchange
+        //     only 3 GROUP BY keys exist because b.col2 cycles between "foo", "bar", "alice".
         new Object[]{"SELECT a.col1, SUM(b.col3), COUNT(*), SUM(2) FROM a JOIN b ON a.col1 = b.col2 "
             + " WHERE a.col3 >= 0 GROUP BY a.col1", 3},
+        //   - non-optimizable transport for GROUP BY key after JOIN, using HASH exchange
+        //     only 2 GROUP BY keys exist for b.col3.
+        new Object[]{"SELECT b.col3, SUM(a.col3) FROM a JOIN b"
+            + " on a.col1 = b.col1 AND a.col2 = b.col2 GROUP BY b.col3", 2},
 
         // Sub-query
         new Object[]{"SELECT b.col1, b.col3, i.maxVal FROM b JOIN "
@@ -162,6 +167,13 @@ public class QueryRunnerTest extends QueryRunnerTestBase {
 
         // Order-by
         new Object[]{"SELECT a.col1, a.col3, b.col3 FROM a JOIN b ON a.col1 = b.col1 ORDER BY a.col3, b.col3 DESC", 15},
+
+        // test customized functions
+        //   - on leaf stage
+        new Object[]{"SELECT dateTrunc('DAY', ts) FROM a LIMIT 10", 15},
+        //   - on intermediate stage
+        new Object[]{"SELECT dateTrunc('DAY', round(a.ts, b.ts)) FROM a JOIN b "
+            + "ON a.col1 = b.col1 AND a.col2 = b.col2", 15},
     };
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org