petern48 commented on code in PR #241:
URL: https://github.com/apache/sedona-db/pull/241#discussion_r2467958615


##########
python/sedonadb/tests/functions/test_functions.py:
##########
@@ -176,6 +176,115 @@ def test_st_buffer(eng, geom, dist, expected_area):
     )
 
 
[email protected]("eng", [SedonaDB, PostGIS])
[email protected](
+    ("geom", "dist", "buffer_style_parameters", "expected_area"),
+    [
+        (None, None, None, None),
+        ("POINT(100 90)", 50, "'quad_segs=8'", 7803.612880645131),
+        (
+            "LINESTRING(50 50,150 150,150 50)",
+            10,
+            "'endcap=round join=round'",
+            5016.204476944362,
+        ),
+        (
+            "POLYGON((0 0, 0 10, 10 10, 10 0, 0 0))",
+            2,
+            "'join=miter'",
+            196.0,
+        ),
+        (
+            "LINESTRING(0 0, 10 0)",
+            5,
+            "'endcap=square'",
+            200.0,
+        ),
+        (
+            "POINT(0 0)",
+            10,
+            "'quad_segs=4'",
+            306.1467458920718,
+        ),
+        (
+            "POINT(0 0)",
+            10,
+            "'quad_segs=16'",
+            313.654849054594,
+        ),
+        (
+            "LINESTRING(0 0, 100 0, 100 100)",
+            5,
+            "'join=bevel'",
+            2065.536128806451,
+        ),
+        (
+            "LINESTRING(0 0, 50 0)",
+            10,
+            "'endcap=flat'",
+            1000.0,
+        ),
+        (
+            "POLYGON((0 0, 0 20, 20 20, 20 0, 0 0))",
+            -2,
+            "'join=round'",
+            256.0,
+        ),
+        (
+            "POLYGON((0 0, 0 100, 100 100, 100 0, 0 0), (20 20, 20 80, 80 80, 
80 20, 20 20))",
+            5,
+            "'join=round quad_segs=4'",
+            9576.536686473019,
+        ),
+        (
+            "MULTIPOINT((10 10), (30 30))",
+            5,
+            "'quad_segs=8'",
+            156.0722576129026,
+        ),
+        (
+            "GEOMETRYCOLLECTION(POINT(10 10), LINESTRING(50 50, 60 60))",
+            3,
+            "'endcap=round join=round'",
+            141.0388264830308,
+        ),
+        (
+            "POLYGON((0 0, 0 10, 10 10, 10 0, 0 0))",
+            0,
+            "'join=miter'",
+            100.0,
+        ),
+        (
+            "POINT(0 0)",
+            0.1,
+            "'quad_segs=8'",
+            0.031214451522580514,
+        ),
+        (
+            "LINESTRING(0 0, 50 0, 50 50)",
+            10,
+            "'join=miter miter_limit=2'",
+            2312.1445152258043,
+        ),
+        (
+            "LINESTRING(0 0, 0 100)",
+            10,
+            "'side=left'",
+            1000.0,
+        ),

Review Comment:
   This single test case with `side` is so simple that it doesn't catch the 
edge-case behavior.
   
   Here are 4 cases I want to add. You're welcome to modify the exact test 
cases, but I think we should follow the comments.
   
   ```python
           # non-default side and should implicitly use square endcap
           (
               "LINESTRING (50 50, 150 150, 150 50)",
               100,
               "'side=right'",
               <TODO>
           ),
           # non-default side should implicitly use square endcap
           (
               "POLYGON ((50 50, 50 150, 150 150, 150 50, 50 50))",
               20,
               "'side=left'",
               <TODO>,
           ),
           # non-default side and non-default flat endcap should not use square 
endcap
           # Specifying flat here doesn't actually make a difference. I found 
it hard to find a good test case here
           (
               "POLYGON ((50 50, 50 150, 150 150, 150 50, 50 50))",
               20,
               "'side=right endcap=flat'",
               <TODO>,
           ),
           # explicitly specifying default side=both should not set 
endcap=square
           (
               "LINESTRING (50 50, 150 150, 150 50)",
               100,
               "'side=both'",
               <TODO>,
           ),
   ```
   
   Here are the results I got. The current code fails the first 3 cases and 
passes the 4th. The failing ones are roughly 2x off from the correct 
answer.
   1. Sedona: 35847.55494469865. PostGIS: 16285.07633336958
   2. Sedona: 10000.0 PostGIS: 19248.578060903223
   3. Sedona: 10000.0 PostGIS: 3600.0
   4. Sedona: 69888.08929186598  PostGIS: 69888.089291866
   
   After implementing the changes, if you're not careful, you might fail the 
3rd or 4th cases for different reasons. Just make sure you follow the comments 
I left there. It's not hard, but it's easy to miss. This is the exact bug that 
the PR I mentioned above was fixing, so consider checking out the java 
implementation there.



##########
c/sedona-geos/src/st_buffer.rs:
##########
@@ -54,65 +67,191 @@ impl SedonaScalarKernel for STBuffer {
         arg_types: &[SedonaType],
         args: &[ColumnarValue],
     ) -> Result<ColumnarValue> {
-        // Default params
-        let params_builder = BufferParams::builder();
+        invoke_batch_impl(arg_types, args)
+    }
+}
 
-        let params = params_builder
-            .build()
-            .map_err(|e| DataFusionError::External(Box::new(e)))?;
-
-        // Extract the constant scalar value before looping over the input 
geometries
-        let distance: Option<f64>;
-        let arg1 = args[1].cast_to(&DataType::Float64, None)?;
-        if let ColumnarValue::Scalar(scalar_arg) = &arg1 {
-            if scalar_arg.is_null() {
-                distance = None;
-            } else {
-                distance = Some(f64::try_from(scalar_arg.clone())?);
-            }
-        } else {
-            return Err(DataFusionError::Execution(format!(
-                "Invalid distance: {:?}",
-                args[1]
-            )));
-        }
+pub fn st_buffer_style_impl() -> ScalarKernelRef {
+    Arc::new(STBufferStyle {})
+}
+#[derive(Debug)]
+struct STBufferStyle {}
 
-        let executor = GeosExecutor::new(arg_types, args);
-        let mut builder = BinaryBuilder::with_capacity(
-            executor.num_iterations(),
-            WKB_MIN_PROBABLE_BYTES * executor.num_iterations(),
-        );
-        executor.execute_wkb_void(|wkb| {
-            match (wkb, distance) {
-                (Some(wkb), Some(distance)) => {
-                    invoke_scalar(&wkb, distance, &params, &mut builder)?;
-                    builder.append_value([]);
-                }
-                _ => builder.append_null(),
-            }
+impl SedonaScalarKernel for STBufferStyle {
+    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        let matcher = ArgMatcher::new(
+            vec![
+                ArgMatcher::is_geometry(),
+                ArgMatcher::is_numeric(),
+                ArgMatcher::is_string(),
+            ],
+            WKB_GEOMETRY,
+        );
 
-            Ok(())
-        })?;
+        matcher.match_args(args)
+    }
 
-        executor.finish(Arc::new(builder.finish()))
+    fn invoke_batch(
+        &self,
+        arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        invoke_batch_impl(arg_types, args)
     }
 }
 
+fn invoke_batch_impl(arg_types: &[SedonaType], args: &[ColumnarValue]) -> 
Result<ColumnarValue> {
+    let executor = GeosExecutor::new(arg_types, args);
+    let mut builder = BinaryBuilder::with_capacity(
+        executor.num_iterations(),
+        WKB_MIN_PROBABLE_BYTES * executor.num_iterations(),
+    );
+
+    // Extract Args
+    let distance_value = args[1]
+        .cast_to(&DataType::Float64, None)?
+        .to_array(executor.num_iterations())?;
+    let distance_array = as_float64_array(&distance_value)?;
+    let mut distance_iter = distance_array.iter();
+
+    let buffer_style_params = extract_optional_string(args.get(2))?;
+
+    // Build BufferParams based on style parameters
+    let params = parse_buffer_params(buffer_style_params.as_deref())?;
+
+    executor.execute_wkb_void(|wkb| {
+        match (wkb, distance_iter.next().unwrap()) {
+            (Some(wkb), Some(distance)) => {
+                builder.append_value(invoke_scalar(&wkb, distance, &params)?);
+            }
+            _ => builder.append_null(),
+        }
+        Ok(())
+    })?;
+
+    executor.finish(Arc::new(builder.finish()))
+}
+
 fn invoke_scalar(
     geos_geom: &geos::Geometry,
     distance: f64,
     params: &BufferParams,
-    writer: &mut impl std::io::Write,
-) -> Result<()> {
+) -> Result<Vec<u8>> {
     let geometry = geos_geom
         .buffer_with_params(distance, params)
         .map_err(|e| DataFusionError::External(Box::new(e)))?;
     let wkb = geometry
         .to_wkb()
         .map_err(|e| DataFusionError::Execution(format!("Failed to convert to 
wkb: {e}")))?;
 
-    writer.write_all(wkb.as_ref())?;
-    Ok(())
+    Ok(wkb.into())
+}
+
+fn extract_optional_string(arg: Option<&ColumnarValue>) -> 
Result<Option<String>> {
+    let Some(arg) = arg else { return Ok(None) };
+    let casted = arg.cast_to(&DataType::Utf8, None)?;
+    match &casted {
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)) | 
ScalarValue::LargeUtf8(Some(s))) => {
+            Ok(Some(s.clone()))
+        }
+        ColumnarValue::Scalar(scalar) if scalar.is_null() => Ok(None),
+        ColumnarValue::Scalar(_) => Ok(None),
+        _ => Err(DataFusionError::Execution(format!(
+            "Expected scalar bufferStyleParameters, got: {arg:?}",
+        ))),
+    }
+}
+
+fn parse_buffer_params(params_str: Option<&str>) -> Result<BufferParams> {
+    let Some(params_str) = params_str else {
+        return BufferParams::builder()
+            .build()
+            .map_err(|e| DataFusionError::External(Box::new(e)));
+    };
+
+    let mut params_builder = BufferParams::builder();
+
+    for param in params_str.split_whitespace() {
+        let Some((key, value)) = param.split_once('=') else {
+            return Err(DataFusionError::Execution(format!(
+                "Missing value for buffer parameter: {param}",
+            )));
+        };
+
+        if key.eq_ignore_ascii_case("endcap") {
+            params_builder = 
params_builder.end_cap_style(parse_cap_style(value)?);
+        } else if key.eq_ignore_ascii_case("join") {
+            params_builder = 
params_builder.join_style(parse_join_style(value)?);
+        } else if key.eq_ignore_ascii_case("side") {
+            params_builder = 
params_builder.single_sided(is_single_sided(value)?);

Review Comment:
   We're still missing this logic:
   
https://github.com/apache/sedona/blame/d6ea87cf9f2a27adee00ff0fb41ee92002b9a834/common/src/main/java/org/apache/sedona/common/Functions.java#L399-L402
   
   Here's a better docstring to follow because the PostGIS docs haven't been 
updated yet.
   
   - `side=both|left|right` : Defaults to `both`. Setting `left` or `right` 
enables a single-sided buffer operation on the geometry, with the buffered side 
aligned according to the direction of the line. This functionality is specific 
to LINESTRING geometry and has no impact on POINT or POLYGON geometries. By 
default, square end caps are applied when `left` or `right` are specified.
   



##########
c/sedona-geos/src/st_buffer.rs:
##########
@@ -163,4 +302,203 @@ mod tests {
         let envelope_result = 
envelope_tester.invoke_array(buffer_result).unwrap();
         assert_array_equal(&envelope_result, &expected_envelope);
     }
+
+    #[rstest]
+    fn udf_with_buffer_params(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] 
sedona_type: SedonaType) {
+        let udf = SedonaScalarUDF::from_kernel("st_buffer", 
st_buffer_style_impl());
+        let tester = ScalarUdfTester::new(
+            udf.into(),
+            vec![
+                sedona_type.clone(),
+                SedonaType::Arrow(DataType::Float64),
+                SedonaType::Arrow(DataType::Utf8),
+            ],
+        );
+        tester.assert_return_type(WKB_GEOMETRY);
+
+        let envelope_udf = sedona_functions::st_envelope::st_envelope_udf();
+        let envelope_tester = ScalarUdfTester::new(envelope_udf.into(), 
vec![WKB_GEOMETRY]);
+
+        let buffer_result_flat = tester
+            .invoke_scalar_scalar_scalar("LINESTRING (0 0, 10 0)", 2.0, 
"endcap=flat".to_string())
+            .unwrap();
+        let envelope_result = 
envelope_tester.invoke_scalar(buffer_result_flat).unwrap();
+        let expected_envelope = "POLYGON((0 -2, 0 2, 10 2, 10 -2, 0 -2))";
+        tester.assert_scalar_result_equals(envelope_result, expected_envelope);
+
+        let buffer_result_square = tester
+            .invoke_scalar_scalar_scalar("LINESTRING (0 0, 10 0)", 1.0, 
"endcap=square".to_string())
+            .unwrap();
+        let envelope_result = 
envelope_tester.invoke_scalar(buffer_result_square).unwrap();
+        let expected_envelope = "POLYGON((-1 -1, -1 1, 11 1, 11 -1, -1 -1))";
+        tester.assert_scalar_result_equals(envelope_result, expected_envelope);
+    }
+
+    #[rstest]
+    fn udf_with_quad_segs(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] 
sedona_type: SedonaType) {
+        let udf = SedonaScalarUDF::from_kernel("st_buffer", 
st_buffer_style_impl());
+        let tester = ScalarUdfTester::new(
+            udf.into(),
+            vec![
+                sedona_type.clone(),
+                SedonaType::Arrow(DataType::Float64),
+                SedonaType::Arrow(DataType::Utf8),
+            ],
+        );
+        tester.assert_return_type(WKB_GEOMETRY);
+
+        let envelope_udf = sedona_functions::st_envelope::st_envelope_udf();
+        let envelope_tester = ScalarUdfTester::new(envelope_udf.into(), 
vec![WKB_GEOMETRY]);
+        let input_wkt = "POINT (5 5)";
+        let buffer_dist = 3.0;
+
+        let buffer_result_default = tester
+            .invoke_scalar_scalar_scalar(input_wkt, buffer_dist, 
"endcap=round".to_string())
+            .unwrap();
+        let envelope_result_default = envelope_tester
+            .invoke_scalar(buffer_result_default)
+            .unwrap();
+
+        let expected_envelope = "POLYGON((2 2, 2 8, 8 8, 8 2, 2 2))";
+        tester.assert_scalar_result_equals(envelope_result_default, 
expected_envelope);
+
+        let buffer_result_low_segs = tester
+            .invoke_scalar_scalar_scalar(
+                input_wkt,
+                buffer_dist,
+                "quad_segs=1 endcap=round".to_string(),
+            )
+            .unwrap();
+        let envelope_result_low_segs = envelope_tester
+            .invoke_scalar(buffer_result_low_segs)
+            .unwrap();
+        tester.assert_scalar_result_equals(envelope_result_low_segs, 
expected_envelope);
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_endcap() {
+        let err = parse_buffer_params(Some("endcap=invalid")).err().unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid endcap style: 'invalid'. Valid options: round, flat, 
butt, square"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_join() {
+        let err = parse_buffer_params(Some("join=invalid")).err().unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid join style: 'invalid'. Valid options: round, mitre, 
miter, bevel"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_side() {
+        let err = parse_buffer_params(Some("side=invalid")).err().unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid side: 'invalid'. Valid options: both, left, right"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_mitre_limit() {
+        let err = parse_buffer_params(Some("mitre_limit=not_a_number"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid mitre_limit value: 'not_a_number'. Expected a valid 
number"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_miter_limit() {
+        let err = parse_buffer_params(Some("miter_limit=abc")).err().unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid mitre_limit value: 'abc'. Expected a valid number"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_quad_segs() {
+        let err = parse_buffer_params(Some("quad_segs=not_an_int"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid quadrant_segments value: 'not_an_int'. Expected a valid 
number"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_quadrant_segments() {
+        let err = parse_buffer_params(Some("quadrant_segments=xyz"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid quadrant_segments value: 'xyz'. Expected a valid number"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_multiple_invalid_params() {
+        // Test that the first invalid parameter is caught
+        let err = parse_buffer_params(Some("endcap=wrong join=mitre"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid endcap style: 'wrong'. Valid options: round, flat, butt, 
square"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_mixed_with_valid() {
+        // Test invalid parameter after valid ones
+        let err = parse_buffer_params(Some("endcap=round join=invalid"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid join style: 'invalid'. Valid options: round, mitre, 
miter, bevel"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_invalid_param_name() {
+        let err = parse_buffer_params(Some("unknown_param=value"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Invalid buffer parameter: unknown_param (accept: 'endcap', 
'join', 'mitre_limit', 'miter_limit', 'quad_segs', 'quadrant_segments' and 
'side')"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_missing_value() {
+        let err = parse_buffer_params(Some("endcap=round bare_param 
join=mitre"))
+            .err()
+            .unwrap();
+        assert_eq!(
+            err.message(),
+            "Missing value for buffer parameter: bare_param"
+        );
+    }
+
+    #[test]
+    fn test_parse_buffer_params_duplicate_params_no_error() {
+        let result = parse_buffer_params(Some("endcap=round endcap=flat"));
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_parse_buffer_params_quad_segs_out_of_range() {
+        let result = parse_buffer_params(Some("quad_segs=-5"));
+        assert!(result.is_ok());
+    }

Review Comment:
   It would be nice to write tests that directly check the values set for the 
params object (rather than testing the resulting geometries). Particularly, it 
would test the cases I mentioned in the Python tests. Here they are again for 
convenience.
   
   ```
   1. non-default side and should implicitly use square endcap
   2. non-default side should implicitly use square endcap
   3. non-default side and non-default flat endcap should not use square endcap
   4. explicitly specifying default side=both should not set endcap=square
   ```
   
   I know this PR already feels long, so if you're not up for it, feel free to 
create a new issue for you or someone to do later. As long as the correct 
behavior is implemented and we have some python tests for it, I'm good to merge 
it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to