Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5431#discussion_r28290532
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/param/params.scala ---
    @@ -179,52 +179,96 @@ trait Params extends Identifiable with Serializable {
       /**
        * Sets a parameter (by name) in the embedded param map.
        */
    -  private[ml] def set(param: String, value: Any): this.type = {
    +  protected final def set(param: String, value: Any): this.type = {
         set(getParam(param), value)
       }
     
       /**
    -   * Gets the value of a parameter in the embedded param map.
    +   * Optionally returns the user-supplied value of a param.
    +   */
    +  final def get[T](param: Param[T]): Option[T] = {
    +    shouldOwn(param)
    +    paramMap.get(param)
    +  }
    +
    +  /**
    +   * Clears the user-supplied value for the input param.
    +   */
    +  final def clear(param: Param[_]): this.type = {
    +    shouldOwn(param)
    +    paramMap.remove(param)
    +    this
    +  }
    +
    +  /**
    +   * Gets the value of a param in the embedded param map or its default value. Throws an exception
    +   * if neither is set.
        */
    -  protected def get[T](param: Param[T]): T = {
    -    require(param.parent.eq(this))
    -    paramMap(param)
    +  final def getOrDefault[T](param: Param[T]): T = {
    +    shouldOwn(param)
    +    get(param).orElse(getDefault(param)).get
       }
     
       /**
    -   * Internal param map.
    +   * Sets a default value. Make sure that the input param is initialized before this gets called.
        */
    -  protected val paramMap: ParamMap = ParamMap.empty
    +  protected final def setDefault[T](param: Param[T], value: T): this.type 
= {
    +    shouldOwn(param)
    +    defaultParamMap.put(param, value)
    +    this
    +  }
     
       /**
    -   * Check whether the given schema contains an input column.
    -   * @param colName  Input column name
    -   * @param dataType  Input column DataType
    +   * Sets default values. Make sure that the input params are initialized before this gets called.
        */
    -  protected def checkInputColumn(schema: StructType, colName: String, 
dataType: DataType): Unit = {
    -    val actualDataType = schema(colName).dataType
    -    require(actualDataType.equals(dataType), s"Input column $colName must 
be of type $dataType" +
    -      s" but was actually $actualDataType.  Column param description: 
${getParam(colName)}")
    +  protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
    +    paramPairs.foreach { p =>
    +      setDefault(p.param.asInstanceOf[Param[Any]], p.value)
    +    }
    +    this
       }
     
       /**
    -   * Add an output column to the given schema.
    -   * This fails if the given output column already exists.
    -   * @param schema  Initial schema (not modified)
    -   * @param colName  Output column name.  If this column name is an empy 
String "", this method
    -   *                 returns the initial schema, unchanged.  This allows 
users to disable output
    -   *                 columns.
    -   * @param dataType  Output column DataType
    -   */
    -  protected def addOutputColumn(
    -      schema: StructType,
    -      colName: String,
    -      dataType: DataType): StructType = {
    -    if (colName.length == 0) return schema
    -    val fieldNames = schema.fieldNames
    -    require(!fieldNames.contains(colName), s"Output column $colName 
already exists.")
    -    val outputFields = schema.fields ++ Seq(StructField(colName, dataType, 
nullable = false))
    -    StructType(outputFields)
    +   * Gets the default value of a parameter.
    +   */
    +  final def getDefault[T](param: Param[T]): Option[T] = {
    +    shouldOwn(param)
    +    defaultParamMap.get(param)
    +  }
    +
    +  /**
    +   * Tests whether the input param has a default value set.
    +   */
    +  final def hasDefault[T](param: Param[T]): Boolean = {
    +    shouldOwn(param)
    +    defaultParamMap.contains(param)
    +  }
    +
    +  /**
    +   * Extracts the embedded default param values and user-supplied values, and then merges them with
    +   * extra values from input into a flat param map, where the latter value is used if there exist
    +   * conflicts.
    --- End diff --
    
    done


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to