You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kyuubi.apache.org by ul...@apache.org on 2023/01/13 03:44:38 UTC
[kyuubi] branch master updated: [KYUUBI #3936] Parse trino http request header
This is an automated email from the ASF dual-hosted git repository.
ulyssesyou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/master by this push:
new 7f6ba8d99 [KYUUBI #3936] Parse trino http request header
7f6ba8d99 is described below
commit 7f6ba8d99c79019f387278cea68f75b6b902c4e0
Author: odone <od...@gmail.com>
AuthorDate: Fri Jan 13 11:44:25 2023 +0800
[KYUUBI #3936] Parse trino http request header
close #3936
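### _Why are the changes needed?_
Trino clients describe the query context (user, catalog, schema, session properties, prepared statements, etc.) in HTTP request headers; Kyuubi's Trino frontend needs to parse these headers into a `TrinoContext`.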
### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request
Closes #4121 from iodone/kyuubi-3936.
Closes #3936
e6cf32c9 [odone] parse trino http request header
Authored-by: odone <od...@gmail.com>
Signed-off-by: ulysses-you <ul...@apache.org>
---
.../kyuubi/server/trino/api/TrinoContext.scala | 169 +++++++++++++++++++++
.../server/trino/api/v1/StatementResource.scala | 4 +
.../org/apache/kyuubi/TrinoClientTestHelper.scala | 80 ++++++++++
.../server/trino/api/TrinoContextSuite.scala | 70 +++++++++
4 files changed, 323 insertions(+)
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
new file mode 100644
index 000000000..8f3131f61
--- /dev/null
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.server.trino.api
+
+import java.io.UnsupportedEncodingException
+import java.net.{URLDecoder, URLEncoder}
+import javax.ws.rs.core.{HttpHeaders, Response}
+
+import scala.collection.JavaConverters._
+
+import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+import io.trino.client.QueryResults
+
+/**
+ * Per-request context of the Trino HTTP protocol: the values a Trino client sends in its
+ * request headers, some of which are echoed back to the client in response headers.
+ *
+ * @param user Session user; must be supplied with every query
+ * @param timeZone The time zone for query processing
+ * @param clientCapabilities Capabilities advertised by the client; only meaningful to
+ *                           a Trino server
+ * @param source Name of the software that submitted the query,
+ *               e.g. `trino-jdbc` or `trino-cli` by default
+ * @param catalog The catalog context for query processing; also returned to the client
+ *                in the set-catalog response header
+ * @param schema The schema context for query processing
+ * @param language The language to use when processing the query and formatting results,
+ *                 formatted as a Java Locale string, e.g., en-US for US English
+ * @param traceToken Trace token for correlating requests across systems
+ * @param clientInfo Extra information about the client
+ * @param clientTags Client tags for selecting resource groups. Example: abc,xyz
+ * @param session Session properties as name -> value pairs, with values URL-decoded
+ * @param preparedStatement Prepared statements as name -> value pairs, where the names
+ *                          are names of previously prepared SQL statements
+ *                          and the values are keys that identify the
+ *                          executable form of the named prepared statements
+ */
+case class TrinoContext(
+    user: String,
+    timeZone: Option[String] = None,
+    clientCapabilities: Option[String] = None,
+    source: Option[String] = None,
+    catalog: Option[String] = None,
+    schema: Option[String] = None,
+    language: Option[String] = None,
+    traceToken: Option[String] = None,
+    clientInfo: Option[String] = None,
+    clientTags: Set[String] = Set.empty,
+    session: Map[String, String] = Map.empty,
+    preparedStatement: Map[String, String] = Map.empty)
+
+object TrinoContext {
+
+  /**
+   * Trino request headers Kyuubi cannot honour yet; a request carrying any of them is
+   * rejected with an [[UnsupportedOperationException]].
+   */
+  private val unsupportedHeaders: Seq[String] = Seq(
+    TRINO_HEADERS.requestPath,
+    TRINO_HEADERS.requestRole,
+    TRINO_HEADERS.requestResourceEstimate,
+    TRINO_HEADERS.requestExtraCredential)
+
+  /**
+   * Builds a [[TrinoContext]] from the headers of a JAX-RS request.
+   *
+   * @param headers request headers injected by JAX-RS
+   * @return the context parsed from the headers
+   * @throws UnsupportedOperationException if an unsupported Trino header is present
+   */
+  def apply(headers: HttpHeaders): TrinoContext = {
+    apply(headers.getRequestHeaders.asScala.toMap.map {
+      case (k, v) => (k, v.asScala.toList)
+    })
+  }
+
+  /**
+   * Builds a [[TrinoContext]] from raw header name -> values pairs.
+   *
+   * Header names are matched case-insensitively. Multi-valued headers (client tags, session
+   * properties, prepared statements) accept both repeated headers and comma-joined values;
+   * entries are trimmed and empty or malformed entries are skipped. Non-Trino headers are
+   * ignored.
+   *
+   * @param headers header name -> header values
+   * @return the parsed context; `user` is the empty string when no user header is given
+   * @throws UnsupportedOperationException if an unsupported Trino header is present, or a
+   *                                       transaction id other than `NONE` is requested
+   */
+  def apply(headers: Map[String, List[String]]): TrinoContext = {
+    headers.foldLeft(TrinoContext("")) { case (context, (name, values)) =>
+      name match {
+        case _ if TRINO_HEADERS.requestUser.equalsIgnoreCase(name) && values.nonEmpty =>
+          context.copy(user = values.head)
+        case _ if TRINO_HEADERS.requestTimeZone.equalsIgnoreCase(name) =>
+          context.copy(timeZone = values.headOption)
+        case _ if TRINO_HEADERS.requestClientCapabilities.equalsIgnoreCase(name) =>
+          context.copy(clientCapabilities = values.headOption)
+        case _ if TRINO_HEADERS.requestSource.equalsIgnoreCase(name) =>
+          context.copy(source = values.headOption)
+        case _ if TRINO_HEADERS.requestCatalog.equalsIgnoreCase(name) =>
+          context.copy(catalog = values.headOption)
+        case _ if TRINO_HEADERS.requestSchema.equalsIgnoreCase(name) =>
+          context.copy(schema = values.headOption)
+        case _ if TRINO_HEADERS.requestLanguage.equalsIgnoreCase(name) =>
+          context.copy(language = values.headOption)
+        case _ if TRINO_HEADERS.requestTraceToken.equalsIgnoreCase(name) =>
+          context.copy(traceToken = values.headOption)
+        case _ if TRINO_HEADERS.requestClientInfo.equalsIgnoreCase(name) =>
+          context.copy(clientInfo = values.headOption)
+        case _ if TRINO_HEADERS.requestClientTags.equalsIgnoreCase(name) =>
+          context.copy(clientTags = context.clientTags ++ splitHeaderValues(values))
+        case _ if TRINO_HEADERS.requestSession.equalsIgnoreCase(name) =>
+          context.copy(session = context.session ++ parseKeyValues(values))
+        case _ if TRINO_HEADERS.requestPreparedStatement.equalsIgnoreCase(name) =>
+          context.copy(preparedStatement = context.preparedStatement ++ parseKeyValues(values))
+
+        // Only the "no transaction" marker is accepted; a real transaction id is not supported
+        case _
+            if TRINO_HEADERS.requestTransactionId.equalsIgnoreCase(name)
+              && values.exists(isTransactionRequested) =>
+          throw new UnsupportedOperationException(s"$name is not currently supported")
+        case _ if unsupportedHeaders.exists(_.equalsIgnoreCase(name)) =>
+          throw new UnsupportedOperationException(s"$name is not currently supported")
+        case _ =>
+          context
+      }
+    }
+  }
+
+  /**
+   * Whether a transaction id header value asks for an actual transaction. Blank values and
+   * `NONE` in any letter case mean "no transaction".
+   * NOTE(review): mirrors the Trino server's leniency as we understand it; confirm.
+   */
+  private def isTransactionRequested(transactionId: String): Boolean = {
+    val id = transactionId.trim
+    id.nonEmpty && !id.equalsIgnoreCase("NONE")
+  }
+
+  /** Splits (possibly comma-joined) header values into trimmed, non-empty entries. */
+  private def splitHeaderValues(values: List[String]): List[String] =
+    values.flatMap(_.split(",").toList).map(_.trim).filter(_.nonEmpty)
+
+  /**
+   * Parses `name=urlEncodedValue` entries. Each entry is split on its first `=` and the value
+   * is URL-decoded; entries without `=` or with an empty name or value are skipped.
+   */
+  private def parseKeyValues(values: List[String]): Map[String, String] =
+    splitHeaderValues(values).map(_.split("=", 2)).collect {
+      case Array(key, value) if key.trim.nonEmpty && value.trim.nonEmpty =>
+        key.trim -> urlDecode(value.trim)
+    }.toMap
+
+  /**
+   * Wraps `qr` in a 200 response carrying the context-derived Trino response headers:
+   * catalog, schema, session properties and prepared statements (values URL-encoded).
+   *
+   * Deallocated-prepare, clear-session and clear-transaction headers are deliberately not
+   * sent until they can be derived from real state: clients act on their presence.
+   * NOTE(review): the Trino client appears to clear its transaction id whenever the
+   * clear-transaction-id header exists, whatever its value; confirm.
+   */
+  // TODO: Building response with TrinoContext and other information
+  def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = {
+    val responseBuilder = Response.ok(qr)
+
+    trinoContext.catalog.foreach(
+      responseBuilder.header(TRINO_HEADERS.responseSetCatalog, _))
+    trinoContext.schema.foreach(
+      responseBuilder.header(TRINO_HEADERS.responseSetSchema, _))
+
+    trinoContext.session.foreach { case (k, v) =>
+      responseBuilder.header(TRINO_HEADERS.responseSetSession, s"$k=${urlEncode(v)}")
+    }
+    trinoContext.preparedStatement.foreach { case (k, v) =>
+      responseBuilder.header(TRINO_HEADERS.responseAddedPrepare, s"$k=${urlEncode(v)}")
+    }
+
+    responseBuilder.build()
+  }
+
+  /** URL-encodes `value` as UTF-8; UTF-8 is always available, so failure is a bug. */
+  def urlEncode(value: String): String =
+    try URLEncoder.encode(value, "UTF-8")
+    catch {
+      case e: UnsupportedEncodingException =>
+        throw new AssertionError(e)
+    }
+
+  /** URL-decodes `value` as UTF-8; UTF-8 is always available, so failure is a bug. */
+  def urlDecode(value: String): String =
+    try URLDecoder.decode(value, "UTF-8")
+    catch {
+      case e: UnsupportedEncodingException =>
+        throw new AssertionError(e)
+    }
+
+}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
index 04e3408a3..3d149b5f3 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/v1/StatementResource.scala
@@ -67,6 +67,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@GET
@Path("/queued/{queryId}/{slug}/{token}")
def getQueuedStatementStatus(
+ @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
@@ -82,6 +83,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@GET
@Path("/executing/{queryId}/{slug}/{token}")
def getExecutingStatementStatus(
+ @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
@@ -97,6 +99,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@DELETE
@Path("/queued/{queryId}/{slug}/{token}")
def cancelQueuedStatement(
+ @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
@@ -112,6 +115,7 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@DELETE
@Path("/executing/{queryId}/{slug}/{token}")
def cancelExecutingStatementStatus(
+ @Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala
new file mode 100644
index 000000000..c0b3949f4
--- /dev/null
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/TrinoClientTestHelper.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi
+
+import java.net.URI
+import java.time.ZoneId
+import java.util.{Locale, Optional}
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+import io.airlift.units.Duration
+import io.trino.client.{ClientSelectedRole, ClientSession, StatementClient, StatementClientFactory}
+import okhttp3.OkHttpClient
+
+trait TrinoClientTestHelper extends RestFrontendTestHelper {
+
+  /** HTTP client shared by every statement client created by this helper. */
+  private val httpClient = new OkHttpClient.Builder().build()
+
+  /**
+   * Client session pointing at the REST frontend. Lazy so that `baseUri` is not read while
+   * the trait is being constructed.
+   * NOTE(review): presumably `baseUri` is only meaningful once the server managed by
+   * RestFrontendTestHelper is running; confirm.
+   */
+  protected lazy val clientSession: ClientSession = createClientSession(baseUri)
+
+  override def afterAll(): Unit = {
+    try {
+      super.afterAll()
+    } finally {
+      // Release OkHttp's dispatcher threads and pooled connections held by the shared client
+      httpClient.dispatcher().executorService().shutdown()
+      httpClient.connectionPool().evictAll()
+    }
+  }
+
+  /** Creates a Trino statement client that submits `sql` with [[clientSession]]. */
+  def getTrinoStatementClient(sql: String): StatementClient = {
+    StatementClientFactory.newStatementClient(httpClient, clientSession, sql)
+  }
+
+  /** Builds a fully-populated test [[ClientSession]] targeting `connectUrl`. */
+  def createClientSession(connectUrl: URI): ClientSession = {
+    new ClientSession(
+      connectUrl,
+      "kyuubi_test",
+      Optional.of("test_user"),
+      "kyuubi",
+      Optional.of("test_token_tracing"),
+      Set.empty[String].asJava,
+      "test_client_info",
+      "test_catalog",
+      "test_schema",
+      "test_path",
+      ZoneId.systemDefault(),
+      Locale.getDefault,
+      Map[String, String](
+        "test_resource_key0" -> "test_resource_value0",
+        "test_resource_key1" -> "test_resource_value1").asJava,
+      Map[String, String](
+        "test_property_key0" -> "test_property_value0",
+        "test_property_key1" -> "test_propert_value1").asJava,
+      Map[String, String](
+        "test_statement_key0" -> "select 1",
+        "test_statement_key1" -> "select 2").asJava,
+      Map[String, ClientSelectedRole](
+        // A ROLE selection names its role as ROLE{name}.
+        // NOTE(review): a bare "ROLE" appears to be rejected by ClientSelectedRole; confirm.
+        "test_role_key0" -> ClientSelectedRole.valueOf("ROLE{test_role}"),
+        "test_role_key2" -> ClientSelectedRole.valueOf("ALL")).asJava,
+      Map[String, String](
+        "test_credentials_key0" -> "test_credentials_value0",
+        "test_credentials_key1" -> "test_credentials_value1").asJava,
+      "test_transaction_id",
+      new Duration(2, TimeUnit.MINUTES),
+      true)
+  }
+
+}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
new file mode 100644
index 000000000..67a502288
--- /dev/null
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.server.trino.api
+
+import java.time.ZoneId
+
+import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+
+import org.apache.kyuubi.KyuubiFunSuite
+
+class TrinoContextSuite extends KyuubiFunSuite {
+  import TrinoContext._
+
+  test("create trino request context with header") {
+    val testHeader0 = Map(
+      TRINO_HEADERS.requestUser -> List("requestUser"),
+      TRINO_HEADERS.requestTimeZone -> List(ZoneId.systemDefault().getId),
+      TRINO_HEADERS.requestTransactionId -> List("NONE"),
+      TRINO_HEADERS.requestClientCapabilities -> List("requestClientCapabilities"),
+      TRINO_HEADERS.requestSource -> List("requestSource"),
+      TRINO_HEADERS.requestCatalog -> List("requestCatalog"),
+      TRINO_HEADERS.requestSchema -> List("requestSchema"),
+      TRINO_HEADERS.requestLanguage -> List("requestLanguage"),
+      TRINO_HEADERS.requestTraceToken -> List("requestTraceToken"),
+      TRINO_HEADERS.requestClientInfo -> List("requestClientInfo"),
+      TRINO_HEADERS.requestClientTags -> List(
+        "requestClientTag1,requestClientTag2,requestClientTag2"),
+      TRINO_HEADERS.requestSession -> List(
+        "",
+        s"key0=${urlEncode("value0")}",
+        s"key1=${urlEncode("value1")}",
+        "badcase"),
+      TRINO_HEADERS.requestPreparedStatement -> List(
+        "badcase",
+        s"key0=${urlEncode("select 1")}",
+        s"key1=${urlEncode("select 2")}",
+        ""))
+    val expectedTrinoContext = new TrinoContext(
+      user = "requestUser",
+      timeZone = Some(ZoneId.systemDefault().getId),
+      clientCapabilities = Some("requestClientCapabilities"),
+      source = Some("requestSource"),
+      catalog = Some("requestCatalog"),
+      schema = Some("requestSchema"),
+      language = Some("requestLanguage"),
+      traceToken = Some("requestTraceToken"),
+      clientInfo = Some("requestClientInfo"),
+      clientTags = Set("requestClientTag1", "requestClientTag2"),
+      session = Map("key0" -> "value0", "key1" -> "value1"),
+      preparedStatement = Map("key0" -> "select 1", "key1" -> "select 2"))
+    val actual = TrinoContext(testHeader0)
+    assert(actual == expectedTrinoContext)
+  }
+
+  test("create trino request context without any trino header") {
+    assert(TrinoContext(Map.empty[String, List[String]]) == TrinoContext(""))
+    // The "no transaction" marker alone must not be rejected
+    assert(
+      TrinoContext(Map(TRINO_HEADERS.requestTransactionId -> List("NONE"))) == TrinoContext(""))
+  }
+
+  test("trino request header names are matched case-insensitively") {
+    val headers = Map(
+      TRINO_HEADERS.requestUser.toLowerCase(java.util.Locale.ROOT) -> List("lowerUser"),
+      TRINO_HEADERS.requestCatalog.toUpperCase(java.util.Locale.ROOT) -> List("upperCatalog"))
+    val actual = TrinoContext(headers)
+    assert(actual.user == "lowerUser")
+    assert(actual.catalog.contains("upperCatalog"))
+  }
+
+  test("unsupported trino request headers are rejected") {
+    Seq(
+      TRINO_HEADERS.requestPath,
+      TRINO_HEADERS.requestRole,
+      TRINO_HEADERS.requestResourceEstimate,
+      TRINO_HEADERS.requestExtraCredential,
+      TRINO_HEADERS.requestTransactionId).foreach { header =>
+      intercept[UnsupportedOperationException] {
+        TrinoContext(Map(header -> List("value")))
+      }
+    }
+  }
+
+}