http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/mr/CtableInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/CtableInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/CtableInstruction.java new file mode 100644 index 0000000..213f401 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/CtableInstruction.java @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.mr; + +import java.util.HashMap; + +import org.apache.sysml.lops.Ctable; +import org.apache.sysml.lops.Ctable.OperationTypes; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.CTableMap; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixValue; +import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; +import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; +import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; + +public class CtableInstruction extends MRInstruction { + private OperationTypes _op; + + public byte input1; + public byte input2; + public byte input3; + public double scalar_input2; + public double scalar_input3; + private long _outputDim1, _outputDim2; + + private CtableInstruction(MRType type, OperationTypes op, byte in1, double scalar_in2, double scalar_in3, byte out, + long outputDim1, long outputDim2, String istr) { + super(type, null, out); + _op = op; + input1 = in1; + scalar_input2 = scalar_in2; + scalar_input3 = scalar_in3; + _outputDim1 = outputDim1; + _outputDim2 = outputDim2; + instString = istr; + } + + private CtableInstruction(MRType type, OperationTypes op, byte in1, byte in2, double scalar_in3, byte out, long outputDim1, + long outputDim2, String istr) { + super(type, null, out); + _op = op; + input1 = in1; + input2 = in2; + scalar_input3 = scalar_in3; + _outputDim1 = outputDim1; + _outputDim2 = outputDim2; + instString = istr; + } + + private CtableInstruction(MRType type, OperationTypes op, byte in1, double scalar_in2, byte in3, byte out, long outputDim1, + long outputDim2, String istr) { + super(type, null, out); + _op = op; + input1 = in1; + scalar_input2 = scalar_in2; + input3 = in3; + _outputDim1 = outputDim1; + _outputDim2 = outputDim2; + instString = istr; + } + + protected CtableInstruction(MRType type, OperationTypes op, byte in1, byte in2, byte in3, byte out, long outputDim1, + long outputDim2, String istr) { + super(type, null, out); + _op = op; + input1 = in1; + input2 = in2; + input3 = in3; + _outputDim1 = outputDim1; + _outputDim2 = outputDim2; + instString = istr; + } + + public long getOutputDim1() { + return _outputDim1; + } + + public long getOutputDim2() { + return _outputDim2; + } + + public boolean knownOutputDims() { + return (_outputDim1 >0 && _outputDim2>0); + } + + public static CtableInstruction parseInstruction ( String str ) + throws DMLRuntimeException + { + // example instruction string + // - ctabletransform:::0:DOUBLE:::1:DOUBLE:::2:DOUBLE:::3:DOUBLE + // - ctabletransformscalarweight:::0:DOUBLE:::1:DOUBLE:::1.0:DOUBLE:::3:DOUBLE + // - ctabletransformhistogram:::0:DOUBLE:::1.0:DOUBLE:::1.0:DOUBLE:::3:DOUBLE + // - ctabletransformweightedhistogram:::0:DOUBLE:::1:INT:::1:DOUBLE:::2:DOUBLE + + //check number of fields + InstructionUtils.checkNumFields ( str, 6 ); + + //common setup + byte in1, in2, in3, out; + String[] parts = InstructionUtils.getInstructionParts ( str ); + String opcode = parts[0]; + in1 = Byte.parseByte(parts[1]); + long outputDim1 = (long) Double.parseDouble(parts[4]); + long outputDim2 = (long) Double.parseDouble(parts[5]); + out = Byte.parseByte(parts[6]); + + OperationTypes op = Ctable.getOperationType(opcode); + + switch( op ) + { + case CTABLE_TRANSFORM: { + in2 = Byte.parseByte(parts[2]); + in3 = Byte.parseByte(parts[3]); + return new CtableInstruction(MRType.Ctable, op, in1, in2, in3, out, outputDim1, outputDim2, str); + } + case CTABLE_TRANSFORM_SCALAR_WEIGHT: { + in2 = Byte.parseByte(parts[2]); + double scalar_in3 = Double.parseDouble(parts[3]); + return new CtableInstruction(MRType.Ctable, op, in1, in2, scalar_in3, out, outputDim1, outputDim2, str); + } + case CTABLE_EXPAND_SCALAR_WEIGHT: { + double scalar_in2 = Double.parseDouble(parts[2]); + double type = Double.parseDouble(parts[3]); //used as type (1 left, 0 right) + return new CtableInstruction(MRType.Ctable, op, in1, scalar_in2, type, out, outputDim1, outputDim2, str); + } + case CTABLE_TRANSFORM_HISTOGRAM: { + double scalar_in2 = Double.parseDouble(parts[2]); + double scalar_in3 = Double.parseDouble(parts[3]); + return new CtableInstruction(MRType.Ctable, op, in1, scalar_in2, scalar_in3, out, outputDim1, outputDim2, str); + } + case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { + double scalar_in2 = Double.parseDouble(parts[2]); + in3 = Byte.parseByte(parts[3]); + return new CtableInstruction(MRType.Ctable, op, in1, scalar_in2, in3, out, outputDim1, outputDim2, str); + } + default: + throw new DMLRuntimeException("Unrecognized opcode in Ternary Instruction: " + op); + } + } + + public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, + IndexedMatrixValue zeroInput, HashMap<Byte, CTableMap> resultMaps, HashMap<Byte, MatrixBlock> resultBlocks, + int blockRowFactor, int blockColFactor) + throws DMLRuntimeException + { + + IndexedMatrixValue in1, in2, in3 = null; + in1 = cachedValues.getFirst(input1); + + CTableMap ctableResult = null; + MatrixBlock ctableResultBlock = null; + + if ( knownOutputDims() ) { + if ( resultBlocks != null ) { + ctableResultBlock = resultBlocks.get(output); + if ( ctableResultBlock == null ) { + // From MR, output of ctable is set to be sparse since it is built from a single input block. + ctableResultBlock = new MatrixBlock((int)_outputDim1, (int)_outputDim2, true); + resultBlocks.put(output, ctableResultBlock); + } + } + else { + throw new DMLRuntimeException("Unexpected error in processing table instruction."); + } + } + else { + //prepare aggregation maps + ctableResult=resultMaps.get(output); + if(ctableResult==null) + { + ctableResult = new CTableMap(); + resultMaps.put(output, ctableResult); + } + } + + //get inputs and process instruction + switch( _op ) + { + case CTABLE_TRANSFORM: { + in2 = cachedValues.getFirst(input2); + in3 = cachedValues.getFirst(input3); + if(in1==null || in2==null || in3 == null ) + return; + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), + in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); + break; + } + case CTABLE_TRANSFORM_SCALAR_WEIGHT: { + // 3rd input is a scalar + in2 = cachedValues.getFirst(input2); + if(in1==null || in2==null ) + return; + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), + scalar_input3, ctableResult, ctableResultBlock, optr); + break; + } + case CTABLE_EXPAND_SCALAR_WEIGHT: { + // 2nd and 3rd input is a scalar + if(in1==null ) + return; + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), scalar_input2, (scalar_input3==1), + blockRowFactor, ctableResult, ctableResultBlock, optr); + break; + } + case CTABLE_TRANSFORM_HISTOGRAM: { + // 2nd and 3rd inputs are scalars + if(in1==null ) + return; + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), scalar_input2, scalar_input3, ctableResult, ctableResultBlock, optr); + break; + } + case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { + // 2nd and 3rd inputs are scalars + in3 = cachedValues.getFirst(input3); + if(in1==null || in3==null) + return; + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), scalar_input2, + in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); + break; + } + default: + throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString); + } + } + + @Override + public void processInstruction(Class<? extends MatrixValue> valueClass, + CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, + int blockRowFactor, int blockColFactor) + throws DMLRuntimeException + { + throw new DMLRuntimeException("This function should not be called!"); + } + + @Override + public byte[] getAllIndexes() throws DMLRuntimeException { + return new byte[]{input1, input2, input3, output}; + } + + @Override + public byte[] getInputIndexes() throws DMLRuntimeException { + return new byte[]{input1, input2, input3}; + } + +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java index b2e5283..355ccb4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/MRInstruction.java @@ -32,7 +32,7 @@ public abstract class MRInstruction extends Instruction { public enum MRType { Append, Aggregate, Binary, Binary2, AggregateBinary, AggregateUnary, Rand, Seq, CSVReblock, CSVWrite, Reblock, Reorg, Replicate, Unary, CombineBinary, CombineUnary, CombineTernary, - PickByCount, Partition, Ternary, Quaternary, CM_N_COV, MapGroupedAggregate, GroupedAggregate, RightIndex, + PickByCount, Partition, Ctable, Quaternary, CM_N_COV, MapGroupedAggregate, GroupedAggregate, RightIndex, ZeroOut, MMTSJ, PMMJ, MatrixReshape, ParameterizedBuiltin, Sort, MapMultChain, CumsumAggregate, CumsumSplit, CumsumOffset, BinUaggChain, UaggOuterChain, RemoveEmpty } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/mr/QuaternaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/QuaternaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/QuaternaryInstruction.java index 8efddb2..98504b0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/QuaternaryInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/QuaternaryInstruction.java @@ -293,7 +293,7 @@ public class QuaternaryInstruction extends MRInstruction implements IDistributed if( imv==null ) continue; MatrixIndexes inIx = imv.getIndexes(); - MatrixValue inVal = imv.getValue(); + MatrixBlock inVal = (MatrixBlock) imv.getValue(); //allocate space for the output value IndexedMatrixValue iout = null; @@ -305,8 +305,8 @@ public class QuaternaryInstruction extends MRInstruction implements IDistributed MatrixIndexes outIx = iout.getIndexes(); MatrixValue outVal = iout.getValue(); - //Step 2: get remaining inputs: Wij, Ui, Vj - MatrixValue Xij = inVal; + //Step 2: get remaining inputs: Wij, Ui, Vj + MatrixBlock Xij = inVal; //get Wij if existing (null of WeightsType.NONE or WSigmoid any type) IndexedMatrixValue iWij = (_input4 != -1) ? cachedValues.getFirst(_input4) : null; @@ -331,7 +331,7 @@ public class QuaternaryInstruction extends MRInstruction implements IDistributed } //Step 3: process instruction - Xij.quaternaryOperations(qop, Ui, Vj, Wij, outVal); + Xij.quaternaryOperations(qop, (MatrixBlock)Ui, (MatrixBlock)Vj, (MatrixBlock)Wij, (MatrixBlock)outVal); //set output indexes http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java deleted file mode 100644 index 14772f2..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/TernaryInstruction.java +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.instructions.mr; - -import java.util.HashMap; - -import org.apache.sysml.lops.Ternary; -import org.apache.sysml.lops.Ternary.OperationTypes; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.matrix.data.CTableMap; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.data.MatrixValue; -import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; -import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; -import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; - -public class TernaryInstruction extends MRInstruction { - private OperationTypes _op; - - public byte input1; - public byte input2; - public byte input3; - public double scalar_input2; - public double scalar_input3; - private long _outputDim1, _outputDim2; - - private TernaryInstruction(MRType type, OperationTypes op, byte in1, double scalar_in2, double scalar_in3, byte out, - long outputDim1, long outputDim2, String istr) { - super(type, null, out); - _op = op; - input1 = in1; - scalar_input2 = scalar_in2; - scalar_input3 = scalar_in3; - _outputDim1 = outputDim1; - _outputDim2 = outputDim2; - instString = istr; - } - - private TernaryInstruction(MRType type, OperationTypes op, byte in1, byte in2, double scalar_in3, byte out, long outputDim1, - long outputDim2, String istr) { - super(type, null, out); - _op = op; - input1 = in1; - input2 = in2; - scalar_input3 = scalar_in3; - _outputDim1 = outputDim1; - _outputDim2 = outputDim2; - instString = istr; - } - - private TernaryInstruction(MRType type, OperationTypes op, byte in1, double scalar_in2, byte in3, byte out, long outputDim1, - long outputDim2, String istr) { - super(type, null, out); - _op = op; - input1 = in1; - scalar_input2 = scalar_in2; - input3 = in3; - _outputDim1 = outputDim1; - _outputDim2 = outputDim2; - instString = istr; - } - - protected TernaryInstruction(MRType type, OperationTypes op, byte in1, byte in2, byte in3, byte out, long outputDim1, - long outputDim2, String istr) { - super(type, null, out); - _op = op; - input1 = in1; - input2 = in2; - input3 = in3; - _outputDim1 = outputDim1; - _outputDim2 = outputDim2; - instString = istr; - } - - public long getOutputDim1() { - return _outputDim1; - } - - public long getOutputDim2() { - return _outputDim2; - } - - public boolean knownOutputDims() { - return (_outputDim1 >0 && _outputDim2>0); - } - - public static TernaryInstruction parseInstruction ( String str ) - throws DMLRuntimeException - { - // example instruction string - // - ctabletransform:::0:DOUBLE:::1:DOUBLE:::2:DOUBLE:::3:DOUBLE - // - ctabletransformscalarweight:::0:DOUBLE:::1:DOUBLE:::1.0:DOUBLE:::3:DOUBLE - // - ctabletransformhistogram:::0:DOUBLE:::1.0:DOUBLE:::1.0:DOUBLE:::3:DOUBLE - // - ctabletransformweightedhistogram:::0:DOUBLE:::1:INT:::1:DOUBLE:::2:DOUBLE - - //check number of fields - InstructionUtils.checkNumFields ( str, 6 ); - - //common setup - byte in1, in2, in3, out; - String[] parts = InstructionUtils.getInstructionParts ( str ); - String opcode = parts[0]; - in1 = Byte.parseByte(parts[1]); - long outputDim1 = (long) Double.parseDouble(parts[4]); - long outputDim2 = (long) Double.parseDouble(parts[5]); - out = Byte.parseByte(parts[6]); - - OperationTypes op = Ternary.getOperationType(opcode); - - switch( op ) - { - case CTABLE_TRANSFORM: { - in2 = Byte.parseByte(parts[2]); - in3 = Byte.parseByte(parts[3]); - return new TernaryInstruction(MRType.Ternary, op, in1, in2, in3, out, outputDim1, outputDim2, str); - } - case CTABLE_TRANSFORM_SCALAR_WEIGHT: { - in2 = Byte.parseByte(parts[2]); - double scalar_in3 = Double.parseDouble(parts[3]); - return new TernaryInstruction(MRType.Ternary, op, in1, in2, scalar_in3, out, outputDim1, outputDim2, str); - } - case CTABLE_EXPAND_SCALAR_WEIGHT: { - double scalar_in2 = Double.parseDouble(parts[2]); - double type = Double.parseDouble(parts[3]); //used as type (1 left, 0 right) - return new TernaryInstruction(MRType.Ternary, op, in1, scalar_in2, type, out, outputDim1, outputDim2, str); - } - case CTABLE_TRANSFORM_HISTOGRAM: { - double scalar_in2 = Double.parseDouble(parts[2]); - double scalar_in3 = Double.parseDouble(parts[3]); - return new TernaryInstruction(MRType.Ternary, op, in1, scalar_in2, scalar_in3, out, outputDim1, outputDim2, str); - } - case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { - double scalar_in2 = Double.parseDouble(parts[2]); - in3 = Byte.parseByte(parts[3]); - return new TernaryInstruction(MRType.Ternary, op, in1, scalar_in2, in3, out, outputDim1, outputDim2, str); - } - default: - throw new DMLRuntimeException("Unrecognized opcode in Ternary Instruction: " + op); - } - } - - public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, - IndexedMatrixValue zeroInput, HashMap<Byte, CTableMap> resultMaps, HashMap<Byte, MatrixBlock> resultBlocks, - int blockRowFactor, int blockColFactor) - throws DMLRuntimeException - { - - IndexedMatrixValue in1, in2, in3 = null; - in1 = cachedValues.getFirst(input1); - - CTableMap ctableResult = null; - MatrixBlock ctableResultBlock = null; - - if ( knownOutputDims() ) { - if ( resultBlocks != null ) { - ctableResultBlock = resultBlocks.get(output); - if ( ctableResultBlock == null ) { - // From MR, output of ctable is set to be sparse since it is built from a single input block. - ctableResultBlock = new MatrixBlock((int)_outputDim1, (int)_outputDim2, true); - resultBlocks.put(output, ctableResultBlock); - } - } - else { - throw new DMLRuntimeException("Unexpected error in processing table instruction."); - } - } - else { - //prepare aggregation maps - ctableResult=resultMaps.get(output); - if(ctableResult==null) - { - ctableResult = new CTableMap(); - resultMaps.put(output, ctableResult); - } - } - - //get inputs and process instruction - switch( _op ) - { - case CTABLE_TRANSFORM: { - in2 = cachedValues.getFirst(input2); - in3 = cachedValues.getFirst(input3); - if(in1==null || in2==null || in3 == null ) - return; - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), - in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); - break; - } - case CTABLE_TRANSFORM_SCALAR_WEIGHT: { - // 3rd input is a scalar - in2 = cachedValues.getFirst(input2); - if(in1==null || in2==null ) - return; - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), - scalar_input3, ctableResult, ctableResultBlock, optr); - break; - } - case CTABLE_EXPAND_SCALAR_WEIGHT: { - // 2nd and 3rd input is a scalar - if(in1==null ) - return; - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, (scalar_input3==1), - blockRowFactor, ctableResult, ctableResultBlock, optr); - break; - } - case CTABLE_TRANSFORM_HISTOGRAM: { - // 2nd and 3rd inputs are scalars - if(in1==null ) - return; - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, scalar_input3, ctableResult, ctableResultBlock, optr); - break; - } - case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { - // 2nd and 3rd inputs are scalars - in3 = cachedValues.getFirst(input3); - if(in1==null || in3==null) - return; - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, - in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); - break; - } - default: - throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString); - } - } - - @Override - public void processInstruction(Class<? extends MatrixValue> valueClass, - CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, - int blockRowFactor, int blockColFactor) - throws DMLRuntimeException - { - throw new DMLRuntimeException("This function should not be called!"); - } - - @Override - public byte[] getAllIndexes() throws DMLRuntimeException { - return new byte[]{input1, input2, input3, output}; - } - - @Override - public byte[] getInputIndexes() throws DMLRuntimeException { - return new byte[]{input1, input2, input3}; - } - -} http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/spark/CtableSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CtableSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CtableSPInstruction.java new file mode 100644 index 0000000..001e7b6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CtableSPInstruction.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.spark; + +import java.util.ArrayList; +import java.util.Iterator; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.PairFunction; + +import scala.Tuple2; + +import org.apache.sysml.lops.Ctable; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.CTable; +import org.apache.sysml.runtime.instructions.Instruction; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; +import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.CTableMap; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixCell; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; +import org.apache.sysml.runtime.matrix.data.Pair; +import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; +import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.runtime.matrix.operators.SimpleOperator; +import org.apache.sysml.runtime.util.LongLongDoubleHashMap.ADoubleEntry; +import org.apache.sysml.runtime.util.UtilFunctions; + +public class CtableSPInstruction extends ComputationSPInstruction { + private String _outDim1; + private String _outDim2; + private boolean _dim1Literal; + private boolean _dim2Literal; + private boolean _isExpand; + private boolean _ignoreZeros; + + private CtableSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, + boolean ignoreZeros, String opcode, String istr) { + super(SPType.Ctable, op, in1, in2, in3, out, opcode, istr); + _outDim1 = outputDim1; + _dim1Literal = dim1Literal; + _outDim2 = outputDim2; + _dim2Literal = dim2Literal; + _isExpand = isExpand; + _ignoreZeros = ignoreZeros; + } + + public static CtableSPInstruction parseInstruction(String inst) + throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst); + InstructionUtils.checkNumFields ( parts, 7 ); + + String opcode = parts[0]; + + //handle opcode + if ( !(opcode.equalsIgnoreCase("ctable") || opcode.equalsIgnoreCase("ctableexpand")) ) { + throw new DMLRuntimeException("Unexpected opcode in TertiarySPInstruction: " + inst); + } + boolean isExpand = opcode.equalsIgnoreCase("ctableexpand"); + + //handle operands + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + + //handle known dimension information + String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX); + String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX); + + CPOperand out = new CPOperand(parts[6]); + boolean ignoreZeros = Boolean.parseBoolean(parts[7]); + + // ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject + return new CtableSPInstruction(new SimpleOperator(null), in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst); + } + + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + SparkExecutionContext sec = (SparkExecutionContext)ec; + + //get input rdd handle + JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); + JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = null; + JavaPairRDD<MatrixIndexes,MatrixBlock> in3 = null; + double scalar_input2 = -1, scalar_input3 = -1; + + Ctable.OperationTypes ctableOp = Ctable.findCtableOperationByInputDataTypes( + input1.getDataType(), input2.getDataType(), input3.getDataType()); + ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp; + + MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName()); + MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); + + // First get the block sizes and then set them as -1 to allow for binary cell reblock + int brlen = mc1.getRowsPerBlock(); + int bclen = mc1.getColsPerBlock(); + + JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs = null; + JavaPairRDD<MatrixIndexes, CTableMap> ctables = null; + JavaPairRDD<MatrixIndexes, Double> bincellsNoFilter = null; + boolean setLineage2 = false; + boolean setLineage3 = false; + switch(ctableOp) { + case CTABLE_TRANSFORM: //(VECTOR) + // F=ctable(A,B,W) + in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); + in3 = sec.getBinaryBlockRDDHandleForVariable( input3.getName() ); + setLineage2 = true; + setLineage3 = true; + + inputMBs = in1.cogroup(in2).cogroup(in3) + .mapToPair(new MapThreeMBIterableIntoAL()); + + ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, + scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); + break; + + + case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR) + // F = ctable(seq,A) or F = ctable(seq,B,1) + scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); + if(scalar_input3 == 1) { + in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); + setLineage2 = true; + bincellsNoFilter = in2.flatMapToPair(new ExpandScalarCtableOperation(brlen)); + break; + } + case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX) + // F = ctable(A,B) or F = ctable(A,B,1) + in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); + setLineage2 = true; + + scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); + inputMBs = in1.cogroup(in2).mapToPair(new MapTwoMBIterableIntoAL()); + + ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, + scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); + break; + + case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR) + // F=ctable(A,1) or F = ctable(A,1,1) + scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); + scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); + inputMBs = in1.mapToPair(new MapMBIntoAL()); + + ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, + scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); + break; + + case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR) + // F=ctable(A,1,W) + in3 = sec.getBinaryBlockRDDHandleForVariable( input3.getName() ); + setLineage3 = true; + + scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); + inputMBs = in1.cogroup(in3).mapToPair(new MapTwoMBIterableIntoAL()); + + ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, + scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); + break; + + default: + throw new DMLRuntimeException("Encountered an invalid ctable operation ("+ctableOp+") while executing instruction: " + this.toString()); + } + + // Now perform aggregation on ctables to get binaryCells + if(bincellsNoFilter == null && ctables != null) { + bincellsNoFilter = + ctables.values() + .flatMapToPair(new ExtractBinaryCellsFromCTable()); + bincellsNoFilter = RDDAggregateUtils.sumCellsByKeyStable(bincellsNoFilter); + } + else if(!(bincellsNoFilter != null && ctables == null)) { + throw new DMLRuntimeException("Incorrect ctable operation"); + } + + // handle known/unknown dimensions + long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (sec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue()); + long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (sec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue()); + MatrixCharacteristics mcBinaryCells = null; + boolean findDimensions = (outputDim1 == -1 && outputDim2 == -1); + + if( !findDimensions ) { + if((outputDim1 == -1 && outputDim2 != -1) || (outputDim1 != -1 && outputDim2 == -1)) + throw new DMLRuntimeException("Incorrect output dimensions passed to TernarySPInstruction:" + outputDim1 + " " + outputDim2); + else + mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen); + + // filtering according to given dimensions + bincellsNoFilter = bincellsNoFilter + .filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols())); + } + + // convert double values to matrix cell + JavaPairRDD<MatrixIndexes, MatrixCell> binaryCells = bincellsNoFilter + .mapToPair(new ConvertToBinaryCell()); + + // find dimensions if necessary (w/ cache for reblock) + if( findDimensions ) { + binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells); + mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells); + } + + //store output rdd handle + sec.setRDDHandleForVariable(output.getName(), binaryCells); + mcOut.set(mcBinaryCells); + // Since we are outputing binary cells, we set block sizes = -1 + mcOut.setRowsPerBlock(-1); mcOut.setColsPerBlock(-1); + sec.addLineageRDD(output.getName(), input1.getName()); + if(setLineage2) + sec.addLineageRDD(output.getName(), input2.getName()); + if(setLineage3) + sec.addLineageRDD(output.getName(), input3.getName()); + } + + private static class ExpandScalarCtableOperation implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, Double> + { + private static final long serialVersionUID = -12552669148928288L; + + private int _brlen; + + public ExpandScalarCtableOperation(int brlen) { + _brlen = brlen; + } + + @Override + public Iterator<Tuple2<MatrixIndexes, Double>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) + throws Exception + { + MatrixIndexes ix = arg0._1(); + MatrixBlock mb = arg0._2(); //col-vector + + //create an output cell per matrix block row (aligned w/ original source position) + ArrayList<Tuple2<MatrixIndexes, Double>> retVal = new ArrayList<>(); + CTable ctab = CTable.getCTableFnObject(); + for( int i=0; i<mb.getNumRows(); i++ ) + { + //compute global target indexes (via ctable obj for error handling consistency) + long row = UtilFunctions.computeCellIndex(ix.getRowIndex(), _brlen, i); + double v2 = mb.quickGetValue(i, 0); + Pair<MatrixIndexes,Double> p = ctab.execute(row, v2, 1.0); + + //indirect construction over pair to avoid tuple2 dependency in general ctable obj + if( p.getKey().getRowIndex() >= 1 ) //filter rejected entries + retVal.add(new Tuple2<>(p.getKey(), p.getValue())); + } + + return retVal.iterator(); + } + } + + private static class MapTwoMBIterableIntoAL implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> { + + private static final long serialVersionUID = 271459913267735850L; + + private static MatrixBlock extractBlock(Iterable<MatrixBlock> blks, MatrixBlock retVal) throws Exception { + for(MatrixBlock blk1 : blks) { + if(retVal != null) { + throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index"); + } + retVal = blk1; + } + if(retVal == null) { + throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index"); + } + return retVal; + } + + @Override + public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call( + Tuple2<MatrixIndexes, Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>> kv) + throws Exception { + MatrixBlock in1 = null; MatrixBlock in2 = null; + in1 = extractBlock(kv._2._1, in1); + in2 = extractBlock(kv._2._2, in2); + // Now return unflatten AL + ArrayList<MatrixBlock> inputs = new ArrayList<>(); + inputs.add(in1); inputs.add(in2); + return new Tuple2<>(kv._1, inputs); + } + + } + + private static class MapThreeMBIterableIntoAL implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>,Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> { + + private static final long serialVersionUID = -4873754507037646974L; + + private static MatrixBlock extractBlock(Iterable<MatrixBlock> blks, MatrixBlock retVal) throws Exception { + for(MatrixBlock blk1 : blks) { + if(retVal != null) { + throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index"); + } + retVal = blk1; + } + if(retVal == null) { + throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index"); + } + return retVal; + } + + @Override + public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call( + Tuple2<MatrixIndexes, Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>>, Iterable<MatrixBlock>>> kv) + throws Exception { + MatrixBlock in1 = null; MatrixBlock in2 = null; MatrixBlock in3 = null; + + for(Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>> blks : kv._2._1) { + in1 = extractBlock(blks._1, in1); + in2 = extractBlock(blks._2, in2); + } + in3 = extractBlock(kv._2._2, in3); + + // Now return unflatten AL + ArrayList<MatrixBlock> inputs = new ArrayList<>(); + inputs.add(in1); inputs.add(in2); inputs.add(in3); + return new Tuple2<>(kv._1, inputs); + } + + } + + private static class PerformCTableMapSideOperation implements PairFunction<Tuple2<MatrixIndexes,ArrayList<MatrixBlock>>, MatrixIndexes, CTableMap> { + + private static final long serialVersionUID = 5348127596473232337L; + + Ctable.OperationTypes ctableOp; + double scalar_input2; double scalar_input3; + String instString; + Operator optr; + boolean ignoreZeros; + + public PerformCTableMapSideOperation(Ctable.OperationTypes ctableOp, double scalar_input2, double scalar_input3, String instString, Operator optr, boolean ignoreZeros) { + this.ctableOp = ctableOp; + this.scalar_input2 = scalar_input2; + this.scalar_input3 = scalar_input3; + this.instString = instString; + this.optr = optr; + this.ignoreZeros = ignoreZeros; + } + + private static void expectedALSize(int length, ArrayList<MatrixBlock> al) throws Exception { + if(al.size() != length) { + throw new Exception("Expected arraylist of size:" + length + ", but found " + al.size()); + } + } + + @Override + public Tuple2<MatrixIndexes, CTableMap> call( + Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> kv) throws Exception { + CTableMap ctableResult = new CTableMap(); + MatrixBlock ctableResultBlock = null; + + IndexedMatrixValue in1, in2, in3 = null; + in1 = new IndexedMatrixValue(kv._1, kv._2.get(0)); + MatrixBlock matBlock1 = kv._2.get(0); + + switch( ctableOp ) + { + case CTABLE_TRANSFORM: { + in2 = new IndexedMatrixValue(kv._1, kv._2.get(1)); + in3 = new IndexedMatrixValue(kv._1, kv._2.get(2)); + expectedALSize(3, kv._2); + + if(in1==null || in2==null || in3 == null ) + break; + else + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), in2.getIndexes(), + in2.getValue(), in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); + break; + } + case CTABLE_TRANSFORM_SCALAR_WEIGHT: + case CTABLE_EXPAND_SCALAR_WEIGHT: + { + // 3rd input is a scalar + in2 = new IndexedMatrixValue(kv._1, kv._2.get(1)); + expectedALSize(2, kv._2); + if(in1==null || in2==null ) + break; + else + matBlock1.ctableOperations((SimpleOperator)optr, kv._2.get(1), scalar_input3, ignoreZeros, ctableResult, ctableResultBlock); + break; + } + case CTABLE_TRANSFORM_HISTOGRAM: { + expectedALSize(1, kv._2); + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), scalar_input2, + scalar_input3, ctableResult, ctableResultBlock, optr); + break; + } + case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { + // 2nd and 3rd inputs are scalars + expectedALSize(2, kv._2); + in3 = new IndexedMatrixValue(kv._1, kv._2.get(1)); // Note: kv._2.get(1), not kv._2.get(2) + + if(in1==null || in3==null) + break; + else + OperationsOnMatrixValues.performCtable(in1.getIndexes(), in1.getValue(), scalar_input2, + in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); + break; + } + default: + throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString); + } + return new Tuple2<>(kv._1, ctableResult); + } + + } + + private static class MapMBIntoAL implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, ArrayList<MatrixBlock>> { + + private static final long serialVersionUID = 2068398913653350125L; + + @Override + public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call( + Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception { + ArrayList<MatrixBlock> retVal = new ArrayList<>(); + retVal.add(kv._2); + return new Tuple2<>(kv._1, retVal); + } + + } + + private static class ExtractBinaryCellsFromCTable implements PairFlatMapFunction<CTableMap, MatrixIndexes, Double> { + + private static final long serialVersionUID = -5933677686766674444L; + + @Override + public Iterator<Tuple2<MatrixIndexes, Double>> call(CTableMap ctableMap) + throws Exception { + ArrayList<Tuple2<MatrixIndexes, Double>> retVal = new ArrayList<>(); + Iterator<ADoubleEntry> iter = ctableMap.getIterator(); + while( iter.hasNext() ) { + ADoubleEntry ijv = iter.next(); + long i = ijv.getKey1(); + long j = ijv.getKey2(); + double v = ijv.value; + retVal.add(new Tuple2<>(new MatrixIndexes(i, j), v)); + } + return retVal.iterator(); + } + + } + + private static class ConvertToBinaryCell implements PairFunction<Tuple2<MatrixIndexes,Double>, MatrixIndexes, MatrixCell> { + + private static final long serialVersionUID = 7481186480851982800L; + + @Override + public Tuple2<MatrixIndexes, MatrixCell> call( + Tuple2<MatrixIndexes, Double> kv) throws Exception { + + MatrixCell cell = new MatrixCell(kv._2().doubleValue()); + return new Tuple2<>(kv._1(), cell); + } + + } + + private static class FilterCells implements Function<Tuple2<MatrixIndexes,Double>, Boolean> { + private static final long serialVersionUID = 108448577697623247L; + + long rlen; long clen; + public FilterCells(long rlen, long clen) { + this.rlen = rlen; + this.clen = clen; + } + + @Override + public Boolean call(Tuple2<MatrixIndexes, Double> kv) throws Exception { + if(kv._1.getRowIndex() <= 0 || kv._1.getColumnIndex() <= 0) { + throw new Exception("Incorrect cell values in TernarySPInstruction:" + kv._1); + } + if(kv._1.getRowIndex() <= rlen && kv._1.getColumnIndex() <= clen) { + return true; + } + return false; + } + + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java index d1c12ef..ef7158c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java @@ -36,7 +36,7 @@ public abstract class SPInstruction extends Instruction { Builtin, Unary, BuiltinNary, MultiReturnBuiltin, Checkpoint, Compression, Cast, CentralMoment, Covariance, QSort, QPick, ParameterizedBuiltin, MAppend, RAppend, GAppend, GAlignedAppend, Rand, - MatrixReshape, Ternary, Quaternary, CumsumAggregate, CumsumOffset, BinUaggChain, UaggOuterChain, + MatrixReshape, Ctable, Quaternary, CumsumAggregate, CumsumOffset, BinUaggChain, UaggOuterChain, Write, SpoofFused, Convolution } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java deleted file mode 100644 index 7bc4cb5..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java +++ /dev/null @@ -1,510 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.instructions.spark; - -import java.util.ArrayList; -import java.util.Iterator; - -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.apache.spark.api.java.function.PairFunction; - -import scala.Tuple2; - -import org.apache.sysml.lops.Ternary; -import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.functionobjects.CTable; -import org.apache.sysml.runtime.instructions.Instruction; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; -import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; -import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.data.CTableMap; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.data.MatrixCell; -import org.apache.sysml.runtime.matrix.data.MatrixIndexes; -import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; -import org.apache.sysml.runtime.matrix.data.Pair; -import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; -import org.apache.sysml.runtime.matrix.operators.Operator; -import org.apache.sysml.runtime.matrix.operators.SimpleOperator; -import org.apache.sysml.runtime.util.LongLongDoubleHashMap.ADoubleEntry; -import org.apache.sysml.runtime.util.UtilFunctions; - -public class TernarySPInstruction extends ComputationSPInstruction { - private String _outDim1; - private String _outDim2; - private boolean _dim1Literal; - private boolean _dim2Literal; - private boolean _isExpand; - private boolean _ignoreZeros; - - private TernarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, - String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, - boolean ignoreZeros, String opcode, String istr) { - super(SPType.Ternary, op, in1, in2, in3, out, opcode, istr); - _outDim1 = outputDim1; - _dim1Literal = dim1Literal; - _outDim2 = outputDim2; - _dim2Literal = dim2Literal; - _isExpand = isExpand; - _ignoreZeros = ignoreZeros; - } - - public static TernarySPInstruction parseInstruction(String inst) - throws DMLRuntimeException - { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst); - InstructionUtils.checkNumFields ( parts, 7 ); - - String opcode = parts[0]; - - //handle opcode - if ( !(opcode.equalsIgnoreCase("ctable") || opcode.equalsIgnoreCase("ctableexpand")) ) { - throw new DMLRuntimeException("Unexpected opcode in TertiarySPInstruction: " + inst); - } - boolean isExpand = opcode.equalsIgnoreCase("ctableexpand"); - - //handle operands - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = new CPOperand(parts[3]); - - //handle known dimension information - String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX); - String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX); - - CPOperand out = new CPOperand(parts[6]); - boolean ignoreZeros = Boolean.parseBoolean(parts[7]); - - // ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject - return new TernarySPInstruction(new SimpleOperator(null), in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst); - } - - - @Override - public void processInstruction(ExecutionContext ec) - throws DMLRuntimeException - { - SparkExecutionContext sec = (SparkExecutionContext)ec; - - //get input rdd handle - JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() ); - JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = null; - JavaPairRDD<MatrixIndexes,MatrixBlock> in3 = null; - double scalar_input2 = -1, scalar_input3 = -1; - - Ternary.OperationTypes ctableOp = Ternary.findCtableOperationByInputDataTypes( - input1.getDataType(), input2.getDataType(), input3.getDataType()); - ctableOp = _isExpand ? Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp; - - MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName()); - MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); - - // First get the block sizes and then set them as -1 to allow for binary cell reblock - int brlen = mc1.getRowsPerBlock(); - int bclen = mc1.getColsPerBlock(); - - JavaPairRDD<MatrixIndexes, ArrayList<MatrixBlock>> inputMBs = null; - JavaPairRDD<MatrixIndexes, CTableMap> ctables = null; - JavaPairRDD<MatrixIndexes, Double> bincellsNoFilter = null; - boolean setLineage2 = false; - boolean setLineage3 = false; - switch(ctableOp) { - case CTABLE_TRANSFORM: //(VECTOR) - // F=ctable(A,B,W) - in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); - in3 = sec.getBinaryBlockRDDHandleForVariable( input3.getName() ); - setLineage2 = true; - setLineage3 = true; - - inputMBs = in1.cogroup(in2).cogroup(in3) - .mapToPair(new MapThreeMBIterableIntoAL()); - - ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, - scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); - break; - - - case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR) - // F = ctable(seq,A) or F = ctable(seq,B,1) - scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); - if(scalar_input3 == 1) { - in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); - setLineage2 = true; - bincellsNoFilter = in2.flatMapToPair(new ExpandScalarCtableOperation(brlen)); - break; - } - case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX) - // F = ctable(A,B) or F = ctable(A,B,1) - in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() ); - setLineage2 = true; - - scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); - inputMBs = in1.cogroup(in2).mapToPair(new MapTwoMBIterableIntoAL()); - - ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, - scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); - break; - - case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR) - // F=ctable(A,1) or F = ctable(A,1,1) - scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); - scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); - inputMBs = in1.mapToPair(new MapMBIntoAL()); - - ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, - scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); - break; - - case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR) - // F=ctable(A,1,W) - in3 = sec.getBinaryBlockRDDHandleForVariable( input3.getName() ); - setLineage3 = true; - - scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); - inputMBs = in1.cogroup(in3).mapToPair(new MapTwoMBIterableIntoAL()); - - ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2, - scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros)); - break; - - default: - throw new DMLRuntimeException("Encountered an invalid ctable operation ("+ctableOp+") while executing instruction: " + this.toString()); - } - - // Now perform aggregation on ctables to get binaryCells - if(bincellsNoFilter == null && ctables != null) { - bincellsNoFilter = - ctables.values() - .flatMapToPair(new ExtractBinaryCellsFromCTable()); - bincellsNoFilter = RDDAggregateUtils.sumCellsByKeyStable(bincellsNoFilter); - } - else if(!(bincellsNoFilter != null && ctables == null)) { - throw new DMLRuntimeException("Incorrect ctable operation"); - } - - // handle known/unknown dimensions - long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (sec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue()); - long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (sec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue()); - MatrixCharacteristics mcBinaryCells = null; - boolean findDimensions = (outputDim1 == -1 && outputDim2 == -1); - - if( !findDimensions ) { - if((outputDim1 == -1 && outputDim2 != -1) || (outputDim1 != -1 && outputDim2 == -1)) - throw new DMLRuntimeException("Incorrect output dimensions passed to TernarySPInstruction:" + outputDim1 + " " + outputDim2); - else - mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen); - - // filtering according to given dimensions - bincellsNoFilter = bincellsNoFilter - .filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols())); - } - - // convert double values to matrix cell - JavaPairRDD<MatrixIndexes, MatrixCell> binaryCells = bincellsNoFilter - .mapToPair(new ConvertToBinaryCell()); - - // find dimensions if necessary (w/ cache for reblock) - if( findDimensions ) { - binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells); - mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells); - } - - //store output rdd handle - sec.setRDDHandleForVariable(output.getName(), binaryCells); - mcOut.set(mcBinaryCells); - // Since we are outputing binary cells, we set block sizes = -1 - mcOut.setRowsPerBlock(-1); mcOut.setColsPerBlock(-1); - sec.addLineageRDD(output.getName(), input1.getName()); - if(setLineage2) - sec.addLineageRDD(output.getName(), input2.getName()); - if(setLineage3) - sec.addLineageRDD(output.getName(), input3.getName()); - } - - private static class ExpandScalarCtableOperation implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, Double> - { - private static final long serialVersionUID = -12552669148928288L; - - private int _brlen; - - public ExpandScalarCtableOperation(int brlen) { - _brlen = brlen; - } - - @Override - public Iterator<Tuple2<MatrixIndexes, Double>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) - throws Exception - { - MatrixIndexes ix = arg0._1(); - MatrixBlock mb = arg0._2(); //col-vector - - //create an output cell per matrix block row (aligned w/ original source position) - ArrayList<Tuple2<MatrixIndexes, Double>> retVal = new ArrayList<>(); - CTable ctab = CTable.getCTableFnObject(); - for( int i=0; i<mb.getNumRows(); i++ ) - { - //compute global target indexes (via ctable obj for error handling consistency) - long row = UtilFunctions.computeCellIndex(ix.getRowIndex(), _brlen, i); - double v2 = mb.quickGetValue(i, 0); - Pair<MatrixIndexes,Double> p = ctab.execute(row, v2, 1.0); - - //indirect construction over pair to avoid tuple2 dependency in general ctable obj - if( p.getKey().getRowIndex() >= 1 ) //filter rejected entries - retVal.add(new Tuple2<>(p.getKey(), p.getValue())); - } - - return retVal.iterator(); - } - } - - private static class MapTwoMBIterableIntoAL implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> { - - private static final long serialVersionUID = 271459913267735850L; - - private static MatrixBlock extractBlock(Iterable<MatrixBlock> blks, MatrixBlock retVal) throws Exception { - for(MatrixBlock blk1 : blks) { - if(retVal != null) { - throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index"); - } - retVal = blk1; - } - if(retVal == null) { - throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index"); - } - return retVal; - } - - @Override - public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call( - Tuple2<MatrixIndexes, Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>> kv) - throws Exception { - MatrixBlock in1 = null; MatrixBlock in2 = null; - in1 = extractBlock(kv._2._1, in1); - in2 = extractBlock(kv._2._2, in2); - // Now return unflatten AL - ArrayList<MatrixBlock> inputs = new ArrayList<>(); - inputs.add(in1); inputs.add(in2); - return new Tuple2<>(kv._1, inputs); - } - - } - - private static class MapThreeMBIterableIntoAL implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>,Iterable<MatrixBlock>>>, MatrixIndexes, ArrayList<MatrixBlock>> { - - private static final long serialVersionUID = -4873754507037646974L; - - private static MatrixBlock extractBlock(Iterable<MatrixBlock> blks, MatrixBlock retVal) throws Exception { - for(MatrixBlock blk1 : blks) { - if(retVal != null) { - throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index"); - } - retVal = blk1; - } - if(retVal == null) { - throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index"); - } - return retVal; - } - - @Override - public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call( - Tuple2<MatrixIndexes, Tuple2<Iterable<Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>>, Iterable<MatrixBlock>>> kv) - throws Exception { - MatrixBlock in1 = null; MatrixBlock in2 = null; MatrixBlock in3 = null; - - for(Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>> blks : kv._2._1) { - in1 = extractBlock(blks._1, in1); - in2 = extractBlock(blks._2, in2); - } - in3 = extractBlock(kv._2._2, in3); - - // Now return unflatten AL - ArrayList<MatrixBlock> inputs = new ArrayList<>(); - inputs.add(in1); inputs.add(in2); inputs.add(in3); - return new Tuple2<>(kv._1, inputs); - } - - } - - private static class PerformCTableMapSideOperation implements PairFunction<Tuple2<MatrixIndexes,ArrayList<MatrixBlock>>, MatrixIndexes, CTableMap> { - - private static final long serialVersionUID = 5348127596473232337L; - - Ternary.OperationTypes ctableOp; - double scalar_input2; double scalar_input3; - String instString; - Operator optr; - boolean ignoreZeros; - - public PerformCTableMapSideOperation(Ternary.OperationTypes ctableOp, double scalar_input2, double scalar_input3, String instString, Operator optr, boolean ignoreZeros) { - this.ctableOp = ctableOp; - this.scalar_input2 = scalar_input2; - this.scalar_input3 = scalar_input3; - this.instString = instString; - this.optr = optr; - this.ignoreZeros = ignoreZeros; - } - - private static void expectedALSize(int length, ArrayList<MatrixBlock> al) throws Exception { - if(al.size() != length) { - throw new Exception("Expected arraylist of size:" + length + ", but found " + al.size()); - } - } - - @Override - public Tuple2<MatrixIndexes, CTableMap> call( - Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> kv) throws Exception { - CTableMap ctableResult = new CTableMap(); - MatrixBlock ctableResultBlock = null; - - IndexedMatrixValue in1, in2, in3 = null; - in1 = new IndexedMatrixValue(kv._1, kv._2.get(0)); - MatrixBlock matBlock1 = kv._2.get(0); - - switch( ctableOp ) - { - case CTABLE_TRANSFORM: { - in2 = new IndexedMatrixValue(kv._1, kv._2.get(1)); - in3 = new IndexedMatrixValue(kv._1, kv._2.get(2)); - expectedALSize(3, kv._2); - - if(in1==null || in2==null || in3 == null ) - break; - else - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(), - in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); - break; - } - case CTABLE_TRANSFORM_SCALAR_WEIGHT: - case CTABLE_EXPAND_SCALAR_WEIGHT: - { - // 3rd input is a scalar - in2 = new IndexedMatrixValue(kv._1, kv._2.get(1)); - expectedALSize(2, kv._2); - if(in1==null || in2==null ) - break; - else - matBlock1.ternaryOperations((SimpleOperator)optr, kv._2.get(1), scalar_input3, ignoreZeros, ctableResult, ctableResultBlock); - break; - } - case CTABLE_TRANSFORM_HISTOGRAM: { - expectedALSize(1, kv._2); - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, - scalar_input3, ctableResult, ctableResultBlock, optr); - break; - } - case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: { - // 2nd and 3rd inputs are scalars - expectedALSize(2, kv._2); - in3 = new IndexedMatrixValue(kv._1, kv._2.get(1)); // Note: kv._2.get(1), not kv._2.get(2) - - if(in1==null || in3==null) - break; - else - OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2, - in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr); - break; - } - default: - throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString); - } - return new Tuple2<>(kv._1, ctableResult); - } - - } - - private static class MapMBIntoAL implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, ArrayList<MatrixBlock>> { - - private static final long serialVersionUID = 2068398913653350125L; - - @Override - public Tuple2<MatrixIndexes, ArrayList<MatrixBlock>> call( - Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception { - ArrayList<MatrixBlock> retVal = new ArrayList<>(); - retVal.add(kv._2); - return new Tuple2<>(kv._1, retVal); - } - - } - - private static class ExtractBinaryCellsFromCTable implements PairFlatMapFunction<CTableMap, MatrixIndexes, Double> { - - private static final long serialVersionUID = -5933677686766674444L; - - @Override - public Iterator<Tuple2<MatrixIndexes, Double>> call(CTableMap ctableMap) - throws Exception { - ArrayList<Tuple2<MatrixIndexes, Double>> retVal = new ArrayList<>(); - Iterator<ADoubleEntry> iter = ctableMap.getIterator(); - while( iter.hasNext() ) { - ADoubleEntry ijv = iter.next(); - long i = ijv.getKey1(); - long j = ijv.getKey2(); - double v = ijv.value; - retVal.add(new Tuple2<>(new MatrixIndexes(i, j), v)); - } - return retVal.iterator(); - } - - } - - private static class ConvertToBinaryCell implements PairFunction<Tuple2<MatrixIndexes,Double>, MatrixIndexes, MatrixCell> { - - private static final long serialVersionUID = 7481186480851982800L; - - @Override - public Tuple2<MatrixIndexes, MatrixCell> call( - Tuple2<MatrixIndexes, Double> kv) throws Exception { - - MatrixCell cell = new MatrixCell(kv._2().doubleValue()); - return new Tuple2<>(kv._1(), cell); - } - - } - - private static class FilterCells implements Function<Tuple2<MatrixIndexes,Double>, Boolean> { - private static final long serialVersionUID = 108448577697623247L; - - long rlen; long clen; - public FilterCells(long rlen, long clen) { - this.rlen = rlen; - this.clen = clen; - } - - @Override - public Boolean call(Tuple2<MatrixIndexes, Double> kv) throws Exception { - if(kv._1.getRowIndex() <= 0 || kv._1.getColumnIndex() <= 0) { - throw new Exception("Incorrect cell values in TernarySPInstruction:" + kv._1); - } - if(kv._1.getRowIndex() <= rlen && kv._1.getColumnIndex() <= clen) { - return true; - } - return false; - } - - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java index 491934d..d31dce5 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java @@ -129,7 +129,7 @@ public class ZipmmSPInstruction extends BinarySPInstruction { MatrixBlock tmp = (MatrixBlock)in2.reorgOperations(_rop, new MatrixBlock(), 0, 0, 0); //core matrix multiplication (for t(y)%*%X or t(X)%*%y) - return (MatrixBlock)tmp.aggregateBinaryOperations(tmp, in1, new MatrixBlock(), _abop); + return tmp.aggregateBinaryOperations(tmp, in1, new MatrixBlock(), _abop); } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java index a0e81ed..b447b23 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java @@ -56,7 +56,7 @@ import org.apache.sysml.runtime.instructions.mr.ReorgInstruction; import org.apache.sysml.runtime.instructions.mr.ReplicateInstruction; import org.apache.sysml.runtime.instructions.mr.ScalarInstruction; import org.apache.sysml.runtime.instructions.mr.SeqInstruction; -import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; +import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase; @@ -377,7 +377,7 @@ public class MatrixCharacteristics implements Serializable } } else if (ins instanceof CombineTernaryInstruction ) { - TernaryInstruction realIns=(TernaryInstruction)ins; + CtableInstruction realIns=(CtableInstruction)ins; dimOut.set(dims.get(realIns.input1)); } else if (ins instanceof CombineUnaryInstruction ) { @@ -393,8 +393,8 @@ public class MatrixCharacteristics implements Serializable MatrixCharacteristics dimIn = dims.get(realIns.input); realIns.computeOutputCharacteristics(dimIn, dimOut); } - else if (ins instanceof TernaryInstruction) { - TernaryInstruction realIns = (TernaryInstruction)ins; + else if (ins instanceof CtableInstruction) { + CtableInstruction realIns = (CtableInstruction)ins; MatrixCharacteristics in_dim=dims.get(realIns.input1); dimOut.set(realIns.getOutputDim1(), realIns.getOutputDim2(), in_dim.numRowsPerBlock, in_dim.numColumnsPerBlock); } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java index b74ff10..a01d1bb 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/CM_N_COVCell.java @@ -30,12 +30,10 @@ import org.apache.hadoop.io.WritableComparable; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.cp.CM_COV_Object; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; -import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.runtime.matrix.operators.Operator; -import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.matrix.operators.ScalarOperator; import org.apache.sysml.runtime.matrix.operators.UnaryOperator; @@ -50,42 +48,35 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable public String toString() { return cm.toString(); } - - @Override - public MatrixValue aggregateBinaryOperations(MatrixValue m1Value, - MatrixValue m2Value, MatrixValue result, AggregateBinaryOperator op) - throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); - } - + @Override public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int brlen, int bclen, MatrixIndexes indexesIn) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override public MatrixValue binaryOperations(BinaryOperator op, MatrixValue thatValue, MatrixValue result) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override public void binaryOperationsInPlace(BinaryOperator op, MatrixValue thatValue) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override public void copy(MatrixValue that, boolean sp) { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override public void copy(MatrixValue that) { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override @@ -105,21 +96,21 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable @Override public double getValue(int r, int c) { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override public void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, MatrixValue newWithCorrection) throws DMLRuntimeException { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override public void incrementalAggregate(AggregateOperator aggOp, MatrixValue newWithCorrection) throws DMLRuntimeException { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override @@ -136,7 +127,7 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable public MatrixValue reorgOperations(ReorgOperator op, MatrixValue result, int startRow, int startColumn, int length) throws DMLRuntimeException { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override @@ -157,18 +148,18 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable @Override public MatrixValue scalarOperations(ScalarOperator op, MatrixValue result) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override public void setValue(int r, int c, double v) { - throw new RuntimeException("operation not supported fro WeightedCell"); + throw new RuntimeException("operation not supported for CM_N_COVCell"); } @Override public MatrixValue unaryOperations(UnaryOperator op, MatrixValue result) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override @@ -241,65 +232,56 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable @Override public MatrixValue zeroOutOperations(MatrixValue result, IndexRange range, boolean complementary) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override - public void ternaryOperations(Operator op, MatrixValue that, + public void ctableOperations(Operator op, MatrixValue that, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); - + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override - public void ternaryOperations(Operator op, MatrixValue that, + public void ctableOperations(Operator op, MatrixValue that, double scalarThat2, boolean ignoreZeros, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override - public void ternaryOperations(Operator op, double scalarThat, + public void ctableOperations(Operator op, double scalarThat, double scalarThat2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override - public void ternaryOperations(Operator op, MatrixIndexes ix1, double scalarThat, boolean left, int brlen, + public void ctableOperations(Operator op, MatrixIndexes ix1, double scalarThat, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override - public void ternaryOperations(Operator op, double scalarThat, + public void ctableOperations(Operator op, double scalarThat, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); - } - - @Override - public MatrixValue quaternaryOperations(QuaternaryOperator qop, MatrixValue um, MatrixValue vm, MatrixValue wm, MatrixValue out) - throws DMLRuntimeException - { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } - @Override public void sliceOperations(ArrayList<IndexedMatrixValue> outlist, IndexRange range, int rowCut, int colCut, int blockRowFactor, int blockColFactor, int boundaryRlen, int boundaryClen) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override public MatrixValue replaceOperations(MatrixValue result, double pattern, double replacement) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override @@ -307,21 +289,13 @@ public class CM_N_COVCell extends MatrixValue implements WritableComparable MatrixValue result, int blockingFactorRow, int blockingFactorCol, MatrixIndexes indexesIn, boolean inCP) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); - } - - @Override - public MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, - MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value, - MatrixValue result, AggregateBinaryOperator op) - throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } @Override public void appendOperations(MatrixValue valueIn2, ArrayList<IndexedMatrixValue> outlist, int blockRowFactor, int blockColFactor, boolean cbind, boolean m2IsLast, int nextNCol) throws DMLRuntimeException { - throw new DMLRuntimeException("operation not supported fro WeightedCell"); + throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 08f5c85..098427f 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -4625,21 +4625,16 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab return sum_wt; } - @Override - public MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value, - MatrixValue result, AggregateBinaryOperator op ) throws DMLRuntimeException + public MatrixBlock aggregateBinaryOperations(MatrixIndexes m1Index, MatrixBlock m1, MatrixIndexes m2Index, MatrixBlock m2, + MatrixBlock ret, AggregateBinaryOperator op ) throws DMLRuntimeException { - return aggregateBinaryOperations(m1Value, m2Value, result, op); + return aggregateBinaryOperations(m1, m2, ret, op); } - @Override - public MatrixValue aggregateBinaryOperations(MatrixValue m1Value, MatrixValue m2Value, MatrixValue result, AggregateBinaryOperator op) + public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) throws DMLRuntimeException { //check input types, dimensions, configuration - MatrixBlock m1 = checkType(m1Value); - MatrixBlock m2 = checkType(m2Value); - MatrixBlock ret = checkType(result); if( m1.clen != m2.rlen ) { throw new RuntimeException("Dimensions do not match for matrix multiplication ("+m1.clen+"!="+m2.rlen+")."); } @@ -4923,7 +4918,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * (i3,j3,w) from input3 (that2) */ @Override - public void ternaryOperations(Operator op, double scalarThat, + public void ctableOperations(Operator op, double scalarThat, MatrixValue that2Val, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -4954,7 +4949,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * (w) from scalar_input3 (scalarThat2) */ @Override - public void ternaryOperations(Operator op, double scalarThat, + public void ctableOperations(Operator op, double scalarThat, double scalarThat2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -4982,7 +4977,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * */ @Override - public void ternaryOperations(Operator op, MatrixIndexes ix1, double scalarThat, + public void ctableOperations(Operator op, MatrixIndexes ix1, double scalarThat, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -5018,7 +5013,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * we can also use a sparse-safe implementation */ @Override - public void ternaryOperations(Operator op, MatrixValue thatVal, double scalarThat2, boolean ignoreZeros, + public void ctableOperations(Operator op, MatrixValue thatVal, double scalarThat2, boolean ignoreZeros, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -5082,7 +5077,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * @param resultBlock result matrix block * @throws DMLRuntimeException if DMLRuntimeException occurs */ - public void ternaryOperations(Operator op, MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock) + public void ctableOperations(Operator op, MatrixValue thatMatrix, double thatScalar, MatrixBlock resultBlock) throws DMLRuntimeException { MatrixBlock that = checkType(thatMatrix); @@ -5118,14 +5113,14 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * @param resultMap table map * @throws DMLRuntimeException if DMLRuntimeException occurs */ - public void ternaryOperations(Operator op, MatrixValue thatVal, MatrixValue that2Val, CTableMap resultMap) + public void ctableOperations(Operator op, MatrixValue thatVal, MatrixValue that2Val, CTableMap resultMap) throws DMLRuntimeException { - ternaryOperations(op, thatVal, that2Val, resultMap, null); + ctableOperations(op, thatVal, that2Val, resultMap, null); } @Override - public void ternaryOperations(Operator op, MatrixValue thatVal, MatrixValue that2Val, CTableMap resultMap, MatrixBlock resultBlock) + public void ctableOperations(Operator op, MatrixValue thatVal, MatrixValue that2Val, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { MatrixBlock that = checkType(thatVal); @@ -5159,26 +5154,23 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab } } - @Override - public MatrixValue quaternaryOperations(QuaternaryOperator qop, MatrixValue um, MatrixValue vm, MatrixValue wm, MatrixValue out) + public MatrixBlock quaternaryOperations(QuaternaryOperator qop, MatrixBlock um, MatrixBlock vm, MatrixBlock wm, MatrixBlock out) throws DMLRuntimeException { return quaternaryOperations(qop, um, vm, wm, out, 1); } - public MatrixValue quaternaryOperations(QuaternaryOperator qop, MatrixValue um, MatrixValue vm, MatrixValue wm, MatrixValue out, int k) + public MatrixBlock quaternaryOperations(QuaternaryOperator qop, MatrixBlock U, MatrixBlock V, MatrixBlock wm, MatrixBlock out, int k) throws DMLRuntimeException { //check input dimensions - if( getNumRows() != um.getNumRows() ) - throw new DMLRuntimeException("Dimension mismatch rows on quaternary operation: "+getNumRows()+"!="+um.getNumRows()); - if( getNumColumns() != vm.getNumRows() ) - throw new DMLRuntimeException("Dimension mismatch columns quaternary operation: "+getNumColumns()+"!="+vm.getNumRows()); + if( getNumRows() != U.getNumRows() ) + throw new DMLRuntimeException("Dimension mismatch rows on quaternary operation: "+getNumRows()+"!="+U.getNumRows()); + if( getNumColumns() != V.getNumRows() ) + throw new DMLRuntimeException("Dimension mismatch columns quaternary operation: "+getNumColumns()+"!="+V.getNumRows()); //check input data types MatrixBlock X = this; - MatrixBlock U = checkType(um); - MatrixBlock V = checkType(vm); MatrixBlock R = checkType(out); //prepare intermediates and output @@ -5230,7 +5222,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab if( k > 1 ) LibMatrixMult.matrixMultWuMM(X, U, V, R, qop.wtype5, qop.fn, k); else - LibMatrixMult.matrixMultWuMM(X, U, V, R, qop.wtype5, qop.fn); + LibMatrixMult.matrixMultWuMM(X, U, V, R, qop.wtype5, qop.fn); } return R;
