Repository: incubator-hivemall Updated Branches: refs/heads/HIVEMALL-24-2 [created] f70e7c52e
Implement -ffm option in `feature_pairs` Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/591e3b0f Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/591e3b0f Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/591e3b0f Branch: refs/heads/HIVEMALL-24-2 Commit: 591e3b0f255e6523167157ed2e68c9499b4ca2cd Parents: 1cccf66 Author: Takuya Kitazawa <[email protected]> Authored: Fri Mar 3 17:48:12 2017 +0900 Committer: Takuya Kitazawa <[email protected]> Committed: Fri Mar 3 17:48:12 2017 +0900 ---------------------------------------------------------------------- .../ftvec/pairing/FeaturePairsUDTF.java | 134 +++++++++++++++---- 1 file changed, 109 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/591e3b0f/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java index 6aebd64..d814b40 100644 --- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java +++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java @@ -20,6 +20,7 @@ package hivemall.ftvec.pairing; import hivemall.UDTFWithOptions; import hivemall.model.FeatureValue; +import hivemall.fm.Feature; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hashing.HashFunction; import hivemall.utils.lang.Preconditions; @@ -29,6 +30,7 @@ import java.util.List; import javax.annotation.Nonnull; +import hivemall.utils.lang.Primitives; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; @@ -50,6 +52,8 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { private Type _type; private RowProcessor _proc; + private int _numFields; + private int _numFeatures; public FeaturePairsUDTF() {} @@ -60,6 +64,9 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:true]"); opts.addOption("ffm", false, "Generate feature pairs for Field-aware Factorization Machines [default:false]"); + opts.addOption("feature_hashing", true, + "The number of bits for feature hashing in range [18,31] [default:21]"); + opts.addOption("num_fields", true, "The number of fields [default:1024]"); return opts; } @@ -70,13 +77,17 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { String args = HiveUtils.getConstString(argOIs[1]); cl = parseOptions(args); - Preconditions.checkArgument(cl.getOptions().length == 1, UDFArgumentException.class, - "Only one option can be specified: " + cl.getArgList()); + Preconditions.checkArgument(cl.getOptions().length <= 3, UDFArgumentException.class, + "Too many options were specified: " + cl.getArgList()); if (cl.hasOption("kpa")) { this._type = Type.kpa; } else if (cl.hasOption("ffm")) { this._type = Type.ffm; + int featureBits = Primitives.parseInt( + cl.getOptionValue("feature_hashing"), Feature.DEFAULT_FEATURE_BITS); + this._numFeatures = 1 << featureBits; + this._numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), Feature.DEFAULT_NUM_FIELDS); } else { throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0)); } @@ -113,8 +124,16 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { break; } case ffm: { - throw new UDFArgumentException("-ffm is not supported yet"); - //break; + this._proc = new FFMProcessor(fvOI); + fieldNames.add("i"); // <ei, jField> index + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("j"); // <ej, iField> index + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("xi"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("xj"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + break; } default: throw new UDFArgumentException("Illegal condition: " + _type); @@ -144,26 +163,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { this.fvOI = fvOI; } - void process(@Nonnull Object arg) throws HiveException { - final int size = fvOI.getListLength(arg); - if (size == 0) { - return; - } - - final List<FeatureValue> features = new ArrayList<FeatureValue>(size); - for (int i = 0; i < size; i++) { - Object f = fvOI.getListElement(arg, i); - if (f == null) { - continue; - } - FeatureValue fv = FeatureValue.parse(f, true); - features.add(fv); - } - - process(features); - } - - abstract void process(@Nonnull List<FeatureValue> features) throws HiveException; + abstract void process(@Nonnull Object arg) throws HiveException; } @@ -186,7 +186,22 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { } @Override - void process(@Nonnull List<FeatureValue> features) throws HiveException { + void process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + final List<FeatureValue> features = new ArrayList<FeatureValue>(size); + for (int i = 0; i < size; i++) { + Object f = fvOI.getListElement(arg, i); + if (f == null) { + continue; + } + FeatureValue fv = FeatureValue.parse(f, true); + features.add(fv); + } + forward[0] = f0; f0.set(0); forward[1] = null; @@ -222,6 +237,75 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { } } + final class FFMProcessor extends RowProcessor { + + @Nonnull + private final IntWritable f0, f1; + @Nonnull + private final DoubleWritable f2, f3; + @Nonnull + private final Writable[] forward; + + FFMProcessor(@Nonnull ListObjectInspector fvOI) { + super(fvOI); + this.f0 = new IntWritable(); + this.f1 = new IntWritable(); + this.f2 = new DoubleWritable(); + this.f3 = new DoubleWritable(); + this.forward = new Writable[] {f0, null, null, null}; + } + + @Override + void process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + final Feature[] features = Feature.parseFFMFeatures(arg, fvOI, null, _numFeatures, _numFields); + + // W0 + forward[0] = f0; + f0.set(0); + forward[1] = null; + forward[2] = null; + forward[3] = null; + forward(forward); + + forward[2] = f2; + for (int i = 0, len = features.length; i < len; i++) { + Feature ei = features[i]; + double xi = ei.getValue(); + int iField = ei.getField(); + + // Wi + forward[0] = f0; + f0.set(i); + forward[1] = null; + f2.set(xi); + forward[3] = null; + forward(forward); + + forward[1] = f1; + forward[3] = f3; + for (int j = i + 1; j < len; j++) { + Feature ej = features[j]; + double xj = ej.getValue(); + int jField = ej.getField(); + + int ifj = Feature.toIntFeature(ei, jField, _numFields); + int jfi = Feature.toIntFeature(ej, iField, _numFields); + + // Vifj, Vjfi + f0.set(ifj); + f1.set(jfi); + // `f2` is consistently set to `xi` + f3.set(xj); + forward(forward); + } + } + } + } @Override public void close() throws HiveException {
