mboehm7 commented on code in PR #2190: URL: https://github.com/apache/systemds/pull/2190#discussion_r1929504112
########## src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java: ########## @@ -270,26 +271,26 @@ public String toString() { case VECT_CBIND: return "b(cbind)"; case VECT_BIASADD: return "b(vbias+)"; case VECT_BIASMULT: return "b(vbias*)"; - case MULT: return "b(*)"; - case DIV: return "b(/)"; - case PLUS: return "b(+)"; - case MINUS: return "b(-)"; - case POW: return "b(^)"; - case MODULUS: return "b(%%)"; - case INTDIV: return "b(%/%)"; - case LESS: return "b(<)"; - case LESSEQUAL: return "b(<=)"; - case GREATER: return "b(>)"; - case GREATEREQUAL: return "b(>=)"; - case EQUAL: return "b(==)"; - case NOTEQUAL: return "b(!=)"; + case MULT: return "b(" + Opcodes.MULT.getName() + ")"; + case DIV: return "b(" + Opcodes.DIV.getName() + ")"; + case PLUS: return "b(" + Opcodes.PLUS.getName() + ")"; + case MINUS: return "b(" + Opcodes.MINUS.getName() + ")"; + case POW: return "b(" + Opcodes.POW.getName() + ")"; + case MODULUS: return "b(" + Opcodes.MODULUS.getName() + ")"; + case INTDIV: return "b(" + Opcodes.INTDIV.getName() + ")"; + case LESS: return "b(" + Opcodes.LESS.getName() + ")"; + case LESSEQUAL: return "b(" + Opcodes.LESSEQUAL.getName() + ")"; + case GREATER: return "b(" + Opcodes.GREATER.getName() + ")"; + case GREATEREQUAL: return "b(" + Opcodes.GREATEREQUAL.getName() + ")"; + case EQUAL: return "b(" + Opcodes.EQUAL.getName() + ")"; + case NOTEQUAL: return "b(" + Opcodes.NOTEQUAL.getName() + ")"; case OR: return "b(|)"; case AND: return "b(&)"; Review Comment: Why are these operators (OR, AND) not covered? 
########## src/main/java/org/apache/sysds/common/Opcodes.java: ########## @@ -0,0 +1,328 @@ +package org.apache.sysds.common; + +import org.apache.sysds.lops.*; +import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; +import org.apache.sysds.common.Types.OpOp1; +import org.apache.sysds.hops.FunctionOp; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; + +public enum Opcodes { + MMULT("ba+*", CPType.AggregateBinary), + TAKPM("tak+*", CPType.AggregateTernary), + TACKPM("tack+*", CPType.AggregateTernary), + + UAKP("uak+", CPType.AggregateUnary), + UARKP("uark+", CPType.AggregateUnary), + UACKP( "uack+", CPType.AggregateUnary), + UASQKP( "uasqk+", CPType.AggregateUnary), + UARSQKP( "uarsqk+", CPType.AggregateUnary), + UACSQKP( "uacsqk+", CPType.AggregateUnary), + UAMEAN( "uamean", CPType.AggregateUnary), + UARMEAN("uarmean", CPType.AggregateUnary), + UACMEAN("uacmean", CPType.AggregateUnary), + UAVAR("uavar", CPType.AggregateUnary), + UARVAR("uarvar", CPType.AggregateUnary), + UACVAR("uacvar", CPType.AggregateUnary), + UAMAX("uamax", CPType.AggregateUnary), + UARMAX("uarmax", CPType.AggregateUnary), + UARIMAX("uarimax", CPType.AggregateUnary), + UACMAX("uacmax", CPType.AggregateUnary), + UAMIN("uamin", CPType.AggregateUnary), + UARMIN("uarmin", CPType.AggregateUnary), + UARIMIN("uarimin", CPType.AggregateUnary), + UACMIN("uacmin", CPType.AggregateUnary), + UAP("ua+", CPType.AggregateUnary), + UARP("uar+", CPType.AggregateUnary), + UACP("uac+", CPType.AggregateUnary), + UAM("ua*", CPType.AggregateUnary), + UARM("uar*", CPType.AggregateUnary), + UACM("uac*", CPType.AggregateUnary), + UATRACE("uatrace", CPType.AggregateUnary), + UAKTRACE("uaktrace", CPType.AggregateUnary), + + NROW("nrow", CPType.AggregateUnary), + NCOL("ncol", CPType.AggregateUnary), + LENGTH("length", CPType.AggregateUnary), + EXISTS("exists", CPType.AggregateUnary), + LINEAGE("lineage", CPType.AggregateUnary), + UACD("uacd", 
CPType.AggregateUnary), + UACDR("uacdr", CPType.AggregateUnary), + UACDC("uacdc", CPType.AggregateUnary), + UACDAP("uacdap", CPType.AggregateUnary), + UACDAPR("uacdapr", CPType.AggregateUnary), + UACDAPC("uacdapc", CPType.AggregateUnary), + UNIQUE("unique", CPType.AggregateUnary), + UNIQUER("uniquer", CPType.AggregateUnary), + UNIQUEC("uniquec", CPType.AggregateUnary), + + UAGGOUTERCHAIN("uaggouterchain", CPType.UaggOuterChain), + + // Arithmetic Instruction Opcodes + PLUS("+", CPType.Binary), + MINUS("-", CPType.Binary), + MULT("*", CPType.Binary), + DIV("/", CPType.Binary), + MODULUS("%%", CPType.Binary), + INTDIV("%/%", CPType.Binary), + POW("^", CPType.Binary), + MINUS1_MULT("1-*", CPType.Binary), //special * case + POW2("^2", CPType.Binary), //special ^ case + MULT2("*2", CPType.Binary), //special * case + MINUS_NZ("-nz", CPType.Binary), //special - case + + // Boolean Instruction Opcodes + AND("&&", CPType.Binary), + OR("||", CPType.Binary), + XOR("xor", CPType.Binary), + BITWAND("bitwAnd", CPType.Binary), + BITWOR("bitwOr", CPType.Binary), + BITWXOR("bitwXor", CPType.Binary), + BITWSHIFTL("bitwShiftL", CPType.Binary), + BITWSHIFTR("bitwShiftR", CPType.Binary), + NOT("!", CPType.Unary), + + // Relational Instruction Opcodes + EQUAL("==", CPType.Binary), + NOTEQUAL("!=", CPType.Binary), + LESS("<", CPType.Binary), + GREATER(">", CPType.Binary), + LESSEQUAL("<=", CPType.Binary), + GREATEREQUAL(">=", CPType.Binary), + + // Builtin Instruction Opcodes + LOG("log", CPType.Builtin), + LOGNZ("log_nz", CPType.Builtin), + + SOLVE("solve", CPType.Binary), + MAX("max", CPType.Binary), + MIN("min", CPType.Binary), + DROPINVALIDTYPE("dropInvalidType", CPType.Binary), + DROPINVALIDLENGTH("dropInvalidLength", CPType.Binary), + FREPLICATE("freplicate", CPType.Binary), + VALUESWAP("valueSwap", CPType.Binary), + APPLYSCHEMA("applySchema", CPType.Binary), + MAP("_map", CPType.Ternary), + + NMAX("nmax", CPType.BuiltinNary), + NMIN("nmin", CPType.BuiltinNary), + NP("n+", 
CPType.BuiltinNary), + NM("n*", CPType.BuiltinNary), + + EXP("exp", CPType.Unary), + ABS("abs", CPType.Unary), + SIN("sin", CPType.Unary), + COS("cos", CPType.Unary), + TAN("tan", CPType.Unary), + SINH("sinh", CPType.Unary), + COSH("cosh", CPType.Unary), + TANH("tanh", CPType.Unary), + ASIN("asin", CPType.Unary), + ACOS("acos", CPType.Unary), + ATAN("atan", CPType.Unary), + SIGN("sign", CPType.Unary), + SQRT("sqrt", CPType.Unary), + PLOGP("plogp", CPType.Unary), + PRINT("print", CPType.Unary), + ASSERT("assert", CPType.Unary), + ROUND("round", CPType.Unary), + CEIL("ceil", CPType.Unary), + FLOOR("floor", CPType.Unary), + UCUMKP("ucumk+", CPType.Unary), + UCUMM("ucum*", CPType.Unary), + UCUMKPM("ucumk+*", CPType.Unary), + UCUMMIN("ucummin", CPType.Unary), + UCUMMAX("ucummax", CPType.Unary), + STOP("stop", CPType.Unary), + INVERSE("inverse", CPType.Unary), + CHOLESKY("cholesky", CPType.Unary), + SPROP("sprop", CPType.Unary), + SIGMOID("sigmoid", CPType.Unary), + TYPEOF("typeOf", CPType.Unary), + DETECTSCHEMA("detectSchema", CPType.Unary), + COLNAMES("colnames", CPType.Unary), + ISNA("isna", CPType.Unary), + ISNAN("isnan", CPType.Unary), + ISINF("isinf", CPType.Unary), + PRINTF("printf", CPType.BuiltinNary), + CBIND("cbind", CPType.BuiltinNary), + RBIND("rbind", CPType.BuiltinNary), + EVAL("eval", CPType.BuiltinNary), + LIST("list", CPType.BuiltinNary), + + //Parametrized builtin functions + AUTODIFF("autoDiff", CPType.ParameterizedBuiltin), + CONTAINS("contains", CPType.ParameterizedBuiltin), + PARAMSERV("paramserv", CPType.ParameterizedBuiltin), + NVLIST("nvlist", CPType.ParameterizedBuiltin), + CDF("cdf", CPType.ParameterizedBuiltin), + INVCDF("invcdf", CPType.ParameterizedBuiltin), + GROUPEDAGG("groupedagg", CPType.ParameterizedBuiltin), + RMEMPTY("rmempty", CPType.ParameterizedBuiltin), + REPLACE("replace", CPType.ParameterizedBuiltin), + LOWERTRI("lowertri", CPType.ParameterizedBuiltin), + UPPERTRI("uppertri", CPType.ParameterizedBuiltin), + REXPAND("rexpand", 
CPType.ParameterizedBuiltin), + TOSTRING("toString", CPType.ParameterizedBuiltin), + TOKENIZE("tokenize", CPType.ParameterizedBuiltin), + TRANSFORMAPPLY("transformapply", CPType.ParameterizedBuiltin), + TRANSFORMDECODE("transformdecode", CPType.ParameterizedBuiltin), + TRANSFORMCOLMAP("transformcolmap", CPType.ParameterizedBuiltin), + TRANSFORMMETA("transformmeta", CPType.ParameterizedBuiltin), + TRANSFORMENCODE("transformencode", CPType.MultiReturnParameterizedBuiltin), + + //Ternary instruction opcodes + PM("+*", CPType.Ternary), + MINUSMULT("-*", CPType.Ternary), + IFELSE("ifelse", CPType.Ternary), + + //Variable instruction opcodes + ASSIGNVAR("assignvar", CPType.Variable), + CPVAR("cpvar", CPType.Variable), + MVVAR("mvvar", CPType.Variable), + RMVAR("rmvar", CPType.Variable), + RMFILEVAR("rmfilevar", CPType.Variable), + CAST_AS_SCALAR(OpOp1.CAST_AS_SCALAR.toString(), CPType.Variable), + CAST_AS_MATRIX(OpOp1.CAST_AS_MATRIX.toString(), CPType.Variable), + CAST_AS_FRAME_VAR("cast_as_frame", CPType.Variable), + CAST_AS_FRAME(OpOp1.CAST_AS_FRAME.toString(), CPType.Variable), + CAST_AS_LIST(OpOp1.CAST_AS_LIST.toString(), CPType.Variable), + CAST_AS_DOUBLE(OpOp1.CAST_AS_DOUBLE.toString(), CPType.Variable), + CAST_AS_INT(OpOp1.CAST_AS_INT.toString(), CPType.Variable), + CAST_AS_BOOLEAN(OpOp1.CAST_AS_BOOLEAN.toString(), CPType.Variable), + ATTACHFILETOVAR("attachfiletovar", CPType.Variable), + READ("read", CPType.Variable), + WRITE("write", CPType.Variable), + CREATEVAR("createvar", CPType.Variable), + + //Reorg instruction opcodes + TRANSPOSE("r'", CPType.Reorg), + REV("rev", CPType.Reorg), + ROLL("roll", CPType.Reorg), + DIAG("rdiag", CPType.Reorg), + RESHAPE("rshape", CPType.Reshape), + SORT("rsort", CPType.Reorg), + + // Opcodes related to convolutions + RELU_BACKWARD("relu_backward", CPType.Dnn), + RELU_MAXPOOLING("relu_maxpooling", CPType.Dnn), + RELU_MAXPOOLING_BACKWARD("relu_maxpooling_backward", CPType.Dnn), + MAXPOOLING("maxpooling", CPType.Dnn), + 
MAXPOOLING_BACKWARD("maxpooling_backward", CPType.Dnn), + AVGPOOLING("avgpooling", CPType.Dnn), + AVGPOOLING_BACKWARD("avgpooling_backward", CPType.Dnn), + CONV2D("conv2d", CPType.Dnn), + CONV2D_BIAS_ADD("conv2d_bias_add", CPType.Dnn), + CONV2D_BACKWARD_FILTER("conv2d_backward_filter", CPType.Dnn), + CONV2D_BACKWARD_DATA("conv2d_backward_data", CPType.Dnn), + BIAS_ADD("bias_add", CPType.Dnn), + BIAS_MULTIPLY("bias_multiply", CPType.Dnn), + BATCH_NORM2D("batch_norm2d", CPType.Dnn), + BATCH_NORM2D_BACKWARD("batch_norm2d_backward", CPType.Dnn), + LSTM("lstm", CPType.Dnn), + LSTM_BACKWARD("lstm_backward", CPType.Dnn), + + //Quaternary instruction opcodes + WSLOSS("wsloss", CPType.Quaternary), + WSIGMOID("wsigmoid", CPType.Quaternary), + WDIVMM("wdivmm", CPType.Quaternary), + WCEMM("wcemm", CPType.Quaternary), + WUMM("wumm", CPType.Quaternary), + + //User-defined function Opcodes + FCALL(FunctionOp.OPCODE, CPType.FCall), + + APPEND(Append.OPCODE, CPType.Append), + REMOVE("remove", CPType.Append), + + //data generation opcodes + RANDOM(DataGen.RAND_OPCODE, CPType.Rand), + SEQUENCE(DataGen.SEQ_OPCODE, CPType.Rand), + STRINGINIT(DataGen.SINIT_OPCODE, CPType.StringInit), + SAMPLE(DataGen.SAMPLE_OPCODE, CPType.Rand), + TIME(DataGen.TIME_OPCODE, CPType.Rand), + FRAME(DataGen.FRAME_OPCODE, CPType.Rand), + + CTABLE("ctable", CPType.Ctable), + CTABLEEXPAND("ctableexpand", CPType.Ctable), + + //central moment, covariance, quantiles (sort/pick) + CM("cm", CPType.CentralMoment), + COV("cov", CPType.Covariance), + QSORT("qsort", CPType.QSort), + QPICK("qpick", CPType.QPick), + + RIGHT_INDEX(RightIndex.OPCODE, CPType.MatrixIndexing), + LEFT_INDEX(LeftIndex.OPCODE, CPType.MatrixIndexing), + + TSMM("tsmm", CPType.MMTSJ), + PMM("pmm", CPType.PMMJ), + MMCHAIN("mmchain", CPType.MMChain), + + QR("qr", CPType.MultiReturnBuiltin), + LU("lu", CPType.MultiReturnBuiltin), + EIGEN("eigen", CPType.MultiReturnBuiltin), + FFT("fft", CPType.MultiReturnBuiltin), + IFFT("ifft", 
CPType.MultiReturnComplexMatrixBuiltin), + FFT_LINEARIZED("fft_linearized", CPType.MultiReturnBuiltin), + IFFT_LINEARIZED("ifft_linearized", CPType.MultiReturnComplexMatrixBuiltin), + STFT("stft", CPType.MultiReturnComplexMatrixBuiltin), + SVD("svd", CPType.MultiReturnBuiltin), + RCM("rcm", CPType.MultiReturnComplexMatrixBuiltin), + + PARTITION("partition", CPType.Partition), + COMPRESS(Compression.OPCODE, CPType.Compression), + DECOMPRESS(DeCompression.OPCODE, CPType.DeCompression), + SPOOF("spoof", CPType.SpoofFused), + PREFETCH("prefetch", CPType.Prefetch), + EVICT("_evict", CPType.EvictLineageCache), + BROADCAST("broadcast", CPType.Broadcast), + TRIGREMOTE("trigremote", CPType.TrigRemote), + LOCAL(Local.OPCODE, CPType.Local), + + SQL("sql", CPType.Sql); + + + // Constructor + Opcodes(String name, CPType type) { + this._name = name; + this._type = type; + } + + // Fields + private final String _name; + private final CPType _type; + + private static final Map<String, Opcodes> _lookupMap = new HashMap<>(); + + // Initialize lookup map + static { + for (Opcodes op : EnumSet.allOf(Opcodes.class)) { + _lookupMap.put(op.getName(), op); + } + } + + // Getters + public String getName() { + return _name; + } + + public CPType getType() { + return _type; + } + + public static CPType getCPTypeByOpcode(String opcode) { + for (Opcodes op : Opcodes.values()) { + if (op.getName().equalsIgnoreCase(opcode.trim())) { + return op.getType(); + } + } + return null; + } + + + Review Comment: Avoid such blank lines at the end of the enum. ########## src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java: ########## @@ -281,15 +282,15 @@ private static double getNFLOP( String optype, boolean inMR, long d1m, long d1n, //NOTE: all instruction types that are equivalent in CP and MR are only //included in CP to prevent redundancy - CPType cptype = CPInstructionParser.String2CPInstructionType.get(optype); - if( cptype != null ) //for CP Ops and equivalent MR ops + CPType
cptype = Opcodes.valueOf(optype).getType(); Review Comment: Shouldn't this be `getCPTypeByOpcode(optype)`? `Opcodes.valueOf` looks up a constant by its enum name (e.g., `MMULT`), not by its opcode string (e.g., `ba+*`), and it throws an `IllegalArgumentException` for unknown names instead of returning null like the previous map lookup did. ########## src/main/java/org/apache/sysds/common/Opcodes.java: ########## @@ -0,0 +1,328 @@ +package org.apache.sysds.common; + +import org.apache.sysds.lops.*; +import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; +import org.apache.sysds.common.Types.OpOp1; +import org.apache.sysds.hops.FunctionOp; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; + +public enum Opcodes { + MMULT("ba+*", CPType.AggregateBinary), + TAKPM("tak+*", CPType.AggregateTernary), + TACKPM("tack+*", CPType.AggregateTernary), + + UAKP("uak+", CPType.AggregateUnary), + UARKP("uark+", CPType.AggregateUnary), + UACKP( "uack+", CPType.AggregateUnary), + UASQKP( "uasqk+", CPType.AggregateUnary), + UARSQKP( "uarsqk+", CPType.AggregateUnary), + UACSQKP( "uacsqk+", CPType.AggregateUnary), + UAMEAN( "uamean", CPType.AggregateUnary), + UARMEAN("uarmean", CPType.AggregateUnary), Review Comment: There seems to be a formatting issue - please ensure consistent tab indentation (e.g., some entries have a stray space after the opening parenthesis). 
########## src/main/java/org/apache/sysds/common/Opcodes.java: ########## @@ -0,0 +1,328 @@ +package org.apache.sysds.common; + +import org.apache.sysds.lops.*; +import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; +import org.apache.sysds.common.Types.OpOp1; +import org.apache.sysds.hops.FunctionOp; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; + +public enum Opcodes { + MMULT("ba+*", CPType.AggregateBinary), + TAKPM("tak+*", CPType.AggregateTernary), + TACKPM("tack+*", CPType.AggregateTernary), + + UAKP("uak+", CPType.AggregateUnary), + UARKP("uark+", CPType.AggregateUnary), + UACKP( "uack+", CPType.AggregateUnary), + UASQKP( "uasqk+", CPType.AggregateUnary), + UARSQKP( "uarsqk+", CPType.AggregateUnary), + UACSQKP( "uacsqk+", CPType.AggregateUnary), + UAMEAN( "uamean", CPType.AggregateUnary), + UARMEAN("uarmean", CPType.AggregateUnary), + UACMEAN("uacmean", CPType.AggregateUnary), + UAVAR("uavar", CPType.AggregateUnary), + UARVAR("uarvar", CPType.AggregateUnary), + UACVAR("uacvar", CPType.AggregateUnary), + UAMAX("uamax", CPType.AggregateUnary), + UARMAX("uarmax", CPType.AggregateUnary), + UARIMAX("uarimax", CPType.AggregateUnary), + UACMAX("uacmax", CPType.AggregateUnary), + UAMIN("uamin", CPType.AggregateUnary), + UARMIN("uarmin", CPType.AggregateUnary), + UARIMIN("uarimin", CPType.AggregateUnary), + UACMIN("uacmin", CPType.AggregateUnary), + UAP("ua+", CPType.AggregateUnary), + UARP("uar+", CPType.AggregateUnary), + UACP("uac+", CPType.AggregateUnary), + UAM("ua*", CPType.AggregateUnary), + UARM("uar*", CPType.AggregateUnary), + UACM("uac*", CPType.AggregateUnary), + UATRACE("uatrace", CPType.AggregateUnary), + UAKTRACE("uaktrace", CPType.AggregateUnary), + + NROW("nrow", CPType.AggregateUnary), + NCOL("ncol", CPType.AggregateUnary), + LENGTH("length", CPType.AggregateUnary), + EXISTS("exists", CPType.AggregateUnary), + LINEAGE("lineage", CPType.AggregateUnary), + UACD("uacd", 
CPType.AggregateUnary), + UACDR("uacdr", CPType.AggregateUnary), + UACDC("uacdc", CPType.AggregateUnary), + UACDAP("uacdap", CPType.AggregateUnary), + UACDAPR("uacdapr", CPType.AggregateUnary), + UACDAPC("uacdapc", CPType.AggregateUnary), + UNIQUE("unique", CPType.AggregateUnary), + UNIQUER("uniquer", CPType.AggregateUnary), + UNIQUEC("uniquec", CPType.AggregateUnary), + + UAGGOUTERCHAIN("uaggouterchain", CPType.UaggOuterChain), + + // Arithmetic Instruction Opcodes + PLUS("+", CPType.Binary), + MINUS("-", CPType.Binary), + MULT("*", CPType.Binary), + DIV("/", CPType.Binary), + MODULUS("%%", CPType.Binary), + INTDIV("%/%", CPType.Binary), + POW("^", CPType.Binary), + MINUS1_MULT("1-*", CPType.Binary), //special * case + POW2("^2", CPType.Binary), //special ^ case + MULT2("*2", CPType.Binary), //special * case + MINUS_NZ("-nz", CPType.Binary), //special - case + + // Boolean Instruction Opcodes + AND("&&", CPType.Binary), + OR("||", CPType.Binary), + XOR("xor", CPType.Binary), + BITWAND("bitwAnd", CPType.Binary), + BITWOR("bitwOr", CPType.Binary), + BITWXOR("bitwXor", CPType.Binary), + BITWSHIFTL("bitwShiftL", CPType.Binary), + BITWSHIFTR("bitwShiftR", CPType.Binary), + NOT("!", CPType.Unary), + + // Relational Instruction Opcodes + EQUAL("==", CPType.Binary), + NOTEQUAL("!=", CPType.Binary), + LESS("<", CPType.Binary), + GREATER(">", CPType.Binary), + LESSEQUAL("<=", CPType.Binary), + GREATEREQUAL(">=", CPType.Binary), + + // Builtin Instruction Opcodes + LOG("log", CPType.Builtin), + LOGNZ("log_nz", CPType.Builtin), + + SOLVE("solve", CPType.Binary), + MAX("max", CPType.Binary), + MIN("min", CPType.Binary), + DROPINVALIDTYPE("dropInvalidType", CPType.Binary), + DROPINVALIDLENGTH("dropInvalidLength", CPType.Binary), + FREPLICATE("freplicate", CPType.Binary), + VALUESWAP("valueSwap", CPType.Binary), + APPLYSCHEMA("applySchema", CPType.Binary), + MAP("_map", CPType.Ternary), + + NMAX("nmax", CPType.BuiltinNary), + NMIN("nmin", CPType.BuiltinNary), + NP("n+", 
CPType.BuiltinNary), + NM("n*", CPType.BuiltinNary), + + EXP("exp", CPType.Unary), + ABS("abs", CPType.Unary), + SIN("sin", CPType.Unary), + COS("cos", CPType.Unary), + TAN("tan", CPType.Unary), + SINH("sinh", CPType.Unary), + COSH("cosh", CPType.Unary), + TANH("tanh", CPType.Unary), + ASIN("asin", CPType.Unary), + ACOS("acos", CPType.Unary), + ATAN("atan", CPType.Unary), + SIGN("sign", CPType.Unary), + SQRT("sqrt", CPType.Unary), + PLOGP("plogp", CPType.Unary), + PRINT("print", CPType.Unary), + ASSERT("assert", CPType.Unary), + ROUND("round", CPType.Unary), + CEIL("ceil", CPType.Unary), + FLOOR("floor", CPType.Unary), + UCUMKP("ucumk+", CPType.Unary), + UCUMM("ucum*", CPType.Unary), + UCUMKPM("ucumk+*", CPType.Unary), + UCUMMIN("ucummin", CPType.Unary), + UCUMMAX("ucummax", CPType.Unary), + STOP("stop", CPType.Unary), + INVERSE("inverse", CPType.Unary), + CHOLESKY("cholesky", CPType.Unary), + SPROP("sprop", CPType.Unary), + SIGMOID("sigmoid", CPType.Unary), + TYPEOF("typeOf", CPType.Unary), + DETECTSCHEMA("detectSchema", CPType.Unary), + COLNAMES("colnames", CPType.Unary), + ISNA("isna", CPType.Unary), + ISNAN("isnan", CPType.Unary), + ISINF("isinf", CPType.Unary), + PRINTF("printf", CPType.BuiltinNary), + CBIND("cbind", CPType.BuiltinNary), + RBIND("rbind", CPType.BuiltinNary), + EVAL("eval", CPType.BuiltinNary), + LIST("list", CPType.BuiltinNary), + + //Parametrized builtin functions + AUTODIFF("autoDiff", CPType.ParameterizedBuiltin), + CONTAINS("contains", CPType.ParameterizedBuiltin), + PARAMSERV("paramserv", CPType.ParameterizedBuiltin), + NVLIST("nvlist", CPType.ParameterizedBuiltin), + CDF("cdf", CPType.ParameterizedBuiltin), + INVCDF("invcdf", CPType.ParameterizedBuiltin), + GROUPEDAGG("groupedagg", CPType.ParameterizedBuiltin), + RMEMPTY("rmempty", CPType.ParameterizedBuiltin), + REPLACE("replace", CPType.ParameterizedBuiltin), + LOWERTRI("lowertri", CPType.ParameterizedBuiltin), + UPPERTRI("uppertri", CPType.ParameterizedBuiltin), + REXPAND("rexpand", 
CPType.ParameterizedBuiltin), + TOSTRING("toString", CPType.ParameterizedBuiltin), + TOKENIZE("tokenize", CPType.ParameterizedBuiltin), + TRANSFORMAPPLY("transformapply", CPType.ParameterizedBuiltin), + TRANSFORMDECODE("transformdecode", CPType.ParameterizedBuiltin), + TRANSFORMCOLMAP("transformcolmap", CPType.ParameterizedBuiltin), + TRANSFORMMETA("transformmeta", CPType.ParameterizedBuiltin), + TRANSFORMENCODE("transformencode", CPType.MultiReturnParameterizedBuiltin), + + //Ternary instruction opcodes + PM("+*", CPType.Ternary), + MINUSMULT("-*", CPType.Ternary), + IFELSE("ifelse", CPType.Ternary), + + //Variable instruction opcodes + ASSIGNVAR("assignvar", CPType.Variable), + CPVAR("cpvar", CPType.Variable), + MVVAR("mvvar", CPType.Variable), + RMVAR("rmvar", CPType.Variable), + RMFILEVAR("rmfilevar", CPType.Variable), + CAST_AS_SCALAR(OpOp1.CAST_AS_SCALAR.toString(), CPType.Variable), + CAST_AS_MATRIX(OpOp1.CAST_AS_MATRIX.toString(), CPType.Variable), + CAST_AS_FRAME_VAR("cast_as_frame", CPType.Variable), + CAST_AS_FRAME(OpOp1.CAST_AS_FRAME.toString(), CPType.Variable), + CAST_AS_LIST(OpOp1.CAST_AS_LIST.toString(), CPType.Variable), + CAST_AS_DOUBLE(OpOp1.CAST_AS_DOUBLE.toString(), CPType.Variable), + CAST_AS_INT(OpOp1.CAST_AS_INT.toString(), CPType.Variable), + CAST_AS_BOOLEAN(OpOp1.CAST_AS_BOOLEAN.toString(), CPType.Variable), + ATTACHFILETOVAR("attachfiletovar", CPType.Variable), + READ("read", CPType.Variable), + WRITE("write", CPType.Variable), + CREATEVAR("createvar", CPType.Variable), + + //Reorg instruction opcodes + TRANSPOSE("r'", CPType.Reorg), + REV("rev", CPType.Reorg), + ROLL("roll", CPType.Reorg), + DIAG("rdiag", CPType.Reorg), + RESHAPE("rshape", CPType.Reshape), + SORT("rsort", CPType.Reorg), + + // Opcodes related to convolutions + RELU_BACKWARD("relu_backward", CPType.Dnn), + RELU_MAXPOOLING("relu_maxpooling", CPType.Dnn), + RELU_MAXPOOLING_BACKWARD("relu_maxpooling_backward", CPType.Dnn), + MAXPOOLING("maxpooling", CPType.Dnn), + 
MAXPOOLING_BACKWARD("maxpooling_backward", CPType.Dnn), + AVGPOOLING("avgpooling", CPType.Dnn), + AVGPOOLING_BACKWARD("avgpooling_backward", CPType.Dnn), + CONV2D("conv2d", CPType.Dnn), + CONV2D_BIAS_ADD("conv2d_bias_add", CPType.Dnn), + CONV2D_BACKWARD_FILTER("conv2d_backward_filter", CPType.Dnn), + CONV2D_BACKWARD_DATA("conv2d_backward_data", CPType.Dnn), + BIAS_ADD("bias_add", CPType.Dnn), + BIAS_MULTIPLY("bias_multiply", CPType.Dnn), + BATCH_NORM2D("batch_norm2d", CPType.Dnn), + BATCH_NORM2D_BACKWARD("batch_norm2d_backward", CPType.Dnn), + LSTM("lstm", CPType.Dnn), + LSTM_BACKWARD("lstm_backward", CPType.Dnn), + + //Quaternary instruction opcodes + WSLOSS("wsloss", CPType.Quaternary), + WSIGMOID("wsigmoid", CPType.Quaternary), + WDIVMM("wdivmm", CPType.Quaternary), + WCEMM("wcemm", CPType.Quaternary), + WUMM("wumm", CPType.Quaternary), + + //User-defined function Opcodes + FCALL(FunctionOp.OPCODE, CPType.FCall), + + APPEND(Append.OPCODE, CPType.Append), + REMOVE("remove", CPType.Append), + + //data generation opcodes + RANDOM(DataGen.RAND_OPCODE, CPType.Rand), + SEQUENCE(DataGen.SEQ_OPCODE, CPType.Rand), + STRINGINIT(DataGen.SINIT_OPCODE, CPType.StringInit), + SAMPLE(DataGen.SAMPLE_OPCODE, CPType.Rand), + TIME(DataGen.TIME_OPCODE, CPType.Rand), + FRAME(DataGen.FRAME_OPCODE, CPType.Rand), + + CTABLE("ctable", CPType.Ctable), + CTABLEEXPAND("ctableexpand", CPType.Ctable), + + //central moment, covariance, quantiles (sort/pick) + CM("cm", CPType.CentralMoment), + COV("cov", CPType.Covariance), + QSORT("qsort", CPType.QSort), + QPICK("qpick", CPType.QPick), + + RIGHT_INDEX(RightIndex.OPCODE, CPType.MatrixIndexing), + LEFT_INDEX(LeftIndex.OPCODE, CPType.MatrixIndexing), + + TSMM("tsmm", CPType.MMTSJ), + PMM("pmm", CPType.PMMJ), + MMCHAIN("mmchain", CPType.MMChain), + + QR("qr", CPType.MultiReturnBuiltin), + LU("lu", CPType.MultiReturnBuiltin), + EIGEN("eigen", CPType.MultiReturnBuiltin), + FFT("fft", CPType.MultiReturnBuiltin), + IFFT("ifft", 
CPType.MultiReturnComplexMatrixBuiltin), + FFT_LINEARIZED("fft_linearized", CPType.MultiReturnBuiltin), + IFFT_LINEARIZED("ifft_linearized", CPType.MultiReturnComplexMatrixBuiltin), + STFT("stft", CPType.MultiReturnComplexMatrixBuiltin), + SVD("svd", CPType.MultiReturnBuiltin), + RCM("rcm", CPType.MultiReturnComplexMatrixBuiltin), + + PARTITION("partition", CPType.Partition), + COMPRESS(Compression.OPCODE, CPType.Compression), + DECOMPRESS(DeCompression.OPCODE, CPType.DeCompression), + SPOOF("spoof", CPType.SpoofFused), + PREFETCH("prefetch", CPType.Prefetch), + EVICT("_evict", CPType.EvictLineageCache), + BROADCAST("broadcast", CPType.Broadcast), + TRIGREMOTE("trigremote", CPType.TrigRemote), + LOCAL(Local.OPCODE, CPType.Local), + + SQL("sql", CPType.Sql); + + + // Constructor + Opcodes(String name, CPType type) { + this._name = name; + this._type = type; + } + + // Fields + private final String _name; + private final CPType _type; + + private static final Map<String, Opcodes> _lookupMap = new HashMap<>(); + + // Initialize lookup map + static { + for (Opcodes op : EnumSet.allOf(Opcodes.class)) { + _lookupMap.put(op.getName(), op); + } + } + + // Getters + public String getName() { + return _name; + } Review Comment: Instead of this method, please override `toString`, which then allows using `Opcodes.XXX` instead of `Opcodes.XXX.getName()`. ########## src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java: ########## @@ -270,26 +271,26 @@ public String toString() { case VECT_CBIND: return "b(cbind)"; case VECT_BIASADD: return "b(vbias+)"; case VECT_BIASMULT: return "b(vbias*)"; - case MULT: return "b(*)"; - case DIV: return "b(/)"; - case PLUS: return "b(+)"; - case MINUS: return "b(-)"; - case POW: return "b(^)"; - case MODULUS: return "b(%%)"; - case INTDIV: return "b(%/%)"; - case LESS: return "b(<)"; - case LESSEQUAL: return "b(<=)"; - case GREATER: return "b(>)"; - case GREATEREQUAL: return "b(>=)"; - case EQUAL: return "b(==)"; - case NOTEQUAL: 
return "b(!=)"; + case MULT: return "b(" + Opcodes.MULT.getName() + ")"; + case DIV: return "b(" + Opcodes.DIV.getName() + ")"; + case PLUS: return "b(" + Opcodes.PLUS.getName() + ")"; + case MINUS: return "b(" + Opcodes.MINUS.getName() + ")"; + case POW: return "b(" + Opcodes.POW.getName() + ")"; + case MODULUS: return "b(" + Opcodes.MODULUS.getName() + ")"; + case INTDIV: return "b(" + Opcodes.INTDIV.getName() + ")"; + case LESS: return "b(" + Opcodes.LESS.getName() + ")"; + case LESSEQUAL: return "b(" + Opcodes.LESSEQUAL.getName() + ")"; + case GREATER: return "b(" + Opcodes.GREATER.getName() + ")"; + case GREATEREQUAL: return "b(" + Opcodes.GREATEREQUAL.getName() + ")"; + case EQUAL: return "b(" + Opcodes.EQUAL.getName() + ")"; + case NOTEQUAL: return "b(" + Opcodes.NOTEQUAL.getName() + ")"; case OR: return "b(|)"; case AND: return "b(&)"; - case XOR: return "b(xor)"; - case BITWAND: return "b(bitwAnd)"; + case XOR: return "b(" + Opcodes.XOR.getName() + ")"; Review Comment: XOR can use the default branch ########## src/main/java/org/apache/sysds/common/Opcodes.java: ########## @@ -0,0 +1,328 @@ +package org.apache.sysds.common; Review Comment: please add the license header ########## src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java: ########## @@ -79,288 +69,11 @@ public class CPInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(CPInstructionParser.class.getName()); - public static final HashMap<String, CPType> String2CPInstructionType; - static { - String2CPInstructionType = new HashMap<>(); - String2CPInstructionType.put( "ba+*" , CPType.AggregateBinary); - String2CPInstructionType.put( "tak+*" , CPType.AggregateTernary); - String2CPInstructionType.put( "tack+*" , CPType.AggregateTernary); - - String2CPInstructionType.put( "uak+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uark+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uack+" , CPType.AggregateUnary); - 
String2CPInstructionType.put( "uasqk+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarsqk+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacsqk+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uamean" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarmean" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacmean" , CPType.AggregateUnary); - String2CPInstructionType.put( "uavar" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarvar" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacvar" , CPType.AggregateUnary); - String2CPInstructionType.put( "uamax" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarmax" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarimax" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacmax" , CPType.AggregateUnary); - String2CPInstructionType.put( "uamin" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarmin" , CPType.AggregateUnary); - String2CPInstructionType.put( "uarimin" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacmin" , CPType.AggregateUnary); - String2CPInstructionType.put( "ua+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uar+" , CPType.AggregateUnary); - String2CPInstructionType.put( "uac+" , CPType.AggregateUnary); - String2CPInstructionType.put( "ua*" , CPType.AggregateUnary); - String2CPInstructionType.put( "uar*" , CPType.AggregateUnary); - String2CPInstructionType.put( "uac*" , CPType.AggregateUnary); - String2CPInstructionType.put( "uatrace" , CPType.AggregateUnary); - String2CPInstructionType.put( "uaktrace", CPType.AggregateUnary); - String2CPInstructionType.put( "nrow" , CPType.AggregateUnary); - String2CPInstructionType.put( "ncol" , CPType.AggregateUnary); - String2CPInstructionType.put( "length" , CPType.AggregateUnary); - String2CPInstructionType.put( "exists" , CPType.AggregateUnary); - String2CPInstructionType.put( "lineage" , CPType.AggregateUnary); - 
String2CPInstructionType.put( "uacd" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacdr" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacdc" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacdap" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacdapr" , CPType.AggregateUnary); - String2CPInstructionType.put( "uacdapc" , CPType.AggregateUnary); - String2CPInstructionType.put( "unique" , CPType.AggregateUnary); - String2CPInstructionType.put( "uniquer" , CPType.AggregateUnary); - String2CPInstructionType.put( "uniquec" , CPType.AggregateUnary); - - String2CPInstructionType.put( "uaggouterchain", CPType.UaggOuterChain); - - // Arithmetic Instruction Opcodes - String2CPInstructionType.put( "+" , CPType.Binary); - String2CPInstructionType.put( "-" , CPType.Binary); - String2CPInstructionType.put( "*" , CPType.Binary); - String2CPInstructionType.put( "/" , CPType.Binary); - String2CPInstructionType.put( "%%" , CPType.Binary); - String2CPInstructionType.put( "%/%" , CPType.Binary); - String2CPInstructionType.put( "^" , CPType.Binary); - String2CPInstructionType.put( "1-*" , CPType.Binary); //special * case - String2CPInstructionType.put( "^2" , CPType.Binary); //special ^ case - String2CPInstructionType.put( "*2" , CPType.Binary); //special * case - String2CPInstructionType.put( "-nz" , CPType.Binary); //special - case - - // Boolean Instruction Opcodes - String2CPInstructionType.put( "&&" , CPType.Binary); - String2CPInstructionType.put( "||" , CPType.Binary); - String2CPInstructionType.put( "xor" , CPType.Binary); - String2CPInstructionType.put( "bitwAnd", CPType.Binary); - String2CPInstructionType.put( "bitwOr", CPType.Binary); - String2CPInstructionType.put( "bitwXor", CPType.Binary); - String2CPInstructionType.put( "bitwShiftL", CPType.Binary); - String2CPInstructionType.put( "bitwShiftR", CPType.Binary); - String2CPInstructionType.put( "!" 
, CPType.Unary); - - // Relational Instruction Opcodes - String2CPInstructionType.put( "==" , CPType.Binary); - String2CPInstructionType.put( "!=" , CPType.Binary); - String2CPInstructionType.put( "<" , CPType.Binary); - String2CPInstructionType.put( ">" , CPType.Binary); - String2CPInstructionType.put( "<=" , CPType.Binary); - String2CPInstructionType.put( ">=" , CPType.Binary); - - // Builtin Instruction Opcodes - String2CPInstructionType.put( "log" , CPType.Builtin); - String2CPInstructionType.put( "log_nz" , CPType.Builtin); - - String2CPInstructionType.put( "solve" , CPType.Binary); - String2CPInstructionType.put( "max" , CPType.Binary); - String2CPInstructionType.put( "min" , CPType.Binary); - String2CPInstructionType.put( "dropInvalidType" , CPType.Binary); - String2CPInstructionType.put( "dropInvalidLength" , CPType.Binary); - String2CPInstructionType.put( "freplicate" , CPType.Binary); - String2CPInstructionType.put( "valueSwap" , CPType.Binary); - String2CPInstructionType.put( "applySchema" , CPType.Binary); - String2CPInstructionType.put( "_map" , CPType.Ternary); // _map represents the operation map - - String2CPInstructionType.put( "nmax", CPType.BuiltinNary); - String2CPInstructionType.put( "nmin", CPType.BuiltinNary); - String2CPInstructionType.put( "n+" , CPType.BuiltinNary); - String2CPInstructionType.put( "n*" , CPType.BuiltinNary); - - String2CPInstructionType.put( "exp" , CPType.Unary); - String2CPInstructionType.put( "abs" , CPType.Unary); - String2CPInstructionType.put( "sin" , CPType.Unary); - String2CPInstructionType.put( "cos" , CPType.Unary); - String2CPInstructionType.put( "tan" , CPType.Unary); - String2CPInstructionType.put( "sinh" , CPType.Unary); - String2CPInstructionType.put( "cosh" , CPType.Unary); - String2CPInstructionType.put( "tanh" , CPType.Unary); - String2CPInstructionType.put( "asin" , CPType.Unary); - String2CPInstructionType.put( "acos" , CPType.Unary); - String2CPInstructionType.put( "atan" , CPType.Unary); - 
String2CPInstructionType.put( "sign" , CPType.Unary); - String2CPInstructionType.put( "sqrt" , CPType.Unary); - String2CPInstructionType.put( "plogp" , CPType.Unary); - String2CPInstructionType.put( "print" , CPType.Unary); - String2CPInstructionType.put( "assert" , CPType.Unary); - String2CPInstructionType.put( "round" , CPType.Unary); - String2CPInstructionType.put( "ceil" , CPType.Unary); - String2CPInstructionType.put( "floor" , CPType.Unary); - String2CPInstructionType.put( "ucumk+", CPType.Unary); - String2CPInstructionType.put( "ucum*" , CPType.Unary); - String2CPInstructionType.put( "ucumk+*" , CPType.Unary); - String2CPInstructionType.put( "ucummin", CPType.Unary); - String2CPInstructionType.put( "ucummax", CPType.Unary); - String2CPInstructionType.put( "stop" , CPType.Unary); - String2CPInstructionType.put( "inverse", CPType.Unary); - String2CPInstructionType.put( "cholesky",CPType.Unary); - String2CPInstructionType.put( "sprop", CPType.Unary); - String2CPInstructionType.put( "sigmoid", CPType.Unary); - String2CPInstructionType.put( "typeOf", CPType.Unary); - String2CPInstructionType.put( "detectSchema", CPType.Unary); - String2CPInstructionType.put( "colnames", CPType.Unary); - String2CPInstructionType.put( "isna", CPType.Unary); - String2CPInstructionType.put( "isnan", CPType.Unary); - String2CPInstructionType.put( "isinf", CPType.Unary); - String2CPInstructionType.put( "printf", CPType.BuiltinNary); - String2CPInstructionType.put( "cbind", CPType.BuiltinNary); - String2CPInstructionType.put( "rbind", CPType.BuiltinNary); - String2CPInstructionType.put( "eval", CPType.BuiltinNary); - String2CPInstructionType.put( "list", CPType.BuiltinNary); - - // Parameterized Builtin Functions - String2CPInstructionType.put( "autoDiff" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "contains", CPType.ParameterizedBuiltin); - String2CPInstructionType.put("paramserv", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "nvlist", 
CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "cdf", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "invcdf", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "groupedagg", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "rmempty" , CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "replace", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "lowertri", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "uppertri", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "rexpand", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "toString", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "tokenize", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformapply", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformdecode",CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformcolmap",CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformmeta", CPType.ParameterizedBuiltin); - String2CPInstructionType.put( "transformencode",CPType.MultiReturnParameterizedBuiltin); - - // Ternary Instruction Opcodes - String2CPInstructionType.put( "+*", CPType.Ternary); - String2CPInstructionType.put( "-*", CPType.Ternary); - String2CPInstructionType.put( "ifelse", CPType.Ternary); - - // Variable Instruction Opcodes - String2CPInstructionType.put( "assignvar" , CPType.Variable); - String2CPInstructionType.put( "cpvar" , CPType.Variable); - String2CPInstructionType.put( "mvvar" , CPType.Variable); - String2CPInstructionType.put( "rmvar" , CPType.Variable); - String2CPInstructionType.put( "rmfilevar" , CPType.Variable); - String2CPInstructionType.put( OpOp1.CAST_AS_SCALAR.toString(), CPType.Variable); - String2CPInstructionType.put( OpOp1.CAST_AS_MATRIX.toString(), CPType.Variable); - String2CPInstructionType.put( "cast_as_frame", CPType.Variable); - String2CPInstructionType.put( 
OpOp1.CAST_AS_FRAME.toString(), CPType.Variable); - String2CPInstructionType.put( OpOp1.CAST_AS_LIST.toString(), CPType.Variable); - String2CPInstructionType.put( OpOp1.CAST_AS_DOUBLE.toString(), CPType.Variable); - String2CPInstructionType.put( OpOp1.CAST_AS_INT.toString(), CPType.Variable); - String2CPInstructionType.put( OpOp1.CAST_AS_BOOLEAN.toString(), CPType.Variable); - String2CPInstructionType.put( "attachfiletovar" , CPType.Variable); - String2CPInstructionType.put( "read" , CPType.Variable); - String2CPInstructionType.put( "write" , CPType.Variable); - String2CPInstructionType.put( "createvar" , CPType.Variable); - - // Reorg Instruction Opcodes (repositioning of existing values) - String2CPInstructionType.put( "r'" , CPType.Reorg); - String2CPInstructionType.put( "rev" , CPType.Reorg); - String2CPInstructionType.put( "roll" , CPType.Reorg); - String2CPInstructionType.put( "rdiag" , CPType.Reorg); - String2CPInstructionType.put( "rshape" , CPType.Reshape); - String2CPInstructionType.put( "rsort" , CPType.Reorg); - - // Opcodes related to convolutions - String2CPInstructionType.put( "relu_backward" , CPType.Dnn); - String2CPInstructionType.put( "relu_maxpooling" , CPType.Dnn); - String2CPInstructionType.put( "relu_maxpooling_backward" , CPType.Dnn); - String2CPInstructionType.put( "maxpooling" , CPType.Dnn); - String2CPInstructionType.put( "maxpooling_backward" , CPType.Dnn); - String2CPInstructionType.put( "avgpooling" , CPType.Dnn); - String2CPInstructionType.put( "avgpooling_backward" , CPType.Dnn); - String2CPInstructionType.put( "conv2d" , CPType.Dnn); - String2CPInstructionType.put( "conv2d_bias_add" , CPType.Dnn); - String2CPInstructionType.put( "conv2d_backward_filter" , CPType.Dnn); - String2CPInstructionType.put( "conv2d_backward_data" , CPType.Dnn); - String2CPInstructionType.put( "bias_add" , CPType.Dnn); - String2CPInstructionType.put( "bias_multiply" , CPType.Dnn); - String2CPInstructionType.put( "batch_norm2d", CPType.Dnn); - 
String2CPInstructionType.put( "batch_norm2d_backward", CPType.Dnn); - String2CPInstructionType.put( "lstm" , CPType.Dnn); - String2CPInstructionType.put( "lstm_backward" , CPType.Dnn); - - // Quaternary instruction opcodes - String2CPInstructionType.put( "wsloss" , CPType.Quaternary); - String2CPInstructionType.put( "wsigmoid", CPType.Quaternary); - String2CPInstructionType.put( "wdivmm", CPType.Quaternary); - String2CPInstructionType.put( "wcemm", CPType.Quaternary); - String2CPInstructionType.put( "wumm", CPType.Quaternary); - - // User-defined function Opcodes - String2CPInstructionType.put(FunctionOp.OPCODE, CPType.FCall); - - String2CPInstructionType.put(Append.OPCODE, CPType.Append); - String2CPInstructionType.put( "remove", CPType.Append); - - // data generation opcodes - String2CPInstructionType.put( DataGen.RAND_OPCODE , CPType.Rand); - String2CPInstructionType.put( DataGen.SEQ_OPCODE , CPType.Rand); - String2CPInstructionType.put( DataGen.SINIT_OPCODE , CPType.StringInit); - String2CPInstructionType.put( DataGen.SAMPLE_OPCODE , CPType.Rand); - String2CPInstructionType.put( DataGen.TIME_OPCODE , CPType.Rand); - String2CPInstructionType.put( DataGen.FRAME_OPCODE , CPType.Rand); - - String2CPInstructionType.put( "ctable", CPType.Ctable); - String2CPInstructionType.put( "ctableexpand", CPType.Ctable); - - //central moment, covariance, quantiles (sort/pick) - String2CPInstructionType.put( "cm", CPType.CentralMoment); - String2CPInstructionType.put( "cov", CPType.Covariance); - String2CPInstructionType.put( "qsort", CPType.QSort); - String2CPInstructionType.put( "qpick", CPType.QPick); - - - String2CPInstructionType.put( RightIndex.OPCODE, CPType.MatrixIndexing); - String2CPInstructionType.put( LeftIndex.OPCODE, CPType.MatrixIndexing); - - String2CPInstructionType.put( "tsmm", CPType.MMTSJ); - String2CPInstructionType.put( "pmm", CPType.PMMJ); - String2CPInstructionType.put( "mmchain", CPType.MMChain); - - String2CPInstructionType.put( "qr", 
CPType.MultiReturnBuiltin); - String2CPInstructionType.put( "lu", CPType.MultiReturnBuiltin); - String2CPInstructionType.put( "eigen", CPType.MultiReturnBuiltin); - String2CPInstructionType.put( "fft", CPType.MultiReturnBuiltin); - String2CPInstructionType.put( "ifft", CPType.MultiReturnComplexMatrixBuiltin); - String2CPInstructionType.put( "fft_linearized", CPType.MultiReturnBuiltin); - String2CPInstructionType.put( "ifft_linearized", CPType.MultiReturnComplexMatrixBuiltin); - String2CPInstructionType.put( "stft", CPType.MultiReturnComplexMatrixBuiltin); - String2CPInstructionType.put( "svd", CPType.MultiReturnBuiltin); - String2CPInstructionType.put( "rcm", CPType.MultiReturnComplexMatrixBuiltin); - - String2CPInstructionType.put( "partition", CPType.Partition); - String2CPInstructionType.put( Compression.OPCODE, CPType.Compression); - String2CPInstructionType.put( DeCompression.OPCODE, CPType.DeCompression); - String2CPInstructionType.put( "spoof", CPType.SpoofFused); - String2CPInstructionType.put( "prefetch", CPType.Prefetch); - String2CPInstructionType.put( "_evict", CPType.EvictLineageCache); - String2CPInstructionType.put( "broadcast", CPType.Broadcast); - String2CPInstructionType.put( "trigremote", CPType.TrigRemote); - String2CPInstructionType.put( Local.OPCODE, CPType.Local); - - String2CPInstructionType.put( "sql", CPType.Sql); - } - public static CPInstruction parseSingleInstruction (String str ) { if ( str == null || str.isEmpty() ) return null; - CPType cptype = InstructionUtils.getCPType(str); + CPType cptype = InstructionUtils.getCPType(str); + //CPType cptype = Opcodes.getCPTypeByOpcode(str); Review Comment: avoid such commented code -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org