Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9267#discussion_r48585214
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala ---
    @@ -53,6 +65,96 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
         val associationRules = new AssociationRules(confidence)
         associationRules.run(freqItemsets)
       }
    +
    +  override def save(sc: SparkContext, path: String): Unit = {
    +    FPGrowthModel.SaveLoadV1_0.save(this, path)
    +  }
    +
    +  override protected val formatVersion: String = "1.0"
    +}
    +
    +object FPGrowthModel extends Loader[FPGrowthModel[_]] {
    +
    +  override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
    +    val inferredItemset = FPGrowthModel.SaveLoadV1_0.inferItemType(sc, 
path)
    +    FPGrowthModel.SaveLoadV1_0.load(sc, path, inferredItemset)
    +  }
    +
    +  private[fpm] object SaveLoadV1_0 {
    +
    +    private val thisFormatVersion = "1.0"
    +
    +    private[fpm] val thisClassName = 
"org.apache.spark.mllib.fpm.FPGrowthModel"
    +
    +    def save[Item: ClassTag: TypeTag](model: FPGrowthModel[Item], path: 
String): Unit = {
    +      val sc = model.freqItemsets.sparkContext
    +      val sqlContext = new SQLContext(sc)
    +
    +      val metadata = compact(render(
    +        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
    +      sc.parallelize(Seq(metadata), 
1).saveAsTextFile(Loader.metadataPath(path))
    +
    +      val itemType = ScalaReflection.schemaFor[Item].dataType
    +      val fields = Array(StructField("items", ArrayType(itemType)),
    +        StructField("freq", LongType))
    +      val schema = StructType(fields)
    +      val rowDataRDD = model.freqItemsets.map { x =>
    +        Row(x.items, x.freq)
    +      }
    +      sqlContext.createDataFrame(rowDataRDD, 
schema).write.parquet(Loader.dataPath(path))
    +    }
    +
    +    def inferItemType(sc: SparkContext, path: String): FreqItemset[_] = {
    +      val sqlContext = new SQLContext(sc)
    +      val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
    +      val itemsetType = freqItemsets.schema(0).dataType
    +      val freqType = freqItemsets.schema(1).dataType
    +      require(itemsetType.isInstanceOf[ArrayType],
    +        s"items should be ArrayType, but get $itemsetType")
    +      require(freqType.isInstanceOf[LongType], s"freq should be LongType, 
but get $freqType")
    +      val itemType = itemsetType.asInstanceOf[ArrayType].elementType
    +      val result = itemType match {
    +        case BooleanType => new FreqItemset(Array[Boolean](), 0L)
    +        case BinaryType => new FreqItemset(Array(Array[Byte]()), 0L)
    +        case StringType => new FreqItemset(Array[String](), 0L)
    +        case ByteType => new FreqItemset(Array[Byte](), 0L)
    +        case ShortType => new FreqItemset(Array[Short](), 0L)
    +        case IntegerType => new FreqItemset(Array[Int](), 0L)
    +        case LongType => new FreqItemset(Array[Long](), 0L)
    +        case FloatType => new FreqItemset(Array[Float](), 0L)
    +        case DoubleType => new FreqItemset(Array[Double](), 0L)
    +        case DateType => new FreqItemset(Array[java.sql.Date](), 0L)
    +        case DecimalType.SYSTEM_DEFAULT => new 
FreqItemset(Array[java.math.BigDecimal](), 0L)
    +        case TimestampType => new FreqItemset(Array[java.sql.Timestamp](), 
0L)
    +        case _: ArrayType => new FreqItemset(Array[Seq[_]](), 0L)
    +        case _: MapType => new FreqItemset(Array[Map[_, _]](), 0L)
    +        case _: StructType => new FreqItemset(Array[Row](), 0L)
    +        case other =>
    +          throw new UnsupportedOperationException(s"Schema for type $other 
is not supported")
    +      }
    +      result
    +    }
    +
    +    def load[Item: ClassTag: TypeTag](
    +        sc: SparkContext,
    +        path: String,
    +        inferredItemset: FreqItemset[Item]): FPGrowthModel[Item] = {
    +      implicit val formats = DefaultFormats
    +      val sqlContext = new SQLContext(sc)
    +
    +      val (className, formatVersion, metadata) = loadMetadata(sc, path)
    +      assert(className == thisClassName)
    +      assert(formatVersion == thisFormatVersion)
    +
    +      val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
    +      val freqItemsetsRDD = freqItemsets.map { x =>
    +        val items = x.getAs[Seq[Item]](0).toArray
    --- End diff --
    
    Are you able to use `getSeq[_]` here? I'm wondering if we can eliminate
`inferItemType`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to