GitHub user jaxony commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/167#discussion_r226523395
  
    --- Diff: core/src/main/java/hivemall/mf/CofactorizationUDTF.java ---
    @@ -0,0 +1,574 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.mf;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.common.ConversionState;
    +import hivemall.fm.Feature;
    +import hivemall.fm.StringFeature;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.io.FileUtils;
    +import hivemall.utils.io.NioStatefulSegment;
    +import hivemall.utils.lang.NumberUtils;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.SizeOf;
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.*;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
    +import org.apache.hadoop.mapred.Counters;
    +import org.apache.hadoop.mapred.Reporter;
    +
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +import java.io.File;
    +import java.io.IOException;
    +import java.nio.ByteBuffer;
    +import java.util.ArrayList;
    +import java.util.List;
    +
    +import static hivemall.utils.lang.Primitives.FALSE_BYTE;
    +import static hivemall.utils.lang.Primitives.TRUE_BYTE;
    +
    +public class CofactorizationUDTF extends UDTFWithOptions {
    +    private static final Log LOG = 
LogFactory.getLog(CofactorizationUDTF.class);
    +
    +    // Option variables
    +    // The number of latent factors
    +    protected int factor;
    +    // The scaling hyperparameter for zero entries in the rank matrix
    +    protected float scale_zero;
    +    // The scaling hyperparameter for non-zero entries in the rank matrix
    +    protected float scale_nonzero;
    +    // The preferred size of the miniBatch for training
    +    protected int batchSize;
    +    // The initial mean rating
    +    protected float globalBias;
    +    // Whether update (and return) the mean rating or not
    +    protected boolean updateGlobalBias;
    +    // The number of iterations
    +    protected int maxIters;
    +    // Whether to use bias clause
    +    protected boolean useBiasClause;
    +    // Whether to use normalization
    +    protected boolean useL2Norm;
    +    // regularization hyperparameters
    +    protected float lambdaTheta;
    +    protected float lambdaBeta;
    +    protected float lambdaGamma;
    +
    +    // Initialization strategy of rank matrix
    +    protected CofactorModel.RankInitScheme rankInit;
    +
    +    // Model itself
    +    protected CofactorModel model;
    +    protected int numItems;
    +
    +    // Variable managing status of learning
    +
    +    // The number of processed training examples
    +    protected long count;
    +
    +    protected ConversionState cvState;
    +    private ConversionState validationState;
    +
    +    // Input OIs and Context
    +    protected StringObjectInspector contextOI;
    +    protected ListObjectInspector featuresOI;
    +    protected BooleanObjectInspector isItemOI;
    +    protected ListObjectInspector sppmiOI;
    +
    +    // Used for iterations
    +    protected NioStatefulSegment fileIO;
    +    protected ByteBuffer inputBuf;
    +    private long lastWritePos;
    +
    +    private Feature contextProbe;
    +    private Feature[] featuresProbe;
    +    private Feature[] sppmiProbe;
    +    private boolean isItemProbe;
    +    private long numValidations;
    +    private long numTraining;
    +
    +
    +    static class MiniBatch {
    +        protected int maxSize;
    +        private List<TrainingSample> users;
    +        private List<TrainingSample> items;
    +
    +        protected MiniBatch(int maxSize) {
    +            this.maxSize = maxSize;
    +        }
    +
    +        protected void add(TrainingSample sample) {
    +            if (size() == this.maxSize) {
    +                return;
    --- End diff --
    
    I think an explicit mini-batch size is no longer required: the upper limit is already the size of `inputBuf`. Is this correct?


---

Reply via email to