add InceptionFetcher.
Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/b91054c5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/b91054c5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/b91054c5 Branch: refs/heads/master Commit: b91054c5b874f2cfc69a2f7d6822191dcc9c1ccf Parents: 5ee1906 Author: DO YUNG YOON <[email protected]> Authored: Sat May 12 10:07:54 2018 +0900 Committer: DO YUNG YOON <[email protected]> Committed: Sat May 12 10:07:54 2018 +0900 ---------------------------------------------------------------------- .../movielens/schema/edge.similar.movie.graphql | 1 + .../core/fetcher/tensorflow/LabelImage.java | 214 +++++++++++++++++++ .../org/apache/s2graph/core/Management.scala | 7 + .../fetcher/tensorflow/InceptionFetcher.scala | 85 ++++++++ .../s2graph/core/fetcher/BaseFetcherTest.scala | 77 +++++++ .../tensorflow/InceptionFetcherTest.scala | 76 +++++++ .../apache/s2graph/s2jobs/JobDescription.scala | 15 +- .../task/custom/process/ALSModelProcess.scala | 15 -- .../task/custom/sink/AnnoyIndexBuildSink.scala | 21 ++ .../custom/process/ALSModelProcessTest.scala | 2 +- 10 files changed, 495 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/example/movielens/schema/edge.similar.movie.graphql ---------------------------------------------------------------------- diff --git a/example/movielens/schema/edge.similar.movie.graphql b/example/movielens/schema/edge.similar.movie.graphql index f8ac33b..bb625a0 100644 --- a/example/movielens/schema/edge.similar.movie.graphql +++ b/example/movielens/schema/edge.similar.movie.graphql @@ -45,6 +45,7 @@ mutation{ name:"_PK" propNames:["score"] } + options: "{\n \"fetcher\": {\n \"class\": \"org.apache.s2graph.core.fetcher.annoy.AnnoyModelFetcher\",\n \"annoyIndexFilePath\": \"/tmp/annoy_result\"\n }\n}" ) { isSuccess message http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java ---------------------------------------------------------------------- diff --git a/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java b/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java new file mode 100644 index 0000000..1125cea --- /dev/null +++ b/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java @@ -0,0 +1,214 @@ +package org.apache.s2graph.core.fetcher.tensorflow; + +import org.tensorflow.*; +import org.tensorflow.types.UInt8; + +import java.io.IOException; +import java.io.PrintStream; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; + +/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */ +public class LabelImage { + private static void printUsage(PrintStream s) { + final String url = + "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; + s.println( + "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)"); + s.println("to label JPEG images."); + s.println("TensorFlow version: " + TensorFlow.version()); + s.println(); + s.println("Usage: label_image <model dir> <image file>"); + s.println(); + s.println("Where:"); + s.println("<model dir> is a directory containing the unzipped contents of the inception model"); + s.println(" (from " + url + ")"); + s.println("<image file> is the path to a JPEG image file"); + } + + public static void main(String[] args) { + if (args.length != 2) { + printUsage(System.err); + System.exit(1); + } + String modelDir = args[0]; + String imageFile = args[1]; + + byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb")); + List<String> labels = + readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt")); + byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile)); + + try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) { + float[] labelProbabilities = executeInceptionGraph(graphDef, image); + int bestLabelIdx = maxIndex(labelProbabilities); + System.out.println( + String.format("BEST MATCH: %s (%.2f%% likely)", + labels.get(bestLabelIdx), + labelProbabilities[bestLabelIdx] * 100f)); + } + } + + static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) { + try (Graph g = new Graph()) { + GraphBuilder b = new GraphBuilder(g); + // Some constants specific to the pre-trained model at: + // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip + // + // - The model was trained with images scaled to 224x224 pixels. + // - The colors, represented as R, G, B in 1-byte each were converted to + // float using (value - Mean)/Scale. + final int H = 224; + final int W = 224; + final float mean = 117f; + final float scale = 1f; + + // Since the graph is being constructed once per execution here, we can use a constant for the + // input image. If the graph were to be re-used for multiple input images, a placeholder would + // have been more appropriate. + final Output<String> input = b.constant("input", imageBytes); + final Output<Float> output = + b.div( + b.sub( + b.resizeBilinear( + b.expandDims( + b.cast(b.decodeJpeg(input, 3), Float.class), + b.constant("make_batch", 0)), + b.constant("size", new int[] {H, W})), + b.constant("mean", mean)), + b.constant("scale", scale)); + try (Session s = new Session(g)) { + return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class); + } + } + } + + static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) { + try (Graph g = new Graph()) { + g.importGraphDef(graphDef); + try (Session s = new Session(g); + Tensor<Float> result = + s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) { + final long[] rshape = result.shape(); + if (result.numDimensions() != 2 || rshape[0] != 1) { + throw new RuntimeException( + String.format( + "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", + Arrays.toString(rshape))); + } + int nlabels = (int) rshape[1]; + return result.copyTo(new float[1][nlabels])[0]; + } + } + } + + static int maxIndex(float[] probabilities) { + int best = 0; + for (int i = 1; i < probabilities.length; ++i) { + if (probabilities[i] > probabilities[best]) { + best = i; + } + } + return best; + } + + static byte[] readAllBytesOrExit(Path path) { + try { + return Files.readAllBytes(path); + } catch (IOException e) { + System.err.println("Failed to read [" + path + "]: " + e.getMessage()); + System.exit(1); + } + return null; + } + + static List<String> readAllLinesOrExit(Path path) { + try { + return Files.readAllLines(path, Charset.forName("UTF-8")); + } catch (IOException e) { + System.err.println("Failed to read [" + path + "]: " + e.getMessage()); + System.exit(0); + } + return null; + } + + // In the fullness of time, equivalents of the methods of this class should be auto-generated from + // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages + // like Python, C++ and Go. + static class GraphBuilder { + GraphBuilder(Graph g) { + this.g = g; + } + + Output<Float> div(Output<Float> x, Output<Float> y) { + return binaryOp("Div", x, y); + } + + <T> Output<T> sub(Output<T> x, Output<T> y) { + return binaryOp("Sub", x, y); + } + + <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) { + return binaryOp3("ResizeBilinear", images, size); + } + + <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) { + return binaryOp3("ExpandDims", input, dim); + } + + <T, U> Output<U> cast(Output<T> value, Class<U> type) { + DataType dtype = DataType.fromClass(type); + return g.opBuilder("Cast", "Cast") + .addInput(value) + .setAttr("DstT", dtype) + .build() + .<U>output(0); + } + + Output<UInt8> decodeJpeg(Output<String> contents, long channels) { + return g.opBuilder("DecodeJpeg", "DecodeJpeg") + .addInput(contents) + .setAttr("channels", channels) + .build() + .<UInt8>output(0); + } + + <T> Output<T> constant(String name, Object value, Class<T> type) { + try (Tensor<T> t = Tensor.<T>create(value, type)) { + return g.opBuilder("Const", name) + .setAttr("dtype", DataType.fromClass(type)) + .setAttr("value", t) + .build() + .<T>output(0); + } + } + Output<String> constant(String name, byte[] value) { + return this.constant(name, value, String.class); + } + + Output<Integer> constant(String name, int value) { + return this.constant(name, value, Integer.class); + } + + Output<Integer> constant(String name, int[] value) { + return this.constant(name, value, Integer.class); + } + + Output<Float> constant(String name, float value) { + return this.constant(name, value, Float.class); + } + + private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) { + return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0); + } + + private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) { + return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0); + } + private Graph g; + } +} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/main/scala/org/apache/s2graph/core/Management.scala ---------------------------------------------------------------------- diff --git a/s2core/src/main/scala/org/apache/s2graph/core/Management.scala b/s2core/src/main/scala/org/apache/s2graph/core/Management.scala index f64e058..c41e890 100644 --- a/s2core/src/main/scala/org/apache/s2graph/core/Management.scala +++ b/s2core/src/main/scala/org/apache/s2graph/core/Management.scala @@ -429,6 +429,10 @@ class Management(graph: S2GraphLike) { ColumnMeta.findOrInsert(serviceColumn.id.get, propName, dataType, defaultValue, storeInGlobalIndex = storeInGlobalIndex, useCache = false) } + + updateVertexMutator(serviceColumn, None) + updateVertexFetcher(serviceColumn, None) + serviceColumn } } @@ -505,6 +509,9 @@ class Management(graph: S2GraphLike) { )) storage.createTable(config, newLabel.hbaseTableName) + updateEdgeFetcher(newLabel, None) + updateEdgeFetcher(newLabel, None) + newLabel } } http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala ---------------------------------------------------------------------- diff --git a/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala b/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala new file mode 100644 index 0000000..b35dd4a --- /dev/null +++ b/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala @@ -0,0 +1,85 @@ +package org.apache.s2graph.core.fetcher.tensorflow + +import java.net.URL +import java.nio.file.Paths + +import com.typesafe.config.Config +import org.apache.commons.io.IOUtils +import org.apache.s2graph.core._ +import org.apache.s2graph.core.types.VertexId + +import scala.concurrent.{ExecutionContext, Future} + + +object InceptionFetcher { + val ModelPath = "modelPath" + + def getImageBytes(urlText: String): Array[Byte] = { + val url = new URL(urlText) + + IOUtils.toByteArray(url) + } + + def predict(graphDef: Array[Byte], + labels: Seq[String])(imageBytes: Array[Byte], topK: Int = 10): Seq[(String, Float)] = { + try { + val image = LabelImage.constructAndExecuteGraphToNormalizeImage(imageBytes) + try { + val labelProbabilities = LabelImage.executeInceptionGraph(graphDef, image) + val topKIndices = labelProbabilities.zipWithIndex.sortBy(_._1).reverse + .take(Math.min(labelProbabilities.length, topK)).map(_._2) + + val ls = topKIndices.map { idx => (labels(idx), labelProbabilities(idx)) } + + ls + } catch { + case e: Throwable => Nil + } finally if (image != null) image.close() + } + } +} + +class InceptionFetcher(graph: S2GraphLike) extends EdgeFetcher { + + import InceptionFetcher._ + + import scala.collection.JavaConverters._ + val builder = graph.elementBuilder + + var graphDef: Array[Byte] = _ + var labels: Seq[String] = _ + + override def init(config: Config)(implicit ec: ExecutionContext): Unit = { + val modelPath = config.getString(ModelPath) + graphDef = LabelImage.readAllBytesOrExit(Paths.get(modelPath, "tensorflow_inception_graph.pb")) + labels = LabelImage.readAllLinesOrExit(Paths.get(modelPath, "imagenet_comp_graph_label_strings.txt")).asScala + } + + override def close(): Unit = {} + + override def fetches(queryRequests: Seq[QueryRequest], + prevStepEdges: Map[VertexId, Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = { + val stepResultLs = queryRequests.map { queryRequest => + val vertex = queryRequest.vertex + val queryParam = queryRequest.queryParam + val urlText = vertex.innerId.toIdString() + + val edgeWithScores = predict(graphDef, labels)(getImageBytes(urlText), queryParam.limit).map { case (label, score) => + val tgtVertexId = builder.newVertexId(queryParam.label.service, + queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), label) + + val props: Map[String, Any] = if (queryParam.label.metaPropsInvMap.contains("score")) Map("score" -> score) else Map.empty + val edge = graph.toEdge(vertex.innerId.value, tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction, props = props) + + EdgeWithScore(edge, score, queryParam.label) + } + + StepResult(edgeWithScores, Nil, Nil) + } + + Future.successful(stepResultLs) + } + + override def fetchEdgesAll()(implicit ec: ExecutionContext): Future[Seq[S2EdgeLike]] = + Future.successful(Nil) +} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala ---------------------------------------------------------------------- diff --git a/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala new file mode 100644 index 0000000..a614d53 --- /dev/null +++ b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala @@ -0,0 +1,77 @@ +package org.apache.s2graph.core.fetcher + +import com.typesafe.config.{Config, ConfigFactory} +import org.apache.s2graph.core.Management.JsonModel.{Index, Prop} +import org.apache.s2graph.core.rest.RequestParser +import org.apache.s2graph.core._ +import org.apache.s2graph.core.schema.{Label, Service, ServiceColumn} +import org.scalatest._ + +import scala.concurrent.{Await, ExecutionContext} +import scala.concurrent.duration.Duration + +trait BaseFetcherTest extends FunSuite with Matchers with BeforeAndAfterAll { + var graph: S2Graph = _ + var parser: RequestParser = _ + var management: Management = _ + var config: Config = _ + + override def beforeAll = { + config = ConfigFactory.load() + graph = new S2Graph(config)(ExecutionContext.Implicits.global) + management = new Management(graph) + parser = new RequestParser(graph) + } + + override def afterAll(): Unit = { + graph.shutdown() + } + + def queryEdgeFetcher(service: Service, + serviceColumn: ServiceColumn, + label: Label, + srcVertices: Seq[String]): StepResult = { + + val vertices = srcVertices.map(graph.elementBuilder.toVertex(service.serviceName, serviceColumn.columnName, _)) + + val queryParam = QueryParam(labelName = label.label, limit = 10) + + val query = Query.toQuery(srcVertices = vertices, queryParams = Seq(queryParam)) + Await.result(graph.getEdges(query), Duration("60 seconds")) + } + + def initEdgeFetcher(serviceName: String, + columnName: String, + labelName: String, + options: Option[String]): (Service, ServiceColumn, Label) = { + val service = management.createService(serviceName, "localhost", "s2graph_htable", -1, None).get + val serviceColumn = + management.createServiceColumn(serviceName, columnName, "string", Nil) + + Label.findByName(labelName, useCache = false).foreach { label => Label.delete(label.id.get) } + + val label = management.createLabel( + labelName, + service.serviceName, + serviceColumn.columnName, + serviceColumn.columnType, + service.serviceName, + serviceColumn.columnName, + serviceColumn.columnType, + service.serviceName, + Seq.empty[Index], + Seq(Prop(name = "score", defaultValue = "0.0", dataType = "double")), + isDirected = true, + consistencyLevel = "strong", + hTableName = None, + hTableTTL = None, + schemaVersion = "v3", + compressionAlgorithm = "gz", + options = options + ).get + + management.updateEdgeFetcher(label, options) + + (service, serviceColumn, Label.findById(label.id.get, useCache = false)) + } +} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala ---------------------------------------------------------------------- diff --git a/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala new file mode 100644 index 0000000..31557a7 --- /dev/null +++ b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala @@ -0,0 +1,76 @@ +package org.apache.s2graph.core.fetcher.tensorflow + +import java.io.File + +import org.apache.commons.io.FileUtils +import org.apache.s2graph.core.fetcher.BaseFetcherTest +import play.api.libs.json.Json + +class InceptionFetcherTest extends BaseFetcherTest { + val runDownloadModel: Boolean = false + val runCleanup: Boolean = false + + def cleanup(downloadPath: String, dir: String) = { + synchronized { + FileUtils.deleteQuietly(new File(downloadPath)) + FileUtils.deleteDirectory(new File(dir)) + } + } + def downloadModel(dir: String) = { + import sys.process._ + synchronized { + FileUtils.forceMkdir(new File(dir)) + + val url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" + val wget = s"wget $url" + wget ! + val unzip = s"unzip inception5h.zip -d $dir" + unzip ! + } + } + + test("test get bytes for image url") { + val downloadPath = "inception5h.zip" + val modelPath = "inception" + try { + if (runDownloadModel) downloadModel(modelPath) + + val serviceName = "s2graph" + val columnName = "user" + val labelName = "image_net" + val options = + s""" + |{ + | "fetcher": { + | "className": "org.apache.s2graph.core.fetcher.tensorflow.InceptionV3Fetcher", + | "modelPath": "$modelPath" + | } + |} + """.stripMargin + val (service, column, label) = initEdgeFetcher(serviceName, columnName, labelName, Option(options)) + + val srcVertices = Seq( + "http://www.gstatic.com/webp/gallery/1.jpg", + "http://www.gstatic.com/webp/gallery/2.jpg", + "http://www.gstatic.com/webp/gallery/3.jpg" + ) + val stepResult = queryEdgeFetcher(service, column, label, srcVertices) + + stepResult.edgeWithScores.groupBy(_.edge.srcVertex).foreach { case (srcVertex, ls) => + val url = srcVertex.innerIdVal.toString + val scores = ls.map { es => + val edge = es.edge + val label = edge.tgtVertex.innerIdVal.toString + val score = edge.property[Double]("score").value() + + Json.obj("label" -> label, "score" -> score) + } + val jsArr = Json.toJson(scores) + val json = Json.obj("url" -> url, "scores" -> jsArr) + println(Json.prettyPrint(json)) + } + } finally { + if (runCleanup) cleanup(downloadPath, modelPath) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala ---------------------------------------------------------------------- diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala index dc32bc5..9a529aa 100644 --- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala +++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala @@ -21,7 +21,6 @@ package org.apache.s2graph.s2jobs import play.api.libs.json.{JsValue, Json} import org.apache.s2graph.s2jobs.task._ -import org.apache.s2graph.s2jobs.task.custom.process.AnnoyIndexBuildSink case class JobDescription( name:String, @@ -83,7 +82,19 @@ object JobDescription extends Logger { case "file" => new FileSink(jobName, conf) case "es" => new ESSink(jobName, conf) case "s2graph" => new S2GraphSink(jobName, conf) - case "annoy" => new AnnoyIndexBuildSink(jobName, conf) + case "custom" => + val customClassOpt = conf.options.get("class") + customClassOpt match { + case Some(customClass:String) => + logger.debug(s"custom class for sink init.. $customClass") + + Class.forName(customClass) + .getConstructor(classOf[String], classOf[TaskConf]) + .newInstance(jobName, conf) + .asInstanceOf[task.Sink] + + case None => throw new IllegalArgumentException(s"sink custom class name is not exist.. ${conf}") + } case _ => throw new IllegalArgumentException(s"unsupported sink type : ${conf.`type`}") } http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala ---------------------------------------------------------------------- diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala index dfbefbf..3c60481 100644 --- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala +++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala @@ -122,18 +122,3 @@ class ALSModelProcess(conf: TaskConf) extends org.apache.s2graph.s2jobs.task.Pro } override def mandatoryOptions: Set[String] = Set.empty } - -class AnnoyIndexBuildSink(queryName: String, conf: TaskConf) extends Sink(queryName, conf) { - override val FORMAT: String = "parquet" - - override def mandatoryOptions: Set[String] = Set("path", "itemFactors") - - override def write(inputDF: DataFrame): Unit = { - val df = repartition(preprocess(inputDF), inputDF.sparkSession.sparkContext.defaultParallelism) - - if (inputDF.isStreaming) throw new IllegalStateException("AnnoyIndexBuildSink can not be run as streaming.") - else { - ALSModelProcess.buildAnnoyIndex(conf, inputDF) - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala ---------------------------------------------------------------------- diff --git a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala new file mode 100644 index 0000000..595b95a --- /dev/null +++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala @@ -0,0 +1,21 @@ +package org.apache.s2graph.s2jobs.task.custom.sink + +import org.apache.s2graph.s2jobs.task.{Sink, TaskConf} +import org.apache.s2graph.s2jobs.task.custom.process.ALSModelProcess +import org.apache.spark.sql.DataFrame + + +class AnnoyIndexBuildSink(queryName: String, conf: TaskConf) extends Sink(queryName, conf) { + override val FORMAT: String = "parquet" + + override def mandatoryOptions: Set[String] = Set("path", "itemFactors") + + override def write(inputDF: DataFrame): Unit = { + val df = repartition(preprocess(inputDF), inputDF.sparkSession.sparkContext.defaultParallelism) + + if (inputDF.isStreaming) throw new IllegalStateException("AnnoyIndexBuildSink can not be run as streaming.") + else { + ALSModelProcess.buildAnnoyIndex(conf, inputDF) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala ---------------------------------------------------------------------- diff --git a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala index 3f12e8c..5f36f4b 100644 --- a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala +++ b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala @@ -2,7 +2,6 @@ package org.apache.s2graph.s2jobs.task.custom.process import java.io.File -import com.holdenkarau.spark.testing.DataFrameSuiteBase import org.apache.commons.io.FileUtils import org.apache.s2graph.core.Management.JsonModel.{Index, Prop} import org.apache.s2graph.core.fetcher.annoy.AnnoyModelFetcher @@ -10,6 +9,7 @@ import org.apache.s2graph.core.{Query, QueryParam, ResourceManager} import org.apache.s2graph.core.schema.Label import org.apache.s2graph.s2jobs.BaseSparkTest import org.apache.s2graph.s2jobs.task.TaskConf +import org.apache.s2graph.s2jobs.task.custom.sink.AnnoyIndexBuildSink import scala.concurrent.{Await, ExecutionContext} import scala.concurrent.duration.Duration
