http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java new file mode 100644 index 0000000..dfae61d --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.junit.Test; +@Deprecated +public final class DataConverterTest extends MahoutTestCase { + + private static final int ATTRIBUTE_COUNT = 10; + + private static final int INSTANCE_COUNT = 100; + + @Test + public void testConvert() throws Exception { + Random rng = RandomUtils.getRandom(); + + String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT); + double[][] source = Utils.randomDoubles(rng, descriptor, false, INSTANCE_COUNT); + String[] sData = Utils.double2String(source); + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + Data data = DataLoader.loadData(dataset, sData); + + DataConverter converter = new DataConverter(dataset); + + for (int index = 0; index < data.size(); index++) { + assertEquals(data.get(index), converter.convert(sData[index])); + } + + // regression + source = Utils.randomDoubles(rng, descriptor, true, INSTANCE_COUNT); + sData = Utils.double2String(source); + dataset = DataLoader.generateDataset(descriptor, true, sData); + data = DataLoader.loadData(dataset, sData); + + converter = new DataConverter(dataset); + + for (int index = 0; index < data.size(); index++) { + assertEquals(data.get(index), converter.convert(sData[index])); + } + } +}
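
For orientation, the round trip that DataConverterTest exercises can be summarized in a small standalone sketch. This is illustrative only and not part of the patch; it reuses the deprecated org.apache.mahout.classifier.df.data API exactly as it appears in the diffs above, and the class name ConverterSketch plus the toy descriptor and rows are assumptions borrowed from the DatasetTest added later in this commit.

  // Illustrative sketch, not part of this commit. Assumes the deprecated
  // org.apache.mahout.classifier.df.data API shown in the surrounding diffs.
  import org.apache.mahout.classifier.df.data.Data;
  import org.apache.mahout.classifier.df.data.DataConverter;
  import org.apache.mahout.classifier.df.data.DataLoader;
  import org.apache.mahout.classifier.df.data.Dataset;
  import org.apache.mahout.classifier.df.data.Instance;

  public final class ConverterSketch {
    public static void main(String[] args) throws Exception {
      // Descriptor "N C I L": numerical, categorical, ignored, label.
      // regression = true, so the label is treated as numerical.
      String descriptor = "N C I L";
      String[] rows = {"1 foo 2 3", "4 bar 5 6"};

      // Build the Dataset once, then load the same rows through DataLoader.
      Dataset dataset = DataLoader.generateDataset(descriptor, true, rows);
      Data data = DataLoader.loadData(dataset, rows);

      // DataConverter converts one raw line at a time against that same Dataset;
      // the test asserts it yields the same Instance that DataLoader produced.
      DataConverter converter = new DataConverter(dataset);
      for (int i = 0; i < data.size(); i++) {
        Instance converted = converter.convert(rows[i]);
        System.out.println(converted.equals(data.get(i))); // expected: true
      }
    }
  }

The point the test makes is that conversion is only meaningful against the Dataset used to load the data, since categorical values are encoded relative to that Dataset.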
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java new file mode 100644 index 0000000..aeb69fc --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java @@ -0,0 +1,350 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import java.util.Collection; +import java.util.Random; + +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; +import org.junit.Test; +@Deprecated +public final class DataLoaderTest extends MahoutTestCase { + + private Random rng; + + @Override + public void setUp() throws Exception { + super.setUp(); + rng = RandomUtils.getRandom(); + } + + @Test + public void testLoadDataWithDescriptor() throws Exception { + int nbAttributes = 10; + int datasize = 100; + + // prepare the descriptors + String descriptor = Utils.randomDescriptor(rng, nbAttributes); + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + // prepare the data + double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize); + Collection<Integer> missings = Lists.newArrayList(); + String[] sData = prepareData(data, attrs, missings); + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + Data loaded = DataLoader.loadData(dataset, sData); + + testLoadedData(data, attrs, missings, loaded); + testLoadedDataset(data, attrs, missings, loaded); + + // regression + data = Utils.randomDoubles(rng, descriptor, true, datasize); + missings = Lists.newArrayList(); + sData = prepareData(data, attrs, missings); + dataset = DataLoader.generateDataset(descriptor, true, sData); + loaded = DataLoader.loadData(dataset, sData); + + testLoadedData(data, attrs, missings, loaded); + testLoadedDataset(data, attrs, missings, loaded); + } + + /** + * Test method for + * {@link DataLoader#generateDataset(CharSequence, boolean, String[])} + */ + @Test + public void testGenerateDataset() throws Exception { + int nbAttributes = 10; + int datasize = 100; + + // prepare the descriptors + String descriptor = Utils.randomDescriptor(rng, nbAttributes); + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + // 
prepare the data + double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize); + Collection<Integer> missings = Lists.newArrayList(); + String[] sData = prepareData(data, attrs, missings); + Dataset expected = DataLoader.generateDataset(descriptor, false, sData); + + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + + assertEquals(expected, dataset); + + // regression + data = Utils.randomDoubles(rng, descriptor, true, datasize); + missings = Lists.newArrayList(); + sData = prepareData(data, attrs, missings); + expected = DataLoader.generateDataset(descriptor, true, sData); + + dataset = DataLoader.generateDataset(descriptor, true, sData); + + assertEquals(expected, dataset); +} + + /** + * Converts the data to an array of comma-separated strings and adds some + * missing values in all but IGNORED attributes + * + * @param missings indexes of vectors with missing values + */ + private String[] prepareData(double[][] data, Attribute[] attrs, Collection<Integer> missings) { + int nbAttributes = attrs.length; + + String[] sData = new String[data.length]; + + for (int index = 0; index < data.length; index++) { + int missingAttr; + if (rng.nextDouble() < 0.0) { + // add a missing value + missings.add(index); + + // choose a random attribute (not IGNORED) + do { + missingAttr = rng.nextInt(nbAttributes); + } while (attrs[missingAttr].isIgnored()); + } else { + missingAttr = -1; + } + + StringBuilder builder = new StringBuilder(); + + for (int attr = 0; attr < nbAttributes; attr++) { + if (attr == missingAttr) { + // add a missing value here + builder.append('?').append(','); + } else { + builder.append(data[index][attr]).append(','); + } + } + + sData[index] = builder.toString(); + } + + return sData; + } + + /** + * Test if the loaded data matches the source data + * + * @param missings indexes of instance with missing values + */ + static void testLoadedData(double[][] data, Attribute[] attrs, Collection<Integer> missings, Data loaded) { + int nbAttributes = attrs.length; + + // check the vectors + assertEquals("number of instance", data.length - missings.size(), loaded .size()); + + // make sure that the attributes are loaded correctly + int lind = 0; + for (int index = 0; index < data.length; index++) { + if (missings.contains(index)) { + continue; + }// this vector won't be loaded + + double[] vector = data[index]; + Instance instance = loaded.get(lind); + + int aId = 0; + for (int attr = 0; attr < nbAttributes; attr++) { + if (attrs[attr].isIgnored()) { + continue; + } + + if (attrs[attr].isNumerical()) { + assertEquals(vector[attr], instance.get(aId), EPSILON); + aId++; + } else if (attrs[attr].isCategorical()) { + checkCategorical(data, missings, loaded, attr, aId, vector[attr], + instance.get(aId)); + aId++; + } else if (attrs[attr].isLabel()) { + if (loaded.getDataset().isNumerical(aId)) { + assertEquals(vector[attr], instance.get(aId), EPSILON); + } else { + checkCategorical(data, missings, loaded, attr, aId, vector[attr], + instance.get(aId)); + } + aId++; + } + } + + lind++; + } + + } + + /** + * Test if the loaded dataset matches the source data + * + * @param missings indexes of instance with missing values + */ + static void testLoadedDataset(double[][] data, + Attribute[] attrs, + Collection<Integer> missings, + Data loaded) { + int nbAttributes = attrs.length; + + int iId = 0; + for (int index = 0; index < data.length; index++) { + if (missings.contains(index)) { + continue; + } + + Instance instance = loaded.get(iId++); + + int aId = 0; + 
for (int attr = 0; attr < nbAttributes; attr++) { + if (attrs[attr].isIgnored()) { + continue; + } + + if (attrs[attr].isLabel()) { + if (!loaded.getDataset().isNumerical(aId)) { + double nValue = instance.get(aId); + String oValue = Double.toString(data[index][attr]); + assertEquals(loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON); + } + } else { + assertEquals(attrs[attr].isNumerical(), loaded.getDataset().isNumerical(aId)); + + if (attrs[attr].isCategorical()) { + double nValue = instance.get(aId); + String oValue = Double.toString(data[index][attr]); + assertEquals(loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON); + } + } + aId++; + } + } + + } + + @Test + public void testLoadDataFromFile() throws Exception { + int nbAttributes = 10; + int datasize = 100; + + // prepare the descriptors + String descriptor = Utils.randomDescriptor(rng, nbAttributes); + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + // prepare the data + double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize); + Collection<Integer> missings = Lists.newArrayList(); + String[] sData = prepareData(source, attrs, missings); + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + + Path dataPath = Utils.writeDataToTestFile(sData); + FileSystem fs = dataPath.getFileSystem(getConfiguration()); + Data loaded = DataLoader.loadData(dataset, fs, dataPath); + + testLoadedData(source, attrs, missings, loaded); + + // regression + source = Utils.randomDoubles(rng, descriptor, true, datasize); + missings = Lists.newArrayList(); + sData = prepareData(source, attrs, missings); + dataset = DataLoader.generateDataset(descriptor, true, sData); + + dataPath = Utils.writeDataToTestFile(sData); + fs = dataPath.getFileSystem(getConfiguration()); + loaded = DataLoader.loadData(dataset, fs, dataPath); + + testLoadedData(source, attrs, missings, loaded); +} + + /** + * Test method for + * {@link DataLoader#generateDataset(CharSequence, boolean, FileSystem, Path)} + */ + @Test + public void testGenerateDatasetFromFile() throws Exception { + int nbAttributes = 10; + int datasize = 100; + + // prepare the descriptors + String descriptor = Utils.randomDescriptor(rng, nbAttributes); + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + // prepare the data + double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize); + Collection<Integer> missings = Lists.newArrayList(); + String[] sData = prepareData(source, attrs, missings); + Dataset expected = DataLoader.generateDataset(descriptor, false, sData); + + Path path = Utils.writeDataToTestFile(sData); + FileSystem fs = path.getFileSystem(getConfiguration()); + + Dataset dataset = DataLoader.generateDataset(descriptor, false, fs, path); + + assertEquals(expected, dataset); + + // regression + source = Utils.randomDoubles(rng, descriptor, false, datasize); + missings = Lists.newArrayList(); + sData = prepareData(source, attrs, missings); + expected = DataLoader.generateDataset(descriptor, false, sData); + + path = Utils.writeDataToTestFile(sData); + fs = path.getFileSystem(getConfiguration()); + + dataset = DataLoader.generateDataset(descriptor, false, fs, path); + + assertEquals(expected, dataset); + } + + /** + * each time oValue appears in data for the attribute 'attr', the nValue must + * appear in vectors for the same attribute. 
+ * + * @param attr attribute's index in source + * @param aId attribute's index in loaded + * @param oValue old value in source + * @param nValue new value in loaded + */ + static void checkCategorical(double[][] source, + Collection<Integer> missings, + Data loaded, + int attr, + int aId, + double oValue, + double nValue) { + int lind = 0; + + for (int index = 0; index < source.length; index++) { + if (missings.contains(index)) { + continue; + } + + if (source[index][attr] == oValue) { + assertEquals(nValue, loaded.get(lind).get(aId), EPSILON); + } else { + assertFalse(nValue == loaded.get(lind).get(aId)); + } + + lind++; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java new file mode 100644 index 0000000..70ed7f6 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java @@ -0,0 +1,396 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import java.util.Arrays; +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.data.conditions.Condition; +import org.junit.Test; +@Deprecated +public class DataTest extends MahoutTestCase { + + private static final int ATTRIBUTE_COUNT = 10; + + private static final int DATA_SIZE = 100; + + private Random rng; + + private Data classifierData; + + private Data regressionData; + + @Override + public void setUp() throws Exception { + super.setUp(); + rng = RandomUtils.getRandom(); + classifierData = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); + regressionData = Utils.randomData(rng, ATTRIBUTE_COUNT, true, DATA_SIZE); + } + + /** + * Test method for + * {@link org.apache.mahout.classifier.df.data.Data#subset(org.apache.mahout.classifier.df.data.conditions.Condition)}. 
+ */ + @Test + public void testSubset() { + int n = 10; + + for (int nloop = 0; nloop < n; nloop++) { + int attr = rng.nextInt(classifierData.getDataset().nbAttributes()); + + double[] values = classifierData.values(attr); + double value = values[rng.nextInt(values.length)]; + + Data eSubset = classifierData.subset(Condition.equals(attr, value)); + Data lSubset = classifierData.subset(Condition.lesser(attr, value)); + Data gSubset = classifierData.subset(Condition.greaterOrEquals(attr, value)); + + for (int index = 0; index < DATA_SIZE; index++) { + Instance instance = classifierData.get(index); + + if (instance.get(attr) < value) { + assertTrue(lSubset.contains(instance)); + assertFalse(eSubset.contains(instance)); + assertFalse(gSubset.contains(instance)); + } else if (instance.get(attr) == value) { + assertFalse(lSubset.contains(instance)); + assertTrue(eSubset.contains(instance)); + assertTrue(gSubset.contains(instance)); + } else { + assertFalse(lSubset.contains(instance)); + assertFalse(eSubset.contains(instance)); + assertTrue(gSubset.contains(instance)); + } + } + + // regression + attr = rng.nextInt(regressionData.getDataset().nbAttributes()); + + values = regressionData.values(attr); + value = values[rng.nextInt(values.length)]; + + eSubset = regressionData.subset(Condition.equals(attr, value)); + lSubset = regressionData.subset(Condition.lesser(attr, value)); + gSubset = regressionData.subset(Condition.greaterOrEquals(attr, value)); + + for (int index = 0; index < DATA_SIZE; index++) { + Instance instance = regressionData.get(index); + + if (instance.get(attr) < value) { + assertTrue(lSubset.contains(instance)); + assertFalse(eSubset.contains(instance)); + assertFalse(gSubset.contains(instance)); + } else if (instance.get(attr) == value) { + assertFalse(lSubset.contains(instance)); + assertTrue(eSubset.contains(instance)); + assertTrue(gSubset.contains(instance)); + } else { + assertFalse(lSubset.contains(instance)); + assertFalse(eSubset.contains(instance)); + assertTrue(gSubset.contains(instance)); + } + } + } + } + + @Test + public void testValues() throws Exception { + for (int attr = 0; attr < classifierData.getDataset().nbAttributes(); attr++) { + double[] values = classifierData.values(attr); + + // each value of the attribute should appear exactly one time in values + for (int index = 0; index < DATA_SIZE; index++) { + assertEquals(1, count(values, classifierData.get(index).get(attr))); + } + } + + for (int attr = 0; attr < regressionData.getDataset().nbAttributes(); attr++) { + double[] values = regressionData.values(attr); + + // each value of the attribute should appear exactly one time in values + for (int index = 0; index < DATA_SIZE; index++) { + assertEquals(1, count(values, regressionData.get(index).get(attr))); + } + } + } + + private static int count(double[] values, double value) { + int count = 0; + for (double v : values) { + if (v == value) { + count++; + } + } + return count; + } + + @Test + public void testIdenticalTrue() throws Exception { + // generate a small data, only to get the dataset + Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset(); + + // test empty data + Data empty = new Data(dataset); + assertTrue(empty.isIdentical()); + + // test identical data, except for the labels + Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); + Instance model = identical.get(0); + for (int index = 1; index < DATA_SIZE; index++) { + for (int attr = 0; attr < identical.getDataset().nbAttributes(); attr++) { + 
identical.get(index).set(attr, model.get(attr)); + } + } + + assertTrue(identical.isIdentical()); + } + + @Test + public void testIdenticalFalse() throws Exception { + int n = 10; + + for (int nloop = 0; nloop < n; nloop++) { + Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); + + // choose a random instance + int index = rng.nextInt(DATA_SIZE); + Instance instance = data.get(index); + + // change a random attribute + int attr = rng.nextInt(data.getDataset().nbAttributes()); + instance.set(attr, instance.get(attr) + 1); + + assertFalse(data.isIdentical()); + } + } + + @Test + public void testIdenticalLabelTrue() throws Exception { + // generate a small data, only to get a dataset + Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset(); + + // test empty data + Data empty = new Data(dataset); + assertTrue(empty.identicalLabel()); + + // test identical labels + String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT); + double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, + DATA_SIZE, rng.nextInt()); + String[] sData = Utils.double2String(source); + + dataset = DataLoader.generateDataset(descriptor, false, sData); + Data data = DataLoader.loadData(dataset, sData); + + assertTrue(data.identicalLabel()); + } + + @Test + public void testIdenticalLabelFalse() throws Exception { + int n = 10; + + for (int nloop = 0; nloop < n; nloop++) { + String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT); + int label = Utils.findLabel(descriptor); + double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, + DATA_SIZE, rng.nextInt()); + // choose a random vector and change its label + int index = rng.nextInt(DATA_SIZE); + source[index][label]++; + + String[] sData = Utils.double2String(source); + + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + Data data = DataLoader.loadData(dataset, sData); + + assertFalse(data.identicalLabel()); + } + } + + /** + * Test method for + * {@link org.apache.mahout.classifier.df.data.Data#bagging(java.util.Random)}. + */ + @Test + public void testBagging() { + Data bag = classifierData.bagging(rng); + + // the bag should have the same size as the data + assertEquals(classifierData.size(), bag.size()); + + // at least one element from the data should not be in the bag + boolean found = false; + for (int index = 0; index < classifierData.size() && !found; index++) { + found = !bag.contains(classifierData.get(index)); + } + + assertTrue("some instances from data should not be in the bag", found); + + // regression + bag = regressionData.bagging(rng); + + // the bag should have the same size as the data + assertEquals(regressionData.size(), bag.size()); + + // at least one element from the data should not be in the bag + found = false; + for (int index = 0; index < regressionData.size() && !found; index++) { + found = !bag.contains(regressionData.get(index)); + } + + assertTrue("some instances from data should not be in the bag", found); +} + + /** + * Test method for + * {@link org.apache.mahout.classifier.df.data.Data#rsplit(java.util.Random, int)}. 
+ */ + @Test + public void testRsplit() { + + // rsplit should handle empty subsets + Data source = classifierData.clone(); + Data subset = source.rsplit(rng, 0); + assertTrue("subset should be empty", subset.isEmpty()); + assertEquals("source.size is incorrect", DATA_SIZE, source.size()); + + // rsplit should handle full size subsets + source = classifierData.clone(); + subset = source.rsplit(rng, DATA_SIZE); + assertEquals("subset.size is incorrect", DATA_SIZE, subset.size()); + assertTrue("source should be empty", source.isEmpty()); + + // random case + int subsize = rng.nextInt(DATA_SIZE); + source = classifierData.clone(); + subset = source.rsplit(rng, subsize); + assertEquals("subset.size is incorrect", subsize, subset.size()); + assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size()); + + // regression + // rsplit should handle empty subsets + source = regressionData.clone(); + subset = source.rsplit(rng, 0); + assertTrue("subset should be empty", subset.isEmpty()); + assertEquals("source.size is incorrect", DATA_SIZE, source.size()); + + // rsplit should handle full size subsets + source = regressionData.clone(); + subset = source.rsplit(rng, DATA_SIZE); + assertEquals("subset.size is incorrect", DATA_SIZE, subset.size()); + assertTrue("source should be empty", source.isEmpty()); + + // random case + subsize = rng.nextInt(DATA_SIZE); + source = regressionData.clone(); + subset = source.rsplit(rng, subsize); + assertEquals("subset.size is incorrect", subsize, subset.size()); + assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size()); +} + + @Test + public void testCountLabel() throws Exception { + Dataset dataset = classifierData.getDataset(); + int[] counts = new int[dataset.nblabels()]; + + int n = 10; + + for (int nloop = 0; nloop < n; nloop++) { + Arrays.fill(counts, 0); + classifierData.countLabels(counts); + + for (int index = 0; index < classifierData.size(); index++) { + counts[(int) dataset.getLabel(classifierData.get(index))]--; + } + + for (int label = 0; label < classifierData.getDataset().nblabels(); label++) { + assertEquals("Wrong label 'equals' count", 0, counts[0]); + } + } + } + + @Test + public void testMajorityLabel() throws Exception { + + // all instances have the same label + String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT); + int label = Utils.findLabel(descriptor); + + int label1 = rng.nextInt(); + double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, + label1); + String[] sData = Utils.double2String(source); + + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + Data data = DataLoader.loadData(dataset, sData); + + int code1 = dataset.labelCode(Double.toString(label1)); + + assertEquals(code1, data.majorityLabel(rng)); + + // 51/100 vectors have label2 + int label2 = label1 + 1; + int nblabel2 = 51; + while (nblabel2 > 0) { + double[] vector = source[rng.nextInt(100)]; + if (vector[label] != label2) { + vector[label] = label2; + nblabel2--; + } + } + sData = Utils.double2String(source); + dataset = DataLoader.generateDataset(descriptor, false, sData); + data = DataLoader.loadData(dataset, sData); + int code2 = dataset.labelCode(Double.toString(label2)); + + // label2 should be the majority label + assertEquals(code2, data.majorityLabel(rng)); + + // 50 vectors with label1 and 50 vectors with label2 + do { + double[] vector = source[rng.nextInt(100)]; + if (vector[label] == label2) { + vector[label] = label1; + break; + } + } while (true); + sData = 
Utils.double2String(source); + + data = DataLoader.loadData(dataset, sData); + code1 = dataset.labelCode(Double.toString(label1)); + code2 = dataset.labelCode(Double.toString(label2)); + + // majorityLabel should return label1 and label2 at random + boolean found1 = false; + boolean found2 = false; + for (int index = 0; index < 10 && (!found1 || !found2); index++) { + int major = data.majorityLabel(rng); + if (major == code1) { + found1 = true; + } + if (major == code2) { + found2 = true; + } + } + assertTrue(found1 && found2); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java new file mode 100644 index 0000000..e5c9ee7 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package org.apache.mahout.classifier.df.data; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; +@Deprecated +public final class DatasetTest extends MahoutTestCase { + + @Test + public void jsonEncoding() throws DescriptorException { + Dataset to = DataLoader.generateDataset("N C I L", true, new String[]{"1 foo 2 3", "4 bar 5 6"}); + + // to JSON + //assertEquals(json, to.toJSON()); + assertEquals(3, to.nbAttributes()); + assertEquals(1, to.getIgnored().length); + assertEquals(2, to.getIgnored()[0]); + assertEquals(2, to.getLabelId()); + assertTrue(to.isNumerical(0)); + + // from JSON + Dataset fromJson = Dataset.fromJSON(to.toJSON()); + assertEquals(3, fromJson.nbAttributes()); + assertEquals(1, fromJson.getIgnored().length); + assertEquals(2, fromJson.getIgnored()[0]); + assertTrue(fromJson.isNumerical(0)); + + // read values for a nominal + assertNotEquals(fromJson.valueOf(1, "bar"), fromJson.valueOf(1, "foo")); + } + + @Test + public void jsonEncodingIgnoreFeatures() throws DescriptorException {; + Dataset to = DataLoader.generateDataset("N C I L", false, new String[]{"1 foo 2 Red", "4 bar 5 Blue"}); + + // to JSON + //assertEquals(json, to.toJSON()); + assertEquals(3, to.nbAttributes()); + assertEquals(1, to.getIgnored().length); + assertEquals(2, to.getIgnored()[0]); + assertEquals(2, to.getLabelId()); + assertTrue(to.isNumerical(0)); + assertNotEquals(to.valueOf(1, "bar"), to.valueOf(1, "foo")); + assertNotEquals(to.valueOf(2, "Red"), to.valueOf(2, "Blue")); + + // from JSON + Dataset fromJson = Dataset.fromJSON(to.toJSON()); + assertEquals(3, fromJson.nbAttributes()); + assertEquals(1, fromJson.getIgnored().length); + assertEquals(2, fromJson.getIgnored()[0]); + assertTrue(fromJson.isNumerical(0)); + + // read values for a nominal, one before and one after the ignore feature + assertNotEquals(fromJson.valueOf(1, "bar"), fromJson.valueOf(1, "foo")); + assertNotEquals(fromJson.valueOf(2, "Red"), fromJson.valueOf(2, "Blue")); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java new file mode 100644 index 0000000..619f067 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; +import org.junit.Test; +@Deprecated +public final class DescriptorUtilsTest extends MahoutTestCase { + + /** + * Test method for + * {@link org.apache.mahout.classifier.df.data.DescriptorUtils#parseDescriptor(java.lang.CharSequence)}. + */ + @Test + public void testParseDescriptor() throws Exception { + int n = 10; + int maxnbAttributes = 100; + + Random rng = RandomUtils.getRandom(); + + for (int nloop = 0; nloop < n; nloop++) { + int nbAttributes = rng.nextInt(maxnbAttributes) + 1; + + char[] tokens = Utils.randomTokens(rng, nbAttributes); + Attribute[] attrs = DescriptorUtils.parseDescriptor(Utils.generateDescriptor(tokens)); + + // verify that the attributes matches the token list + assertEquals("attributes size", nbAttributes, attrs.length); + + for (int attr = 0; attr < nbAttributes; attr++) { + switch (tokens[attr]) { + case 'I': + assertTrue(attrs[attr].isIgnored()); + break; + case 'N': + assertTrue(attrs[attr].isNumerical()); + break; + case 'C': + assertTrue(attrs[attr].isCategorical()); + break; + case 'L': + assertTrue(attrs[attr].isLabel()); + break; + } + } + } + } + + @Test + public void testGenerateDescription() throws Exception { + validate("", ""); + validate("I L C C N N N C", "I L C C N N N C"); + validate("I L C C N N N C", "I L 2 C 3 N C"); + validate("I L C C N N N C", " I L 2 C 3 N C "); + + try { + validate("", "I L 2 2 C 2 N C"); + fail("2 consecutive multiplicators"); + } catch (DescriptorException e) { + } + + try { + validate("", "I L 2 C -2 N C"); + fail("negative multiplicator"); + } catch (DescriptorException e) { + } + } + + private static void validate(String descriptor, CharSequence description) throws DescriptorException { + assertEquals(descriptor, DescriptorUtils.generateDescriptor(description)); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java new file mode 100644 index 0000000..9b51ec9 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java @@ -0,0 +1,284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Random; + +import com.google.common.base.Charsets; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; +import org.apache.mahout.common.MahoutTestCase; + +/** + * Helper methods used by the tests + * + */ +@Deprecated +public final class Utils { + + private Utils() {} + + /** Used when generating random CATEGORICAL values */ + private static final int CATEGORICAL_RANGE = 100; + + /** + * Generates a random list of tokens + * <ul> + * <li>each attribute has 50% chance to be NUMERICAL ('N') or CATEGORICAL + * ('C')</li> + * <li>10% of the attributes are IGNORED ('I')</li> + * <li>one randomly chosen attribute becomes the LABEL ('L')</li> + * </ul> + * + * @param rng Random number generator + * @param nbTokens number of tokens to generate + */ + public static char[] randomTokens(Random rng, int nbTokens) { + char[] result = new char[nbTokens]; + + for (int token = 0; token < nbTokens; token++) { + double rand = rng.nextDouble(); + if (rand < 0.1) { + result[token] = 'I'; // IGNORED + } else if (rand >= 0.5) { + result[token] = 'C'; + } else { + result[token] = 'N'; // NUMERICAL + } // CATEGORICAL + } + + // choose the label + result[rng.nextInt(nbTokens)] = 'L'; + + return result; + } + + /** + * Generates a space-separated String that contains all the tokens + */ + public static String generateDescriptor(char[] tokens) { + StringBuilder builder = new StringBuilder(); + + for (char token : tokens) { + builder.append(token).append(' '); + } + + return builder.toString(); + } + + /** + * Generates a random descriptor as follows:<br> + * <ul> + * <li>each attribute has 50% chance to be NUMERICAL or CATEGORICAL</li> + * <li>10% of the attributes are IGNORED</li> + * <li>one randomly chosen attribute becomes the LABEL</li> + * </ul> + */ + public static String randomDescriptor(Random rng, int nbAttributes) { + return generateDescriptor(randomTokens(rng, nbAttributes)); + } + + /** + * generates random data based on the given descriptor + * + * @param rng Random number generator + * @param descriptor attributes description + * @param number number of data lines to generate + */ + public static double[][] randomDoubles(Random rng, CharSequence descriptor, boolean regression, int number) + throws DescriptorException { + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + double[][] data = new double[number][]; + + for (int index = 0; index < number; index++) { + data[index] = randomVector(rng, attrs, regression); + } + + return data; + } + + /** + * Generates random data + * + * @param rng Random number generator + * @param nbAttributes number of attributes + * @param regression true is the label should be numerical + * @param size data size + */ + public static Data randomData(Random rng, int nbAttributes, boolean regression, int size) throws DescriptorException { + String descriptor = randomDescriptor(rng, nbAttributes); + double[][] source = randomDoubles(rng, descriptor, regression, size); + String[] sData = double2String(source); + Dataset dataset = DataLoader.generateDataset(descriptor, regression, sData); + + return DataLoader.loadData(dataset, sData); + } + + /** + * generates a random 
vector based on the given attributes.<br> + * the attributes' values are generated as follows :<br> + * <ul> + * <li>each IGNORED attribute receives a Double.NaN</li> + * <li>each NUMERICAL attribute receives a random double</li> + * <li>each CATEGORICAL and LABEL attribute receives a random integer in the + * range [0, CATEGORICAL_RANGE[</li> + * </ul> + * + * @param attrs attributes description + */ + private static double[] randomVector(Random rng, Attribute[] attrs, boolean regression) { + double[] vector = new double[attrs.length]; + + for (int attr = 0; attr < attrs.length; attr++) { + if (attrs[attr].isIgnored()) { + vector[attr] = Double.NaN; + } else if (attrs[attr].isNumerical()) { + vector[attr] = rng.nextDouble(); + } else if (attrs[attr].isCategorical()) { + vector[attr] = rng.nextInt(CATEGORICAL_RANGE); + } else { // LABEL + if (regression) { + vector[attr] = rng.nextDouble(); + } else { + vector[attr] = rng.nextInt(CATEGORICAL_RANGE); + } + } + } + + return vector; + } + + /** + * converts a double array to a comma-separated string + * + * @param v double array + * @return comma-separated string + */ + private static String double2String(double[] v) { + StringBuilder builder = new StringBuilder(); + + for (double aV : v) { + builder.append(aV).append(','); + } + + return builder.toString(); + } + + /** + * converts an array of double arrays to an array of comma-separated strings + * + * @param source array of double arrays + * @return array of comma-separated strings + */ + public static String[] double2String(double[][] source) { + String[] output = new String[source.length]; + + for (int index = 0; index < source.length; index++) { + output[index] = double2String(source[index]); + } + + return output; + } + + /** + * Generates random data with same label value + * + * @param number data size + * @param value label value + */ + public static double[][] randomDoublesWithSameLabel(Random rng, + CharSequence descriptor, + boolean regression, + int number, + int value) throws DescriptorException { + int label = findLabel(descriptor); + + double[][] source = randomDoubles(rng, descriptor, regression, number); + + for (int index = 0; index < number; index++) { + source[index][label] = value; + } + + return source; + } + + /** + * finds the label attribute's index + */ + public static int findLabel(CharSequence descriptor) throws DescriptorException { + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + return ArrayUtils.indexOf(attrs, Attribute.LABEL); + } + + private static void writeDataToFile(String[] sData, Path path) throws IOException { + BufferedWriter output = null; + try { + output = Files.newWriter(new File(path.toString()), Charsets.UTF_8); + for (String line : sData) { + output.write(line); + output.write('\n'); + } + } finally { + Closeables.close(output, false); + } + + } + + public static Path writeDataToTestFile(String[] sData) throws IOException { + Path testData = new Path("testdata/Data"); + MahoutTestCase ca = new MahoutTestCase(); + FileSystem fs = testData.getFileSystem(ca.getConfiguration()); + if (!fs.exists(testData)) { + fs.mkdirs(testData); + } + + Path path = new Path(testData, "DataLoaderTest.data"); + + writeDataToFile(sData, path); + + return path; + } + + /** + * Split the data into numMaps splits + */ + public static String[][] splitData(String[] sData, int numMaps) { + int nbInstances = sData.length; + int partitionSize = nbInstances / numMaps; + + String[][] splits = new String[numMaps][]; + + for (int partition = 0; partition < 
numMaps; partition++) { + int from = partition * partitionSize; + int to = partition == (numMaps - 1) ? nbInstances : (partition + 1) * partitionSize; + + splits[partition] = Arrays.copyOfRange(sData, from, to); + } + + return splits; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java new file mode 100644 index 0000000..6a17aa2 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java @@ -0,0 +1,109 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.inmem; + +import java.util.List; +import java.util.Random; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit; +import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemRecordReader; +import org.junit.Test; +@Deprecated +public final class InMemInputFormatTest extends MahoutTestCase { + + @Test + public void testSplits() throws Exception { + int n = 1; + int maxNumSplits = 100; + int maxNbTrees = 1000; + + Random rng = RandomUtils.getRandom(); + + for (int nloop = 0; nloop < n; nloop++) { + int numSplits = rng.nextInt(maxNumSplits) + 1; + int nbTrees = rng.nextInt(maxNbTrees) + 1; + + Configuration conf = getConfiguration(); + Builder.setNbTrees(conf, nbTrees); + + InMemInputFormat inputFormat = new InMemInputFormat(); + List<InputSplit> splits = inputFormat.getSplits(conf, numSplits); + + assertEquals(numSplits, splits.size()); + + int nbTreesPerSplit = nbTrees / numSplits; + int totalTrees = 0; + int expectedId = 0; + + for (int index = 0; index < numSplits; index++) { + assertTrue(splits.get(index) instanceof InMemInputSplit); + + InMemInputSplit split = (InMemInputSplit) splits.get(index); + + assertEquals(expectedId, split.getFirstId()); + + if (index < numSplits - 1) { + assertEquals(nbTreesPerSplit, split.getNbTrees()); + } else { + assertEquals(nbTrees - totalTrees, split.getNbTrees()); + } + + totalTrees += split.getNbTrees(); + expectedId += split.getNbTrees(); + } + } + } + + @Test + public void testRecordReader() throws Exception { + int n = 
1; + int maxNumSplits = 100; + int maxNbTrees = 1000; + + Random rng = RandomUtils.getRandom(); + + for (int nloop = 0; nloop < n; nloop++) { + int numSplits = rng.nextInt(maxNumSplits) + 1; + int nbTrees = rng.nextInt(maxNbTrees) + 1; + + Configuration conf = getConfiguration(); + Builder.setNbTrees(conf, nbTrees); + + InMemInputFormat inputFormat = new InMemInputFormat(); + List<InputSplit> splits = inputFormat.getSplits(conf, numSplits); + + for (int index = 0; index < numSplits; index++) { + InMemInputSplit split = (InMemInputSplit) splits.get(index); + InMemRecordReader reader = new InMemRecordReader(split); + + reader.initialize(split, null); + + for (int tree = 0; tree < split.getNbTrees(); tree++) { + // reader.next() should return true until there is no tree left + assertEquals(tree < split.getNbTrees(), reader.nextKeyValue()); + assertEquals(split.getFirstId() + tree, reader.getCurrentKey().get()); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java new file mode 100644 index 0000000..aeea084 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java @@ -0,0 +1,77 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.mapreduce.inmem; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit; +import org.junit.Before; +import org.junit.Test; +@Deprecated +public final class InMemInputSplitTest extends MahoutTestCase { + + private Random rng; + private ByteArrayOutputStream byteOutStream; + private DataOutput out; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + rng = RandomUtils.getRandom(); + byteOutStream = new ByteArrayOutputStream(); + out = new DataOutputStream(byteOutStream); + } + + /** + * Make sure that all the fields are processed correctly + */ + @Test + public void testWritable() throws Exception { + InMemInputSplit split = new InMemInputSplit(rng.nextInt(), rng.nextInt(1000), rng.nextLong()); + + split.write(out); + assertEquals(split, readSplit()); + } + + /** + * test the case seed == null + */ + @Test + public void testNullSeed() throws Exception { + InMemInputSplit split = new InMemInputSplit(rng.nextInt(), rng.nextInt(1000), null); + + split.write(out); + assertEquals(split, readSplit()); + } + + private InMemInputSplit readSplit() throws IOException { + ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray()); + DataInput in = new DataInputStream(byteInStream); + return InMemInputSplit.read(in); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java new file mode 100644 index 0000000..2821034 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java @@ -0,0 +1,197 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Random; + +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.Writer; +import org.apache.hadoop.mapreduce.Job; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.junit.Test; +@Deprecated +public final class PartialBuilderTest extends MahoutTestCase { + + private static final int NUM_MAPS = 5; + + private static final int NUM_TREES = 32; + + /** instances per partition */ + private static final int NUM_INSTANCES = 20; + + @Test + public void testProcessOutput() throws Exception { + Configuration conf = getConfiguration(); + conf.setInt("mapred.map.tasks", NUM_MAPS); + + Random rng = RandomUtils.getRandom(); + + // prepare the output + TreeID[] keys = new TreeID[NUM_TREES]; + MapredOutput[] values = new MapredOutput[NUM_TREES]; + int[] firstIds = new int[NUM_MAPS]; + randomKeyValues(rng, keys, values, firstIds); + + // store the output in a sequence file + Path base = getTestTempDirPath("testdata"); + FileSystem fs = base.getFileSystem(conf); + + Path outputFile = new Path(base, "PartialBuilderTest.seq"); + Writer writer = SequenceFile.createWriter(fs, conf, outputFile, + TreeID.class, MapredOutput.class); + + try { + for (int index = 0; index < NUM_TREES; index++) { + writer.append(keys[index], values[index]); + } + } finally { + Closeables.close(writer, false); + } + + // load the output and make sure its valid + TreeID[] newKeys = new TreeID[NUM_TREES]; + Node[] newTrees = new Node[NUM_TREES]; + + PartialBuilder.processOutput(new Job(conf), base, newKeys, newTrees); + + // check the forest + for (int tree = 0; tree < NUM_TREES; tree++) { + assertEquals(values[tree].getTree(), newTrees[tree]); + } + + assertTrue("keys not equal", Arrays.deepEquals(keys, newKeys)); + } + + /** + * Make sure that the builder passes the good parameters to the job + * + */ + @Test + public void testConfigure() { + TreeBuilder treeBuilder = new DefaultTreeBuilder(); + Path dataPath = new Path("notUsedDataPath"); + Path datasetPath = new Path("notUsedDatasetPath"); + Long seed = 5L; + + new PartialBuilderChecker(treeBuilder, dataPath, datasetPath, seed); + } + + /** + * Generates random (key, value) pairs. 
Shuffles the partition's order + * + * @param rng + * @param keys + * @param values + * @param firstIds partitions's first ids in hadoop's order + */ + private static void randomKeyValues(Random rng, TreeID[] keys, MapredOutput[] values, int[] firstIds) { + int index = 0; + int firstId = 0; + Collection<Integer> partitions = Lists.newArrayList(); + + for (int p = 0; p < NUM_MAPS; p++) { + // select a random partition, not yet selected + int partition; + do { + partition = rng.nextInt(NUM_MAPS); + } while (partitions.contains(partition)); + + partitions.add(partition); + + int nbTrees = Step1Mapper.nbTrees(NUM_MAPS, NUM_TREES, partition); + + for (int treeId = 0; treeId < nbTrees; treeId++) { + Node tree = new Leaf(rng.nextInt(100)); + + keys[index] = new TreeID(partition, treeId); + values[index] = new MapredOutput(tree, nextIntArray(rng, NUM_INSTANCES)); + + index++; + } + + firstIds[p] = firstId; + firstId += NUM_INSTANCES; + } + + } + + private static int[] nextIntArray(Random rng, int size) { + int[] array = new int[size]; + for (int index = 0; index < size; index++) { + array[index] = rng.nextInt(101) - 1; + } + + return array; + } + + static class PartialBuilderChecker extends PartialBuilder { + + private final Long seed; + + private final TreeBuilder treeBuilder; + + private final Path datasetPath; + + PartialBuilderChecker(TreeBuilder treeBuilder, Path dataPath, + Path datasetPath, Long seed) { + super(treeBuilder, dataPath, datasetPath, seed); + + this.seed = seed; + this.treeBuilder = treeBuilder; + this.datasetPath = datasetPath; + } + + @Override + protected boolean runJob(Job job) throws IOException { + // no need to run the job, just check if the params are correct + + Configuration conf = job.getConfiguration(); + + assertEquals(seed, getRandomSeed(conf)); + + // PartialBuilder should detect the 'local' mode and overrides the number + // of map tasks + assertEquals(1, conf.getInt("mapred.map.tasks", -1)); + + assertEquals(NUM_TREES, getNbTrees(conf)); + + assertFalse(isOutput(conf)); + + assertEquals(treeBuilder, getTreeBuilder(conf)); + + assertEquals(datasetPath, getDistributedCacheFile(conf, 0)); + + return true; + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java new file mode 100644 index 0000000..c5aec7f --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java @@ -0,0 +1,160 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import org.easymock.EasyMock; +import java.util.Random; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Utils; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.MahoutTestCase; +import org.easymock.Capture; +import org.easymock.CaptureType; +import org.junit.Test; +@Deprecated +public final class Step1MapperTest extends MahoutTestCase { + + /** + * Make sure that the data used to build the trees is from the mapper's + * partition + * + */ + private static class MockTreeBuilder implements TreeBuilder { + + private Data expected; + + public void setExpected(Data data) { + expected = data; + } + + @Override + public Node build(Random rng, Data data) { + for (int index = 0; index < data.size(); index++) { + assertTrue(expected.contains(data.get(index))); + } + + return new Leaf(Double.NaN); + } + } + + /** + * Special Step1Mapper that can be configured without using a Configuration + * + */ + private static class MockStep1Mapper extends Step1Mapper { + private MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed, + int partition, int numMapTasks, int numTrees) { + configure(false, treeBuilder, dataset); + configure(seed, partition, numMapTasks, numTrees); + } + } + + private static class TreeIDCapture extends Capture<TreeID> { + + private TreeIDCapture() { + super(CaptureType.ALL); + } + + @Override + public void setValue(final TreeID value) { + super.setValue(value.clone()); + } + } + + /** nb attributes per generated data instance */ + static final int NUM_ATTRIBUTES = 4; + + /** nb generated data instances */ + static final int NUM_INSTANCES = 100; + + /** nb trees to build */ + static final int NUM_TREES = 10; + + /** nb mappers to use */ + static final int NUM_MAPPERS = 2; + + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testMapper() throws Exception { + Random rng = RandomUtils.getRandom(); + + // prepare the data + String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES); + double[][] source = Utils.randomDoubles(rng, descriptor, false, NUM_INSTANCES); + String[] sData = Utils.double2String(source); + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + String[][] splits = Utils.splitData(sData, NUM_MAPPERS); + + MockTreeBuilder treeBuilder = new MockTreeBuilder(); + + LongWritable key = new LongWritable(); + Text value = new Text(); + + int treeIndex = 0; + + for (int partition = 0; partition < NUM_MAPPERS; partition++) { + String[] split = splits[partition]; + treeBuilder.setExpected(DataLoader.loadData(dataset, split)); + + // expected number of trees that this mapper will build + int mapNbTrees = Step1Mapper.nbTrees(NUM_MAPPERS, NUM_TREES, partition); + + Mapper.Context context = EasyMock.createNiceMock(Mapper.Context.class); + Capture<TreeID> capturedKeys = new TreeIDCapture(); + context.write(EasyMock.capture(capturedKeys), EasyMock.anyObject()); + 
EasyMock.expectLastCall().anyTimes(); + + EasyMock.replay(context); + + MockStep1Mapper mapper = new MockStep1Mapper(treeBuilder, dataset, null, + partition, NUM_MAPPERS, NUM_TREES); + + // make sure the mapper computed firstTreeId correctly + assertEquals(treeIndex, mapper.getFirstTreeId()); + + for (int index = 0; index < split.length; index++) { + key.set(index); + value.set(split[index]); + mapper.map(key, value, context); + } + + mapper.cleanup(context); + EasyMock.verify(context); + + // make sure the mapper built all its trees + assertEquals(mapNbTrees, capturedKeys.getValues().size()); + + // check the returned keys + for (TreeID k : capturedKeys.getValues()) { + assertEquals(partition, k.partition()); + assertEquals(treeIndex, k.treeId()); + + treeIndex++; + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java new file mode 100644 index 0000000..c4beeaf --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.junit.Test; +@Deprecated +public final class TreeIDTest extends MahoutTestCase { + + @Test + public void testTreeID() { + Random rng = RandomUtils.getRandom(); + + for (int nloop = 0; nloop < 1000000; nloop++) { + int partition = Math.abs(rng.nextInt()); + int treeId = rng.nextInt(TreeID.MAX_TREEID); + + TreeID t1 = new TreeID(partition, treeId); + + assertEquals(partition, t1.partition()); + assertEquals(treeId, t1.treeId()); + + TreeID t2 = new TreeID(); + t2.set(partition, treeId); + + assertEquals(partition, t2.partition()); + assertEquals(treeId, t2.treeId()); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java new file mode 100644 index 0000000..1300926 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.node; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.junit.Before; +import org.junit.Test; +@Deprecated +public final class NodeTest extends MahoutTestCase { + + private Random rng; + + private ByteArrayOutputStream byteOutStream; + private DataOutput out; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + rng = RandomUtils.getRandom(); + + byteOutStream = new ByteArrayOutputStream(); + out = new DataOutputStream(byteOutStream); + } + + /** + * Test method for + * {@link org.apache.mahout.classifier.df.node.Node#read(java.io.DataInput)}. 
+ */ + @Test + public void testReadTree() throws Exception { + Node node1 = new CategoricalNode(rng.nextInt(), + new double[] { rng.nextDouble(), rng.nextDouble() }, + new Node[] { new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()) }); + Node node2 = new NumericalNode(rng.nextInt(), rng.nextDouble(), + new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble())); + + Node root = new CategoricalNode(rng.nextInt(), + new double[] { rng.nextDouble(), rng.nextDouble(), rng.nextDouble() }, + new Node[] { node1, node2, new Leaf(rng.nextDouble()) }); + + // write the node to a DataOutput + root.write(out); + + // read the node back + assertEquals(root, readNode()); + } + + Node readNode() throws IOException { + ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray()); + DataInput in = new DataInputStream(byteInStream); + return Node.read(in); + } + + @Test + public void testReadLeaf() throws Exception { + + Node leaf = new Leaf(rng.nextDouble()); + leaf.write(out); + assertEquals(leaf, readNode()); + } + + @Test + public void testParseNumerical() throws Exception { + + Node node = new NumericalNode(rng.nextInt(), rng.nextDouble(), new Leaf(rng + .nextInt()), new Leaf(rng.nextDouble())); + node.write(out); + assertEquals(node, readNode()); + } + + @Test + public void testCategoricalNode() throws Exception { + + Node node = new CategoricalNode(rng.nextInt(), new double[]{rng.nextDouble(), + rng.nextDouble(), rng.nextDouble()}, new Node[]{ + new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()), + new Leaf(rng.nextDouble())}); + + node.write(out); + assertEquals(node, readNode()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java new file mode 100644 index 0000000..94d0ad9 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java @@ -0,0 +1,78 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.split; + +import java.util.Random; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Utils; +import org.junit.Test; +@Deprecated +public final class DefaultIgSplitTest extends MahoutTestCase { + + private static final int NUM_ATTRIBUTES = 10; + + @Test + public void testEntropy() throws Exception { + Random rng = RandomUtils.getRandom(); + String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES); + int label = Utils.findLabel(descriptor); + + // all the vectors have the same label (0) + double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, 0); + String[] sData = Utils.double2String(temp); + Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); + Data data = DataLoader.loadData(dataset, sData); + DefaultIgSplit iG = new DefaultIgSplit(); + + double expected = 0.0 - 1.0 * Math.log(1.0) / Math.log(2.0); + assertEquals(expected, iG.entropy(data), EPSILON); + + // 50/100 of the vectors have the label (1) + // 50/100 of the vectors have the label (0) + for (int index = 0; index < 50; index++) { + temp[index][label] = 1.0; + } + sData = Utils.double2String(temp); + dataset = DataLoader.generateDataset(descriptor, false, sData); + data = DataLoader.loadData(dataset, sData); + iG = new DefaultIgSplit(); + + expected = 2.0 * -0.5 * Math.log(0.5) / Math.log(2.0); + assertEquals(expected, iG.entropy(data), EPSILON); + + // 15/100 of the vectors have the label (2) + // 35/100 of the vectors have the label (1) + // 50/100 of the vectors have the label (0) + for (int index = 0; index < 15; index++) { + temp[index][label] = 2.0; + } + sData = Utils.double2String(temp); + dataset = DataLoader.generateDataset(descriptor, false, sData); + data = DataLoader.loadData(dataset, sData); + iG = new DefaultIgSplit(); + + expected = -0.15 * Math.log(0.15) / Math.log(2.0) - 0.35 * Math.log(0.35) + / Math.log(2.0) - 0.5 * Math.log(0.5) / Math.log(2.0); + assertEquals(expected, iG.entropy(data), EPSILON); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java new file mode 100644 index 0000000..9c5893a --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java @@ -0,0 +1,87 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.split; + +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.DescriptorException; +import org.apache.mahout.classifier.df.data.conditions.Condition; +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; +@Deprecated +public final class RegressionSplitTest extends MahoutTestCase { + + private static Data[] generateTrainingData() throws DescriptorException { + // Training data + String[] trainData = new String[20]; + for (int i = 0; i < trainData.length; i++) { + if (i % 3 == 0) { + trainData[i] = "A," + (40 - i) + ',' + (i + 20); + } else if (i % 3 == 1) { + trainData[i] = "B," + (i + 20) + ',' + (40 - i); + } else { + trainData[i] = "C," + (i + 20) + ',' + (i + 20); + } + } + // Dataset + Dataset dataset = DataLoader.generateDataset("C N L", true, trainData); + Data[] datas = new Data[3]; + datas[0] = DataLoader.loadData(dataset, trainData); + + // Training data + trainData = new String[20]; + for (int i = 0; i < trainData.length; i++) { + if (i % 2 == 0) { + trainData[i] = "A," + (50 - i) + ',' + (i + 10); + } else { + trainData[i] = "B," + (i + 10) + ',' + (50 - i); + } + } + datas[1] = DataLoader.loadData(dataset, trainData); + + // Training data + trainData = new String[10]; + for (int i = 0; i < trainData.length; i++) { + trainData[i] = "A," + (40 - i) + ',' + (i + 20); + } + datas[2] = DataLoader.loadData(dataset, trainData); + + return datas; + } + + @Test + public void testComputeSplit() throws DescriptorException { + Data[] datas = generateTrainingData(); + + RegressionSplit igSplit = new RegressionSplit(); + Split split = igSplit.computeSplit(datas[0], 1); + assertEquals(180.0, split.getIg(), EPSILON); + assertEquals(38.0, split.getSplit(), EPSILON); + split = igSplit.computeSplit(datas[0].subset(Condition.lesser(1, 38.0)), 1); + assertEquals(76.5, split.getIg(), EPSILON); + assertEquals(21.5, split.getSplit(), EPSILON); + + split = igSplit.computeSplit(datas[1], 0); + assertEquals(2205.0, split.getIg(), EPSILON); + assertEquals(Double.NaN, split.getSplit(), EPSILON); + split = igSplit.computeSplit(datas[1].subset(Condition.equals(0, 0.0)), 1); + assertEquals(250.0, split.getIg(), EPSILON); + assertEquals(41.0, split.getSplit(), EPSILON); + } +}
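Note (not part of the commit): the entropy values asserted in DefaultIgSplitTest above are ordinary Shannon entropies of the label distribution, H = -sum(p_i * log2(p_i)). The standalone sketch below is illustrative only — it is not Mahout's DefaultIgSplit implementation, and the class and method names are made up for the example — but it reproduces the three expected values: 0.0 when every instance shares label 0, 1.0 for a 50/50 split, and roughly 1.44 for a 50/35/15 split.

public final class EntropySketch {

  // Shannon entropy in bits over raw label counts; 0 * log2(0) is treated as 0.
  static double entropy(double... labelCounts) {
    double total = 0.0;
    for (double count : labelCounts) {
      total += count;
    }
    double h = 0.0;
    for (double count : labelCounts) {
      if (count > 0.0) {
        double p = count / total;
        h -= p * Math.log(p) / Math.log(2.0);
      }
    }
    return h;
  }

  public static void main(String[] args) {
    System.out.println(entropy(100.0));            // 0.0  : all 100 instances carry label 0
    System.out.println(entropy(50.0, 50.0));       // 1.0  : 50 instances of label 1, 50 of label 0
    System.out.println(entropy(50.0, 35.0, 15.0)); // ~1.44: 50/35/15 split across labels 0, 1, 2
  }
}

The test spells the same quantities out term by term (for example 2.0 * -0.5 * Math.log(0.5) / Math.log(2.0) for the 50/50 case), so agreement with this sketch is exact up to EPSILON.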

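Note (not part of the commit): both PartialBuilderTest.randomKeyValues() and Step1MapperTest implicitly rely on Step1Mapper.nbTrees(numMaps, numTrees, partition) summing to exactly NUM_TREES across all partitions, so that the keys/values arrays are filled completely and the firstTreeId checks line up. The sketch below only illustrates that invariant with one plausible distribution rule; it is an assumption for illustration, not necessarily the exact rule Mahout implements, and the helper name distributeTrees is made up.

final class TreeDistributionSketch {

  // Give every partition numTrees / numMaps trees and hand the integer
  // remainder to partition 0, so the per-partition counts always sum to numTrees.
  static int distributeTrees(int numMaps, int numTrees, int partition) {
    int perMap = numTrees / numMaps;
    if (partition == 0) {
      perMap += numTrees - perMap * numMaps; // remainder
    }
    return perMap;
  }

  public static void main(String[] args) {
    // PartialBuilderTest: 32 trees over 5 maps -> 8 + 6 + 6 + 6 + 6 = 32
    // Step1MapperTest:    10 trees over 2 maps -> 5 + 5             = 10
    int total = 0;
    for (int partition = 0; partition < 5; partition++) {
      total += distributeTrees(5, 32, partition);
    }
    System.out.println(total); // 32
  }
}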