This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c4bf88  [SYSTEMDS-2617] New builtin for obtaining frame column names
2c4bf88 is described below

commit 2c4bf8816a4b87d0a91dd02ff390bae87cef4b3e
Author: Kevin Innerebner <[email protected]>
AuthorDate: Thu Aug 13 18:58:31 2020 +0200

    [SYSTEMDS-2617] New builtin for obtaining frame column names
    
    New builtin colnames(X) for obtaining a single-row frame holding the
    column names by position.
    
    Closes #1020.
---
 .../java/org/apache/sysds/common/Builtins.java     |   3 +-
 src/main/java/org/apache/sysds/common/Types.java   |   3 +-
 src/main/java/org/apache/sysds/hops/UnaryOp.java   |   3 +-
 .../sysds/parser/BuiltinFunctionExpression.java    |   1 +
 .../org/apache/sysds/parser/DMLTranslator.java     |   5 +-
 .../runtime/instructions/CPInstructionParser.java  |   1 +
 .../runtime/instructions/SPInstructionParser.java  |  69 ++++++-------
 .../instructions/cp/UnaryFrameCPInstruction.java   |  14 ++-
 .../spark/UnaryFrameSPInstruction.java             |  33 ++++--
 .../sysds/runtime/matrix/data/FrameBlock.java      |   9 +-
 .../org/apache/sysds/test/AutomatedTestBase.java   |   9 --
 .../test/functions/frame/FrameColumnNamesTest.java | 113 +++++++++++++++++++++
 .../frame/{TypeOf.dml => ColumnNames.dml}          |   8 +-
 src/test/scripts/functions/frame/TypeOf.dml        |   2 +-
 14 files changed, 209 insertions(+), 64 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 22134ea..cc5b12b 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -64,6 +64,7 @@ public enum Builtins {
        COLMAX("colMaxs", false),
        COLMEAN("colMeans", false),
        COLMIN("colMins", false),
+       COLNAMES("colnames", false),
        COLPROD("colProds", false),
        COLSD("colSds", false),
        COLSUM("colSums", false),
@@ -182,7 +183,7 @@ public enum Builtins {
        TANH("tanh", false),
        TRACE("trace", false),
        TO_ONE_HOT("toOneHot", true),
-       TYPEOF("typeOf", false),
+       TYPEOF("typeof", false),
        COUNT_DISTINCT("countDistinct",false),
        COUNT_DISTINCT_APPROX("countDistinctApprox",false),
        VAR("var", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 92027a5..978c644 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -195,7 +195,7 @@ public class Types
                ABS, ACOS, ASIN, ASSERT, ATAN, CAST_AS_SCALAR, CAST_AS_MATRIX,
                CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
                CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
-               CUMSUMPROD, DETECTSCHEMA, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
+               CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, 
INVERSE,
                IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
                MEDIAN, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, SVD,
                TAN, TANH, TYPEOF,
@@ -231,6 +231,7 @@ public class Types
                                case CUMPROD:         return "ucum*";
                                case CUMSUM:          return "ucumk+";
                                case CUMSUMPROD:      return "ucumk+*";
+                               case COLNAMES:        return "colnames";
                                case DETECTSCHEMA:    return "detectSchema";
                                case MULT2:           return "*2";
                                case NOT:             return "!";
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java 
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index f9b46a8..6da0e32 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -539,7 +539,8 @@ public class UnaryOp extends MultiThreadedHop
                        setDim1(input.getDim1());
                        setDim2(1);
                }
-               else if(_op == OpOp1.TYPEOF || _op == OpOp1.DETECTSCHEMA) {
+               else if(_op == OpOp1.TYPEOF || _op == OpOp1.DETECTSCHEMA || _op 
== OpOp1.COLNAMES) {
+                       //TODO these three builtins should rather be moved to 
unary aggregates
                        setDim1(1);
                        setDim2(input.getDim2());
                }
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index d3966e2..7db411c 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -716,6 +716,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        break;
                case TYPEOF:
                case DETECTSCHEMA:
+               case COLNAMES:
                        checkNumParameters(1);
                        checkMatrixFrameParam(getFirstExpr());
                        output.setDataType(DataType.FRAME);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index f84f469..4747bfe 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2656,8 +2656,9 @@ public class DMLTranslator
                case CHOLESKY:
                case TYPEOF:
                case DETECTSCHEMA:
-                       currBuiltinOp = new UnaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
-                               OpOp1.valueOf(source.getOpCode().name()), expr);
+               case COLNAMES:
+                       currBuiltinOp = new UnaryOp(target.getName(), 
target.getDataType(),
+                               target.getValueType(), 
OpOp1.valueOf(source.getOpCode().name()), expr);
                        break;
                        
                case OUTER:
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 30ec6bd..d280b16 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -190,6 +190,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "sigmoid", CPType.Unary);
                String2CPInstructionType.put( "typeOf", CPType.Unary);
                String2CPInstructionType.put( "detectSchema", CPType.Unary);
+               String2CPInstructionType.put( "colnames", CPType.Unary);
                String2CPInstructionType.put( "isna", CPType.Unary);
                String2CPInstructionType.put( "isnan", CPType.Unary);
                String2CPInstructionType.put( "isinf", CPType.Unary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index a53acb9..b4104e1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -93,10 +93,10 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType = new HashMap<>();
                
                //unary aggregate operators
-               String2SPInstructionType.put( "uak+"    , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uak+"    , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uark+"   , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uack+"   , 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uasqk+"  , 
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uasqk+"  , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uarsqk+" , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uacsqk+" , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uamean"  , 
SPType.AggregateUnary);
@@ -107,7 +107,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "uacvar"  , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uamax"   , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uarmax"  , 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uarimax",  
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uarimax" ,  
SPType.AggregateUnary);
                String2SPInstructionType.put( "uacmax"  , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uamin"   , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uarmin"  , 
SPType.AggregateUnary);
@@ -127,7 +127,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "mapmmchain" , SPType.MAPMMCHAIN);
                String2SPInstructionType.put( "tsmm"       , SPType.TSMM); 
//single-pass tsmm
                String2SPInstructionType.put( "tsmm2"      , SPType.TSMM2); 
//multi-pass tsmm
-               String2SPInstructionType.put( "cpmm"       , SPType.CPMM);
+               String2SPInstructionType.put( "cpmm"       , SPType.CPMM);
                String2SPInstructionType.put( "rmm"        , SPType.RMM);
                String2SPInstructionType.put( "pmm"        , SPType.PMM);
                String2SPInstructionType.put( "zipmm"      , SPType.ZIPMM);
@@ -141,42 +141,42 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "tack+*"     , 
SPType.AggregateTernary);
 
                // Neural network operators
-               String2SPInstructionType.put( "conv2d",                 
SPType.Dnn);
+               String2SPInstructionType.put( "conv2d",          SPType.Dnn);
                String2SPInstructionType.put( "conv2d_bias_add", SPType.Dnn);
-               String2SPInstructionType.put( "maxpooling",             
SPType.Dnn);
-               String2SPInstructionType.put( "relu_maxpooling",          
SPType.Dnn);
+               String2SPInstructionType.put( "maxpooling",      SPType.Dnn);
+               String2SPInstructionType.put( "relu_maxpooling", SPType.Dnn);
                
                String2SPInstructionType.put( RightIndex.OPCODE, 
SPType.MatrixIndexing);
-               String2SPInstructionType.put( LeftIndex.OPCODE, 
SPType.MatrixIndexing);
-               String2SPInstructionType.put( "mapLeftIndex" , 
SPType.MatrixIndexing);
+               String2SPInstructionType.put( LeftIndex.OPCODE,  
SPType.MatrixIndexing);
+               String2SPInstructionType.put( "mapLeftIndex",    
SPType.MatrixIndexing);
                
                // Reorg Instruction Opcodes (repositioning of existing values)
-               String2SPInstructionType.put( "r'"         , SPType.Reorg);
-               String2SPInstructionType.put( "rev"        , SPType.Reorg);
-               String2SPInstructionType.put( "rdiag"      , SPType.Reorg);
-               String2SPInstructionType.put( "rshape"     , 
SPType.MatrixReshape);
-               String2SPInstructionType.put( "rsort"      , SPType.Reorg);
+               String2SPInstructionType.put( "r'",       SPType.Reorg);
+               String2SPInstructionType.put( "rev",      SPType.Reorg);
+               String2SPInstructionType.put( "rdiag",    SPType.Reorg);
+               String2SPInstructionType.put( "rshape",   SPType.MatrixReshape);
+               String2SPInstructionType.put( "rsort",    SPType.Reorg);
                
-               String2SPInstructionType.put( "+"    , SPType.Binary);
-               String2SPInstructionType.put( "-"    , SPType.Binary);
-               String2SPInstructionType.put( "*"    , SPType.Binary);
-               String2SPInstructionType.put( "/"    , SPType.Binary);
-               String2SPInstructionType.put( "%%"   , SPType.Binary);
-               String2SPInstructionType.put( "%/%"  , SPType.Binary);
-               String2SPInstructionType.put( "1-*"  , SPType.Binary);
-               String2SPInstructionType.put( "^"    , SPType.Binary);
-               String2SPInstructionType.put( "^2"   , SPType.Binary);
-               String2SPInstructionType.put( "*2"   , SPType.Binary);
-               String2SPInstructionType.put( "map+"    , SPType.Binary);
-               String2SPInstructionType.put( "map-"    , SPType.Binary);
-               String2SPInstructionType.put( "map*"    , SPType.Binary);
-               String2SPInstructionType.put( "map/"    , SPType.Binary);
-               String2SPInstructionType.put( "map%%"   , SPType.Binary);
-               String2SPInstructionType.put( "map%/%"  , SPType.Binary);
-               String2SPInstructionType.put( "map1-*"  , SPType.Binary);
-               String2SPInstructionType.put( "map^"    , SPType.Binary);
-               String2SPInstructionType.put( "map+*"   , SPType.Binary);
-               String2SPInstructionType.put( "map-*"   , SPType.Binary);
+               String2SPInstructionType.put( "+",        SPType.Binary);
+               String2SPInstructionType.put( "-",        SPType.Binary);
+               String2SPInstructionType.put( "*",        SPType.Binary);
+               String2SPInstructionType.put( "/",        SPType.Binary);
+               String2SPInstructionType.put( "%%",       SPType.Binary);
+               String2SPInstructionType.put( "%/%",      SPType.Binary);
+               String2SPInstructionType.put( "1-*",      SPType.Binary);
+               String2SPInstructionType.put( "^",        SPType.Binary);
+               String2SPInstructionType.put( "^2",       SPType.Binary);
+               String2SPInstructionType.put( "*2",       SPType.Binary);
+               String2SPInstructionType.put( "map+",     SPType.Binary);
+               String2SPInstructionType.put( "map-",     SPType.Binary);
+               String2SPInstructionType.put( "map*",     SPType.Binary);
+               String2SPInstructionType.put( "map/",     SPType.Binary);
+               String2SPInstructionType.put( "map%%",    SPType.Binary);
+               String2SPInstructionType.put( "map%/%",   SPType.Binary);
+               String2SPInstructionType.put( "map1-*",   SPType.Binary);
+               String2SPInstructionType.put( "map^",     SPType.Binary);
+               String2SPInstructionType.put( "map+*",    SPType.Binary);
+               String2SPInstructionType.put( "map-*",    SPType.Binary);
                String2SPInstructionType.put( "dropInvalidType", SPType.Binary);
                String2SPInstructionType.put( "mapdropInvalidLength", 
SPType.Binary);
                // Relational Instruction Opcodes
@@ -250,6 +250,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "sprop", SPType.Unary);
                String2SPInstructionType.put( "sigmoid", SPType.Unary);
                String2SPInstructionType.put( "detectSchema", SPType.Unary);
+               String2SPInstructionType.put( "colnames", SPType.Unary);
                String2SPInstructionType.put( "isna", SPType.Unary);
                String2SPInstructionType.put( "isnan", SPType.Unary);
                String2SPInstructionType.put( "isinf", SPType.Unary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
index 13af891..4cbf93c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLScriptException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -37,12 +38,19 @@ public class UnaryFrameCPInstruction extends 
UnaryCPInstruction {
                        ec.releaseFrameInput(input1.getName());
                        ec.setFrameOutput(output.getName(), retBlock);
                }
-               else if(getOpcode().equals("detectSchema"))
-               {
+               else if(getOpcode().equals("detectSchema")) {
                        FrameBlock inBlock = ec.getFrameInput(input1.getName());
                        FrameBlock retBlock = 
inBlock.detectSchemaFromRow(Lop.SAMPLE_FRACTION);
                        ec.releaseFrameInput(input1.getName());
                        ec.setFrameOutput(output.getName(), retBlock);
                }
+               else if(getOpcode().equals("colnames")) {
+                       FrameBlock inBlock = ec.getFrameInput(input1.getName());
+                       FrameBlock retBlock = inBlock.getColumnNamesAsFrame();
+                       ec.releaseFrameInput(input1.getName());
+                       ec.setFrameOutput(output.getName(), retBlock);
+               }
+               else
+                       throw new DMLScriptException("Opcode '" + getOpcode() + 
"' is not a valid UnaryFrameCPInstruction");
        }
-}
\ No newline at end of file
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
index 6cd2785..d4bcf42 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java
@@ -23,7 +23,9 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLScriptException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -37,7 +39,7 @@ public class UnaryFrameSPInstruction extends 
UnarySPInstruction {
                super(SPInstruction.SPType.Unary, op, in, out, opcode, instr);
        }
 
-       public static UnaryFrameSPInstruction parseInstruction (String str ) {
+       public static UnaryFrameSPInstruction parseInstruction(String str) {
                CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
                CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
                String opcode = parseUnaryInstruction(str, in, out);
@@ -46,18 +48,36 @@ public class UnaryFrameSPInstruction extends 
UnarySPInstruction {
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               SparkExecutionContext sec = (SparkExecutionContext)ec;
-               //get input
-               JavaPairRDD<Long, FrameBlock> in = 
sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName() );
-               JavaPairRDD<Long,FrameBlock> out = in.mapToPair(new 
DetectSchemaUsingRows());
+               SparkExecutionContext sec = (SparkExecutionContext) ec;
+               if(getOpcode().equals(OpOp1.DETECTSCHEMA.toString()))
+                       detectSchema(sec);
+               else if(getOpcode().equals(OpOp1.COLNAMES.toString()))
+                       columnNames(sec);
+               else
+                       throw new DMLScriptException("Opcode '" + getOpcode() + 
"' is not a valid UnaryFrameSPInstruction");
+       }
+
+       private void columnNames(SparkExecutionContext sec) {
+               // get input
+               JavaPairRDD<Long, FrameBlock> in = 
sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
+               // get the first row block (frames are only blocked rowwise) 
and get its column names
+               FrameBlock outFrame = 
in.lookup(1L).get(0).getColumnNamesAsFrame();
+               sec.setFrameOutput(output.getName(), outFrame);
+       }
+
+       public void detectSchema(SparkExecutionContext sec) {
+               // get input
+               JavaPairRDD<Long, FrameBlock> in = 
sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
+               JavaPairRDD<Long, FrameBlock> out = in.mapToPair(new 
DetectSchemaUsingRows());
                FrameBlock outFrame = out.values().reduce(new MergeFrame());
                sec.setFrameOutput(output.getName(), outFrame);
        }
 
        private static class DetectSchemaUsingRows implements 
PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
                private static final long serialVersionUID = 
5850400295183766400L;
+
                @Override
-               public Tuple2<Long,FrameBlock> call(Tuple2<Long, FrameBlock> 
arg0) throws Exception {
+               public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> 
arg0) throws Exception {
                        FrameBlock resultBlock = new 
FrameBlock(arg0._2.detectSchemaFromRow(Lop.SAMPLE_FRACTION));
                        return new Tuple2<>(1L, resultBlock);
                }
@@ -65,6 +85,7 @@ public class UnaryFrameSPInstruction extends 
UnarySPInstruction {
 
        private static class MergeFrame implements Function2<FrameBlock, 
FrameBlock, FrameBlock> {
                private static final long serialVersionUID = 
942744896521069893L;
+
                @Override
                public FrameBlock call(FrameBlock arg0, FrameBlock arg1) throws 
Exception {
                        return new FrameBlock(FrameBlock.mergeSchema(arg0, 
arg1));
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index e473acd..7ae6b53 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -176,7 +176,14 @@ public class FrameBlock implements CacheBlock, 
Externalizable
        public String[] getColumnNames() {
                return getColumnNames(true);
        }
-               
+       
+       
+       public FrameBlock getColumnNamesAsFrame() {
+               FrameBlock fb = new FrameBlock(getNumColumns(), 
ValueType.STRING);
+               fb.appendRow(getColumnNames());
+               return fb;
+       }
+       
        /**
         * Returns the column names of the frame block. This method 
         * allocates default column names if required.
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4e55248..7e63127 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -20,8 +20,6 @@
 package org.apache.sysds.test;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.io.ByteArrayOutputStream;
@@ -197,13 +195,6 @@ public abstract class AutomatedTestBase {
 
        private boolean isOutAndExpectedDeletionDisabled = false;
 
-       private int iExpectedStdOutState = 0;
-       private int iUnexpectedStdOutState = 0;
-       // private PrintStream originalPrintStreamStd = null;
-
-       private int iExpectedStdErrState = 0;
-       // private PrintStream originalErrStreamStd = null;
-
        private boolean outputBuffering = true;
        
        // Timestamp before test start.
diff --git 
a/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java 
b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
new file mode 100644
index 0000000..be00f61
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import java.util.Arrays;
+import java.util.Collection;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import edu.emory.mathcs.backport.java.util.Collections;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FrameColumnNamesTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "ColumnNames";
+       private final static String TEST_DIR = "functions/frame/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FrameColumnNamesTest.class.getSimpleName() + "/";
+
+       private final static int _rows = 10000;
+       @Parameterized.Parameter()
+       public String[] _columnNames;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {{new String[] {"A", "B", 
"C"}}, {new String[] {"1", "2", "3"}},
+                       {new String[] {"Hello", "hello", "Hello", "hi", "u", 
"w", "u"}},});
+       }
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+       }
+
+       @Test
+       public void testDetectSchemaDoubleCP() {
+               runGetColNamesTest(_columnNames, ExecType.CP);
+       }
+
+       @Test
+       public void testDetectSchemaDoubleSpark() {
+               runGetColNamesTest(_columnNames, ExecType.SPARK);
+       }
+
+       @SuppressWarnings("unchecked")
+       private void runGetColNamesTest(String[] columnNames, ExecType et) {
+               Types.ExecMode platformOld = setExecMode(et);
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-args", input("A"), 
String.valueOf(_rows),
+                               Integer.toString(columnNames.length), 
output("B")};
+
+                       Types.ValueType[] schema = (Types.ValueType[]) 
Collections
+                               .nCopies(columnNames.length, 
Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+                       FrameBlock frame1 = new FrameBlock(schema);
+                       frame1.setColumnNames(columnNames);
+                       FrameWriter writer = 
FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+                               new FileFormatPropertiesCSV(true, ",", false));
+
+                       double[][] A = getRandomMatrix(_rows, schema.length, 
Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+                       TestUtils.initFrameData(frame1, A, schema, _rows);
+                       writer.writeFrameToHDFS(frame1, input("A"), _rows, 
schema.length);
+
+                       runTest(true, false, null, -1);
+                       FrameBlock frame2 = readDMLFrameFromHDFS("B", 
FileFormat.BINARY);
+
+                       // verify output schema
+                       for(int i = 0; i < schema.length; i++) {
+                               Assert
+                                       .assertEquals("Wrong result: " + 
columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString());
+                       }
+               }
+               catch(Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/frame/TypeOf.dml 
b/src/test/scripts/functions/frame/ColumnNames.dml
similarity index 87%
copy from src/test/scripts/functions/frame/TypeOf.dml
copy to src/test/scripts/functions/frame/ColumnNames.dml
index 7394541..319a03c 100644
--- a/src/test/scripts/functions/frame/TypeOf.dml
+++ b/src/test/scripts/functions/frame/ColumnNames.dml
@@ -19,8 +19,6 @@
 #
 #-------------------------------------------------------------
 
-X = read($1, rows=$2, cols=$3, data_type="frame", format="csv");
-R = typeOf(X);
-print(toString(R))
-write(R, $4, format="binary");
-  
\ No newline at end of file
+X = read($1, rows=$2, cols=$3, data_type="frame", format="csv", header=TRUE);
+R = colnames(X);
+write(R, $4, format="binary");
\ No newline at end of file
diff --git a/src/test/scripts/functions/frame/TypeOf.dml 
b/src/test/scripts/functions/frame/TypeOf.dml
index 7394541..6e8b3bb 100644
--- a/src/test/scripts/functions/frame/TypeOf.dml
+++ b/src/test/scripts/functions/frame/TypeOf.dml
@@ -20,7 +20,7 @@
 #-------------------------------------------------------------
 
 X = read($1, rows=$2, cols=$3, data_type="frame", format="csv");
-R = typeOf(X);
+R = typeof(X);
 print(toString(R))
 write(R, $4, format="binary");
   
\ No newline at end of file

Reply via email to