wingo pushed a commit to branch main
in repository guile.

commit 88e0933450ed9f1cc96858641e0756ffa44c53c6
Author: Andy Wingo <wi...@pobox.com>
AuthorDate: Tue Mar 5 23:02:25 2024 +0100

    Rework make-c-struct, parse-c-struct
    
    * module/system/foreign.scm (bytevector-complex-single-native-ref)
    (bytevector-complex-single-native-set!)
    (bytevector-complex-double-native-ref)
    (bytevector-complex-double-native-set!): New procedures, defined
    statically instead of via the reader-complex/writer-complex closures.
    (compile-time-eval):
    (switch/compile-time-keys): New helpers.
    (align): Make available at compile-time.
    (read-field, read-fields, write-field, write-fields): New helpers.  More
    efficient than looking up accessors in the *readers*/*writers* alists.
    (write-c-struct, read-c-struct): Rework in terms of new helpers.
    (parse-c-struct): Just use sizeof to get the size.
---
 module/system/foreign.scm | 232 ++++++++++++++++++++++++++++++----------------
 1 file changed, 152 insertions(+), 80 deletions(-)

diff --git a/module/system/foreign.scm b/module/system/foreign.scm
index 3ddfd204b..438ecd5ed 100644
--- a/module/system/foreign.scm
+++ b/module/system/foreign.scm
@@ -16,10 +16,12 @@
 
 
 (define-module (system foreign)
+  #:use-module (ice-9 match)
   #:use-module (rnrs bytevectors)
   #:use-module (srfi srfi-1)
   #:use-module (srfi srfi-9)
   #:use-module (srfi srfi-9 gnu)
+  #:use-module (system base target)
   #:export (void
             float double
             complex-float complex-double
@@ -77,6 +79,36 @@
 ;;; Structures.
 ;;;
 
+(define-syntax compile-time-eval
+  (lambda (stx)
+    "Evaluate the target-dependent expression EXP at compile-time if we are
+not cross-compiling; otherwise leave it to be evaluated at run-time."
+    (syntax-case stx ()
+      ((_ exp)
+       (cond
+        ((not (equal? (target-type) %host-type)) #'exp)  ; cross-compiling
+        (else
+         (let ((datum (primitive-eval (syntax->datum #'exp))))
+           #`(quote #,(datum->syntax #'here datum)))))))))
+
+;; Note that when cross-compiling Guile, the host and the target may
+;; have different values for, say, `long'.  However, explicitly-sized
+;; types such as int8 and float have the same value on all platforms,
+;; and sizeof on them is target-invariant too.  alignof, notably, is
+;; *not* target-invariant.
+
+(define-syntax switch/compile-time-keys
+  ;; cond on (eq? X K) per clause; each key K goes through compile-time-eval.
+  (syntax-rules (else)
+    ((_ x (k expr) ... (else alt))
+     (let ((val x))
+       (cond ((eq? val (compile-time-eval k)) expr)
+             ...
+             (else alt))))))
+
+(define-syntax-rule (align off alignment)  ; round OFF up; ALIGNMENT is a power of two
+  (+ 1 (logior (- off 1) (- alignment 1))))
+
 (define bytevector-pointer-ref
   (case (sizeof '*)
     ((8) (lambda (bv offset)
@@ -93,85 +125,130 @@
            (bytevector-u32-native-set! bv offset (pointer-address ptr))))
     (else (error "what machine is this?"))))
 
-(define (writer-complex set size)
-  (lambda (bv i val)
-    (set bv i (real-part val))
-    (set bv (+ i size) (imag-part val))))
-
-(define (reader-complex ref size)
-  (lambda (bv i)
-    (make-rectangular
-     (ref bv i)
-     (ref bv (+ i size)))))
-
-(define *writers*
-  `((,float . ,bytevector-ieee-single-native-set!)
-    (,double . ,bytevector-ieee-double-native-set!)
-    (,complex-float
-     . ,(writer-complex bytevector-ieee-single-native-set! (sizeof float)))
-    (,complex-double
-     . ,(writer-complex bytevector-ieee-double-native-set! (sizeof double)))
-    (,int8 . ,bytevector-s8-set!)
-    (,uint8 . ,bytevector-u8-set!)
-    (,int16 . ,bytevector-s16-native-set!)
-    (,uint16 . ,bytevector-u16-native-set!)
-    (,int32 . ,bytevector-s32-native-set!)
-    (,uint32 . ,bytevector-u32-native-set!)
-    (,int64 . ,bytevector-s64-native-set!)
-    (,uint64 . ,bytevector-u64-native-set!)
-    (* . ,bytevector-pointer-set!)))
-
-(define *readers*
-  `((,float . ,bytevector-ieee-single-native-ref)
-    (,double . ,bytevector-ieee-double-native-ref)
-    (,complex-float
-     . ,(reader-complex bytevector-ieee-single-native-ref (sizeof float)))
-    (,complex-double
-     . ,(reader-complex bytevector-ieee-double-native-ref (sizeof double)))
-    (,int8 . ,bytevector-s8-ref)
-    (,uint8 . ,bytevector-u8-ref)
-    (,int16 . ,bytevector-s16-native-ref)
-    (,uint16 . ,bytevector-u16-native-ref)
-    (,int32 . ,bytevector-s32-native-ref)
-    (,uint32 . ,bytevector-u32-native-ref)
-    (,int64 . ,bytevector-s64-native-ref)
-    (,uint64 . ,bytevector-u64-native-ref)
-    (* . ,bytevector-pointer-ref)))
-
-(define (align off alignment)
-  (1+ (logior (1- off) (1- alignment))))
+(define-syntax-rule (define-complex-accessors (read write) (%read %write size))
+  ;; Define READ and WRITE for a complex number stored as two consecutive
+  ;; SIZE-byte parts, real part first, using the scalar %READ / %WRITE.
+  (begin
+    (define (read bv offset)
+      (make-rectangular (%read bv offset) (%read bv (+ offset size))))
+    (define (write bv offset val)
+      (%write bv offset (real-part val))
+      (%write bv (+ offset size) (imag-part val)))))
+
+(define-complex-accessors
+  (bytevector-complex-single-native-ref bytevector-complex-single-native-set!)
+  (bytevector-ieee-single-native-ref bytevector-ieee-single-native-set! 4))  ; 4 = size of an IEEE single
+
+(define-complex-accessors
+  (bytevector-complex-double-native-ref bytevector-complex-double-native-set!)
+  (bytevector-ieee-double-native-ref bytevector-ieee-double-native-set! 8))  ; 8 = size of an IEEE double
+
+(define-syntax-rule (read-field %bv %offset %type)  ; => (values VALUE NEXT-OFFSET)
+  (let ((bv %bv)  ; bind each argument so it is evaluated only once
+        (offset %offset)
+        (type %type))
+    (define-syntax-rule (%read type reader)  ; read a known primitive TYPE via READER
+      (let* ((offset (align offset (compile-time-eval (alignof type))))
+             (val (reader bv offset)))
+        (values val  ; the value, then the offset just past it
+                (+ offset (compile-time-eval (sizeof type))))))
+    (define-syntax-rule (dispatch-read type (%%type reader) (... ...))
+      (switch/compile-time-keys
+       type
+       (%%type (%read %%type reader))
+       (... ...)
+       (else  ; no primitive key matched: TYPE should be a list of field types
+        (let ((offset (align offset (alignof type))))
+          (values (read-c-struct bv offset type)
+                  (+ offset (sizeof type)))))))
+    (dispatch-read  ; keys are folded by compile-time-eval unless cross-compiling
+     type
+     (int8 bytevector-s8-ref)
+     (uint8 bytevector-u8-ref)
+     (int16 bytevector-s16-native-ref)
+     (uint16 bytevector-u16-native-ref)
+     (int32 bytevector-s32-native-ref)
+     (uint32 bytevector-u32-native-ref)
+     (int64 bytevector-s64-native-ref)
+     (uint64 bytevector-u64-native-ref)
+     (float bytevector-ieee-single-native-ref)
+     (double bytevector-ieee-double-native-ref)
+     (complex-float bytevector-complex-single-native-ref)
+     (complex-double bytevector-complex-double-native-ref)
+     ('* bytevector-pointer-ref))))
+
+(define-syntax read-fields
+  (syntax-rules ()
+    ((read-fields () bv offset k) (k offset))  ; done: call K with the end offset
+    ((read-fields ((field type) . rest) bv offset k)
+     (call-with-values (lambda ()
+                         (read-field bv offset (compile-time-eval type)))
+       (lambda (field next)  ; fresh NEXT: OFFSET need not be an identifier
+         (read-fields rest bv next k))))))
+
+(define-syntax-rule (write-field %bv %offset %type %value)  ; => NEXT-OFFSET
+  (let ((bv %bv)  ; bind each argument so it is evaluated only once
+        (offset %offset)
+        (type %type)
+        (value %value))
+    (define-syntax-rule (%write type writer)  ; write a known primitive TYPE via WRITER
+      (let ((offset (align offset (compile-time-eval (alignof type)))))
+        (writer bv offset value)
+        (+ offset (compile-time-eval (sizeof type)))))
+    (define-syntax-rule (dispatch-write type (%%type writer) (... ...))
+      (switch/compile-time-keys
+       type
+       (%%type (%write %%type writer))
+       (... ...)
+       (else  ; no primitive key matched: TYPE should be a list of field types
+        (let ((offset (align offset (alignof type))))
+          (write-c-struct bv offset type value)
+          (+ offset (sizeof type))))))
+    (dispatch-write  ; keys are folded by compile-time-eval unless cross-compiling
+     type
+     (int8 bytevector-s8-set!)
+     (uint8 bytevector-u8-set!)
+     (int16 bytevector-s16-native-set!)
+     (uint16 bytevector-u16-native-set!)
+     (int32 bytevector-s32-native-set!)
+     (uint32 bytevector-u32-native-set!)
+     (int64 bytevector-s64-native-set!)
+     (uint64 bytevector-u64-native-set!)
+     (float bytevector-ieee-single-native-set!)
+     (double bytevector-ieee-double-native-set!)
+     (complex-float bytevector-complex-single-native-set!)
+     (complex-double bytevector-complex-double-native-set!)
+     ('* bytevector-pointer-set!))))
+
+(define-syntax write-fields
+  (syntax-rules ()
+    ((write-fields () bv offset k) (k offset))  ; done: call K with the end offset
+    ((write-fields ((field type) . rest) bv offset k)
+     (let ((next (write-field bv offset (compile-time-eval type) field)))
+       (write-fields rest bv next k)))))  ; fresh NEXT: OFFSET need not be an identifier
 
+;; Like write-fields, but dispatching on TYPES at run-time; returns #t.
 (define (write-c-struct bv offset types vals)
   (let lp ((offset offset) (types types) (vals vals))
-    (cond
-     ((not (pair? types))
-      (or (null? vals)
-          (error "too many values" vals)))
-     ((not (pair? vals))
-      (error "too few values" types))
-     (else
-      ;; alignof will error-check
-      (let* ((type (car types))
-             (offset (align offset (alignof type))))
-        (if (pair? type)
-            (write-c-struct bv offset (car types) (car vals))
-            ((assv-ref *writers* type) bv offset (car vals)))
-        (lp (+ offset (sizeof type)) (cdr types) (cdr vals)))))))
+    (match types
+      (() (match vals  ; every type written: VALS must be exhausted too
+            (() #t)
+            (_ (error "too many values" vals))))
+      ((type . types)
+       (match vals
+         ((val . vals)
+          (lp (write-field bv offset type val) types vals))  ; write-field returns the next offset
+         (() (error "too few values" (cons type types))))))))  ; report the unfilled types
 
+;; Run-time counterpart of read-fields: returns the field values as a list.
 (define (read-c-struct bv offset types)
-  (let lp ((offset offset) (types types) (vals '()))
-    (cond
-     ((not (pair? types))
-      (reverse vals))
-     (else
-      ;; alignof will error-check
-      (let* ((type (car types))
-             (offset (align offset (alignof type))))
-        (lp (+ offset (sizeof type)) (cdr types)
-            (cons (if (pair? type)
-                      (read-c-struct bv offset (car types))
-                      ((assv-ref *readers* type) bv offset))
-                  vals)))))))
+  (let lp ((offset offset) (types types))
+    (match types
+      ((type . rest)
+       ;; read-field yields the value and the offset just past it.
+       (call-with-values (lambda () (read-field bv offset type))
+         (lambda (val next) (cons val (lp next rest)))))
+      (() '()))))
 
 (define (make-c-struct types vals)
   (let ((bv (make-bytevector (sizeof types) 0)))
@@ -179,12 +256,7 @@
     (bytevector->pointer bv)))
 
 (define (parse-c-struct foreign types)
-  (let ((size (fold (lambda (type total)
-                      (+ (sizeof type)
-                         (align total (alignof type))))
-                    0
-                    types)))
-    (read-c-struct (pointer->bytevector foreign size) 0 types)))
+  (let ((bv (pointer->bytevector foreign (sizeof types)))) (read-c-struct bv 0 types)))
 
 
 ;;;

Reply via email to