This is an automated email from the ASF dual-hosted git repository.

corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new c1d2172ac6 [Fix][Transform-V2] Reduce embedding precision from double to float (#9635)
c1d2172ac6 is described below

commit c1d2172ac63998ee22929c7f96aee9d796bf3a23
Author: xiaochen <[email protected]>
AuthorDate: Thu Jul 31 21:40:50 2025 +0800

    [Fix][Transform-V2] Reduce embedding precision from double to float (#9635)
---
 docs/en/transform-v2/embedding.md                  |   9 ++
 docs/zh/transform-v2/embedding.md                  |   9 ++
 .../nlpmodel/embedding/remote/AbstractModel.java   |  14 +--
 .../embedding/remote/amazon/BedrockModel.java      |  26 ++---
 .../embedding/remote/custom/CustomModel.java       |   6 +-
 .../embedding/remote/doubao/DoubaoModel.java       |  10 +-
 .../embedding/remote/openai/OpenAIModel.java       |  10 +-
 .../embedding/remote/qianfan/QianfanModel.java     |  10 +-
 .../embedding/remote/zhipu/ZhipuModel.java         |  10 +-
 .../transform/embedding/EmbeddingVectorTest.java   | 125 +++++++++++++++++++++
 10 files changed, 186 insertions(+), 43 deletions(-)

diff --git a/docs/en/transform-v2/embedding.md b/docs/en/transform-v2/embedding.md
index 660a291684..819285cd90 100644
--- a/docs/en/transform-v2/embedding.md
+++ b/docs/en/transform-v2/embedding.md
@@ -8,6 +8,8 @@ The `Embedding` transform plugin leverages embedding models to convert text data
 transformation can be applied to various fields. The plugin supports multiple model providers and can be integrated with
 different API endpoints.
 
+> **Important Note:** The current embedding precision only supports float32 format.
+
 ## Options
 
 | Name                           | Type   | Required | Default Value | 
Description                                                                     
                                                                                
        |
@@ -27,6 +29,13 @@ different API endpoints.
 | custom_request_headers         | map    | no       |               | Custom 
headers for the request to the model.                                           
                                                                                
 |
 | custom_request_body            | map    | no       |               | Custom 
body for the request. Supports placeholders like `${model}`, `${input}`.        
                                                                                
 |
 
+## Precision Support
+
+**Important:** The current version of the Embedding plugin only supports **float32** precision for vector data.
+
+- All generated embedding vectors will be stored in float32 format
+- If your model or API returns other precision formats (such as float64), the plugin will automatically convert them to float32
+
 ### model_provider
 
 The providers for generating embeddings include common options such as `AMAZON`, `DOUBAO`, `QIANFAN`, and `OPENAI`. Additionally,
diff --git a/docs/zh/transform-v2/embedding.md b/docs/zh/transform-v2/embedding.md
index 228e5e9656..ccbdd66821 100644
--- a/docs/zh/transform-v2/embedding.md
+++ b/docs/zh/transform-v2/embedding.md
@@ -6,6 +6,8 @@
 
 `Embedding` 转换插件利用 embedding 模型将文本数据转换为向量化表示。此转换可以应用于各种字段。该插件支持多种模型提供商,并且可以与不同的API集成。
 
+> **重要提示:** 当前 embedding 精确度仅支持 float32 
+
 ## 配置选项
 
 | 名称                             | 类型     | 是否必填 | 默认值    | 描述                 
                                              |
@@ -25,6 +27,13 @@
 | custom_request_headers         | map    | 否    |        | 发送到模型的请求的自定义头信息。   
                                              |
 | custom_request_body            | map    | 否    |        | 请求体的自定义配置。支持占位符如 
`${model}`、`${input}`。                          |
 
+## 精度支持
+
+**重要:** 当前版本的 Embedding 插件仅支持 **float32** 精度的向量数据。
+
+- 所有生成的 embedding 向量将以 float32 格式存储
+- 如果您的模型或API返回其他精度格式(如 float64),插件会自动转换为 float32
+
 ### embedding_model_provider
 
 用于生成 embedding 的模型提供商。常见选项包括 `AMAZON`、 `DOUBAO`、`QIANFAN`、`OPENAI` 等,同时可选择 `CUSTOM` 实现自定义 embedding
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
index a53fa4684c..0803dfd7ad 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
@@ -42,23 +42,23 @@ public abstract class AbstractModel implements Model {
     public List<ByteBuffer> vectorization(Object[] fields) throws IOException {
         List<ByteBuffer> result = new ArrayList<>();
 
-        List<List<Double>> vectors = batchProcess(fields, singleVectorizedInputNumber);
-        for (List<Double> vector : vectors) {
-            result.add(BufferUtils.toByteBuffer(vector.toArray(new Double[0])));
+        List<List<Float>> vectors = batchProcess(fields, singleVectorizedInputNumber);
+        for (List<Float> vector : vectors) {
+            result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0])));
         }
         return result;
     }
 
-    protected abstract List<List<Double>> vector(Object[] fields) throws IOException;
+    protected abstract List<List<Float>> vector(Object[] fields) throws IOException;
 
-    public List<List<Double>> batchProcess(Object[] array, int batchSize) throws IOException {
-        List<List<Double>> merged = new ArrayList<>();
+    public List<List<Float>> batchProcess(Object[] array, int batchSize) throws IOException {
+        List<List<Float>> merged = new ArrayList<>();
         if (array == null || array.length == 0) {
             return merged;
         }
         for (int i = 0; i < array.length; i += batchSize) {
             Object[] batch = ArrayUtils.subarray(array, i, i + batchSize);
-            List<List<Double>> vector = vector(batch);
+            List<List<Float>> vector = vector(batch);
             merged.addAll(vector);
         }
         if (array.length != merged.size()) {
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/amazon/BedrockModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/amazon/BedrockModel.java
index 514614e98a..35ab49df22 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/amazon/BedrockModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/amazon/BedrockModel.java
@@ -186,7 +186,7 @@ public class BedrockModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Double>> vector(Object[] fields) throws IOException {
+    protected List<List<Float>> vector(Object[] fields) throws IOException {
         if (fields == null || fields.length == 0) {
             return new ArrayList<>();
         }
@@ -247,26 +247,26 @@ public class BedrockModel extends AbstractModel {
         return requestBody;
     }
 
-    private List<List<Double>> parseSingleResponse(String responseBody) throws IOException {
+    private List<List<Float>> parseSingleResponse(String responseBody) throws IOException {
         try {
             JsonNode responseJson = OBJECT_MAPPER.readTree(responseBody);
-            List<List<Double>> result = new ArrayList<>();
+            List<List<Float>> result = new ArrayList<>();
 
             if (modelId.startsWith("amazon.titan")) {
                 JsonNode embedding = responseJson.get("embedding");
                 if (embedding != null && embedding.isArray()) {
-                    List<Double> vector = new ArrayList<>();
+                    List<Float> vector = new ArrayList<>();
                     for (JsonNode value : embedding) {
-                        vector.add(value.asDouble());
+                        vector.add(value.floatValue());
                     }
                     result.add(vector);
                 }
             } else if (modelId.startsWith("cohere.")) {
                 JsonNode embeddings = responseJson.get("embeddings");
                 if (embeddings != null && embeddings.isArray() && !embeddings.isEmpty()) {
-                    List<Double> vector = new ArrayList<>();
+                    List<Float> vector = new ArrayList<>();
                     for (JsonNode value : embeddings.get(0)) {
-                        vector.add(value.asDouble());
+                        vector.add(value.floatValue());
                     }
                     result.add(vector);
                 }
@@ -278,26 +278,26 @@ public class BedrockModel extends AbstractModel {
         }
     }
 
-    private List<List<Double>> parseBatchResponse(String responseBody) throws IOException {
+    private List<List<Float>> parseBatchResponse(String responseBody) throws IOException {
         try {
             JsonNode responseJson = OBJECT_MAPPER.readTree(responseBody);
-            List<List<Double>> result = new ArrayList<>();
+            List<List<Float>> result = new ArrayList<>();
             JsonNode embeddings = responseJson.get("embeddings");
             if (embeddings != null && embeddings.isArray()) {
                 if (modelId.startsWith("amazon.titan")) {
                     for (JsonNode embedding : embeddings) {
-                        List<Double> vector = new ArrayList<>();
+                        List<Float> vector = new ArrayList<>();
                         for (JsonNode value : embedding) {
-                            vector.add(value.asDouble());
+                            vector.add(value.floatValue());
                         }
                         result.add(vector);
                     }
 
                 } else if (modelId.startsWith("cohere.")) {
                     for (JsonNode embedding : embeddings) {
-                        List<Double> vector = new ArrayList<>();
+                        List<Float> vector = new ArrayList<>();
                         for (JsonNode value : embedding) {
-                            vector.add(value.asDouble());
+                            vector.add(value.floatValue());
                         }
                         result.add(vector);
                     }
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
index ea39f15462..179315f956 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/custom/CustomModel.java
@@ -67,7 +67,7 @@ public class CustomModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Double>> vector(Object[] fields) throws IOException {
+    protected List<List<Float>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -76,7 +76,7 @@ public class CustomModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
     }
 
-    private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
+    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
         HttpPost post = new HttpPost(apiPath);
         // Construct a request with custom parameters
         for (Map.Entry<String, String> entry : header.entrySet()) {
@@ -96,7 +96,7 @@ public class CustomModel extends AbstractModel {
         }
 
         return OBJECT_MAPPER.convertValue(
-                parseResponse(responseStr), new TypeReference<List<List<Double>>>() {});
+                parseResponse(responseStr), new TypeReference<List<List<Float>>>() {});
     }
 
     @VisibleForTesting
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
index 1591cd587e..f2b1e348c7 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
@@ -54,7 +54,7 @@ public class DoubaoModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Double>> vector(Object[] fields) throws IOException {
+    protected List<List<Float>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -63,7 +63,7 @@ public class DoubaoModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
     }
 
-    private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
+    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
         HttpPost post = new HttpPost(apiPath);
         post.setHeader("Authorization", "Bearer " + apiKey);
         post.setHeader("Content-Type", "application/json");
@@ -82,14 +82,14 @@ public class DoubaoModel extends AbstractModel {
         }
 
         JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
-        List<List<Double>> embeddings = new ArrayList<>();
+        List<List<Float>> embeddings = new ArrayList<>();
 
         if (data.isArray()) {
             for (JsonNode node : data) {
                 JsonNode embeddingNode = node.get("embedding");
-                List<Double> embedding =
+                List<Float> embedding =
                         OBJECT_MAPPER.readValue(
-                                embeddingNode.traverse(), new TypeReference<List<Double>>() {});
+                                embeddingNode.traverse(), new TypeReference<List<Float>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
index 467d6cb406..2a45cc829f 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/openai/OpenAIModel.java
@@ -53,7 +53,7 @@ public class OpenAIModel extends AbstractModel {
     }
 
     @Override
-    protected List<List<Double>> vector(Object[] fields) throws IOException {
+    protected List<List<Float>> vector(Object[] fields) throws IOException {
         if (fields.length > 1) {
             throw new IllegalArgumentException("OpenAI model only supports single input");
         }
@@ -65,7 +65,7 @@ public class OpenAIModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).size();
     }
 
-    private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
+    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
         HttpPost post = new HttpPost(apiPath);
         post.setHeader("Authorization", "Bearer " + apiKey);
         post.setHeader("Content-Type", "application/json");
@@ -84,14 +84,14 @@ public class OpenAIModel extends AbstractModel {
         }
 
         JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
-        List<List<Double>> embeddings = new ArrayList<>();
+        List<List<Float>> embeddings = new ArrayList<>();
 
         if (data.isArray()) {
             for (JsonNode node : data) {
                 JsonNode embeddingNode = node.get("embedding");
-                List<Double> embedding =
+                List<Float> embedding =
                         OBJECT_MAPPER.readValue(
-                                embeddingNode.traverse(), new TypeReference<List<Double>>() {});
+                                embeddingNode.traverse(), new TypeReference<List<Float>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
index 67c1a8147a..f85619eb3e 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/qianfan/QianfanModel.java
@@ -100,7 +100,7 @@ public class QianfanModel extends AbstractModel {
     }
 
     @Override
-    public List<List<Double>> vector(Object[] fields) throws IOException {
+    public List<List<Float>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -109,7 +109,7 @@ public class QianfanModel extends AbstractModel {
         return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
     }
 
-    private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
+    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
         String formattedApiPath =
                 String.format(
                         (apiPath.endsWith("/") ? apiPath : apiPath + "/") + "%s?access_token=%s",
@@ -143,14 +143,14 @@ public class QianfanModel extends AbstractModel {
                     "Failed to get vector from qianfan, response: " + result.get("error_msg"));
         }
 
-        List<List<Double>> embeddings = new ArrayList<>();
+        List<List<Float>> embeddings = new ArrayList<>();
         JsonNode data = result.get("data");
         if (data.isArray()) {
             for (JsonNode node : data) {
-                List<Double> embedding =
+                List<Float> embedding =
                         OBJECT_MAPPER.readValue(
                                 node.get("embedding").traverse(),
-                                new TypeReference<List<Double>>() {});
+                                new TypeReference<List<Float>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
index df72261bb5..a36535821e 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/zhipu/ZhipuModel.java
@@ -66,7 +66,7 @@ public class ZhipuModel extends AbstractModel {
     }
 
     @Override
-    public List<List<Double>> vector(Object[] fields) throws IOException {
+    public List<List<Float>> vector(Object[] fields) throws IOException {
         return vectorGeneration(fields);
     }
 
@@ -75,7 +75,7 @@ public class ZhipuModel extends AbstractModel {
         return dimension;
     }
 
-    private List<List<Double>> vectorGeneration(Object[] fields) throws IOException {
+    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
 
         if (fields == null || fields.length > MAX_INPUT_SIZE) {
             throw new IOException(
@@ -98,14 +98,14 @@ public class ZhipuModel extends AbstractModel {
             throw new IOException("Failed to get vector from zhipu, response: " + responseStr);
         }
         JsonNode data = OBJECT_MAPPER.readTree(responseStr).get("data");
-        List<List<Double>> embeddings = new ArrayList<>();
+        List<List<Float>> embeddings = new ArrayList<>();
 
         if (data.isArray()) {
             for (JsonNode node : data) {
                 JsonNode embeddingNode = node.get("embedding");
-                List<Double> embedding =
+                List<Float> embedding =
                         OBJECT_MAPPER.readValue(
-                                embeddingNode.traverse(), new TypeReference<List<Double>>() {});
+                                embeddingNode.traverse(), new TypeReference<List<Float>>() {});
                 embeddings.add(embedding);
             }
         }
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java
new file mode 100644
index 0000000000..0813de6588
--- /dev/null
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+public class EmbeddingVectorTest {
+    private static class MockApiModel extends AbstractModel {
+
+        public MockApiModel() {
+            super(1);
+        }
+
+        @Override
+        protected List<List<Float>> vector(Object[] fields) throws IOException {
+            String mockApiResponse = createMockApiResponse(fields);
+            return parseApiResponse(mockApiResponse);
+        }
+
+        private String createMockApiResponse(Object[] fields) {
+            ObjectNode response = OBJECT_MAPPER.createObjectNode();
+            response.put("object", "list");
+            response.put("model", "text-embedding-3-small");
+
+            ArrayNode dataArray = OBJECT_MAPPER.createArrayNode();
+
+            for (int i = 0; i < fields.length; i++) {
+                ObjectNode embeddingObj = OBJECT_MAPPER.createObjectNode();
+                embeddingObj.put("object", "embedding");
+                embeddingObj.put("index", i);
+                ArrayNode embeddingArray = OBJECT_MAPPER.createArrayNode();
+                embeddingArray.add(-0.006929283495992422);
+                embeddingArray.add(-0.005336422007530928);
+                embeddingArray.add(-4.547132266452536e-05);
+                embeddingArray.add(-0.024047505110502243);
+
+                embeddingObj.set("embedding", embeddingArray);
+                dataArray.add(embeddingObj);
+            }
+
+            response.set("data", dataArray);
+
+            ObjectNode usage = OBJECT_MAPPER.createObjectNode();
+            usage.put("prompt_tokens", 5);
+            usage.put("total_tokens", 5);
+            response.set("usage", usage);
+
+            return response.toString();
+        }
+
+        private List<List<Float>> parseApiResponse(String responseStr) throws IOException {
+            JsonNode responseJson = OBJECT_MAPPER.readTree(responseStr);
+            JsonNode data = responseJson.get("data");
+            List<List<Float>> embeddings = new ArrayList<>();
+
+            if (data.isArray()) {
+                for (JsonNode node : data) {
+                    JsonNode embeddingNode = node.get("embedding");
+                    List<Float> embedding =
+                            OBJECT_MAPPER.readValue(
+                                    embeddingNode.traverse(), new TypeReference<List<Float>>() {});
+                    embeddings.add(embedding);
+                }
+            }
+            return embeddings;
+        }
+
+        @Override
+        public Integer dimension() throws IOException {
+            return 4;
+        }
+
+        @Override
+        public void close() throws IOException {}
+    }
+
+    /**
+     * Currently, when the embedding model returns a type of double, it gets converted to float,
+     * resulting in a loss of precision.
+     */
+    @Test
+    public void testVectorPrecision() throws IOException {
+        MockApiModel model = new MockApiModel();
+        Object[] inputFields = {"test input"};
+        List<ByteBuffer> result = model.vectorization(inputFields);
+        ByteBuffer buffer = result.get(0);
+        Float[] embedding = BufferUtils.toFloatArray(buffer);
+        Assertions.assertEquals(4, embedding.length);
+        Assertions.assertEquals(-0.0069292835f, embedding[0]);
+        Assertions.assertEquals(-0.005336422f, embedding[1]);
+        Assertions.assertEquals(-4.5471323E-5f, embedding[2]);
+        Assertions.assertEquals(-0.024047505f, embedding[3]);
+
+        model.close();
+    }
+}
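
The narrowing that this commit applies (a provider returns double-precision JSON numbers and the plugin stores float32) can be illustrated in isolation. The following is a minimal sketch and not code from the commit: the class `EmbeddingPrecisionSketch` and its methods are hypothetical names, and it uses plain `java.nio.ByteBuffer` instead of SeaTunnel's `BufferUtils`. The sample values are the ones used in `EmbeddingVectorTest`.

```java
import java.nio.ByteBuffer;

// Hypothetical illustration class; not part of SeaTunnel.
public class EmbeddingPrecisionSketch {

    /** Narrows each double to the nearest float, as a Java (float) cast does. */
    public static float[] narrow(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }

    /** Packs the floats into a buffer at 4 bytes per component. */
    public static ByteBuffer pack(float[] values) {
        ByteBuffer buffer = ByteBuffer.allocate(values.length * Float.BYTES);
        for (float v : values) {
            buffer.putFloat(v);
        }
        buffer.flip();
        return buffer;
    }

    /** Largest relative error introduced by narrowing, skipping zero components. */
    public static double maxRelativeError(double[] original, float[] narrowed) {
        double worst = 0.0;
        for (int i = 0; i < original.length; i++) {
            if (original[i] == 0.0) {
                continue;
            }
            double error = Math.abs(narrowed[i] - original[i]) / Math.abs(original[i]);
            worst = Math.max(worst, error);
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] raw = {
            -0.006929283495992422,
            -0.005336422007530928,
            -4.547132266452536e-05,
            -0.024047505110502243
        };
        float[] narrowed = narrow(raw);
        System.out.println("float32 payload bytes: " + pack(narrowed).remaining());
        System.out.println("float64 payload bytes: " + raw.length * Double.BYTES);
        System.out.println("max relative error:    " + maxRelativeError(raw, narrowed));
    }
}
```

For values in the normal float range, round-to-nearest narrowing keeps the relative error at or below 2^-24 (roughly 6e-8), while the stored vector takes half the bytes of its double-precision form.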
