pwrliang commented on code in PR #465:
URL: https://github.com/apache/sedona-db/pull/465#discussion_r2705662532


##########
c/sedona-libgpuspatial/src/lib.rs:
##########
@@ -77,197 +140,411 @@ impl GpuSpatialContext {
         #[cfg(gpu_available)]
         {
             Ok(Self {
-                joiner: None,
-                context: None,
+                rt_engine: None,
+                index: None,
+                refiner: None,
                 initialized: false,
             })
         }
     }
 
-    pub fn init(&mut self) -> Result<()> {
+    pub fn init(&mut self, concurrency: u32, device_id: i32) -> Result<()> {
         #[cfg(not(gpu_available))]
         {
+            let _ = (concurrency, device_id);
             Err(GpuSpatialError::GpuNotAvailable)
         }
 
         #[cfg(gpu_available)]
         {
-            let mut joiner = GpuSpatialJoinerWrapper::new();
-
             // Get PTX path from OUT_DIR
             let out_path = std::path::PathBuf::from(env!("OUT_DIR"));
             let ptx_root = out_path.join("share/gpuspatial/shaders");
             let ptx_root_str = ptx_root
                 .to_str()
                 .ok_or_else(|| GpuSpatialError::Init("Invalid PTX 
path".to_string()))?;
 
-            // Initialize with concurrency of 1 for now
-            joiner.init(1, ptx_root_str)?;
+            let rt_engine = GpuSpatialRTEngineWrapper::try_new(device_id, 
ptx_root_str)?;
 
-            // Create context
-            let mut ctx = GpuSpatialJoinerContext {
-                last_error: std::ptr::null(),
-                private_data: std::ptr::null_mut(),
-                build_indices: std::ptr::null_mut(),
-                stream_indices: std::ptr::null_mut(),
-            };
-            joiner.create_context(&mut ctx);
+            self.rt_engine = Some(Arc::new(Mutex::new(rt_engine)));
+
+            let index = GpuSpatialIndexFloat2DWrapper::try_new(
+                self.rt_engine.as_ref().unwrap(),
+                concurrency,
+            )?;
+
+            self.index = Some(index);
+
+            let refiner =
+                
GpuSpatialRefinerWrapper::try_new(self.rt_engine.as_ref().unwrap(), 
concurrency)?;
+            self.refiner = Some(refiner);
 
-            self.joiner = Some(joiner);
-            self.context = Some(ctx);
             self.initialized = true;
             Ok(())
         }
     }
 
-    #[cfg(gpu_available)]
-    pub fn get_joiner_mut(&mut self) -> Option<&mut GpuSpatialJoinerWrapper> {
-        self.joiner.as_mut()
-    }
+    pub fn is_gpu_available() -> bool {
+        #[cfg(not(gpu_available))]
+        {
+            false
+        }
+        #[cfg(gpu_available)]
+        {
+            let nvml = match Nvml::init() {
+                Ok(instance) => instance,
+                Err(_) => return false,
+            };
 
-    #[cfg(gpu_available)]
-    pub fn get_context_mut(&mut self) -> Option<&mut GpuSpatialJoinerContext> {
-        self.context.as_mut()
+            // Check if the device count is greater than zero
+            match nvml.device_count() {
+                Ok(count) => count > 0,
+                Err(_) => false,
+            }
+        }
     }
 
     pub fn is_initialized(&self) -> bool {
         self.initialized
     }
 
-    /// Perform spatial join between two geometry arrays
-    pub fn spatial_join(
-        &mut self,
-        left_geom: arrow_array::ArrayRef,
-        right_geom: arrow_array::ArrayRef,
-        predicate: SpatialPredicate,
-    ) -> Result<(Vec<u32>, Vec<u32>)> {
+    /// Clear previous build data
+    pub fn clear(&mut self) -> Result<()> {
         #[cfg(not(gpu_available))]
         {
-            let _ = (left_geom, right_geom, predicate);
             Err(GpuSpatialError::GpuNotAvailable)
         }
-
         #[cfg(gpu_available)]
         {
             if !self.initialized {
-                return Err(GpuSpatialError::Init("Context not 
initialized".into()));
+                return Err(GpuSpatialError::Init("GpuSpatial not 
initialized".into()));
             }
 
-            let joiner = self
-                .joiner
+            let index = self
+                .index
                 .as_mut()
-                .ok_or_else(|| GpuSpatialError::Init("GPU joiner not 
available".into()))?;
+                .ok_or_else(|| GpuSpatialError::Init("GPU index is not 
available".into()))?;
 
             // Clear previous build data
-            joiner.clear();
-
-            // Push build data (left side)
-            log::info!(
-                "DEBUG: Pushing {} geometries to GPU (build side)",
-                left_geom.len()
-            );
-            log::info!("DEBUG: Left array data type: {:?}", 
left_geom.data_type());
-            if let Some(binary_arr) = left_geom
-                .as_any()
-                .downcast_ref::<arrow_array::BinaryArray>()
-            {
-                log::info!("DEBUG: Left binary array has {} values", 
binary_arr.len());
-                if binary_arr.len() > 0 {
-                    let first_wkb = binary_arr.value(0);
-                    log::info!(
-                        "DEBUG: First left WKB length: {}, first bytes: {:?}",
-                        first_wkb.len(),
-                        &first_wkb[..8.min(first_wkb.len())]
-                    );
-                }
-            }
+            index.clear();
+            Ok(())
+        }
+    }
+
+    pub fn push_build(&mut self, rects: &[Rect<f32>]) -> Result<()> {
+        #[cfg(not(gpu_available))]
+        {
+            let _ = rects;
+            Err(GpuSpatialError::GpuNotAvailable)
+        }
+        #[cfg(gpu_available)]
+        {
+            let index = self
+                .index
+                .as_mut()
+                .ok_or_else(|| GpuSpatialError::Init("GPU index not 
available".into()))?;
 
-            joiner.push_build(&left_geom, 0, left_geom.len() as i64)?;
-            joiner.finish_building()?;
+            unsafe { index.push_build(rects.as_ptr() as *const f32, 
rects.len() as u32) }
+        }
+    }
+
+    pub fn finish_building(&mut self) -> Result<()> {
+        #[cfg(not(gpu_available))]
+        return Err(GpuSpatialError::GpuNotAvailable);
 
-            // Recreate context after building (required by libgpuspatial)
-            let mut new_context = 
libgpuspatial_glue_bindgen::GpuSpatialJoinerContext {
+        #[cfg(gpu_available)]
+        self.index
+            .as_mut()
+            .ok_or_else(|| GpuSpatialError::Init("GPU index not 
available".into()))?
+            .finish_building()
+    }
+
+    pub fn probe(&self, rects: &[Rect<f32>]) -> Result<(Vec<u32>, Vec<u32>)> {
+        #[cfg(not(gpu_available))]
+        {
+            let _ = rects;
+            Err(GpuSpatialError::GpuNotAvailable)
+        }
+
+        #[cfg(gpu_available)]
+        {
+            let index = self
+                .index
+                .as_ref()
+                .ok_or_else(|| GpuSpatialError::Init("GPU index not 
available".into()))?;
+            // Create context
+            let mut ctx = GpuSpatialIndexContext {
                 last_error: std::ptr::null(),
-                private_data: std::ptr::null_mut(),
                 build_indices: std::ptr::null_mut(),
-                stream_indices: std::ptr::null_mut(),
+                probe_indices: std::ptr::null_mut(),
             };
-            joiner.create_context(&mut new_context);
-            self.context = Some(new_context);
-            let context = self.context.as_mut().unwrap();
-            // Push stream data (right side) and perform join
-            let gpu_predicate = predicate.into();
-            joiner.push_stream(
-                context,
-                &right_geom,
-                0,
-                right_geom.len() as i64,
-                gpu_predicate,
-                0, // array_index_offset
-            )?;
+            index.create_context(&mut ctx);
 
-            // Get results
-            let build_indices = 
joiner.get_build_indices_buffer(context).to_vec();
-            let stream_indices = 
joiner.get_stream_indices_buffer(context).to_vec();
+            // Push stream data (probe side) and perform join
+            unsafe {
+                index.probe(&mut ctx, rects.as_ptr() as *const f32, 
rects.len() as u32)?;
+            }
 
-            Ok((build_indices, stream_indices))
+            // Get results
+            let build_indices = index.get_build_indices_buffer(&mut 
ctx).to_vec();
+            let probe_indices = index.get_probe_indices_buffer(&mut 
ctx).to_vec();
+            index.destroy_context(&mut ctx);
+            Ok((build_indices, probe_indices))

Review Comment:
   Fixed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to