Fokko commented on code in PR #36846:
URL: https://github.com/apache/arrow/pull/36846#discussion_r1274828067
##########
cpp/src/arrow/type.cc:
##########
@@ -316,6 +298,431 @@ std::shared_ptr<Field> Field::WithNullable(const bool
nullable) const {
return std::make_shared<Field>(name_, type_, nullable, metadata_);
}
+Field::MergeOptions Field::MergeOptions::Permissive() {
+ MergeOptions options = Defaults();
+ options.promote_nullability = true;
+ options.promote_decimal = true;
+ options.promote_decimal_float = true;
+ options.promote_integer_decimal = true;
+ options.promote_integer_float = true;
+ options.promote_integer_sign = true;
+ options.promote_numeric_width = true;
+ options.promote_binary = true;
+ options.promote_date = true;
+ options.promote_duration = true;
+ options.promote_time = true;
+ options.promote_timestamp = true;
+ options.promote_dictionary = true;
+ options.promote_dictionary_ordered = false;
+ options.promote_large = true;
+ options.promote_nested = true;
+ return options;
+}
+
+std::string Field::MergeOptions::ToString() const {
+ std::stringstream ss;
+ ss << "MergeOptions{";
+ ss << "promote_nullability=" << (promote_nullability ? "true" : "false");
+ ss << ", promote_numeric_width=" << (promote_numeric_width ? "true" :
"false");
+ ss << ", promote_integer_float=" << (promote_integer_float ? "true" :
"false");
+ ss << ", promote_integer_decimal=" << (promote_integer_decimal ? "true" :
"false");
+ ss << ", promote_decimal_float=" << (promote_decimal_float ? "true" :
"false");
+ ss << ", promote_date=" << (promote_date ? "true" : "false");
+ ss << ", promote_time=" << (promote_time ? "true" : "false");
+ ss << ", promote_duration=" << (promote_duration ? "true" : "false");
+ ss << ", promote_timestamp=" << (promote_timestamp ? "true" : "false");
+ ss << ", promote_nested=" << (promote_nested ? "true" : "false");
+ ss << ", promote_dictionary=" << (promote_dictionary ? "true" : "false");
+ ss << ", promote_integer_sign=" << (promote_integer_sign ? "true" : "false");
+ ss << ", promote_large=" << (promote_large ? "true" : "false");
+ ss << ", promote_binary=" << (promote_binary ? "true" : "false");
+ ss << '}';
+ return ss.str();
+}
+
+namespace {
+// Utilities for Field::MergeWith
+
+std::shared_ptr<DataType> MakeSigned(const DataType& type) {
+ switch (type.id()) {
+ case Type::INT8:
+ case Type::UINT8:
+ return int8();
+ case Type::INT16:
+ case Type::UINT16:
+ return int16();
+ case Type::INT32:
+ case Type::UINT32:
+ return int32();
+ case Type::INT64:
+ case Type::UINT64:
+ return int64();
+ default:
+ DCHECK(false) << "unreachable";
+ }
+ return std::shared_ptr<DataType>(nullptr);
+}
+std::shared_ptr<DataType> MakeBinary(const DataType& type) {
+ switch (type.id()) {
+ case Type::BINARY:
+ case Type::STRING:
+ return binary();
+ case Type::LARGE_BINARY:
+ case Type::LARGE_STRING:
+ return large_binary();
+ default:
+ DCHECK(false) << "unreachable";
+ }
+ return std::shared_ptr<DataType>(nullptr);
+}
+TimeUnit::type CommonTimeUnit(TimeUnit::type left, TimeUnit::type right) {
+ if (left == TimeUnit::NANO || right == TimeUnit::NANO) {
+ return TimeUnit::NANO;
+ } else if (left == TimeUnit::MICRO || right == TimeUnit::MICRO) {
+ return TimeUnit::MICRO;
+ } else if (left == TimeUnit::MILLI || right == TimeUnit::MILLI) {
+ return TimeUnit::MILLI;
+ }
+ return TimeUnit::SECOND;
+}
+
+Result<std::shared_ptr<DataType>> MergeTypes(std::shared_ptr<DataType>
promoted_type,
+ std::shared_ptr<DataType>
other_type,
+ const Field::MergeOptions&
options);
+
+// Merge two dictionary types, or else give an error.
+Result<std::shared_ptr<DataType>> MergeDictionaryTypes(
+ const std::shared_ptr<DataType>& promoted_type,
+ const std::shared_ptr<DataType>& other_type, const Field::MergeOptions&
options) {
+ const auto& left = checked_cast<const DictionaryType&>(*promoted_type);
+ const auto& right = checked_cast<const DictionaryType&>(*other_type);
+ if (!options.promote_dictionary_ordered && left.ordered() !=
right.ordered()) {
+ return Status::Invalid(
+ "Cannot merge ordered and unordered dictionary unless "
+ "promote_dictionary_ordered=true");
+ }
+ Field::MergeOptions index_options = options;
+ index_options.promote_integer_sign = true;
+ index_options.promote_numeric_width = true;
+ ARROW_ASSIGN_OR_RAISE(auto indices,
+ MergeTypes(left.index_type(), right.index_type(),
index_options));
+ ARROW_ASSIGN_OR_RAISE(auto values,
+ MergeTypes(left.value_type(), right.value_type(),
options));
+ auto ordered = left.ordered() && right.ordered();
+ if (indices && values) {
+ return dictionary(indices, values, ordered);
+ } else if (values) {
+ return Status::Invalid("Could not merge index types");
+ }
+ return Status::Invalid("Could not merge value types");
+}
+
+// Merge temporal types based on options. Returns nullptr for non-temporal
types.
+Result<std::shared_ptr<DataType>> MaybeMergeTemporalTypes(
+ const std::shared_ptr<DataType>& promoted_type,
+ const std::shared_ptr<DataType>& other_type, const Field::MergeOptions&
options) {
+ if (options.promote_date) {
+ if (promoted_type->id() == Type::DATE32 && other_type->id() ==
Type::DATE64) {
+ return date64();
+ }
+ if (promoted_type->id() == Type::DATE64 && other_type->id() ==
Type::DATE32) {
+ return date64();
+ }
+ }
+
+ if (options.promote_duration && promoted_type->id() == Type::DURATION &&
+ other_type->id() == Type::DURATION) {
+ const auto& left = checked_cast<const DurationType&>(*promoted_type);
+ const auto& right = checked_cast<const DurationType&>(*other_type);
+ return duration(CommonTimeUnit(left.unit(), right.unit()));
+ }
+
+ if (options.promote_time && is_time(promoted_type->id()) &&
is_time(other_type->id())) {
+ const auto& left = checked_cast<const TimeType&>(*promoted_type);
+ const auto& right = checked_cast<const TimeType&>(*other_type);
+ const auto unit = CommonTimeUnit(left.unit(), right.unit());
+ if (unit == TimeUnit::MICRO || unit == TimeUnit::NANO) {
+ return time64(unit);
+ }
+ return time32(unit);
+ }
+
+ if (options.promote_timestamp && promoted_type->id() == Type::TIMESTAMP &&
+ other_type->id() == Type::TIMESTAMP) {
+ const auto& left = checked_cast<const TimestampType&>(*promoted_type);
+ const auto& right = checked_cast<const TimestampType&>(*other_type);
+ if (left.timezone().empty() ^ right.timezone().empty()) {
+ return Status::Invalid(
+ "Cannot merge timestamp with timezone and timestamp without
timezone");
+ }
+ if (left.timezone() != right.timezone()) {
+ return Status::Invalid("Cannot merge timestamps with differing
timezones");
+ }
+ return timestamp(CommonTimeUnit(left.unit(), right.unit()),
left.timezone());
+ }
+
+ return nullptr;
+}
+
+// Merge numeric types based on options. Returns nullptr for non-numeric
types.
+Result<std::shared_ptr<DataType>> MaybeMergeNumericTypes(
+ std::shared_ptr<DataType> promoted_type, std::shared_ptr<DataType>
other_type,
+ const Field::MergeOptions& options) {
+ bool promoted = false;
+ if (options.promote_decimal_float) {
+ if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) {
+ promoted_type = other_type;
+ promoted = true;
+ } else if (is_floating(promoted_type->id()) &&
is_decimal(other_type->id())) {
+ other_type = promoted_type;
+ promoted = true;
+ }
+ }
+
+ if (options.promote_integer_decimal) {
+ if (is_integer(promoted_type->id()) && is_decimal(other_type->id())) {
+ promoted_type.swap(other_type);
+ }
+
+ if (is_decimal(promoted_type->id()) && is_integer(other_type->id())) {
+ ARROW_ASSIGN_OR_RAISE(const int32_t precision,
+ MaxDecimalDigitsForInteger(other_type->id()));
+ ARROW_ASSIGN_OR_RAISE(other_type,
+ DecimalType::Make(promoted_type->id(), precision,
0));
+ promoted = true;
+ }
+ }
+
+ if (options.promote_decimal && is_decimal(promoted_type->id()) &&
+ is_decimal(other_type->id())) {
+ const auto& left = checked_cast<const DecimalType&>(*promoted_type);
+ const auto& right = checked_cast<const DecimalType&>(*other_type);
+ if (!options.promote_numeric_width && left.bit_width() !=
right.bit_width()) {
+ return Status::Invalid(
+ "Cannot promote decimal128 to decimal256 without
promote_numeric_width=true");
+ }
+ const int32_t max_scale = std::max<int32_t>(left.scale(), right.scale());
+ const int32_t common_precision =
+ std::max<int32_t>(left.precision() + max_scale - left.scale(),
+ right.precision() + max_scale - right.scale());
+ if (left.id() == Type::DECIMAL256 || right.id() == Type::DECIMAL256 ||
+ (options.promote_numeric_width &&
+ common_precision > BasicDecimal128::kMaxPrecision)) {
+ return DecimalType::Make(Type::DECIMAL256, common_precision, max_scale);
+ }
+ return DecimalType::Make(Type::DECIMAL128, common_precision, max_scale);
+ }
+
+ if (options.promote_integer_sign) {
+ if (is_unsigned_integer(promoted_type->id()) &&
is_signed_integer(other_type->id())) {
+ promoted = bit_width(other_type->id()) >= bit_width(promoted_type->id());
+ promoted_type = MakeSigned(*promoted_type);
+ } else if (is_signed_integer(promoted_type->id()) &&
+ is_unsigned_integer(other_type->id())) {
+ promoted = bit_width(promoted_type->id()) >= bit_width(other_type->id());
+ other_type = MakeSigned(*other_type);
+ }
+ }
+
+ if (options.promote_integer_float &&
+ ((is_floating(promoted_type->id()) && is_integer(other_type->id())) ||
+ (is_integer(promoted_type->id()) && is_floating(other_type->id())))) {
+ const int max_width =
+ std::max<int>(bit_width(promoted_type->id()),
bit_width(other_type->id()));
+ if (max_width >= 64) {
+ promoted_type = float64();
+ } else if (max_width >= 32) {
+ promoted_type = float32();
+ } else {
+ promoted_type = float16();
+ }
+ promoted = true;
+ }
+
+ if (options.promote_numeric_width) {
+ const int max_width =
+ std::max<int>(bit_width(promoted_type->id()),
bit_width(other_type->id()));
+ if (is_floating(promoted_type->id()) && is_floating(other_type->id())) {
+ if (max_width >= 64) {
+ return float64();
+ } else if (max_width >= 32) {
+ return float32();
+ }
+ return float16();
+ } else if (is_signed_integer(promoted_type->id()) &&
+ is_signed_integer(other_type->id())) {
+ if (max_width >= 64) {
+ return int64();
+ } else if (max_width >= 32) {
+ return int32();
+ } else if (max_width >= 16) {
+ return int16();
+ }
+ return int8();
+ } else if (is_unsigned_integer(promoted_type->id()) &&
+ is_unsigned_integer(other_type->id())) {
+ if (max_width >= 64) {
+ return uint64();
+ } else if (max_width >= 32) {
+ return uint32();
+ } else if (max_width >= 16) {
+ return uint16();
+ }
+ return uint8();
+ }
+ }
+
+ return promoted ? promoted_type : nullptr;
+}
+
+Result<std::shared_ptr<DataType>> MergeTypes(std::shared_ptr<DataType>
promoted_type,
+ std::shared_ptr<DataType>
other_type,
+ const Field::MergeOptions&
options) {
+ if (promoted_type->Equals(*other_type)) return promoted_type;
+
+ bool promoted = false;
+ if (options.promote_nullability) {
+ if (promoted_type->id() == Type::NA) {
+ return other_type;
+ } else if (other_type->id() == Type::NA) {
+ return promoted_type;
+ }
+ } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) {
+ return Status::Invalid("Cannot merge type with null unless
promote_nullability=true");
+ }
+
+ if (options.promote_dictionary && is_dictionary(promoted_type->id()) &&
+ is_dictionary(other_type->id())) {
+ return MergeDictionaryTypes(promoted_type, other_type, options);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto maybe_promoted,
+ MaybeMergeTemporalTypes(promoted_type, other_type,
options));
+ if (maybe_promoted) return maybe_promoted;
+
+ ARROW_ASSIGN_OR_RAISE(maybe_promoted,
+ MaybeMergeNumericTypes(promoted_type, other_type,
options));
+ if (maybe_promoted) return maybe_promoted;
+
+ if (options.promote_large) {
+ if (promoted_type->id() == Type::FIXED_SIZE_BINARY &&
+ is_base_binary_like(other_type->id())) {
+ promoted_type = binary();
+ promoted = other_type->id() == Type::BINARY;
+ }
+ if (other_type->id() == Type::FIXED_SIZE_BINARY &&
+ is_base_binary_like(promoted_type->id())) {
+ other_type = binary();
+ promoted = promoted_type->id() == Type::BINARY;
+ }
+
+ if (promoted_type->id() == Type::FIXED_SIZE_LIST &&
+ is_var_length_list(other_type->id())) {
+ promoted_type =
+ list(checked_cast<const
BaseListType&>(*promoted_type).value_field());
+ promoted = other_type->Equals(*promoted_type);
+ }
+ if (other_type->id() == Type::FIXED_SIZE_LIST &&
+ is_var_length_list(promoted_type->id())) {
+ other_type = list(checked_cast<const
BaseListType&>(*other_type).value_field());
+ promoted = other_type->Equals(*promoted_type);
+ }
+ }
+
+ if (options.promote_binary) {
+ if (promoted_type->id() == Type::FIXED_SIZE_BINARY &&
+ other_type->id() == Type::FIXED_SIZE_BINARY) {
+ return binary();
+ }
+ if (is_string(promoted_type->id()) && is_binary(other_type->id())) {
+ promoted_type = MakeBinary(*promoted_type);
+ promoted =
+ offset_bit_width(promoted_type->id()) ==
offset_bit_width(other_type->id());
+ } else if (is_binary(promoted_type->id()) && is_string(other_type->id())) {
+ other_type = MakeBinary(*other_type);
+ promoted =
+ offset_bit_width(promoted_type->id()) ==
offset_bit_width(other_type->id());
+ }
+ }
+
+ if (options.promote_large) {
+ if ((promoted_type->id() == Type::STRING && other_type->id() ==
Type::LARGE_STRING) ||
+ (promoted_type->id() == Type::LARGE_STRING && other_type->id() ==
Type::STRING)) {
+ return large_utf8();
+ } else if ((promoted_type->id() == Type::BINARY &&
+ other_type->id() == Type::LARGE_BINARY) ||
+ (promoted_type->id() == Type::LARGE_BINARY &&
+ other_type->id() == Type::BINARY)) {
+ return large_binary();
+ }
+ if ((promoted_type->id() == Type::LIST && other_type->id() ==
Type::LARGE_LIST) ||
+ (promoted_type->id() == Type::LARGE_LIST && other_type->id() ==
Type::LIST)) {
+ promoted_type =
+ large_list(checked_cast<const
BaseListType&>(*promoted_type).value_field());
+ promoted = true;
Review Comment:
I took the original code from #12000. I would probably implement this using
a post-order visitor pattern to traverse the schema; that would simplify the
logic quite a bit.
I've updated the code to merge the values as well.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]