This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d3c5b85 ARROW-3532: [Python] Emit warning when looking up for duplicate struct or schema fields
d3c5b85 is described below
commit d3c5b85e6881543f12378fd63c2921aead36cc5f
Author: Antoine Pitrou <[email protected]>
AuthorDate: Wed Feb 20 20:31:47 2019 +0100
ARROW-3532: [Python] Emit warning when looking up for duplicate struct or schema fields
Also provide C++ APIs to get all fields having a given name.
Author: Antoine Pitrou <[email protected]>
Closes #3713 from pitrou/ARROW-3532-duplicated-field-warning and squashes the following commits:
d23b241a6 <Antoine Pitrou> ARROW-3532: Emit warning when looking up for duplicate struct or schema fields
---
cpp/src/arrow/testing/gtest_util.h | 7 +++
cpp/src/arrow/type-test.cc | 58 +++++++++++++++++++++++-
cpp/src/arrow/type.cc | 85 ++++++++++++++++++++++++------------
cpp/src/arrow/type.h | 16 ++++++-
python/pyarrow/includes/libarrow.pxd | 3 ++
python/pyarrow/tests/test_schema.py | 52 ++++++++++++++++++++++
python/pyarrow/tests/test_types.py | 33 ++------------
python/pyarrow/types.pxi | 28 +++++++++---
8 files changed, 214 insertions(+), 68 deletions(-)
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 8fe56ad..ad5830a 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -314,4 +314,11 @@ inline void BitmapFromVector(const std::vector<T>& is_valid,
   ASSERT_OK(GetBitmapFromVector(is_valid, out));
 }
 
+template <typename T>
+void AssertSortedEquals(std::vector<T> u, std::vector<T> v) {
+  std::sort(u.begin(), u.end());
+  std::sort(v.begin(), v.end());
+  ASSERT_EQ(u, v);
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc
index 5138177..49a0332 100644
--- a/cpp/src/arrow/type-test.cc
+++ b/cpp/src/arrow/type-test.cc
@@ -202,6 +202,41 @@ TEST_F(TestSchema, GetFieldIndex) {
   ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
 }
 
+TEST_F(TestSchema, GetFieldDuplicates) {
+  auto f0 = field("f0", int32());
+  auto f1 = field("f1", uint8(), false);
+  auto f2 = field("f2", utf8());
+  auto f3 = field("f1", list(int16()));
+
+  auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+  ASSERT_EQ(0, schema->GetFieldIndex(f0->name()));
+  ASSERT_EQ(-1, schema->GetFieldIndex(f1->name()));  // duplicate
+  ASSERT_EQ(2, schema->GetFieldIndex(f2->name()));
+  ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
+  ASSERT_EQ(std::vector<int>{0}, schema->GetAllFieldIndices(f0->name()));
+  AssertSortedEquals(std::vector<int>{1, 3}, schema->GetAllFieldIndices(f1->name()));
+
+  std::vector<std::shared_ptr<Field>> results;
+
+  results = schema->GetAllFieldsByName(f0->name());
+  ASSERT_EQ(results.size(), 1);
+  ASSERT_TRUE(results[0]->Equals(f0));
+
+  results = schema->GetAllFieldsByName(f1->name());
+  ASSERT_EQ(results.size(), 2);
+  if (results[0]->type()->id() == Type::UINT8) {
+    ASSERT_TRUE(results[0]->Equals(f1));
+    ASSERT_TRUE(results[1]->Equals(f3));
+  } else {
+    ASSERT_TRUE(results[0]->Equals(f3));
+    ASSERT_TRUE(results[1]->Equals(f1));
+  }
+
+  results = schema->GetAllFieldsByName("not-found");
+  ASSERT_EQ(results.size(), 0);
+}
+
 TEST_F(TestSchema, TestMetadataConstruction) {
   auto metadata0 = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});
   auto metadata1 = key_value_metadata({{"foo", "baz"}});
@@ -495,7 +530,7 @@ TEST(TestStructType, GetFieldIndex) {
   ASSERT_EQ(-1, struct_type.GetFieldIndex("not-found"));
 }
 
-TEST(TestStructType, GetFieldIndexDuplicates) {
+TEST(TestStructType, GetFieldDuplicates) {
   auto f0 = field("f0", int32());
   auto f1 = field("f1", int64());
   auto f2 = field("f1", utf8());
@@ -503,6 +538,27 @@ TEST(TestStructType, GetFieldIndexDuplicates) {
 
   ASSERT_EQ(0, struct_type.GetFieldIndex("f0"));
   ASSERT_EQ(-1, struct_type.GetFieldIndex("f1"));
+  ASSERT_EQ(std::vector<int>{0}, struct_type.GetAllFieldIndices(f0->name()));
+  AssertSortedEquals(std::vector<int>{1, 2}, struct_type.GetAllFieldIndices(f1->name()));
+
+  std::vector<std::shared_ptr<Field>> results;
+
+  results = struct_type.GetAllFieldsByName(f0->name());
+  ASSERT_EQ(results.size(), 1);
+  ASSERT_TRUE(results[0]->Equals(f0));
+
+  results = struct_type.GetAllFieldsByName(f1->name());
+  ASSERT_EQ(results.size(), 2);
+  if (results[0]->type()->id() == Type::INT64) {
+    ASSERT_TRUE(results[0]->Equals(f1));
+    ASSERT_TRUE(results[1]->Equals(f2));
+  } else {
+    ASSERT_TRUE(results[0]->Equals(f2));
+    ASSERT_TRUE(results[1]->Equals(f1));
+  }
+
+  results = struct_type.GetAllFieldsByName("not-found");
+  ASSERT_EQ(results.size(), 0);
 }
 
 TEST(TestDictionaryType, Equals) {
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 852ddb0..3bb997a 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -225,15 +225,31 @@ std::string UnionType::ToString() const {
 
 namespace {
 
-std::unordered_map<std::string, int> CreateNameToIndexMap(
+std::unordered_multimap<std::string, int> CreateNameToIndexMap(
     const std::vector<std::shared_ptr<Field>>& fields) {
-  std::unordered_map<std::string, int> name_to_index;
+  std::unordered_multimap<std::string, int> name_to_index;
   for (size_t i = 0; i < fields.size(); ++i) {
-    name_to_index[fields[i]->name()] = static_cast<int>(i);
+    name_to_index.emplace(fields[i]->name(), static_cast<int>(i));
   }
   return name_to_index;
 }
 
+int LookupNameIndex(const std::unordered_multimap<std::string, int>& name_to_index,
+                    const std::string& name) {
+  auto p = name_to_index.equal_range(name);
+  auto it = p.first;
+  if (it == p.second) {
+    // Not found
+    return -1;
+  }
+  auto index = it->second;
+  if (++it != p.second) {
+    // Duplicate field name
+    return -1;
+  }
+  return index;
+}
+
 }  // namespace
 
 StructType::StructType(const std::vector<std::shared_ptr<Field>>& fields)
@@ -261,33 +277,30 @@ std::shared_ptr<Field> StructType::GetFieldByName(const std::string& name) const
 }
 
 int StructType::GetFieldIndex(const std::string& name) const {
-  if (name_to_index_.size() < children_.size()) {
-    // There are duplicate field names. Refuse to guess
-    int counts = 0;
-    int last_observed_index = -1;
-    for (size_t i = 0; i < children_.size(); ++i) {
-      if (children_[i]->name() == name) {
-        ++counts;
-        last_observed_index = static_cast<int>(i);
-      }
-    }
+  return LookupNameIndex(name_to_index_, name);
+}
 
-    if (counts == 1) {
-      return last_observed_index;
-    } else {
-      // Duplicate or not found
-      return -1;
-    }
+std::vector<int> StructType::GetAllFieldIndices(const std::string& name) const {
+  std::vector<int> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(it->second);
   }
+  return result;
+}
 
-  auto it = name_to_index_.find(name);
-  if (it == name_to_index_.end()) {
-    return -1;
-  } else {
-    return it->second;
+std::vector<std::shared_ptr<Field>> StructType::GetAllFieldsByName(
+    const std::string& name) const {
+  std::vector<std::shared_ptr<Field>> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(children_[it->second]);
   }
+  return result;
 }
 
+// Deprecated methods
+
 std::shared_ptr<Field> StructType::GetChildByName(const std::string& name) const {
   return GetFieldByName(name);
 }
@@ -386,12 +399,26 @@ std::shared_ptr<Field> Schema::GetFieldByName(const std::string& name) const {
 }
 
 int Schema::GetFieldIndex(const std::string& name) const {
-  auto it = name_to_index_.find(name);
-  if (it == name_to_index_.end()) {
-    return -1;
-  } else {
-    return it->second;
+  return LookupNameIndex(name_to_index_, name);
+}
+
+std::vector<int> Schema::GetAllFieldIndices(const std::string& name) const {
+  std::vector<int> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(it->second);
+  }
+  return result;
+}
+
+std::vector<std::shared_ptr<Field>> Schema::GetAllFieldsByName(
+    const std::string& name) const {
+  std::vector<std::shared_ptr<Field>> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(fields_[it->second]);
   }
+  return result;
 }
 
 Status Schema::AddField(int i, const std::shared_ptr<Field>& field,
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index c775b11..472ba03 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -518,10 +518,16 @@ class ARROW_EXPORT StructType : public NestedType {
   /// Returns null if name not found
   std::shared_ptr<Field> GetFieldByName(const std::string& name) const;
 
+  /// Return all fields having this name
+  std::vector<std::shared_ptr<Field>> GetAllFieldsByName(const std::string& name) const;
+
   /// Returns -1 if name not found or if there are multiple fields having the
   /// same name
   int GetFieldIndex(const std::string& name) const;
 
+  /// Return the indices of all fields having this name
+  std::vector<int> GetAllFieldIndices(const std::string& name) const;
+
   ARROW_DEPRECATED("Use GetFieldByName")
   std::shared_ptr<Field> GetChildByName(const std::string& name) const;
 
@@ -529,7 +535,7 @@ class ARROW_EXPORT StructType : public NestedType {
   int GetChildIndex(const std::string& name) const;
 
  private:
-  std::unordered_map<std::string, int> name_to_index_;
+  std::unordered_multimap<std::string, int> name_to_index_;
 };
 
 /// \brief Base type class for (fixed-size) decimal data
@@ -826,9 +832,15 @@ class ARROW_EXPORT Schema {
   /// Returns null if name not found
   std::shared_ptr<Field> GetFieldByName(const std::string& name) const;
 
+  /// Return all fields having this name
+  std::vector<std::shared_ptr<Field>> GetAllFieldsByName(const std::string& name) const;
+
   /// Returns -1 if name not found
   int GetFieldIndex(const std::string& name) const;
 
+  /// Return the indices of all fields having this name
+  std::vector<int> GetAllFieldIndices(const std::string& name) const;
+
   const std::vector<std::shared_ptr<Field>>& fields() const { return fields_; }
   std::vector<std::string> field_names() const;
 
@@ -866,7 +878,7 @@ class ARROW_EXPORT Schema {
 
  private:
   std::vector<std::shared_ptr<Field>> fields_;
-  std::unordered_map<std::string, int> name_to_index_;
+  std::unordered_multimap<std::string, int> name_to_index_;
   std::shared_ptr<const KeyValueMetadata> metadata_;
 };
 
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 97bc892..64b907b 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -277,6 +277,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         CStructType(const vector[shared_ptr[CField]]& fields)
 
         shared_ptr[CField] GetFieldByName(const c_string& name)
+        vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
         int GetFieldIndex(const c_string& name)
 
     cdef cppclass CUnionType" arrow::UnionType"(CDataType):
@@ -298,7 +299,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         shared_ptr[CField] field(int i)
         shared_ptr[const CKeyValueMetadata] metadata()
         shared_ptr[CField] GetFieldByName(const c_string& name)
+        vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
         int64_t GetFieldIndex(const c_string& name)
+        vector[int] GetAllFieldIndices(const c_string& name)
         int num_fields()
         c_string ToString()
 
diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py
index 8549d61..75c7f48 100644
--- a/python/pyarrow/tests/test_schema.py
+++ b/python/pyarrow/tests/test_schema.py
@@ -20,6 +20,7 @@ import pickle
 
 import pytest
 import numpy as np
+import pandas as pd
 
 import pyarrow as pa
 
@@ -266,6 +267,30 @@ baz: list<item: int8>
         pa.schema(fields)
 
 
+def test_schema_duplicate_fields():
+    fields = [
+        pa.field('foo', pa.int32()),
+        pa.field('bar', pa.string()),
+        pa.field('foo', pa.list_(pa.int8())),
+    ]
+    sch = pa.schema(fields)
+    assert sch.names == ['foo', 'bar', 'foo']
+    assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
+    assert len(sch) == 3
+    assert repr(sch) == """\
+foo: int32
+bar: string
+foo: list<item: int8>
+  child 0, item: int8"""
+
+    assert sch[0].name == 'foo'
+    assert sch[0].type == fields[0].type
+    assert sch.field_by_name('bar') == fields[1]
+    assert sch.field_by_name('xxx') is None
+    with pytest.warns(UserWarning):
+        assert sch.field_by_name('foo') is None
+
+
 def test_field_flatten():
     f0 = pa.field('foo', pa.int32()).add_metadata({b'foo': b'bar'})
     assert f0.flatten() == [f0]
@@ -456,3 +481,30 @@ def test_type_schema_pickling():
     schema = pa.schema(fields, metadata={b'foo': b'bar'})
     roundtripped = pickle.loads(pickle.dumps(schema))
     assert schema == roundtripped
+
+
+def test_empty_table():
+    schema = pa.schema([
+        pa.field('oneField', pa.int64())
+    ])
+    table = schema.empty_table()
+    assert isinstance(table, pa.Table)
+    assert table.num_rows == 0
+    assert table.schema == schema
+
+
+@pytest.mark.parametrize('data', [
+    list(range(10)),
+    pd.Categorical(list(range(10))),
+    ['foo', 'bar', None, 'baz', 'qux'],
+    np.array([
+        '2007-07-13T01:23:34.123456789',
+        '2006-01-13T12:34:56.432539784',
+        '2010-08-13T05:46:57.437699912'
+    ], dtype='datetime64[ns]')
+])
+def test_schema_from_pandas(data):
+    df = pd.DataFrame({'a': data})
+    schema = pa.Schema.from_pandas(df)
+    expected = pa.Table.from_pandas(df).schema
+    assert schema == expected
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index 11c0cca..2401bf5 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -22,7 +22,6 @@ import pytest
 
 import hypothesis as h
 import hypothesis.strategies as st
-import pandas as pd
 import numpy as np
 import pyarrow as pa
 import pyarrow.types as types
@@ -249,8 +248,9 @@ def test_struct_type():
     assert ty['b'] == ty[2]
 
     # Duplicate
-    with pytest.raises(KeyError):
-        ty['a']
+    with pytest.warns(UserWarning):
+        with pytest.raises(KeyError):
+            ty['a']
 
     # Not found
     with pytest.raises(KeyError):
@@ -519,16 +519,6 @@ def test_field_add_remove_metadata():
     assert f5.equals(f6)
 
 
-def test_empty_table():
-    schema = pa.schema([
-        pa.field('oneField', pa.int64())
-    ])
-    table = schema.empty_table()
-    assert isinstance(table, pa.Table)
-    assert table.num_rows == 0
-    assert table.schema == schema
-
-
 def test_is_integer_value():
     assert pa.types.is_integer_value(1)
     assert pa.types.is_integer_value(np.int64(1))
@@ -550,23 +540,6 @@ def test_is_boolean_value():
     assert pa.types.is_boolean_value(np.bool_(False))
 
 
-@pytest.mark.parametrize('data', [
-    list(range(10)),
-    pd.Categorical(list(range(10))),
-    ['foo', 'bar', None, 'baz', 'qux'],
-    np.array([
-        '2007-07-13T01:23:34.123456789',
-        '2006-01-13T12:34:56.432539784',
-        '2010-08-13T05:46:57.437699912'
-    ], dtype='datetime64[ns]')
-])
-def test_schema_from_pandas(data):
-    df = pd.DataFrame({'a': data})
-    schema = pa.Schema.from_pandas(df)
-    expected = pa.Table.from_pandas(df).schema
-    assert schema == expected
-
-
 @h.given(
     past.all_types |
     past.all_fields |
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 72f62b3..0960d34 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -17,6 +17,7 @@
 
 import collections
 import re
+import warnings
 
 # These are imprecise because the type (in pandas 0.x) depends on the presence
 # of nulls
@@ -247,13 +248,17 @@ cdef class StructType(DataType):
         """
         Return a child field by its name rather than its index.
         """
-        cdef shared_ptr[CField] field
+        cdef vector[shared_ptr[CField]] fields
 
-        field = self.struct_type.GetFieldByName(tobytes(name))
-        if field == nullptr:
+        fields = self.struct_type.GetAllFieldsByName(tobytes(name))
+        if fields.size() == 0:
             raise KeyError(name)
-
-        return pyarrow_wrap_field(field)
+        elif fields.size() > 1:
+            warnings.warn("Struct field name corresponds to more "
+                          "than one field", UserWarning)
+            raise KeyError(name)
+        else:
+            return pyarrow_wrap_field(fields[0])
 
     def __len__(self):
         """
@@ -740,7 +745,18 @@ cdef class Schema:
         -------
         field: pyarrow.Field
         """
-        return pyarrow_wrap_field(self.schema.GetFieldByName(tobytes(name)))
+        cdef:
+            vector[shared_ptr[CField]] results
+
+        results = self.schema.GetAllFieldsByName(tobytes(name))
+        if results.size() == 0:
+            return None
+        elif results.size() > 1:
+            warnings.warn("Schema field name corresponds to more "
+                          "than one field", UserWarning)
+            return None
+        else:
+            return pyarrow_wrap_field(results[0])
 
     def get_field_index(self, name):
         return self.schema.GetFieldIndex(tobytes(name))