Repository: mahout Updated Branches: refs/heads/master d38edb1cd -> bd1f7bdab
MAHOUT-1799: Read null row vectors from file in TextDelimitedReaderWriter driver, this closes apache/mahout#182 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/bd1f7bda Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/bd1f7bda Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/bd1f7bda Branch: refs/heads/master Commit: bd1f7bdabceeaaaffd3e7d9c372d40b3b714afc8 Parents: d38edb1 Author: smarthi <[email protected]> Authored: Sat May 21 20:05:41 2016 -0400 Committer: smarthi <[email protected]> Committed: Sat May 21 20:05:41 2016 -0400 ---------------------------------------------------------------------- .../mahout/math/indexeddataset/Schema.scala | 13 ++--- .../drivers/TextDelimitedReaderWriter.scala | 40 +++++++++------ .../TextDelimitedReaderWriterSuite.scala | 53 ++++++++++++++++++++ 3 files changed, 84 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/bd1f7bda/math-scala/src/main/scala/org/apache/mahout/math/indexeddataset/Schema.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/indexeddataset/Schema.scala b/math-scala/src/main/scala/org/apache/mahout/math/indexeddataset/Schema.scala index 3b4a2e9..b7f120b 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/indexeddataset/Schema.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/indexeddataset/Schema.scala @@ -46,7 +46,7 @@ class Schema(params: Tuple2[String, Any]*) extends HashMap[String, Any] { * This tells the reader to input elements of the default (rowID<comma, tab, or space>columnID * <comma, tab, or space>here may be other ignored text...) 
*/ -final object DefaultIndexedDatasetElementReadSchema extends Schema( +object DefaultIndexedDatasetElementReadSchema extends Schema( "delim" -> "[,\t ]", //comma, tab or space "filter" -> "", "rowIDColumn" -> 0, @@ -59,7 +59,7 @@ final object DefaultIndexedDatasetElementReadSchema extends Schema( * The default form: * (rowID<tab>columnID1:score1<space>columnID2:score2...) */ -final object DefaultIndexedDatasetWriteSchema extends Schema( +object DefaultIndexedDatasetWriteSchema extends Schema( "rowKeyDelim" -> "\t", "columnIdStrengthDelim" -> ":", "elementDelim" -> " ", @@ -70,10 +70,11 @@ final object DefaultIndexedDatasetWriteSchema extends Schema( * row-wise input. This tells the reader to input text lines of the form: * (rowID<tab>columnID1:score1,columnID2:score2,...) */ -final object DefaultIndexedDatasetReadSchema extends Schema( +object DefaultIndexedDatasetReadSchema extends Schema( "rowKeyDelim" -> "\t", "columnIdStrengthDelim" -> ":", - "elementDelim" -> " ") + "elementDelim" -> " ", + "omitScore" -> false) /** * Default Schema for reading a text delimited [[org.apache.mahout.math.indexeddataset.IndexedDataset]] file where @@ -84,7 +85,7 @@ final object DefaultIndexedDatasetReadSchema extends Schema( * (rowID<tab>columnID1<space>columnID2 ...) where presence indicates a score of 1. This is the default * output format for [[IndexedDatasetWriteBooleanSchema]] */ -final object IndexedDatasetReadBooleanSchema extends Schema( +object IndexedDatasetReadBooleanSchema extends Schema( "rowKeyDelim" -> "\t", "columnIdStrengthDelim" -> ":", "elementDelim" -> " ", @@ -96,7 +97,7 @@ final object IndexedDatasetReadBooleanSchema extends Schema( * [[org.apache.mahout.math.indexeddataset.IndexedDataset]] row of the form * (rowID<tab>columnID1<space>columnID2...) 
*/ -final object IndexedDatasetWriteBooleanSchema extends Schema( +object IndexedDatasetWriteBooleanSchema extends Schema( "rowKeyDelim" -> "\t", "columnIdStrengthDelim" -> ":", "elementDelim" -> " ", http://git-wip-us.apache.org/repos/asf/mahout/blob/bd1f7bda/spark/src/main/scala/org/apache/mahout/drivers/TextDelimitedReaderWriter.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/drivers/TextDelimitedReaderWriter.scala b/spark/src/main/scala/org/apache/mahout/drivers/TextDelimitedReaderWriter.scala index d4f1aea..93d977b 100644 --- a/spark/src/main/scala/org/apache/mahout/drivers/TextDelimitedReaderWriter.scala +++ b/spark/src/main/scala/org/apache/mahout/drivers/TextDelimitedReaderWriter.scala @@ -149,7 +149,8 @@ trait TDIndexedDatasetReader extends Reader[IndexedDatasetSpark]{ // get row and column IDs val interactions = rows.map { row => - row(0) -> row(1)// rowID token -> string of column IDs+strengths + // rowID token -> string of column IDs+strengths or null if empty (all elements zero) + row(0) -> (if (row.length > 1) row(1) else null) } interactions.cache() @@ -159,12 +160,15 @@ trait TDIndexedDatasetReader extends Reader[IndexedDatasetSpark]{ // the columns are in a TD string so separate them and get unique ones val columnIDs = interactions.flatMap { case (_, columns) => columns - val elements = columns.split(elementDelim) - val colIDs = if (!omitScore) - elements.map( elem => elem.split(columnIdStrengthDelim)(0) ) - else - elements - colIDs + if (columns == null) None + else { + val elements = columns.split(elementDelim) + val colIDs = if (!omitScore) + elements.map(elem => elem.split(columnIdStrengthDelim)(0)) + else + elements + colIDs + } }.distinct().collect() // create BiMaps for bi-directional lookup of ID by either Mahout ID or external ID @@ -186,17 +190,21 @@ trait TDIndexedDatasetReader extends Reader[IndexedDatasetSpark]{ interactions.map { case (rowID, columns) => 
val rowIndex = rowIDDictionary_bcast.value.getOrElse(rowID, -1) - val elements = columns.split(elementDelim) val row = new RandomAccessSparseVector(ncol) - for (element <- elements) { - val id = if (omitScore) element else element.split(columnIdStrengthDelim)(0) - val columnID = columnIDDictionary_bcast.value.getOrElse(id, -1) - val strength = if (omitScore) 1.0d else {// if the input says not to omit but there is no seperator treat - // as omitting and return a strength of 1 - if (element.split(columnIdStrengthDelim).size == 1) 1.0d - else element.split(columnIdStrengthDelim)(1).toDouble + if (columns != null) { + val elements = columns.split(elementDelim) + for (element <- elements) { + val id = if (omitScore) element else element.split(columnIdStrengthDelim)(0) + val columnID = columnIDDictionary_bcast.value.getOrElse(id, -1) + val strength = if (omitScore) 1.0d + else { + // if the input says not to omit but there is no seperator treat + // as omitting and return a strength of 1 + if (element.split(columnIdStrengthDelim).size == 1) 1.0d + else element.split(columnIdStrengthDelim)(1).toDouble + } + row.setQuick(columnID, strength) } - row.setQuick(columnID, strength) } rowIndex -> row } http://git-wip-us.apache.org/repos/asf/mahout/blob/bd1f7bda/spark/src/test/scala/org/apache/mahout/drivers/TextDelimitedReaderWriterSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/drivers/TextDelimitedReaderWriterSuite.scala b/spark/src/test/scala/org/apache/mahout/drivers/TextDelimitedReaderWriterSuite.scala new file mode 100644 index 0000000..5d92cca --- /dev/null +++ b/spark/src/test/scala/org/apache/mahout/drivers/TextDelimitedReaderWriterSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.drivers + +import org.apache.mahout.math.indexeddataset.DefaultIndexedDatasetReadSchema +import org.apache.mahout.sparkbindings._ +import org.apache.mahout.sparkbindings.test.DistributedSparkSuite +import org.scalatest.FunSuite + +import scala.collection.JavaConversions._ + +class TextDelimitedReaderWriterSuite extends FunSuite with DistributedSparkSuite { + test("indexedDatasetDFSRead should read sparse matrix file with null rows") { + val OutFile = TmpDir + "similarity-matrices/part-00000" + + val lines = Array( + "galaxy\tnexus:1.0", + "ipad\tiphone:2.0", + "nexus\tgalaxy:3.0", + "iphone\tipad:4.0", + "surface" + ) + val linesRdd = mahoutCtx.parallelize(lines).saveAsTextFile(OutFile) + + val data = mahoutCtx.indexedDatasetDFSRead(OutFile, DefaultIndexedDatasetReadSchema) + + data.rowIDs.toMap.keySet should equal(Set("galaxy", "ipad", "nexus", "iphone", "surface")) + data.columnIDs.toMap.keySet should equal(Set("nexus", "iphone", "galaxy", "ipad")) + + val a = data.matrix.collect + a.setRowLabelBindings(mapAsJavaMap(data.rowIDs.toMap).asInstanceOf[java.util.Map[java.lang.String, java.lang.Integer]]) + a.setColumnLabelBindings(mapAsJavaMap(data.columnIDs.toMap).asInstanceOf[java.util.Map[java.lang.String, java.lang.Integer]]) + a.get("galaxy", "nexus") should equal(1.0) + a.get("ipad", "iphone") should equal(2.0) + a.get("nexus", "galaxy") should 
equal(3.0) + a.get("iphone", "ipad") should equal(4.0) + } +}
