Posted to commits@kyuubi.apache.org by ul...@apache.org on 2023/02/07 03:08:59 UTC

[kyuubi] branch master updated: [KYUUBI #3934] Compatible with Trino REST DTO

This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b92d8067 [KYUUBI #3934] Compatible with Trino REST DTO
1b92d8067 is described below

commit 1b92d80678d6a05ca8a4e9f3ddded6deaf0b3d9e
Author: yehere <86...@qq.com>
AuthorDate: Tue Feb 7 11:08:48 2023 +0800

    [KYUUBI #3934] Compatible with Trino REST DTO
    
    ### _Why are the changes needed?_
    
    close #3934
    
    ### _How was this patch tested?_
    - [x] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4182 from yehere/kyuubi-3934.
    
    Closes #3934
    
    ced64e2c [yehere] [KYUUBI #3934] Add more result types support
    3ef85230 [yehere] [KYUUBI #3934] Optimization for code review
    69e1f442 [yehere] [KYUUBI #3934] Merge the test class to TrinoContextSuite
    4f0a0152 [yehere] [KYUUBI #3934] Merge the class to TrinoContext
    7c9473f6 [yehere] [KYUUBI #3934] Format style, with Copyright Profiles
    2023f3ce [yehere] [KYUUBI #3934] Format and add test case
    a2243b46 [yehere] [KYUUBI #3934] Compatible with Trino REST DTO
    
    Authored-by: yehere <86...@qq.com>
    Signed-off-by: ulyssesyou <ul...@apache.org>
---
 .../kyuubi/server/trino/api/TrinoContext.scala     | 235 +++++++++++++++++++--
 .../server/trino/api/TrinoContextSuite.scala       |  94 ++++++++-
 2 files changed, 310 insertions(+), 19 deletions(-)
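
The patch below teaches Kyuubi's Trino frontend to translate Thrift result
artifacts (TGetResultSetMetadataResp metadata, TRowSet rows, OperationStatus)
into the DTOs Trino's REST clients expect (QueryResults, Column, QueryError).
A minimal sketch of how a caller might use the new entry point follows; the
URI shapes and the toTrinoResponse wrapper are illustrative, only the
TrinoContext.createQueryResults call itself comes from the patch:

    import java.net.URI

    import io.trino.client.QueryResults
    import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet}

    import org.apache.kyuubi.operation.OperationStatus
    import org.apache.kyuubi.server.trino.api.TrinoContext

    object TrinoResponses {
      // Assemble a Trino REST response from the Thrift results of a finished
      // operation. In the server, the inputs would come from BackendService
      // calls such as getResultSetMetadata / fetchResults / getOperationStatus
      // (see the test suite in this patch).
      def toTrinoResponse(
          queryId: String,
          status: OperationStatus,
          metadata: Option[TGetResultSetMetadataResp],
          rows: Option[TRowSet]): QueryResults = {
        val nextUri = new URI(s"/v1/statement/executing/$queryId/0") // assumed path shape
        val htmlUri = new URI(s"/ui/query.html?$queryId")            // assumed path shape
        TrinoContext.createQueryResults(queryId, nextUri, htmlUri, status, metadata, rows)
      }
    }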

diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
index 8f3131f61..4a0736ddb 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala
@@ -18,34 +18,38 @@
 package org.apache.kyuubi.server.trino.api
 
 import java.io.UnsupportedEncodingException
-import java.net.{URLDecoder, URLEncoder}
+import java.net.{URI, URLDecoder, URLEncoder}
+import java.util
 import javax.ws.rs.core.{HttpHeaders, Response}
 
 import scala.collection.JavaConverters._
 
+import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryError, QueryResults, StatementStats, Warning}
 import io.trino.client.ProtocolHeaders.TRINO_HEADERS
-import io.trino.client.QueryResults
+import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}
+
+import org.apache.kyuubi.operation.OperationStatus
 
 /**
  * The description and functionality of trino request
  * and response's context
  *
- * @param user Specifies the session user, must be supplied with every query
- * @param timeZone  The timezone for query processing
+ * @param user               Specifies the session user, must be supplied with every query
+ * @param timeZone           The timezone for query processing
  * @param clientCapabilities Exclusive for trino server
- * @param source This supplies the name of the software that submitted the query,
- *               e.g. `trino-jdbc` or `trino-cli` by default
- * @param catalog The catalog context for query processing, will be set response
- * @param schema The schema context for query processing
- * @param language The language to use when processing the query and formatting results,
- *               formatted as a Java Locale string, e.g., en-US for US English
- * @param traceToken Trace token for correlating requests across systems
- * @param clientInfo Extra information about the client
- * @param clientTags Client tags for selecting resource groups. Example: abc,xyz
- * @param preparedStatement `preparedStatement` are kv pairs, where the names
- *                          are names of previously prepared SQL statements,
- *                          and the values are keys that identify the
- *                          executable form of the named prepared statements
+ * @param source             This supplies the name of the software that submitted the query,
+ *                           e.g. `trino-jdbc` or `trino-cli` by default
+ * @param catalog            The catalog context for query processing, will be set response
+ * @param schema             The schema context for query processing
+ * @param language           The language to use when processing the query and formatting results,
+ *                           formatted as a Java Locale string, e.g., en-US for US English
+ * @param traceToken         Trace token for correlating requests across systems
+ * @param clientInfo         Extra information about the client
+ * @param clientTags         Client tags for selecting resource groups. Example: abc,xyz
+ * @param preparedStatement  `preparedStatement` are kv pairs, where the names
+ *                           are names of previously prepared SQL statements,
+ *                           and the values are keys that identify the
+ *                           executable form of the named prepared statements
  */
 case class TrinoContext(
     user: String,
@@ -63,6 +67,11 @@ case class TrinoContext(
 
 object TrinoContext {
 
+  private val defaultWarning: util.List[Warning] = new util.ArrayList[Warning]()
+  private val GENERIC_INTERNAL_ERROR_CODE = 65536
+  private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
+  private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
+
   def apply(headers: HttpHeaders): TrinoContext = {
     apply(headers.getRequestHeaders.asScala.toMap.map {
       case (k, v) => (k, v.asScala.toList)
@@ -166,4 +175,196 @@ object TrinoContext {
         throw new AssertionError(e)
     }
 
+  def createQueryResults(
+      queryId: String,
+      nextUri: URI,
+      queryHtmlUri: URI,
+      queryStatus: OperationStatus,
+      columns: Option[TGetResultSetMetadataResp] = None,
+      data: Option[TRowSet] = None): QueryResults = {
+
+    val columnList = columns match {
+      case Some(value) => convertTColumn(value)
+      case None => null
+    }
+    val rowList = data match {
+      case Some(value) => convertTRowSet(value)
+      case None => null
+    }
+
+    new QueryResults(
+      queryId,
+      queryHtmlUri,
+      nextUri,
+      nextUri,
+      columnList,
+      rowList,
+      StatementStats.builder.setState(queryStatus.state.name()).setQueued(false)
+        .setElapsedTimeMillis(0).setQueuedTimeMillis(0).build(),
+      toQueryError(queryStatus),
+      defaultWarning,
+      null,
+      0L)
+  }
+
+  def convertTColumn(columns: TGetResultSetMetadataResp): util.List[Column] = {
+    columns.getSchema.getColumns.asScala.map(c => {
+      val tp = c.getTypeDesc.getTypes.get(0).getPrimitiveEntry.getType match {
+        case TTypeId.BOOLEAN_TYPE => ClientStandardTypes.BOOLEAN
+        case TTypeId.TINYINT_TYPE => ClientStandardTypes.TINYINT
+        case TTypeId.SMALLINT_TYPE => ClientStandardTypes.SMALLINT
+        case TTypeId.INT_TYPE => ClientStandardTypes.INTEGER
+        case TTypeId.BIGINT_TYPE => ClientStandardTypes.BIGINT
+        case TTypeId.FLOAT_TYPE => ClientStandardTypes.DOUBLE
+        case TTypeId.DOUBLE_TYPE => ClientStandardTypes.DOUBLE
+        case TTypeId.STRING_TYPE => ClientStandardTypes.VARCHAR
+        case TTypeId.TIMESTAMP_TYPE => ClientStandardTypes.TIMESTAMP
+        case TTypeId.BINARY_TYPE => ClientStandardTypes.VARBINARY
+        case TTypeId.DECIMAL_TYPE => ClientStandardTypes.DECIMAL
+        case TTypeId.DATE_TYPE => ClientStandardTypes.DATE
+        case TTypeId.VARCHAR_TYPE => ClientStandardTypes.VARCHAR
+        case TTypeId.CHAR_TYPE => ClientStandardTypes.CHAR
+        case TTypeId.INTERVAL_YEAR_MONTH_TYPE => ClientStandardTypes.INTERVAL_YEAR_TO_MONTH
+        case TTypeId.INTERVAL_DAY_TIME_TYPE => ClientStandardTypes.TIME_WITH_TIME_ZONE
+        case TTypeId.TIMESTAMPLOCALTZ_TYPE => ClientStandardTypes.TIMESTAMP_WITH_TIME_ZONE
+        case _ => ClientStandardTypes.VARCHAR
+      }
+      new Column(c.getColumnName, tp, new ClientTypeSignature(tp))
+    }).toList.asJava
+  }
+
+  def convertTRowSet(rowSet: TRowSet): util.List[util.List[Object]] = {
+    val dataResult = new util.LinkedList[util.List[Object]]
+
+    if (rowSet.getColumns == null) {
+      return rowSet.getRows.asScala
+        .map(t => t.getColVals.asScala.map(v => v.getFieldValue.asInstanceOf[Object]).asJava)
+        .asJava
+    }
+
+    rowSet.getColumns.asScala.foreach {
+      case tColumn if tColumn.isSetBoolVal =>
+        val nulls = util.BitSet.valueOf(tColumn.getBoolVal.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getBoolVal.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getBoolVal.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn if tColumn.isSetByteVal =>
+        val nulls = util.BitSet.valueOf(tColumn.getByteVal.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getByteVal.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getByteVal.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn if tColumn.isSetI16Val =>
+        val nulls = util.BitSet.valueOf(tColumn.getI16Val.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getI16Val.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getI16Val.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn if tColumn.isSetI32Val =>
+        val nulls = util.BitSet.valueOf(tColumn.getI32Val.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getI32Val.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getI32Val.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn if tColumn.isSetI64Val =>
+        val nulls = util.BitSet.valueOf(tColumn.getI64Val.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getI64Val.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getI64Val.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn if tColumn.isSetDoubleVal =>
+        val nulls = util.BitSet.valueOf(tColumn.getDoubleVal.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getDoubleVal.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getDoubleVal.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn if tColumn.isSetBinaryVal =>
+        val nulls = util.BitSet.valueOf(tColumn.getBinaryVal.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getBinaryVal.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getBinaryVal.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+      case tColumn =>
+        val nulls = util.BitSet.valueOf(tColumn.getStringVal.getNulls)
+        if (dataResult.isEmpty) {
+          (1 to tColumn.getStringVal.getValuesSize).foreach(_ =>
+            dataResult.add(new util.LinkedList[Object]()))
+        }
+
+        tColumn.getStringVal.getValues.asScala.zipWithIndex.foreach {
+          case (_, rowIdx) if nulls.get(rowIdx) =>
+            dataResult.get(rowIdx).add(null)
+          case (v, rowIdx) =>
+            dataResult.get(rowIdx).add(v)
+        }
+    }
+    dataResult
+  }
+
+  def toQueryError(queryStatus: OperationStatus): QueryError = {
+    val exception = queryStatus.exception
+    if (exception.isEmpty) {
+      null
+    } else {
+      new QueryError(
+        exception.get.getMessage,
+        queryStatus.state.name(),
+        GENERIC_INTERNAL_ERROR_CODE,
+        GENERIC_INTERNAL_ERROR_NAME,
+        GENERIC_INTERNAL_ERROR_TYPE,
+        null,
+        null)
+    }
+  }
+
 }
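
A note on convertTRowSet above: Thrift returns results column-oriented, one
value array per column plus a null bitmap, while Trino's QueryResults is
row-oriented. Each branch therefore performs the same transposition, fanning
a column out across per-row lists. A standalone sketch of that technique,
using plain Scala values in place of the Thrift TColumn wrappers:

    import java.util

    // Fan one column (its values plus its null bitmap) out across per-row
    // lists, creating the row lists when the first column arrives. This is
    // the shape of every branch in convertTRowSet.
    def addColumn(
        rows: util.List[util.List[Object]],
        values: Seq[AnyRef],
        nulls: util.BitSet): Unit = {
      if (rows.isEmpty) {
        values.indices.foreach(_ => rows.add(new util.LinkedList[Object]()))
      }
      values.zipWithIndex.foreach {
        case (_, i) if nulls.get(i) => rows.get(i).add(null)
        case (v, i) => rows.get(i).add(v)
      }
    }

    val rows = new util.LinkedList[util.List[Object]]
    addColumn(rows, Seq("a1", "a2", "a3"), new util.BitSet())
    val bNulls = new util.BitSet(); bNulls.set(1) // row 1 of column b is null
    addColumn(rows, Seq[AnyRef](Int.box(1), Int.box(0), Int.box(3)), bNulls)
    // rows is now [[a1, 1], [a2, null], [a3, 3]]
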
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
index 67a502288..8d7b2bf2c 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
@@ -17,13 +17,24 @@
 
 package org.apache.kyuubi.server.trino.api
 
+import java.net.URI
 import java.time.ZoneId
+import javax.ws.rs.core.MediaType
+
+import scala.collection.JavaConverters._
 
 import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+import org.apache.hive.service.rpc.thrift.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V9
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
+
+import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
+import org.apache.kyuubi.events.KyuubiOperationEvent
+import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
+import org.apache.kyuubi.operation.OperationState.{FINISHED, OperationState}
 
-import org.apache.kyuubi.KyuubiFunSuite
+class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
 
-class TrinoContextSuite extends KyuubiFunSuite {
   import TrinoContext._
 
   test("create trino request context with header") {
@@ -67,4 +78,83 @@ class TrinoContextSuite extends KyuubiFunSuite {
     assert(actual == expectedTrinoContext)
   }
 
+  test("test convert") {
+    val opHandle = getOpHandle("select 1")
+    val opHandleStr = opHandle.identifier.toString
+    checkOpState(opHandleStr, FINISHED)
+
+    val metadataResp = fe.be.getResultSetMetadata(opHandle)
+    val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
+    val status = fe.be.getOperationStatus(opHandle)
+
+    val uri = new URI("sfdsfsdfdsf")
+    val results = TrinoContext
+      .createQueryResults("/xdfd/xdf", uri, uri, status, Option(metadataResp), Option(tRowSet))
+
+    print(results.toString)
+    assert(results.getColumns.get(0).getType.equals("integer"))
+    assert(results.getData.asScala.last.get(0) == 1)
+  }
+
+  test("test convert from table") {
+    initSql("CREATE DATABASE IF NOT EXISTS INIT_DB")
+    initSql(
+      "CREATE TABLE IF NOT EXISTS INIT_DB.test(a int, b double, c String," +
+        "d BOOLEAN,e DATE,f TIMESTAMP,g ARRAY<String>,h DECIMAL," +
+        "i MAP<String,String>) USING PARQUET;")
+    initSql(
+      "INSERT INTO INIT_DB.test VALUES (1,2.2,'3',true,current_date()," +
+        "current_timestamp(),array('1','2'),2.0, map('m','p') )")
+
+    val opHandle = getOpHandle("SELECT * FROM INIT_DB.test")
+    val opHandleStr = opHandle.identifier.toString
+    checkOpState(opHandleStr, FINISHED)
+
+    val metadataResp = fe.be.getResultSetMetadata(opHandle)
+    val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
+    val status = fe.be.getOperationStatus(opHandle)
+
+    val uri = new URI("sfdsfsdfdsf")
+    val results = TrinoContext
+      .createQueryResults("/xdfd/xdf", uri, uri, status, Option(metadataResp), Option(tRowSet))
+
+    print(results.toString)
+    assert(results.getColumns.get(0).getType.equals("integer"))
+    assert(results.getData.asScala.last.get(0) != null)
+  }
+
+  def getOpHandleStr(statement: String = "show tables"): String = {
+    getOpHandle(statement).identifier.toString
+  }
+
+  def getOpHandle(statement: String = "show tables"): OperationHandle = {
+    val sessionHandle = fe.be.openSession(
+      HIVE_CLI_SERVICE_PROTOCOL_V9,
+      "admin",
+      "123456",
+      "localhost",
+      Map("testConfig" -> "testValue"))
+
+    if (statement.nonEmpty) {
+      fe.be.executeStatement(sessionHandle, statement, Map.empty, runAsync = false, 30000)
+    } else {
+      fe.be.getCatalogs(sessionHandle)
+    }
+  }
+
+  private def checkOpState(opHandleStr: String, state: OperationState): Unit = {
+    eventually(Timeout(30.seconds)) {
+      val response = webTarget.path(s"api/v1/operations/$opHandleStr/event")
+        .request(MediaType.APPLICATION_JSON_TYPE).get()
+      assert(response.getStatus === 200)
+      val operationEvent = response.readEntity(classOf[KyuubiOperationEvent])
+      assert(operationEvent.state === state.name())
+    }
+  }
+
+  private def initSql(sql: String): Unit = {
+    val initOpHandle = getOpHandle(sql)
+    val initOpHandleStr = initOpHandle.identifier.toString
+    checkOpState(initOpHandleStr, FINISHED)
+  }
 }
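
A note on the test helpers above: each conversion test follows the same three
steps, run a statement through the backend, poll the REST operation-event
endpoint until FINISHED, then feed metadata, rows, and status into
createQueryResults. A sketch of reusing them for another statement inside
this suite (the query id, URI shapes, and the assertion are illustrative):

    val handle = getOpHandle("SELECT 1.5 AS d")
    checkOpState(handle.identifier.toString, FINISHED)
    val results = TrinoContext.createQueryResults(
      "q-1",                          // illustrative query id
      new URI("/v1/statement/q-1/1"), // assumed URI shape
      new URI("/ui/query.html?q-1"),  // assumed URI shape
      fe.be.getOperationStatus(handle),
      Option(fe.be.getResultSetMetadata(handle)),
      Option(fe.be.fetchResults(handle, FetchOrientation.FETCH_NEXT, 10, false)))
    assert(results.getData != null)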