You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@s2graph.apache.org by st...@apache.org on 2018/05/14 12:29:55 UTC
[11/25] incubator-s2graph git commit: add FastTextFetcher.
add FastTextFetcher.
Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/54c56c36
Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/54c56c36
Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/54c56c36
Branch: refs/heads/master
Commit: 54c56c36a04864c101b196180359eb357f5ca030
Parents: 08a80cd
Author: DO YUNG YOON <st...@apache.org>
Authored: Fri May 4 16:51:11 2018 +0900
Committer: DO YUNG YOON <st...@apache.org>
Committed: Fri May 4 16:51:11 2018 +0900
----------------------------------------------------------------------
project/Common.scala | 2 +
s2core/build.sbt | 2 +-
.../s2graph/core/model/AnnoyModelFetcher.scala | 128 ------------
.../core/model/annoy/AnnoyModelFetcher.scala | 115 +++++++++++
.../s2graph/core/model/fasttext/CopyModel.scala | 122 ++++++++++++
.../s2graph/core/model/fasttext/FastText.scala | 194 +++++++++++++++++++
.../core/model/fasttext/FastTextArgs.scala | 119 ++++++++++++
.../core/model/fasttext/FastTextFetcher.scala | 48 +++++
.../apache/s2graph/core/model/FetcherTest.scala | 3 +-
.../model/fasttext/FastTextFetcherTest.scala | 60 ++++++
.../custom/process/ALSModelProcessTest.scala | 6 +-
11 files changed, 666 insertions(+), 133 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/project/Common.scala
----------------------------------------------------------------------
diff --git a/project/Common.scala b/project/Common.scala
index 04279f4..08552a8 100644
--- a/project/Common.scala
+++ b/project/Common.scala
@@ -33,6 +33,8 @@ object Common {
val KafkaVersion = "0.10.2.1"
+ val rocksVersion = "5.11.3"
+
val annoy4sVersion = "0.6.0"
val tensorflowVersion = "1.7.0"
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/build.sbt
----------------------------------------------------------------------
diff --git a/s2core/build.sbt b/s2core/build.sbt
index bd84c37..cfc32d6 100644
--- a/s2core/build.sbt
+++ b/s2core/build.sbt
@@ -50,7 +50,7 @@ libraryDependencies ++= Seq(
"org.apache.hadoop" % "hadoop-hdfs" % hadoopVersion ,
"org.apache.lucene" % "lucene-core" % "6.6.0",
"org.apache.lucene" % "lucene-queryparser" % "6.6.0",
- "org.rocksdb" % "rocksdbjni" % "5.8.0",
+ "org.rocksdb" % "rocksdbjni" % rocksVersion,
"org.scala-lang.modules" %% "scala-java8-compat" % "0.8.0",
"com.sksamuel.elastic4s" %% "elastic4s-core" % elastic4sVersion excludeLogging(),
"com.sksamuel.elastic4s" %% "elastic4s-http" % elastic4sVersion excludeLogging(),
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala
deleted file mode 100644
index 2f2a40c..0000000
--- a/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala
+++ /dev/null
@@ -1,128 +0,0 @@
-package org.apache.s2graph.core.model
-
-import annoy4s.Converters.KeyConverter
-import annoy4s._
-import com.typesafe.config.Config
-import org.apache.s2graph.core._
-import org.apache.s2graph.core.model.AnnoyModelFetcher.IndexFilePathKey
-import org.apache.s2graph.core.types.VertexId
-
-import scala.concurrent.{ExecutionContext, Future}
-
-object AnnoyModelFetcher {
- val IndexFilePathKey = "annoyIndexFilePath"
- val DictFilePathKey = "annoyDictFilePath"
- val DimensionKey = "annoyIndexDimension"
- val IndexTypeKey = "annoyIndexType"
-
- // def loadDictFromLocal(file: File): Map[Int, String] = {
- // val files = if (file.isDirectory) {
- // file.listFiles()
- // } else {
- // Array(file)
- // }
- //
- // files.flatMap { file =>
- // Source.fromFile(file).getLines().zipWithIndex.flatMap { case (line, _idx) =>
- // val tokens = line.stripMargin.split(",")
- // try {
- // val tpl = if (tokens.length < 2) {
- // (tokens.head.toInt, tokens.head)
- // } else {
- // (tokens.head.toInt, tokens.tail.head)
- // }
- // Seq(tpl)
- // } catch {
- // case e: Exception => Nil
- // }
- // }
- // }.toMap
- // }
-
- def buildAnnoy4s[T](indexPath: String)(implicit converter: KeyConverter[T]): Annoy[T] = {
- Annoy.load[T](indexPath)
- }
-
- // def buildIndex(indexPath: String,
- // dictPath: String,
- // dimension: Int,
- // indexType: IndexType): ANNIndexWithDict = {
- // val dict = loadDictFromLocal(new File(dictPath))
- // val index = new ANNIndex(dimension, indexPath, indexType)
- //
- // ANNIndexWithDict(index, dict)
- // }
- //
- // def buildIndex(config: Config): ANNIndexWithDict = {
- // val indexPath = config.getString(IndexFilePathKey)
- // val dictPath = config.getString(DictFilePathKey)
- //
- // val dimension = config.getInt(DimensionKey)
- // val indexType = Try { config.getString(IndexTypeKey) }.toOption.map(IndexType.valueOf).getOrElse(IndexType.ANGULAR)
- //
- // buildIndex(indexPath, dictPath, dimension, indexType)
- // }
-}
-
-//
-//case class ANNIndexWithDict(index: ANNIndex, dict: Map[Int, String]) {
-// val dictRev = dict.map(kv => kv._2 -> kv._1)
-//}
-
-class AnnoyModelFetcher(val graph: S2GraphLike) extends Fetcher {
- val builder = graph.elementBuilder
-
- // var model: ANNIndexWithDict = _
- var model: Annoy[String] = _
-
- override def init(config: Config)(implicit ec: ExecutionContext): Future[Fetcher] = {
- Future {
- model = AnnoyModelFetcher.buildAnnoy4s(config.getString(IndexFilePathKey))
- // AnnoyModelFetcher.buildIndex(config)
-
- this
- }
- }
-
- /** Fetch **/
- override def fetches(queryRequests: Seq[QueryRequest],
- prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
- val stepResultLs = queryRequests.map { queryRequest =>
- val vertex = queryRequest.vertex
- val queryParam = queryRequest.queryParam
-
- val edgeWithScores = model.query(vertex.innerId.toIdString(), queryParam.limit).getOrElse(Nil).map { case (tgtId, score) =>
- val tgtVertexId = builder.newVertexId(queryParam.label.service,
- queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), tgtId)
-
- val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction)
-
- EdgeWithScore(edge, score, queryParam.label)
- }
-
- StepResult(edgeWithScores, Nil, Nil)
- //
- // val srcIndexOpt = model.dictRev.get(vertex.innerId.toIdString())
- //
- // srcIndexOpt.map { srcIdx =>
- // val srcVector = model.index.getItemVector(srcIdx)
- // val nns = model.index.getNearest(srcVector, queryParam.limit).asScala
- //
- // val edges = nns.map { tgtIdx =>
- // val tgtVertexId = builder.newVertexId(queryParam.label.service,
- // queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), model.dict(tgtIdx))
- //
- // graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction)
- // }
- // val edgeWithScores = edges.map(e => EdgeWithScore(e, 1.0, queryParam.label))
- // StepResult(edgeWithScores, Nil, Nil)
- // }.getOrElse(StepResult.Empty)
- }
-
- Future.successful(stepResultLs)
- }
-
- override def close(): Unit = {
- // do clean up
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala
new file mode 100644
index 0000000..a4e2aae
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala
@@ -0,0 +1,115 @@
+package org.apache.s2graph.core.model.annoy
+
+import annoy4s.Converters.KeyConverter
+import annoy4s._
+import com.typesafe.config.Config
+import org.apache.s2graph.core._
+import org.apache.s2graph.core.types.VertexId
+
+import scala.concurrent.{ExecutionContext, Future}
+
+object AnnoyModelFetcher {
+  /** Config key holding the path of the Annoy index file that `init` loads. */
+  val IndexFilePathKey = "annoyIndexFilePath"
+  /** Config key for the id dictionary path (carried in configs; not read by this fetcher's init). */
+  val DictFilePathKey = "annoyDictFilePath"
+  /** Config key for the index dimension (carried in configs; not read by this fetcher's init). */
+  val DimensionKey = "annoyIndexDimension"
+  /** Config key for the index metric type (carried in configs; not read by this fetcher's init). */
+  val IndexTypeKey = "annoyIndexType"
+
+  /**
+   * Loads a prebuilt annoy4s index.
+   *
+   * @param indexPath location handed straight to `Annoy.load`
+   * @param converter annoy4s key converter for the item id type `T`
+   * @return the loaded index
+   */
+  def buildAnnoy4s[T](indexPath: String)(implicit converter: KeyConverter[T]): Annoy[T] =
+    Annoy.load[T](indexPath)
+}
+
+class AnnoyModelFetcher(val graph: S2GraphLike) extends Fetcher {
+  import AnnoyModelFetcher._
+
+  val builder = graph.elementBuilder
+
+  /** Loaded Annoy index; stays `null` until `init` completes successfully. */
+  var model: Annoy[String] = _
+
+  /**
+   * Loads the index whose path is stored under `IndexFilePathKey` in `config`.
+   * The load runs asynchronously on `ec`; the returned future yields this fetcher.
+   */
+  override def init(config: Config)(implicit ec: ExecutionContext): Future[Fetcher] = {
+    Future {
+      model = AnnoyModelFetcher.buildAnnoy4s(config.getString(IndexFilePathKey))
+      this
+    }
+  }
+
+  /**
+   * Answers each request with the nearest neighbours of its source vertex id.
+   *
+   * The result is a failed future if `init` has not completed or if a lookup throws.
+   * Errors are not thrown synchronously from this Future-returning method.
+   */
+  override def fetches(queryRequests: Seq[QueryRequest],
+                       prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
+    if (model == null) {
+      Future.failed(new IllegalStateException("AnnoyModelFetcher is not initialized; call init() first"))
+    } else {
+      Future.fromTry(scala.util.Try(queryRequests.map(toStepResult)))
+    }
+  }
+
+  /** Queries the index for one request and converts each (id, score) hit into a scored edge. */
+  private def toStepResult(queryRequest: QueryRequest): StepResult = {
+    val vertex = queryRequest.vertex
+    val queryParam = queryRequest.queryParam
+    val label = queryParam.label
+    // Only attach a "score" property when the label actually declares one.
+    val hasScoreProp = label.metaPropsInvMap.contains("score")
+
+    val edgeWithScores = model.query(vertex.innerId.toIdString(), queryParam.limit).getOrElse(Nil).map { case (tgtId, score) =>
+      val tgtVertexId = builder.newVertexId(label.service, label.tgtColumnWithDir(queryParam.labelWithDir.dir), tgtId)
+      val props: Map[String, Any] = if (hasScoreProp) Map("score" -> score) else Map.empty
+      val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction, props = props)
+
+      EdgeWithScore(edge, score, label)
+    }
+
+    StepResult(edgeWithScores, Nil, Nil)
+  }
+
+  /** Releases the native index; safe to call even if `init` never completed. */
+  override def close(): Unit = if (model != null) model.close
+}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala
new file mode 100644
index 0000000..c3e36c7
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala
@@ -0,0 +1,122 @@
+package org.apache.s2graph.core.model.fasttext
+
+
+import java.io.{BufferedInputStream, FileInputStream, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util
+
+import org.apache.s2graph.core.model.fasttext.fasttext.FastTextArgs
+import org.rocksdb._
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+object CopyModel {
+
+  /** Reads exactly `len` bytes into `buf` starting at `off`; throws EOFException if the stream ends first. */
+  private def readFully(is: InputStream, buf: Array[Byte], off: Int, len: Int): Unit = {
+    var pos = 0
+    while (pos < len) {
+      val n = is.read(buf, off + pos, len - pos)
+      if (n < 0) throw new java.io.EOFException(s"expected $len bytes but stream ended after $pos")
+      pos += n
+    }
+  }
+
+  /** Reads a NUL-terminated byte string; the terminator is consumed but not returned. */
+  private def readCString(is: InputStream): Array[Byte] = {
+    val buf = new ArrayBuffer[Byte]
+    var b = is.read()
+    while (b != 0) {
+      // read() returns -1 forever at EOF; without this check the loop never ends.
+      if (b < 0) throw new java.io.EOFException("stream ended inside a vocabulary entry")
+      buf += b.toByte
+      b = is.read()
+    }
+    buf.toArray
+  }
+
+  /** Stores the serialized model header under the key "args". */
+  def writeArgs(db: RocksDB, handle: ColumnFamilyHandle, args: FastTextArgs): Unit = {
+    val wo = new WriteOptions().setDisableWAL(true).setSync(false)
+    try db.put(handle, wo, "args".getBytes("UTF-8"), args.serialize)
+    finally wo.close()
+    println("done ")
+  }
+
+  /**
+   * Copies `args.size` vocabulary records from the model stream.
+   *
+   * Each record is keyed by the word bytes in `vocabHandle`. Its value is the word id (int32)
+   * followed by the 9 raw bytes the model stores after the word: a count (int64) and an entry
+   * type (int8), all little-endian. Entries of type 1 also get a reverse mapping
+   * label-id (big-endian int32) -> word in `labelHandle`.
+   */
+  def writeVocab(is: InputStream, db: RocksDB,
+                 vocabHandle: ColumnFamilyHandle, labelHandle: ColumnFamilyHandle, args: FastTextArgs): Unit = {
+    val wo = new WriteOptions().setDisableWAL(true).setSync(false)
+    val bb = ByteBuffer.allocate(13).order(ByteOrder.LITTLE_ENDIAN)
+    try {
+      for (wid <- 0 until args.size) {
+        val word = readCString(is)
+        bb.clear()
+        bb.putInt(wid)
+        readFully(is, bb.array(), 4, 9)
+        db.put(vocabHandle, wo, word, bb.array())
+
+        if (bb.get(12) == 1) {
+          val label = wid - args.nwords
+          db.put(labelHandle, wo, ByteBuffer.allocate(4).putInt(label).array(), word)
+        }
+
+        if ((wid + 1) % 1000 == 0)
+          print(f"\rprocessing ${100 * (wid + 1) / args.size.toFloat}%.2f%%")
+      }
+    } finally wo.close()
+    println("\rdone ")
+  }
+
+  /**
+   * Copies one dense (non-quantized) matrix from the model stream.
+   *
+   * Row `i` is stored under the big-endian int64 key `i`, and its value is the raw row bytes.
+   */
+  def writeVectors(is: InputStream, db: RocksDB, handle: ColumnFamilyHandle, args: FastTextArgs): Unit = {
+    // The leading flag byte is non-zero for quantized matrices, which are not supported.
+    require(is.read() == 0, "not implemented")
+    val wo = new WriteOptions().setDisableWAL(true).setSync(false)
+    val bb = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
+    val key = ByteBuffer.allocate(8)
+    val value = new Array[Byte](args.dim * 4)
+    try {
+      readFully(is, bb.array(), 0, 16)
+      val m = bb.getLong
+      val n = bb.getLong
+      require(n * 4 == value.length, s"matrix has $n columns but args.dim is ${args.dim}")
+      var i = 0L
+      while (i < m) {
+        key.clear()
+        key.putLong(i)
+        readFully(is, value, 0, value.length)
+        db.put(handle, wo, key.array(), value)
+        if ((i + 1) % 1000 == 0)
+          print(f"\rprocessing ${100 * (i + 1) / m.toFloat}%.2f%%")
+        i += 1
+      }
+    } finally wo.close()
+    println("\rdone ")
+  }
+
+  def printHelp(): Unit = {
+    println("usage: CopyModel <in> <out>")
+  }
+
+  /** Command-line entry point matching `printHelp`: `CopyModel <in> <out>`. */
+  def main(args: Array[String]): Unit = {
+    if (args.length != 2) printHelp()
+    else copy(args(0), args(1))
+  }
+
+  /**
+   * Converts the fastText binary model at `in` into a RocksDB database at `out`.
+   *
+   * Any existing database at `out` is destroyed first. Native handles and the input stream
+   * are released even when the copy fails.
+   */
+  def copy(in: String, out: String): Unit = {
+    val destroyOptions = new Options
+    try RocksDB.destroyDB(out, destroyOptions) finally destroyOptions.close()
+
+    val dbOptions = new DBOptions()
+      .setCreateIfMissing(true)
+      .setCreateMissingColumnFamilies(true)
+      .setAllowMmapReads(false)
+      .setMaxOpenFiles(500000)
+      .setDbWriteBufferSize(134217728)
+      .setMaxBackgroundCompactions(20)
+
+    // Column families: default (args + label strings), vocab, input matrix "i", output matrix "o".
+    val descriptors = new java.util.LinkedList[ColumnFamilyDescriptor]()
+    descriptors.add(new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY))
+    descriptors.add(new ColumnFamilyDescriptor("vocab".getBytes()))
+    descriptors.add(new ColumnFamilyDescriptor("i".getBytes()))
+    descriptors.add(new ColumnFamilyDescriptor("o".getBytes()))
+    val handles = new util.LinkedList[ColumnFamilyHandle]()
+
+    try {
+      val db = RocksDB.open(dbOptions, out, descriptors, handles)
+      try {
+        val is = new BufferedInputStream(new FileInputStream(in))
+        try {
+          val fastTextArgs = FastTextArgs.fromInputStream(is)
+
+          require(fastTextArgs.magic == FastText.FASTTEXT_FILEFORMAT_MAGIC_INT32,
+            s"$in is not a fastText model (magic = ${fastTextArgs.magic})")
+          require(fastTextArgs.version == FastText.FASTTEXT_VERSION,
+            s"unsupported fastText version ${fastTextArgs.version}")
+
+          println("step 1: writing args")
+          writeArgs(db, handles.get(0), fastTextArgs)
+          println("step 2: writing vocab")
+          writeVocab(is, db, handles.get(1), handles.get(0), fastTextArgs)
+          println("step 3: writing input vectors")
+          writeVectors(is, db, handles.get(2), fastTextArgs)
+          println("step 4: writing output vectors")
+          writeVectors(is, db, handles.get(3), fastTextArgs)
+          println("step 5: compactRange")
+          db.compactRange()
+          println("done")
+        } finally is.close()
+      } finally {
+        handles.asScala.foreach(_.close())
+        db.close()
+      }
+    } finally dbOptions.close()
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala
new file mode 100644
index 0000000..b5d10a9
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala
@@ -0,0 +1,194 @@
+package org.apache.s2graph.core.model.fasttext
+
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util
+
+import org.apache.s2graph.core.model.fasttext.fasttext.FastTextArgs
+import org.rocksdb.{ColumnFamilyDescriptor, ColumnFamilyHandle, DBOptions, RocksDB}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/** A tokenized input line: label ids found in the text, plus input-matrix row ids for its words/subwords. */
+case class Line(labels: Array[Int],
+                words: Array[Long])
+/** A vocabulary lookup result: word id (-1 when absent), count, entry type byte, and input-matrix row ids. */
+case class Entry(wid: Int,
+                 count: Long,
+                 tpe: Byte,
+                 subwords: Array[Long])
+
+object FastText {
+  val EOS = "</s>"
+  val BOW = "<"
+  val EOW = ">"
+
+  val FASTTEXT_VERSION = 12 // Version 1b
+  val FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314
+
+  val MODEL_CBOW = 1
+  val MODEL_SG = 2
+  val MODEL_SUP = 3
+
+  val LOSS_HS = 1
+  val LOSS_NS = 2
+  val LOSS_SOFTMAX = 3
+
+  val DBPathKey = "dbPath"
+
+  /**
+   * Splits `in` on whitespace and appends the end-of-sentence token.
+   * Empty pieces produced by leading/trailing whitespace are dropped, so they never become tokens.
+   */
+  def tokenize(in: String): Array[String] =
+    in.split("\\s+").filter(_.nonEmpty) :+ EOS
+
+  /**
+   * Character n-grams of `word` with lengths in [max(minn, 1), maxn], ordered by length and then position.
+   *
+   * Lengths are counted in Unicode code points, so surrogate pairs (e.g. emoji) are never split.
+   * Like fastText's computeSubwords, only the 1-grams at the very first and last positions are
+   * excluded; callers pass BOW + word + EOW, so those are the boundary markers.
+   */
+  def getSubwords(word: String, minn: Int, maxn: Int): Array[String] = {
+    // offsets(k) is the UTF-16 index where the k-th code point starts; the last entry is word.length.
+    val offsets = Iterator.iterate(0)(o => o + Character.charCount(word.codePointAt(o)))
+      .takeWhile(_ < word.length)
+      .toArray :+ word.length
+    val len = offsets.length - 1
+    val lo = math.max(minn, 1)
+    val hi = math.min(maxn, len)
+    (lo to hi).flatMap { n =>
+      (0 to len - n)
+        .filterNot(i => n == 1 && (i == 0 || i + n == len))
+        .map(i => word.substring(offsets(i), offsets(i + n)))
+    }.toArray
+  }
+
+  /**
+   * 32-bit FNV-1a hash of the UTF-8 bytes of `str`, returned as an unsigned value in a Long.
+   *
+   * Bytes are sign-extended before the XOR (Scala Byte -> Int), mirroring fastText's int8_t cast.
+   * Int multiplication wraps like uint32, so bucket ids match models trained by fastText.
+   */
+  def hash(str: String): Long = {
+    val h = str.getBytes(java.nio.charset.StandardCharsets.UTF_8)
+      .foldLeft(2166136261L.toInt) { (acc, b) => (acc ^ b) * 16777619 }
+    h & 0xffffffffL
+  }
+
+}
+
+class FastText(name: String) extends AutoCloseable {
+
+  import FastText._
+
+  // The binary model does not record fastText's label prefix; this assumes the "-label" default.
+  private val LabelPrefix = "__label__"
+
+  private val dbOptions = new DBOptions()
+  // Column-family order matches CopyModel.copy: default (args + labels), vocab, input "i", output "o".
+  private val descriptors = new java.util.LinkedList[ColumnFamilyDescriptor]()
+  descriptors.add(new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY))
+  descriptors.add(new ColumnFamilyDescriptor("vocab".getBytes()))
+  descriptors.add(new ColumnFamilyDescriptor("i".getBytes()))
+  descriptors.add(new ColumnFamilyDescriptor("o".getBytes()))
+  private val handles = new util.LinkedList[ColumnFamilyHandle]()
+  private val db = RocksDB.openReadOnly(dbOptions, name, descriptors, handles)
+
+  private val defaultHandle = handles.get(0)
+  private val vocabHandle = handles.get(1)
+  private val inputVectorHandle = handles.get(2)
+  private val outputVectorHandle = handles.get(3)
+
+  private val args = {
+    val raw = db.get(defaultHandle, "args".getBytes("UTF-8"))
+    require(raw != null, s"no fastText args found in $name")
+    FastTextArgs.fromByteArray(raw)
+  }
+
+  println(args)
+
+  // Validate the header before loading anything whose size is taken from it.
+  require(args.magic == FASTTEXT_FILEFORMAT_MAGIC_INT32, s"bad fastText magic: ${args.magic}")
+  require(args.version == FASTTEXT_VERSION, s"unsupported fastText version: ${args.version}")
+
+  // only sup/softmax supported
+  // others are the future work.
+  require(args.model == MODEL_SUP, s"only supervised models are supported (model = ${args.model})")
+  require(args.loss == LOSS_SOFTMAX, s"only softmax loss is supported (loss = ${args.loss})")
+
+  /** Output-matrix rows, one per label, loaded eagerly. */
+  private val wo = loadOutputVectors()
+  /** Label strings indexed by label id. */
+  private val labels = loadLabels()
+
+  /** Reads row `key` (big-endian int64 key) from `handle` as `args.dim` little-endian floats. */
+  private def getVector(handle: ColumnFamilyHandle, key: Long): Array[Float] = {
+    val keyBytes = ByteBuffer.allocate(8).putLong(key).array()
+    val raw = db.get(handle, keyBytes)
+    require(raw != null, s"missing vector row $key")
+    val bb = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN)
+    Array.fill(args.dim)(bb.getFloat)
+  }
+
+  private def loadOutputVectors(): Array[Array[Float]] =
+    Array.tabulate(args.nlabels)(key => getVector(outputVectorHandle, key.toLong))
+
+  /**
+   * Loads label strings from the default column family.
+   *
+   * Label entries use 4-byte big-endian ids in [0, nlabels). Other keys, such as "args", fall
+   * outside that range and are skipped. The native iterator is always closed.
+   */
+  private def loadLabels(): Array[String] = {
+    val result = new Array[String](args.nlabels)
+    val it = db.newIterator(defaultHandle)
+    try {
+      var i = 0
+      it.seekToFirst()
+      while (it.isValid) {
+        val rawKey = it.key()
+        if (rawKey.length == 4) {
+          val key = ByteBuffer.wrap(rawKey).getInt()
+          if (key >= 0 && key < args.nlabels) {
+            require(i == key, s"label ids are not contiguous: expected $i, found $key")
+            result(i) = new String(it.value(), "UTF-8")
+            i += 1
+          }
+        }
+        it.next()
+      }
+      require(i == args.nlabels, s"expected ${args.nlabels} labels but found $i")
+    } finally {
+      it.close()
+    }
+    result
+  }
+
+  def getInputVector(key: Long): Array[Float] = getVector(inputVectorHandle, key)
+
+  def getOutputVector(key: Long): Array[Float] = getVector(outputVectorHandle, key)
+
+  /**
+   * Looks up `word` in the vocabulary.
+   *
+   * For an unknown word the result has wid = -1, no subwords, and an entry type decided by its
+   * prefix (label if it starts with the label prefix, word otherwise), as in fastText's
+   * Dictionary::getType. This lets getLine embed unknown words through their character n-grams.
+   */
+  def getEntry(word: String): Entry = {
+    val raw = db.get(vocabHandle, word.getBytes("UTF-8"))
+    if (raw == null) {
+      val tpe: Byte = if (word.startsWith(LabelPrefix)) 1.toByte else 0.toByte
+      Entry(-1, 0L, tpe, Array.emptyLongArray)
+    } else {
+      val bb = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN)
+      val wid = bb.getInt
+      val count = bb.getLong
+      val tpe = bb.get
+      val subwords = if (word != EOS && tpe == 0) Array(wid.toLong) ++ computeSubwords(BOW + word + EOW) else Array(wid.toLong)
+      Entry(wid, count, tpe, subwords)
+    }
+  }
+
+  /** Input-matrix row ids for the hashed character n-grams of `word` (rows nwords until nwords + bucket). */
+  def computeSubwords(word: String): Array[Long] =
+    getSubwords(word, args.minn, args.maxn).map { w => args.nwords + (hash(w) % args.bucket.toLong) }
+
+  /**
+   * Tokenizes `in` into input rows and label ids.
+   *
+   * Known words contribute their id and subwords. Unknown words (except EOS) contribute their
+   * computed subwords. Known labels contribute `wid - nwords`.
+   */
+  def getLine(in: String): Line = {
+    val words = new ArrayBuffer[Long]()
+    val lineLabels = new ArrayBuffer[Int]()
+    tokenize(in).foreach { token =>
+      val Entry(wid, _, tpe, subwords) = getEntry(token)
+      if (tpe == 0) {
+        if (wid < 0) { // OOV
+          if (token != EOS) words ++= computeSubwords(BOW + token + EOW)
+        } else {
+          words ++= subwords
+        }
+      } else if (tpe == 1 && wid >= 0) {
+        lineLabels += wid - args.nwords
+      }
+    }
+    Line(lineLabels.toArray, words.toArray)
+  }
+
+  /** Averages the input rows of `input` (sum, then scale by 1/size as fastText does); all zeros when empty. */
+  def computeHidden(input: Array[Long]): Array[Float] = {
+    val hidden = new Array[Float](args.dim)
+    input.foreach { key =>
+      val row = getInputVector(key)
+      var i = 0
+      while (i < hidden.length) {
+        hidden(i) += row(i)
+        i += 1
+      }
+    }
+    if (input.nonEmpty) {
+      val scale = (1.0 / input.length).toFloat
+      var i = 0
+      while (i < hidden.length) {
+        hidden(i) *= scale
+        i += 1
+      }
+    }
+    hidden
+  }
+
+  /**
+   * Returns up to `k` (label, probability) pairs, most probable first, from a softmax over the output layer.
+   * Returns an empty array when the model has no labels.
+   */
+  def predict(line: Line, k: Int = 1): Array[(String, Float)] = {
+    if (wo.isEmpty) {
+      Array.empty[(String, Float)]
+    } else {
+      val hidden = computeHidden(line.words)
+      val output = wo.map { o =>
+        o.zip(hidden).map(a => a._1 * a._2).sum
+      }
+      // Shift by the max logit before exponentiating for numerical stability.
+      val max = output.max
+      var i = 0
+      var z = 0.0f
+      while (i < output.length) {
+        output(i) = math.exp((output(i) - max).toDouble).toFloat
+        z += output(i)
+        i += 1
+      }
+      i = 0
+      while (i < output.length) {
+        output(i) /= z
+        i += 1
+      }
+      output.zipWithIndex.sortBy(-_._1).take(k).map { case (prob, idx) =>
+        labels(idx) -> prob
+      }
+    }
+  }
+
+  /** Closes column-family handles, the database, and the DB options. */
+  def close(): Unit = {
+    handles.asScala.foreach(_.close())
+    db.close()
+    dbOptions.close()
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala
new file mode 100644
index 0000000..20c25f0
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala
@@ -0,0 +1,119 @@
+package org.apache.s2graph.core.model.fasttext
+
+
+package fasttext
+
+import java.io.{ByteArrayInputStream, FileInputStream, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+
+case class FastTextArgs(
+    magic: Int,
+    version: Int,
+    dim: Int,
+    ws: Int,
+    epoch: Int,
+    minCount: Int,
+    neg: Int,
+    wordNgrams: Int,
+    loss: Int,
+    model: Int,
+    bucket: Int,
+    minn: Int,
+    maxn: Int,
+    lrUpdateRate: Int,
+    t: Double,
+    size: Int,
+    nwords: Int,
+    nlabels: Int,
+    ntokens: Long,
+    pruneidxSize: Long) {
+
+  /**
+   * Encodes the header as 92 little-endian bytes, field by field in declaration order:
+   * 14 int32 values, one float64 (`t`), 3 int32 values, then 2 int64 values.
+   */
+  def serialize: Array[Byte] =
+    ByteBuffer.allocate(92).order(ByteOrder.LITTLE_ENDIAN)
+      .putInt(magic).putInt(version).putInt(dim).putInt(ws).putInt(epoch)
+      .putInt(minCount).putInt(neg).putInt(wordNgrams).putInt(loss).putInt(model)
+      .putInt(bucket).putInt(minn).putInt(maxn).putInt(lrUpdateRate)
+      .putDouble(t)
+      .putInt(size).putInt(nwords).putInt(nlabels)
+      .putLong(ntokens).putLong(pruneidxSize)
+      .array()
+
+  /** One "name: value" line per field, each terminated by a newline. */
+  override def toString: String = {
+    val fields: Seq[(String, Any)] = Seq(
+      "magic" -> magic,
+      "version" -> version,
+      "dim" -> dim,
+      "ws " -> ws,
+      "epoch" -> epoch,
+      "minCount" -> minCount,
+      "neg" -> neg,
+      "wordNgrams" -> wordNgrams,
+      "loss" -> loss,
+      "model" -> model,
+      "bucket" -> bucket,
+      "minn" -> minn,
+      "maxn" -> maxn,
+      "lrUpdateRate" -> lrUpdateRate,
+      "t" -> t,
+      "size" -> size,
+      "nwords" -> nwords,
+      "nlabels" -> nlabels,
+      "ntokens" -> ntokens,
+      "pruneIdxSize" -> pruneidxSize)
+    fields.map { case (key, value) => s"$key: $value" }.mkString("", "\n", "\n")
+  }
+
+}
+
+object FastTextArgs {
+
+  /**
+   * Reads exactly `len` bytes into the start of `buffer` and returns a little-endian view of them.
+   * Throws EOFException if the stream ends first, so a truncated header fails loudly instead of
+   * being decoded from stale buffer contents.
+   */
+  private def readBytes(len: Int)(implicit inputStream: InputStream, buffer: Array[Byte]): ByteBuffer = {
+    var pos = 0
+    while (pos < len) {
+      val n = inputStream.read(buffer, pos, len - pos)
+      if (n < 0) throw new java.io.EOFException(s"expected $len bytes but stream ended after $pos")
+      pos += n
+    }
+    ByteBuffer.wrap(buffer, 0, len).order(ByteOrder.LITTLE_ENDIAN)
+  }
+
+  private def getInt(implicit inputStream: InputStream, buffer: Array[Byte]): Int = readBytes(4).getInt
+
+  private def getLong(implicit inputStream: InputStream, buffer: Array[Byte]): Long = readBytes(8).getLong
+
+  private def getDouble(implicit inputStream: InputStream, buffer: Array[Byte]): Double = readBytes(8).getDouble
+
+  /** Decodes a header previously produced by `FastTextArgs.serialize`. */
+  def fromByteArray(ar: Array[Byte]): FastTextArgs =
+    fromInputStream(new ByteArrayInputStream(ar))
+
+  /**
+   * Decodes the header fields in declaration order from `inputStream`, consuming exactly 92 bytes.
+   * The stream is left open and positioned just after the header.
+   */
+  def fromInputStream(inputStream: InputStream): FastTextArgs = {
+    implicit val is: InputStream = inputStream
+    implicit val bytes: Array[Byte] = new Array[Byte](8)
+    FastTextArgs(
+      getInt, getInt, getInt, getInt, getInt, getInt, getInt, getInt, getInt, getInt,
+      getInt, getInt, getInt, getInt, getDouble, getInt, getInt, getInt, getLong, getLong)
+  }
+
+  /**
+   * Debug helper: parses a model header, round-trips it through `serialize`, and prints both copies.
+   * The model path is `args(0)` when given; otherwise it falls back to the original developer path.
+   */
+  def main(args: Array[String]): Unit = {
+    val path = args.headOption.getOrElse("/Users/emeth.kim/d/g/fastText/dataset/sample.model.bin")
+    val in = new FileInputStream(path)
+    val args0 = try FastTextArgs.fromInputStream(in) finally in.close()
+    val serialized = args0.serialize
+    val args1 = FastTextArgs.fromByteArray(serialized)
+
+    println(args0)
+    println(args1)
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala
new file mode 100644
index 0000000..774d784
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala
@@ -0,0 +1,48 @@
+package org.apache.s2graph.core.model.fasttext
+
+import com.typesafe.config.Config
+import org.apache.s2graph.core._
+import org.apache.s2graph.core.types.VertexId
+
+import scala.concurrent.{ExecutionContext, Future}
+
+
+class FastTextFetcher(val graph: S2GraphLike) extends Fetcher {
+  val builder = graph.elementBuilder
+  /** Model handle; stays `null` until `init` completes successfully. */
+  var fastText: FastText = _
+
+  /** Opens the RocksDB-backed model whose path is stored under `FastText.DBPathKey`, asynchronously on `ec`. */
+  override def init(config: Config)(implicit ec: ExecutionContext): Future[Fetcher] = {
+    Future {
+      val dbPath = config.getString(FastText.DBPathKey)
+
+      fastText = new FastText(dbPath)
+
+      this
+    }
+  }
+
+  /**
+   * Treats each request's source vertex id as input text and returns the top `limit` predicted labels as edges.
+   *
+   * The result is a failed future if `init` has not completed or if prediction throws.
+   * Errors are not thrown synchronously from this Future-returning method.
+   */
+  override def fetches(queryRequests: Seq[QueryRequest],
+                       prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
+    if (fastText == null) {
+      Future.failed(new IllegalStateException("FastTextFetcher is not initialized; call init() first"))
+    } else {
+      Future.fromTry(scala.util.Try(queryRequests.map(toStepResult)))
+    }
+  }
+
+  /** Runs prediction for one request and converts each (label, probability) pair into a scored edge. */
+  private def toStepResult(queryRequest: QueryRequest): StepResult = {
+    val vertex = queryRequest.vertex
+    val queryParam = queryRequest.queryParam
+    val label = queryParam.label
+    // Only attach a "score" property when the label actually declares one.
+    val hasScoreProp = label.metaPropsInvMap.contains("score")
+    val line = fastText.getLine(vertex.innerId.toIdString())
+
+    val edgeWithScores = fastText.predict(line, queryParam.limit).map { case (predicted, score) =>
+      val tgtVertexId = builder.newVertexId(label.service, label.tgtColumnWithDir(queryParam.labelWithDir.dir), predicted)
+      val props: Map[String, Any] = if (hasScoreProp) Map("score" -> score) else Map.empty
+      val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction, props = props)
+
+      EdgeWithScore(edge, score, label)
+    }
+
+    StepResult(edgeWithScores, Nil, Nil)
+  }
+
+  /** Closes the model if it was opened; safe to call before or without `init`. */
+  override def close(): Unit = if (fastText != null) fastText.close()
+}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
----------------------------------------------------------------------
diff --git a/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala b/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
index 54e6763..ca1f3a7 100644
--- a/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
+++ b/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
@@ -6,6 +6,7 @@ import com.typesafe.config.ConfigFactory
import org.apache.commons.io.FileUtils
import org.apache.s2graph.core.Integrate.IntegrateCommon
import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
+import org.apache.s2graph.core.model.annoy.AnnoyModelFetcher
import org.apache.s2graph.core.schema.Label
import org.apache.s2graph.core.{Query, QueryParam}
@@ -98,7 +99,7 @@ class FetcherTest extends IntegrateCommon{
| }]
| },
| "fetcher": {
- | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.AnnoyModelFetcher",
+ | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.annoy.AnnoyModelFetcher",
| "${AnnoyModelFetcher.IndexFilePathKey}": "${localIndexFilePath}",
| "${AnnoyModelFetcher.DictFilePathKey}": "${localDictFilePath}",
| "${AnnoyModelFetcher.DimensionKey}": 10
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala
----------------------------------------------------------------------
diff --git a/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala b/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala
new file mode 100644
index 0000000..f91e0d5
--- /dev/null
+++ b/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala
@@ -0,0 +1,60 @@
+package org.apache.s2graph.core.model.fasttext
+
+import com.typesafe.config.ConfigFactory
+import org.apache.s2graph.core.Integrate.IntegrateCommon
+import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
+import org.apache.s2graph.core.{Query, QueryParam, QueryRequest}
+import org.apache.s2graph.core.schema.Label
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{Await, ExecutionContext}
+import scala.concurrent.duration.Duration
+
+class FastTextFetcherTest extends IntegrateCommon {
+  import TestUtil._
+
+  // The fastText model is a large local artifact that is not checked in.
+  // Point the test at it with -Ds2graph.fasttext.modelPath=<dir>; the default is the original author's path.
+  private val modelPath = sys.props.getOrElse(
+    "s2graph.fasttext.modelPath",
+    "/Users/shon/Downloads/emoji-context-by-story-comments-20170901-20180410")
+
+  test("FastTextFetcher init test.") {
+    // Cancel (rather than fail) on machines that do not have the model.
+    assume(new java.io.File(modelPath).exists(), s"fastText model not found at $modelPath")
+
+    val config = ConfigFactory.parseMap(Map(FastText.DBPathKey -> modelPath).asJava)
+    val fetcher = new FastTextFetcher(graph)
+    // Await.result rethrows an init failure here; Await.ready would hide it.
+    Await.result(fetcher.init(config)(ExecutionContext.Implicits.global), Duration("3 minutes"))
+
+    try {
+      val service = management.createService("s2graph", "localhost", "s2graph_htable", -1, None).get
+      val serviceColumn =
+        management.createServiceColumn("s2graph", "keyword", "string", Seq(Prop("age", "0", "int", true)))
+
+      val labelName = "fasttext_test_label"
+
+      Label.findByName(labelName, useCache = false).foreach { label => Label.delete(label.id.get) }
+
+      val label = management.createLabel(
+        labelName,
+        serviceColumn,
+        serviceColumn,
+        true,
+        service.serviceName,
+        Seq.empty[Index].asJava,
+        Seq.empty[Prop].asJava,
+        "strong",
+        null,
+        -1,
+        "v3",
+        "gz",
+        ""
+      )
+      val vertex = graph.elementBuilder.toVertex(service.serviceName, serviceColumn.columnName, "안녕하세요")
+      val queryParam = QueryParam(labelName = labelName, limit = 5)
+
+      val query = Query.toQuery(srcVertices = Seq(vertex), queryParams = Seq(queryParam))
+      val queryRequests = Seq(
+        QueryRequest(query, 0, vertex, queryParam)
+      )
+      val future = fetcher.fetches(queryRequests, Map.empty)
+      val results = Await.result(future, Duration("10 seconds"))
+
+      // One step result per request, each capped at the requested limit.
+      assert(results.size == queryRequests.size)
+      results.foreach { stepResult =>
+        assert(stepResult.edgeWithScores.size <= queryParam.limit)
+        stepResult.edgeWithScores.foreach { es =>
+          println(es.edge.tgtVertex.innerIdVal)
+        }
+      }
+    } finally {
+      fetcher.close()
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
----------------------------------------------------------------------
diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
index a8479fe..4d2623e 100644
--- a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
+++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
@@ -8,7 +8,7 @@ import org.apache.commons.io.FileUtils
import org.apache.s2graph.core.Integrate.IntegrateCommon
import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
import org.apache.s2graph.core.{Query, QueryParam}
-import org.apache.s2graph.core.model.{ANNIndexWithDict, AnnoyModelFetcher, HDFSImporter, ModelManager}
+import org.apache.s2graph.core.model.{ANNIndexWithDict, HDFSImporter, ModelManager}
import org.apache.s2graph.core.schema.Label
import org.apache.s2graph.s2jobs.task.TaskConf
@@ -57,7 +57,7 @@ class ALSModelProcessTest extends IntegrateCommon with DataFrameSuiteBase {
// | "${ModelManager.ImporterClassNameKey}": "org.apache.s2graph.core.model.IdentityImporter"
// | },
// | "fetcher": {
-// | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.AnnoyModelFetcher",
+// | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.annoy.AnnoyModelFetcher",
// | "${AnnoyModelFetcher.IndexFilePathKey}": "${remoteIndexFilePath}",
// | "${AnnoyModelFetcher.DictFilePathKey}": "${remoteDictFilePath}",
// | "${AnnoyModelFetcher.DimensionKey}": 10
@@ -107,7 +107,7 @@ class ALSModelProcessTest extends IntegrateCommon with DataFrameSuiteBase {
| "${ModelManager.ImporterClassNameKey}": "org.apache.s2graph.core.model.IdentityImporter"
| },
| "fetcher": {
- | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.AnnoyModelFetcher",
+ | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.annoy.AnnoyModelFetcher",
| "${AnnoyModelFetcher.IndexFilePathKey}": "${indexPath}",
| "${AnnoyModelFetcher.DictFilePathKey}": "${dictPath}",
| "${AnnoyModelFetcher.DimensionKey}": 10