nswamy closed pull request #11808: remove mod from arity 2 version of 
load-checkpoint in clojure-package
URL: https://github.com/apache/incubator-mxnet/pull/11808
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj 
b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
index 22ab761547e..ab6d345fe91 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
@@ -309,7 +309,6 @@
 
 (defn load-checkpoint
   "Create a model from previously saved checkpoint.
-   - mod module
    - opts map of
      -  prefix Path prefix of saved model files. You should have 
prefix-symbol.json,
                  prefix-xxxx.params, and optionally prefix-xxxx.states,
@@ -341,7 +340,7 @@
     (util/->option (when workload-list (util/vec->indexed-seq workload-list)))
     (util/->option (when fixed-param-names (util/vec->set 
fixed-param-names)))))
   ([prefix epoch]
-   (load-checkpoint mod {:prefix prefix :epoch epoch})))
+   (load-checkpoint {:prefix prefix :epoch epoch})))
 
 (defn load-optimizer-states [mod fname]
   (.mod load fname))
@@ -670,4 +669,3 @@
 
   (fit-params {:allow-missing true})
   (fit-params {}))
-
diff --git 
a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj 
b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
index f3d4e75e8c9..0f71b5a850c 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj
@@ -101,13 +101,20 @@
         (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 
:momentum 0.9})})
         (m/update)
         (m/save-checkpoint {:prefix "test" :epoch 0 :save-opt-states true}))
-
     (let [mod2 (m/load-checkpoint {:prefix "test" :epoch 0 
:load-optimizer-states true})]
       (-> mod2
           (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]})
           (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 
:momentum 0.9})}))
-      (is (= (-> mod m/symbol sym/to-json)  (-> mod2 m/symbol sym/to-json)))
-      (is (= (-> mod m/params first) (-> mod2 m/params first))))))
+      (is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json)))
+      (is (= (-> mod m/params first) (-> mod2 m/params first))))
+    ;; arity 2 version of above. `load-optimizer-states` is `false` here by 
default,
+    ;; but optimizers states aren't checked here so it's not relevant to the 
test outcome.
+    (let [mod3 (m/load-checkpoint "test" 0)]
+      (-> mod3
+          (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]})
+          (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 
:momentum 0.9})}))
+      (is (= (-> mod m/symbol sym/to-json) (-> mod3 m/symbol sym/to-json)))
+      (is (= (-> mod m/params first) (-> mod3 m/params first))))))
 
 (deftest test-module-save-load-multi-device
   (let [s (sym/variable "data")
@@ -321,4 +328,3 @@
 (comment
 
   (m/data-shapes x))
-


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to