Forward filling in Spark

Recently I had to work out the status of a person in our database on any given date, while the only thing we store is the moment at which a person's status changes. The query I want to answer looks like the following SQL statement:

SELECT status FROM person_statuses WHERE person = "John Doe" AND time = current_date()
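
The difficulty is that an equality filter on time only matches the days on which a change was recorded. What is really needed is an "as of" lookup: the most recent change on or before the requested date. A minimal plain-Python sketch of that idea, using made-up changes and a hypothetical helper called status_as_of:

```python
from bisect import bisect_right
from datetime import date

# Hypothetical change log for one person: (date of change, new status), sorted by date
changes = [
    (date(2020, 5, 2), "5-3"),
    (date(2020, 5, 6), "4-4"),
]


def status_as_of(changes, day):
    """Return the status valid on `day`, or None if `day` precedes the first change."""
    change_dates = [changed_on for changed_on, _ in changes]
    # Number of changes on or before `day`; the last of those is the valid one
    position = bisect_right(change_dates, day)
    return changes[position - 1][1] if position else None


print(status_as_of(changes, date(2020, 5, 4)))  # a day without a recorded change
```

The rest of this post is about getting the same answer out of Spark.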

Step 1 - Create the data

To generate the dummy data I use an event generator. The generated events are the input for the forward fill exercise. The code comes from the data-pipeline-project in this repo.

In [1]:
from datetime import datetime, timedelta
from faker import Faker
import json


DATE_END = datetime.now()
DATE_START = DATE_END - timedelta(days=31)
NUM_EVENTS = 10

class EventGenerator:
    """ Defines the EventGenerator """

    MIN_LIVES = 1
    MAX_LIVES = 99
    CHARACTERS = ["Mario", "Luigi", "Peach", "Toad"]

    def __init__(self, num_events, output_type, start_date, end_date, output_file=None):
        """ Initialize the EventGenerator """
        self.faker = Faker()
        self.num_events = num_events
        self.output_type = output_type
        self.output_file = output_file
        self.start_date = start_date
        self.end_date = end_date

    def _get_date_between(self, date_start, date_end):
        """ Get a date between start and end date """
        return self.faker.date_between_dates(date_start=date_start, date_end=date_end)

    def _generate_events(self):
        """ Generate the metric data """
        for _ in range(self.num_events):
            yield {
                "character": self.faker.random_element(self.CHARACTERS),
                "world": self.faker.random_int(min=1, max=8, step=1),
                "level": self.faker.random_int(min=1, max=4, step=1),
                "lives": self.faker.random_int(
                    min=self.MIN_LIVES, max=self.MAX_LIVES, step=1
                ),
                "time": str(self._get_date_between(self.start_date, self.end_date)),
            }

    def store_events(self):
        if self.output_type == "jl":
            with open(self.output_file, "w") as outputfile:
                for event in self._generate_events():
                    outputfile.write(f"{json.dumps(event)}\n")
        elif self.output_type == "list":
            return list(self._generate_events())

I only want 10 events to keep the dataframe we use in Spark small.

In [2]:
params = {
    "num_events": NUM_EVENTS,
    "output_type": "list",
    "start_date": DATE_START,
    "end_date": DATE_END,
}
# Create the event generator
generator = EventGenerator(**params)
# Create and store the events
events = generator.store_events()

Step 2 - Analyze the data

The events represent the persons (Nintendo characters) and their status (current world and level). From looking at the data it is obvious there are big gaps in time before the characters advance to the next level. How do we know where Mario was without storing the daily world/level status?

In [3]:
import pandas as pd
pd.DataFrame(events).sort_values(["character", "time"])
Out[3]:
  character  world  level  lives        time
9     Luigi      4      3      9  2020-04-24
8     Luigi      7      3     42  2020-04-30
5     Luigi      5      2     53  2020-05-14
0     Luigi      6      4     99  2020-05-22
3     Luigi      2      1     41  2020-05-23
1     Mario      5      3     13  2020-05-02
2     Mario      4      4     80  2020-05-06
7     Peach      8      4     55  2020-04-29
4      Toad      6      3     40  2020-05-11
6      Toad      6      3     58  2020-05-14

Step 3 - Create Spark Dataframe

To create the Spark Dataframe a SparkSession is used. I don't use any external libraries right now, since we will use plain (Py)Spark.

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

import os
os.environ['PYSPARK_PYTHON'] = "/data/jupyter/bin/python"

spark = (SparkSession
  .builder
  .appName("Spark Forward Fill")
  .getOrCreate())

# SQLContext is deprecated; the SparkSession offers createDataFrame() and sql() itself.
# The name `sc` is kept because the cells below use it.
sc = spark
In [5]:
df = sc.createDataFrame(events)
/data/jupyter/lib/python3.6/site-packages/pyspark/sql/session.py:346: UserWarning: inferring schema from dict is deprecated,please use pyspark.sql.Row instead
  warnings.warn("inferring schema from dict is deprecated,"
In [6]:
df.show()
+---------+-----+-----+----------+-----+
|character|level|lives|      time|world|
+---------+-----+-----+----------+-----+
|    Luigi|    4|   99|2020-05-22|    6|
|    Mario|    3|   13|2020-05-02|    5|
|    Mario|    4|   80|2020-05-06|    4|
|    Luigi|    1|   41|2020-05-23|    2|
|     Toad|    3|   40|2020-05-11|    6|
|    Luigi|    2|   53|2020-05-14|    5|
|     Toad|    3|   58|2020-05-14|    6|
|    Peach|    4|   55|2020-04-29|    8|
|    Luigi|    3|   42|2020-04-30|    7|
|    Luigi|    3|    9|2020-04-24|    4|
+---------+-----+-----+----------+-----+

Step 4 - Correct the data

Ensure we have the correct datatypes:

In [7]:
df.dtypes
Out[7]:
[('character', 'string'),
 ('level', 'bigint'),
 ('lives', 'bigint'),
 ('time', 'string'),
 ('world', 'bigint')]
In [8]:
df = df.withColumn("time", F.to_date(F.to_timestamp("time")))
In [9]:
df.dtypes
Out[9]:
[('character', 'string'),
 ('level', 'bigint'),
 ('lives', 'bigint'),
 ('time', 'date'),
 ('world', 'bigint')]

As a status I combine the world and the level:

In [10]:
df = df.withColumn("status", F.concat(F.col('world'), F.lit('-'), F.col('level')))
df.show()
+---------+-----+-----+----------+-----+------+
|character|level|lives|      time|world|status|
+---------+-----+-----+----------+-----+------+
|    Luigi|    4|   99|2020-05-22|    6|   6-4|
|    Mario|    3|   13|2020-05-02|    5|   5-3|
|    Mario|    4|   80|2020-05-06|    4|   4-4|
|    Luigi|    1|   41|2020-05-23|    2|   2-1|
|     Toad|    3|   40|2020-05-11|    6|   6-3|
|    Luigi|    2|   53|2020-05-14|    5|   5-2|
|     Toad|    3|   58|2020-05-14|    6|   6-3|
|    Peach|    4|   55|2020-04-29|    8|   8-4|
|    Luigi|    3|   42|2020-04-30|    7|   7-3|
|    Luigi|    3|    9|2020-04-24|    4|   4-3|
+---------+-----+-----+----------+-----+------+

Step 5 - Forward fill the data

The data should be partitioned by character and ordered by the time. For this I will use a simple window function to go through the data.

In [11]:
w = Window().partitionBy("character").orderBy("time")

Apply the window to add a new column holding the date of the next status update per character. Despite its name, next_status contains a date, not a status.

In [12]:
df = df.withColumn("next_status", F.lead("time").over(w))
df.show()
+---------+-----+-----+----------+-----+------+-----------+
|character|level|lives|      time|world|status|next_status|
+---------+-----+-----+----------+-----+------+-----------+
|    Peach|    4|   55|2020-04-29|    8|   8-4|       null|
|     Toad|    3|   40|2020-05-11|    6|   6-3| 2020-05-14|
|     Toad|    3|   58|2020-05-14|    6|   6-3|       null|
|    Mario|    3|   13|2020-05-02|    5|   5-3| 2020-05-06|
|    Mario|    4|   80|2020-05-06|    4|   4-4|       null|
|    Luigi|    3|    9|2020-04-24|    4|   4-3| 2020-04-30|
|    Luigi|    3|   42|2020-04-30|    7|   7-3| 2020-05-14|
|    Luigi|    2|   53|2020-05-14|    5|   5-2| 2020-05-22|
|    Luigi|    4|   99|2020-05-22|    6|   6-4| 2020-05-23|
|    Luigi|    1|   41|2020-05-23|    2|   2-1|       null|
+---------+-----+-----+----------+-----+------+-----------+
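
For readers less familiar with window functions, this is roughly what lead("time") over that window computes. A plain-Python sketch with a few rows copied from the table above; the helper name add_next_change is made up for illustration, and the dates are kept as ISO strings, which happen to sort chronologically:

```python
from itertools import groupby
from operator import itemgetter

rows = [
    {"character": "Mario", "time": "2020-05-06"},
    {"character": "Toad", "time": "2020-05-11"},
    {"character": "Mario", "time": "2020-05-02"},
    {"character": "Toad", "time": "2020-05-14"},
]


def add_next_change(rows):
    """Mimic lead("time") partitioned by character and ordered by time."""
    ordered = sorted(rows, key=itemgetter("character", "time"))
    result = []
    for _, group in groupby(ordered, key=itemgetter("character")):
        group = list(group)
        # Pair every row with its successor; the last row of a partition has none
        for current, following in zip(group, group[1:] + [None]):
            next_time = following["time"] if following else None
            result.append({**current, "next_status": next_time})
    return result


for row in add_next_change(rows):
    print(row)
```

The last row of every character gets None, which corresponds to the null values in the Spark output.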

With Spark's sequence function I create a column containing every date from the status change up to the day before next_status. When there is no next status, the column contains only the date of the change itself. The sequence column holds a list of dates that can be exploded to create a row for each date in the list.

In [13]:
df = df.withColumn("sequence", F.when(F.col("next_status").isNotNull(),
                               F.expr("sequence(to_date(time), date_sub(to_date(next_status),1), interval 1 day)"))\
                                .otherwise(F.array("time")))
df.show()
+---------+-----+-----+----------+-----+------+-----------+--------------------+
|character|level|lives|      time|world|status|next_status|            sequence|
+---------+-----+-----+----------+-----+------+-----------+--------------------+
|    Peach|    4|   55|2020-04-29|    8|   8-4|       null|        [2020-04-29]|
|     Toad|    3|   40|2020-05-11|    6|   6-3| 2020-05-14|[2020-05-11, 2020...|
|     Toad|    3|   58|2020-05-14|    6|   6-3|       null|        [2020-05-14]|
|    Mario|    3|   13|2020-05-02|    5|   5-3| 2020-05-06|[2020-05-02, 2020...|
|    Mario|    4|   80|2020-05-06|    4|   4-4|       null|        [2020-05-06]|
|    Luigi|    3|    9|2020-04-24|    4|   4-3| 2020-04-30|[2020-04-24, 2020...|
|    Luigi|    3|   42|2020-04-30|    7|   7-3| 2020-05-14|[2020-04-30, 2020...|
|    Luigi|    2|   53|2020-05-14|    5|   5-2| 2020-05-22|[2020-05-14, 2020...|
|    Luigi|    4|   99|2020-05-22|    6|   6-4| 2020-05-23|        [2020-05-22]|
|    Luigi|    1|   41|2020-05-23|    2|   2-1|       null|        [2020-05-23]|
+---------+-----+-----+----------+-----+------+-----------+--------------------+
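
The boundaries of the sequence are easy to get wrong, so here is the same logic as a plain-Python sketch. The function name date_sequence is an assumption of this example and not part of Spark:

```python
from datetime import date, timedelta


def date_sequence(start, next_change):
    """Days from `start` up to the day before `next_change`; only `start` if there is none."""
    if next_change is None:
        return [start]
    number_of_days = (next_change - start).days
    return [start + timedelta(days=offset) for offset in range(number_of_days)]


# Toad's first status: valid on 11, 12 and 13 May, the next change is on 14 May
print(date_sequence(date(2020, 5, 11), date(2020, 5, 14)))
# Peach has no later change, so only the day of the change itself is returned
print(date_sequence(date(2020, 4, 29), None))
```

Luigi's status 6-4 shows the shortest possible case: the next change is one day later, so the list holds a single date.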

Now explode is applied to the sequence column which gives the result shown below. As we can see the events are filled from the first status change until the latest status change.

In [14]:
df.select("character", "status", F.explode("sequence").alias("time")).show()
+---------+------+----------+
|character|status|      time|
+---------+------+----------+
|    Peach|   8-4|2020-04-29|
|     Toad|   6-3|2020-05-11|
|     Toad|   6-3|2020-05-12|
|     Toad|   6-3|2020-05-13|
|     Toad|   6-3|2020-05-14|
|    Mario|   5-3|2020-05-02|
|    Mario|   5-3|2020-05-03|
|    Mario|   5-3|2020-05-04|
|    Mario|   5-3|2020-05-05|
|    Mario|   4-4|2020-05-06|
|    Luigi|   4-3|2020-04-24|
|    Luigi|   4-3|2020-04-25|
|    Luigi|   4-3|2020-04-26|
|    Luigi|   4-3|2020-04-27|
|    Luigi|   4-3|2020-04-28|
|    Luigi|   4-3|2020-04-29|
|    Luigi|   7-3|2020-04-30|
|    Luigi|   7-3|2020-05-01|
|    Luigi|   7-3|2020-05-02|
|    Luigi|   7-3|2020-05-03|
+---------+------+----------+
only showing top 20 rows
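
Conceptually, explode is a flatten: every element of the list becomes its own row and the other columns are repeated. A small plain-Python sketch using Toad's two rows from above (an illustration, not how Spark implements it):

```python
rows = [
    {"character": "Toad", "status": "6-3",
     "sequence": ["2020-05-11", "2020-05-12", "2020-05-13"]},
    {"character": "Toad", "status": "6-3",
     "sequence": ["2020-05-14"]},
]

# One output row per (input row, date) pair
exploded = [
    {"character": row["character"], "status": row["status"], "time": day}
    for row in rows
    for day in row["sequence"]
]

for row in exploded:
    print(row)
```

Like this comprehension, Spark's explode produces no output row for an empty list, which is one more reason to make sure every sequence contains at least one date.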

One thing missing in the approach above is that we only know the status up to the latest change of each character. What I need is the status today, so I have to modify the code and create an artificial end date when there is no next status. That way I can fill the dates after the last status change up until today with that last status. For completeness I repeat the code that I created before.

In [15]:
# Create the dataframe
df = sc.createDataFrame(events)
# Modify the data type
df = df.withColumn("time",
                   F.to_date(F.to_timestamp("time")))
# Create the status column
df = df.withColumn("status",
                   F.concat(F.col('world'),
                            F.lit('-'),
                            F.col('level')))
# Apply the window
df = df.withColumn("next_status",
                   F.lead("time").over(w))
# Determine the last day each status is valid: the day before the next
# status change, or today when there is no later change
df = df.withColumn("last_day",
                   F.when(F.col("next_status").isNull(),
                          F.current_date())\
                           .otherwise(F.date_sub(F.col("next_status"), 1)))
# Apply the sequence, so that today is included for the latest status
df = df.withColumn("sequence",
                   F.expr("sequence(time, last_day, interval 1 day)"))
# Select the columns and explode the sequence
df.select("character", "status", F.explode("sequence").alias("time")).show()
+---------+------+----------+
|character|status|      time|
+---------+------+----------+
|    Peach|   8-4|2020-04-29|
|    Peach|   8-4|2020-04-30|
|    Peach|   8-4|2020-05-01|
|    Peach|   8-4|2020-05-02|
|    Peach|   8-4|2020-05-03|
|    Peach|   8-4|2020-05-04|
|    Peach|   8-4|2020-05-05|
|    Peach|   8-4|2020-05-06|
|    Peach|   8-4|2020-05-07|
|    Peach|   8-4|2020-05-08|
|    Peach|   8-4|2020-05-09|
|    Peach|   8-4|2020-05-10|
|    Peach|   8-4|2020-05-11|
|    Peach|   8-4|2020-05-12|
|    Peach|   8-4|2020-05-13|
|    Peach|   8-4|2020-05-14|
|    Peach|   8-4|2020-05-15|
|    Peach|   8-4|2020-05-16|
|    Peach|   8-4|2020-05-17|
|    Peach|   8-4|2020-05-18|
+---------+------+----------+
only showing top 20 rows

Step 6 - Avoid forward filling!

From this small test it is clear that using the sequence to create an event for every date for every character is not wise. The number of rows grows with the number of characters times the number of days, which quickly becomes massive, so this should not be used in practice. The main idea behind trying the forward fill in Spark was to compare the complexity with pandas, where you can forward fill by simply using ffill. I used the following code to get a comparable result to the Spark script above.

In [16]:
pdf = pd.DataFrame(events)
# Convert datatypes
pdf['time'] = pd.to_datetime(pdf['time'])
# Add status column
pdf['status'] = pdf.apply(lambda x: f"{x['world']}-{x['level']}", axis=1)
# Create full time range as frame
timeframe = pd.date_range(start=min(pdf['time']),
                          end=datetime.now().date()).to_frame().reset_index(drop=True).rename(columns={0: 'time'})
# Merge timeframe into original frame
pdf = pdf.merge(timeframe,
                left_on='time',
                right_on='time',
                how='right')
# 1. Pivot to get dates on rows and characters as columns
# 2. Forward fill values per character
#    (dates before a character's first event stay NaN)
pdf = pdf.pivot(index='time',
                columns='character',
                values='status')

pdf = pdf.ffill()
# Drop NaN column and reset the index
pdf = pdf.loc[:, pdf.columns.notnull()].reset_index()
# Melt the columns back
pdf = pd.melt(pdf,
              id_vars='time',
              value_name='status')
print(f"Original length: {len(events)}, new length: {len(pdf)}")
pdf.head(10)
Original length: 10, new length: 128
Out[16]:
        time character status
0 2020-04-24     Luigi    4-3
1 2020-04-25     Luigi    4-3
2 2020-04-26     Luigi    4-3
3 2020-04-27     Luigi    4-3
4 2020-04-28     Luigi    4-3
5 2020-04-29     Luigi    4-3
6 2020-04-30     Luigi    7-3
7 2020-05-01     Luigi    7-3
8 2020-05-02     Luigi    7-3
9 2020-05-03     Luigi    7-3
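
The new length can be reproduced by hand, because the pivot yields one row per date per character. The sketch below assumes the notebook ran on 2020-05-25, the date that shows up as the filled-in end date later in this post. The second calculation uses invented numbers to show how quickly this grows:

```python
from datetime import date


def rows_after_forward_fill(num_characters, first_day, last_day):
    """One row per character per day, both boundary days included."""
    num_days = (last_day - first_day).days + 1
    return num_characters * num_days


# Sample data: 4 characters, earliest event on 2020-04-24, assumed run date 2020-05-25
sample_rows = rows_after_forward_fill(4, date(2020, 4, 24), date(2020, 5, 25))
print(sample_rows)  # 128, the length printed above

# Hypothetical production table: one million persons tracked for all of 2020
production_rows = rows_after_forward_fill(1_000_000, date(2020, 1, 1), date(2020, 12, 31))
print(production_rows)
```

Ten change events turn into 128 rows here, and the hypothetical table ends up with hundreds of millions of rows regardless of how rarely a status actually changes.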

Step 7 - Be smarter

Instead of using the forward fill approach to create all this data, it is a better idea to add only one more column to the table, in which the end of the status is recorded. For example, if Mario was in level 1-1 seven days ago and today he finally made it to level 1-2, there is no need to create seven events containing level 1-1 and one event for level 1-2. It is better to have one event for 1-1 with a start date seven days ago and an end date yesterday, plus one event for 1-2 with a start date today and an end date that is still open (today's date or a date somewhere in the future).
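
The conversion from change events to validity intervals can be sketched in plain Python before touching Spark. The helper to_intervals is made up for this illustration, and today is pinned to a fixed date so the example is reproducible:

```python
from datetime import date, timedelta


def to_intervals(changes, today):
    """Turn [(start, status), ...] sorted by start into (status, start, end) tuples.

    Both `start` and `end` are inclusive.
    """
    intervals = []
    for index, (start, status) in enumerate(changes):
        if index + 1 < len(changes):
            # Valid until the day before the next change
            end = changes[index + 1][0] - timedelta(days=1)
        else:
            # The latest status is still valid today
            end = today
        intervals.append((status, start, end))
    return intervals


today = date(2020, 5, 25)  # fixed date instead of date.today()
changes = [(today - timedelta(days=7), "1-1"), (today, "1-2")]

for interval in to_intervals(changes, today):
    print(interval)
```

Two change events stay two rows, no matter how many days lie between them.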

Coming back to the initial SQL query I can rewrite this easily to a query with BETWEEN:

SELECT status FROM person_statuses WHERE person = "John Doe" AND time = current_date()

is rewritten to

SELECT status FROM person_statuses WHERE person = "John Doe" AND current_date() BETWEEN time AND endtime

I use a similar script, but subtract one day from the endtime column, because BETWEEN includes both boundaries and a day must not match two statuses. No explode is needed with this approach.

In [17]:
# Create the dataframe
df = sc.createDataFrame(events)
# Modify the data type
df = df.withColumn("time",
                   F.to_date(F.to_timestamp("time")))
# Create the status column
df = df.withColumn("status",
                   F.concat(F.col('world'),
                            F.lit('-'),
                            F.col('level')))
# Apply the window
df = df.withColumn("endtime",
                   F.lead("time").over(w))
# Subtract one day from the endtime
df = df.withColumn("endtime",
                   F.expr("date_sub(to_date(endtime), 1)"))
# Fill in the empty `endtime` column with today's date
df = df.withColumn("endtime",
                   F.when(F.col("endtime").isNull(),
                          F.expr("current_date()"))\
                           .otherwise(F.col("endtime")))
# Select the columns to display
df.select("character", "status", "time", "endtime").show()
+---------+------+----------+----------+
|character|status|      time|   endtime|
+---------+------+----------+----------+
|    Peach|   8-4|2020-04-29|2020-05-25|
|     Toad|   6-3|2020-05-11|2020-05-13|
|     Toad|   6-3|2020-05-14|2020-05-25|
|    Mario|   5-3|2020-05-02|2020-05-05|
|    Mario|   4-4|2020-05-06|2020-05-25|
|    Luigi|   4-3|2020-04-24|2020-04-29|
|    Luigi|   7-3|2020-04-30|2020-05-13|
|    Luigi|   5-2|2020-05-14|2020-05-21|
|    Luigi|   6-4|2020-05-22|2020-05-22|
|    Luigi|   2-1|2020-05-23|2020-05-25|
+---------+------+----------+----------+

Register the dataframe as a temporary view so that I can query it with SQL in Spark:

In [18]:
df.createOrReplaceTempView("df")
In [19]:
sc.sql("SELECT status, time, endtime FROM df WHERE character = 'Peach'").show()
+------+----------+----------+
|status|      time|   endtime|
+------+----------+----------+
|   8-4|2020-04-29|2020-05-25|
+------+----------+----------+

In [20]:
sc.sql("SELECT status FROM df WHERE character = 'Peach' AND '2020-05-24' BETWEEN time AND endtime").show()
+------+
|status|
+------+
|   8-4|
+------+
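
The boundary behaviour of BETWEEN can also be checked without Spark. The sketch below uses Python's built-in sqlite3 module as a stand-in, with Mario's two intervals copied from the table above stored as ISO date strings. SQLite is not Spark SQL, but BETWEEN is inclusive on both ends in both:

```python
import sqlite3

connection = sqlite3.connect(":memory:")
connection.execute(
    "CREATE TABLE person_statuses (person TEXT, status TEXT, time TEXT, endtime TEXT)"
)
connection.executemany(
    "INSERT INTO person_statuses VALUES (?, ?, ?, ?)",
    [
        ("Mario", "5-3", "2020-05-02", "2020-05-05"),
        ("Mario", "4-4", "2020-05-06", "2020-05-25"),
    ],
)


def status_on(person, day):
    """Return the status whose [time, endtime] interval contains `day`, else None."""
    row = connection.execute(
        "SELECT status FROM person_statuses "
        "WHERE person = ? AND ? BETWEEN time AND endtime",
        (person, day),
    ).fetchone()
    return row[0] if row else None


print(status_on("Mario", "2020-05-05"))  # last day of the first interval
print(status_on("Mario", "2020-05-06"))  # first day of the second interval
```

Because the end date is the day before the next change, every day matches exactly one interval.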

Instead of using PySpark to create the dataframe, the same result can be achieved using SQL.

In [21]:
# Create the events view
eventsdf = sc.createDataFrame(events)
eventsdf.createOrReplaceTempView("events")
In [22]:
sc.sql("""
WITH statuses AS (
    SELECT
        character,
        CONCAT(world, '-', level) AS status,
        to_date(time) AS start,
        to_date(DATE_SUB(LEAD(time) OVER (PARTITION BY character ORDER BY time), 1)) AS end
    FROM
        events
)
SELECT
  character,
  status,
  start,
  IF(end IS NOT NULL, end, current_date()) AS end
FROM statuses
""").show()
+---------+------+----------+----------+
|character|status|     start|       end|
+---------+------+----------+----------+
|    Peach|   8-4|2020-04-29|2020-05-25|
|     Toad|   6-3|2020-05-11|2020-05-13|
|     Toad|   6-3|2020-05-14|2020-05-25|
|    Mario|   5-3|2020-05-02|2020-05-05|
|    Mario|   4-4|2020-05-06|2020-05-25|
|    Luigi|   4-3|2020-04-24|2020-04-29|
|    Luigi|   7-3|2020-04-30|2020-05-13|
|    Luigi|   5-2|2020-05-14|2020-05-21|
|    Luigi|   6-4|2020-05-22|2020-05-22|
|    Luigi|   2-1|2020-05-23|2020-05-25|
+---------+------+----------+----------+

And that's it. This was a good exercise to understand how easy it is to make things too complex.