This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f6c65f59d7 [SYSTEMDS-3579] More support for deduplicated dense block
f6c65f59d7 is described below

commit f6c65f59d747419b3e5911d237118b026e181308
Author: e-strauss <[email protected]>
AuthorDate: Fri Aug 11 13:14:55 2023 +0200

    [SYSTEMDS-3579] More support for deduplicated dense block
    
    This patch extends support for deduplicated dense blocks. We now support
    row, column, and full sums, matrix multiplication (dedup %*% dense), and
    serialization of dedup blocks.
    
    Closes #1870
---
 src/main/java/org/apache/sysds/common/Types.java   |   1 +
 .../sysds/runtime/data/DenseBlockFP64DEDUP.java    | 496 +++++++++++---------
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    | 110 ++++-
 .../sysds/runtime/matrix/data/LibMatrixMult.java   |  62 ++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 137 +++++-
 .../test/functions/io/binary/SerializeTest.java    |  72 ++-
 .../TransformFrameEncodeWordEmbedding2Test.java    | 511 +++++++++++----------
 .../TransformFrameEncodeWordEmbeddingMMTest.java   | 108 +++++
 ...ransformFrameEncodeWordEmbeddingRowSumTest.java | 231 ++++++++++
 .../TransformFrameEncodeWordEmbeddingsColSum.dml   |  32 ++
 .../TransformFrameEncodeWordEmbeddingsFullSum.dml  |  32 ++
 .../TransformFrameEncodeWordEmbeddingsMM.dml       |  38 ++
 .../TransformFrameEncodeWordEmbeddingsRowSum.dml   |  32 ++
 13 files changed, 1347 insertions(+), 515 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 78ed173bc9..7bf48f5da5 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -213,6 +213,7 @@ public class Types
                ULTRA_SPARSE_BLOCK,
                SPARSE_BLOCK,
                DENSE_BLOCK,
+               DEDUP_BLOCK,
        }
        
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
index b2273faa75..15433654c3 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -19,236 +19,280 @@
 
 package org.apache.sysds.runtime.data;
 
-import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.utils.MemoryEstimates;
 
 import java.util.Arrays;
 import java.util.HashMap;
 
-public class DenseBlockFP64DEDUP extends DenseBlockDRB{
-       private static final long serialVersionUID = 4124905190428752213L;
-       private double[][] _data;
-
-       protected DenseBlockFP64DEDUP(int[] dims) {
-               super(dims);
-               reset(_rlen, _odims, 0);
-       }
-
-       @Override
-       protected void allocateBlock(int bix, int length) {
-               _data[bix] = new double[length];
-       }
-
-       @Override
-       public void reset(int rlen, int[] odims, double v) {
-               if(rlen >  capacity() / _odims[0])
-                       _data = new double[rlen][];
-               else{
-                       if(v == 0.0){
-                               for(int i = 0; i < rlen; i++)
-                                       _data[i] = null;
-                       }
-                       else {
-                               for(int i = 0; i < rlen; i++){
-                                       if(odims[0] > _odims[0] ||_data[i] == 
null )
-                                               allocateBlock(i, odims[0]);
-                                       Arrays.fill(_data[i], 0, odims[0], v);
-                               }
-                       }
-               }
-               _rlen = rlen;
-               _odims = odims;
-       }
-
-       @Override
-       public void resetNoFill(int rlen, int[] odims) {
-               if(_data == null || rlen > _rlen){
-                       _data = new double[rlen][];
-               }
-               _rlen = rlen;
-               _odims = odims;
-       }
-
-       @Override
-       public boolean isNumeric() {
-               return true;
-       }
-
-       @Override
-       public boolean isNumeric(Types.ValueType vt) {
-               return Types.ValueType.FP64 == vt;
-       }
-
-       @Override
-       public long capacity() {
-               return (_data != null) ? _data.length*_odims[0] : -1;
-       }
-
-       @Override
-       public long countNonZeros(){
-               long nnz = 0;
-               HashMap<double[], Long> cache = new HashMap<>();
-               for (int i = 0; i < _rlen; i++) {
-                       double[] row = this._data[i];
-                       if(row == null)
-                               continue;
-                       Long count = cache.getOrDefault(row, null);
-                       if(count == null){
-                               count = Long.valueOf(countNonZeros(i));
-                               cache.put(row, count);
-                       }
-                       nnz += count;
-               }
-               return nnz;
-       }
-
-       @Override
-       public int countNonZeros(int r) {
-               return _data[r] == null ? 0 : 
UtilFunctions.computeNnz(_data[r], 0, _odims[0]);
-       }
-
-       @Override
-       protected long computeNnz(int bix, int start, int length) {
-               int nnz = 0;
-               int row_start = (int) Math.floor(start / _odims[0]);
-               int col_start = start % _odims[0];
-               for (int i = 0; i < length; i++) {
-                       if(_data[row_start] == null){
-                               i += _odims[0] - 1 - col_start;
-                               col_start = 0;
-                               row_start += 1;
-                               continue;
-                       }
-                       nnz += _data[row_start][col_start] != 0 ? 1 : 0;
-                       col_start += 1;
-                       if(col_start == _odims[0]) {
-                               col_start = 0;
-                               row_start += 1;
-                       }
-               }
-               return nnz;
-       }
-
-       @Override
-       public int pos(int r){
-               return 0;
-       }
-
-       @Override
-       public int pos(int[] ix){
-               int pos = ix[ix.length - 1];
-               for(int i = 1; i < ix.length - 1; i++)
-                       pos += ix[i] * _odims[i];
-               return pos;
-       }
-
-       @Override
-       public double[] values(int r) {
-               return valuesAt(r);
-       }
-
-       @Override
-       public double[] valuesAt(int bix) {
-               return _data[bix] == null ? new double[_odims[0]] : _data[bix];
-       }
-
-       @Override
-       public int index(int r) {
-               return r;
-       }
-
-       @Override
-       public int numBlocks(){
-               return _data.length;
-       }
-
-       @Override
-       public int size(int bix) {
-               return _odims[0];
-       }
-
-       @Override
-       public void incr(int r, int c) {
-               incr(r,c,1.0);
-       }
-
-       @Override
-       public void incr(int r, int c, double delta) {
-               if(_data[r] == null)
-                       allocateBlock(r, _odims[0]);
-               _data[r][c] += delta;
-       }
-
-       @Override
-       protected void fillBlock(int bix, int fromIndex, int toIndex, double v) 
{
-               if(_data[bix] == null)
-                       allocateBlock(bix, _odims[0]);
-               Arrays.fill(_data[bix], fromIndex, toIndex, v);
-       }
-
-       @Override
-       protected void setInternal(int bix, int ix, double v) {
-               set(bix, ix, v);
-       }
-
-       @Override
-       public DenseBlock set(int r, int c, double v) {
-               if(_data[r] == null)
-                       _data[r] = new double[_odims[0]];
-               _data[r][c] = v;
-               return this;
-       }
-
-       @Override
-       public DenseBlock set(int r, double[] v) {
-               if(v.length == _odims[0])
-                       _data[r] = v;
-               else
-                       throw new RuntimeException("set Denseblock called with 
an array length [" + v.length +"], array to overwrite is of length [" + 
_odims[0] + "]");
-               return this;
-       }
-
-       @Override
-       public DenseBlock set(DenseBlock db) {
-               throw new NotImplementedException();
-       }
-
-       @Override
-       public DenseBlock set(int[] ix, double v) {
-               return set(ix[0], pos(ix), v);
-       }
-
-       @Override
-       public DenseBlock set(int[] ix, long v) {
-               return set(ix[0], pos(ix), v);
-       }
-
-       @Override
-       public DenseBlock set(int[] ix, String v) {
-               return set(ix[0], pos(ix), Double.parseDouble(v));
-       }
-
-       @Override
-       public double get(int r, int c) {
-               if(_data[r] == null)
-                       return 0.0;
-               else
-                       return _data[r][c];
-       }
-
-       @Override
-       public double get(int[] ix) {
-               return get(ix[0], pos(ix));
-       }
-
-       @Override
-       public String getString(int[] ix) {
-               return String.valueOf(get(ix[0], pos(ix)));
-       }
-
-       @Override
-       public long getLong(int[] ix) {
-               return UtilFunctions.toLong(get(ix[0], pos(ix)));
-       }
+public class DenseBlockFP64DEDUP extends DenseBlockDRB
+{
+    private double[][] _data;
+    private int _distinct = 0;
+
+    protected DenseBlockFP64DEDUP(int[] dims) {
+        super(dims);
+        reset(_rlen, _odims, 0);
+    }
+
+    public int getNrDistinctRows(){
+        return _distinct;
+    }
+
+    @Override
+    protected void allocateBlock(int bix, int length) {
+        _data[bix] = new double[length];
+    }
+
+    @Override
+    public void reset(int rlen, int[] odims, double v) {
+        if(rlen >  capacity() / _odims[0])
+            _data = new double[rlen][];
+        else {
+            if(v == 0.0) {
+                for(int i = 0; i < rlen; i++)
+                    _data[i] = null;
+            }
+            else {
+                for(int i = 0; i < rlen; i++) {
+                    if(odims[0] > _odims[0] ||_data[i] == null )
+                        allocateBlock(i, odims[0]);
+                    Arrays.fill(_data[i], 0, odims[0], v);
+                }
+            }
+        }
+        _rlen = rlen;
+        _odims = odims;
+    }
+
+    @Override
+    public void resetNoFill(int rlen, int[] odims) {
+        if(_data == null || rlen > _rlen){
+            _data = new double[rlen][];
+        }
+        _rlen = rlen;
+        _odims = odims;
+    }
+
+    @Override
+    public boolean isNumeric() {
+        return true;
+    }
+
+    @Override
+    public boolean isNumeric(Types.ValueType vt) {
+        return Types.ValueType.FP64 == vt;
+    }
+
+    @Override
+    public long capacity() {
+        return (_data != null) ? ((long) _data.length)*_odims[0] : -1;
+    }
+
+    @Override
+    public long countNonZeros() {
+        long nnz = 0;
+        HashMap<double[], Long> cache = new HashMap<double[], Long>();
+        for (int i = 0; i < _rlen; i++) {
+            double[] row = this._data[i];
+            if(row == null)
+                continue;
+            Long count = cache.getOrDefault(row, null);
+            if(count == null){
+                count = Long.valueOf(countNonZeros(i));
+                cache.put(row, count);
+            }
+            nnz += count;
+        }
+        this._distinct = cache.size();
+        return nnz;
+    }
+
+    @Override
+    public int countNonZeros(int r) {
+        return _data[r] == null ? 0 : UtilFunctions.computeNnz(_data[r], 0, 
_odims[0]);
+    }
+
+    @Override
+    public long countNonZeros(int rl, int ru, int ol, int ou) {
+        long nnz = 0;
+        HashMap<double[], Long> cache = new HashMap<double[], Long>();
+        for (int i = rl; i < ru; i++) {
+            double[] row = this._data[i];
+            if(row == null)
+                continue;
+            Long count = cache.getOrDefault(row, null);
+            if(count == null){
+                count = Long.valueOf(UtilFunctions.computeNnz(_data[i], ol, 
ou));
+                cache.put(row, count);
+            }
+            nnz += count;
+        }
+        return nnz;
+    }
+
+    @Override
+    protected long computeNnz(int bix, int start, int length) {
+        int nnz = 0;
+        int row_start = (int) Math.floor(start / _odims[0]);
+        int col_start = start % _odims[0];
+        for (int i = 0; i < length; i++) {
+            if(_data[row_start] == null){
+                i += _odims[0] - 1 - col_start;
+                col_start = 0;
+                row_start += 1;
+                continue;
+            }
+            nnz += _data[row_start][col_start] != 0 ? 1 : 0;
+            col_start += 1;
+            if(col_start == _odims[0]) {
+                col_start = 0;
+                row_start += 1;
+            }
+        }
+        return nnz;
+    }
+
+    @Override
+    public int pos(int r){
+        return 0;
+    }
+
+    @Override
+    public int pos(int r, int c){
+        return c;
+    }
+
+    @Override
+    public int pos(int[] ix){
+        int pos = ix[ix.length - 1];
+        for(int i = 1; i < ix.length - 1; i++)
+            pos += ix[i] * _odims[i];
+        return pos;
+    }
+
+    @Override
+    public int blockSize(int bix) {
+        return 1;
+    }
+    public boolean isContiguous(int rl, int ru){
+        return rl == ru;
+    }
+    @Override
+    public double[] values(int r) {
+        return valuesAt(r);
+    }
+
+    @Override
+    public double[] valuesAt(int bix) {
+        return _data[bix] == null ? new double[_odims[0]] : _data[bix];
+    }
+
+    @Override
+    public int index(int r) {
+        return r;
+    }
+
+    @Override
+    public int numBlocks(){
+        return _data.length;
+    }
+
+    @Override
+    public int size(int bix) {
+        return _odims[0];
+    }
+
+    @Override
+    public void incr(int r, int c) {
+        incr(r,c,1.0);
+    }
+
+    @Override
+    public void incr(int r, int c, double delta) {
+        if(_data[r] == null)
+            allocateBlock(r, _odims[0]);
+        _data[r][c] += delta;
+    }
+
+    @Override
+    protected void fillBlock(int bix, int fromIndex, int toIndex, double v) {
+        if(_data[bix] == null)
+            allocateBlock(bix, _odims[0]);
+        Arrays.fill(_data[bix], fromIndex, toIndex, v);
+    }
+
+    @Override
+    protected void setInternal(int bix, int ix, double v) {
+        set(bix, ix, v);
+    }
+
+    @Override
+    public DenseBlock set(int r, int c, double v) {
+        if(_data[r] == null)
+            _data[r] = new double[_odims[0]];
+        _data[r][c] = v;
+        return this;
+    }
+
+    @Override
+    public DenseBlock set(int r, double[] v) {
+        if(v.length == _odims[0])
+            _data[r] = v;
+        else
+            throw new RuntimeException("set Denseblock called with an array 
length [" + v.length +"], array to overwrite is of length [" + _odims[0] + "]");
+        return this;
+    }
+
+    @Override
+    public DenseBlock set(DenseBlock db) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, double v) {
+        return set(ix[0], pos(ix), v);
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, long v) {
+        return set(ix[0], pos(ix), v);
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, String v) {
+        return set(ix[0], pos(ix), Double.parseDouble(v));
+    }
+
+    @Override
+    public double get(int r, int c) {
+        if(_data[r] == null)
+            return 0.0;
+        else
+            return _data[r][c];
+    }
+
+    @Override
+    public double get(int[] ix) {
+        return get(ix[0], pos(ix));
+    }
+
+    @Override
+    public String getString(int[] ix) {
+        return String.valueOf(get(ix[0], pos(ix)));
+    }
+
+    @Override
+    public long getLong(int[] ix) {
+        return UtilFunctions.toLong(get(ix[0], pos(ix)));
+    }
+
+    public double estimateMemory(){
+        if( (double)_rlen + this._odims[0] > Long.MAX_VALUE )
+            return Long.MAX_VALUE;
+        return DenseBlock.estimateMemory(_rlen, _odims[0])
+                + MemoryEstimates.doubleArrayCost(_odims[0])*_distinct + 
MemoryEstimates.objectArrayCost(_rlen);
+    }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 61eecf251f..70ee962162 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.matrix.data;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -36,6 +37,7 @@ import 
org.apache.sysds.runtime.codegen.SpoofOperator.SideInputSparseCell;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockCSR;
@@ -1828,12 +1830,16 @@ public class LibMatrixAgg {
         * @param ru row upper index
         */
        private static void d_uakp( DenseBlock a, DenseBlock c, int n, 
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
-               final int bil = a.index(rl);
-               final int biu = a.index(ru-1);
-               for(int bi=bil; bi<=biu; bi++) {
-                       int lpos = (bi==bil) ? a.pos(rl) : 0;
-                       int len = (bi==biu) ? a.pos(ru-1)-lpos+n : 
a.blockSize(bi)*n;
-                       sum(a.valuesAt(bi), lpos, len, kbuff, kplus);
+               if(a instanceof DenseBlockFP64DEDUP)
+                       uakpDedup(a, c, n, kbuff, kplus, rl, ru);
+               else {
+                       final int bil = a.index(rl);
+                       final int biu = a.index(ru - 1);
+                       for (int bi = bil; bi <= biu; bi++) {
+                               int lpos = (bi == bil) ? a.pos(rl) : 0;
+                               int len = (bi == biu) ? a.pos(ru - 1) - lpos + 
n : a.blockSize(bi) * n;
+                               sum(a.valuesAt(bi), lpos, len, kbuff, kplus);
+                       }
                }
                c.set(kbuff);
        }
@@ -1851,10 +1857,14 @@ public class LibMatrixAgg {
         */
        private static void d_uarkp( DenseBlock a, DenseBlock c, int n, 
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) 
        {
-               for( int i=rl; i<ru; i++ ) {
-                       kbuff.set(0, 0); //reset buffer
-                       sum( a.values(i), a.pos(i), n, kbuff, kplus );
-                       c.set(i, kbuff);
+               if(a instanceof DenseBlockFP64DEDUP)
+                       uarkpDedup(a, c, n, kbuff, kplus, rl, ru);
+               else {
+                       for (int i = rl; i < ru; i++) {
+                               kbuff.set(0, 0); //reset buffer
+                               sum(a.values(i), a.pos(i), n, kbuff, kplus);
+                               c.set(i, kbuff);
+                       }
                }
        }
        
@@ -1870,8 +1880,12 @@ public class LibMatrixAgg {
         * @param ru row upper index
         */
        private static void d_uackp( DenseBlock a, DenseBlock c, int n, 
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
-               for( int i=rl; i<ru; i++ )
-                       sumAgg( a.values(i), c, a.pos(i), n, kbuff, kplus );
+               if(a instanceof DenseBlockFP64DEDUP)
+                       uackpDedup(a, c, n, kbuff, kplus, rl, ru);
+               else {
+                       for( int i=rl; i<ru; i++ )
+                               sumAgg( a.values(i), c, a.pos(i), n, kbuff, 
kplus );
+               }
        }
 
        /**
@@ -3462,7 +3476,77 @@ public class LibMatrixAgg {
                        c[ ci+aix[ i+7 ] ] --;
                }
        }
-       
+
+
+       //////////////////////////////////////////////////////
+       // Deduplicated dense block related utility functions //
+       /////////////////////////////////////////////////////
+
+
+       private static void uakpDedup (DenseBlock a, DenseBlock c, int n, 
KahanObject kbuff, KahanPlus kplus, int rl, int ru) {
+               HashMap<double[], Integer> counts = new HashMap<>();
+               for(int i = rl; i < ru; i++) {
+                       double[] row = a.values(i);
+                       Integer count = counts.getOrDefault(row, 0);
+                       count += 1;
+                       counts.put(row, count);
+               }
+               counts.forEach((row, count) -> {
+                       for(double r : row) {
+                               kplus.execute3(kbuff, r, count);
+                       }
+               });
+       }
+
+       private static void uarkpDedup( DenseBlock a, DenseBlock c, int n, 
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
+               HashMap<double[], double[]> cache = new HashMap<>();
+               for(int i = rl; i < ru; i++) {
+                       double[] row = a.values(i);
+                       int finalI = i;
+                       double[] kbuff_array = cache.computeIfAbsent(row, 
lambda_row -> {
+                               kbuff.set(0, 0);
+                               sum(lambda_row, a.pos(finalI), n, kbuff, kplus);
+                               return new double[] {kbuff._sum, 
kbuff._correction};
+                       });
+                       cache.putIfAbsent(row, kbuff_array);
+                       c.set(i, 0, kbuff_array[0]);
+                       c.set(i, 1, kbuff_array[1]);
+               }
+       }
+
+       private static void uackpDedup( DenseBlock a, DenseBlock c, int n, 
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
+               HashMap<double[], Integer> counts = new HashMap<>();
+               for(int i = rl; i < ru; i++) {
+                       double[] row = a.values(i);
+                       Integer count = counts.getOrDefault(row, 0);
+                       count += 1;
+                       counts.put(row, count);
+               }
+               double[] sum = new double[n];
+               double[] corr = new double[n];
+               counts.forEach((row, count) -> {
+                       for(int i = 0; i < row.length; i++) {
+                               kbuff._sum = sum[i];
+                               kbuff._correction = corr[i];
+                               kplus.execute3(kbuff, row[i], count);
+                               sum[i] = kbuff._sum;
+                               corr[i] = kbuff._correction;
+                       }
+               });
+               double[] out_sum = c.values(0);
+               double[] out_corr = c.values(1);
+               int pos0 = c.pos(0), pos1 = c.pos(1);
+               for(int i = 0; i < n; i++) {
+                       double tmp_sum = out_sum[pos0 + i] + sum[i];
+                       if(Math.abs(out_sum[pos0 + i]) > Math.abs(sum[i]))
+                               out_corr[pos1 + i] = ((out_sum[pos0 + i] - 
tmp_sum) + sum[i]) + out_corr[pos1 + i] + corr[i];
+                       else
+                               out_corr[pos1 + i] = ((sum[i] - tmp_sum) + 
out_sum[pos0 + i]) + out_corr[pos1 + i] + corr[i];
+                       out_sum[pos0 + i] = tmp_sum + out_corr[pos1 + i];
+               }
+       }
+
+
        /////////////////////////////////////////////////////////
        // Task Implementations for Multi-Threaded Operations  //
        /////////////////////////////////////////////////////////
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index ce44e2e343..54ac8d2e22 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
@@ -41,6 +42,7 @@ import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlock.Type;
@@ -197,7 +199,10 @@ public class LibMatrixMult
                        ret = new MatrixBlock(m1.rlen, m2.clen, ultraSparse | 
sparse);
                else 
                        ret.reset(m1.rlen, m2.clen, ultraSparse | sparse);
-               ret.allocateBlock();
+               if(m1.denseBlock instanceof DenseBlockFP64DEDUP)
+                       ret.allocateDenseBlock(true, true);
+               else
+                       ret.allocateBlock();
                
                // Detect if we should transpose skinny right side.
                boolean tm2 = !fixedRet && checkPrepMatrixMultRightInput(m1,m2);
@@ -258,9 +263,11 @@ public class LibMatrixMult
                try {
                        ExecutorService pool = CommonThreadPool.get(k);
                        ArrayList<MatrixMultTask> tasks = new ArrayList<>();
-                       ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r || pm2c));
+                       ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(num, k,
+                               (pm2r || pm2c || ret.denseBlock instanceof 
DenseBlockFP64DEDUP));
+                       ConcurrentHashMap<double[], double[]> cache = 
m1.denseBlock instanceof DenseBlockFP64DEDUP ? new ConcurrentHashMap(): null;
                        for(int i = 0, lb = 0; i < blklens.size(); lb += 
blklens.get(i), i++)
-                               tasks.add(new MatrixMultTask(m1, m2, ret, tm2, 
pm2r, pm2c, m1Perm, sparse, lb, lb + blklens.get(i)));
+                               tasks.add(new MatrixMultTask(m1, m2, ret, tm2, 
pm2r, pm2c, m1Perm, sparse, lb, lb + blklens.get(i), cache));
                        // execute tasks
                        List<Future<Object>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
@@ -1129,21 +1136,44 @@ public class LibMatrixMult
                                cvals[cix+j] = dotProduct(avals, b.values(j), 
aix, b.pos(j), cd);
                }
        }
-       
+
+       public static void matrixMultDenseDenseMMDedup(DenseBlock a, DenseBlock 
b, DenseBlock c, int n, int cd, int rl, int ru, ConcurrentHashMap<double[], 
double[]> cache) {
+               //n = m2.clen;
+               //cd = m1.clen;
+               for (int i = rl; i < ru; i++) {
+                       double[] a_row = a.values(i);
+                       double[] c_row = cache.getOrDefault(a_row, null);
+                       if (c_row == null) {
+                               c_row = new double[n];
+                               for (int j = 0; j < n; j++) {
+                                       c_row[j] = 0.0;
+                                       //the following requires 
b.isContiguous(0,cd)
+                                       double[] b_column = b.values(0);
+                                       for (int k = 0; k < cd; k++) {
+                                               c_row[j] += a_row[k] * 
b_column[b.pos(k, j)];
+                                       }
+                               }
+                               //the following requires
+                               cache.put(a_row, c_row);
+                       }
+                       c.set(i, c_row);
+               }
+       }
+
        //note: public for use by codegen for consistency
        public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, 
DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
                //1) Unrolled inner loop (for better instruction-level 
parallelism)
-               //2) Blocked execution (for less cache trashing in parallel 
exec) 
+               //2) Blocked execution (for less cache trashing in parallel 
exec)
                //3) Asymmetric block sizes (for less misses in inner loop, yet 
blocks in L1/L2)
-               
-               final int blocksizeI = 32; //64//256KB c block (typical L2 size 
per core), 32KB a block 
-               final int blocksizeK = 24; //64//256KB b block (typical L2 size 
per core), used while read 512B of a / read/write 4KB of c 
-               final int blocksizeJ = 1024; //512//4KB (typical main-memory 
page size), for scan 
+
+               final int blocksizeI = 32; //64//256KB c block (typical L2 size 
per core), 32KB a block
+               final int blocksizeK = 24; //64//256KB b block (typical L2 size 
per core), used while read 512B of a / read/write 4KB of c
+               final int blocksizeJ = 1024; //512//4KB (typical main-memory 
page size), for scan
 
                //temporary arrays (nnz a, b index)
                double[] ta = new double[ blocksizeK ];
                int[]  tbi  = new int[ blocksizeK ];
-               
+
                //blocked execution
                for( int bi = rl; bi < ru; bi+=blocksizeI )
                        for( int bk = 0, bimin = Math.min(ru, bi+blocksizeI); 
bk < cd; bk+=blocksizeK ) 
@@ -4135,9 +4165,10 @@ public class LibMatrixMult
                private final boolean _sparse; //sparse output
                private final int _rl;
                private final int _ru;
+               private final ConcurrentHashMap<double[], double[]> _cache;
 
                protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret,
-                       boolean tm2, boolean pm2r, boolean pm2c, boolean 
m1Perm, boolean sparse, int rl, int ru )
+                       boolean tm2, boolean pm2r, boolean pm2c, boolean 
m1Perm, boolean sparse, int rl, int ru, ConcurrentHashMap<double[], double[]> 
cache )
                {
                        _m1 = m1;
                        _m2 = m2;
@@ -4148,7 +4179,8 @@ public class LibMatrixMult
                        _sparse = sparse;
                        _rl = rl;
                        _ru = ru;
-                       
+                       _cache = cache;
+
                        if( pm2r ) { //vector-matrix / matrix-matrix
                                //allocate local result for partial aggregation
                                _ret = new MatrixBlock(ret.rlen, ret.clen, 
false);
@@ -4174,7 +4206,11 @@ public class LibMatrixMult
                        if( _ret.sparse ) //ultra-sparse
                                matrixMultUltraSparse(_m1, _m2, _ret, _m1Perm, 
rl, ru);
                        else if(!_m1.sparse && !_m2.sparse)
-                               matrixMultDenseDense(_m1, _m2, _ret, _tm2, 
_pm2r, rl, ru, cl, cu);
+                               if(_m1.denseBlock instanceof 
DenseBlockFP64DEDUP && _m2.denseBlock.isContiguous(0,_m1.clen) && cl == 0 && cu 
== _m2.clen)
+                                       
matrixMultDenseDenseMMDedup(_m1.denseBlock, _m2.denseBlock, _ret.denseBlock, 
_m2.clen, _m1.clen, rl, ru, _cache);
+                               else
+                                       matrixMultDenseDense(_m1, _m2, _ret, 
_tm2, _pm2r, rl, ru, cl, cu);
+
                        else if(_m1.sparse && _m2.sparse)
                                matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, 
_sparse, rl, ru);
                        else if(_m1.sparse)
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 01a5216b4b..10c2e16ae5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -30,6 +30,7 @@ import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
@@ -56,6 +57,7 @@ import 
org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.DenseBlockFP64;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockCOO;
@@ -172,7 +174,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        public MatrixBlock(int rl, int cl, boolean sp, long estnnz) {
                reset(rl, cl, sp, estnnz, 0);
        }
-       
+
+       /**
+        * Creates a block of the given dimensions and forwards a request for a
+        * deduplicated dense representation to {@code reset}.
+        * NOTE(review): the fill value passed on is 0, and resetDense only allocates
+        * for a non-zero fill value, so (presumably, with no dense block existing at
+        * construction) no DenseBlockFP64DEDUP is created here; callers still call
+        * allocateDenseBlock(..., true) to obtain one.
+        *
+        * @param rl     number of rows
+        * @param cl     number of columns
+        * @param sp     sparse representation hint
+        * @param estnnz estimated number of non-zeros
+        * @param dedup  request a deduplicated dense block when one is allocated
+        */
+       public MatrixBlock(int rl, int cl, boolean sp, long estnnz, boolean dedup) {
+               final double initVal = 0;
+               reset(rl, cl, sp, estnnz, initVal, dedup);
+       }
+
        public MatrixBlock(MatrixBlock that) {
                copy(that);
        }
@@ -298,7 +304,27 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                else
                        resetDense(val);
        }
-       
+
+       /**
+        * Resets this block to the given dimensions and fill value; a dense reset
+        * forwards the dedup request to {@code resetDense(double, boolean)}.
+        *
+        * @param rl     number of rows (must be non-negative)
+        * @param cl     number of columns (must be non-negative)
+        * @param sp     sparse hint, only honoured for a zero fill value
+        * @param estnnz estimated non-zeros, used for the per-row estimate of sparse blocks
+        * @param val    fill value
+        * @param dedup  request a deduplicated dense block on dense reset
+        */
+       public void reset(int rl, int cl, boolean sp, long estnnz, double val, boolean dedup) {
+               if( rl < 0 || cl < 0 )
+                       throw new RuntimeException("Invalid block dimensions: "+rl+" "+cl);
+
+               rlen = rl;
+               clen = cl;
+               //a non-zero fill value forces a fully populated dense block
+               final boolean zeroFill = (val == 0);
+               if( zeroFill ) {
+                       sparse = sp;
+                       nonZeros = 0;
+               }
+               else {
+                       sparse = false;
+                       nonZeros = (long)rl*cl;
+               }
+               //the per-row nnz estimate is only kept for sparse blocks
+               if( sparse && estnnz >= 0 )
+                       estimatedNNzsPerRow = (int)Math.ceil((double)estnnz/(double)rlen);
+               else
+                       estimatedNNzsPerRow = -1;
+
+               if( !sparse )
+                       resetDense(val, dedup);
+               else
+                       resetSparse();
+       }
+
        private void resetSparse() {
                if(sparseBlock == null)
                        return;
@@ -315,7 +341,18 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                        denseBlock.set(val);
                }
        }
-       
+
+       /**
+        * Resets the dense block to the current dimensions, filled with {@code val}.
+        * An existing dense block is reset in place; otherwise a new block is only
+        * allocated (honouring the dedup request) for a non-zero fill value.
+        * NOTE(review): when a dense block already exists the dedup flag is not
+        * consulted, so an existing non-dedup block stays non-dedup — confirm intended.
+        *
+        * @param val   fill value
+        * @param dedup request a DenseBlockFP64DEDUP if a new block is allocated
+        */
+       private void resetDense(double val, boolean dedup) {
+               if( denseBlock == null ) {
+                       //zero fill on an unallocated block: leave it unallocated
+                       if( val == 0 )
+                               return;
+                       allocateDenseBlock(false, dedup);
+                       denseBlock.set(val);
+                       return;
+               }
+               denseBlock.reset(rlen, clen, val);
+       }
+
        /**
         * NOTE: This method is designed only for dense representation.
         * 
@@ -401,6 +438,10 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                        denseBlock = DenseBlockFactory.createDenseBlock(rlen, 
clen, containsDuplicates);
                        return true;
                }
+               else if( containsDuplicates && !(denseBlock instanceof 
DenseBlockFP64DEDUP)) {
+                       denseBlock = DenseBlockFactory.createDenseBlock(rlen, 
clen, true);
+                       return true;
+               }
                else if( denseBlock.capacity() < limit ){
                        denseBlock.reset(rlen, clen);
                        return true;
@@ -2037,6 +2078,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                                        cleanupBlock(false, true); //reuse dense
                                        readDenseBlock(in); //always dense 
in-mem if dense on disk
                                        break;
+                               case DEDUP_BLOCK:
+                                       sparse = false;
+                                       cleanupBlock(false, true); //reuse dense
+                                       readDedupDenseBlock(in); //always dense 
in-mem if dense on disk
+                                       break;
                                case EMPTY_BLOCK:
                                        sparse = true;
                                        cleanupBlock(true, !(sparseBlock 
instanceof SparseBlockCSR));
@@ -2052,6 +2098,33 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                }
        }
 
+       /**
+        * Reads a block written by writeDedupDenseblock. The input holds rlen int
+        * indices (one per row, naming its distinct row), followed by the distinct
+        * rows themselves, clen doubles each, in index order. Rows with the same
+        * index share one array.
+        *
+        * @param in data input positioned after the block type byte
+        * @throws IOException if reading from the input fails
+        * @throws DMLRuntimeException if the distinct row indices are not exactly 0..k-1
+        */
+       private void readDedupDenseBlock(DataInput in) throws IOException, DMLRuntimeException {
+               allocateDenseBlock(true,true);
+               DenseBlock block = getDenseBlock();
+               //fix the dimensions if the (re-used) block does not match
+               if( block.getDim(0) != rlen || block.getDim(1) != clen )
+                       block.resetNoFill(rlen, clen);
+
+               //pass 1: row -> distinct row index, one shared array per distinct index
+               HashMap<Integer, double[]> distinct = new HashMap<>();
+               for( int r=0; r<rlen; r++ ) {
+                       final int idx = in.readInt();
+                       block.set(r, distinct.computeIfAbsent(idx, key -> new double[clen]));
+               }
+
+               //pass 2: fill the distinct rows in index order
+               final int numDistinct = distinct.size();
+               for( int k=0; k<numDistinct; k++ ) {
+                       double[] target = distinct.get(k);
+                       if( target == null )
+                               throw new DMLRuntimeException("serialized object is corrupt, did not find unique row number [" + k +"] in mappings");
+                       for( int c=0; c<clen; c++ )
+                               target[c] = in.readDouble();
+               }
+               nonZeros = block.countNonZeros();
+       }
+
        private void readDenseBlock(DataInput in) throws IOException, 
DMLRuntimeException {
                // allocate dense block resets the block if already allocated.
                allocateDenseBlock(true);
@@ -2186,7 +2259,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        }
        
        @Override
-       public void write(DataOutput out) 
+       public void write(DataOutput out)
                throws IOException 
        {
                //determine format
@@ -2218,12 +2291,55 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                                writeDenseToUltraSparse(out);
                        else if( sparseDst )
                                writeDenseToSparse(out);
+                       else if( denseBlock instanceof DenseBlockFP64DEDUP )
+                               writeDedupDenseblock(out);
                        else
                                writeDenseBlock(out);
                }
        }
 
-       private static void writeEmptyBlock(DataOutput out) 
+       /**
+        * Writes a DEDUP_BLOCK. The output is the block type byte, then for every row
+        * the index of its distinct row array (numbered by first occurrence), then the
+        * distinct rows themselves (clen doubles each) in index order. Rows count as
+        * duplicates only if they are the same array instance (identity-hashed keys).
+        *
+        * @param out data output
+        * @throws IOException if writing fails
+        */
+       private void writeDedupDenseblock(DataOutput out)
+                       throws IOException
+       {
+               out.writeByte( BlockType.DEDUP_BLOCK.ordinal() );
+
+               DenseBlockFP64DEDUP a = (DenseBlockFP64DEDUP) getDenseBlock();
+               if (rlen > a.numBlocks())
+                       throw new DMLRuntimeException("Serialize DedupDenseblock: block does not contain enough rows ["+a.numBlocks() +" < " + rlen + "]");
+
+               //presize for the expected distinct rows under HashMap's default 0.75 load
+               //factor (a plain *1.1 still rehashes); clamp so a negative count cannot throw
+               final int expectedDistinct = Math.max(0, (int) a.getNrDistinctRows());
+               HashMap<double[], Integer> mapping = new HashMap<>((int) Math.ceil(expectedDistinct / 0.75) + 1);
+               ArrayList<double[]> uniqueRows = new ArrayList<>(expectedDistinct);
+
+               //row -> index of its distinct array, assigned by first occurrence
+               for(int i=0; i<rlen; i++) {
+                       double[] avals = a.values(i); //equals 1 row
+                       Integer pos = mapping.get(avals);
+                       if (pos == null) {
+                               pos = uniqueRows.size();
+                               uniqueRows.add(avals);
+                               mapping.put(avals, pos);
+                       }
+                       out.writeInt(pos);
+               }
+
+               MatrixBlockDataOutput mout = (out instanceof MatrixBlockDataOutput) ?
+                       (MatrixBlockDataOutput) out : null;
+               for (double[] row : uniqueRows) {
+                       //NOTE(review): if values(i) can return null, write that row as zeros;
+                       //this matches the zero-initialized rows readDedupDenseBlock allocates
+                       double[] vals = (row != null) ? row : new double[clen];
+                       if( mout != null ) //fast serialize
+                               mout.writeDoubleArray(clen, vals);
+                       else //general case (if fast serialize not supported)
+                               for (int j = 0; j < clen; j++)
+                                       out.writeDouble(vals[j]);
+               }
+       }
+
+       private static void writeEmptyBlock(DataOutput out)
                throws IOException
        {
                //empty blocks do not need to materialize row information
@@ -2586,6 +2702,10 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        }
        
        public long estimateSizeInMemory() {
+               //dedup dense blocks use less memory than the dimension-based estimate assumes
+               if (denseBlock instanceof DenseBlockFP64DEDUP) {
+                       final DenseBlockFP64DEDUP dedupBlock = (DenseBlockFP64DEDUP) denseBlock;
+                       final double bytes = getHeaderSize() + dedupBlock.estimateMemory();
+                       return (long) Math.min(bytes, Long.MAX_VALUE);
+               }
                return estimateSizeInMemory(rlen, clen, getSparsity());
        }
 
@@ -2781,6 +2901,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                //in-memory size given by header if not allocated
                if( !isAllocated() ) 
                        return getHeaderSize();
+               //dedup dense block uses less in-memory than other dense blocks
+               if (denseBlock instanceof DenseBlockFP64DEDUP) {
+                       double size = getHeaderSize() + ((DenseBlockFP64DEDUP) 
denseBlock).estimateMemory();
+                       return (long) Math.min(size, Long.MAX_VALUE);
+               }
                //in-memory size of dense/sparse representation
                return !sparse ? estimateSizeDenseInMemory(rlen, clen) :
                        estimateSizeSparseInMemory(rlen, clen, getSparsity(),
@@ -5155,7 +5280,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                        ret.reset(rlen, clen, replacement);
                        return ret;
                }
-               
+
                boolean NaNpattern = Double.isNaN(pattern);
                if( sparse ) //SPARSE
                {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java 
b/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
index c3253388ae..be1e6a1ac2 100644
--- a/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
@@ -19,6 +19,10 @@
 
 package org.apache.sysds.test.functions.io.binary;
 
+import com.google.crypto.tink.subtle.Random;
+import org.apache.sysds.runtime.controlprogram.caching.ByteBuffer;
+import org.apache.sysds.runtime.util.FastBufferedDataOutputStream;
+import org.apache.sysds.runtime.util.LocalFileUtils;
 import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.common.Types.FileFormat;
@@ -31,7 +35,11 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
-public class SerializeTest extends AutomatedTestBase 
+import java.io.FileOutputStream;
+import java.util.HashMap;
+import java.util.HashSet;
+
+public class SerializeTest extends AutomatedTestBase
 {
        private final static String TEST_NAME = "SerializeTest";
        private final static String TEST_DIR = "functions/io/binary/";
@@ -61,7 +69,13 @@ public class SerializeTest extends AutomatedTestBase
        { 
                runSerializeTest( rows1, cols1, 1.0 ); 
        }
-       
+
+       /** Round-trips a dense block with shared (duplicated) rows through local serialization. */
+       @Test
+       public void testDedupDenseBlock() {
+               runSerializeDedupDenseTest(rows1, cols1);
+       }
+
        @Test
        public void testDenseSparseBlock() 
        { 
@@ -123,4 +137,58 @@ public class SerializeTest extends AutomatedTestBase
                        throw new RuntimeException(ex);
                }
        }
+
+       /**
+        * Round-trip check for dedup dense blocks. It builds a (rows*10) x cols dedup
+        * block whose rows reference randomly chosen rows of a rows x cols random
+        * matrix. The block is written to a temporary local file and read back.
+        * The test then checks that all values match and that the read block has the
+        * same row-sharing pattern, i.e. identical first-occurrence numbering of
+        * row arrays.
+        */
+       private void runSerializeDedupDenseTest( int rows, int cols )
+       {
+               //use a temp file instead of writing artifacts into the source tree (SCRIPT_DIR)
+               java.io.File tmp = null;
+               try
+               {
+                       //generate actual dataset
+                       double[][] X = getRandomMatrix(rows, cols, -1.0, 1.0, 1.0, 7);
+                       final int nrow = rows*10;
+                       MatrixBlock mb = new MatrixBlock(nrow, cols, false, 0, true);
+                       mb.allocateDenseBlock(true, true);
+                       //seeded like the random matrix above, so failures are reproducible
+                       java.util.Random rand = new java.util.Random(7);
+                       //distinct row arrays, numbered by first occurrence
+                       HashMap<double[], Integer> seen = new HashMap<>();
+                       for (int i = 0; i < nrow; i++) {
+                               double[] src = X[rand.nextInt(rows)];
+                               seen.putIfAbsent(src, seen.size());
+                               mb.quickSetRow(i, src);
+                       }
+
+                       tmp = java.io.File.createTempFile("dedupSerializedBlock", ".out");
+                       String fname = tmp.getAbsolutePath();
+                       LocalFileUtils.writeCacheBlockToLocal(fname, mb);
+                       MatrixBlock mb2 = (MatrixBlock) LocalFileUtils.readCacheBlockFromLocal(fname, true);
+
+                       Assert.assertEquals(mb.getNumRows(), mb2.getNumRows());
+                       Assert.assertEquals(mb.getNumColumns(), mb2.getNumColumns());
+                       Assert.assertTrue(mb2.getDenseBlock() instanceof org.apache.sysds.runtime.data.DenseBlockFP64DEDUP);
+
+                       //compare matrices - values
+                       for( int i=0; i<mb.getNumRows(); i++ )
+                               for( int j=0; j<mb.getNumColumns(); j++ )
+                                       Assert.assertEquals(mb.quickGetValue(i, j), mb2.quickGetValue(i, j), eps);
+
+                       //compare matrices - row-sharing structure (first-occurrence numbering)
+                       HashMap<double[], Integer> seen2 = new HashMap<>();
+                       for( int i=0; i<mb.getNumRows(); i++ ) {
+                               double[] row = mb2.getDenseBlock().values(i);
+                               seen2.putIfAbsent(row, seen2.size());
+                               Integer expected = seen.get(mb.getDenseBlock().values(i));
+                               Assert.assertNotNull(expected);
+                               Assert.assertEquals((long) expected, (long) seen2.get(row));
+                       }
+               }
+               catch(Exception ex)
+               {
+                       ex.printStackTrace();
+                       throw new RuntimeException(ex);
+               }
+               finally
+               {
+                       if( tmp != null && !tmp.delete() )
+                               tmp.deleteOnExit();
+               }
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
index 4787d35bcf..6fb9f511ea 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -41,259 +41,260 @@ import java.util.Random;
 
 public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
 {
-       private final static String TEST_NAME1 = 
"TransformFrameEncodeWordEmbeddings2";
-       private final static String TEST_NAME2a = 
"TransformFrameEncodeWordEmbeddings2MultiCols1";
-       private final static String TEST_NAME2b = 
"TransformFrameEncodeWordEmbeddings2MultiCols2";
-
-       private final static String TEST_DIR = "functions/transform/";
-
-       @Override
-       public void setUp() {
-               TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_DIR, TEST_NAME1));
-               addTestConfiguration(TEST_NAME2a, new 
TestConfiguration(TEST_DIR, TEST_NAME2a));
-               addTestConfiguration(TEST_NAME2b, new 
TestConfiguration(TEST_DIR, TEST_NAME2b));
-       }
-
-       @Test
-       public void testTransformToWordEmbeddings() {
-               runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       @Ignore
-       public void testNonRandomTransformToWordEmbeddings2Cols() {
-               runTransformTest(TEST_NAME2a, ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       @Ignore
-       public void testRandomTransformToWordEmbeddings4Cols() {
-               runTransformTestMultiCols(TEST_NAME2b, ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       @Ignore
-       public void runBenchmark(){
-               runBenchmark(TEST_NAME1, ExecMode.SINGLE_NODE);
-       }
-
-
-       private void runBenchmark(String testname, ExecMode rt)
-       {
-               //set runtime platform
-               ExecMode rtold = setExecMode(rt);
-               try
-               {
-                       int rows = 100;
-                       //int cols = 300;
-                       getAndLoadTestConfiguration(testname);
-                       fullDMLScriptName = getScript();
-
-                       // Generate random embeddings for the distinct tokens
-                       // double[][] a = createRandomMatrix("embeddings", 
rows, cols, 0, 10, 1, new Date().getTime());
-
-                       // Generate random distinct tokens
-                       List<String> strings = generateRandomStrings(rows, 10);
-
-                       // Generate the dictionary by assigning unique ID to 
each distinct token
-                       writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + 
"dict");
-
-                       // Create the dataset by repeating and shuffling the 
distinct tokens
-                       List<String> stringsColumn = 
shuffleAndMultiplyStrings(strings, 320);
-                       writeStringsToCsvFile(stringsColumn, baseDirectory + 
INPUT_DIR + "data");
-
-                       //run script
-                       programArgs = new String[]{"-stats","-args", 
input("embeddings"), input("data"), input("dict"), output("result")};
-                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-               }
-               catch(Exception ex) {
-                       throw new RuntimeException(ex);
-
-               }
-               finally {
-                       resetExecMode(rtold);
-               }
-       }
-
-       private void runTransformTest(String testname, ExecMode rt)
-       {
-               //set runtime platform
-               ExecMode rtold = setExecMode(rt);
-               try
-               {
-                       int rows = 100;
-                       int cols = 300;
-                       getAndLoadTestConfiguration(testname);
-                       fullDMLScriptName = getScript();
-
-                       // Generate random embeddings for the distinct tokens
-                       double[][] a = createRandomMatrix("embeddings", rows, 
cols, 0, 10, 1, new Date().getTime());
-
-                       // Generate random distinct tokens
-                       List<String> strings = generateRandomStrings(rows, 10);
-
-                       // Generate the dictionary by assigning unique ID to 
each distinct token
-                       Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
-
-                       // Create the dataset by repeating and shuffling the 
distinct tokens
-                       List<String> stringsColumn = 
shuffleAndMultiplyStrings(strings, 32);
-                       writeStringsToCsvFile(stringsColumn, baseDirectory + 
INPUT_DIR + "data");
-
-                       //run script
-                       programArgs = new String[]{"-stats","-args", 
input("embeddings"), input("data"), input("dict"), output("result")};
-                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-
-                       // Manually derive the expected result
-                       double[][] res_expected = 
manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
-
-                       // Compare results
-                       HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
-                       double[][] resultActualDouble = 
TestUtils.convertHashMapToDoubleArray(res_actual);
-                       TestUtils.compareMatrices(resultActualDouble, 
res_expected, 1e-6);
-               }
-               catch(Exception ex) {
-                       throw new RuntimeException(ex);
-
-               }
-               finally {
-                       resetExecMode(rtold);
-               }
-       }
-
-       @SuppressWarnings("unused")
-       private void print2DimDoubleArray(double[][] resultActualDouble) {
-               Arrays.stream(resultActualDouble).forEach(
-                               e -> 
System.out.println(Arrays.stream(e).mapToObj(d -> String.format("%06.1f", d))
-                                               .reduce("", (sub, elem) -> sub 
+ " " + elem)));
-       }
-
-       private void runTransformTestMultiCols(String testname, ExecMode rt)
-       {
-               //set runtime platform
-               ExecMode rtold = setExecMode(rt);
-               try
-               {
-                       int rows = 100;
-                       int cols = 100;
-                       getAndLoadTestConfiguration(testname);
-                       fullDMLScriptName = getScript();
-
-                       // Generate random embeddings for the distinct tokens
-                       double[][] a = createRandomMatrix("embeddings", rows, 
cols, 0, 10, 1, new Date().getTime());
-
-                       // Generate random distinct tokens
-                       List<String> strings = generateRandomStrings(rows, 10);
-
-                       // Generate the dictionary by assigning unique ID to 
each distinct token
-                       Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
-
-                       // Create the dataset by repeating and shuffling the 
distinct tokens
-                       List<String> stringsColumn = 
shuffleAndMultiplyStrings(strings, 10);
-                       writeStringsToCsvFile(stringsColumn, baseDirectory + 
INPUT_DIR + "data");
-
-                       //run script
-                       programArgs = new String[]{"-stats","-args", 
input("embeddings"), input("data"), input("dict"), output("result"), 
output("result2")};
-                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-
-                       // Manually derive the expected result
-                       double[][] res_expected = 
manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
-
-                       // Compare results
-                       HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
-                       HashMap<MatrixValue.CellIndex, Double> res_actual2 = 
readDMLMatrixFromOutputDir("result2");
-                       double[][] resultActualDouble  = 
TestUtils.convertHashMapToDoubleArray(res_actual);
-                       double[][] resultActualDouble2 = 
TestUtils.convertHashMapToDoubleArray(res_actual2);
-                       //System.out.println("Actual Result1 [" + 
resultActualDouble.length + "x" + resultActualDouble[0].length + "]:");
-                       ///print2DimDoubleArray(resultActualDouble);
-                       //System.out.println("\nActual Result2 [" + 
resultActualDouble.length + "x" + resultActualDouble[0].length + "]:");
-                       //print2DimDoubleArray(resultActualDouble2);
-                       //System.out.println("\nExpected Result [" + 
res_expected.length + "x" + res_expected[0].length + "]:");
-                       //print2DimDoubleArray(res_expected);
-                       TestUtils.compareMatrices(resultActualDouble, 
res_expected, 1e-6);
-                       TestUtils.compareMatrices(resultActualDouble, 
resultActualDouble2, 1e-6);
-               }
-               catch(Exception ex) {
-                       throw new RuntimeException(ex);
-
-               }
-               finally {
-                       resetExecMode(rtold);
-               }
-       }
-
-       private double[][] manuallyDeriveWordEmbeddings(int cols, double[][] a, 
Map<String, Integer> map, List<String> stringsColumn) {
-               // Manually derive the expected result
-               double[][] res_expected = new 
double[stringsColumn.size()][cols];
-               for (int i = 0; i < stringsColumn.size(); i++) {
-                       int rowMapped = map.get(stringsColumn.get(i));
-                       System.arraycopy(a[rowMapped], 0, res_expected[i], 0, 
cols);
-               }
-               return res_expected;
-       }
-
-       @SuppressWarnings("unused")
-       private double[][] generateWordEmbeddings(int rows, int cols) {
-               double[][] a = new double[rows][cols];
-               for (int i = 0; i < a.length; i++) {
-                       for (int j = 0; j < a[i].length; j++) {
-                               a[i][j] = cols *i + j;
-                       }
-
-               }
-               return a;
-       }
-
-       public static List<String> shuffleAndMultiplyStrings(List<String> 
strings, int multiply){
-               List<String> out = new ArrayList<>();
-               Random random = new Random();
-               for (int i = 0; i < strings.size()*multiply; i++) {
-                       out.add(strings.get(random.nextInt(strings.size())));
-               }
-               return out;
-       }
-
-       public static List<String> generateRandomStrings(int numStrings, int 
stringLength) {
-               List<String> randomStrings = new ArrayList<>();
-               Random random = new Random();
-               String characters = 
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
-               for (int i = 0; i < numStrings; i++) {
-                       randomStrings.add(generateRandomString(random, 
stringLength, characters));
-               }
-               return randomStrings;
-       }
-
-       public static String generateRandomString(Random random, int 
stringLength, String characters){
-               StringBuilder randomString = new StringBuilder();
-               for (int j = 0; j < stringLength; j++) {
-                       int randomIndex = random.nextInt(characters.length());
-                       randomString.append(characters.charAt(randomIndex));
-               }
-               return randomString.toString();
-       }
-
-       public static void writeStringsToCsvFile(List<String> strings, String 
fileName) {
-               try (BufferedWriter bw = new BufferedWriter(new 
FileWriter(fileName))) {
-                       for (String line : strings) {
-                               bw.write(line);
-                               bw.newLine();
-                       }
-               } catch (IOException e) {
-                       e.printStackTrace();
-               }
-       }
-
-       public static Map<String,Integer> writeDictToCsvFile(List<String> 
strings, String fileName) {
-               try (BufferedWriter bw = new BufferedWriter(new 
FileWriter(fileName))) {
-                       Map<String,Integer> map = new HashMap<>();
-                       for (int i = 0; i < strings.size(); i++) {
-                               map.put(strings.get(i), i);
-                               bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + 
(i+1) + "\n");
-                       }
-                       return map;
-               } catch (IOException e) {
-                       e.printStackTrace();
-                       return null;
-               }
-       }
+    private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings2";
+    private final static String TEST_NAME2a = "TransformFrameEncodeWordEmbeddings2MultiCols1";
+    private final static String TEST_NAME2b = "TransformFrameEncodeWordEmbeddings2MultiCols2";
+
+    private final static String TEST_DIR = "functions/transform/";
+    //per-class directory: must name this class, not a sibling test, to avoid sharing outputs
+    private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbedding2Test.class.getSimpleName() + "/";
+
+    /** Registers one test configuration per DML script under the transform test dir. */
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        for (String name : new String[]{TEST_NAME1, TEST_NAME2a, TEST_NAME2b})
+            addTestConfiguration(name, new TestConfiguration(TEST_DIR, name));
+    }
+
+    /** Runs TEST_NAME1 via runTransformTest in single-node mode. */
+    @Test
+    public void testTransformToWordEmbeddings() {
+        runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE);
+    }
+
+    /** Runs TEST_NAME2a via runTransformTest in single-node mode (currently ignored). */
+    @Test
+    @Ignore
+    public void testNonRandomTransformToWordEmbeddings2Cols() {
+        runTransformTest(TEST_NAME2a, ExecMode.SINGLE_NODE);
+    }
+
+    /** Runs TEST_NAME2b via runTransformTestMultiCols in single-node mode (currently ignored). */
+    @Test
+    @Ignore
+    public void testRandomTransformToWordEmbeddings4Cols() {
+        runTransformTestMultiCols(TEST_NAME2b, ExecMode.SINGLE_NODE);
+    }
+
+    /** Benchmark entry point for TEST_NAME1 (ignored by default). */
+    @Test
+    @Ignore
+    public void runBenchmark(){
+        runBenchmark(TEST_NAME1, ExecMode.SINGLE_NODE);
+    }
+
+    /**
+     * Benchmark driver that prepares inputs and runs the script with -stats.
+     * It writes random embeddings (input "embeddings"), a dictionary of random
+     * tokens (input "dict") and a token column sampled from the dictionary
+     * (multiplier 320, input "data"). No result is compared.
+     */
+    private void runBenchmark(String testname, ExecMode rt)
+    {
+        //set runtime platform
+        ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens (only the written input is needed)
+            createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct token
+            writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeating and shuffling the distinct tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 320);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    private void runTransformTest(String testname, ExecMode rt)
+    {
+        //set runtime platform
+        ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 
1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct 
token
+            Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeating and shuffling the distinct 
tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
320);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Manually derive the expected result
+            double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, 
map, stringsColumn);
+
+            // Compare results
+            HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
+            double[][] resultActualDouble = 
TestUtils.convertHashMapToDoubleArray(res_actual);
+            TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    /**
+     * Debug helper: prints each row on its own line, every value formatted with {@code %06.1f}
+     * and preceded by a single space.
+     */
+    public static void print2DimDoubleArray(double[][] resultActualDouble) {
+        for (double[] row : resultActualDouble) {
+            // StringBuilder instead of reduce-with-concatenation, which is quadratic in the row width
+            StringBuilder line = new StringBuilder();
+            for (double d : row)
+                line.append(' ').append(String.format("%06.1f", d));
+            System.out.println(line);
+        }
+    }
+
+    /**
+     * Runs a multi-column embedding script that writes "result" and "result2", then checks that
+     * "result" matches the Java reference and that both outputs agree.
+     *
+     * @param testname name of the registered test configuration / DML script
+     * @param rt       execution mode to use for the duration of the test
+     */
+    private void runTransformTestMultiCols(String testname, ExecMode rt)
+    {
+        //set runtime platform
+        ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 100;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct token
+            Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeatedly sampling the distinct tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 10);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result"), output("result2")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Manually derive the expected result
+            double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
+
+            // Compare results (no unconditional dump of the full result matrix to stdout)
+            HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
+            HashMap<MatrixValue.CellIndex, Double> res_actual2 = readDMLMatrixFromOutputDir("result2");
+            double[][] resultActualDouble  = TestUtils.convertHashMapToDoubleArray(res_actual);
+            double[][] resultActualDouble2 = TestUtils.convertHashMapToDoubleArray(res_actual2);
+            TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
+            TestUtils.compareMatrices(resultActualDouble, resultActualDouble2, 1e-6);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    /**
+     * Builds the expected encoder output: row {@code i} is a copy of the first {@code cols} values
+     * of embedding row {@code map.get(stringsColumn.get(i))}.
+     *
+     * @param cols          number of embedding columns to copy per row
+     * @param a             embedding matrix, one row per distinct token
+     * @param map           token to embedding-row index
+     * @param stringsColumn the token sequence to encode
+     * @return a {@code stringsColumn.size() x cols} matrix
+     */
+    public static double[][] manuallyDeriveWordEmbeddings(int cols, double[][] a, Map<String, Integer> map, List<String> stringsColumn) {
+        double[][] expected = new double[stringsColumn.size()][cols];
+        int row = 0;
+        for (String token : stringsColumn) {
+            double[] embedding = a[map.get(token)];
+            System.arraycopy(embedding, 0, expected[row++], 0, cols);
+        }
+        return expected;
+    }
+
+    /** Deterministic {@code rows x cols} embeddings where entry (r, c) holds {@code cols * r + c}. */
+    private double[][] generateWordEmbeddings(int rows, int cols) {
+        double[][] emb = new double[rows][cols];
+        for (int r = 0; r < rows; r++)
+            for (int c = 0; c < cols; c++)
+                emb[r][c] = cols * r + c;
+        return emb;
+    }
+
+    public static List<String> shuffleAndMultiplyStrings(List<String> strings, 
int multiply){
+        List<String> out = new ArrayList<>();
+        Random random = new Random();
+        for (int i = 0; i < strings.size()*multiply; i++) {
+            out.add(strings.get(random.nextInt(strings.size())));
+        }
+        return out;
+    }
+
+    /**
+     * Generates {@code numStrings} pairwise distinct random strings of length {@code stringLength}
+     * over {@code [A-Za-z0-9]}, in generation order.
+     *
+     * Distinctness matters: the strings are used as dictionary tokens, and a duplicate would
+     * silently overwrite an entry of the token-to-row map built by writeDictToCsvFile.
+     *
+     * @throws IllegalArgumentException if fewer than {@code numStrings} distinct strings of the
+     *         requested length exist (which would otherwise loop forever)
+     */
+    public static List<String> generateRandomStrings(int numStrings, int stringLength) {
+        String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
+        if (Math.pow(characters.length(), Math.max(stringLength, 0)) < numStrings)
+            throw new IllegalArgumentException("Cannot generate " + numStrings
+                + " distinct strings of length " + stringLength);
+        java.util.Set<String> distinct = new java.util.LinkedHashSet<>();
+        Random random = new Random();
+        while (distinct.size() < numStrings)
+            distinct.add(generateRandomString(random, stringLength, characters));
+        return new ArrayList<>(distinct);
+    }
+
+    /**
+     * Returns a string of {@code stringLength} characters, each picked uniformly from
+     * {@code characters} using {@code random}; non-positive lengths yield {@code ""}.
+     */
+    public static String generateRandomString(Random random, int stringLength, String characters){
+        char[] buf = new char[Math.max(0, stringLength)];
+        for (int pos = 0; pos < buf.length; pos++)
+            buf[pos] = characters.charAt(random.nextInt(characters.length()));
+        return new String(buf);
+    }
+
+    /**
+     * Writes one string per line to {@code fileName}, overwriting any existing file.
+     *
+     * @throws java.io.UncheckedIOException if the file cannot be written, so that a test fails at
+     *         the I/O error instead of continuing against a missing or partial input file
+     */
+    public static void writeStringsToCsvFile(List<String> strings, String fileName) {
+        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
+            for (String line : strings) {
+                bw.write(line);
+                bw.newLine();
+            }
+        } catch (IOException e) {
+            throw new java.io.UncheckedIOException("Failed to write " + fileName, e);
+        }
+    }
+
+    /**
+     * Writes the recode dictionary for {@code strings} to {@code fileName} and returns the
+     * matching token-to-row map. Token {@code i} is written with id {@code i+1} (joined by
+     * {@code Lop.DATATYPE_PREFIX}), while the returned map holds the 0-based row {@code i} used to
+     * index the embedding matrix.
+     *
+     * @throws java.io.UncheckedIOException if the file cannot be written (rather than returning
+     *         {@code null}, which only surfaced later as a NullPointerException)
+     */
+    public static Map<String,Integer> writeDictToCsvFile(List<String> strings, String fileName) {
+        Map<String,Integer> map = new HashMap<>();
+        try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) {
+            for (int i = 0; i < strings.size(); i++) {
+                map.put(strings.get(i), i);
+                bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + (i+1) + "\n");
+            }
+        } catch (IOException e) {
+            throw new java.io.UncheckedIOException("Failed to write dictionary " + fileName, e);
+        }
+        return map;
+    }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java
new file mode 100644
index 0000000000..3862294ca6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static 
org.apache.sysds.test.functions.transform.TransformFrameEncodeWordEmbedding2Test.*;
+
+public class TransformFrameEncodeWordEmbeddingMMTest extends AutomatedTestBase {
+    private static final String TEST_NAME1 = "TransformFrameEncodeWordEmbeddingsMM";
+    private static final String TEST_DIR = "functions/transform/";
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1));
+    }
+
+    @Test
+    public void testMultiplication() {
+        runMatrixMultiplicationTest(TEST_NAME1, Types.ExecMode.SINGLE_NODE);
+    }
+
+    /**
+     * Encodes a generated token column into word embeddings, multiplies the encoded matrix by a
+     * random dense matrix in DML, and compares against a Java reference product.
+     *
+     * @param testname name of the registered test configuration / DML script
+     * @param rt       execution mode to use for the duration of the test
+     */
+    private void runMatrixMultiplicationTest(String testname, Types.ExecMode rt)
+    {
+        //set runtime platform
+        Types.ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens and the multiplication factor;
+            // distinct seeds so both matrices are not generated from an identical seed
+            long seed = new Date().getTime();
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, seed);
+            double[][] b = createRandomMatrix("factor", cols, cols, 0, 10, 1, seed + 1);
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct token
+            Map<String,Integer> map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeatedly sampling the distinct tokens
+            int factor = 320;
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, factor);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), input("factor"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Every encoded row is a copy of an embedding row, so each distinct row's product with b
+            // is computed once and then expanded like the embeddings. The per-cell summation order
+            // matches multiplying the full encoded matrix, at roughly 1/factor of the cost.
+            double[][] embeddingsMM = multiply(a, b);
+            double[][] res_expectedMM = manuallyDeriveWordEmbeddings(cols, embeddingsMM, map, stringsColumn);
+
+            // Compare results
+            HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
+            double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual);
+            TestUtils.compareMatrices(res_expectedMM, resultActualDouble, 1e-8);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    /** Plain dense product {@code x %*% y}, accumulating each cell in increasing {@code k} order. */
+    private static double[][] multiply(double[][] x, double[][] y) {
+        int n = x.length;
+        int inner = y.length;
+        int m = y[0].length;
+        double[][] out = new double[n][m];
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < m; j++) {
+                double sum = 0.0;
+                for (int k = 0; k < inner; k++)
+                    sum += x[i][k] * y[k][j];
+                out[i][j] = sum;
+            }
+        }
+        return out;
+    }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java
new file mode 100644
index 0000000000..b3a09f3ec8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
+import org.apache.sysds.runtime.instructions.cp.KahanObject;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static 
org.apache.sysds.runtime.functionobjects.KahanPlus.getKahanPlusFnObject;
+import static 
org.apache.sysds.test.functions.transform.TransformFrameEncodeWordEmbedding2Test.*;
+
+public class TransformFrameEncodeWordEmbeddingRowSumTest extends 
AutomatedTestBase {
+    private final static String TEST_NAME1 = 
"TransformFrameEncodeWordEmbeddingsRowSum";
+    private final static String TEST_NAME2 = 
"TransformFrameEncodeWordEmbeddingsColSum";
+    private final static String TEST_NAME3 = 
"TransformFrameEncodeWordEmbeddingsFullSum";
+    private final static String TEST_DIR = "functions/transform/";
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, 
TEST_NAME1));
+        addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, 
TEST_NAME2));
+        addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, 
TEST_NAME3));
+    }
+
+    @Test
+    public void testDedupRowSums() {
+        runDedupRowSumTest(TEST_NAME1, Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    public void testDedupColSums() {
+        runDedupColSumTest(TEST_NAME2, Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    public void testDedupFullSums() {
+        runDedupFullSumTest(TEST_NAME3, Types.ExecMode.SINGLE_NODE);
+    }
+
+    private void runDedupFullSumTest(String testname, Types.ExecMode rt)
+    {
+        //set runtime platform
+        Types.ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 
1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct 
token
+            Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeating and shuffling the distinct 
tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
320*6);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Manually derive the expected result
+            double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, 
map, stringsColumn);
+            double[][] sums_expected = new double[1][1];
+            KahanObject ko = new KahanObject(0,0);
+            KahanPlus kp = getKahanPlusFnObject();
+            for (int i = 0; i < res_expected.length; i++) {
+                for (int j = 0; j < res_expected[i].length; j++) {
+                    kp.execute2(ko,  res_expected[i][j]);
+                }
+            }
+            sums_expected[0][0] = ko._sum;
+            // Compare results
+            HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
+            double[][] resultActualDouble = 
TestUtils.convertHashMapToDoubleArray(res_actual);
+            //print2DimDoubleArray(resultActualDouble);
+            TestUtils.compareMatrices(sums_expected, resultActualDouble, 
1e-14);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    private void runDedupColSumTest(String testname, Types.ExecMode rt)
+    {
+        //set runtime platform
+        Types.ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 
1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct 
token
+            Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeating and shuffling the distinct 
tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
320*6);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Manually derive the expected result
+            double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, 
map, stringsColumn);
+            double[][] sums_expected = new double[1][res_expected[0].length];
+            KahanObject ko = new KahanObject(0,0);
+            KahanPlus kp = getKahanPlusFnObject();
+            for (int i = 0; i < res_expected[0].length; i++) {
+                ko.set(0,0);
+                for (int j = 0; j < res_expected.length; j++) {
+                    kp.execute2(ko,  res_expected[j][i]);
+                }
+                sums_expected[0][i] = ko._sum;
+            }
+            // Compare results
+            HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
+            double[][] resultActualDouble = 
TestUtils.convertHashMapToDoubleArray(res_actual);
+            //print2DimDoubleArray(resultActualDouble);
+            TestUtils.compareMatrices(sums_expected, resultActualDouble, 1e-9);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    private void runDedupRowSumTest(String testname, Types.ExecMode rt)
+    {
+        //set runtime platform
+        Types.ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 
1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct 
token
+            Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeating and shuffling the distinct 
tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
320);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Manually derive the expected result
+            double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, 
map, stringsColumn);
+            double[][] sums_expected = new double[res_expected.length][1];
+            KahanObject ko = new KahanObject(0,0);
+            KahanPlus kp = getKahanPlusFnObject();
+            for (int i = 0; i < res_expected.length; i++) {
+                ko.set(0,0);
+                for (int j = 0; j < res_expected[i].length; j++) {
+                    kp.execute2(ko,  res_expected[i][j]);
+                }
+                sums_expected[i][0] = ko._sum;
+            }
+            // Compare results
+            HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
+            double[][] resultActualDouble = 
TestUtils.convertHashMapToDoubleArray(res_actual);
+            //print2DimDoubleArray(resultActualDouble);
+            TestUtils.compareMatrices(sums_expected, resultActualDouble, 
1e-15);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+}
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsColSum.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsColSum.dml
new file mode 100644
index 0000000000..59e7963183
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsColSum.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (100 distinct tokens x 300 dims)
+E = read($1, rows=100, cols=300, format="text");
+# Read the token sequence over the distinct tokens (its length is chosen by the calling test)
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+jspec = "{ids: true, word_embedding: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+MatrixColSum = colSums(Data_enc);
+write(MatrixColSum, $4, format="text");
\ No newline at end of file
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsFullSum.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsFullSum.dml
new file mode 100644
index 0000000000..5b5fb81303
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsFullSum.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (100 distinct tokens x 300 dims)
+E = read($1, rows=100, cols=300, format="text");
+# Read the token sequence over the distinct tokens (its length is chosen by the calling test)
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+jspec = "{ids: true, word_embedding: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+# wrap the scalar sum in a 1x1 matrix so it can be written like the other results
+FullSum = matrix(sum(Data_enc), 1, 1);
+write(FullSum, $4, format="text");
\ No newline at end of file
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
new file mode 100644
index 0000000000..c439ef50d7
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (100 distinct tokens x 300 dims)
+E = read($1, rows=100, cols=300, format="text");
+# Read the token sequence over the distinct tokens (its length is chosen by the calling test)
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+# Read the matrix that is used for multiplication after transform
+MM = read($4, rows=300, cols=300, format="text");
+
+jspec = "{ids: true, word_embedding: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+Product = Data_enc %*% MM;
+write(Product, $5, format="text");
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsRowSum.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsRowSum.dml
new file mode 100644
index 0000000000..2ce047ba39
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsRowSum.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (100 distinct tokens x 300 dims)
+E = read($1, rows=100, cols=300, format="text");
+# Read the token sequence over the distinct tokens (its length is chosen by the calling test)
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+jspec = "{ids: true, word_embedding: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+MatrixRowSum = rowSums(Data_enc);
+write(MatrixRowSum, $4, format="text");
\ No newline at end of file

Reply via email to