GitHub user mateiz commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1598#discussion_r15678985
  
    --- Diff: python/pyspark/sql.py ---
    @@ -405,70 +359,432 @@ def _parse_datatype_string(datatype_string):
         >>> check_datatype(complex_arraytype)
         True
         >>> # Complex MapType.
    -    >>> complex_maptype = MapType(complex_structtype, complex_arraytype, 
False)
    +    >>> complex_maptype = MapType(complex_structtype,
    +    ...                           complex_arraytype, False)
         >>> check_datatype(complex_maptype)
         True
         """
    -    left_bracket_index = datatype_string.find("(")
    -    if left_bracket_index == -1:
    +    index = datatype_string.find("(")
    +    if index == -1:
             # It is a primitive type.
    -        left_bracket_index = len(datatype_string)
    -    type_or_field = datatype_string[:left_bracket_index]
    -    rest_part = 
datatype_string[left_bracket_index+1:len(datatype_string)-1].strip()
    -    if type_or_field == "StringType":
    -        return StringType()
    -    elif type_or_field == "BinaryType":
    -        return BinaryType()
    -    elif type_or_field == "BooleanType":
    -        return BooleanType()
    -    elif type_or_field == "TimestampType":
    -        return TimestampType()
    -    elif type_or_field == "DecimalType":
    -        return DecimalType()
    -    elif type_or_field == "DoubleType":
    -        return DoubleType()
    -    elif type_or_field == "FloatType":
    -        return FloatType()
    -    elif type_or_field == "ByteType":
    -        return ByteType()
    -    elif type_or_field == "IntegerType":
    -        return IntegerType()
    -    elif type_or_field == "LongType":
    -        return LongType()
    -    elif type_or_field == "ShortType":
    -        return ShortType()
    +        index = len(datatype_string)
    +    type_or_field = datatype_string[:index]
    +    rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()
    +
    +    if type_or_field in _all_primitive_types:
    +        return _all_primitive_types[type_or_field]()
    +
         elif type_or_field == "ArrayType":
             last_comma_index = rest_part.rfind(",")
             containsNull = True
    -        if rest_part[last_comma_index+1:].strip().lower() == "false":
    +        if rest_part[last_comma_index + 1:].strip().lower() == "false":
                 containsNull = False
    -        elementType = 
_parse_datatype_string(rest_part[:last_comma_index].strip())
    +        elementType = _parse_datatype_string(
    +            rest_part[:last_comma_index].strip())
             return ArrayType(elementType, containsNull)
    +
         elif type_or_field == "MapType":
             last_comma_index = rest_part.rfind(",")
             valueContainsNull = True
    -        if rest_part[last_comma_index+1:].strip().lower() == "false":
    +        if rest_part[last_comma_index + 1:].strip().lower() == "false":
                 valueContainsNull = False
    -        keyType, valueType = 
_parse_datatype_list(rest_part[:last_comma_index].strip())
    +        keyType, valueType = _parse_datatype_list(
    +            rest_part[:last_comma_index].strip())
             return MapType(keyType, valueType, valueContainsNull)
    +
         elif type_or_field == "StructField":
             first_comma_index = rest_part.find(",")
             name = rest_part[:first_comma_index].strip()
             last_comma_index = rest_part.rfind(",")
             nullable = True
    -        if rest_part[last_comma_index+1:].strip().lower() == "false":
    +        if rest_part[last_comma_index + 1:].strip().lower() == "false":
                 nullable = False
             dataType = _parse_datatype_string(
    -            rest_part[first_comma_index+1:last_comma_index].strip())
    +            rest_part[first_comma_index + 1:last_comma_index].strip())
             return StructField(name, dataType, nullable)
    +
         elif type_or_field == "StructType":
             # rest_part should be in the format like
             # List(StructField(field1,IntegerType,false)).
    -        field_list_string = rest_part[rest_part.find("(")+1:-1]
    +        field_list_string = rest_part[rest_part.find("(") + 1:-1]
             fields = _parse_datatype_list(field_list_string)
             return StructType(fields)
     
     
    +# Mapping Python types to Spark SQL DateType
    +_type_mappings = {
    +    bool: BooleanType,
    +    int: IntegerType,
    +    long: LongType,
    +    float: DoubleType,
    +    str: StringType,
    +    unicode: StringType,
    +    decimal.Decimal: DecimalType,
    +    datetime.datetime: TimestampType,
    +    datetime.date: TimestampType,
    +    datetime.time: TimestampType,
    +}
    +
    +
    +def _infer_type(obj):
    +    """Infer the DataType from obj"""
    +    if obj is None:
    +        raise ValueError("Can not infer type for None")
    +
    +    dataType = _type_mappings.get(type(obj))
    +    if dataType is not None:
    +        return dataType()
    +
    +    if isinstance(obj, dict):
    +        if not obj:
    +            raise ValueError("Can not infer type for empty dict")
    +        key, value = obj.iteritems().next()
    +        return MapType(_infer_type(key), _infer_type(value), True)
    +    elif isinstance(obj, (list, array.array)):
    +        if not obj:
    +            raise ValueError("Can not infer type for empty list/array")
    +        return ArrayType(_infer_type(obj[0]), True)
    +    else:
    +        try:
    +            return _infer_schema(obj)
    +        except ValueError:
    +            raise ValueError("not supported type: %s" % type(obj))
    +
    +
    +def _infer_schema(row):
    +    """Infer the schema from dict/namedtuple/object"""
    +    if isinstance(row, dict):
    +        items = sorted(row.items())
    +    elif isinstance(row, tuple):
    +        if hasattr(row, "_fields"): # namedtuple
    +            items = zip(row._fields, tuple(row))
    +        elif all(isinstance(x, tuple) and len(x) == 2
    +                 for x in row):
    +            items = row
    +    elif hasattr(row, "__dict__"): # object
    +        items = sorted(row.__dict__.items())
    +    else:
    +        raise ValueError("Can not infer schema for type: %s" % type(row))
    +
    +    fields = [StructField(k, _infer_type(v), True) for k, v in items]
    +    return StructType(fields)
    +
    +
    +def _create_converter(obj, dataType):
    +    """Create an converter to drop the names of fields in obj """
    +    if not _has_struct(dataType):
    +        return lambda x: x
    +
    +    elif isinstance(dataType, ArrayType):
    +        conv = _create_converter(obj[0], dataType.elementType)
    +        return lambda row: map(conv, row)
    +
    +    elif isinstance(dataType, MapType):
    +        value = obj.values()[0]
    +        conv = _create_converter(value, dataType.valueType)
    +        return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
    +
    +    # dataType must be StructType
    +    names = [f.name for f in dataType.fields]
    +
    +    if isinstance(obj, dict):
    +        conv = lambda o: tuple(o.get(n) for n in names)
    +
    +    elif isinstance(obj, tuple):
    +        if hasattr(obj, "_fields"): # namedtuple
    +            conv = tuple
    +        elif all(isinstance(x, tuple) and len(x) == 2
    +                 for x in obj):
    +            conv = lambda o: tuple(v for k, v in o)
    +
    +    elif hasattr(obj, "__dict__"): # object
    +        conv = lambda o: [o.__dict__.get(n, None) for n in names]
    +
    +    nested = any(_has_struct(f.dataType) for f in dataType.fields)
    +    if not nested:
    +        return conv
    +
    +    row = conv(obj)
    +    convs = [_create_converter(v, f.dataType)
    +             for v, f in zip(row, dataType.fields)]
    +
    +    def nested_conv(row):
    +        return tuple(f(v) for f, v in zip(convs, conv(row)))
    +
    +    return nested_conv
    +
    +
    +def _dropSchema(rows, schema):
    +    """Drop all the names of fields, becoming tuples"""
    +    iterator = iter(rows)
    +    row = iterator.next()
    +    converter = _create_converter(row, schema)
    +    yield converter(row)
    +    for i in iterator:
    +        yield converter(i)
    +
    +
    +_BRAKETS = {'(': ')', '[': ']', '{': '}'}
    --- End diff --
    
    Small typo: `_BRAKETS` should be `_BRACKETS` (with a C).


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to