Since we are on Spark 2.2, I backported/fixed it. Here is the diff against https://github.com/apache/spark/blob/73fe1d8087cfc2d59ac5b9af48b4cf5f5b86f920/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala:
24c24
< import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
---
> import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators, IntParam}
44c44,46
< val size = new Param[Int](this, "size", "Size of vectors in column.", {s: Int => s >= 0})
---
> val size: IntParam =
>   new IntParam(this, "size", "Size of vectors in column.",
>     ParamValidators.gt(0))
57c59
< @Since("2.3.0")
---
> /* @Since("2.3.0")
64c66
< ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))
---
> ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))*/
134c136
< override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
---
> override def copy(extra: ParamMap): VectorSizeHint = defaultCopy(extra)

The first two changes are required to make the pipeline save the model with the VectorSizeHint info. The third is required because the member being overridden is final in Spark 2.2. The fourth fixes wrong code (copy() declared VectorAssembler as its return type), which was causing a ClassCastException.
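To make the diff easier to follow in context, here is a minimal Scala sketch of the backported class with all four changes applied. This is my reconstruction from the diff, not the exact backported file: the transform body is elided (the linked 2.3 source has the real metadata-attaching logic), and the sketch has to live in Spark's own package tree because the shared param traits it mixes in are package-private.

package org.apache.spark.ml.feature

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

class VectorSizeHint(override val uid: String)
  extends Transformer with HasInputCol with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("vectSizeHint"))

  // Changes 1 and 2: IntParam instead of Param[Int]. A generic Param only
  // knows how to JSON-encode strings and vectors, so saving the stage
  // fails; IntParam overrides jsonEncode, which is what lets the pipeline
  // writer persist the size value.
  val size: IntParam =
    new IntParam(this, "size", "Size of vectors in column.", ParamValidators.gt(0))

  def getSize: Int = $(size)

  def setSize(value: Int): this.type = set(size, value)

  def setInputCol(value: String): this.type = set(inputCol, value)

  // Change 3: the handleInvalid override from the 2.3 source is commented
  // out, because the shared definition it overrides is final in Spark 2.2.

  override def transform(dataset: Dataset[_]): DataFrame = {
    // Elided in this sketch; the linked 2.3 source attaches the
    // vector-size metadata to the input column here.
    ???
  }

  override def transformSchema(schema: StructType): StructType = schema

  // Change 4: the return type must be VectorSizeHint, not VectorAssembler;
  // the wrong declared type is what produced the ClassCastException.
  override def copy(extra: ParamMap): VectorSizeHint = defaultCopy(extra)
}

// Companion object so the stage can be read back when the model is loaded.
object VectorSizeHint extends DefaultParamsReadable[VectorSizeHint]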
Here is the working code after using this new transformer:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorSizeHint;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * A simple text classification pipeline that classifies tweet sentiment,
 * persists the fitted TrainValidationSplitModel, and applies the reloaded
 * model to both batch and streaming input.
 */
public class StreamingIssueCountVectorizerSplitFailed {

  public static void main(String[] args) throws Exception {
    SparkSession sparkSession = SparkSession.builder()
        .appName("StreamingIssueCountVectorizer")
        .master("local[2]")
        .getOrCreate();

    List<Row> _trainData = Arrays.asList(
        RowFactory.create("sunny fantastic day", 1, "Positive"),
        RowFactory.create("fantastic morning match", 1, "Positive"),
        RowFactory.create("good morning", 1, "Positive"),
        RowFactory.create("boring evening", 5, "Negative"),
        RowFactory.create("tragic evening event", 5, "Negative"),
        RowFactory.create("today is bad ", 5, "Negative"));
    List<Row> _testData = Arrays.asList(
        RowFactory.create("sunny morning", 1),
        RowFactory.create("bad evening", 5));

    StructType schema = new StructType(new StructField[]{
        new StructField("tweet", DataTypes.StringType, false, Metadata.empty()),
        new StructField("time", DataTypes.IntegerType, false, Metadata.empty()),
        new StructField("sentiment", DataTypes.StringType, true, Metadata.empty())
    });
    StructType testSchema = new StructType(new StructField[]{
        new StructField("tweet", DataTypes.StringType, false, Metadata.empty()),
        new StructField("time", DataTypes.IntegerType, false, Metadata.empty())
    });

    Dataset<Row> trainData = sparkSession.createDataFrame(_trainData, schema);
    Dataset<Row> testData = sparkSession.createDataFrame(_testData, testSchema);

    StringIndexerModel labelIndexerModel = new StringIndexer()
        .setInputCol("sentiment")
        .setOutputCol("label")
        .setHandleInvalid("skip")
        .fit(trainData);
    Tokenizer tokenizer = new Tokenizer()
        .setInputCol("tweet")
        .setOutputCol("words");
    CountVectorizer countVectorizer = new CountVectorizer()
        .setInputCol(tokenizer.getOutputCol())
        .setOutputCol("wordfeatures")
        .setVocabSize(3)
        .setMinDF(2)
        .setMinTF(2)
        .setBinary(true);

    // Declare the vector size up front so the column metadata carries the
    // size that VectorAssembler needs when transforming streaming data.
    VectorSizeHint wordfeatures = new VectorSizeHint();
    wordfeatures.setInputCol("wordfeatures");
    wordfeatures.setSize(3);

    VectorAssembler vectorAssembler = new VectorAssembler()
        .setInputCols(new String[]{"wordfeatures", "time"})
        .setOutputCol("features");

    Dataset<Row> words = tokenizer.transform(trainData);
    CountVectorizerModel countVectorizerModel = countVectorizer.fit(words);

    LogisticRegression lr = new LogisticRegression()
        .setMaxIter(10)
        .setRegParam(0.001);
    IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predicted")
        .setLabels(labelIndexerModel.labels());

    countVectorizerModel.setMinTF(1);

    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{
        labelIndexerModel, tokenizer, countVectorizerModel, wordfeatures,
        vectorAssembler, lr, labelConverter});

    ParamMap[] paramGrid = new ParamGridBuilder()
        .addGrid(lr.regParam(), new double[]{0.1, 0.01})
        .addGrid(lr.fitIntercept())
        .addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
        .build();

    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
    evaluator.setLabelCol("label");
    evaluator.setPredictionCol("prediction");

    TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
        .setEstimator(pipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(paramGrid)
        .setTrainRatio(0.7);

    // Fit the pipeline to the training documents.
    TrainValidationSplitModel trainValidationSplitModel = trainValidationSplit.fit(trainData);
    trainValidationSplitModel.write().overwrite().save("/tmp/CountSplit.model");
    TrainValidationSplitModel _loadedModel =
        TrainValidationSplitModel.load("/tmp/CountSplit.model");
    PipelineModel loadedModel = (PipelineModel) _loadedModel.bestModel();

    // Test on non-streaming data
    Dataset<Row> predicted = loadedModel.transform(testData);
    predicted.show();
    List<Row> _rows = predicted.select("tweet", "predicted").collectAsList();
    for (Row r : _rows) {
      System.out.println("[" + r.get(0) + "], prediction=" + r.get(1));
    }

    // Test on streaming data
    Dataset<Row> lines = sparkSession.readStream()
        .option("sep", ",")
        .schema(testSchema)
        .option("header", "true")
        .option("inferSchema", "true")
        .format("com.databricks.spark.csv")
        .load("file:///home/davis/Documents/Bugs/StreamingTwitter1");
    StreamingQuery query = loadedModel.transform(lines).writeStream()
        .outputMode("append")
        .format("console")
        .start();
    query.awaitTermination();
  }
}