This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch clojure-bert-sentence-pair-classification in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git.
commit 5270b89aaa218d84328a1c66ca251c66c45528e9 Author: gigasquid <[email protected]> AuthorDate: Fri Apr 19 17:01:40 2019 -0400 more wip --- .../src/bert_qa/bert_sentence_classification.clj | 141 +++++++++++++++------ 1 file changed, 105 insertions(+), 36 deletions(-) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj index e53d5da..053dade 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj @@ -10,6 +10,8 @@ [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.symbol :as sym] [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.infer :as infer] + [org.apache.clojure-mxnet.optimizer :as optimizer] [clojure.pprint :as pprint] [clojure-csv.core :as csv] [bert-qa.infer :as bert-infer])) @@ -22,18 +24,10 @@ (def model-vocab "model/vocab.json") ;; the input question ;; the maximum length of the sequence -(def seq-length 384) +(def seq-length 128) + -(defn fine-tune-model - "msymbol: the pretrained network symbol - arg-params: the argument parameters of the pretrained model - num-classes: the number of classes for the fine-tune datasets" - [msymbol num-classes] - (as-> msymbol data - (sym/flatten "flatten-finetune" {:data data}) - (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes}) - (sym/softmax-output "softmax" {:data data}))) ;;; Data preprocessing @@ -58,12 +52,81 @@ #_(fit devs net new-args arg-params))) +(defn pre-processing [ctx idx->token token->idx train-item] + (let [[sentence-a sentence-b label] train-item + ;;; pre-processing tokenize sentence + token-1 (bert-infer/tokenize (string/lower-case sentence-a)) + token-2 (bert-infer/tokenize (string/lower-case sentence-b)) + valid-length (+ (count token-1) (count token-2)) + ;;; generate 
token types [0000...1111...0000] + qa-embedded (into (bert-infer/pad [] 0 (count token-1)) + (bert-infer/pad [] 1 (count token-2))) + token-types (bert-infer/pad qa-embedded 0 seq-length) + ;;; make BERT pre-processing standard + token-2 (conj token-2 "[SEP]") + token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2)) + tokens (bert-infer/pad token-1 "[PAD]" seq-length) + ;;; pre-processing - token to index translation + indexes (bert-infer/tokens->idxs token->idx tokens)] + {:input-batch [(ndarray/array indexes [1 seq-length] {:context ctx}) + (ndarray/array token-types [1 seq-length] {:context ctx}) + (ndarray/array [valid-length] [1] {:context ctx})] + :label (if (= "0" label) + (ndarray/array [1 0] [2] {:ctx ctx}) + (ndarray/array [0 1] [2] {:ctx ctx})) + :tokens tokens + :train-item train-item})) + +(defn fine-tune-model + "msymbol: the pretrained network symbol + arg-params: the argument parameters of the pretrained model + num-classes: the number of classes for the fine-tune datasets" + [msymbol num-classes] + (as-> msymbol data + (sym/flatten "flatten-finetune" {:data data}) + (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes}) + (sym/softmax-output "softmax" {:data data}))) + + (comment ;;; load the pre-trained BERT model using the module api (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})) ;;; now that we have loaded the BERT model we need to attach an additional layer for classification which is a dense layer with 2 classes - (def model (fine-tune-model (m/symbol bert-base) 2)) + (def model-sym (fine-tune-model (m/symbol bert-base) 2)) + (def arg-params (m/arg-params bert-base)) + (def aux-params (m/aux-params bert-base)) + + (def devs [(context/default-context)]) + (def input-descs [{:name "data0" + :shape [1 seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT} + {:name "data1" + :shape [1 seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT} + {:name "data2" + :shape [1] + :dtype dtype/FLOAT32 + 
:layout layout/N}]) + (def label-descs [{:name "softmax_label" + :shape [1 2] + :dtype dtype/FLOAT32 + :layout layout/NT}]) + + ;; now create the module + (def mod (-> (m/module model-sym {:contexts devs + :data-names ["data0" "data1" "data2"]}) + (m/bind {:data-shapes input-descs :label-shapes label-descs}) + (m/init-params {:arg-params arg-params :aux-params aux-params + :allow-missing true}) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))) + + (def base-mod (-> bert-base + (m/bind {:data-shapes input-descs}) + (m/init-params {:arg-params arg-params :aux-params aux-params + :allow-missing true}))) ;;; Data Preprocessing for BERT @@ -106,33 +169,39 @@ (def idx->token (:idx->token vocab)) (def token->idx (:token->idx vocab)) - (defn pre-processing [ctx idx->token token->idx train-item] - (let [[sentence-a sentence-b label] train-item - ;;; pre-processing tokenize sentence - token-1 (bert-infer/tokenize (string/lower-case sentence-a)) - token-2 (bert-infer/tokenize (string/lower-case sentence-b)) - valid-length (+ (count token-1) (count token-2)) - ;;; generate token types [0000...1111...0000] - qa-embedded (into (bert-infer/pad [] 0 (count token-1)) - (bert-infer/pad [] 1 (count token-2))) - token-types (bert-infer/pad qa-embedded 0 seq-length) - ;;; make BERT pre-processing standard - token-2 (conj token-2 "[SEP]") - token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2)) - tokens (bert-infer/pad token-1 "[PAD]" seq-length) - ;;; pre-processing - token to index translation - indexes (bert-infer/tokens->idxs token->idx tokens)] - {:input-batch [(ndarray/array indexes [1 seq-length] {:context ctx}) - (ndarray/array token-types [1 seq-length] {:context ctx}) - (ndarray/array [valid-length] [1] {:context ctx})] - :label (if (= "0" label) - (ndarray/array [1 0] [2] {:ctx ctx}) - (ndarray/array [0 1] [2] {:ctx ctx})) - :tokens tokens - :train-item train-item})) + ;;; our sample item - (pre-processing 
(context/default-context) idx->token token->idx sample) + (def sample-data (pre-processing (context/default-context) idx->token token->idx sample)) + + + + ;; with a predictor + (defn make-predictor [ctx] + (let [input-descs [{:name "data0" + :shape [1 seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT} + {:name "data1" + :shape [1 seq-length] + :dtype dtype/FLOAT32 + :layout layout/NT} + {:name "data2" + :shape [1] + :dtype dtype/FLOAT32 + :layout layout/N}] + factory (infer/model-factory model-path-prefix input-descs)] + (infer/create-predictor + factory + {:contexts [ctx] + :epoch 0}))) + + (def predictor (make-predictor (context/default-context))) + (def sample-result (first (infer/predict-with-ndarray predictor (:input-batch sample-data)))) + + + + )
