Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/03/17 13:01:42 UTC

[GitHub] [spark] maropu commented on a change in pull request #31791: [SPARK-34678][SQL] Add table function registry

maropu commented on a change in pull request #31791:
URL: https://github.com/apache/spark/pull/31791#discussion_r595992593



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
##########
@@ -707,3 +721,150 @@ object FunctionRegistry {
     (name, (info, outerBuilder))
   }
 }
+
+trait TableFunctionRegistry extends FunctionRegistryBase[LogicalPlan] {
+
+  /** Create a copy of this registry with identical functions as this registry. */
+  override def clone(): TableFunctionRegistry = throw new CloneNotSupportedException()
+}
+
+class SimpleTableFunctionRegistry
+    extends SimpleFunctionRegistryBase[LogicalPlan]
+    with TableFunctionRegistry {
+
+  override def clone(): SimpleTableFunctionRegistry = synchronized {
+    val registry = new SimpleTableFunctionRegistry
+    functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+      registry.registerFunction(name, info, builder)
+    }
+    registry
+  }
+}
+
+object EmptyTableFunctionRegistry
+    extends EmptyFunctionRegistryBase[LogicalPlan]
+    with TableFunctionRegistry {
+
+  override def clone(): TableFunctionRegistry = this
+}
+
+object TableFunctionRegistry {
+
+  type TableFunctionBuilder = Seq[Expression] => LogicalPlan
+
+  /**
+   * A TVF maps argument lists to resolver functions that accept those arguments. Using a map
+   * here allows for function overloading.
+   */
+  private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]
+
+  /**
+   * List of argument names and their types, used to declare a function.
+   */
+  private case class ArgumentList(args: (String, DataType)*) {
+    /**
+     * Try to cast the expressions to satisfy the expected types of this argument list. If any
+     * of the expressions cannot be cast, None is returned.
+     */
+    def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
+      if (args.length == values.length) {
+        val casted = values.zip(args).map { case (value, (_, expectedType)) =>
+          TypeCoercion.implicitCast(value, expectedType)
+        }
+        if (casted.forall(_.isDefined)) {
+          return Some(casted.map(_.get))
+        }
+      }
+      None
+    }
+
+    override def toString: String = {
+      args.map { a =>
+        s"${a._1}: ${a._2.typeName}"
+      }.mkString(", ")
+    }
+  }
+
+  /**
+   * Builds a single TVF overload: an argument list paired with its resolver function.
+   */
+  private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan])
+      : (ArgumentList, Seq[Any] => LogicalPlan) = {
+    (ArgumentList(args: _*),
+      pf orElse {
+        case arguments =>
+          // This is caught again by the apply function and rethrown with richer information
+          // about the position, etc., for a better error message.
+          throw new AnalysisException(
+            "Invalid arguments for resolved function: " + arguments.mkString(", "))
+      })
+  }
+
+  private def logicalPlan[T <: LogicalPlan : ClassTag](name: String, function: TVF)
+      : (String, (ExpressionInfo, TableFunctionBuilder)) = {
+    val argLists = function.keys.map(_.toString).toSeq.sorted.map(x => s"$name($x)").mkString("\n")
+    val builder = (expressions: Seq[Expression]) => {
+      val argTypes = expressions.map(_.dataType.typeName).mkString(", ")
+      function.flatMap { case (argList, resolver) =>
+        argList.implicitCast(expressions) match {
+          case Some(casted) =>
+            try {
+              Some(resolver(casted.map(_.eval())))
+            } catch {
+              case _: AnalysisException =>
+                throw QueryCompilationErrors.cannotApplyTableValuedFunctionError(
+                  name, argLists, argTypes)
+            }
+          case _ => None
+        }
+      }.headOption.getOrElse {
+        throw QueryCompilationErrors.cannotApplyTableValuedFunctionError(
+          name, argLists, argTypes)
+      }
+    }
+    val clazz = scala.reflect.classTag[T].runtimeClass
+    val info = new ExpressionInfo(clazz.getCanonicalName, null, name, argLists,
+      "", "", "", "", "", "")

Review comment:
       Could you fill in the other fields of `ExpressionInfo`, like `examples`, `since`, ...?
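
As a rough illustration of what filling those fields in could look like (not the PR author's code): the sketch below assumes the 10-argument `ExpressionInfo` constructor order (className, db, name, usage, arguments, examples, note, group, since, deprecated), and the example query and the "3.2.0" version string are placeholders. The `examples` text follows the "    Examples:" block format used by `@ExpressionDescription`.

    // Sketch only: same constructor call as in the hunk, with `examples` and `since`
    // filled in. The concrete values are illustrative placeholders, not from the PR.
    val info = new ExpressionInfo(
      clazz.getCanonicalName,
      null,                        // db
      name,
      argLists,                    // usage: the supported signatures
      "",                          // arguments
      "\n    Examples:\n      > SELECT * FROM range(2);\n       0\n       1",  // examples
      "",                          // note
      "",                          // group
      "3.2.0",                     // since (placeholder version)
      "")                          // deprecated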




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org