This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 706ee13  [SYSTEMDS-2960] Fix codegen binary template for element-wise ops
706ee13 is described below

commit 706ee13c4d0f2a5bbb02781b589c27cf1c77030b
Author: ywcb00 <[email protected]>
AuthorDate: Fri May 14 15:25:05 2021 +0200

    [SYSTEMDS-2960] Fix codegen binary template for element-wise ops
    
    Closes #1245.
---
 src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 6d53e1e..2e6bcd5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -76,6 +76,10 @@ public class CNodeBinary extends CNode {
                        return ssComm || vsComm || vvComm;
                }
                
+               public boolean isElementwise() {
+                       return this != DOT_PRODUCT && this != VECT_MATRIXMULT 
&& this != VECT_OUTERMULT_ADD;
+               }
+               
                public boolean isVectorPrimitive() {
                        return isVectorScalarPrimitive() 
                                || isVectorVectorPrimitive()
@@ -184,7 +188,8 @@ public class CNodeBinary extends CNode {
                        //replace start position of main input
                        tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) 
instanceof CNodeData 
                                        && 
_inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : 
-                                       (TemplateUtils.isMatrix(_inputs.get(j)) 
&& _type!=BinType.VECT_MATRIXMULT) ? 
+                                       
((TemplateUtils.isMatrix(_inputs.get(j)) || (_type.isElementwise()
+                                               && 
TemplateUtils.isColVector(_inputs.get(j)))) && _type!=BinType.VECT_MATRIXMULT) ?
                                        varj + ".pos(rix)" : "0" : "0");
                }
                //replace length information (e.g., after matrix mult)

Reply via email to