This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-sentence-pair-classification
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 3f1058edba47ee114c13284d566ab5a3389c77f1
Author: gigasquid <[email protected]>
AuthorDate: Fri Apr 19 15:13:40 2019 -0400

    wip
---
 .../clojure-package/examples/bert-qa/project.clj   |   3 +-
 .../src/bert_qa/bert_sentence_classification.clj   | 138 +++++++++++++++++++++
 2 files changed, 140 insertions(+), 1 deletion(-)

diff --git a/contrib/clojure-package/examples/bert-qa/project.clj b/contrib/clojure-package/examples/bert-qa/project.clj
index d256d44..7d79dc6 100644
--- a/contrib/clojure-package/examples/bert-qa/project.clj
+++ b/contrib/clojure-package/examples/bert-qa/project.clj
@@ -21,7 +21,8 @@
   :plugins [[lein-cljfmt "0.5.7"]]
   :dependencies [[org.clojure/clojure "1.9.0"]
                  [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]
-                 [cheshire "5.8.1"]]
+                 [cheshire "5.8.1"]
+                 [clojure-csv/clojure-csv "2.0.1"]]
   :pedantic? :skip
   :java-source-paths ["src/java"]
   :main bert-qa.infer
diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
new file mode 100644
index 0000000..e53d5da
--- /dev/null
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
@@ -0,0 +1,138 @@
+(ns bert-qa.bert-sentence-classification
+  (:require [clojure.string :as string]
+            [clojure.reflect :as r]
+            [cheshire.core :as json]
+            [clojure.java.io :as io]
+            [clojure.set :as set]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.context :as context]
+            [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.symbol :as sym]
+            [org.apache.clojure-mxnet.module :as m]
+            [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.initializer :as init]
+            [org.apache.clojure-mxnet.callback :as callback]
+            [clojure.pprint :as pprint]
+            [clojure-csv.core :as csv]
+            [bert-qa.infer :as bert-infer]))
+
+
+(def model-path-prefix "model/bert-base")
+;; epoch number of the model
+(def epoch 0)
+;; the vocabulary used in the model
+(def model-vocab "model/vocab.json")
+;; the maximum length of the sequence
+(def seq-length 384)
+
+
+(defn fine-tune-model
+  "msymbol: the pretrained network symbol
+   num-classes: the number of classes for the fine-tune datasets"
+  [msymbol num-classes]
+  (as-> msymbol data
+    (sym/flatten "flatten-finetune" {:data data})
+    (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
+    (sym/softmax-output "softmax" {:data data})))
+
+;;; Data preprocessing
+
+#_(defn fit [devs msymbol arg-params aux-params]
+    (let [mod (-> (m/module msymbol {:contexts devs})
+                  (m/bind {:data-shapes (mx-io/provide-data-desc train-iter)
+                           :label-shapes (mx-io/provide-label-desc val-iter)})
+                  (m/init-params {:arg-params arg-params :aux-params aux-params
+                                  :allow-missing true}))]
+      (m/fit mod
+             {:train-data train-iter
+              :eval-data val-iter
+              :num-epoch 1
+              :fit-params (m/fit-params {:initializer (init/xavier {:rand-type "gaussian"
+                                                                    :factor-type "in"
+                                                                    :magnitude 2})
+                                         :batch-end-callback (callback/speedometer batch-size 10)})})))
+
+#_(defn fine-tune!
+    [devs]
+    (let [{:keys [msymbol arg-params aux-params] :as model} (get-model)
+          new-model (fine-tune-model msymbol 2)]
+      new-model
+      #_(fit devs new-model arg-params aux-params)))
+
+
+(comment
+
+  ;;; load the pre-trained BERT model using the module api
+  (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))
+  ;;; now that we have loaded the BERT model we need to attach an additional
+  ;;; layer for classification: a dense layer with 2 classes
+  (def model (fine-tune-model (m/symbol bert-base) 2))
+
+  ;;; Data Preprocessing for BERT
+
+  ;; For demonstration purposes, we use the dev set of the Microsoft Research
+  ;; Paraphrase Corpus dataset. The file is named 'dev.tsv'. Let's take a look
+  ;; at the raw dataset - it contains 5 columns separated by tabs.
+  (def raw-file (->> (string/split (slurp "dev.tsv") #"\n")
+                     (map #(string/split % #"\t"))))
+  ;; or, equivalently, parse the tab-separated file with clojure-csv
+  (def raw-file (csv/parse-csv (slurp "dev.tsv") :delimiter \tab))
+  (take 3 raw-file)
+  ;; (["Quality" "#1 ID" "#2 ID" "#1 String" "#2 String"]
+  ;;  ["1"
+  ;;   "1355540"
+  ;;   "1355592"
+  ;;   "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ."
+  ;;   "\" The foodservice pie business does not fit our long-term growth strategy ."]
+  ;;  ["0"
+  ;;   "2029631"
+  ;;   "2029565"
+  ;;   "Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war ."
+  ;;   "His wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war ."])
+
+  ;;; for our task we are only interested in columns 0, 3, and 4:
+  ;;; the quality label and the two sentence strings
+  (vals (select-keys (first raw-file) [3 4 0]))
+  ;=> ("#1 String" "#2 String" "Quality")
+  (def data-train-raw (->> raw-file
+                           (map #(vals (select-keys % [3 4 0])))
+                           (rest))) ;; drop the header row
+  (def sample (first data-train-raw))
+  (nth sample 0) ;; sentence a
+  ;=> "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ."
+  (nth sample 1) ;; sentence b
+  ;=> "\" The foodservice pie business does not fit our long-term growth strategy ."
+
+  (nth sample 2) ; 1 means equivalent, 0 means not equivalent
+  ;=> "1"
+
+  ;;; Now we need to turn these into ndarrays to make a Data Iterator
+  (def vocab (bert-infer/get-vocab))
+  (def idx->token (:idx->token vocab))
+  (def token->idx (:token->idx vocab))
+
+  (defn pre-processing [ctx idx->token token->idx train-item]
+    (let [[sentence-a sentence-b label] train-item
+          ;;; pre-processing tokenize sentence
+          token-1 (bert-infer/tokenize (string/lower-case sentence-a))
+          token-2 (bert-infer/tokenize (string/lower-case sentence-b))
+          valid-length (+ (count token-1) (count token-2))
+          ;;; generate token types [0000...1111...0000]
+          qa-embedded (into (bert-infer/pad [] 0 (count token-1))
+                            (bert-infer/pad [] 1 (count token-2)))
+          token-types (bert-infer/pad qa-embedded 0 seq-length)
+          ;;; make BERT pre-processing standard
+          token-2 (conj token-2 "[SEP]")
+          token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
+          tokens (bert-infer/pad token-1 "[PAD]" seq-length)
+          ;;; pre-processing - token to index translation
+          indexes (bert-infer/tokens->idxs token->idx tokens)]
+      {:input-batch [(ndarray/array indexes [1 seq-length] {:ctx ctx})
+                     (ndarray/array token-types [1 seq-length] {:ctx ctx})
+                     (ndarray/array [valid-length] [1] {:ctx ctx})]
+       :label (if (= "0" label)
+                (ndarray/array [1 0] [2] {:ctx ctx})
+                (ndarray/array [0 1] [2] {:ctx ctx}))
+       :tokens tokens
+       :train-item train-item}))
+
+  ;;; our sample item
+  (pre-processing (context/default-context) idx->token token->idx sample)
+
+
+  )
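
As a rough follow-up sketch (not taken from the commit itself, and assuming the forms in the comment block above have been evaluated, i.e. the model files and dev.tsv are in place), the shapes that pre-processing hands to the network can be checked at the REPL with the existing ndarray/shape-vec helper; the expected values simply follow from the [1 seq-length], [1], and [2] shapes used above.

;; sanity-check the three input NDArrays and the one-hot label for `sample`
(let [{:keys [input-batch label tokens]}
      (pre-processing (context/default-context) idx->token token->idx sample)]
  {:input-shapes (mapv ndarray/shape-vec input-batch) ;=> [[1 384] [1 384] [1]]
   :label-shape (ndarray/shape-vec label)             ;=> [2]
   :n-tokens (count tokens)})                         ;=> 384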
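
The comment block stops after pre-processing a single sample, while the reader-discarded fit/fine-tune! stubs still expect a train-iter. Below is one possible sketch of that missing step: batching the whole dev set into a few large NDArrays and wrapping them in mx-io/ndarray-iter. The pre-process-row helper, the batch size of 32, the index-style (non one-hot) labels, and the iterator options are assumptions for illustration only; in practice the iterator's data names would also need to line up with the BERT symbol's actual input names.

;; hypothetical helper: like pre-processing above, but returns plain vectors so
;; the whole dataset can be packed into a handful of large NDArrays
(defn pre-process-row [train-item]
  (let [[sentence-a sentence-b label] train-item
        token-1 (bert-infer/tokenize (string/lower-case sentence-a))
        token-2 (bert-infer/tokenize (string/lower-case sentence-b))
        valid-length (+ (count token-1) (count token-2))
        token-types (-> (into (bert-infer/pad [] 0 (count token-1))
                              (bert-infer/pad [] 1 (count token-2)))
                        (bert-infer/pad 0 seq-length))
        tokens (-> (concat ["[CLS]"] token-1 ["[SEP]"] token-2 ["[SEP]"])
                   (vec)
                   (bert-infer/pad "[PAD]" seq-length))]
    {:indexes (bert-infer/tokens->idxs token->idx tokens)
     :token-types token-types
     :valid-length valid-length
     :label (if (= "0" label) 0 1)})) ;; class index rather than one-hot

;; pack every row into one NDArray per input and build the data iterator
(def train-iter
  (let [rows (mapv pre-process-row data-train-raw)
        n (count rows)]
    (mx-io/ndarray-iter [(ndarray/array (vec (mapcat :indexes rows)) [n seq-length])
                         (ndarray/array (vec (mapcat :token-types rows)) [n seq-length])
                         (ndarray/array (mapv :valid-length rows) [n])]
                        {:label [(ndarray/array (mapv :label rows) [n])]
                         :data-batch-size 32
                         :last-batch-handle "pad"})))

From there, the fit stub above could be tried against this train-iter (plus a similarly built val-iter) once batch-size is defined.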
