From 945cd9835bc3b7c779b9c50c122e6a3d3eff10bd Mon Sep 17 00:00:00 2001
From: Jez Ng <me@jezng.com>
Date: Sun, 5 May 2013 23:25:13 -0400
Subject: [PATCH] Implement array-fold and array-fold-all.

---
 module/ice-9/boot-9.scm      | 37 +++++++++++++++++++++++++++++++++++++
 test-suite/tests/arrays.test | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/module/ice-9/boot-9.scm b/module/ice-9/boot-9.scm
index f1fd041..309e9df 100644
--- a/module/ice-9/boot-9.scm
+++ b/module/ice-9/boot-9.scm
@@ -1125,6 +1125,43 @@ VALUE."
   (map (lambda (ind) (if (number? ind) (list 0 (+ -1 ind)) ind))
        (array-dimensions a)))
 
+(define (array-fold-all proc init . arrays)
+  (apply array-for-each
+         (lambda elements
+           (set! init (apply proc (append! elements `(,init)))))
+         arrays)
+  init)
+
+(define (array-fold proc init . arrays)
+  (define shape (array-shape (car arrays)))
+  (define type (array-type (car arrays)))
+  (unless (> (length shape) 1)
+    (error "array-fold must be called on arrays with of at least rank 2"))
+  (for-each (lambda (a)
+              (unless (equal? shape (array-shape a))
+                (error "array shape mismatch"))
+              (unless (equal? type (array-type a))
+                (error "array type mismatch")))
+            (cdr arrays))
+  (let ((result (apply make-typed-array
+                       type
+                       (apply:nconc2last
+                        `(0 ,(list-head shape (- (length shape) 1)))))))
+    (array-index-map!
+     result
+     (lambda indices
+       (apply array-fold-all
+              proc
+              init
+              (map
+               (lambda (a)
+                 (make-shared-array
+                  a
+                  (lambda (i) (append indices `(,i)))
+                  (car (last-pair shape))))
+               arrays))))
+    result))
+
 
 
 ;;; {Keywords}
diff --git a/test-suite/tests/arrays.test b/test-suite/tests/arrays.test
index 0b3d57c..a6610b6 100644
--- a/test-suite/tests/arrays.test
+++ b/test-suite/tests/arrays.test
@@ -678,3 +678,36 @@
                      #u32(2 3)))
     (pass-if (equal? (array-ref (array-row array 1) 0)
                      2))))
+
+;;;
+;;; array-fold
+;;;
+
+(define exception:shape-mismatch
+  (cons 'misc-error "array shape mismatch"))
+
+(define exception:type-mismatch
+  (cons 'misc-error "array type mismatch"))
+
+(define exception:insufficient-rank
+  (cons 'misc-error "array-fold must be called on arrays with of at least rank 2"))
+
+(let ((a1 #2u32((0 1) (2 3)))
+      (a2 #2u32((4 5) (6 7)))
+      (a3 #2u32((8 9)))
+      (a4 #2u16((10 11)))
+      (a5 #u16(12 13 14)))
+  (with-test-prefix "array-fold and array-fold-all"
+    (pass-if (equal? (array-fold + 0 a1)
+                     #u32(1 5)))
+    (pass-if (equal? (array-fold + 0 a2)
+                     #u32(9 13)))
+    (pass-if (equal? (array-fold + 0 a1 a2)
+                     #u32(10 18)))
+    (pass-if (equal? (array-fold-all + 0 a1 a2) 28))
+    (pass-if-exception "wrong shape" exception:shape-mismatch
+      (array-fold + 0 a1 a3))
+    (pass-if-exception "wrong type" exception:type-mismatch
+      (array-fold + 0 a3 a4))
+    (pass-if-exception "insufficient rank" exception:insufficient-rank
+      (array-fold + 0 a5))))
-- 
1.8.2.2

