GitHub user jaxony commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/167#discussion_r226817004
--- Diff: core/src/main/java/hivemall/mf/CofactorizationUDTF.java ---
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.mf;
+
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.common.ConversionState;
+import hivemall.fm.Feature;
+import hivemall.fm.StringFeature;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NIOUtils;
+import hivemall.utils.io.NioStatefulSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import static hivemall.utils.lang.Primitives.FALSE_BYTE;
+import static hivemall.utils.lang.Primitives.TRUE_BYTE;
+
+public class CofactorizationUDTF extends UDTFWithOptions {
+    private static final Log LOG = LogFactory.getLog(CofactorizationUDTF.class);
+
+ // Option variables
+ // The number of latent factors
+ private int factor;
+ // The scaling hyperparameter for zero entries in the rank matrix
+ private float scale_zero;
+ // The scaling hyperparameter for non-zero entries in the rank matrix
+ private float scale_nonzero;
+    // The preferred size of the mini-batch for training
+ private int batchSize;
+ // The initial mean rating
+ private float globalBias;
+    // Whether to update (and return) the mean rating
+ private boolean updateGlobalBias;
+ // The number of iterations
+ private int maxIters;
+ // Whether to use bias clause
+ private boolean useBiasClause;
+ // Whether to use normalization
+ private boolean useL2Norm;
+ // regularization hyperparameters
+ private float lambdaTheta;
+ private float lambdaBeta;
+ private float lambdaGamma;
+
+ // Initialization strategy of rank matrix
+ private CofactorModel.RankInitScheme rankInit;
+
+ // Model itself
+ private CofactorModel model;
+ private int numItems;
+
+ // Variable managing status of learning
+
+ // The number of processed training examples
+ private long count;
+
+ private ConversionState cvState;
+ private ConversionState validationState;
+
+ // Input OIs and Context
+ private StringObjectInspector contextOI;
+ @VisibleForTesting
+ protected ListObjectInspector featuresOI;
+ private BooleanObjectInspector isItemOI;
+ @VisibleForTesting
+ protected ListObjectInspector sppmiOI;
+
+ // Used for iterations
+ @VisibleForTesting
+ protected NioStatefulSegment fileIO;
+ private ByteBuffer inputBuf;
+ private long lastWritePos;
+
+ private String contextProbe;
+ private Feature[] featuresProbe;
+ private Feature[] sppmiProbe;
+ private boolean isItemProbe;
+ private long numValidations;
+ private long numTraining;
+
+ static class MiniBatch {
+ private List<TrainingSample> users;
+ private List<TrainingSample> items;
+
+ protected MiniBatch() {
+ users = new ArrayList<>();
+ items = new ArrayList<>();
+ }
+
+ protected void add(TrainingSample sample) {
+ if (sample.isItem()) {
+ items.add(sample);
+ } else {
+ users.add(sample);
+ }
+ }
+
+ protected void clear() {
+ users.clear();
+ items.clear();
+ }
+
+ protected int size() {
+ return items.size() + users.size();
+ }
+
+ protected List<TrainingSample> getItems() {
+ return items;
+ }
+
+ protected List<TrainingSample> getUsers() {
+ return users;
+ }
+ }
+
+ static class TrainingSample {
+ protected String context;
+ protected Feature[] features;
+ protected Feature[] sppmi;
+
+        protected TrainingSample(String context, Feature[] features, Feature[] sppmi) {
+ this.context = context;
+ this.features = features;
+ this.sppmi = sppmi;
+ }
+
+ protected boolean isItem() {
+ return sppmi != null;
+ }
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+        opts.addOption("k", "factor", true,
+                "The number of latent factors [default: 10]. Note this is an alias for the `factors` option.");
+        opts.addOption("f", "factors", true, "The number of latent factors [default: 10]");
+        opts.addOption("r", "lambda", true, "The regularization factor [default: 0.03]");
+        opts.addOption("lambda_theta", true,
+                "The regularization factor for theta [default: 1e-5]");
+        opts.addOption("lambda_beta", true,
+                "The regularization factor for beta [default: 1e-5]");
+        opts.addOption("lambda_gamma", true,
+                "The regularization factor for gamma [default: 1.0]");
+        opts.addOption("c0", "scale_zero", true,
+                "The scaling hyperparameter for zero entries in the rank matrix [default: 0.1]");
+        opts.addOption("c1", "scale_nonzero", true,
+                "The scaling hyperparameter for non-zero entries in the rank matrix [default: 1.0]");
+        opts.addOption("b", "batch_size", true, "The mini-batch size for training [default: 1024]");
+        opts.addOption("n", "num_items", true, "The number of items");
+        opts.addOption("gb", "global_bias", true, "The global bias [default: 0.0]");
+ opts.addOption("update_gb", "update_gb", false,
+                "Whether to update (and return) the global bias");
+        opts.addOption("rankinit", true,
+                "Initialization strategy of rank matrix [random, gaussian] (default: gaussian)");
+        opts.addOption("maxval", "max_init_value", true,
+                "The maximum initial value in the rank matrix [default: 1.0]");
+        opts.addOption("min_init_stddev", true,
+                "The minimum standard deviation of the initial rank matrix [default: 0.01]");
+        opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
+        opts.addOption("iter", true,
+                "The number of iterations [default: 1]. Alias for `-iterations`");
+ opts.addOption("disable_cv", "disable_cvtest", false,
+ "Whether to disable convergence check [default: enabled]");
+ opts.addOption("cv_rate", "convergence_rate", true,
+ "Threshold to determine convergence [default: 0.005]");
+        opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause");
+ // feature representation
+ opts.addOption("int_feature", "feature_as_integer", false,
+ "Parse a feature as integer [default: OFF]");
+ // normalization
+        opts.addOption("disable_norm", "disable_l2norm", false, "Disable instance-wise L2 normalization");
+ return opts;
+ }
+
+ @Override
+    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = null;
+ String rankInitOpt = "gaussian";
+ float maxInitValue = 1.f;
+ double initStdDev = 0.1d;
+ boolean conversionCheck = true;
+ double convergenceRate = 0.005d;
+
+ if (argOIs.length >= 5) {
+ String rawArgs = HiveUtils.getConstString(argOIs[4]);
+ cl = parseOptions(rawArgs);
+ if (cl.hasOption("factors")) {
+                this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
+ } else {
+                this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10);
+ }
+            this.lambdaTheta = Primitives.parseFloat(cl.getOptionValue("lambda_theta"), 1e-5f);
+            this.lambdaBeta = Primitives.parseFloat(cl.getOptionValue("lambda_beta"), 1e-5f);
+            this.lambdaGamma = Primitives.parseFloat(cl.getOptionValue("lambda_gamma"), 1e+0f);
+            this.scale_zero = Primitives.parseFloat(cl.getOptionValue("scale_zero"), 0.1f);
+            this.scale_nonzero = Primitives.parseFloat(cl.getOptionValue("scale_nonzero"), 1.0f);
+            this.batchSize = Primitives.parseInt(cl.getOptionValue("batch_size"), 1024);
+ if (cl.hasOption("num_items")) {
+                this.numItems = Primitives.parseInt(cl.getOptionValue("num_items"), 1024);
+ } else {
+                throw new UDFArgumentException("-num_items must be specified");
+ }
+            this.globalBias = Primitives.parseFloat(cl.getOptionValue("gb"), 0.f);
+ this.updateGlobalBias = cl.hasOption("update_gb");
+ rankInitOpt = cl.getOptionValue("rankinit");
+            maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.f);
+            initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.01d);
+ if (cl.hasOption("iter")) {
+                this.maxIters = Primitives.parseInt(cl.getOptionValue("iter"), 1);
+ } else {
+                this.maxIters = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
+ }
+ if (maxIters < 1) {
+ throw new UDFArgumentException(
+                        "'-iterations' must be greater than or equal to 1: " + maxIters);
+ }
+ conversionCheck = !cl.hasOption("disable_cvtest");
+            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
+ boolean noBias = cl.hasOption("no_bias");
+ this.useBiasClause = !noBias;
+ if (noBias && updateGlobalBias) {
+ throw new UDFArgumentException(
+                        "Cannot set both `update_gb` and `no_bias` options");
+ }
+ this.useL2Norm = !cl.hasOption("disable_l2norm");
+ }
+ this.rankInit = CofactorModel.RankInitScheme.resolve(rankInitOpt);
+ rankInit.setMaxInitValue(maxInitValue);
+ initStdDev = Math.max(initStdDev, 1.0d / factor);
+ rankInit.setInitStdDev(initStdDev);
+        this.cvState = new ConversionState(conversionCheck, convergenceRate);
+ return cl;
+ }
+
+ @Override
+    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 4) {
+            throw new UDFArgumentException(
+                    "_FUNC_ takes 4 or 5 arguments: string context, array<string> features, boolean is_item, array<string> sppmi [, CONSTANT STRING options]");
+ }
+ this.contextOI = HiveUtils.asStringOI(argOIs[0]);
+ this.featuresOI = HiveUtils.asListOI(argOIs[1]);
+        HiveUtils.validateFeatureOI(featuresOI.getListElementObjectInspector());
+ this.isItemOI = HiveUtils.asBooleanOI(argOIs[2]);
+ this.sppmiOI = HiveUtils.asListOI(argOIs[3]);
+        HiveUtils.validateFeatureOI(sppmiOI.getListElementObjectInspector());
+
+ processOptions(argOIs);
+
+        this.model = new CofactorModel(factor, rankInit, scale_zero, scale_nonzero,
+                lambdaTheta, lambdaBeta, lambdaGamma);
+ this.count = 0L;
+ this.lastWritePos = 0L;
+
+        List<String> fieldNames = new ArrayList<>();
+        List<ObjectInspector> fieldOIs = new ArrayList<>();
+ fieldNames.add("idx");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldNames.add("Pu");
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
+ fieldNames.add("Qi");
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
+        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (args.length != 4) {
+            throw new HiveException("Expected 4 arguments, but got " + args.length);
+ }
+
+ String context = contextOI.getPrimitiveJavaObject(args[0]);
+        Feature[] features = parseFeatures(args[1], featuresOI, featuresProbe);
+ if (features == null) {
+ throw new HiveException("features must not be null");
+ }
+
+        boolean isItem = isItemOI.get(args[2]);
+ Feature[] sppmi = null;
+ if (isItem) {
+ sppmi = parseFeatures(args[3], sppmiOI, sppmiProbe);
+ }
+
+ model.recordContext(context, isItem);
+
+ this.contextProbe = context;
+ this.featuresProbe = features;
+ this.isItemProbe = isItem;
+ this.sppmiProbe = sppmi;
+
+ recordTrain(context, features, sppmi);
+ }
+
+ @Nullable
+ @VisibleForTesting
+    protected static Feature[] parseFeatures(@Nullable final Object arg,
+            ListObjectInspector listOI, Feature[] probe) throws HiveException {
+ if (arg == null) {
+ return null;
+ }
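+        // parse the raw features, then keep only the non-zero entries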
+        Feature[] rawFeatures = Feature.parseFeatures(arg, listOI, probe, false);
+ return createNnzFeatureArray(rawFeatures);
+ }
+
+ @VisibleForTesting
+    protected static Feature[] createNnzFeatureArray(@Nonnull Feature[] x) {
+ int nnz = countNnzFeatures(x);
+ Feature[] nnzFeatures = new Feature[nnz];
+ int i = 0;
+ for (Feature f : x) {
+ if (f.getValue() != 0.d) {
+ nnzFeatures[i++] = f;
+ }
+ }
+ return nnzFeatures;
+ }
+
+ private static int countNnzFeatures(@Nonnull Feature[] x) {
+ int nnz = 0;
+ for (Feature f : x) {
+ if (f.getValue() != 0.d) {
+ nnz++;
+ }
+ }
+ return nnz;
+ }
+
+    private Double trainMiniBatch(MiniBatch miniBatch) throws HiveException {
+ model.updateWithUsers(miniBatch.getUsers());
+ model.updateWithItems(miniBatch.getItems());
+        return model.calculateLoss(miniBatch.getUsers(), miniBatch.getItems());
+ }
+
+    private void recordTrain(final String context, final Feature[] features, final Feature[] sppmi)
+ throws HiveException {
+ numTraining++;
+ ByteBuffer inputBuf = this.inputBuf;
+ NioStatefulSegment dst = this.fileIO;
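+        // lazily allocate the 1 MiB write buffer and the backing temp file on the first record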
+ if (inputBuf == null) {
+ final File file = createTempFile();
+            this.inputBuf = inputBuf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MiB
+ this.fileIO = dst = new NioStatefulSegment(file, false);
+ }
+
+ writeRecordToBuffer(inputBuf, dst, context, features, sppmi);
+ }
+
+    private static void writeRecordToBuffer(@Nonnull final ByteBuffer inputBuf,
+            @Nonnull final NioStatefulSegment dst, @Nonnull final String context,
+            @Nonnull final Feature[] features, @Nullable final Feature[] sppmi)
+            throws HiveException {
+ int recordBytes = calculateRecordBytes(context, features, sppmi);
+ int requiredBytes = SizeOf.INT + recordBytes;
+ int remain = inputBuf.remaining();
+
+ if (remain < requiredBytes) {
+ writeBuffer(inputBuf, dst);
+ }
+
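+        // record layout: [int record size][context string][features][is-item flag byte][sppmi features if item]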
+ inputBuf.putInt(recordBytes);
+ NIOUtils.putString(context, inputBuf);
+ writeFeaturesToBuffer(features, inputBuf);
+ if (sppmi != null) {
+ inputBuf.put(TRUE_BYTE);
+ writeFeaturesToBuffer(sppmi, inputBuf);
+ } else {
+ inputBuf.put(FALSE_BYTE);
+ }
+ }
+
+    private static int calculateRecordBytes(String context, Feature[] features, Feature[] sppmi) {
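+        // context: an int length followed by 2-byte chars (the layout NIOUtils.putString writes)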
+ int contextBytes = SizeOf.INT + SizeOf.CHAR * context.length();
+ int featuresBytes = SizeOf.INT + Feature.requiredBytes(features);
+ int isItemBytes = SizeOf.BYTE;
+        int sppmiBytes = sppmi != null ? SizeOf.INT + Feature.requiredBytes(sppmi) : 0;
+ return contextBytes + featuresBytes + isItemBytes + sppmiBytes;
+ }
+
+ private static File createTempFile() throws UDFArgumentException {
+ final File file;
+ try {
+ file = File.createTempFile("hivemall_cofactor", ".sgmt");
+ file.deleteOnExit();
+ if (!file.canWrite()) {
+ throw new UDFArgumentException(
+                        "Cannot write a temporary file: " + file.getAbsolutePath());
+ }
+            LOG.info("Record training examples to a file: " + file.getAbsolutePath());
+ } catch (IOException ioe) {
+ throw new UDFArgumentException(ioe);
+ } catch (Throwable e) {
+ throw new UDFArgumentException(e);
+ }
+ return file;
+ }
+
+    private static void writeFeaturesToBuffer(Feature[] features, ByteBuffer buffer) {
+ buffer.putInt(features.length);
+ for (Feature f : features) {
+ f.writeTo(buffer);
+ }
+ }
+
+    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst)
+ throws HiveException {
+ srcBuf.flip();
+ try {
+ dst.write(srcBuf);
+ } catch (IOException e) {
+            throw new HiveException("Exception caused while writing a buffer to a file", e);
+ }
+ srcBuf.clear();
+ }
+
+ @Override
+ public void close() throws HiveException {
+ try {
+ boolean lossIncreasedLastIter = false;
+
+ final Reporter reporter = getReporter();
+ final Counters.Counter iterCounter = (reporter == null) ? null
+                    : reporter.getCounter("hivemall.mf.Cofactor$Counter", "iteration");
+
+ prepareForRead();
+
+ if (LOG.isInfoEnabled()) {
+ File tmpFile = fileIO.getFile();
+ LOG.info("Wrote " + numTraining
+                        + " records to a temporary file for iterative training: "
+                        + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
+ }
+
+ for (int iteration = 0; iteration < maxIters; iteration++) {
+                // train the model on a full batch (i.e., all the data) using mini-batch updates
+// validationState.next();
+ cvState.next();
+ reportProgress(reporter);
+ setCounterValue(iterCounter, iteration);
+ runTrainingIteration();
+
+                LOG.info("Performed " + cvState.getCurrentIteration() + " iterations of "
+ + NumberUtils.formatNumber(maxIters));
+//                        + " training examples on a secondary storage (thus "
+//                        + NumberUtils.formatNumber(_t) + " training updates in total), used "
+// + _numValidations + " validation examples");
+ }
+ } finally {
+ // delete the temporary file and release resources
+ try {
+ fileIO.close(true);
+ } catch (IOException e) {
+ throw new HiveException(
+                        "Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e);
+ }
+ this.inputBuf = null;
+ this.fileIO = null;
+ }
+ }
+
+ @VisibleForTesting
+ protected void prepareForRead() throws HiveException {
+ // write training examples in buffer to a temporary file
+ if (inputBuf.remaining() > 0) {
+ writeBuffer(inputBuf, fileIO);
+ }
+ try {
+ fileIO.flush();
+ } catch (IOException e) {
+ throw new HiveException(
+                    "Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
+ }
+ fileIO.resetPosition();
+ }
+
+ private void runTrainingIteration() throws HiveException {
+ fileIO.resetPosition();
+ MiniBatch miniBatch = new MiniBatch();
+ // read minibatch from disk into memory
+ while (readMiniBatchFromFile(miniBatch)) {
+ Double trainLoss = trainMiniBatch(miniBatch);
+ if (trainLoss != null) {
+ cvState.incrLoss(trainLoss);
+ }
+ miniBatch.clear();
+ }
+    }
+
+ @Nonnull
+    private static Feature instantiateFeature(@Nonnull final ByteBuffer input) {
+ return new StringFeature(input);
+ }
+
+ @VisibleForTesting
+    protected boolean readMiniBatchFromFile(MiniBatch miniBatch) throws HiveException {
+ inputBuf.clear();
--- End diff --
Fix this bug: remove the `inputBuf.clear()` line above.
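
Clearing `inputBuf` at the top of this method would throw away any unconsumed bytes left over from the previous call, e.g. a record that straddled the last read from `fileIO`. For reference, here is a minimal sketch of how the read loop could look without it. It assumes `NioStatefulSegment.read(ByteBuffer)` appends into the buffer and returns 0 at EOF, that `NIOUtils.getString` is the read-side counterpart of the `NIOUtils.putString` call in `writeRecordToBuffer`, and a hypothetical `readFeatures` helper that mirrors `writeFeaturesToBuffer`; none of these are confirmed by the diff above.

```java
@VisibleForTesting
protected boolean readMiniBatchFromFile(MiniBatch miniBatch) throws HiveException {
    // do NOT clear inputBuf here: it may still hold the tail of the previous read
    final int bytesRead;
    try {
        bytesRead = fileIO.read(inputBuf); // assumed: stateful read that appends into inputBuf
    } catch (IOException e) {
        throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
    }
    if (bytesRead == 0 && inputBuf.position() == 0) {
        return false; // EOF and nothing buffered: this iteration is done
    }
    inputBuf.flip();
    while (inputBuf.remaining() >= SizeOf.INT && miniBatch.size() < batchSize) {
        inputBuf.mark();
        final int recordBytes = inputBuf.getInt();
        if (inputBuf.remaining() < recordBytes) {
            inputBuf.reset(); // partial record: rewind and fetch the rest on the next call
            break;
        }
        final String context = NIOUtils.getString(inputBuf);
        final Feature[] features = readFeatures(inputBuf); // hypothetical helper, not in this diff
        final Feature[] sppmi = (inputBuf.get() == TRUE_BYTE) ? readFeatures(inputBuf) : null;
        miniBatch.add(new TrainingSample(context, features, sppmi));
    }
    inputBuf.compact(); // carry any partial record over to the next call
    return true;
}
```

The key invariants are the `mark()`/`reset()` pair for records that straddle a read boundary and the trailing `compact()`; once those hold, a leading `clear()` can only lose data.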
---