http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java b/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java deleted file mode 100644 index b311a5e..0000000 --- a/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCInMemoryItemSimilarity.java +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.impl.similarity.jdbc; - -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; - -import javax.sql.DataSource; - -public class SQL92JDBCInMemoryItemSimilarity extends AbstractJDBCInMemoryItemSimilarity { - - static final String DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL = - "SELECT " + AbstractJDBCItemSimilarity.DEFAULT_ITEM_A_ID_COLUMN + ", " - + AbstractJDBCItemSimilarity.DEFAULT_ITEM_B_ID_COLUMN + ", " - + AbstractJDBCItemSimilarity.DEFAULT_SIMILARITY_COLUMN + " FROM " - + AbstractJDBCItemSimilarity.DEFAULT_SIMILARITY_TABLE; - - - public SQL92JDBCInMemoryItemSimilarity() throws TasteException { - this(AbstractJDBCComponent.lookupDataSource(AbstractJDBCComponent.DEFAULT_DATASOURCE_NAME), - DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL); - } - - public SQL92JDBCInMemoryItemSimilarity(String dataSourceName) throws TasteException { - this(AbstractJDBCComponent.lookupDataSource(dataSourceName), DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL); - } - - public SQL92JDBCInMemoryItemSimilarity(DataSource dataSource) { - this(dataSource, DEFAULT_GET_ALL_ITEMSIMILARITIES_SQL); - } - - public SQL92JDBCInMemoryItemSimilarity(DataSource dataSource, String getAllItemSimilaritiesSQL) { - super(dataSource, getAllItemSimilaritiesSQL); - } - -}
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java b/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java deleted file mode 100644 index f449561..0000000 --- a/integration/src/main/java/org/apache/mahout/cf/taste/impl/similarity/jdbc/SQL92JDBCItemSimilarity.java +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.impl.similarity.jdbc; - -import org.apache.mahout.cf.taste.common.TasteException; - -import javax.sql.DataSource; - -public class SQL92JDBCItemSimilarity extends AbstractJDBCItemSimilarity { - - public SQL92JDBCItemSimilarity() throws TasteException { - this(DEFAULT_DATASOURCE_NAME); - } - - public SQL92JDBCItemSimilarity(String dataSourceName) throws TasteException { - this(lookupDataSource(dataSourceName)); - } - - public SQL92JDBCItemSimilarity(DataSource dataSource) { - this(dataSource, - DEFAULT_SIMILARITY_TABLE, - DEFAULT_ITEM_A_ID_COLUMN, - DEFAULT_ITEM_B_ID_COLUMN, - DEFAULT_SIMILARITY_COLUMN); - } - - public SQL92JDBCItemSimilarity(DataSource dataSource, - String similarityTable, - String itemAIDColumn, - String itemBIDColumn, - String similarityColumn) { - super(dataSource, - similarityTable, - itemAIDColumn, - itemBIDColumn, similarityColumn, - "SELECT " + similarityColumn + " FROM " + similarityTable + " WHERE " - + itemAIDColumn + "=? AND " + itemBIDColumn + "=?", - "SELECT " + itemAIDColumn + ", " + itemBIDColumn + " FROM " + similarityTable + " WHERE " - + itemAIDColumn + "=? OR " + itemBIDColumn + "=?"); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java b/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java deleted file mode 100644 index a5a89c6..0000000 --- a/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderServlet.java +++ /dev/null @@ -1,215 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.web; - -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.model.Preference; -import org.apache.mahout.cf.taste.model.PreferenceArray; -import org.apache.mahout.cf.taste.recommender.RecommendedItem; -import org.apache.mahout.cf.taste.recommender.Recommender; - -import javax.servlet.ServletConfig; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.List; - -/** - * <p>A servlet which returns recommendations, as its name implies. The servlet accepts GET and POST - * HTTP requests, and looks for two parameters:</p> - * - * <ul> - * <li><em>userID</em>: the user ID for which to produce recommendations</li> - * <li><em>howMany</em>: the number of recommendations to produce</li> - * <li><em>debug</em>: (optional) output a lot of information that is useful in debugging. - * Defaults to false, of course.</li> - * </ul> - * - * <p>The response is text, and contains a list of the IDs of recommended items, in descending - * order of relevance, one per line.</p> - * - * <p>For example, you can get 10 recommendations for user 123 from the following URL (assuming - * you are running taste in a web application running locally on port 8080):<br/> - * {@code http://localhost:8080/taste/RecommenderServlet?userID=123&howMany=10}</p> - * - * <p>This servlet requires one {@code init-param} in {@code web.xml}: it must find - * a parameter named "recommender-class" which is the name of a class that implements - * {@link Recommender} and has a no-arg constructor. The servlet will instantiate and use - * this {@link Recommender} to produce recommendations.</p> - */ -public final class RecommenderServlet extends HttpServlet { - - private static final int NUM_TOP_PREFERENCES = 20; - private static final int DEFAULT_HOW_MANY = 20; - - private Recommender recommender; - - @Override - public void init(ServletConfig config) throws ServletException { - super.init(config); - String recommenderClassName = config.getInitParameter("recommender-class"); - if (recommenderClassName == null) { - throw new ServletException("Servlet init-param \"recommender-class\" is not defined"); - } - RecommenderSingleton.initializeIfNeeded(recommenderClassName); - recommender = RecommenderSingleton.getInstance().getRecommender(); - } - - @Override - public void doGet(HttpServletRequest request, - HttpServletResponse response) throws ServletException { - - String userIDString = request.getParameter("userID"); - if (userIDString == null) { - throw new ServletException("userID was not specified"); - } - long userID = Long.parseLong(userIDString); - String howManyString = request.getParameter("howMany"); - int howMany = howManyString == null ? DEFAULT_HOW_MANY : Integer.parseInt(howManyString); - boolean debug = Boolean.parseBoolean(request.getParameter("debug")); - String format = request.getParameter("format"); - if (format == null) { - format = "text"; - } - - try { - List<RecommendedItem> items = recommender.recommend(userID, howMany); - if ("text".equals(format)) { - writePlainText(response, userID, debug, items); - } else if ("xml".equals(format)) { - writeXML(response, items); - } else if ("json".equals(format)) { - writeJSON(response, items); - } else { - throw new ServletException("Bad format parameter: " + format); - } - } catch (TasteException | IOException te) { - throw new ServletException(te); - } - - } - - private static void writeXML(HttpServletResponse response, Iterable<RecommendedItem> items) throws IOException { - response.setContentType("application/xml"); - response.setCharacterEncoding("UTF-8"); - response.setHeader("Cache-Control", "no-cache"); - PrintWriter writer = response.getWriter(); - writer.print("<?xml version=\"1.0\" encoding=\"UTF-8\"?><recommendedItems>"); - for (RecommendedItem recommendedItem : items) { - writer.print("<item><value>"); - writer.print(recommendedItem.getValue()); - writer.print("</value><id>"); - writer.print(recommendedItem.getItemID()); - writer.print("</id></item>"); - } - writer.println("</recommendedItems>"); - } - - private static void writeJSON(HttpServletResponse response, Iterable<RecommendedItem> items) throws IOException { - response.setContentType("application/json"); - response.setCharacterEncoding("UTF-8"); - response.setHeader("Cache-Control", "no-cache"); - PrintWriter writer = response.getWriter(); - writer.print("{\"recommendedItems\":{\"item\":["); - boolean first = true; - for (RecommendedItem recommendedItem : items) { - if (first) { - first = false; - } else { - writer.print(','); - } - writer.print("{\"value\":\""); - writer.print(recommendedItem.getValue()); - writer.print("\",\"id\":\""); - writer.print(recommendedItem.getItemID()); - writer.print("\"}"); - } - writer.println("]}}"); - } - - private void writePlainText(HttpServletResponse response, - long userID, - boolean debug, - Iterable<RecommendedItem> items) throws IOException, TasteException { - response.setContentType("text/plain"); - response.setCharacterEncoding("UTF-8"); - response.setHeader("Cache-Control", "no-cache"); - PrintWriter writer = response.getWriter(); - if (debug) { - writeDebugRecommendations(userID, items, writer); - } else { - writeRecommendations(items, writer); - } - } - - private static void writeRecommendations(Iterable<RecommendedItem> items, PrintWriter writer) { - for (RecommendedItem recommendedItem : items) { - writer.print(recommendedItem.getValue()); - writer.print('\t'); - writer.println(recommendedItem.getItemID()); - } - } - - private void writeDebugRecommendations(long userID, Iterable<RecommendedItem> items, PrintWriter writer) - throws TasteException { - DataModel dataModel = recommender.getDataModel(); - writer.print("User:"); - writer.println(userID); - writer.print("Recommender: "); - writer.println(recommender); - writer.println(); - writer.print("Top "); - writer.print(NUM_TOP_PREFERENCES); - writer.println(" Preferences:"); - PreferenceArray rawPrefs = dataModel.getPreferencesFromUser(userID); - int length = rawPrefs.length(); - PreferenceArray sortedPrefs = rawPrefs.clone(); - sortedPrefs.sortByValueReversed(); - // Cap this at NUM_TOP_PREFERENCES just to be brief - int max = Math.min(NUM_TOP_PREFERENCES, length); - for (int i = 0; i < max; i++) { - Preference pref = sortedPrefs.get(i); - writer.print(pref.getValue()); - writer.print('\t'); - writer.println(pref.getItemID()); - } - writer.println(); - writer.println("Recommendations:"); - for (RecommendedItem recommendedItem : items) { - writer.print(recommendedItem.getValue()); - writer.print('\t'); - writer.println(recommendedItem.getItemID()); - } - } - - @Override - public void doPost(HttpServletRequest request, - HttpServletResponse response) throws ServletException { - doGet(request, response); - } - - @Override - public String toString() { - return "RecommenderServlet[recommender:" + recommender + ']'; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java b/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java deleted file mode 100644 index 265d7c0..0000000 --- a/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderSingleton.java +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.web; - -import org.apache.mahout.cf.taste.recommender.Recommender; -import org.apache.mahout.common.ClassUtils; - -/** - * <p>A singleton which holds an instance of a {@link Recommender}. This is used to share - * a {@link Recommender} between {@link RecommenderServlet} and {@code RecommenderService.jws}.</p> - */ -public final class RecommenderSingleton { - - private final Recommender recommender; - - private static RecommenderSingleton instance; - - public static synchronized RecommenderSingleton getInstance() { - if (instance == null) { - throw new IllegalStateException("Not initialized"); - } - return instance; - } - - public static synchronized void initializeIfNeeded(String recommenderClassName) { - if (instance == null) { - instance = new RecommenderSingleton(recommenderClassName); - } - } - - private RecommenderSingleton(String recommenderClassName) { - if (recommenderClassName == null) { - throw new IllegalArgumentException("Recommender class name is null"); - } - recommender = ClassUtils.instantiateAs(recommenderClassName, Recommender.class); - } - - public Recommender getRecommender() { - return recommender; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java b/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java deleted file mode 100644 index e927098..0000000 --- a/integration/src/main/java/org/apache/mahout/cf/taste/web/RecommenderWrapper.java +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.web; - -import com.google.common.io.Files; -import com.google.common.io.InputSupplier; -import com.google.common.io.Resources; -import org.apache.mahout.cf.taste.common.Refreshable; -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.recommender.IDRescorer; -import org.apache.mahout.cf.taste.recommender.RecommendedItem; -import org.apache.mahout.cf.taste.recommender.Recommender; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.net.URL; -import java.util.Collection; -import java.util.List; - -/** - * Users of the packaging and deployment mechanism in this module need - * to produce a {@link Recommender} implementation with a no-arg constructor, - * which will internally build the desired {@link Recommender} and delegate - * to it. This wrapper simplifies that process. Simply extend this class and - * implement {@link #buildRecommender()}. - */ -public abstract class RecommenderWrapper implements Recommender { - - private static final Logger log = LoggerFactory.getLogger(RecommenderWrapper.class); - - private final Recommender delegate; - - protected RecommenderWrapper() throws TasteException, IOException { - this.delegate = buildRecommender(); - } - - /** - * @return the {@link Recommender} which should be used to produce recommendations - * by this wrapper implementation - */ - protected abstract Recommender buildRecommender() throws IOException, TasteException; - - @Override - public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException { - return delegate.recommend(userID, howMany); - } - - @Override - public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException { - return delegate.recommend(userID, howMany, rescorer); - } - - @Override - public float estimatePreference(long userID, long itemID) throws TasteException { - return delegate.estimatePreference(userID, itemID); - } - - @Override - public void setPreference(long userID, long itemID, float value) throws TasteException { - delegate.setPreference(userID, itemID, value); - } - - @Override - public void removePreference(long userID, long itemID) throws TasteException { - delegate.removePreference(userID, itemID); - } - - @Override - public DataModel getDataModel() { - return delegate.getDataModel(); - } - - @Override - public void refresh(Collection<Refreshable> alreadyRefreshed) { - delegate.refresh(alreadyRefreshed); - } - - /** - * Reads the given resource into a temporary file. This is intended to be used - * to read data files which are stored as a resource available on the classpath, - * such as in a JAR file. However for convenience the resource name will also - * be interpreted as a relative path to a local file, if no such resource is - * found. This facilitates testing. - * - * @param resourceName name of resource in classpath, or relative path to file - * @return temporary {@link File} with resource data - * @throws IOException if an error occurs while reading or writing data - */ - public static File readResourceToTempFile(String resourceName) throws IOException { - String absoluteResource = resourceName.startsWith("/") ? resourceName : '/' + resourceName; - log.info("Loading resource {}", absoluteResource); - InputSupplier<? extends InputStream> inSupplier; - try { - URL resourceURL = Resources.getResource(RecommenderWrapper.class, absoluteResource); - inSupplier = Resources.newInputStreamSupplier(resourceURL); - } catch (IllegalArgumentException iae) { - File resourceFile = new File(resourceName); - log.info("Falling back to load file {}", resourceFile.getAbsolutePath()); - inSupplier = Files.newInputStreamSupplier(resourceFile); - } - File tempFile = File.createTempFile("taste", null); - tempFile.deleteOnExit(); - Files.copy(inSupplier, tempFile); - return tempFile; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java b/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java deleted file mode 100644 index 03a3000..0000000 --- a/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java +++ /dev/null @@ -1,425 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier; - -import com.google.common.collect.Lists; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.MatrixWritable; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.io.PrintStream; -import java.util.Iterator; -import java.util.List; -import java.util.Map; - -/** - * Export a ConfusionMatrix in various text formats: ToString version Grayscale HTML table Summary HTML table - * Table of counts all with optional HTML wrappers - * - * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair - * - * Intended to consume ConfusionMatrix SequenceFile output by Bayes TestClassifier class - */ -public final class ConfusionMatrixDumper extends AbstractJob { - - private static final String TAB_SEPARATOR = "|"; - - // HTML wrapper - default CSS - private static final String HEADER = "<html>" - + "<head>\n" - + "<title>TITLE</title>\n" - + "</head>" - + "<body>\n" - + "<style type='text/css'> \n" - + "table\n" - + "{\n" - + "border:3px solid black; text-align:left;\n" - + "}\n" - + "th.normalHeader\n" - + "{\n" - + "border:1px solid black;border-collapse:collapse;text-align:center;" - + "background-color:white\n" - + "}\n" - + "th.tallHeader\n" - + "{\n" - + "border:1px solid black;border-collapse:collapse;text-align:center;" - + "background-color:white; height:6em\n" - + "}\n" - + "tr.label\n" - + "{\n" - + "border:1px solid black;border-collapse:collapse;text-align:center;" - + "background-color:white\n" - + "}\n" - + "tr.row\n" - + "{\n" - + "border:1px solid gray;text-align:center;background-color:snow\n" - + "}\n" - + "td\n" - + "{\n" - + "min-width:2em\n" - + "}\n" - + "td.cell\n" - + "{\n" - + "border:1px solid black;text-align:right;background-color:snow\n" - + "}\n" - + "td.empty\n" - + "{\n" - + "border:0px;text-align:right;background-color:snow\n" - + "}\n" - + "td.white\n" - + "{\n" - + "border:0px solid black;text-align:right;background-color:white\n" - + "}\n" - + "td.black\n" - + "{\n" - + "border:0px solid red;text-align:right;background-color:black\n" - + "}\n" - + "td.gray1\n" - + "{\n" - + "border:0px solid green;text-align:right; background-color:LightGray\n" - + "}\n" + "td.gray2\n" + "{\n" - + "border:0px solid blue;text-align:right;background-color:gray\n" - + "}\n" + "td.gray3\n" + "{\n" - + "border:0px solid red;text-align:right;background-color:DarkGray\n" - + "}\n" + "th" + "{\n" + " text-align: center;\n" - + " vertical-align: bottom;\n" - + " padding-bottom: 3px;\n" + " padding-left: 5px;\n" - + " padding-right: 5px;\n" + "}\n" + " .verticalText\n" - + " {\n" + " text-align: center;\n" - + " vertical-align: middle;\n" + " width: 20px;\n" - + " margin: 0px;\n" + " padding: 0px;\n" - + " padding-left: 3px;\n" + " padding-right: 3px;\n" - + " padding-top: 10px;\n" + " white-space: nowrap;\n" - + " -webkit-transform: rotate(-90deg); \n" - + " -moz-transform: rotate(-90deg); \n" + " };\n" - + "</style>\n"; - private static final String FOOTER = "</html></body>"; - - // CSS style names. - private static final String CSS_TABLE = "table"; - private static final String CSS_LABEL = "label"; - private static final String CSS_TALL_HEADER = "tall"; - private static final String CSS_VERTICAL = "verticalText"; - private static final String CSS_CELL = "cell"; - private static final String CSS_EMPTY = "empty"; - private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"}; - - private ConfusionMatrixDumper() {} - - public static void main(String[] args) throws Exception { - ToolRunner.run(new ConfusionMatrixDumper(), args); - } - - @Override - public int run(String[] args) throws IOException { - addInputOption(); - addOption("output", "o", "Output path", null); // AbstractJob output feature requires param - addOption(DefaultOptionCreator.overwriteOption().create()); - addFlag("html", null, "Create complete HTML page"); - addFlag("text", null, "Dump simple text"); - Map<String,List<String>> parsedArgs = parseArguments(args); - if (parsedArgs == null) { - return -1; - } - - Path inputPath = getInputPath(); - String outputFile = hasOption("output") ? getOption("output") : null; - boolean text = parsedArgs.containsKey("--text"); - boolean wrapHtml = parsedArgs.containsKey("--html"); - PrintStream out = getPrintStream(outputFile); - if (text) { - exportText(inputPath, out); - } else { - exportTable(inputPath, out, wrapHtml); - } - out.flush(); - if (out != System.out) { - out.close(); - } - return 0; - } - - private static void exportText(Path inputPath, PrintStream out) throws IOException { - MatrixWritable mw = new MatrixWritable(); - Text key = new Text(); - readSeqFile(inputPath, key, mw); - Matrix m = mw.get(); - ConfusionMatrix cm = new ConfusionMatrix(m); - out.println(String.format("%-40s", "Label") + TAB_SEPARATOR + String.format("%-10s", "Total") - + TAB_SEPARATOR + String.format("%-10s", "Correct") + TAB_SEPARATOR - + String.format("%-6s", "%") + TAB_SEPARATOR); - out.println(String.format("%-70s", "-").replace(' ', '-')); - List<String> labels = stripDefault(cm); - for (String label : labels) { - int correct = cm.getCorrect(label); - double accuracy = cm.getAccuracy(label); - int count = getCount(cm, label); - out.println(String.format("%-40s", label) + TAB_SEPARATOR + String.format("%-10s", count) - + TAB_SEPARATOR + String.format("%-10s", correct) + TAB_SEPARATOR - + String.format("%-6s", (int) Math.round(accuracy)) + TAB_SEPARATOR); - } - out.println(String.format("%-70s", "-").replace(' ', '-')); - out.println(cm.toString()); - } - - private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException { - MatrixWritable mw = new MatrixWritable(); - Text key = new Text(); - readSeqFile(inputPath, key, mw); - String fileName = inputPath.getName(); - fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length()); - Matrix m = mw.get(); - ConfusionMatrix cm = new ConfusionMatrix(m); - if (wrapHtml) { - printHeader(out, fileName); - } - out.println("<p/>"); - printSummaryTable(cm, out); - out.println("<p/>"); - printGrayTable(cm, out); - out.println("<p/>"); - printCountsTable(cm, out); - out.println("<p/>"); - printTextInBox(cm, out); - out.println("<p/>"); - if (wrapHtml) { - printFooter(out); - } - } - - private static List<String> stripDefault(ConfusionMatrix cm) { - List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); - String defaultLabel = cm.getDefaultLabel(); - int unclassified = cm.getTotal(defaultLabel); - if (unclassified > 0) { - return stripped; - } - stripped.remove(defaultLabel); - return stripped; - } - - // TODO: test - this should work with HDFS files - private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(conf); - SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf); - reader.next(key, m); - } - - // TODO: test - this might not work with HDFS files? - // after all, it does no seeks - private static PrintStream getPrintStream(String outputFilename) throws IOException { - if (outputFilename != null) { - File outputFile = new File(outputFilename); - if (outputFile.exists()) { - outputFile.delete(); - } - outputFile.createNewFile(); - OutputStream os = new FileOutputStream(outputFile); - return new PrintStream(os, false, Charsets.UTF_8.displayName()); - } else { - return System.out; - } - } - - private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) { - Iterator<String> iter = cm.getLabels().iterator(); - int count = 0; - while (iter.hasNext()) { - count += cm.getCount(rowLabel, iter.next()); - } - return count; - } - - // HTML generator code - - private static void printTextInBox(ConfusionMatrix cm, PrintStream out) { - out.println("<div style='width:90%;overflow:scroll;'>"); - out.println("<pre>"); - out.println(cm.toString()); - out.println("</pre>"); - out.println("</div>"); - } - - public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) { - format("<table class='%s'>\n", out, CSS_TABLE); - format("<tr class='%s'>", out, CSS_LABEL); - out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>"); - out.println("</tr>"); - List<String> labels = stripDefault(cm); - for (String label : labels) { - printSummaryRow(cm, out, label); - } - out.println("</table>"); - } - - private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) { - format("<tr class='%s'>", out, CSS_CELL); - int correct = cm.getCorrect(label); - double accuracy = cm.getAccuracy(label); - int count = getCount(cm, label); - format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", out, CSS_CELL, label, count, correct, - (int) Math.round(accuracy)); - out.println("</tr>"); - } - - private static int getCount(ConfusionMatrix cm, String label) { - int count = 0; - for (String s : cm.getLabels()) { - count += cm.getCount(label, s); - } - return count; - } - - public static void printGrayTable(ConfusionMatrix cm, PrintStream out) { - format("<table class='%s'>\n", out, CSS_TABLE); - printCountsHeader(cm, out, true); - printGrayRows(cm, out); - out.println("</table>"); - } - - /** - * Print each value in a four-value grayscale based on count/max. Gives a mostly white matrix with grays in - * misclassified, and black in diagonal. TODO: Using the sqrt(count/max) as the rating is more stringent - */ - private static void printGrayRows(ConfusionMatrix cm, PrintStream out) { - List<String> labels = stripDefault(cm); - for (String label : labels) { - printGrayRow(cm, out, labels, label); - } - } - - private static void printGrayRow(ConfusionMatrix cm, - PrintStream out, - Iterable<String> labels, - String rowLabel) { - format("<tr class='%s'>", out, CSS_LABEL); - format("<td>%s</td>", out, rowLabel); - int total = getLabelTotal(cm, rowLabel); - for (String columnLabel : labels) { - printGrayCell(cm, out, total, rowLabel, columnLabel); - } - out.println("</tr>"); - } - - // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs - // assign black to count = total, meaning complete success - // alternative rating is to use sqrt(total) instead of total - this is more drastic - private static void printGrayCell(ConfusionMatrix cm, - PrintStream out, - int total, - String rowLabel, - String columnLabel) { - - int count = cm.getCount(rowLabel, columnLabel); - if (count == 0) { - out.format("<td class='%s'/>", CSS_EMPTY); - } else { - // 0 is white, full is black, everything else gray - int rating = (int) ((count / (double) total) * 4); - String css = CSS_GRAY_CELLS[rating]; - format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count); - } - } - - public static void printCountsTable(ConfusionMatrix cm, PrintStream out) { - format("<table class='%s'>\n", out, CSS_TABLE); - printCountsHeader(cm, out, false); - printCountsRows(cm, out); - out.println("</table>"); - } - - private static void printCountsRows(ConfusionMatrix cm, PrintStream out) { - List<String> labels = stripDefault(cm); - for (String label : labels) { - printCountsRow(cm, out, labels, label); - } - } - - private static void printCountsRow(ConfusionMatrix cm, - PrintStream out, - Iterable<String> labels, - String rowLabel) { - out.println("<tr>"); - format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel); - for (String columnLabel : labels) { - printCountsCell(cm, out, rowLabel, columnLabel); - } - out.println("</tr>"); - } - - private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) { - int count = cm.getCount(rowLabel, columnLabel); - String s = count == 0 ? "" : Integer.toString(count); - format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, s); - } - - private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) { - List<String> labels = stripDefault(cm); - int longest = getLongestHeader(labels); - if (vertical) { - // do vertical - rotation is a bitch - out.format("<tr class='%s' style='height:%dem'><th> </th>%n", CSS_TALL_HEADER, longest / 2); - for (String label : labels) { - out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label); - } - out.println("</tr>"); - } else { - // header - empty cell in upper left - out.format("<tr class='%s'><td class='%s'></td>%n", CSS_TABLE, CSS_LABEL); - for (String label : labels) { - out.format("<td>%s</td>", label); - } - out.format("</tr>"); - } - } - - private static int getLongestHeader(Iterable<String> labels) { - int max = 0; - for (String label : labels) { - max = Math.max(label.length(), max); - } - return max; - } - - private static void format(String format, PrintStream out, Object... args) { - String format2 = String.format(format, args); - out.println(format2); - } - - public static void printHeader(PrintStream out, CharSequence title) { - out.println(HEADER.replace("TITLE", title)); - } - - public static void printFooter(PrintStream out) { - out.println(FOOTER); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java b/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java deleted file mode 100644 index 545c1ff..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java +++ /dev/null @@ -1,387 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.cdbw; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.GaussianAccumulator; -import org.apache.mahout.clustering.OnlineGaussianAccumulator; -import org.apache.mahout.clustering.evaluation.RepresentativePointsDriver; -import org.apache.mahout.clustering.evaluation.RepresentativePointsMapper; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.Vector.Element; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; - -/** - * This class calculates the CDbw metric as defined in - * http://www.db-net.aueb.gr/index.php/corporate/content/download/227/833/file/HV_poster2002.pdf - */ -public final class CDbwEvaluator { - - private static final Logger log = LoggerFactory.getLogger(CDbwEvaluator.class); - - private final Map<Integer,List<VectorWritable>> representativePoints; - private final Map<Integer,Double> stDevs = new HashMap<>(); - private final List<Cluster> clusters; - private final DistanceMeasure measure; - private Double interClusterDensity = null; - // these are symmetric so we only compute half of them - private Map<Integer,Map<Integer,Double>> minimumDistances = null; - // these are symmetric too - private Map<Integer,Map<Integer,Double>> interClusterDensities = null; - // these are symmetric too - private Map<Integer,Map<Integer,int[]>> closestRepPointIndices = null; - - /** - * For testing only - * - * @param representativePoints - * a Map<Integer,List<VectorWritable>> of representative points keyed by clusterId - * @param clusters - * a Map<Integer,Cluster> of the clusters keyed by clusterId - * @param measure - * an appropriate DistanceMeasure - */ - public CDbwEvaluator(Map<Integer,List<VectorWritable>> representativePoints, List<Cluster> clusters, - DistanceMeasure measure) { - this.representativePoints = representativePoints; - this.clusters = clusters; - this.measure = measure; - for (Integer cId : representativePoints.keySet()) { - computeStd(cId); - } - } - - /** - * Initialize a new instance from job information - * - * @param conf - * a Configuration with appropriate parameters - * @param clustersIn - * a String path to the input clusters directory - */ - public CDbwEvaluator(Configuration conf, Path clustersIn) { - measure = ClassUtils - .instantiateAs(conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY), DistanceMeasure.class); - representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf); - clusters = loadClusters(conf, clustersIn); - for (Integer cId : representativePoints.keySet()) { - computeStd(cId); - } - } - - /** - * Load the clusters from their sequence files - * - * @param clustersIn - * a String pathname to the directory containing input cluster files - * @return a List<Cluster> of the clusters - */ - private static List<Cluster> loadClusters(Configuration conf, Path clustersIn) { - List<Cluster> clusters = new ArrayList<>(); - for (ClusterWritable clusterWritable : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST, - PathFilters.logsCRCFilter(), conf)) { - Cluster cluster = clusterWritable.getValue(); - clusters.add(cluster); - } - return clusters; - } - - /** - * Compute the standard deviation of the representative points for the given cluster. Store these in stDevs, indexed - * by cI - * - * @param cI - * a int clusterId. - */ - private void computeStd(int cI) { - List<VectorWritable> repPts = representativePoints.get(cI); - GaussianAccumulator accumulator = new OnlineGaussianAccumulator(); - for (VectorWritable vw : repPts) { - accumulator.observe(vw.get(), 1.0); - } - accumulator.compute(); - double d = accumulator.getAverageStd(); - stDevs.put(cI, d); - } - - /** - * Compute the density of points near the midpoint between the two closest points of the clusters (eqn 2) used for - * inter-cluster density calculation - * - * @param uIJ - * the Vector midpoint between the closest representative points of the clusters - * @param cI - * the int clusterId of the i-th cluster - * @param cJ - * the int clusterId of the j-th cluster - * @param avgStd - * the double average standard deviation of the two clusters - * @return a double - */ - private double density(Vector uIJ, int cI, int cJ, double avgStd) { - List<VectorWritable> repI = representativePoints.get(cI); - List<VectorWritable> repJ = representativePoints.get(cJ); - double sum = 0.0; - // count the number of representative points of the clusters which are within the - // average std of the two clusters from the midpoint uIJ (eqn 3) - for (VectorWritable vwI : repI) { - if (uIJ != null && measure.distance(uIJ, vwI.get()) <= avgStd) { - sum++; - } - } - for (VectorWritable vwJ : repJ) { - if (uIJ != null && measure.distance(uIJ, vwJ.get()) <= avgStd) { - sum++; - } - } - int nI = repI.size(); - int nJ = repJ.size(); - return sum / (nI + nJ); - } - - /** - * Compute the CDbw validity metric (eqn 8). The goal of this metric is to reward clusterings which have a high - * intraClusterDensity and also a high cluster separation. - * - * @return a double - */ - public double getCDbw() { - return intraClusterDensity() * separation(); - } - - /** - * The average density within clusters is defined as the percentage of representative points that reside in the - * neighborhood of the clusters' centers. The goal is the density within clusters to be significantly high. (eqn 5) - * - * @return a double - */ - public double intraClusterDensity() { - double avgDensity = 0; - int count = 0; - for (Element elem : intraClusterDensities().nonZeroes()) { - double value = elem.get(); - if (!Double.isNaN(value)) { - avgDensity += value; - count++; - } - } - return avgDensity / count; - } - - /** - * This function evaluates the density of points in the regions between each clusters (eqn 1). The goal is the density - * in the area between clusters to be significant low. - * - * @return a Map<Integer,Map<Integer,Double>> of the inter-cluster densities - */ - public Map<Integer,Map<Integer,Double>> interClusterDensities() { - if (interClusterDensities != null) { - return interClusterDensities; - } - interClusterDensities = new TreeMap<>(); - // find the closest representative points between the clusters - for (int i = 0; i < clusters.size(); i++) { - int cI = clusters.get(i).getId(); - Map<Integer,Double> map = new TreeMap<>(); - interClusterDensities.put(cI, map); - for (int j = i + 1; j < clusters.size(); j++) { - int cJ = clusters.get(j).getId(); - double minDistance = minimumDistance(cI, cJ); // the distance between the closest representative points - Vector uIJ = midpointVector(cI, cJ); // the midpoint between the closest representative points - double stdSum = stDevs.get(cI) + stDevs.get(cJ); - double density = density(uIJ, cI, cJ, stdSum / 2); - double interDensity = minDistance * density / stdSum; - map.put(cJ, interDensity); - if (log.isDebugEnabled()) { - log.debug("minDistance[{},{}]={}", cI, cJ, minDistance); - log.debug("interDensity[{},{}]={}", cI, cJ, density); - log.debug("density[{},{}]={}", cI, cJ, interDensity); - } - } - } - return interClusterDensities; - } - - /** - * Calculate the separation of clusters (eqn 4) taking into account both the distances between the clusters' closest - * points and the Inter-cluster density. The goal is the distances between clusters to be high while the - * representative point density in the areas between them are low. - * - * @return a double - */ - public double separation() { - double minDistanceSum = 0; - Map<Integer,Map<Integer,Double>> distances = minimumDistances(); - for (Map<Integer,Double> map : distances.values()) { - for (Double dist : map.values()) { - if (!Double.isInfinite(dist)) { - minDistanceSum += dist * 2; // account for other half of calculated triangular minimumDistances matrix - } - } - } - return minDistanceSum / (1.0 + interClusterDensity()); - } - - /** - * This function evaluates the average density of points in the regions between clusters (eqn 1). The goal is the - * density in the area between clusters to be significant low. - * - * @return a double - */ - public double interClusterDensity() { - if (interClusterDensity != null) { - return interClusterDensity; - } - double sum = 0.0; - int count = 0; - Map<Integer,Map<Integer,Double>> distances = interClusterDensities(); - for (Map<Integer,Double> row : distances.values()) { - for (Double density : row.values()) { - if (!Double.isNaN(density)) { - sum += density; - count++; - } - } - } - log.debug("interClusterDensity={}", sum); - interClusterDensity = sum / count; - return interClusterDensity; - } - - /** - * The average density within clusters is defined as the percentage of representative points that reside in the - * neighborhood of the clusters' centers. The goal is the density within clusters to be significantly high. (eqn 5) - * - * @return a Vector of the intra-densities of each clusterId - */ - public Vector intraClusterDensities() { - Vector densities = new RandomAccessSparseVector(Integer.MAX_VALUE); - // compute the average standard deviation of the clusters - double stdev = 0.0; - for (Integer cI : representativePoints.keySet()) { - stdev += stDevs.get(cI); - } - int c = representativePoints.size(); - stdev /= c; - for (Cluster cluster : clusters) { - Integer cI = cluster.getId(); - List<VectorWritable> repPtsI = representativePoints.get(cI); - int r = repPtsI.size(); - double sumJ = 0.0; - // compute the term density (eqn 6) - for (VectorWritable pt : repPtsI) { - // compute f(x, vIJ) (eqn 7) - Vector repJ = pt.get(); - double densityIJ = measure.distance(cluster.getCenter(), repJ) <= stdev ? 1.0 : 0.0; - // accumulate sumJ - sumJ += densityIJ / stdev; - } - densities.set(cI, sumJ / r); - } - return densities; - } - - /** - * Calculate and cache the distances between the clusters' closest representative points. Also cache the indices of - * the closest representative points used for later use - * - * @return a Map<Integer,Vector> of the closest distances, keyed by clusterId - */ - private Map<Integer,Map<Integer,Double>> minimumDistances() { - if (minimumDistances != null) { - return minimumDistances; - } - minimumDistances = new TreeMap<>(); - closestRepPointIndices = new TreeMap<>(); - for (int i = 0; i < clusters.size(); i++) { - Integer cI = clusters.get(i).getId(); - Map<Integer,Double> map = new TreeMap<>(); - Map<Integer,int[]> treeMap = new TreeMap<>(); - closestRepPointIndices.put(cI, treeMap); - minimumDistances.put(cI, map); - List<VectorWritable> closRepI = representativePoints.get(cI); - for (int j = i + 1; j < clusters.size(); j++) { - // find min{d(closRepI, closRepJ)} - Integer cJ = clusters.get(j).getId(); - List<VectorWritable> closRepJ = representativePoints.get(cJ); - double minDistance = Double.MAX_VALUE; - int[] midPointIndices = null; - for (int xI = 0; xI < closRepI.size(); xI++) { - VectorWritable aRepI = closRepI.get(xI); - for (int xJ = 0; xJ < closRepJ.size(); xJ++) { - VectorWritable aRepJ = closRepJ.get(xJ); - double distance = measure.distance(aRepI.get(), aRepJ.get()); - if (distance < minDistance) { - minDistance = distance; - midPointIndices = new int[] {xI, xJ}; - } - } - } - map.put(cJ, minDistance); - treeMap.put(cJ, midPointIndices); - } - } - return minimumDistances; - } - - private double minimumDistance(int cI, int cJ) { - Map<Integer,Double> distances = minimumDistances().get(cI); - if (distances != null) { - return distances.get(cJ); - } else { - return minimumDistances().get(cJ).get(cI); - } - } - - private Vector midpointVector(int cI, int cJ) { - Map<Integer,Double> distances = minimumDistances().get(cI); - if (distances != null) { - int[] ks = closestRepPointIndices.get(cI).get(cJ); - if (ks == null) { - return null; - } - return representativePoints.get(cI).get(ks[0]).get().plus(representativePoints.get(cJ).get(ks[1]).get()) - .divide(2); - } else { - int[] ks = closestRepPointIndices.get(cJ).get(cI); - if (ks == null) { - return null; - } - return representativePoints.get(cJ).get(ks[1]).get().plus(representativePoints.get(cI).get(ks[0]).get()) - .divide(2); - } - - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java b/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java deleted file mode 100644 index 6a2b376..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.conversion; - -import java.io.IOException; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.OptionException; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Job; -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; -import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; -import org.apache.mahout.common.CommandLineUtil; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This class converts text files containing space-delimited floating point numbers into - * Mahout sequence files of VectorWritable suitable for input to the clustering jobs in - * particular, and any Mahout job requiring this input in general. - * - */ -public final class InputDriver { - - private static final Logger log = LoggerFactory.getLogger(InputDriver.class); - - private InputDriver() { - } - - public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException { - DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); - ArgumentBuilder abuilder = new ArgumentBuilder(); - GroupBuilder gbuilder = new GroupBuilder(); - - Option inputOpt = DefaultOptionCreator.inputOption().withRequired(false).create(); - Option outputOpt = DefaultOptionCreator.outputOption().withRequired(false).create(); - Option vectorOpt = obuilder.withLongName("vector").withRequired(false).withArgument( - abuilder.withName("v").withMinimum(1).withMaximum(1).create()).withDescription( - "The vector implementation to use.").withShortName("v").create(); - - Option helpOpt = DefaultOptionCreator.helpOption(); - - Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption( - vectorOpt).withOption(helpOpt).create(); - - try { - Parser parser = new Parser(); - parser.setGroup(group); - CommandLine cmdLine = parser.parse(args); - if (cmdLine.hasOption(helpOpt)) { - CommandLineUtil.printHelp(group); - return; - } - - Path input = new Path(cmdLine.getValue(inputOpt, "testdata").toString()); - Path output = new Path(cmdLine.getValue(outputOpt, "output").toString()); - String vectorClassName = cmdLine.getValue(vectorOpt, - "org.apache.mahout.math.RandomAccessSparseVector").toString(); - runJob(input, output, vectorClassName); - } catch (OptionException e) { - log.error("Exception parsing command line: ", e); - CommandLineUtil.printHelp(group); - } - } - - public static void runJob(Path input, Path output, String vectorClassName) - throws IOException, InterruptedException, ClassNotFoundException { - Configuration conf = new Configuration(); - conf.set("vector.implementation.class.name", vectorClassName); - Job job = new Job(conf, "Input Driver running over input: " + input); - - job.setOutputKeyClass(Text.class); - job.setOutputValueClass(VectorWritable.class); - job.setOutputFormatClass(SequenceFileOutputFormat.class); - job.setMapperClass(InputMapper.class); - job.setNumReduceTasks(0); - job.setJarByClass(InputDriver.class); - - FileInputFormat.addInputPath(job, input); - FileOutputFormat.setOutputPath(job, output); - - boolean succeeded = job.waitForCompletion(true); - if (!succeeded) { - throw new IllegalStateException("Job failed!"); - } - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java b/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java deleted file mode 100644 index e4c72c6..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.conversion; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.VectorWritable; - -import java.io.IOException; -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.regex.Pattern; - -public class InputMapper extends Mapper<LongWritable, Text, Text, VectorWritable> { - - private static final Pattern SPACE = Pattern.compile(" "); - - private Constructor<?> constructor; - - @Override - protected void map(LongWritable key, Text values, Context context) throws IOException, InterruptedException { - - String[] numbers = SPACE.split(values.toString()); - // sometimes there are multiple separator spaces - Collection<Double> doubles = new ArrayList<>(); - for (String value : numbers) { - if (!value.isEmpty()) { - doubles.add(Double.valueOf(value)); - } - } - // ignore empty lines in data file - if (!doubles.isEmpty()) { - try { - Vector result = (Vector) constructor.newInstance(doubles.size()); - int index = 0; - for (Double d : doubles) { - result.set(index++, d); - } - VectorWritable vectorWritable = new VectorWritable(result); - context.write(new Text(String.valueOf(index)), vectorWritable); - - } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new IllegalStateException(e); - } - } - } - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - super.setup(context); - Configuration conf = context.getConfiguration(); - String vectorImplClassName = conf.get("vector.implementation.class.name"); - try { - Class<? extends Vector> outputClass = conf.getClassByName(vectorImplClassName).asSubclass(Vector.class); - constructor = outputClass.getConstructor(int.class); - } catch (NoSuchMethodException | ClassNotFoundException e) { - throw new IllegalStateException(e); - } - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java b/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java deleted file mode 100644 index 757f38c..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.evaluation; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.Vector.Element; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; - -public class ClusterEvaluator { - - private static final Logger log = LoggerFactory.getLogger(ClusterEvaluator.class); - - private final Map<Integer,List<VectorWritable>> representativePoints; - - private final List<Cluster> clusters; - - private final DistanceMeasure measure; - - /** - * For testing only - * - * @param representativePoints - * a Map<Integer,List<VectorWritable>> of representative points keyed by clusterId - * @param clusters - * a Map<Integer,Cluster> of the clusters keyed by clusterId - * @param measure - * an appropriate DistanceMeasure - */ - public ClusterEvaluator(Map<Integer,List<VectorWritable>> representativePoints, List<Cluster> clusters, - DistanceMeasure measure) { - this.representativePoints = representativePoints; - this.clusters = clusters; - this.measure = measure; - } - - /** - * Initialize a new instance from job information - * - * @param conf - * a Configuration with appropriate parameters - * @param clustersIn - * a String path to the input clusters directory - */ - public ClusterEvaluator(Configuration conf, Path clustersIn) { - measure = ClassUtils - .instantiateAs(conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY), DistanceMeasure.class); - representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf); - clusters = loadClusters(conf, clustersIn); - } - - /** - * Load the clusters from their sequence files - * - * @param clustersIn - * a String pathname to the directory containing input cluster files - * @return a List<Cluster> of the clusters - */ - private static List<Cluster> loadClusters(Configuration conf, Path clustersIn) { - List<Cluster> clusters = new ArrayList<>(); - for (ClusterWritable clusterWritable : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST, - PathFilters.logsCRCFilter(), conf)) { - Cluster cluster = clusterWritable.getValue(); - clusters.add(cluster); - } - return clusters; - } - - /** - * Computes the inter-cluster density as defined in "Mahout In Action" - * - * @return the interClusterDensity - */ - public double interClusterDensity() { - double max = Double.NEGATIVE_INFINITY; - double min = Double.POSITIVE_INFINITY; - double sum = 0; - int count = 0; - Map<Integer,Vector> distances = interClusterDistances(); - for (Vector row : distances.values()) { - for (Element element : row.nonZeroes()) { - double d = element.get(); - min = Math.min(d, min); - max = Math.max(d, max); - sum += d; - count++; - } - } - double density = (sum / count - min) / (max - min); - log.info("Scaled Inter-Cluster Density = {}", density); - return density; - } - - /** - * Computes the inter-cluster distances - * - * @return a Map<Integer, Vector> - */ - public Map<Integer,Vector> interClusterDistances() { - Map<Integer,Vector> distances = new TreeMap<>(); - for (int i = 0; i < clusters.size(); i++) { - Cluster clusterI = clusters.get(i); - RandomAccessSparseVector row = new RandomAccessSparseVector(Integer.MAX_VALUE); - distances.put(clusterI.getId(), row); - for (int j = i + 1; j < clusters.size(); j++) { - Cluster clusterJ = clusters.get(j); - double d = measure.distance(clusterI.getCenter(), clusterJ.getCenter()); - row.set(clusterJ.getId(), d); - } - } - return distances; - } - - /** - * Computes the average intra-cluster density as the average of each cluster's intra-cluster density - * - * @return the average intraClusterDensity - */ - public double intraClusterDensity() { - double avgDensity = 0; - int count = 0; - for (Element elem : intraClusterDensities().nonZeroes()) { - double value = elem.get(); - if (!Double.isNaN(value)) { - avgDensity += value; - count++; - } - } - avgDensity = clusters.isEmpty() ? 0 : avgDensity / count; - log.info("Average Intra-Cluster Density = {}", avgDensity); - return avgDensity; - } - - /** - * Computes the intra-cluster densities for all clusters as the average distance of the representative points from - * each other - * - * @return a Vector of the intraClusterDensity of the representativePoints by clusterId - */ - public Vector intraClusterDensities() { - Vector densities = new RandomAccessSparseVector(Integer.MAX_VALUE); - for (Cluster cluster : clusters) { - int count = 0; - double max = Double.NEGATIVE_INFINITY; - double min = Double.POSITIVE_INFINITY; - double sum = 0; - List<VectorWritable> repPoints = representativePoints.get(cluster.getId()); - for (int i = 0; i < repPoints.size(); i++) { - for (int j = i + 1; j < repPoints.size(); j++) { - Vector v1 = repPoints.get(i).get(); - Vector v2 = repPoints.get(j).get(); - double d = measure.distance(v1, v2); - min = Math.min(d, min); - max = Math.max(d, max); - sum += d; - count++; - } - } - double density = (sum / count - min) / (max - min); - densities.set(cluster.getId(), density); - log.info("Intra-Cluster Density[{}] = {}", cluster.getId(), density); - } - return densities; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java b/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java deleted file mode 100644 index 2fe37ef..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java +++ /dev/null @@ -1,243 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.evaluation; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileStatus; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.mapreduce.Job; -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; -import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; -import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.clustering.AbstractCluster; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.classify.WeightedVectorWritable; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public final class RepresentativePointsDriver extends AbstractJob { - - public static final String STATE_IN_KEY = "org.apache.mahout.clustering.stateIn"; - - public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.measure"; - - private static final Logger log = LoggerFactory.getLogger(RepresentativePointsDriver.class); - - private RepresentativePointsDriver() {} - - public static void main(String[] args) throws Exception { - ToolRunner.run(new Configuration(), new RepresentativePointsDriver(), args); - } - - @Override - public int run(String[] args) throws ClassNotFoundException, IOException, InterruptedException { - addInputOption(); - addOutputOption(); - addOption("clusteredPoints", "cp", "The path to the clustered points", true); - addOption(DefaultOptionCreator.distanceMeasureOption().create()); - addOption(DefaultOptionCreator.maxIterationsOption().create()); - addOption(DefaultOptionCreator.methodOption().create()); - if (parseArguments(args) == null) { - return -1; - } - - Path input = getInputPath(); - Path output = getOutputPath(); - String distanceMeasureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); - int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); - boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase( - DefaultOptionCreator.SEQUENTIAL_METHOD); - DistanceMeasure measure = ClassUtils.instantiateAs(distanceMeasureClass, DistanceMeasure.class); - Path clusteredPoints = new Path(getOption("clusteredPoints")); - run(getConf(), input, clusteredPoints, output, measure, maxIterations, runSequential); - return 0; - } - - /** - * Utility to print out representative points - * - * @param output - * the Path to the directory containing representativePoints-i folders - * @param numIterations - * the int number of iterations to print - */ - public static void printRepresentativePoints(Path output, int numIterations) { - for (int i = 0; i <= numIterations; i++) { - Path out = new Path(output, "representativePoints-" + i); - System.out.println("Representative Points for iteration " + i); - Configuration conf = new Configuration(); - for (Pair<IntWritable,VectorWritable> record : new SequenceFileDirIterable<IntWritable,VectorWritable>(out, - PathType.LIST, PathFilters.logsCRCFilter(), null, true, conf)) { - System.out.println("\tC-" + record.getFirst().get() + ": " - + AbstractCluster.formatVector(record.getSecond().get(), null)); - } - } - } - - public static void run(Configuration conf, Path clustersIn, Path clusteredPointsIn, Path output, - DistanceMeasure measure, int numIterations, boolean runSequential) throws IOException, InterruptedException, - ClassNotFoundException { - Path stateIn = new Path(output, "representativePoints-0"); - writeInitialState(stateIn, clustersIn); - - for (int iteration = 0; iteration < numIterations; iteration++) { - log.info("Representative Points Iteration {}", iteration); - // point the output to a new directory per iteration - Path stateOut = new Path(output, "representativePoints-" + (iteration + 1)); - runIteration(conf, clusteredPointsIn, stateIn, stateOut, measure, runSequential); - // now point the input to the old output directory - stateIn = stateOut; - } - - conf.set(STATE_IN_KEY, stateIn.toString()); - conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName()); - } - - private static void writeInitialState(Path output, Path clustersIn) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(output.toUri(), conf); - for (FileStatus dir : fs.globStatus(clustersIn)) { - Path inPath = dir.getPath(); - for (FileStatus part : fs.listStatus(inPath, PathFilters.logsCRCFilter())) { - Path inPart = part.getPath(); - Path path = new Path(output, inPart.getName()); - try (SequenceFile.Writer writer = - new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class)){ - for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(inPart, true, conf)) { - Cluster cluster = clusterWritable.getValue(); - if (log.isDebugEnabled()) { - log.debug("C-{}: {}", cluster.getId(), AbstractCluster.formatVector(cluster.getCenter(), null)); - } - writer.append(new IntWritable(cluster.getId()), new VectorWritable(cluster.getCenter())); - } - } - } - } - } - - private static void runIteration(Configuration conf, Path clusteredPointsIn, Path stateIn, Path stateOut, - DistanceMeasure measure, boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException { - if (runSequential) { - runIterationSeq(conf, clusteredPointsIn, stateIn, stateOut, measure); - } else { - runIterationMR(conf, clusteredPointsIn, stateIn, stateOut, measure); - } - } - - /** - * Run the job using supplied arguments as a sequential process - * - * @param conf - * the Configuration to use - * @param clusteredPointsIn - * the directory pathname for input points - * @param stateIn - * the directory pathname for input state - * @param stateOut - * the directory pathname for output state - * @param measure - * the DistanceMeasure to use - */ - private static void runIterationSeq(Configuration conf, Path clusteredPointsIn, Path stateIn, Path stateOut, - DistanceMeasure measure) throws IOException { - - Map<Integer,List<VectorWritable>> repPoints = RepresentativePointsMapper.getRepresentativePoints(conf, stateIn); - Map<Integer,WeightedVectorWritable> mostDistantPoints = new HashMap<>(); - FileSystem fs = FileSystem.get(clusteredPointsIn.toUri(), conf); - for (Pair<IntWritable,WeightedVectorWritable> record - : new SequenceFileDirIterable<IntWritable,WeightedVectorWritable>(clusteredPointsIn, PathType.LIST, - PathFilters.logsCRCFilter(), null, true, conf)) { - RepresentativePointsMapper.mapPoint(record.getFirst(), record.getSecond(), measure, repPoints, mostDistantPoints); - } - int part = 0; - try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(stateOut, "part-m-" + part++), - IntWritable.class, VectorWritable.class)){ - for (Entry<Integer,List<VectorWritable>> entry : repPoints.entrySet()) { - for (VectorWritable vw : entry.getValue()) { - writer.append(new IntWritable(entry.getKey()), vw); - } - } - } - try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(stateOut, "part-m-" + part++), - IntWritable.class, VectorWritable.class)){ - for (Map.Entry<Integer,WeightedVectorWritable> entry : mostDistantPoints.entrySet()) { - writer.append(new IntWritable(entry.getKey()), new VectorWritable(entry.getValue().getVector())); - } - } - } - - /** - * Run the job using supplied arguments as a Map/Reduce process - * - * @param conf - * the Configuration to use - * @param input - * the directory pathname for input points - * @param stateIn - * the directory pathname for input state - * @param stateOut - * the directory pathname for output state - * @param measure - * the DistanceMeasure to use - */ - private static void runIterationMR(Configuration conf, Path input, Path stateIn, Path stateOut, - DistanceMeasure measure) throws IOException, InterruptedException, ClassNotFoundException { - conf.set(STATE_IN_KEY, stateIn.toString()); - conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName()); - Job job = new Job(conf, "Representative Points Driver running over input: " + input); - job.setJarByClass(RepresentativePointsDriver.class); - job.setOutputKeyClass(IntWritable.class); - job.setOutputValueClass(VectorWritable.class); - job.setMapOutputKeyClass(IntWritable.class); - job.setMapOutputValueClass(WeightedVectorWritable.class); - - FileInputFormat.setInputPaths(job, input); - FileOutputFormat.setOutputPath(job, stateOut); - - job.setMapperClass(RepresentativePointsMapper.class); - job.setReducerClass(RepresentativePointsReducer.class); - job.setInputFormatClass(SequenceFileInputFormat.class); - job.setOutputFormatClass(SequenceFileOutputFormat.class); - - boolean succeeded = job.waitForCompletion(true); - if (!succeeded) { - throw new IllegalStateException("Job failed!"); - } - } -}
