GitHub user helenahm commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/93#discussion_r125552049
  
    --- Diff: core/src/main/java/hivemall/smile/classification/MaxEntUDTF.java ---
    @@ -0,0 +1,440 @@
    +package hivemall.smile.classification;
    +
    +import java.io.FileNotFoundException;
    +import java.io.IOException;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.BitSet;
    +import java.util.HashMap;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.concurrent.Callable;
    +import java.util.concurrent.atomic.AtomicInteger;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +import javax.annotation.concurrent.GuardedBy;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.MapredContext;
    +import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.udf.UDFType;
    +import org.apache.hadoop.hive.serde2.io.DoubleWritable;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +import org.apache.hadoop.mapred.Reporter;
    +import org.apache.hadoop.mapred.Counters.Counter;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.math.matrix.Matrix;
    +import hivemall.math.matrix.MatrixUtils;
    +import hivemall.math.matrix.builders.CSRMatrixBuilder;
    +import hivemall.math.matrix.builders.MatrixBuilder;
    +import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
    +import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
    +import hivemall.math.matrix.ints.DoKIntMatrix;
    +import hivemall.math.matrix.ints.IntMatrix;
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.math.vector.Vector;
    +import hivemall.math.vector.VectorProcedure;
    +import hivemall.smile.classification.DecisionTree.SplitRule;
    +import hivemall.smile.data.Attribute;
    +import hivemall.smile.tools.MatrixEventStream;
    +import hivemall.smile.tools.SepDelimitedTextGISModelWriter;
    +import hivemall.smile.utils.SmileExtUtils;
    +import hivemall.smile.utils.SmileTaskExecutor;
    +import hivemall.utils.codec.Base91;
    +import hivemall.utils.collections.lists.IntArrayList;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.hadoop.WritableUtils;
    +import hivemall.utils.lang.Preconditions;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.RandomUtils;
    +
    +import opennlp.maxent.GIS;
    +import opennlp.maxent.io.GISModelWriter;
    +import opennlp.model.AbstractModel;
    +import opennlp.model.Event;
    +import opennlp.model.EventStream;
    +import opennlp.model.OnePassRealValueDataIndexer;
    +
    +@Description(
    +        name = "train_maxent_classifier",
    +        value = "_FUNC_(array<double> features, int label [, const string options])"
    +                + " - Returns a maximum entropy model per subset of data.")
    +@UDFType(deterministic = true, stateful = false)
    +public class MaxEntUDTF extends UDTFWithOptions {
    +    private static final Log logger = LogFactory.getLog(MaxEntUDTF.class);
    +
    +    private ListObjectInspector featureListOI;
    +    private PrimitiveObjectInspector featureElemOI;
    +    private PrimitiveObjectInspector labelOI;
    +
    +    private MatrixBuilder matrixBuilder;
    +    private IntArrayList labels;
    +
    +    private boolean _real;
    +    private Attribute[] _attributes;
    +    private static boolean _USE_SMOOTHING;
    +    private double _SMOOTHING_OBSERVATION;
    +
    +    private int _numTrees = 1;
    +    
    +    @Nullable
    +    private Reporter _progressReporter;
    +    @Nullable
    +    private Counter _treeBuildTaskCounter;
    +    
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("real", "quantative_feature_presence_indication", 
true,
    +            "true or false [default: true]");
    +        opts.addOption("smoothing", "smoothimg", true, "Shall smoothing be 
performed [default: false]");
    +        opts.addOption("constant", "smoothing_constant", true, "real 
number [default: 1.0]");
    +        opts.addOption("attrs", "attribute_types", true, "Comma separated 
attribute types "
    +                + "(Q for quantitative variable and C for categorical 
variable. e.g., [Q,C,Q,C])");
    +        return opts;
    +    }
    +    
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        boolean real = true;
    +        boolean USE_SMOOTHING = false;
    +        double SMOOTHING_OBSERVATION = 0.1;
    +
    +        Attribute[] attrs = null;
    +
    +        CommandLine cl = null;
    +        if (argOIs.length >= 3) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +            cl = parseOptions(rawArgs);
    +
    +            real = Primitives.parseBoolean(cl.getOptionValue("quantative_feature_presence_indication"), real);
    +            attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
    +            USE_SMOOTHING = Primitives.parseBoolean(cl.getOptionValue("smoothing"), USE_SMOOTHING);
    +            SMOOTHING_OBSERVATION = Primitives.parseDouble(cl.getOptionValue("smoothing_constant"), SMOOTHING_OBSERVATION);
    +        }
    +
    +        this._real = real;
    +        this._attributes = attrs;
    +        this._USE_SMOOTHING = USE_SMOOTHING;
    +        this._SMOOTHING_OBSERVATION = SMOOTHING_OBSERVATION;
    +
    +        return cl;
    +    }
    +    
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        if (argOIs.length < 2 || argOIs.length > 3) {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes 2 ~ 3 arguments: array<double> features, int label [, const string options]: "
    +                        + argOIs.length);
    +        }
    +
    +        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
    +        ObjectInspector elemOI = listOI.getListElementObjectInspector();
    +        this.featureListOI = listOI;
    +        if (HiveUtils.isNumberOI(elemOI)) {
    +            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
    +            this.matrixBuilder = new CSRMatrixBuilder(8192);
    +        } else {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes double[] for the first argument: " + listOI.getTypeName());
    +        }
    +        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
    +
    +        processOptions(argOIs);
    +
    +        this.labels = new IntArrayList(1024);
    +
    +        final ArrayList<String> fieldNames = new ArrayList<String>(6);
    +        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
    +
    +        fieldNames.add("model_id");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("model_weight");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    +        fieldNames.add("model");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("attributes");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("oob_errors");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("oob_tests");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +    
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (args[0] == null) {
    +            throw new HiveException("array<double> features was null");
    +        }
    +        parseFeatures(args[0], matrixBuilder);
    +        int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
    +        labels.add(label);
    +    }
    +    
    +    private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
    +        final int length = featureListOI.getListLength(argObj);
    +        for (int i = 0; i < length; i++) {
    +            Object o = featureListOI.getListElement(argObj, i);
    +            if (o == null) {
    +                continue;
    +            }
    +            double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
    +            builder.nextColumn(i, v);
    +        }
    +        builder.nextRow();
    +    }
    +    
    +    @Override
    +    public void close() throws HiveException {
    +        this._progressReporter = getReporter();
    +        this._treeBuildTaskCounter = (_progressReporter == null) ? null
    +                : _progressReporter.getCounter("hivemall.smile.MaxEntClassifier$Counter", "finishedGISTask");
    +        reportProgress(_progressReporter);
    +
    +        if (!labels.isEmpty()) {
    +            Matrix x = matrixBuilder.buildMatrix();
    +            this.matrixBuilder = null;
    +            int[] y = labels.toArray();
    +            this.labels = null;
    +
    +            // run training
    +            train(x, y);
    +        }
    +
    +        // clean up
    +        this.featureListOI = null;
    +        this.featureElemOI = null;
    +        this.labelOI = null;
    +    }
    +    
    +    private void checkOptions() throws HiveException {
    +        // 0.1 is the default smoothing constant, so any other value implies the user set it explicitly
    +        if (_USE_SMOOTHING == false && _SMOOTHING_OBSERVATION != 0.1) {
    +            throw new HiveException("Smoothing is disabled but a smoothing constant was given ["
    +                    + _SMOOTHING_OBSERVATION + "]");
    +        }
    +    }
    +    
    +    /**
    +     * @param x features
    +     * @param y labels
    +     */
    +    private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
    +        final int numExamples = x.numRows();
    +        if (numExamples != y.length) {
    +            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
    +                numExamples, y.length));
    +        }
    +        checkOptions();
    +
    +        int[] labels = SmileExtUtils.classLables(y);
    +        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
    +
    +        if (logger.isInfoEnabled()) {
    +            logger.info("real: " + _real + ", smoothing: " + this._USE_SMOOTHING
    +                    + ", smoothing constant: " + _SMOOTHING_OBSERVATION);
    +        }
    +
    +        IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
    +        AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
    +        List<TrainingTask> tasks = new ArrayList<TrainingTask>();
    +        for (int i = 0; i < _numTrees; i++) {
    +            tasks.add(new TrainingTask(this, i, attributes, x, y, prediction, remainingTasks));
    +        }
    +
    +        MapredContext mapredContext = MapredContextAccessor.get();
    +        final SmileTaskExecutor executor = new SmileTaskExecutor(mapredContext);
    +        try {
    +            executor.run(tasks);
    +        } catch (Exception ex) {
    +            throw new HiveException(ex);
    +        } finally {
    +            executor.shotdown();
    +        }
    +    }
    +    
    +    /**
    +     * Synchronized because {@link #forward(Object)} should be called from a single thread.
    +     */
    +    synchronized void forward(final int taskId, @Nonnull final Text model,
    +            @Nonnull final Attribute[] attributes, @Nonnegative final double accuracy,
    +            final int[] y, @Nonnull final IntMatrix prediction, final boolean lastTask)
    +            throws HiveException {
    +        int oobErrors = 0;
    +        int oobTests = 0;
    +        if (lastTask) {
    +            // out-of-bag error estimate
    +            for (int i = 0; i < y.length; i++) {
    +                final int pred = MatrixUtils.whichMax(prediction, i);
    +                if (pred != -1 && prediction.get(i, pred) > 0) {
    +                    oobTests++;
    +                    if (pred != y[i]) {
    +                        oobErrors++;
    +                    }
    +                }
    +            }
    +        }
    +        
    +        final String attributesString = SmileExtUtils.resolveAttributes(attributes);
    +
    +        final Object[] forwardObjs = new Object[6];
    +        String modelId = RandomUtils.getUUID();
    +        forwardObjs[0] = new Text(modelId);
    +        forwardObjs[1] = new DoubleWritable(accuracy);
    +        forwardObjs[2] = model;
    +        forwardObjs[3] = new Text(attributesString); // reuse instead of resolving the attributes twice
    +        forwardObjs[4] = new IntWritable(oobErrors);
    +        forwardObjs[5] = new IntWritable(oobTests);
    +        forward(forwardObjs);
    +
    +        reportProgress(_progressReporter);
    +        incrCounter(_treeBuildTaskCounter, 1);
    +
    +        logger.info("Forwarded " + taskId + "-th DecisionTree out of " + 
_numTrees);
    +    }
    +    
    +    /**
    +     * Trains a maximum entropy model.
    +     */
    +    private static final class TrainingTask implements Callable<Integer> {
    +
    +        /**
    +         * Training instances.
    +         */
    +        @Nonnull
    +        private final Matrix _x;
    +        /**
    +         * Training sample labels.
    +         */
    +        @Nonnull
    +        private final int[] _y;
    +        
    +        /**
    +         * Attribute properties.
    +         */
    +        @Nonnull
    +        private final Attribute[] _attributes;
    +
    +        /**
    +         * The out-of-bag predictions.
    +         */
    +        @Nonnull
    +        @GuardedBy("_udtf")
    +        private final IntMatrix _prediction;
    +
    +        @Nonnull
    +        private final MaxEntUDTF _udtf;
    +        private final int _taskId;
    + 
    +        @Nonnull
    +        private final AtomicInteger _remainingTasks;
    +
    +        TrainingTask(@Nonnull MaxEntUDTF udtf, int taskId, @Nonnull Attribute[] attributes,
    +                @Nonnull Matrix x, @Nonnull int[] y, @Nonnull IntMatrix prediction,
    +                @Nonnull AtomicInteger remainingTasks) {
    +            this._udtf = udtf;
    +            this._taskId = taskId;
    +            this._attributes = attributes;
    +            this._x = x;
    +            this._y = y;
    +            this._prediction = prediction;
    +            this._remainingTasks = remainingTasks;
    +        }
    +
    +        @Override
    +        public Integer call() throws HiveException {
    +            final int N = _x.numRows();
    +
    +            EventStream es = new MatrixEventStream(_x, _y, _attributes);
    +            AbstractModel model;
    +            try {
    +                model = GIS.trainModel(1000, new OnePassRealValueDataIndexer(es, 0), _USE_SMOOTHING);
    --- End diff ---
    
    Yes, it definitely should be.
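    
    For reference, a minimal sketch of how that `GIS.trainModel` call could wrap the checked `IOException` thrown by `OnePassRealValueDataIndexer` into a `HiveException`. The catch clause and the `GIS.SMOOTHING_OBSERVATION` assignment are illustrative only (the diff is truncated before the actual catch); the field names are the ones declared in the diff above:
    
        EventStream es = new MatrixEventStream(_x, _y, _attributes);
        AbstractModel model;
        try {
            if (_USE_SMOOTHING) {
                // GIS consults this static field when smoothing is enabled
                GIS.SMOOTHING_OBSERVATION = _SMOOTHING_OBSERVATION;
            }
            // 1000 GIS iterations; cutoff 0 keeps every observed feature
            model = GIS.trainModel(1000, new OnePassRealValueDataIndexer(es, 0), _USE_SMOOTHING);
        } catch (IOException e) {
            // surface indexing failures to Hive rather than leaking a checked IOException
            throw new HiveException("Failed to index events for GIS training", e);
        }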

