This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-qa-example
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 665ed7d3fec1fb1560b1985dd6ab2c092cf0d87c
Author: gigasquid <[email protected]>
AuthorDate: Fri Apr 12 16:34:08 2019 -0400

    wip
---
 .../clojure-package/examples/bert-qa/.gitignore    |  12 +++
 contrib/clojure-package/examples/bert-qa/README.md |  22 +++++
 .../examples/bert-qa/get_bert_data.sh              |  29 ++++++
 .../clojure-package/examples/bert-qa/project.clj   |  10 ++
 .../examples/bert-qa/src/bert_qa/core.clj          | 107 +++++++++++++++++++++
 .../examples/bert-qa/src/java}/BertDataParser.java |  10 +-
 .../examples/bert-qa/src/java}/BertQA.java         |   7 +-
 .../examples/bert-qa/test/bert_qa/core_test.clj    |   7 ++
 .../javaapi/infer/bert/BertDataParser.java         |  10 +-
 .../mxnetexamples/javaapi/infer/bert/BertQA.java   |   1 +
 10 files changed, 202 insertions(+), 13 deletions(-)

diff --git a/contrib/clojure-package/examples/bert-qa/.gitignore b/contrib/clojure-package/examples/bert-qa/.gitignore
new file mode 100644
index 0000000..d18f225
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/.gitignore
@@ -0,0 +1,12 @@
+/target
+/classes
+/checkouts
+profiles.clj
+pom.xml
+pom.xml.asc
+*.jar
+*.class
+/.lein-*
+/.nrepl-port
+.hgignore
+.hg/
diff --git a/contrib/clojure-package/examples/bert-qa/README.md b/contrib/clojure-package/examples/bert-qa/README.md
new file mode 100644
index 0000000..fc21bdd
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/README.md
@@ -0,0 +1,22 @@
+# bert-qa
+
+Example of BERT question-answering inference with the Clojure MXNet package, adapted from the Java BertQA example.
+
+## Usage
+
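+Run `./get_bert_data.sh` to download the pretrained model and vocabulary into `model/`, then start `lein repl` (it opens in `bert-qa.core`) and evaluate `(infer)`. This is a work in progress: `infer` currently builds and returns the model's input batch rather than an answer.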
+
+## License
+
+Copyright © 2019 FIXME
+
+This program and the accompanying materials are made available under the
+terms of the Eclipse Public License 2.0 which is available at
+http://www.eclipse.org/legal/epl-2.0.
+
+This Source Code may also be made available under the following Secondary
+Licenses when the conditions for such availability set forth in the Eclipse
+Public License, v. 2.0 are satisfied: GNU General Public License as published by
+the Free Software Foundation, either version 2 of the License, or (at your
+option) any later version, with the GNU Classpath Exception which is available
+at https://www.gnu.org/software/classpath/license.html.
diff --git a/contrib/clojure-package/examples/bert-qa/get_bert_data.sh b/contrib/clojure-package/examples/bert-qa/get_bert_data.sh
new file mode 100755
index 0000000..603194a
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/get_bert_data.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+data_path=model
+
+if [ ! -d "$data_path" ]; then
+  mkdir -p "$data_path"
+  curl 
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json
 -o $data_path/vocab.json
+  curl 
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params
 -o $data_path/static_bert_qa-0002.params
+  curl 
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json
 -o $data_path/static_bert_qa-symbol.json
+fi
diff --git a/contrib/clojure-package/examples/bert-qa/project.clj b/contrib/clojure-package/examples/bert-qa/project.clj
new file mode 100644
index 0000000..328d040
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/project.clj
@@ -0,0 +1,10 @@
+(defproject bert-qa "0.1.0-SNAPSHOT"
+  :description "BERT QA Example"
+  :plugins [[lein-cljfmt "0.5.7"]]
+  :dependencies [[org.clojure/clojure "1.9.0"]
+                 [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]
+                 ;; Gson is used by src/java/BertDataParser.java
+                 [com.google.code.gson/gson "2.8.5"]
+                 [cheshire "5.8.1"]]
+  :pedantic? :skip
+  ;; compiles the bert.BertQA / BertDataParser classes imported by bert-qa.core
+  :java-source-paths ["src/java"]
+  :repl-options {:init-ns bert-qa.core})
diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
new file mode 100644
index 0000000..941e6aa
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj
@@ -0,0 +1,107 @@
+(ns bert-qa.core
+  (:require [clojure.string :as string]
+            [clojure.reflect :as r]
+            [cheshire.core :as json]
+            [clojure.java.io :as io]
+            [clojure.set :as set]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.context :as context]
+            [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray])
+  (:import (bert BertDataParser)
+           (bert BertQA)))
+
+;; prefix of the model files (symbol JSON and params) fetched by get_bert_data.sh
+(def model-path-prefix "model/static_bert_qa")
+;; epoch number of the model
+(def epoch 2)
+;; the vocabulary used in the model
+(def model-vocab "model/vocab.json")
+;; the input question
+(def input-q "When did BBC Japan start broadcasting?")
+;; the input passage that contains the answer
+(def input-a (str "BBC Japan was a general entertainment Channel.\n"
+                  " Which operated between December 2004 and April 2006.\n"
+                  "It ceased operations after its Japanese distributor folded."))
+;; the maximum length of the sequence
+(def seq-length 384)
+
+;;; data helpers
+
+(defn break-out-punctuation
+  "Splits string `s` on every occurrence of the single punctuation character
+  `str-match`, keeping each occurrence as a token of its own, e.g.
+  (break-out-punctuation \"hello?\" \"?\") => (\"hello\" \"?\")."
+  [s str-match]
+  ;; A limit of -1 keeps trailing empty pieces, so a trailing mark still gets
+  ;; interposed; the empty strings produced around marks are then dropped.
+  (->> (string/split s (re-pattern (str "\\" str-match)) -1)
+       (interpose str-match)
+       (remove empty?)))
+
+(defn break-out-punctuations
+  "Splits string `s` into runs of non-punctuation characters and individual
+  punctuation marks (. , ? !), so that every mark becomes a separate token,
+  e.g. \"a,b.\" => (\"a\" \",\" \"b\" \".\"). Returns [s] for the empty string."
+  [s]
+  (or (re-seq #"[^.,?!]+|[.,?!]" s)
+      [s]))
+
+(defn tokenizer
+  "Splits `s` on whitespace, breaks punctuation out into separate tokens and
+  returns the tokens as a vector. Empty strings (e.g. the piece produced by
+  leading whitespace) are dropped so they do not become spurious tokens."
+  [s]
+  (->> (string/split s #"\s+")
+       (mapcat break-out-punctuations)
+       (remove empty?)
+       (into [])))
+
+(defn pad
+  "Returns `tokens` grown to `num` items by adding copies of `pad-item` with
+  `into` (i.e. appended at the end when `tokens` is a vector). Inputs that
+  already hold `num` or more items are returned as-is."
+  [tokens pad-item num]
+  (let [missing (- num (count tokens))]
+    (if (pos? missing)
+      (into tokens (repeat missing pad-item))
+      tokens)))
+
+(defn get-vocab
+  "Reads the vocabulary JSON file at `vocab-path` (default: `model-vocab`) and
+  returns a map with :idx2token (the \"idx_to_token\" entry) and :token2idx
+  (the \"token_to_idx\" entry)."
+  ([] (get-vocab model-vocab))
+  ([vocab-path]
+   ;; NOTE(review): closing the reader inside with-open relies on cheshire
+   ;; parsing a top-level JSON object eagerly (only top-level arrays are
+   ;; parsed lazily) -- confirm if the vocab file format changes.
+   (with-open [rdr (io/reader vocab-path)]
+     (let [vocab (json/parse-stream rdr)]
+       {:idx2token (get vocab "idx_to_token")
+        :token2idx (get vocab "token_to_idx")}))))
+
+(defn tokens->idxs
+  "Maps each token to its vocabulary index via `token2idx`; tokens missing
+  from the vocabulary map to the index of \"[UNK]\" (not the string itself,
+  which could not be fed into a numeric NDArray)."
+  [token2idx tokens]
+  (let [unk-idx (get token2idx "[UNK]")]
+    (mapv #(get token2idx % unk-idx) tokens)))
+
+(defn idxs->tokens
+  "Returns a vector with the entry `idx2token` holds for each index in `idxs`
+  (nil where an index has no entry)."
+  [idx2token idxs]
+  (into [] (map (fn [idx] (get idx2token idx))) idxs))
+
+
+(defn infer
+  "Work in progress: builds the BERT QA input batch for `input-q`/`input-a`.
+  Returns a vector of three NDArrays on the default context: token indexes
+  [1 seq-length], token types [1 seq-length] and the valid length [1]."
+  []
+  (let [ctx (context/default-context)
+        ;; pre-processing: tokenize the lower-cased question and passage
+        q-tokens (tokenizer (string/lower-case input-q))
+        a-tokens (tokenizer (string/lower-case input-a))
+        q-count (count q-tokens)
+        a-count (count a-tokens)
+        valid-length (+ q-count a-count)
+        _ (println "Valid length " valid-length)
+        ;; token types [0000...1111...0000]: 0 per question token, 1 per
+        ;; passage token, then 0-padding up to seq-length.
+        ;; NOTE(review): these are computed before [CLS]/[SEP] are inserted,
+        ;; so they are offset from `tokens` below -- confirm this is intended.
+        token-types (-> (pad [] 0 q-count)
+                        (into (pad [] 1 a-count))
+                        (pad 0 seq-length))
+        ;; standard BERT framing: [CLS] question [SEP] passage [SEP]
+        framed (into [] (concat ["[CLS]"] q-tokens ["[SEP]"] a-tokens ["[SEP]"]))
+        tokens (pad framed "[PAD]" seq-length)
+        _ (println "Pre-processed tokens " framed)
+        ;; pre-processing: token -> vocabulary index
+        {:keys [token2idx]} (get-vocab)
+        indexes (tokens->idxs token2idx tokens)
+        to-ndarray (fn [data shape] (ndarray/array data shape {:context ctx}))]
+    ;; the input data for the model
+    [(to-ndarray indexes [1 seq-length])
+     (to-ndarray token-types [1 seq-length])
+     (to-ndarray [valid-length] [1])]))
+
+
+;; REPL scratch pad: `comment` discards its body, so nothing here runs when the
+;; namespace is loaded. Evaluate the forms one at a time from an editor REPL.
+(comment
+
+  (repeat 3 0)
+  ;; NOTE(review): `x` is defined without a value; presumably it was meant to
+  ;; hold the parsed vocab JSON inspected by the next two forms -- fill in first.
+  (def x )
+  (keys x)
+  (get x "idx_to_token")
+  ;; exercise the Java BertDataParser directly (these methods are public in
+  ;; src/java/BertDataParser.java); the bare vectors presumably record the
+  ;; results observed at the REPL for the form just above them
+  (def bert-parser (new BertDataParser))
+  (.parseJSON bert-parser model-vocab)
+  (.token2idx bert-parser (java.util.ArrayList. ["and" "where"]))
+  [1998 2073]
+  (.idx2token bert-parser (java.util.ArrayList. (map int [1998 2073])))
+  ["and" "where"]
+  (def bert-qa (new BertQA))
+
+  ;; invoke the Java example's main entry point with a relative model prefix
+  (BertQA/main (into-array ["--model-path-prefix" "model/static_bert_qa"]) )
+  (r/reflect bert-qa))
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java
similarity index 93%
copy from scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
copy to contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java
index 440670a..a0a821a 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
+++ b/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java
@@ -41,7 +41,7 @@ public class BertDataParser {
      * @param jsonFile the filePath of the vocab.json
      * @throws Exception
      */
-    void parseJSON(String jsonFile) throws Exception {
+    public void parseJSON(String jsonFile) throws Exception {
         Gson gson = new Gson();
         token2idx = new HashMap<>();
         idx2token = new LinkedList<>();
@@ -62,7 +62,7 @@ public class BertDataParser {
      * @param input The input string
      * @return List of tokens
      */
-    List<String> tokenizer(String input) {
+    public List<String> tokenizer(String input) {
         String[] step1 = input.split("\\s+");
         List<String> finalResult = new LinkedList<>();
         for (String item : step1) {
@@ -85,7 +85,7 @@ public class BertDataParser {
      * @param num total length after padding
      * @return List of padded tokens
      */
-    <E> List<E> pad(List<E> tokens, E padItem, int num) {
+    public <E> List<E> pad(List<E> tokens, E padItem, int num) {
         if (tokens.size() >= num) return tokens;
         List<E> padded = new LinkedList<>(tokens);
         for (int i = 0; i < num - tokens.size(); i++) {
@@ -99,7 +99,7 @@ public class BertDataParser {
      * @param tokens input tokens
      * @return List of indexes
      */
-    List<Integer> token2idx(List<String> tokens) {
+    public List<Integer> token2idx(List<String> tokens) {
         List<Integer> indexes = new ArrayList<>();
         for (String token : tokens) {
             if (token2idx.containsKey(token)) {
@@ -116,7 +116,7 @@ public class BertDataParser {
      * @param indexes List of indexes
      * @return List of tokens
      */
-    List<String> idx2token(List<Integer> indexes) {
+    public List<String> idx2token(List<Integer> indexes) {
         List<String> tokens = new ArrayList<>();
         for (int index : indexes) {
             tokens.add(idx2token.get(index));
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
similarity index 97%
copy from scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
copy to contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
index b40a4e9..7308bca 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
+++ b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.mxnetexamples.javaapi.infer.bert;
+package bert;
 
 import org.apache.mxnet.infer.javaapi.Predictor;
 import org.apache.mxnet.javaapi.*;
@@ -33,11 +33,11 @@ import java.util.*;
  */
 public class BertQA {
     @Option(name = "--model-path-prefix", usage = "input model directory and 
prefix of the model")
-    private String modelPathPrefix = "/model/static_bert_qa";
+    private String modelPathPrefix = "model/static_bert_qa";
     @Option(name = "--model-epoch", usage = "Epoch number of the model")
     private int epoch = 2;
     @Option(name = "--model-vocab", usage = "the vocabulary used in the model")
-    private String modelVocab = "/model/vocab.json";
+    private String modelVocab = "model/vocab.json";
     @Option(name = "--input-question", usage = "the input question")
     private String inputQ = "When did BBC Japan start broadcasting?";
     @Option(name = "--input-answer", usage = "the input answer")
@@ -126,6 +126,7 @@ public class BertQA {
                 new NDArray(new float[] { validLength },
                         new Shape(new int[]{1}), context)
         );
+       logger.info("Carin inputbatch: " + 
Arrays.toString(inputBatch.toArray()));
         // Build the model
         List<Context> contexts = new ArrayList<>();
         contexts.add(context);
diff --git a/contrib/clojure-package/examples/bert-qa/test/bert_qa/core_test.clj b/contrib/clojure-package/examples/bert-qa/test/bert_qa/core_test.clj
new file mode 100644
index 0000000..9345881
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/test/bert_qa/core_test.clj
@@ -0,0 +1,7 @@
+(ns bert-qa.core-test
+  (:require [clojure.test :refer :all]
+            [bert-qa.core :refer :all]))
+
+;; Unit tests for the pure data helpers in bert-qa.core (no model needed).
+(deftest pad-test
+  (testing "pads up to the requested length"
+    (is (= [1 2 0 0] (pad [1 2] 0 4))))
+  (testing "leaves inputs that are already long enough untouched"
+    (is (= [1 2 3] (pad [1 2 3] 0 2)))))
+
+(deftest tokenizer-test
+  (testing "splits on whitespace and breaks out trailing punctuation"
+    (is (= ["when" "did" "bbc" "japan" "start" "broadcasting" "?"]
+           (tokenizer "when did bbc japan start broadcasting?")))))
+
+(deftest tokens->idxs-test
+  (testing "looks up known tokens in the vocabulary"
+    (is (= [1 2] (tokens->idxs {"a" 1 "b" 2} ["a" "b"])))))
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
index 440670a..a0a821a 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
@@ -41,7 +41,7 @@ public class BertDataParser {
      * @param jsonFile the filePath of the vocab.json
      * @throws Exception
      */
-    void parseJSON(String jsonFile) throws Exception {
+    public void parseJSON(String jsonFile) throws Exception {
         Gson gson = new Gson();
         token2idx = new HashMap<>();
         idx2token = new LinkedList<>();
@@ -62,7 +62,7 @@ public class BertDataParser {
      * @param input The input string
      * @return List of tokens
      */
-    List<String> tokenizer(String input) {
+    public List<String> tokenizer(String input) {
         String[] step1 = input.split("\\s+");
         List<String> finalResult = new LinkedList<>();
         for (String item : step1) {
@@ -85,7 +85,7 @@ public class BertDataParser {
      * @param num total length after padding
      * @return List of padded tokens
      */
-    <E> List<E> pad(List<E> tokens, E padItem, int num) {
+    public <E> List<E> pad(List<E> tokens, E padItem, int num) {
         if (tokens.size() >= num) return tokens;
         List<E> padded = new LinkedList<>(tokens);
         for (int i = 0; i < num - tokens.size(); i++) {
@@ -99,7 +99,7 @@ public class BertDataParser {
      * @param tokens input tokens
      * @return List of indexes
      */
-    List<Integer> token2idx(List<String> tokens) {
+    public List<Integer> token2idx(List<String> tokens) {
         List<Integer> indexes = new ArrayList<>();
         for (String token : tokens) {
             if (token2idx.containsKey(token)) {
@@ -116,7 +116,7 @@ public class BertDataParser {
      * @param indexes List of indexes
      * @return List of tokens
      */
-    List<String> idx2token(List<Integer> indexes) {
+    public List<String> idx2token(List<Integer> indexes) {
         List<String> tokens = new ArrayList<>();
         for (int index : indexes) {
             tokens.add(idx2token.get(index));
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
index b40a4e9..7a7b852 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
@@ -126,6 +126,7 @@ public class BertQA {
                 new NDArray(new float[] { validLength },
                         new Shape(new int[]{1}), context)
         );
+       logger.info("Carin inputbatch: " + 
Arrays.toString(inputBatch.toArray()));
         // Build the model
         List<Context> contexts = new ArrayList<>();
         contexts.add(context);

Reply via email to