This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-bert-qa-example in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 665ed7d3fec1fb1560b1985dd6ab2c092cf0d87c Author: gigasquid <[email protected]> AuthorDate: Fri Apr 12 16:34:08 2019 -0400 wip --- .../clojure-package/examples/bert-qa/.gitignore | 12 +++ contrib/clojure-package/examples/bert-qa/README.md | 22 +++++ .../examples/bert-qa/get_bert_data.sh | 29 ++++++ .../clojure-package/examples/bert-qa/project.clj | 10 ++ .../examples/bert-qa/src/bert_qa/core.clj | 107 +++++++++++++++++++++ .../examples/bert-qa/src/java}/BertDataParser.java | 10 +- .../examples/bert-qa/src/java}/BertQA.java | 7 +- .../examples/bert-qa/test/bert_qa/core_test.clj | 7 ++ .../javaapi/infer/bert/BertDataParser.java | 10 +- .../mxnetexamples/javaapi/infer/bert/BertQA.java | 1 + 10 files changed, 202 insertions(+), 13 deletions(-) diff --git a/contrib/clojure-package/examples/bert-qa/.gitignore b/contrib/clojure-package/examples/bert-qa/.gitignore new file mode 100644 index 0000000..d18f225 --- /dev/null +++ b/contrib/clojure-package/examples/bert-qa/.gitignore @@ -0,0 +1,12 @@ +/target +/classes +/checkouts +profiles.clj +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/bert-qa/README.md b/contrib/clojure-package/examples/bert-qa/README.md new file mode 100644 index 0000000..fc21bdd --- /dev/null +++ b/contrib/clojure-package/examples/bert-qa/README.md @@ -0,0 +1,22 @@ +# bert-qa + +A Clojure library designed to ... well, that part is up to you. + +## Usage + +FIXME + +## License + +Copyright © 2019 FIXME + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +This Source Code may also be made available under the following Secondary +Licenses when the conditions for such availability set forth in the Eclipse +Public License, v. 2.0 are satisfied: GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or (at your +option) any later version, with the GNU Classpath Exception which is available +at https://www.gnu.org/software/classpath/license.html. diff --git a/contrib/clojure-package/examples/bert-qa/get_bert_data.sh b/contrib/clojure-package/examples/bert-qa/get_bert_data.sh new file mode 100755 index 0000000..603194a --- /dev/null +++ b/contrib/clojure-package/examples/bert-qa/get_bert_data.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +data_path=model + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o $data_path/vocab.json + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o $data_path/static_bert_qa-0002.params + curl https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o $data_path/static_bert_qa-symbol.json +fi diff --git a/contrib/clojure-package/examples/bert-qa/project.clj b/contrib/clojure-package/examples/bert-qa/project.clj new file mode 100644 index 0000000..328d040 --- /dev/null +++ b/contrib/clojure-package/examples/bert-qa/project.clj @@ -0,0 +1,10 @@ +(defproject bert-qa "0.1.0-SNAPSHOT" + :description "BERT QA Example" + :plugins [[lein-cljfmt "0.5.7"]] + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"] + [com.google.code.gson/gson "2.8.5"] + [cheshire "5.8.1"]] + :pedantic? :skip + :java-source-paths ["src/java"] + :repl-options {:init-ns bert-qa.core}) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj new file mode 100644 index 0000000..941e6aa --- /dev/null +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/core.clj @@ -0,0 +1,107 @@ +(ns bert-qa.core + (:require [clojure.string :as string] + [clojure.reflect :as r] + [cheshire.core :as json] + [clojure.java.io :as io] + [clojure.set :as set] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.ndarray :as ndarray]) + (:import (bert BertDataParser) + (bert BertQA))) + +(def model-path-prefix "model/static_bert_qa") +;; epoch number of the model +(def epoch 2) +;; the vocabulary used in the model +(def model-vocab "model/vocab.json") +;; the input question +(def input-q "When did BBC Japan start broadcasting?") +;;; the input answer +(def input-a (str "BBC Japan was a general entertainment Channel.\n" + " Which operated between December 2004 and April 2006.\n" + "It ceased operations after its Japanese distributor folded.")) +;; the maximum length of the sequence +(def seq-length 384) + +;;; data helpers + +(defn break-out-punctuation [s str-match] + (->> (string/split (str s "<punc>") (re-pattern (str "\\" str-match))) + (map #(string/replace % "<punc>" str-match)))) + +(defn break-out-punctuations [s] + (if-let [target-char (first (re-seq #"[.,?!]" s))] + (break-out-punctuation s target-char) + [s])) + +(defn tokenizer [s] + (->> (string/split s #"\s+") + (mapcat break-out-punctuations) + (into []))) + +(defn pad [tokens pad-item num] + (if (>= (count tokens) num) + tokens + (into tokens (repeat (- num (count tokens)) pad-item)))) + +(defn get-vocab [] + (let [vocab (json/parse-stream (clojure.java.io/reader "model/vocab.json"))] + {:idx2token (get vocab "idx_to_token") + :token2idx (get vocab "token_to_idx")})) + +(defn tokens->idxs [token2idx tokens] + (mapv #(get token2idx % "[UNK]") tokens)) + +(defn idxs->tokens [idx2token idxs] + (mapv #(get idx2token %) idxs)) + + +(defn infer [] + (let [ctx (context/default-context) + ;;; pre-processing tokenize sentence + token-q (tokenizer (string/lower-case input-q)) + token-a (tokenizer (string/lower-case input-a)) + valid-length (+ (count token-q) (count token-a)) + _ (println "Valid length " valid-length) + ;;; generate token types [0000...1111...0000] + qa-embedded (into (pad [] 0 (count token-q)) + (pad [] 1 (count token-a))) + token-types (pad qa-embedded 0 seq-length) + ;;; make BERT pre-processing standard + token-a (conj token-a "[SEP]") + token-q (into [] (concat ["[CLS]"] token-q ["[SEP]"] token-a)) + tokens (pad token-q "[PAD]" seq-length) + _ (println "Pre-processed tokens " token-q) + ;;; pre-processing - token to index translation + {:keys [idx2token token2idx]} (get-vocab) + indexes (tokens->idxs token2idx tokens) + ;;; preparing the input data + input-batch [(ndarray/array indexes [1 seq-length] {:context ctx}) + (ndarray/array token-types [1 seq-length] {:context ctx}) + (ndarray/array [valid-length] [1] {:context ctx})] + + + ] + input-batch) + + ) + + +(comment + + (repeat 3 0) + (def x ) + (keys x) + (get x "idx_to_token") + (def bert-parser (new BertDataParser)) + (.parseJSON bert-parser model-vocab) + (.token2idx bert-parser (java.util.ArrayList. ["and" "where"])) + [1998 2073] + (.idx2token bert-parser (java.util.ArrayList. (map int [1998 2073]))) + ["and" "where"] + (def bert-qa (new BertQA)) + + (BertQA/main (into-array ["--model-path-prefix" "model/static_bert_qa"]) ) + (r/reflect bert-qa)) diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java similarity index 93% copy from scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java copy to contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java index 440670a..a0a821a 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java +++ b/contrib/clojure-package/examples/bert-qa/src/java/BertDataParser.java @@ -41,7 +41,7 @@ public class BertDataParser { * @param jsonFile the filePath of the vocab.json * @throws Exception */ - void parseJSON(String jsonFile) throws Exception { + public void parseJSON(String jsonFile) throws Exception { Gson gson = new Gson(); token2idx = new HashMap<>(); idx2token = new LinkedList<>(); @@ -62,7 +62,7 @@ public class BertDataParser { * @param input The input string * @return List of tokens */ - List<String> tokenizer(String input) { + public List<String> tokenizer(String input) { String[] step1 = input.split("\\s+"); List<String> finalResult = new LinkedList<>(); for (String item : step1) { @@ -85,7 +85,7 @@ public class BertDataParser { * @param num total length after padding * @return List of padded tokens */ - <E> List<E> pad(List<E> tokens, E padItem, int num) { + public <E> List<E> pad(List<E> tokens, E padItem, int num) { if (tokens.size() >= num) return tokens; List<E> padded = new LinkedList<>(tokens); for (int i = 0; i < num - tokens.size(); i++) { @@ -99,7 +99,7 @@ public class BertDataParser { * @param tokens input tokens * @return List of indexes */ - List<Integer> token2idx(List<String> tokens) { + public List<Integer> token2idx(List<String> tokens) { List<Integer> indexes = new ArrayList<>(); for (String token : tokens) { if (token2idx.containsKey(token)) { @@ -116,7 +116,7 @@ public class BertDataParser { * @param indexes List of indexes * @return List of tokens */ - List<String> idx2token(List<Integer> indexes) { + public List<String> idx2token(List<Integer> indexes) { List<String> tokens = new ArrayList<>(); for (int index : indexes) { tokens.add(idx2token.get(index)); diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java similarity index 97% copy from scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java copy to contrib/clojure-package/examples/bert-qa/src/java/BertQA.java index b40a4e9..7308bca 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/contrib/clojure-package/examples/bert-qa/src/java/BertQA.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.mxnetexamples.javaapi.infer.bert; +package bert; import org.apache.mxnet.infer.javaapi.Predictor; import org.apache.mxnet.javaapi.*; @@ -33,11 +33,11 @@ import java.util.*; */ public class BertQA { @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") - private String modelPathPrefix = "/model/static_bert_qa"; + private String modelPathPrefix = "model/static_bert_qa"; @Option(name = "--model-epoch", usage = "Epoch number of the model") private int epoch = 2; @Option(name = "--model-vocab", usage = "the vocabulary used in the model") - private String modelVocab = "/model/vocab.json"; + private String modelVocab = "model/vocab.json"; @Option(name = "--input-question", usage = "the input question") private String inputQ = "When did BBC Japan start broadcasting?"; @Option(name = "--input-answer", usage = "the input answer") @@ -126,6 +126,7 @@ public class BertQA { new NDArray(new float[] { validLength }, new Shape(new int[]{1}), context) ); + logger.info("Carin inputbatch: " + Arrays.toString(inputBatch.toArray())); // Build the model List<Context> contexts = new ArrayList<>(); contexts.add(context); diff --git a/contrib/clojure-package/examples/bert-qa/test/bert_qa/core_test.clj b/contrib/clojure-package/examples/bert-qa/test/bert_qa/core_test.clj new file mode 100644 index 0000000..9345881 --- /dev/null +++ b/contrib/clojure-package/examples/bert-qa/test/bert_qa/core_test.clj @@ -0,0 +1,7 @@ +(ns bert-qa.core-test + (:require [clojure.test :refer :all] + [bert-qa.core :refer :all])) + +(deftest a-test + (testing "FIXME, I fail." + (is (= 0 1)))) diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java index 440670a..a0a821a 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java @@ -41,7 +41,7 @@ public class BertDataParser { * @param jsonFile the filePath of the vocab.json * @throws Exception */ - void parseJSON(String jsonFile) throws Exception { + public void parseJSON(String jsonFile) throws Exception { Gson gson = new Gson(); token2idx = new HashMap<>(); idx2token = new LinkedList<>(); @@ -62,7 +62,7 @@ public class BertDataParser { * @param input The input string * @return List of tokens */ - List<String> tokenizer(String input) { + public List<String> tokenizer(String input) { String[] step1 = input.split("\\s+"); List<String> finalResult = new LinkedList<>(); for (String item : step1) { @@ -85,7 +85,7 @@ public class BertDataParser { * @param num total length after padding * @return List of padded tokens */ - <E> List<E> pad(List<E> tokens, E padItem, int num) { + public <E> List<E> pad(List<E> tokens, E padItem, int num) { if (tokens.size() >= num) return tokens; List<E> padded = new LinkedList<>(tokens); for (int i = 0; i < num - tokens.size(); i++) { @@ -99,7 +99,7 @@ public class BertDataParser { * @param tokens input tokens * @return List of indexes */ - List<Integer> token2idx(List<String> tokens) { + public List<Integer> token2idx(List<String> tokens) { List<Integer> indexes = new ArrayList<>(); for (String token : tokens) { if (token2idx.containsKey(token)) { @@ -116,7 +116,7 @@ public class BertDataParser { * @param indexes List of indexes * @return List of tokens */ - List<String> idx2token(List<Integer> indexes) { + public List<String> idx2token(List<Integer> indexes) { List<String> tokens = new ArrayList<>(); for (int index : indexes) { tokens.add(idx2token.get(index)); diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index b40a4e9..7a7b852 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -126,6 +126,7 @@ public class BertQA { new NDArray(new float[] { validLength }, new Shape(new int[]{1}), context) ); + logger.info("Carin inputbatch: " + Arrays.toString(inputBatch.toArray())); // Build the model List<Context> contexts = new ArrayList<>(); contexts.add(context);
