HappenLee commented on code in PR #57868:
URL: https://github.com/apache/doris/pull/57868#discussion_r2554200355


##########
be/src/udf/python/python_udf_server.py:
##########
@@ -860,14 +894,487 @@ def has_python_file_recursive(location: str) -> bool:
         return any(path.rglob("*.py"))
 
 
-class UDFFlightServer(flight.FlightServerBase):
-    """Arrow Flight server for executing Python UDFs."""
+class UDAFClassLoader:
+    """
+    Utility class for loading UDAF classes from various sources.
+    
+    This class is responsible for loading UDAF classes from:
+    - Inline code (embedded in SQL)
+    - Module files (imported from filesystem)
+    """
+
+    @staticmethod
+    def load_udaf_class(python_udf_meta: PythonUDFMeta) -> type:
+        """
+        Load the UDAF class from metadata.
+
+        Args:
+            python_udf_meta: UDAF metadata
+
+        Returns:
+            The UDAF class
+
+        Raises:
+            RuntimeError: If inline code execution fails
+            ValueError: If class is not found or invalid
+        """
+        loader = UDFLoaderFactory.get_loader(python_udf_meta)
+
+        # For UDAF, we need the class, not an instance
+        if isinstance(loader, InlineUDFLoader):
+            return UDAFClassLoader.load_from_inline(python_udf_meta)
+        elif isinstance(loader, ModuleUDFLoader):
+            return UDAFClassLoader.load_from_module(python_udf_meta, loader)
+        else:
+            raise ValueError(f"Unsupported loader type: {type(loader)}")
+
+    @staticmethod
+    def load_from_inline(python_udf_meta: PythonUDFMeta) -> type:
+        """
+        Load UDAF class from inline code.
+
+        Args:
+            python_udf_meta: UDAF metadata with inline code
+
+        Returns:
+            The UDAF class
+        """
+        symbol = python_udf_meta.symbol
+        inline_code = python_udf_meta.inline_code.decode("utf-8")
+        env: dict[str, Any] = {}
+
+        try:
+            exec(inline_code, env)  # nosec B102
+        except Exception as e:
+            raise RuntimeError(f"Failed to exec inline code: {e}") from e
+
+        udaf_class = env.get(symbol)
+        if udaf_class is None:
+            raise ValueError(f"UDAF class '{symbol}' not found in inline code")
+
+        if not inspect.isclass(udaf_class):
+            raise ValueError(f"'{symbol}' is not a class (type: {type(udaf_class)})")
+
+        UDAFClassLoader.validate_udaf_class(udaf_class)
+        return udaf_class
+
+    @staticmethod
+    def load_from_module(python_udf_meta: PythonUDFMeta, loader: ModuleUDFLoader) -> type:
+        """
+        Load UDAF class from module file.
+
+        Args:
+            python_udf_meta: UDAF metadata with module location
+            loader: Module loader instance
+
+        Returns:
+            The UDAF class
+        """
+        symbol = python_udf_meta.symbol
+        location = python_udf_meta.location
+
+        package_name, module_name, class_name = loader.parse_symbol(symbol)
+        udaf_class = loader.load_udf_from_module(
+            location, package_name, module_name, class_name
+        )
+
+        if not inspect.isclass(udaf_class):
+            raise ValueError(f"'{symbol}' is not a class (type: {type(udaf_class)})")
+
+        UDAFClassLoader.validate_udaf_class(udaf_class)
+        return udaf_class
+
+    @staticmethod
+    def validate_udaf_class(udaf_class: type):
+        """
+        Validate that the UDAF class follows the required Snowflake pattern.
+
+        Args:
+            udaf_class: The class to validate
+
+        Raises:
+            ValueError: If class doesn't implement required methods or properties
+        """
+        required_methods = ["__init__", "accumulate", "merge", "finish"]
+        for method in required_methods:
+            if not hasattr(udaf_class, method):
+                raise ValueError(
+                    f"UDAF class must implement '{method}' method. "
+                    f"Missing in {udaf_class.__name__}"
+                )
+
+        # Check for aggregate_state property
+        if not hasattr(udaf_class, "aggregate_state"):
+            raise ValueError(
+                f"UDAF class must have 'aggregate_state' property. "
+                f"Missing in {udaf_class.__name__}"
+            )
+
+        # Verify it's actually a property
+        try:
+            attr = inspect.getattr_static(udaf_class, "aggregate_state")
+            if not isinstance(attr, property):
+                raise ValueError(
+                    f"'aggregate_state' must be a @property in {udaf_class.__name__}"
+                )
+        except AttributeError:
+            raise ValueError(
+                f"UDAF class must have 'aggregate_state' property. "
+                f"Missing in {udaf_class.__name__}"
+            )
+
+
+class UDAFStateManager:
+    """
+    Manages UDAF aggregate states for Python UDAF execution.
+
+    This class maintains a mapping from place_id to UDAF instances,
+    following the Snowflake UDAF pattern:
+    - __init__(): Initialize state
+    - aggregate_state: Property returning serializable state
+    - accumulate(*args): Add input values
+    - merge(other_state): Merge two states
+    - finish(): Return final result
+    """
+
+    def __init__(self):
+        """Initialize the state manager."""
+        self.states: Dict[int, Any] = {}  # place_id -> UDAF instance
+        self.udaf_class = None  # UDAF class to instantiate
+        self.lock = threading.Lock()  # Thread-safe state access
+
+    def set_udaf_class(self, udaf_class: type):
+        """
+        Set the UDAF class to use for creating instances.
+
+        Args:
+            udaf_class: The UDAF class (must follow Snowflake pattern)
+        
+        Note:
+            Validation is performed by UDAFClassLoader before calling this method.
+        """
+        self.udaf_class = udaf_class
+        logging.info("UDAF class set to: %s", udaf_class.__name__)
+
+    def create_state(self, place_id: int) -> None:
+        """
+        Create a new UDAF state for the given place_id.
+
+        Args:
+            place_id: Unique identifier for this aggregate state
+        """
+        with self.lock:
+            if place_id in self.states:
+                # Destroy old state before creating new one
+                logging.warning(

Review Comment:
   Should this case ever happen (a `place_id` that already has a state)? Maybe this should raise an error instead of logging a warning?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to