[ 
https://issues.apache.org/jira/browse/SPARK-48463?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17852751#comment-17852751
 ] 

someshwar kale commented on SPARK-48463:
----------------------------------------

As a temporary workaround, you can rename the dotted columns before they reach 
the feature transformer by adding a small custom transformer, as shown below:
{code:java}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{StructField, StructType}

/**
 * Pipeline stage that renames one top-level column from `inputCol` to `outputCol`.
 *
 * Intended as a workaround for SPARK-48463: ML feature transformers fail to resolve
 * column names containing a dot (such as "location.longitude" produced by flattening a
 * nested schema). Renaming such a column first (e.g. to "location_longitude") lets the
 * downstream stage use it. The input column is matched by its exact field name, so a
 * dotted name needs no backtick quoting here.
 */
class RenameColumn(val uid: String) extends Transformer with Params
  with HasInputCol with HasOutputCol with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("RenameColumn"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Validates `schema` and returns the schema that [[transform]] will produce.
   *
   * The input field is renamed in place: position, data type, nullability and metadata
   * are kept, which mirrors what `withColumnRenamed` does to the data. (Adding a new
   * field while keeping the old one would advertise a column that no longer exists
   * after the rename.)
   *
   * @param schema schema of the incoming dataset
   * @return `schema` with the input field renamed to the output column name
   * @throws IllegalArgumentException if the input column is missing, or if the output
   *                                  name is already used by a different column
   */
  def validateAndTransformSchema(schema: StructType): StructType = {
    val inputName = getInputCol
    val outputName = getOutputCol
    val fieldNames = schema.fieldNames
    require(fieldNames.contains(inputName),
      s"Input column $inputName does not exist. Available columns: ${fieldNames.mkString(", ")}")
    // Renaming onto an existing column would leave two columns with the same name,
    // which later stages cannot resolve unambiguously.
    require(inputName == outputName || !fieldNames.contains(outputName),
      s"Output column $outputName already exists.")
    StructType(schema.fields.map { field =>
      if (field.name == inputName) field.copy(name = outputName) else field
    })
  }

  override def transformSchema(schema: StructType): StructType =
    validateAndTransformSchema(schema)

  override def copy(extra: ParamMap): RenameColumn = defaultCopy(extra)

  /** Validates the schema, then renames `inputCol` to `outputCol` in `dataset`. */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    dataset.toDF().withColumnRenamed(getInputCol, getOutputCol)
  }
}

/** Companion object exposing `RenameColumn.load(path)` for reading back a saved stage. */
object RenameColumn extends DefaultParamsReadable[RenameColumn] {
  override def load(path: String): RenameColumn = read.load(path)
} {code}
Then use this transformer as the first stage of the pipeline, before the StringIndexer:
{code:java}
// Sample rows: a nested (longitude, latitude) struct followed by a top-level salary.
val structureData = Seq(
  (10, 12, 1000),
  (12, 14, 4300),
  (37, 891, 1400),
  (8902, 12, 4000),
  (12, 89, 1000)
).map { case (longitude, latitude, salary) => Row(Row(longitude, latitude), salary) }

// Schema: location struct<longitude: int, latitude: int>, then salary: int.
val structureSchema = {
  val locationType = new StructType()
    .add("longitude", IntegerType)
    .add("latitude", IntegerType)
  new StructType().add("location", locationType).add("salary", IntegerType)
}

val df = {
  val rows = spark.sparkContext.parallelize(structureData)
  spark.createDataFrame(rows, structureSchema)
}

/**
 * Flattens a (possibly nested) schema into one top-level column per leaf field.
 *
 * Each leaf is selected by its full path and aliased to that path joined with ".",
 * e.g. `location.longitude`. Such dotted aliases are what ML feature transformers
 * later fail to resolve (SPARK-48463), so rename them before using them in a pipeline.
 *
 * @param schema       struct whose leaf fields should be flattened
 * @param prefix       selection path of the enclosing struct (segments already
 *                     backtick-quoted where needed); null at the top level
 * @param prefixSelect alias prefix of the enclosing struct; null at the top level
 * @return one column per leaf field, in schema order
 */
def flattenSchema(
    schema: StructType,
    prefix: String = null,
    prefixSelect: String = null): Array[Column] = {
  // Backtick-quote a single path segment so that a dot inside a field's own name is
  // not parsed by col() as struct nesting; embedded backticks are escaped by doubling.
  def quoted(name: String): String = "`" + name.replace("`", "``") + "`"

  schema.fields.flatMap { f =>
    val colName = Option(prefix).fold(quoted(f.name))(p => s"$p.${quoted(f.name)}")
    // The alias prefix is checked on its own; previously a null prefixSelect paired with
    // a non-null prefix produced an alias starting with "null.".
    val colnameSelect = Option(prefixSelect).fold(f.name)(p => s"$p.${f.name}")

    f.dataType match {
      case st: StructType => flattenSchema(st, colName, colnameSelect)
      case _ => Array(col(colName).as(colnameSelect))
    }
  }
}

val flattenColumns = flattenSchema(df.schema)
val flattenedDf = df.select(flattenColumns: _*)

// Inspect the flattened frame: nested leaves are now top-level columns with dotted names.
flattenedDf.printSchema()
flattenedDf.show()

// Stage 1 renames the dotted column; stage 2 (StringIndexer) then indexes the renamed one.
val renameColumn = new RenameColumn()
  .setInputCol("location.longitude")
  .setOutputCol("location_longitude")
val si = new StringIndexer()
  .setInputCol("location_longitude")
  .setOutputCol("longitutdee")
val pipeline = new Pipeline().setStages(Array(renameColumn, si))

pipeline
  .fit(flattenedDf)
  .transform(flattenedDf)
  .show()

/**
 * +------------------+-----------------+------+-----------+
 * |location_longitude|location.latitude|salary|longitutdee|
 * +------------------+-----------------+------+-----------+
 * |                10|               12|  1000|        1.0|
 * |                12|               14|  4300|        0.0|
 * |                37|              891|  1400|        2.0|
 * |              8902|               12|  4000|        3.0|
 * |                12|               89|  1000|        0.0|
 * +------------------+-----------------+------+-----------+
 */ {code}

> MLLib function unable to handle nested data
> -------------------------------------------
>
>                 Key: SPARK-48463
>                 URL: https://issues.apache.org/jira/browse/SPARK-48463
>             Project: Spark
>          Issue Type: Bug
>          Components: ML, MLlib
>    Affects Versions: 3.5.1
>            Reporter: Chhavi Bansal
>            Priority: Major
>              Labels: ML, MLPipelines, mllib, nested
>
> I am trying to use a feature transformer on nested data after flattening it, 
> but it fails.
>  
> {code:java}
> val structureData = Seq(
>   Row(Row(10, 12), 1000),
>   Row(Row(12, 14), 4300),
>   Row( Row(37, 891), 1400),
>   Row(Row(8902, 12), 4000),
>   Row(Row(12, 89), 1000)
> )
> val structureSchema = new StructType()
>   .add("location", new StructType()
>     .add("longitude", IntegerType)
>     .add("latitude", IntegerType))
>   .add("salary", IntegerType) 
> val df = spark.createDataFrame(spark.sparkContext.parallelize(structureData), 
> structureSchema) 
> def flattenSchema(schema: StructType, prefix: String = null, prefixSelect: 
> String = null):
> Array[Column] = {
>   schema.fields.flatMap(f => {
>     val colName = if (prefix == null) f.name else (prefix + "." + f.name)
>     val colnameSelect = if (prefix == null) f.name else (prefixSelect + "." + 
> f.name)
>     f.dataType match {
>       case st: StructType => flattenSchema(st, colName, colnameSelect)
>       case _ =>
>         Array(col(colName).as(colnameSelect))
>     }
>   })
> }
> val flattenColumns = flattenSchema(df.schema)
> val flattenedDf = df.select(flattenColumns: _*){code}
> Next, I apply a StringIndexer to the flattened column, using its dot-notation name:
>  
> {code:java}
> val si = new 
> StringIndexer().setInputCol("location.longitude").setOutputCol("longitutdee")
> val pipeline = new Pipeline().setStages(Array(si))
> pipeline.fit(flattenedDf).transform(flattenedDf).show() {code}
> The above code fails with:
> {code:java}
> Exception in thread "main" org.apache.spark.sql.AnalysisException: Cannot 
> resolve column name "location.longitude" among (location.longitude, 
> location.latitude, salary); did you mean to quote the `location.longitude` 
> column?
>     at 
> org.apache.spark.sql.errors.QueryCompilationErrors$.cannotResolveColumnNameAmongFieldsError(QueryCompilationErrors.scala:2261)
>     at 
> org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$resolveException(Dataset.scala:258)
>     at org.apache.spark.sql.Dataset.$anonfun$resolve$1(Dataset.scala:250)
> ..... {code}
> This is the same failure that occurs when selecting a column whose name contains 
> a dot in a Spark DataFrame, which is normally solved by quoting the name with backticks: *`column.name`*.
> [https://stackoverflow.com/a/51430335/11688337]
>  
> *So next,*
> I used backticks when defining the StringIndexer:
> {code:java}
> val si = new 
> StringIndexer().setInputCol("`location.longitude`").setOutputCol("longitutdee")
>  {code}
> In this case *it fails again* (for a different reason), this time inside the 
> StringIndexer code itself:
> {code:java}
> Exception in thread "main" org.apache.spark.SparkException: Input column 
> `location.longitude` does not exist.
>     at 
> org.apache.spark.ml.feature.StringIndexerBase.$anonfun$validateAndTransformSchema$2(StringIndexer.scala:128)
>     at 
> scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:244)
>     at 
> scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
>     at 
> scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33) 
> {code}
>  
> This prevents me from using feature transformers on nested columns. 
> Any help in solving this problem would be highly appreciated.



--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org

Reply via email to