Repository: spark Updated Branches: refs/heads/master 23a9448c0 -> 633aaae0a
[SPARK-6530] [ML] Add chi-square selector for ml package See JIRA [here](https://issues.apache.org/jira/browse/SPARK-6530). Author: Xusen Yin <yinxu...@gmail.com> Closes #5742 from yinxusen/SPARK-6530. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/633aaae0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/633aaae0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/633aaae0 Branch: refs/heads/master Commit: 633aaae0a1e31e9ba634423840e350b22342c6b5 Parents: 23a9448 Author: Xusen Yin <yinxu...@gmail.com> Authored: Fri Oct 2 10:25:58 2015 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Fri Oct 2 10:25:58 2015 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/feature/ChiSqSelector.scala | 150 +++++++++++++++++++ .../spark/mllib/feature/ChiSqSelector.scala | 2 + .../spark/ml/feature/ChiSqSelectorSuite.scala | 61 ++++++++ 3 files changed, 213 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/633aaae0/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala new file mode 100644 index 0000000..5e4061f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.attribute.{AttributeGroup, _} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} + +/** + * Params for [[ChiSqSelector]] and [[ChiSqSelectorModel]]. + */ +private[feature] trait ChiSqSelectorParams extends Params + with HasFeaturesCol with HasOutputCol with HasLabelCol { + + /** + * Number of features that selector will select (ordered by statistic value descending). If the + * number of features is < numTopFeatures, then this will select all features. The default value + * of numTopFeatures is 50. + * @group param + */ + final val numTopFeatures = new IntParam(this, "numTopFeatures", + "Number of features that selector will select, ordered by statistics value descending. If the" + + " number of features is < numTopFeatures, then this will select all features.", + ParamValidators.gtEq(1)) + setDefault(numTopFeatures -> 50) + + /** @group getParam */ + def getNumTopFeatures: Int = $(numTopFeatures) +} + +/** + * :: Experimental :: + * Chi-Squared feature selection, which selects categorical features to use for predicting a + * categorical label. + */ +@Experimental +final class ChiSqSelector(override val uid: String) + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams { + + def this() = this(Identifiable.randomUID("chiSqSelector")) + + /** @group setParam */ + def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def fit(dataset: DataFrame): ChiSqSelectorModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(labelCol), $(featuresCol)).map { + case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input) + copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) + } + + override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[ChiSqSelector]]. + */ +@Experimental +final class ChiSqSelectorModel private[ml] ( + override val uid: String, + private val chiSqSelector: feature.ChiSqSelectorModel) + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams { + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def transform(dataset: DataFrame): DataFrame = { + val transformedSchema = transformSchema(dataset.schema, logging = true) + val newField = transformedSchema.last + val selector = udf { chiSqSelector.transform _ } + dataset.withColumn($(outputCol), selector(col($(featuresCol))), newField.metadata) + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + val newField = prepOutputField(schema) + val outputFields = schema.fields :+ newField + StructType(outputFields) + } + + /** + * Prepare the output column field, including per-feature metadata. + */ + private def prepOutputField(schema: StructType): StructField = { + val selector = chiSqSelector.selectedFeatures.toSet + val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol))) + val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { + origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1) + } else { + Array.fill[Attribute](selector.size)(NominalAttribute.defaultAttr) + } + val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) + newAttributeGroup.toStructField() + } + + override def copy(extra: ParamMap): ChiSqSelectorModel = { + val copied = new ChiSqSelectorModel(uid, chiSqSelector) + copyValues(copied, extra).setParent(parent) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/633aaae0/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 4743cfd..b1524cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -109,6 +109,8 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Creates a ChiSquared feature selector. * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) + * Note that if the number of features is < numTopFeatures, then this will + * select all features. */ @Since("1.3.0") @Experimental http://git-wip-us.apache.org/repos/asf/spark/blob/633aaae0/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala new file mode 100644 index 0000000..e5a4296 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { + test("Test Chi-Square selector") { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + val data = Seq( + LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), + LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), + LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) + ) + + val preFilteredData = Seq( + Vectors.dense(0.0), + Vectors.dense(6.0), + Vectors.dense(8.0), + Vectors.dense(5.0) + ) + + val df = sc.parallelize(data.zip(preFilteredData)) + .map(x => (x._1.label, x._1.features, x._2)) + .toDF("label", "data", "preFilteredData") + + val model = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("data") + .setLabelCol("label") + .setOutputCol("filtered") + + model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org