Imbruced commented on code in PR #1589: URL: https://github.com/apache/sedona/pull/1589#discussion_r1775987299
##########
spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala:
##########
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.clustering
+
+import org.apache.sedona.stats.Util.getGeometryColumnName
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, ST_DistanceSpheroid}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
+import org.graphframes.GraphFrame
+
+object DBSCAN {
+
+  private val ID_COLUMN = "__id"
+
+  /**
+   * Annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+   * The dataframe should contain at least one GeometryType column. Rows must be unique. If one
+   * geometry column is present it will be used automatically. If two are present, the one named
+   * 'geometry' will be used. If more than one are present and neither is named 'geometry', the
+   * column name must be provided. The new column will be named 'cluster'.
+   *
+   * @param dataframe
+   *   Apache Sedona dataframe containing the point geometries
+   * @param epsilon
+   *   minimum distance parameter of the DBSCAN algorithm
+   * @param min_pts
+   *   minimum number of points parameter of the DBSCAN algorithm
+   * @param geometry
+   *   name of the geometry column
+   * @param includeOutliers
+   *   whether to include outliers in the output. Default is true
+   * @param useSpheroid
+   *   whether to use a cartesian or spheroidal distance calculation. Default is false
+   * @return
+   *   The input DataFrame with the cluster label added to each row. Outliers will have a cluster
+   *   value of -1 if included.
+   */
+  def dbscan(
+      dataframe: DataFrame,
+      epsilon: Double,
+      min_pts: Int,
+      geometry: String = null,
+      includeOutliers: Boolean = true,
+      useSpheroid: Boolean = false): DataFrame = {
+
+    // We want to disable broadcast joins because the broadcast references were using too much driver memory
+    val spark = SparkSession.getActiveSession.get
+
+    val geometryCol = geometry match {
+      case null => getGeometryColumnName(dataframe)
+      case _ => geometry
+    }
+    validateInputs(dataframe, epsilon, min_pts, geometryCol)
+
+    val distanceFunction: (Column, Column) => Column =
+      if (useSpheroid) ST_DistanceSpheroid else ST_Distance
+
+    val hasIdColumn = dataframe.columns.contains("id")
+    val idDataframe = if (hasIdColumn) {
+      dataframe
+        .withColumnRenamed("id", ID_COLUMN)
+        .withColumn("id", sha2(to_json(struct("*")), 256))
+    } else {
+      dataframe.withColumn("id", sha2(to_json(struct("*")), 256))

Review Comment:
   Should duplicated records be treated as equal and aggregated, or not? I am wondering whether the `monotonically_increasing_id` function is enough here. `sha2(to_json(struct("*")))` might be costly, and as a join key it might not be as efficient as bigints.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
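The trade-off the reviewer raises can be sketched outside Spark. In this plain-Python sketch (an illustration, not Sedona or Spark code), `hashlib`/`json` stand in for Spark's `sha2`/`to_json`, and a positional integer stands in for `monotonically_increasing_id`: the JSON-hash key is deterministic, so exact duplicate rows collapse onto the same key, but each key is a 64-character hex string, a much wider join key than an 8-byte bigint.

```python
import hashlib
import json

def row_key_sha2(row: dict) -> str:
    """Analogue of Spark's sha2(to_json(struct("*")), 256): hash the JSON
    encoding of the whole row. Identical rows get identical keys."""
    encoded = json.dumps(row, sort_keys=True).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()

def row_keys_positional(rows: list) -> list:
    """Analogue of monotonically_increasing_id: one compact integer per
    row position, so duplicate rows get DIFFERENT keys."""
    return list(range(len(rows)))

rows = [
    {"name": "a", "geometry": "POINT (0 0)"},
    {"name": "a", "geometry": "POINT (0 0)"},  # exact duplicate row
]

sha_keys = [row_key_sha2(r) for r in rows]
int_keys = row_keys_positional(rows)

assert sha_keys[0] == sha_keys[1]  # duplicates collide under the content hash
assert int_keys[0] != int_keys[1]  # duplicates stay distinct under position ids
assert len(sha_keys[0]) == 64      # 64-char hex string vs. an 8-byte long
```

Which behavior is correct is exactly the open question: the content hash silently merges duplicate rows into one node of the cluster graph, while `monotonically_increasing_id` would keep each duplicate as its own (cheaper-to-join) node.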
