You are viewing a plain text version of this content; the canonical (linked) version is on the original mailing-list archive page.
Posted to commits@hbase.apache.org by te...@apache.org on 2016/03/10 15:44:35 UTC

hbase git commit: HBASE-15336 Support Dataframe writer to the spark connector (Zhan Zhang)

Repository: hbase
Updated Branches:
  refs/heads/master d14b6c381 -> f6945c463


HBASE-15336 Support Dataframe writer to the spark connector (Zhan Zhang)


Project: http://git-wip-us.apache.org/repos/asf/hbase/repo
Commit: http://git-wip-us.apache.org/repos/asf/hbase/commit/f6945c46
Tree: http://git-wip-us.apache.org/repos/asf/hbase/tree/f6945c46
Diff: http://git-wip-us.apache.org/repos/asf/hbase/diff/f6945c46

Branch: refs/heads/master
Commit: f6945c4631e7697976fd8c2272f8152905c6f875
Parents: d14b6c3
Author: tedyu <yu...@gmail.com>
Authored: Thu Mar 10 06:44:29 2016 -0800
Committer: tedyu <yu...@gmail.com>
Committed: Thu Mar 10 06:44:29 2016 -0800

----------------------------------------------------------------------
 .../hadoop/hbase/spark/DefaultSource.scala      | 121 +++++++++++++++++--
 .../hadoop/hbase/spark/datasources/Utils.scala  |  44 +++++++
 .../datasources/hbase/HBaseTableCatalog.scala   |  12 +-
 .../hadoop/hbase/spark/DefaultSourceSuite.scala |  77 ++++++++++++
 4 files changed, 242 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hbase/blob/f6945c46/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
index 97a8e9e..6a6bc1a 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -21,17 +21,20 @@ import java.util
 import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.hadoop.hbase.mapred.TableOutputFormat
+import org.apache.hadoop.hbase.spark.datasources.Utils
 import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
 import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
 import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
 import org.apache.hadoop.hbase.types._
 import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
-import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
+import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, HBaseConfiguration, TableName}
+import org.apache.hadoop.mapred.JobConf
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog}
-import org.apache.spark.sql.types.{DataType => SparkDataType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, SaveMode, Row, SQLContext}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 
@@ -48,10 +51,11 @@ import scala.collection.mutable
  * - Type conversions of basic SQL types.  All conversions will be
  *   Through the HBase Bytes object commands.
  */
-class DefaultSource extends RelationProvider with Logging {
+class DefaultSource extends RelationProvider  with CreatableRelationProvider with Logging {
   /**
    * Is given input from SparkSQL to construct a BaseRelation
-   * @param sqlContext SparkSQL context
+    *
+    * @param sqlContext SparkSQL context
    * @param parameters Parameters given to us from SparkSQL
    * @return           A BaseRelation Object
    */
@@ -60,18 +64,31 @@ class DefaultSource extends RelationProvider with Logging {
   BaseRelation = {
     new HBaseRelation(parameters, None)(sqlContext)
   }
+
+  /** Writer entry point: runs createTable() (creates the table only when `newtable` > 3), then insert(). */
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode, // NOTE(review): not consulted -- every SaveMode takes the same create-then-write path
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext)
+    relation.createTable()
+    relation.insert(data, false)
+    relation
+  }
 }
 
 /**
  * Implementation of Spark BaseRelation that will build up our scan logic
  * , do the scan pruning, filter push down, and value conversions
- * @param sqlContext              SparkSQL context
+  *
+  * @param sqlContext              SparkSQL context
  */
 case class HBaseRelation (
     @transient parameters: Map[String, String],
     userSpecifiedSchema: Option[StructType]
   )(@transient val sqlContext: SQLContext)
-  extends BaseRelation with PrunedFilteredScan with Logging {
+  extends BaseRelation with PrunedFilteredScan  with InsertableRelation  with Logging {
   val catalog = HBaseTableCatalog(parameters)
   def tableName = catalog.name
   val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "")
@@ -116,6 +133,90 @@ case class HBaseRelation (
    */
   override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
 
+  /** Creates the catalog's table, pre-split with Bytes.split(regionStart, regionEnd, newtable),
+    * when the `newtable` option is larger than 3 and the table does not exist; otherwise only logs. */
+  def createTable(): Unit = {
+    val numReg = parameters.get(HBaseTableCatalog.newTable).map(x => x.toInt).getOrElse(0)
+    val startKey = Bytes.toBytes(
+      parameters.get(HBaseTableCatalog.regionStart)
+        .getOrElse(HBaseTableCatalog.defaultRegionStart))
+    val endKey = Bytes.toBytes(
+      parameters.get(HBaseTableCatalog.regionEnd)
+        .getOrElse(HBaseTableCatalog.defaultRegionEnd))
+    if (numReg > 3) {
+      val tName = TableName.valueOf(catalog.name)
+      val cfs = catalog.getColumnFamilies
+      val connection = ConnectionFactory.createConnection(hbaseConf)
+      try {
+        val admin = connection.getAdmin()
+        try {
+          // tableExists, not isTableAvailable: an existing-but-offline table must not be re-created
+          if (!admin.tableExists(tName)) {
+            val tableDesc = new HTableDescriptor(tName)
+            cfs.foreach { x =>
+              logDebug(s"add family $x to ${catalog.name}")
+              tableDesc.addFamily(new HColumnDescriptor(Bytes.toBytes(x))) // UTF-8, not platform charset
+            }
+            admin.createTable(tableDesc, Bytes.split(startKey, endKey, numReg))
+          }
+        } finally {
+          admin.close()
+        }
+      } finally {
+        connection.close() // also released when getAdmin() throws
+      }
+    } else {
+      logInfo(
+        s"""${HBaseTableCatalog.newTable}
+           |is not defined or no larger than 3, skip the create table""".stripMargin)
+    }
+  }
+
+  /**
+    * Writes every row of `data` to table `catalog.name` via saveAsHadoopDataset and the mapred
+    * TableOutputFormat: one Put per row, keyed by the concatenated row-key columns.
+    * @param data      rows laid out according to this relation's `schema`; a value that
+    *                  `Utils.toBytes` cannot encode (including null) makes the write throw
+    * @param overwrite not consulted; existing rows are never deleted first, Puts are written on top
+    */
+  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+    val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass)
+    jobConfig.setOutputFormat(classOf[TableOutputFormat])
+    jobConfig.set(TableOutputFormat.OUTPUT_TABLE, catalog.name)
+    val rkFields = catalog.getRowKey
+    // Hoisted out of the filter below so the row-key names are computed once, not per column.
+    val rkColNames = rkFields.map(_.colName).toSet
+    val rkIdxedFields = rkFields.map { x =>
+      (schema.fieldIndex(x.colName), x)
+    }
+    // Every non-row-key column, paired with its index in `schema`.
+    val colsIdxedFields = schema
+      .fieldNames
+      .filterNot(rkColNames.contains)
+      .map(x => (schema.fieldIndex(x), catalog.getField(x)))
+    // Encodes one Row as the (key, Put) pair consumed by saveAsHadoopDataset.
+    def convertToPut(row: Row) = {
+      // construct bytes for row key
+      val rowBytes = rkIdxedFields.map { case (x, y) =>
+        Utils.toBytes(row(x), y)
+      }
+      val rBytes = new Array[Byte](rowBytes.map(_.length).sum)
+      var offset = 0
+      rowBytes.foreach { x =>
+        System.arraycopy(x, 0, rBytes, offset, x.length)
+        offset += x.length
+      }
+      val put = new Put(rBytes)
+
+      colsIdxedFields.foreach { case (x, y) =>
+        val b = Utils.toBytes(row(x), y)
+        put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b)
+      }
+      (new ImmutableBytesWritable, put)
+    }
+    data.rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
+  }
+
   /**
    * Here we are building the functionality to populate the resulting RDD[Row]
    * Here is where we will do the following:
@@ -356,7 +457,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
   /**
    * Function to merge another scan object through a AND operation
-   * @param other Other scan object
+    *
+    * @param other Other scan object
    */
   def mergeIntersect(other:ScanRange): Unit = {
     val upperBoundCompare = compareRange(upperBound, other.upperBound)
@@ -376,7 +478,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
   /**
    * Function to merge another scan object through a OR operation
-   * @param other Other scan object
+    *
+    * @param other Other scan object
    */
   def mergeUnion(other:ScanRange): Unit = {
 

http://git-wip-us.apache.org/repos/asf/hbase/blob/f6945c46/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
----------------------------------------------------------------------
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
new file mode 100644
index 0000000..090e81a
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
@@ -0,0 +1,44 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.datasources.hbase.Field
+import org.apache.spark.unsafe.types.UTF8String
+
+object Utils {
+  // Encodes one Spark SQL value as HBase bytes; `field` only feeds error messages. TODO: more types.
+  def toBytes(input: Any, field: Field): Array[Byte] = {
+    input match {
+      case null => throw new IllegalArgumentException(s"null value for column ${field.colName}")
+      case data: Boolean => Bytes.toBytes(data)
+      case data: Byte => Array(data)
+      case data: Array[Byte] => data
+      case data: Double => Bytes.toBytes(data)
+      case data: Float => Bytes.toBytes(data)
+      case data: Int => Bytes.toBytes(data)
+      case data: Long => Bytes.toBytes(data)
+      case data: Short => Bytes.toBytes(data)
+      case data: UTF8String => data.getBytes
+      case data: String => Bytes.toBytes(data)
+      // type patterns never match null, so null is handled explicitly above
+      case _ => throw new IllegalArgumentException(s"unsupported data type ${field.dt}")
+    }
+  }

http://git-wip-us.apache.org/repos/asf/hbase/blob/f6945c46/hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
----------------------------------------------------------------------
diff --git a/hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala b/hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
index 103fb90..45fa60f 100644
--- a/hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
+++ b/hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
@@ -121,7 +121,7 @@ case class HBaseTableCatalog(
      name: String,
      row: RowKey,
      sMap: SchemaMap,
-     numReg: Int) extends Logging {
+     @transient params: Map[String, String]) extends Logging {
   def toDataType = StructType(sMap.toFields)
   def getField(name: String) = sMap.getField(name)
   def getRowKey: Seq[Field] = row.fields
@@ -130,6 +130,8 @@ case class HBaseTableCatalog(
     sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey)
   }
 
+  def get(key: String) = params.get(key)
+
   // Setup the start and length for each dimension of row key at runtime.
   def dynSetupRowKey(rowKey: HBaseType) {
     logDebug(s"length: ${rowKey.length}")
@@ -179,8 +181,13 @@ case class HBaseTableCatalog(
 }
 
 object HBaseTableCatalog {
+  // If defined and larger than 3, a new table will be created with the number of regions specified.
   val newTable = "newtable"
   // The json string specifying hbase catalog information
+  val regionStart = "regionStart"
+  val defaultRegionStart = "aaaaaaa"
+  val regionEnd = "regionEnd"
+  val defaultRegionEnd = "zzzzzzz"
   val tableCatalog = "catalog"
   // The row key with format key1:key2 specifying table row key
   val rowKey = "rowkey"
@@ -232,9 +239,8 @@ object HBaseTableCatalog {
         sAvro, sd, len)
       schemaMap.+=((name, f))
     }
-    val numReg = parameters.get(newTable).map(x => x.toInt).getOrElse(0)
     val rKey = RowKey(map.get(rowKey).get.asInstanceOf[String])
-    HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), numReg)
+    HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), parameters)
   }
 
   val TABLE_KEY: String = "hbase.table"

http://git-wip-us.apache.org/repos/asf/hbase/blob/f6945c46/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
----------------------------------------------------------------------
diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
index 2987ec6..a2aa3c6 100644
--- a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
+++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -26,6 +26,26 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.{SparkConf, SparkContext, Logging}
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
+/** Row type written and read back by the DataFrame writer tests below. */
+case class HBaseRecord(
+    col0: String, col1: String,
+    col2: Double, col3: Float,
+    col4: Int, col5: Long)
+
+object HBaseRecord {
+  /**
+    * Builds record `i`: col0 and col1 both hold "row" followed by `i` zero-padded to three
+    * digits, and col2..col5 hold `i` converted to each numeric column type.
+    * The second argument is not used.
+    */
+  def apply(i: Int, t: String): HBaseRecord = {
+    val key = "row" + "%03d".format(i)
+    HBaseRecord(
+      key, key,
+      i.toDouble, i.toFloat, i, i.toLong)
+  }
+}
+
 class DefaultSourceSuite extends FunSuite with
 BeforeAndAfterEach with BeforeAndAfterAll with Logging {
   @transient var sc: SparkContext = null
@@ -63,6 +83,7 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     sparkConf.set(HBaseSparkConf.BLOCK_CACHE_ENABLE, "true")
     sparkConf.set(HBaseSparkConf.BATCH_NUM, "100")
     sparkConf.set(HBaseSparkConf.CACHE_SIZE, "100")
+
     sc  = new SparkContext("local", "test", sparkConf)
 
     val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration)
@@ -759,4 +780,60 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
 
     assert(executionRules.dynamicLogicExpression == null)
   }
+  // Catalog for the write/read tests: col0 is the row key "key"; col1..col5 live in cf1..cf5.
+  def writeCatalog: String = """{
+                    |"table":{"namespace":"default", "name":"table1"},
+                    |"rowkey":"key",
+                    |"columns":{
+                    |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+                    |"col1":{"cf":"cf1", "col":"col1", "type":"string"},
+                    |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
+                    |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+                    |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
+                    |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}
+                    |}
+                    |}""".stripMargin
+
+  /** Loads a DataFrame through this connector using the given JSON catalog. */
+  def withCatalog(cat: String): DataFrame = {
+    val reader = sqlContext.read
+      .format("org.apache.hadoop.hbase.spark")
+      .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+    reader.load()
+  }
+
+  test("populate table") {
+    val sql = sqlContext
+    import sql.implicits._
+    // 256 records keyed row000..row255; the read tests below assert on exactly these rows.
+    val records = (0 to 255).map(i => HBaseRecord(i, "extra"))
+    val writeOptions = Map(
+      HBaseTableCatalog.tableCatalog -> writeCatalog,
+      HBaseTableCatalog.newTable -> "5")
+    sc.parallelize(records).toDF.write
+      .options(writeOptions).format("org.apache.hadoop.hbase.spark").save()
+  }
+
+  test("empty column") {
+    // count(1) references no columns, so the scan presumably requests an empty projection
+    withCatalog(writeCatalog).registerTempTable("table0")
+    val rows = sqlContext.sql("select count(1) from table0").rdd.collect()
+    assert(rows(0).getAs[Long](0) == 256)
+  }
+
+  test("full query") {
+    val all = withCatalog(writeCatalog)
+    all.show() // printed for manual inspection only
+    assert(all.count() == 256)
+  }
+
+  test("filtered query0") {
+    val sql = sqlContext
+    import sql.implicits._
+    // Keys row000..row005 satisfy the predicate, hence 6 rows.
+    val selected = withCatalog(writeCatalog)
+      .filter($"col0" <= "row005").select("col0", "col1")
+    selected.show()
+    assert(selected.count() == 6)
+  }
 }