sandugood commented on code in PR #1513:
URL: https://github.com/apache/datafusion-ballista/pull/1513#discussion_r2946186536


##########
python/python/ballista/extension.py:
##########
@@ -229,9 +231,319 @@ def write_parquet(
         df = self._to_internal_df()
         df.write_parquet(str(path), compression.value, compression_level)
 
+    def explain_visual(self, analyze: bool = False) -> 
"ExecutionPlanVisualization":
+        """
+        Generate a visual representation of the execution plan.
+
+        This method creates an SVG visualization of the query execution plan,
+        which can be displayed directly in Jupyter notebooks.
+
+        Args:
+            analyze: If True, includes runtime statistics from actual 
execution.
+
+        Returns:
+            ExecutionPlanVisualization: An object that renders as SVG in 
Jupyter.
+
+        Example:
+            >>> df = ctx.sql("SELECT * FROM orders WHERE amount > 100")
+            >>> df.explain_visual()  # Displays SVG in notebook
+            >>> viz = df.explain_visual(analyze=True)
+            >>> viz.save("plan.svg")  # Save to file
+        """
+        # Get the execution plan as a string representation
+        # Note: explain() prints but doesn't return a string, so we use 
logical_plan()
+        try:
+            plan = self.logical_plan()
+            plan_str = plan.display_indent()
+        except Exception:
+            # Fallback if logical_plan() fails
+            plan_str = "Unable to retrieve execution plan"
+        return ExecutionPlanVisualization(plan_str, analyze=analyze)
+
+    def collect_with_progress(
+        self,
+        callback: Optional[callable] = None,
+        poll_interval: float = 0.5,
+    ):
+        """
+        Collect results with progress indication.
+
+        For long-running queries, this method provides progress updates
+        through a callback function or displays a progress bar in Jupyter.
+
+        Args:
+            callback: Optional function to call with progress updates.
+                     Signature: callback(status: str, progress: float)
+            poll_interval: How often to check progress (seconds).
+
+        Returns:
+            The collected result batches.
+
+        Example:
+            >>> def my_callback(status, progress):
+            ...     print(f"{status}: {progress:.1%}")
+            >>> batches = df.collect_with_progress(callback=my_callback)
+        """
+        import threading
+        import time
+
+        result = [None]
+        error = [None]
+        done = threading.Event()
+
+        def execute():
+            try:
+                result[0] = self.collect()
+            except Exception as e:
+                error[0] = e
+            finally:
+                done.set()
+
+        thread = threading.Thread(target=execute)
+        thread.start()
+
+        # Check if we're in a Jupyter environment
+        try:
+            from IPython.display import display, clear_output
+            from IPython import get_ipython
+
+            in_jupyter = get_ipython() is not None
+        except (ImportError, AttributeError):
+            in_jupyter = False
+
+        start_time = time.time()
+
+        if in_jupyter and callback is None:
+            # Display a simple progress indicator
+            try:
+                while not done.wait(timeout=poll_interval):
+                    elapsed = time.time() - start_time
+                    clear_output(wait=True)
+                    print(f"⏳ Query executing... ({elapsed:.1f}s elapsed)")
+
+                clear_output(wait=True)
+                elapsed = time.time() - start_time
+                print(f"✓ Query completed in {elapsed:.1f}s")
+            except Exception:
+                pass  # Ignore display errors
+        elif callback is not None:
+            while not done.wait(timeout=poll_interval):
+                elapsed = time.time() - start_time
+                callback(f"Executing ({elapsed:.1f}s)", -1.0)  # -1 means 
indeterminate
+
+            elapsed = time.time() - start_time
+            callback(f"Completed in {elapsed:.1f}s", 1.0)
+        else:
+            done.wait()
+
+        thread.join()
+
+        if error[0] is not None:
+            raise error[0]
+
+        return result[0]
+
+
class ExecutionPlanVisualization:
    """
    A wrapper for execution plan visualizations that can render as SVG in Jupyter.

    The text representation of a plan (one operator per line, children indented
    deeper than their parent) is converted to Graphviz DOT. ``to_svg`` renders
    that DOT with the ``dot`` executable when it is available. Otherwise it
    falls back to an HTML block containing the plan text.
    """

    # Labels longer than this are truncated and end with "...".
    _MAX_LABEL_LEN = 60

    # (substring, fill colour) pairs checked in order; the first match wins.
    _NODE_COLORS = (
        ("Scan", "#E8F5E9"),  # light green for scans (incl. TableScan)
        ("Filter", "#FFF3E0"),  # light orange for filters
        ("Aggregate", "#F3E5F5"),  # light purple for aggregations
        ("Join", "#FFEBEE"),  # light red for joins
        ("Sort", "#E0F7FA"),  # light cyan for sorts
        ("Projection", "#FFF8E1"),  # light amber for projections
    )
    _DEFAULT_COLOR = "#E3F2FD"  # light blue

    def __init__(self, plan_str: str, analyze: bool = False):
        self.plan_str = plan_str
        self.analyze = analyze
        # Cached output of to_svg(): real SVG, or the HTML fallback.
        self._svg_cache: Optional[str] = None

    @classmethod
    def _node_color(cls, content: str) -> str:
        """Return the fill colour for a plan line based on its operator name."""
        for needle, color in cls._NODE_COLORS:
            if needle in content:
                return color
        return cls._DEFAULT_COLOR

    @classmethod
    def _make_label(cls, content: str) -> str:
        """Truncate a plan line, then escape it for use inside a DOT string."""
        # Truncate before escaping so an escape sequence is never cut in half.
        if len(content) > cls._MAX_LABEL_LEN:
            content = content[: cls._MAX_LABEL_LEN - 3] + "..."
        # Escape backslashes first: an unescaped trailing "\" would otherwise
        # escape the closing quote and make the DOT source invalid.
        return content.replace("\\", "\\\\").replace('"', '\\"')

    def _parse_plan_to_dot(self) -> str:
        """
        Convert the plan string to DOT format for Graphviz.

        Each non-empty line becomes a node. A line's parent is the nearest
        preceding line with strictly smaller indentation. Header lines starting
        with ``physical_plan`` or ``logical_plan`` are skipped.
        """
        header = [
            "digraph ExecutionPlan {",
            "    rankdir=TB;",
            '    node [shape=box, style="rounded,filled", fontname="Helvetica"];',
            '    edge [fontname="Helvetica"];',
            "",
        ]

        nodes = []
        edges = []
        node_id = 0
        stack = []  # (indent_level, node_id) along the current ancestor chain

        for line in self.plan_str.strip().split("\n"):
            content = line.strip()
            if not content or content.startswith(("physical_plan", "logical_plan")):
                continue

            indent = len(line) - len(line.lstrip())
            current_id = node_id
            node_id += 1

            nodes.append(
                f'    node{current_id} [label="{self._make_label(content)}", '
                f'fillcolor="{self._node_color(content)}"];'
            )

            # Drop siblings and deeper nodes; whatever remains on top is the parent.
            while stack and stack[-1][0] >= indent:
                stack.pop()
            if stack:
                edges.append(f"    node{stack[-1][1]} -> node{current_id};")

            stack.append((indent, current_id))

        return "\n".join([*header, *nodes, "", *edges, "}"])

    def to_dot(self) -> str:
        """Get the DOT representation of the execution plan."""
        return self._parse_plan_to_dot()

    def to_svg(self) -> str:
        """
        Convert the execution plan to SVG format.

        Runs Graphviz's ``dot -Tsvg`` on the DOT source. If ``dot`` is missing,
        cannot be executed, times out, or exits with an error, an HTML
        representation of the plan text is returned instead. The result is
        cached on the instance.
        """
        if self._svg_cache is not None:
            return self._svg_cache

        import html
        import subprocess

        try:
            process = subprocess.run(
                ["dot", "-Tsvg"],
                input=self._parse_plan_to_dot().encode(),
                capture_output=True,
                timeout=30,
            )
        except (OSError, subprocess.SubprocessError):
            # OSError: dot not installed / not executable;
            # SubprocessError: includes TimeoutExpired.
            process = None

        if process is not None and process.returncode == 0:
            self._svg_cache = process.stdout.decode()
            return self._svg_cache

        # Fallback: a pre-formatted HTML representation with the plan escaped
        # so that "&", "<" and ">" in the plan text render literally.
        escaped_plan = html.escape(self.plan_str, quote=False)
        self._svg_cache = f"""
        <div style="font-family: monospace; background: #f5f5f5; padding: 10px;
                    border-radius: 5px; overflow-x: auto;">
            <div style="color: #666; margin-bottom: 5px;">
                Execution Plan {'(with statistics)' if self.analyze else ''}
                <br><small>Install graphviz for visual diagram: brew install graphviz</small>
            </div>
            <pre style="margin: 0;">{escaped_plan}</pre>
        </div>
        """
        return self._svg_cache

    def save(self, path: str) -> None:
        """
        Save the visualization to ``path`` as UTF-8 text.

        Paths ending in ``.dot`` receive the DOT source. Any other path receives
        the output of ``to_svg()``, which is HTML rather than SVG when Graphviz
        is unavailable.
        """
        content = self.to_dot() if str(path).endswith(".dot") else self.to_svg()
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    def _repr_html_(self) -> str:
        """HTML representation for Jupyter notebooks."""
        return self.to_svg()

    def _repr_svg_(self) -> Optional[str]:
        """SVG representation for Jupyter, or ``None`` if only HTML is available."""
        svg = self.to_svg()
        if svg.lstrip().startswith(("<svg", "<?xml")):
            return svg
        # None tells IPython there is no SVG representation, so an empty
        # image/svg+xml payload is not published next to the HTML fallback.
        return None

    def __repr__(self) -> str:
        """String representation."""
        return f"ExecutionPlanVisualization(analyze={self.analyze})\n{self.plan_str}"
+
 
 class BallistaSessionContext(SessionContext, 
metaclass=RedefiningSessionContextMeta):
+    """
+    A session context for connecting to and querying a Ballista cluster.
+
+    This class extends DataFusion's SessionContext to work with distributed
+    Ballista clusters, automatically routing query execution to the cluster
+    while maintaining API compatibility with local DataFusion usage.
+
+    Example:
+        >>> from ballista import BallistaSessionContext
+        >>> ctx = BallistaSessionContext("df://localhost:50050")
+        >>> df = ctx.sql("SELECT * FROM my_table LIMIT 10")
+        >>> df.show()
+
+    For Jupyter notebook users:
+        >>> %load_ext ballista.jupyter
+        >>> %ballista connect df://localhost:50050
+        >>> %sql SELECT * FROM my_table
+    """
+
     def __init__(self, address: str, config=None, runtime=None):
         super().__init__(config, runtime)
         self.address = address
         self.session_id = self.session_id()
+
+    def get_tables(self) -> Optional[dict[str, str]]:
+        """Get tables and their respective schemas (in terms of database 
schema)."""
+        try:
+            catalog = self.catalog()
+            schema_names = list(catalog.schema_names())
+            if schema_names is not None:

Review Comment:
   Fixed. Thank you.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to