Author: srowen Date: Fri Jan 22 14:47:03 2010 New Revision: 902105 URL: http://svn.apache.org/viewvc?rev=902105&view=rev Log: MAHOUT-262
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVectorWritable.java Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java?rev=902105&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java Fri Jan 22 14:47:03 2010 @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Writable to handle serialization of a vector and a variable list of + * associated label indexes. + */ +public class MultiLabelVectorWritable extends VectorWritable { + + private int[] labels; + + public void setLabels(int[] labels) { + this.labels = labels; + } + + public int[] getLabels() { + return labels; + } + + public MultiLabelVectorWritable() {} + + public MultiLabelVectorWritable(Vector v, int[] labels) { + super(v); + setLabels(labels); + } + + @Override + public void readFields(DataInput in) throws IOException { + int labelSize = in.readInt(); + labels = new int[labelSize]; + for (int i = 0; i < labelSize; i++) { + labels[i] = in.readInt(); + } + super.readFields(in); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(labels.length); + for (int i = 0; i < labels.length; i++) { + out.writeInt(labels[i]); + } + super.write(out); + } + + public static MultiLabelVectorWritable read(DataInput in) throws IOException { + int labelSize = in.readInt(); + int[] labels = new int[labelSize]; + for (int i = 0; i < labelSize; i++) { + labels[i] = in.readInt(); + } + Vector vector = VectorWritable.readVector(in); + return new MultiLabelVectorWritable(vector, labels); + } + + public static void write(DataOutput out, SequentialAccessSparseVector ssv, + int[] labels) throws IOException { + (new MultiLabelVectorWritable(ssv, labels)).write(out); + } + +} Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVectorWritable.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVectorWritable.java?rev=902105&r1=902104&r2=902105&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVectorWritable.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/SequentialAccessSparseVectorWritable.java Fri Jan 22 14:47:03 2010 @@ -17,13 +17,12 @@ package org.apache.mahout.math; -import org.apache.hadoop.io.Writable; - import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.util.Iterator; +import org.apache.hadoop.io.Writable; public class SequentialAccessSparseVectorWritable extends SequentialAccessSparseVector implements Writable { @@ -45,7 +44,7 @@ Iterator<Element> iter = iterateNonZero(); int count = 0; while (iter.hasNext()) { - Vector.Element element = iter.next(); + Element element = iter.next(); dataOutput.writeInt(element.index()); dataOutput.writeDouble(element.get()); count++; @@ -61,14 +60,12 @@ } else { setName(className); // we have already read the class name in VectorWritable } - int cardinality = dataInput.readInt(); - int size = dataInput.readInt(); - OrderedIntDoubleMapping values = new OrderedIntDoubleMapping(size); - int i = 0; - for (; i < size; i++) { + size = dataInput.readInt(); + int nde = dataInput.readInt(); + OrderedIntDoubleMapping values = new OrderedIntDoubleMapping(nde); + for (int i = 0; i < nde; i++) { values.set(dataInput.readInt(), dataInput.readDouble()); } - assert (i == size); this.values = values; }