This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 384a7071a4 [SYSTEMDS-3580] Add word embedding encoder
384a7071a4 is described below
commit 384a7071a41bf6ff5cf5c7410ebb351effb2e47b
Author: e-strauss <[email protected]>
AuthorDate: Fri Jun 9 12:15:36 2023 +0200
[SYSTEMDS-3580] Add word embedding encoder
This patch extends the transformapply API to accept pre-trained word
embeddings along with the dictionary as inputs. The new word embedding
column encoder is placed after recode and replaces the recoded indices
with the embedding vectors. This addition removes the need for a
matrix multiplication to produce the embedding matrix.
The current implementation is slower than the baseline (with MatMult).
Future commits will introduce a new dense block to deduplicate
large embeddings, as well as multi-threading.
Closes #1839
---
.../ParameterizedBuiltinFunctionExpression.java | 8 +-
.../apache/sysds/runtime/data/DenseBlockFP64.java | 2 +-
.../cp/ParameterizedBuiltinCPInstruction.java | 13 +-
.../apache/sysds/runtime/transform/TfUtils.java | 2 +-
.../runtime/transform/encode/ColumnEncoder.java | 10 +-
.../transform/encode/ColumnEncoderComposite.java | 6 +
.../encode/ColumnEncoderWordEmbedding.java | 111 +++++++++
.../runtime/transform/encode/EncoderFactory.java | 33 ++-
.../transform/encode/MultiColumnEncoder.java | 34 ++-
.../sysds/utils/stats/TransformStatistics.java | 11 +-
.../TransformFrameEncodeWordEmbedding2Test.java | 258 +++++++++++++++++++++
.../TransformFrameEncodeWordEmbeddings2.dml | 36 +++
...ansformFrameEncodeWordEmbeddings2MultiCols1.dml | 43 ++++
...ansformFrameEncodeWordEmbeddings2MultiCols2.dml | 44 ++++
14 files changed, 593 insertions(+), 18 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 1d30d13fea..1906ee818e 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -49,6 +49,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
public static final String TF_FN_PARAM_DATA = "target";
public static final String TF_FN_PARAM_MTD2 = "meta";
public static final String TF_FN_PARAM_SPEC = "spec";
+ public static final String TF_FN_PARAM_EMBD = "embedding";
public static final String LINEAGE_TRACE = "lineage";
public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE
MB: for backwards compatibility
@@ -617,11 +618,14 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
//validate data / metadata (recode maps)
checkDataType(false, "transformapply", TF_FN_PARAM_DATA,
DataType.FRAME, conditional);
checkDataType(false, "transformapply", TF_FN_PARAM_MTD2,
DataType.FRAME, conditional);
-
+
//validate specification
checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC,
DataType.SCALAR, ValueType.STRING, conditional);
validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
-
+
+ //validate additional argument for word_embeddings tranform
+ checkDataType(true, "transformapply", TF_FN_PARAM_EMBD,
DataType.MATRIX, conditional);
+
//set output dimensions
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
index 44f8846ea9..719ad3a9cd 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
@@ -178,7 +178,7 @@ public class DenseBlockFP64 extends DenseBlockDRB
System.arraycopy(v, 0, _data, pos(r), _odims[0]);
return this;
}
-
+
@Override
public DenseBlock set(int[] ix, double v) {
_data[pos(ix)] = v;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 9dfbdbec7f..18a199e930 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -54,7 +54,11 @@ import
org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
import org.apache.sysds.runtime.util.AutoDiff;
import org.apache.sysds.runtime.util.DataConverter;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -310,11 +314,12 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
// acquire locks
FrameBlock data =
ec.getFrameInput(params.get("target"));
FrameBlock meta = ec.getFrameInput(params.get("meta"));
+ MatrixBlock embeddings = params.get("embedding") !=
null ? ec.getMatrixInput(params.get("embedding")) : null;
String[] colNames = data.getColumnNames();
// compute transformapply
MultiColumnEncoder encoder = EncoderFactory
- .createEncoder(params.get("spec"), colNames,
data.getNumColumns(), meta);
+ .createEncoder(params.get("spec"), colNames,
data.getNumColumns(), meta, embeddings);
MatrixBlock mbout = encoder.apply(data,
OptimizerUtils.getTransformNumThreads());
// release locks
@@ -346,7 +351,7 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
// compute transformapply
MultiColumnEncoder encoder = EncoderFactory
- .createEncoder(params.get("spec"), colNames,
meta.getNumColumns(), null);
+ .createEncoder(params.get("spec"), colNames,
meta.getNumColumns(), null, null);
MatrixBlock mbout = encoder.getColMapping(meta);
// release locks
@@ -532,6 +537,8 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
CPOperand target = new CPOperand(params.get("target"),
ValueType.FP64, DataType.FRAME);
CPOperand meta = getLiteral("meta", ValueType.UNKNOWN,
DataType.FRAME);
CPOperand spec = getStringLiteral("spec");
+ //FIXME: Taking only spec file name as a literal leads
to wrong reuse
+ //TODO: Add Embedding to the lineage item
return Pair.of(output.getName(),
new LineageItem(getOpcode(),
LineageItemUtils.getLineage(ec, target, meta, spec)));
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
index ec4758a819..b264004b61 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
@@ -47,7 +47,7 @@ public class TfUtils implements Serializable
//transform methods
public enum TfMethod {
- IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT;
+ IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT, WORD_EMBEDDING;
@Override
public String toString() {
return name().toLowerCase();
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 610e0cc414..3020553e71 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -65,8 +65,13 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
protected int _nBuildPartitions = 0;
protected int _nApplyPartitions = 0;
+	/**
+	 * Hook for passing a pre-trained word embeddings matrix to an encoder.
+	 * The default implementation ignores the matrix; ColumnEncoderWordEmbedding
+	 * overrides it to keep a reference for the apply step.
+	 *
+	 * @param embeddings word embeddings matrix (ignored here)
+	 */
+	public void initEmbeddings(MatrixBlock embeddings){
+		// intentionally empty: most encoders need no embeddings
+	}
+
protected enum TransformType{
- BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, N_A
+ BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF,
WORD_EMBEDDING, N_A
}
protected ColumnEncoder(int colID) {
@@ -106,6 +111,9 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
case DUMMYCODE:
TransformStatistics.incDummyCodeApplyTime(t);
break;
+ case WORD_EMBEDDING:
+
TransformStatistics.incWordEmbeddingApplyTime(t);
+ break;
case FEATURE_HASH:
TransformStatistics.incFeatureHashingApplyTime(t);
break;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 6f18263a26..fd69d5bf26 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -319,6 +319,12 @@ public class ColumnEncoderComposite extends ColumnEncoder {
columnEncoder.initMetaData(out);
}
+	/**
+	 * Passes the word embeddings matrix down to all contained column encoders.
+	 * Only ColumnEncoderWordEmbedding actually uses it; the base implementation
+	 * in ColumnEncoder is a no-op.
+	 *
+	 * @param embeddings word embeddings matrix
+	 */
+	@Override
+	public void initEmbeddings(MatrixBlock embeddings){
+		for(ColumnEncoder columnEncoder : _columnEncoders)
+			columnEncoder.initEmbeddings(embeddings);
+	}
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
new file mode 100644
index 0000000000..03584cf5ee
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.encode;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class ColumnEncoderWordEmbedding extends ColumnEncoder {
+	// pre-trained embeddings; row r-1 holds the vector for recode index r
+	private MatrixBlock wordEmbeddings;
+
+	protected ColumnEncoderWordEmbedding(int colID) {
+		super(colID);
+	}
+
+	/**
+	 * The domain size equals the number of columns of the embeddings matrix,
+	 * i.e., the length of a single embedding vector.
+	 */
+	@Override
+	public int getDomainSize(){
+		return wordEmbeddings.getNumColumns();
+	}
+
+	@Override
+	protected double getCode(CacheBlock<?> in, int row) {
+		throw new NotImplementedException();
+	}
+
+	@Override
+	protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) {
+		throw new NotImplementedException();
+	}
+
+	/**
+	 * Replaces recode indices with word embedding vectors. The preceding recode
+	 * step replaced each string with a 1-based index into the rows of the embeddings
+	 * matrix; each index in column {@code outputCol} is expanded into the
+	 * corresponding embedding vector, spanning {@code getDomainSize()} output
+	 * columns starting at {@code outputCol}. NaN indices yield all-zero output.
+	 * Current limitation: if the transform is applied to multiple columns, the same
+	 * embeddings matrix is used for all of them.
+	 *
+	 * @throws DMLRuntimeException if {@code in} is not a MatrixBlock, or if an index
+	 *         lies outside the rows of the embeddings matrix
+	 */
+	@Override
+	public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){
+		if (!(in instanceof MatrixBlock)){
+			throw new DMLRuntimeException("ColumnEncoderWordEmbedding called with: " + in.getClass().getSimpleName() +
+				" and not MatrixBlock");
+		}
+		// loop invariants, hoisted out of the per-cell loops
+		final int embeddingLength = getDomainSize();
+		final int numEmbeddings = wordEmbeddings.getNumRows();
+		final int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
+		for(int i = rowStart; i < rowEnd; i++){
+			// read the index before overwriting (in and out may share this cell)
+			double embeddingIndex = in.getDouble(i, outputCol);
+			if(Double.isNaN(embeddingIndex)) {
+				// no recode index: fill the output span with zeros
+				for(int j = 0; j < embeddingLength; j++)
+					out.quickSetValue(i, outputCol + j, 0.0);
+				continue;
+			}
+			// recode indices are 1-based, embedding rows are 0-based
+			int embeddingRow = (int) embeddingIndex - 1;
+			if(embeddingRow < 0 || embeddingRow >= numEmbeddings)
+				throw new DMLRuntimeException("Word embedding index " + embeddingIndex + " in row " + i
+					+ " is out of range [1, " + numEmbeddings + "] of the embeddings matrix");
+			for(int j = 0; j < embeddingLength; j++)
+				out.quickSetValue(i, outputCol + j, wordEmbeddings.quickGetValue(embeddingRow, j));
+		}
+	}
+
+	@Override
+	protected TransformType getTransformType() {
+		return TransformType.WORD_EMBEDDING;
+	}
+
+	// build is not supported: the embeddings are supplied via initEmbeddings
+	@Override
+	public void build(CacheBlock<?> in) {
+		throw new NotImplementedException();
+	}
+
+	@Override
+	public void allocateMetaData(FrameBlock meta) {
+		throw new NotImplementedException();
+	}
+
+	@Override
+	public FrameBlock getMetaData(FrameBlock out) {
+		throw new NotImplementedException();
+	}
+
+	// no metadata to initialize; the embeddings are passed via initEmbeddings
+	@Override
+	public void initMetaData(FrameBlock meta) {
+	}
+
+	/**
+	 * Saves a reference to the embeddings matrix for the apply step.
+	 *
+	 * @param embeddings word embeddings matrix (one embedding vector per row)
+	 */
+	@Override
+	public void initEmbeddings(MatrixBlock embeddings){
+		this.wordEmbeddings = embeddings;
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 075b6fbdd4..313258831a 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -36,6 +36,7 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder.EncoderType;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -68,7 +69,21 @@ public interface EncoderFactory {
}
public static MultiColumnEncoder createEncoder(String spec, String[]
colnames, ValueType[] schema, FrameBlock meta,
- int minCol, int maxCol) {
+ int minCol, int maxCol){
+ return createEncoder(spec, colnames, schema, meta, null,
minCol, maxCol);
+ }
+
+	/**
+	 * Creates a multi-column encoder for an all-string schema of {@code clen} columns.
+	 *
+	 * @param embeddings word embeddings matrix; may be null unless the spec
+	 *        contains a word_embedding transform
+	 */
+	public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings) {
+		ValueType[] stringSchema = UtilFunctions.nCopies(clen, ValueType.STRING);
+		return createEncoder(spec, colnames, stringSchema, meta, embeddings);
+	}
+
+	/**
+	 * Creates a multi-column encoder with minCol = maxCol = -1
+	 * (presumably: no column range restriction).
+	 *
+	 * @param embeddings word embeddings matrix; may be null unless the spec
+	 *        contains a word_embedding transform
+	 */
+	public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema,
+		FrameBlock meta, MatrixBlock embeddings) {
+		final int noColBound = -1;
+		return createEncoder(spec, colnames, schema, meta, embeddings, noColBound, noColBound);
+	}
+
+ public static MultiColumnEncoder createEncoder(String spec, String[]
colnames, ValueType[] schema, FrameBlock meta,
+ MatrixBlock embeddings, int minCol, int maxCol) {
MultiColumnEncoder encoder;
int clen = schema.length;
@@ -88,9 +103,18 @@ public interface EncoderFactory {
List<Integer> dcIDs = Arrays.asList(ArrayUtils
.toObject(TfMetaUtils.parseJsonIDList(jSpec,
colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
List<Integer> binIDs =
TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
+ List<Integer> weIDs = Arrays.asList(ArrayUtils
+
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames,
TfMethod.WORD_EMBEDDING.toString(), minCol, maxCol)));
+
+ //check if user passed an embeddings matrix
+ if(!weIDs.isEmpty() && embeddings == null)
+ throw new DMLRuntimeException("Missing argument
Embeddings Matrix for transform [" + TfMethod.WORD_EMBEDDING + "]");
+
// NOTE: any dummycode column requires recode as
preparation, unless the dummycode
// column follows binning or feature hashing
rcIDs = unionDistinct(rcIDs, except(except(dcIDs,
binIDs), haIDs));
+ // NOTE: Word Embeddings requires recode as preparation
+ rcIDs = unionDistinct(rcIDs, weIDs);
// Error out if the first level encoders have overlaps
if (intersect(rcIDs, binIDs, haIDs))
throw new DMLRuntimeException("More than one
encoders (recode, binning, hashing) on one column is not allowed");
@@ -114,7 +138,9 @@ public interface EncoderFactory {
if(!ptIDs.isEmpty())
for(Integer id : ptIDs)
addEncoderToMap(new
ColumnEncoderPassThrough(id), colEncoders);
-
+ if(!weIDs.isEmpty())
+ for(Integer id : weIDs)
+ addEncoderToMap(new
ColumnEncoderWordEmbedding(id), colEncoders);
if(!binIDs.isEmpty())
for(Object o : (JSONArray)
jSpec.get(TfMethod.BIN.toString())) {
JSONObject colspec = (JSONObject) o;
@@ -185,6 +211,9 @@ public interface EncoderFactory {
}
encoder.initMetaData(meta);
}
+ //initialize embeddings matrix block in the encoders in
case word embedding transform is used
+ if(!weIDs.isEmpty())
+ encoder.initEmbeddings(embeddings);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 6838cdd1e2..59c22f5640 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -314,10 +314,12 @@ public class MultiColumnEncoder implements Encoder {
public MatrixBlock apply(CacheBlock<?> in, int k) {
// domain sizes are not updated if called from transformapply
boolean hasUDF = _columnEncoders.stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderUDF.class));
+ boolean hasWE = _columnEncoders.stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderWordEmbedding.class));
for(ColumnEncoderComposite columnEncoder : _columnEncoders)
columnEncoder.updateAllDCEncoders();
int numCols = getNumOutCols();
- long estNNz = (long) in.getNumRows() * (hasUDF ? numCols :
(long) in.getNumColumns());
+ long estNNz = (long) in.getNumRows() * (hasUDF ? numCols :
hasWE ? getEstNNzRow() : (long) in.getNumColumns());
+ // FIXME: estimate nnz for multiple encoders including
dummycode and embedding
boolean sparse =
MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) &&
!hasUDF;
MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols,
sparse, estNNz);
return apply(in, out, 0, k);
@@ -353,8 +355,7 @@ public class MultiColumnEncoder implements Encoder {
int offset = outputCol;
for(ColumnEncoderComposite columnEncoder :
_columnEncoders) {
columnEncoder.apply(in, out,
columnEncoder._colID - 1 + offset);
- if
(columnEncoder.hasEncoder(ColumnEncoderDummycode.class))
- offset +=
columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+ offset = getOffset(offset, columnEncoder);
}
}
// Recomputing NNZ since we access the block directly
@@ -373,12 +374,19 @@ public class MultiColumnEncoder implements Encoder {
int offset = outputCol;
for(ColumnEncoderComposite e : _columnEncoders) {
tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 +
offset));
- if(e.hasEncoder(ColumnEncoderDummycode.class))
- offset +=
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+ offset = getOffset(offset, e);
}
return tasks;
}
+	/**
+	 * Returns the output-column offset after applying composite encoder {@code e}.
+	 * Dummycode and word embedding expand one input column into domain-size
+	 * output columns, shifting all following columns by domain size minus one.
+	 */
+	private int getOffset(int offset, ColumnEncoderComposite e) {
+		int extraCols = 0;
+		if(e.hasEncoder(ColumnEncoderDummycode.class))
+			extraCols += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+		if(e.hasEncoder(ColumnEncoderWordEmbedding.class))
+			extraCols += e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
+		return offset + extraCols;
+	}
+
private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol,
int k) {
DependencyThreadPool pool = new DependencyThreadPool(k);
try {
@@ -386,8 +394,7 @@ public class MultiColumnEncoder implements Encoder {
int offset = outputCol;
for (ColumnEncoderComposite e :
_columnEncoders) {
pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset));
- if
(e.hasEncoder(ColumnEncoderDummycode.class))
- offset +=
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+ offset = getOffset(offset, e);
}
} else
pool.submitAllAndWait(getApplyTasks(in, out,
outputCol));
@@ -696,6 +703,12 @@ public class MultiColumnEncoder implements Encoder {
_legacyMVImpute.initMetaData(meta);
}
+	/**
+	 * Passes the word embeddings matrix down to all composite column encoders.
+	 *
+	 * @param embeddings word embeddings matrix
+	 */
+	public void initEmbeddings(MatrixBlock embeddings) {
+		_columnEncoders.forEach(composite -> composite.initEmbeddings(embeddings));
+	}
+
@Override
public void prepareBuildPartial() {
for(Encoder encoder : _columnEncoders)
@@ -855,6 +868,13 @@ public class MultiColumnEncoder implements Encoder {
return getEncoderTypes(-1);
}
+	/**
+	 * Estimates the number of non-zeros per output row as the sum of the
+	 * domain sizes of all composite column encoders.
+	 */
+	public int getEstNNzRow(){
+		int estimate = 0;
+		for(ColumnEncoderComposite composite : _columnEncoders)
+			estimate += composite.getDomainSize();
+		return estimate;
+	}
+
public int getNumOutCols() {
int sum = 0;
for(int i = 0; i < _columnEncoders.size(); i++)
diff --git
a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
index b7779e4ee1..9ace729462 100644
--- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
@@ -32,6 +32,8 @@ public class TransformStatistics {
//private static final LongAdder applyTime = new LongAdder();
private static final LongAdder recodeApplyTime = new LongAdder();
private static final LongAdder dummyCodeApplyTime = new LongAdder();
+
+ private static final LongAdder wordEmbeddingApplyTime = new LongAdder();
private static final LongAdder passThroughApplyTime = new LongAdder();
private static final LongAdder featureHashingApplyTime = new
LongAdder();
private static final LongAdder binningApplyTime = new LongAdder();
@@ -55,6 +57,10 @@ public class TransformStatistics {
dummyCodeApplyTime.add(t);
}
+	/**
+	 * Adds to the accumulated word-embedding apply time.
+	 *
+	 * @param t elapsed time in nanoseconds (reported as {@code t*1e-9} sec.)
+	 */
+	public static void incWordEmbeddingApplyTime(long t) {
+		wordEmbeddingApplyTime.add(t);
+	}
+
public static void incBinningApplyTime(long t) {
binningApplyTime.add(t);
}
@@ -112,7 +118,7 @@ public class TransformStatistics {
return dummyCodeApplyTime.longValue() +
binningApplyTime.longValue() +
featureHashingApplyTime.longValue() +
passThroughApplyTime.longValue() +
recodeApplyTime.longValue() +
UDFApplyTime.longValue() +
- omitApplyTime.longValue() +
imputeApplyTime.longValue();
+ omitApplyTime.longValue() +
imputeApplyTime.longValue() + wordEmbeddingApplyTime.longValue();
}
public static void reset() {
@@ -163,6 +169,9 @@ public class TransformStatistics {
if(dummyCodeApplyTime.longValue() > 0)
sb.append("\tDummyCode apply
time:\t").append(String.format("%.3f",
dummyCodeApplyTime.longValue()*1e-9)).append(" sec.\n");
+ if(wordEmbeddingApplyTime.longValue() > 0)
+ sb.append("\tWordEmbedding apply
time:\t").append(String.format("%.3f",
+
wordEmbeddingApplyTime.longValue()*1e-9)).append(" sec.\n");
if(featureHashingApplyTime.longValue() > 0)
sb.append("\tHashing apply
time:\t").append(String.format("%.3f",
featureHashingApplyTime.longValue()*1e-9)).append(" sec.\n");
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
new file mode 100644
index 0000000000..8ab52d9f64
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
+{
+	private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings2";
+	private final static String TEST_NAME2 = "TransformFrameEncodeWordEmbeddings2MultiCols1";
+	private final static String TEST_NAME3 = "TransformFrameEncodeWordEmbeddings2MultiCols2";
+
+	private final static String TEST_DIR = "functions/transform/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, TEST_NAME2));
+		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, TEST_NAME3));
+	}
+
+	@Test
+	public void testTransformToWordEmbeddings() {
+		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	@Ignore
+	public void testNonRandomTransformToWordEmbeddings2Cols() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	@Ignore
+	public void testRandomTransformToWordEmbeddings4Cols() {
+		runTransformTestMultiCols(TEST_NAME3, ExecMode.SINGLE_NODE);
+	}
+
+	/**
+	 * Runs a script producing a single result matrix and compares it against
+	 * the manually derived embeddings of the generated data column.
+	 */
+	private void runTransformTest(String testname, ExecMode rt)
+	{
+		//set runtime platform
+		ExecMode rtold = setExecMode(rt);
+		try
+		{
+			int rows = 100;
+			int cols = 100;
+			getAndLoadTestConfiguration(testname);
+			fullDMLScriptName = getScript();
+
+			// Generate random embeddings for the distinct tokens
+			double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime());
+
+			// Generate random (presumably distinct) tokens
+			List<String> strings = generateRandomStrings(rows, 10);
+
+			// Generate the dictionary by assigning unique ID to each distinct token
+			Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
+
+			// Create the dataset by sampling (with replacement) from the distinct tokens
+			List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 320);
+			writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
+
+			//run script
+			programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result")};
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+			// Manually derive the expected result
+			double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
+
+			// Compare results
+			HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
+			double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual);
+			TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			resetExecMode(rtold);
+		}
+	}
+
+	/** Debug helper: prints a 2D array row by row. */
+	private void print2DimDoubleArray(double[][] resultActualDouble) {
+		Arrays.stream(resultActualDouble).forEach(
+			e -> System.out.println(Arrays.stream(e).mapToObj(d -> String.format("%06.1f", d))
+				.reduce("", (sub, elem) -> sub + " " + elem)));
+	}
+
+	/**
+	 * Runs a script producing two result matrices; the first is compared against
+	 * the manually derived embeddings, the second against the first.
+	 */
+	private void runTransformTestMultiCols(String testname, ExecMode rt)
+	{
+		//set runtime platform
+		ExecMode rtold = setExecMode(rt);
+		try
+		{
+			int rows = 100;
+			int cols = 100;
+			getAndLoadTestConfiguration(testname);
+			fullDMLScriptName = getScript();
+
+			// Generate random embeddings for the distinct tokens
+			double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime());
+
+			// Generate random (presumably distinct) tokens
+			List<String> strings = generateRandomStrings(rows, 10);
+
+			// Generate the dictionary by assigning unique ID to each distinct token
+			Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
+
+			// Create the dataset by sampling (with replacement) from the distinct tokens
+			List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 10);
+			writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
+
+			//run script
+			programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result"), output("result2")};
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+			// Manually derive the expected result
+			double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
+
+			// Compare results
+			HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
+			HashMap<MatrixValue.CellIndex, Double> res_actual2 = readDMLMatrixFromOutputDir("result2");
+			double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual);
+			double[][] resultActualDouble2 = TestUtils.convertHashMapToDoubleArray(res_actual2);
+			TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
+			TestUtils.compareMatrices(resultActualDouble, resultActualDouble2, 1e-6);
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			resetExecMode(rtold);
+		}
+	}
+
+	/**
+	 * Builds the expected output: row i is the embedding row assigned (via the
+	 * 0-based dictionary map) to the i-th token of the data column.
+	 */
+	private double[][] manuallyDeriveWordEmbeddings(int cols, double[][] a, Map<String, Integer> map, List<String> stringsColumn) {
+		double[][] res_expected = new double[stringsColumn.size()][cols];
+		for (int i = 0; i < stringsColumn.size(); i++) {
+			int rowMapped = map.get(stringsColumn.get(i));
+			System.arraycopy(a[rowMapped], 0, res_expected[i], 0, cols);
+		}
+		return res_expected;
+	}
+
+	/**
+	 * Returns {@code strings.size()*multiply} tokens drawn uniformly at random
+	 * (with replacement) from {@code strings}. Despite the name, this samples
+	 * rather than shuffles.
+	 */
+	public static List<String> shuffleAndMultiplyStrings(List<String> strings, int multiply){
+		List<String> out = new ArrayList<>();
+		Random random = new Random();
+		for (int i = 0; i < strings.size()*multiply; i++) {
+			out.add(strings.get(random.nextInt(strings.size())));
+		}
+		return out;
+	}
+
+	/** Generates {@code numStrings} random alphanumeric strings of length {@code stringLength}. */
+	public static List<String> generateRandomStrings(int numStrings, int stringLength) {
+		List<String> randomStrings = new ArrayList<>();
+		Random random = new Random();
+		String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
+		for (int i = 0; i < numStrings; i++) {
+			randomStrings.add(generateRandomString(random, stringLength, characters));
+		}
+		return randomStrings;
+	}
+
+	/** Generates one random string of {@code stringLength} chars drawn from {@code characters}. */
+	public static String generateRandomString(Random random, int stringLength, String characters){
+		StringBuilder randomString = new StringBuilder();
+		for (int j = 0; j < stringLength; j++) {
+			int randomIndex = random.nextInt(characters.length());
+			randomString.append(characters.charAt(randomIndex));
+		}
+		return randomString.toString();
+	}
+
+	/**
+	 * Writes one string per line as UTF-8.
+	 *
+	 * @throws UncheckedIOException if the file cannot be written, so the test fails
+	 *         immediately instead of running on missing input
+	 */
+	public static void writeStringsToCsvFile(List<String> strings, String fileName) {
+		try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(fileName), StandardCharsets.UTF_8)) {
+			for (String line : strings) {
+				bw.write(line);
+				bw.newLine();
+			}
+		}
+		catch (IOException e) {
+			throw new UncheckedIOException("Failed to write " + fileName, e);
+		}
+	}
+
+	/**
+	 * Writes the recode dictionary ("token" + Lop.DATATYPE_PREFIX + 1-based ID per
+	 * line) as UTF-8. UTF-8 is explicit because the separator may be non-ASCII.
+	 *
+	 * @return map from token to its 0-based embedding row
+	 * @throws UncheckedIOException if the file cannot be written
+	 */
+	public static Map<String,Integer> writeDictToCsvFile(List<String> strings, String fileName) {
+		Map<String,Integer> map = new HashMap<>();
+		try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(fileName), StandardCharsets.UTF_8)) {
+			for (int i = 0; i < strings.size(); i++) {
+				map.put(strings.get(i), i);
+				bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + (i+1) + "\n");
+			}
+		}
+		catch (IOException e) {
+			throw new UncheckedIOException("Failed to write " + fileName, e);
+		}
+		return map;
+	}
+}
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
new file mode 100644
index 0000000000..29a4bfab74
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Encode a single token column with pre-trained word embeddings via
+# transformapply and write the encoded matrix.
+#   $1: embedding matrix (text format, read as 100 x 100)
+#   $2: token column (csv frame)
+#   $3: recode map for the tokens (csv frame)
+#   $4: output path for the encoded matrix
+
+# Read the pre-trained word embeddings
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (1K) w/ 100 distinct tokens
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+# ids: true -> columns are referenced by position; column 1 holds the tokens,
+# whose recoded indices are replaced by the corresponding embedding rows of E
+jspec = "{ids: true, word_embedding: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+
+write(Data_enc, $4, format="text");
+
+
+
+
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
new file mode 100644
index 0000000000..00484697d6
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Multi-column variant: surround the token column with two constant columns,
+# embed only the token column (position 2), and write the embedding part.
+#   $1: embedding matrix (text format, read as 100 x 100)
+#   $2: token column (csv frame)
+#   $3: recode map for the tokens (csv frame)
+#   $4: output path for the embedding columns
+
+# Read the pre-trained word embeddings
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (1K) w/ 100 distinct tokens
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+# Constant column of ones with one row per token. NOTE(review): length(Data)
+# equals the row count only if Data has a single column (presumably true for
+# the token file); nrow(Data) would state the intent directly.
+DataExtension = as.frame(matrix(1, rows=length(Data), cols=1))
+# Data layout becomes [ones, tokens, ones]
+Data = cbind(Data, DataExtension)
+Data = cbind(DataExtension, Data)
+# Duplicate the recode map so that column 2 of Meta (the token column's
+# position) holds a copy of it
+Meta = cbind(Meta, Meta)
+
+# Only column 2 (the tokens) is encoded with word embeddings
+jspec = "{ids: true, word_embedding: [2]}";
+#jspec = "{ids: true, dummycode: [2]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+
+# Keep the 100 embedding columns. Presumably column 1 passes through as a
+# single column, so the embedding occupies columns 2..101 -- confirm.
+Data_enc = Data_enc[,2:101]
+write(Data_enc, $4, format="text");
+
+
+
+
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml
new file mode 100644
index 0000000000..fd742520e7
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml
@@ -0,0 +1,44 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Multi-column variant with two token columns: embed columns 1 and 3 and
+# write each embedding block to its own output.
+#   $1: embedding matrix (text format, read as 100 x 100)
+#   $2: token column (csv frame)
+#   $3: recode map for the tokens (csv frame)
+#   $4, $5: output paths for the embeddings of columns 1 and 3
+
+# Read the pre-trained word embeddings
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (1K) w/ 100 distinct tokens
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+# Constant column of ones with one row per token. NOTE(review): length(Data)
+# equals the row count only if Data has a single column (presumably true for
+# the token file); nrow(Data) would state the intent directly.
+DataExtension = as.frame(matrix(1, rows=length(Data), cols=1))
+# Data layout becomes [tokens, ones], then [tokens, ones, tokens, ones]
+Data = cbind(Data, DataExtension)
+Data = cbind(Data, Data)
+# Replicate the recode map to four columns, one copy per data column
+Meta = cbind(Meta, Meta)
+Meta = cbind(Meta, Meta)
+
+# Encode both token columns (positions 1 and 3) with word embeddings
+jspec = "{ids: true, word_embedding: [1,3]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+
+# Presumably: columns 1..100 hold the embedding of column 1, column 101 the
+# ones column, and columns 102..201 the embedding of column 3 -- confirm.
+Data_enc1 = Data_enc[,1:100]
+Data_enc2 = Data_enc[,102:201]
+write(Data_enc1, $4, format="text");
+write(Data_enc2, $5, format="text");
+
+
+