http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java new file mode 100644 index 0000000..3ae9990 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCInMemoryItemSimilarity.java @@ -0,0 +1,132 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity.jdbc; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; +import org.apache.mahout.cf.taste.impl.common.jdbc.ResultSetIterator; +import org.apache.mahout.cf.taste.impl.model.jdbc.ConnectionPoolDataSource; +import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.sql.DataSource; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Collection; +import java.util.Iterator; +import java.util.concurrent.locks.ReentrantLock; + +/** + * loads all similarities from the database into RAM + */ +abstract class AbstractJDBCInMemoryItemSimilarity extends AbstractJDBCComponent implements ItemSimilarity { + + private ItemSimilarity delegate; + + private final DataSource dataSource; + private final String getAllItemSimilaritiesSQL; + private final ReentrantLock reloadLock; + + private static final Logger log = LoggerFactory.getLogger(AbstractJDBCInMemoryItemSimilarity.class); + + AbstractJDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) { + + AbstractJDBCComponent.checkNotNullAndLog("getAllItemSimilaritiesSQL", getAllItemSimilaritiesSQL); + + if (!(dataSource instanceof ConnectionPoolDataSource)) { + log.warn("You are not using ConnectionPoolDataSource. 
Make sure your DataSource pools connections " + + "to the database itself, or database performance will be severely reduced."); + } + + this.dataSource = dataSource; + this.getAllItemSimilaritiesSQL = getAllItemSimilaritiesSQL; + this.reloadLock = new ReentrantLock(); + + reload(); + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) throws TasteException { + return delegate.itemSimilarity(itemID1, itemID2); + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { + return delegate.itemSimilarities(itemID1, itemID2s); + } + + @Override + public long[] allSimilarItemIDs(long itemID) throws TasteException { + return delegate.allSimilarItemIDs(itemID); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + log.debug("Reloading..."); + reload(); + } + + protected void reload() { + if (reloadLock.tryLock()) { + try { + delegate = new GenericItemSimilarity(new JDBCSimilaritiesIterable(dataSource, getAllItemSimilaritiesSQL)); + } finally { + reloadLock.unlock(); + } + } + } + + private static final class JDBCSimilaritiesIterable implements Iterable<GenericItemSimilarity.ItemItemSimilarity> { + + private final DataSource dataSource; + private final String getAllItemSimilaritiesSQL; + + private JDBCSimilaritiesIterable(DataSource dataSource, String getAllItemSimilaritiesSQL) { + this.dataSource = dataSource; + this.getAllItemSimilaritiesSQL = getAllItemSimilaritiesSQL; + } + + @Override + public Iterator<GenericItemSimilarity.ItemItemSimilarity> iterator() { + try { + return new JDBCSimilaritiesIterator(dataSource, getAllItemSimilaritiesSQL); + } catch (SQLException sqle) { + throw new IllegalStateException(sqle); + } + } + } + + private static final class JDBCSimilaritiesIterator + extends ResultSetIterator<GenericItemSimilarity.ItemItemSimilarity> { + + private JDBCSimilaritiesIterator(DataSource dataSource, String getAllItemSimilaritiesSQL) throws SQLException { + super(dataSource, getAllItemSimilaritiesSQL); + } + + @Override + protected GenericItemSimilarity.ItemItemSimilarity parseElement(ResultSet resultSet) throws SQLException { + return new GenericItemSimilarity.ItemItemSimilarity(resultSet.getLong(1), + resultSet.getLong(2), + resultSet.getDouble(3)); + } + } + +}
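
For context, here is a minimal sketch of how an in-memory JDBC similarity like the class above might be wired into an item-based recommender. The MysqlDataSource setup, the connection credentials, and the prefs.csv path are illustrative assumptions, not part of this patch; only the Mahout taste classes introduced in these diffs (ConnectionPoolDataSource, MySQLJDBCInMemoryItemSimilarity) plus the standard taste recommender API are used.

import java.io.File;
import java.util.List;

import com.mysql.jdbc.jdbc2.optional.MysqlDataSource; // assumed MySQL Connector/J data source

import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
import org.apache.mahout.cf.taste.impl.model.jdbc.ConnectionPoolDataSource;
import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
import org.apache.mahout.cf.taste.impl.similarity.jdbc.MySQLJDBCInMemoryItemSimilarity;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;

public class InMemorySimilarityExample {

  public static void main(String[] args) throws Exception {
    // Hypothetical connection details; replace with your own database settings.
    MysqlDataSource underlying = new MysqlDataSource();
    underlying.setServerName("localhost");
    underlying.setDatabaseName("taste");
    underlying.setUser("taste");
    underlying.setPassword("secret");

    // Wrap the raw DataSource so connections are pooled, as the constructor's warning advises.
    ConnectionPoolDataSource dataSource = new ConnectionPoolDataSource(underlying);

    // Reads every row of taste_item_similarity into RAM once, at construction time.
    ItemSimilarity similarity = new MySQLJDBCInMemoryItemSimilarity(dataSource);

    // Preferences can come from any DataModel; a CSV file is used here purely for illustration.
    DataModel model = new FileDataModel(new File("prefs.csv"));

    GenericItemBasedRecommender recommender = new GenericItemBasedRecommender(model, similarity);
    List<RecommendedItem> top10 = recommender.recommend(123L, 10);
    for (RecommendedItem item : top10) {
      System.out.println(item.getItemID() + "\t" + item.getValue());
    }
  }
}

Because refresh() above simply delegates to reload(), calling similarity.refresh(null) after the similarity table changes re-reads the table into a fresh GenericItemSimilarity without recreating the object.
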
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java new file mode 100644 index 0000000..1b8d109 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/AbstractJDBCItemSimilarity.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity.jdbc; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Collection; + +import javax.sql.DataSource; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; +import org.apache.mahout.cf.taste.impl.model.jdbc.ConnectionPoolDataSource; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.apache.mahout.common.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An {@link ItemSimilarity} which draws pre-computed item-item similarities from a database table via JDBC. 
+ */ +public abstract class AbstractJDBCItemSimilarity extends AbstractJDBCComponent implements ItemSimilarity { + + private static final Logger log = LoggerFactory.getLogger(AbstractJDBCItemSimilarity.class); + + static final String DEFAULT_SIMILARITY_TABLE = "taste_item_similarity"; + static final String DEFAULT_ITEM_A_ID_COLUMN = "item_id_a"; + static final String DEFAULT_ITEM_B_ID_COLUMN = "item_id_b"; + static final String DEFAULT_SIMILARITY_COLUMN = "similarity"; + + private final DataSource dataSource; + private final String similarityTable; + private final String itemAIDColumn; + private final String itemBIDColumn; + private final String similarityColumn; + private final String getItemItemSimilaritySQL; + private final String getAllSimilarItemIDsSQL; + + protected AbstractJDBCItemSimilarity(DataSource dataSource, + String getItemItemSimilaritySQL, + String getAllSimilarItemIDsSQL) { + this(dataSource, + DEFAULT_SIMILARITY_TABLE, + DEFAULT_ITEM_A_ID_COLUMN, + DEFAULT_ITEM_B_ID_COLUMN, + DEFAULT_SIMILARITY_COLUMN, + getItemItemSimilaritySQL, + getAllSimilarItemIDsSQL); + } + + protected AbstractJDBCItemSimilarity(DataSource dataSource, + String similarityTable, + String itemAIDColumn, + String itemBIDColumn, + String similarityColumn, + String getItemItemSimilaritySQL, + String getAllSimilarItemIDsSQL) { + AbstractJDBCComponent.checkNotNullAndLog("similarityTable", similarityTable); + AbstractJDBCComponent.checkNotNullAndLog("itemAIDColumn", itemAIDColumn); + AbstractJDBCComponent.checkNotNullAndLog("itemBIDColumn", itemBIDColumn); + AbstractJDBCComponent.checkNotNullAndLog("similarityColumn", similarityColumn); + + AbstractJDBCComponent.checkNotNullAndLog("getItemItemSimilaritySQL", getItemItemSimilaritySQL); + AbstractJDBCComponent.checkNotNullAndLog("getAllSimilarItemIDsSQL", getAllSimilarItemIDsSQL); + + if (!(dataSource instanceof ConnectionPoolDataSource)) { + log.warn("You are not using ConnectionPoolDataSource. 
Make sure your DataSource pools connections " + + "to the database itself, or database performance will be severely reduced."); + } + + this.dataSource = dataSource; + this.similarityTable = similarityTable; + this.itemAIDColumn = itemAIDColumn; + this.itemBIDColumn = itemBIDColumn; + this.similarityColumn = similarityColumn; + this.getItemItemSimilaritySQL = getItemItemSimilaritySQL; + this.getAllSimilarItemIDsSQL = getAllSimilarItemIDsSQL; + } + + protected String getSimilarityTable() { + return similarityTable; + } + + protected String getItemAIDColumn() { + return itemAIDColumn; + } + + protected String getItemBIDColumn() { + return itemBIDColumn; + } + + protected String getSimilarityColumn() { + return similarityColumn; + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) throws TasteException { + if (itemID1 == itemID2) { + return 1.0; + } + Connection conn = null; + PreparedStatement stmt = null; + try { + conn = dataSource.getConnection(); + stmt = conn.prepareStatement(getItemItemSimilaritySQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + stmt.setFetchDirection(ResultSet.FETCH_FORWARD); + stmt.setFetchSize(getFetchSize()); + return doItemSimilarity(stmt, itemID1, itemID2); + } catch (SQLException sqle) { + log.warn("Exception while retrieving similarity", sqle); + throw new TasteException(sqle); + } finally { + IOUtils.quietClose(null, stmt, conn); + } + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { + double[] result = new double[itemID2s.length]; + Connection conn = null; + PreparedStatement stmt = null; + try { + conn = dataSource.getConnection(); + stmt = conn.prepareStatement(getItemItemSimilaritySQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + stmt.setFetchDirection(ResultSet.FETCH_FORWARD); + stmt.setFetchSize(getFetchSize()); + for (int i = 0; i < itemID2s.length; i++) { + result[i] = doItemSimilarity(stmt, itemID1, itemID2s[i]); + } + } catch (SQLException sqle) { + log.warn("Exception while retrieving item similarities", sqle); + throw new TasteException(sqle); + } finally { + IOUtils.quietClose(null, stmt, conn); + } + return result; + } + + @Override + public long[] allSimilarItemIDs(long itemID) throws TasteException { + FastIDSet allSimilarItemIDs = new FastIDSet(); + Connection conn = null; + PreparedStatement stmt = null; + ResultSet rs = null; + try { + conn = dataSource.getConnection(); + stmt = conn.prepareStatement(getAllSimilarItemIDsSQL, ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY); + stmt.setFetchDirection(ResultSet.FETCH_FORWARD); + stmt.setFetchSize(getFetchSize()); + stmt.setLong(1, itemID); + stmt.setLong(2, itemID); + rs = stmt.executeQuery(); + while (rs.next()) { + allSimilarItemIDs.add(rs.getLong(1)); + allSimilarItemIDs.add(rs.getLong(2)); + } + } catch (SQLException sqle) { + log.warn("Exception while retrieving all similar itemIDs", sqle); + throw new TasteException(sqle); + } finally { + IOUtils.quietClose(rs, stmt, conn); + } + allSimilarItemIDs.remove(itemID); + return allSimilarItemIDs.toArray(); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + // do nothing + } + + private double doItemSimilarity(PreparedStatement stmt, long itemID1, long itemID2) throws SQLException { + // Order as smaller - larger + if (itemID1 > itemID2) { + long temp = itemID1; + itemID1 = itemID2; + itemID2 = temp; + } + stmt.setLong(1, itemID1); + stmt.setLong(2, itemID2); + log.debug("Executing SQL query: 
{}", getItemItemSimilaritySQL); + ResultSet rs = null; + try { + rs = stmt.executeQuery(); + // If not found, perhaps the items exist but have no presence in the table, + // so NaN is appropriate + return rs.next() ? rs.getDouble(1) : Double.NaN; + } finally { + IOUtils.quietClose(rs); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java new file mode 100644 index 0000000..cc831d9 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCInMemoryItemSimilarity.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity.jdbc; + +import org.apache.mahout.cf.taste.common.TasteException; + +import javax.sql.DataSource; + +public class MySQLJDBCInMemoryItemSimilarity extends SQL92JDBCInMemoryItemSimilarity { + + public MySQLJDBCInMemoryItemSimilarity() throws TasteException { + } + + public MySQLJDBCInMemoryItemSimilarity(String dataSourceName) throws TasteException { + super(dataSourceName); + } + + public MySQLJDBCInMemoryItemSimilarity(DataSource dataSource) { + super(dataSource); + } + + public MySQLJDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) { + super(dataSource, getAllItemSimilaritiesSQL); + } + + @Override + protected int getFetchSize() { + // Need to return this for MySQL Connector/J to make it use streaming mode + return Integer.MIN_VALUE; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java new file mode 100644 index 0000000..af0742e --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/MySQLJDBCItemSimilarity.java @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity.jdbc; + +import javax.sql.DataSource; + +import org.apache.mahout.cf.taste.common.TasteException; + +/** + * <p> + * An {@link org.apache.mahout.cf.taste.similarity.ItemSimilarity} backed by a MySQL database + * and accessed via JDBC. It may work with other JDBC + * databases. By default, this class assumes that there is a {@link DataSource} available under the JNDI name + * "jdbc/taste", which gives access to a database with a "taste_item_similarity" table with the following + * schema: + * </p> + * + * <table> + * <tr> + * <th>item_id_a</th> + * <th>item_id_b</th> + * <th>similarity</th> + * </tr> + * <tr> + * <td>ABC</td> + * <td>DEF</td> + * <td>0.9</td> + * </tr> + * <tr> + * <td>DEF</td> + * <td>EFG</td> + * <td>0.1</td> + * </tr> + * </table> + * + * <p> + * For example, the following command sets up a suitable table in MySQL, complete with primary key and + * indexes: + * </p> + * + * <p> + * + * <pre> + * CREATE TABLE taste_item_similarity ( + * item_id_a BIGINT NOT NULL, + * item_id_b BIGINT NOT NULL, + * similarity FLOAT NOT NULL, + * PRIMARY KEY (item_id_a, item_id_b), + * ) + * </pre> + * + * </p> + * + * <p> + * Note that for each row, item_id_a should be less than item_id_b. It is redundant to store it both ways, + * so the pair is always stored as a pair with the lesser one first. 
+ * + * @see org.apache.mahout.cf.taste.impl.model.jdbc.MySQLJDBCDataModel + */ +public class MySQLJDBCItemSimilarity extends SQL92JDBCItemSimilarity { + + public MySQLJDBCItemSimilarity() throws TasteException { + } + + public MySQLJDBCItemSimilarity(String dataSourceName) throws TasteException { + super(dataSourceName); + } + + public MySQLJDBCItemSimilarity(DataSource dataSource) { + super(dataSource); + } + + public MySQLJDBCItemSimilarity(DataSource dataSource, + String similarityTable, + String itemAIDColumn, + String itemBIDColumn, + String similarityColumn) { + super(dataSource, similarityTable, itemAIDColumn, itemBIDColumn, similarityColumn); + } + + @Override + protected int getFetchSize() { + // Need to return this for MySQL Connector/J to make it use streaming mode + return Integer.MIN_VALUE; + } + +} + http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java new file mode 100644 index 0000000..b311a5e --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity.jdbc; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; + +import javax.sql.DataSource; + +public class SQL92JDBCInMemoryItemSimilarity extends AbstractJDBCInMemoryItemSimilarity { + + static final String DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL = + "SELECT " + AbstractJDBCItemSimilarity.DEFAULT_ITEM_A_ID_COLUMN + ", " + + AbstractJDBCItemSimilarity.DEFAULT_ITEM_B_ID_COLUMN + ", " + + AbstractJDBCItemSimilarity.DEFAULT_SIMILARITY_COLUMN + " FROM " + + AbstractJDBCItemSimilarity.DEFAULT_SIMILARITY_TABLE; + + + public SQL92JDBCInMemoryItemSimilarity() throws TasteException { + this(AbstractJDBCComponent.lookupDataSource(AbstractJDBCComponent.DEFAULT_DATASOURCE_NAME), + DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL); + } + + public SQL92JDBCInMemoryItemSimilarity(String dataSourceName) throws TasteException { + this(AbstractJDBCComponent.lookupDataSource(dataSourceName), DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL); + } + + public SQL92JDBCInMemoryItemSimilarity(DataSource dataSource) { + this(dataSource, DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL); + } + + public SQL92JDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) { + super(dataSource, getAllItemSimilaritiesSQL); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java new file mode 100644 index 0000000..f449561 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity.jdbc; + +import org.apache.mahout.cf.taste.common.TasteException; + +import javax.sql.DataSource; + +public class SQL92JDBCItemSimilarity extends AbstractJDBCItemSimilarity { + + public SQL92JDBCItemSimilarity() throws TasteException { + this(DEFAULT_DATASOURCE_NAME); + } + + public SQL92JDBCItemSimilarity(String dataSourceName) throws TasteException { + this(lookupDataSource(dataSourceName)); + } + + public SQL92JDBCItemSimilarity(DataSource dataSource) { + this(dataSource, + DEFAULT_SIMILARITY_TABLE, + DEFAULT_ITEM_A_ID_COLUMN, + DEFAULT_ITEM_B_ID_COLUMN, + DEFAULT_SIMILARITY_COLUMN); + } + + public SQL92JDBCItemSimilarity(DataSource dataSource, + String similarityTable, + String itemAIDColumn, + String itemBIDColumn, + String similarityColumn) { + super(dataSource, + similarityTable, + itemAIDColumn, + itemBIDColumn, similarityColumn, + "SELECT " + similarityColumn + " FROM " + similarityTable + " WHERE " + + itemAIDColumn + "=? AND " + itemBIDColumn + "=?", + "SELECT " + itemAIDColumn + ", " + itemBIDColumn + " FROM " + similarityTable + " WHERE " + + itemAIDColumn + "=? OR " + itemBIDColumn + "=?"); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java new file mode 100644 index 0000000..a5a89c6 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java @@ -0,0 +1,215 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.web; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.cf.taste.recommender.Recommender; + +import javax.servlet.ServletConfig; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.List; + +/** + * <p>A servlet which returns recommendations, as its name implies. 
The servlet accepts GET and POST + * HTTP requests, and looks for two parameters:</p> + * + * <ul> + * <li><em>userID</em>: the user ID for which to produce recommendations</li> + * <li><em>howMany</em>: the number of recommendations to produce</li> + * <li><em>debug</em>: (optional) output a lot of information that is useful in debugging. + * Defaults to false, of course.</li> + * </ul> + * + * <p>The response is text, and contains a list of the IDs of recommended items, in descending + * order of relevance, one per line.</p> + * + * <p>For example, you can get 10 recommendations for user 123 from the following URL (assuming + * you are running taste in a web application running locally on port 8080):<br/> + * {@code http://localhost:8080/taste/RecommenderServlet?userID=123&howMany=10}</p> + * + * <p>This servlet requires one {@code init-param} in {@code web.xml}: it must find + * a parameter named "recommender-class" which is the name of a class that implements + * {@link Recommender} and has a no-arg constructor. The servlet will instantiate and use + * this {@link Recommender} to produce recommendations.</p> + */ +public final class RecommenderServlet extends HttpServlet { + + private static final int NUM_TOP_PREFERENCES = 20; + private static final int DEFAULT_HOW_MANY = 20; + + private Recommender recommender; + + @Override + public void init(ServletConfig config) throws ServletException { + super.init(config); + String recommenderClassName = config.getInitParameter("recommender-class"); + if (recommenderClassName == null) { + throw new ServletException("Servlet init-param \"recommender-class\" is not defined"); + } + RecommenderSingleton.initializeIfNeeded(recommenderClassName); + recommender = RecommenderSingleton.getInstance().getRecommender(); + } + + @Override + public void doGet(HttpServletRequest request, + HttpServletResponse response) throws ServletException { + + String userIDString = request.getParameter("userID"); + if (userIDString == null) { + throw new ServletException("userID was not specified"); + } + long userID = Long.parseLong(userIDString); + String howManyString = request.getParameter("howMany"); + int howMany = howManyString == null ? 
DEFAULT_HOW_MANY : Integer.parseInt(howManyString); + boolean debug = Boolean.parseBoolean(request.getParameter("debug")); + String format = request.getParameter("format"); + if (format == null) { + format = "text"; + } + + try { + List<RecommendedItem> items = recommender.recommend(userID, howMany); + if ("text".equals(format)) { + writePlainText(response, userID, debug, items); + } else if ("xml".equals(format)) { + writeXML(response, items); + } else if ("json".equals(format)) { + writeJSON(response, items); + } else { + throw new ServletException("Bad format parameter: " + format); + } + } catch (TasteException | IOException te) { + throw new ServletException(te); + } + + } + + private static void writeXML(HttpServletResponse response, Iterable<RecommendedItem> items) throws IOException { + response.setContentType("application/xml"); + response.setCharacterEncoding("UTF-8"); + response.setHeader("Cache-Control", "no-cache"); + PrintWriter writer = response.getWriter(); + writer.print("<?xml version=\"1.0\" encoding=\"UTF-8\"?><recommendedItems>"); + for (RecommendedItem recommendedItem : items) { + writer.print("<item><value>"); + writer.print(recommendedItem.getValue()); + writer.print("</value><id>"); + writer.print(recommendedItem.getItemID()); + writer.print("</id></item>"); + } + writer.println("</recommendedItems>"); + } + + private static void writeJSON(HttpServletResponse response, Iterable<RecommendedItem> items) throws IOException { + response.setContentType("application/json"); + response.setCharacterEncoding("UTF-8"); + response.setHeader("Cache-Control", "no-cache"); + PrintWriter writer = response.getWriter(); + writer.print("{\"recommendedItems\":{\"item\":["); + boolean first = true; + for (RecommendedItem recommendedItem : items) { + if (first) { + first = false; + } else { + writer.print(','); + } + writer.print("{\"value\":\""); + writer.print(recommendedItem.getValue()); + writer.print("\",\"id\":\""); + writer.print(recommendedItem.getItemID()); + writer.print("\"}"); + } + writer.println("]}}"); + } + + private void writePlainText(HttpServletResponse response, + long userID, + boolean debug, + Iterable<RecommendedItem> items) throws IOException, TasteException { + response.setContentType("text/plain"); + response.setCharacterEncoding("UTF-8"); + response.setHeader("Cache-Control", "no-cache"); + PrintWriter writer = response.getWriter(); + if (debug) { + writeDebugRecommendations(userID, items, writer); + } else { + writeRecommendations(items, writer); + } + } + + private static void writeRecommendations(Iterable<RecommendedItem> items, PrintWriter writer) { + for (RecommendedItem recommendedItem : items) { + writer.print(recommendedItem.getValue()); + writer.print('\t'); + writer.println(recommendedItem.getItemID()); + } + } + + private void writeDebugRecommendations(long userID, Iterable<RecommendedItem> items, PrintWriter writer) + throws TasteException { + DataModel dataModel = recommender.getDataModel(); + writer.print("User:"); + writer.println(userID); + writer.print("Recommender: "); + writer.println(recommender); + writer.println(); + writer.print("Top "); + writer.print(NUM_TOP_PREFERENCES); + writer.println(" Preferences:"); + PreferenceArray rawPrefs = dataModel.getPreferencesFromUser(userID); + int length = rawPrefs.length(); + PreferenceArray sortedPrefs = rawPrefs.clone(); + sortedPrefs.sortByValueReversed(); + // Cap this at NUM_TOP_PREFERENCES just to be brief + int max = Math.min(NUM_TOP_PREFERENCES, length); + for (int i = 0; i < max; i++) { + 
Preference pref = sortedPrefs.get(i); + writer.print(pref.getValue()); + writer.print('\t'); + writer.println(pref.getItemID()); + } + writer.println(); + writer.println("Recommendations:"); + for (RecommendedItem recommendedItem : items) { + writer.print(recommendedItem.getValue()); + writer.print('\t'); + writer.println(recommendedItem.getItemID()); + } + } + + @Override + public void doPost(HttpServletRequest request, + HttpServletResponse response) throws ServletException { + doGet(request, response); + } + + @Override + public String toString() { + return "RecommenderServlet[recommender:" + recommender + ']'; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java new file mode 100644 index 0000000..265d7c0 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.web; + +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.apache.mahout.common.ClassUtils; + +/** + * <p>A singleton which holds an instance of a {@link Recommender}. 
This is used to share + * a {@link Recommender} between {@link RecommenderServlet} and {@code RecommenderService.jws}.</p> + */ +public final class RecommenderSingleton { + + private final Recommender recommender; + + private static RecommenderSingleton instance; + + public static synchronized RecommenderSingleton getInstance() { + if (instance == null) { + throw new IllegalStateException("Not initialized"); + } + return instance; + } + + public static synchronized void initializeIfNeeded(String recommenderClassName) { + if (instance == null) { + instance = new RecommenderSingleton(recommenderClassName); + } + } + + private RecommenderSingleton(String recommenderClassName) { + if (recommenderClassName == null) { + throw new IllegalArgumentException("Recommender class name is null"); + } + recommender = ClassUtils.instantiateAs(recommenderClassName, Recommender.class); + } + + public Recommender getRecommender() { + return recommender; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java new file mode 100644 index 0000000..e927098 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java @@ -0,0 +1,126 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.web; + +import com.google.common.io.Files; +import com.google.common.io.InputSupplier; +import com.google.common.io.Resources; +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.recommender.IDRescorer; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.util.Collection; +import java.util.List; + +/** + * Users of the packaging and deployment mechanism in this module need + * to produce a {@link Recommender} implementation with a no-arg constructor, + * which will internally build the desired {@link Recommender} and delegate + * to it. This wrapper simplifies that process. Simply extend this class and + * implement {@link #buildRecommender()}. 
+ */ +public abstract class RecommenderWrapper implements Recommender { + + private static final Logger log = LoggerFactory.getLogger(RecommenderWrapper.class); + + private final Recommender delegate; + + protected RecommenderWrapper() throws TasteException, IOException { + this.delegate = buildRecommender(); + } + + /** + * @return the {@link Recommender} which should be used to produce recommendations + * by this wrapper implementation + */ + protected abstract Recommender buildRecommender() throws IOException, TasteException; + + @Override + public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException { + return delegate.recommend(userID, howMany); + } + + @Override + public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException { + return delegate.recommend(userID, howMany, rescorer); + } + + @Override + public float estimatePreference(long userID, long itemID) throws TasteException { + return delegate.estimatePreference(userID, itemID); + } + + @Override + public void setPreference(long userID, long itemID, float value) throws TasteException { + delegate.setPreference(userID, itemID, value); + } + + @Override + public void removePreference(long userID, long itemID) throws TasteException { + delegate.removePreference(userID, itemID); + } + + @Override + public DataModel getDataModel() { + return delegate.getDataModel(); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + delegate.refresh(alreadyRefreshed); + } + + /** + * Reads the given resource into a temporary file. This is intended to be used + * to read data files which are stored as a resource available on the classpath, + * such as in a JAR file. However for convenience the resource name will also + * be interpreted as a relative path to a local file, if no such resource is + * found. This facilitates testing. + * + * @param resourceName name of resource in classpath, or relative path to file + * @return temporary {@link File} with resource data + * @throws IOException if an error occurs while reading or writing data + */ + public static File readResourceToTempFile(String resourceName) throws IOException { + String absoluteResource = resourceName.startsWith("/") ? resourceName : '/' + resourceName; + log.info("Loading resource {}", absoluteResource); + InputSupplier<? 
extends InputStream> inSupplier; + try { + URL resourceURL = Resources.getResource(RecommenderWrapper.class, absoluteResource); + inSupplier = Resources.newInputStreamSupplier(resourceURL); + } catch (IllegalArgumentException iae) { + File resourceFile = new File(resourceName); + log.info("Falling back to load file {}", resourceFile.getAbsolutePath()); + inSupplier = Files.newInputStreamSupplier(resourceFile); + } + File tempFile = File.createTempFile("taste", null); + tempFile.deleteOnExit(); + Files.copy(inSupplier, tempFile); + return tempFile; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java new file mode 100644 index 0000000..03a3000 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier; + +import com.google.common.collect.Lists; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.PrintStream; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Export a ConfusionMatrix in various text formats: ToString version Grayscale HTML table Summary HTML table + * Table of counts all with optional HTML wrappers + * + * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair + * + * Intended to consume ConfusionMatrix SequenceFile output by Bayes TestClassifier class + */ +public final class ConfusionMatrixDumper extends AbstractJob { + + private static final String TAB_SEPARATOR = "|"; + + // HTML wrapper - default CSS + private static final String HEADER = "<html>" + + "<head>\n" + + "<title>TITLE</title>\n" + + "</head>" + + "<body>\n" + + "<style type='text/css'> \n" + + "table\n" + + "{\n" + + "border:3px solid black; text-align:left;\n" + + "}\n" + + "th.normalHeader\n" + + "{\n" + + "border:1px solid black;border-collapse:collapse;text-align:center;" + + "background-color:white\n" + + "}\n" + + "th.tallHeader\n" + + "{\n" + + "border:1px solid black;border-collapse:collapse;text-align:center;" + + "background-color:white; height:6em\n" + + "}\n" + + "tr.label\n" + + "{\n" + + "border:1px solid black;border-collapse:collapse;text-align:center;" + + "background-color:white\n" + + "}\n" + + "tr.row\n" + + "{\n" + + "border:1px solid gray;text-align:center;background-color:snow\n" + + "}\n" + + "td\n" + + "{\n" + + "min-width:2em\n" + + "}\n" + + "td.cell\n" + + "{\n" + + "border:1px solid black;text-align:right;background-color:snow\n" + + "}\n" + + "td.empty\n" + + "{\n" + + "border:0px;text-align:right;background-color:snow\n" + + "}\n" + + "td.white\n" + + "{\n" + + "border:0px solid black;text-align:right;background-color:white\n" + + "}\n" + + "td.black\n" + + "{\n" + + "border:0px solid red;text-align:right;background-color:black\n" + + "}\n" + + "td.gray1\n" + + "{\n" + + "border:0px solid green;text-align:right; background-color:LightGray\n" + + "}\n" + "td.gray2\n" + "{\n" + + "border:0px solid blue;text-align:right;background-color:gray\n" + + "}\n" + "td.gray3\n" + "{\n" + + "border:0px solid red;text-align:right;background-color:DarkGray\n" + + "}\n" + "th" + "{\n" + " text-align: center;\n" + + " vertical-align: bottom;\n" + + " padding-bottom: 3px;\n" + " padding-left: 5px;\n" + + " padding-right: 5px;\n" + "}\n" + " .verticalText\n" + + " {\n" + " text-align: center;\n" + + " vertical-align: middle;\n" + " width: 20px;\n" + + " margin: 0px;\n" + " padding: 0px;\n" + + " padding-left: 3px;\n" + " padding-right: 3px;\n" + + " padding-top: 10px;\n" + " white-space: nowrap;\n" + + " -webkit-transform: rotate(-90deg); \n" + + " -moz-transform: rotate(-90deg); \n" + " };\n" + + "</style>\n"; + private static final String FOOTER = "</html></body>"; + + // CSS style names. 
+ private static final String CSS_TABLE = "table"; + private static final String CSS_LABEL = "label"; + private static final String CSS_TALL_HEADER = "tall"; + private static final String CSS_VERTICAL = "verticalText"; + private static final String CSS_CELL = "cell"; + private static final String CSS_EMPTY = "empty"; + private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"}; + + private ConfusionMatrixDumper() {} + + public static void main(String[] args) throws Exception { + ToolRunner.run(new ConfusionMatrixDumper(), args); + } + + @Override + public int run(String[] args) throws IOException { + addInputOption(); + addOption("output", "o", "Output path", null); // AbstractJob output feature requires param + addOption(DefaultOptionCreator.overwriteOption().create()); + addFlag("html", null, "Create complete HTML page"); + addFlag("text", null, "Dump simple text"); + Map<String,List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + + Path inputPath = getInputPath(); + String outputFile = hasOption("output") ? getOption("output") : null; + boolean text = parsedArgs.containsKey("--text"); + boolean wrapHtml = parsedArgs.containsKey("--html"); + PrintStream out = getPrintStream(outputFile); + if (text) { + exportText(inputPath, out); + } else { + exportTable(inputPath, out, wrapHtml); + } + out.flush(); + if (out != System.out) { + out.close(); + } + return 0; + } + + private static void exportText(Path inputPath, PrintStream out) throws IOException { + MatrixWritable mw = new MatrixWritable(); + Text key = new Text(); + readSeqFile(inputPath, key, mw); + Matrix m = mw.get(); + ConfusionMatrix cm = new ConfusionMatrix(m); + out.println(String.format("%-40s", "Label") + TAB_SEPARATOR + String.format("%-10s", "Total") + + TAB_SEPARATOR + String.format("%-10s", "Correct") + TAB_SEPARATOR + + String.format("%-6s", "%") + TAB_SEPARATOR); + out.println(String.format("%-70s", "-").replace(' ', '-')); + List<String> labels = stripDefault(cm); + for (String label : labels) { + int correct = cm.getCorrect(label); + double accuracy = cm.getAccuracy(label); + int count = getCount(cm, label); + out.println(String.format("%-40s", label) + TAB_SEPARATOR + String.format("%-10s", count) + + TAB_SEPARATOR + String.format("%-10s", correct) + TAB_SEPARATOR + + String.format("%-6s", (int) Math.round(accuracy)) + TAB_SEPARATOR); + } + out.println(String.format("%-70s", "-").replace(' ', '-')); + out.println(cm.toString()); + } + + private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException { + MatrixWritable mw = new MatrixWritable(); + Text key = new Text(); + readSeqFile(inputPath, key, mw); + String fileName = inputPath.getName(); + fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length()); + Matrix m = mw.get(); + ConfusionMatrix cm = new ConfusionMatrix(m); + if (wrapHtml) { + printHeader(out, fileName); + } + out.println("<p/>"); + printSummaryTable(cm, out); + out.println("<p/>"); + printGrayTable(cm, out); + out.println("<p/>"); + printCountsTable(cm, out); + out.println("<p/>"); + printTextInBox(cm, out); + out.println("<p/>"); + if (wrapHtml) { + printFooter(out); + } + } + + private static List<String> stripDefault(ConfusionMatrix cm) { + List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); + String defaultLabel = cm.getDefaultLabel(); + int unclassified = cm.getTotal(defaultLabel); + if (unclassified > 0) { + return stripped; + } + 
stripped.remove(defaultLabel); + return stripped; + } + + // TODO: test - this should work with HDFS files + private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException { + Configuration conf = new Configuration(); + FileSystem fs = FileSystem.get(conf); + SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf); + reader.next(key, m); + } + + // TODO: test - this might not work with HDFS files? + // after all, it does no seeks + private static PrintStream getPrintStream(String outputFilename) throws IOException { + if (outputFilename != null) { + File outputFile = new File(outputFilename); + if (outputFile.exists()) { + outputFile.delete(); + } + outputFile.createNewFile(); + OutputStream os = new FileOutputStream(outputFile); + return new PrintStream(os, false, Charsets.UTF_8.displayName()); + } else { + return System.out; + } + } + + private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) { + Iterator<String> iter = cm.getLabels().iterator(); + int count = 0; + while (iter.hasNext()) { + count += cm.getCount(rowLabel, iter.next()); + } + return count; + } + + // HTML generator code + + private static void printTextInBox(ConfusionMatrix cm, PrintStream out) { + out.println("<div style='width:90%;overflow:scroll;'>"); + out.println("<pre>"); + out.println(cm.toString()); + out.println("</pre>"); + out.println("</div>"); + } + + public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) { + format("<table class='%s'>\n", out, CSS_TABLE); + format("<tr class='%s'>", out, CSS_LABEL); + out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>"); + out.println("</tr>"); + List<String> labels = stripDefault(cm); + for (String label : labels) { + printSummaryRow(cm, out, label); + } + out.println("</table>"); + } + + private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) { + format("<tr class='%s'>", out, CSS_CELL); + int correct = cm.getCorrect(label); + double accuracy = cm.getAccuracy(label); + int count = getCount(cm, label); + format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", out, CSS_CELL, label, count, correct, + (int) Math.round(accuracy)); + out.println("</tr>"); + } + + private static int getCount(ConfusionMatrix cm, String label) { + int count = 0; + for (String s : cm.getLabels()) { + count += cm.getCount(label, s); + } + return count; + } + + public static void printGrayTable(ConfusionMatrix cm, PrintStream out) { + format("<table class='%s'>\n", out, CSS_TABLE); + printCountsHeader(cm, out, true); + printGrayRows(cm, out); + out.println("</table>"); + } + + /** + * Print each value in a four-value grayscale based on count/max. Gives a mostly white matrix with grays in + * misclassified, and black in diagonal. 
TODO: Using the sqrt(count/max) as the rating is more stringent + */ + private static void printGrayRows(ConfusionMatrix cm, PrintStream out) { + List<String> labels = stripDefault(cm); + for (String label : labels) { + printGrayRow(cm, out, labels, label); + } + } + + private static void printGrayRow(ConfusionMatrix cm, + PrintStream out, + Iterable<String> labels, + String rowLabel) { + format("<tr class='%s'>", out, CSS_LABEL); + format("<td>%s</td>", out, rowLabel); + int total = getLabelTotal(cm, rowLabel); + for (String columnLabel : labels) { + printGrayCell(cm, out, total, rowLabel, columnLabel); + } + out.println("</tr>"); + } + + // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs + // assign black to count = total, meaning complete success + // alternative rating is to use sqrt(total) instead of total - this is more drastic + private static void printGrayCell(ConfusionMatrix cm, + PrintStream out, + int total, + String rowLabel, + String columnLabel) { + + int count = cm.getCount(rowLabel, columnLabel); + if (count == 0) { + out.format("<td class='%s'/>", CSS_EMPTY); + } else { + // 0 is white, full is black, everything else gray + int rating = (int) ((count / (double) total) * 4); + String css = CSS_GRAY_CELLS[rating]; + format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count); + } + } + + public static void printCountsTable(ConfusionMatrix cm, PrintStream out) { + format("<table class='%s'>\n", out, CSS_TABLE); + printCountsHeader(cm, out, false); + printCountsRows(cm, out); + out.println("</table>"); + } + + private static void printCountsRows(ConfusionMatrix cm, PrintStream out) { + List<String> labels = stripDefault(cm); + for (String label : labels) { + printCountsRow(cm, out, labels, label); + } + } + + private static void printCountsRow(ConfusionMatrix cm, + PrintStream out, + Iterable<String> labels, + String rowLabel) { + out.println("<tr>"); + format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel); + for (String columnLabel : labels) { + printCountsCell(cm, out, rowLabel, columnLabel); + } + out.println("</tr>"); + } + + private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) { + int count = cm.getCount(rowLabel, columnLabel); + String s = count == 0 ? "" : Integer.toString(count); + format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, s); + } + + private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) { + List<String> labels = stripDefault(cm); + int longest = getLongestHeader(labels); + if (vertical) { + // do vertical - rotation is a bitch + out.format("<tr class='%s' style='height:%dem'><th> </th>%n", CSS_TALL_HEADER, longest / 2); + for (String label : labels) { + out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label); + } + out.println("</tr>"); + } else { + // header - empty cell in upper left + out.format("<tr class='%s'><td class='%s'></td>%n", CSS_TABLE, CSS_LABEL); + for (String label : labels) { + out.format("<td>%s</td>", label); + } + out.format("</tr>"); + } + } + + private static int getLongestHeader(Iterable<String> labels) { + int max = 0; + for (String label : labels) { + max = Math.max(label.length(), max); + } + return max; + } + + private static void format(String format, PrintStream out, Object... 
args) { + String format2 = String.format(format, args); + out.println(format2); + } + + public static void printHeader(PrintStream out, CharSequence title) { + out.println(HEADER.replace("TITLE", title)); + } + + public static void printFooter(PrintStream out) { + out.println(FOOTER); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java new file mode 100644 index 0000000..545c1ff --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java @@ -0,0 +1,387 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.cdbw; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.GaussianAccumulator; +import org.apache.mahout.clustering.OnlineGaussianAccumulator; +import org.apache.mahout.clustering.evaluation.RepresentativePointsDriver; +import org.apache.mahout.clustering.evaluation.RepresentativePointsMapper; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * This class calculates the CDbw metric as defined in + * http://www.db-net.aueb.gr/index.php/corporate/content/download/227/833/file/HV_poster2002.pdf + */ +public final class CDbwEvaluator { + + private static final Logger log = LoggerFactory.getLogger(CDbwEvaluator.class); + + private final Map<Integer,List<VectorWritable>> representativePoints; + private final Map<Integer,Double> stDevs = new HashMap<>(); + private final List<Cluster> clusters; + private final DistanceMeasure measure; + private Double interClusterDensity = null; + // these are symmetric so we only compute half of them + 
private Map<Integer,Map<Integer,Double>> minimumDistances = null; + // these are symmetric too + private Map<Integer,Map<Integer,Double>> interClusterDensities = null; + // these are symmetric too + private Map<Integer,Map<Integer,int[]>> closestRepPointIndices = null; + + /** + * For testing only + * + * @param representativePoints + * a Map<Integer,List<VectorWritable>> of representative points keyed by clusterId + * @param clusters + * a List<Cluster> of the clusters + * @param measure + * an appropriate DistanceMeasure + */ + public CDbwEvaluator(Map<Integer,List<VectorWritable>> representativePoints, List<Cluster> clusters, + DistanceMeasure measure) { + this.representativePoints = representativePoints; + this.clusters = clusters; + this.measure = measure; + for (Integer cId : representativePoints.keySet()) { + computeStd(cId); + } + } + + /** + * Initialize a new instance from job information + * + * @param conf + * a Configuration with appropriate parameters + * @param clustersIn + * a Path to the input clusters directory + */ + public CDbwEvaluator(Configuration conf, Path clustersIn) { + measure = ClassUtils + .instantiateAs(conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY), DistanceMeasure.class); + representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf); + clusters = loadClusters(conf, clustersIn); + for (Integer cId : representativePoints.keySet()) { + computeStd(cId); + } + } + + /** + * Load the clusters from their sequence files + * + * @param conf + * the Configuration to use + * @param clustersIn + * a Path to the directory containing input cluster files + * @return a List<Cluster> of the clusters + */ + private static List<Cluster> loadClusters(Configuration conf, Path clustersIn) { + List<Cluster> clusters = new ArrayList<>(); + for (ClusterWritable clusterWritable : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST, + PathFilters.logsCRCFilter(), conf)) { + Cluster cluster = clusterWritable.getValue(); + clusters.add(cluster); + } + return clusters; + } + + /** + * Compute the standard deviation of the representative points for the given cluster. Store these in stDevs, indexed + * by cI + * + * @param cI + * an int clusterId.
+ */ + private void computeStd(int cI) { + List<VectorWritable> repPts = representativePoints.get(cI); + GaussianAccumulator accumulator = new OnlineGaussianAccumulator(); + for (VectorWritable vw : repPts) { + accumulator.observe(vw.get(), 1.0); + } + accumulator.compute(); + double d = accumulator.getAverageStd(); + stDevs.put(cI, d); + } + + /** + * Compute the density of points near the midpoint between the two closest points of the clusters (eqn 2), used for + * the inter-cluster density calculation + * + * @param uIJ + * the Vector midpoint between the closest representative points of the clusters + * @param cI + * the int clusterId of the i-th cluster + * @param cJ + * the int clusterId of the j-th cluster + * @param avgStd + * the double average standard deviation of the two clusters + * @return a double + */ + private double density(Vector uIJ, int cI, int cJ, double avgStd) { + List<VectorWritable> repI = representativePoints.get(cI); + List<VectorWritable> repJ = representativePoints.get(cJ); + double sum = 0.0; + // count the number of representative points of the clusters which are within the + // average std of the two clusters from the midpoint uIJ (eqn 3) + for (VectorWritable vwI : repI) { + if (uIJ != null && measure.distance(uIJ, vwI.get()) <= avgStd) { + sum++; + } + } + for (VectorWritable vwJ : repJ) { + if (uIJ != null && measure.distance(uIJ, vwJ.get()) <= avgStd) { + sum++; + } + } + int nI = repI.size(); + int nJ = repJ.size(); + return sum / (nI + nJ); + } + + /** + * Compute the CDbw validity metric (eqn 8). The goal of this metric is to reward clusterings which have a high + * intraClusterDensity and also a high cluster separation. + * + * @return a double + */ + public double getCDbw() { + return intraClusterDensity() * separation(); + } + + /** + * The average density within clusters is defined as the percentage of representative points that reside in the + * neighborhood of the clusters' centers. The goal is for the density within clusters to be significantly high. (eqn 5) + * + * @return a double + */ + public double intraClusterDensity() { + double avgDensity = 0; + int count = 0; + for (Element elem : intraClusterDensities().nonZeroes()) { + double value = elem.get(); + if (!Double.isNaN(value)) { + avgDensity += value; + count++; + } + } + return avgDensity / count; + } + + /** + * This function evaluates the density of points in the regions between each pair of clusters (eqn 1). The goal is for the density + * in the area between clusters to be significantly low.
+ * + * @return a Map<Integer,Map<Integer,Double>> of the inter-cluster densities + */ + public Map<Integer,Map<Integer,Double>> interClusterDensities() { + if (interClusterDensities != null) { + return interClusterDensities; + } + interClusterDensities = new TreeMap<>(); + // find the closest representative points between the clusters + for (int i = 0; i < clusters.size(); i++) { + int cI = clusters.get(i).getId(); + Map<Integer,Double> map = new TreeMap<>(); + interClusterDensities.put(cI, map); + for (int j = i + 1; j < clusters.size(); j++) { + int cJ = clusters.get(j).getId(); + double minDistance = minimumDistance(cI, cJ); // the distance between the closest representative points + Vector uIJ = midpointVector(cI, cJ); // the midpoint between the closest representative points + double stdSum = stDevs.get(cI) + stDevs.get(cJ); + double density = density(uIJ, cI, cJ, stdSum / 2); + double interDensity = minDistance * density / stdSum; + map.put(cJ, interDensity); + if (log.isDebugEnabled()) { + log.debug("minDistance[{},{}]={}", cI, cJ, minDistance); + log.debug("density[{},{}]={}", cI, cJ, density); + log.debug("interDensity[{},{}]={}", cI, cJ, interDensity); + } + } + } + return interClusterDensities; + } + + /** + * Calculate the separation of clusters (eqn 4) taking into account both the distances between the clusters' closest + * points and the inter-cluster density. The goal is for the distances between clusters to be high while the + * representative point density in the areas between them is low. + * + * @return a double + */ + public double separation() { + double minDistanceSum = 0; + Map<Integer,Map<Integer,Double>> distances = minimumDistances(); + for (Map<Integer,Double> map : distances.values()) { + for (Double dist : map.values()) { + if (!Double.isInfinite(dist)) { + minDistanceSum += dist * 2; // account for other half of calculated triangular minimumDistances matrix + } + } + } + return minDistanceSum / (1.0 + interClusterDensity()); + } + + /** + * This function evaluates the average density of points in the regions between clusters (eqn 1). The goal is for the + * density in the area between clusters to be significantly low. + * + * @return a double + */ + public double interClusterDensity() { + if (interClusterDensity != null) { + return interClusterDensity; + } + double sum = 0.0; + int count = 0; + Map<Integer,Map<Integer,Double>> distances = interClusterDensities(); + for (Map<Integer,Double> row : distances.values()) { + for (Double density : row.values()) { + if (!Double.isNaN(density)) { + sum += density; + count++; + } + } + } + log.debug("interClusterDensity={}", sum); + interClusterDensity = sum / count; + return interClusterDensity; + } + + /** + * The average density within clusters is defined as the percentage of representative points that reside in the + * neighborhood of the clusters' centers. The goal is for the density within clusters to be significantly high.
(eqn 5) + * + * @return a Vector of the intra-densities of each clusterId + */ + public Vector intraClusterDensities() { + Vector densities = new RandomAccessSparseVector(Integer.MAX_VALUE); + // compute the average standard deviation of the clusters + double stdev = 0.0; + for (Integer cI : representativePoints.keySet()) { + stdev += stDevs.get(cI); + } + int c = representativePoints.size(); + stdev /= c; + for (Cluster cluster : clusters) { + Integer cI = cluster.getId(); + List<VectorWritable> repPtsI = representativePoints.get(cI); + int r = repPtsI.size(); + double sumJ = 0.0; + // compute the term density (eqn 6) + for (VectorWritable pt : repPtsI) { + // compute f(x, vIJ) (eqn 7) + Vector repJ = pt.get(); + double densityIJ = measure.distance(cluster.getCenter(), repJ) <= stdev ? 1.0 : 0.0; + // accumulate sumJ + sumJ += densityIJ / stdev; + } + densities.set(cI, sumJ / r); + } + return densities; + } + + /** + * Calculate and cache the distances between the clusters' closest representative points. Also cache the indices of + * the closest representative points for later use + * + * @return a Map<Integer,Map<Integer,Double>> of the closest distances, keyed by clusterId + */ + private Map<Integer,Map<Integer,Double>> minimumDistances() { + if (minimumDistances != null) { + return minimumDistances; + } + minimumDistances = new TreeMap<>(); + closestRepPointIndices = new TreeMap<>(); + for (int i = 0; i < clusters.size(); i++) { + Integer cI = clusters.get(i).getId(); + Map<Integer,Double> map = new TreeMap<>(); + Map<Integer,int[]> treeMap = new TreeMap<>(); + closestRepPointIndices.put(cI, treeMap); + minimumDistances.put(cI, map); + List<VectorWritable> closRepI = representativePoints.get(cI); + for (int j = i + 1; j < clusters.size(); j++) { + // find min{d(closRepI, closRepJ)} + Integer cJ = clusters.get(j).getId(); + List<VectorWritable> closRepJ = representativePoints.get(cJ); + double minDistance = Double.MAX_VALUE; + int[] midPointIndices = null; + for (int xI = 0; xI < closRepI.size(); xI++) { + VectorWritable aRepI = closRepI.get(xI); + for (int xJ = 0; xJ < closRepJ.size(); xJ++) { + VectorWritable aRepJ = closRepJ.get(xJ); + double distance = measure.distance(aRepI.get(), aRepJ.get()); + if (distance < minDistance) { + minDistance = distance; + midPointIndices = new int[] {xI, xJ}; + } + } + } + map.put(cJ, minDistance); + treeMap.put(cJ, midPointIndices); + } + } + return minimumDistances; + } + + private double minimumDistance(int cI, int cJ) { + Map<Integer,Double> distances = minimumDistances().get(cI); + if (distances != null) { + return distances.get(cJ); + } else { + return minimumDistances().get(cJ).get(cI); + } + } + + private Vector midpointVector(int cI, int cJ) { + Map<Integer,Double> distances = minimumDistances().get(cI); + if (distances != null) { + int[] ks = closestRepPointIndices.get(cI).get(cJ); + if (ks == null) { + return null; + } + return representativePoints.get(cI).get(ks[0]).get().plus(representativePoints.get(cJ).get(ks[1]).get()) + .divide(2); + } else { + int[] ks = closestRepPointIndices.get(cJ).get(cI); + if (ks == null) { + return null; + } + return representativePoints.get(cJ).get(ks[1]).get().plus(representativePoints.get(cI).get(ks[0]).get()) + .divide(2); + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java ---------------------------------------------------------------------- diff --git
a/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java new file mode 100644 index 0000000..6a2b376 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.conversion; + +import java.io.IOException; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class converts text files containing space-delimited floating point numbers into + * Mahout sequence files of VectorWritable, suitable as input to the clustering jobs in + * particular, and to any Mahout job requiring this input in general.
+ * + */ + public final class InputDriver { + + private static final Logger log = LoggerFactory.getLogger(InputDriver.class); + + private InputDriver() { + } + + public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option inputOpt = DefaultOptionCreator.inputOption().withRequired(false).create(); + Option outputOpt = DefaultOptionCreator.outputOption().withRequired(false).create(); + Option vectorOpt = obuilder.withLongName("vector").withRequired(false).withArgument( + abuilder.withName("v").withMinimum(1).withMaximum(1).create()).withDescription( + "The vector implementation to use.").withShortName("v").create(); + + Option helpOpt = DefaultOptionCreator.helpOption(); + + Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption( + vectorOpt).withOption(helpOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return; + } + + Path input = new Path(cmdLine.getValue(inputOpt, "testdata").toString()); + Path output = new Path(cmdLine.getValue(outputOpt, "output").toString()); + String vectorClassName = cmdLine.getValue(vectorOpt, + "org.apache.mahout.math.RandomAccessSparseVector").toString(); + runJob(input, output, vectorClassName); + } catch (OptionException e) { + log.error("Exception parsing command line: ", e); + CommandLineUtil.printHelp(group); + } + } + + public static void runJob(Path input, Path output, String vectorClassName) + throws IOException, InterruptedException, ClassNotFoundException { + Configuration conf = new Configuration(); + conf.set("vector.implementation.class.name", vectorClassName); + Job job = new Job(conf, "Input Driver running over input: " + input); + + job.setOutputKeyClass(Text.class); + job.setOutputValueClass(VectorWritable.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(InputMapper.class); + job.setNumReduceTasks(0); + job.setJarByClass(InputDriver.class); + + FileInputFormat.addInputPath(job, input); + FileOutputFormat.setOutputPath(job, output); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java new file mode 100644 index 0000000..e4c72c6 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.conversion; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.regex.Pattern; + +public class InputMapper extends Mapper<LongWritable, Text, Text, VectorWritable> { + + private static final Pattern SPACE = Pattern.compile(" "); + + private Constructor<?> constructor; + + @Override + protected void map(LongWritable key, Text values, Context context) throws IOException, InterruptedException { + + String[] numbers = SPACE.split(values.toString()); + // sometimes there are multiple separator spaces + Collection<Double> doubles = new ArrayList<>(); + for (String value : numbers) { + if (!value.isEmpty()) { + doubles.add(Double.valueOf(value)); + } + } + // ignore empty lines in data file + if (!doubles.isEmpty()) { + try { + Vector result = (Vector) constructor.newInstance(doubles.size()); + int index = 0; + for (Double d : doubles) { + result.set(index++, d); + } + VectorWritable vectorWritable = new VectorWritable(result); + context.write(new Text(String.valueOf(index)), vectorWritable); + + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + String vectorImplClassName = conf.get("vector.implementation.class.name"); + try { + Class<? extends Vector> outputClass = conf.getClassByName(vectorImplClassName).asSubclass(Vector.class); + constructor = outputClass.getConstructor(int.class); + } catch (NoSuchMethodException | ClassNotFoundException e) { + throw new IllegalStateException(e); + } + } + +}
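The CDbwEvaluator added above can be exercised without running a Hadoop job through its public "For testing only" constructor. The following is a minimal sketch, not part of the commit: the toy vectors are invented, and it assumes that Kluster (one concrete Cluster implementation) and EuclideanDistanceMeasure are available from mahout-mr with the constructors used here.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.cdbw.CDbwEvaluator;
import org.apache.mahout.clustering.kmeans.Kluster;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;

public class CDbwEvaluatorSketch {
  public static void main(String[] args) {
    DistanceMeasure measure = new EuclideanDistanceMeasure();

    // Representative points for two toy clusters, keyed by cluster id.
    // Including each center among its representatives keeps the toy
    // intra-cluster density (eqns 5-7) away from the degenerate zero case.
    Map<Integer,List<VectorWritable>> repPoints = new HashMap<>();
    repPoints.put(0, Arrays.asList(
        new VectorWritable(new DenseVector(new double[] {0.0, 0.0})),
        new VectorWritable(new DenseVector(new double[] {0.05, 0.05})),
        new VectorWritable(new DenseVector(new double[] {0.1, 0.1}))));
    repPoints.put(1, Arrays.asList(
        new VectorWritable(new DenseVector(new double[] {5.0, 5.0})),
        new VectorWritable(new DenseVector(new double[] {5.05, 5.05})),
        new VectorWritable(new DenseVector(new double[] {5.1, 5.1}))));

    // Kluster is assumed here as the concrete Cluster implementation.
    List<Cluster> clusters = Arrays.<Cluster>asList(
        new Kluster(new DenseVector(new double[] {0.05, 0.05}), 0, measure),
        new Kluster(new DenseVector(new double[] {5.05, 5.05}), 1, measure));

    CDbwEvaluator evaluator = new CDbwEvaluator(repPoints, clusters, measure);

    // getCDbw() (eqn 8) is intraClusterDensity() (eqn 5) times separation() (eqn 4).
    System.out.println("intra-cluster density = " + evaluator.intraClusterDensity());
    System.out.println("separation            = " + evaluator.separation());
    System.out.println("CDbw                  = " + evaluator.getCDbw());
  }
}

In a real evaluation the same inputs come from RepresentativePointsDriver output, consumed through the CDbwEvaluator(Configuration, Path) constructor instead.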
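InputDriver and InputMapper are easiest to read end to end: runJob() wires a map-only job whose mapper parses each space-delimited text line into a Vector of the configured class and writes it as a VectorWritable into a sequence file. A minimal sketch, with the input and output paths chosen purely for illustration (they mirror the defaults in main()):

import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.conversion.InputDriver;

public class InputConversionSketch {
  public static void main(String[] args) throws Exception {
    // testdata/points.txt might contain lines such as:
    //   1.0 2.0 3.0
    //   4.0  5.0 6.0
    // Repeated separator spaces and empty lines are tolerated by InputMapper.
    InputDriver.runJob(new Path("testdata"),
                       new Path("output/vectors"),
                       "org.apache.mahout.math.RandomAccessSparseVector");
    // The result under output/vectors is a SequenceFile<Text, VectorWritable>,
    // ready for the clustering jobs.
  }
}

Any Vector implementation can be substituted for RandomAccessSparseVector: InputMapper.setup() only requires that the named class expose an int (cardinality) constructor, which it looks up reflectively via "vector.implementation.class.name".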