This is an automated email from the ASF dual-hosted git repository.

kevingurney pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f3db6580e GH-37477: [MATLAB] Add `AllowNonScalar` name-value pair to 
arrow.internal.validate.index.* validation functions (#37482)
2f3db6580e is described below

commit 2f3db6580ed59215524f30e09469f0451cb4a6c8
Author: sgilmore10 <[email protected]>
AuthorDate: Thu Aug 31 10:38:32 2023 -0400

    GH-37477: [MATLAB] Add `AllowNonScalar` name-value pair to 
arrow.internal.validate.index.* validation functions (#37482)
    
    
    
    ### Rationale for this change
    
    Per https://github.com/apache/arrow/pull/37475#discussion_r1310732596, we 
should consider adding a name-value pair like `AllowNonScalar = true | false` 
to the `arrow.internal.validate.index.*` validation functions since it is 
relatively common to want to explicitly allow (or disallow) non-scalar inputs 
to indexing functions (e.g. the `column` method of `RecordBatch` should only 
support scalar index values).
    
    ### What changes are included in this PR?
    
    1. Modified all functions within the `arrow.internal.validate.index` 
package (i.e. `numeric()`, `string()`, and `numericOrString()`) to accept a 
name-value pair called `AllowNonScalar`. This name-value pair can be set to 
a `logical` scalar, and by default it's set to `true`.
    2. Updated the `column()` method in `RecordBatch` to pass 
`AllowNonScalar=false` to `numericOrString()`.
    3. Updated the `field()` method in `Schema` to pass 
`AllowNonScalar=false` to `numericOrString()`.
    
    **NOTE:** While character row vectors (e.g. `'ABC'`) are not scalar, they 
are equivalent to scalar `string` arrays. Therefore, neither `string()` nor 
`numericOrString()` errors when given a character row vector as the index 
to validate, even if `AllowNonScalar=false`.
    
    ### Are these changes tested?
    
    Yes. Added new test cases to `tNumeric.m`, `tString.m`, and 
`tNumericOrString.m`.
    
    ### Are there any user-facing changes?
    
    No.
    
    * Closes: #37477
    
    Authored-by: Sarah Gilmore <[email protected]>
    Signed-off-by: Kevin Gurney <[email protected]>
---
 .../+arrow/+internal/+validate/+index/numeric.m    | 13 ++++-
 .../+internal/+validate/+index/numericOrString.m   | 13 ++++-
 .../+arrow/+internal/+validate/+index/string.m     | 13 ++++-
 matlab/src/matlab/+arrow/+tabular/RecordBatch.m    |  4 +-
 matlab/src/matlab/+arrow/+tabular/Schema.m         |  4 +-
 .../test/arrow/internal/validate/index/tNumeric.m  | 39 ++++++++++++-
 .../internal/validate/index/tNumericOrString.m     | 66 ++++++++++++++++++++++
 .../test/arrow/internal/validate/index/tString.m   | 43 ++++++++++++++
 matlab/test/arrow/tabular/tRecordBatch.m           |  6 +-
 matlab/test/arrow/tabular/tSchema.m                | 12 ++--
 10 files changed, 191 insertions(+), 22 deletions(-)

diff --git a/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m 
b/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
index 99bd109b53..89594139ae 100644
--- a/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
+++ b/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
@@ -19,7 +19,12 @@
 % implied.  See the License for the specific language governing
 % permissions and limitations under the License.
 
-function index = numeric(index, intType)
+function index = numeric(index, intType, opts)
+    arguments
+        index
+        intType(1, 1) string
+        opts.AllowNonScalar(1, 1) = true
+    end
 
     if ~isnumeric(index)
         errid = "arrow:badsubscript:NonNumeric";
@@ -27,6 +32,12 @@ function index = numeric(index, intType)
         error(errid, msg);
     end
 
+    if ~opts.AllowNonScalar && ~isscalar(index)
+        errid = "arrow:badsubscript:NonScalar";
+        msg = "Expected a scalar index value.";
+        error(errid, msg);
+    end
+
      % Convert to full storage if sparse
     if issparse(index)
         index = full(index);
diff --git 
a/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m 
b/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
index ec4f00503f..7b9e9cda2c 100644
--- a/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
+++ b/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
@@ -15,14 +15,21 @@
 % implied.  See the License for the specific language governing
 % permissions and limitations under the License.
 
-function idx = numericOrString(idx, numericIndexType)
+function idx = numericOrString(idx, numericIndexType, opts)
+    arguments
+        idx
+        numericIndexType(1, 1) string
+        opts.AllowNonScalar(1, 1) logical = true
+    end
+
     import arrow.internal.validate.*
 
+    opts = namedargs2cell(opts);
     idx = convertCharsToStrings(idx);
     if isnumeric(idx)
-        idx = index.numeric(idx, numericIndexType);
+        idx = index.numeric(idx, numericIndexType, opts{:});
     elseif isstring(idx)
-        idx = index.string(idx);
+        idx = index.string(idx, opts{:});
     else
         errid = "arrow:badsubscript:UnsupportedIndexType";
         msg = "Indices must be positive integers or nonmissing strings.";
diff --git a/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m 
b/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
index e8c713770a..ab629c800d 100644
--- a/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
+++ b/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
@@ -14,12 +14,21 @@
 % WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 % implied.  See the License for the specific language governing
 % permissions and limitations under the License.
-function index = string(index)
+function index = string(index, opts)
+    arguments
+        index
+        opts.AllowNonScalar(1, 1) = true
+    end
 
     index = convertCharsToStrings(index);
-
     index = reshape(index, [], 1);
 
+    if ~opts.AllowNonScalar && ~isscalar(index)
+        errid = "arrow:badsubscript:NonScalar";
+        msg = "Expected a scalar index value.";
+        error(errid, msg);
+    end
+
     if ~isstring(index)
         errid = "arrow:badsubscript:NonString";
         msg = "Expected string index values.";
diff --git a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m 
b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
index b00c99f7a1..d8c6eb47c6 100644
--- a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
+++ b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
@@ -54,14 +54,12 @@ classdef RecordBatch < matlab.mixin.CustomDisplay & ...
         function arrowArray = column(obj, idx)
             import arrow.internal.validate.*
 
-            idx = index.numericOrString(idx, "int32");
+            idx = index.numericOrString(idx, "int32", AllowNonScalar=false);
 
             if isnumeric(idx)
-                validateattributes(idx, "int32", "scalar");
                 args = struct(Index=idx);
                 [proxyID, typeID] = obj.Proxy.getColumnByIndex(args);
             else
-                validateattributes(idx, "string", "scalar");
                 args = struct(Name=idx);
                 [proxyID, typeID] = obj.Proxy.getColumnByName(args);
             end
diff --git a/matlab/src/matlab/+arrow/+tabular/Schema.m 
b/matlab/src/matlab/+arrow/+tabular/Schema.m
index 229260a6ff..4b4a3d5782 100644
--- a/matlab/src/matlab/+arrow/+tabular/Schema.m
+++ b/matlab/src/matlab/+arrow/+tabular/Schema.m
@@ -45,14 +45,12 @@ classdef Schema < matlab.mixin.CustomDisplay
         function F = field(obj, idx)
             import arrow.internal.validate.*
             
-            idx = index.numericOrString(idx, "int32");
+            idx = index.numericOrString(idx, "int32", AllowNonScalar=false);
 
             if isnumeric(idx)
-                validateattributes(idx, "int32", "scalar");
                 args = struct(Index=idx);
                 proxyID = obj.Proxy.getFieldByIndex(args);
             else
-                validateattributes(idx, "string", "scalar");
                 args = struct(Name=idx);
                 proxyID = obj.Proxy.getFieldByName(args);
             end
diff --git a/matlab/test/arrow/internal/validate/index/tNumeric.m 
b/matlab/test/arrow/internal/validate/index/tNumeric.m
index a2c07bf79e..e3d1c9d7e5 100644
--- a/matlab/test/arrow/internal/validate/index/tNumeric.m
+++ b/matlab/test/arrow/internal/validate/index/tNumeric.m
@@ -128,7 +128,7 @@ classdef tNumeric < matlab.unittest.TestCase
 
             import arrow.internal.validate.index.numeric
 
-            fcn = @() numeric(false);
+            fcn = @() numeric(false, "int32");
             testCase.verifyError(fcn, "arrow:badsubscript:NonNumeric");
         end
 
@@ -161,5 +161,42 @@ classdef tNumeric < matlab.unittest.TestCase
             actual = numeric(original, "int32");
             testCase.verifyEqual(actual, expected);
         end
+
+        function AllowNonScalarTrue(testCase)
+            % Verify numeric() behaves as expected provided
+            % AllowNonScalar=true.
+
+            import arrow.internal.validate.index.numeric
+            
+            % Provide a nonscalar array
+            original = [1 2 3]';
+            expected = int32([1 2 3])';
+            actual = numeric(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+
+            % Provide a scalar array
+            original = 1;
+            expected = int32(1);
+            actual = numeric(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+        end
+
+        function AllowNonScalarFalse(testCase)
+            % Verify numeric() behaves as expected when provided
+            % AllowNonScalar=false.
+
+            import arrow.internal.validate.index.numeric
+            
+            % Should throw an error when provided a nonscalar double array
+            original = [1 2 3]';
+            fcn = @() numeric(original, "int32", AllowNonScalar=false);
+            testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+            % Should not throw an error when provided a scalar double array
+            original = 1;
+            expected = int32(1);
+            actual = numeric(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+        end
     end
 end
\ No newline at end of file
diff --git a/matlab/test/arrow/internal/validate/index/tNumericOrString.m 
b/matlab/test/arrow/internal/validate/index/tNumericOrString.m
index 1657a1b24a..f62a0e94cf 100644
--- a/matlab/test/arrow/internal/validate/index/tNumericOrString.m
+++ b/matlab/test/arrow/internal/validate/index/tNumericOrString.m
@@ -60,5 +60,71 @@ classdef tNumericOrString < matlab.unittest.TestCase
 
             testCase.verifyEqual(numericOrString(["B" "A"], "int32"), ["B", 
"A"]');
         end
+
+        function AllowNonScalarTrue(testCase)
+            % Verify numericOrString() behaves as expected provided
+            % AllowNonScalar=true.
+
+            import arrow.internal.validate.index.numericOrString
+            
+            % Provide a nonscalar double array
+            original = [1 2 3]';
+            expected = int32([1 2 3])';
+            actual = numericOrString(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+
+            % Provide a scalar double array
+            original = 1;
+            expected = int32(1);
+            actual = numericOrString(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+
+            % Provide a nonscalar string array
+            original = ["A", "B", "C"];
+            expected = ["A", "B", "C"]';
+            actual = numericOrString(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+
+            % Provide a scalar string array
+            original = "A";
+            expected = "A";
+            actual = numericOrString(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+        end
+
+        function AllowNonScalarFalse(testCase)
+            % Verify numericOrString() behaves as expected when provided
+            % AllowNonScalar=false.
+
+            import arrow.internal.validate.index.numericOrString
+            
+            % Should throw an error when provided a nonscalar double array
+            original = [1 2 3]';
+            fcn = @() numericOrString(original, "int32", AllowNonScalar=false);
+            testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+            % Should not throw an error when provided a scalar double array
+            original = 1;
+            expected = int32(1);
+            actual = numericOrString(original, "int32", AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+
+            % Should throw an error if provided a nonscalar string array
+            original = ["A", "B", "C"];
+            fcn = @() numericOrString(original, "int32", AllowNonScalar=false);
+            testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+            % Should not throw an error if provided a scalar string array
+            original = "A";
+            expected = "A";
+            actual = numericOrString(original, "int32", AllowNonScalar=false);
+            testCase.verifyEqual(actual, expected);
+
+            % Should not throw an error if provided a character row vector
+            original = 'ABC';
+            expected = "ABC";
+            actual = numericOrString(original, "int32", AllowNonScalar=false);
+            testCase.verifyEqual(actual, expected);
+        end
     end
 end
\ No newline at end of file
diff --git a/matlab/test/arrow/internal/validate/index/tString.m 
b/matlab/test/arrow/internal/validate/index/tString.m
index 2e6bf020d6..e8e27024ea 100644
--- a/matlab/test/arrow/internal/validate/index/tString.m
+++ b/matlab/test/arrow/internal/validate/index/tString.m
@@ -113,5 +113,48 @@ classdef tString < matlab.unittest.TestCase
             actual = index.string(original);
             testCase.verifyEqual(actual, expected);
         end
+
+        function AllowNonScalarTrue(testCase)
+            % Verify string() behaves as expected provided
+            % AllowNonScalar=true.
+
+            import arrow.internal.validate.*
+            
+            % Provide a nonscalar string array
+            original = ["A", "B", "C"];
+            expected = ["A", "B", "C"]';
+            actual = index.string(original, AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+
+            % Provide a scalar string array
+            original = "A";
+            expected = "A";
+            actual = index.string(original, AllowNonScalar=true);
+            testCase.verifyEqual(actual, expected);
+        end
+
+        function AllowNonScalarFalse(testCase)
+            % Verify string() behaves as expected when provided
+            % AllowNonScalar=false.
+
+            import arrow.internal.validate.*
+            
+            % Should throw an error if provided a nonscalar string array
+            original = ["A", "B", "C"];
+            fcn = @() index.string(original, AllowNonScalar=false);
+            testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+            % Should not throw an error if provided a scalar string array
+            original = "A";
+            expected = "A";
+            actual = index.string(original, AllowNonScalar=false);
+            testCase.verifyEqual(actual, expected);
+
+            % Should not throw an error if provided a character row vector
+            original = 'ABC';
+            expected = "ABC";
+            actual = index.string(original, AllowNonScalar=false);
+            testCase.verifyEqual(actual, expected);
+        end
     end
 end
\ No newline at end of file
diff --git a/matlab/test/arrow/tabular/tRecordBatch.m 
b/matlab/test/arrow/tabular/tRecordBatch.m
index 6721f2f3ae..f4b156a377 100644
--- a/matlab/test/arrow/tabular/tRecordBatch.m
+++ b/matlab/test/arrow/tabular/tRecordBatch.m
@@ -116,7 +116,7 @@ classdef tRecordBatch < matlab.unittest.TestCase
             TOriginal = table(1, 2, 3);
             arrowRecordBatch = arrow.recordBatch(TOriginal);
             fcn = @() arrowRecordBatch.column([1 2]);
-            tc.verifyError(fcn, "MATLAB:expectedScalar");
+            tc.verifyError(fcn, "arrow:badsubscript:NonScalar");
         end
 
         function ErrorIfIndexIsNonPositive(tc)
@@ -380,10 +380,10 @@ classdef tRecordBatch < matlab.unittest.TestCase
             );
 
             name = ["A", "B", "C"];
-            testCase.verifyError(@() recordBatch.column(name), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:badsubscript:NonScalar");
 
             name = ["A";  "B"; "C"];
-            testCase.verifyError(@() recordBatch.column(name), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() recordBatch.column(name), 
"arrow:badsubscript:NonScalar");
         end
 
     end
diff --git a/matlab/test/arrow/tabular/tSchema.m 
b/matlab/test/arrow/tabular/tSchema.m
index 45329a2f5f..b57ebffbc5 100644
--- a/matlab/test/arrow/tabular/tSchema.m
+++ b/matlab/test/arrow/tabular/tSchema.m
@@ -139,7 +139,7 @@ classdef tSchema < matlab.unittest.TestCase
             ]);
 
             index = [];
-            testCase.verifyError(@() schema.field(index), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() schema.field(index), 
"arrow:badsubscript:NonScalar");
 
             index = 0;
             testCase.verifyError(@() schema.field(index), 
"arrow:badsubscript:NonPositive");
@@ -157,7 +157,7 @@ classdef tSchema < matlab.unittest.TestCase
             testCase.verifyError(@() schema.field(index), 
"arrow:badsubscript:UnsupportedIndexType");
 
             index = [1; 1];
-            testCase.verifyError(@() schema.field(index), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() schema.field(index), 
"arrow:badsubscript:NonScalar");
         end
 
         function GetFieldByIndex(testCase)
@@ -446,10 +446,10 @@ classdef tSchema < matlab.unittest.TestCase
             ]);
 
             fieldName = [1, 2, 3];
-            testCase.verifyError(@() schema.field(fieldName), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() schema.field(fieldName), 
"arrow:badsubscript:NonScalar");
 
             fieldName = [1; 2; 3];
-            testCase.verifyError(@() schema.field(fieldName), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() schema.field(fieldName), 
"arrow:badsubscript:NonScalar");
         end
 
         function ErrorIfFieldNameIsNonScalar(testCase)
@@ -462,10 +462,10 @@ classdef tSchema < matlab.unittest.TestCase
             ]);
 
             fieldName = ["A", "B", "C"];
-            testCase.verifyError(@() schema.field(fieldName), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() schema.field(fieldName), 
"arrow:badsubscript:NonScalar");
 
             fieldName = ["A";  "B"; "C"];
-            testCase.verifyError(@() schema.field(fieldName), 
"MATLAB:expectedScalar");
+            testCase.verifyError(@() schema.field(fieldName), 
"arrow:badsubscript:NonScalar");
         end
 
     end

Reply via email to