GitHub user jaxony commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/167#discussion_r226523395
--- Diff: core/src/main/java/hivemall/mf/CofactorizationUDTF.java ---
@@ -0,0 +1,574 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.mf;
+
+import hivemall.UDTFWithOptions;
+import hivemall.common.ConversionState;
+import hivemall.fm.Feature;
+import hivemall.fm.StringFeature;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefulSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.SizeOf;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.*;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.Reporter;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import static hivemall.utils.lang.Primitives.FALSE_BYTE;
+import static hivemall.utils.lang.Primitives.TRUE_BYTE;
+
+public class CofactorizationUDTF extends UDTFWithOptions {
+ private static final Log LOG =
LogFactory.getLog(CofactorizationUDTF.class);
+
+ // Option variables
+ // The number of latent factors
+ protected int factor;
+ // The scaling hyperparameter for zero entries in the rank matrix
+ protected float scale_zero;
+ // The scaling hyperparameter for non-zero entries in the rank matrix
+ protected float scale_nonzero;
+ // The preferred size of the miniBatch for training
+ protected int batchSize;
+ // The initial mean rating
+ protected float globalBias;
+ // Whether update (and return) the mean rating or not
+ protected boolean updateGlobalBias;
+ // The number of iterations
+ protected int maxIters;
+ // Whether to use bias clause
+ protected boolean useBiasClause;
+ // Whether to use normalization
+ protected boolean useL2Norm;
+ // regularization hyperparameters
+ protected float lambdaTheta;
+ protected float lambdaBeta;
+ protected float lambdaGamma;
+
+ // Initialization strategy of rank matrix
+ protected CofactorModel.RankInitScheme rankInit;
+
+ // Model itself
+ protected CofactorModel model;
+ protected int numItems;
+
+ // Variable managing status of learning
+
+ // The number of processed training examples
+ protected long count;
+
+ protected ConversionState cvState;
+ private ConversionState validationState;
+
+ // Input OIs and Context
+ protected StringObjectInspector contextOI;
+ protected ListObjectInspector featuresOI;
+ protected BooleanObjectInspector isItemOI;
+ protected ListObjectInspector sppmiOI;
+
+ // Used for iterations
+ protected NioStatefulSegment fileIO;
+ protected ByteBuffer inputBuf;
+ private long lastWritePos;
+
+ private Feature contextProbe;
+ private Feature[] featuresProbe;
+ private Feature[] sppmiProbe;
+ private boolean isItemProbe;
+ private long numValidations;
+ private long numTraining;
+
+
+ static class MiniBatch {
+ protected int maxSize;
+ private List<TrainingSample> users;
+ private List<TrainingSample> items;
+
+ protected MiniBatch(int maxSize) {
+ this.maxSize = maxSize;
+ }
+
+ protected void add(TrainingSample sample) {
+ if (size() == this.maxSize) {
+ return;
--- End diff --
I think an explicit mini-batch size is no longer required, because the upper
limit is the size of `inputBuf`. Is this correct?
---