This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d3c5b85  ARROW-3532: [Python] Emit warning when looking up for duplicate struct or schema fields
d3c5b85 is described below

commit d3c5b85e6881543f12378fd63c2921aead36cc5f
Author: Antoine Pitrou <[email protected]>
AuthorDate: Wed Feb 20 20:31:47 2019 +0100

    ARROW-3532: [Python] Emit warning when looking up for duplicate struct or schema fields
    
    Also provide C++ APIs to get all fields having a given name.
    
    Author: Antoine Pitrou <[email protected]>
    
    Closes #3713 from pitrou/ARROW-3532-duplicated-field-warning and squashes the following commits:
    
    d23b241a6 <Antoine Pitrou> ARROW-3532:  Emit warning when looking up for duplicate struct or schema fields
---
 cpp/src/arrow/testing/gtest_util.h   |  7 +++
 cpp/src/arrow/type-test.cc           | 58 +++++++++++++++++++++++-
 cpp/src/arrow/type.cc                | 85 ++++++++++++++++++++++++------------
 cpp/src/arrow/type.h                 | 16 ++++++-
 python/pyarrow/includes/libarrow.pxd |  3 ++
 python/pyarrow/tests/test_schema.py  | 52 ++++++++++++++++++++++
 python/pyarrow/tests/test_types.py   | 33 ++------------
 python/pyarrow/types.pxi             | 28 +++++++++---
 8 files changed, 214 insertions(+), 68 deletions(-)

diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 8fe56ad..ad5830a 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -314,4 +314,11 @@ inline void BitmapFromVector(const std::vector<T>& is_valid,
   ASSERT_OK(GetBitmapFromVector(is_valid, out));
 }
 
+template <typename T>
+void AssertSortedEquals(std::vector<T> u, std::vector<T> v) {
+  std::sort(u.begin(), u.end());
+  std::sort(v.begin(), v.end());
+  ASSERT_EQ(u, v);
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc
index 5138177..49a0332 100644
--- a/cpp/src/arrow/type-test.cc
+++ b/cpp/src/arrow/type-test.cc
@@ -202,6 +202,41 @@ TEST_F(TestSchema, GetFieldIndex) {
   ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
 }
 
+TEST_F(TestSchema, GetFieldDuplicates) {
+  auto f0 = field("f0", int32());
+  auto f1 = field("f1", uint8(), false);
+  auto f2 = field("f2", utf8());
+  auto f3 = field("f1", list(int16()));
+
+  auto schema = ::arrow::schema({f0, f1, f2, f3});
+
+  ASSERT_EQ(0, schema->GetFieldIndex(f0->name()));
+  ASSERT_EQ(-1, schema->GetFieldIndex(f1->name()));  // duplicate
+  ASSERT_EQ(2, schema->GetFieldIndex(f2->name()));
+  ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
+  ASSERT_EQ(std::vector<int>{0}, schema->GetAllFieldIndices(f0->name()));
+  AssertSortedEquals(std::vector<int>{1, 3}, schema->GetAllFieldIndices(f1->name()));
+
+  std::vector<std::shared_ptr<Field>> results;
+
+  results = schema->GetAllFieldsByName(f0->name());
+  ASSERT_EQ(results.size(), 1);
+  ASSERT_TRUE(results[0]->Equals(f0));
+
+  results = schema->GetAllFieldsByName(f1->name());
+  ASSERT_EQ(results.size(), 2);
+  if (results[0]->type()->id() == Type::UINT8) {
+    ASSERT_TRUE(results[0]->Equals(f1));
+    ASSERT_TRUE(results[1]->Equals(f3));
+  } else {
+    ASSERT_TRUE(results[0]->Equals(f3));
+    ASSERT_TRUE(results[1]->Equals(f1));
+  }
+
+  results = schema->GetAllFieldsByName("not-found");
+  ASSERT_EQ(results.size(), 0);
+}
+
 TEST_F(TestSchema, TestMetadataConstruction) {
   auto metadata0 = key_value_metadata({{"foo", "bar"}, {"bizz", "buzz"}});
   auto metadata1 = key_value_metadata({{"foo", "baz"}});
@@ -495,7 +530,7 @@ TEST(TestStructType, GetFieldIndex) {
   ASSERT_EQ(-1, struct_type.GetFieldIndex("not-found"));
 }
 
-TEST(TestStructType, GetFieldIndexDuplicates) {
+TEST(TestStructType, GetFieldDuplicates) {
   auto f0 = field("f0", int32());
   auto f1 = field("f1", int64());
   auto f2 = field("f1", utf8());
@@ -503,6 +538,27 @@ TEST(TestStructType, GetFieldIndexDuplicates) {
 
   ASSERT_EQ(0, struct_type.GetFieldIndex("f0"));
   ASSERT_EQ(-1, struct_type.GetFieldIndex("f1"));
+  ASSERT_EQ(std::vector<int>{0}, struct_type.GetAllFieldIndices(f0->name()));
+  AssertSortedEquals(std::vector<int>{1, 2}, struct_type.GetAllFieldIndices(f1->name()));
+
+  std::vector<std::shared_ptr<Field>> results;
+
+  results = struct_type.GetAllFieldsByName(f0->name());
+  ASSERT_EQ(results.size(), 1);
+  ASSERT_TRUE(results[0]->Equals(f0));
+
+  results = struct_type.GetAllFieldsByName(f1->name());
+  ASSERT_EQ(results.size(), 2);
+  if (results[0]->type()->id() == Type::INT64) {
+    ASSERT_TRUE(results[0]->Equals(f1));
+    ASSERT_TRUE(results[1]->Equals(f2));
+  } else {
+    ASSERT_TRUE(results[0]->Equals(f2));
+    ASSERT_TRUE(results[1]->Equals(f1));
+  }
+
+  results = struct_type.GetAllFieldsByName("not-found");
+  ASSERT_EQ(results.size(), 0);
 }
 
 TEST(TestDictionaryType, Equals) {
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 852ddb0..3bb997a 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -225,15 +225,31 @@ std::string UnionType::ToString() const {
 
 namespace {
 
-std::unordered_map<std::string, int> CreateNameToIndexMap(
+std::unordered_multimap<std::string, int> CreateNameToIndexMap(
     const std::vector<std::shared_ptr<Field>>& fields) {
-  std::unordered_map<std::string, int> name_to_index;
+  std::unordered_multimap<std::string, int> name_to_index;
   for (size_t i = 0; i < fields.size(); ++i) {
-    name_to_index[fields[i]->name()] = static_cast<int>(i);
+    name_to_index.emplace(fields[i]->name(), static_cast<int>(i));
   }
   return name_to_index;
 }
 
+int LookupNameIndex(const std::unordered_multimap<std::string, int>& name_to_index,
+                    const std::string& name) {
+  auto p = name_to_index.equal_range(name);
+  auto it = p.first;
+  if (it == p.second) {
+    // Not found
+    return -1;
+  }
+  auto index = it->second;
+  if (++it != p.second) {
+    // Duplicate field name
+    return -1;
+  }
+  return index;
+}
+
 }  // namespace
 
 StructType::StructType(const std::vector<std::shared_ptr<Field>>& fields)
@@ -261,33 +277,30 @@ std::shared_ptr<Field> StructType::GetFieldByName(const std::string& name) const
 }
 
 int StructType::GetFieldIndex(const std::string& name) const {
-  if (name_to_index_.size() < children_.size()) {
-    // There are duplicate field names. Refuse to guess
-    int counts = 0;
-    int last_observed_index = -1;
-    for (size_t i = 0; i < children_.size(); ++i) {
-      if (children_[i]->name() == name) {
-        ++counts;
-        last_observed_index = static_cast<int>(i);
-      }
-    }
+  return LookupNameIndex(name_to_index_, name);
+}
 
-    if (counts == 1) {
-      return last_observed_index;
-    } else {
-      // Duplicate or not found
-      return -1;
-    }
+std::vector<int> StructType::GetAllFieldIndices(const std::string& name) const {
+  std::vector<int> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(it->second);
   }
+  return result;
+}
 
-  auto it = name_to_index_.find(name);
-  if (it == name_to_index_.end()) {
-    return -1;
-  } else {
-    return it->second;
+std::vector<std::shared_ptr<Field>> StructType::GetAllFieldsByName(
+    const std::string& name) const {
+  std::vector<std::shared_ptr<Field>> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(children_[it->second]);
   }
+  return result;
 }
 
+// Deprecated methods
+
 std::shared_ptr<Field> StructType::GetChildByName(const std::string& name) const {
   return GetFieldByName(name);
 }
@@ -386,12 +399,26 @@ std::shared_ptr<Field> Schema::GetFieldByName(const std::string& name) const {
 }
 
 int Schema::GetFieldIndex(const std::string& name) const {
-  auto it = name_to_index_.find(name);
-  if (it == name_to_index_.end()) {
-    return -1;
-  } else {
-    return it->second;
+  return LookupNameIndex(name_to_index_, name);
+}
+
+std::vector<int> Schema::GetAllFieldIndices(const std::string& name) const {
+  std::vector<int> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(it->second);
+  }
+  return result;
+}
+
+std::vector<std::shared_ptr<Field>> Schema::GetAllFieldsByName(
+    const std::string& name) const {
+  std::vector<std::shared_ptr<Field>> result;
+  auto p = name_to_index_.equal_range(name);
+  for (auto it = p.first; it != p.second; ++it) {
+    result.push_back(fields_[it->second]);
   }
+  return result;
 }
 
 Status Schema::AddField(int i, const std::shared_ptr<Field>& field,
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index c775b11..472ba03 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -518,10 +518,16 @@ class ARROW_EXPORT StructType : public NestedType {
   /// Returns null if name not found
   std::shared_ptr<Field> GetFieldByName(const std::string& name) const;
 
+  /// Return all fields having this name
+  std::vector<std::shared_ptr<Field>> GetAllFieldsByName(const std::string& name) const;
+
   /// Returns -1 if name not found or if there are multiple fields having the
   /// same name
   int GetFieldIndex(const std::string& name) const;
 
+  /// Return the indices of all fields having this name
+  std::vector<int> GetAllFieldIndices(const std::string& name) const;
+
   ARROW_DEPRECATED("Use GetFieldByName")
   std::shared_ptr<Field> GetChildByName(const std::string& name) const;
 
@@ -529,7 +535,7 @@ class ARROW_EXPORT StructType : public NestedType {
   int GetChildIndex(const std::string& name) const;
 
  private:
-  std::unordered_map<std::string, int> name_to_index_;
+  std::unordered_multimap<std::string, int> name_to_index_;
 };
 
 /// \brief Base type class for (fixed-size) decimal data
@@ -826,9 +832,15 @@ class ARROW_EXPORT Schema {
   /// Returns null if name not found
   std::shared_ptr<Field> GetFieldByName(const std::string& name) const;
 
+  /// Return all fields having this name
+  std::vector<std::shared_ptr<Field>> GetAllFieldsByName(const std::string& name) const;
+
   /// Returns -1 if name not found
   int GetFieldIndex(const std::string& name) const;
 
+  /// Return the indices of all fields having this name
+  std::vector<int> GetAllFieldIndices(const std::string& name) const;
+
   const std::vector<std::shared_ptr<Field>>& fields() const { return fields_; }
 
   std::vector<std::string> field_names() const;
@@ -866,7 +878,7 @@ class ARROW_EXPORT Schema {
  private:
   std::vector<std::shared_ptr<Field>> fields_;
 
-  std::unordered_map<std::string, int> name_to_index_;
+  std::unordered_multimap<std::string, int> name_to_index_;
 
   std::shared_ptr<const KeyValueMetadata> metadata_;
 };
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 97bc892..64b907b 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -277,6 +277,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         CStructType(const vector[shared_ptr[CField]]& fields)
 
         shared_ptr[CField] GetFieldByName(const c_string& name)
+        vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
         int GetFieldIndex(const c_string& name)
 
     cdef cppclass CUnionType" arrow::UnionType"(CDataType):
@@ -298,7 +299,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         shared_ptr[CField] field(int i)
         shared_ptr[const CKeyValueMetadata] metadata()
         shared_ptr[CField] GetFieldByName(const c_string& name)
+        vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
         int64_t GetFieldIndex(const c_string& name)
+        vector[int64_t] GetAllFieldIndice(const c_string& name)
         int num_fields()
         c_string ToString()
 
diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py
index 8549d61..75c7f48 100644
--- a/python/pyarrow/tests/test_schema.py
+++ b/python/pyarrow/tests/test_schema.py
@@ -20,6 +20,7 @@ import pickle
 
 import pytest
 import numpy as np
+import pandas as pd
 
 import pyarrow as pa
 
@@ -266,6 +267,30 @@ baz: list<item: int8>
         pa.schema(fields)
 
 
+def test_schema_duplicate_fields():
+    fields = [
+        pa.field('foo', pa.int32()),
+        pa.field('bar', pa.string()),
+        pa.field('foo', pa.list_(pa.int8())),
+    ]
+    sch = pa.schema(fields)
+    assert sch.names == ['foo', 'bar', 'foo']
+    assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
+    assert len(sch) == 3
+    assert repr(sch) == """\
+foo: int32
+bar: string
+foo: list<item: int8>
+  child 0, item: int8"""
+
+    assert sch[0].name == 'foo'
+    assert sch[0].type == fields[0].type
+    assert sch.field_by_name('bar') == fields[1]
+    assert sch.field_by_name('xxx') is None
+    with pytest.warns(UserWarning):
+        assert sch.field_by_name('foo') is None
+
+
 def test_field_flatten():
     f0 = pa.field('foo', pa.int32()).add_metadata({b'foo': b'bar'})
     assert f0.flatten() == [f0]
@@ -456,3 +481,30 @@ def test_type_schema_pickling():
     schema = pa.schema(fields, metadata={b'foo': b'bar'})
     roundtripped = pickle.loads(pickle.dumps(schema))
     assert schema == roundtripped
+
+
+def test_empty_table():
+    schema = pa.schema([
+        pa.field('oneField', pa.int64())
+    ])
+    table = schema.empty_table()
+    assert isinstance(table, pa.Table)
+    assert table.num_rows == 0
+    assert table.schema == schema
+
+
+@pytest.mark.parametrize('data', [
+    list(range(10)),
+    pd.Categorical(list(range(10))),
+    ['foo', 'bar', None, 'baz', 'qux'],
+    np.array([
+        '2007-07-13T01:23:34.123456789',
+        '2006-01-13T12:34:56.432539784',
+        '2010-08-13T05:46:57.437699912'
+    ], dtype='datetime64[ns]')
+])
+def test_schema_from_pandas(data):
+    df = pd.DataFrame({'a': data})
+    schema = pa.Schema.from_pandas(df)
+    expected = pa.Table.from_pandas(df).schema
+    assert schema == expected
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index 11c0cca..2401bf5 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -22,7 +22,6 @@ import pytest
 import hypothesis as h
 import hypothesis.strategies as st
 
-import pandas as pd
 import numpy as np
 import pyarrow as pa
 import pyarrow.types as types
@@ -249,8 +248,9 @@ def test_struct_type():
     assert ty['b'] == ty[2]
 
     # Duplicate
-    with pytest.raises(KeyError):
-        ty['a']
+    with pytest.warns(UserWarning):
+        with pytest.raises(KeyError):
+            ty['a']
 
     # Not found
     with pytest.raises(KeyError):
@@ -519,16 +519,6 @@ def test_field_add_remove_metadata():
     assert f5.equals(f6)
 
 
-def test_empty_table():
-    schema = pa.schema([
-        pa.field('oneField', pa.int64())
-    ])
-    table = schema.empty_table()
-    assert isinstance(table, pa.Table)
-    assert table.num_rows == 0
-    assert table.schema == schema
-
-
 def test_is_integer_value():
     assert pa.types.is_integer_value(1)
     assert pa.types.is_integer_value(np.int64(1))
@@ -550,23 +540,6 @@ def test_is_boolean_value():
     assert pa.types.is_boolean_value(np.bool_(False))
 
 
-@pytest.mark.parametrize('data', [
-    list(range(10)),
-    pd.Categorical(list(range(10))),
-    ['foo', 'bar', None, 'baz', 'qux'],
-    np.array([
-        '2007-07-13T01:23:34.123456789',
-        '2006-01-13T12:34:56.432539784',
-        '2010-08-13T05:46:57.437699912'
-    ], dtype='datetime64[ns]')
-])
-def test_schema_from_pandas(data):
-    df = pd.DataFrame({'a': data})
-    schema = pa.Schema.from_pandas(df)
-    expected = pa.Table.from_pandas(df).schema
-    assert schema == expected
-
-
 @h.given(
     past.all_types |
     past.all_fields |
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 72f62b3..0960d34 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -17,6 +17,7 @@
 
 import collections
 import re
+import warnings
 
 # These are imprecise because the type (in pandas 0.x) depends on the presence
 # of nulls
@@ -247,13 +248,17 @@ cdef class StructType(DataType):
         """
         Return a child field by its name rather than its index.
         """
-        cdef shared_ptr[CField] field
+        cdef vector[shared_ptr[CField]] fields
 
-        field = self.struct_type.GetFieldByName(tobytes(name))
-        if field == nullptr:
+        fields = self.struct_type.GetAllFieldsByName(tobytes(name))
+        if fields.size() == 0:
             raise KeyError(name)
-
-        return pyarrow_wrap_field(field)
+        elif fields.size() > 1:
+            warnings.warn("Struct field name corresponds to more "
+                          "than one field", UserWarning)
+            raise KeyError(name)
+        else:
+            return pyarrow_wrap_field(fields[0])
 
     def __len__(self):
         """
@@ -740,7 +745,18 @@ cdef class Schema:
         -------
         field: pyarrow.Field
         """
-        return pyarrow_wrap_field(self.schema.GetFieldByName(tobytes(name)))
+        cdef:
+            vector[shared_ptr[CField]] results
+
+        results = self.schema.GetAllFieldsByName(tobytes(name))
+        if results.size() == 0:
+            return None
+        elif results.size() > 1:
+            warnings.warn("Schema field name corresponds to more "
+                          "than one field", UserWarning)
+            return None
+        else:
+            return pyarrow_wrap_field(results[0])
 
     def get_field_index(self, name):
         return self.schema.GetFieldIndex(tobytes(name))

Reply via email to