jorisvandenbossche commented on code in PR #14106: URL: https://github.com/apache/arrow/pull/14106#discussion_r976448330
########## cpp/src/arrow/compute/kernels/scalar_cast_extension.cc: ########## @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to extension types +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/scalar.h" + +namespace arrow { +namespace compute { +namespace internal { + +namespace { +Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const CastOptions& options = checked_cast<const CastState*>(ctx->state())->options; + auto out_ty = static_cast<const ExtensionType&>(*options.to_type.type).storage_type(); + + DCHECK(batch[0].is_array()); + std::shared_ptr<Array> array = batch[0].array.ToArray(); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> result, + Cast(*array, out_ty, options, ctx->exec_context())); + + ExtensionArray extension(options.to_type.GetSharedPtr(), result); + out->value = std::move(extension.data()); + return Status::OK(); +} + +std::shared_ptr<CastFunction> GetCastToExtension(std::string name) { + auto func = std::make_shared<CastFunction>(std::move(name), Type::EXTENSION); + for (auto types : {IntTypes(), FloatingPointTypes(), StringTypes(), 
BinaryTypes()}) { + for (auto in_ty : types) { + DCHECK_OK( + func->AddKernel(in_ty->id(), {in_ty}, kOutputTargetType, CastToExtension)); + } + } + DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, + kOutputTargetType, CastToExtension)); + return func; +} + +}; // namespace + +std::vector<std::shared_ptr<CastFunction>> GetExtensionCasts() { + auto func = GetCastToExtension("cast_extension"); Review Comment: Small thing I am wondering: is there a need for the separate `GetCastToExtension` function, or could the content of that function simply be inlined here? ########## python/pyarrow/tests/test_extension_type.py: ########## @@ -517,10 +517,31 @@ def test_cast_kernel_on_extension_arrays(): assert isinstance(casted, pa.ChunkedArray) -def test_casting_to_extension_type_raises(): - arr = pa.array([1, 2, 3, 4], pa.int64()) - with pytest.raises(pa.ArrowNotImplementedError): - arr.cast(IntegerType()) +@pytest.mark.parametrize("data,ty", ( + ([1, 2], pa.int32), + ([1, 2], pa.int64), + (["1", "2"], pa.string), + ([b"1", b"2"], pa.binary), + ([1.0, 2.0], pa.float32), + ([1.0, 2.0], pa.float64) +)) +def test_casting_to_extension_type(data, ty): + arr = pa.array(data, ty()) + out = arr.cast(IntegerType()) + assert isinstance(out, pa.ExtensionArray) + assert out.type == IntegerType() + assert out.to_pylist() == [1, 2] Review Comment: I didn't check if this is covered by the C++ tests, but we should also add some failure cases (where casting to the storage type fails, either always or depending on an option, e.g. disallowing vs allowing float truncation when casting a float array to this IntegerType). ########## cpp/src/arrow/compute/kernels/scalar_cast_test.cc: ########## @@ -2765,6 +2772,68 @@ TEST(Cast, ExtensionTypeToIntDowncast) { } } +TEST(Cast, PrimitiveToExtension) { + { + auto primitive_array = ArrayFromJSON(uint8(), "[0, 1, 3]"); + auto extension_array = SmallintArrayFromJSON("[0, 1, 3]"); + CastOptions options; + options.to_type = smallint(); 
+ CheckCast(primitive_array, extension_array, options); + } + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(utf8(), "[\"hello\"]"), options); + } +} + +TEST(Cast, ExtensionDictToExtension) { + auto extension_array = SmallintArrayFromJSON("[1, 2, 1]"); + auto indices_array = ArrayFromJSON(int32(), "[0, 1, 0]"); + + ASSERT_OK_AND_ASSIGN(auto dict_array, + DictionaryArray::FromArrays(indices_array, extension_array)); + + CastOptions options; + options.to_type = smallint(); + CheckCast(dict_array, extension_array, options); +} + +TEST(Cast, IntToExtensionTypeDowncast) { + CheckCast(ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"), + SmallintArrayFromJSON("[0, 100, 200, 1, 2]")); + + // int32 to Smallint(int16), with overflow + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(int32(), "[0, null, 32768, 1, 3]"), options); + + options.allow_int_overflow = true; + CheckCast(ArrayFromJSON(int32(), "[0, null, 32768, 1, 3]"), + SmallintArrayFromJSON("[0, null, -32768, 1, 3]"), options); + } + + // int32 to Smallint(int16), with underflow + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(int32(), "[0, null, -32769, 1, 3]"), options); + + options.allow_int_overflow = true; + CheckCast(ArrayFromJSON(int32(), "[0, null, -32769, 1, 3]"), + SmallintArrayFromJSON("[0, null, 32767, 1, 3]"), options); + } + + // Cannot cast between extension types + { + CastOptions options; + options.to_type = smallint(); + auto tiny_array = TinyintArrayFromJSON("[0, 1, 3]"); + ASSERT_NOT_OK(Cast(tiny_array, smallint(), options)); + } Review Comment: Is this just not yet implemented? Should we add `AddCommonCasts(Type::EXTENSION, ..)` to `GetExtensionCasts`? (That would additionally cover casting null to an extension type.) Or do we actually want to prohibit casting extension type -> extension type (assuming that the storage type -> storage type cast is implemented)? 
(This can also be done in a follow-up.) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org