Add InceptionFetcher.

Project: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/commit/b91054c5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/tree/b91054c5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-s2graph/diff/b91054c5

Branch: refs/heads/master
Commit: b91054c5b874f2cfc69a2f7d6822191dcc9c1ccf
Parents: 5ee1906
Author: DO YUNG YOON <[email protected]>
Authored: Sat May 12 10:07:54 2018 +0900
Committer: DO YUNG YOON <[email protected]>
Committed: Sat May 12 10:07:54 2018 +0900

----------------------------------------------------------------------
 .../movielens/schema/edge.similar.movie.graphql |   1 +
 .../core/fetcher/tensorflow/LabelImage.java     | 214 +++++++++++++++++++
 .../org/apache/s2graph/core/Management.scala    |   7 +
 .../fetcher/tensorflow/InceptionFetcher.scala   |  85 ++++++++
 .../s2graph/core/fetcher/BaseFetcherTest.scala  |  77 +++++++
 .../tensorflow/InceptionFetcherTest.scala       |  76 +++++++
 .../apache/s2graph/s2jobs/JobDescription.scala  |  15 +-
 .../task/custom/process/ALSModelProcess.scala   |  15 --
 .../task/custom/sink/AnnoyIndexBuildSink.scala  |  21 ++
 .../custom/process/ALSModelProcessTest.scala    |   2 +-
 10 files changed, 495 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/example/movielens/schema/edge.similar.movie.graphql
----------------------------------------------------------------------
diff --git a/example/movielens/schema/edge.similar.movie.graphql 
b/example/movielens/schema/edge.similar.movie.graphql
index f8ac33b..bb625a0 100644
--- a/example/movielens/schema/edge.similar.movie.graphql
+++ b/example/movielens/schema/edge.similar.movie.graphql
@@ -45,6 +45,7 @@ mutation{
         name:"_PK"
         propNames:["score"]
       }
+      options: "{\n    \"fetcher\": {\n        \"class\": 
\"org.apache.s2graph.core.fetcher.annoy.AnnoyModelFetcher\",\n        
\"annoyIndexFilePath\": \"/tmp/annoy_result\"\n    }\n}"
     ) {
       isSuccess
       message

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java
----------------------------------------------------------------------
diff --git 
a/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java
 
b/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java
new file mode 100644
index 0000000..1125cea
--- /dev/null
+++ 
b/s2core/src/main/java/org/apache/s2graph/core/fetcher/tensorflow/LabelImage.java
@@ -0,0 +1,214 @@
+package org.apache.s2graph.core.fetcher.tensorflow;
+
+import org.tensorflow.*;
+import org.tensorflow.types.UInt8;
+
+import java.io.IOException;
+import java.io.PrintStream;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+
+/** Sample use of the TensorFlow Java API to label images using a pre-trained 
model. */
+public class LabelImage {
+    private static void printUsage(PrintStream s) {
+        final String url =
+                
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
+        s.println(
+                "Java program that uses a pre-trained Inception model 
(http://arxiv.org/abs/1512.00567)");
+        s.println("to label JPEG images.");
+        s.println("TensorFlow version: " + TensorFlow.version());
+        s.println();
+        s.println("Usage: label_image <model dir> <image file>");
+        s.println();
+        s.println("Where:");
+        s.println("<model dir> is a directory containing the unzipped contents 
of the inception model");
+        s.println("            (from " + url + ")");
+        s.println("<image file> is the path to a JPEG image file");
+    }
+
+    public static void main(String[] args) {
+        if (args.length != 2) {
+            printUsage(System.err);
+            System.exit(1);
+        }
+        String modelDir = args[0];
+        String imageFile = args[1];
+
+        byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, 
"tensorflow_inception_graph.pb"));
+        List<String> labels =
+                readAllLinesOrExit(Paths.get(modelDir, 
"imagenet_comp_graph_label_strings.txt"));
+        byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
+
+        try (Tensor<Float> image = 
constructAndExecuteGraphToNormalizeImage(imageBytes)) {
+            float[] labelProbabilities = executeInceptionGraph(graphDef, 
image);
+            int bestLabelIdx = maxIndex(labelProbabilities);
+            System.out.println(
+                    String.format("BEST MATCH: %s (%.2f%% likely)",
+                            labels.get(bestLabelIdx),
+                            labelProbabilities[bestLabelIdx] * 100f));
+        }
+    }
+
+    static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] 
imageBytes) {
+        try (Graph g = new Graph()) {
+            GraphBuilder b = new GraphBuilder(g);
+            // Some constants specific to the pre-trained model at:
+            // 
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
+            //
+            // - The model was trained with images scaled to 224x224 pixels.
+            // - The colors, represented as R, G, B in 1-byte each were 
converted to
+            //   float using (value - Mean)/Scale.
+            final int H = 224;
+            final int W = 224;
+            final float mean = 117f;
+            final float scale = 1f;
+
+            // Since the graph is being constructed once per execution here, 
we can use a constant for the
+            // input image. If the graph were to be re-used for multiple input 
images, a placeholder would
+            // have been more appropriate.
+            final Output<String> input = b.constant("input", imageBytes);
+            final Output<Float> output =
+                    b.div(
+                            b.sub(
+                                    b.resizeBilinear(
+                                            b.expandDims(
+                                                    b.cast(b.decodeJpeg(input, 
3), Float.class),
+                                                    b.constant("make_batch", 
0)),
+                                            b.constant("size", new int[] {H, 
W})),
+                                    b.constant("mean", mean)),
+                            b.constant("scale", scale));
+            try (Session s = new Session(g)) {
+                return 
s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
+            }
+        }
+    }
+
+    static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) 
{
+        try (Graph g = new Graph()) {
+            g.importGraphDef(graphDef);
+            try (Session s = new Session(g);
+                 Tensor<Float> result =
+                         s.runner().feed("input", 
image).fetch("output").run().get(0).expect(Float.class)) {
+                final long[] rshape = result.shape();
+                if (result.numDimensions() != 2 || rshape[0] != 1) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Expected model to produce a [1 N] shaped 
tensor where N is the number of labels, instead it produced one with shape %s",
+                                    Arrays.toString(rshape)));
+                }
+                int nlabels = (int) rshape[1];
+                return result.copyTo(new float[1][nlabels])[0];
+            }
+        }
+    }
+
+    static int maxIndex(float[] probabilities) {
+        int best = 0;
+        for (int i = 1; i < probabilities.length; ++i) {
+            if (probabilities[i] > probabilities[best]) {
+                best = i;
+            }
+        }
+        return best;
+    }
+
+    static byte[] readAllBytesOrExit(Path path) {
+        try {
+            return Files.readAllBytes(path);
+        } catch (IOException e) {
+            System.err.println("Failed to read [" + path + "]: " + 
e.getMessage());
+            System.exit(1);
+        }
+        return null;
+    }
+
+    static List<String> readAllLinesOrExit(Path path) {
+        try {
+            return Files.readAllLines(path, Charset.forName("UTF-8"));
+        } catch (IOException e) {
+            System.err.println("Failed to read [" + path + "]: " + 
e.getMessage());
+            System.exit(0);
+        }
+        return null;
+    }
+
+    // In the fullness of time, equivalents of the methods of this class 
should be auto-generated from
+    // the OpDefs linked into libtensorflow_jni.so. That would match what is 
done in other languages
+    // like Python, C++ and Go.
+    static class GraphBuilder {
+        GraphBuilder(Graph g) {
+            this.g = g;
+        }
+
+        Output<Float> div(Output<Float> x, Output<Float> y) {
+            return binaryOp("Div", x, y);
+        }
+
+        <T> Output<T> sub(Output<T> x, Output<T> y) {
+            return binaryOp("Sub", x, y);
+        }
+
+        <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> 
size) {
+            return binaryOp3("ResizeBilinear", images, size);
+        }
+
+        <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
+            return binaryOp3("ExpandDims", input, dim);
+        }
+
+        <T, U> Output<U> cast(Output<T> value, Class<U> type) {
+            DataType dtype = DataType.fromClass(type);
+            return g.opBuilder("Cast", "Cast")
+                    .addInput(value)
+                    .setAttr("DstT", dtype)
+                    .build()
+                    .<U>output(0);
+        }
+
+        Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
+            return g.opBuilder("DecodeJpeg", "DecodeJpeg")
+                    .addInput(contents)
+                    .setAttr("channels", channels)
+                    .build()
+                    .<UInt8>output(0);
+        }
+
+        <T> Output<T> constant(String name, Object value, Class<T> type) {
+            try (Tensor<T> t = Tensor.<T>create(value, type)) {
+                return g.opBuilder("Const", name)
+                        .setAttr("dtype", DataType.fromClass(type))
+                        .setAttr("value", t)
+                        .build()
+                        .<T>output(0);
+            }
+        }
+        Output<String> constant(String name, byte[] value) {
+            return this.constant(name, value, String.class);
+        }
+
+        Output<Integer> constant(String name, int value) {
+            return this.constant(name, value, Integer.class);
+        }
+
+        Output<Integer> constant(String name, int[] value) {
+            return this.constant(name, value, Integer.class);
+        }
+
+        Output<Float> constant(String name, float value) {
+            return this.constant(name, value, Float.class);
+        }
+
+        private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> 
in2) {
+            return g.opBuilder(type, 
type).addInput(in1).addInput(in2).build().<T>output(0);
+        }
+
+        private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, 
Output<V> in2) {
+            return g.opBuilder(type, 
type).addInput(in1).addInput(in2).build().<T>output(0);
+        }
+        private Graph g;
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/main/scala/org/apache/s2graph/core/Management.scala
----------------------------------------------------------------------
diff --git a/s2core/src/main/scala/org/apache/s2graph/core/Management.scala 
b/s2core/src/main/scala/org/apache/s2graph/core/Management.scala
index f64e058..c41e890 100644
--- a/s2core/src/main/scala/org/apache/s2graph/core/Management.scala
+++ b/s2core/src/main/scala/org/apache/s2graph/core/Management.scala
@@ -429,6 +429,10 @@ class Management(graph: S2GraphLike) {
             ColumnMeta.findOrInsert(serviceColumn.id.get, propName, dataType, 
defaultValue,
               storeInGlobalIndex = storeInGlobalIndex, useCache = false)
           }
+
+          updateVertexMutator(serviceColumn, None)
+          updateVertexFetcher(serviceColumn, None)
+
           serviceColumn
       }
     }
@@ -505,6 +509,9 @@ class Management(graph: S2GraphLike) {
       ))
       storage.createTable(config, newLabel.hbaseTableName)
 
+      updateEdgeFetcher(newLabel, None)
+      updateEdgeFetcher(newLabel, None)
+
       newLabel
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala
----------------------------------------------------------------------
diff --git 
a/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala
 
b/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala
new file mode 100644
index 0000000..b35dd4a
--- /dev/null
+++ 
b/s2core/src/main/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcher.scala
@@ -0,0 +1,85 @@
+package org.apache.s2graph.core.fetcher.tensorflow
+
+import java.net.URL
+import java.nio.file.Paths
+
+import com.typesafe.config.Config
+import org.apache.commons.io.IOUtils
+import org.apache.s2graph.core._
+import org.apache.s2graph.core.types.VertexId
+
+import scala.concurrent.{ExecutionContext, Future}
+
+
+object InceptionFetcher {
+  val ModelPath = "modelPath"
+
+  def getImageBytes(urlText: String): Array[Byte] = {
+    val url = new URL(urlText)
+
+    IOUtils.toByteArray(url)
+  }
+
+  def predict(graphDef: Array[Byte],
+              labels: Seq[String])(imageBytes: Array[Byte], topK: Int = 10): 
Seq[(String, Float)] = {
+    try {
+      val image = 
LabelImage.constructAndExecuteGraphToNormalizeImage(imageBytes)
+      try {
+        val labelProbabilities = LabelImage.executeInceptionGraph(graphDef, 
image)
+        val topKIndices = labelProbabilities.zipWithIndex.sortBy(_._1).reverse
+          .take(Math.min(labelProbabilities.length, topK)).map(_._2)
+
+        val ls = topKIndices.map { idx => (labels(idx), 
labelProbabilities(idx)) }
+
+        ls
+      } catch {
+        case e: Throwable => Nil
+      } finally if (image != null) image.close()
+    }
+  }
+}
+
+class InceptionFetcher(graph: S2GraphLike) extends EdgeFetcher {
+
+  import InceptionFetcher._
+
+  import scala.collection.JavaConverters._
+  val builder = graph.elementBuilder
+
+  var graphDef: Array[Byte] = _
+  var labels: Seq[String] = _
+
+  override def init(config: Config)(implicit ec: ExecutionContext): Unit = {
+    val modelPath = config.getString(ModelPath)
+    graphDef = LabelImage.readAllBytesOrExit(Paths.get(modelPath, 
"tensorflow_inception_graph.pb"))
+    labels = LabelImage.readAllLinesOrExit(Paths.get(modelPath, 
"imagenet_comp_graph_label_strings.txt")).asScala
+  }
+
+  override def close(): Unit = {}
+
+  override def fetches(queryRequests: Seq[QueryRequest],
+                       prevStepEdges: Map[VertexId, 
Seq[EdgeWithScore]])(implicit ec: ExecutionContext): Future[Seq[StepResult]] = {
+    val stepResultLs = queryRequests.map { queryRequest =>
+      val vertex = queryRequest.vertex
+      val queryParam = queryRequest.queryParam
+      val urlText = vertex.innerId.toIdString()
+
+      val edgeWithScores = predict(graphDef, labels)(getImageBytes(urlText), 
queryParam.limit).map { case (label, score) =>
+        val tgtVertexId = builder.newVertexId(queryParam.label.service,
+          queryParam.label.tgtColumnWithDir(queryParam.labelWithDir.dir), 
label)
+
+        val props: Map[String, Any] = if 
(queryParam.label.metaPropsInvMap.contains("score")) Map("score" -> score) else 
Map.empty
+        val edge = graph.toEdge(vertex.innerId.value, 
tgtVertexId.innerId.value, queryParam.labelName, queryParam.direction, props = 
props)
+
+        EdgeWithScore(edge, score, queryParam.label)
+      }
+
+      StepResult(edgeWithScores, Nil, Nil)
+    }
+
+    Future.successful(stepResultLs)
+  }
+
+  override def fetchEdgesAll()(implicit ec: ExecutionContext): 
Future[Seq[S2EdgeLike]] =
+    Future.successful(Nil)
+}

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala
----------------------------------------------------------------------
diff --git 
a/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala 
b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala
new file mode 100644
index 0000000..a614d53
--- /dev/null
+++ 
b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/BaseFetcherTest.scala
@@ -0,0 +1,77 @@
+package org.apache.s2graph.core.fetcher
+
+import com.typesafe.config.{Config, ConfigFactory}
+import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
+import org.apache.s2graph.core.rest.RequestParser
+import org.apache.s2graph.core._
+import org.apache.s2graph.core.schema.{Label, Service, ServiceColumn}
+import org.scalatest._
+
+import scala.concurrent.{Await, ExecutionContext}
+import scala.concurrent.duration.Duration
+
+trait BaseFetcherTest extends FunSuite with Matchers with BeforeAndAfterAll {
+  var graph: S2Graph = _
+  var parser: RequestParser = _
+  var management: Management = _
+  var config: Config = _
+
+  override def beforeAll = {
+    config = ConfigFactory.load()
+    graph = new S2Graph(config)(ExecutionContext.Implicits.global)
+    management = new Management(graph)
+    parser = new RequestParser(graph)
+  }
+
+  override def afterAll(): Unit = {
+    graph.shutdown()
+  }
+
+  def queryEdgeFetcher(service: Service,
+                       serviceColumn: ServiceColumn,
+                       label: Label,
+                       srcVertices: Seq[String]): StepResult = {
+
+    val vertices = 
srcVertices.map(graph.elementBuilder.toVertex(service.serviceName, 
serviceColumn.columnName, _))
+
+    val queryParam = QueryParam(labelName = label.label, limit = 10)
+
+    val query = Query.toQuery(srcVertices = vertices, queryParams = 
Seq(queryParam))
+    Await.result(graph.getEdges(query), Duration("60 seconds"))
+  }
+
+  def initEdgeFetcher(serviceName: String,
+                      columnName: String,
+                      labelName: String,
+                      options: Option[String]): (Service, ServiceColumn, 
Label) = {
+    val service = management.createService(serviceName, "localhost", 
"s2graph_htable", -1, None).get
+    val serviceColumn =
+      management.createServiceColumn(serviceName, columnName, "string", Nil)
+
+    Label.findByName(labelName, useCache = false).foreach { label => 
Label.delete(label.id.get) }
+
+    val label = management.createLabel(
+      labelName,
+      service.serviceName,
+      serviceColumn.columnName,
+      serviceColumn.columnType,
+      service.serviceName,
+      serviceColumn.columnName,
+      serviceColumn.columnType,
+      service.serviceName,
+      Seq.empty[Index],
+      Seq(Prop(name = "score", defaultValue = "0.0", dataType = "double")),
+      isDirected = true,
+      consistencyLevel =  "strong",
+      hTableName = None,
+      hTableTTL = None,
+      schemaVersion = "v3",
+      compressionAlgorithm =  "gz",
+      options = options
+    ).get
+
+    management.updateEdgeFetcher(label, options)
+
+    (service, serviceColumn, Label.findById(label.id.get, useCache = false))
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala
----------------------------------------------------------------------
diff --git 
a/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala
 
b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala
new file mode 100644
index 0000000..31557a7
--- /dev/null
+++ 
b/s2core/src/test/scala/org/apache/s2graph/core/fetcher/tensorflow/InceptionFetcherTest.scala
@@ -0,0 +1,76 @@
+package org.apache.s2graph.core.fetcher.tensorflow
+
+import java.io.File
+
+import org.apache.commons.io.FileUtils
+import org.apache.s2graph.core.fetcher.BaseFetcherTest
+import play.api.libs.json.Json
+
+class InceptionFetcherTest extends BaseFetcherTest {
+  val runDownloadModel: Boolean = false
+  val runCleanup: Boolean = false
+
+  def cleanup(downloadPath: String, dir: String) = {
+    synchronized {
+      FileUtils.deleteQuietly(new File(downloadPath))
+      FileUtils.deleteDirectory(new File(dir))
+    }
+  }
+  def downloadModel(dir: String) = {
+    import sys.process._
+    synchronized {
+      FileUtils.forceMkdir(new File(dir))
+
+      val url = 
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"
+      val wget = s"wget $url"
+      wget !
+      val unzip = s"unzip inception5h.zip -d $dir"
+      unzip !
+    }
+  }
+
+  test("test get bytes for image url") {
+    val downloadPath = "inception5h.zip"
+    val modelPath = "inception"
+    try {
+      if (runDownloadModel) downloadModel(modelPath)
+
+      val serviceName = "s2graph"
+      val columnName = "user"
+      val labelName = "image_net"
+      val options =
+        s"""
+           |{
+           |  "fetcher": {
+           |    "className": 
"org.apache.s2graph.core.fetcher.tensorflow.InceptionV3Fetcher",
+           |    "modelPath": "$modelPath"
+           |  }
+           |}
+       """.stripMargin
+      val (service, column, label) = initEdgeFetcher(serviceName, columnName, 
labelName, Option(options))
+
+      val srcVertices = Seq(
+        "http://www.gstatic.com/webp/gallery/1.jpg",
+        "http://www.gstatic.com/webp/gallery/2.jpg",
+        "http://www.gstatic.com/webp/gallery/3.jpg"
+      )
+      val stepResult = queryEdgeFetcher(service, column, label, srcVertices)
+
+      stepResult.edgeWithScores.groupBy(_.edge.srcVertex).foreach { case 
(srcVertex, ls) =>
+        val url = srcVertex.innerIdVal.toString
+        val scores = ls.map { es =>
+          val edge = es.edge
+          val label = edge.tgtVertex.innerIdVal.toString
+          val score = edge.property[Double]("score").value()
+
+          Json.obj("label" -> label, "score" -> score)
+        }
+        val jsArr = Json.toJson(scores)
+        val json = Json.obj("url" -> url, "scores" -> jsArr)
+        println(Json.prettyPrint(json))
+      }
+    } finally {
+      if (runCleanup) cleanup(downloadPath, modelPath)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala
----------------------------------------------------------------------
diff --git 
a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala 
b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala
index dc32bc5..9a529aa 100644
--- a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala
+++ b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/JobDescription.scala
@@ -21,7 +21,6 @@ package org.apache.s2graph.s2jobs
 
 import play.api.libs.json.{JsValue, Json}
 import org.apache.s2graph.s2jobs.task._
-import org.apache.s2graph.s2jobs.task.custom.process.AnnoyIndexBuildSink
 
 case class JobDescription(
                            name:String,
@@ -83,7 +82,19 @@ object JobDescription extends Logger {
       case "file" => new FileSink(jobName, conf)
       case "es" => new ESSink(jobName, conf)
       case "s2graph" => new S2GraphSink(jobName, conf)
-      case "annoy" => new AnnoyIndexBuildSink(jobName, conf)
+      case "custom" =>
+        val customClassOpt = conf.options.get("class")
+        customClassOpt match {
+          case Some(customClass:String) =>
+            logger.debug(s"custom class for sink init.. $customClass")
+
+            Class.forName(customClass)
+              .getConstructor(classOf[String], classOf[TaskConf])
+              .newInstance(jobName, conf)
+              .asInstanceOf[task.Sink]
+
+          case None => throw new IllegalArgumentException(s"sink custom class 
name is not exist.. ${conf}")
+        }
       case _ => throw new IllegalArgumentException(s"unsupported sink type : 
${conf.`type`}")
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
----------------------------------------------------------------------
diff --git 
a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
 
b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
index dfbefbf..3c60481 100644
--- 
a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
+++ 
b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcess.scala
@@ -122,18 +122,3 @@ class ALSModelProcess(conf: TaskConf) extends 
org.apache.s2graph.s2jobs.task.Pro
   }
   override def mandatoryOptions: Set[String] = Set.empty
 }
-
-class AnnoyIndexBuildSink(queryName: String, conf: TaskConf) extends 
Sink(queryName, conf) {
-  override val FORMAT: String = "parquet"
-
-  override def mandatoryOptions: Set[String] = Set("path", "itemFactors")
-
-  override def write(inputDF: DataFrame): Unit = {
-    val df = repartition(preprocess(inputDF), 
inputDF.sparkSession.sparkContext.defaultParallelism)
-
-    if (inputDF.isStreaming) throw new 
IllegalStateException("AnnoyIndexBuildSink can not be run as streaming.")
-    else {
-      ALSModelProcess.buildAnnoyIndex(conf, inputDF)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala
----------------------------------------------------------------------
diff --git 
a/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala
 
b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala
new file mode 100644
index 0000000..595b95a
--- /dev/null
+++ 
b/s2jobs/src/main/scala/org/apache/s2graph/s2jobs/task/custom/sink/AnnoyIndexBuildSink.scala
@@ -0,0 +1,21 @@
+package org.apache.s2graph.s2jobs.task.custom.sink
+
+import org.apache.s2graph.s2jobs.task.{Sink, TaskConf}
+import org.apache.s2graph.s2jobs.task.custom.process.ALSModelProcess
+import org.apache.spark.sql.DataFrame
+
+
+class AnnoyIndexBuildSink(queryName: String, conf: TaskConf) extends 
Sink(queryName, conf) {
+  override val FORMAT: String = "parquet"
+
+  override def mandatoryOptions: Set[String] = Set("path", "itemFactors")
+
+  override def write(inputDF: DataFrame): Unit = {
+    val df = repartition(preprocess(inputDF), 
inputDF.sparkSession.sparkContext.defaultParallelism)
+
+    if (inputDF.isStreaming) throw new 
IllegalStateException("AnnoyIndexBuildSink can not be run as streaming.")
+    else {
+      ALSModelProcess.buildAnnoyIndex(conf, inputDF)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-s2graph/blob/b91054c5/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
----------------------------------------------------------------------
diff --git 
a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
 
b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
index 3f12e8c..5f36f4b 100644
--- 
a/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
+++ 
b/s2jobs/src/test/scala/org/apache/s2graph/s2jobs/task/custom/process/ALSModelProcessTest.scala
@@ -2,7 +2,6 @@ package org.apache.s2graph.s2jobs.task.custom.process
 
 import java.io.File
 
-import com.holdenkarau.spark.testing.DataFrameSuiteBase
 import org.apache.commons.io.FileUtils
 import org.apache.s2graph.core.Management.JsonModel.{Index, Prop}
 import org.apache.s2graph.core.fetcher.annoy.AnnoyModelFetcher
@@ -10,6 +9,7 @@ import org.apache.s2graph.core.{Query, QueryParam, 
ResourceManager}
 import org.apache.s2graph.core.schema.Label
 import org.apache.s2graph.s2jobs.BaseSparkTest
 import org.apache.s2graph.s2jobs.task.TaskConf
+import org.apache.s2graph.s2jobs.task.custom.sink.AnnoyIndexBuildSink
 
 import scala.concurrent.{Await, ExecutionContext}
 import scala.concurrent.duration.Duration

Reply via email to