gigasquid closed pull request #12630: [MXNET-12627] Fixed param coercion of clojure executor/forward
URL: https://github.com/apache/incubator-mxnet/pull/12630
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub does not display its diff
after the merge, so the diff is reproduced below:

diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
index 4f4155e2d80..b9883f77d56 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj
@@ -34,7 +34,8 @@
    (do (.forward executor)
        executor))
   ([executor is-train kwargs]
-   (do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"})))))
+   (do (.forward executor is-train (util/map->scala-tuple-seq kwargs))
+       executor)))
 
 (defn backward
   "* Do backward pass to get the gradient of arguments.
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 8f2bb3bfae9..6f22b0eb3a0 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -204,3 +204,24 @@
     (throw (ex-info error-msg
                     (s/explain-data spec value)))))
 
+(defn map->scala-tuple-seq
+  "* Convert a map to a scala-Seq of scala-Tubple.
+   * Should also work if a seq of seq of 2 things passed.
+   * Otherwise passed through unchanged."
+  [map-or-tuple-seq]
+  (letfn [(key->name [k]
+            (if (or (keyword? k) (string? k) (symbol? k))
+              (string/replace (name k) "-" "_")
+              k))
+          (->tuple [kvp-or-tuple]
+            (if (coll? kvp-or-tuple)
+              (let [[k v] kvp-or-tuple]
+                ($/tuple (key->name k) v))
+              ;; pass-through
+              kvp-or-tuple))]
+    (if (coll? map-or-tuple-seq)
+      (->> map-or-tuple-seq
+           (map ->tuple)
+           (apply $/immutable-list))
+      ;; pass-through
+      map-or-tuple-seq)))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
index b2a87d41e34..fb73f009156 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj
@@ -74,3 +74,23 @@
     (is (every? #(= 4.0 %) (->> (executor/outputs exec)
                                 (map ndarray/->vec)
                                 first)))))
+
+(deftest test-forward
+  (let [a (sym/variable "a")
+        b (sym/variable "b")
+        c (sym/+ a b)
+        ex (sym/bind c {:a (ndarray/* (ndarray/ones [1 2]) 2)
+                        :b (ndarray/* (ndarray/ones [1 2]) 3)})]
+    ;; test forward with binded values
+    (executor/forward ex)
+    (is (= [5.0 5.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; test forward with new a (b is still [3.0 3.0]
+    (executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 4)})
+    (is (= [7.0 7.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; test forward with new b (a is still [4.0 4.0]
+    (executor/forward ex false {:b (ndarray/* (ndarray/ones [1 2]) 5)})
+    (is (= [9.0 9.0] (-> ex executor/outputs first ndarray/->vec)))
+    ;; test forward with new a & b
+    (executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 6)
+                                :b (ndarray/* (ndarray/ones [1 2]) 7)})
+    (is (= [13.0 13.0] (-> ex executor/outputs first ndarray/->vec)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index de3480827ba..ee7710317e4 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -190,4 +190,19 @@
         data2 [1 1 1 1 9 9 9 9]
         data3 [1 1 1 2]]
     (is (not (test-util/approx= 1e-9 data1 data2)))
-    (is (test-util/approx= 2 data1 data3))))
\ No newline at end of file
+    (is (test-util/approx= 2 data1 data3))))
+
+(deftest test-map->scala-tuple-seq
+  ;; convert as much, and pass-through the rest
+  (is (nil? (util/map->scala-tuple-seq nil)))
+  (is (= "List()"
+         (str (util/map->scala-tuple-seq {}))
+         (str (util/map->scala-tuple-seq []))
+         (str (util/map->scala-tuple-seq '()))))
+  (is (= "List(a, b)" (str (util/map->scala-tuple-seq ["a" "b"]))))
+  (is (= "List((a,b), (c,d), (e,f), (a_b,g), (c_d,h), (e_f,i))"
+         (str (util/map->scala-tuple-seq {:a "b", 'c "d", "e" "f"
+                                          :a-b "g", 'c-d "h", "e-f" "i"}))))
+  (let [nda (util/map->scala-tuple-seq {:a-b (ndarray/ones [1 2])})]
+    (is (= "a_b" (._1 (.head nda))))
+    (is (= [1.0 1.0] (ndarray/->vec (._2 (.head nda)))))))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to