zhengruifeng commented on pull request #31693:
URL: https://github.com/apache/spark/pull/31693#issuecomment-796504039
I used the Scala test case from the ticket:
```
// scalastyle:off println
// Reproduces the binary-logistic-regression (BLR) case from the ticket.
// It fits LogisticRegression twice on the same synthetic data: once on (feature1, feature2),
// and once with an extra nearly-constant column (90% 1.0, 10% 0.9) appended.
// It then prints coefficients, the intercept, the objective history and sample predictions,
// so the runs with and without feature centering can be compared.
test("BLR") {
  import org.apache.spark.ml.feature.VectorAssembler
  val centered = false
  val regParam = 1.0e-8
  val num_distribution_samplings = 1000
  val num_rows_per_sampling = 1000
  val theta_1 = 0.3f
  val theta_2 = 0.2f
  val intercept = -4.0f
  val (feature1, feature2, target) = generate_blr_data(theta_1, theta_2, intercept, centered,
    num_distribution_samplings, num_rows_per_sampling)
  val num_rows = num_distribution_samplings * num_rows_per_sampling

  // Nearly-constant column: the first 10% of rows are 0.9, all others are 1.0.
  val const_feature = Array.tabulate(num_rows)(i => if (i < num_rows / 10) 0.9f else 1.0f)

  val data = (0 until num_rows).map { i =>
    (feature1(i), feature2(i), const_feature(i), target(i))
  }
  val spark_df = spark.createDataFrame(data)
    .toDF("feature1", "feature2", "const_feature", "label").cache()

  val lr = new LogisticRegression()
    .setMaxIter(100)
    .setRegParam(regParam)
    .setElasticNetParam(0.5)
    .setFitIntercept(true)

  // Fit 1: only the two informative features.
  val vec = new VectorAssembler()
    .setInputCols(Array("feature1", "feature2"))
    .setOutputCol("features")
  val spark_df1 = vec.transform(spark_df).cache()
  val lrModel = lr.fit(spark_df1)
  println("Just the blr data")
  println(s"Coefficients: ${lrModel.coefficients}")
  println(s"Intercept: ${lrModel.intercept}")

  // Fit 2: the same features plus the nearly-constant column.
  val vec2 = new VectorAssembler()
    .setInputCols(Array("feature1", "feature2", "const_feature"))
    .setOutputCol("features")
  val spark_df2 = vec2.transform(spark_df).cache()
  val lrModel2 = lr.fit(spark_df2)
  println("blr data plus one vector that is filled with 1's and .9's")
  println(s"Coefficients: ${lrModel2.coefficients}")
  println(s"Intercept: ${lrModel2.intercept}")
  println(s"objective: ${lrModel2.summary.objectiveHistory.mkString(",")}")

  // Sample every 3000th row. The bound is num_rows rather than a hard-coded 1000000,
  // so the indices stay valid if the size parameters above are changed.
  // The vectors are built once and reused for both kinds of prediction.
  val sampleVectors = Seq.range(0, num_rows, 3000).map { i =>
    Vectors.dense(feature1(i), feature2(i), const_feature(i))
  }
  val rawPreds = sampleVectors.map(v => lrModel2.predictRaw(v)(0))
  println(s"rawPreds: ${rawPreds.mkString(",")}")
  val probs = sampleVectors.map(v => lrModel2.predictProbability(v)(0))
  println(s"probs: ${probs.mkString(",")}")
}
/**
 * Generates grouped synthetic data for binary logistic regression.
 *
 * The function draws `num_distribution_samplings` uniform values for each feature from a
 * `Random` seeded with 12345, so the output is deterministic. Each drawn value is repeated
 * `num_rows_per_sampling` times. Within each group, the first
 * `round(num_rows_per_sampling * sigmoid(intercept + theta_1 * a + theta_2 * b))` rows are
 * labelled 1 and the remaining rows 0.
 *
 * @param theta_1 coefficient applied to the first feature
 * @param theta_2 coefficient applied to the second feature
 * @param intercept intercept of the linear predictor
 * @param centered if true, feature1 is shifted by -0.5 and feature2 is mapped via `2f - 1`;
 *                 otherwise feature1 is left as drawn and feature2 is shifted by +1
 * @param num_distribution_samplings number of distinct (feature1, feature2) groups
 * @param num_rows_per_sampling number of rows generated per group
 * @return (feature1, feature2, target), each of length
 *         `num_distribution_samplings * num_rows_per_sampling`, laid out group by group
 */
def generate_blr_data(theta_1: Float,
    theta_2: Float,
    intercept: Float,
    centered: Boolean,
    num_distribution_samplings: Int,
    num_rows_per_sampling: Int): (Array[Float], Array[Float], Array[Int]) = {
  val random = new Random(12345L)
  // Draw both vectors before any shifting, in the same order as always, so the random
  // stream (and therefore the generated data) does not depend on `centered`.
  val drawn1 = Array.fill(num_distribution_samplings)(random.nextFloat())
  val drawn2 = Array.fill(num_distribution_samplings)(random.nextFloat())

  // Pure `map` instead of in-place `transform` (deprecated on Scala 2.13 collections).
  val (uniforms, uniforms2) =
    if (centered) (drawn1.map(f => f - 0.5f), drawn2.map(f => 2.0f * f - 1.0f))
    else (drawn1, drawn2.map(f => f + 1.0f))

  // Per-group success probability: sigmoid of the (Float) linear predictor.
  val prob = uniforms.zip(uniforms2).map { case (a, b) =>
    val h = intercept + theta_1 * a + theta_2 * b
    1.0 / (1.0 + math.exp(-h))
  }

  // Repeats each group value once per row of that group.
  def expand(values: Array[Float]): Array[Float] =
    values.flatMap(v => Array.fill(num_rows_per_sampling)(v))

  // Per group: the leading round(rows * p) labels are 1, the rest are 0.
  val target = prob.flatMap { p =>
    val numOnes = math.round(num_rows_per_sampling * p).toInt
    Array.tabulate(num_rows_per_sampling)(j => if (j < numOnes) 1 else 0)
  }

  (expand(uniforms), expand(uniforms2), target)
}
// scalastyle:on println
```
Results on master (without centering):
solution: Intercept: -3.5498333785156753, Coefficients:
[0.29728813129204773,0.19287733247367822,-0.44180451028644724]
Objective history:
0.12762747240520095,0.12762216631186177,0.12764985979454083,0.12761175327662522,0.12758366054986692,0.1275656630375162,0.12755645900672408,0.127532008088537,0.12750910987285638,0.1275082263627089,0.12750722952815843,0.1275062817766826,0.12750538161482666,0.1275045356612968,0.12750373487279973,0.1275029817907778,0.12750225789238678,0.12750158637218928,0.1275009135177896,0.1275005830775023,0.12749954294254812,0.12749904352248712,0.12749837794728083,0.1274977512022644,0.12749737729191743,0.1274966670966546,0.12749651869450704,0.12749576053861586,0.12749556188951974,0.12749305449824697,0.1274926258332541,0.12749237844278866,0.1274918338503058,0.1274916122125306,0.12749117813089475,0.12749098122913377,0.12749063524924278,0.12749046164580588,0.12749018587090755,0.1274900081622212,0.1274898078942527,0.12748974247471218,0.12748960522165648,0.12748951458190824,0.1274894296030913,0.1274893211438692,0.12748927727579804,0.1274891570656151,0.1274891306335918,0.12748901615923106,0.1274890035995
9753,0.12748889237426544,0.12748886393809275,0.12748883685446644,0.1274884708470191,0.12748845591533808,0.1274884262204896,0.12748841313224754,0.12748838705800047,0.1274883755234384,0.12748835263609506,0.12748834246869473,0.1274883223799556,0.12748831341596667,0.12748829578487086,0.1274882916026448,0.1274882672142916,0.12748826098016602,0.1274882597476527,0.1274882398160488,0.1274882353158831,0.1274882216517839,0.12748821488372022,0.12748820601030036,0.12748819772926567,0.12748819250736595,0.12748818326927866,0.12748818082144145,0.12748817105754642,0.12748816804373495,0.12748813555318347,0.12748813488583727,0.12748813118913596,0.1274881307065699,0.1274881273830832,0.12748810045033077,0.12748809985451237,0.12748809460766752,0.1274880928663959,0.12748808779542983,0.12748808499985892,0.12748808351023963,0.12748808239530365
Results with this PR (with centering):
solution: Intercept: -4.00916620713765, Coefficients:
[0.29886368349072456,0.2009762822382037,0.008400272616188038]
Objective history:
0.12762747240520095,0.12752273497899994,0.12749490885333506,0.12748788084010645,0.12748550735190392,0.12748550700163083
Prediction differences between master and this PR:
avg(abs(diff_raw_prediction)) = 0.007523936579265968
max(abs(diff_raw_prediction)) = 0.045657305342184706
avg(abs(diff_prob)) = 2.0648205657827485E-4
max(abs(diff_prob)) = 0.001266340205206995
@mengxr
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]