add FastTextFetcher.
Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/54c56c36 Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/54c56c36 Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/54c56c36 Branch: refs/heads/master Commit: 54c56c36a04864c101b196180359eb357f5ca030 Parents: 08a80cd Author: DO YUNG YOON <[email protected]> Authored: Fri May 4 16:51:11 2018 +0900 Committer: DO YUNG YOON <[email protected]> Committed: Fri May 4 16:51:11 2018 +0900 ---------------------------------------------------------------------- project/Common.scala | 2 + s2core/build.sbt | 2 +- .../s2graph/core/model/AnnoyModelFetcher.scala | 128 ------------ .../core/model/annoy/AnnoyModelFetcher.scala | 115 +++++++++++ .../s2graph/core/model/fasttext/CopyModel.scala | 122 ++++++++++++ .../s2graph/core/model/fasttext/FastText.scala | 194 +++++++++++++++++++ .../core/model/fasttext/FastTextArgs.scala | 119 ++++++++++++ .../core/model/fasttext/FastTextFetcher.scala | 48 +++++ .../apache/s2graph/core/model/FetcherTest.scala | 3 +- .../model/fasttext/FastTextFetcherTest.scala | 60 ++++++ .../custom/process/ALSModelProcessTest.scala | 6 +- 11 files changed, 666 insertions(+), 133 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/project/Common.scala ---------------------------------------------------------------------- diff --git a/project/Common.scala b/project/Common.scala index 04279f4..08552a8 100644 --- a/project/Common.scala +++ b/project/Common.scala @@ -33,6 +33,8 @@ object Common { val KafkaVersion = "0.10.2.1" + val rocksVersion = "5.11.3" + val annoy4sVersion = "0.6.0" val tensorflowVersion = "1.7.0" http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/build.sbt ---------------------------------------------------------------------- diff --git 
a/s2core/build.sbt b/s2core/build.sbt index bd84c37..cfc32d6 100644 --- a/s2core/build.sbt +++ b/s2core/build.sbt @@ -50,7 +50,7 @@ libraryDependencies ++= Seq( "org.apache.hadoop" % "hadoop-hdfs" % hadoopVersion , "org.apache.lucene" % "lucene-core" % "6.6.0", "org.apache.lucene" % "lucene-queryparser" % "6.6.0", - "org.rocksdb" % "rocksdbjni" % "5.8.0", + "org.rocksdb" % "rocksdbjni" % rocksVersion, "org.scala-lang.modules" %% "scala-java8-compat" % "0.8.0", "com.sksamuel.elastic4s" %% "elastic4s-core" % elastic4sVersion excludeLogging(), "com.sksamuel.elastic4s" %% "elastic4s-http" % elastic4sVersion excludeLogging(), http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala ---------------------------------------------------------------------- diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala deleted file mode 100644 index 2f2a40c..0000000 --- a/s2core/src/main/scala/org/apache/s2graph/core/model/AnnoyModelFetcher.scala +++ /dev/null @@ -1,128 +0,0 @@ -package org.apache.s2graph.core.model - -import annoy4s.Converters.KeyConverter -import annoy4s._ -import com.typesafe.config.Config -import org.apache.s2graph.core._ -import org.apache.s2graph.core.model.AnnoyModelFetcher.IndexFilePathKey -import org.apache.s2graph.core.types.VertexId - -import scala.concurrent.{ExecutionContext, Future} - -object AnnoyModelFetcher { - val IndexFilePathKey = "annoyIndexFilePath" - val DictFilePathKey = "annoyDictFilePath" - val DimensionKey = "annoyIndexDimension" - val IndexTypeKey = "annoyIndexType" - - // def loadDictFromLocal(file: File): Map[Int, String] = { - // val files = if (file.isDirectory) { - // file.listFiles() - // } else { - // Array(file) - // } - // - // files.flatMap { file => - // Source.fromFile(file).getLines().zipWithIndex.flatMap { case (line, _idx) => 
- // val tokens = line.stripMargin.split(",") - // try { - // val tpl = if (tokens.length < 2) { - // (tokens.head.toInt, tokens.head) - // } else { - // (tokens.head.toInt, tokens.tail.head) - // } - // Seq(tpl) - // } catch { - // case e: Exception => Nil - // } - // } - // }.toMap - // } - - def buildAnnoy4s[T](indexPath: String)(implicit converter: KeyConverter[T]): Annoy[T] = { - Annoy.load[T](indexPath) - } - - // def buildIndex(indexPath: String, - // dictPath: String, - // dimension: Int, - // indexType: IndexType): ANNIndexWithDict = { - // val dict = loadDictFromLocal(new File(dictPath)) - // val index = new ANNIndex(dimension, indexPath, indexType) - // - // ANNIndexWithDict(index, dict) - // } - // - // def buildIndex(config: Config): ANNIndexWithDict = { - // val indexPath = config.getString(IndexFilePathKey) - // val dictPath = config.getString(DictFilePathKey) - // - // val dimension = config.getInt(DimensionKey) - // val indexType = Try { config.getString(IndexTypeKey) }.toOption.map(IndexType.valueOf).getOrElse(IndexType.ANGULAR) - // - // buildIndex(indexPath, dictPath, dimension, indexType) - // } -} - -// -//case class ANNIndexWithDict(index: ANNIndex, dict: Map[Int, String]) { -// val dictRev = dict.map(kv => kv._2 -> kv._1) -//} - -class AnnoyModelFetcher(val graph: S2GraphLike) extends Fetcher { - val builder = graph.elementBuilder - - // var model: ANNIndexWithDict = _ - var model: Annoy[String] = _ - - override def init(config: Config)(implicit ec: ExecutionContext): Future[Fetcher] = { - Future { - model = AnnoyModelFetcher.buildAnnoy4s(config.getString(IndexFilePathKey)) - // AnnoyModelFetcher.buildIndex(config) - - this - } - } - - /** Fetch **/ - override def fetches(queryRequests: Seq[QueryRequest], - prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = { - val stepResultLs = queryRequests.map { queryRequest => - val vertex = queryRequest.vertex - val queryParam = 
queryRequest.queryParam - - val edgeWithScores = model.query(vertex.innerId.toIdString(), queryParam.limit).getOrElse(Nil).map { case (tgtId, score) => - val tgtVertexId = builder.newVertexId(queryParam.label.service, - queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), tgtId) - - val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction) - - EdgeWithScore(edge, score, queryParam.label) - } - - StepResult(edgeWithScores, Nil, Nil) - // - // val srcIndexOpt = model.dictRev.get(vertex.innerId.toIdString()) - // - // srcIndexOpt.map { srcIdx => - // val srcVector = model.index.getItemVector(srcIdx) - // val nns = model.index.getNearest(srcVector, queryParam.limit).asScala - // - // val edges = nns.map { tgtIdx => - // val tgtVertexId = builder.newVertexId(queryParam.label.service, - // queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), model.dict(tgtIdx)) - // - // graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction) - // } - // val edgeWithScores = edges.map(e => EdgeWithScore(e, 1.0, queryParam.label)) - // StepResult(edgeWithScores, Nil, Nil) - // }.getOrElse(StepResult.Empty) - } - - Future.successful(stepResultLs) - } - - override def close(): Unit = { - // do clean up - } -} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala ---------------------------------------------------------------------- diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala new file mode 100644 index 0000000..a4e2aae --- /dev/null +++ b/s2core/src/main/scala/org/apache/s2graph/core/model/annoy/AnnoyModelFetcher.scala @@ -0,0 +1,115 @@ +package org.apache.s2graph.core.model.annoy + +import annoy4s.Converters.KeyConverter +import 
annoy4s._
+import com.typesafe.config.Config
+import org.apache.s2graph.core._
+import org.apache.s2graph.core.types.VertexId
+
+import scala.concurrent.{ExecutionContext, Future}
+
+object AnnoyModelFetcher {
+  val IndexFilePathKey = "annoyIndexFilePath"
+  val DictFilePathKey = "annoyDictFilePath"
+  val DimensionKey = "annoyIndexDimension"
+  val IndexTypeKey = "annoyIndexType"
+
+  //  def loadDictFromLocal(file: File): Map[Int, String] = {
+  //    val files = if (file.isDirectory) {
+  //      file.listFiles()
+  //    } else {
+  //      Array(file)
+  //    }
+  //
+  //    files.flatMap { file =>
+  //      Source.fromFile(file).getLines().zipWithIndex.flatMap { case (line, _idx) =>
+  //        val tokens = line.stripMargin.split(",")
+  //        try {
+  //          val tpl = if (tokens.length < 2) {
+  //            (tokens.head.toInt, tokens.head)
+  //          } else {
+  //            (tokens.head.toInt, tokens.tail.head)
+  //          }
+  //          Seq(tpl)
+  //        } catch {
+  //          case e: Exception => Nil
+  //        }
+  //      }
+  //    }.toMap
+  //  }
+
+  def buildAnnoy4s[T](indexPath: String)(implicit converter: KeyConverter[T]): Annoy[T] = {
+    Annoy.load[T](indexPath)
+  }
+
+  //  def buildIndex(indexPath: String,
+  //                 dictPath: String,
+  //                 dimension: Int,
+  //                 indexType: IndexType): ANNIndexWithDict = {
+  //    val dict = loadDictFromLocal(new File(dictPath))
+  //    val index = new ANNIndex(dimension, indexPath, indexType)
+  //
+  //    ANNIndexWithDict(index, dict)
+  //  }
+  //
+  //  def buildIndex(config: Config): ANNIndexWithDict = {
+  //    val indexPath = config.getString(IndexFilePathKey)
+  //    val dictPath = config.getString(DictFilePathKey)
+  //
+  //    val dimension = config.getInt(DimensionKey)
+  //    val indexType = Try { config.getString(IndexTypeKey) }.toOption.map(IndexType.valueOf).getOrElse(IndexType.ANGULAR)
+  //
+  //    buildIndex(indexPath, dictPath, dimension, indexType)
+  //  }
+}
+
+//
+//case class ANNIndexWithDict(index: ANNIndex, dict: Map[Int, String]) {
+//  val dictRev = dict.map(kv => kv._2 -> kv._1)
+//}
+
+class AnnoyModelFetcher(val graph: S2GraphLike) extends Fetcher {
+  import AnnoyModelFetcher._
+
+  val builder = graph.elementBuilder
+
+  //  var model: ANNIndexWithDict = _
+  var model: Annoy[String] = _
+
+  override def init(config: Config)(implicit ec: ExecutionContext): Future[Fetcher] = {
+    Future {
+      model = AnnoyModelFetcher.buildAnnoy4s(config.getString(IndexFilePathKey))
+      //      AnnoyModelFetcher.buildIndex(config)
+
+      this
+    }
+  }
+
+  /** Fetch **/
+  override def fetches(queryRequests: Seq[QueryRequest],
+                       prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
+    val stepResultLs = queryRequests.map { queryRequest =>
+      val vertex = queryRequest.vertex
+      val queryParam = queryRequest.queryParam
+
+      val edgeWithScores = model.query(vertex.innerId.toIdString(), queryParam.limit).getOrElse(Nil).map { case (tgtId, score) =>
+        val tgtVertexId = builder.newVertexId(queryParam.label.service,
+          queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), tgtId)
+
+        val props: Map[String, Any] = if (queryParam.label.metaPropsInvMap.contains("score")) Map("score" -> score) else Map.empty
+        val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction, props = props)
+
+        EdgeWithScore(edge, score, queryParam.label)
+      }
+
+      StepResult(edgeWithScores, Nil, Nil)
+    }
+
+    Future.successful(stepResultLs)
+  }
+
+  override def close(): Unit = {
+    // init() may never have run (or may have failed), in which case there is no index to release.
+    if (model != null) model.close
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala
new file mode 100644
index 0000000..c3e36c7
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/CopyModel.scala
@@ -0,0 +1,141 @@
+package org.apache.s2graph.core.model.fasttext
+
+
+import java.io.{BufferedInputStream, EOFException, FileInputStream, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util
+
+import org.apache.s2graph.core.model.fasttext.fasttext.FastTextArgs
+import org.rocksdb._
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+  * Converts a fastText binary model file into the RocksDB layout that `FastText` opens read-only:
+  * the default column family holds the "args" header plus the label strings (keyed by a 4-byte
+  * big-endian label index), "vocab" maps each word's bytes to its 13-byte entry, and "i" / "o"
+  * hold the input / output matrix rows keyed by an 8-byte big-endian row index.
+  */
+object CopyModel {
+
+  def writeArgs(db: RocksDB, handle: ColumnFamilyHandle, args: FastTextArgs): Unit = {
+    val wo = new WriteOptions().setDisableWAL(true).setSync(false)
+    db.put(handle, wo, "args".getBytes("UTF-8"), args.serialize)
+    wo.close()
+    println("done ")
+  }
+
+  def writeVocab(is: InputStream, db: RocksDB,
+                 vocabHandle: ColumnFamilyHandle, labelHandle: ColumnFamilyHandle, args: FastTextArgs): Unit = {
+    val wo = new WriteOptions().setDisableWAL(true).setSync(false)
+    // value layout: wid (4 bytes, little-endian) followed by the 9 bytes that follow the word in the
+    // model file, which FastText.getEntry decodes as count (8 bytes) and type (1 byte).
+    val bb = ByteBuffer.allocate(13).order(ByteOrder.LITTLE_ENDIAN)
+    val wb = new ArrayBuffer[Byte]
+    for (wid <- 0 until args.size) {
+      bb.clear()
+      wb.clear()
+      // words are NUL-terminated; a truncated file must not spin forever on read() == -1
+      var b = is.read()
+      while (b != 0) {
+        if (b < 0) throw new EOFException(s"model file ended while reading vocab entry $wid")
+        wb += b.toByte
+        b = is.read()
+      }
+      bb.putInt(wid)
+      is.read(bb.array(), 4, 9)
+      db.put(vocabHandle, wo, wb.toArray, bb.array())
+
+      // type byte 1 marks a label: also store its string under its label index (read by FastText.loadLabels)
+      if (bb.get(12) == 1) {
+        val label = wid - args.nwords
+        db.put(labelHandle, ByteBuffer.allocate(4).putInt(label).array(), wb.toArray)
+      }
+
+      if ((wid + 1) % 1000 == 0)
+        print(f"\rprocessing ${100 * (wid + 1) / args.size.toFloat}%.2f%%")
+    }
+    println("\rdone ")
+    wo.close()
+  }
+
+  def writeVectors(is: InputStream, db: RocksDB, handle: ColumnFamilyHandle, args: FastTextArgs): Unit = {
+    // NOTE(review): the leading byte is presumably fastText's "quantized" flag; only 0 is handled.
+    require(is.read() == 0, "not implemented")
+    val wo = new WriteOptions().setDisableWAL(true).setSync(false)
+    val bb = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
+    val key = ByteBuffer.allocate(8)
+    val value = new Array[Byte](args.dim * 4)
+    is.read(bb.array())
+    val m = bb.getLong
+    val n = bb.getLong
+    require(n * 4 == value.length)
+    var i = 0L
+    while (i < m) {
+      key.clear()
+      key.putLong(i)
+      is.read(value)
+      db.put(handle, wo, key.array(), value)
+      if ((i + 1) % 1000 == 0)
+        print(f"\rprocessing ${100 * (i + 1) / m.toFloat}%.2f%%")
+      i += 1
+    }
+    println("\rdone ")
+    wo.close()
+  }
+
+  def printHelp(): Unit = {
+    println("usage: CopyModel <in> <out>")
+  }
+
+  /** Replaces any existing DB at `out` with a conversion of the fastText model file `in`. */
+  def copy(in: String, out: String): Unit = {
+    // Options wraps a native handle; release it once the old DB has been destroyed.
+    val destroyOptions = new Options
+    try RocksDB.destroyDB(out, destroyOptions) finally destroyOptions.close()
+
+    val dbOptions = new DBOptions()
+      .setCreateIfMissing(true)
+      .setCreateMissingColumnFamilies(true)
+      .setAllowMmapReads(false)
+      .setMaxOpenFiles(500000)
+      .setDbWriteBufferSize(134217728)
+      .setMaxBackgroundCompactions(20)
+
+    val descriptors = new java.util.LinkedList[ColumnFamilyDescriptor]()
+    descriptors.add(new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY))
+    descriptors.add(new ColumnFamilyDescriptor("vocab".getBytes()))
+    descriptors.add(new ColumnFamilyDescriptor("i".getBytes()))
+    descriptors.add(new ColumnFamilyDescriptor("o".getBytes()))
+    val handles = new util.LinkedList[ColumnFamilyHandle]()
+    val db = RocksDB.open(dbOptions, out, descriptors, handles)
+
+    val is = new BufferedInputStream(new FileInputStream(in))
+    try {
+      val fastTextArgs = FastTextArgs.fromInputStream(is)
+
+      require(fastTextArgs.magic == FastText.FASTTEXT_FILEFORMAT_MAGIC_INT32)
+      require(fastTextArgs.version == FastText.FASTTEXT_VERSION)
+
+      println("step 1: writing args")
+      writeArgs(db, handles.get(0), fastTextArgs)
+      println("step 2: writing vocab")
+      writeVocab(is, db, handles.get(1), handles.get(0), fastTextArgs)
+      println("step 3: writing input vectors")
+      writeVectors(is, db, handles.get(2), fastTextArgs)
+      println("step 4: writing output vectors")
+      writeVectors(is, db, handles.get(3), fastTextArgs)
+      println("step 5: compactRange")
+      db.compactRange()
+      println("done")
+    } finally {
+      // release the input file and every native RocksDB handle even if a step fails
+      is.close()
+      handles.asScala.foreach(_.close())
+      db.close()
+      dbOptions.close()
+    }
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala
new file mode 100644
index 0000000..b5d10a9
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastText.scala
@@ -0,0 +1,206 @@
+package org.apache.s2graph.core.model.fasttext
+
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util
+
+import org.apache.s2graph.core.model.fasttext.fasttext.FastTextArgs
+import org.rocksdb.{ColumnFamilyDescriptor, ColumnFamilyHandle, DBOptions, RocksDB}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+case class Line(labels: Array[Int], words: Array[Long])
+
+case class Entry(wid: Int, count: Long, tpe: Byte, subwords: Array[Long])
+
+object FastText {
+  val EOS = "</s>"
+  val BOW = "<"
+  val EOW = ">"
+
+  val FASTTEXT_VERSION = 12 // Version 1b
+  val FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314
+
+  val MODEL_CBOW = 1
+  val MODEL_SG = 2
+  val MODEL_SUP = 3
+
+  val LOSS_HS = 1
+  val LOSS_NS = 2
+  val LOSS_SOFTMAX = 3
+
+  val DBPathKey = "dbPath"
+
+  /** Splits on whitespace, dropping empty tokens (e.g. from leading blanks), and appends EOS. */
+  def tokenize(in: String): Array[String] = in.split("\\s+").filter(_.nonEmpty) :+ EOS
+
+  def getSubwords(word: String, minn: Int, maxn: Int): Array[String] = {
+    val l = math.max(minn, 1)
+    val u = math.min(maxn, word.length)
+    val r = l to u flatMap word.sliding
+    r.filterNot(s => s == BOW || s == EOW).toArray
+  }
+
+  /** 32-bit FNV-1a over the UTF-8 bytes of `str` (bytes sign-extended, like fastText's Dictionary::hash). */
+  def hash(str: String): Long = {
+    var h = 2166136261L.toInt
+    // explicit UTF-8: the platform default charset would hash non-ASCII tokens differently
+    for (b <- str.getBytes("UTF-8")) {
+      h = (h ^ b) * 16777619
+    }
+    h & 0xffffffffL
+  }
+
+}
+
+/**
+  * Read-only fastText model backed by a RocksDB directory written by `CopyModel`.
+  * Only supervised models trained with softmax loss are supported.
+  */
+class FastText(name: String) extends AutoCloseable {
+
+  import FastText._
+
+  private val dbOptions = new DBOptions()
+  private val descriptors = new java.util.LinkedList[ColumnFamilyDescriptor]()
+  descriptors.add(new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY))
+  descriptors.add(new ColumnFamilyDescriptor("vocab".getBytes()))
+  descriptors.add(new ColumnFamilyDescriptor("i".getBytes()))
+  descriptors.add(new ColumnFamilyDescriptor("o".getBytes()))
+  private val handles = new util.LinkedList[ColumnFamilyHandle]()
+  private val db = RocksDB.openReadOnly(dbOptions, name, descriptors, handles)
+
+  private val defaultHandle = handles.get(0)
+  private val vocabHandle = handles.get(1)
+  private val inputVectorHandle = handles.get(2)
+  private val outputVectorHandle = handles.get(3)
+
+  private val args = FastTextArgs.fromByteArray(db.get(defaultHandle, "args".getBytes("UTF-8")))
+
+  // Validate the header before loading anything whose size it dictates.
+  require(args.magic == FASTTEXT_FILEFORMAT_MAGIC_INT32)
+  require(args.version == FASTTEXT_VERSION)
+
+  // only sup/softmax supported
+  // others are the future work.
+  require(args.model == MODEL_SUP)
+  require(args.loss == LOSS_SOFTMAX)
+
+  private val wo = loadOutputVectors()
+  private val labels = loadLabels()
+
+  private def getVector(handle: ColumnFamilyHandle, key: Long): Array[Float] = {
+    val keyBytes = ByteBuffer.allocate(8).putLong(key).array()
+    val bb = ByteBuffer.wrap(db.get(handle, keyBytes)).order(ByteOrder.LITTLE_ENDIAN)
+    Array.fill(args.dim)(bb.getFloat)
+  }
+
+  private def loadOutputVectors(): Array[Array[Float]] =
+    Array.tabulate(args.nlabels)(key => getVector(outputVectorHandle, key.toLong))
+
+  private def loadLabels(): Array[String] = {
+    val result = new Array[String](args.nlabels)
+    val it = db.newIterator(defaultHandle)
+    try {
+      var i = 0
+      it.seekToFirst()
+      while (it.isValid) {
+        // label keys are 4-byte big-endian indices; the "args" key decodes to a large int and is skipped
+        val key = ByteBuffer.wrap(it.key()).getInt()
+        if (key < args.nlabels) {
+          require(i == key)
+          result(i) = new String(it.value(), "UTF-8")
+          i += 1
+        }
+        it.next()
+      }
+    } finally it.close() // RocksIterator holds a native handle
+    result
+  }
+
+  def getInputVector(key: Long): Array[Float] = getVector(inputVectorHandle, key)
+
+  def getOutputVector(key: Long): Array[Float] = getVector(outputVectorHandle, key)
+
+  def getEntry(word: String): Entry = {
+    val raw = db.get(vocabHandle, word.getBytes("UTF-8"))
+    if (raw == null) {
+      // OOV: typed as a word (0), not a label, so getLine still adds its character n-grams
+      Entry(-1, 0L, 0, Array.emptyLongArray)
+    } else {
+      val bb = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN)
+      val wid = bb.getInt
+      val count = bb.getLong
+      val tpe = bb.get
+      val subwords = if (word != EOS && tpe == 0) Array(wid.toLong) ++ computeSubwords(BOW + word + EOW) else Array(wid.toLong)
+      Entry(wid, count, tpe, subwords)
+    }
+  }
+
+  def computeSubwords(word: String): Array[Long] =
+    getSubwords(word, args.minn, args.maxn).map { w => args.nwords + (hash(w) % args.bucket.toLong) }
+
+  def getLine(in: String): Line = {
+    val tokens = tokenize(in)
+    val words = new ArrayBuffer[Long]()
+    val labels = new ArrayBuffer[Int]()
+    tokens foreach { token =>
+      val Entry(wid, count, tpe, subwords) = getEntry(token)
+      if (tpe == 0) {
+        // addSubwords
+        if (wid < 0) { // OOV
+          if (token != EOS) {
+            words ++= computeSubwords(BOW + token + EOW)
+          }
+        } else {
+          words ++= subwords
+        }
+      } else if (tpe == 1 && wid >= 0) {
+        labels += wid - args.nwords
+      }
+    }
+    Line(labels.toArray, words.toArray)
+  }
+
+  def computeHidden(input: Array[Long]): Array[Float] = {
+    val hidden = new Array[Float](args.dim)
+    for (row <- input.map(getInputVector)) {
+      var i = 0
+      while (i < hidden.length) {
+        hidden(i) += row(i) / input.length
+        i += 1
+      }
+    }
+    hidden
+  }
+
+  def predict(line: Line, k: Int = 1): Array[(String, Float)] = {
+    val hidden = computeHidden(line.words)
+    val output = wo.map { o =>
+      o.zip(hidden).map(a => a._1 * a._2).sum
+    }
+    val max = output.max
+    var i = 0
+    var z = 0.0f
+    while (i < output.length) {
+      output(i) = math.exp((output(i) - max).toDouble).toFloat
+      z += output(i)
+      i += 1
+    }
+    i = 0
+    while (i < output.length) {
+      output(i) /= z
+      i += 1
+    }
+    output.zipWithIndex.sortBy(-_._1).take(k).map { case (prob, i) =>
+      labels(i) -> prob
+    }
+  }
+
+  def close(): Unit = {
+    handles.asScala.foreach(_.close())
+    db.close()
+    dbOptions.close()
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala
new file mode 100644
index 0000000..20c25f0
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextArgs.scala
@@ -0,0 +1,131 @@
+package org.apache.s2graph.core.model.fasttext
+
+
+// NOTE(review): this second clause places FastTextArgs in package ...fasttext.fasttext,
+// which is where CopyModel and FastText import it from.
+package fasttext
+
+import java.io.{ByteArrayInputStream, EOFException, FileInputStream, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+
+/**
+  * Header fields of a fastText binary model, in file order.
+  * `serialize` writes the same 92-byte little-endian layout that `FastTextArgs.fromInputStream` reads.
+  */
+case class FastTextArgs(
+    magic: Int,
+    version: Int,
+    dim: Int,
+    ws: Int,
+    epoch: Int,
+    minCount: Int,
+    neg: Int,
+    wordNgrams: Int,
+    loss: Int,
+    model: Int,
+    bucket: Int,
+    minn: Int,
+    maxn: Int,
+    lrUpdateRate: Int,
+    t: Double,
+    size: Int,
+    nwords: Int,
+    nlabels: Int,
+    ntokens: Long,
+    pruneidxSize: Long) {
+
+  def serialize: Array[Byte] = {
+    val bb = ByteBuffer.allocate(92).order(ByteOrder.LITTLE_ENDIAN)
+    bb.putInt(magic)
+    bb.putInt(version)
+    bb.putInt(dim)
+    bb.putInt(ws)
+    bb.putInt(epoch)
+    bb.putInt(minCount)
+    bb.putInt(neg)
+    bb.putInt(wordNgrams)
+    bb.putInt(loss)
+    bb.putInt(model)
+    bb.putInt(bucket)
+    bb.putInt(minn)
+    bb.putInt(maxn)
+    bb.putInt(lrUpdateRate)
+    bb.putDouble(t)
+    bb.putInt(size)
+    bb.putInt(nwords)
+    bb.putInt(nlabels)
+    bb.putLong(ntokens)
+    bb.putLong(pruneidxSize)
+    bb.array()
+  }
+
+  override def toString: String = {
+    s"""magic: $magic
+       |version: $version
+       |dim: $dim
+       |ws : $ws
+       |epoch: $epoch
+       |minCount: $minCount
+       |neg: $neg
+       |wordNgrams: $wordNgrams
+       |loss: $loss
+       |model: $model
+       |bucket: $bucket
+       |minn: $minn
+       |maxn: $maxn
+       |lrUpdateRate: $lrUpdateRate
+       |t: $t
+       |size: $size
+       |nwords: $nwords
+       |nlabels: $nlabels
+       |ntokens: $ntokens
+       |pruneIdxSize: $pruneidxSize
+       |""".stripMargin
+  }
+
+}
+
+object FastTextArgs {
+
+  /** Reads exactly `n` bytes into `buffer` and returns them as a little-endian view; fails on early EOF. */
+  private def readFully(n: Int)(implicit inputStream: InputStream, buffer: Array[Byte]): ByteBuffer = {
+    var off = 0
+    while (off < n) {
+      val read = inputStream.read(buffer, off, n - off)
+      if (read < 0) throw new EOFException(s"expected $n bytes but the stream ended after $off")
+      off += read
+    }
+    ByteBuffer.wrap(buffer, 0, n).order(ByteOrder.LITTLE_ENDIAN)
+  }
+
+  private def getInt(implicit inputStream: InputStream, buffer: Array[Byte]): Int = readFully(4).getInt
+
+  private def getLong(implicit inputStream: InputStream, buffer: Array[Byte]): Long = readFully(8).getLong
+
+  private def getDouble(implicit inputStream: InputStream, buffer: Array[Byte]): Double = readFully(8).getDouble
+
+  def fromByteArray(ar: Array[Byte]): FastTextArgs =
+    fromInputStream(new ByteArrayInputStream(ar))
+
+  def fromInputStream(inputStream: InputStream): FastTextArgs = {
+    implicit val is: InputStream = inputStream
+    implicit val bytes: Array[Byte] = new Array[Byte](8)
+    // argument evaluation is left-to-right, so fields are read in declaration order
+    FastTextArgs(
+      getInt, getInt, getInt, getInt, getInt, getInt, getInt, getInt, getInt, getInt,
+      getInt, getInt, getInt, getInt, getDouble, getInt, getInt, getInt, getLong, getLong)
+  }
+
+  /** Debug helper: round-trips the header of the model file given as the first argument. */
+  def main(args: Array[String]): Unit = {
+    require(args.nonEmpty, "usage: FastTextArgs <fasttext-model.bin>")
+    val in = new FileInputStream(args(0))
+    val args0 = try FastTextArgs.fromInputStream(in) finally in.close()
+    val serialized = args0.serialize
+    val args1 = FastTextArgs.fromByteArray(serialized)
+
+    println(args0)
+    println(args1)
+  }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala
new file mode 100644
index 0000000..774d784
--- /dev/null
+++ b/s2core/src/main/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcher.scala
@@ -0,0 +1,58 @@
+package org.apache.s2graph.core.model.fasttext
+
+import com.typesafe.config.Config
+import org.apache.s2graph.core._
+import org.apache.s2graph.core.types.VertexId
+
+import scala.concurrent.{ExecutionContext, Future}
+
+/**
+  * [[Fetcher]] that answers a query by running a [[FastText]] supervised model over the source vertex id
+  * (treated as input text) and emitting one edge per top-`limit` predicted label.
+  *
+  * Config: `FastText.DBPathKey` must point at a RocksDB directory produced by `CopyModel`.
+  */
+class FastTextFetcher(val graph: S2GraphLike) extends Fetcher {
+  val builder = graph.elementBuilder
+  var fastText: FastText = _
+
+  override def init(config: Config)(implicit ec: ExecutionContext): Future[Fetcher] = {
+    Future {
+      val dbPath = config.getString(FastText.DBPathKey)
+
+      fastText = new FastText(dbPath)
+
+      this
+    }
+  }
+
+  override def fetches(queryRequests: Seq[QueryRequest],
+                       prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
+    val stepResultLs = queryRequests.map { queryRequest =>
+      val vertex = queryRequest.vertex
+      val queryParam = queryRequest.queryParam
+      val label = queryParam.label
+      // Per-request invariants, hoisted out of the per-prediction loop.
+      val tgtColumn = label.tgtColumnWithDir(queryParam.labelWithDir.dir)
+      val hasScoreProp = label.metaPropsInvMap.contains("score")
+      val line = fastText.getLine(vertex.innerId.toIdString())
+
+      // NOTE(review): predicted label strings are used verbatim as target vertex ids; if the model was
+      // trained with fastText's default "__label__" prefix, that prefix ends up in the id — confirm intent.
+      val edgeWithScores = fastText.predict(line, queryParam.limit).map { case (predicted, score) =>
+        val tgtVertexId = builder.newVertexId(label.service, tgtColumn, predicted)
+        val props: Map[String, Any] = if (hasScoreProp) Map("score" -> score) else Map.empty
+        val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction, props = props)
+
+        EdgeWithScore(edge, score, label)
+      }
+
+      StepResult(edgeWithScores, Nil, Nil)
+    }
+
+    Future.successful(stepResultLs)
+  }
+
+  /** Releases the RocksDB handles held by the model; safe to call before or without a successful `init`. */
+  override def close(): Unit = if (fastText != null) fastText.close()
+}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
----------------------------------------------------------------------
diff --git a/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala b/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
index 54e6763..ca1f3a7 100644
--- a/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
+++ b/s2core/src/test/scala/org/apache/s2graph/core/model/FetcherTest.scala
@@ -6,6 +6,7 @@ import com.typesafe.config.ConfigFactory
 import org.apache.commons.io.FileUtils
 import org.apache.s2graph.core.Integrate.IntegrateCommon
 import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
+import org.apache.s2graph.core.model.annoy.AnnoyModelFetcher
 import org.apache.s2graph.core.schema.Label
 import org.apache.s2graph.core.{Query, QueryParam}
 
@@ -98,7 +99,7 @@ class FetcherTest extends IntegrateCommon{
 | }]
 | },
 | "fetcher": {
- | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.AnnoyModelFetcher",
+ | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.annoy.AnnoyModelFetcher",
 | "${AnnoyModelFetcher.IndexFilePathKey}": "${localIndexFilePath}",
 | "${AnnoyModelFetcher.DictFilePathKey}": "${localDictFilePath}",
 | "${AnnoyModelFetcher.DimensionKey}": 10
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala
----------------------------------------------------------------------
diff --git a/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala b/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala
new file mode 100644
index 0000000..f91e0d5
--- /dev/null
+++ b/s2core/src/test/scala/org/apache/s2graph/core/model/fasttext/FastTextFetcherTest.scala
@@ -0,0 +1,64 @@
+package org.apache.s2graph.core.model.fasttext
+
+import com.typesafe.config.ConfigFactory
+import org.apache.s2graph.core.Integrate.IntegrateCommon
+import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
+import org.apache.s2graph.core.{Query, QueryParam, QueryRequest}
+import org.apache.s2graph.core.schema.Label
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{Await, ExecutionContext}
+import scala.concurrent.duration.Duration
+
+class FastTextFetcherTest extends IntegrateCommon {
+  import TestUtil._
+
+  test("FastTextFetcher init test.") {
+    // The model is a local RocksDB dir produced by CopyModel; cancel (not fail) when it is not provided.
+    val modelPath = sys.env.getOrElse("FASTTEXT_MODEL_PATH", cancel("FASTTEXT_MODEL_PATH is not set"))
+    val config = ConfigFactory.parseMap(Map(FastText.DBPathKey -> modelPath).asJava)
+    val fetcher = new FastTextFetcher(graph)
+    // Await.result (unlike Await.ready) rethrows if init failed, e.g. on a bad model path.
+    Await.result(fetcher.init(config)(ExecutionContext.Implicits.global), Duration("3 minutes"))
+
+    val service = management.createService("s2graph", "localhost", "s2graph_htable", -1, None).get
+    val serviceColumn =
+      management.createServiceColumn("s2graph", "keyword", "string", Seq(Prop("age", "0", "int", true)))
+
+    val labelName = "fasttext_test_label"
+
+    Label.findByName(labelName, useCache = false).foreach { label => Label.delete(label.id.get) }
+
+    val label = management.createLabel(
+      labelName,
+      serviceColumn,
+      serviceColumn,
+      true,
+      service.serviceName,
+      Seq.empty[Index].asJava,
+      Seq.empty[Prop].asJava,
+      "strong",
+      null,
+      -1,
+      "v3",
+      "gz",
+      ""
+    )
+    val vertex = graph.elementBuilder.toVertex(service.serviceName, serviceColumn.columnName, "안녕하세요")
+    val queryParam = QueryParam(labelName = labelName, limit = 5)
+
+    val query = Query.toQuery(srcVertices = Seq(vertex), queryParams = Seq(queryParam))
+    val queryRequests = Seq(
+      QueryRequest(query, 0, vertex, queryParam)
+    )
+    val future = fetcher.fetches(queryRequests, Map.empty)
+    val results = Await.result(future, Duration("10 seconds"))
+    results.foreach { stepResult =>
+      assert(stepResult.edgeWithScores.size <= queryParam.limit)
+      stepResult.edgeWithScores.foreach { es =>
+        println(es.edge.tgtVertex.innerIdVal)
+      }
+    }
+    fetcher.close()
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/54c56c36/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
----------------------------------------------------------------------
diff --git
a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala index a8479fe..4d2623e 100644 --- a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala +++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala @@ -8,7 +8,7 @@ import org.apache.commons.io.FileUtils import org.apache.s2graph.core.Integrate.IntegrateCommon import org.apache.s2graph.core.Management.JsonModel.{Index, Prop} import org.apache.s2graph.core.{Query, QueryParam} -import org.apache.s2graph.core.model.{ANNIndexWithDict, AnnoyModelFetcher, HDFSImporter, ModelManager} +import org.apache.s2graph.core.model.{ANNIndexWithDict, HDFSImporter, ModelManager} import org.apache.s2graph.core.schema.Label import org.apache.s2graph.s2jobs.task.TaskConf @@ -57,7 +57,7 @@ class ALSModelProcessTest extends IntegrateCommon with DataFrameSuiteBase { // | "${ModelManager.ImporterClassNameKey}": "org.apache.s2graph.core.model.IdentityImporter" // | }, // | "fetcher": { -// | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.AnnoyModelFetcher", +// | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.annoy.AnnoyModelFetcher", // | "${AnnoyModelFetcher.IndexFilePathKey}": "${remoteIndexFilePath}", // | "${AnnoyModelFetcher.DictFilePathKey}": "${remoteDictFilePath}", // | "${AnnoyModelFetcher.DimensionKey}": 10 @@ -107,7 +107,7 @@ class ALSModelProcessTest extends IntegrateCommon with DataFrameSuiteBase { | "${ModelManager.ImporterClassNameKey}": "org.apache.s2graph.core.model.IdentityImporter" | }, | "fetcher": { - | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.AnnoyModelFetcher", + | "${ModelManager.FetcherClassNameKey}": "org.apache.s2graph.core.model.annoy.AnnoyModelFetcher", | "${AnnoyModelFetcher.IndexFilePathKey}": "${indexPath}", | 
"${AnnoyModelFetcher.DictFilePathKey}": "${dictPath}", | "${AnnoyModelFetcher.DimensionKey}": 10
