pitrou commented on code in PR #40070:
URL: https://github.com/apache/arrow/pull/40070#discussion_r1502930081


##########
python/pyarrow/table.pxi:
##########
@@ -2681,6 +2681,28 @@ cdef class RecordBatch(_Tabular):
 
         return pyarrow_wrap_batch(c_batch)
 
+    def cast(self, Schema target_schema, safe=None, options=None):
+        """
+        Cast batch values to another schema.
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to, the names and order of fields must match.
+        safe : bool, default True
+            Check for overflows or other unsafe conversions.
+        options : CastOptions, default None
+            Additional checks pass by CastOptions
+
+        Returns
+        -------
+        RecordBatch
+        """
+        # Wrap the more general Table cast implementation
+        tbl = Table.from_batches([self])
+        casted_tbl = tbl.cast(target_schema, safe=safe, options=options)
+        return list(casted_tbl.to_batches())[0]

Review Comment:
   Hmm, we should make sure we don't silently ignore any extraneous batches. Something like this?
   ```suggestion
           casted_batch, = casted_tbl.to_batches()
           return casted_batch
   ```
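The point of the suggestion is that `list(...)[0]` silently drops every batch after the first, while single-element tuple unpacking raises `ValueError` unless there is exactly one element. A minimal pure-Python sketch of the idiom; the helper name `only_item` is hypothetical and not part of pyarrow:

```python
from typing import Iterable, TypeVar

T = TypeVar("T")


def only_item(items: Iterable[T]) -> T:
    """Return the single element of `items`; raise ValueError on 0 or 2+."""
    # `x, = seq` fails loudly instead of discarding extra elements
    item, = items
    return item
```

By contrast, `[first, second][0]` returns `first` without complaint, which is exactly the failure mode the reviewer wants to rule out.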



##########
python/pyarrow/src/arrow/python/ipc.cc:
##########
@@ -63,5 +64,70 @@ Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
   return reader;
 }
 
+CastingRecordBatchReader::CastingRecordBatchReader() {}

Review Comment:
   Nit
   ```suggestion
   CastingRecordBatchReader::CastingRecordBatchReader() = default;
   ```



##########
python/pyarrow/src/arrow/python/ipc.cc:
##########
@@ -63,5 +64,70 @@ Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
   return reader;
 }
 
+CastingRecordBatchReader::CastingRecordBatchReader() {}
+
+Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,
+                                      std::shared_ptr<Schema> schema) {
+  std::shared_ptr<Schema> src = parent->schema();
+
+  // The check for names has already been done in Python where it's easier to
+  // generate a nice error message.
+  int num_fields = schema->num_fields();
+  if (src->num_fields() != num_fields) {
+    return Status::Invalid("Number of fields not equal");
+  }
+
+  // Ensure all columns can be cast before succeeding
+  for (int i = 0; i < num_fields; i++) {
+    if (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type())) {
+      return Status::NotImplemented("Field ", i, " cannot be cast from ",

Review Comment:
   `Status::TypeError` sounds better IMHO. `NotImplemented` implies that the corresponding cast should be implemented some day.
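For reference, pyarrow surfaces `Status::TypeError` as `ArrowTypeError` (a `TypeError` subclass) and `Status::NotImplemented` as `ArrowNotImplementedError` (a `NotImplementedError` subclass), so the choice is visible to Python callers. The pure-Python sketch below mirrors the shape of `CastingRecordBatchReader`: validate every field eagerly at construction (as `Init` does), then cast each batch lazily as it is pulled. Everything here is a hypothetical stand-in, not the PR's code: a "batch" is a list of columns, a "schema" is a list of type names, and `can_cast` / `cast_value` play the role of the compute layer.

```python
from typing import Callable, Iterable, Iterator, List, Sequence

# Hypothetical stand-ins for illustration only
Column = List[object]
Batch = List[Column]


class CastingReader:
    """Sketch of a reader that validates eagerly and casts lazily."""

    def __init__(self, batches: Iterable[Batch], src_types: Sequence[str],
                 dst_types: Sequence[str],
                 can_cast: Callable[[str, str], bool],
                 cast_value: Callable[[object, str], object]) -> None:
        if len(src_types) != len(dst_types):
            # Structural mismatch, analogous to Status::Invalid in Init
            raise ValueError("Number of fields not equal")
        # Check every field before accepting the reader, as Init does
        for i, (src, dst) in enumerate(zip(src_types, dst_types)):
            if not can_cast(src, dst):
                # A type error rather than "not implemented": the cast is
                # invalid for these types, not merely missing
                raise TypeError(f"Field {i} cannot be cast from {src} to {dst}")
        self._batches = iter(batches)
        self._dst_types = list(dst_types)
        self._cast_value = cast_value

    def __iter__(self) -> Iterator[Batch]:
        # Nothing is cast until the consumer pulls the next batch
        for batch in self._batches:
            yield [[self._cast_value(v, t) for v in col]
                   for col, t in zip(batch, self._dst_types)]
```

Failing at construction time means a bad target schema is reported before any data is read, which matches the "ensure all columns can be cast before succeeding" comment in `Init`.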



##########
python/pyarrow/tests/test_ipc.py:
##########
@@ -51,16 +51,16 @@ def get_source(self):
 
     def write_batches(self, num_batches=5, as_table=False):
         nrows = 5
-        schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])
+        schema = pa.schema([("one", pa.float64()), ("two", pa.utf8())])

Review Comment:
   Any reason for all these style changes? These don't seem related. Did you apply a formatting tool by mistake?



##########
python/pyarrow/table.pxi:
##########
@@ -2995,7 +3017,7 @@ cdef class RecordBatch(_Tabular):
         ----------
         requested_schema : PyCapsule | None
             A PyCapsule containing a C ArrowSchema representation of a requested
-            schema. PyArrow will attempt to cast the batch to this schema.
+            schema. PyArrow will attempt to cast each batch to this schema.

Review Comment:
   Why this change?



##########
python/pyarrow/ipc.pxi:
##########
@@ -772,6 +772,38 @@ cdef class RecordBatchReader(_Weakrefable):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
+    def cast(self, target_schema):
+        """
+        Wrap this reader with one that casts each batch lazily as it is pulled.
+        Currently only a safe cast to target_schema is implemented.
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to, the names and order of fields must match.
+
+        Returns
+        -------
+        RecordBatchReader
+        """
+        cdef:
+            shared_ptr[CSchema] c_schema
+            shared_ptr[CRecordBatchReader] c_reader
+            RecordBatchReader out
+
+        if self.schema.names != target_schema.names:
+            raise ValueError("Target schema's field names are not matching "

Review Comment:
   Nit, but you can use f-strings now rather than explicit `format` calls.
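A minimal sketch of the f-string form of the name check. The full error message is truncated in the diff above, so the wording after "not matching" and the function name `check_field_names` are hypothetical:

```python
from typing import Sequence


def check_field_names(source_names: Sequence[str],
                      target_names: Sequence[str]) -> None:
    if list(source_names) != list(target_names):
        # f-string interpolation instead of an explicit str.format() call;
        # the message wording here is illustrative only
        raise ValueError(
            f"Target schema's field names are not matching the reader's: "
            f"{list(target_names)!r} vs {list(source_names)!r}")
```

The `!r` conversion gives the same output as `"{!r}".format(...)`, so switching styles does not change the message.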



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
