This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a0987e536a [SYSTEMDS-3343,3366] Fix missing handling of positional
defaults in eval
a0987e536a is described below
commit a0987e536a2be71d16d64ac64e9873206083e49b
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue May 10 20:46:29 2022 +0200
[SYSTEMDS-3343,3366] Fix missing handling of positional defaults in eval
This patch extends the recently added support for appending named defaults
in eval function calls (as used by generic functions like gridSearch). We now
extend this functionality to positional defaults as well, which
broadens the set of functions that can be used as UDF encoders in
transformencode.
---
.../instructions/cp/EvalNaryCPInstruction.java | 28 +++++++++++++++++++++-
.../runtime/transform/encode/ColumnEncoderUDF.java | 6 +++--
.../transform/TransformEncodeUDFTest.java | 19 +++++++--------
.../functions/transform/TransformEncodeUDF2.dml | 2 +-
4 files changed, 41 insertions(+), 14 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index b7d315c612..9f151e14cf 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -142,7 +142,9 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
&& !(fpb.getInputParams().size() == 1 &&
fpb.getInputParams().get(0).getDataType().isList()))
{
ListObject lo = ec.getListObject(boundInputs[0]);
- lo = appendNamedDefaults(lo, fpb.getStatementBlock());
+ lo = lo.isNamedList() ?
+ appendNamedDefaults(lo,
fpb.getStatementBlock()) :
+ appendPositionalDefaults(lo,
fpb.getStatementBlock());
checkValidArguments(lo.getData(), lo.getNames(),
fpb.getInputParamNames());
if( lo.isNamedList() )
lo = reorderNamedListForFunctionCall(lo,
fpb.getInputParamNames());
@@ -305,6 +307,30 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
return ret;
}
+ private static ListObject appendPositionalDefaults(ListObject params,
StatementBlock sb) {
+ if( sb == null )
+ return params;
+
+ //best effort replacement of scalar literal defaults
+ FunctionStatement fstmt = (FunctionStatement)
sb.getStatement(0);
+ ListObject ret = new ListObject(params);
+ for( int i=ret.getLength(); i<fstmt.getInputParams().size();
i++ ) {
+ String param = fstmt.getInputParamNames()[i];
+ if( !(fstmt.getInputDefaults().get(i) != null
+ &&
fstmt.getInputParams().get(i).getDataType().isScalar()
+ && fstmt.getInputDefaults().get(i) instanceof
ConstIdentifier) )
+ throw new DMLRuntimeException("Unable to append
positional scalar default for '"+param+"'");
+ ValueType vt =
fstmt.getInputParams().get(i).getValueType();
+ Expression expr = fstmt.getInputDefaults().get(i);
+ ScalarObject sobj =
ScalarObjectFactory.createScalarObject(vt, expr.toString());
+ LineageItem litem = !DMLScript.LINEAGE ? null :
+
LineageItemUtils.createScalarLineageItem(ScalarObjectFactory.createLiteralOp(sobj));
+ ret.add(sobj, litem);
+ }
+
+ return ret;
+ }
+
private static void checkValidArguments(List<Data> loData, List<String>
loNames, List<String> fArgNames) {
//check number of parameters
int listSize = (loNames != null) ? loNames.size() :
loData.size();
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
index 15fa568d65..a3f76623f2 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
@@ -33,7 +33,9 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.EvalNaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DependencyTask;
@@ -75,7 +77,7 @@ public class ColumnEncoderUDF extends ColumnEncoder {
//create execution context and input
ExecutionContext ec = ExecutionContextFactory.createContext(new
Program(new DMLProgram()));
MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1,
_colID-1, new MatrixBlock());
- ec.setVariable("I", ParamservUtils.newMatrixObject(col, true));
+ ec.setVariable("I", new ListObject(new Data[]
{ParamservUtils.newMatrixObject(col, true)}));
ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));
//call UDF function via eval machinery
@@ -83,7 +85,7 @@ public class ColumnEncoderUDF extends ColumnEncoder {
new CPOperand("O", ValueType.FP64, DataType.MATRIX),
new CPOperand[] {
new CPOperand(_fName, ValueType.STRING,
DataType.SCALAR, true),
- new CPOperand("I", ValueType.FP64,
DataType.MATRIX)});
+ new CPOperand("I", ValueType.UNKNOWN,
DataType.LIST)});
fun.processInstruction(ec);
//obtain result and in-place write back
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
index e7ccc8d582..1586a51b7d 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
@@ -56,16 +56,15 @@ public class TransformEncodeUDFTest extends
AutomatedTestBase
runTransformTest(ExecMode.HYBRID, TEST_NAME1);
}
-// TODO default handling without named lists
-// @Test
-// public void testUDF2Singlenode() {
-// runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME2);
-// }
-//
-// @Test
-// public void testUDF2Hybrid() {
-// runTransformTest(ExecMode.HYBRID, TEST_NAME2);
-// }
+ @Test
+ public void testUDF2Singlenode() {
+ runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME2);
+ }
+
+ @Test
+ public void testUDF2Hybrid() {
+ runTransformTest(ExecMode.HYBRID, TEST_NAME2);
+ }
private void runTransformTest(ExecMode rt, String testname)
{
diff --git a/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
index 233f25b73c..a62ca2d860 100644
--- a/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
+++ b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
@@ -34,6 +34,6 @@ jspec2 = "{ids: true, recode: [1, 2, 7], udf: {name: scale,
ids: [1, 2, 3, 4, 5,
while(FALSE){}
-R = sum(R1==R2);
+R = sum(abs(R1-R2)<1e-10);
write(R, $R);