This is an automated email from the ASF dual-hosted git repository.

kevingurney pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 798132ce65 GH-37473: [MATLAB] Add support for indexing `RecordBatch` 
columns by `Field` name  (#37475)
798132ce65 is described below

commit 798132ce65ee4d7fb902e737e2e7832bf58faaaa
Author: Kevin Gurney <[email protected]>
AuthorDate: Wed Aug 30 16:13:13 2023 -0400

    GH-37473: [MATLAB] Add support for indexing `RecordBatch` columns by 
`Field` name  (#37475)
    
    ### Rationale for this change
    
    Currently, `arrow.tabular.Schema` supports indexing by `Field` name. 
However, `arrow.tabular.RecordBatch` does not.
    
    This pull request adds the ability to index columns in a `RecordBatch` by 
`Field` name.
    
    ### What changes are included in this PR?
    
    1. Added support for indexing columns in a `RecordBatch` by `Field` name 
via the `column` method.
    
    **Example**
    ```matlab
    >> recordBatch = arrow.tabular.RecordBatch.fromArrays(...
           arrow.array([1, 2, 3]), ...
           arrow.array(["A", "B", "C"]), ...
           arrow.array([true, false, true]), ...
           ColumnNames=["A", "B", "C"] ...
       )
    
    recordBatch =
    
    A:   [
        1,
        2,
        3
      ]
    B:   [
        "A",
        "B",
        "C"
      ]
    C:   [
        true,
        false,
        true
      ]
    
    >> recordBatch.column("B")
    
    ans =
    
    [
      "A",
      "B",
      "C"
    ]
    
    >> recordBatch.column("C")
    
    ans =
    
    [
      true,
      false,
      true
    ]
    ```
    2. Removed comments about vectorizing the `field` method of `Schema` and 
the `column` method of `RecordBatch`. After further consideration, we believe 
it would make more sense to only allow these methods to accept scalar inputs. 
We could revisit support for vectorization if we overload the parentheses 
operator (e.g. `recordBatch(rows, columns)`) in the future to return another 
`RecordBatch`/`Schema` that only includes the selected columns/fields.
    3. Fixed typo in `tSchema.m`.
    
    ### Are these changes tested?
    
    Yes.
    
    1. Added tests for indexing by column name using the `column` method to 
`tRecordBatch.m`.
    
    ### Are there any user-facing changes?
    
    Yes.
    
    1. Users can now index `RecordBatch` columns by name using the syntax 
`column(name)`.
    
    ### Future Directions
    
    1. Consider overloading parentheses-based indexing on `RecordBatch` and 
`Schema`.
    * Closes: #37473
    
    Lead-authored-by: Kevin Gurney <[email protected]>
    Co-authored-by: Sarah Gilmore <[email protected]>
    Signed-off-by: Kevin Gurney <[email protected]>
---
 .../cpp/arrow/matlab/tabular/proxy/record_batch.cc |  29 ++++
 .../cpp/arrow/matlab/tabular/proxy/record_batch.h  |   1 +
 matlab/src/matlab/+arrow/+tabular/RecordBatch.m    |  18 ++-
 matlab/src/matlab/+arrow/+tabular/Schema.m         |   4 -
 matlab/test/arrow/tabular/tRecordBatch.m           | 165 ++++++++++++++++++++-
 matlab/test/arrow/tabular/tSchema.m                |   2 +-
 6 files changed, 206 insertions(+), 13 deletions(-)

diff --git a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc 
b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc
index 7b81eee21e..b785d2cf08 100644
--- a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc
+++ b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc
@@ -55,6 +55,7 @@ namespace arrow::matlab::tabular::proxy {
         REGISTER_METHOD(RecordBatch, numColumns);
         REGISTER_METHOD(RecordBatch, columnNames);
         REGISTER_METHOD(RecordBatch, getColumnByIndex);
+        REGISTER_METHOD(RecordBatch, getColumnByName);
         REGISTER_METHOD(RecordBatch, getSchema);
     }
 
@@ -166,6 +167,34 @@ namespace arrow::matlab::tabular::proxy {
         context.outputs[1] = array_type_id_mda;
     }
 
+    void RecordBatch::getColumnByName(libmexclass::proxy::method::Context& 
context) {
+        namespace mda = ::matlab::data;
+        using namespace libmexclass::proxy;
+        mda::ArrayFactory factory;
+
+        mda::StructArray args = context.inputs[0];
+        const mda::StringArray name_mda = args[0]["Name"];
+        const auto name_utf16 = std::u16string(name_mda[0]);
+        MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto name, 
arrow::util::UTF16StringToUTF8(name_utf16), context, 
error::UNICODE_CONVERSION_ERROR_ID);
+
+        const std::vector<std::string> names = {name};
+        const auto& schema = record_batch->schema();
+        
MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT(schema->CanReferenceFieldsByNames(names), 
context, error::ARROW_TABULAR_SCHEMA_AMBIGUOUS_FIELD_NAME);
+
+        const auto array = record_batch->GetColumnByName(name);
+        MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto array_proxy,
+                                            
arrow::matlab::array::proxy::wrap(array),
+                                            context,
+                                            
error::UNKNOWN_PROXY_FOR_ARRAY_TYPE);
+
+        const auto array_proxy_id = ProxyManager::manageProxy(array_proxy);
+        const auto array_proxy_id_mda = factory.createScalar(array_proxy_id);
+        const auto array_type_id_mda = 
factory.createScalar(static_cast<int32_t>(array->type_id()));
+
+        context.outputs[0] = array_proxy_id_mda;
+        context.outputs[1] = array_type_id_mda;
+    }
+
     void RecordBatch::getSchema(libmexclass::proxy::method::Context& context) {
         namespace mda = ::matlab::data;
         using namespace libmexclass::proxy;
diff --git a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h 
b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h
index 3ffb8769a1..febf18dd17 100644
--- a/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h
+++ b/matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h
@@ -38,6 +38,7 @@ namespace arrow::matlab::tabular::proxy {
             void numColumns(libmexclass::proxy::method::Context& context);
             void columnNames(libmexclass::proxy::method::Context& context);
             void getColumnByIndex(libmexclass::proxy::method::Context& 
context);
+            void getColumnByName(libmexclass::proxy::method::Context& context);
             void getSchema(libmexclass::proxy::method::Context& context);
 
             std::shared_ptr<arrow::RecordBatch> record_batch;
diff --git a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m 
b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
index 8e1fc39ab7..b00c99f7a1 100644
--- a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
+++ b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
@@ -54,13 +54,17 @@ classdef RecordBatch < matlab.mixin.CustomDisplay & ...
         function arrowArray = column(obj, idx)
             import arrow.internal.validate.*
 
-            idx = index.numeric(idx, "int32");
-            % TODO: Consider vectorizing column() in the future to support
-            % extracting multiple columns at once.
-            validateattributes(idx, "int32", "scalar");
-
-            args = struct(Index=idx);
-            [proxyID, typeID] = obj.Proxy.getColumnByIndex(args);              
  
+            idx = index.numericOrString(idx, "int32");
+
+            if isnumeric(idx)
+                validateattributes(idx, "int32", "scalar");
+                args = struct(Index=idx);
+                [proxyID, typeID] = obj.Proxy.getColumnByIndex(args);
+            else
+                validateattributes(idx, "string", "scalar");
+                args = struct(Name=idx);
+                [proxyID, typeID] = obj.Proxy.getColumnByName(args);
+            end
             
             traits = arrow.type.traits.traits(arrow.type.ID(typeID));
             proxy = libmexclass.proxy.Proxy(Name=traits.ArrayProxyClassName, 
ID=proxyID);
diff --git a/matlab/src/matlab/+arrow/+tabular/Schema.m 
b/matlab/src/matlab/+arrow/+tabular/Schema.m
index b613bd650d..229260a6ff 100644
--- a/matlab/src/matlab/+arrow/+tabular/Schema.m
+++ b/matlab/src/matlab/+arrow/+tabular/Schema.m
@@ -48,14 +48,10 @@ classdef Schema < matlab.mixin.CustomDisplay
             idx = index.numericOrString(idx, "int32");
 
             if isnumeric(idx)
-                % TODO: Consider vectorizing field() to support extracting
-                % multiple fields at once.
                 validateattributes(idx, "int32", "scalar");
                 args = struct(Index=idx);
                 proxyID = obj.Proxy.getFieldByIndex(args);
             else
-                % TODO: Consider vectorizing field() to support extracting
-                % multiple fields at once.
                 validateattributes(idx, "string", "scalar");
                 args = struct(Name=idx);
                 proxyID = obj.Proxy.getFieldByName(args);
diff --git a/matlab/test/arrow/tabular/tRecordBatch.m 
b/matlab/test/arrow/tabular/tRecordBatch.m
index d9c3c98652..6721f2f3ae 100644
--- a/matlab/test/arrow/tabular/tRecordBatch.m
+++ b/matlab/test/arrow/tabular/tRecordBatch.m
@@ -109,7 +109,7 @@ classdef tRecordBatch < matlab.unittest.TestCase
             TOriginal = table(1, 2, 3);
             arrowRecordBatch = arrow.recordBatch(TOriginal);
             fcn = @() arrowRecordBatch.column(datetime(2022, 1, 3));
-            tc.verifyError(fcn, "arrow:badsubscript:NonNumeric");
+            tc.verifyError(fcn, "arrow:badsubscript:UnsupportedIndexType");
         end
 
         function ErrorIfIndexIsNonScalar(tc)
@@ -223,6 +223,169 @@ classdef tRecordBatch < matlab.unittest.TestCase
                 "MATLAB:class:SetProhibited");
         end
 
+        function GetColumnByName(testCase)
+            % Verify that columns can be accessed by name.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                ColumnNames=["A", "B", "C"] ...
+            );
+
+            expected = arrow.array([1, 2, 3]);
+            actual = recordBatch.column("A");
+            testCase.verifyEqual(actual, expected);
+
+            expected = arrow.array(["A", "B", "C"]);
+            actual = recordBatch.column("B");
+            testCase.verifyEqual(actual, expected);
+
+            expected = arrow.array([true, false, true]);
+            actual = recordBatch.column("C");
+            testCase.verifyEqual(actual, expected);
+        end
+
+        function GetColumnByNameWithEmptyString(testCase)
+            % Verify that a column whose name is the empty string ("")
+            % can be accessed using the column() method.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                ColumnNames=["A", "", "C"] ...
+            );
+
+            expected = arrow.array(["A", "B", "C"]);
+            actual = recordBatch.column("");
+            testCase.verifyEqual(actual, expected)
+        end
+
+        function GetColumnByNameWithWhitespace(testCase)
+            % Verify that a column whose name contains only whitespace
+            % characters can be accessed using the column() method.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                ColumnNames=[" ", "  ", "   "] ...
+            );
+
+            expected = arrow.array([1, 2, 3]);
+            actual = recordBatch.column(" ");
+            testCase.verifyEqual(actual, expected);
+
+            expected = arrow.array(["A", "B", "C"]);
+            actual = recordBatch.column("  ");
+            testCase.verifyEqual(actual, expected);
+
+            expected = arrow.array([true, false, true]);
+            actual = recordBatch.column("   ");
+            testCase.verifyEqual(actual, expected);
+        end
+
+        function ErrorIfColumnNameDoesNotExist(testCase)
+            % Verify that an error is thrown when trying to access a column
+            % with a name that is not part of the Schema of the RecordBatch.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                ColumnNames=["A", "B", "C"] ...
+            );
+
+            % Matching should be case sensitive.
+            name = "a";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+
+            name = "aA";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+
+            name = "D";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+
+            name = "";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+
+            name = " ";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+        end
+
+        function ErrorIfAmbiguousColumnName(testCase)
+            % Verify that an error is thrown when trying to access a column
+            % with a name that is ambiguous / occurs more than once in the
+            % Schema of the RecordBatch.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                arrow.array([days(1), days(2), days(3)]), ...
+                ColumnNames=["A", "A", "B", "B"] ...
+            );
+
+            name = "A";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+
+            name = "B";
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:tabular:schema:AmbiguousFieldName");
+        end
+
+        function GetColumnByNameWithChar(testCase)
+            % Verify that the column method works when supplied a char
+            % vector as input.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                ColumnNames=["", "B", "123"] ...
+            );
+
+            % Should match the first column whose name is the
+            % empty string ("").
+            name = char.empty(0, 0);
+            expected = arrow.array([1, 2, 3]);
+            actual = recordBatch.column(name);
+            testCase.verifyEqual(actual, expected);
+
+            name = char.empty(0, 1);
+            expected = arrow.array([1, 2, 3]);
+            actual = recordBatch.column(name);
+            testCase.verifyEqual(actual, expected);
+
+            name = char.empty(1, 0);
+            expected = arrow.array([1, 2, 3]);
+            actual = recordBatch.column(name);
+            testCase.verifyEqual(actual, expected);
+
+            % Should match the second column whose name is "B".
+            name = 'B';
+            expected = arrow.array(["A", "B", "C"]);
+            actual = recordBatch.column(name);
+            testCase.verifyEqual(actual, expected);
+
+            % Should match the third column whose name is "123".
+            name = '123';
+            expected = arrow.array([true, false, true]);
+            actual = recordBatch.column(name);
+            testCase.verifyEqual(actual, expected);
+        end
+
+        function ErrorIfColumnNameIsNonScalar(testCase)
+            % Verify that an error is thrown if a nonscalar string array is
+            % specified as a column name to the column method.
+            recordBatch = arrow.tabular.RecordBatch.fromArrays(...
+                arrow.array([1, 2, 3]), ...
+                arrow.array(["A", "B", "C"]), ...
+                arrow.array([true, false, true]), ...
+                ColumnNames=["A", "B", "C"] ...
+            );
+
+            name = ["A", "B", "C"];
+            testCase.verifyError(@() recordBatch.column(name), 
"MATLAB:expectedScalar");
+
+            name = ["A";  "B"; "C"];
+            testCase.verifyError(@() recordBatch.column(name), 
"MATLAB:expectedScalar");
+        end
+
     end
 
     methods
diff --git a/matlab/test/arrow/tabular/tSchema.m 
b/matlab/test/arrow/tabular/tSchema.m
index dbeb2db1fd..45329a2f5f 100644
--- a/matlab/test/arrow/tabular/tSchema.m
+++ b/matlab/test/arrow/tabular/tSchema.m
@@ -429,7 +429,7 @@ classdef tSchema < matlab.unittest.TestCase
             testCase.verifyEqual(field.Name, "B");
             testCase.verifyEqual(field.Type.ID, arrow.type.ID.UInt16);
 
-            % Should match the second field whose name is "123".
+            % Should match the third field whose name is "123".
             fieldName = '123';
             field = schema.field(fieldName);
             testCase.verifyEqual(field.Name, "123");

Reply via email to