Changeset: 65dca0ad1418 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/65dca0ad1418
Modified Files:
        monetdb5/modules/kernel/vss.c
Branch: nested
Log Message:

Add bulk (BAT-level) implementations for the other vector distance functions.


diffs (truncated from 517 to 300 lines):

diff --git a/monetdb5/modules/kernel/vss.c b/monetdb5/modules/kernel/vss.c
--- a/monetdb5/modules/kernel/vss.c
+++ b/monetdb5/modules/kernel/vss.c
@@ -52,11 +52,11 @@
  */
 #define DEFINE_METRIC_L1(NAME, T, R)                                          \
 static R \
-metric_l1_##NAME(MalStkPtr stk, InstrPtr pci, size_t dim)                   \
+metric_l1_##NAME(MalStkPtr stk, InstrPtr pci, size_t dim)                      
\
 {                                                                             \
     R dist = 0;                                                               \
     /* SIMD optimization hints */                                             \
-    GCC_Pragma("GCC ivdep")                                                    
  \
+    GCC_Pragma("GCC ivdep")                                                   \
     for (size_t i=0; i < dim; i++) {                                          \
         T ai = *getArgReference_##T(stk, pci, 1 + i);                         \
         T bi = *getArgReference_##T(stk, pci, 1 + dim + i);                   \
@@ -194,6 +194,7 @@ METRIC##_distance_##TNAME(Client ctx, Ma
     return MAL_SUCCEED;                                                        
         \
 }
 
+
 // Inner (Dot) Product float (32-bit)
 DEFINE_SCALAR_HANDLER(ip, f32, flt, dbl)
 // Inner (Dot) Product double (64-bit)
@@ -214,127 +215,320 @@ DEFINE_SCALAR_HANDLER(cos, f32, flt, dbl
 // cos double (64-bit)
 DEFINE_SCALAR_HANDLER(cos, f64, dbl, dbl)
 
-
+#if 0
 #define PATIAL_L2sq(R, BL, BR, CNT, T)                  \
 do {                                                    \
     const T *a = (const T*) Tloc(BL, 0);                \
     const T *b = (const T*) Tloc(BR, 0);                \
     for(BUN p = 0; p<cnt; p++) {                        \
-        dbl d = a[p] - b[p];                              \
-        R[p] += d*d;                                 \
+        dbl d = a[p] - b[p];                            \
+        R[p] += d*d;                                                           
        \
+    }                                                   \
+} while (0)
+
+#define PATIAL_L2sq_const(R, BL, b, CNT, T)             \
+do {                                                    \
+    const T *a = (const T*) Tloc(BL, 0);                \
+    for(BUN p = 0; p<cnt; p++) {                        \
+        dbl d = a[p] - b;                               \
+        R[p] += d*d;                                                           
        \
+    }                                                   \
+} while (0)
+#endif
+
+/**
+ * @brief Partial Euclidean distance accumulation.
+ * @param RES  Tloc of result BAT
+ * @param RV2   not used
+ * @param COL  Tloc of input BAT
+ * @param QVAL Query value for dimension
+ * @param CNT  count/size of RES and COL
+ * @param T    result type
+ */
+#define PATIAL_l2sq(RES, RV2, COL, QVAL, CNT, T)               \
+do {                                                    \
+    /* SIMD optimization hints */                       \
+    GCC_Pragma("GCC ivdep")                             \
+    for(BUN p = 0; p < (CNT); p++) {                    \
+        T d = (T)COL[p] - (T)QVAL;                                             
\
+        (RES)[p] += d*d;                                                       
        \
     }                                                   \
 } while (0)
 
-#define PATIAL_L2sq_const(R, BL, b, CNT, T)            \
+
+/**
+ * @brief Partial accumulations for cos distance.
+ * @param RDP  dot product buffer
+ * @param RV2  vector values buffer
+ * @param COL  Tloc of input BAT
+ * @param QVAL  Query value for dimension
+ * @param CNT  count/size of COL and RES
+ * @param T            result type
+ */
+#define PATIAL_cos(RDP, RV2, COL, QVAL, CNT, T)                        \
 do {                                                    \
-    const T *a = (const T*) Tloc(BL, 0);                \
-    for(BUN p = 0; p<cnt; p++) {                        \
-        dbl d = a[p] - b;                                \
-        R[p] += d*d;                                 \
+    T q_dim = (T)(QVAL);                                \
+    /* SIMD optimization */                                                    
        \
+    GCC_Pragma("GCC ivdep")                             \
+    for(BUN p = 0; p < (CNT); p++) {                    \
+        /* Accumulate Dot Product */                                   \
+        (RDP)[p] += (T)(COL)[p] * q_dim;                \
+        /* Accumulate vector norms */                                  \
+               (RV2)[p] += (T)(COL)[p] * (T)(COL)[p];                  \
+    }                                                   \
+} while (0)
+
+/**
+ * @brief Partial Inner Product accumulation.
+ * @param RES   Tloc of result BAT
+ * @param RV2   not used
+ * @param COL   Tloc of input BAT
+ * @param QVAL  Query value for dimension
+ * @param CNT   count/size of RES and COL
+ * @param T            result type
+ */
+#define PATIAL_ip(RES, RV2, COL, QVAL, CNT, T)                 \
+do {                                                    \
+    T q_dim = (T)(QVAL);                                \
+    /* SIMD optimization hints */                       \
+    GCC_Pragma("GCC ivdep")                             \
+    for(BUN p = 0; p < (CNT); p++) {                    \
+               (RES)[p] += (T)COL[p] * q_dim;                                  
\
+    }                                                   \
+} while (0)
+
+/**
+ * @brief Partial L1/Manhattan distance.
+ * @param RES    Tloc of result BAT
+ * @param RV2   not used
+ * @param COL    Tloc of input BAT
+ * @param QVAL   Query value for dimension
+ * @param CNT   count/size of RES and COL
+ * @param T      result type
+ */
+#define PATIAL_l1(RES, RV2, COL, QVAL, CNT, T)          \
+do {                                                    \
+    /* SIMD optimization hints */                       \
+    GCC_Pragma("GCC ivdep")                             \
+    for(BUN p = 0; p < (CNT); p++) {                    \
+               T diff = (T)COL[p] - (T)QVAL;                                   
\
+               (RES)[p] += (diff < 0) ? -diff : diff;                  \
     }                                                   \
 } while (0)
 
-static char*
-BATl2sq_distance(Client ctx, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci)
-{
-       (void) ctx;
-       (void) mb;
-       //lng T0 = GDKusec();
-    size_t ncols = (size_t) (pci->argc - pci->retc);
-    // check for even num bats
-    if ((ncols & 1) != 0)
-        throw(MAL, "batvss.l2sq_distance", SQLSTATE(HY097) 
DIMENSION_MISMATCH_ERR);
-    bat *ret = getArgReference_bat(stk, pci, 0);
-    // 1st dimmension bat
-    BAT *b = BATdescriptor(*getArgReference_bat(stk, pci, 1));
-    size_t cnt = 0;
-    if (b)
-        cnt = BATcount(b);
-    BBPreclaim(b);
-
-    BAT *bn = COLnew(b->hseqbase, TYPE_dbl, cnt, TRANSIENT);
-    if (bn == NULL)
-               throw(MAL, "batvss.l2sq_distance", MAL_MALLOC_FAIL);
-
-       dbl *dest = (dbl *) Tloc(bn, 0);
-       memset(dest, 0, sizeof(dbl) * cnt);
-    size_t dims = ncols / 2;
+/**
+ * @brief Generic Macro to generate BAT handlers.
+ * @param METRIC The prefix for the function name (e.g. l2sq, cos, ip).
+ * @param TNAME The suffix for the function name (e.g., f32, f64).
+ * @param T    The scalar type for input vectors (e.g., float, double).
+ * @param R    The result/accumulation type (to prevent overflow).
+ */
+#define DEFINE_BAT_HANDLER(METRIC, TNAME, T, R)                                
                                                                \
+static char*                                                                   
                                                                                
        \
+BAT##METRIC##_distance_##TNAME(Client ctx, MalBlkPtr mb, MalStkPtr stk, 
InstrPtr pci)       \
+{                                                                              
                                                                                
                        \
+       (void) ctx;                                                             
                                                                                
                \
+       (void) mb;                                                              
                                                                                
                \
+       allocator *ta = MT_thread_getallocator();                               
                                                                \
+       allocator_state ta_state = ma_open(ta);                                 
                                                                \
+    size_t ncols = (size_t) (pci->argc - pci->retc);                           
                                                \
+    /*check for even input*/                                                   
                                                                        \
+    if ((ncols & 1) != 0)                                                      
                                                                                
\
+        throw(MAL, "batvss." #METRIC "_distance", SQLSTATE(HY097) 
DIMENSION_MISMATCH_ERR);     \
+    size_t dims = ncols / 2;                                                   
                                                                        \
+       /*find out args input order, e.g. does qry vector come first or not*/   
                                \
+    size_t cnt = 0, qry_idx, bat_idx;                                          
                                                                \
+       int _tpe = getArgType(mb, pci, 1);                                      
                                                                        \
+       BAT *b = NULL;                                                          
                                                                                
        \
+       if (isaBatType(_tpe)) {                                                 
                                                                                
\
+               bat_idx = 1;                                                    
                                                                                
        \
+               qry_idx = 1 + dims;                                             
                                                                                
        \
+               b = BATdescriptor(*getArgReference_bat(stk, pci, bat_idx));     
                                                \
+               if (!b)                                                         
                                                                                
                \
+                       throw(MAL, "batvss." #METRIC "_distance", 
RUNTIME_OBJECT_MISSING);                              \
+               cnt = BATcount(b);                                              
                                                                                
        \
+               BBPreclaim(b);                                                  
                                                                                
        \
+       } else  {                                                               
                                                                                
                \
+               bat_idx = 1 + dims;                                             
                                                                                
        \
+               qry_idx = 1;                                                    
                                                                                
        \
+               b = BATdescriptor(*getArgReference_bat(stk, pci, bat_idx));     
                                                \
+               if (!b)                                                         
                                                                                
                \
+                       throw(MAL, "batvss." #METRIC "_distance", 
RUNTIME_OBJECT_MISSING);                              \
+               cnt = BATcount(b);                                              
                                                                                
        \
+               BBPreclaim(b);                                                  
                                                                                
        \
+       }                                                                       
                                                                                
                        \
+    bat *ret = getArgReference_bat(stk, pci, 0);                               
                                                        \
+    BAT *bn = COLnew(b->hseqbase, TYPE_##R, cnt, TRANSIENT);                   
                                        \
+    if (bn == NULL)                                                            
                                                                                
        \
+               throw(MAL, "batvss." #METRIC "_distance", MAL_MALLOC_FAIL);     
                                                \
+       R *dest = (R*) Tloc(bn, 0);                                             
                                                                                
\
+       memset(dest, 0, sizeof(R) * cnt);                                       
                                                                        \
+       R *v2 = NULL;                                                           
                                                                                
        \
+       R sum_q2 = 0;                                                           
                                                                                
        \
+    if (strcmp(#METRIC, "cos") == 0) {                                         
                                                                \
+               v2 = ma_alloc(ta, sizeof(R) * cnt);                             
                                                                        \
+               if (v2 == NULL)                                                 
                                                                                
        \
+                       throw(MAL, "batvss." #METRIC "_distance", 
MAL_MALLOC_FAIL);                                             \
+       }                                                                       
                                                                                
                        \
+       for (size_t i = 0; i < dims; i++) {                                     
                                                                        \
+               BAT *b = BATdescriptor(*getArgReference_bat(stk, pci, (bat_idx 
+ i)));                          \
+               if (!b) {                                                       
                                                                                
                \
+                       if (bn) BBPunfix(bn->batCacheid);                       
                                                                        \
+                       ma_close(&ta_state);                                    
                                                                                
\
+                       throw(MAL, "batvss." #METRIC "_distance", 
RUNTIME_OBJECT_MISSING);                              \
+               }                                                               
                                                                                
                        \
+               const R qv = (R) *getArgReference_##T(stk, pci, (qry_idx + i)); 
                                        \
+               sum_q2 += qv *qv;                                               
                                                                                
        \
+               const T *b_vals = (const T*) Tloc(b, 0);                        
                                                                \
+               PATIAL_##METRIC(dest, v2, b_vals, qv, cnt, R);                  
                                                        \
+               BBPreclaim(b);                                                  
                                                                                
        \
+       }                                                                       
                                                                                
                        \
+       /* Final transformation for IP */                                       
                                                                        \
+    if (strcmp(#METRIC, "ip") == 0) {                                          
                                                                \
+        for (BUN p = 0; p < cnt; p++) {                                        
                                                                        \
+            dest[p] = (R)1.0 - dest[p];                                        
                                                                        \
+        }                                                                      
                                                                                
                \
+    }                                                                          
                                                                                
                \
+    /* Final transformation for Cosine Distance */                             
                                                        \
+    if (strcmp(#METRIC, "cos") == 0) {                                         
                                                                \
+               R q_norm = SELECT_SQRT(R, sum_q2);                              
                                                                        \
+               for (BUN p = 0; p < cnt; p++) {                                 
                                                                        \
+                       R v_norm = SELECT_SQRT(R, v2[p]);                       
                                                                        \
+                       R result_if_zero[2][2];                                 
                                                                                
\
+                       R similarity = dest[p] / (v_norm * q_norm);             
                                                                \
+                       if (similarity > 1.0)                                   
                                                                                
\
+                               similarity = 1.0;                               
                                                                                
        \
+                       else if (similarity < -1.0)                             
                                                                                
\
+                               similarity = -1.0;                              
                                                                                
        \
+                       result_if_zero[0][0] = (R)1.0 - similarity;             
                                                                \
+                       result_if_zero[0][1] = (R)1.0;                          
                                                                        \
+                       result_if_zero[1][0] = (R)1.0;                          
                                                                        \
+                       result_if_zero[1][1] = (R)0.0;                          
                                                                        \
+                       dest[p] = result_if_zero[v_norm == 0][q_norm == 0];     
                                                        \
+               }                                                               
                                                                                
                        \
+       }                                                                       
                                                                                
                        \
+    /* Finalize BAT metadata */                                                
                                                                                
\
+    BATsetcount(bn, cnt);                                                      
                                                                                
\
+    bn->tnonil = true;                                                         
                                                                                
\
+    bn->tkey = false;                                                          
                                                                                
\
+    bn->tsorted = false;                                                       
                                                                                
\
+    bn->trevsorted = false;                                                    
                                                                                
\
+    *ret = bn->batCacheid;                                                     
                                                                                
\
+    BBPkeepref(bn);                                                            
                                                                                
        \
+       ma_close(&ta_state);                                                    
                                                                                
\
+    return MAL_SUCCEED;                                                        
                                                                                
        \
+}
 
-       int qtype = getArgType(mb, pci, 1 + dims);
-       if (!isaBatType(qtype)) {
-               if (qtype == TYPE_dbl) {
-                       for (size_t i = 0; i < dims; i++) {
-                               BAT *bl = 
BATdescriptor(*getArgReference_bat(stk, pci, (i+1)));
-                               const dbl b = *getArgReference_dbl(stk, pci, 
(i+1) + dims);
-                               int err = !bl;
-                               if (!err)
-                                       PATIAL_L2sq_const(dest, bl, b, cnt, 
dbl);
-                               BBPreclaim(bl);
-                               if (err) {
-                                       if (bn) BBPunfix(bn->batCacheid);
-                                       throw(MAL, "batvss.l2sq_distance", 
RUNTIME_OBJECT_MISSING);
-                               }
-                       }
-               } else {
-                       for (size_t i = 0; i < dims; i++) {
-                               BAT *bl = 
BATdescriptor(*getArgReference_bat(stk, pci, (i+1)));
-                               const flt b = *getArgReference_flt(stk, pci, 
(i+1) + dims);
-                               int err = !bl;
-                               if (!err)
-                                       PATIAL_L2sq_const(dest, bl, b, cnt, 
flt);
-                               BBPreclaim(bl);
-                               if (err) {
-                                       if (bn) BBPunfix(bn->batCacheid);
-                                       throw(MAL, "batvss.l2sq_distance", 
RUNTIME_OBJECT_MISSING);
-                               }
-                       }
-               }
-       } else {
-               if (b->ttype == TYPE_dbl) {
-                       for (size_t i = 0; i < dims; i++) {
-                               BAT *bl = 
BATdescriptor(*getArgReference_bat(stk, pci, (i+1)));
-                               BAT *br = 
BATdescriptor(*getArgReference_bat(stk, pci, (i+1) + dims));
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]

Reply via email to