[ 
https://issues.apache.org/jira/browse/ARROW-969?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16361876#comment-16361876
 ] 

ASF GitHub Bot commented on ARROW-969:
--------------------------------------

wesm closed pull request #1574: ARROW-969: [C++] Add add/remove field functions 
for RecordBatch
URL: https://github.com/apache/arrow/pull/1574
 
 
   

This pull request was merged from a forked repository. Because GitHub
does not display the original diff of a fork-based pull request once it
has been merged, the full diff is reproduced below for the sake of
provenance:

diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index d418cc4a2..f295b864c 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -21,15 +21,24 @@
 #include <cstdlib>
 #include <memory>
 #include <sstream>
+#include <string>
 #include <utility>
 
 #include "arrow/array.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/util/logging.h"
+#include "arrow/util/stl.h"
 
 namespace arrow {
 
+/// \brief Convenience overload of AddColumn that builds the field from a name
+///
+/// The field is created with ::arrow::field(field_name, column->type()) and the
+/// call is forwarded to the Field-based AddColumn overload, whose Status is
+/// returned unchanged.
+///
+/// \param[in] i index at which the column is inserted
+/// \param[in] field_name name given to the new field
+/// \param[in] column column to be added; must not be null
+/// \param[out] out record batch with the column added
+Status RecordBatch::AddColumn(int i, const std::string& field_name,
+                              const std::shared_ptr<Array>& column,
+                              std::shared_ptr<RecordBatch>* out) const {
+  // column is dereferenced on the next line, so assert it here; the null check
+  // in the Field-based overload would only run after that dereference.
+  DCHECK(column != nullptr);
+  auto field = ::arrow::field(field_name, column->type());
+  return AddColumn(i, field, column, out);
+}
+
 /// \class SimpleRecordBatch
 /// \brief A basic, non-lazy in-memory record batch
 class SimpleRecordBatch : public RecordBatch {
@@ -78,6 +87,42 @@ class SimpleRecordBatch : public RecordBatch {
 
   std::shared_ptr<ArrayData> column_data(int i) const override { return 
columns_[i]; }
 
+  /// \brief Insert a column at position i, producing a new record batch
+  ///
+  /// Returns Status::Invalid when the field type and the column type differ or
+  /// when the column length differs from num_rows_. A non-OK Status from
+  /// Schema::AddField is returned as-is.
+  ///
+  /// \param[in] i position of the new column
+  /// \param[in] field field describing the new column; must not be null
+  /// \param[in] column column to insert; must not be null
+  /// \param[out] out the resulting record batch
+  Status AddColumn(int i, const std::shared_ptr<Field>& field,
+                   const std::shared_ptr<Array>& column,
+                   std::shared_ptr<RecordBatch>* out) const override {
+    DCHECK(field != nullptr);
+    DCHECK(column != nullptr);
+
+    if (!field->type()->Equals(column->type())) {
+      // Print each type under its own label: the column's type after "Column
+      // data type" and the field's type after "field data type".
+      std::stringstream ss;
+      ss << "Column data type " << column->type()->name()
+         << " does not match field data type " << field->type()->name();
+      return Status::Invalid(ss.str());
+    }
+    if (column->length() != num_rows_) {
+      std::stringstream ss;
+      ss << "Added column's length must match record batch's length. Expected length "
+         << num_rows_ << " but got length " << column->length();
+      return Status::Invalid(ss.str());
+    }
+
+    // Extend the schema first so that a failure there returns before the
+    // column vector is modified.
+    std::shared_ptr<Schema> new_schema;
+    RETURN_NOT_OK(schema_->AddField(i, field, &new_schema));
+
+    *out = RecordBatch::Make(new_schema, num_rows_,
+                             internal::AddVectorElement(columns_, i, column->data()));
+    return Status::OK();
+  }
+
+  /// \brief Drop the column at position i, producing a new record batch
+  ///
+  /// \param[in] i position of the column to remove
+  /// \param[out] out the resulting record batch
+  Status RemoveColumn(int i, std::shared_ptr<RecordBatch>* out) const override {
+    // Shrink the schema first; a non-OK Status from it is returned as-is.
+    std::shared_ptr<Schema> trimmed_schema;
+    RETURN_NOT_OK(schema_->RemoveField(i, &trimmed_schema));
+
+    *out = RecordBatch::Make(trimmed_schema, num_rows_,
+                             internal::DeleteVectorElement(columns_, i));
+    return Status::OK();
+  }
+
   std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
       const std::shared_ptr<const KeyValueMetadata>& metadata) const override {
     auto new_schema = schema_->AddMetadata(metadata);
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index b2c4c76b3..6fb747c40 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -96,6 +96,35 @@ class ARROW_EXPORT RecordBatch {
   /// \return an internal ArrayData object
   virtual std::shared_ptr<ArrayData> column_data(int i) const = 0;
 
+  /// \brief Add column to the record batch, producing a new RecordBatch
+  ///
+  /// \param[in] i index at which the new column is inserted; it is bounds-checked
+  /// \param[in] field field describing the column to be added
+  /// \param[in] column column to be added
+  /// \param[out] out record batch with column added
+  virtual Status AddColumn(int i, const std::shared_ptr<Field>& field,
+                           const std::shared_ptr<Array>& column,
+                           std::shared_ptr<RecordBatch>* out) const = 0;
+
+  /// \brief Add new nullable column to the record batch, producing a new
+  /// RecordBatch.
+  ///
+  /// The default implementation builds the field from field_name and the type
+  /// of column, then calls the Field-based version of this method.
+  /// For non-nullable columns, use the Field-based version of this method.
+  ///
+  /// \param[in] i index at which the new column is inserted; it is bounds-checked
+  /// \param[in] field_name name of field to be added
+  /// \param[in] column column to be added
+  /// \param[out] out record batch with column added
+  virtual Status AddColumn(int i, const std::string& field_name,
+                           const std::shared_ptr<Array>& column,
+                           std::shared_ptr<RecordBatch>* out) const;
+
+  /// \brief Remove column from the record batch, producing a new RecordBatch
+  ///
+  /// \param[in] i index of the column to remove; it is bounds-checked
+  /// \param[out] out record batch with column removed
+  virtual Status RemoveColumn(int i, std::shared_ptr<RecordBatch>* out) const 
= 0;
+
   virtual std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
       const std::shared_ptr<const KeyValueMetadata>& metadata) const = 0;
 
diff --git a/cpp/src/arrow/table-test.cc b/cpp/src/arrow/table-test.cc
index 99e4dd5db..af7441682 100644
--- a/cpp/src/arrow/table-test.cc
+++ b/cpp/src/arrow/table-test.cc
@@ -466,6 +466,8 @@ TEST_F(TestTable, AddColumn) {
   // Some negative tests with invalid index
   Status status = table.AddColumn(10, columns_[0], &result);
   ASSERT_TRUE(status.IsInvalid());
+  status = table.AddColumn(4, columns_[0], &result);
+  ASSERT_TRUE(status.IsInvalid());
   status = table.AddColumn(-1, columns_[0], &result);
   ASSERT_TRUE(status.IsInvalid());
 
@@ -588,6 +590,115 @@ TEST_F(TestRecordBatch, Slice) {
   }
 }
 
+// Covers RecordBatch::AddColumn: invalid indices, a column of the wrong length,
+// a field/column type mismatch, and successful inserts through both the
+// Field-based and the name-based overloads.
+TEST_F(TestRecordBatch, AddColumn) {
+  const int num_rows = 10;
+
+  auto f1 = field("f1", int32());
+  auto f2 = field("f2", uint8());
+  auto f3 = field("f3", int16());
+
+  auto values1 = MakeRandomArray<Int32Array>(num_rows);
+  auto values2 = MakeRandomArray<UInt8Array>(num_rows);
+  auto values3 = MakeRandomArray<Int16Array>(num_rows);
+
+  // Expected results, plus the single-column batch every call starts from.
+  auto expected_f1_f2 =
+      RecordBatch::Make(::arrow::schema({f1, f2}), num_rows, {values1, values2});
+  auto expected_f2_f3 =
+      RecordBatch::Make(::arrow::schema({f2, f3}), num_rows, {values2, values3});
+  auto only_f2 = RecordBatch::Make(::arrow::schema({f2}), num_rows, {values2});
+
+  const RecordBatch& base = *only_f2;
+  std::shared_ptr<RecordBatch> result;
+
+  // Indices past the end (5 and 2) and a negative index must all be rejected.
+  ASSERT_TRUE(base.AddColumn(5, f1, values1, &result).IsInvalid());
+  ASSERT_TRUE(base.AddColumn(2, f1, values1, &result).IsInvalid());
+  ASSERT_TRUE(base.AddColumn(-1, f1, values1, &result).IsInvalid());
+
+  // A column with one row too many must be rejected.
+  auto too_long = MakeRandomArray<Int32Array>(num_rows + 1);
+  ASSERT_TRUE(base.AddColumn(0, f1, too_long, &result).IsInvalid());
+
+  // An int32 field paired with a uint8 column must be rejected.
+  ASSERT_TRUE(base.AddColumn(0, f1, values2, &result).IsInvalid());
+
+  ASSERT_OK(base.AddColumn(0, f1, values1, &result));
+  ASSERT_TRUE(result->Equals(*expected_f1_f2));
+
+  ASSERT_OK(base.AddColumn(1, f3, values3, &result));
+  ASSERT_TRUE(result->Equals(*expected_f2_f3));
+
+  // The name-based overload must match the Field-based result and yield a
+  // nullable field.
+  std::shared_ptr<RecordBatch> by_name;
+  ASSERT_OK(base.AddColumn(1, "f3", values3, &by_name));
+  ASSERT_TRUE(by_name->Equals(*result));
+
+  ASSERT_TRUE(by_name->schema()->field(1)->nullable());
+}
+
+// Covers RecordBatch::RemoveColumn: invalid indices, then removal of the first,
+// middle and last column of a three-column batch.
+TEST_F(TestRecordBatch, RemoveColumn) {
+  const int num_rows = 10;
+
+  auto f1 = field("f1", int32());
+  auto f2 = field("f2", uint8());
+  auto f3 = field("f3", int16());
+
+  auto values1 = MakeRandomArray<Int32Array>(num_rows);
+  auto values2 = MakeRandomArray<UInt8Array>(num_rows);
+  auto values3 = MakeRandomArray<Int16Array>(num_rows);
+
+  // The three-column starting batch and the expected result of each removal.
+  auto full = RecordBatch::Make(::arrow::schema({f1, f2, f3}), num_rows,
+                                {values1, values2, values3});
+  auto without_f1 =
+      RecordBatch::Make(::arrow::schema({f2, f3}), num_rows, {values2, values3});
+  auto without_f2 =
+      RecordBatch::Make(::arrow::schema({f1, f3}), num_rows, {values1, values3});
+  auto without_f3 =
+      RecordBatch::Make(::arrow::schema({f1, f2}), num_rows, {values1, values2});
+
+  const RecordBatch& base = *full;
+  std::shared_ptr<RecordBatch> result;
+
+  // One past the last valid index and a negative index must be rejected.
+  ASSERT_TRUE(base.RemoveColumn(3, &result).IsInvalid());
+  ASSERT_TRUE(base.RemoveColumn(-1, &result).IsInvalid());
+
+  ASSERT_OK(base.RemoveColumn(0, &result));
+  ASSERT_TRUE(result->Equals(*without_f1));
+
+  ASSERT_OK(base.RemoveColumn(1, &result));
+  ASSERT_TRUE(result->Equals(*without_f2));
+
+  ASSERT_OK(base.RemoveColumn(2, &result));
+  ASSERT_TRUE(result->Equals(*without_f3));
+}
+
+// Removing the only column must keep the row count, and adding that column
+// back must reproduce the original batch.
+TEST_F(TestRecordBatch, RemoveColumnEmpty) {
+  const int num_rows = 10;
+
+  auto f1 = field("f1", int32());
+  auto values1 = MakeRandomArray<Int32Array>(num_rows);
+  auto original = RecordBatch::Make(::arrow::schema({f1}), num_rows, {values1});
+
+  std::shared_ptr<RecordBatch> no_columns;
+  ASSERT_OK(original->RemoveColumn(0, &no_columns));
+  ASSERT_EQ(original->num_rows(), no_columns->num_rows());
+
+  std::shared_ptr<RecordBatch> restored;
+  ASSERT_OK(no_columns->AddColumn(0, f1, values1, &restored));
+  ASSERT_TRUE(restored->Equals(*original));
+}
+
 class TestTableBatchReader : public TestBase {};
 
 TEST_F(TestTableBatchReader, ReadNext) {
diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc
index 8cfd67fae..62ea32a8d 100644
--- a/cpp/src/arrow/table.cc
+++ b/cpp/src/arrow/table.cc
@@ -234,14 +234,8 @@ class SimpleTable : public Table {
 
   Status AddColumn(int i, const std::shared_ptr<Column>& col,
                    std::shared_ptr<Table>* out) const override {
-    if (i < 0 || i > num_columns() + 1) {
-      return Status::Invalid("Invalid column index.");
-    }
-    if (col == nullptr) {
-      std::stringstream ss;
-      ss << "Column " << i << " was null";
-      return Status::Invalid(ss.str());
-    }
+    DCHECK(col != nullptr);
+
     if (col->length() != num_rows_) {
       std::stringstream ss;
       ss << "Added column's length must match table's length. Expected length "
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 0a2889f04..836a2aa93 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -293,8 +293,9 @@ int64_t Schema::GetFieldIndex(const std::string& name) 
const {
 
 Status Schema::AddField(int i, const std::shared_ptr<Field>& field,
                         std::shared_ptr<Schema>* out) const {
-  DCHECK_GE(i, 0);
-  DCHECK_LE(i, this->num_fields());
+  if (i < 0 || i > this->num_fields()) {
+    return Status::Invalid("Invalid column index to add field.");
+  }
 
   *out =
       std::make_shared<Schema>(internal::AddVectorElement(fields_, i, field), 
metadata_);
@@ -323,8 +324,9 @@ std::shared_ptr<Schema> Schema::RemoveMetadata() const {
 }
 
 Status Schema::RemoveField(int i, std::shared_ptr<Schema>* out) const {
-  DCHECK_GE(i, 0);
-  DCHECK_LT(i, this->num_fields());
+  if (i < 0 || i >= this->num_fields()) {
+    return Status::Invalid("Invalid column index to remove field.");
+  }
 
   *out = std::make_shared<Schema>(internal::DeleteVectorElement(fields_, i), 
metadata_);
   return Status::OK();


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> [C++/Python] Add add/remove field functions for RecordBatch
> -----------------------------------------------------------
>
>                 Key: ARROW-969
>                 URL: https://issues.apache.org/jira/browse/ARROW-969
>             Project: Apache Arrow
>          Issue Type: New Feature
>          Components: C++, Python
>            Reporter: Wes McKinney
>            Assignee: Panchen Xue
>            Priority: Major
>              Labels: pull-request-available
>             Fix For: 0.9.0
>
>
> Analogous to the Table equivalents



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to