[
https://issues.apache.org/jira/browse/SPARK-34045?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
zhengruifeng reassigned SPARK-34045:
------------------------------------
Assignee: zhengruifeng
> OneVsRestModel.transform should not call setter of submodels
> ------------------------------------------------------------
>
> Key: SPARK-34045
> URL: https://issues.apache.org/jira/browse/SPARK-34045
> Project: Spark
> Issue Type: Improvement
> Components: ML
> Affects Versions: 3.2.0
> Reporter: zhengruifeng
> Assignee: zhengruifeng
> Priority: Minor
>
> The featuresCol of the submodels may be changed as a side effect of transform:
> {code:java}
> scala> val df = spark.read.format("libsvm").load("/d0/Dev/Opensource/spark/data/mllib/sample_multiclass_classification_data.txt")
> 21/01/08 09:52:01 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
> df: org.apache.spark.sql.DataFrame = [label: double, features: vector]
> scala> val lr = new LogisticRegression().setMaxIter(1).setTol(1E-6).setFitIntercept(true)
> lr: org.apache.spark.ml.classification.LogisticRegression = logreg_3003cb3321a1
> scala> val ovr = new OneVsRest().setClassifier(lr)
> ovr: org.apache.spark.ml.classification.OneVsRest = oneVsRest_b2ec3ec45dbf
> scala> val ovrm = ovr.fit(df)
> 21/01/08 09:52:05 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
> 21/01/08 09:52:05 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
> ovrm: org.apache.spark.ml.classification.OneVsRestModel = OneVsRestModel: uid=oneVsRest_b2ec3ec45dbf, classifier=logreg_3003cb3321a1, numClasses=3, numFeatures=4
> scala> val df2 = df.withColumnRenamed("features", "features2")
> df2: org.apache.spark.sql.DataFrame = [label: double, features2: vector]
> scala> ovrm.setFeaturesCol("features2")
> res0: ovrm.type = OneVsRestModel: uid=oneVsRest_b2ec3ec45dbf, classifier=logreg_3003cb3321a1, numClasses=3, numFeatures=4
> scala> ovrm.models.map(_.getFeaturesCol)
> res1: Array[String] = Array(features, features, features)
> scala> ovrm.transform(df2)
> res2: org.apache.spark.sql.DataFrame = [label: double, features2: vector ... 2 more fields]
> scala> ovrm.models.map(_.getFeaturesCol)
> res3: Array[String] = Array(features2, features2, features2)
> {code}
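
The side effect in the session above can be illustrated without Spark. The following is a minimal sketch, not the Spark implementation: `SubModel`, `ParentModel`, `transformMutating`, and `transformWithCopy` are hypothetical stand-ins that only mimic a model with a mutable featuresCol. It contrasts calling the setter on the shared submodels with calling it on a per-call copy, which is one possible way to keep transform free of side effects.

```scala
// Hypothetical stand-ins, NOT Spark classes. They only mimic a model
// that carries a mutable featuresCol setting.
object CopyBeforeSetSketch {

  final class SubModel(private var featuresCol: String) {
    def getFeaturesCol: String = featuresCol
    def setFeaturesCol(value: String): this.type = { featuresCol = value; this }
    // Independent copy carrying the current setting
    def copy(): SubModel = new SubModel(featuresCol)
    // Stand-in for transform: reports which column this model would read
    def inputColumn: String = featuresCol
  }

  final class ParentModel(
      val models: Array[SubModel],
      private var featuresCol: String = "features") {

    def setFeaturesCol(value: String): this.type = { featuresCol = value; this }

    // Pattern from the report: the setter is called on the shared submodels,
    // so their state changes as a side effect of the call.
    def transformMutating(): Seq[String] =
      models.toSeq.map(m => m.setFeaturesCol(featuresCol).inputColumn)

    // Alternative: set the column on a throwaway copy and leave the
    // stored submodels untouched.
    def transformWithCopy(): Seq[String] =
      models.toSeq.map(m => m.copy().setFeaturesCol(featuresCol).inputColumn)
  }

  def main(args: Array[String]): Unit = {
    val parent = new ParentModel(Array.fill(3)(new SubModel("features")))
    parent.setFeaturesCol("features2")

    parent.transformWithCopy()
    // prints "features, features, features"
    println(parent.models.map(_.getFeaturesCol).mkString(", "))

    parent.transformMutating()
    // prints "features2, features2, features2"
    println(parent.models.map(_.getFeaturesCol).mkString(", "))
  }
}
```

Both variants read from the column set on the parent, but only the copying variant leaves `models` reporting the original column afterwards, which is the behaviour the issue title asks for.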