zhengruifeng commented on a change in pull request #28349:
URL: https://github.com/apache/spark/pull/28349#discussion_r415255855
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
##########
@@ -154,31 +156,54 @@ class LinearSVC @Since("2.2.0") (
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)
+ /**
+ * Set block size for stacking input data in matrices.
+ * Default is 1.
+ *
+ * @group expertSetParam
+ */
+ @Since("3.0.0")
+ def setBlockSize(value: Int): this.type = set(blockSize, value)
+ setDefault(blockSize -> 1)
+
@Since("2.2.0")
override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
override protected def train(dataset: Dataset[_]): LinearSVCModel =
instrumented { instr =>
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-
- val instances = extractInstances(dataset)
- if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
-
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol,
rawPredictionCol,
- regParam, maxIter, fitIntercept, tol, standardization, threshold,
aggregationDepth)
+ regParam, maxIter, fitIntercept, tol, standardization, threshold,
aggregationDepth, blockSize)
- val (summarizer, labelSummarizer) =
+ val instances = extractInstances(dataset).setName("training instances")
+
+ val (summarizer, labelSummarizer) = if ($(blockSize) == 1) {
+ if (dataset.storageLevel == StorageLevel.NONE) {
+ instances.persist(StorageLevel.MEMORY_AND_DISK)
+ }
Summarizer.getClassificationSummarizers(instances, $(aggregationDepth))
- instr.logNumExamples(summarizer.count)
- instr.logNamedValue("lowestLabelWeight",
labelSummarizer.histogram.min.toString)
- instr.logNamedValue("highestLabelWeight",
labelSummarizer.histogram.max.toString)
- instr.logSumOfWeights(summarizer.weightSum)
+ } else {
+ // instances will be standardized and converted to blocks, so no need to
cache instances.
+ Summarizer.getClassificationSummarizers(instances, $(aggregationDepth),
+ Seq("mean", "std", "count", "numNonZeros"))
+ }
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
val numFeatures = summarizer.mean.size
- val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else
numFeatures
+
+ instr.logNumExamples(summarizer.count)
+ instr.logNamedValue("lowestLabelWeight",
labelSummarizer.histogram.min.toString)
+ instr.logNamedValue("highestLabelWeight",
labelSummarizer.histogram.max.toString)
+ instr.logSumOfWeights(summarizer.weightSum)
+ if ($(blockSize) > 1) {
Review comment:
I think it is up to the end user to choose whether high-level BLAS is
used and which BLAS lib is used.
Here we compute the sparsity of the dataset; if the input is too sparse, log a warning.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]