
2020-12-31

Python-friendly dtypes for pyspark dataframes

When using pyspark, most of the JVM core of Apache Spark is hidden from the python user. A notable exception is the DataFrame.dtypes attribute, which contains the JVM format string representations of the data types of the DataFrame columns. While for the atomic data types the translation to python data types is trivial, for composite data types the string representations quickly become unwieldy (e.g. when using the elasticsearch-hadoop InputFormat).

from pyspark.sql import SparkSession, functions as f, types as t, Row

# Use the active SparkSession (e.g. from the pyspark shell) or create one
spark = SparkSession.builder.getOrCreate()

# Create DataFrame
ints = [1,2,3]
arrays = [[1,2,3], [4,5,6], [7,8,9]]
maps = [{'a':1}, {'b':2}, {'c':3}]  # three elements, so zip() yields three rows
rows = [Row(x=1, y='1'), Row(x=2, y='2'), Row(x=3, y='3')]
composites = [Row(x=1, y=[1,2], z={'a':1, 'b':2}), Row(x=2, y=[3,4], z={'a':3, 'b':4})]
df = spark.createDataFrame(zip(ints, arrays, maps, rows, composites))

# Show standard dtypes
for x in df.dtypes:
    print(x)

# ('_1', 'bigint')
# ('_2', 'array<bigint>')
# ('_3', 'map<string,bigint>')
# ('_4', 'struct<x:bigint,y:string>')
# ('_5', 'struct<x:bigint,y:array<bigint>,z:map<string,bigint>>')

# Show python types in collected DataFrame
row = df.collect()[0]
for x in row:
    print(type(x))

# <class 'int'>
# <class 'list'>
# <class 'dict'>
# <class 'pyspark.sql.types.Row'>
# <class 'pyspark.sql.types.Row'>

# Show python types passed to a user defined function
def python_type(x):
    return str(type(x))

udf_python_type = f.udf(python_type, t.StringType())
row = df \
    .withColumn('_1', udf_python_type('_1')) \
    .withColumn('_2', udf_python_type('_2')) \
    .withColumn('_3', udf_python_type('_3')) \
    .withColumn('_4', udf_python_type('_4')) \
    .withColumn('_5', udf_python_type('_5')) \
    .collect()[0]
for x in row:
    print(x)

# <class 'int'>
# <class 'list'>
# <class 'dict'>
# <class 'pyspark.sql.types.Row'>
# <class 'pyspark.sql.types.Row'>

While the dtypes attribute shows the data types in terms of the JVM StructType, ArrayType and MapType classes, the python programmer gets to see the corresponding python types when collecting the DataFrame or passing a column to a user defined function.
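To make this concrete, the collected python objects from the example above can be manipulated with the ordinary list, dict and Row operations. A minimal sketch, assuming the first collected row corresponds to the first elements of the input lists (which holds for this small local example):

# Access the collected python objects; df is the DataFrame created above
row = df.collect()[0]
print(row['_2'][0])      # list indexing: 1
print(row['_3']['a'])    # dict lookup: 1
print(row['_4'].x)       # Row attribute access: 1
print(row['_5'].z['b'])  # nested Row and dict: 2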

To fill this gap in type representations, this blog post presents a small utility that translates the content of the dtypes attribute into a data structure with string representations of the corresponding python types. The utility can be found as a gist on github, but is also listed below:

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import re
import string

import pyparsing


def pysql_dtypes(dtypes):
    """Represents the spark-sql dtypes in terms of python [], {} and Row()
    constructs.
    :param dtypes: [(string, string)] result from pyspark.sql.DataFrame.dtypes
    :return: [(string, string)]
    """

    def assemble(nested):
        cur = 0
        assembled = ''
        while cur < len(nested):
            parts = re.findall(r'[^:,]+', nested[cur])
            if not parts:
                parts = [nested[cur]]
            tail = parts[-1]
            if tail == 'array':
                assembled += nested[cur][:-5] + '['
                assembled += assemble(nested[cur+1])
                assembled += ']'
                cur += 2
            elif tail == 'map':
                assembled += nested[cur][:-3] + '{'
                assembled += assemble(nested[cur+1])
                assembled += '}'
                cur += 2
            elif tail == 'struct':
                assembled += nested[cur][:-6] + 'Row('
                assembled += assemble(nested[cur+1])
                assembled += ')'
                cur += 2
            else:
                assembled += nested[cur]
                cur += 1
        return assembled

    chars = ''.join([x for x in string.printable if x not in ['<', '>']])
    word = pyparsing.Word(chars)
    parens = pyparsing.nestedExpr('<', '>', content=word)
    dtype = word + pyparsing.Optional(parens)

    result = []
    for name, schema in dtypes:
        tree = dtype.parseString(schema).asList()
        pyschema = assemble(tree).replace(',', ', ').replace(',  ', ', ')
        result.append((name, pyschema))
    return result

The pysql_dtypes() function starts by building a simple grammar using the pyparsing package, to parse the dtypes as given by pyspark. Central in the grammar are the special characters '<' and '>' that are used to recognize nested types in the array<>, map<> and struct<> constructs. Note that these characters cannot occur in JVM or python field names. The pyparsing.nestedExpr() method takes care of any multi-level nesting. Words are defined as arbitrary successions of printable characters with the exception of the angle brackets (because we use the output of DataFrame.dtypes as input, we assume that we will not encounter any weird characters). Finally, a word occurs either at the start of a JVM type representation (in the case of so-called atomic types) or within angle brackets.
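As an illustration of the grammar on its own, parsing the nested dtype string of column '_5' from the example above yields a nested list; a minimal sketch outside of pysql_dtypes(), with the commented parse result being indicative:

import string

import pyparsing

# Same grammar as in pysql_dtypes(): words are runs of printable characters
# without angle brackets, and '<'/'>' delimit the nested types
chars = ''.join([x for x in string.printable if x not in ['<', '>']])
word = pyparsing.Word(chars)
parens = pyparsing.nestedExpr('<', '>', content=word)
dtype = word + pyparsing.Optional(parens)

print(dtype.parseString('struct<x:bigint,y:array<bigint>,z:map<string,bigint>>').asList())

# ['struct', ['x:bigint,y:array', ['bigint'], ',z:map', ['string,bigint']]]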

The assemble() function translates the parsed JVM type representation into the corresponding python types and re-assembles them. This function is recursive because the nested expressions can have arbitrary depth. It splits the earlier defined 'words' into parts separated by a ',' or ':' character and then applies a simple recipe for assembling the parts for the various possible combinations of the recognized array, map and struct constructs. The gist on github also contains a test suite that provides many sample outputs of the pysql_dtypes() function. The code sample below takes the example used earlier.

from pysql_dtypes import pysql_dtypes

for x in pysql_dtypes(df.dtypes):
    print(x)

# ('_1', 'bigint')
# ('_2', '[bigint]')
# ('_3', '{string, bigint}')
# ('_4', 'Row(x:bigint, y:string)')
# ('_5', 'Row(x:bigint, y:[bigint], z:{string, bigint})')
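The translation is not tied to a DataFrame: any list of (name, dtype string) tuples can be fed to pysql_dtypes(). For instance, with a made-up, more deeply nested type string (the column name and type below are purely illustrative):

# A hand-written nested dtype string, e.g. as the elasticsearch-hadoop
# InputFormat might produce for a list of sub-documents
print(pysql_dtypes([('doc', 'array<struct<id:bigint,tags:array<string>>>')]))

# [('doc', '[Row(id:bigint, tags:[string])]')]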

The pysql_dtypes() function will be proposed to the Apache Spark project, so if you would like the utility to become available as part of pyspark, be sure to star the gist on github and add yourself as a watcher to the corresponding issue (requires an Apache Jira account).