kou commented on code in PR #37013: URL: https://github.com/apache/arrow/pull/37013#discussion_r1283680536
########## matlab/src/cpp/arrow/matlab/tabular/proxy/schema.cc: ########## @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/matlab/error/error.h" +#include "arrow/matlab/tabular/proxy/schema.h" +#include "arrow/matlab/type/proxy/field.h" + +#include "libmexclass/proxy/ProxyManager.h" +#include "libmexclass/error/Error.h" + +#include "arrow/util/utf8.h" + +#include <sstream> + +namespace arrow::matlab::tabular::proxy { + + namespace { + + libmexclass::error::Error makeUnknownFieldNameError(const std::string& name) { + using namespace libmexclass::error; + const std::string error_message_id = std::string{error::ARROW_TABULAR_SCHEMA_UNKNOWN_FIELD_NAME}; + std::stringstream error_message_stream; + error_message_stream << "Unknown field name: '"; + error_message_stream << name; + error_message_stream << "'."; + const std::string& error_message = error_message_stream.str(); + return Error{error_message_id, error_message}; Review Comment: Can we simplify this (and reduce needless string copies)? ```suggestion std::stringstream error_message_stream; error_message_stream << "Unknown field name: '"; error_message_stream << name; error_message_stream << "'."; return Error{error::ARROW_TABULAR_SCHEMA_UNKNOWN_FIELD_NAME, error_message_stream.str()}; ``` ########## matlab/src/cpp/arrow/matlab/tabular/proxy/schema.cc: ########## @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/matlab/error/error.h" +#include "arrow/matlab/tabular/proxy/schema.h" +#include "arrow/matlab/type/proxy/field.h" + +#include "libmexclass/proxy/ProxyManager.h" +#include "libmexclass/error/Error.h" + +#include "arrow/util/utf8.h" + +#include <sstream> + +namespace arrow::matlab::tabular::proxy { + + namespace { + + libmexclass::error::Error makeUnknownFieldNameError(const std::string& name) { + using namespace libmexclass::error; + const std::string error_message_id = std::string{error::ARROW_TABULAR_SCHEMA_UNKNOWN_FIELD_NAME}; + std::stringstream error_message_stream; + error_message_stream << "Unknown field name: '"; + error_message_stream << name; + error_message_stream << "'."; + const std::string& error_message = error_message_stream.str(); + return Error{error_message_id, error_message}; + } + + libmexclass::error::Error makeEmptySchemaError() { + using namespace libmexclass::error; + const std::string error_message_id = std::string{error::ARROW_TABULAR_SCHEMA_NUMERIC_FIELD_INDEX_WITH_EMPTY_SCHEMA}; + std::stringstream error_message_stream; + error_message_stream << "Numeric indexing using the field method is not supported for schemas with no fields."; + const std::string& error_message = error_message_stream.str(); + return Error{error_message_id, error_message}; Review Comment: Can we simplify this (and reduce needless memory allocations)? ```suggestion return Error{error::ARROW_TABULAR_SCHEMA_NUMERIC_FIELD_INDEX_WITH_EMPTY_SCHEMA, "Numeric indexing using the field method is not supported for schemas with no fields."}; ``` ########## matlab/test/arrow/tabular/tSchema.m: ########## @@ -0,0 +1,474 @@ +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef tSchema < matlab.unittest.TestCase +% Tests for the arrow.tabular.Schema class and the associated arrow.schema +% construction function. + + methods(Test) + + function ErrorIfUnsupportedInputType(testCase) + % Verify that an error is thrown by arrow.schema if an + % unsupported input argument is supplied. + testCase.verifyError(@() arrow.schema("test"), "MATLAB:validation:UnableToConvert"); + end + + function ErrorIfUnsupportedConstructorInputs(testCase) + % Verify that an error is thrown by the constructor of + % arrow.tabular.Schema if unsupported arguments are passed to + % the constructor. + testCase.verifyError(@() arrow.tabular.Schema("test"), "MATLAB:validation:UnableToConvert"); + end + + function ErrorIfTooFewInputs(testCase) + % Verify that an error is thrown by arrow.schema if too few + % input arguments are supplied. + testCase.verifyError(@() arrow.schema(), "MATLAB:minrhs"); + end + + function ErrorIfTooManyInputs(testCase) + % Verify that an error is thrown by arrow.schema if too many + % input arguments are supplied. + testCase.verifyError(@() arrow.schema("a", "b", "c"), "MATLAB:TooManyInputs"); + end + + function ClassType(testCase) + % Verify that the class type of the object returned by a call + % to arrow.schema is "arrow.tabular.Schema". + schema = arrow.schema(arrow.field("A", arrow.uint8)); + testCase.verifyInstanceOf(schema, "arrow.tabular.Schema"); + end + + function ConstructSchemaFromProxy(testCase) + % Verify that an arrow.tabular.Schema instance can be + % constructred directly from an existing + % arrow.tabular.proxy.Schema Proxy instance. + schema1 = arrow.schema(arrow.field("a", arrow.uint8)); + % Construct an instance of arrow.tabular.Schema directly from a + % Proxy of type "arrow.tabular.proxy.Schema". + schema2 = arrow.tabular.Schema(schema1.Proxy); + testCase.verifyEqual(schema1.FieldNames, schema2.FieldNames); + testCase.verifyEqual(schema1.NumFields, schema2.NumFields); + end + + function Fields(testCase) + % Verify that the Fields property returns an expected array of + % Field objects. + f1 = arrow.field("A", arrow.uint8); + f2 = arrow.field("B", arrow.uint16); + f3 = arrow.field("C", arrow.uint32); + expectedFields = [f1, f2, f3]; + schema = arrow.schema(expectedFields); + + actualFields = schema.Fields; + + testCase.verifyEqual(actualFields(1).Name, expectedFields(1).Name); + testCase.verifyEqual(actualFields(1).Type.ID, expectedFields(1).Type.ID); + testCase.verifyEqual(actualFields(2).Name, expectedFields(2).Name); + testCase.verifyEqual(actualFields(2).Type.ID, expectedFields(2).Type.ID); + testCase.verifyEqual(actualFields(3).Name, expectedFields(3).Name); + testCase.verifyEqual(actualFields(3).Type.ID, expectedFields(3).Type.ID); + end + + function FieldNames(testCase) + % Verify that the FieldNames property returns an expected + % string array of field names. + expectedFieldNames = ["A" , "B" , "C"]; + schema = arrow.schema([... + arrow.field(expectedFieldNames(1), arrow.uint8), ... + arrow.field(expectedFieldNames(2), arrow.uint16), ... + arrow.field(expectedFieldNames(3), arrow.uint32) ... + ]); + actualFieldNames = schema.FieldNames; + testCase.verifyEqual(actualFieldNames, expectedFieldNames); + end + + function FieldNamesNoSetter(testCase) + % Verify that an error is thrown when trying to set the value + % of the FieldNames property. + schema = arrow.schema(arrow.field("A", arrow.uint8)); + testCase.verifyError(@() setfield(schema, "FieldNames", "B"), "MATLAB:class:SetProhibited"); + end + + function NumFieldsNoSetter(testCase) + % Verify than an error is thrown when trying to set the value + % of the NumFields property. + schema = arrow.schema(arrow.field("A", arrow.uint8)); + testCase.verifyError(@() setfield(schema, "NumFields", 123), "MATLAB:class:SetProhibited"); + end + + function FieldsNoSetter(testCase) + % Verify that an error is thrown when trying to set the value + % of the Fields property. + schema = arrow.schema(arrow.field("A", arrow.uint8)); + testCase.verifyError(@() setfield(schema, "Fields", arrow.field("B", arrow.uint8)), "MATLAB:class:SetProhibited"); + end + + function NumFields(testCase) + % Verify that the NumFields property returns an execpted number + % of fields. + schema = arrow.schema([... + arrow.field("A", arrow.uint8), ... + arrow.field("B", arrow.uint16), ... + arrow.field("C", arrow.uint32) ... + ]); + expectedNumFields = int32(3); + actualNumFields = schema.NumFields; + testCase.verifyEqual(actualNumFields, expectedNumFields); + end + + function ErrorIfUnsupportedFieldIndex(testCase) + % Verify that an error is thrown if an invalid field index is + % supplied to the field method (e.g. -1.1, NaN, {1}, etc.). + schema = arrow.schema([... + arrow.field("A", arrow.uint8), ... + arrow.field("B", arrow.uint16), ... + arrow.field("C", arrow.uint32) ... + ]); + + index = []; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + + index = 0; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + + index = -1; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + + index = -1.23; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + + index = NaN; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + + index = {1}; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + + index = [1; 1]; + testCase.verifyError(@() schema.field(index), "arrow:tabular:schema:UnsupportedFieldIndexType"); + end + + function GetFieldByIndex(testCase) + % Verify that Fields can be accessed using a numeric index. + schema = arrow.schema([... + arrow.field("A", arrow.uint8), ... + arrow.field("B", arrow.uint16), ... + arrow.field("C", arrow.uint32) ... + ]); + + field = schema.field(1); + testCase.verifyEqual(field.Name, "A"); + testCase.verifyEqual(field.Type.ID, arrow.type.ID.UInt8); + + field = schema.field(2); + testCase.verifyEqual(field.Name, "B"); + testCase.verifyEqual(field.Type.ID, arrow.type.ID.UInt16); + + field = schema.field(3); + testCase.verifyEqual(field.Name, "C"); + testCase.verifyEqual(field.Type.ID, arrow.type.ID.UInt32); + end + + function GetFieldByName(testCase) + % Verify that Fields can be accessed using a field name. + % Verify that Fields can be accessed using a numeric index. Review Comment: ```suggestion % Verify that Fields can be accessed using a numeric index. ``` ########## matlab/src/cpp/arrow/matlab/tabular/proxy/schema.cc: ########## @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/matlab/error/error.h" +#include "arrow/matlab/tabular/proxy/schema.h" +#include "arrow/matlab/type/proxy/field.h" + +#include "libmexclass/proxy/ProxyManager.h" +#include "libmexclass/error/Error.h" + +#include "arrow/util/utf8.h" + +#include <sstream> + +namespace arrow::matlab::tabular::proxy { + + namespace { + + libmexclass::error::Error makeUnknownFieldNameError(const std::string& name) { + using namespace libmexclass::error; + const std::string error_message_id = std::string{error::ARROW_TABULAR_SCHEMA_UNKNOWN_FIELD_NAME}; + std::stringstream error_message_stream; + error_message_stream << "Unknown field name: '"; + error_message_stream << name; + error_message_stream << "'."; + const std::string& error_message = error_message_stream.str(); + return Error{error_message_id, error_message}; + } + + libmexclass::error::Error makeEmptySchemaError() { + using namespace libmexclass::error; + const std::string error_message_id = std::string{error::ARROW_TABULAR_SCHEMA_NUMERIC_FIELD_INDEX_WITH_EMPTY_SCHEMA}; + std::stringstream error_message_stream; + error_message_stream << "Numeric indexing using the field method is not supported for schemas with no fields."; + const std::string& error_message = error_message_stream.str(); + return Error{error_message_id, error_message}; + } + + } + + Schema::Schema(std::shared_ptr<arrow::Schema> schema) : schema{std::move(schema)} { + REGISTER_METHOD(Schema, getFieldByIndex); + REGISTER_METHOD(Schema, getFieldByName); + REGISTER_METHOD(Schema, getNumFields); + REGISTER_METHOD(Schema, getFieldNames); + REGISTER_METHOD(Schema, toString); + } + + libmexclass::proxy::MakeResult Schema::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { + namespace mda = ::matlab::data; + using SchemaProxy = arrow::matlab::tabular::proxy::Schema; + + mda::StructArray args = constructor_arguments[0]; + const mda::TypedArray<uint64_t> field_proxy_ids_mda = args[0]["FieldProxyIDs"]; + + std::vector<std::shared_ptr<arrow::Field>> fields; + for (const auto proxy_id : field_proxy_ids_mda) { + using namespace libmexclass::proxy; + auto proxy = std::static_pointer_cast<arrow::matlab::type::proxy::Field>(ProxyManager::getProxy(proxy_id)); + auto field = proxy->unwrap(); + fields.push_back(field); + } + auto schema = arrow::schema(fields); + return std::make_shared<SchemaProxy>(std::move(schema)); + } + + std::shared_ptr<arrow::Schema> Schema::unwrap() { + return schema; + } + + void Schema::getFieldByIndex(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + using namespace libmexclass::proxy; + using FieldProxy = arrow::matlab::type::proxy::Field; + mda::ArrayFactory factory; + + mda::StructArray args = context.inputs[0]; + const mda::TypedArray<int32_t> index_mda = args[0]["Index"]; + const auto matlab_index = int32_t(index_mda[0]); + // Note: MATLAB uses 1-based indexing, so subtract 1. + // arrow::Schema::field does not do any bounds checking. + const int32_t index = matlab_index - 1; + const auto num_fields = schema->num_fields(); + + if (num_fields == 0) { + const auto& error = makeEmptySchemaError(); + context.error = error; + return; + } + + if (matlab_index < 1 || matlab_index > num_fields) { + using namespace libmexclass::error; + const std::string& error_message_id = std::string{error::ARROW_TABULAR_SCHEMA_INVALID_NUMERIC_FIELD_INDEX}; + std::stringstream error_message_stream; + error_message_stream << "Invalid field index: "; + error_message_stream << matlab_index; + error_message_stream << ". Field index must be between 1 and the number of fields ("; + error_message_stream << num_fields; + error_message_stream << ")."; + const std::string& error_message = error_message_stream.str(); + context.error = Error{error_message_id, error_message}; + return; + } + + const auto& field = schema->field(index); + auto field_proxy = std::make_shared<FieldProxy>(field); + const auto field_proxy_id = ProxyManager::manageProxy(field_proxy); + const auto field_proxy_id_mda = factory.createScalar(field_proxy_id); + + context.outputs[0] = field_proxy_id_mda; + } + + void Schema::getFieldByName(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + using namespace libmexclass::proxy; + using FieldProxy = arrow::matlab::type::proxy::Field; + mda::ArrayFactory factory; + + mda::StructArray args = context.inputs[0]; + const mda::StringArray name_mda = args[0]["Name"]; + const auto name_utf16 = std::u16string(name_mda[0]); + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(const auto name, arrow::util::UTF16StringToUTF8(name_utf16), context, error::UNICODE_CONVERSION_ERROR_ID); + const std::vector<std::string> names = {name}; + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT(schema->CanReferenceFieldsByNames(names), context, error::ARROW_TABULAR_SCHEMA_AMBIGUOUS_FIELD_NAME); + const auto field = schema->GetFieldByName(name); + if (!field) { + // Note: This line should never be reached because CanReferenceFieldsByNames + // should already handle validating whether the supplied field name is valid. + const auto& error = makeUnknownFieldNameError(name); + context.error = error; + return; + } Review Comment: Can we remove this? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
