Counting Turkeys with the Apache Spark Python API.

This is a notebook showing how to compute frequencies of Thanksgiving-related topics on Twitter with PySpark (the Spark Python API).

Note that as of version 2.2, PySpark is on PyPI. This makes setup very easy (pip install pyspark).

First, we'll import a bunch of packages we'll use below. Here is some version information:

  • bqplot: 0.10.1
  • geopandas: 0.3.0
  • matplotlib: 2.0.2
  • pandas: 0.20.3
  • pyspark: 2.2.0
  • Python: 3.6.2
In [ ]:
from functools import reduce
import gzip
import itertools
from io import BytesIO
import json
from operator import itemgetter
import os
import re
import sys

import bqplot
import geopandas as gpd
from IPython.display import HTML, display, Image
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import pyspark
from pyspark.sql import DataFrame, functions as F
from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, StringType, FloatType, BooleanType
import us

print("bqplot:", bqplot.__version__)
print("geopandas:", gpd.__version__)
print("matplotlib:", mpl.__version__)
print("pandas:", pd.__version__)
print("pyspark:", pyspark.__version__)
print("Python:", sys.version)

Connecting to the local Spark cluster

Now, we set some Spark environment variables. Increasing the driver memory (--driver-memory) was necessary when I ran experiments on the full dataset.

Instantiating the Spark context allows us to connect to the Spark cluster, which can just be on your local machine.

In [ ]:
os.environ['PYSPARK_SUBMIT_ARGS'] = '--driver-memory 8g pyspark-shell'
os.environ['SPARK_LOCAL_HOSTNAME'] = 'localhost'
sc = pyspark.SparkContext(appName="TurkeyCounts")
session = pyspark.sql.SparkSession(sc)

Reading Twitter data

I downloaded tweets using GNIP's Historical PowerTrack API. Note that the tweets are in GNIP's activity streams format, which differs slightly from the format of tweets from Twitter's API (e.g., the tweet text is in a body field rather than a text field).

The GNIP query I ran was the following (cf. syntax documentation), which includes various terms I had noticed as frequent and relevant from a prior look at a small sample of Thanksgiving tweets (I ended up not using "kale" because there wasn't much data):

(thanksgiving OR #thanksgiving OR turkey OR #turkey OR "black friday" OR blackfriday OR #blackfriday OR #happythanksgiving OR #cowboys OR #lions OR cowboys OR lions OR @lions OR @dallascowboys OR pie OR #pie OR kale OR #kale) -is:retweet profile_country:US

As a preprocessing step (not in this notebook), I removed the last line, which contains job status info, from each of GNIP's files (which have one JSON line per tweet).
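The cleanup step itself isn't part of this notebook. A minimal sketch of what it might look like, assuming each file is small enough to read into memory and that the last line is always the status record to drop (drop_last_line and load_activities are hypothetical helper names, not from the original):

```python
import gzip
import json


def drop_last_line(src_path, dst_path):
    """Copy a gzipped JSON-lines file to dst_path without its final line.

    Hypothetical helper: assumes the whole file fits in memory and that the
    last line is always GNIP's job-status record.
    """
    with gzip.open(src_path, 'rt', encoding='utf-8') as f:
        lines = f.read().splitlines()
    with gzip.open(dst_path, 'wt', encoding='utf-8') as f:
        for line in lines[:-1]:
            f.write(line + '\n')
    return max(len(lines) - 1, 0)  # number of lines kept


def load_activities(path):
    """Parse each non-empty line of a cleaned file as one JSON object (tweet)."""
    with gzip.open(path, 'rt', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
```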

The following will load up tweets into a Spark DataFrame, with one row per tweet. Two cells below, the number of tweets is printed.

In [ ]:
# The files in data_dir were already cleaned outside this notebook by
# removing GNIP's last line, which always holds job status info.
data_dir = 'cleaned_data'
paths = [
    os.path.join(data_dir, x) for x in os.listdir(data_dir)
    if x.endswith('activities.json.gz')
]
In [ ]:
# GNIP provides files in 10-minute chunks. We could sample a few to run quick analyses 
# for debugging, etc.
# paths = np.random.RandomState(42).choice(paths, 3, replace=False).tolist()
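# Spark reads gzipped JSON-lines files directly, one row per line (tweet).
input_df = session.read.json(paths)
print("Number of tweets:", input_df.count())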

Defining Topics

Next, we'll define some topics. Each topic is a list of words or phrases.

Spark can build dataframes from multiple sources, and those can be easily accessed or joined to each other using Spark's APIs. Above, we load from JSON. Here, we load from Python tuples. Loading from CSV and various other formats is also possible.
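For intuition, here is the same topic-to-rows flattening in plain Python, without Spark, using a made-up two-topic subset. Each (topic, word) tuple becomes one row; a word-to-topic dict is roughly what the later join on word looks up. This is a sketch, not the notebook's code.

```python
import itertools

# A made-up subset of the topics defined in the next cell.
topics = {
    'turkey': ['turkey', '#turkey'],
    'pie': ['pie', '#pie'],
}

# One (topic, word) tuple per word; chain.from_iterable concatenates the
# per-topic generators into a single flat sequence of rows.
rows = list(itertools.chain.from_iterable(
    ((name, w) for w in words) for name, words in topics.items()))

# Word -> topic lookup (assumes each word belongs to only one topic).
word_to_topic = {w: t for t, w in rows}
```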

In [ ]:
topics = {
    'thanksgiving': ['thanksgiving', '#thanksgiving', '#happythanksgiving'],
    'turkey': ['turkey', '#turkey'],
    'black friday': ['black friday', '#blackfriday'],
    'pie': ['pie', '#pie'],
    'lions': ['lions', '@lions', '#lions'],
    'cowboys': ['cowboys', '#cowboys', '@dallascowboys'],
}

# This converts the dictionary above into a dataframe with two columns (topic, word)
# and a row for each pair. `itertools.chain.from_iterable` concatenates the lists
# of topic-word pairs for each topic into a single iterable.
df_topic_words = session.createDataFrame(
    list(itertools.chain.from_iterable(
        ((name, w) for w in words)
        for name, words in topics.items())),
    schema=['topic', 'word'])

Tokenization and preprocessing

Here, we'll make some Spark User-Defined Functions (UDFs) to do some particular processing in Python, easily extending what's possible with Spark's SQL-like syntax.

Specifically, we'll clean up IDs to look like Twitter IDs (this isn't actually necessary here, but it caused me trouble early on, so I left it in as an example). We'll also extract which state each user is from (GNIP provides normalized profile locations), filter out users who aren't using one of the popular official Twitter clients (to avoid bots and spammers), and tokenize tweets to find out which ones contain words from the topics defined above. The result is a cleaned-up DataFrame with one row per tweet.

For details about the PySpark syntax here, see this page.

Side note: In the relatively near future, with Spark 2.3.x, it may be possible to avoid some of the overhead in converting between Python and Java representations by replacing these UDFs with vectorized UDFs.
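To illustrate the difference, here is a sketch in plain pandas (no Spark) of the official-client check written both ways. A Spark 2.3+ vectorized UDF (pyspark.sql.functions.pandas_udf, which also requires PyArrow) receives a batch of values as a pandas Series, like is_official_batch below, instead of one value per call. The Spark wrapping itself isn't shown, and both function names are hypothetical.

```python
import pandas as pd

OFFICIAL_CLIENTS = {
    'Twitter for iPhone',
    'Twitter for Android',
    'Twitter for iPad',
    'Twitter Web Client',
}


def is_official_row(name):
    # Row-at-a-time: called once per value, like a regular Python UDF.
    return name in OFFICIAL_CLIENTS


def is_official_batch(names: pd.Series) -> pd.Series:
    # Vectorized: one call handles a whole batch of values at once.
    return names.isin(OFFICIAL_CLIENTS)
```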

In [ ]:
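@udf(returnType=StringType())
def remove_user_id_prefix(user_id):
    assert user_id.startswith('')
    return user_id[len(''):]

@udf(returnType=StringType())
def extract_state(profile_locations):
    # Return None (not "") so the `state IS NOT NULL` filter below drops
    # tweets without a normalized profile location.
    if not profile_locations:
        return None
    return profile_locations[0]['address']['region']

_OFFICIAL_CLIENTS = {
    'Twitter for iPhone',
    'Twitter for Android',
    'Twitter for iPad',
    'Twitter Web Client',
}

@udf(returnType=BooleanType())
def is_from_official_client(generator_name):
    return generator_name in _OFFICIAL_CLIENTS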

TOK_PATTERN = re.compile(r"[#@]?\w+")

# A list of words and phrases to look for when preprocessing the texts.
TARGET_WORDS = set(itertools.chain.from_iterable(topics.values()))

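def tokenize(text):
    """Tokenize a text str into a list of strings matched from TARGET_WORDS."""
    # Strip URLs and ampersands before tokenizing.
    text = re.sub(r'(http\S+|&)', ' ', text)
    words = list(map(str.lower, TOK_PATTERN.findall(text)))
    # Candidates are all adjacent-word bigrams (for phrases such as
    # "black friday") plus all single words; keep those in TARGET_WORDS.
    toks = list({
        y for y in
        [' '.join(x) for x in zip(words, words[1:])] + words
        if y in TARGET_WORDS
    })
    return toks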

# Check that tokenization works.
assert (
    set(tokenize("black friday wibble thanksgiving wobble #thanksgiving wubble")) ==
    {'black friday', 'thanksgiving', '#thanksgiving'}
)
tokenize_udf = udf(returnType=ArrayType(StringType()))(tokenize)

df = (
    input_df
    .withColumn("tokens", tokenize_udf(input_df['body']))
    .withColumn("state", extract_state(input_df['gnip.profileLocations']))
    .filter('state IS NOT NULL')
    .withColumn("user_id", remove_user_id_prefix(input_df['']))
)

Associating users with states

Here, we'll make a dataframe with the name of the U.S. state for each user. If a user's profile location changed during the period covered by our dataset, we'll arbitrarily pick one of their states.
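A plain-Python sketch of that choice, assuming we already have (user_id, state) pairs. It deterministically keeps the first non-null state seen for each user, which is one arbitrary-but-reproducible option, not necessarily what the Spark code in this notebook does (pick_state_per_user is a hypothetical name).

```python
def pick_state_per_user(user_state_pairs):
    """Return {user_id: state}, keeping the first non-null state per user."""
    chosen = {}
    for user_id, state in user_state_pairs:
        if state is None:
            continue  # no usable location for this tweet
        chosen.setdefault(user_id, state)  # ignore later, different states
    return chosen
```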

In [ ]:
def pick_state(x):
    try:
        return (x[0], x[1][0].split('\t')[1])
    except IndexError:
        return (x[0], None)

df_states_non_uniq =, df.state,

df_user_states = session.createDataFrame(
    schema=['user_id', 'state']

Compute user counts per topic per state

Next, we'll use the explode function to turn the one-row-per-tweet df into a one-row-per-word dataframe df_tokens. That is then joined with the dataframe mapping words to topics df_topic_words and the mapping from users to U.S. states df_user_states. The result is then used to count distinct user IDs per state/topic pair.

The cache() calls tell Spark to keep intermediate datasets in memory (or on disk) so they don't have to be recomputed when later steps reuse them.
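The same explode, join, and count-distinct logic can be sketched in plain Python for a handful of made-up tweets. This assumes each word belongs to exactly one topic (true for the topics defined above) and mimics inner-join semantics by dropping users without a state and words without a topic; topic_state_user_counts is a hypothetical name.

```python
from collections import defaultdict


def topic_state_user_counts(tweets, word_to_topic, user_states):
    """Count distinct users per (state, topic) pair.

    tweets: iterable of (user_id, tokens) pairs, one per tweet
    word_to_topic: dict mapping each target word to its topic
    user_states: dict mapping user_id to a U.S. state
    """
    users = defaultdict(set)
    for user_id, tokens in tweets:
        state = user_states.get(user_id)     # like the join on user_id
        if state is None:
            continue
        for word in tokens:                  # like explode: one row per word
            topic = word_to_topic.get(word)  # like the join on word
            if topic is not None:
                users[(state, topic)].add(user_id)
    # Like countDistinct('user_id') per (state, topic) group.
    return {key: len(ids) for key, ids in users.items()}
```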

In [ ]:
df_tokens = df.select(df.user_id, F.explode(df.tokens).alias('word'))

df_topics_by_state = (
    df_tokens
    .join(df_topic_words, on='word')
    .join(df_user_states, on='user_id')
    .select('user_id', 'topic', 'state')
    .groupby(['state', 'topic'])
    .agg(F.countDistinct('user_id').alias('n_users'))
    .cache()
)


Map projection fun

This loads up U.S. state geometries in GeoJSON format that are conveniently available in bqplot, which is a great Python library for making interactive maps via d3.

However, we'll use geopandas to make static maps below. To make a nice-looking U.S. map with geopandas, I had to fiddle with proj4 strings and such, which I won't even try to explain here.
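For the curious, +proj=aea is the Albers equal-area conic projection. Below is a sketch of the textbook spherical form of its forward equations, using the lat_1, lat_2, lat_0, and lon_0 values from the proj4 string in the next cell and an assumed mean Earth radius of 6,371 km. The real proj4 string uses the GRS80 ellipsoid, so its output differs slightly; this is for intuition only.

```python
import math


def albers_spherical(lat, lon, lat_1=29.5, lat_2=45.5, lat_0=23.0,
                     lon_0=-96.0, radius=6371000.0):
    """Project (lat, lon) in degrees to (x, y) in meters on a sphere."""
    phi, lam = math.radians(lat), math.radians(lon)
    phi1, phi2 = math.radians(lat_1), math.radians(lat_2)
    phi0, lam0 = math.radians(lat_0), math.radians(lon_0)
    # Cone constant from the two standard parallels.
    n = (math.sin(phi1) + math.sin(phi2)) / 2
    c = math.cos(phi1) ** 2 + 2 * n * math.sin(phi1)
    # Radius of the projected parallel for the point and for the origin.
    rho = radius * math.sqrt(c - 2 * n * math.sin(phi)) / n
    rho0 = radius * math.sqrt(c - 2 * n * math.sin(phi0)) / n
    theta = n * (lam - lam0)
    return rho * math.sin(theta), rho0 - rho * math.cos(theta)
```

The projection origin (23°N, 96°W) maps to (0, 0), and scale is true along the two standard parallels (29.5° and 45.5°), which keeps distortion low across the contiguous U.S.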

In [ ]:
df_geom = gpd.read_file(
    os.path.join(bqplot.__path__[0], 'map_data', 'USStatesMap.json'),
    driver='GeoJSON')
df_geom.crs = "+init=epsg:4326"

proj4 = '+proj=aea +lat_1=29.5 +lat_2=45.5 +lat_0=23 +lon_0=-96 +x_0=0 +y_0=0 +ellps=GRS80 +towgs84=1,1,-1,0,0,0,0 +units=m +no_defs'
df_geom = df_geom.to_crs(proj4)
df_geom[(df_geom['name'] == 'Alaska') | (df_geom['name'] == 'Hawaii')] = None
df_geom = (
    df_geom
    .rename(columns={'name': 'state'})
    .drop('id', axis=1)
)

Define a simple map function

This cell defines a function to be used to make a map given topic counts and state population statistics. It will be called below for each topic.

In [ ]:
def make_map(query_counts, state_pops, df_geom, max_val=0.1, cmap='OrRd', title=""):

    # Compute rates of users tweeting, normalized by general state-level
    # activity statistics.
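    df_rate = pd.DataFrame([
        {'state': k, 'adj_user_activity': query_counts.get(k, 0.) / v}
        for k, v in state_pops.items()
    ])

    # Here's the actual map-making command! States are matched by name; the
    # inner merge also drops the rows blanked out for Alaska and Hawaii.
    # Values above max_val are drawn in the darkest color.
    ax = df_geom.merge(df_rate, on='state').plot(
        column='adj_user_activity', cmap=cmap, vmin=0, vmax=max_val,
        figsize=(15, 10))
    ax.set_title(title)
    ax.set_axis_off()

    return ax, df_rate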

# A color scale from light-gray to a shade of blue that Civis uses.
cmap = mpl.colors.LinearSegmentedColormap.from_list(
        'civisblue', ((0.9, 0.9, 0.9), (1/255, 148/255, 211/255)), N=100)
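With only two colors, a colormap built by LinearSegmentedColormap.from_list blends linearly between them. This plain-Python sketch computes the blended RGB value for a position t in [0, 1], ignoring the N=100 lookup-table quantization that matplotlib applies (lerp_color is a hypothetical helper, not part of matplotlib).

```python
def lerp_color(c0, c1, t):
    """Linearly interpolate between RGB tuples c0 and c1 at position t."""
    t = min(max(t, 0.0), 1.0)  # clip t to [0, 1]
    return tuple(a + (b - a) * t for a, b in zip(c0, c1))


LIGHT_GRAY = (0.9, 0.9, 0.9)
CIVIS_BLUE = (1 / 255, 148 / 255, 211 / 255)
```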

Load up state population data

Next, we'll load up some pre-computed counts of distinct users per state, so we can adjust the state-level user counts for state population sizes. These counts were calculated from a separate, non-topic specific sample of tweets with U.S. profile locations.

The profile locations from GNIP are automatically normalized to known locations (cities, counties, states). This process is imperfect, and we observed some mistakes; for example, some people from Virginia appeared to be mistakenly assigned to West Virginia. Using these general state-level counts instead of state population sizes (e.g., from the Census) will hopefully account for such errors, since the same normalization affects both the topic counts and the counts we divide by.
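The adjustment is a simple ratio. A sketch with made-up numbers, assuming both inputs are dicts keyed by state name; it follows make_map in treating states with no topic users as 0, and additionally skips states with a zero denominator (adjusted_activity is a hypothetical name).

```python
def adjusted_activity(topic_users, state_users):
    """Share of a state's sampled users who tweeted about a topic.

    topic_users: {state: distinct users who tweeted about the topic}
    state_users: {state: distinct users in a general, non-topic sample}
    """
    return {
        state: topic_users.get(state, 0) / total
        for state, total in state_users.items()
        if total > 0  # avoid dividing by zero
    }
```

For example, 50 topic users out of 1,000 sampled users in a state gives an adjusted activity of 0.05.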

In [ ]:
with open('distinct_user_counts_by_state_2016-11.json') as f:
    state_pops = json.load(f)
print(sorted(state_pops.items(), key=lambda x: -x[1])[:5], "...")

Make maps for each topic!

Next, we'll actually print out the maps and some other statistics for each topic.

In [ ]:
def plot_topic(topic, state_pops, df_geom, cmap):
    rate_label = 'Adjusted user activity'

    # Use `collect` to trigger execution of the Spark computations defined above,
    # and convert the resulting Spark dataframe to a dictionary of state-level
    # user counts.
    query_counts = {
        x['state']: x['n_users']
        for x in (df_topics_by_state
                  .where(col('topic') == topic)
                  .collect())
    }

    # Use the `make_map` function to make a map, show it, and save a copy to disk.
    title = '%s (%s)' % (rate_label, topic.title())
    ax, df_rate = make_map(
        query_counts, state_pops, df_geom, cmap=cmap, title=title)
    buf = BytesIO()
    fig = ax.get_figure()
    fig.savefig(buf, format='png', bbox_inches='tight')
    plt.close(fig)  # avoid a second, inline copy of the figure
    img_bytes = buf.getvalue()
    with open('map_%s.png' % topic, 'wb') as f:
        f.write(img_bytes)
    map_img = Image(data=img_bytes, format='png')

    # Display the maps in the notebook, with total user counts.
    details_html = """
    <b>Total Users:</b> {:,}
    """.format(sum(query_counts.values()))
    display(
        HTML("<hr />"),
        HTML("<h2>Topic: {title}</h2>".format(title=topic.title())),
        map_img,
        HTML(details_html))

for topic in topics:
    plot_topic(topic, state_pops, df_geom, cmap)

Cleaning up

The following cell closes the Spark context and cleans things up.

In [ ]:
sc.stop()