You are viewing a plain text version of this content. The canonical (hyperlinked) version is available on the original mailing-list archive page.
Posted to commits@livy.apache.org by va...@apache.org on 2018/11/30 23:57:36 UTC
[3/5] incubator-livy git commit: [LIVY-502] Remove dependency on hive-exec
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala
index 5289354..725bdc8 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyCLIService.scala
@@ -22,40 +22,36 @@ import java.util
import java.util.concurrent.{CancellationException, ExecutionException, TimeoutException, TimeUnit}
import javax.security.auth.login.LoginException
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.hive.common.log.ProgressMonitor
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.ql.parse.ParseUtils
-import org.apache.hadoop.hive.shims.Utils
-import org.apache.hadoop.security.UserGroupInformation
-import org.apache.hive.service.{CompositeService, ServiceException}
-import org.apache.hive.service.auth.HiveAuthFactory
+import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation}
+import org.apache.hive.service.ServiceException
import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.Operation
import org.apache.hive.service.rpc.thrift.{TOperationHandle, TProtocolVersion}
-import org.apache.livy.{LIVY_VERSION, Logging}
+import org.apache.livy.{LIVY_VERSION, LivyConf, Logging}
+import org.apache.livy.thriftserver.auth.AuthFactory
+import org.apache.livy.thriftserver.operation.{Operation, OperationStatus}
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.types.Schema
class LivyCLIService(server: LivyThriftServer)
- extends CompositeService(classOf[LivyCLIService].getName) with ICLIService with Logging {
+ extends ThriftService(classOf[LivyCLIService].getName) with Logging {
import LivyCLIService._
private var sessionManager: LivyThriftSessionManager = _
private var defaultFetchRows: Int = _
private var serviceUGI: UserGroupInformation = _
private var httpUGI: UserGroupInformation = _
+ private var maxTimeout: Long = _
- override def init(hiveConf: HiveConf): Unit = {
- sessionManager = new LivyThriftSessionManager(server, hiveConf)
+ override def init(livyConf: LivyConf): Unit = {
+ sessionManager = new LivyThriftSessionManager(server, livyConf)
addService(sessionManager)
- defaultFetchRows =
- hiveConf.getIntVar(ConfVars.HIVE_SERVER2_THRIFT_RESULTSET_DEFAULT_FETCH_SIZE)
+ defaultFetchRows = livyConf.getInt(LivyConf.THRIFT_RESULTSET_DEFAULT_FETCH_SIZE)
+ maxTimeout = livyConf.getTimeAsMs(LivyConf.THRIFT_LONG_POLLING_TIMEOUT)
// If the hadoop cluster is secure, do a kerberos login for the service from the keytab
if (UserGroupInformation.isSecurityEnabled) {
try {
- serviceUGI = Utils.getUGI
+ serviceUGI = UserGroupInformation.getCurrentUser
} catch {
case e: IOException =>
throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
@@ -63,19 +59,20 @@ class LivyCLIService(server: LivyThriftServer)
throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
}
// Also try creating a UGI object for the SPNego principal
- val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_PRINCIPAL)
- val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_KEYTAB)
+ val principal = livyConf.get(LivyConf.AUTH_KERBEROS_PRINCIPAL)
+ val keyTabFile = livyConf.get(LivyConf.AUTH_KERBEROS_KEYTAB)
if (principal.isEmpty || keyTabFile.isEmpty) {
info(s"SPNego httpUGI not created, SPNegoPrincipal: $principal, ketabFile: $keyTabFile")
} else try {
- httpUGI = HiveAuthFactory.loginFromSpnegoKeytabAndReturnUGI(hiveConf)
+ httpUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(
+ SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)
info("SPNego httpUGI successfully created.")
} catch {
case e: IOException =>
warn("SPNego httpUGI creation failed: ", e)
}
}
- super.init(hiveConf)
+ super.init(livyConf)
}
def getServiceUGI: UserGroupInformation = this.serviceUGI
@@ -85,7 +82,7 @@ class LivyCLIService(server: LivyThriftServer)
def getSessionManager: LivyThriftSessionManager = sessionManager
@throws[HiveSQLException]
- override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
+ def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
getInfoType match {
case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Livy JDBC")
case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Livy JDBC")
@@ -95,7 +92,7 @@ class LivyCLIService(server: LivyThriftServer)
case GetInfoType.CLI_MAX_SCHEMA_NAME_LEN => new GetInfoValue(128)
case GetInfoType.CLI_MAX_TABLE_NAME_LEN => new GetInfoValue(128)
case GetInfoType.CLI_ODBC_KEYWORDS =>
- new GetInfoValue(ParseUtils.getKeywords(LivyCLIService.ODBC_KEYWORDS))
+ new GetInfoValue(getKeywords)
case _ => throw new HiveSQLException(s"Unrecognized GetInfoType value: $getInfoType")
}
}
@@ -128,7 +125,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def openSession(
+ def openSession(
username: String,
password: String,
configuration: util.Map[String, String]): SessionHandle = {
@@ -139,7 +136,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def openSessionWithImpersonation(
+ def openSessionWithImpersonation(
username: String,
password: String,
configuration: util.Map[String, String], delegationToken: String): SessionHandle = {
@@ -150,13 +147,13 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def closeSession(sessionHandle: SessionHandle): Unit = {
+ def closeSession(sessionHandle: SessionHandle): Unit = {
sessionManager.closeSession(sessionHandle)
debug(sessionHandle + ": closeSession()")
}
@throws[HiveSQLException]
- override def executeStatement(
+ def executeStatement(
sessionHandle: SessionHandle,
statement: String,
confOverlay: util.Map[String, String]): OperationHandle = {
@@ -167,7 +164,7 @@ class LivyCLIService(server: LivyThriftServer)
* Execute statement on the server with a timeout. This is a blocking call.
*/
@throws[HiveSQLException]
- override def executeStatement(
+ def executeStatement(
sessionHandle: SessionHandle,
statement: String,
confOverlay: util.Map[String, String],
@@ -179,7 +176,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def executeStatementAsync(
+ def executeStatementAsync(
sessionHandle: SessionHandle,
statement: String,
confOverlay: util.Map[String, String]): OperationHandle = {
@@ -190,7 +187,7 @@ class LivyCLIService(server: LivyThriftServer)
* Execute statement asynchronously on the server with a timeout. This is a non-blocking call
*/
@throws[HiveSQLException]
- override def executeStatementAsync(
+ def executeStatementAsync(
sessionHandle: SessionHandle,
statement: String,
confOverlay: util.Map[String, String],
@@ -202,19 +199,19 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getTypeInfo(sessionHandle: SessionHandle): OperationHandle = {
+ def getTypeInfo(sessionHandle: SessionHandle): OperationHandle = {
debug(sessionHandle + ": getTypeInfo()")
sessionManager.operationManager.getTypeInfo(sessionHandle)
}
@throws[HiveSQLException]
- override def getCatalogs(sessionHandle: SessionHandle): OperationHandle = {
+ def getCatalogs(sessionHandle: SessionHandle): OperationHandle = {
debug(sessionHandle + ": getCatalogs()")
sessionManager.operationManager.getCatalogs(sessionHandle)
}
@throws[HiveSQLException]
- override def getSchemas(
+ def getSchemas(
sessionHandle: SessionHandle,
catalogName: String,
schemaName: String): OperationHandle = {
@@ -223,7 +220,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getTables(
+ def getTables(
sessionHandle: SessionHandle,
catalogName: String,
schemaName: String,
@@ -234,13 +231,13 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getTableTypes(sessionHandle: SessionHandle): OperationHandle = {
+ def getTableTypes(sessionHandle: SessionHandle): OperationHandle = {
debug(sessionHandle + ": getTableTypes()")
sessionManager.operationManager.getTableTypes(sessionHandle)
}
@throws[HiveSQLException]
- override def getColumns(
+ def getColumns(
sessionHandle: SessionHandle,
catalogName: String,
schemaName: String,
@@ -251,7 +248,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getFunctions(
+ def getFunctions(
sessionHandle: SessionHandle,
catalogName: String,
schemaName: String,
@@ -261,7 +258,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getPrimaryKeys(
+ def getPrimaryKeys(
sessionHandle: SessionHandle,
catalog: String,
schema: String,
@@ -271,7 +268,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getCrossReference(
+ def getCrossReference(
sessionHandle: SessionHandle,
primaryCatalog: String,
primarySchema: String,
@@ -284,7 +281,7 @@ class LivyCLIService(server: LivyThriftServer)
}
@throws[HiveSQLException]
- override def getOperationStatus(
+ def getOperationStatus(
opHandle: OperationHandle,
getProgressUpdate: Boolean): OperationStatus = {
val operation: Operation = sessionManager.operationManager.getOperation(opHandle)
@@ -294,10 +291,6 @@ class LivyCLIService(server: LivyThriftServer)
* However, if the background operation is complete, we return immediately.
*/
if (operation.shouldRunAsync) {
- val maxTimeout: Long = HiveConf.getTimeVar(
- getHiveConf,
- HiveConf.ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT,
- TimeUnit.MILLISECONDS)
val elapsed: Long = System.currentTimeMillis - operation.getBeginTime
// A step function to increase the polling timeout by 500 ms every 10 sec,
// starting from 500 ms up to HIVE_SERVER2_LONG_POLLING_TIMEOUT
@@ -321,76 +314,75 @@ class LivyCLIService(server: LivyThriftServer)
// In this case, the call might return sooner than long polling timeout
}
}
- val opStatus: OperationStatus = operation.getStatus
+ val opStatus = operation.getStatus
debug(opHandle + ": getOperationStatus()")
- opStatus.setJobProgressUpdate(new JobProgressUpdate(ProgressMonitor.NULL))
opStatus
}
@throws[HiveSQLException]
- override def cancelOperation(opHandle: OperationHandle): Unit = {
+ def cancelOperation(opHandle: OperationHandle): Unit = {
sessionManager.operationManager.cancelOperation(opHandle)
debug(opHandle + ": cancelOperation()")
}
@throws[HiveSQLException]
- override def closeOperation(opHandle: OperationHandle): Unit = {
+ def closeOperation(opHandle: OperationHandle): Unit = {
sessionManager.operationManager.closeOperation(opHandle)
debug(opHandle + ": closeOperation")
}
@throws[HiveSQLException]
- override def getResultSetMetadata(opHandle: OperationHandle): TableSchema = {
+ def getResultSetMetadata(opHandle: OperationHandle): Schema = {
debug(opHandle + ": getResultSetMetadata()")
sessionManager.operationManager.getOperation(opHandle).getResultSetSchema
}
@throws[HiveSQLException]
- override def fetchResults(opHandle: OperationHandle): RowSet = {
+ def fetchResults(opHandle: OperationHandle): ThriftResultSet = {
fetchResults(
opHandle, Operation.DEFAULT_FETCH_ORIENTATION, defaultFetchRows, FetchType.QUERY_OUTPUT)
}
@throws[HiveSQLException]
- override def fetchResults(
+ def fetchResults(
opHandle: OperationHandle,
orientation: FetchOrientation,
maxRows: Long,
- fetchType: FetchType): RowSet = {
+ fetchType: FetchType): ThriftResultSet = {
debug(opHandle + ": fetchResults()")
sessionManager.operationManager.fetchResults(opHandle, orientation, maxRows, fetchType)
}
@throws[HiveSQLException]
- override def getDelegationToken(
+ def getDelegationToken(
sessionHandle: SessionHandle,
- authFactory: HiveAuthFactory,
+ authFactory: AuthFactory,
owner: String,
renewer: String): String = {
throw new HiveSQLException("Operation not yet supported.")
}
@throws[HiveSQLException]
- override def setApplicationName(sh: SessionHandle, value: String): Unit = {
+ def setApplicationName(sh: SessionHandle, value: String): Unit = {
throw new HiveSQLException("Operation not yet supported.")
}
- override def cancelDelegationToken(
+ def cancelDelegationToken(
sessionHandle: SessionHandle,
- authFactory: HiveAuthFactory,
+ authFactory: AuthFactory,
tokenStr: String): Unit = {
throw new HiveSQLException("Operation not yet supported.")
}
- override def renewDelegationToken(
+ def renewDelegationToken(
sessionHandle: SessionHandle,
- authFactory: HiveAuthFactory,
+ authFactory: AuthFactory,
tokenStr: String): Unit = {
throw new HiveSQLException("Operation not yet supported.")
}
@throws[HiveSQLException]
- override def getQueryId(opHandle: TOperationHandle): String = {
+ def getQueryId(opHandle: TOperationHandle): String = {
throw new HiveSQLException("Operation not yet supported.")
}
}
@@ -428,5 +420,36 @@ object LivyCLIService {
"TIMESTAMP", "TIMEZONE_HOUR", "TIMEZONE_MINUTE", "TO", "TRAILING", "TRANSACTION", "TRANSLATE",
"TRANSLATION", "TRIM", "TRUE", "UNION", "UNIQUE", "UNKNOWN", "UPDATE", "UPPER", "USAGE",
"USER", "USING", "VALUE", "VALUES", "VARCHAR", "VARYING", "VIEW", "WHEN", "WHENEVER", "WHERE",
- "WITH", "WORK", "WRITE", "YEAR", "ZONE").asJava
+ "WITH", "WORK", "WRITE", "YEAR", "ZONE")
+
+ // scalastyle:off line.size.limit
+ // https://github.com/apache/spark/blob/515708d5f33d5acdb4206c626192d1838f8e691f/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4#L744-L1013
+ // scalastyle:on line.size.limit
+ private val SPARK_KEYWORDS = Seq("SELECT", "FROM", "ADD", "AS", "ALL", "ANY", "DISTINCT",
+ "WHERE", "GROUP", "BY", "GROUPING", "SETS", "CUBE", "ROLLUP", "ORDER", "HAVING", "LIMIT",
+ "AT", "OR", "AND", "IN", "NOT", "NO", "EXISTS", "BETWEEN", "LIKE", "RLIKE", "REGEXP", "IS",
+ "NULL", "TRUE", "FALSE", "NULLS", "ASC", "DESC", "FOR", "INTERVAL", "CASE", "WHEN", "THEN",
+ "ELSE", "END", "JOIN", "CROSS", "OUTER", "INNER", "LEFT", "SEMI", "RIGHT", "FULL", "NATURAL",
+ "ON", "PIVOT", "LATERAL", "WINDOW", "OVER", "PARTITION", "RANGE", "ROWS", "UNBOUNDED",
+ "PRECEDING", "FOLLOWING", "CURRENT", "FIRST", "AFTER", "LAST", "ROW", "WITH", "VALUES",
+ "CREATE", "TABLE", "DIRECTORY", "VIEW", "REPLACE", "INSERT", "DELETE", "INTO", "DESCRIBE",
+ "EXPLAIN", "FORMAT", "LOGICAL", "CODEGEN", "COST", "CAST", "SHOW", "TABLES", "COLUMNS",
+ "COLUMN", "USE", "PARTITIONS", "FUNCTIONS", "DROP", "UNION", "EXCEPT", "MINUS", "INTERSECT",
+ "TO", "TABLESAMPLE", "STRATIFY", "ALTER", "RENAME", "ARRAY", "MAP", "STRUCT", "COMMENT", "SET",
+ "RESET", "DATA", "START", "TRANSACTION", "COMMIT", "ROLLBACK", "MACRO", "IGNORE", "BOTH",
+ "LEADING", "TRAILING", "IF", "POSITION", "EXTRACT", "DIV", "PERCENT", "BUCKET", "OUT", "OF",
+ "SORT", "CLUSTER", "DISTRIBUTE", "OVERWRITE", "TRANSFORM", "REDUCE", "USING", "SERDE",
+ "SERDEPROPERTIES", "RECORDREADER", "RECORDWRITER", "DELIMITED", "FIELDS", "TERMINATED",
+ "COLLECTION", "ITEMS", "KEYS", "ESCAPED", "LINES", "SEPARATED", "FUNCTION", "EXTENDED",
+ "REFRESH", "CLEAR", "CACHE", "UNCACHE", "LAZY", "FORMATTED", "GLOBAL", "TEMPORARY", "TEMP",
+ "OPTIONS", "UNSET", "TBLPROPERTIES", "DBPROPERTIES", "BUCKETS", "SKEWED", "STORED",
+ "DIRECTORIES", "LOCATION", "EXCHANGE", "ARCHIVE", "UNARCHIVE", "FILEFORMAT", "TOUCH",
+ "COMPACT", "CONCATENATE", "CHANGE", "CASCADE", "RESTRICT", "CLUSTERED", "SORTED", "PURGE",
+ "INPUTFORMAT", "OUTPUTFORMAT", "DATABASE", "SCHEMA", "DATABASES", "SCHEMAS", "DFS", "TRUNCATE",
+ "ANALYZE", "COMPUTE", "LIST", "STATISTICS", "PARTITIONED", "EXTERNAL", "DEFINED", "REVOKE",
+ "GRANT", "LOCK", "UNLOCK", "MSCK", "REPAIR", "RECOVER", "EXPORT", "IMPORT", "LOAD", "ROLE",
+ "ROLES", "COMPACTIONS", "PRINCIPALS", "TRANSACTIONS", "INDEX", "INDEXES", "LOCKS", "OPTION",
+ "ANTI", "LOCAL", "INPATH")
+  // CLI_ODBC_KEYWORDS lists only data-source-specific keywords, so ODBC ones are excluded.
+  def getKeywords: String = SPARK_KEYWORDS.filterNot(ODBC_KEYWORDS.contains).mkString(",")
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala
index c2c4716..a067788 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyExecuteStatementOperation.scala
@@ -18,31 +18,27 @@
package org.apache.livy.thriftserver
import java.security.PrivilegedExceptionAction
-import java.util
-import java.util.{Map => JMap}
import java.util.concurrent.{ConcurrentLinkedQueue, RejectedExecutionException}
import scala.collection.mutable
-import scala.collection.JavaConverters._
import scala.util.control.NonFatal
-import org.apache.hadoop.hive.serde2.thrift.{ColumnBuffer => ThriftColumnBuffer}
-import org.apache.hadoop.hive.shims.Utils
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.Operation
import org.apache.livy.Logging
import org.apache.livy.thriftserver.SessionStates._
+import org.apache.livy.thriftserver.operation.Operation
import org.apache.livy.thriftserver.rpc.RpcClient
-import org.apache.livy.thriftserver.types.DataTypeUtils._
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.types.{BasicDataType, DataTypeUtils, Field, Schema}
class LivyExecuteStatementOperation(
sessionHandle: SessionHandle,
statement: String,
- confOverlay: JMap[String, String],
runInBackground: Boolean = true,
sessionManager: LivyThriftSessionManager)
- extends Operation(sessionHandle, confOverlay, OperationType.EXECUTE_STATEMENT)
+ extends Operation(sessionHandle, OperationType.EXECUTE_STATEMENT)
with Logging {
/**
@@ -62,31 +58,23 @@ class LivyExecuteStatementOperation(
}
private var rowOffset = 0L
- private def statementId: String = getHandle.getHandleIdentifier.toString
+ private def statementId: String = opHandle.getHandleIdentifier.toString
private def rpcClientValid: Boolean =
sessionManager.livySessionState(sessionHandle) == CREATION_SUCCESS && rpcClient.isValid
- override def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = {
- validateDefaultFetchOrientation(order)
- assertState(util.Arrays.asList(OperationState.FINISHED))
+ override def getNextRowSet(order: FetchOrientation, maxRowsL: Long): ThriftResultSet = {
+ validateFetchOrientation(order)
+ assertState(Seq(OperationState.FINISHED))
setHasResultSet(true)
// maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int
val maxRows = maxRowsL.toInt
val resultSet = rpcClient.fetchResult(sessionHandle, statementId, maxRows).get()
-
- val thriftColumns = resultSet.getColumns().map { col =>
- new ThriftColumnBuffer(toHiveThriftType(col.getType()), col.getNulls(), col.getValues())
- }
- val result = new ColumnBasedSet(
- toHiveTableSchema(resultSet.getSchema()).toTypeDescriptors,
- thriftColumns.toList.asJava,
- rowOffset)
- if (resultSet.getColumns() != null && resultSet.getColumns().length > 0) {
- rowOffset += resultSet.getColumns()(0).size()
- }
- result
+ val livyColumnResultSet = ThriftResultSet(resultSet)
+ livyColumnResultSet.setRowOffset(rowOffset)
+ rowOffset += livyColumnResultSet.numRows
+ livyColumnResultSet
}
override def runInternal(): Unit = {
@@ -96,7 +84,7 @@ class LivyExecuteStatementOperation(
if (!runInBackground) {
execute()
} else {
- val livyServiceUGI = Utils.getUGI
+ val livyServiceUGI = UserGroupInformation.getCurrentUser
// Runnable impl to call runInternal asynchronously,
// from a different thread
@@ -153,7 +141,7 @@ class LivyExecuteStatementOperation(
rpcClient.executeSql(sessionHandle, statementId, statement).get()
} catch {
case e: Throwable =>
- val currentState = getStatus.getState
+ val currentState = getStatus.state
info(s"Error executing query, currentState $currentState, ", e)
setState(OperationState.ERROR)
throw new HiveSQLException(e)
@@ -171,14 +159,17 @@ class LivyExecuteStatementOperation(
cleanup(state)
}
- def getResultSetSchema: TableSchema = {
- val tableSchema = toHiveTableSchema(
+ override def shouldRunAsync: Boolean = runInBackground
+
+ override def getResultSetSchema: Schema = {
+ val tableSchema = DataTypeUtils.schemaFromSparkJson(
rpcClient.fetchResultSchema(sessionHandle, statementId).get())
// Workaround for operations returning an empty schema (eg. CREATE, INSERT, ...)
- if (tableSchema.getSize == 0) {
- tableSchema.addStringColumn("Result", "")
+ if (!tableSchema.fields.isEmpty) {
+ tableSchema
+ } else {
+ Schema(Field("Result", BasicDataType("string"), ""))
}
- tableSchema
}
private def cleanup(state: OperationState) {
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala
index c71171a..e6d48ff 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyOperationManager.scala
@@ -23,12 +23,12 @@ import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.{GetCatalogsOperation, GetTableTypesOperation, GetTypeInfoOperation, Operation}
-import org.apache.livy.Logging
+import org.apache.livy.{LivyConf, Logging}
+import org.apache.livy.thriftserver.operation._
+import org.apache.livy.thriftserver.serde.ThriftResultSet
+import org.apache.livy.thriftserver.session.DataType
class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManager)
extends Logging {
@@ -37,12 +37,15 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
private val sessionToOperationHandles =
new mutable.HashMap[SessionHandle, mutable.Set[OperationHandle]]()
+ private val operationTimeout =
+ livyThriftSessionManager.livyConf.getTimeAsMs(LivyConf.THRIFT_IDLE_OPERATION_TIMEOUT)
+
private def addOperation(operation: Operation, sessionHandle: SessionHandle): Unit = {
- handleToOperation.put(operation.getHandle, operation)
+ handleToOperation.put(operation.opHandle, operation)
sessionToOperationHandles.synchronized {
val set = sessionToOperationHandles.getOrElseUpdate(sessionHandle,
new mutable.HashSet[OperationHandle])
- set += operation.getHandle
+ set += operation.opHandle
}
}
@@ -52,7 +55,7 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
if (operation == null) {
throw new HiveSQLException(s"Operation does not exist: $operationHandle")
}
- val sessionHandle = operation.getSessionHandle
+ val sessionHandle = operation.sessionHandle
sessionToOperationHandles.synchronized {
sessionToOperationHandles(sessionHandle) -= operationHandle
if (sessionToOperationHandles(sessionHandle).isEmpty) {
@@ -74,7 +77,7 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
opHandles.flatMap { handle =>
// Some operations may have finished and been removed since we got them.
Option(handleToOperation.get(handle))
- }.filter(_.isTimedOut(currentTime))
+ }.filter(_.isTimedOut(currentTime, operationTimeout))
}
@throws[HiveSQLException]
@@ -95,7 +98,6 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
val op = new LivyExecuteStatementOperation(
sessionHandle,
statement,
- confOverlay,
runAsync,
livyThriftSessionManager)
addOperation(op, sessionHandle)
@@ -107,16 +109,13 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
def getOperationLogRowSet(
opHandle: OperationHandle,
orientation: FetchOrientation,
- maxRows: Long): RowSet = {
- val tableSchema = new TableSchema(LivyOperationManager.LOG_SCHEMA)
- val session = livyThriftSessionManager.getSessionInfo(getOperation(opHandle).getSessionHandle)
- val logs = RowSetFactory.create(tableSchema, session.protocolVersion, false)
-
- if (!livyThriftSessionManager.getHiveConf.getBoolVar(
- ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) {
- warn("Try to get operation log when " +
- ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED.varname +
- " is false, no log will be returned. ")
+ maxRows: Long): ThriftResultSet = {
+ val session = livyThriftSessionManager.getSessionInfo(getOperation(opHandle).sessionHandle)
+ val logs = ThriftResultSet(LivyOperationManager.LOG_SCHEMA, session.protocolVersion)
+
+ if (!livyThriftSessionManager.livyConf.getBoolean(LivyConf.THRIFT_LOG_OPERATION_ENABLED)) {
+ warn(s"Try to get operation log when ${LivyConf.THRIFT_LOG_OPERATION_ENABLED.key} is " +
+ "false, no log will be returned.")
} else {
// Get the operation log. This is implemented only for LivyExecuteStatementOperation
val operation = getOperation(opHandle)
@@ -149,7 +148,7 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
var opHandle: OperationHandle = null
try {
val operation = operationCreator
- opHandle = operation.getHandle
+ opHandle = operation.opHandle
operation.run()
opHandle
} catch {
@@ -194,9 +193,9 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
@throws[HiveSQLException]
def cancelOperation(opHandle: OperationHandle, errMsg: String): Unit = {
val operation = getOperation(opHandle)
- val opState = operation.getStatus.getState
+ val opState = operation.getStatus.state
if (opState.isTerminal) {
- // Cancel should be a no-op
+      // Cancel should be a no-op in either case
debug(s"$opHandle: Operation is already aborted in state - $opState")
} else {
debug(s"$opHandle: Attempting to cancel from state - $opState")
@@ -223,7 +222,7 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
opHandle: OperationHandle,
orientation: FetchOrientation,
maxRows: Long,
- fetchType: FetchType): RowSet = {
+ fetchType: FetchType): ThriftResultSet = {
if (fetchType == FetchType.QUERY_OUTPUT) {
getOperation(opHandle).getNextRowSet(orientation, maxRows)
} else {
@@ -233,12 +232,5 @@ class LivyOperationManager(val livyThriftSessionManager: LivyThriftSessionManage
}
object LivyOperationManager {
- val LOG_SCHEMA: Schema = {
- val schema = new Schema
- val fieldSchema = new FieldSchema
- fieldSchema.setName("operation_log")
- fieldSchema.setType("string")
- schema.addToFieldSchemas(fieldSchema)
- schema
- }
+ val LOG_SCHEMA: Array[DataType] = Array(DataType.STRING)
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala
index c670217..daf7b82 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftServer.scala
@@ -19,17 +19,14 @@ package org.apache.livy.thriftserver
import java.security.PrivilegedExceptionAction
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.security.UserGroupInformation
-import org.apache.hive.service.server.HiveServer2
import org.apache.livy.{LivyConf, Logging}
import org.apache.livy.server.AccessManager
import org.apache.livy.server.interactive.InteractiveSession
import org.apache.livy.server.recovery.SessionStore
import org.apache.livy.sessions.InteractiveSessionManager
+import org.apache.livy.thriftserver.cli.{ThriftBinaryCLIService, ThriftHttpCLIService}
/**
* The main entry point for the Livy thrift server leveraging HiveServer2. Starts up a
@@ -41,21 +38,6 @@ object LivyThriftServer extends Logging {
private[thriftserver] var thriftServerThread: Thread = _
private var thriftServer: LivyThriftServer = _
- private def hiveConf(livyConf: LivyConf): HiveConf = {
- val conf = new HiveConf()
- // Remove all configs coming from hive-site.xml which may be in the classpath for the Spark
- // applications to run.
- conf.getAllProperties.asScala.filter(_._1.startsWith("hive.")).foreach { case (key, _) =>
- conf.unset(key)
- }
- livyConf.asScala.foreach {
- case nameAndValue if nameAndValue.getKey.startsWith("livy.hive") =>
- conf.set(nameAndValue.getKey.stripPrefix("livy."), nameAndValue.getValue)
- case _ => // Ignore
- }
- conf
- }
-
def start(
livyConf: LivyConf,
livySessionManager: InteractiveSessionManager,
@@ -97,7 +79,7 @@ object LivyThriftServer extends Logging {
}
private def doStart(livyConf: LivyConf): Unit = {
- thriftServer.init(hiveConf(livyConf))
+ thriftServer.init(livyConf)
thriftServer.start()
}
@@ -114,6 +96,11 @@ object LivyThriftServer extends Logging {
thriftServer.stop()
thriftServer = null
}
+
+ def isHTTPTransportMode(livyConf: LivyConf): Boolean = {
+ val transportMode = livyConf.get(LivyConf.THRIFT_TRANSPORT_MODE)
+ transportMode != null && transportMode.equalsIgnoreCase("http")
+ }
}
@@ -121,17 +108,38 @@ class LivyThriftServer(
private[thriftserver] val livyConf: LivyConf,
private[thriftserver] val livySessionManager: InteractiveSessionManager,
private[thriftserver] val sessionStore: SessionStore,
- private[thriftserver] val accessManager: AccessManager) extends HiveServer2 {
- override def init(hiveConf: HiveConf): Unit = {
- this.cliService = new LivyCLIService(this)
- super.init(hiveConf)
+ private[thriftserver] val accessManager: AccessManager)
+ extends ThriftService(classOf[LivyThriftServer].getName) with Logging {
+
+ val cliService = new LivyCLIService(this)
+
+ override def init(livyConf: LivyConf): Unit = {
+ addService(cliService)
+ val server = this
+ val oomHook = new Runnable() {
+ override def run(): Unit = {
+ server.stop()
+ }
+ }
+ val thriftCLIService = if (LivyThriftServer.isHTTPTransportMode(livyConf)) {
+ new ThriftHttpCLIService(cliService, oomHook)
+ } else {
+ new ThriftBinaryCLIService(cliService, oomHook)
+ }
+ addService(thriftCLIService)
+ super.init(livyConf)
}
- private[thriftserver] def getSessionManager(): LivyThriftSessionManager = {
- this.cliService.asInstanceOf[LivyCLIService].getSessionManager
+ private[thriftserver] def getSessionManager = {
+ cliService.getSessionManager
}
def isAllowedToUse(user: String, session: InteractiveSession): Boolean = {
session.owner == user || accessManager.checkModifyPermissions(user)
}
+
+ override def stop(): Unit = {
+ info("Shutting down LivyThriftServer")
+ super.stop()
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
index 344a990..5be5536 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
@@ -31,10 +31,7 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success, Try}
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.shims.Utils
-import org.apache.hive.service.CompositeService
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.cli.{HiveSQLException, SessionHandle}
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup
@@ -47,8 +44,9 @@ import org.apache.livy.thriftserver.SessionStates._
import org.apache.livy.thriftserver.rpc.RpcClient
import org.apache.livy.utils.LivySparkUtils
-class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
- extends CompositeService(classOf[LivyThriftSessionManager].getName) with Logging {
+
+class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyConf)
+ extends ThriftService(classOf[LivyThriftSessionManager].getName) with Logging {
private[thriftserver] val operationManager = new LivyOperationManager(this)
private val sessionHandleToLivySession =
@@ -79,17 +77,13 @@ class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
}
// Configs from Hive
- private val userLimit = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS_PER_USER)
- private val ipAddressLimit =
- hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS_PER_IPADDRESS)
+ private val userLimit = livyConf.getInt(LivyConf.THRIFT_LIMIT_CONNECTIONS_PER_USER)
+ private val ipAddressLimit = livyConf.getInt(LivyConf.THRIFT_LIMIT_CONNECTIONS_PER_IPADDRESS)
private val userIpAddressLimit =
- hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS_PER_USER_IPADDRESS)
- private val checkInterval = HiveConf.getTimeVar(
- hiveConf, ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL, TimeUnit.MILLISECONDS)
- private val sessionTimeout = HiveConf.getTimeVar(
- hiveConf, ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT, TimeUnit.MILLISECONDS)
- private val checkOperation = HiveConf.getBoolVar(
- hiveConf, ConfVars.HIVE_SERVER2_IDLE_SESSION_CHECK_OPERATION)
+ livyConf.getInt(LivyConf.THRIFT_LIMIT_CONNECTIONS_PER_USER_IPADDRESS)
+ private val checkInterval = livyConf.getTimeAsMs(LivyConf.THRIFT_SESSION_CHECK_INTERVAL)
+ private val sessionTimeout = livyConf.getTimeAsMs(LivyConf.THRIFT_IDLE_SESSION_TIMEOUT)
+ private val checkOperation = livyConf.getBoolean(LivyConf.THRIFT_IDLE_SESSION_CHECK_OPERATION)
private var backgroundOperationPool: ThreadPoolExecutor = _
@@ -243,7 +237,7 @@ class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
newSession
}
val futureLivySession = Future {
- val livyServiceUGI = Utils.getUGI
+ val livyServiceUGI = UserGroupInformation.getCurrentUser
var livySession: InteractiveSession = null
try {
livyServiceUGI.doAs(new PrivilegedExceptionAction[InteractiveSession] {
@@ -294,21 +288,20 @@ class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
}
// Taken from Hive
- override def init(hiveConf: HiveConf): Unit = {
- createBackgroundOperationPool(hiveConf)
+ override def init(livyConf: LivyConf): Unit = {
+ createBackgroundOperationPool(livyConf)
info("Connections limit are user: {} ipaddress: {} user-ipaddress: {}",
userLimit, ipAddressLimit, userIpAddressLimit)
- super.init(hiveConf)
+ super.init(livyConf)
}
// Taken from Hive
- private def createBackgroundOperationPool(hiveConf: HiveConf): Unit = {
- val poolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
+ private def createBackgroundOperationPool(livyConf: LivyConf): Unit = {
+ val poolSize = livyConf.getInt(LivyConf.THRIFT_ASYNC_EXEC_THREADS)
info("HiveServer2: Background operation thread pool size: " + poolSize)
- val poolQueueSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_WAIT_QUEUE_SIZE)
+ val poolQueueSize = livyConf.getInt(LivyConf.THRIFT_ASYNC_EXEC_WAIT_QUEUE_SIZE)
info("HiveServer2: Background operation thread wait queue size: " + poolQueueSize)
- val keepAliveTime = HiveConf.getTimeVar(
- hiveConf, ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME, TimeUnit.SECONDS)
+ val keepAliveTime = livyConf.getTimeAsMs(LivyConf.THRIFT_ASYNC_EXEC_KEEPALIVE_TIME) / 1000
info(s"HiveServer2: Background operation thread keepalive time: $keepAliveTime seconds")
// Create a thread pool with #poolSize threads
// Threads terminate when they are idle for more than the keepAliveTime
@@ -364,11 +357,11 @@ class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
if (operations.nonEmpty) {
operations.foreach { op =>
try {
- warn(s"Operation ${op.getHandle} is timed-out and will be closed")
- operationManager.closeOperation(op.getHandle)
+ warn(s"Operation ${op.opHandle} is timed-out and will be closed")
+ operationManager.closeOperation(op.opHandle)
} catch {
case e: Exception =>
- warn("Exception is thrown closing timed-out operation: " + op.getHandle, e)
+ warn("Exception is thrown closing timed-out operation: " + op.opHandle, e)
}
}
}
@@ -405,14 +398,13 @@ class LivyThriftSessionManager(val server: LivyThriftServer, hiveConf: HiveConf)
shutdownTimeoutChecker()
if (backgroundOperationPool != null) {
backgroundOperationPool.shutdown()
- val timeout =
- hiveConf.getTimeVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT, TimeUnit.SECONDS)
+ val timeout = livyConf.getTimeAsMs(LivyConf.THRIFT_ASYNC_EXEC_SHUTDOWN_TIMEOUT) / 1000
try {
backgroundOperationPool.awaitTermination(timeout, TimeUnit.SECONDS)
} catch {
case e: InterruptedException =>
- warn("HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT = " + timeout +
- " seconds has been exceeded. RUNNING background operations will be shut down", e)
+ warn(s"THRIFT_ASYNC_EXEC_SHUTDOWN_TIMEOUT = $timeout seconds has been exceeded. " +
+ "RUNNING background operations will be shut down", e)
}
backgroundOperationPool = null
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftServerFactoryImpl.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftServerFactoryImpl.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftServerFactoryImpl.scala
index 7669580..16f903a 100644
--- a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftServerFactoryImpl.scala
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftServerFactoryImpl.scala
@@ -38,6 +38,11 @@ class ThriftServerFactoryImpl extends ThriftServerFactory {
LivyThriftServer.start(livyConf, livySessionManager, sessionStore, accessManager)
}
+ override def stop(): Unit = {
+   // Stopping is expected only after start(); the assertion documents that invariant.
+   assert(LivyThriftServer.getInstance.nonEmpty)
+   for (server <- LivyThriftServer.getInstance) {
+     server.stop()
+   }
+ }
+
override def getServlet(basePath: String): Servlet = new ThriftJsonServlet(basePath)
override def getServletMappings: Seq[String] = Seq("/thriftserver/*")
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftService.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftService.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftService.scala
new file mode 100644
index 0000000..467592a
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/ThriftService.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver
+
+import scala.collection.mutable
+
+import org.apache.hive.service.ServiceException
+
+import org.apache.livy.{LivyConf, Logging}
+
+/**
+ * Lifecycle states of a [[ThriftService]]. A service is expected to move forward through
+ * NOTINITED -> INITED -> STARTED -> STOPPED; the ids follow that declaration order.
+ */
+object STATE extends Enumeration {
+  type STATE = Value
+
+  /** Constructed but not initialized. */
+  val NOTINITED: Value = Value
+
+  /** Initialized but not started or stopped. */
+  val INITED: Value = Value
+
+  /** Started and not stopped. */
+  val STARTED: Value = Value
+
+  /** Stopped. No further state transitions are permitted. */
+  val STOPPED: Value = Value
+}
+
+/**
+ * Minimal composite-service lifecycle that replaces Hive's `CompositeService` for the Livy
+ * thrift server. A service owns an ordered list of child services: children are initialized
+ * and started in insertion order and stopped in reverse order, while the service's own
+ * [[STATE]] advances NOTINITED -> INITED -> STARTED -> STOPPED.
+ *
+ * The lifecycle methods are synchronized because `stop()` may be triggered from a callback
+ * (e.g. an out-of-memory hook) rather than the thread that started the service.
+ *
+ * @param name human-readable service name used in log messages
+ */
+class ThriftService(val name: String) extends Logging {
+  private val serviceList = new mutable.ListBuffer[ThriftService]
+
+  /** Service state: initially [[STATE.NOTINITED]]. */
+  @volatile private var state = STATE.NOTINITED
+
+  /** Service start time (`System.currentTimeMillis`). Will be zero until the service starts. */
+  @volatile private var startTime = 0L
+
+  /** Returns a snapshot of the child services, in the order they were added. */
+  def getServices: Seq[ThriftService] = serviceList.toList
+
+  /** Registers `service` as a child whose lifecycle follows this service's. */
+  protected def addService(service: ThriftService): Unit = {
+    serviceList += service
+  }
+
+  /** Unregisters a previously added child service. */
+  protected def removeService(service: ThriftService): Unit = serviceList -= service
+
+  /**
+   * Initializes every child (in insertion order), then this service.
+   *
+   * @throws IllegalStateException if this service is not in [[STATE.NOTINITED]]
+   */
+  def init(conf: LivyConf): Unit = synchronized {
+    serviceList.foreach(_.init(conf))
+    ensureCurrentState(STATE.NOTINITED)
+    changeState(STATE.INITED)
+    info(s"Service:$getName is inited.")
+  }
+
+  /**
+   * Starts every child (in insertion order), then this service.
+   *
+   * On any failure, the children up to and including the failing one are stopped in reverse
+   * order so partially acquired resources are released.
+   *
+   * @throws ServiceException wrapping the original failure
+   */
+  def start(): Unit = synchronized {
+    // Index of the child being started; equals serviceList.size once all children started.
+    var i = 0
+    try {
+      val n = serviceList.size
+      while (i < n) {
+        serviceList(i).start()
+        i += 1
+      }
+      startTime = System.currentTimeMillis
+      ensureCurrentState(STATE.INITED)
+      changeState(STATE.STARTED)
+      info(s"Service:$getName is started.")
+    } catch {
+      case e: Throwable =>
+        error("Error starting services " + getName, e)
+        // Note that the state of the failed service is still INITED and not STARTED. Even
+        // though it did not start completely, stop() is still called on every service up to
+        // and including it so that cleanup happens. If the failure came after all children
+        // started, `i` is one past the last child; stop(Int) clamps it.
+        stop(i)
+        throw new ServiceException("Failed to Start " + getName, e)
+    }
+  }
+
+  /**
+   * Stops every child in reverse insertion order, then marks this service STOPPED.
+   * Calling it on an already stopped service is a no-op. A service that never reached
+   * STARTED (e.g. because another service failing cancelled startup) still stops its
+   * children but keeps its current state.
+   */
+  def stop(): Unit = synchronized {
+    if (state != STATE.STOPPED) {
+      stop(serviceList.size - 1)
+      if (state == STATE.STARTED) {
+        changeState(STATE.STOPPED)
+        info(s"Service:$getName is stopped.")
+      }
+    }
+  }
+
+  /**
+   * Stops the children at indices `numOfServicesStarted` down to 0 (reverse order of start).
+   * The index is clamped to the last child so an out-of-range value cannot raise an
+   * IndexOutOfBoundsException; a negative value stops nothing. A failing child is logged and
+   * does not prevent the remaining children from being stopped.
+   */
+  private def stop(numOfServicesStarted: Int): Unit = {
+    var i = math.min(numOfServicesStarted, serviceList.size - 1)
+    while (i >= 0) {
+      val service = serviceList(i)
+      try {
+        service.stop()
+      } catch {
+        case t: Throwable => warn("Error stopping " + service.getName, t)
+      }
+      i -= 1
+    }
+  }
+
+  def getServiceState: STATE.Value = state
+
+  def getName: String = name
+
+  def getStartTime: Long = startTime
+
+  /**
+   * Verify that a service is in a given state.
+   *
+   * @param currentState the desired state
+   * @throws IllegalStateException if the service state is different from the desired state
+   */
+  private def ensureCurrentState(currentState: STATE.Value): Unit = {
+    if (state != currentState) {
+      throw new IllegalStateException(
+        s"For this operation, the current service state must be $currentState instead of $state")
+    }
+  }
+
+  /**
+   * Change to a new state.
+   *
+   * @param newState new service state
+   */
+  private def changeState(newState: STATE.Value): Unit = {
+    state = newState
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala
new file mode 100644
index 0000000..d2091e6
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthBridgeServer.scala
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.auth
+
+import java.io.IOException
+import java.net.InetAddress
+import java.security.{PrivilegedAction, PrivilegedExceptionAction}
+import java.util
+import javax.security.auth.callback.{Callback, CallbackHandler, NameCallback, PasswordCallback, UnsupportedCallbackException}
+import javax.security.sasl.{AuthorizeCallback, RealmCallback, SaslServer}
+
+import org.apache.commons.codec.binary.Base64
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.security.{SaslRpcServer, UserGroupInformation}
+import org.apache.hadoop.security.SaslRpcServer.AuthMethod
+import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod
+import org.apache.hadoop.security.token.SecretManager.InvalidToken
+import org.apache.thrift.{TException, TProcessor}
+import org.apache.thrift.protocol.TProtocol
+import org.apache.thrift.transport.{TSaslServerTransport, TSocket, TTransport, TTransportException, TTransportFactory}
+
+import org.apache.livy.Logging
+
+/**
+ * Server side of the Thrift/Hadoop SASL bridge, taken from Hive's
+ * `HadoopThriftAuthBridge.Server`. It builds Thrift SASL transport factories that accept
+ * Kerberos and delegation-token (DIGEST-MD5) clients, wraps processors so that client details
+ * are recorded per request, and exposes those details for the current thread.
+ *
+ * The current user is captured when the instance is created; an [[IOException]] while doing so
+ * is rethrown as a [[TTransportException]].
+ *
+ * @param secretManager delegation-token secret manager used for token clients
+ */
+class AuthBridgeServer(private val secretManager: LivyDelegationTokenSecretManager) {
+  private val ugi: UserGroupInformation =
+    try {
+      UserGroupInformation.getCurrentUser
+    } catch {
+      case ioe: IOException => throw new TTransportException(ioe)
+    }
+
+  /**
+   * Creates a [[TTransportFactory]] that, when a client socket connects, negotiates a
+   * Kerberized SASL transport while running as the server's UGI. The result can be passed as
+   * both the input and output transport factory of, for example, a `TThreadPoolServer`.
+   *
+   * @param saslProps SASL properties passed to every server definition
+   */
+  @throws[TTransportException]
+  def createTransportFactory(saslProps: util.Map[String, String]): TTransportFactory =
+    wrapTransportFactory(createSaslServerTransportFactory(saslProps))
+
+  /**
+   * Creates a [[TSaslServerTransport.Factory]] with two server definitions: Kerberos, using the
+   * first two parts of the current user's principal, and delegation tokens, handled by a
+   * [[SaslDigestCallbackHandler]] backed by `secretManager`.
+   *
+   * @param saslProps SASL properties passed to every server definition
+   * @throws TTransportException if the current user's name does not split into exactly three
+   *                             Kerberos principal parts
+   */
+  @throws[TTransportException]
+  def createSaslServerTransportFactory(
+      saslProps: util.Map[String, String]): TSaslServerTransport.Factory = {
+    val kerberosName = ugi.getUserName
+    SaslRpcServer.splitKerberosName(kerberosName) match {
+      case Array(primary, instance, _) =>
+        val factory = new TSaslServerTransport.Factory
+        factory.addServerDefinition(AuthMethod.KERBEROS.getMechanismName,
+          primary, instance, saslProps, new SaslRpcServer.SaslGssCallbackHandler)
+        factory.addServerDefinition(AuthMethod.TOKEN.getMechanismName,
+          null, SaslRpcServer.SASL_DEFAULT_REALM, saslProps,
+          new SaslDigestCallbackHandler(secretManager))
+        factory
+      case _ =>
+        throw new TTransportException(s"Kerberos principal should have 3 parts: $kerberosName")
+    }
+  }
+
+  /**
+   * Wraps `transFactory` so that every transport it creates is created while running as the
+   * server's UGI (captured at construction).
+   */
+  def wrapTransportFactory(transFactory: TTransportFactory): TTransportFactory =
+    new TUGIAssumingTransportFactory(transFactory, ugi)
+
+  /**
+   * Wraps `processor` so that each request records the client's details and runs as a proxy
+   * user for the authenticated client.
+   */
+  def wrapProcessor(processor: TProcessor): TProcessor =
+    new TUGIAssumingProcessor(processor, secretManager, useProxy = true)
+
+  /**
+   * Wraps `processor` so that each request records the client's details (user, address,
+   * mechanism) without switching to a proxy user.
+   */
+  def wrapNonAssumingProcessor(processor: TProcessor): TProcessor =
+    new TUGIAssumingProcessor(processor, secretManager, useProxy = false)
+
+  /** Address of the client served by the current thread, or `null` if none was recorded. */
+  def getRemoteAddress: InetAddress = AuthBridgeServer.remoteAddress.get
+
+  /** User name of the client served by the current thread, or `null` if none was recorded. */
+  def getRemoteUser: String = AuthBridgeServer.remoteUser.get
+
+  /** SASL mechanism used by the client served by the current thread. */
+  def getUserAuthMechanism: String = AuthBridgeServer.userAuthMechanism.get
+}
+
+/**
+ * A [[TTransportFactory]] decorator that creates each transport while running as the given
+ * [[UserGroupInformation]]. On the server side this lets the server's own principal be in
+ * effect while a client connection is being accepted.
+ *
+ * This class is derived from Hive's one.
+ *
+ * @param wrapped the factory that actually creates transports; must not be null
+ * @param ugi     the identity to assume while creating them; must not be null
+ */
+private[auth] class TUGIAssumingTransportFactory(
+    val wrapped: TTransportFactory,
+    val ugi: UserGroupInformation) extends TTransportFactory {
+  assert(wrapped != null)
+  assert(ugi != null)
+
+  override def getTransport(trans: TTransport): TTransport = {
+    val createTransport = new PrivilegedAction[TTransport]() {
+      override def run: TTransport = wrapped.getTransport(trans)
+    }
+    ugi.doAs(createTransport)
+  }
+}
+
+/**
+ * CallbackHandler for the SASL DIGEST-MD5 mechanism, registered for delegation-token clients.
+ *
+ * This code is pretty much completely based on Hadoop's SaslRpcServer.SaslDigestCallbackHandler -
+ * the only reason we could not use that Hadoop class as-is was because it needs a
+ * Server.Connection.
+ *
+ * @param secretManager resolves token identifiers and their passwords
+ */
+sealed class SaslDigestCallbackHandler(
+    val secretManager: LivyDelegationTokenSecretManager) extends CallbackHandler with Logging {
+
+  /** Looks up the password of `tokenId` and returns it Base64-encoded. */
+  @throws[InvalidToken]
+  private def getPassword(tokenId: LivyDelegationTokenIdentifier): Array[Char] =
+    encodePassword(secretManager.retrievePassword(tokenId))
+
+  private def encodePassword(password: Array[Byte]): Array[Char] =
+    new String(Base64.encodeBase64(password)).toCharArray
+
+  @throws[InvalidToken]
+  @throws[UnsupportedCallbackException]
+  override def handle(callbacks: Array[Callback]): Unit = {
+    var nameCallback: NameCallback = null
+    var passwordCallback: PasswordCallback = null
+    for (callback <- callbacks) {
+      callback match {
+        case ac: AuthorizeCallback => authorize(ac)
+        case c: NameCallback => nameCallback = c
+        case c: PasswordCallback => passwordCallback = c
+        case _: RealmCallback => // Do nothing.
+        case other =>
+          throw new UnsupportedCallbackException(other, "Unrecognized SASL DIGEST-MD5 Callback")
+      }
+    }
+    if (passwordCallback != null) {
+      // NOTE(review): nameCallback is assumed to be present here; a PasswordCallback without a
+      // NameCallback results in a NullPointerException.
+      val tokenIdentifier =
+        SaslRpcServer.getIdentifier(nameCallback.getDefaultName, secretManager)
+      val password = getPassword(tokenIdentifier)
+      if (logger.isDebugEnabled) {
+        debug("SASL server DIGEST-MD5 callback: setting password for client: " +
+          tokenIdentifier.getUser)
+      }
+      passwordCallback.setPassword(password)
+    }
+  }
+
+  /** Authorizes the client only when its authentication and authorization IDs are equal. */
+  private def authorize(ac: AuthorizeCallback): Unit = {
+    val authid = ac.getAuthenticationID
+    val authzid = ac.getAuthorizationID
+    ac.setAuthorized(authid == authzid)
+    if (ac.isAuthorized) {
+      if (logger.isDebugEnabled) {
+        val username = SaslRpcServer.getIdentifier(authzid, secretManager).getUser.getUserName
+        debug(s"SASL server DIGEST-MD5 callback: setting canonicalized client ID: $username")
+      }
+      ac.setAuthorizedID(authzid)
+    }
+  }
+}
+
+/**
+ * Processor that pulls the SaslServer object out of the transport and, depending on the
+ * negotiated SASL mechanism, processes the request as (or on behalf of) the remote user.
+ *
+ * For every call it records the client's address, the SASL mechanism and the remote user
+ * name in the [[AuthBridgeServer]] thread-locals:
+ *  - PLAIN: the SASL authorization ID is recorded as the remote user and the request is
+ *    processed directly, without assuming any UGI.
+ *  - Kerberos / delegation token: the end user is the authorization ID (for tokens, the user
+ *    encoded in the token identifier). With `useProxy`, the request runs inside `doAs` of a
+ *    proxy UGI for that user on top of the login user; otherwise only the user's short name
+ *    is recorded and the request runs as the current user.
+ *
+ * This class is derived from Hive's one.
+ *
+ * @param wrapped       the processor that actually handles the request
+ * @param secretManager used to decode delegation-token identifiers
+ * @param useProxy      whether to run each request as a proxy user for the client
+ */
+sealed class TUGIAssumingProcessor(
+ val wrapped: TProcessor,
+ val secretManager: LivyDelegationTokenSecretManager,
+ var useProxy: Boolean) extends TProcessor with Logging {
+
+ /**
+  * Records the client's details and processes one request.
+  *
+  * @throws TException if the transport is not a SASL transport, the delegation token is
+  *                    invalid, or the wrapped processor fails with a TException
+  */
+ @throws[TException]
+ override def process(inProt: TProtocol, outProt: TProtocol): Boolean = {
+ val trans = inProt.getTransport
+ // The authenticated identity is only available from a SASL server transport.
+ if (!trans.isInstanceOf[TSaslServerTransport]) {
+ throw new TException(s"Unexpected non-SASL transport ${trans.getClass}")
+ }
+ val saslTrans: TSaslServerTransport = trans.asInstanceOf[TSaslServerTransport]
+ val saslServer: SaslServer = saslTrans.getSaslServer
+ val authId: String = saslServer.getAuthorizationID
+ debug(s"AUTH ID ======> $authId")
+ var endUser = authId
+ // NOTE(review): assumes the SASL transport wraps a TSocket; any other underlying transport
+ // fails here with a ClassCastException.
+ val socket = saslTrans.getUnderlyingTransport.asInstanceOf[TSocket].getSocket
+ AuthBridgeServer.remoteAddress.set(socket.getInetAddress)
+ val mechanismName: String = saslServer.getMechanismName
+ AuthBridgeServer.userAuthMechanism.set(mechanismName)
+ // PLAIN (user/password) clients: record the user and process without assuming a UGI.
+ // NOTE(review): authenticationMethod is not updated on this path, so it keeps whatever
+ // value this thread had before.
+ if (AuthMethod.PLAIN.getMechanismName.equalsIgnoreCase(mechanismName)) {
+ AuthBridgeServer.remoteUser.set(endUser)
+ return wrapped.process(inProt, outProt)
+ }
+ AuthBridgeServer.authenticationMethod.set(UserGroupInformation.AuthenticationMethod.KERBEROS)
+ // Delegation-token clients: the real user is encoded in the token identifier.
+ if (AuthMethod.TOKEN.getMechanismName.equalsIgnoreCase(mechanismName)) {
+ try {
+ val tokenId = SaslRpcServer.getIdentifier(authId, secretManager)
+ endUser = tokenId.getUser.getUserName
+ AuthBridgeServer.authenticationMethod.set(UserGroupInformation.AuthenticationMethod.TOKEN)
+ } catch {
+ case e: InvalidToken => throw new TException(e.getMessage)
+ }
+ }
+ // Only set in proxy mode; the finally block uses it to release per-UGI resources.
+ var clientUgi: UserGroupInformation = null
+ try {
+ if (useProxy) {
+ // Impersonate the end user on top of the service's login user for this request.
+ clientUgi = UserGroupInformation.createProxyUser(
+ endUser, UserGroupInformation.getLoginUser)
+ AuthBridgeServer.remoteUser.set(clientUgi.getShortUserName)
+ debug(s"Set remoteUser : ${AuthBridgeServer.remoteUser.get}")
+ clientUgi.doAs(new PrivilegedExceptionAction[Boolean]() {
+ // TException is checked, so it is tunnelled out of doAs inside a RuntimeException.
+ override def run: Boolean = try {
+ wrapped.process(inProt, outProt)
+ } catch {
+ case te: TException => throw new RuntimeException(te)
+ }
+ })
+ } else {
+ // use the short user name for the request
+ val endUserUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(endUser)
+ AuthBridgeServer.remoteUser.set(endUserUgi.getShortUserName)
+ debug(s"Set remoteUser: ${AuthBridgeServer.remoteUser.get}, from endUser :" + endUser)
+ wrapped.process(inProt, outProt)
+ }
+ } catch {
+ // Unwrap a TException tunnelled out of doAs; wrap the checked exceptions doAs declares.
+ case rte: RuntimeException if rte.getCause.isInstanceOf[TException] => throw rte.getCause
+ case rte: RuntimeException => throw rte
+ case ie: InterruptedException => throw new RuntimeException(ie) // unexpected!
+ case ioe: IOException => throw new RuntimeException(ioe)
+ } finally {
+ // Close the FileSystem instances cached for the per-request proxy UGI.
+ if (clientUgi != null) {
+ try {
+ FileSystem.closeAllForUGI(clientUgi)
+ } catch {
+ case exception: IOException =>
+ error(s"Could not clean up file-system handles for UGI: $clientUgi", exception)
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Per-thread details about the client whose request is being processed. They are written by
+ * [[TUGIAssumingProcessor]] and read through the [[AuthBridgeServer]] accessors.
+ */
+object AuthBridgeServer extends Logging {
+  /** Address of the connected client; `null` until set on this thread. */
+  private[auth] val remoteAddress: ThreadLocal[InetAddress] = new ThreadLocal[InetAddress]
+
+  /** How the client authenticated; TOKEN until set on this thread. */
+  private[auth] val authenticationMethod: ThreadLocal[AuthenticationMethod] =
+    new ThreadLocal[AuthenticationMethod] {
+      override protected def initialValue(): AuthenticationMethod = AuthenticationMethod.TOKEN
+    }
+
+  /** Name of the remote user; `null` until set on this thread. */
+  private[auth] val remoteUser: ThreadLocal[String] = new ThreadLocal[String]
+
+  /** SASL mechanism negotiated by the client; the Kerberos mechanism until set. */
+  private[auth] val userAuthMechanism: ThreadLocal[String] = new ThreadLocal[String] {
+    override protected def initialValue(): String = AuthMethod.KERBEROS.getMechanismName
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthFactory.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthFactory.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthFactory.scala
new file mode 100644
index 0000000..6ac61d2
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthFactory.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.auth
+
+import java.io.IOException
+import java.util
+import javax.security.auth.callback._
+import javax.security.auth.login.LoginException
+import javax.security.sasl.{AuthorizeCallback, Sasl}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION
+import org.apache.hadoop.security.SaslRpcServer.AuthMethod
+import org.apache.hive.service.auth.{SaslQOP, TSetIpAddressProcessor}
+import org.apache.hive.service.auth.AuthenticationProviderFactory.AuthMethods
+import org.apache.hive.service.auth.HiveAuthConstants.AuthTypes
+import org.apache.hive.service.cli.HiveSQLException
+import org.apache.hive.service.rpc.thrift.TCLIService
+import org.apache.hive.service.rpc.thrift.TCLIService.Iface
+import org.apache.thrift.{TProcessor, TProcessorFactory}
+import org.apache.thrift.transport.{TTransport, TTransportException, TTransportFactory}
+
+import org.apache.livy.{LivyConf, Logging}
+import org.apache.livy.thriftserver.cli.ThriftCLIService
+
+/**
+ * This class is a porting of the parts we use from `HiveAuthFactory` by Hive.
+ *
+ * It builds the Thrift transport and processor factories according to
+ * [[LivyConf.THRIFT_AUTHENTICATION]] and the Hadoop security setting. When Hadoop security is
+ * Kerberos and the authentication type is not NOSASL, a delegation-token secret manager is
+ * created and its threads started on construction; failing to start it raises a
+ * [[TTransportException]].
+ *
+ * @param conf Livy configuration
+ */
+class AuthFactory(val conf: LivyConf) extends Logging {
+
+  // NOTE: field order matters; isSASLWithKerberizedHadoop (used below) reads these two.
+  private val authTypeStr = conf.get(LivyConf.THRIFT_AUTHENTICATION)
+  // ShimLoader.getHadoopShims().isSecurityEnabled() will only check that
+  // hadoopAuth is not simple, it does not guarantee it is kerberos
+  private val hadoopAuth = new Configuration().get(HADOOP_SECURITY_AUTHENTICATION)
+
+  private val secretManager: Option[LivyDelegationTokenSecretManager] =
+    if (isSASLWithKerberizedHadoop) {
+      val sm = new LivyDelegationTokenSecretManager(conf)
+      try {
+        sm.startThreads()
+      } catch {
+        case e: IOException =>
+          throw new TTransportException("Failed to start token manager", e)
+      }
+      Some(sm)
+    } else {
+      None
+    }
+
+  private val saslServer: Option[AuthBridgeServer] = secretManager.map(new AuthBridgeServer(_))
+
+  /**
+   * SASL properties: the QOP configured in [[LivyConf.THRIFT_SASL_QOP]] and a request that the
+   * server authenticate itself to the client.
+   */
+  def getSaslProperties: util.Map[String, String] = {
+    val saslProps = new util.HashMap[String, String]
+    val saslQOP = SaslQOP.fromString(conf.get(LivyConf.THRIFT_SASL_QOP))
+    saslProps.put(Sasl.QOP, saslQOP.toString)
+    saslProps.put(Sasl.SERVER_AUTH, "true")
+    saslProps
+  }
+
+  /**
+   * Returns the transport factory for the configured authentication type.
+   *
+   * With a Kerberized Hadoop, the SASL factory accepts Kerberos and, for NONE/CUSTOM, also
+   * PLAIN. Otherwise NONE/CUSTOM use a PLAIN-only factory and NOSASL a raw factory.
+   *
+   * @throws LoginException if the authentication type is unsupported or the SASL factory
+   *                        cannot be created
+   */
+  @throws[LoginException]
+  def getAuthTransFactory: TTransportFactory = {
+    val isAuthKerberos = authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName)
+    val isAuthNoSASL = authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName)
+    // TODO: add LDAP and PAM when supported
+    val isAuthOther = authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName) ||
+      authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName)
+
+    saslServer match {
+      case Some(server) =>
+        val serverTransportFactory = try {
+          server.createSaslServerTransportFactory(getSaslProperties)
+        } catch {
+          case e: TTransportException =>
+            throw new LoginException(e.getMessage)
+        }
+        if (isAuthOther) {
+          PlainSaslServer.addPlainServerDefinition(serverTransportFactory, authTypeStr, conf)
+        } else if (!isAuthKerberos) {
+          throw new LoginException(s"Unsupported authentication type $authTypeStr")
+        }
+        server.wrapTransportFactory(serverTransportFactory)
+      case None =>
+        if (isAuthOther) {
+          PlainSaslServer.getPlainTransportFactory(authTypeStr, conf)
+        } else if (isAuthNoSASL) {
+          new TTransportFactory
+        } else {
+          throw new LoginException(s"Unsupported authentication type $authTypeStr")
+        }
+    }
+  }
+
+  /**
+   * Returns the thrift processor factory for binary mode
+   */
+  @throws[LoginException]
+  def getAuthProcFactory(service: ThriftCLIService): TProcessorFactory = saslServer match {
+    case Some(server) => new CLIServiceProcessorFactory(service, server)
+    case None => new SQLPlainProcessorFactory(service)
+  }
+
+  /** Remote user recorded for the current thread, or `null` without a SASL server. */
+  def getRemoteUser: String = saslServer.map(_.getRemoteUser).orNull
+
+  /** Remote client IP for the current thread, or `null` if unknown. */
+  def getIpAddress: String =
+    saslServer.flatMap(s => Option(s.getRemoteAddress)).map(_.getHostAddress).orNull
+
+  /** SASL mechanism recorded for the current thread, or `null` without a SASL server. */
+  def getUserAuthMechanism: String = saslServer.map(_.getUserAuthMechanism).orNull
+
+  /** True when Hadoop security is Kerberos and the authentication type is not NOSASL. */
+  def isSASLWithKerberizedHadoop: Boolean = {
+    "kerberos".equalsIgnoreCase(hadoopAuth) &&
+      !authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName)
+  }
+
+  /** True when the current thread's client used the Kerberos or delegation-token mechanism. */
+  def isSASLKerberosUser: Boolean = {
+    AuthMethod.KERBEROS.getMechanismName == getUserAuthMechanism ||
+      AuthMethod.TOKEN.getMechanismName == getUserAuthMechanism
+  }
+
+  /**
+   * Verifies `delegationToken` and returns the result of the secret manager's verification.
+   *
+   * @throws HiveSQLException (SQL state 08S01) if delegation tokens are not enabled or the
+   *                          token cannot be verified
+   */
+  @throws[HiveSQLException]
+  def verifyDelegationToken(delegationToken: String): String = secretManager match {
+    case None =>
+      throw new HiveSQLException(
+        "Delegation token only supported over kerberos authentication", "08S01")
+    case Some(sm) =>
+      try {
+        sm.verifyDelegationToken(delegationToken)
+      } catch {
+        case e: IOException =>
+          // The token string can be used to authenticate as its owner, so it is kept out of
+          // both the server log and the error returned to the client.
+          val msg = "Error verifying delegation token"
+          error(msg, e)
+          throw new HiveSQLException(msg, "08S01", e)
+      }
+  }
+}
+
+/**
+ * Processor factory used when no SASL server is configured: every connection gets a Hive
+ * `TSetIpAddressProcessor` over `service`. The base class receives `null` because
+ * `getProcessor` is overridden and never consults it.
+ */
+class SQLPlainProcessorFactory(val service: Iface) extends TProcessorFactory(null) {
+  override def getProcessor(trans: TTransport): TProcessor =
+    new TSetIpAddressProcessor[Iface](service)
+}
+
+/**
+ * Processor factory used with a SASL server: each connection gets a `TCLIService.Processor`
+ * wrapped by `saslServer` so that client details are recorded without a proxy user. The base
+ * class receives `null` because `getProcessor` is overridden and never consults it.
+ */
+class CLIServiceProcessorFactory(val service: Iface, val saslServer: AuthBridgeServer)
+  extends TProcessorFactory(null) {
+
+  override def getProcessor(trans: TTransport): TProcessor =
+    saslServer.wrapNonAssumingProcessor(new TCLIService.Processor[Iface](service))
+}
+
+/**
+ * SASL PLAIN callback handler: it collects the user name and password, passes them to the
+ * [[PasswdAuthenticationProvider]] for the configured method and, if that call returns
+ * normally, authorizes the client.
+ *
+ * This is copied from Hive because its constructor is not accessible.
+ *
+ * @param authMethodStr authentication method name, resolved with Hive's
+ *                      `AuthMethods.getValidAuthMethod` on construction
+ * @param livyConf      configuration handed to the authentication provider
+ */
+class PlainServerCallbackHandler(authMethodStr: String, livyConf: LivyConf)
+  extends CallbackHandler {
+
+  private val authMethod: AuthMethods = AuthMethods.getValidAuthMethod(authMethodStr)
+
+  override def handle(callbacks: Array[Callback]): Unit = {
+    var user: String = null
+    var pass: String = null
+    var authorizeCallback: AuthorizeCallback = null
+
+    for (callback <- callbacks) {
+      callback match {
+        case nc: NameCallback => user = nc.getName
+        case pc: PasswordCallback => pass = new String(pc.getPassword)
+        case ac: AuthorizeCallback => authorizeCallback = ac
+        case unknown => throw new UnsupportedCallbackException(unknown)
+      }
+    }
+
+    // Providers are expected to reject bad credentials by throwing.
+    AuthenticationProvider
+      .getAuthenticationProvider(authMethod.getAuthMethod, livyConf)
+      .Authenticate(user, pass)
+
+    // Only reached when the provider accepted the credentials.
+    Option(authorizeCallback).foreach(_.setAuthorized(true))
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthenticationProvider.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthenticationProvider.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthenticationProvider.scala
new file mode 100644
index 0000000..9464af5
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/AuthenticationProvider.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.auth
+
+import java.lang.reflect.InvocationTargetException
+import javax.security.sasl.AuthenticationException
+
+import org.apache.hive.service.auth.PasswdAuthenticationProvider
+
+import org.apache.livy.LivyConf
+
/**
 * Factory for the password-based authentication providers supported by the Thrift server.
 */
object AuthenticationProvider {
  // TODO: support LDAP and PAM
  /** Authentication method names accepted by `getAuthenticationProvider`. */
  val AUTH_METHODS: Seq[String] = Seq("NONE", "CUSTOM")

  /**
   * Returns the provider implementing the given authentication method.
   *
   * @param method the authentication method name; one of [[AUTH_METHODS]] (case-sensitive)
   * @param conf the Livy configuration, used by the CUSTOM provider to locate its class
   * @throws AuthenticationException if `method` is not supported
   */
  @throws[AuthenticationException]
  def getAuthenticationProvider(method: String, conf: LivyConf): PasswdAuthenticationProvider = {
    method match {
      case "NONE" => new NoneAuthenticationProvider
      case "CUSTOM" => new CustomAuthenticationProvider(conf)
      case _ =>
        // Name the rejected method and the accepted ones so misconfigurations are easy to spot.
        throw new AuthenticationException(
          s"Unsupported authentication method: $method. Supported methods are: " +
            AUTH_METHODS.mkString(", "))
    }
  }
}
+
/**
 * A no-op [[PasswdAuthenticationProvider]]: `Authenticate` ignores its arguments and never
 * throws, so no credentials are ever rejected by it.
 */
class NoneAuthenticationProvider extends PasswdAuthenticationProvider {
  override def Authenticate(user: String, password: String): Unit = ()
}
+
/**
 * An implementation of [[PasswdAuthenticationProvider]] delegating the class configured in
 * [[LivyConf.THRIFT_CUSTOM_AUTHENTICATION_CLASS]] to authenticate a user.
 *
 * The configured class must implement [[PasswdAuthenticationProvider]]. It must expose either
 * a public constructor taking a [[LivyConf]], which is preferred, or a public no-argument
 * constructor.
 */
class CustomAuthenticationProvider(conf: LivyConf) extends PasswdAuthenticationProvider {
  private val customClass: Class[_ <: PasswdAuthenticationProvider] = {
    Class.forName(conf.get(LivyConf.THRIFT_CUSTOM_AUTHENTICATION_CLASS))
      .asSubclass(classOf[PasswdAuthenticationProvider])
  }

  val provider: PasswdAuthenticationProvider = {
    // Use the constructor taking the LivyConf if the class declares one, otherwise the
    // no-argument constructor. Only the *absence* of the LivyConf constructor triggers the
    // fallback. If that constructor exists but fails, the error is propagated rather than hidden
    // behind a provider silently built without the configuration.
    val acceptsConf =
      try {
        customClass.getConstructor(classOf[LivyConf])
        true
      } catch {
        case _: NoSuchMethodException => false
      }
    if (acceptsConf) {
      customClass.getConstructor(classOf[LivyConf]).newInstance(conf)
    } else {
      customClass.getConstructor().newInstance()
    }
  }

  override def Authenticate(user: String, password: String): Unit = {
    provider.Authenticate(user, password)
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/LivyDelegationTokenSecretManager.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/LivyDelegationTokenSecretManager.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/LivyDelegationTokenSecretManager.scala
new file mode 100644
index 0000000..e34306e
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/LivyDelegationTokenSecretManager.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.auth
+
+import java.io.{ByteArrayInputStream, DataInputStream, IOException}
+
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.token.Token
+import org.apache.hadoop.security.token.delegation.{AbstractDelegationTokenIdentifier, AbstractDelegationTokenSecretManager}
+
+import org.apache.livy.LivyConf
+
/**
 * Secret manager for Livy delegation tokens, taken from the analogous implementation in the
 * MapReduce client. The four intervals/lifetimes handed to the parent class are read from the
 * given [[LivyConf]] (key update, token max lifetime, token renew, and GC intervals).
 */
class LivyDelegationTokenSecretManager(val livyConf: LivyConf)
  extends AbstractDelegationTokenSecretManager[LivyDelegationTokenIdentifier](
    livyConf.getTimeAsMs(LivyConf.THRIFT_DELEGATION_KEY_UPDATE_INTERVAL),
    livyConf.getTimeAsMs(LivyConf.THRIFT_DELEGATION_TOKEN_MAX_LIFETIME),
    livyConf.getTimeAsMs(LivyConf.THRIFT_DELEGATION_TOKEN_RENEW_INTERVAL),
    livyConf.getTimeAsMs(LivyConf.THRIFT_DELEGATION_TOKEN_GC_INTERVAL)) {

  override def createIdentifier: LivyDelegationTokenIdentifier = new LivyDelegationTokenIdentifier

  /**
   * Verifies a token given in URL-safe string form.
   *
   * @param tokenStrForm the token, encoded as accepted by `Token.decodeFromUrlString`
   * @return the short user name of the token's owner
   * @throws IOException if the token cannot be decoded or verified
   */
  @throws[IOException]
  def verifyDelegationToken(tokenStrForm: String): String = {
    val token = new Token[LivyDelegationTokenIdentifier]
    token.decodeFromUrlString(tokenStrForm)
    val identifier = getTokenIdentifier(token)
    verifyToken(identifier, token.getPassword)
    identifier.getUser.getShortUserName
  }

  /**
   * Rebuilds the identifier from the token's serialized identifier bytes, so that it can be used
   * for the cache lookup.
   */
  @throws[IOException]
  protected def getTokenIdentifier(
      token: Token[LivyDelegationTokenIdentifier]): LivyDelegationTokenIdentifier = {
    val identifier = createIdentifier
    identifier.readFields(new DataInputStream(new ByteArrayInputStream(token.getIdentifier)))
    identifier
  }
}
+
/**
 * Identifier of a Livy delegation token. Its kind is always
 * `LivyDelegationTokenIdentifier.LIVY_DELEGATION_KIND`.
 *
 * @param owner the effective username of the token owner
 * @param renewer the username of the renewer
 * @param realUser the real username of the token owner
 */
class LivyDelegationTokenIdentifier(owner: Text, renewer: Text, realUser: Text)
  extends AbstractDelegationTokenIdentifier(owner, renewer, realUser) {

  /** Creates an identifier with empty owner/renewer/real user, to be filled via `readFields`. */
  def this() = this(new Text, new Text, new Text)

  override def getKind: Text = LivyDelegationTokenIdentifier.LIVY_DELEGATION_KIND
}

object LivyDelegationTokenIdentifier {
  /** Token kind shared by every Livy delegation token. */
  val LIVY_DELEGATION_KIND: Text = new Text("LIVY_DELEGATION_TOKEN")
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/PlainSaslServer.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/PlainSaslServer.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/PlainSaslServer.scala
new file mode 100644
index 0000000..ea8bd51
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/auth/PlainSaslServer.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.livy.thriftserver.auth
+
+
+import java.io.IOException
+import java.security.{Provider, Security}
+import java.util
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.auth.login.LoginException
+import javax.security.sasl._
+
+import org.apache.hive.service.auth.AuthenticationProviderFactory.AuthMethods
+import org.apache.thrift.transport.TSaslServerTransport
+
+import org.apache.livy.LivyConf
+
+
/**
 * Sun JDK only provides a PLAIN client and no server. This class implements the Plain SASL server
 * conforming to RFC #4616 (http://www.ietf.org/rfc/rfc4616.txt).
 *
 * @param handler callback handler which authenticates and authorizes the received credentials
 * @param authMethodStr name of the authentication method; it is validated on construction
 */
class PlainSaslServer private[auth] (
    val handler: CallbackHandler,
    val authMethodStr: String) extends SaslServer {

  // Validate the authentication method eagerly, so that an invalid one fails construction.
  AuthMethods.getValidAuthMethod(authMethodStr)

  // The authenticated user. It is only set once a response has been successfully authorized, so
  // a failed exchange is never reported as complete by isComplete / getAuthorizationID.
  private var user: String = null

  override def getMechanismName: String = PlainSaslServer.PLAIN_METHOD

  /**
   * Processes the PLAIN message `[authzid] NUL authcid NUL passwd` sent by the client. Every
   * field is decoded as UTF-8, as RFC 4616 mandates, and the credentials are checked through
   * `handler`.
   *
   * @param response the raw client message
   * @return always null: PLAIN has no challenge to send back to the client
   * @throws SaslException if the message is malformed or the credentials are not accepted
   */
  @throws[SaslException]
  override def evaluateResponse(response: Array[Byte]): Array[Byte] = {
    try {
      if (response == null) {
        throw new SaslException("Invalid message format")
      }
      val tokenList = splitOnNul(response)
      // validate response
      if (tokenList.size < 2 || tokenList.size > 3) {
        throw new SaslException("Invalid message format")
      }
      val passwd = tokenList.removeLast()
      val authcId = tokenList.removeLast()
      // optional authzid: when absent, the authentication identity is used
      val authzId = if (tokenList.isEmpty) authcId else tokenList.removeLast()
      if (authcId.isEmpty) {
        throw new SaslException("No user name provided")
      }
      if (passwd.isEmpty) {
        throw new SaslException("No password provided")
      }
      val nameCallback = new NameCallback("User")
      nameCallback.setName(authcId)
      val pcCallback = new PasswordCallback("Password", false)
      pcCallback.setPassword(passwd.toCharArray)
      val acCallback = new AuthorizeCallback(authcId, authzId)
      handler.handle(Array[Callback](nameCallback, pcCallback, acCallback))
      if (!acCallback.isAuthorized) {
        throw new SaslException("Authentication failed")
      }
      user = authcId
    } catch {
      // NOTE: SaslException extends IOException, so the SaslExceptions raised above (and the
      // AuthenticationExceptions raised by the handler) end up wrapped by the IOException case.
      case eL: IllegalStateException =>
        throw new SaslException("Invalid message format", eL)
      case eI: IOException =>
        throw new SaslException("Error validating the login", eI)
      case eU: UnsupportedCallbackException =>
        throw new SaslException("Error validating the login", eU)
    }
    null
  }

  /**
   * Splits `message` on NUL bytes, decoding every field as UTF-8. Empty fields, including a
   * trailing one, are kept.
   */
  private def splitOnNul(message: Array[Byte]): util.Deque[String] = {
    import java.nio.charset.StandardCharsets.UTF_8
    val tokens = new util.ArrayDeque[String]
    var start = 0
    for (i <- message.indices if message(i) == 0) {
      tokens.addLast(new String(message, start, i - start, UTF_8))
      start = i + 1
    }
    tokens.addLast(new String(message, start, message.length - start, UTF_8))
    tokens
  }

  override def isComplete: Boolean = user != null

  override def getAuthorizationID: String = user

  override def unwrap(incoming: Array[Byte], offset: Int, len: Int): Array[Byte] = {
    throw new UnsupportedOperationException
  }

  override def wrap(outgoing: Array[Byte], offset: Int, len: Int): Array[Byte] = {
    throw new UnsupportedOperationException
  }

  override def getNegotiatedProperty(propName: String): Object = null

  override def dispose(): Unit = {}
}
+
/**
 * Helpers to build Thrift SASL server transports that authenticate with the PLAIN mechanism.
 */
object PlainSaslServer {
  val PLAIN_METHOD = "PLAIN"

  // Register the provider mapping the PLAIN SaslServerFactory to SaslPlainServerFactory, so that
  // the JDK SASL framework can create PLAIN servers.
  Security.addProvider(new SaslPlainProvider)

  /**
   * Creates a SASL server transport factory accepting PLAIN authentication.
   *
   * @param authTypeStr the authentication method used to check the credentials
   * @param conf Livy configuration handed to the authentication provider
   * @throws LoginException if the callback handler cannot be created
   */
  def getPlainTransportFactory(
      authTypeStr: String,
      conf: LivyConf): TSaslServerTransport.Factory = {
    val saslFactory = new TSaslServerTransport.Factory()
    addPlainServerDefinition(saslFactory, authTypeStr, conf)
    saslFactory
  }

  /**
   * Registers a PLAIN server definition on `saslFactory`.
   *
   * @param saslFactory the factory to extend
   * @param authTypeStr the authentication method used to check the credentials
   * @param conf Livy configuration handed to the authentication provider
   * @throws LoginException if the callback handler cannot be created; the original
   *                        AuthenticationException is attached as its cause
   */
  def addPlainServerDefinition(
      saslFactory: TSaslServerTransport.Factory,
      authTypeStr: String,
      conf: LivyConf): Unit = {
    try {
      // authTypeStr is passed in the "protocol" slot: SaslPlainServerFactory forwards the
      // protocol to PlainSaslServer as the authentication method name.
      saslFactory.addServerDefinition(PLAIN_METHOD,
        authTypeStr,
        null,
        new util.HashMap[String, String](),
        new PlainServerCallbackHandler(authTypeStr, conf))
    } catch {
      case e: AuthenticationException =>
        // LoginException has no cause-taking constructor; attach the cause explicitly so the
        // original stack trace is not lost.
        val loginException = new LoginException(s"Error setting callback handler $e")
        loginException.initCause(e)
        throw loginException
    }
  }
}
+
/**
 * [[SaslServerFactory]] creating [[PlainSaslServer]]s for the PLAIN mechanism.
 *
 * The `protocol` argument is forwarded to [[PlainSaslServer]] as its authentication method name.
 */
class SaslPlainServerFactory extends SaslServerFactory {
  override def createSaslServer(
      mechanism: String,
      protocol: String,
      serverName: String,
      props: util.Map[String, _],
      cbh: CallbackHandler): PlainSaslServer = {
    if (mechanism != PlainSaslServer.PLAIN_METHOD) {
      null
    } else {
      try {
        new PlainSaslServer(cbh, protocol)
      } catch {
        // The interface contract requires throwing only when a SaslServer cannot be created
        // because of an error. When the supplied parameters cannot be served, it requires
        // returning null. The only thing PlainSaslServer can fail on is an unsupported
        // authentication method, which is a parameter problem, so null is returned.
        case _: SaslException => null
      }
    }
  }

  override def getMechanismNames(props: util.Map[String, _]): Array[String] =
    Array(PlainSaslServer.PLAIN_METHOD)
}
+
/**
 * Security provider registering [[SaslPlainServerFactory]] as the SASL server factory for the
 * PLAIN mechanism.
 */
class SaslPlainProvider
  extends Provider("LivySaslPlain", 1.0, "Livy Plain SASL provider") {

  put("SaslServerFactory.PLAIN", classOf[SaslPlainServerFactory].getName)
}
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/47d3ee6b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThreadPoolExecutorWithOomHook.scala
----------------------------------------------------------------------
diff --git a/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThreadPoolExecutorWithOomHook.scala b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThreadPoolExecutorWithOomHook.scala
new file mode 100644
index 0000000..0a06baa
--- /dev/null
+++ b/thriftserver/server/src/main/scala/org/apache/livy/thriftserver/cli/ThreadPoolExecutorWithOomHook.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.livy.thriftserver.cli
+
+import java.util.concurrent._
+
/**
 * A [[ThreadPoolExecutor]] that runs `oomHook` whenever a task fails with an
 * [[OutOfMemoryError]].
 *
 * This class is taken from Hive, because it is package private so it cannot be accessed.
 * If it will become public we can remove this from here.
 *
 * Tasks handed over through `submit` are wrapped in a [[Future]] that captures their failure.
 * For them `afterExecute` receives a null throwable, so the failure is read back from the
 * completed future, unwrapping the [[ExecutionException]] it is reported in.
 */
class ThreadPoolExecutorWithOomHook(
    corePoolSize: Int,
    maximumPoolSize: Int,
    keepAliveTime: Long,
    unit: TimeUnit,
    workQueue: BlockingQueue[Runnable],
    threadFactory: ThreadFactory,
    val oomHook: Runnable)
  extends ThreadPoolExecutor(
    corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory) {

  override protected def afterExecute(r: Runnable, t: Throwable): Unit = {
    super.afterExecute(r, t)
    val failure = Option(t).orElse(completedFutureFailure(r))
    if (failure.exists(_.isInstanceOf[OutOfMemoryError])) {
      oomHook.run()
    }
  }

  /** Returns the failure recorded by `r` if it is a completed [[Future]] that failed. */
  private def completedFutureFailure(r: Runnable): Option[Throwable] = {
    import scala.util.control.NonFatal
    r match {
      case future: Future[_] if future.isDone =>
        try {
          future.get()
          None
        } catch {
          // The task itself failed: the relevant throwable is the wrapped cause.
          case e: ExecutionException => Option(e.getCause)
          case _: CancellationException => None
          case _: InterruptedException =>
            Thread.currentThread.interrupt()
            None
          case oom: OutOfMemoryError => Some(oom)
          // Best effort: any other problem while inspecting the future is ignored.
          case NonFatal(_) => None
        }
      case _ => None
    }
  }
}