This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d61ae6  [SYSTEMDS-2572] Additional mlcontext test for nn-library imports
8d61ae6 is described below

commit 8d61ae6f46f0a8ce21f9ad7c3a617023f6983778
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Jul 24 23:08:27 2020 +0200

    [SYSTEMDS-2572] Additional mlcontext test for nn-library imports
    
    The bug reported in SYSTEMDS-2572 could not be reproduced, either in a
    local environment or through spark-shell. However, because the mlcontext
    tests did not yet cover sourcing (importing) DML scripts, this commit
    adds a corresponding test.
---
 .../sysds/test/functions/mlcontext/MLContextTest.java      | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index 3e07b15..697e9e9 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -1904,5 +1904,17 @@ public class MLContextTest extends MLContextTestBase {
                Assert.assertEquals(true, c);
                Assert.assertEquals("yes it's TRUE", d);
        }
-
+       
+       @Test
+       public void testNNImport() {
+               System.out.println("MLContextTest - NN import");
+               String s =    "source(\"scripts/nn/layers/relu.dml\") as 
relu;\n"
+                                       + "X = rand(rows=100, cols=10, min=-1, 
max=1);\n"
+                                       + "R1 = relu::forward(X);\n"
+                                       + "R2 = max(X, 0);\n"
+                                       + "R = sum(R1==R2);\n";
+               double ret = ml.execute(dml(s).out("R"))
+                       .getScalarObject("R").getDoubleValue();
+               Assert.assertEquals(1000, ret, 1e-20);
+       }
 }

Reply via email to