janniklinde commented on code in PR #2495:
URL: https://github.com/apache/systemds/pull/2495#discussion_r3465732288


##########
src/main/java/org/apache/sysds/parser/DMLTranslator.java:
##########
@@ -2762,6 +2762,22 @@ else if ( in.length == 2 )
                case TYPEOF:
                case DET:
                case DETECTSCHEMA:
+               case SET_NAMES:
+                       currBuiltinOp = new BinaryOp(
+                                       target.getName(),
+                                       target.getDataType(),
+                                       target.getValueType(),
+                                       OpOp2.SET_COLNAMES, expr, expr2
+                       );
+                       break;

Review Comment:
   The case labels above have no bodies of their own, so they fall through to this branch. As a result, they are all now mapped to a `BinaryOp` of type 
`SET_COLNAMES`. This is incorrect and is the reason the tests fail (see, for 
example: 
https://github.com/apache/systemds/actions/runs/27905840894/job/82663839671?pr=2495#step:3:3857).



##########
src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java:
##########
@@ -1095,13 +1095,34 @@ else if( getAllExpr().length == 2 ) { //binary
                case TYPEOF:
                case DETECTSCHEMA:
                case COLNAMES:
+               case GET_NAMES:
                        checkNumParameters(1);
                        checkMatrixFrameParam(getFirstExpr());
                        output.setDataType(DataType.FRAME);
                        output.setDimensions(1, id.getDim2());
                        output.setBlocksize (id.getBlocksize());
                        output.setValueType(ValueType.STRING);
                        break;
+               case SET_NAMES:
+                       //check if we use 2 parameters (Frame on which nemas 
are set and vector for names)
+                       checkNumParameters(2);
+
+                       // check if first paramters is a frame
+                       checkMatrixFrameParam(getFirstExpr());
+
+                       // check if second paramters is a vector 1xn Frame
+                       checkMatrixFrameParam(getSecondExpr());
+
+                       //output should be a frame
+                       output.setDataType(DataType.FRAME);
+
+
+                       checkMatrixFrameParam(getFirstExpr());
+                       output.setDataType(DataType.FRAME);
+                       output.setDimensions(id.getDim1(), id.getDim2());
+                       output.setBlocksize (id.getBlocksize());
+                       output.setValueType(ValueType.STRING);

Review Comment:
   `setName` does not return `STRING` values, so the output value type should not be set to `STRING` here.



##########
src/test/java/org/apache/sysds/test/functions/frame/FrameColNamesPropagationTest.java:
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FrameColNamesPropagationTest extends AutomatedTestBase {
+    private final static String TEST_NAME_CBIND = "ColNameCbindPropagation";
+    private final static String TEST_NAME_RBIND = "ColNameRbindPropagation";
+    private final static String TEST_NAME_SLICE = "ColNameSlicePropagation";
+    private final static String TEST_DIR = "functions/frame/";
+    private static final String TEST_CLASS_DIR = TEST_DIR + 
FrameColumnNamesTest.class.getSimpleName() + "/";

Review Comment:
   Wrong class: `TEST_CLASS_DIR` should be built from `FrameColNamesPropagationTest`, not `FrameColumnNamesTest`.



##########
src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java:
##########
@@ -923,6 +925,9 @@ else if ( 
opcode.equalsIgnoreCase(Opcodes.DROPINVALIDLENGTH.toString()) || opcod
                        return new 
BinaryOperator(Builtin.getBuiltinFnObject("dropInvalidLength"));
                else if ( opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()) 
|| opcode.equalsIgnoreCase("mapValueSwap") )
                        return new 
BinaryOperator(Builtin.getBuiltinFnObject("valueSwap"));
+               //TODO: Check what  "|| 
opcode.equalsIgnoreCase("mapValueSwap"))" does
+               else if 
(opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString()) || 
opcode.equalsIgnoreCase("mapValueSwap"))

Review Comment:
   This condition should not also match `mapValueSwap`; that alias belongs to the `VALUESWAP` branch above.



##########
src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java:
##########
@@ -62,6 +62,28 @@ else if(getOpcode().equals(Opcodes.APPLYSCHEMA.toString())) {
                        final int k = 
((MultiThreadedOperator)_optr).getNumThreads();
                        final FrameBlock out = 
FrameLibApplySchema.applySchema(inBlock1, inBlock2, k);
                        ec.setFrameOutput(output.getName(), out);
+               }
+               else if(getOpcode().equals(Opcodes.SET_COLNAMES.toString())) {
+
+                       FrameBlock in = ec.getFrameInput(input1.getName());
+                       FrameBlock names = ec.getFrameInput(input2.getName());
+
+                       String[] colNames = new String[(int) 
names.getNumColumns()];
+                       for(int i = 0; i < colNames.length; i++){
+                               colNames[i] = names.get(0, i).toString();
+                       }

Review Comment:
   No size/data validation is happening here. For example, nothing checks that `names` has a single row with as many columns as `in`.



##########
src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java:
##########
@@ -52,6 +52,14 @@ else if(getOpcode().equals(Opcodes.COLNAMES.toString())) {
                        ec.releaseFrameInput(input1.getName());
                        ec.setFrameOutput(output.getName(), retBlock);
                }
+               //TODO: Check if new OPcode handling has to be implemented
+               else if(getOpcode().equals(Opcodes.COLNAMES.toString())) {
+                       FrameBlock inBlock = ec.getFrameInput(input1.getName());
+                       FrameBlock retBlock = inBlock.getColumnNamesAsFrame();
+                       ec.releaseFrameInput(input1.getName());
+                       ec.setFrameOutput(output.getName(), retBlock);
+               }
+

Review Comment:
   Duplicate code, not needed: this branch repeats the `COLNAMES` condition handled directly above, so it is unreachable.



##########
src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java:
##########
@@ -686,6 +686,8 @@ else if( 
opcode.equalsIgnoreCase(Opcodes.VALUESWAP.toString()))
                        return new 
BinaryOperator(Builtin.getBuiltinFnObject("valueSwap"));
                else if( opcode.equalsIgnoreCase(Opcodes.FREPLICATE.toString()))
                        return new 
BinaryOperator(Builtin.getBuiltinFnObject("freplicate"));
+               else if( 
opcode.equalsIgnoreCase(Opcodes.SET_COLNAMES.toString()))
+                       return new 
BinaryOperator(Builtin.getBuiltinFnObject("set_colnames"));

Review Comment:
   Does `Builtin.getBuiltinFnObject("set_colnames")` currently return a proper function object?



##########
src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java:
##########
@@ -60,6 +62,9 @@ public static Collection<Object[]> data() {
        @Override
        public void setUp() {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+               addTestConfiguration(TEST_NAME_GET, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_SET, new String[] {"B"}));
+               addTestConfiguration(TEST_NAME_SET, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_GET, new String[] {"B"}));

Review Comment:
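   Get/set swapped: `TEST_NAME_GET` is registered with `TEST_NAME_SET` as its test name, and `TEST_NAME_SET` with `TEST_NAME_GET`.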



##########
src/test/java/org/apache/sysds/test/functions/frame/FrameColumnNamesTest.java:
##########
@@ -72,6 +77,107 @@ public void testDetectSchemaDoubleSpark() {
                runGetColNamesTest(_columnNames, ExecType.SPARK);
        }
 
+       @Test
+       public void testGetNamesCP() {
+               runGetNamesTest(_columnNames,  ExecType.CP);
+       }
+
+       @Test
+       public void testSetNamesCP() {
+               runSetNamesTest(_columnNames,  ExecType.CP);
+       }
+
+       private void runGetNamesTest(String[] columnNames, ExecType et) {
+               Types.ExecMode platformOld = setExecMode(et);
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               setOutputBuffering(true);
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME_GET + ".dml";
+                       programArgs = new String[] {"-args", input("A"), 
String.valueOf(_rows),
+                                       Integer.toString(columnNames.length), 
output("B")};
+
+                       Types.ValueType[] schema = Collections.nCopies(
+                                       columnNames.length, 
Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+                       FrameBlock frame1 = new FrameBlock(schema);
+                       frame1.setColumnNames(columnNames);
+                       FrameWriter writer = 
FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+                                       new FileFormatPropertiesCSV(true, ",", 
false));
+
+                       double[][] A = getRandomMatrix(_rows, schema.length, 
Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+                       TestUtils.initFrameData(frame1, A, schema, _rows);
+                       writer.writeFrameToHDFS(frame1, input("A"), _rows, 
schema.length);
+
+                       runTest(true, false, null, -1);
+                       FrameBlock frame2 = readDMLFrameFromHDFS("B", 
FileFormat.BINARY);
+
+                       // verify output schema
+                       for(int i = 0; i < schema.length; i++) {
+                               Assert
+                                               .assertEquals("Wrong result: " 
+ columnNames[i] + ".", columnNames[i], frame2.get(0, i).toString());
+                       }
+               }
+               catch(Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+       private void runSetNamesTest(String[] columnNames, ExecType et) {
+               Types.ExecMode platformOld = setExecMode(et);
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               setOutputBuffering(true);
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME_SET);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME_SET + ".dml";
+                       programArgs = new String[] 
{"-args",input("X"),String.valueOf(_rows),Integer.toString(columnNames.length),
+                                       input("N"),output("B")
+                       };
+
+                       Types.ValueType[] schema = Collections.nCopies(
+                                       columnNames.length, 
Types.ValueType.FP64).toArray(new Types.ValueType[0]);
+
+                       FrameBlock frame1 = new FrameBlock(schema);
+                       FrameWriter writer = 
FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+                                       new FileFormatPropertiesCSV(true, ",", 
false));
+
+                       double[][] A = getRandomMatrix(_rows, schema.length, 
Double.MIN_VALUE, Double.MAX_VALUE, 0.7, 14123);
+                       TestUtils.initFrameData(frame1, A, schema, _rows);
+                       writer.writeFrameToHDFS(frame1, input("X"), _rows, 
schema.length);
+
+                       Types.ValueType[] nameSchema = Collections.nCopies(
+                                       columnNames.length, 
Types.ValueType.STRING).toArray(new Types.ValueType[0]);
+
+                       FrameBlock names = new FrameBlock(nameSchema);
+                       names.ensureAllocatedColumns(1);
+                       for(int i = 0; i < columnNames.length; i++)
+                               names.set(0, i, columnNames[i]);
+                       FrameWriter nameWriter = 
FrameWriterFactory.createFrameWriter(FileFormat.CSV,
+                                       new FileFormatPropertiesCSV(false, ",", 
false));
+                       System.out.println("N path = " + input("N"));

Review Comment:
   Please avoid debug prints such as this `System.out.println` in tests.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to