Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/21493#discussion_r192909750
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
---
@@ -182,66 +137,59 @@ class PowerIterationClustering private[clustering] (
/** @group setParam */
@Since("2.4.0")
- def setIdCol(value: String): this.type = set(idCol, value)
+ def setSrcCol(value: String): this.type = set(srcCol, value)
/** @group setParam */
@Since("2.4.0")
- def setNeighborsCol(value: String): this.type = set(neighborsCol, value)
+ def setDstCol(value: String): this.type = set(dstCol, value)
/** @group setParam */
@Since("2.4.0")
- def setSimilaritiesCol(value: String): this.type = set(similaritiesCol,
value)
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+ /**
+ * @param dataset A dataset with columns src, dst, weight representing
the affinity matrix,
+ * which is the matrix A in the PIC paper. Suppose the
src column value is i,
+ * the dst column value is j, and the weight column value is
the similarity s,,ij,,,
+ * which must be nonnegative. This is a symmetric matrix
and hence s,,ij,, = s,,ji,,.
+ * For any (i, j) with nonzero similarity, there should
be either
+ * (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows
with i = j are ignored,
+ * because we assume s,,ij,, = 0.0.
+ * @return A dataset that contains columns of vertex id and the
corresponding cluster for the id.
+ * Its schema will be:
+ * - id: Long
+ * - cluster: Int
+ */
@Since("2.4.0")
- override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ def assignClusters(dataset: Dataset[_]): DataFrame = {
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+ lit(1.0)
+ } else {
+ col($(weightCol)).cast(DoubleType)
+ }
- val sparkSession = dataset.sparkSession
- val idColValue = $(idCol)
- val rdd: RDD[(Long, Long, Double)] =
- dataset.select(
- col($(idCol)).cast(LongType),
- col($(neighborsCol)).cast(ArrayType(LongType, containsNull =
false)),
- col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull =
false))
- ).rdd.flatMap {
- case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) =>
- require(nbrs.size == sims.size, s"The length of the neighbor ID
list must be " +
- s"equal to the length of the neighbor similarity list.
Row for ID " +
- s"$idColValue=$id has neighbor ID list of length
${nbrs.length} but similarity list " +
- s"of length ${sims.length}.")
-
nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map {
- case (nbr, similarity) => (id, nbr, similarity)
- }
- }
+ SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol),
Seq(IntegerType, LongType))
+ SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol),
Seq(IntegerType, LongType))
+ val rdd: RDD[(Long, Long, Double)] = dataset.select(
+ col($(srcCol)).cast(LongType),
+ col($(dstCol)).cast(LongType),
+ w).rdd.map {
+ case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight)
+ }
val algorithm = new MLlibPowerIterationClustering()
.setK($(k))
.setInitializationMode($(initMode))
.setMaxIterations($(maxIter))
val model = algorithm.run(rdd)
- val predictionsRDD: RDD[Row] = model.assignments.map { assignment =>
+ val assignmentsRDD: RDD[Row] = model.assignments.map { assignment =>
--- End diff --
`model.assignments.toDF` should work.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]