This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new b1ceb1e8c [SEDONA-661] add local outlier factor implementation. (#1623)
b1ceb1e8c is described below

commit b1ceb1e8c5e5fdc1c4e1bdd9ec929aa1393045de
Author: James Willis <[email protected]>
AuthorDate: Tue Oct 15 20:49:18 2024 -0700

    [SEDONA-661] add local outlier factor implementation. (#1623)
    
    * add local outlier factor implementation.
    
    * LOF docs
    
    * precommit changes
    
    * precommit formatting changes
    
    ---------
    
    Co-authored-by: jameswillis <[email protected]>
---
 docs/api/stats/sql.md                              |  22 ++-
 docs/tutorial/sql.md                               |  67 +++++++++-
 python/sedona/stats/outlier_detection/__init__.py  |  18 +++
 .../outlier_detection/local_outlier_factor.py      |  60 +++++++++
 python/tests/stats/test_local_outlier_factor.py    | 107 +++++++++++++++
 .../outlierDetection/LocalOutlierFactor.scala      | 148 +++++++++++++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      |   2 +
 .../outlierDetection/LocalOutlierFactorTest.scala  |  59 ++++++++
 8 files changed, 479 insertions(+), 4 deletions(-)

diff --git a/docs/api/stats/sql.md b/docs/api/stats/sql.md
index fe7c0e90e..290691710 100644
--- a/docs/api/stats/sql.md
+++ b/docs/api/stats/sql.md
@@ -9,7 +9,7 @@ complete set of geospatial analysis tools.
 
 ## Using DBSCAN
 
-The DBSCAN function is provided at `org.apache.sedona.stats.DBSCAN.dbscan` in 
scala/java and `sedona.stats.dbscan.dbscan` in python.
+The DBSCAN function is provided at 
`org.apache.sedona.stats.clustering.DBSCAN.dbscan` in scala/java and 
`sedona.stats.clustering.dbscan.dbscan` in python.
 
 The function annotates a dataframe with a cluster label for each data record 
using the DBSCAN algorithm.
 The dataframe should contain at least one `GeometryType` column. Rows must be 
unique. If one
@@ -29,3 +29,23 @@ names in parentheses are python variable names
 - useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal 
distance calculation. Default is false
 
 The output is the input DataFrame with the cluster label added to each row. 
Outliers will have a cluster value of -1 if included.
+
+## Using Local Outlier Factor (LOF)
+
+The LOF function is provided at 
`org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor`
 in scala/java and 
`sedona.stats.outlier_detection.local_outlier_factor.local_outlier_factor` in 
python.
+
+The function annotates a dataframe with a column containing the local outlier 
factor for each data record.
+The dataframe should contain at least one `GeometryType` column. Rows must be unique. If one
+geometry column is present it will be used automatically. If more than one is present, the one
+named 'geometry' will be used. If more than one is present and none is named 'geometry', the
+column name must be provided.
+
+### Parameters
+
+names in parentheses are python variable names
+
+- dataframe - dataframe containing the point geometries
+- k - number of nearest neighbors that will be considered for the LOF 
calculation
+- geometry - name of the geometry column
+- handleTies (handle_ties) - whether to handle ties in the k-distance calculation by enabling `spark.sedona.join.knn.includeTieBreakers` for the KNN join. Default is false
+- useSpheroid (use_spheroid) - whether to use spheroidal (true) rather than cartesian (false) distance calculation. Default is false
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 75a013f85..7754338a8 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -842,7 +842,7 @@ The first parameter is the dataframe, the next two are the 
epsilon and min_point
 === "Scala"
 
        ```scala
-       import org.apache.sedona.stats.DBSCAN.dbscan
+       import org.apache.sedona.stats.clustering.DBSCAN.dbscan
 
        dbscan(df, 0.1, 5).show()
        ```
@@ -850,7 +850,7 @@ The first parameter is the dataframe, the next two are the 
epsilon and min_point
 === "Java"
 
        ```java
-       import org.apache.sedona.stats.DBSCAN;
+       import org.apache.sedona.stats.clustering.DBSCAN;
 
        DBSCAN.dbscan(df, 0.1, 5).show();
        ```
@@ -858,7 +858,7 @@ The first parameter is the dataframe, the next two are the 
epsilon and min_point
 === "Python"
 
        ```python
-       from sedona.stats.dbscan import dbscan
+       from sedona.stats.clustering.dbscan import dbscan
 
        dbscan(df, 0.1, 5).show()
        ```
@@ -885,6 +885,67 @@ The output will look like this:
 +----------------+---+------+-------+
 ```
 
+## Calculate the Local Outlier Factor (LOF)
+
+Sedona provides an implementation of the [Local Outlier 
Factor](https://en.wikipedia.org/wiki/Local_outlier_factor) algorithm to 
identify anomalous data.
+
+The algorithm is available as a Scala, Java, and Python function called on a spatial dataframe. The returned dataframe has an additional `lof` column containing the local outlier factor.
+
+The first parameter is the dataframe, the next is the number of nearest neighbors to consider when calculating the score.
+
+=== "Scala"
+
+       ```scala
+       import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
+
+       localOutlierFactor(df, 20).show()
+       ```
+
+=== "Java"
+
+       ```java
+       import org.apache.sedona.stats.outlierDetection.LocalOutlierFactor;
+
+       LocalOutlierFactor.localOutlierFactor(df, 20).show();
+       ```
+
+=== "Python"
+
+       ```python
+       from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor
+
+       local_outlier_factor(df, 20).show()
+       ```
+
+The output will look like this:
+
+```
++--------------------+------------------+
+|            geometry|               lof|
++--------------------+------------------+
+|POINT (-2.0231305...| 0.952098153363662|
+|POINT (-2.0346944...|0.9975325496668104|
+|POINT (-2.2040074...|1.0825843906411081|
+|POINT (1.61573501...|1.7367129352162634|
+|POINT (-2.1176324...|1.5714144683150393|
+|POINT (-2.2349759...|0.9167275845938276|
+|POINT (1.65470192...| 1.046231536764447|
+|POINT (0.62624112...|1.1988700676990034|
+|POINT (2.01746261...|1.1060219481067417|
+|POINT (-2.0483857...|1.0775553430145446|
+|POINT (2.43969463...|1.1129132178576646|
+|POINT (-2.2425480...| 1.104108012697006|
+|POINT (-2.7859235...|  2.86371824574529|
+|POINT (-1.9738858...|1.0398822680356794|
+|POINT (2.00153403...| 0.927409656346015|
+|POINT (2.06422812...|0.9222203762264445|
+|POINT (-1.7533819...|1.0273650471626696|
+|POINT (-2.2030766...| 0.964744555830738|
+|POINT (-1.8509857...|1.0375927869698574|
+|POINT (2.10849080...|1.0753419197322656|
++--------------------+------------------+
+```
+
 ## Run spatial queries
 
 After creating a Geometry type column, you are able to run spatial queries.
diff --git a/python/sedona/stats/outlier_detection/__init__.py 
b/python/sedona/stats/outlier_detection/__init__.py
new file mode 100644
index 000000000..4dd25a3ff
--- /dev/null
+++ b/python/sedona/stats/outlier_detection/__init__.py
@@ -0,0 +1,18 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+
+"""Algorithms for detecting outliers in spatial datasets."""
diff --git a/python/sedona/stats/outlier_detection/local_outlier_factor.py 
b/python/sedona/stats/outlier_detection/local_outlier_factor.py
new file mode 100644
index 000000000..3050d216b
--- /dev/null
+++ b/python/sedona/stats/outlier_detection/local_outlier_factor.py
@@ -0,0 +1,60 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+
+"""Functions related to calculating the local outlier factor of a dataset."""
+from typing import Optional
+
+from pyspark.sql import DataFrame, SparkSession
+
+# Internal column names. They mirror the private constants of the Scala
+# LocalOutlierFactor object; nothing in this module references them.
+ID_COLUMN_NAME = "__id"
+CONTENTS_COLUMN_NAME = "__contents"
+
+
+def local_outlier_factor(
+    dataframe: DataFrame,
+    k: int = 20,
+    geometry: Optional[str] = None,
+    handle_ties: bool = False,
+    use_spheroid=False,
+):
+    """Annotates a dataframe with a column containing the local outlier factor 
for each data record.
+
+    The dataframe should contain at least one GeometryType column. Rows must 
be unique. If one geometry column is
+    present it will be used automatically. If two are present, the one named 
'geometry' will be used. If more than one
+    are present and neither is named 'geometry', the column name must be 
provided.
+
+    Args:
+        dataframe: apache sedona idDataframe containing the point geometries
+        k: number of nearest neighbors that will be considered for the LOF 
calculation
+        geometry: name of the geometry column
+        handle_ties: whether to handle ties in the k-distance calculation. 
Default is false
+        use_spheroid: whether to use a cartesian or spheroidal distance 
calculation. Default is false
+
+    Returns:
+        A PySpark DataFrame containing the lof for each row
+    """
+    sedona = SparkSession.getActiveSession()
+
+    result_df = 
sedona._jvm.org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor(
+        dataframe._jdf,
+        k,
+        geometry,
+        handle_ties,
+        use_spheroid,
+    )
+
+    return DataFrame(result_df, sedona)
diff --git a/python/tests/stats/test_local_outlier_factor.py 
b/python/tests/stats/test_local_outlier_factor.py
new file mode 100644
index 000000000..52ec860a0
--- /dev/null
+++ b/python/tests/stats/test_local_outlier_factor.py
@@ -0,0 +1,107 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+
+import numpy as np
+import pyspark.sql.functions as f
+import pytest
+from pyspark.sql import DataFrame
+from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
+from sklearn.neighbors import LocalOutlierFactor
+from tests.test_base import TestBase
+
+from sedona.sql.st_constructors import ST_MakePoint
+from sedona.sql.st_functions import ST_X, ST_Y
+from sedona.stats.outlier_detection.local_outlier_factor import 
local_outlier_factor
+
+
+class TestLOF(TestBase):
+    def get_small_data(self) -> DataFrame:
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField("x", DoubleType(), True),
+                StructField("y", DoubleType(), True),
+            ]
+        )
+        return self.spark.createDataFrame(
+            [
+                (1, 1.0, 2.0),
+                (2, 2.0, 2.0),
+                (3, 3.0, 3.0),
+            ],
+            schema,
+        ).select("id", ST_MakePoint("x", "y").alias("geometry"))
+
+    def get_medium_data(self):
+        np.random.seed(42)
+
+        X_inliers = 0.3 * np.random.randn(100, 2)
+        X_inliers = np.r_[X_inliers + 2, X_inliers - 2]
+        X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2))
+        return np.r_[X_inliers, X_outliers]
+
+    def get_medium_dataframe(self, data):
+        schema = StructType(
+            [StructField("x", DoubleType(), True), StructField("y", 
DoubleType(), True)]
+        )
+
+        return (
+            self.spark.createDataFrame(data, schema)
+            .select(ST_MakePoint("x", "y").alias("geometry"))
+            .withColumn("anotherColumn", f.rand())
+        )
+
+    def compare_results(self, actual, expected, k):
+        assert len(actual) == len(expected)
+        missing = set(expected.keys()) - set(actual.keys())
+        assert len(missing) == 0
+        big_diff = {
+            k: (v, expected[k], abs(1 - v / expected[k]))
+            for k, v in actual.items()
+            if abs(1 - v / expected[k]) > 0.0000000001
+        }
+        assert len(big_diff) == 0
+
+    @pytest.mark.parametrize("k", [5, 21, 3])
+    def test_lof_matches_sklearn(self, k):
+        data = self.get_medium_data()
+        actual = {
+            tuple(x[0]): x[1]
+            for x in 
local_outlier_factor(self.get_medium_dataframe(data.tolist()), k)
+            .select(f.array(ST_X("geometry"), ST_Y("geometry")), "lof")
+            .collect()
+        }
+        clf = LocalOutlierFactor(n_neighbors=k, contamination="auto")
+        clf.fit_predict(data)
+        expected = dict(
+            zip(
+                [tuple(x) for x in data],
+                [float(-x) for x in clf.negative_outlier_factor_],
+            )
+        )
+        self.compare_results(actual, expected, k)
+
+    # TODO uncomment when KNN join supports empty dfs
+    # def test_handle_empty_dataframe(self):
+    #     empty_df = self.spark.createDataFrame([], 
self.get_small_data().schema)
+    #     result_df = local_outlier_factor(empty_df, 2)
+    #
+    #     assert 0 == result_df.count()
+
+    def test_raise_error_for_invalid_k_value(self):
+        with pytest.raises(Exception):
+            local_outlier_factor(self.get_small_data(), -1)
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
 
b/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
new file mode 100644
index 000000000..b98919de2
--- /dev/null
+++ 
b/spark/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.outlierDetection
+
+import org.apache.sedona.stats.Util.getGeometryColumnName
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, 
ST_DistanceSpheroid}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
+
+object LocalOutlierFactor {
+
+  private val ID_COLUMN_NAME = "__id"
+  private val CONTENTS_COLUMN_NAME = "__contents"
+  // Internal geometry column. A private name avoids clashing with a user column called
+  // "geometry" when the caller selects a different geometry column.
+  private val GEOMETRY_COLUMN_NAME = "__geometry"
+  private val TIE_BREAKERS_CONF = "spark.sedona.join.knn.includeTieBreakers"
+
+  /**
+   * Annotates a dataframe with a column containing the local outlier factor for each data record.
+   * The dataframe should contain at least one GeometryType column. Rows must be unique. If one
+   * geometry column is present it will be used automatically. If more than one is present, the
+   * one named 'geometry' will be used. If more than one is present and none is named 'geometry',
+   * the column name must be provided.
+   *
+   * The k-nearest-neighbor self join is materialized eagerly: with a reliable checkpoint when a
+   * checkpoint directory is configured on the SparkContext, otherwise with a local checkpoint.
+   *
+   * @param dataframe
+   *   dataframe containing the point geometries
+   * @param k
+   *   number of nearest neighbors that will be considered for the LOF calculation; must be >= 1
+   * @param geometry
+   *   name of the geometry column; when null it is inferred from the dataframe
+   * @param handleTies
+   *   whether to handle ties in the k-distance calculation. Default is false. When true,
+   *   `spark.sedona.join.knn.includeTieBreakers` is enabled for the duration of the call and then
+   *   restored to its previous value (or unset), even if the computation fails.
+   * @param useSpheroid
+   *   whether to use spheroidal (true) rather than cartesian (false) distances. Default is false
+   * @return
+   *   the input columns plus a `lof` column holding the local outlier factor of each row
+   * @throws IllegalArgumentException
+   *   if k < 1
+   */
+  def localOutlierFactor(
+      dataframe: DataFrame,
+      k: Int = 20,
+      geometry: String = null,
+      handleTies: Boolean = false,
+      useSpheroid: Boolean = false): DataFrame = {
+
+    if (k < 1)
+      throw new IllegalArgumentException("k must be a positive integer")
+
+    val distanceFunction: (Column, Column) => Column =
+      if (useSpheroid) ST_DistanceSpheroid else ST_Distance
+    val geometryColumn = Option(geometry).getOrElse(getGeometryColumnName(dataframe))
+
+    // Store original contents and prep the internal id/geometry columns. The geometry is read
+    // from the contents struct by field name, so the column name is matched literally (the same
+    // way withColumnRenamed matched it).
+    val formattedDataframe = dataframe
+      .withColumn(CONTENTS_COLUMN_NAME, f.struct("*"))
+      .withColumn(ID_COLUMN_NAME, f.sha2(f.to_json(f.col(CONTENTS_COLUMN_NAME)), 256))
+      .withColumn(GEOMETRY_COLUMN_NAME, f.col(CONTENTS_COLUMN_NAME).getField(geometryColumn))
+
+    val session: SparkSession = dataframe.sparkSession
+    // None: config untouched. Some(prior): config was overridden and `prior` must be restored.
+    val priorTieBreakers: Option[Option[String]] =
+      if (handleTies) {
+        val prior = session.conf.getOption(TIE_BREAKERS_CONF)
+        session.conf.set(TIE_BREAKERS_CONF, true)
+        Some(prior)
+      } else None
+
+    try {
+      // Materializing eagerly runs the KNN join while the tie-breaker config is in effect and
+      // truncates its lineage, so restoring the config below does not affect later actions.
+      val kDistanceDf =
+        materialize(kDistances(formattedDataframe, k, useSpheroid, distanceFunction))
+      lofScores(localReachabilityDensities(kDistanceDf, distanceFunction))
+    } finally {
+      priorTieBreakers.foreach {
+        case Some(value) => session.conf.set(TIE_BREAKERS_CONF, value)
+        case None => session.conf.unset(TIE_BREAKERS_CONF)
+      }
+    }
+  }
+
+  /**
+   * Self-joins `df` with its k nearest neighbors and, per row, keeps its geometry, original
+   * contents, k-distance (largest distance to any neighbor) and the ids of its neighbors.
+   */
+  private def kDistances(
+      df: DataFrame,
+      k: Int,
+      useSpheroid: Boolean,
+      distanceFunction: (Column, Column) => Column): DataFrame = {
+    val useSpheroidString = if (useSpheroid) "True" else "False" // for the SQL expression
+    // k + 1 because the KNN join also matches each row to itself, which is filtered out below
+    val knnPredicate =
+      s"ST_KNN(l.$GEOMETRY_COLUMN_NAME, r.$GEOMETRY_COLUMN_NAME, ${k + 1}, $useSpheroidString)"
+    val left = f.col(s"l.$GEOMETRY_COLUMN_NAME")
+    val right = f.col(s"r.$GEOMETRY_COLUMN_NAME")
+    df.alias("l")
+      .join(
+        df.alias("r"),
+        f.expr(knnPredicate) && f.col(s"l.$ID_COLUMN_NAME") =!= f.col(s"r.$ID_COLUMN_NAME"))
+      .groupBy(s"l.$ID_COLUMN_NAME")
+      .agg(
+        f.first(s"l.$GEOMETRY_COLUMN_NAME").alias(GEOMETRY_COLUMN_NAME),
+        f.first(s"l.$CONTENTS_COLUMN_NAME").alias(CONTENTS_COLUMN_NAME),
+        f.max(distanceFunction(left, right)).alias("k_distance"),
+        f.collect_list(s"r.$ID_COLUMN_NAME").alias("neighbors"))
+  }
+
+  /**
+   * Eagerly materializes `df` and truncates its lineage. Falls back to a local checkpoint when
+   * no checkpoint directory has been set on the SparkContext.
+   */
+  private def materialize(df: DataFrame): DataFrame =
+    if (df.sparkSession.sparkContext.getCheckpointDir.isDefined) df.checkpoint()
+    else df.localCheckpoint()
+
+  /**
+   * Computes each row's local reachability density: the inverse of the mean reachability
+   * distance max(k-distance(b), d(a, b)) over its neighbors b.
+   */
+  private def localReachabilityDensities(
+      kDistanceDf: DataFrame,
+      distanceFunction: (Column, Column) => Column): DataFrame =
+    kDistanceDf
+      .select(
+        f.col(ID_COLUMN_NAME).alias("a_id"),
+        f.col(CONTENTS_COLUMN_NAME),
+        f.col(GEOMETRY_COLUMN_NAME).alias("a_geometry"),
+        f.explode(f.col("neighbors")).alias("n_id"))
+      .join(
+        kDistanceDf.select(
+          f.col(ID_COLUMN_NAME).alias("b_id"),
+          f.col(GEOMETRY_COLUMN_NAME).alias("b_geometry"),
+          f.col("k_distance").alias("b_k_distance")),
+        f.expr("n_id = b_id"))
+      .select(
+        f.col("a_id"),
+        f.col("b_id"),
+        f.col(CONTENTS_COLUMN_NAME),
+        f.array_max(
+          f.array(
+            f.col("b_k_distance"),
+            distanceFunction(f.col("a_geometry"), f.col("b_geometry"))))
+          .alias("rd"))
+      .groupBy("a_id")
+      .agg(
+        // + 1e-10 to avoid division by zero, matches sklearn impl
+        (f.lit(1.0) / (f.mean("rd") + 1e-10)).alias("lrd"),
+        f.collect_list(f.col("b_id")).alias("neighbors"),
+        f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME))
+
+  /**
+   * Computes the LOF of each row (mean neighbor lrd divided by the row's own lrd) and
+   * re-expands the original columns next to the resulting `lof` column.
+   */
+  private def lofScores(lrdDf: DataFrame): DataFrame =
+    lrdDf
+      .select(
+        f.col("a_id"),
+        f.col("lrd").alias("a_lrd"),
+        f.col(CONTENTS_COLUMN_NAME),
+        f.explode(f.col("neighbors")).alias("n_id"))
+      .join(
+        lrdDf.select(f.col("a_id").alias("b_id"), f.col("lrd").alias("b_lrd")),
+        f.expr("n_id = b_id"))
+      .groupBy("a_id")
+      .agg(
+        f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME),
+        (f.sum("b_lrd") / (f.count("b_lrd") * f.first("a_lrd"))).alias("lof"))
+      .select(f.col(s"$CONTENTS_COLUMN_NAME.*"), f.col("lof"))
+
+}
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index b6c073574..dc9c841eb 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -32,6 +32,7 @@ import org.locationtech.jts.geom._
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
 import java.io.File
+import java.nio.file.Files
 
 trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getRootLogger.setLevel(Level.WARN)
@@ -97,6 +98,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   override def beforeAll(): Unit = {
     super.beforeAll()
     SedonaContext.create(sparkSession)
+    sc.setCheckpointDir(Files.createTempDirectory("checkpoints").toString)
   }
 
   override def afterAll(): Unit = {
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactorTest.scala
 
b/spark/common/src/test/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactorTest.scala
new file mode 100644
index 000000000..c401f599b
--- /dev/null
+++ 
b/spark/common/src/test/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactorTest.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.outlierDetection
+
+import org.apache.sedona.sql.TestBaseScala
+import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint
+import org.apache.spark.sql.{DataFrame, functions => f}
+
+class LocalOutlierFactorTest extends TestBaseScala {
+
+  case class Point(id: Int, x: Double, y: Double, expected_lof: Double)
+
+  /** Eleven grid points with their expected LOF for k = 4, taken from scikit-learn. */
+  def get_data(): DataFrame = {
+    // expected pulled from sklearn.neighbors.LocalOutlierFactor
+    sparkSession
+      .createDataFrame(
+        Seq(
+          Point(0, 2.0, 2.0, 0.8747092756023607),
+          Point(1, 2.0, 3.0, 0.9460678118688717),
+          Point(2, 3.0, 3.0, 1.0797104443580348),
+          Point(3, 3.0, 2.0, 1.0517766952923475),
+          Point(4, 3.0, 1.0, 1.0797104443580348),
+          Point(5, 2.0, 1.0, 0.9460678118688719),
+          Point(6, 1.0, 1.0, 1.0797104443580348),
+          Point(7, 1.0, 2.0, 1.0517766952923475),
+          Point(8, 1.0, 3.0, 1.0797104443580348),
+          Point(9, 0.0, 2.0, 1.0517766952923475),
+          Point(10, 4.0, 2.0, 1.0517766952923475)))
+      .withColumn("geometry", ST_MakePoint("x", "y"))
+      .drop("x", "y")
+  }
+
+  describe("LocalOutlierFactor") {
+    it("returns correct results") {
+      val resultDf = LocalOutlierFactor.localOutlierFactor(get_data(), 4)
+      assert(resultDf.count() == 11)
+      // A null or NaN lof makes the comparison null or false, so such rows count as mismatches.
+      val withinTolerance = f.abs(f.col("expected_lof") - f.col("lof")) < 1e-8
+      val mismatches = resultDf
+        .filter(!f.coalesce(withinTolerance, f.lit(false)))
+        .select("id", "expected_lof", "lof")
+        .collect()
+      assert(
+        mismatches.isEmpty,
+        s"rows whose lof differs from sklearn (id, expected, actual): ${mismatches.mkString(", ")}")
+    }
+  }
+}

Reply via email to