Changeset: 65dca0ad1418 for MonetDB
URL: https://dev.monetdb.org/hg/MonetDB/rev/65dca0ad1418
Modified Files:
	monetdb5/modules/kernel/vss.c
Branch: nested
Log Message:
Add bulk (BAT) implementations of the other vector distance functions (cosine, inner product, L1), built on a generic DEFINE_BAT_HANDLER macro.
diffs (truncated from 517 to 300 lines):
diff --git a/monetdb5/modules/kernel/vss.c b/monetdb5/modules/kernel/vss.c
--- a/monetdb5/modules/kernel/vss.c
+++ b/monetdb5/modules/kernel/vss.c
@@ -52,11 +52,11 @@
*/
#define DEFINE_METRIC_L1(NAME, T, R) \
static R \
-metric_l1_##NAME(MalStkPtr stk, InstrPtr pci, size_t dim) \
+metric_l1_##NAME(MalStkPtr stk, InstrPtr pci, size_t dim)
\
{ \
R dist = 0; \
/* SIMD optimization hints */ \
- GCC_Pragma("GCC ivdep")
\
+ GCC_Pragma("GCC ivdep") \
for (size_t i=0; i < dim; i++) { \
T ai = *getArgReference_##T(stk, pci, 1 + i); \
T bi = *getArgReference_##T(stk, pci, 1 + dim + i); \
@@ -194,6 +194,7 @@ METRIC##_distance_##TNAME(Client ctx, Ma
return MAL_SUCCEED;
\
}
+
// Inner (Dot) Product float (32-bit)
DEFINE_SCALAR_HANDLER(ip, f32, flt, dbl)
// Inner (Dot) Product double (64-bit)
@@ -214,127 +215,320 @@ DEFINE_SCALAR_HANDLER(cos, f32, flt, dbl
// cos double (64-bit)
DEFINE_SCALAR_HANDLER(cos, f64, dbl, dbl)
-
+#if 0
#define PATIAL_L2sq(R, BL, BR, CNT, T) \
do { \
const T *a = (const T*) Tloc(BL, 0); \
const T *b = (const T*) Tloc(BR, 0); \
for(BUN p = 0; p<cnt; p++) { \
- dbl d = a[p] - b[p]; \
- R[p] += d*d; \
+ dbl d = a[p] - b[p]; \
+ R[p] += d*d;
\
+ } \
+} while (0)
+
+#define PATIAL_L2sq_const(R, BL, b, CNT, T) \
+do { \
+ const T *a = (const T*) Tloc(BL, 0); \
+ for(BUN p = 0; p<cnt; p++) { \
+ dbl d = a[p] - b; \
+ R[p] += d*d;
\
+ } \
+} while (0)
+#endif
+
+/**
+ * @brief Partial Euclidean distance accumulation.
+ * @param RES Tloc of result BAT
+ * @param RV2 not used
+ * @param COL Tloc of input BAT
+ * @param QVAL Query value for dimension
+ * @param CNT count/size of RES and COL
+ * @param T result type
+ */
+#define PATIAL_l2sq(RES, RV2, COL, QVAL, CNT, T) \
+do { \
+ /* SIMD optimization hints */ \
+ GCC_Pragma("GCC ivdep") \
+ for(BUN p = 0; p < (CNT); p++) { \
+ T d = (T)COL[p] - (T)QVAL;
\
+ (RES)[p] += d*d;
\
} \
} while (0)
-#define PATIAL_L2sq_const(R, BL, b, CNT, T) \
+
+/**
+ * @brief Partial accumulations for cos distance.
+ * @param RDP dot product buffer
+ * @param RV2 vector values buffer
+ * @param COL Tloc of input BAT
+ * @param QVAL Query value for dimension
+ * @param CNT count/size of COL and RES
+ * @param T result type
+ */
+#define PATIAL_cos(RDP, RV2, COL, QVAL, CNT, T) \
do { \
- const T *a = (const T*) Tloc(BL, 0); \
- for(BUN p = 0; p<cnt; p++) { \
- dbl d = a[p] - b; \
- R[p] += d*d; \
+ T q_dim = (T)(QVAL); \
+ /* SIMD optimization */
\
+ GCC_Pragma("GCC ivdep") \
+ for(BUN p = 0; p < (CNT); p++) { \
+ /* Accumulate Dot Product */ \
+ (RDP)[p] += (T)(COL)[p] * q_dim; \
+ /* Accumulate vector norms */ \
+ (RV2)[p] += (T)(COL)[p] * (T)(COL)[p]; \
+ } \
+} while (0)
+
+/**
+ * @brief Partial Inner Product accumulation.
+ * @param RES Tloc of result BAT
+ * @param RV2 not used
+ * @param COL Tloc of input BAT
+ * @param QVAL Query value for dimension
+ * @param CNT count/size of RES and COL
+ * @param T result type
+ */
+#define PATIAL_ip(RES, RV2, COL, QVAL, CNT, T) \
+do { \
+ T q_dim = (T)(QVAL); \
+ /* SIMD optimization hints */ \
+ GCC_Pragma("GCC ivdep") \
+ for(BUN p = 0; p < (CNT); p++) { \
+ (RES)[p] += (T)COL[p] * q_dim;
\
+ } \
+} while (0)
+
+/**
+ * @brief Partial L1/Manhattan distance.
+ * @param RES Tloc of result BAT
+ * @param RV2 not used
+ * @param COL Tloc of input BAT
+ * @param QVAL Query value for dimension
+ * @param CNT count/size of RES and COL
+ * @param T result type
+ */
+#define PATIAL_l1(RES, RV2, COL, QVAL, CNT, T) \
+do { \
+ /* SIMD optimization hints */ \
+ GCC_Pragma("GCC ivdep") \
+ for(BUN p = 0; p < (CNT); p++) { \
+ T diff = (T)COL[p] - (T)QVAL;
\
+ (RES)[p] += (diff < 0) ? -diff : diff; \
} \
} while (0)
-static char*
-BATl2sq_distance(Client ctx, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci)
-{
- (void) ctx;
- (void) mb;
- //lng T0 = GDKusec();
- size_t ncols = (size_t) (pci->argc - pci->retc);
- // check for even num bats
- if ((ncols & 1) != 0)
- throw(MAL, "batvss.l2sq_distance", SQLSTATE(HY097)
DIMENSION_MISMATCH_ERR);
- bat *ret = getArgReference_bat(stk, pci, 0);
- // 1st dimmension bat
- BAT *b = BATdescriptor(*getArgReference_bat(stk, pci, 1));
- size_t cnt = 0;
- if (b)
- cnt = BATcount(b);
- BBPreclaim(b);
-
- BAT *bn = COLnew(b->hseqbase, TYPE_dbl, cnt, TRANSIENT);
- if (bn == NULL)
- throw(MAL, "batvss.l2sq_distance", MAL_MALLOC_FAIL);
-
- dbl *dest = (dbl *) Tloc(bn, 0);
- memset(dest, 0, sizeof(dbl) * cnt);
- size_t dims = ncols / 2;
+/**
+ * @brief Generic Macro to generate BAT handlers.
+ * @param METRIC The prefix for the function name (e.g. l2sq, cos, ip).
+ * @param TNAME The suffix for the function name (e.g., f32, f64).
+ * @param T The scalar type for input vectors (e.g., float, double).
+ * @param R The result/accumulation type (to prevent overflow).
+ */
+#define DEFINE_BAT_HANDLER(METRIC, TNAME, T, R)
\
+static char*
\
+BAT##METRIC##_distance_##TNAME(Client ctx, MalBlkPtr mb, MalStkPtr stk,
InstrPtr pci) \
+{
\
+ (void) ctx;
\
+ (void) mb;
\
+ allocator *ta = MT_thread_getallocator();
\
+ allocator_state ta_state = ma_open(ta);
\
+ size_t ncols = (size_t) (pci->argc - pci->retc);
\
+ /*check for even input*/
\
+ if ((ncols & 1) != 0)
\
+ throw(MAL, "batvss." #METRIC "_distance", SQLSTATE(HY097)
DIMENSION_MISMATCH_ERR); \
+ size_t dims = ncols / 2;
\
+ /*find out args input order, e.g. does qry vector come first or not*/
\
+ size_t cnt = 0, qry_idx, bat_idx;
\
+ int _tpe = getArgType(mb, pci, 1);
\
+ BAT *b = NULL;
\
+ if (isaBatType(_tpe)) {
\
+ bat_idx = 1;
\
+ qry_idx = 1 + dims;
\
+ b = BATdescriptor(*getArgReference_bat(stk, pci, bat_idx));
\
+ if (!b)
\
+ throw(MAL, "batvss." #METRIC "_distance",
RUNTIME_OBJECT_MISSING); \
+ cnt = BATcount(b);
\
+ BBPreclaim(b);
\
+ } else {
\
+ bat_idx = 1 + dims;
\
+ qry_idx = 1;
\
+ b = BATdescriptor(*getArgReference_bat(stk, pci, bat_idx));
\
+ if (!b)
\
+ throw(MAL, "batvss." #METRIC "_distance",
RUNTIME_OBJECT_MISSING); \
+ cnt = BATcount(b);
\
+ BBPreclaim(b);
\
+ }
\
+ bat *ret = getArgReference_bat(stk, pci, 0);
\
+ BAT *bn = COLnew(b->hseqbase, TYPE_##R, cnt, TRANSIENT);
\
+ if (bn == NULL)
\
+ throw(MAL, "batvss." #METRIC "_distance", MAL_MALLOC_FAIL);
\
+ R *dest = (R*) Tloc(bn, 0);
\
+ memset(dest, 0, sizeof(R) * cnt);
\
+ R *v2 = NULL;
\
+ R sum_q2 = 0;
\
+ if (strcmp(#METRIC, "cos") == 0) {
\
+ v2 = ma_alloc(ta, sizeof(R) * cnt);
\
+ if (v2 == NULL)
\
+ throw(MAL, "batvss." #METRIC "_distance",
MAL_MALLOC_FAIL); \
+ }
\
+ for (size_t i = 0; i < dims; i++) {
\
+ BAT *b = BATdescriptor(*getArgReference_bat(stk, pci, (bat_idx
+ i))); \
+ if (!b) {
\
+ if (bn) BBPunfix(bn->batCacheid);
\
+ ma_close(&ta_state);
\
+ throw(MAL, "batvss." #METRIC "_distance",
RUNTIME_OBJECT_MISSING); \
+ }
\
+ const R qv = (R) *getArgReference_##T(stk, pci, (qry_idx + i));
\
+ sum_q2 += qv *qv;
\
+ const T *b_vals = (const T*) Tloc(b, 0);
\
+ PATIAL_##METRIC(dest, v2, b_vals, qv, cnt, R);
\
+ BBPreclaim(b);
\
+ }
\
+ /* Final transformation for IP */
\
+ if (strcmp(#METRIC, "ip") == 0) {
\
+ for (BUN p = 0; p < cnt; p++) {
\
+ dest[p] = (R)1.0 - dest[p];
\
+ }
\
+ }
\
+ /* Final transformation for Cosine Distance */
\
+ if (strcmp(#METRIC, "cos") == 0) {
\
+ R q_norm = SELECT_SQRT(R, sum_q2);
\
+ for (BUN p = 0; p < cnt; p++) {
\
+ R v_norm = SELECT_SQRT(R, v2[p]);
\
+ R result_if_zero[2][2];
\
+ R similarity = dest[p] / (v_norm * q_norm);
\
+ if (similarity > 1.0)
\
+ similarity = 1.0;
\
+ else if (similarity < -1.0)
\
+ similarity = -1.0;
\
+ result_if_zero[0][0] = (R)1.0 - similarity;
\
+ result_if_zero[0][1] = (R)1.0;
\
+ result_if_zero[1][0] = (R)1.0;
\
+ result_if_zero[1][1] = (R)0.0;
\
+ dest[p] = result_if_zero[v_norm == 0][q_norm == 0];
\
+ }
\
+ }
\
+ /* Finalize BAT metadata */
\
+ BATsetcount(bn, cnt);
\
+ bn->tnonil = true;
\
+ bn->tkey = false;
\
+ bn->tsorted = false;
\
+ bn->trevsorted = false;
\
+ *ret = bn->batCacheid;
\
+ BBPkeepref(bn);
\
+ ma_close(&ta_state);
\
+ return MAL_SUCCEED;
\
+}
- int qtype = getArgType(mb, pci, 1 + dims);
- if (!isaBatType(qtype)) {
- if (qtype == TYPE_dbl) {
- for (size_t i = 0; i < dims; i++) {
- BAT *bl =
BATdescriptor(*getArgReference_bat(stk, pci, (i+1)));
- const dbl b = *getArgReference_dbl(stk, pci,
(i+1) + dims);
- int err = !bl;
- if (!err)
- PATIAL_L2sq_const(dest, bl, b, cnt,
dbl);
- BBPreclaim(bl);
- if (err) {
- if (bn) BBPunfix(bn->batCacheid);
- throw(MAL, "batvss.l2sq_distance",
RUNTIME_OBJECT_MISSING);
- }
- }
- } else {
- for (size_t i = 0; i < dims; i++) {
- BAT *bl =
BATdescriptor(*getArgReference_bat(stk, pci, (i+1)));
- const flt b = *getArgReference_flt(stk, pci,
(i+1) + dims);
- int err = !bl;
- if (!err)
- PATIAL_L2sq_const(dest, bl, b, cnt,
flt);
- BBPreclaim(bl);
- if (err) {
- if (bn) BBPunfix(bn->batCacheid);
- throw(MAL, "batvss.l2sq_distance",
RUNTIME_OBJECT_MISSING);
- }
- }
- }
- } else {
- if (b->ttype == TYPE_dbl) {
- for (size_t i = 0; i < dims; i++) {
- BAT *bl =
BATdescriptor(*getArgReference_bat(stk, pci, (i+1)));
- BAT *br =
BATdescriptor(*getArgReference_bat(stk, pci, (i+1) + dims));
_______________________________________________
checkin-list mailing list -- [email protected]
To unsubscribe send an email to [email protected]
