zhengruifeng commented on a change in pull request #32124:
URL: https://github.com/apache/spark/pull/32124#discussion_r613812199
##########
File path: R/pkg/tests/fulltests/test_mllib_classification.R
##########
@@ -38,14 +38,14 @@ test_that("spark.svmLinear", {
expect_true(class(summary$coefficients[, 1]) == "numeric")
coefs <- summary$coefficients[, "Estimate"]
- expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085)
Review comment:
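About this change, I checked it against R's solution (e1071). With 100 training instances and regParam = 0.01, the equivalent cost is
C (cost) = 2 / numInstances (100) / regParam (0.01) = 2: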
```R
> library(e1071)
> training <- iris[iris$Species %in% c("versicolor", "virginica"), ]
> model <- svm(Species ~ ., data=training, type='C', kernel='linear', cost=2, scale=F, tolerance=1e-4)
> w <- -t(model$coefs) %*% model$SV
> w
     Sepal.Length Sepal.Width Petal.Length Petal.Width
[1,]   -0.6345262   -1.044127     2.220617     2.72316
> model$rho
[1] -8.350759
> prediction <- predict(model, training)
> head(training, 10)
   Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
51          7.0         3.2          4.7         1.4 versicolor
52          6.4         3.2          4.5         1.5 versicolor
53          6.9         3.1          4.9         1.5 versicolor
54          5.5         2.3          4.0         1.3 versicolor
55          6.5         2.8          4.6         1.5 versicolor
56          5.7         2.8          4.5         1.3 versicolor
57          6.3         3.3          4.7         1.6 versicolor
58          4.9         2.4          3.3         1.0 versicolor
59          6.6         2.9          4.6         1.3 versicolor
60          5.2         2.7          3.9         1.4 versicolor
> head(prediction, 10)
        51         52         53         54         55         56         57
versicolor versicolor versicolor versicolor versicolor versicolor versicolor
        58         59         60
versicolor versicolor versicolor
```
With the existing implementation:
```R
> df <- suppressWarnings(createDataFrame(iris))
> training <- df[df$Species %in% c("versicolor", "virginica"), ]
> model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10)
> summary <- summary(model)
> coefs <- summary$coefficients[, "Estimate"]
> coefs
 (Intercept) Sepal_Length  Sepal_Width Petal_Length  Petal_Width
 -0.06004978  -0.15630830  -0.46064800   0.22766256   1.05508538
> prediction <- predict(model, training)
> head(prediction, 10)
   Sepal_Length Sepal_Width Petal_Length Petal_Width    Species
1           7.0         3.2          4.7         1.4 versicolor
2           6.4         3.2          4.5         1.5 versicolor
3           6.9         3.1          4.9         1.5 versicolor
4           5.5         2.3          4.0         1.3 versicolor
5           6.5         2.8          4.6         1.5 versicolor
6           5.7         2.8          4.5         1.3 versicolor
7           6.3         3.3          4.7         1.6 versicolor
8           4.9         2.4          3.3         1.0 versicolor
9           6.6         2.9          4.6         1.3 versicolor
10          5.2         2.7          3.9         1.4 versicolor
                   rawPrediction prediction
1  <environment: 0x56500bd01ea8> versicolor
2  <environment: 0x56500bd08228>  virginica
3  <environment: 0x56500bd0a890>  virginica
4  <environment: 0x56500bd10d28>  virginica
5  <environment: 0x56500bd19188>  virginica
6  <environment: 0x56500bd1f690>  virginica
7  <environment: 0x56500bd25a48>  virginica
8  <environment: 0x56500bd2a078> versicolor
9  <environment: 0x56500bd305b8> versicolor
10 <environment: 0x56500bd36ac0>  virginica
```
We can see that 7 of the first 10 predictions are wrong, and the coefficients (used in the
test suite) are far from R's.
With this PR:
```R
> df <- suppressWarnings(createDataFrame(iris))
> training <- df[df$Species %in% c("versicolor", "virginica"), ]
> model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10)
> summary <- summary(model)
> coefs <- summary$coefficients[, "Estimate"]
> coefs
 (Intercept) Sepal_Length  Sepal_Width Petal_Length  Petal_Width
  -6.8823988   -0.6154984   -1.5135447    1.9694126    3.3736856
> prediction <- predict(model, training)
> head(prediction, 10)
   Sepal_Length Sepal_Width Petal_Length Petal_Width    Species
1           7.0         3.2          4.7         1.4 versicolor
2           6.4         3.2          4.5         1.5 versicolor
3           6.9         3.1          4.9         1.5 versicolor
4           5.5         2.3          4.0         1.3 versicolor
5           6.5         2.8          4.6         1.5 versicolor
6           5.7         2.8          4.5         1.3 versicolor
7           6.3         3.3          4.7         1.6 versicolor
8           4.9         2.4          3.3         1.0 versicolor
9           6.6         2.9          4.6         1.3 versicolor
10          5.2         2.7          3.9         1.4 versicolor
                   rawPrediction prediction
1  <environment: 0x561312c6a2b8> versicolor
2  <environment: 0x561312c726e0> versicolor
3  <environment: 0x561312c7ab08> versicolor
4  <environment: 0x561312c81048> versicolor
5  <environment: 0x561312c85678> versicolor
6  <environment: 0x561312c8daa0> versicolor
7  <environment: 0x561312c95f00> versicolor
8  <environment: 0x561312c9e360> versicolor
9  <environment: 0x561312ca0a70> versicolor
10 <environment: 0x561312ca6f78> versicolor
```
The coefficients from this PR are much closer to R's, and the first 10 predictions are all
correct.
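As a cross-check that does not need Spark, the first 10 predictions can be recomputed in base R directly from the printed coefficients, since a linear SVM classifies by the sign of intercept + w·x. This is a minimal sketch: `margin()` and `label()` are hypothetical helpers, and it assumes the reported coefficients apply to the raw feature values and that a positive margin maps to "virginica" (both consistent with the printed predictions above).

```R
# Coefficients copied from the summary(model) outputs above
# (intercept first, then Sepal.Length, Sepal.Width, Petal.Length, Petal.Width).
coefs_existing <- c(-0.06004978, -0.15630830, -0.46064800, 0.22766256, 1.05508538)
coefs_pr       <- c(-6.8823988, -0.6154984, -1.5135447, 1.9694126, 3.3736856)

# Same 100 versicolor/virginica rows as in the e1071 session; take the first 10.
training <- iris[iris$Species %in% c("versicolor", "virginica"), ]
x <- as.matrix(training[1:10, 1:4])
truth <- as.character(training$Species[1:10])

# Hypothetical helpers: linear SVM margin = intercept + w . x, and its sign.
# Assumption: a positive margin means label index 1, i.e. "virginica".
margin <- function(coefs, x) drop(coefs[1] + x %*% coefs[-1])
label <- function(m) ifelse(m > 0, "virginica", "versicolor")

pred_existing <- label(margin(coefs_existing, x))
pred_pr <- label(margin(coefs_pr, x))

sum(pred_existing != truth)  # 7 of the first 10 rows misclassified
sum(pred_pr != truth)        # 0
```

With the existing coefficients all ten margins lie within about 0.31 of zero, while with this PR's coefficients all ten are below -1.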