You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kyuubi.apache.org by ul...@apache.org on 2023/01/13 03:44:38 UTC

[kyuubi] branch master updated: [KYUUBI #3936] Parse trino http request header

This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 7f6ba8d99 [KYUUBI #3936] Parse trino http request header
7f6ba8d99 is described below

commit 7f6ba8d99c79019f387278cea68f75b6b902c4e0
Author: odone <od...@gmail.com>
AuthorDate: Fri Jan 13 11:44:25 2023 +0800

    [KYUUBI #3936] Parse trino http request header
    
    close #3936
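    ### _Why are the changes needed?_
    
    The Trino-compatible REST frontend needs to understand the HTTP headers sent by Trino
    clients (user, catalog/schema, session properties, prepared statements, etc.). This patch
    parses them into a `TrinoContext` and adds a helper to build responses from it.
    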
    ### _How was this patch tested?_
    - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request
    
    Closes #4121 from iodone/kyuubi-3936.
    
    Closes #3936
    
    e6cf32c9 [odone] parse trino http request header
    
    Authored-by: odone <od...@gmail.com>
    Signed-off-by: ulysses-you <ul...@apache.org>
---
 .../kyuubi/server/trino/api/TrinoContext.scala     | 169 +++++++++++++++++++++
 .../server/trino/api/v1/StatementResource.scala    |   4 +
 .../org/apache/kyuubi/TrinoClientTestHelper.scala  |  80 ++++++++++
 .../server/trino/api/TrinoContextSuite.scala       |  70 +++++++++
 4 files changed, 323 insertions(+)

diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
new file mode 100644
index 000000000..8f3131f61
--- /dev/null
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.server.trino.api
+
+import java.io.UnsupportedEncodingException
+import java.net.{URLDecoder, URLEncoder}
+import javax.ws.rs.core.{HttpHeaders, Response}
+
+import scala.collection.JavaConverters._
+
+import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+import io.trino.client.QueryResults
+
+/**
+ * Context of a Trino protocol request, parsed from the request's HTTP headers by the
+ * companion object's `apply` methods and used to populate the response headers.
+ *
+ * @param user Session user; must be supplied with every query. Empty when the request
+ *             did not carry a user header
+ * @param timeZone The time zone ID for query processing
+ * @param clientCapabilities Capabilities advertised by the client
+ * @param source Name of the software that submitted the query,
+ *               e.g. `trino-jdbc` or `trino-cli`
+ * @param catalog The catalog context for query processing; echoed back to the client
+ *                in the response headers
+ * @param schema The schema context for query processing; echoed back to the client
+ *               in the response headers
+ * @param language The language to use when processing the query and formatting results,
+ *                 formatted as a Java Locale string, e.g. en-US for US English
+ * @param traceToken Trace token for correlating requests across systems
+ * @param clientInfo Extra information about the client
+ * @param clientTags Client tags for selecting resource groups, e.g. sent as `abc,xyz`
+ * @param session Session properties: property name -> URL-decoded value
+ * @param preparedStatement Prepared statements: statement name -> URL-decoded SQL text
+ *                          (e.g. `select 1`)
+ */
+case class TrinoContext(
+    user: String,
+    timeZone: Option[String] = None,
+    clientCapabilities: Option[String] = None,
+    source: Option[String] = None,
+    catalog: Option[String] = None,
+    schema: Option[String] = None,
+    language: Option[String] = None,
+    traceToken: Option[String] = None,
+    clientInfo: Option[String] = None,
+    clientTags: Set[String] = Set.empty,
+    session: Map[String, String] = Map.empty,
+    preparedStatement: Map[String, String] = Map.empty)
+
+object TrinoContext {
+
+  /**
+   * Trino request headers Kyuubi cannot honor yet. A request carrying any of them is
+   * rejected by `apply` with an `UnsupportedOperationException`.
+   */
+  private val unsupportedHeaders: Seq[String] = Seq(
+    TRINO_HEADERS.requestPath,
+    TRINO_HEADERS.requestRole,
+    TRINO_HEADERS.requestResourceEstimate,
+    TRINO_HEADERS.requestExtraCredential)
+
+  /**
+   * Builds a [[TrinoContext]] from the headers of a JAX-RS request.
+   *
+   * @param headers the incoming request headers
+   * @return the parsed context (see the `Map` overload for the parsing rules)
+   * @throws UnsupportedOperationException if a header Kyuubi does not support is present
+   */
+  def apply(headers: HttpHeaders): TrinoContext = {
+    apply(headers.getRequestHeaders.asScala.toMap.map {
+      case (k, v) => (k, v.asScala.toList)
+    })
+  }
+
+  /**
+   * Builds a [[TrinoContext]] from a header-name -> values map.
+   *
+   * Header names are matched case-insensitively; unrecognized headers are ignored.
+   * Client tags, session properties and prepared statements may be sent as repeated
+   * headers and/or comma-separated entries. Entries are trimmed and blank entries dropped.
+   * `name=value` entries are split on the first '=', and the value is URL-decoded.
+   * Entries without '=' or with an empty name/value are skipped.
+   *
+   * @param headers request headers: header name -> all values sent under that name
+   * @return the parsed context; `user` is "" when no user header is supplied
+   * @throws UnsupportedOperationException for a transaction id other than `NONE`, or for
+   *                                       path, role, resource-estimate and
+   *                                       extra-credential headers
+   * @throws IllegalArgumentException if a `name=value` entry has a malformed URL-encoded
+   *                                  value (raised by `URLDecoder`)
+   */
+  def apply(headers: Map[String, List[String]]): TrinoContext = {
+    headers.foldLeft(TrinoContext("")) { case (context, (name, values)) =>
+      def is(header: String): Boolean = header.equalsIgnoreCase(name)
+      name match {
+        case _ if is(TRINO_HEADERS.requestUser) && values.nonEmpty =>
+          context.copy(user = values.head)
+        case _ if is(TRINO_HEADERS.requestTimeZone) =>
+          context.copy(timeZone = values.headOption)
+        case _ if is(TRINO_HEADERS.requestClientCapabilities) =>
+          context.copy(clientCapabilities = values.headOption)
+        case _ if is(TRINO_HEADERS.requestSource) =>
+          context.copy(source = values.headOption)
+        case _ if is(TRINO_HEADERS.requestCatalog) =>
+          context.copy(catalog = values.headOption)
+        case _ if is(TRINO_HEADERS.requestSchema) =>
+          context.copy(schema = values.headOption)
+        case _ if is(TRINO_HEADERS.requestLanguage) =>
+          context.copy(language = values.headOption)
+        case _ if is(TRINO_HEADERS.requestTraceToken) =>
+          context.copy(traceToken = values.headOption)
+        case _ if is(TRINO_HEADERS.requestClientInfo) =>
+          context.copy(clientInfo = values.headOption)
+        case _ if is(TRINO_HEADERS.requestClientTags) =>
+          // Trim and drop blanks so "a, b," yields Set("a", "b").
+          context.copy(clientTags = context.clientTags ++ splitEntries(values))
+        case _ if is(TRINO_HEADERS.requestSession) =>
+          context.copy(session = context.session ++ parseKeyValues(values))
+        case _ if is(TRINO_HEADERS.requestPreparedStatement) =>
+          context.copy(preparedStatement =
+            context.preparedStatement ++ parseKeyValues(values))
+        case _
+            if is(TRINO_HEADERS.requestTransactionId)
+              && values.headOption.exists(_ != "NONE") =>
+          // Only the "no transaction" marker is accepted; explicit transactions are not.
+          throw new UnsupportedOperationException(s"$name is not currently supported")
+        case _ if unsupportedHeaders.exists(is) =>
+          throw new UnsupportedOperationException(s"$name is not currently supported")
+        case _ =>
+          context
+      }
+    }
+  }
+
+  /** Splits every header value on ',' and returns the trimmed, non-empty entries. */
+  private def splitEntries(values: List[String]): List[String] =
+    values.flatMap(_.split(",")).map(_.trim).filter(_.nonEmpty)
+
+  /**
+   * Parses `name=urlEncodedValue` entries into a map of name -> decoded value.
+   * Each entry is split on its first '=' only, so '=' may appear in the value.
+   * Entries with no '=' or with an empty name/value are skipped.
+   */
+  private def parseKeyValues(values: List[String]): Map[String, String] =
+    splitEntries(values).flatMap { entry =>
+      entry.split("=", 2) match {
+        case Array(key, value) if key.trim.nonEmpty && value.trim.nonEmpty =>
+          Some(key.trim -> urlDecode(value.trim))
+        case _ => None
+      }
+    }.toMap
+
+  /**
+   * Builds an HTTP 200 response carrying `qr` as the entity and reflecting `trinoContext`
+   * in the Trino response headers.
+   *
+   * @param qr the query results to return
+   * @param trinoContext the context whose catalog, schema, session properties and prepared
+   *                     statements are sent back to the client
+   * @return the built response
+   */
+  // TODO: Building response with TrinoContext and other information
+  def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = {
+    val responseBuilder = Response.ok(qr)
+
+    trinoContext.catalog.foreach(
+      responseBuilder.header(TRINO_HEADERS.responseSetCatalog, _))
+    trinoContext.schema.foreach(
+      responseBuilder.header(TRINO_HEADERS.responseSetSchema, _))
+
+    trinoContext.session.foreach { case (k, v) =>
+      responseBuilder.header(TRINO_HEADERS.responseSetSession, s"$k=${urlEncode(v)}")
+    }
+    trinoContext.preparedStatement.foreach { case (k, v) =>
+      responseBuilder.header(TRINO_HEADERS.responseAddedPrepare, s"$k=${urlEncode(v)}")
+    }
+
+    // Deliberately no deallocated-prepare / clear-session / clear-transaction-id headers:
+    // the Trino client acts on them (the clear-transaction-id header takes effect whenever
+    // it is present, whatever its value), so they must only be sent with real values.
+    responseBuilder.build()
+  }
+
+  /** URL-encodes `value` as UTF-8; UTF-8 is always supported, hence the AssertionError. */
+  def urlEncode(value: String): String =
+    try URLEncoder.encode(value, "UTF-8")
+    catch {
+      case e: UnsupportedEncodingException =>
+        throw new AssertionError(e)
+    }
+
+  /**
+   * URL-decodes `value` as UTF-8.
+   *
+   * @throws IllegalArgumentException if `value` contains an illegal '%' escape
+   */
+  def urlDecode(value: String): String =
+    try URLDecoder.decode(value, "UTF-8")
+    catch {
+      case e: UnsupportedEncodingException =>
+        throw new AssertionError(e)
+    }
+
+}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
index 04e3408a3..3d149b5f3 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
@@ -67,6 +67,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
   @GET
   @Path("/queued/{queryId}/{slug}/{token}")
   def getQueuedStatementStatus(
+      @Context headers: HttpHeaders,
       @PathParam("queryId") queryId: String,
       @PathParam("slug") slug: String,
       @PathParam("token") token: Long): QueryResults = {
@@ -82,6 +83,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
   @GET
   @Path("/executing/{queryId}/{slug}/{token}")
   def getExecutingStatementStatus(
+      @Context headers: HttpHeaders,
       @PathParam("queryId") queryId: String,
       @PathParam("slug") slug: String,
       @PathParam("token") token: Long): QueryResults = {
@@ -97,6 +99,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
   @DELETE
   @Path("/queued/{queryId}/{slug}/{token}")
   def cancelQueuedStatement(
+      @Context headers: HttpHeaders,
       @PathParam("queryId") queryId: String,
       @PathParam("slug") slug: String,
       @PathParam("token") token: Long): QueryResults = {
@@ -112,6 +115,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
   @DELETE
   @Path("/executing/{queryId}/{slug}/{token}")
   def cancelExecutingStatementStatus(
+      @Context headers: HttpHeaders,
       @PathParam("queryId") queryId: String,
       @PathParam("slug") slug: String,
       @PathParam("token") token: Long): QueryResults = {
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala
new file mode 100644
index 000000000..c0b3949f4
--- /dev/null
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi
+
+import java.net.URI
+import java.time.ZoneId
+import java.util.{Locale, Optional}
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+import io.airlift.units.Duration
+import io.trino.client.{ClientSelectedRole, ClientSession, StatementClient, StatementClientFactory}
+import okhttp3.OkHttpClient
+
+/**
+ * Test helper that creates Trino `StatementClient`s against the REST frontend exposed by
+ * [[RestFrontendTestHelper]] (its `baseUri`), using a fixed test `ClientSession`.
+ */
+trait TrinoClientTestHelper extends RestFrontendTestHelper {
+
+  // Currently only delegates to RestFrontendTestHelper.afterAll().
+  override def afterAll(): Unit = {
+    super.afterAll()
+  }
+
+  // OkHttp client shared by all statement clients created by this helper.
+  private val httpClient = new OkHttpClient.Builder().build()
+
+  // NOTE(review): evaluated during trait initialization; assumes `baseUri` from
+  // RestFrontendTestHelper is already usable at that point — TODO confirm.
+  protected val clientSession = createClientSession(baseUri: URI)
+
+  /**
+   * Creates a Trino statement client that submits `sql` with `clientSession`.
+   *
+   * @param sql the SQL text to submit
+   * @return a new statement client bound to the shared OkHttp client
+   */
+  def getTrinoStatementClient(sql: String): StatementClient = {
+    StatementClientFactory.newStatementClient(httpClient, clientSession, sql)
+  }
+
+  /**
+   * Builds the fixed `ClientSession` used by these tests.
+   *
+   * NOTE(review): this session also sets a path, roles, extra credentials and the
+   * transaction id "test_transaction_id". `TrinoContext.apply` rejects those headers
+   * with UnsupportedOperationException, so requests built from this session may be
+   * refused once the server parses headers. Confirm this is intended.
+   * NOTE(review): Trino's `ClientSelectedRole` appears to require a role name for type
+   * ROLE (e.g. "ROLE{name}"); confirm `ClientSelectedRole.valueOf("ROLE")` does not throw.
+   *
+   * @param connectUrl the server URI the session connects to
+   * @return the test client session
+   */
+  def createClientSession(connectUrl: URI): ClientSession = {
+    new ClientSession(
+      connectUrl,
+      "kyuubi_test",
+      Optional.of("test_user"),
+      "kyuubi",
+      Optional.of("test_token_tracing"),
+      Set[String]().asJava,
+      "test_client_info",
+      "test_catalog",
+      "test_schema",
+      "test_path",
+      ZoneId.systemDefault(),
+      Locale.getDefault,
+      Map[String, String](
+        "test_resource_key0" -> "test_resource_value0",
+        "test_resource_key1" -> "test_resource_value1").asJava,
+      Map[String, String](
+        "test_property_key0" -> "test_property_value0",
+        "test_property_key1" -> "test_propert_value1").asJava,
+      Map[String, String](
+        "test_statement_key0" -> "select 1",
+        "test_statement_key1" -> "select 2").asJava,
+      Map[String, ClientSelectedRole](
+        "test_role_key0" -> ClientSelectedRole.valueOf("ROLE"),
+        "test_role_key2" -> ClientSelectedRole.valueOf("ALL")).asJava,
+      Map[String, String](
+        "test_credentials_key0" -> "test_credentials_value0",
+        "test_credentials_key1" -> "test_credentials_value1").asJava,
+      "test_transaction_id",
+      new Duration(2, TimeUnit.MINUTES),
+      true)
+
+  }
+
+}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
new file mode 100644
index 000000000..67a502288
--- /dev/null
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.server.trino.api
+
+import java.time.ZoneId
+
+import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+
+import org.apache.kyuubi.KyuubiFunSuite
+
+/** Unit tests for parsing Trino request headers into a [[TrinoContext]]. */
+class TrinoContextSuite extends KyuubiFunSuite {
+  import TrinoContext._
+
+  test("create trino request context with header") {
+    val testHeader0 = Map(
+      TRINO_HEADERS.requestUser -> List("requestUser"),
+      TRINO_HEADERS.requestTimeZone -> List(ZoneId.systemDefault().getId),
+      TRINO_HEADERS.requestTransactionId -> List("NONE"),
+      TRINO_HEADERS.requestClientCapabilities -> List("requestClientCapabilities"),
+      TRINO_HEADERS.requestSource -> List("requestSource"),
+      TRINO_HEADERS.requestCatalog -> List("requestCatalog"),
+      TRINO_HEADERS.requestSchema -> List("requestSchema"),
+      TRINO_HEADERS.requestLanguage -> List("requestLanguage"),
+      TRINO_HEADERS.requestTraceToken -> List("requestTraceToken"),
+      TRINO_HEADERS.requestClientInfo -> List("requestClientInfo"),
+      TRINO_HEADERS.requestClientTags -> List(
+        "requestClientTag1,requestClientTag2,requestClientTag2"),
+      TRINO_HEADERS.requestSession -> List(
+        "",
+        s"key0=${urlEncode("value0")}",
+        s"key1=${urlEncode("value1")}",
+        "badcase"),
+      TRINO_HEADERS.requestPreparedStatement -> List(
+        "badcase",
+        s"key0=${urlEncode("select 1")}",
+        s"key1=${urlEncode("select 2")}",
+        ""))
+    val expectedTrinoContext = new TrinoContext(
+      user = "requestUser",
+      timeZone = Some(ZoneId.systemDefault().getId),
+      clientCapabilities = Some("requestClientCapabilities"),
+      source = Some("requestSource"),
+      catalog = Some("requestCatalog"),
+      schema = Some("requestSchema"),
+      language = Some("requestLanguage"),
+      traceToken = Some("requestTraceToken"),
+      clientInfo = Some("requestClientInfo"),
+      clientTags = Set("requestClientTag1", "requestClientTag2"),
+      session = Map("key0" -> "value0", "key1" -> "value1"),
+      preparedStatement = Map("key0" -> "select 1", "key1" -> "select 2"))
+    val actual = TrinoContext(testHeader0)
+    assert(actual == expectedTrinoContext)
+  }
+
+  test("trino header names are matched case-insensitively") {
+    val actual = TrinoContext(Map(
+      TRINO_HEADERS.requestUser.toLowerCase(java.util.Locale.ROOT) -> List("lowerUser"),
+      TRINO_HEADERS.requestCatalog.toUpperCase(java.util.Locale.ROOT) -> List("upperCatalog")))
+    assert(actual.user == "lowerUser")
+    assert(actual.catalog == Some("upperCatalog"))
+  }
+
+  test("unsupported trino headers are rejected") {
+    Seq(
+      TRINO_HEADERS.requestPath,
+      TRINO_HEADERS.requestRole,
+      TRINO_HEADERS.requestResourceEstimate,
+      TRINO_HEADERS.requestExtraCredential).foreach { header =>
+      val e = intercept[UnsupportedOperationException] {
+        TrinoContext(Map(header -> List("value")))
+      }
+      assert(e.getMessage.contains(header))
+    }
+  }
+
+  test("only the NONE transaction id is accepted") {
+    assert(TrinoContext(Map(TRINO_HEADERS.requestTransactionId -> List("NONE"))) ==
+      TrinoContext(""))
+    intercept[UnsupportedOperationException] {
+      TrinoContext(Map(TRINO_HEADERS.requestTransactionId -> List("test_transaction_id")))
+    }
+  }
+
+}