This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-sentence-pair-classification
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 5270b89aaa218d84328a1c66ca251c66c45528e9
Author: gigasquid <[email protected]>
AuthorDate: Fri Apr 19 17:01:40 2019 -0400

    more wip
---
 .../src/bert_qa/bert_sentence_classification.clj   | 141 +++++++++++++++------
 1 file changed, 105 insertions(+), 36 deletions(-)

diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
index e53d5da..053dade 100644
--- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
@@ -10,6 +10,8 @@
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.symbol :as sym]
             [org.apache.clojure-mxnet.module :as m]
+            [org.apache.clojure-mxnet.infer :as infer]
+            [org.apache.clojure-mxnet.optimizer :as optimizer]
             [clojure.pprint :as pprint]
             [clojure-csv.core :as csv]
             [bert-qa.infer :as bert-infer]))
@@ -22,18 +24,10 @@
 (def model-vocab "model/vocab.json")
 ;; the input question
 ;; the maximum length of the sequence
-(def seq-length 384)
+(def seq-length "Fixed number of tokens every BERT input sequence is padded to." 128)
+
 
 
-(defn fine-tune-model
-  "msymbol: the pretrained network symbol
-    arg-params: the argument parameters of the pretrained model
-    num-classes: the number of classes for the fine-tune datasets"
-  [msymbol num-classes]
-  (as-> msymbol data
-    (sym/flatten "flatten-finetune" {:data data})
-    (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
-    (sym/softmax-output "softmax" {:data data})))
 
 ;;; Data preprocessing
 
@@ -58,12 +52,81 @@
     #_(fit devs net new-args arg-params)))
 
 
+(defn pre-processing
+  "Converts one training item into the inputs the BERT classifier expects.
+
+   ctx: context the NDArrays are created on
+   idx->token: not used by this function; kept for the existing call signature
+   token->idx: vocabulary map from token to index (see bert-infer/tokens->idxs)
+   train-item: [sentence-a sentence-b label]; label \"0\" selects class 0,
+               any other value selects class 1
+
+   The token sequence is [CLS] sentence-a [SEP] sentence-b [SEP], truncated
+   to seq-length and then padded with [PAD]. Returns a map with
+   :input-batch [token indexes, token types, valid length] (NDArrays),
+   :label (one-hot NDArray of shape [2]), :tokens and the original :train-item."
+  [ctx idx->token token->idx train-item]
+  (let [[sentence-a sentence-b label] train-item
+        ;;; pre-processing tokenize sentence
+        token-1 (bert-infer/tokenize (string/lower-case sentence-a))
+        token-2 (bert-infer/tokenize (string/lower-case sentence-b))
+        ;;; make BERT pre-processing standard: [CLS] a [SEP] b [SEP]
+        segment-1 (into ["[CLS]"] (concat token-1 ["[SEP]"]))
+        segment-2 (into [] (concat token-2 ["[SEP]"]))
+        ;; truncate so the arrays always match the fixed [1 seq-length] shape
+        real-tokens (into [] (take seq-length (concat segment-1 segment-2)))
+        ;; valid length counts every non-padding token, [CLS] and [SEP] included
+        valid-length (count real-tokens)
+        ;;; generate token types [0000...1111...0000], aligned with the tokens:
+        ;;; [CLS] a [SEP] are type 0, b [SEP] are type 1, padding is type 0
+        token-types (bert-infer/pad
+                     (into [] (take valid-length
+                                    (concat (repeat (count segment-1) 0)
+                                            (repeat (count segment-2) 1))))
+                     0 seq-length)
+        tokens (bert-infer/pad real-tokens "[PAD]" seq-length)
+        ;;; pre-processing - token to index translation
+        indexes (bert-infer/tokens->idxs token->idx tokens)]
+    ;; ndarray/array takes its context from the :ctx option (the key the
+    ;; label arrays already used); :context was not honoured for the inputs
+    {:input-batch [(ndarray/array indexes [1 seq-length] {:ctx ctx})
+                   (ndarray/array token-types [1 seq-length] {:ctx ctx})
+                   (ndarray/array [valid-length] [1] {:ctx ctx})]
+     :label (if (= "0" label)
+              (ndarray/array [1 0] [2] {:ctx ctx})
+              (ndarray/array [0 1] [2] {:ctx ctx}))
+     :tokens tokens
+     :train-item train-item}))
+
+(defn fine-tune-model
+  "Attaches a classification head to a pretrained network symbol.
+
+   msymbol: the pretrained network symbol
+   num-classes: the number of classes for the fine-tune dataset
+
+   The output of msymbol is flattened (\"flatten-finetune\"), passed through
+   a fully connected layer \"fc-finetune\" with num-classes hidden units and
+   fed into a softmax output named \"softmax\" (label input \"softmax_label\").
+   Returns the resulting symbol."
+  [msymbol num-classes]
+  (as-> msymbol data
+    (sym/flatten "flatten-finetune" {:data data})
+    (sym/fully-connected "fc-finetune" {:data data :num-hidden num-classes})
+    (sym/softmax-output "softmax" {:data data})))
+
+
 (comment
 
   ;;; load the pre-trained BERT model using the module api
   (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))
   ;;; now that we have loaded the BERT model we need to attach an additional layer for classification which is a dense layer with 2 classes
-  (def model (fine-tune-model (m/symbol bert-base) 2))
+  (def model-sym (fine-tune-model (m/symbol bert-base) 2))
+  (def arg-params (m/arg-params bert-base))
+  (def aux-params (m/aux-params bert-base))
+
+  (def devs [(context/default-context)])
+  ;; descriptors for the three inputs, in the order pre-processing builds its
+  ;; :input-batch: token indexes (data0), token types (data1), valid length (data2)
+  (def input-descs [{:name "data0"
+                     :shape [1 seq-length]
+                     :dtype dtype/FLOAT32
+                     :layout layout/NT}
+                    {:name "data1"
+                     :shape [1 seq-length]
+                     :dtype dtype/FLOAT32
+                     :layout layout/NT}
+                    {:name "data2"
+                     :shape [1]
+                     :dtype dtype/FLOAT32
+                     :layout layout/N}])
+  ;; label input of the "softmax" output added by fine-tune-model
+  ;; NOTE(review): pre-processing creates labels of shape [2]; confirm they
+  ;; are reshaped to [1 2] before being fed to this module
+  (def label-descs [{:name "softmax_label"
+                     :shape [1 2]
+                     :dtype dtype/FLOAT32
+                     :layout layout/NT}])
+
+  ;; now create the module. :allow-missing true is presumably required because
+  ;; the fc-finetune layer added by fine-tune-model has no weights in the
+  ;; bert-base checkpoint.
+  ;; NOTE(review): this var shadows clojure.core/mod in this namespace
+  (def mod (-> (m/module model-sym {:contexts devs
+                                    :data-names ["data0" "data1" "data2"]})
+               (m/bind {:data-shapes input-descs :label-shapes label-descs})
+               (m/init-params {:arg-params arg-params :aux-params aux-params
+                               :allow-missing true})
+               (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01
+                                                             :momentum 0.9})})))
+
+  ;; the pretrained network on its own, bound to the same inputs (no label)
+  (def base-mod (-> bert-base
+                    (m/bind {:data-shapes input-descs})
+                    (m/init-params {:arg-params arg-params
+                                    :aux-params aux-params
+                                    :allow-missing true})))
 
   ;;; Data Preprocessing for BERT
 
@@ -106,33 +169,39 @@
   (def idx->token (:idx->token vocab))
   (def token->idx (:token->idx vocab))
 
-  (defn pre-processing [ctx idx->token token->idx train-item]
-    (let [[sentence-a sentence-b label] train-item
-       ;;; pre-processing tokenize sentence
-          token-1 (bert-infer/tokenize (string/lower-case sentence-a))
-          token-2 (bert-infer/tokenize (string/lower-case sentence-b))
-          valid-length (+ (count token-1) (count token-2))
-        ;;; generate token types [0000...1111...0000]
-          qa-embedded (into (bert-infer/pad [] 0 (count token-1))
-                            (bert-infer/pad [] 1 (count token-2)))
-          token-types (bert-infer/pad qa-embedded 0 seq-length)
-        ;;; make BERT pre-processing standard
-          token-2 (conj token-2 "[SEP]")
-          token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
-          tokens (bert-infer/pad token-1 "[PAD]" seq-length)
-        ;;; pre-processing - token to index translation
-          indexes (bert-infer/tokens->idxs token->idx tokens)]
-    {:input-batch [(ndarray/array indexes [1 seq-length] {:context ctx})
-                   (ndarray/array token-types [1 seq-length] {:context ctx})
-                   (ndarray/array [valid-length] [1] {:context ctx})]
-     :label (if (= "0" label)
-              (ndarray/array [1 0] [2] {:ctx ctx})
-              (ndarray/array [0 1] [2] {:ctx ctx}))
-     :tokens tokens
-     :train-item train-item}))
+  
 
   ;;; our sample item
-  (pre-processing (context/default-context) idx->token token->idx sample)
+  (def sample-data
+    (pre-processing (context/default-context) idx->token token->idx sample))
+
+  
+
+  ;; with a predictor
+  (defn make-predictor
+    "Creates an infer predictor for the checkpoint at model-path-prefix,
+    epoch 0, running on the single context ctx. Inputs data0 and data1 have
+    shape [1 seq-length]; data2 has shape [1]."
+    [ctx]
+    (let [sequence-desc (fn [input-name]
+                          {:name input-name
+                           :shape [1 seq-length]
+                           :dtype dtype/FLOAT32
+                           :layout layout/NT})
+          descriptors [(sequence-desc "data0")
+                       (sequence-desc "data1")
+                       {:name "data2"
+                        :shape [1]
+                        :dtype dtype/FLOAT32
+                        :layout layout/N}]]
+      (-> (infer/model-factory model-path-prefix descriptors)
+          (infer/create-predictor {:contexts [ctx]
+                                   :epoch 0}))))
+
+  (def predictor (make-predictor (context/default-context)))
+  ;; first output returned by the predictor for the sample's input batch
+  (def sample-result
+    (first (infer/predict-with-ndarray predictor (:input-batch sample-data))))
+  
+
+
+  
  
 
   )

Reply via email to