rangadi commented on code in PR #40586:
URL: https://github.com/apache/spark/pull/40586#discussion_r1157385141


##########
connector/connect/common/src/main/protobuf/spark/connect/commands.proto:
##########
@@ -177,3 +179,118 @@ message WriteOperationV2 {
   // (Optional) A condition for overwrite saving mode
   Expression overwrite_condition = 8;
 }
+
+// Starts write stream operation as streaming query. Query ID and Run ID of 
the streaming
+// query are returned.
+message WriteStreamOperationStart {
+
+  // (Required) The output of the `input` streaming relation will be written.
+  Relation input = 1;
+
+  // The following fields directly map to API for DataStreamWriter().
+  // Consult API documentation unless explicitly documented here.
+
+  string format = 2;
+  map<string, string> options = 3;
+  repeated string partitioning_column_names = 4;
+
+  oneof trigger {
+    string processing_time_interval = 5;
+    bool available_now = 6;
+    bool one_time = 7;

Review Comment:
   Fixed. 'one_time' matched the server-side name in streaming, but `once`
is better.



##########
connector/connect/common/src/main/protobuf/spark/connect/commands.proto:
##########
@@ -177,3 +179,118 @@ message WriteOperationV2 {
   // (Optional) A condition for overwrite saving mode
   Expression overwrite_condition = 8;
 }
+
+// Starts write stream operation as streaming query. Query ID and Run ID of 
the streaming
+// query are returned.
+message WriteStreamOperationStart {
+
+  // (Required) The output of the `input` streaming relation will be written.
+  Relation input = 1;
+
+  // The following fields directly map to API for DataStreamWriter().
+  // Consult API documentation unless explicitly documented here.
+
+  string format = 2;
+  map<string, string> options = 3;
+  repeated string partitioning_column_names = 4;
+
+  oneof trigger {
+    string processing_time_interval = 5;
+    bool available_now = 6;
+    bool one_time = 7;
+    string continuous_checkpoint_interval = 8;
+  }
+
+  string output_mode = 9;
+  string query_name = 10;
+
+  // The destination is optional. When set, it can be a path or a table name.
+  oneof sink_destination {
+    string path = 11;
+    string table_name = 12;
+  }
+}
+
+message WriteStreamOperationStartResult {
+
+  // (Required)
+  string query_id = 1;
+
+  // (Required)
+  string run_id = 2;
+
+  // An optional query name.
+  string name = 3;
+
+  // TODO: How do we indicate errors?
+  // TODO: Consider adding StreamingQueryStatusResult here.
+}
+
+// Commands for a streaming query.
+message StreamingQueryCommand {
+
+  // (Required) query id of the streaming query.
+  string query_id = 1;
+  // (Required) run id of the streaming query.
+  string run_id = 2;
+
+  // A running query is identified by both run_id and query_id.
+
+  oneof command_type {
+    // Status of the query. Used to support multiple status related API like 
lastProgress().
+    StatusCommand status = 3;
+    // Stops the query.
+    bool stop = 4;
+    // Waits till all the available data is processed. See 
processAllAvailable() API doc.
+    bool process_all_available = 5;
+    // Returns logical and physical plans.
+    ExplainCommand explain = 6;
+
+    // TODO(SPARK-42960) Add more commands: await_termination(), exception() 
etc.
+  }
+
+  message StatusCommand {
+    // A limit on how many progress reports to return.
+    int32 recent_progress_limit = 1;
+  }
+
+  message ExplainCommand {
+    // TODO: Consider reusing Explain from AnalyzePlanRequest message.
+    //       We can not do this right now since it base.proto imports this 
file.
+    bool extended = 1;
+  }
+
+}
+
+// Response for commands on a streaming query.
+message StreamingQueryCommandResult {
+  // (Required)
+  string query_id = 1;

Review Comment:
   Let me move this into a 'RunningQueryId' message that includes both of
these IDs, to avoid this issue.



##########
connector/connect/common/src/main/protobuf/spark/connect/commands.proto:
##########
@@ -177,3 +179,118 @@ message WriteOperationV2 {
   // (Optional) A condition for overwrite saving mode
   Expression overwrite_condition = 8;
 }
+
+// Starts write stream operation as streaming query. Query ID and Run ID of 
the streaming
+// query are returned.
+message WriteStreamOperationStart {
+
+  // (Required) The output of the `input` streaming relation will be written.
+  Relation input = 1;
+
+  // The following fields directly map to API for DataStreamWriter().
+  // Consult API documentation unless explicitly documented here.
+
+  string format = 2;
+  map<string, string> options = 3;
+  repeated string partitioning_column_names = 4;
+
+  oneof trigger {
+    string processing_time_interval = 5;
+    bool available_now = 6;
+    bool one_time = 7;
+    string continuous_checkpoint_interval = 8;
+  }
+
+  string output_mode = 9;
+  string query_name = 10;
+
+  // The destination is optional. When set, it can be a path or a table name.
+  oneof sink_destination {
+    string path = 11;
+    string table_name = 12;
+  }
+}
+
+message WriteStreamOperationStartResult {
+
+  // (Required)
+  string query_id = 1;
+
+  // (Required)
+  string run_id = 2;
+
+  // An optional query name.
+  string name = 3;
+
+  // TODO: How do we indicate errors?
+  // TODO: Consider adding StreamingQueryStatusResult here.
+}
+
+// Commands for a streaming query.
+message StreamingQueryCommand {
+
+  // (Required) query id of the streaming query.
+  string query_id = 1;
+  // (Required) run id of the streaming query.
+  string run_id = 2;
+
+  // A running query is identified by both run_id and query_id.
+
+  oneof command_type {
+    // Status of the query. Used to support multiple status related API like 
lastProgress().
+    StatusCommand status = 3;
+    // Stops the query.
+    bool stop = 4;
+    // Waits till all the available data is processed. See 
processAllAvailable() API doc.
+    bool process_all_available = 5;
+    // Returns logical and physical plans.
+    ExplainCommand explain = 6;
+
+    // TODO(SPARK-42960) Add more commands: await_termination(), exception() 
etc.
+  }
+
+  message StatusCommand {
+    // A limit on how many progress reports to return.
+    int32 recent_progress_limit = 1;
+  }
+
+  message ExplainCommand {
+    // TODO: Consider reusing Explain from AnalyzePlanRequest message.
+    //       We can not do this right now since it base.proto imports this 
file.
+    bool extended = 1;
+  }
+
+}
+
+// Response for commands on a streaming query.
+message StreamingQueryCommandResult {
+  // (Required)
+  string query_id = 1;
+
+  oneof result_type {
+    StatusResult status = 2;
+    ExplainResult explain = 3;
+  }
+
+  message StatusResult {
+    // This status includes all the available to status, including progress 
messages.
+
+    // Fields from Scala 'StreamingQueryStatus' struct
+    string status_message = 1;
+    bool is_data_available = 2;
+    bool is_trigger_active = 3;
+
+    bool is_active = 4;
+
+    // Progress reports as an array of json strings.
+    repeated string recent_progress_json = 5;

Review Comment:
   I have split these into separate commands.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import sys
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+
+from pyspark.errors import StreamingQueryException
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.streaming.query import (
+    StreamingQuery as PySparkStreamingQuery,
+)
+
+__all__ = [
+    "StreamingQuery",  # TODO(WIP): "StreamingQueryManager"
+]
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.session import SparkSession
+
+
+class StreamingQuery:
+    def __init__(
+        self, session: "SparkSession", queryId: str, runId: str, name: 
Optional[str] = None
+    ) -> None:
+        self._session = session
+        self._query_id = queryId
+        self._run_id = runId
+        self._name = name
+
+    @property
+    def id(self) -> str:
+        return self._query_id
+
+    id.__doc__ = PySparkStreamingQuery.id.__doc__
+
+    @property
+    def runId(self) -> str:
+        return self._run_id
+
+    runId.__doc__ = PySparkStreamingQuery.runId.__doc__
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    name.__doc__ = PySparkStreamingQuery.name.__doc__
+
+    @property
+    def isActive(self) -> bool:
+        return self._fetch_status().is_active
+
+    isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__
+
+    def awaitTermination(self, timeout: Optional[int] = None) -> 
Optional[bool]:
+        raise NotImplementedError()
+
+    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
+
+    @property
+    def status(self) -> Dict[str, Any]:
+        proto = self._fetch_status()
+        return {
+            "message": proto.status_message,
+            "isDataAvailable": proto.is_data_available,
+            "isTriggerActive": proto.is_trigger_active,
+        }
+
+    status.__doc__ = PySparkStreamingQuery.status.__doc__
+
+    @property
+    def recentProgress(self) -> List[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=10000).recent_progress_json)
+        # Server only keeps 100, so 10000 limit is high enough.
+        return [json.loads(p) for p in progress]
+
+    recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
+
+    @property
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=1).recent_progress_json)
+        if len(progress) > 0:
+            return json.loads(progress[-1])
+        else:
+            return None
+
+    lastProgress.__doc__ = PySparkStreamingQuery.lastProgress.__doc__
+
+    def processAllAvailable(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.process_all_available = True
+        self._execute_streaming_query_cmd(cmd)
+
+    processAllAvailable.__doc__ = 
PySparkStreamingQuery.processAllAvailable.__doc__
+
+    def stop(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.stop = True
+        self._execute_streaming_query_cmd(cmd)
+
+    stop.__doc__ = PySparkStreamingQuery.stop.__doc__
+
+    def explain(self, extended: bool = False) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.explain.extended = extended
+        result = self._execute_streaming_query_cmd(cmd).explain.result
+        print(result)
+
+    explain.__doc__ = PySparkStreamingQuery.explain.__doc__
+
+    def exception(self) -> Optional[StreamingQueryException]:
+        raise NotImplementedError()
+
+    exception.__doc__ = PySparkStreamingQuery.exception.__doc__
+
+    def _fetch_status(
+        self, recent_progress_limit=0
+    ) -> pb2.StreamingQueryCommandResult.StatusResult:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.status.recent_progress_limit = recent_progress_limit
+
+        return self._execute_streaming_query_cmd(cmd).status
+
+    def _execute_streaming_query_cmd(
+        self, cmd: pb2.StreamingQueryCommand
+    ) -> pb2.StreamingQueryCommandResult:
+        cmd.query_id = self._query_id
+        cmd.run_id = self._run_id
+        exec_cmd = pb2.Command()
+        exec_cmd.streaming_query_command.CopyFrom(cmd)
+        (_, properties) = self._session.client.execute_command(exec_cmd)
+        return cast(pb2.StreamingQueryCommandResult, 
properties["streaming_query_command_result"])
+
+
+# TODO(WIP) class StreamingQueryManager:
+
+
+def _test() -> None:
+    import doctest
+    import os
+    from pyspark.sql import SparkSession
+    import pyspark.sql.streaming.query

Review Comment:
   I removed this implementation and left a TODO:
   ```  
      # TODO(SPARK-43031): port _test() from legacy query.py.
   ```



##########
python/pyspark/sql/dataframe.py:
##########
@@ -529,14 +529,10 @@ def writeStream(self) -> DataStreamWriter:
         --------
         >>> import tempfile
         >>> df = spark.readStream.format("rate").load()
-        >>> type(df.writeStream)

Review Comment:
   Nice. Done.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import sys
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+
+from pyspark.errors import StreamingQueryException
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.streaming.query import (
+    StreamingQuery as PySparkStreamingQuery,
+)
+
+__all__ = [
+    "StreamingQuery",  # TODO(WIP): "StreamingQueryManager"

Review Comment:
   Done. The TODO now references SPARK-43032.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import sys
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+
+from pyspark.errors import StreamingQueryException
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.streaming.query import (
+    StreamingQuery as PySparkStreamingQuery,
+)
+
+__all__ = [
+    "StreamingQuery",  # TODO(WIP): "StreamingQueryManager"
+]
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.session import SparkSession
+
+
+class StreamingQuery:
+    def __init__(
+        self, session: "SparkSession", queryId: str, runId: str, name: 
Optional[str] = None
+    ) -> None:
+        self._session = session
+        self._query_id = queryId
+        self._run_id = runId
+        self._name = name
+
+    @property
+    def id(self) -> str:
+        return self._query_id
+
+    id.__doc__ = PySparkStreamingQuery.id.__doc__
+
+    @property
+    def runId(self) -> str:
+        return self._run_id
+
+    runId.__doc__ = PySparkStreamingQuery.runId.__doc__
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    name.__doc__ = PySparkStreamingQuery.name.__doc__
+
+    @property
+    def isActive(self) -> bool:
+        return self._fetch_status().is_active
+
+    isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__
+
+    def awaitTermination(self, timeout: Optional[int] = None) -> 
Optional[bool]:
+        raise NotImplementedError()
+
+    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
+
+    @property
+    def status(self) -> Dict[str, Any]:
+        proto = self._fetch_status()
+        return {
+            "message": proto.status_message,
+            "isDataAvailable": proto.is_data_available,
+            "isTriggerActive": proto.is_trigger_active,
+        }
+
+    status.__doc__ = PySparkStreamingQuery.status.__doc__
+
+    @property
+    def recentProgress(self) -> List[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=10000).recent_progress_json)
+        # Server only keeps 100, so 10000 limit is high enough.
+        return [json.loads(p) for p in progress]
+
+    recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
+
+    @property
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=1).recent_progress_json)
+        if len(progress) > 0:
+            return json.loads(progress[-1])
+        else:
+            return None
+
+    lastProgress.__doc__ = PySparkStreamingQuery.lastProgress.__doc__
+
+    def processAllAvailable(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.process_all_available = True
+        self._execute_streaming_query_cmd(cmd)
+
+    processAllAvailable.__doc__ = 
PySparkStreamingQuery.processAllAvailable.__doc__
+
+    def stop(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.stop = True
+        self._execute_streaming_query_cmd(cmd)
+
+    stop.__doc__ = PySparkStreamingQuery.stop.__doc__
+
+    def explain(self, extended: bool = False) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.explain.extended = extended
+        result = self._execute_streaming_query_cmd(cmd).explain.result
+        print(result)
+
+    explain.__doc__ = PySparkStreamingQuery.explain.__doc__
+
+    def exception(self) -> Optional[StreamingQueryException]:
+        raise NotImplementedError()
+
+    exception.__doc__ = PySparkStreamingQuery.exception.__doc__
+
+    def _fetch_status(
+        self, recent_progress_limit=0
+    ) -> pb2.StreamingQueryCommandResult.StatusResult:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.status.recent_progress_limit = recent_progress_limit
+
+        return self._execute_streaming_query_cmd(cmd).status
+
+    def _execute_streaming_query_cmd(
+        self, cmd: pb2.StreamingQueryCommand
+    ) -> pb2.StreamingQueryCommandResult:
+        cmd.query_id = self._query_id
+        cmd.run_id = self._run_id
+        exec_cmd = pb2.Command()
+        exec_cmd.streaming_query_command.CopyFrom(cmd)
+        (_, properties) = self._session.client.execute_command(exec_cmd)
+        return cast(pb2.StreamingQueryCommandResult, 
properties["streaming_query_command_result"])
+
+
+# TODO(WIP) class StreamingQueryManager:
+
+
+def _test() -> None:
+    import doctest
+    import os
+    from pyspark.sql import SparkSession
+    import pyspark.sql.streaming.query
+    from py4j.protocol import Py4JError
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.streaming.query.__dict__.copy()

Review Comment:
   Same as above. We will port these in a follow-up (SPARK-43031).
   Removed the code for now.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import sys
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+
+from pyspark.errors import StreamingQueryException
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.streaming.query import (
+    StreamingQuery as PySparkStreamingQuery,
+)
+
+__all__ = [
+    "StreamingQuery",  # TODO(WIP): "StreamingQueryManager"
+]
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.session import SparkSession
+
+
+class StreamingQuery:
+    def __init__(
+        self, session: "SparkSession", queryId: str, runId: str, name: 
Optional[str] = None
+    ) -> None:
+        self._session = session
+        self._query_id = queryId
+        self._run_id = runId
+        self._name = name
+
+    @property
+    def id(self) -> str:
+        return self._query_id
+
+    id.__doc__ = PySparkStreamingQuery.id.__doc__
+
+    @property
+    def runId(self) -> str:
+        return self._run_id
+
+    runId.__doc__ = PySparkStreamingQuery.runId.__doc__
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    name.__doc__ = PySparkStreamingQuery.name.__doc__
+
+    @property
+    def isActive(self) -> bool:
+        return self._fetch_status().is_active
+
+    isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__
+
+    def awaitTermination(self, timeout: Optional[int] = None) -> 
Optional[bool]:
+        raise NotImplementedError()
+
+    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
+
+    @property
+    def status(self) -> Dict[str, Any]:
+        proto = self._fetch_status()
+        return {
+            "message": proto.status_message,
+            "isDataAvailable": proto.is_data_available,
+            "isTriggerActive": proto.is_trigger_active,
+        }
+
+    status.__doc__ = PySparkStreamingQuery.status.__doc__
+
+    @property
+    def recentProgress(self) -> List[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=10000).recent_progress_json)
+        # Server only keeps 100, so 10000 limit is high enough.
+        return [json.loads(p) for p in progress]
+
+    recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
+
+    @property
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=1).recent_progress_json)
+        if len(progress) > 0:
+            return json.loads(progress[-1])
+        else:
+            return None
+
+    lastProgress.__doc__ = PySparkStreamingQuery.lastProgress.__doc__
+
+    def processAllAvailable(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.process_all_available = True
+        self._execute_streaming_query_cmd(cmd)
+
+    processAllAvailable.__doc__ = 
PySparkStreamingQuery.processAllAvailable.__doc__
+
+    def stop(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.stop = True
+        self._execute_streaming_query_cmd(cmd)
+
+    stop.__doc__ = PySparkStreamingQuery.stop.__doc__
+
+    def explain(self, extended: bool = False) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.explain.extended = extended
+        result = self._execute_streaming_query_cmd(cmd).explain.result
+        print(result)
+
+    explain.__doc__ = PySparkStreamingQuery.explain.__doc__
+
+    def exception(self) -> Optional[StreamingQueryException]:
+        raise NotImplementedError()
+
+    exception.__doc__ = PySparkStreamingQuery.exception.__doc__
+
+    def _fetch_status(
+        self, recent_progress_limit=0
+    ) -> pb2.StreamingQueryCommandResult.StatusResult:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.status.recent_progress_limit = recent_progress_limit
+
+        return self._execute_streaming_query_cmd(cmd).status
+
+    def _execute_streaming_query_cmd(
+        self, cmd: pb2.StreamingQueryCommand
+    ) -> pb2.StreamingQueryCommandResult:
+        cmd.query_id = self._query_id
+        cmd.run_id = self._run_id
+        exec_cmd = pb2.Command()
+        exec_cmd.streaming_query_command.CopyFrom(cmd)
+        (_, properties) = self._session.client.execute_command(exec_cmd)
+        return cast(pb2.StreamingQueryCommandResult, 
properties["streaming_query_command_result"])
+
+
+# TODO(WIP) class StreamingQueryManager:
+
+
+def _test() -> None:
+    import doctest
+    import os
+    from pyspark.sql import SparkSession
+    import pyspark.sql.streaming.query
+    from py4j.protocol import Py4JError
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.streaming.query.__dict__.copy()
+    try:
+        spark = SparkSession._getActiveSessionOrCreate()

Review Comment:
   Noted. Will do it in SPARK-43031.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#

Review Comment:
   Thanks. Added TODO(SPARK-43031). @WweiL is looking into enabling these tests.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import sys
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+
+from pyspark.errors import StreamingQueryException
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.streaming.query import (
+    StreamingQuery as PySparkStreamingQuery,
+)
+
+__all__ = [
+    "StreamingQuery",  # TODO(WIP): "StreamingQueryManager"
+]
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.session import SparkSession
+
+
+class StreamingQuery:
+    def __init__(
+        self, session: "SparkSession", queryId: str, runId: str, name: 
Optional[str] = None
+    ) -> None:
+        self._session = session
+        self._query_id = queryId
+        self._run_id = runId
+        self._name = name
+
+    @property
+    def id(self) -> str:
+        return self._query_id
+
+    id.__doc__ = PySparkStreamingQuery.id.__doc__
+
+    @property
+    def runId(self) -> str:
+        return self._run_id
+
+    runId.__doc__ = PySparkStreamingQuery.runId.__doc__
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    name.__doc__ = PySparkStreamingQuery.name.__doc__
+
+    @property
+    def isActive(self) -> bool:
+        return self._fetch_status().is_active
+
+    isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__
+
+    def awaitTermination(self, timeout: Optional[int] = None) -> 
Optional[bool]:
+        raise NotImplementedError()
+
+    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
+
+    @property
+    def status(self) -> Dict[str, Any]:
+        proto = self._fetch_status()
+        return {
+            "message": proto.status_message,
+            "isDataAvailable": proto.is_data_available,
+            "isTriggerActive": proto.is_trigger_active,
+        }
+
+    status.__doc__ = PySparkStreamingQuery.status.__doc__
+
+    @property
+    def recentProgress(self) -> List[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=10000).recent_progress_json)
+        # Server only keeps 100, so 10000 limit is high enough.
+        return [json.loads(p) for p in progress]
+
+    recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
+
+    @property
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=1).recent_progress_json)
+        if len(progress) > 0:
+            return json.loads(progress[-1])
+        else:
+            return None
+
+    lastProgress.__doc__ = PySparkStreamingQuery.lastProgress.__doc__
+
+    def processAllAvailable(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.process_all_available = True
+        self._execute_streaming_query_cmd(cmd)
+
+    processAllAvailable.__doc__ = 
PySparkStreamingQuery.processAllAvailable.__doc__
+
+    def stop(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.stop = True
+        self._execute_streaming_query_cmd(cmd)
+
+    stop.__doc__ = PySparkStreamingQuery.stop.__doc__
+
+    def explain(self, extended: bool = False) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.explain.extended = extended
+        result = self._execute_streaming_query_cmd(cmd).explain.result
+        print(result)
+
+    explain.__doc__ = PySparkStreamingQuery.explain.__doc__
+
+    def exception(self) -> Optional[StreamingQueryException]:
+        raise NotImplementedError()
+
+    exception.__doc__ = PySparkStreamingQuery.exception.__doc__
+
+    def _fetch_status(
+        self, recent_progress_limit=0
+    ) -> pb2.StreamingQueryCommandResult.StatusResult:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.status.recent_progress_limit = recent_progress_limit
+
+        return self._execute_streaming_query_cmd(cmd).status
+
+    def _execute_streaming_query_cmd(
+        self, cmd: pb2.StreamingQueryCommand
+    ) -> pb2.StreamingQueryCommandResult:
+        cmd.query_id = self._query_id
+        cmd.run_id = self._run_id
+        exec_cmd = pb2.Command()
+        exec_cmd.streaming_query_command.CopyFrom(cmd)
+        (_, properties) = self._session.client.execute_command(exec_cmd)
+        return cast(pb2.StreamingQueryCommandResult, 
properties["streaming_query_command_result"])
+
+
+# TODO(WIP) class StreamingQueryManager:
+
+
+def _test() -> None:
+    import doctest
+    import os
+    from pyspark.sql import SparkSession
+    import pyspark.sql.streaming.query
+    from py4j.protocol import Py4JError
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.streaming.query.__dict__.copy()
+    try:
+        spark = SparkSession._getActiveSessionOrCreate()
+    except Py4JError:  # noqa: F821
+        spark = SparkSession(sc)  # type: ignore[name-defined] # noqa: F821
+
+    globs["spark"] = spark
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.streaming.query,

Review Comment:
   Same as above. Will address in SPARK-43031.



##########
python/pyspark/sql/dataframe.py:
##########
@@ -529,14 +529,10 @@ def writeStream(self) -> DataStreamWriter:
         --------
         >>> import tempfile
         >>> df = spark.readStream.format("rate").load()
-        >>> type(df.writeStream)
-        <class 'pyspark.sql.streaming.readwriter.DataStreamWriter'>
-
         >>> with tempfile.TemporaryDirectory() as d:
         ...     # Create a table with Rate source.
-        ...     df.writeStream.toTable(
+        ...     streaming_query = df.writeStream.toTable(
         ...         "my_table", checkpointLocation=d) # doctest: +ELLIPSIS
-        <pyspark.sql.streaming.query.StreamingQuery object at 0x...>

Review Comment:
   Done.



##########
python/pyspark/sql/connect/streaming/query.py:
##########
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import sys
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+
+from pyspark.errors import StreamingQueryException
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.streaming.query import (
+    StreamingQuery as PySparkStreamingQuery,
+)
+
+__all__ = [
+    "StreamingQuery",  # TODO(WIP): "StreamingQueryManager"
+]
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.session import SparkSession
+
+
+class StreamingQuery:
+    def __init__(
+        self, session: "SparkSession", queryId: str, runId: str, name: 
Optional[str] = None
+    ) -> None:
+        self._session = session
+        self._query_id = queryId
+        self._run_id = runId
+        self._name = name
+
+    @property
+    def id(self) -> str:
+        return self._query_id
+
+    id.__doc__ = PySparkStreamingQuery.id.__doc__
+
+    @property
+    def runId(self) -> str:
+        return self._run_id
+
+    runId.__doc__ = PySparkStreamingQuery.runId.__doc__
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    name.__doc__ = PySparkStreamingQuery.name.__doc__
+
+    @property
+    def isActive(self) -> bool:
+        return self._fetch_status().is_active
+
+    isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__
+
+    def awaitTermination(self, timeout: Optional[int] = None) -> 
Optional[bool]:
+        raise NotImplementedError()
+
+    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
+
+    @property
+    def status(self) -> Dict[str, Any]:
+        proto = self._fetch_status()
+        return {
+            "message": proto.status_message,
+            "isDataAvailable": proto.is_data_available,
+            "isTriggerActive": proto.is_trigger_active,
+        }
+
+    status.__doc__ = PySparkStreamingQuery.status.__doc__
+
+    @property
+    def recentProgress(self) -> List[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=10000).recent_progress_json)
+        # Server only keeps 100, so 10000 limit is high enough.
+        return [json.loads(p) for p in progress]
+
+    recentProgress.__doc__ = PySparkStreamingQuery.recentProgress.__doc__
+
+    @property
+    def lastProgress(self) -> Optional[Dict[str, Any]]:
+        progress = 
list(self._fetch_status(recent_progress_limit=1).recent_progress_json)
+        if len(progress) > 0:
+            return json.loads(progress[-1])
+        else:
+            return None
+
+    lastProgress.__doc__ = PySparkStreamingQuery.lastProgress.__doc__
+
+    def processAllAvailable(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.process_all_available = True
+        self._execute_streaming_query_cmd(cmd)
+
+    processAllAvailable.__doc__ = 
PySparkStreamingQuery.processAllAvailable.__doc__
+
+    def stop(self) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.stop = True
+        self._execute_streaming_query_cmd(cmd)
+
+    stop.__doc__ = PySparkStreamingQuery.stop.__doc__
+
+    def explain(self, extended: bool = False) -> None:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.explain.extended = extended
+        result = self._execute_streaming_query_cmd(cmd).explain.result
+        print(result)
+
+    explain.__doc__ = PySparkStreamingQuery.explain.__doc__
+
+    def exception(self) -> Optional[StreamingQueryException]:
+        raise NotImplementedError()
+
+    exception.__doc__ = PySparkStreamingQuery.exception.__doc__
+
+    def _fetch_status(
+        self, recent_progress_limit=0
+    ) -> pb2.StreamingQueryCommandResult.StatusResult:
+        cmd = pb2.StreamingQueryCommand()
+        cmd.status.recent_progress_limit = recent_progress_limit
+
+        return self._execute_streaming_query_cmd(cmd).status
+
+    def _execute_streaming_query_cmd(
+        self, cmd: pb2.StreamingQueryCommand
+    ) -> pb2.StreamingQueryCommandResult:
+        cmd.query_id = self._query_id
+        cmd.run_id = self._run_id
+        exec_cmd = pb2.Command()
+        exec_cmd.streaming_query_command.CopyFrom(cmd)
+        (_, properties) = self._session.client.execute_command(exec_cmd)
+        return cast(pb2.StreamingQueryCommandResult, 
properties["streaming_query_command_result"])
+
+
+# TODO(WIP) class StreamingQueryManager:
+
+
+def _test() -> None:
+    import doctest
+    import os
+    from pyspark.sql import SparkSession
+    import pyspark.sql.streaming.query
+    from py4j.protocol import Py4JError
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.streaming.query.__dict__.copy()
+    try:
+        spark = SparkSession._getActiveSessionOrCreate()

Review Comment:
   Noted. 
   FYI: @WweiL 



##########
connector/connect/common/src/main/protobuf/spark/connect/commands.proto:
##########
@@ -177,3 +179,118 @@ message WriteOperationV2 {
   // (Optional) A condition for overwrite saving mode
   Expression overwrite_condition = 8;
 }
+
+// Starts write stream operation as streaming query. Query ID and Run ID of 
the streaming
+// query are returned.
+message WriteStreamOperationStart {
+
+  // (Required) The output of the `input` streaming relation will be written.
+  Relation input = 1;
+
+  // The following fields directly map to API for DataStreamWriter().
+  // Consult API documentation unless explicitly documented here.
+
+  string format = 2;
+  map<string, string> options = 3;
+  repeated string partitioning_column_names = 4;
+
+  oneof trigger {
+    string processing_time_interval = 5;
+    bool available_now = 6;
+    bool one_time = 7;
+    string continuous_checkpoint_interval = 8;
+  }
+
+  string output_mode = 9;
+  string query_name = 10;
+
+  // The destination is optional. When set, it can be a path or a table name.
+  oneof sink_destination {
+    string path = 11;
+    string table_name = 12;
+  }
+}
+
+message WriteStreamOperationStartResult {
+
+  // (Required)
+  string query_id = 1;
+
+  // (Required)
+  string run_id = 2;
+
+  // An optional query name.
+  string name = 3;
+
+  // TODO: How do we indicate errors?
+  // TODO: Consider adding StreamingQueryStatusResult here.
+}
+
+// Commands for a streaming query.
+message StreamingQueryCommand {
+
+  // (Required) query id of the streaming query.
+  string query_id = 1;
+  // (Required) run id of the streaming query.
+  string run_id = 2;
+
+  // A running query is identified by both run_id and query_id.
+
+  oneof command_type {
+    // Status of the query. Used to support multiple status related API like 
lastProgress().
+    StatusCommand status = 3;
+    // Stops the query.
+    bool stop = 4;
+    // Waits till all the available data is processed. See 
processAllAvailable() API doc.
+    bool process_all_available = 5;
+    // Returns logical and physical plans.
+    ExplainCommand explain = 6;
+
+    // TODO(SPARK-42960) Add more commands: await_termination(), exception() 
etc.
+  }
+
+  message StatusCommand {
+    // A limit on how many progress reports to return.
+    int32 recent_progress_limit = 1;
+  }
+
+  message ExplainCommand {
+    // TODO: Consider reusing Explain from AnalyzePlanRequest message.
+    //       We can not do this right now since it base.proto imports this 
file.
+    bool extended = 1;
+  }
+
+}
+
+// Response for commands on a streaming query.
+message StreamingQueryCommandResult {
+  // (Required)
+  string query_id = 1;

Review Comment:
   Fixed. Added a 'StreamingQueryInstanceId' struct and use it in multiple 
places.



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -1969,6 +2017,142 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  def handleWriteStreamOperationStart(
+      writeOp: WriteStreamOperationStart,
+      sessionId: String,
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+    val plan = transformRelation(writeOp.getInput)
+    val dataset = Dataset.ofRows(session, logicalPlan = plan)
+
+    val writer = dataset.writeStream
+
+    if (writeOp.getFormat.nonEmpty) {
+      writer.format(writeOp.getFormat)
+    }
+
+    writer.options(writeOp.getOptionsMap)
+
+    if (writeOp.getPartitioningColumnNamesCount > 0) {
+      
writer.partitionBy(writeOp.getPartitioningColumnNamesList.asScala.toList: _*)
+    }
+
+    writeOp.getTriggerCase match {
+      case TriggerCase.PROCESSING_TIME_INTERVAL =>
+        
writer.trigger(Trigger.ProcessingTime(writeOp.getProcessingTimeInterval))
+      case TriggerCase.AVAILABLE_NOW =>
+        writer.trigger(Trigger.AvailableNow())
+      case TriggerCase.ONE_TIME =>
+        writer.trigger(Trigger.Once())
+      case TriggerCase.CONTINUOUS_CHECKPOINT_INTERVAL =>
+        
writer.trigger(Trigger.Continuous(writeOp.getContinuousCheckpointInterval))
+      case TriggerCase.TRIGGER_NOT_SET =>
+    }
+
+    if (writeOp.getOutputMode.nonEmpty) {
+      writer.outputMode(writeOp.getOutputMode)
+    }
+
+    if (writeOp.getQueryName.nonEmpty) {
+      writer.queryName(writeOp.getQueryName)
+    }
+
+    val query = writeOp.getPath match {
+      case "" if writeOp.hasTableName => writer.toTable(writeOp.getTableName)
+      case "" => writer.start()
+      case path => writer.start(path)
+    }
+
+    val result = WriteStreamOperationStartResult
+      .newBuilder()
+      .setQueryId(query.id.toString)
+      .setRunId(query.runId.toString)
+      .setName(Option(query.name).getOrElse(""))
+      .build()
+
+    responseObserver.onNext(
+      ExecutePlanResponse
+        .newBuilder()
+        .setSessionId(sessionId)
+        .setWriteStreamOperationStartResult(result)
+        .build())
+  }
+
+  def handleStreamingQueryCommand(
+      command: StreamingQueryCommand,
+      sessionId: String,
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+
+    val queryId = command.getQueryId
+
+    val respBuilder = StreamingQueryCommandResult
+      .newBuilder()
+      .setQueryId(command.getQueryId)
+
+    val query = Option(session.streams.get(queryId)) match {
+      case Some(query) if query.runId.toString == command.getRunId =>
+        query
+      case Some(query) =>
+        throw new IllegalArgumentException(
+          s"Run id mismatch for query id $queryId. Run id in the request 
${command.getRunId} " +
+            s"does not match one on the server ${query.runId}. The query might 
have restarted.")
+      case None =>
+        throw new IllegalArgumentException(s"Streaming query $queryId is not 
found")
+      // TODO(SPARK-42962): Handle this better. May be cache stopped queries 
for a few minutes.
+    }
+
+    command.getCommandTypeCase match {
+      case StreamingQueryCommand.CommandTypeCase.STATUS =>
+        val recentProgress: Seq[String] = 
command.getStatus.getRecentProgressLimit match {
+          case 0 => Seq.empty
+          case limit if limit < 0 =>
+            query.recentProgress.map(_.json) // All the cached progresses.
+          case limit => query.recentProgress.takeRight(limit).map(_.json) // 
Most recent
+        }
+
+        val queryStatus = query.status
+
+        val statusResult = StreamingQueryCommandResult.StatusResult
+          .newBuilder()
+          .setStatusMessage(queryStatus.message)
+          .setIsDataAvailable(queryStatus.isDataAvailable)
+          .setIsTriggerActive(queryStatus.isTriggerActive)
+          .setIsActive(query.isActive)
+          .addAllRecentProgressJson(recentProgress.asJava)
+          .build()
+
+        respBuilder.setStatus(statusResult)
+
+      case StreamingQueryCommand.CommandTypeCase.STOP =>
+        query.stop()
+
+      case StreamingQueryCommand.CommandTypeCase.PROCESS_ALL_AVAILABLE =>
+        query.processAllAvailable()

Review Comment:
   Updated the comment. The RPC will stay active (because the client sends 
heartbeats to keep the connection alive). The issue could be the session getting 
closed. That will be handled separately. 



##########
python/pyspark/sql/connect/readwriter.py:
##########
@@ -37,7 +37,7 @@
     from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType
     from pyspark.sql.connect.session import SparkSession
 
-__all__ = ["DataFrameReader", "DataFrameWriter"]
+__all__ = ["DataFrameReader", "DataFrameWriter", "OptionUtils", "to_str"]

Review Comment:
   That's nice. Removed the exports. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to