This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f93a6aa97f [SYSTEMDS-3555] Add unit tests for word embeddings
f93a6aa97f is described below

commit f93a6aa97f806ea7ba4a94f163e1c7537795809c
Author: Elias-Strauss <[email protected]>
AuthorDate: Fri Apr 28 14:16:41 2023 +0200

    [SYSTEMDS-3555] Add unit tests for word embeddings
    
    This patch adds a unit test for applying pre-trained
    word embeddings on string tokens.
    
    Closes #1815 #1817
---
 .../TransformFrameEncodeWordEmbeddingTest.java     | 152 +++++++++++++++++++++
 .../TransformFrameEncodeWordEmbeddings.dml         |  36 +++++
 2 files changed, 188 insertions(+)

diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
new file mode 100644
index 0000000000..9972f6b1c6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+public class TransformFrameEncodeWordEmbeddingTest extends AutomatedTestBase
+{
+    private final static String TEST_NAME1 = 
"TransformFrameEncodeWordEmbeddings";
+    private final static String TEST_DIR = "functions/transform/";
+    private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/";
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME1));
+    }
+
+    @Test
+    public void testTransformToWordEmbeddings() {
+        runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE);
+    }
+
+    private void runTransformTest(String testname, ExecMode rt)
+    {
+        //set runtime platform
+        ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 100;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 
1, new Date().getTime());
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+            // Generate the dictionary by assigning unique ID to each distinct 
token
+            Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
+            // Create the dataset by repeating and shuffling the distinct 
tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
10);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
+
+            programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Manually derive the expected result
+            double[][] res_expected = new double[stringsColumn.size()][cols];
+            for (int i = 0; i < stringsColumn.size(); i++) {
+                int rowMapped = map.get(stringsColumn.get(i));
+                System.arraycopy(a[rowMapped], 0, res_expected[i], 0, cols);
+            }
+
+            // Compare results
+            HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
+            
TestUtils.compareMatrices(TestUtils.convertHashMapToDoubleArray(res_actual), 
res_expected, 1e-6);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    public static List<String> shuffleAndMultiplyStrings(List<String> strings, 
int multiply){
+        List<String> out = new ArrayList<>();
+        Random random = new Random();
+        for (int i = 0; i < strings.size()*multiply; i++) {
+            out.add(strings.get(random.nextInt(strings.size())));
+        }
+        return out;
+    }
+
+    public static List<String> generateRandomStrings(int numStrings, int 
stringLength) {
+        List<String> randomStrings = new ArrayList<>();
+        Random random = new Random();
+        String characters = 
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
+        for (int i = 0; i < numStrings; i++) {
+            randomStrings.add(generateRandomString(random, stringLength, 
characters));
+        }
+        return randomStrings;
+    }
+
+    public static String generateRandomString(Random random, int stringLength, 
String characters){
+        StringBuilder randomString = new StringBuilder();
+        for (int j = 0; j < stringLength; j++) {
+            int randomIndex = random.nextInt(characters.length());
+            randomString.append(characters.charAt(randomIndex));
+        }
+        return randomString.toString();
+    }
+
+    public static void writeStringsToCsvFile(List<String> strings, String 
fileName) {
+        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) 
{
+            for (String line : strings) {
+                bw.write(line);
+                bw.newLine();
+            }
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+    }
+
+    public static Map<String,Integer> writeDictToCsvFile(List<String> strings, 
String fileName) {
+        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) 
{
+            Map<String,Integer> map = new HashMap<>();
+            for (int i = 0; i < strings.size(); i++) {
+                map.put(strings.get(i), i);
+                bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + (i+1) + "\n");
+            }
+            return map;
+        } catch (IOException e) {
+            e.printStackTrace();
+            return null;
+        }
+    }
+}
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
new file mode 100644
index 0000000000..1aa1fb0fed
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Script arguments: $1 embeddings matrix, $2 token data, $3 recode map,
+# $4 output location for the result matrix.
+
+# Read the pre-trained word embeddings
+# NOTE: the 100 x 100 dimensions are hard-coded and have to agree with
+# the matrix that the calling test generates.
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (1K) w/ 100 distinct tokens
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv"); 
+
+# Dummy-code the first column using the IDs from the recode map.
+# NOTE(review): presumably "ids: true" means columns are referenced by
+# position and the result has one indicator column per dictionary entry,
+# so that each row of the product below selects one row of E - confirm
+# against the transformapply documentation.
+jspec = "{ids: true, dummycode: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta);
+
+# Apply the embeddings on all tokens (1K x 100)
+R = Data_enc %*% E;
+
+write(R, $4, format="text");
+

Reply via email to