This is an automated email from the ASF dual-hosted git repository.

kedarb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f44f6cf  Extend Clojure BERT example (#15023)
f44f6cf is described below

commit f44f6cfbe752fd8b8036307cecf6a30a30ad8557
Author: Dave Liepmann <dave.liepm...@gmail.com>
AuthorDate: Sat Jun 22 19:35:45 2019 +0200

    Extend Clojure BERT example (#15023)
    
    * Clojure predictor example: add rich comment
    
    This provides an entry point for folks working on this example
    from their REPL rather than from the command line (see the sketch
    after this list).
    
    * Clojure BERT example: refactor prepare-data fn for purity
    
    * Clojure BERT example: test fitted model on samples
    
    * Clojure BERT example: namespace docstring & comment
    
    * Clojure BERT example: format intro, add references
    
    * Clojure BERT example: minor refactor
    
    * Clojure BERT example: trim sentence pair explorations
    
    * Clojure BERT example: port experiment to iPynb
    
    * Clojure BERT example: fix test
    
    Underlying fn was refactored
    
    * Clojure BERT example: add sentence-pair prediction test
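
    For readers new to the pattern: a "rich comment" is a (comment ...)
    form whose expressions are meant to be evaluated one at a time from
    a connected REPL; since comment returns nil, the file still loads
    cleanly. A minimal sketch of the idea, using forms from this
    example (the actual rich comments added by this commit appear in
    the diffs below):

        (comment
          ;; Evaluate these one by one from your REPL; none of them
          ;; run when the namespace is loaded.
          (train (context/cpu 0) 3)
          (predict-equivalence fine-tuned-predictor
                               "The company cut spending to compensate for weak sales ."
                               "In response to poor sales results, the company cut spending ."))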
---
 .../examples/bert/fine-tune-bert.ipynb             | 145 +++++++++++++++++++--
 .../bert/src/bert/bert_sentence_classification.clj | 113 ++++++++++++----
 .../bert/bert_sentence_classification_test.clj     |  24 +++-
 .../predictor/src/infer/predictor_example.clj      |   6 +
 4 files changed, 248 insertions(+), 40 deletions(-)

diff --git a/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb b/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb
index 425a999..5934477 100644
--- a/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb
+++ b/contrib/clojure-package/examples/bert/fine-tune-bert.ipynb
@@ -10,15 +10,16 @@
     "\n",
     "Pre-trained language representations have been shown to improve many 
downstream NLP tasks such as question answering, and natural language 
inference. To apply pre-trained representations to these tasks, there are two 
strategies:\n",
     "\n",
-    "feature-based approach, which uses the pre-trained representations as 
additional features to the downstream task.\n",
-    "fine-tuning based approach, which trains the downstream tasks by 
fine-tuning pre-trained parameters.\n",
-    "While feature-based approaches such as ELMo [3] (introduced in the 
previous tutorial) are effective in improving many downstream tasks, they 
require task-specific architectures. Devlin, Jacob, et al proposed BERT [1] 
(Bidirectional Encoder Representations from Transformers), which fine-tunes 
deep bidirectional representations on a wide range of tasks with minimal 
task-specific parameters, and obtained state- of-the-art results.\n",
+    " - **feature-based approach**, which uses the pre-trained representations 
as additional features to the downstream task.\n",
+    " - **fine-tuning based approach**, which trains the downstream tasks by 
fine-tuning pre-trained parameters.\n",
+    " \n",
+    "While feature-based approaches such as ELMo [1] are effective in 
improving many downstream tasks, they require task-specific architectures. 
Devlin, Jacob, et al proposed BERT [2] (Bidirectional Encoder Representations 
from Transformers), which fine-tunes deep bidirectional representations on a 
wide range of tasks with minimal task-specific parameters, and obtained state- 
of-the-art results.\n",
     "\n",
     "In this tutorial, we will focus on fine-tuning with the pre-trained BERT 
model to classify semantically equivalent sentence pairs. Specifically, we 
will:\n",
     "\n",
-    "load the state-of-the-art pre-trained BERT model and attach an additional 
layer for classification,\n",
-    "process and transform sentence pair data for the task at hand, and\n",
-    "fine-tune BERT model for sentence classification.\n",
+    " 1. load the state-of-the-art pre-trained BERT model and attach an 
additional layer for classification\n",
+    " 2. process and transform sentence pair data for the task at hand, and 
\n",
+    " 3. fine-tune BERT model for sentence classification.\n",
     "\n"
    ]
   },
@@ -59,6 +60,7 @@
     "            [org.apache.clojure-mxnet.callback :as callback]\n",
     "            [org.apache.clojure-mxnet.context :as context]\n",
     "            [org.apache.clojure-mxnet.dtype :as dtype]\n",
+    "            [org.apache.clojure-mxnet.infer :as infer]\n",
     "            [org.apache.clojure-mxnet.eval-metric :as eval-metric]\n",
     "            [org.apache.clojure-mxnet.io :as mx-io]\n",
     "            [org.apache.clojure-mxnet.layout :as layout]\n",
@@ -89,7 +91,7 @@
     "\n",
     "![bert](https://gluon-nlp.mxnet.io/_images/bert-sentence-pair.png)\n",
     "\n",
-    "where the model takes a pair of sequences and pools the representation of 
the first token in the sequence. Note that the original BERT model was trained 
for masked language model and next sentence prediction tasks, which includes 
layers for language model decoding and classification. These layers will not be 
used for fine-tuning sentence pair classification.\n",
+    "where the model takes a pair of sequences and *pools* the representation 
of the first token in the sequence. Note that the original BERT model was 
trained for masked language model and next sentence prediction tasks, which 
includes layers for language model decoding and classification. These layers 
will not be used for fine-tuning sentence pair classification.\n",
     "\n",
     "Let's load the pre-trained BERT using the module API in MXNet."
    ]
@@ -114,12 +116,15 @@
    ],
    "source": [
     "(def model-path-prefix \"data/static_bert_base_net\")\n",
+    "\n",
     ";; the vocabulary used in the model\n",
     "(def vocab (bert-util/get-vocab))\n",
-    ";; the input question\n",
+    "\n",
     ";; the maximum length of the sequence\n",
     "(def seq-length 128)\n",
     "\n",
+    "(def batch-size 32)\n",
+    "\n",
     "(def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))"
    ]
   },
@@ -291,7 +296,7 @@
    "source": [
     "(defn pre-processing\n",
     "  \"Preprocesses the sentences in the format that BERT is expecting\"\n",
-    "  [ctx idx->token token->idx train-item]\n",
+    "  [idx->token token->idx train-item]\n",
     "    (let [[sentence-a sentence-b label] train-item\n",
     "       ;;; pre-processing tokenize sentence\n",
     "          token-1 (bert-util/tokenize (string/lower-case sentence-a))\n",
@@ -319,7 +324,7 @@
     "(def idx->token (:idx->token vocab))\n",
     "(def token->idx (:token->idx vocab))\n",
     "(def dev (context/default-context))\n",
-    "(def processed-datas (mapv #(pre-processing dev idx->token token->idx %) 
data-train-raw))\n",
+    "(def processed-datas (mapv #(pre-processing idx->token token->idx %) 
data-train-raw))\n",
     "(def train-count (count processed-datas))\n",
     "(println \"Train Count is = \" train-count)\n",
     "(println \"[PAD] token id = \" (get token->idx \"[PAD]\"))\n",
@@ -375,8 +380,6 @@
     "                                 (into []))\n",
     "                    :train-num (count processed-datas)})\n",
     "\n",
-    "(def batch-size 32)\n",
-    "\n",
     "(def train-data\n",
     "  (let [{:keys [data0s data1s data2s labels train-num]} prepared-data\n",
     "        data-desc0 (mx-io/data-desc {:name \"data0\"\n",
@@ -480,7 +483,7 @@
     "(def num-epoch 3)\n",
     "\n",
     "(def fine-tune-model (m/module model-sym {:contexts [dev]\n",
-    "                                         :data-names [\"data0\" \"data1\" 
\"data2\"]}))\n",
+    "                                          :data-names [\"data0\" 
\"data1\" \"data2\"]}))\n",
     "\n",
     "(m/fit fine-tune-model {:train-data train-data  :num-epoch num-epoch\n",
     "                        :fit-params (m/fit-params {:allow-missing true\n",
@@ -489,6 +492,122 @@
     "                                                   :optimizer 
(optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})\n",
     "                                                   :batch-end-callback 
(callback/speedometer batch-size 1)})})\n"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Explore results from the fine-tuned model\n",
+    "\n",
+    "Now that our model is fitted, we can use it to infer semantic equivalence 
of arbitrary sentence pairs. Note that for demonstration purpose we skipped the 
warmup learning rate schedule and validation on dev dataset used in the 
original implementation. This means that our model's performance will be 
significantly less than optimal. Please visit 
[here](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html) for the complete 
fine-tuning scripts (using Python and GluonNLP).\n",
+    "\n",
+    "To do inference with our model we need a predictor. It must have a batch 
size of 1 so we can feed the model a single sentence pair."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "#'bert.bert-sentence-classification/fine-tuned-predictor"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(def fine-tuned-prefix \"fine-tune-sentence-bert\")\n",
+    "\n",
+    "(m/save-checkpoint fine-tune-model {:prefix fine-tuned-prefix :epoch 
3})\n",
+    "\n",
+    "(def fine-tuned-predictor\n",
+    "    (infer/create-predictor (infer/model-factory fine-tuned-prefix\n",
+    "                                                 [{:name \"data0\" :shape 
[1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n",
+    "                                                  {:name \"data1\" :shape 
[1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}\n",
+    "                                                  {:name \"data2\" :shape 
[1]            :dtype dtype/FLOAT32 :layout layout/N}])\n",
+    "                            {:epoch 3}))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we can write a function that feeds a sentence pair to the fine-tuned 
model:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "#'bert.bert-sentence-classification/predict-equivalence"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(defn predict-equivalence\n",
+    "    [predictor sentence1 sentence2]\n",
+    "    (let [vocab (bert.util/get-vocab)\n",
+    "          processed-test-data (mapv #(pre-processing (:idx->token 
vocab)\n",
+    "                                                     (:token->idx vocab) 
%)\n",
+    "                                    [[sentence1 sentence2]])\n",
+    "          prediction (infer/predict-with-ndarray predictor\n",
+    "                                                 [(ndarray/array 
(slice-inputs-data processed-test-data 0) [1 seq-length])\n",
+    "                                                  (ndarray/array 
(slice-inputs-data processed-test-data 1) [1 seq-length])\n",
+    "                                                  (ndarray/array 
(slice-inputs-data processed-test-data 2) [1])])]\n",
+    "      (ndarray/->vec (first prediction))))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[0.2633881 0.7366119]"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    ";; Modify an existing sentence pair to test:\n",
+    ";; [\"1\"\n",
+    ";;  \"69773\"\n",
+    ";;  \"69792\"\n",
+    ";;  \"Cisco pared spending to compensate for sluggish sales .\"\n",
+    ";;  \"In response to sluggish sales , Cisco pared spending .\"]\n",
+    "(predict-equivalence fine-tuned-predictor\n",
+    "                     \"The company cut spending to compensate for weak 
sales .\"\n",
+    "                     \"In response to poor sales results, the company cut 
spending .\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## References\n",
+    "\n",
+    "[1] Peters, Matthew E., et al. “Deep contextualized word 
representations.” arXiv preprint arXiv:1802.05365 (2018).\n",
+    "\n",
+    "[2] Devlin, Jacob, et al. “Bert: Pre-training of deep bidirectional 
transformers for language understanding.” arXiv preprint arXiv:1810.04805 
(2018)."
+   ]
   }
  ],
  "metadata": {
diff --git a/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
index 8c056b7..6ec4d58 100644
--- a/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
+++ b/contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
@@ -16,12 +16,21 @@
 ;;
 
 (ns bert.bert-sentence-classification
+  "Fine-tuning Sentence Pair Classification with BERT
+  This tutorial focuses on fine-tuning with the pre-trained BERT model to classify semantically equivalent sentence pairs.
+
+  Specifically, we will:
+    1. load the state-of-the-art pre-trained BERT model
+    2. attach an additional layer for classification
+    3. process and transform sentence pair data for the task at hand
+    4. fine-tune BERT model for sentence classification"
   (:require [bert.util :as bert-util]
             [clojure-csv.core :as csv]
             [clojure.string :as string]
             [org.apache.clojure-mxnet.callback :as callback]
             [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.module :as m]
@@ -29,8 +38,25 @@
             [org.apache.clojure-mxnet.optimizer :as optimizer]
             [org.apache.clojure-mxnet.symbol :as sym]))
 
+;; Pre-trained language representations have been shown to improve
+;; many downstream NLP tasks such as question answering, and natural
+;; language inference. To apply pre-trained representations to these
+;; tasks, there are two strategies:
+
+;;  * feature-based approach, which uses the pre-trained representations as additional features to the downstream task.
+;;  * fine-tuning based approach, which trains the downstream tasks by fine-tuning pre-trained parameters.
+
+;; While feature-based approaches such as ELMo are effective in
+;; improving many downstream tasks, they require task-specific
+;; architectures. Devlin, Jacob, et al. proposed BERT (Bidirectional
+;; Encoder Representations from Transformers), which fine-tunes deep
+;; bidirectional representations on a wide range of tasks with minimal
+;; task-specific parameters, and obtained state-of-the-art results.
+
 (def model-path-prefix "data/static_bert_base_net")
-;; epoch number of the model
+
+(def fine-tuned-prefix "fine-tune-sentence-bert")
+
 ;; the maximum length of the sequence
 (def seq-length 128)
 
@@ -38,20 +64,19 @@
   "Preprocesses the sentences in the format that BERT is expecting"
   [idx->token token->idx train-item]
   (let [[sentence-a sentence-b label] train-item
-       ;;; pre-processing tokenize sentence
+        ;; pre-processing tokenize sentence
         token-1 (bert-util/tokenize (string/lower-case sentence-a))
         token-2 (bert-util/tokenize (string/lower-case sentence-b))
         valid-length (+ (count token-1) (count token-2))
-        ;;; generate token types [0000...1111...0000]
+        ;; generate token types [0000...1111...0000]
         qa-embedded (into (bert-util/pad [] 0 (count token-1))
-
                           (bert-util/pad [] 1 (count token-2)))
         token-types (bert-util/pad qa-embedded 0 seq-length)
-        ;;; make BERT pre-processing standard
+        ;; make BERT pre-processing standard
         token-2 (conj token-2 "[SEP]")
         token-1 (into [] (concat ["[CLS]"] token-1 ["[SEP]"] token-2))
         tokens (bert-util/pad token-1 "[PAD]" seq-length)
-        ;;; pre-processing - token to index translation
+        ;; pre-processing - token to index translation
         indexes (bert-util/tokens->idxs token->idx tokens)]
     {:input-batch [indexes
                    token-types
@@ -83,19 +108,18 @@
 
 (defn get-raw-data []
   (csv/parse-csv (string/replace (slurp "data/dev.tsv") "\"" "")
-               :delimiter \tab
-               :strict true))
+                 :delimiter \tab
+                 :strict true))
 
 (defn prepare-data
-  "This prepares the senetence pairs into NDArrays for use in NDArrayIterator"
-  []
-  (let [raw-file (get-raw-data)
-        vocab (bert-util/get-vocab)
+  "This prepares the sentence pairs into NDArrays for use in NDArrayIterator"
+  [raw-data]
+  (let [vocab (bert-util/get-vocab)
         idx->token (:idx->token vocab)
         token->idx (:token->idx vocab)
-        data-train-raw (->> raw-file
+        data-train-raw (->> raw-data
                             (mapv #(vals (select-keys % [3 4 0])))
-                            (rest) ;;drop header
+                            (rest) ; drop header
                             (into []))
        processed-datas (mapv #(pre-processing idx->token token->idx %) data-train-raw)]
     {:data0s (slice-inputs-data processed-datas 0)
@@ -111,7 +135,7 @@
   [dev num-epoch]
   (let [bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})
        model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1})
-        {:keys [data0s data1s data2s labels train-num]} (prepare-data)
+        {:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data))
         batch-size 32
         data-desc0 (mx-io/data-desc {:name "data0"
                                      :shape [train-num seq-length]
@@ -138,14 +162,16 @@
                                        {:label {label-desc (ndarray/array labels [train-num]
                                                                           {:ctx dev})}
                                         :data-batch-size batch-size})
-        model (m/module model-sym {:contexts [dev]
-                                   :data-names ["data0" "data1" "data2"]})]
-    (m/fit model {:train-data train-data  :num-epoch num-epoch
-                  :fit-params (m/fit-params {:allow-missing true
-                                             :arg-params (m/arg-params bert-base)
-                                             :aux-params (m/aux-params bert-base)
-                                             :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})
-                                             :batch-end-callback (callback/speedometer batch-size 1)})})))
+        fitted-model (m/fit (m/module model-sym {:contexts [dev]
+                                                 :data-names ["data0" "data1" "data2"]})
+                            {:train-data train-data  :num-epoch num-epoch
+                             :fit-params (m/fit-params {:allow-missing true
+                                                        :arg-params (m/arg-params bert-base)
+                                                        :aux-params (m/aux-params bert-base)
+                                                        :optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9})
+                                                        :batch-end-callback (callback/speedometer batch-size 1)})})]
+    (m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch})
+    fitted-model))
 
 (defn -main [& args]
   (let [[dev-arg num-epoch-arg] args
@@ -154,7 +180,46 @@
     (println "Running example with " dev " and " num-epoch " epochs ")
     (train dev num-epoch)))
 
+;; For evaluating the model
+(defn predict-equivalence
+  "Get the fine-tuned model's opinion on whether two sentences are equivalent:"
+  [predictor sentence1 sentence2]
+  (let [vocab (bert.util/get-vocab)
+        processed-test-data (mapv #(pre-processing (:idx->token vocab)
+                                                   (:token->idx vocab) %)
+                                  [[sentence1 sentence2]])
+        prediction (infer/predict-with-ndarray predictor
+                                               [(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])
+                                                (ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])
+                                                (ndarray/array (slice-inputs-data processed-test-data 2) [1])])]
+    (ndarray/->vec (first prediction))))
+
 (comment
 
   (train (context/cpu 0) 3)
-  (m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3}))
+
+  (m/save-checkpoint model {:prefix fine-tuned-prefix :epoch 3})
+
+  
+  ;;;; Explore results from the fine-tuned model
+
+  ;; We need a predictor with a batch size of 1, so we can feed the
+  ;; model a single sentence pair.
+  (def fine-tuned-predictor
+    (infer/create-predictor (infer/model-factory fine-tuned-prefix
+                                                 [{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+                                                  {:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+                                                  {:name "data2" :shape [1]            :dtype dtype/FLOAT32 :layout layout/N}])
+                            {:epoch 3}))
+
+  ;; Modify an existing sentence pair to test:
+  ;; ["1"
+  ;;  "69773"
+  ;;  "69792"
+  ;;  "Cisco pared spending to compensate for sluggish sales ."
+  ;;  "In response to sluggish sales , Cisco pared spending ."]
+  (predict-equivalence fine-tuned-predictor
+                       "The company cut spending to compensate for weak sales ."
+                       "In response to poor sales results, the company cut spending .")
+
+  )
diff --git a/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj b/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj
index 355f23e..c26301e 100644
--- a/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj
+++ b/contrib/clojure-package/examples/bert/test/bert/bert_sentence_classification_test.clj
@@ -26,6 +26,7 @@
             [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.eval-metric :as eval-metric]
+            [org.apache.clojure-mxnet.infer :as infer]
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.layout :as layout]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
@@ -34,6 +35,8 @@
 
 (def model-dir "data/")
 
+(def test-prefix "test-fine-tuning-bert-sentence-pairs")
+
 (when-not (.exists (io/file (str model-dir "static_bert_qa-0002.params")))
   (println "Downloading bert qa data")
   (sh "./get_bert_data.sh"))
@@ -47,7 +50,7 @@
           num-epoch 1
           bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})
          model-sym (fine-tune-model (m/symbol bert-base) {:num-classes 2 :dropout 0.1})
-          {:keys [data0s data1s data2s labels train-num]} (prepare-data)
+          {:keys [data0s data1s data2s labels train-num]} (prepare-data (get-raw-data))
           batch-size 32
           data-desc0 (mx-io/data-desc {:name "data0"
                                        :shape [train-num seq-length]
@@ -82,5 +85,20 @@
                                               :aux-params (m/aux-params bert-base)
                                               :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})
                                               :batch-end-callback (callback/speedometer batch-size 1)})})
-      (is (< 0.5 (-> (m/score model {:eval-data train-data :eval-metric (eval-metric/accuracy) })
-                     (last)))))))
+      (m/save-checkpoint model {:prefix test-prefix :epoch num-epoch})
+      (testing "accuracy"
+        (is (< 0.5 (last (m/score model {:eval-data train-data :eval-metric (eval-metric/accuracy)})))))
+      (testing "prediction"
+        (let [test-predictor (infer/create-predictor (infer/model-factory test-prefix
+                                                                          [{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+                                                                           {:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+                                                                           {:name "data2" :shape [1]            :dtype dtype/FLOAT32 :layout layout/N}])
+                                                     {:epoch num-epoch})
+              prediction (predict-equivalence test-predictor
+                                              "The company cut spending to compensate for weak sales ."
+                                              "In response to poor sales results, the company cut spending .")]
+          ;; We can't say much about what the model will predict here, so we test only the prediction's shape.
+          (is (vector? prediction))
+          (is (number? (first prediction)))
+          (is (number? (second prediction)))
+          (is (= 2 (count prediction))))))))
diff --git a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
index 05eb0ad..41a003a 100644
--- a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
+++ b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
@@ -99,3 +99,9 @@
       (:help options) (println summary)
       (some? errors) (println (join "\n" errors))
       :else (run-predictor options))))
+
+(comment
+  (run-predictor {:model-path-prefix "models/resnet-18/resnet-18"
+                  :input-image "images/kitten.jpg"})
+
+  )
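
As a usage note on the example output above: with the {:num-classes 2}
head attached by fine-tune-model, the vector returned by
predict-equivalence holds two class probabilities (they sum to 1 in the
notebook run, e.g. [0.2633881 0.7366119]). A hypothetical helper, not
part of this commit, for reading that vector might look like the sketch
below; the assumption that index 0 means "not equivalent" and index 1
means "equivalent" matches the quality labels in dev.tsv but should be
verified against your own training data.

    ;; Hypothetical helper, not part of this commit: turn the
    ;; two-element probability vector from predict-equivalence into a
    ;; labelled verdict. The index-to-label mapping is an assumption.
    (defn interpret-equivalence
      [[p-not-equivalent p-equivalent]]
      (if (> p-equivalent p-not-equivalent)
        {:label :equivalent     :confidence p-equivalent}
        {:label :not-equivalent :confidence p-not-equivalent}))

    ;; e.g. (interpret-equivalence [0.2633881 0.7366119])
    ;; => {:label :equivalent, :confidence 0.7366119}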
