This is an automated email from the ASF dual-hosted git repository.
kevingurney pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 2f3db6580e GH-37477: [MATLAB] Add `AllowNonScalar` name-value pair to
arrow.internal.validate.index.* validation functions (#37482)
2f3db6580e is described below
commit 2f3db6580ed59215524f30e09469f0451cb4a6c8
Author: sgilmore10 <[email protected]>
AuthorDate: Thu Aug 31 10:38:32 2023 -0400
GH-37477: [MATLAB] Add `AllowNonScalar` name-value pair to
arrow.internal.validate.index.* validation functions (#37482)
### Rationale for this change
Per https://github.com/apache/arrow/pull/37475#discussion_r1310732596, we
should consider adding a name-value pair like `AllowNonScalar = true | false`
to the `arrow.internal.validate.index.*` validation functions since it is
relatively common to want to explicitly allow (or disallow) non-scalar inputs
to indexing functions (e.g. the `column` method of `RecordBatch` should only
support scalar index values).
### What changes are included in this PR?
1. Modified all functions within the `arrow.internal.validate.index`
package (i.e. `numeric()`, `string()`, and `numericOrString()`) to accept a
name-value pair called `AllowNonScalar`. This name-value pair can be set to
a `logical` scalar and defaults to `true`.
2. Updated the `column()` method in `RecordBatch` to pass
`AllowNonScalar=false` to `numericOrString()`.
3. Updated the `field()` method in `Schema` to pass
`AllowNonScalar=false` to `numericOrString()`.
**NOTE:** While character row vectors (e.g. `'ABC'`) are not scalar, they
are equivalent to scalar `string` arrays. Therefore, neither `string()` nor
`numericOrString()` errors when given a character row vector as the index
to validate with `AllowNonScalar=false`.
### Are these changes tested?
Yes. Added new test cases to `tNumeric.m`, `tString.m`, and
`tNumericOrString.m`, and updated the expected error identifiers in `tRecordBatch.m` and `tSchema.m`.
### Are there any user-facing changes?
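No. Note, however, that `RecordBatch.column()` and `Schema.field()` now throw `arrow:badsubscript:NonScalar` (instead of `MATLAB:expectedScalar`) when given a nonscalar index.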
* Closes: #37477
Authored-by: Sarah Gilmore <[email protected]>
Signed-off-by: Kevin Gurney <[email protected]>
---
.../+arrow/+internal/+validate/+index/numeric.m | 13 ++++-
.../+internal/+validate/+index/numericOrString.m | 13 ++++-
.../+arrow/+internal/+validate/+index/string.m | 13 ++++-
matlab/src/matlab/+arrow/+tabular/RecordBatch.m | 4 +-
matlab/src/matlab/+arrow/+tabular/Schema.m | 4 +-
.../test/arrow/internal/validate/index/tNumeric.m | 39 ++++++++++++-
.../internal/validate/index/tNumericOrString.m | 66 ++++++++++++++++++++++
.../test/arrow/internal/validate/index/tString.m | 43 ++++++++++++++
matlab/test/arrow/tabular/tRecordBatch.m | 6 +-
matlab/test/arrow/tabular/tSchema.m | 12 ++--
10 files changed, 191 insertions(+), 22 deletions(-)
diff --git a/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
b/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
index 99bd109b53..89594139ae 100644
--- a/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
+++ b/matlab/src/matlab/+arrow/+internal/+validate/+index/numeric.m
@@ -19,7 +19,12 @@
% implied. See the License for the specific language governing
% permissions and limitations under the License.
-function index = numeric(index, intType)
+function index = numeric(index, intType, opts)
+ arguments
+ index
+ intType(1, 1) string
+ opts.AllowNonScalar(1, 1) = true
+ end
if ~isnumeric(index)
errid = "arrow:badsubscript:NonNumeric";
@@ -27,6 +32,12 @@ function index = numeric(index, intType)
error(errid, msg);
end
+ if ~opts.AllowNonScalar && ~isscalar(index)
+ errid = "arrow:badsubscript:NonScalar";
+ msg = "Expected a scalar index value.";
+ error(errid, msg);
+ end
+
% Convert to full storage if sparse
if issparse(index)
index = full(index);
diff --git
a/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
b/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
index ec4f00503f..7b9e9cda2c 100644
--- a/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
+++ b/matlab/src/matlab/+arrow/+internal/+validate/+index/numericOrString.m
@@ -15,14 +15,21 @@
% implied. See the License for the specific language governing
% permissions and limitations under the License.
-function idx = numericOrString(idx, numericIndexType)
+function idx = numericOrString(idx, numericIndexType, opts)
+ arguments
+ idx
+ numericIndexType(1, 1) string
+ opts.AllowNonScalar(1, 1) logical = true
+ end
+
import arrow.internal.validate.*
+ opts = namedargs2cell(opts);
idx = convertCharsToStrings(idx);
if isnumeric(idx)
- idx = index.numeric(idx, numericIndexType);
+ idx = index.numeric(idx, numericIndexType, opts{:});
elseif isstring(idx)
- idx = index.string(idx);
+ idx = index.string(idx, opts{:});
else
errid = "arrow:badsubscript:UnsupportedIndexType";
msg = "Indices must be positive integers or nonmissing strings.";
diff --git a/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
b/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
index e8c713770a..ab629c800d 100644
--- a/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
+++ b/matlab/src/matlab/+arrow/+internal/+validate/+index/string.m
@@ -14,12 +14,21 @@
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.
-function index = string(index)
+function index = string(index, opts)
+ arguments
+ index
+ opts.AllowNonScalar(1, 1) = true
+ end
index = convertCharsToStrings(index);
-
index = reshape(index, [], 1);
+ if ~opts.AllowNonScalar && ~isscalar(index)
+ errid = "arrow:badsubscript:NonScalar";
+ msg = "Expected a scalar index value.";
+ error(errid, msg);
+ end
+
if ~isstring(index)
errid = "arrow:badsubscript:NonString";
msg = "Expected string index values.";
diff --git a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
index b00c99f7a1..d8c6eb47c6 100644
--- a/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
+++ b/matlab/src/matlab/+arrow/+tabular/RecordBatch.m
@@ -54,14 +54,12 @@ classdef RecordBatch < matlab.mixin.CustomDisplay & ...
function arrowArray = column(obj, idx)
import arrow.internal.validate.*
- idx = index.numericOrString(idx, "int32");
+ idx = index.numericOrString(idx, "int32", AllowNonScalar=false);
if isnumeric(idx)
- validateattributes(idx, "int32", "scalar");
args = struct(Index=idx);
[proxyID, typeID] = obj.Proxy.getColumnByIndex(args);
else
- validateattributes(idx, "string", "scalar");
args = struct(Name=idx);
[proxyID, typeID] = obj.Proxy.getColumnByName(args);
end
diff --git a/matlab/src/matlab/+arrow/+tabular/Schema.m
b/matlab/src/matlab/+arrow/+tabular/Schema.m
index 229260a6ff..4b4a3d5782 100644
--- a/matlab/src/matlab/+arrow/+tabular/Schema.m
+++ b/matlab/src/matlab/+arrow/+tabular/Schema.m
@@ -45,14 +45,12 @@ classdef Schema < matlab.mixin.CustomDisplay
function F = field(obj, idx)
import arrow.internal.validate.*
- idx = index.numericOrString(idx, "int32");
+ idx = index.numericOrString(idx, "int32", AllowNonScalar=false);
if isnumeric(idx)
- validateattributes(idx, "int32", "scalar");
args = struct(Index=idx);
proxyID = obj.Proxy.getFieldByIndex(args);
else
- validateattributes(idx, "string", "scalar");
args = struct(Name=idx);
proxyID = obj.Proxy.getFieldByName(args);
end
diff --git a/matlab/test/arrow/internal/validate/index/tNumeric.m
b/matlab/test/arrow/internal/validate/index/tNumeric.m
index a2c07bf79e..e3d1c9d7e5 100644
--- a/matlab/test/arrow/internal/validate/index/tNumeric.m
+++ b/matlab/test/arrow/internal/validate/index/tNumeric.m
@@ -128,7 +128,7 @@ classdef tNumeric < matlab.unittest.TestCase
import arrow.internal.validate.index.numeric
- fcn = @() numeric(false);
+ fcn = @() numeric(false, "int32");
testCase.verifyError(fcn, "arrow:badsubscript:NonNumeric");
end
@@ -161,5 +161,42 @@ classdef tNumeric < matlab.unittest.TestCase
actual = numeric(original, "int32");
testCase.verifyEqual(actual, expected);
end
+
+ function AllowNonScalarTrue(testCase)
+ % Verify numeric() accepts both nonscalar and scalar indices
+ % when AllowNonScalar=true.
+
+ import arrow.internal.validate.index.numeric
+
+ % Nonscalar double column vector: converted to int32, shape kept
+ original = [1 2 3]';
+ expected = int32([1 2 3])';
+ actual = numeric(original, "int32", AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+
+ % Scalar double: converted to an int32 scalar
+ original = 1;
+ expected = int32(1);
+ actual = numeric(original, "int32", AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+ end
+
+ function AllowNonScalarFalse(testCase)
+ % Verify numeric() errors on nonscalar indices but accepts
+ % scalar indices when AllowNonScalar=false.
+
+ import arrow.internal.validate.index.numeric
+
+ % Should throw an error when provided a nonscalar double array
+ original = [1 2 3]';
+ fcn = @() numeric(original, "int32", AllowNonScalar=false);
+ testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+ % Should not throw an error when provided a scalar double array
+ original = 1;
+ expected = int32(1);
+ actual = numeric(original, "int32", AllowNonScalar=false);
+ testCase.verifyEqual(actual, expected);
+ end
end
end
\ No newline at end of file
diff --git a/matlab/test/arrow/internal/validate/index/tNumericOrString.m
b/matlab/test/arrow/internal/validate/index/tNumericOrString.m
index 1657a1b24a..f62a0e94cf 100644
--- a/matlab/test/arrow/internal/validate/index/tNumericOrString.m
+++ b/matlab/test/arrow/internal/validate/index/tNumericOrString.m
@@ -60,5 +60,71 @@ classdef tNumericOrString < matlab.unittest.TestCase
testCase.verifyEqual(numericOrString(["B" "A"], "int32"), ["B",
"A"]');
end
+
+ function AllowNonScalarTrue(testCase)
+ % Verify numericOrString() accepts both nonscalar and scalar
+ % numeric and string indices when AllowNonScalar=true.
+
+ import arrow.internal.validate.index.numericOrString
+
+ % Nonscalar double column vector: converted to int32
+ original = [1 2 3]';
+ expected = int32([1 2 3])';
+ actual = numericOrString(original, "int32", AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+
+ % Scalar double: converted to an int32 scalar
+ original = 1;
+ expected = int32(1);
+ actual = numericOrString(original, "int32", AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+
+ % Nonscalar string row vector: returned as a column vector
+ original = ["A", "B", "C"];
+ expected = ["A", "B", "C"]';
+ actual = numericOrString(original, "int32", AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+
+ % Scalar string: returned unchanged
+ original = "A";
+ expected = "A";
+ actual = numericOrString(original, "int32", AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+ end
+
+ function AllowNonScalarFalse(testCase)
+ % Verify numericOrString() errors on nonscalar indices but
+ % accepts scalar indices when AllowNonScalar=false.
+
+ import arrow.internal.validate.index.numericOrString
+
+ % Should throw an error when provided a nonscalar double array
+ original = [1 2 3]';
+ fcn = @() numericOrString(original, "int32", AllowNonScalar=false);
+ testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+ % Should not throw an error when provided a scalar double array
+ original = 1;
+ expected = int32(1);
+ actual = numericOrString(original, "int32", AllowNonScalar=false);
+ testCase.verifyEqual(actual, expected);
+
+ % Should throw an error if provided a nonscalar string array
+ original = ["A", "B", "C"];
+ fcn = @() numericOrString(original, "int32", AllowNonScalar=false);
+ testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+ % Should not throw an error if provided a scalar string array
+ original = "A";
+ expected = "A";
+ actual = numericOrString(original, "int32", AllowNonScalar=false);
+ testCase.verifyEqual(actual, expected);
+
+ % Character row vectors count as scalar (converted to a string)
+ original = 'ABC';
+ expected = "ABC";
+ actual = numericOrString(original, "int32", AllowNonScalar=false);
+ testCase.verifyEqual(actual, expected);
+ end
end
end
\ No newline at end of file
diff --git a/matlab/test/arrow/internal/validate/index/tString.m
b/matlab/test/arrow/internal/validate/index/tString.m
index 2e6bf020d6..e8e27024ea 100644
--- a/matlab/test/arrow/internal/validate/index/tString.m
+++ b/matlab/test/arrow/internal/validate/index/tString.m
@@ -113,5 +113,48 @@ classdef tString < matlab.unittest.TestCase
actual = index.string(original);
testCase.verifyEqual(actual, expected);
end
+
+ function AllowNonScalarTrue(testCase)
+ % Verify string() accepts both nonscalar and scalar string
+ % indices when AllowNonScalar=true.
+
+ import arrow.internal.validate.*
+
+ % Nonscalar string row vector: returned as a column vector
+ original = ["A", "B", "C"];
+ expected = ["A", "B", "C"]';
+ actual = index.string(original, AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+
+ % Scalar string: returned unchanged
+ original = "A";
+ expected = "A";
+ actual = index.string(original, AllowNonScalar=true);
+ testCase.verifyEqual(actual, expected);
+ end
+
+ function AllowNonScalarFalse(testCase)
+ % Verify string() errors on nonscalar string indices but accepts
+ % scalar strings and character row vectors when AllowNonScalar=false.
+
+ import arrow.internal.validate.*
+
+ % Should throw an error if provided a nonscalar string array
+ original = ["A", "B", "C"];
+ fcn = @() index.string(original, AllowNonScalar=false);
+ testCase.verifyError(fcn, "arrow:badsubscript:NonScalar");
+
+ % Should not throw an error if provided a scalar string array
+ original = "A";
+ expected = "A";
+ actual = index.string(original, AllowNonScalar=false);
+ testCase.verifyEqual(actual, expected);
+
+ % A character row vector is converted to a scalar string, so it passes
+ original = 'ABC';
+ expected = "ABC";
+ actual = index.string(original, AllowNonScalar=false);
+ testCase.verifyEqual(actual, expected);
+ end
end
end
\ No newline at end of file
diff --git a/matlab/test/arrow/tabular/tRecordBatch.m
b/matlab/test/arrow/tabular/tRecordBatch.m
index 6721f2f3ae..f4b156a377 100644
--- a/matlab/test/arrow/tabular/tRecordBatch.m
+++ b/matlab/test/arrow/tabular/tRecordBatch.m
@@ -116,7 +116,7 @@ classdef tRecordBatch < matlab.unittest.TestCase
TOriginal = table(1, 2, 3);
arrowRecordBatch = arrow.recordBatch(TOriginal);
fcn = @() arrowRecordBatch.column([1 2]);
- tc.verifyError(fcn, "MATLAB:expectedScalar");
+ tc.verifyError(fcn, "arrow:badsubscript:NonScalar");
end
function ErrorIfIndexIsNonPositive(tc)
@@ -380,10 +380,10 @@ classdef tRecordBatch < matlab.unittest.TestCase
);
name = ["A", "B", "C"];
- testCase.verifyError(@() recordBatch.column(name),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() recordBatch.column(name),
"arrow:badsubscript:NonScalar");
name = ["A"; "B"; "C"];
- testCase.verifyError(@() recordBatch.column(name),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() recordBatch.column(name),
"arrow:badsubscript:NonScalar");
end
end
diff --git a/matlab/test/arrow/tabular/tSchema.m
b/matlab/test/arrow/tabular/tSchema.m
index 45329a2f5f..b57ebffbc5 100644
--- a/matlab/test/arrow/tabular/tSchema.m
+++ b/matlab/test/arrow/tabular/tSchema.m
@@ -139,7 +139,7 @@ classdef tSchema < matlab.unittest.TestCase
]);
index = [];
- testCase.verifyError(@() schema.field(index),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() schema.field(index),
"arrow:badsubscript:NonScalar");
index = 0;
testCase.verifyError(@() schema.field(index),
"arrow:badsubscript:NonPositive");
@@ -157,7 +157,7 @@ classdef tSchema < matlab.unittest.TestCase
testCase.verifyError(@() schema.field(index),
"arrow:badsubscript:UnsupportedIndexType");
index = [1; 1];
- testCase.verifyError(@() schema.field(index),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() schema.field(index),
"arrow:badsubscript:NonScalar");
end
function GetFieldByIndex(testCase)
@@ -446,10 +446,10 @@ classdef tSchema < matlab.unittest.TestCase
]);
fieldName = [1, 2, 3];
- testCase.verifyError(@() schema.field(fieldName),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() schema.field(fieldName),
"arrow:badsubscript:NonScalar");
fieldName = [1; 2; 3];
- testCase.verifyError(@() schema.field(fieldName),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() schema.field(fieldName),
"arrow:badsubscript:NonScalar");
end
function ErrorIfFieldNameIsNonScalar(testCase)
@@ -462,10 +462,10 @@ classdef tSchema < matlab.unittest.TestCase
]);
fieldName = ["A", "B", "C"];
- testCase.verifyError(@() schema.field(fieldName),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() schema.field(fieldName),
"arrow:badsubscript:NonScalar");
fieldName = ["A"; "B"; "C"];
- testCase.verifyError(@() schema.field(fieldName),
"MATLAB:expectedScalar");
+ testCase.verifyError(@() schema.field(fieldName),
"arrow:badsubscript:NonScalar");
end
end