branch: externals/vecdb
commit a5ab9d957e2f7fe84d738a563e0b445bb3673ba3
Author: Andrew Hyatt <ahy...@gmail.com>
Commit: Andrew Hyatt <ahy...@gmail.com>

    Get the psql (Postgres) provider to pass the integration tests
---
 Eldev                     |  3 ++-
 vecdb-integration-test.el | 26 ++++++++++++------
 vecdb-psql.el             | 68 ++++++++++++++++++++++++++---------------------
 vecdb.el                  |  2 +-
 4 files changed, 59 insertions(+), 40 deletions(-)

diff --git a/Eldev b/Eldev
index ad2d86f3ce..43c29f47cf 100644
--- a/Eldev
+++ b/Eldev
@@ -1,4 +1,5 @@
-; -*- mode: emacs-lisp; lexical-binding: t -*-
+                                        ; -*- mode: emacs-lisp; 
lexical-binding: t -*-
 
 (eldev-use-package-archive 'gnu-elpa)
+(eldev-use-package-archive 'nongnu-elpa)
 (eldev-use-plugin 'maintainer)
diff --git a/vecdb-integration-test.el b/vecdb-integration-test.el
index 2517746500..9b2fd1d82b 100644
--- a/vecdb-integration-test.el
+++ b/vecdb-integration-test.el
@@ -47,6 +47,7 @@
 (require 'vecdb)
 (require 'vecdb-chroma)
 (require 'vecdb-qdrant)
+(require 'vecdb-psql)
 (require 'cl-lib) ;; For cl-remove-if-not, cl-every
 
 (declare-function chroma-ext--tmp-project-dir "ext-chroma")
@@ -84,8 +85,8 @@ Skips tests if no providers are configured."
                 (postgres-password (getenv "PSQL_PASSWORD")))
 
             (when postgres-username
-              (make-psql-vecdb-provider
-               :database postgres-db
+              (make-vecdb-psql-provider
+               :dbname postgres-db
                :username postgres-username
                :password postgres-password)))))
 
@@ -104,6 +105,7 @@ itself might globally skip if no providers at all are 
configured)."
   (declare (indent defun))
   (let ((chroma-test-name (intern (format "%s-chroma" base-name)))
         (qdrant-test-name (intern (format "%s-qdrant" base-name)))
+        (psql-test-name (intern(format "%s-psql" base-name)))
         (base-doc (or docstring (format "Test %s for a vector database 
provider." base-name))))
     `(progn
        (ert-deftest ,chroma-test-name ()
@@ -122,7 +124,15 @@ itself might globally skip if no providers at all are 
configured)."
                                              (vecdb-test--get-providers))))
            (if current-provider
                (funcall ,body-function current-provider)
-             (ert-skip (format "Qdrant provider not configured for %s" 
',qdrant-test-name))))))))
+             (ert-skip (format "Qdrant provider not configured for %s" 
',qdrant-test-name)))))
+       (ert-deftest ,psql-test-name ()
+         ,(format "%s (Postgres)" base-doc)
+         (interactive)
+         (let ((current-provider (cl-find-if (lambda (p) (eq (type-of p) 
'vecdb-psql-provider))
+                                             (vecdb-test--get-providers))))
+           (if current-provider
+               (funcall ,body-function current-provider)
+             (ert-skip (format "Postgres provider not configured for %s" 
',psql-test-name))))))))
 
 (defmacro with-test-collection (current-provider collection-var 
collection-name-base options &rest body)
   "Execute BODY with COLLECTION-VAR bound to a new collection.
@@ -134,11 +144,11 @@ The full collection name is generated by appending the 
provider's name.
 The collection is created before BODY and deleted afterwards."
   (declare (indent 1) (debug t))
   (let ((full-collection-name (gensym "full-collection-name-"))
-        (vector-size-val (gensym "vector-size-"))
         (default-vector-size 3))
     `(let* ((,full-collection-name (format "%s-%s" ,collection-name-base 
(vecdb-provider-name ,current-provider)))
-            (,vector-size-val (or (plist-get ,options :vector-size) 
,default-vector-size))
-            (,collection-var (make-vecdb-collection :name 
,full-collection-name :vector-size ,vector-size-val)))
+            (,collection-var (make-vecdb-collection :name ,full-collection-name
+                                                    :vector-size (or 
(plist-get ,options :vector-size) ,default-vector-size)
+                                                    :payload-fields (plist-get 
,options :payload-fields))))
        (unwind-protect
            (progn
              (vecdb-create ,current-provider ,collection-var)
@@ -177,7 +187,7 @@ The collection is created before BODY and deleted 
afterwards."
                  (make-vecdb-item :id 1 :vector [0 1 2] :payload '(:val 1))
                  (make-vecdb-item :id 2 :vector [0 1 2] :payload '(:val 2))
                  (make-vecdb-item :id 3 :vector [0 1 2] :payload '(:val 3)))))
-    (with-test-collection current-provider current-collection collection-name 
`(:vector-size ,vector-size)
+    (with-test-collection current-provider current-collection collection-name 
`(:vector-size ,vector-size :payload-fields ((val . integer)))
                           (vecdb-upsert-items current-provider 
current-collection items t)
                           (dolist (item items)
                             (let ((retrieved-item (vecdb-get-item 
current-provider current-collection (vecdb-item-id item))))
@@ -202,7 +212,7 @@ The collection is created before BODY and deleted 
afterwards."
          (item2 (make-vecdb-item :id 2 :vector [0.4 0.5 0.6] :payload '(:val 
2)))
          (item3 (make-vecdb-item :id 3 :vector [0.7 0.8 0.9] :payload '(:val 
3)))
          (items (list item1 item2 item3)))
-    (with-test-collection current-provider current-collection collection-name 
`(:vector-size ,vector-size)
+    (with-test-collection current-provider current-collection collection-name 
`(:vector-size ,vector-size :payload-fields ((val . integer)))
                           (vecdb-upsert-items current-provider 
current-collection items t)
                           ;; Search for a vector similar to item2
                           (let ((results (vecdb-search-by-vector 
current-provider current-collection [0.41 0.51 0.61] 3)))
diff --git a/vecdb-psql.el b/vecdb-psql.el
index 75dbfc8f15..fe67f76a87 100644
--- a/vecdb-psql.el
+++ b/vecdb-psql.el
@@ -30,8 +30,8 @@
 (require 'map)
 (require 'seq)
 
-(cl-defstruct (vecdb-psql (:include vecdb-provider
-                                    (name "postgres")))
+(cl-defstruct (vecdb-psql-provider (:include vecdb-provider
+                                             (name "postgres")))
   "Provider for the vector database.
 DBNAME is the database name, which must have been created by the user."
   dbname
@@ -44,28 +44,33 @@ DBNAME is the database name, which must have been created 
by the user."
 
 (defun vecdb-psql-get-connection (provider)
   "Get a connection to the database specified by PROVIDER."
-  (let* ((key (vecdb-psql-dbname provider))
+  (let* ((key (vecdb-psql-provider-dbname provider))
          (connection (gethash key vecdb-psql-connection-cache)))
     (unless connection
       (setq connection
             (pg-connect
-             (vecdb-psql-dbname provider)
-             (vecdb-psql-username provider)
-             (vecdb-psql-password provider)))
+             (vecdb-psql-provider-dbname provider)
+             (vecdb-psql-provider-username provider)
+             (vecdb-psql-provider-password provider)))
       (puthash key connection vecdb-psql-connection-cache))
     connection))
 
-(cl-defmethod vecdb-create ((provider vecdb-psql)
+(defun vecdb-psql-table-name (collection-name)
+  "Turn COLLECTION-NAME into a safe table name."
+  (replace-regexp-in-string "[^a-zA-Z0-9_]" "_" (downcase collection-name)))
+
+(cl-defmethod vecdb-create ((provider vecdb-psql-provider)
                             (collection vecdb-collection))
   "Create COLLECTION in database PROVIDER."
   (pg-exec (vecdb-psql-get-connection provider)
            (format "CREATE TABLE IF NOT EXISTS %s (
                      id INTEGER PRIMARY KEY,
-                     vector VECTOR(%d) NOT NULL,
+                     vector VECTOR(%d) NOT NULL%s
                      %s
                    );"
-                   (vecdb-collection-name collection)
+                   (vecdb-psql-table-name (vecdb-collection-name collection))
                    (vecdb-collection-vector-size collection)
+                   (if (vecdb-collection-payload-fields collection) "," "")
                    (mapconcat
                     (lambda (field)
                       (format "%s %s NULL"
@@ -79,25 +84,25 @@ DBNAME is the database name, which must have been created 
by the user."
                     ", ")))
   (pg-exec (vecdb-psql-get-connection provider)
            (format "CREATE INDEX IF NOT EXISTS %s_embedding_hnsw_idx ON %s 
USING hnsw (vector vector_cosine_ops)"
-                   (vecdb-collection-name collection)
-                   (vecdb-collection-name collection)))
+                   (vecdb-psql-table-name (vecdb-collection-name collection))
+                   (vecdb-psql-table-name (vecdb-collection-name collection))))
   (mapc (lambda (field)
           (pg-exec (vecdb-psql-get-connection provider)
                    (format "CREATE INDEX IF NOT EXISTS %s_%s_idx ON %s (%s)"
-                           (vecdb-collection-name collection)
+                           (vecdb-psql-table-name (vecdb-collection-name 
collection))
                            (car field)
-                           (vecdb-collection-name collection)
+                           (vecdb-psql-table-name (vecdb-collection-name 
collection))
                            (car field))))
         (vecdb-collection-payload-fields collection)))
 
-(cl-defmethod vecdb-delete ((provider vecdb-psql)
+(cl-defmethod vecdb-delete ((provider vecdb-psql-provider)
                             (collection vecdb-collection))
   "Delete COLLECTION from database PROVIDER."
   (pg-exec (vecdb-psql-get-connection provider)
            (format "DROP TABLE IF EXISTS %s;"
-                   (vecdb-collection-name collection))))
+                   (vecdb-psql-table-name (vecdb-collection-name 
collection)))))
 
-(cl-defmethod vecdb-exists ((provider vecdb-psql)
+(cl-defmethod vecdb-exists ((provider vecdb-psql-provider)
                             (collection vecdb-collection))
   "Check if the COLLECTION exists in the database specified by PROVIDER."
   (let ((result
@@ -106,7 +111,7 @@ DBNAME is the database name, which must have been created 
by the user."
                             SELECT FROM information_schema.tables
                             WHERE table_name = '%s'
                           );"
-                          (vecdb-collection-name collection)))))
+                          (vecdb-psql-table-name (vecdb-collection-name 
collection))))))
     (and result
          (equal (caar (pg-result result :tuples)) t))))
 
@@ -115,15 +120,15 @@ DBNAME is the database name, which must have been created 
by the user."
   (cl-loop for (k _v) on plist by #'cddr
            collect (substring (symbol-name k) 1)))
 
-(cl-defmethod vecdb-upsert-items ((provider vecdb-psql)
+(cl-defmethod vecdb-upsert-items ((provider vecdb-psql-provider)
                                   (collection vecdb-collection)
                                   data-list &optional _)
   "Upsert items into the COLLECTION in the database PROVIDER.
 All items in DATA-LIST must have the same paylaods."
   (pg-exec (vecdb-psql-get-connection provider)
            (format "INSERT INTO %s (id, vector, %s) VALUES %s
-                    ON CONFLICT (id) DO UPDATE SET vector = EXCLUDED.vector, 
%s;"
-                   (vecdb-collection-name collection)
+                    ON CONFLICT (id) DO UPDATE SET vector = EXCLUDED.vector%s 
%s;"
+                   (vecdb-psql-table-name (vecdb-collection-name collection))
                    ;; We assume every vecdb-item has the same payload structure
                    (mapconcat #'identity (vecdb-psql--plist-keys
                                           (vecdb-item-payload (car data-list)))
@@ -145,6 +150,7 @@ All items in DATA-LIST must have the same paylaods."
                                ", ")))
                     data-list
                     ", ")
+                   (if (vecdb-collection-payload-fields collection) ", " "")
                    (mapconcat
                     (lambda (field)
                       (format "%s = EXCLUDED.%s" (car field) (car field)))
@@ -165,7 +171,7 @@ All items in DATA-LIST must have the same paylaods."
                                                        :test #'equal))
                                      row))))))
 
-(cl-defmethod vecdb-get-item ((provider vecdb-psql)
+(cl-defmethod vecdb-get-item ((provider vecdb-psql-provider)
                               (collection vecdb-collection)
                               id)
   "Get an item from COLLECTION by ID.
@@ -173,19 +179,20 @@ PROVIDER specifies the database that the collection is 
in."
   (let ((result
          (pg-result
           (pg-exec (vecdb-psql-get-connection provider)
-                   (format "SELECT id, vector::vector, %s FROM %s WHERE id = 
%d;"
+                   (format "SELECT id, vector::vector%s %s FROM %s WHERE id = 
%d;"
+                           (if (vecdb-collection-payload-fields collection) ", 
" "")
                            (mapconcat
                             (lambda (field)
-                              (car field))
+                              (format "%s" (car field)))
                             (vecdb-collection-payload-fields collection)
                             ", ")
-                           (vecdb-collection-name collection)
+                           (vecdb-psql-table-name (vecdb-collection-name 
collection))
                            id))
           :tuples)))
     (when result
       (vecdb-psql--full-row-to-item (car result) collection))))
 
-(cl-defmethod vecdb-delete-items ((provider vecdb-psql)
+(cl-defmethod vecdb-delete-items ((provider vecdb-psql-provider)
                                   (collection vecdb-collection)
                                   ids &optional _)
   "Delete items from COLLECTION by IDs.
@@ -193,10 +200,10 @@ PROVIDER is the database that the collection is in."
   (when ids
     (pg-exec (vecdb-psql-get-connection provider)
              (format "DELETE FROM %s WHERE id IN (%s);"
-                     (vecdb-collection-name collection)
+                     (vecdb-psql-table-name (vecdb-collection-name collection))
                      (mapconcat #'number-to-string ids ", ")))))
 
-(cl-defmethod vecdb-search-by-vector ((provider vecdb-psql)
+(cl-defmethod vecdb-search-by-vector ((provider vecdb-psql-provider)
                                       (collection vecdb-collection)
                                       vector
                                       &optional limit)
@@ -209,14 +216,15 @@ PROVIDER is the database that the collection is in."
               (vecdb-psql--full-row-to-item row collection))
             (pg-result
              (pg-exec (vecdb-psql-get-connection provider)
-                      (format "SELECT id, vector::vector, %s FROM %s
+                      (format "SELECT id, vector::vector%s %s FROM %s
                       ORDER BY vector <-> '[%s]'::vector %s;"
+                              (if (vecdb-collection-payload-fields collection) 
", " "")
                               (mapconcat
                                (lambda (field)
-                                 (car field))
+                                 (format "%s" (car field)))
                                (vecdb-collection-payload-fields collection)
                                ", ")
-                              (vecdb-collection-name collection)
+                              (vecdb-psql-table-name (vecdb-collection-name 
collection))
                               (mapconcat
                                (lambda (v)
                                  (format "%s" v))
diff --git a/vecdb.el b/vecdb.el
index 78179c804b..489fac3acb 100644
--- a/vecdb.el
+++ b/vecdb.el
@@ -4,7 +4,7 @@
 
 ;; Author: Andrew Hyatt <ahy...@gmail.com>
 ;; Homepage: https://github.com/ahyatt/vecdb
-;; Package-Requires: ((emacs "29.1") (plz "0.8"))
+;; Package-Requires: ((emacs "29.1") (plz "0.8") pg)
 ;; Package-Version: 0.1
 ;; SPDX-License-Identifier: GPL-3.0-or-later
 ;;

Reply via email to