This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a4deb7188a [SYSTEMDS-3500] Fix lineage support / tests for contains-value function
a4deb7188a is described below

commit a4deb7188a4ca8b45df9e58efef289e21e06a93d
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 23 23:05:25 2023 +0100

    [SYSTEMDS-3500] Fix lineage support / tests for contains-value function
    
    This patch fixes missing lineage reconstruction support and one python
    test for the new contains-value function.
---
 .../instructions/cp/ParameterizedBuiltinCPInstruction.java        | 8 +++++++-
 src/main/python/tests/federated/test_federated_mnist.py           | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index d3c88fd5ff..a67f8cd20d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -474,7 +474,13 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
                String opcode = getOpcode();
-               if(opcode.equalsIgnoreCase("groupedagg")) {
+               if(opcode.equalsIgnoreCase("contains")) {
+                       CPOperand target = getTargetOperand();
+                       CPOperand pattern = getFP64Literal("pattern");
+                       return Pair.of(output.getName(),
+                               new LineageItem(getOpcode(), 
LineageItemUtils.getLineage(ec, target, pattern)));
+               }
+               else if(opcode.equalsIgnoreCase("groupedagg")) {
                        CPOperand target = getTargetOperand();
                        CPOperand groups = new 
CPOperand(params.get(Statement.GAGG_GROUPS), ValueType.FP64, DataType.MATRIX);
                        String wt = params.containsKey(Statement.GAGG_WEIGHTS) 
? params.get(Statement.GAGG_WEIGHTS) : String
diff --git a/src/main/python/tests/federated/test_federated_mnist.py b/src/main/python/tests/federated/test_federated_mnist.py
index 3b11bd3194..c49f64897c 100644
--- a/src/main/python/tests/federated/test_federated_mnist.py
+++ b/src/main/python/tests/federated/test_federated_mnist.py
@@ -114,7 +114,7 @@ class TestFederatedMnist(unittest.TestCase):
         with self.sds.capture_stats_context():
             [_, _, acc] = multiLogRegPredict(Xt, bias, Yt).compute()
         stats = self.sds.take_stats()
-        for fed_instr in ["fed_isnan", "fed_*", "fed_-", "fed_uark+", 
"fed_r'", "fed_rightIndex"]:
+        for fed_instr in ["fed_contains", "fed_*", "fed_-", "fed_uark+", 
"fed_r'", "fed_rightIndex"]:
             self.assertIn(fed_instr, stats)
         self.assertGreater(acc, 80)
 

Reply via email to