http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java index 7fd841a..40957cb 100644 --- a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java +++ b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java @@ -18,127 +18,289 @@ */ package hivemall.smile.tools; -import hivemall.utils.collections.IntArrayList; -import hivemall.utils.lang.Counter; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.SizeOf; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; -import java.util.Map; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDAF; -import org.apache.hadoop.hive.ql.exec.UDAFEvaluator; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.IntWritable; + +@Description( + name = "rf_ensemble", + value = "_FUNC_(int yhat, array<double> proba [, double model_weight=1.0])" + + " - Returns emsebled prediction results in <int label, double probability, array<double> probabilities>") +public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver { + + public RandomForestEnsembleUDAF() { + super(); + } + + @Override + public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 2 && typeInfo.length != 3) { + throw new UDFArgumentLengthException("Expected 2 or 3 arguments but got " + + typeInfo.length); + } + if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]); + } + if (!HiveUtils.isFloatingPointListTypeInfo(typeInfo[1])) { + throw new UDFArgumentTypeException(1, "ARRAY<double> is expected for posteriori: " + + typeInfo[1]); + } + if (typeInfo.length == 3) { + if (!HiveUtils.isFloatingPointTypeInfo(typeInfo[2])) { + throw new UDFArgumentTypeException(2, "Expected DOUBLE or FLOAT for model_weight: " + + typeInfo[2]); + } + } + return new 
RfEvaluator(); + } -@SuppressWarnings("deprecation") -@Description(name = "rf_ensemble", - value = "_FUNC_(int y) - Returns emsebled prediction results of Random Forest classifiers") -public final class RandomForestEnsembleUDAF extends UDAF { - public static class RandomForestPredictUDAFEvaluator implements UDAFEvaluator { + @SuppressWarnings("deprecation") + public static final class RfEvaluator extends GenericUDAFEvaluator { - private Counter<Integer> partial; + private PrimitiveObjectInspector yhatOI; + private ListObjectInspector posterioriOI; + private PrimitiveObjectInspector posterioriElemOI; + @Nullable + private PrimitiveObjectInspector weightOI; - @Override - public void init() { - this.partial = null; + private StructObjectInspector internalMergeOI; + private StructField sizeField, posterioriField; + private IntObjectInspector sizeFieldOI; + private StandardListObjectInspector posterioriFieldOI; + + public RfEvaluator() { + super(); } - public boolean iterate(Integer k) { - if (k == null) { - return true; + @Override + public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] parameters) + throws HiveException { + super.init(mode, parameters); + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.yhatOI = HiveUtils.asIntegerOI(parameters[0]); + this.posterioriOI = HiveUtils.asListOI(parameters[1]); + this.posterioriElemOI = HiveUtils.asDoubleCompatibleOI(posterioriOI.getListElementObjectInspector()); + if (parameters.length == 3) { + this.weightOI = HiveUtils.asDoubleCompatibleOI(parameters[2]); + } + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.sizeField = soi.getStructFieldRef("size"); + this.posterioriField = soi.getStructFieldRef("posteriori"); + this.sizeFieldOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + this.posterioriFieldOI = 
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); } - if (partial == null) { - this.partial = new Counter<Integer>(); + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + List<String> fieldNames = new ArrayList<>(3); + List<ObjectInspector> fieldOIs = new ArrayList<>(3); + fieldNames.add("size"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("posteriori"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, + fieldOIs); + } else {// terminate + List<String> fieldNames = new ArrayList<>(3); + List<ObjectInspector> fieldOIs = new ArrayList<>(3); + fieldNames.add("label"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("probability"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("probabilities"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, + fieldOIs); } - partial.increment(k); - return true; + return outputOI; } - /* - * https://cwiki.apache.org/confluence/display/Hive/GenericUDAFCaseStudy#GenericUDAFCaseStudy-terminatePartial - */ - public Map<Integer, Integer> terminatePartial() { - if (partial == null) { - return null; + @Override + public RfAggregationBuffer getNewAggregationBuffer() throws HiveException { + RfAggregationBuffer buf = new RfAggregationBuffer(); + reset(buf); + return buf; + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + RfAggregationBuffer buf = (RfAggregationBuffer) agg; + buf.reset(); + } + + 
@Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + RfAggregationBuffer buf = (RfAggregationBuffer) agg; + + Preconditions.checkNotNull(parameters[0]); + int yhat = PrimitiveObjectInspectorUtils.getInt(parameters[0], yhatOI); + Preconditions.checkNotNull(parameters[1]); + double[] posteriori = HiveUtils.asDoubleArray(parameters[1], posterioriOI, + posterioriElemOI); + + double weight = 1.0d; + if (parameters.length == 3) { + Preconditions.checkNotNull(parameters[2]); + weight = PrimitiveObjectInspectorUtils.getDouble(parameters[2], weightOI); } - if (partial.size() == 0) { + buf.iterate(yhat, weight, posteriori); + } + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + RfAggregationBuffer buf = (RfAggregationBuffer) agg; + if (buf._k == -1) { return null; - } else { - return partial.getMap(); // CAN NOT return Counter here } + + Object[] partial = new Object[2]; + partial[0] = new IntWritable(buf._k); + partial[1] = WritableUtils.toWritableList(buf._posteriori); + return partial; } - public boolean merge(Map<Integer, Integer> o) { - if (o == null) { - return true; + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + if (partial == null) { + return; } + RfAggregationBuffer buf = (RfAggregationBuffer) agg; - if (partial == null) { - this.partial = new Counter<Integer>(); + Object o1 = internalMergeOI.getStructFieldData(partial, sizeField); + int size = sizeFieldOI.get(o1); + Object posteriori = internalMergeOI.getStructFieldData(partial, posterioriField); + + // -------------------------------------------------------------- + // [workaround] + // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray + // cannot be cast to [Ljava.lang.Object; + if (posteriori instanceof LazyBinaryArray) { + posteriori = ((LazyBinaryArray) posteriori).getList(); } - partial.addAll(o); - return true; + + buf.merge(size, 
posteriori, posterioriFieldOI); } - public Result terminate() { - if (partial == null) { - return null; - } - if (partial.size() == 0) { + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + RfAggregationBuffer buf = (RfAggregationBuffer) agg; + if (buf._k == -1) { return null; } - return new Result(partial); + double[] posteriori = buf._posteriori; + int label = smile.math.Math.whichMax(posteriori); + smile.math.Math.unitize1(posteriori); + double proba = posteriori[label]; + + Object[] result = new Object[3]; + result[0] = new IntWritable(label); + result[1] = new DoubleWritable(proba); + result[2] = WritableUtils.toWritableList(posteriori); + return result; } + } - public static final class Result { - @SuppressWarnings("unused") - private Integer label; - @SuppressWarnings("unused") - private Double probability; - @SuppressWarnings("unused") - private List<Double> probabilities; - - Result(Counter<Integer> partial) { - final Map<Integer, Integer> counts = partial.getMap(); - int size = counts.size(); - assert (size > 0) : size; - IntArrayList keyList = new IntArrayList(size); - - long totalCnt = 0L; - Integer maxKey = null; - int maxCnt = Integer.MIN_VALUE; - for (Map.Entry<Integer, Integer> e : counts.entrySet()) { - Integer key = e.getKey(); - keyList.add(key); - int cnt = e.getValue().intValue(); - totalCnt += cnt; - if (cnt >= maxCnt) { - maxCnt = cnt; - maxKey = key; - } + public static final class RfAggregationBuffer extends AbstractAggregationBuffer { + + @Nullable + private double[] _posteriori; + private int _k; + + public RfAggregationBuffer() { + super(); + reset(); + } + + void reset() { + this._posteriori = null; + this._k = -1; + } + + void iterate(final int yhat, final double weight, @Nonnull final double[] posteriori) + throws HiveException { + if (_posteriori == null) { + this._k = posteriori.length; + this._posteriori = new double[_k]; } + if (yhat >= _k) { + throw new HiveException("Predicted class " + yhat + " 
is out of bounds: " + _k); + } + if (posteriori.length != _k) { + throw new HiveException("Given |posteriori| " + posteriori.length + + " is differs from expected one: " + _k); + } + + _posteriori[yhat] += (posteriori[yhat] * weight); + } - int[] keyArray = keyList.toArray(); - Arrays.sort(keyArray); - int last = keyArray[keyArray.length - 1]; + void merge(int size, @Nonnull Object posterioriObj, + @Nonnull StandardListObjectInspector posterioriOI) throws HiveException { - double totalCnt_d = (double) totalCnt; - final Double[] probabilities = new Double[Math.max(2, last + 1)]; - for (int i = 0, len = probabilities.length; i < len; i++) { - final Integer cnt = counts.get(Integer.valueOf(i)); - if (cnt == null) { - probabilities[i] = Double.valueOf(0d); + if (size != _k) { + if (_k == -1) { + this._k = size; + this._posteriori = new double[size]; } else { - probabilities[i] = Double.valueOf(cnt.intValue() / totalCnt_d); + throw new HiveException("Mismatch in the number of elements: _k=" + _k + + ", size=" + size); } } - this.label = maxKey; - this.probability = Double.valueOf(maxCnt / totalCnt_d); - this.probabilities = Arrays.asList(probabilities); + + final double[] posteriori = _posteriori; + final DoubleObjectInspector doubleOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + for (int i = 0, len = _k; i < len; i++) { + Object o2 = posterioriOI.getListElement(posterioriObj, i); + posteriori[i] += doubleOI.get(o2); + } } + + @Override + public int estimate() { + if (_k == -1) { + return 0; + } + return SizeOf.INT + _k * SizeOf.DOUBLE; + } + } }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java index b5a81d4..dc544ae 100644 --- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java @@ -18,31 +18,26 @@ */ package hivemall.smile.tools; -import hivemall.smile.ModelType; +import hivemall.math.vector.DenseVector; +import hivemall.math.vector.SparseVector; +import hivemall.math.vector.Vector; import hivemall.smile.classification.DecisionTree; +import hivemall.smile.classification.PredictionHandler; import hivemall.smile.regression.RegressionTree; -import hivemall.smile.vm.StackMachine; -import hivemall.smile.vm.VMRuntimeException; import hivemall.utils.codec.Base91; -import hivemall.utils.codec.DeflateCodec; import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.io.IOUtils; +import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Preconditions; -import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import javax.script.Bindings; -import javax.script.Compilable; -import javax.script.CompiledScript; -import javax.script.ScriptEngine; -import javax.script.ScriptEngineManager; -import javax.script.ScriptException; import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.MapredContext; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; @@ -50,73 +45,75 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import 
org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.Writable; -import org.apache.hadoop.mapred.JobConf; @Description( name = "tree_predict", - value = "_FUNC_(string modelId, int modelType, string script, array<double> features [, const boolean classification])" + value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])" + " - Returns a prediction result of a random forest") @UDFType(deterministic = true, stateful = false) public final class TreePredictUDF extends GenericUDF { private boolean classification; - private PrimitiveObjectInspector modelTypeOI; - private StringObjectInspector stringOI; + private StringObjectInspector modelOI; private ListObjectInspector featureListOI; private PrimitiveObjectInspector featureElemOI; + private boolean denseInput; + @Nullable + private Vector featuresProbe; @Nullable private transient Evaluator evaluator; - private boolean support_javascript_eval = true; - - @Override - public void configure(MapredContext context) { - super.configure(context); - - if (context != null) { - JobConf conf = context.getJobConf(); - String tdJarVersion = conf.get("td.jar.version"); - if (tdJarVersion != null) { - this.support_javascript_eval = false; - } - } - } @Override public ObjectInspector initialize(ObjectInspector[] 
argOIs) throws UDFArgumentException { - if (argOIs.length != 4 && argOIs.length != 5) { - throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments"); + if (argOIs.length != 3 && argOIs.length != 4) { + throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments"); } - this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]); - this.stringOI = HiveUtils.asStringOI(argOIs[2]); - ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]); + this.modelOI = HiveUtils.asStringOI(argOIs[1]); + ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]); this.featureListOI = listOI; ObjectInspector elemOI = listOI.getListElementObjectInspector(); - this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + if (HiveUtils.isNumberOI(elemOI)) { + this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + this.denseInput = true; + } else if (HiveUtils.isStringOI(elemOI)) { + this.featureElemOI = HiveUtils.asStringOI(elemOI); + this.denseInput = false; + } else { + throw new UDFArgumentException( + "_FUNC_ takes array<double> or array<string> for the second argument: " + + listOI.getTypeName()); + } boolean classification = false; - if (argOIs.length == 5) { - classification = HiveUtils.getConstBoolean(argOIs[4]); + if (argOIs.length == 4) { + classification = HiveUtils.getConstBoolean(argOIs[3]); } this.classification = classification; if (classification) { - return PrimitiveObjectInspectorFactory.writableIntObjectInspector; + List<String> fieldNames = new ArrayList<String>(2); + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2); + fieldNames.add("value"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("posteriori"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } else { return 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } } @Override - public Writable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { + public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { Object arg0 = arguments[0].get(); if (arg0 == null) { throw new HiveException("ModelId was null"); @@ -125,67 +122,96 @@ public final class TreePredictUDF extends GenericUDF { String modelId = arg0.toString(); Object arg1 = arguments[1].get(); - int modelTypeId = PrimitiveObjectInspectorUtils.getInt(arg1, modelTypeOI); - ModelType modelType = ModelType.resolve(modelTypeId); - - Object arg2 = arguments[2].get(); - if (arg2 == null) { + if (arg1 == null) { return null; } - Text script = stringOI.getPrimitiveWritableObject(arg2); + Text model = modelOI.getPrimitiveWritableObject(arg1); - Object arg3 = arguments[3].get(); - if (arg3 == null) { + Object arg2 = arguments[2].get(); + if (arg2 == null) { throw new HiveException("array<double> features was null"); } - double[] features = HiveUtils.asDoubleArray(arg3, featureListOI, featureElemOI); + this.featuresProbe = parseFeatures(arg2, featuresProbe); if (evaluator == null) { - this.evaluator = getEvaluator(modelType, support_javascript_eval); + this.evaluator = classification ? 
new ClassificationEvaluator() + : new RegressionEvaluator(); } - - Writable result = evaluator.evaluate(modelId, modelType.isCompressed(), script, features, - classification); - return result; + return evaluator.evaluate(modelId, model, featuresProbe); } @Nonnull - private static Evaluator getEvaluator(@Nonnull ModelType type, boolean supportJavascriptEval) + private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector probe) throws UDFArgumentException { - final Evaluator evaluator; - switch (type) { - case serialization: - case serialization_compressed: { - evaluator = new JavaSerializationEvaluator(); - break; + if (denseInput) { + final int length = featureListOI.getListLength(argObj); + if (probe == null) { + probe = new DenseVector(length); + } else if (length != probe.size()) { + probe = new DenseVector(length); } - case opscode: - case opscode_compressed: { - evaluator = new StackmachineEvaluator(); - break; + + for (int i = 0; i < length; i++) { + final Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + probe.set(i, 0.d); + } else { + double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI); + probe.set(i, v); + } } - case javascript: - case javascript_compressed: { - if (!supportJavascriptEval) { + } else { + if (probe == null) { + probe = new SparseVector(); + } else { + probe.clear(); + } + + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + String col = o.toString(); + + final int pos = col.indexOf(':'); + if (pos == 0) { + throw new UDFArgumentException("Invalid feature value representation: " + col); + } + + final String feature; + final double value; + if (pos > 0) { + feature = col.substring(0, pos); + String s2 = col.substring(pos + 1); + value = Double.parseDouble(s2); + } else { + feature = col; + value = 1.d; + } + + if (feature.indexOf(':') != -1) { + throw new 
UDFArgumentException("Invaliad feature format `<index>:<value>`: " + + col); + } + + final int colIndex = Integer.parseInt(feature); + if (colIndex < 0) { throw new UDFArgumentException( - "Javascript evaluation is not allowed in Treasure Data env"); + "Col index MUST be greather than or equals to 0: " + colIndex); } - evaluator = new JavascriptEvaluator(); - break; + probe.set(colIndex, value); } - default: - throw new UDFArgumentException("Unexpected model type was detected: " + type); } - return evaluator; + return probe; } @Override public void close() throws IOException { - this.modelTypeOI = null; - this.stringOI = null; + this.modelOI = null; this.featureElemOI = null; this.featureListOI = null; - IOUtils.closeQuietly(evaluator); this.evaluator = null; } @@ -194,224 +220,81 @@ public final class TreePredictUDF extends GenericUDF { return "tree_predict(" + Arrays.toString(children) + ")"; } - public interface Evaluator extends Closeable { + interface Evaluator { - @Nullable - Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull final Text script, - @Nonnull final double[] features, final boolean classification) + @Nonnull + Object evaluate(@Nonnull String modelId, @Nonnull Text model, @Nonnull Vector features) throws HiveException; } - static final class JavaSerializationEvaluator implements Evaluator { + static final class ClassificationEvaluator implements Evaluator { + + @Nonnull + private final Object[] result; @Nullable private String prevModelId = null; private DecisionTree.Node cNode = null; - private RegressionTree.Node rNode = null; - - JavaSerializationEvaluator() {} - - @Override - public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, - double[] features, boolean classification) throws HiveException { - if (classification) { - return evaluateClassification(modelId, compressed, script, features); - } else { - return evaluteRegression(modelId, compressed, script, features); - } - } - private 
IntWritable evaluateClassification(@Nonnull String modelId, boolean compressed, - @Nonnull Text script, double[] features) throws HiveException { - if (!modelId.equals(prevModelId)) { - this.prevModelId = modelId; - int length = script.getLength(); - byte[] b = script.getBytes(); - b = Base91.decode(b, 0, length); - this.cNode = DecisionTree.deserializeNode(b, b.length, compressed); - } - assert (cNode != null); - int result = cNode.predict(features); - return new IntWritable(result); + ClassificationEvaluator() { + this.result = new Object[2]; } - private DoubleWritable evaluteRegression(@Nonnull String modelId, boolean compressed, - @Nonnull Text script, double[] features) throws HiveException { + @Nonnull + public Object[] evaluate(@Nonnull final String modelId, @Nonnull final Text script, + @Nonnull final Vector features) throws HiveException { if (!modelId.equals(prevModelId)) { this.prevModelId = modelId; int length = script.getLength(); byte[] b = script.getBytes(); b = Base91.decode(b, 0, length); - this.rNode = RegressionTree.deserializeNode(b, b.length, compressed); - } - assert (rNode != null); - double result = rNode.predict(features); - return new DoubleWritable(result); - } - - @Override - public void close() throws IOException {} - - } - - static final class StackmachineEvaluator implements Evaluator { - - private String prevModelId = null; - private StackMachine prevVM = null; - private DeflateCodec codec = null; - - StackmachineEvaluator() {} - - @Override - public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, - double[] features, boolean classification) throws HiveException { - final String scriptStr; - if (compressed) { - if (codec == null) { - this.codec = new DeflateCodec(false, true); - } - byte[] b = script.getBytes(); - int len = script.getLength(); - b = Base91.decode(b, 0, len); - try { - b = codec.decompress(b); - } catch (IOException e) { - throw new HiveException("decompression failed", e); - } - 
scriptStr = new String(b); - } else { - scriptStr = script.toString(); + this.cNode = DecisionTree.deserializeNode(b, b.length, true); } - final StackMachine vm; - if (modelId.equals(prevModelId)) { - vm = prevVM; - } else { - vm = new StackMachine(); - try { - vm.compile(scriptStr); - } catch (VMRuntimeException e) { - throw new HiveException("failed to compile StackMachine", e); + Arrays.fill(result, null); + Preconditions.checkNotNull(cNode); + cNode.predict(features, new PredictionHandler() { + public void handle(int output, double[] posteriori) { + result[0] = new IntWritable(output); + result[1] = WritableUtils.toWritableList(posteriori); } - this.prevModelId = modelId; - this.prevVM = vm; - } - - try { - vm.eval(features); - } catch (VMRuntimeException vme) { - throw new HiveException("failed to eval StackMachine", vme); - } catch (Throwable e) { - throw new HiveException("failed to eval StackMachine", e); - } + }); - Double result = vm.getResult(); - if (result == null) { - return null; - } - if (classification) { - return new IntWritable(result.intValue()); - } else { - return new DoubleWritable(result.doubleValue()); - } - } - - @Override - public void close() throws IOException { - IOUtils.closeQuietly(codec); + return result; } } - static final class JavascriptEvaluator implements Evaluator { + static final class RegressionEvaluator implements Evaluator { - private final ScriptEngine scriptEngine; - private final Compilable compilableEngine; + @Nonnull + private final DoubleWritable result; + @Nullable private String prevModelId = null; - private CompiledScript prevCompiled; - - private DeflateCodec codec = null; + private RegressionTree.Node rNode = null; - JavascriptEvaluator() throws UDFArgumentException { - ScriptEngineManager manager = new ScriptEngineManager(); - ScriptEngine engine = manager.getEngineByExtension("js"); - if (!(engine instanceof Compilable)) { - throw new UDFArgumentException("ScriptEngine was not compilable: " - + 
engine.getFactory().getEngineName() + " version " - + engine.getFactory().getEngineVersion()); - } - this.scriptEngine = engine; - this.compilableEngine = (Compilable) engine; + RegressionEvaluator() { + this.result = new DoubleWritable(); } - @Override - public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, - double[] features, boolean classification) throws HiveException { - final String scriptStr; - if (compressed) { - if (codec == null) { - this.codec = new DeflateCodec(false, true); - } + @Nonnull + public DoubleWritable evaluate(@Nonnull final String modelId, @Nonnull final Text script, + @Nonnull final Vector features) throws HiveException { + if (!modelId.equals(prevModelId)) { + this.prevModelId = modelId; + int length = script.getLength(); byte[] b = script.getBytes(); - int len = script.getLength(); - b = Base91.decode(b, 0, len); - try { - b = codec.decompress(b); - } catch (IOException e) { - throw new HiveException("decompression failed", e); - } - scriptStr = new String(b); - } else { - scriptStr = script.toString(); - } - - final CompiledScript compiled; - if (modelId.equals(prevModelId)) { - compiled = prevCompiled; - } else { - try { - compiled = compilableEngine.compile(scriptStr); - } catch (ScriptException e) { - throw new HiveException("failed to compile: \n" + script, e); - } - this.prevCompiled = compiled; - } - - final Bindings bindings = scriptEngine.createBindings(); - final Object result; - try { - bindings.put("x", features); - result = compiled.eval(bindings); - } catch (ScriptException se) { - throw new HiveException("failed to evaluate: \n" + script, se); - } catch (Throwable e) { - throw new HiveException("failed to evaluate: \n" + script, e); - } finally { - bindings.clear(); - } - - if (result == null) { - return null; - } - if (!(result instanceof Number)) { - throw new HiveException("Got an unexpected non-number result: " + result); - } - if (classification) { - Number casted = (Number) 
result; - return new IntWritable(casted.intValue()); - } else { - Number casted = (Number) result; - return new DoubleWritable(casted.doubleValue()); + b = Base91.decode(b, 0, length); + this.rNode = RegressionTree.deserializeNode(b, b.length, true); } - } + Preconditions.checkNotNull(rNode); - @Override - public void close() throws IOException { - IOUtils.closeQuietly(codec); + double value = rNode.predict(features); + result.set(value); + return result; } - } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java index c0dfc1c..74a3032 100644 --- a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java +++ b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java @@ -18,11 +18,23 @@ */ package hivemall.smile.utils; +import hivemall.math.matrix.ColumnMajorMatrix; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.MatrixUtils; +import hivemall.math.matrix.ints.ColumnMajorDenseIntMatrix2d; +import hivemall.math.matrix.ints.ColumnMajorIntMatrix; +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.VectorProcedure; import hivemall.smile.classification.DecisionTree.SplitRule; import hivemall.smile.data.Attribute; import hivemall.smile.data.Attribute.AttributeType; import hivemall.smile.data.Attribute.NominalAttribute; import hivemall.smile.data.Attribute.NumericAttribute; +import hivemall.utils.collections.lists.DoubleArrayList; +import hivemall.utils.collections.lists.IntArrayList; +import hivemall.utils.lang.mutable.MutableInt; +import hivemall.utils.math.MathUtils; import java.util.Arrays; @@ -49,13 +61,14 @@ public final class SmileExtUtils { } final String[] opts = opt.split(","); final int size = 
opts.length; + final NumericAttribute immutableNumAttr = new NumericAttribute(); final Attribute[] attr = new Attribute[size]; for (int i = 0; i < size; i++) { final String type = opts[i]; if ("Q".equals(type)) { - attr[i] = new NumericAttribute(i); + attr[i] = immutableNumAttr; } else if ("C".equals(type)) { - attr[i] = new NominalAttribute(i); + attr[i] = new NominalAttribute(); } else { throw new UDFArgumentException("Unexpected type: " + type); } @@ -64,13 +77,55 @@ public final class SmileExtUtils { } @Nonnull - public static Attribute[] attributeTypes(@Nullable Attribute[] attributes, - @Nonnull final double[][] x) { + public static Attribute[] attributeTypes(@Nullable final Attribute[] attributes, + @Nonnull final Matrix x) { if (attributes == null) { - int p = x[0].length; - attributes = new Attribute[p]; - for (int i = 0; i < p; i++) { - attributes[i] = new NumericAttribute(i); + int p = x.numColumns(); + Attribute[] newAttributes = new Attribute[p]; + Arrays.fill(newAttributes, new NumericAttribute()); + return newAttributes; + } + + if (x.isRowMajorMatrix()) { + final VectorProcedure proc = new VectorProcedure() { + @Override + public void apply(final int j, final double value) { + final Attribute attr = attributes[j]; + if (attr.type == AttributeType.NOMINAL) { + final int x_ij = ((int) value) + 1; + final int prevSize = attr.getSize(); + if (x_ij > prevSize) { + attr.setSize(x_ij); + } + } + } + }; + for (int i = 0, rows = x.numRows(); i < rows; i++) { + x.eachNonNullInRow(i, proc); + } + } else if (x.isColumnMajorMatrix()) { + final MutableInt max_x = new MutableInt(0); + final VectorProcedure proc = new VectorProcedure() { + @Override + public void apply(final int i, final double value) { + final int x_ij = (int) value; + if (x_ij > max_x.getValue()) { + max_x.setValue(x_ij); + } + } + }; + + final int size = attributes.length; + for (int j = 0; j < size; j++) { + final Attribute attr = attributes[j]; + if (attr.type == AttributeType.NOMINAL) { + if 
(attr.getSize() != -1) { + continue; + } + max_x.setValue(0); + x.eachNonNullInColumn(j, proc); + attr.setSize(max_x.getValue() + 1); + } } } else { int size = attributes.length; @@ -81,8 +136,12 @@ public final class SmileExtUtils { continue; } int max_x = 0; - for (int i = 0; i < x.length; i++) { - int x_ij = (int) x[i][j]; + for (int i = 0, rows = x.numRows(); i < rows; i++) { + final double v = x.get(i, j, Double.NaN); + if (Double.isNaN(v)) { + continue; + } + int x_ij = (int) v; if (x_ij > max_x) { max_x = x_ij; } @@ -97,16 +156,17 @@ public final class SmileExtUtils { @Nonnull public static Attribute[] convertAttributeTypes(@Nonnull final smile.data.Attribute[] original) { final int size = original.length; + final NumericAttribute immutableNumAttr = new NumericAttribute(); final Attribute[] dst = new Attribute[size]; for (int i = 0; i < size; i++) { smile.data.Attribute o = original[i]; switch (o.type) { case NOMINAL: { - dst[i] = new NominalAttribute(i); + dst[i] = new NominalAttribute(); break; } case NUMERIC: { - dst[i] = new NumericAttribute(i); + dst[i] = immutableNumAttr; break; } default: @@ -117,23 +177,52 @@ public final class SmileExtUtils { } @Nonnull - public static int[][] sort(@Nonnull final Attribute[] attributes, @Nonnull final double[][] x) { - final int n = x.length; - final int p = x[0].length; + public static ColumnMajorIntMatrix sort(@Nonnull final Attribute[] attributes, + @Nonnull final Matrix x) { + final int n = x.numRows(); + final int p = x.numColumns(); - final double[] a = new double[n]; final int[][] index = new int[p][]; + if (x.isSparse()) { + int initSize = n / 10; + final DoubleArrayList dlist = new DoubleArrayList(initSize); + final IntArrayList ilist = new IntArrayList(initSize); + final VectorProcedure proc = new VectorProcedure() { + @Override + public void apply(final int i, final double v) { + dlist.add(v); + ilist.add(i); + } + }; - for (int j = 0; j < p; j++) { - if (attributes[j].type == AttributeType.NUMERIC) { - 
for (int i = 0; i < n; i++) { - a[i] = x[i][j]; + final ColumnMajorMatrix x2 = x.toColumnMajorMatrix(); + for (int j = 0; j < p; j++) { + if (attributes[j].type != AttributeType.NUMERIC) { + continue; + } + x2.eachNonNullInColumn(j, proc); + if (ilist.isEmpty()) { + continue; + } + int[] indexJ = ilist.toArray(); + QuickSort.sort(dlist.array(), indexJ, indexJ.length); + index[j] = indexJ; + dlist.clear(); + ilist.clear(); + } + } else { + final double[] a = new double[n]; + for (int j = 0; j < p; j++) { + if (attributes[j].type == AttributeType.NUMERIC) { + for (int i = 0; i < n; i++) { + a[i] = x.get(i, j); + } + index[j] = QuickSort.sort(a); } - index[j] = QuickSort.sort(a); } } - return index; + return new ColumnMajorDenseIntMatrix2d(index, n); } @Nonnull @@ -169,13 +258,13 @@ public final class SmileExtUtils { } } - public static int computeNumInputVars(final float numVars, final double[][] x) { + public static int computeNumInputVars(final float numVars, @Nonnull final Matrix x) { final int numInputVars; if (numVars <= 0.f) { - int dims = x[0].length; + int dims = x.numColumns(); numInputVars = (int) Math.ceil(Math.sqrt(dims)); } else if (numVars > 0.f && numVars <= 1.f) { - numInputVars = (int) (numVars * x[0].length); + numInputVars = (int) (numVars * x.numColumns()); } else { numInputVars = (int) numVars; } @@ -186,42 +275,75 @@ public final class SmileExtUtils { return Thread.currentThread().getId() * System.nanoTime(); } - public static void shuffle(@Nonnull final int[] x, @Nonnull final smile.math.Random rnd) { + public static void shuffle(@Nonnull final int[] x, @Nonnull final PRNG rnd) { for (int i = x.length; i > 1; i--) { int j = rnd.nextInt(i); swap(x, i - 1, j); } } - public static void shuffle(@Nonnull final double[][] x, final int[] y, @Nonnull long seed) { - if (x.length != y.length) { - throw new IllegalArgumentException("x.length (" + x.length + ") != y.length (" + @Nonnull + public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final 
int[] y, long seed) { + final int numRows = x.numRows(); + if (numRows != y.length) { + throw new IllegalArgumentException("x.length (" + numRows + ") != y.length (" + y.length + ')'); } if (seed == -1L) { seed = generateSeed(); } - final smile.math.Random rnd = new smile.math.Random(seed); - for (int i = x.length; i > 1; i--) { - int j = rnd.nextInt(i); - swap(x, i - 1, j); - swap(y, i - 1, j); + + final PRNG rnd = RandomNumberGeneratorFactory.createPRNG(seed); + if (x.swappable()) { + for (int i = numRows; i > 1; i--) { + int j = rnd.nextInt(i); + int k = i - 1; + x.swap(k, j); + swap(y, k, j); + } + return x; + } else { + final int[] indicies = MathUtils.permutation(numRows); + for (int i = numRows; i > 1; i--) { + int j = rnd.nextInt(i); + int k = i - 1; + swap(indicies, k, j); + swap(y, k, j); + } + return MatrixUtils.shuffle(x, indicies); } } - public static void shuffle(@Nonnull final double[][] x, final double[] y, @Nonnull long seed) { - if (x.length != y.length) { - throw new IllegalArgumentException("x.length (" + x.length + ") != y.length (" + @Nonnull + public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final double[] y, + @Nonnull long seed) { + final int numRows = x.numRows(); + if (numRows != y.length) { + throw new IllegalArgumentException("x.length (" + numRows + ") != y.length (" + y.length + ')'); } if (seed == -1L) { seed = generateSeed(); } - final smile.math.Random rnd = new smile.math.Random(seed); - for (int i = x.length; i > 1; i--) { - int j = rnd.nextInt(i); - swap(x, i - 1, j); - swap(y, i - 1, j); + + final PRNG rnd = RandomNumberGeneratorFactory.createPRNG(seed); + if (x.swappable()) { + for (int i = numRows; i > 1; i--) { + int j = rnd.nextInt(i); + int k = i - 1; + x.swap(k, j); + swap(y, k, j); + } + return x; + } else { + final int[] indicies = MathUtils.permutation(numRows); + for (int i = numRows; i > 1; i--) { + int j = rnd.nextInt(i); + int k = i - 1; + swap(indicies, k, j); + swap(y, k, j); + } + return 
MatrixUtils.shuffle(x, indicies); } } @@ -243,15 +365,6 @@ public final class SmileExtUtils { x[j] = s; } - /** - * Swap two elements of an array. - */ - private static void swap(final double[][] x, final int i, final int j) { - double[] s = x[i]; - x[i] = x[j]; - x[j] = s; - } - @Nonnull public static int[] bagsToSamples(@Nonnull final int[] bags) { int maxIndex = -1; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/vm/Operation.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/vm/Operation.java b/core/src/main/java/hivemall/smile/vm/Operation.java deleted file mode 100644 index fff617f..0000000 --- a/core/src/main/java/hivemall/smile/vm/Operation.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.smile.vm; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -public final class Operation { - - final OperationEnum op; - final String operand; - - public Operation(@Nonnull OperationEnum op) { - this(op, null); - } - - public Operation(@Nonnull OperationEnum op, @Nullable String operand) { - this.op = op; - this.operand = operand; - } - - public enum OperationEnum { - ADD, SUB, DIV, MUL, DUP, // reserved - PUSH, POP, GOTO, IFEQ, IFEQ2, IFGE, IFGT, IFLE, IFLT, CALL; // used - - static OperationEnum valueOfLowerCase(String op) { - return OperationEnum.valueOf(op.toUpperCase()); - } - } - - @Override - public String toString() { - return op.toString() + (operand != null ? (" " + operand) : ""); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/vm/StackMachine.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/vm/StackMachine.java b/core/src/main/java/hivemall/smile/vm/StackMachine.java deleted file mode 100644 index 3bf8b46..0000000 --- a/core/src/main/java/hivemall/smile/vm/StackMachine.java +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.smile.vm; - -import hivemall.utils.lang.StringUtils; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Stack; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -public final class StackMachine { - public static final String SEP = "; "; - - @Nonnull - private final List<Operation> code; - @Nonnull - private final Map<String, Double> valuesMap; - @Nonnull - private final Map<String, Integer> jumpMap; - @Nonnull - private final Stack<Double> programStack; - - /** - * Instruction pointer - */ - private int IP; - - /** - * Stack pointer - */ - @SuppressWarnings("unused") - private int SP; - - private int codeLength; - private boolean[] done; - private Double result; - - public StackMachine() { - this.code = new ArrayList<Operation>(); - this.valuesMap = new HashMap<String, Double>(); - this.jumpMap = new HashMap<String, Integer>(); - this.programStack = new Stack<Double>(); - this.SP = 0; - this.result = null; - } - - public void run(@Nonnull String scripts, @Nonnull double[] features) throws VMRuntimeException { - compile(scripts); - eval(features); - } - - public void run(@Nonnull List<String> opslist, @Nonnull double[] features) - throws VMRuntimeException { - compile(opslist); - eval(features); - } - - public void compile(@Nonnull String scripts) throws VMRuntimeException { - List<String> opslist = Arrays.asList(scripts.split(SEP)); - compile(opslist); - } - - public void compile(@Nonnull List<String> opslist) throws VMRuntimeException { - for (String line : opslist) { - String[] ops = line.split(" ", -1); - if (ops.length == 2) { - Operation.OperationEnum o = Operation.OperationEnum.valueOfLowerCase(ops[0]); - code.add(new Operation(o, ops[1])); - } else { - Operation.OperationEnum o = 
Operation.OperationEnum.valueOfLowerCase(ops[0]); - code.add(new Operation(o)); - } - } - - int size = opslist.size(); - this.codeLength = size - 1; - this.done = new boolean[size]; - } - - public void eval(final double[] features) throws VMRuntimeException { - init(); - bind(features); - execute(0); - } - - private void init() { - valuesMap.clear(); - jumpMap.clear(); - programStack.clear(); - this.SP = 0; - this.result = null; - Arrays.fill(done, false); - } - - private void bind(final double[] features) { - final StringBuilder buf = new StringBuilder(); - for (int i = 0; i < features.length; i++) { - String bindKey = buf.append("x[").append(i).append("]").toString(); - valuesMap.put(bindKey, features[i]); - StringUtils.clear(buf); - } - } - - private void execute(int entryPoint) throws VMRuntimeException { - valuesMap.put("end", -1.0); - jumpMap.put("last", codeLength); - - IP = entryPoint; - - while (IP < code.size()) { - if (done[IP]) { - throw new VMRuntimeException("There is a infinite loop in the Machine code."); - } - done[IP] = true; - Operation currentOperation = code.get(IP); - if (!executeOperation(currentOperation)) { - return; - } - } - } - - @Nullable - public Double getResult() { - return result; - } - - private Double pop() { - SP--; - return programStack.pop(); - } - - private Double push(Double val) { - programStack.push(val); - SP++; - return val; - } - - private boolean executeOperation(Operation currentOperation) throws VMRuntimeException { - if (IP < 0) { - return false; - } - switch (currentOperation.op) { - case GOTO: { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - break; - } - case CALL: { - double candidateIP = valuesMap.get(currentOperation.operand); - if (candidateIP < 0) { - evaluateBuiltinByName(currentOperation.operand); - IP++; - } - break; - } - case IFEQ: { - double a = pop(); - double b = pop(); - if (a == b) 
{ - IP++; - } else { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - } - break; - } - case IFEQ2: {// follow the rule of smile's Math class. - double a = pop(); - double b = pop(); - if (smile.math.Math.equals(a, b)) { - IP++; - } else { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - } - break; - } - case IFGE: { - double lower = pop(); - double upper = pop(); - if (upper >= lower) { - IP++; - } else { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - } - break; - } - case IFGT: { - double lower = pop(); - double upper = pop(); - if (upper > lower) { - IP++; - } else { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - } - break; - } - case IFLE: { - double lower = pop(); - double upper = pop(); - if (upper <= lower) { - IP++; - } else { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - } - break; - } - case IFLT: { - double lower = pop(); - double upper = pop(); - if (upper < lower) { - IP++; - } else { - if (StringUtils.isInt(currentOperation.operand)) { - IP = Integer.parseInt(currentOperation.operand); - } else { - IP = jumpMap.get(currentOperation.operand); - } - } - break; - } - case POP: { - valuesMap.put(currentOperation.operand, pop()); - IP++; - break; - } - case PUSH: { - if (StringUtils.isDouble(currentOperation.operand)) { - push(Double.parseDouble(currentOperation.operand)); - } else { - Double v = valuesMap.get(currentOperation.operand); - if (v == null) { - throw new 
VMRuntimeException("value is not binded: " - + currentOperation.operand); - } - push(v); - } - IP++; - break; - } - default: - throw new VMRuntimeException("Machine code has wrong opcode :" - + currentOperation.op); - } - return true; - - } - - private void evaluateBuiltinByName(String name) throws VMRuntimeException { - if (name.equals("end")) { - this.result = pop(); - } else { - throw new VMRuntimeException("Machine code has wrong builin function :" + name); - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java b/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java deleted file mode 100644 index 7fc89c8..0000000 --- a/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.smile.vm; - -public class VMRuntimeException extends Exception { - private static final long serialVersionUID = -7378149197872357802L; - - public VMRuntimeException(String message) { - super(message); - } - - public VMRuntimeException(String message, Throwable cause) { - super(message, cause); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java b/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java index 1f6c324..366b74b 100644 --- a/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java +++ b/core/src/main/java/hivemall/tools/mapred/DistributedCacheLookupUDF.java @@ -19,7 +19,7 @@ package hivemall.tools.mapred; import hivemall.ftvec.ExtractFeatureUDF; -import hivemall.utils.collections.OpenHashMap; +import hivemall.utils.collections.maps.OpenHashMap; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/DoubleArray.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArray.java b/core/src/main/java/hivemall/utils/collections/DoubleArray.java deleted file mode 100644 index a7dfa81..0000000 --- a/core/src/main/java/hivemall/utils/collections/DoubleArray.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import java.io.Serializable; - -import javax.annotation.Nonnull; - -public interface DoubleArray extends Serializable { - - public double get(int key); - - public double get(int key, double valueIfKeyNotFound); - - public void put(int key, double value); - - public int size(); - - public int keyAt(int index); - - @Nonnull - public double[] toArray(); - - @Nonnull - public double[] toArray(boolean copy); - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java b/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java deleted file mode 100644 index 5716212..0000000 --- a/core/src/main/java/hivemall/utils/collections/DoubleArray3D.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import hivemall.utils.lang.Primitives; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; - -import javax.annotation.Nonnull; - -public final class DoubleArray3D { - private static final int DEFAULT_SIZE = 100 * 100 * 10; // feature * field * factor - - private final boolean direct; - - @Nonnull - private DoubleBuffer buffer; - private int capacity; - - private int size; - // number of array in each dimension - private int n1, n2, n3; - // pointer to each dimension - private int p1, p2; - - private boolean sanityCheck; - - public DoubleArray3D() { - this(DEFAULT_SIZE, true); - } - - public DoubleArray3D(int initSize, boolean direct) { - this.direct = direct; - this.buffer = allocate(direct, initSize); - this.capacity = initSize; - this.size = -1; - this.sanityCheck = true; - } - - public DoubleArray3D(int dim1, int dim2, int dim3) { - this.direct = true; - this.capacity = -1; - configure(dim1, dim2, dim3); - this.sanityCheck = true; - } - - public void setSanityCheck(boolean enable) { - this.sanityCheck = enable; - } - - public void configure(final int dim1, final int dim2, final int dim3) { - int requiredSize = cardinarity(dim1, dim2, dim3); - if (requiredSize > capacity) { - this.buffer = allocate(direct, requiredSize); - this.capacity = requiredSize; - } - this.size = requiredSize; - this.n1 = dim1; - this.n2 = dim2; - this.n3 = dim3; - this.p1 = n2 * n3; - this.p2 = n3; - } - - public void clear() { - buffer.clear(); - this.size = -1; - } - - public int getSize() { - return size; - } - - 
int getCapacity() { - return capacity; - } - - public double get(final int i, final int j, final int k) { - int idx = idx(i, j, k); - return buffer.get(idx); - } - - public void set(final int i, final int j, final int k, final double val) { - int idx = idx(i, j, k); - buffer.put(idx, val); - } - - private int idx(final int i, final int j, final int k) { - if (sanityCheck == false) { - return i * p1 + j * p2 + k; - } - - if (size == -1) { - throw new IllegalStateException("Double3DArray#configure() is not called"); - } - if (i >= n1 || i < 0) { - throw new ArrayIndexOutOfBoundsException("Index '" + i - + "' out of bounds for 1st dimension of size " + n1); - } - if (j >= n2 || j < 0) { - throw new ArrayIndexOutOfBoundsException("Index '" + j - + "' out of bounds for 2nd dimension of size " + n2); - } - if (k >= n3 || k < 0) { - throw new ArrayIndexOutOfBoundsException("Index '" + k - + "' out of bounds for 3rd dimension of size " + n3); - } - final int idx = i * p1 + j * p2 + k; - if (idx >= size) { - throw new IndexOutOfBoundsException("Computed internal index '" + idx - + "' exceeds buffer size '" + size + "' where i=" + i + ", j=" + j + ", k=" + k); - } - return idx; - } - - private static int cardinarity(final int dim1, final int dim2, final int dim3) { - if (dim1 <= 0 || dim2 <= 0 || dim3 <= 0) { - throw new IllegalArgumentException("Detected negative dimension size. dim1=" + dim1 - + ", dim2=" + dim2 + ", dim3=" + dim3); - } - return dim1 * dim2 * dim3; - } - - @Nonnull - private static DoubleBuffer allocate(final boolean direct, final int size) { - int bytes = size * Primitives.DOUBLE_BYTES; - ByteBuffer buf = direct ? 
ByteBuffer.allocateDirect(bytes) : ByteBuffer.allocate(bytes); - return buf.asDoubleBuffer(); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java b/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java deleted file mode 100644 index afdc251..0000000 --- a/core/src/main/java/hivemall/utils/collections/DoubleArrayList.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections; - -import java.io.Closeable; -import java.io.Serializable; - -import javax.annotation.Nonnull; - -public final class DoubleArrayList implements Serializable, Closeable { - private static final long serialVersionUID = -8155789759545975413L; - public static final int DEFAULT_CAPACITY = 12; - - /** array entity */ - private double[] data; - private int used; - - public DoubleArrayList() { - this(DEFAULT_CAPACITY); - } - - public DoubleArrayList(int size) { - this.data = new double[size]; - this.used = 0; - } - - public DoubleArrayList(double[] initValues) { - this.data = initValues; - this.used = initValues.length; - } - - public void add(double value) { - if (used >= data.length) { - expand(used + 1); - } - data[used++] = value; - } - - public void add(double[] values) { - final int needs = used + values.length; - if (needs >= data.length) { - expand(needs); - } - System.arraycopy(values, 0, data, used, values.length); - this.used = needs; - } - - /** - * dynamic expansion. - */ - private void expand(int max) { - while (data.length < max) { - final int len = data.length; - double[] newArray = new double[len * 2]; - System.arraycopy(data, 0, newArray, 0, len); - this.data = newArray; - } - } - - public double remove() { - return data[--used]; - } - - public double remove(int index) { - final double ret; - if (index > used) { - throw new IndexOutOfBoundsException(); - } else if (index == used) { - ret = data[--used]; - } else { // index < used - // removed value - ret = data[index]; - final double[] newarray = new double[--used]; - // prefix - System.arraycopy(data, 0, newarray, 0, index - 1); - // appendix - System.arraycopy(data, index + 1, newarray, index, used - index); - // set fields. 
- this.data = newarray; - } - return ret; - } - - public void set(int index, double value) { - if (index > used) { - throw new IllegalArgumentException("Index MUST be less than \"size()\"."); - } else if (index == used) { - ++used; - } - data[index] = value; - } - - public double get(int index) { - if (index >= used) - throw new IndexOutOfBoundsException(); - return data[index]; - } - - public double fastGet(int index) { - return data[index]; - } - - public int size() { - return used; - } - - public boolean isEmpty() { - return used == 0; - } - - public void clear() { - used = 0; - } - - @Nonnull - public double[] toArray() { - return toArray(false); - } - - @Nonnull - public double[] toArray(boolean close) { - final double[] newArray = new double[used]; - System.arraycopy(data, 0, newArray, 0, used); - if (close) { - close(); - } - return newArray; - } - - public double[] array() { - return data; - } - - @Override - public String toString() { - final StringBuilder buf = new StringBuilder(); - buf.append('['); - for (int i = 0; i < used; i++) { - if (i != 0) { - buf.append(", "); - } - buf.append(data[i]); - } - buf.append(']'); - return buf.toString(); - } - - @Override - public void close() { - this.data = null; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/FixedIntArray.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/FixedIntArray.java b/core/src/main/java/hivemall/utils/collections/FixedIntArray.java deleted file mode 100644 index 927ee83..0000000 --- a/core/src/main/java/hivemall/utils/collections/FixedIntArray.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections; - -import java.util.Arrays; - -import javax.annotation.Nonnull; - -/** - * A fixed INT array that has keys greater than or equals to 0. - */ -public final class FixedIntArray implements IntArray { - private static final long serialVersionUID = -1450212841013810240L; - - @Nonnull - private final int[] array; - private final int size; - - public FixedIntArray(@Nonnull int size) { - this.array = new int[size]; - this.size = size; - } - - public FixedIntArray(@Nonnull int[] array) { - this.array = array; - this.size = array.length; - } - - @Override - public int get(int index) { - return array[index]; - } - - @Override - public int get(int index, int valueIfKeyNotFound) { - if (index >= size) { - return valueIfKeyNotFound; - } - return array[index]; - } - - @Override - public void put(int index, int value) { - array[index] = value; - } - - @Override - public int size() { - return array.length; - } - - @Override - public int keyAt(int index) { - return index; - } - - @Override - public int[] toArray() { - return toArray(true); - } - - @Override - public int[] toArray(boolean copy) { - if (copy) { - return Arrays.copyOf(array, size); - } else { - return array; - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/FloatArrayList.java 
---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/FloatArrayList.java b/core/src/main/java/hivemall/utils/collections/FloatArrayList.java deleted file mode 100644 index cfdf504..0000000 --- a/core/src/main/java/hivemall/utils/collections/FloatArrayList.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
/**
 * A growable list of primitive {@code float}s that avoids autoboxing.
 * <p>
 * Not thread-safe. The backing array may have spare capacity beyond
 * {@link #size()}; {@link #array()} exposes it directly.
 */
public final class FloatArrayList implements Serializable {
    private static final long serialVersionUID = 8764828070342317585L;

    public static final int DEFAULT_CAPACITY = 12;

    /** backing array; capacity may exceed {@link #size()} */
    private float[] data;
    /** number of valid elements in {@link #data} */
    private int used;

    public FloatArrayList() {
        this(DEFAULT_CAPACITY);
    }

    /**
     * @param size initial capacity (0 is allowed; the list grows on demand)
     */
    public FloatArrayList(int size) {
        this.data = new float[size];
        this.used = 0;
    }

    /**
     * Wraps the given array WITHOUT copying; every element is considered used.
     */
    public FloatArrayList(float[] initValues) {
        this.data = initValues;
        this.used = initValues.length;
    }

    /** Appends a single value, growing the backing array if needed. */
    public void add(float value) {
        if (used >= data.length) {
            expand(used + 1);
        }
        data[used++] = value;
    }

    /** Appends all given values, growing the backing array if needed. */
    public void add(float[] values) {
        final int needs = used + values.length;
        if (needs > data.length) {
            expand(needs);
        }
        System.arraycopy(values, 0, data, used, values.length);
        this.used = needs;
    }

    /**
     * Grows the backing array to at least {@code max} slots by doubling.
     * <p>
     * FIX: the previous loop doubled {@code data.length} directly, which never
     * terminates when the current capacity is 0 ({@code 0 * 2 == 0}); growth now
     * starts from at least 1.
     */
    private void expand(int max) {
        if (data.length >= max) {
            return;
        }
        int newSize = Math.max(data.length, 1);
        while (newSize < max) {
            newSize *= 2;
        }
        final float[] newData = new float[newSize];
        System.arraycopy(data, 0, newData, 0, used);
        this.data = newData;
    }

    /** Removes and returns the LAST element; no bounds check (caller ensures non-empty). */
    public float remove() {
        return data[--used];
    }

    /**
     * Removes and returns the element at {@code index}, shifting subsequent
     * elements one slot left. Preserving the original (unusual) contract,
     * {@code index == size()} removes the last element and only
     * {@code index > size()} throws.
     * <p>
     * FIX: the previous implementation rebuilt the array but copied
     * {@code index - 1} prefix elements instead of {@code index}, silently
     * zeroing one element and throwing for {@code index == 0}. The tail is now
     * shifted in place instead of reallocating.
     *
     * @throws IndexOutOfBoundsException if {@code index > size()}
     */
    public float remove(int index) {
        final float ret;
        if (index > used) {
            throw new IndexOutOfBoundsException();
        } else if (index == used) { // legacy behavior: treat size() as "last"
            ret = data[--used];
        } else { // index < used
            ret = data[index];
            // shift the tail left in place; no reallocation needed
            System.arraycopy(data, index + 1, data, index, used - index - 1);
            --used;
        }
        return ret;
    }

    /**
     * Sets the value at {@code index}; {@code index == size()} appends.
     * <p>
     * FIX: appending at exactly-full capacity previously incremented
     * {@code used} and then threw {@link ArrayIndexOutOfBoundsException},
     * corrupting the list state; the array now grows first.
     *
     * @throws IllegalArgumentException if {@code index > size()}
     */
    public void set(int index, float value) {
        if (index > used) {
            throw new IllegalArgumentException("Index MUST be less than \"size()\".");
        } else if (index == used) {
            if (used >= data.length) {
                expand(used + 1);
            }
            ++used;
        }
        data[index] = value;
    }

    /**
     * @throws IndexOutOfBoundsException if {@code index >= size()}
     */
    public float get(int index) {
        if (index >= used) {
            throw new IndexOutOfBoundsException();
        }
        return data[index];
    }

    /** Unchecked access to the backing array. */
    public float fastGet(int index) {
        return data[index];
    }

    public int size() {
        return used;
    }

    public boolean isEmpty() {
        return used == 0;
    }

    /** Logical clear; the backing array keeps its capacity and stale values. */
    public void clear() {
        this.used = 0;
    }

    /** @return a trimmed copy containing exactly {@link #size()} elements */
    public float[] toArray() {
        final float[] newArray = new float[used];
        System.arraycopy(data, 0, newArray, 0, used);
        return newArray;
    }

    /** @return the internal backing array (may be longer than {@link #size()}) */
    public float[] array() {
        return data;
    }

    @Override
    public String toString() {
        final StringBuilder buf = new StringBuilder();
        buf.append('[');
        for (int i = 0; i < used; i++) {
            if (i != 0) {
                buf.append(", ");
            }
            buf.append(data[i]);
        }
        buf.append(']');
        return buf.toString();
    }
}
