This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 05bfb26  ARROW-1849: [GLib] Add input checks to GArrowRecordBatch
05bfb26 is described below

commit 05bfb2619b19154eead36e4aa067f440fdedf49a
Author: Kouhei Sutou <k...@clear-code.com>
AuthorDate: Fri Nov 24 09:32:36 2017 -0500

    ARROW-1849: [GLib] Add input checks to GArrowRecordBatch
    
    Author: Kouhei Sutou <k...@clear-code.com>
    
    Closes #1351 from kou/glib-add-input-check-to-record-batch and squashes the following commits:
    
    c85dde22 [Kouhei Sutou] [GLib] Follow API change in Go example
    3bf1eeb4 [Kouhei Sutou] [GLib] Add index check to column readers for GArrowRecordBatch
    94d9f238 [Kouhei Sutou] Always validate on creating new record batch
---
 c_glib/arrow-glib/record-batch.cpp | 54 ++++++++++++++++++++-----
 c_glib/arrow-glib/record-batch.h   |  7 ++--
 c_glib/example/go/read-batch.go    |  4 +-
 c_glib/example/go/read-stream.go   |  4 +-
 c_glib/example/go/write-batch.go   | 10 ++++-
 c_glib/example/go/write-stream.go  | 10 ++++-
 c_glib/test/test-record-batch.rb   | 81 +++++++++++++++++++++++++++-----------
 7 files changed, 128 insertions(+), 42 deletions(-)

diff --git a/c_glib/arrow-glib/record-batch.cpp b/c_glib/arrow-glib/record-batch.cpp
index f23a0cf..64f2020 100644
--- a/c_glib/arrow-glib/record-batch.cpp
+++ b/c_glib/arrow-glib/record-batch.cpp
@@ -28,6 +28,23 @@
 
 #include <sstream>
 
+static inline bool
+garrow_record_batch_adjust_index(const std::shared_ptr<arrow::RecordBatch> 
arrow_record_batch,
+                                 gint &i)
+{
+  auto n_columns = arrow_record_batch->num_columns();
+  if (i < 0) {
+    i += n_columns;
+    if (i < 0) {
+      return false;
+    }
+  }
+  if (i >= n_columns) {
+    return false;
+  }
+  return true;
+}
+
 G_BEGIN_DECLS
 
 /**
@@ -135,13 +152,15 @@ garrow_record_batch_class_init(GArrowRecordBatchClass *klass)
  * @schema: The schema of the record batch.
  * @n_rows: The number of the rows in the record batch.
  * @columns: (element-type GArrowArray): The columns in the record batch.
+ * @error: (nullable): Return location for a #GError or %NULL.
  *
- * Returns: A newly created #GArrowRecordBatch.
+ * Returns: (nullable): A newly created #GArrowRecordBatch or %NULL on error.
  */
 GArrowRecordBatch *
 garrow_record_batch_new(GArrowSchema *schema,
                         guint32 n_rows,
-                        GList *columns)
+                        GList *columns,
+                        GError **error)
 {
   std::vector<std::shared_ptr<arrow::Array>> arrow_columns;
   for (GList *node = columns; node; node = node->next) {
@@ -152,7 +171,12 @@ garrow_record_batch_new(GArrowSchema *schema,
   auto arrow_record_batch =
     arrow::RecordBatch::Make(garrow_schema_get_raw(schema),
                              n_rows, arrow_columns);
-  return garrow_record_batch_new_raw(&arrow_record_batch);
+  auto status = arrow_record_batch->Validate();
+  if (garrow_error_check(error, status, "[record-batch][new]")) {
+    return garrow_record_batch_new_raw(&arrow_record_batch);
+  } else {
+    return NULL;
+  }
 }
 
 /**
@@ -192,15 +216,21 @@ garrow_record_batch_get_schema(GArrowRecordBatch *record_batch)
 /**
  * garrow_record_batch_get_column:
  * @record_batch: A #GArrowRecordBatch.
- * @i: The index of the target column.
+ * @i: The index of the target column. If it's negative, index is
+ *   counted backward from the end of the columns. `-1` means the last
+ *   column.
  *
- * Returns: (transfer full): The i-th column in the record batch.
+ * Returns: (transfer full) (nullable): The i-th column in the record batch
+ *   on success, %NULL on out of index.
  */
 GArrowArray *
 garrow_record_batch_get_column(GArrowRecordBatch *record_batch,
-                               guint i)
+                               gint i)
 {
   const auto arrow_record_batch = garrow_record_batch_get_raw(record_batch);
+  if (!garrow_record_batch_adjust_index(arrow_record_batch, i)) {
+    return NULL;
+  }
   auto arrow_column = arrow_record_batch->column(i);
   return garrow_array_new_raw(&arrow_column);
 }
@@ -230,15 +260,21 @@ garrow_record_batch_get_columns(GArrowRecordBatch *record_batch)
 /**
  * garrow_record_batch_get_column_name:
  * @record_batch: A #GArrowRecordBatch.
- * @i: The index of the target column.
+ * @i: The index of the target column. If it's negative, index is
+ *   counted backward from the end of the columns. `-1` means the last
+ *   column.
  *
- * Returns: The name of the i-th column in the record batch.
+ * Returns: (nullable): The name of the i-th column in the record batch
+ *   on success, %NULL on out of index
  */
 const gchar *
 garrow_record_batch_get_column_name(GArrowRecordBatch *record_batch,
-                                    guint i)
+                                    gint i)
 {
   const auto arrow_record_batch = garrow_record_batch_get_raw(record_batch);
+  if (!garrow_record_batch_adjust_index(arrow_record_batch, i)) {
+    return NULL;
+  }
   return arrow_record_batch->column_name(i).c_str();
 }
 
diff --git a/c_glib/arrow-glib/record-batch.h b/c_glib/arrow-glib/record-batch.h
index 021f894..d31edf4 100644
--- a/c_glib/arrow-glib/record-batch.h
+++ b/c_glib/arrow-glib/record-batch.h
@@ -68,17 +68,18 @@ GType garrow_record_batch_get_type(void) G_GNUC_CONST;
 
 GArrowRecordBatch *garrow_record_batch_new(GArrowSchema *schema,
                                            guint32 n_rows,
-                                           GList *columns);
+                                           GList *columns,
+                                           GError **error);
 
 gboolean garrow_record_batch_equal(GArrowRecordBatch *record_batch,
                                    GArrowRecordBatch *other_record_batch);
 
 GArrowSchema *garrow_record_batch_get_schema     (GArrowRecordBatch 
*record_batch);
 GArrowArray  *garrow_record_batch_get_column     (GArrowRecordBatch 
*record_batch,
-                                                  guint i);
+                                                  gint i);
 GList        *garrow_record_batch_get_columns    (GArrowRecordBatch 
*record_batch);
 const gchar  *garrow_record_batch_get_column_name(GArrowRecordBatch 
*record_batch,
-                                                  guint i);
+                                                  gint i);
 guint         garrow_record_batch_get_n_columns  (GArrowRecordBatch 
*record_batch);
 gint64        garrow_record_batch_get_n_rows     (GArrowRecordBatch 
*record_batch);
 GArrowRecordBatch *garrow_record_batch_slice     (GArrowRecordBatch 
*record_batch,
diff --git a/c_glib/example/go/read-batch.go b/c_glib/example/go/read-batch.go
index ef1a7fb..1472939 100644
--- a/c_glib/example/go/read-batch.go
+++ b/c_glib/example/go/read-batch.go
@@ -57,8 +57,8 @@ func PrintColumnValue(column *arrow.Array, i int64) {
 func PrintRecordBatch(recordBatch *arrow.RecordBatch) {
        nColumns := recordBatch.GetNColumns()
        for i := uint32(0); i < nColumns; i++ {
-               column := recordBatch.GetColumn(i)
-               columnName := recordBatch.GetColumnName(i)
+               column := recordBatch.GetColumn(int32(i))
+               columnName := recordBatch.GetColumnName(int32(i))
                fmt.Printf("  %s: [", columnName)
                nRows := recordBatch.GetNRows()
                for j := int64(0); j < nRows; j++ {
diff --git a/c_glib/example/go/read-stream.go b/c_glib/example/go/read-stream.go
index 7bd0764..ed75a96 100644
--- a/c_glib/example/go/read-stream.go
+++ b/c_glib/example/go/read-stream.go
@@ -57,8 +57,8 @@ func PrintColumnValue(column *arrow.Array, i int64) {
 func PrintRecordBatch(recordBatch *arrow.RecordBatch) {
        nColumns := recordBatch.GetNColumns()
        for i := uint32(0); i < nColumns; i++ {
-               column := recordBatch.GetColumn(i)
-               columnName := recordBatch.GetColumnName(i)
+               column := recordBatch.GetColumn(int32(i))
+               columnName := recordBatch.GetColumnName(int32(i))
                fmt.Printf("  %s: [", columnName)
                nRows := recordBatch.GetNRows()
                for j := int64(0); j < nRows; j++ {
diff --git a/c_glib/example/go/write-batch.go b/c_glib/example/go/write-batch.go
index 9dbc3c0..f4d03ed 100644
--- a/c_glib/example/go/write-batch.go
+++ b/c_glib/example/go/write-batch.go
@@ -188,7 +188,10 @@ func main() {
                BuildDoubleArray(),
        }
 
-       recordBatch := arrow.NewRecordBatch(schema, 4, columns)
+       recordBatch, err := arrow.NewRecordBatch(schema, 4, columns)
+       if err != nil {
+               log.Fatalf("Failed to create record batch #1: %v", err)
+       }
        _, err = writer.WriteRecordBatch(recordBatch)
        if err != nil {
                log.Fatalf("Failed to write record batch #1: %v", err)
@@ -198,7 +201,10 @@ func main() {
        for i, column := range columns {
                slicedColumns[i] = column.Slice(1, 3)
        }
-       recordBatch = arrow.NewRecordBatch(schema, 3, slicedColumns)
+       recordBatch, err = arrow.NewRecordBatch(schema, 3, slicedColumns)
+       if err != nil {
+               log.Fatalf("Failed to create record batch #2: %v", err)
+       }
        _, err = writer.WriteRecordBatch(recordBatch)
        if err != nil {
                log.Fatalf("Failed to write record batch #2: %v", err)
diff --git a/c_glib/example/go/write-stream.go b/c_glib/example/go/write-stream.go
index 244741e..7225156 100644
--- a/c_glib/example/go/write-stream.go
+++ b/c_glib/example/go/write-stream.go
@@ -188,7 +188,10 @@ func main() {
                BuildDoubleArray(),
        }
 
-       recordBatch := arrow.NewRecordBatch(schema, 4, columns)
+       recordBatch, err := arrow.NewRecordBatch(schema, 4, columns)
+       if err != nil {
+               log.Fatalf("Failed to create record batch #1: %v", err)
+       }
        _, err = writer.WriteRecordBatch(recordBatch)
        if err != nil {
                log.Fatalf("Failed to write record batch #1: %v", err)
@@ -198,7 +201,10 @@ func main() {
        for i, column := range columns {
                slicedColumns[i] = column.Slice(1, 3)
        }
-       recordBatch = arrow.NewRecordBatch(schema, 3, slicedColumns)
+       recordBatch, err = arrow.NewRecordBatch(schema, 3, slicedColumns)
+       if err != nil {
+               log.Fatalf("Failed to create record batch #2: %v", err)
+       }
        writer.WriteRecordBatch(recordBatch)
        _, err = writer.WriteRecordBatch(recordBatch)
        if err != nil {
diff --git a/c_glib/test/test-record-batch.rb b/c_glib/test/test-record-batch.rb
index 9fd34b7..365922f 100644
--- a/c_glib/test/test-record-batch.rb
+++ b/c_glib/test/test-record-batch.rb
@@ -18,32 +18,53 @@
 class TestTable < Test::Unit::TestCase
   include Helper::Buildable
 
-  def test_new
-    fields = [
-      Arrow::Field.new("visible", Arrow::BooleanDataType.new),
-      Arrow::Field.new("valid", Arrow::BooleanDataType.new),
-    ]
-    schema = Arrow::Schema.new(fields)
-    columns = [
-      build_boolean_array([true]),
-      build_boolean_array([false]),
-    ]
-    record_batch = Arrow::RecordBatch.new(schema, 1, columns)
-    assert_equal(1, record_batch.n_rows)
+  sub_test_case(".new") do
+    def test_valid
+      fields = [
+        Arrow::Field.new("visible", Arrow::BooleanDataType.new),
+        Arrow::Field.new("valid", Arrow::BooleanDataType.new),
+      ]
+      schema = Arrow::Schema.new(fields)
+      columns = [
+        build_boolean_array([true]),
+        build_boolean_array([false]),
+      ]
+      record_batch = Arrow::RecordBatch.new(schema, 1, columns)
+      assert_equal(1, record_batch.n_rows)
+    end
+
+    def test_no_columns
+      fields = [
+        Arrow::Field.new("visible", Arrow::BooleanDataType.new),
+      ]
+      schema = Arrow::Schema.new(fields)
+      message = "[record-batch][new]: " +
+        "Invalid: Number of columns did not match schema"
+      assert_raise(Arrow::Error::Invalid.new(message)) do
+        Arrow::RecordBatch.new(schema, 0, [])
+      end
+    end
   end
 
   sub_test_case("instance methods") do
     def setup
+      @visible_field = Arrow::Field.new("visible", Arrow::BooleanDataType.new)
+      @visible_values = [true, false, true, false, true]
+      @valid_field = Arrow::Field.new("valid", Arrow::BooleanDataType.new)
+      @valid_values = [false, true, false, true, false]
+
       fields = [
-        Arrow::Field.new("visible", Arrow::BooleanDataType.new),
-        Arrow::Field.new("valid", Arrow::BooleanDataType.new),
+        @visible_field,
+        @valid_field,
       ]
       schema = Arrow::Schema.new(fields)
       columns = [
-        build_boolean_array([true, false, true, false, true, false]),
-        build_boolean_array([false, true, false, true, false]),
+        build_boolean_array(@visible_values),
+        build_boolean_array(@valid_values),
       ]
-      @record_batch = Arrow::RecordBatch.new(schema, 5, columns)
+      @record_batch = Arrow::RecordBatch.new(schema,
+                                             @visible_values.size,
+                                             columns)
     end
 
     def test_equal
@@ -53,7 +74,7 @@ class TestTable < Test::Unit::TestCase
       ]
       schema = Arrow::Schema.new(fields)
       columns = [
-        build_boolean_array([true, false, true, false, true, false]),
+        build_boolean_array([true, false, true, false, true]),
         build_boolean_array([false, true, false, true, false]),
       ]
       other_record_batch = Arrow::RecordBatch.new(schema, 5, columns)
@@ -66,12 +87,28 @@ class TestTable < Test::Unit::TestCase
                    @record_batch.schema.fields.collect(&:name))
     end
 
-    def test_column
-      assert_equal(5, @record_batch.get_column(1).length)
+    sub_test_case("#column") do
+      def test_positive
+        assert_equal(build_boolean_array(@valid_values),
+                     @record_batch.get_column(1))
+      end
+
+      def test_negative
+        assert_equal(build_boolean_array(@visible_values),
+                     @record_batch.get_column(-2))
+      end
+
+      def test_positive_out_of_index
+        assert_nil(@record_batch.get_column(2))
+      end
+
+      def test_negative_out_of_index
+        assert_nil(@record_batch.get_column(-3))
+      end
     end
 
     def test_columns
-      assert_equal([6, 5],
+      assert_equal([5, 5],
                    @record_batch.columns.collect(&:length))
     end
 
@@ -94,7 +131,7 @@ class TestTable < Test::Unit::TestCase
 
     def test_to_s
       assert_equal(<<-PRETTY_PRINT, @record_batch.to_s)
-visible: [true, false, true, false, true, false]
+visible: [true, false, true, false, true]
 valid: [false, true, false, true, false]
       PRETTY_PRINT
     end

-- 
To stop receiving notification emails like this one, please contact
commits@arrow.apache.org.

Reply via email to