This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a8a102  ARROW-969: [C++] Add add/remove field functions for RecordBatch
4a8a102 is described below

commit 4a8a102d508fc31570584b3618291103a901e6b0
Author: Panchen Xue <pan.panchen....@gmail.com>
AuthorDate: Tue Feb 13 01:20:29 2018 -0500

    ARROW-969: [C++] Add add/remove field functions for RecordBatch
    
    Add AddColumn and RemoveColumn methods for RecordBatch, as well as test cases
    
    Author: Panchen Xue <pan.panchen....@gmail.com>
    Author: Wes McKinney <wes.mckin...@twosigma.com>
    
    Closes #1574 from xuepanchen/ARROW-969 and squashes the following commits:
    
    b082c46b [Wes McKinney] Add variant of RecordBatch::AddColumn that takes a field name instead of a Field
    cd3a69d0 [Panchen Xue] Move index boundscheck to Schema module and add more test cases for boundscheck
    68e93f49 [Panchen Xue] Add test cases for AddColumn and RemoveColumn
    5ebfdfff [Panchen Xue] Add AddColumn and RemoveColumn methods for RecordBatch
---
 cpp/src/arrow/record_batch.cc |  45 +++++++++++++++++
 cpp/src/arrow/record_batch.h  |  29 +++++++++++
 cpp/src/arrow/table-test.cc   | 111 ++++++++++++++++++++++++++++++++++++++++++
 cpp/src/arrow/table.cc        |  10 +---
 cpp/src/arrow/type.cc         |  10 ++--
 5 files changed, 193 insertions(+), 12 deletions(-)

diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index d418cc4..f295b86 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -21,15 +21,27 @@
 #include <cstdlib>
 #include <memory>
 #include <sstream>
+#include <string>
 #include <utility>
 
 #include "arrow/array.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/util/logging.h"
+#include "arrow/util/stl.h"
 
 namespace arrow {
 
+// Convenience overload: builds a nullable Field named `field_name` whose type
+// is taken from `column`, then delegates to the Field-based AddColumn.
+Status RecordBatch::AddColumn(int i, const std::string& field_name,
+                              const std::shared_ptr<Array>& column,
+                              std::shared_ptr<RecordBatch>* out) const {
+  DCHECK(column != nullptr);  // column->type() is dereferenced below
+  auto field = ::arrow::field(field_name, column->type());
+  return AddColumn(i, field, column, out);
+}
+
 /// \class SimpleRecordBatch
 /// \brief A basic, non-lazy in-memory record batch
 class SimpleRecordBatch : public RecordBatch {
@@ -78,6 +87,51 @@ class SimpleRecordBatch : public RecordBatch {
 
   std::shared_ptr<ArrayData> column_data(int i) const override { return columns_[i]; }
 
+  // Keep RecordBatch's name-based AddColumn overload visible; without this the
+  // Field-based override below would hide it on SimpleRecordBatch.
+  using RecordBatch::AddColumn;
+
+  // Returns a new batch with `column` (described by `field`) inserted at
+  // position i; this batch is left unchanged.
+  Status AddColumn(int i, const std::shared_ptr<Field>& field,
+                   const std::shared_ptr<Array>& column,
+                   std::shared_ptr<RecordBatch>* out) const override {
+    DCHECK(field != nullptr);
+    DCHECK(column != nullptr);
+
+    if (!field->type()->Equals(column->type())) {
+      std::stringstream ss;
+      ss << "Column data type " << column->type()->name()
+         << " does not match field data type " << field->type()->name();
+      return Status::Invalid(ss.str());
+    }
+    if (column->length() != num_rows_) {
+      std::stringstream ss;
+      ss << "Added column's length must match record batch's length. Expected length "
+         << num_rows_ << " but got length " << column->length();
+      return Status::Invalid(ss.str());
+    }
+
+    // Schema::AddField bounds-checks i (valid range is [0, num_fields()]).
+    std::shared_ptr<Schema> new_schema;
+    RETURN_NOT_OK(schema_->AddField(i, field, &new_schema));
+
+    *out = RecordBatch::Make(new_schema, num_rows_,
+                             internal::AddVectorElement(columns_, i, column->data()));
+    return Status::OK();
+  }
+
+  // Returns a new batch without column i; Schema::RemoveField bounds-checks i
+  // (valid range is [0, num_fields())).
+  Status RemoveColumn(int i, std::shared_ptr<RecordBatch>* out) const override {
+    std::shared_ptr<Schema> new_schema;
+    RETURN_NOT_OK(schema_->RemoveField(i, &new_schema));
+
+    *out = RecordBatch::Make(new_schema, num_rows_,
+                             internal::DeleteVectorElement(columns_, i));
+    return Status::OK();
+  }
+
   std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
       const std::shared_ptr<const KeyValueMetadata>& metadata) const override {
     auto new_schema = schema_->AddMetadata(metadata);
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index b2c4c76..6fb747c 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -96,6 +96,42 @@ class ARROW_EXPORT RecordBatch {
   /// \return an internal ArrayData object
   virtual std::shared_ptr<ArrayData> column_data(int i) const = 0;
 
+  /// \brief Add column to the record batch, producing a new RecordBatch
+  ///
+  /// This record batch itself is left unchanged.
+  ///
+  /// \param[in] i field index, which will be boundschecked; must be in the
+  /// range [0, number of columns], where the upper bound appends
+  /// \param[in] field field to be added; its type must equal column's type
+  /// \param[in] column column to be added; its length must equal num_rows()
+  /// \param[out] out record batch with column added
+  /// \return Status::Invalid if any of the above requirements is not met
+  virtual Status AddColumn(int i, const std::shared_ptr<Field>& field,
+                           const std::shared_ptr<Array>& column,
+                           std::shared_ptr<RecordBatch>* out) const = 0;
+
+  /// \brief Add new nullable column to the record batch, producing a new
+  /// RecordBatch.
+  ///
+  /// The new field is named field_name and takes its type from column.
+  /// For non-nullable columns, use the Field-based version of this method.
+  ///
+  /// \param[in] i field index, which will be boundschecked
+  /// \param[in] field_name name of field to be added
+  /// \param[in] column column to be added
+  /// \param[out] out record batch with column added
+  virtual Status AddColumn(int i, const std::string& field_name,
+                           const std::shared_ptr<Array>& column,
+                           std::shared_ptr<RecordBatch>* out) const;
+
+  /// \brief Remove column from the record batch, producing a new RecordBatch
+  ///
+  /// \param[in] i field index, which will be boundschecked; must be in the
+  /// range [0, number of columns)
+  /// \param[out] out record batch with column removed
+  /// \return Status::Invalid if i is out of range
+  virtual Status RemoveColumn(int i, std::shared_ptr<RecordBatch>* out) const = 0;
+
   virtual std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
       const std::shared_ptr<const KeyValueMetadata>& metadata) const = 0;
 
diff --git a/cpp/src/arrow/table-test.cc b/cpp/src/arrow/table-test.cc
index 99e4dd5..af74416 100644
--- a/cpp/src/arrow/table-test.cc
+++ b/cpp/src/arrow/table-test.cc
@@ -466,6 +466,9 @@ TEST_F(TestTable, AddColumn) {
   // Some negative tests with invalid index
   Status status = table.AddColumn(10, columns_[0], &result);
   ASSERT_TRUE(status.IsInvalid());
+  // Presumably num_columns() + 1, which the old `i > num_columns() + 1` check let through
   status = table.AddColumn(4, columns_[0], &result);
   ASSERT_TRUE(status.IsInvalid());
   status = table.AddColumn(-1, columns_[0], &result);
   ASSERT_TRUE(status.IsInvalid());
 
@@ -588,6 +590,119 @@ TEST_F(TestRecordBatch, Slice) {
   }
 }
 
+TEST_F(TestRecordBatch, AddColumn) {
+  const int length = 10;
+
+  auto field1 = field("f1", int32());
+  auto field2 = field("f2", uint8());
+  auto field3 = field("f3", int16());
+
+  auto schema1 = ::arrow::schema({field1, field2});
+  auto schema2 = ::arrow::schema({field2, field3});
+  auto schema3 = ::arrow::schema({field2});
+
+  auto array1 = MakeRandomArray<Int32Array>(length);
+  auto array2 = MakeRandomArray<UInt8Array>(length);
+  auto array3 = MakeRandomArray<Int16Array>(length);
+
+  auto batch1 = RecordBatch::Make(schema1, length, {array1, array2});
+  auto batch2 = RecordBatch::Make(schema2, length, {array2, array3});
+  auto batch3 = RecordBatch::Make(schema3, length, {array2});
+
+  const RecordBatch& batch = *batch3;
+  std::shared_ptr<RecordBatch> result;
+
+  // Negative tests with invalid index
+  Status status = batch.AddColumn(5, field1, array1, &result);
+  ASSERT_TRUE(status.IsInvalid());
+  status = batch.AddColumn(2, field1, array1, &result);
+  ASSERT_TRUE(status.IsInvalid());
+  status = batch.AddColumn(-1, field1, array1, &result);
+  ASSERT_TRUE(status.IsInvalid());
+  // The name-based overload must enforce the same bounds
+  status = batch.AddColumn(2, "f1", array1, &result);
+  ASSERT_TRUE(status.IsInvalid());
+
+  // Negative test with wrong length
+  auto longer_col = MakeRandomArray<Int32Array>(length + 1);
+  status = batch.AddColumn(0, field1, longer_col, &result);
+  ASSERT_TRUE(status.IsInvalid());
+
+  // Negative test with mismatch type
+  status = batch.AddColumn(0, field1, array2, &result);
+  ASSERT_TRUE(status.IsInvalid());
+
+  ASSERT_OK(batch.AddColumn(0, field1, array1, &result));
+  ASSERT_TRUE(result->Equals(*batch1));
+
+  ASSERT_OK(batch.AddColumn(1, field3, array3, &result));
+  ASSERT_TRUE(result->Equals(*batch2));
+
+  std::shared_ptr<RecordBatch> result2;
+  ASSERT_OK(batch.AddColumn(1, "f3", array3, &result2));
+  ASSERT_TRUE(result2->Equals(*result));
+
+  ASSERT_TRUE(result2->schema()->field(1)->nullable());
+}
+
+TEST_F(TestRecordBatch, RemoveColumn) {
+  const int length = 10;
+
+  auto field1 = field("f1", int32());
+  auto field2 = field("f2", uint8());
+  auto field3 = field("f3", int16());
+
+  auto schema1 = ::arrow::schema({field1, field2, field3});
+  auto schema2 = ::arrow::schema({field2, field3});
+  auto schema3 = ::arrow::schema({field1, field3});
+  auto schema4 = ::arrow::schema({field1, field2});
+
+  auto array1 = MakeRandomArray<Int32Array>(length);
+  auto array2 = MakeRandomArray<UInt8Array>(length);
+  auto array3 = MakeRandomArray<Int16Array>(length);
+
+  auto batch1 = RecordBatch::Make(schema1, length, {array1, array2, array3});
+  auto batch2 = RecordBatch::Make(schema2, length, {array2, array3});
+  auto batch3 = RecordBatch::Make(schema3, length, {array1, array3});
+  auto batch4 = RecordBatch::Make(schema4, length, {array1, array2});
+
+  const RecordBatch& batch = *batch1;
+  std::shared_ptr<RecordBatch> result;
+
+  // Negative tests with invalid index
+  Status status = batch.RemoveColumn(3, &result);
+  ASSERT_TRUE(status.IsInvalid());
+  status = batch.RemoveColumn(-1, &result);
+  ASSERT_TRUE(status.IsInvalid());
+
+  ASSERT_OK(batch.RemoveColumn(0, &result));
+  ASSERT_TRUE(result->Equals(*batch2));
+
+  ASSERT_OK(batch.RemoveColumn(1, &result));
+  ASSERT_TRUE(result->Equals(*batch3));
+
+  ASSERT_OK(batch.RemoveColumn(2, &result));
+  ASSERT_TRUE(result->Equals(*batch4));
+}
+
+TEST_F(TestRecordBatch, RemoveColumnEmpty) {
+  const int length = 10;
+
+  auto field1 = field("f1", int32());
+  auto schema1 = ::arrow::schema({field1});
+  auto array1 = MakeRandomArray<Int32Array>(length);
+  auto batch1 = RecordBatch::Make(schema1, length, {array1});
+
+  std::shared_ptr<RecordBatch> empty;
+  ASSERT_OK(batch1->RemoveColumn(0, &empty));
+  ASSERT_EQ(batch1->num_rows(), empty->num_rows());
+  ASSERT_EQ(0, empty->schema()->num_fields());
+
+  std::shared_ptr<RecordBatch> added;
+  ASSERT_OK(empty->AddColumn(0, field1, array1, &added));
+  ASSERT_TRUE(added->Equals(*batch1));
+}
+
 class TestTableBatchReader : public TestBase {};
 
 TEST_F(TestTableBatchReader, ReadNext) {
diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc
index 8cfd67f..62ea32a 100644
--- a/cpp/src/arrow/table.cc
+++ b/cpp/src/arrow/table.cc
@@ -234,14 +234,10 @@ class SimpleTable : public Table {
 
   Status AddColumn(int i, const std::shared_ptr<Column>& col,
                    std::shared_ptr<Table>* out) const override {
-    if (i < 0 || i > num_columns() + 1) {
-      return Status::Invalid("Invalid column index.");
-    }
-    if (col == nullptr) {
-      std::stringstream ss;
-      ss << "Column " << i << " was null";
-      return Status::Invalid(ss.str());
-    }
+    // A null column is now a caller bug (DCHECK, debug builds only). The index
+    // check moved out of here; presumably Schema::AddField enforces it — confirm.
+    DCHECK(col != nullptr);
+
     if (col->length() != num_rows_) {
       std::stringstream ss;
       ss << "Added column's length must match table's length. Expected length "
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 0a2889f..836a2aa 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -293,8 +293,10 @@ int64_t Schema::GetFieldIndex(const std::string& name) const {
 
 Status Schema::AddField(int i, const std::shared_ptr<Field>& field,
                         std::shared_ptr<Schema>* out) const {
-  DCHECK_GE(i, 0);
-  DCHECK_LE(i, this->num_fields());
+  // Valid positions are [0, num_fields()]; i == num_fields() appends.
+  if (i < 0 || i > this->num_fields()) {
+    return Status::Invalid("Invalid column index to add field.");
+  }
 
   *out =
       std::make_shared<Schema>(internal::AddVectorElement(fields_, i, field), metadata_);
@@ -323,8 +324,10 @@ std::shared_ptr<Schema> Schema::RemoveMetadata() const {
 }
 
 Status Schema::RemoveField(int i, std::shared_ptr<Schema>* out) const {
-  DCHECK_GE(i, 0);
-  DCHECK_LT(i, this->num_fields());
+  // Valid positions are [0, num_fields()); report others as an error Status.
+  if (i < 0 || i >= this->num_fields()) {
+    return Status::Invalid("Invalid column index to remove field.");
+  }
 
   *out = std::make_shared<Schema>(internal::DeleteVectorElement(fields_, i), metadata_);
   return Status::OK();

-- 
To stop receiving notification emails like this one, please contact
w...@apache.org.

Reply via email to