IBM Developer Advocacy

Getting started with GraphFrames in Apache Spark

David Taieb

Introduction to Spark and graphs

GraphX is one of the four foundational components of Spark, along with Spark SQL, Spark Streaming and MLlib. It provides general-purpose graph APIs, including graph-parallel computation.


GraphX APIs are great but have a few limitations. First, they only work with Scala, so if you want to use GraphX with Python in a Jupyter Notebook, you are out of luck. Second, they only work at the RDD (Resilient Distributed Dataset) level, which means that they can’t benefit from the performance improvements provided by DataFrames and the Catalyst query optimizer. GraphFrames is an open source Spark package that was created with the goal of addressing these two issues:

  • Provides a set of Python APIs
  • Works with DataFrames

In this post, we’ll show how to get started with GraphFrames from a Python Notebook. We’ll start by creating a graph composed of airports as the vertices and flight routes as the edges, using the data from the flight predict application. I’ll then show interesting ways of visualizing the data and apply various graph algorithms to extract insights from the data.

Installing GraphFrames

GraphFrames is currently available as a preview Spark package compatible with Spark 1.6 and higher. There are multiple ways to install the package depending on how you are running Spark:

  • Spark-submit or Spark-shell: simply add --packages graphframes:graphframes:0.1.0-spark1.6 as a command-line argument
  • Local Jupyter Notebook: assuming that you have access to the configuration files, all you need is to add --packages graphframes:graphframes:0.1.0-spark1.6 to PYSPARK_SUBMIT_ARGS in the kernel.json located in ~/.ipython/kernels/<yourkernel>/kernel.json. For example (other entries elided):
        {
            "display_name": "pySpark (Spark 1.6.0) with graphFrames",
            "language": "python",
            "argv": [ ... ],
            "env": {
                "SPARK_HOME": "/Users/dtaieb/cdsdev/spark-1.6.0",
                "PYTHONPATH": "/Users/dtaieb/cdsdev/spark-1.6.0/python/:/Users/dtaieb/cdsdev/spark-1.6.0/python/lib/",
                "PYTHONSTARTUP": "/Users/dtaieb/cdsdev/spark-1.6.0/python/pyspark/",
                "PYSPARK_SUBMIT_ARGS": "--packages graphframes:graphframes:0.1.0-spark1.6 --master local[10] pyspark-shell",
                ...
            }
        }
  • IPython Notebook (hosted on IBM Bluemix Apache Spark™ service): When the notebook is hosted and you don’t have access to the configuration files, you need a way to add a Spark package to the session from within the notebook itself. There is no magic command for this today, so I made a workaround: a helper Python library called pixiedust.

Note: The following steps currently only work in a Python notebook hosted on IBM Bluemix.

Open your Python notebook and run the following code:

  1. Cell1: install the pixiedust library.
     !pip install --user pixiedust

    Or if you want to upgrade the version already installed:

     !pip install --user --upgrade --no-deps pixiedust
  2. Cell2: import the pixiedust packageManager module and install graphframes.
        from pixiedust.packageManager import PackageManager
        PackageManager().installPackage("graphframes:graphframes:0.1.0-spark1.6")

    If all goes well, you should see a message printed in red in the output asking you to restart the kernel. You can do so using the menu: Kernel/Restart.

  3. Once the kernel has restarted, run Cell2 again. Even though the GraphFrames jar file is now part of the classpath, you still need to run the command to add the GraphFrames Python APIs to the SparkContext.
  4. Cell3: verify that GraphFrames is correctly installed.
        #import the display module
        from pixiedust.display import *
        #import the Graphs example
        from graphframes.examples import Graphs
        #create the friends example graph
        g = Graphs(sqlContext).friends()
        #use the pixiedust display
        display(g)

Results of the code above should look like this:

Graph sample run

Note: I’ll be using the pixiedust display() API call in this post without diving into the details of how it’s built, which I’ll cover in a future post.

Create a graph with airports as nodes and flight routes as edges

At a high level, GraphFrames is to GraphX what DataFrames is to RDDs. It is built on top of Spark SQL and provides a set of APIs that elegantly combine Graph Analytics and Graph Queries:

GraphFrames Architecture

Diving into the technical details, you need two DataFrames to build a graph: one DataFrame for vertices and a second DataFrame for edges. With GraphFrames successfully installed, we are now ready to load the data from the flight predict application.

As a reminder, the data lives in two Cloudant databases:

  • flight-metadata: contains the airports info
  • flightpredict_training_set: contains the flight routes augmented with weather info

The first step is to configure the Cloudant-spark connector and load the 2 datasets:

    #Configure connector
    import training
    import run
    training.sqlContext = sqlContext

    #load the 2 datasets
    airports = training.loadDataSet("flight-metadata", "airports")
    print("airports count: " + str(airports.count()))
    flights = training.loadDataSet("pycon_flightpredict_training_set","training")
    print("flights count: " + str(flights.count()))


Successfully cached dataframe
Successfully registered SQL table airports
airports count: 17535
Successfully cached dataframe
Successfully registered SQL table training
flights count: 33336

In this step, we build the vertices and edges DataFrames for our graph. The vertices (airports) must all have at least one edge (flights). They also must have a column named “id” that uniquely identifies the vertex. To meet these two requirements, the cell below performs a join between airports and flights, and renames the column “fs” (airport code) to “id”.

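from pyspark.sql import functions as f
from pyspark.sql.types import *
# collect the distinct airport codes that appear in at least one flight
rdd = flights.flatMap(lambda s: [s.arrivalAirportFsCode, s.departureAirportFsCode]).distinct()\
    .map(lambda row:[row])
# keep only those airports, then rename "fs" to the "id" column GraphFrames expects
vertices = airports.join(
      sqlContext.createDataFrame(rdd, StructType([StructField("fs",StringType())])), "fs"
    ).withColumnRenamed("fs", "id")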

The edges dataframe is almost ready, but we need to make sure that it has the columns “src” and “dst” that respectively reference the “id” of the source and destination airport. We also drop a few unneeded columns:
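One way to do the renaming, as a minimal sketch: it assumes the flights DataFrame exposes the departureAirportFsCode and arrivalAirportFsCode columns used in the previous cell. The helper name build_edges and its unneeded_columns parameter are hypothetical, and which columns are worth dropping depends on your data.

```python
def build_edges(flights, unneeded_columns=()):
    """Rename the airport-code columns to the src/dst names GraphFrames expects.

    `flights` is expected to be a Spark DataFrame; only withColumnRenamed()
    and drop() are called on it.
    """
    edges = (flights
             .withColumnRenamed("departureAirportFsCode", "src")
             .withColumnRenamed("arrivalAirportFsCode", "dst"))
    # Drop whatever columns the caller considers unnecessary (hypothetical list).
    for name in unneeded_columns:
        edges = edges.drop(name)
    return edges
```

In the notebook this would be called as edges = build_edges(flights), optionally passing the names of the columns you want to drop.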


We can now build the graph and display it:

from graphframes import GraphFrame
g = GraphFrame(vertices, edges)
display(g)

When you initially run this cell, you’ll see a table. But because pixiedust introspects the dataset, it knows it contains latitude and longitude coordinates that can be displayed on a map. Click the map pin icon to see the graph of airports and flights overlaid on a map of the United States:

Create Graph and display it

Note: The visualization above is coming from a sample pixiedust plugin that visualizes all the flights for selected airports. It also provides menus to display the vertices and edges as tables. To generate this image, I reused Mike Bostock’s d3-based map. Thanks, Mike!

Let’s do some graph computing!

Compute the degree for each vertex in the graph

The degree of a vertex is the number of edges incident to the vertex. In a directed graph, in-degree is the number of edges where the vertex is the destination and out-degree is the number of edges where the vertex is the source. A GraphFrame has degrees, outDegrees and inDegrees properties. Each returns a DataFrame containing the id of the vertex and the number of edges. We then sort the result in descending order:

from pyspark.sql.functions import *
degrees = g.degrees.sort(desc("degree"))
display( degrees )


Compute Graph Degrees
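To make the three notions of degree concrete, here is a plain-Python sketch on a hypothetical toy route list. It does not use Spark and only illustrates the definitions, not how GraphFrames computes them.

```python
from collections import Counter

# Hypothetical toy routes as (src, dst) pairs, one tuple per flight.
routes = [("BOS", "SFO"), ("BOS", "LAX"), ("SFO", "BOS"), ("LAX", "SFO")]

out_degree = Counter(src for src, _ in routes)  # edges leaving each airport
in_degree = Counter(dst for _, dst in routes)   # edges arriving at each airport
degree = out_degree + in_degree                 # all incident edges

# Sort by degree in descending order (ties broken by airport code).
ranked = sorted(degree.items(), key=lambda kv: (-kv[1], kv[0]))
print(ranked)
```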

Compute a list of shortest paths for each vertex to a specified list of landmarks

For this example we use the shortestPaths API, which returns a DataFrame containing the properties for each vertex plus an extra column called distances that contains the number of hops to each landmark. In the following code, we use BOS and LAX as the landmarks:

r = g.shortestPaths(landmarks=["BOS", "LAX"]).select("id", "distances")
display(r)


Compute shortest paths
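The idea of "number of hops to each landmark" can be sketched in plain Python with a breadth-first search that walks the edges backwards from each landmark. The function name hops_to_landmarks and the toy data are hypothetical; this is an illustration of the concept, not the GraphFrames implementation.

```python
from collections import deque

def hops_to_landmarks(edges, landmarks):
    """Return {vertex: {landmark: hops}}, following edge direction."""
    incoming = {}   # dst -> set of vertices with an edge into dst
    vertices = set()
    for src, dst in edges:
        incoming.setdefault(dst, set()).add(src)
        vertices.update((src, dst))

    distances = {v: {} for v in vertices}
    for landmark in landmarks:
        if landmark not in vertices:
            continue  # unknown landmark: nothing can reach it
        distances[landmark][landmark] = 0
        queue = deque([landmark])
        while queue:
            current = queue.popleft()
            for prev in incoming.get(current, ()):
                if landmark not in distances[prev]:
                    distances[prev][landmark] = distances[current][landmark] + 1
                    queue.append(prev)
    return distances
```

Vertices that cannot reach a landmark simply have no entry for it in their distances map.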

Compute the pageRank for each vertex in the graph

PageRank is a famous algorithm used by Google Search to rank vertices in a graph by order of importance. To compute PageRank, we’ll use the pageRank() API call, which returns a new graph in which the vertices have a new pagerank column representing the PageRank score for the vertex, and the edges have a new weight column representing the edge weight that contributed to the PageRank score. We’ll then display the vertex ids and associated scores sorted in descending order:

from pyspark.sql.functions import *
ranks = g.pageRank(resetProbability=0.20, maxIter=5)
display(ranks.vertices.select("id", "pagerank").orderBy(desc("pagerank")))


Compute page Rank
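For intuition about what resetProbability and maxIter control, here is a textbook power-iteration sketch in plain Python. It is an assumption-laden illustration (normalized ranks, dangling vertices spread uniformly) and is not claimed to match the scores GraphFrames produces; the function name pagerank is hypothetical.

```python
def pagerank(edges, reset_probability=0.2, max_iter=5):
    """Textbook PageRank by power iteration; ranks sum to 1."""
    vertices = sorted({v for edge in edges for v in edge})
    out_links = {v: [] for v in vertices}
    for src, dst in edges:
        out_links[src].append(dst)

    n = len(vertices)
    rank = {v: 1.0 / n for v in vertices}
    follow = 1.0 - reset_probability
    for _ in range(max_iter):
        # Rank held by vertices without outgoing edges is spread uniformly.
        dangling = sum(rank[v] for v in vertices if not out_links[v])
        new_rank = {v: (reset_probability + follow * dangling) / n for v in vertices}
        for src in vertices:
            for dst in out_links[src]:
                new_rank[dst] += follow * rank[src] / len(out_links[src])
        rank = new_rank
    return rank
```

A higher reset probability pulls every score towards the uniform value 1/n, while more iterations let the scores settle.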

Search routes between two airports with specific criteria

In this section, we want to find all the routes between Boston and San Francisco operated by United Airlines with at most two hops. To perform this search, we use the bfs() (breadth-first search) API call, which returns a DataFrame containing the shortest paths between matching vertices. For clarity, we only keep the edges when displaying the results:

paths = g.bfs(fromExpr="id='BOS'",toExpr="id = 'SFO'",edgeFilter="carrierFsCode='UA'", maxPathLength = 2).drop("from").drop("to")
display(paths)


Compute BFS
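The same search can be sketched in plain Python on a hypothetical toy edge list, which may help clarify how the edge filter and the maximum path length interact. The function name bfs_paths and the sample flights are illustrative assumptions, not part of GraphFrames.

```python
def bfs_paths(edges, start, goal, edge_filter, max_path_length):
    """Return the shortest paths (as lists of edges) from start to goal.

    Only edges accepted by edge_filter are followed, and paths longer than
    max_path_length edges are never explored.
    """
    allowed = [e for e in edges if edge_filter(e)]
    frontier = [[e] for e in allowed if e["src"] == start]
    for _ in range(max_path_length):
        found = [path for path in frontier if path[-1]["dst"] == goal]
        if found:
            return found  # stop at the first length that reaches the goal
        frontier = [path + [e] for path in frontier
                    for e in allowed if e["src"] == path[-1]["dst"]]
    return []

# Hypothetical toy flights.
flights = [
    {"src": "BOS", "dst": "SFO", "carrierFsCode": "AA"},
    {"src": "BOS", "dst": "ORD", "carrierFsCode": "UA"},
    {"src": "ORD", "dst": "SFO", "carrierFsCode": "UA"},
    {"src": "BOS", "dst": "DEN", "carrierFsCode": "UA"},
    {"src": "DEN", "dst": "SFO", "carrierFsCode": "DL"},
]
ua_paths = bfs_paths(flights, "BOS", "SFO", lambda e: e["carrierFsCode"] == "UA", 2)
```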

Find all airports that do not have direct flights between each other

In this section, we’ll use a very powerful GraphFrames search feature called motif finding, which matches structural patterns in the graph. We’ll use it to apply the pattern "(a)-[]->(b);(b)-[]->(c);!(a)-[]->(c)", which searches for all vertices a, b and c such that there is an edge from a to b and an edge from b to c, but no edge from a to c. Also, because the search is computationally expensive, we reduce the number of edges by grouping the flights that have the same src and dst.

from pyspark.sql import functions as F
h = GraphFrame(g.vertices, g.edges.select("src","dst").groupBy("src","dst").agg(F.count("src").alias("count")))
query = h.find("(a)-[]->(b);(b)-[]->(c);!(a)-[]->(c)").drop("b")
display(query)


Motif findings
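What the pattern matches can be sketched in plain Python on a hypothetical toy route list. The function name missing_direct_routes is illustrative, and the sketch also skips round trips where a and c are the same airport, a choice made here for readability.

```python
def missing_direct_routes(edges):
    """Return sorted (a, c) pairs connected through some b but not directly."""
    pairs = set(edges)  # collapse repeated flights on the same route
    successors = {}
    for src, dst in pairs:
        successors.setdefault(src, set()).add(dst)
    return sorted({
        (a, c)
        for a, b in pairs
        for c in successors.get(b, ())
        if a != c and (a, c) not in pairs
    })

# Hypothetical toy routes; BOS reaches LAX only through ORD.
toy_routes = [("BOS", "ORD"), ("ORD", "SFO"), ("BOS", "SFO"), ("ORD", "LAX"), ("BOS", "ORD")]
print(missing_direct_routes(toy_routes))
```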

Compute the strongly connected components for this graph

Strongly connected components are groups of vertices in which each vertex is reachable from every other vertex. To compute them, we’ll use the stronglyConnectedComponents() API call, which returns a DataFrame containing all the vertices, with the addition of a component column that contains the id of the component each vertex belongs to. We then group all the rows by component and count the member vertices. This gives us a good idea of the distribution of components in the graph.

from pyspark.sql.functions import *
components = g.stronglyConnectedComponents(maxIter=10).select("id","component")\
    .groupBy("component").agg(count("id").alias("count")).orderBy(desc("count"))
display(components)


Compute Strongly Connected components
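The definition can be checked directly on a small graph with a plain-Python sketch based on mutual reachability. This brute-force approach only suits toy inputs and is not how GraphFrames computes components; all names below are hypothetical.

```python
from collections import Counter

def reachable(adjacency, start):
    """Return every vertex reachable from start, including start itself."""
    seen, stack = {start}, [start]
    while stack:
        for nxt in adjacency.get(stack.pop(), ()):
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

def strongly_connected_components(edges):
    """Map each vertex to a component label (the smallest id in its component)."""
    adjacency, vertices = {}, set()
    for src, dst in edges:
        adjacency.setdefault(src, set()).add(dst)
        vertices.update((src, dst))
    reach = {v: reachable(adjacency, v) for v in vertices}
    # u and v share a component when each one can reach the other.
    return {v: min(u for u in reach[v] if v in reach[u]) for v in vertices}

# Group by component and count the members, as in the cell above.
toy_edges = [("A", "B"), ("B", "A"), ("B", "C")]
component_sizes = Counter(strongly_connected_components(toy_edges).values())
```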

Detect communities in the graph using Label Propagation algorithm

Label propagation is a popular algorithm for finding communities within a graph. It has the advantage of being computationally inexpensive and thus works well with large graphs. To compute the communities, we’ll use the labelPropagation() API call, which returns a DataFrame containing all the vertices, with the addition of a label column that identifies the community each vertex was assigned to. Similar to the strongly connected components computation, we’ll then group all the rows by label and count the member vertices.

from pyspark.sql.functions import *
communities = g.labelPropagation(maxIter=5).select("id", "label")\
    .groupBy("label").agg(count("id").alias("count")).orderBy(desc("count"))
display(communities)


Compute communities
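A simplified, deterministic version of label propagation can be sketched in plain Python. It assumes synchronous updates, treats edges as undirected, and breaks ties by picking the smallest label. These are choices made for this sketch, and the result is not claimed to match what GraphFrames returns.

```python
from collections import Counter

def label_propagation(edges, max_iter=5):
    """Each vertex repeatedly adopts the most common label among its neighbours."""
    neighbours = {}
    for src, dst in edges:
        neighbours.setdefault(src, set()).add(dst)
        neighbours.setdefault(dst, set()).add(src)

    labels = {v: v for v in neighbours}  # every vertex starts in its own community
    for _ in range(max_iter):
        updated = {}
        for vertex, adjacent in neighbours.items():
            counts = Counter(labels[u] for u in adjacent)
            best = max(counts.values())
            # Tie-break on the smallest label to keep the sketch deterministic.
            updated[vertex] = min(label for label, c in counts.items() if c == best)
        labels = updated
    return labels

# Two hypothetical, disconnected triangles end up as two communities.
triangles = [("A", "B"), ("B", "C"), ("C", "A"), ("D", "E"), ("E", "F"), ("F", "D")]
community_sizes = Counter(label_propagation(triangles).values())
```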

Use AggregateMessages to compute the average flight delays by originating airport

AggregateMessages is a powerful building block for graph algorithms. It works by sending messages between vertices and aggregating the messages received at each vertex. In this cell, for each flight we send the corresponding flight delay to the source airport, then compute the average for each airport. Unfortunately, this API is not currently available in Python, so we use the PixieDust Scala bridge to call the Scala API.

Note: PixieDust automatically rebinds the Python GraphFrame variable g into a Scala GraphFrame with the same name.

%%scala
import org.graphframes.lib.AggregateMessages
import org.apache.spark.sql.functions.{avg,desc,floor}

// For each airport, average the delays of the departing flights
val msgToSrc = AggregateMessages.edge("deltaDeparture")
val __agg = g.aggregateMessages
  .sendToSrc(msgToSrc)  // send each flight delay to source
  .agg(floor(avg(AggregateMessages.msg)).as("averageDelays"))  // average up all delays
  .orderBy(desc("averageDelays"))  // largest average delay first
  .limit(10)
__agg.show()


+---+-------------+
| id|averageDelays|
+---+-------------+
|FLL|           27|
|EWR|           27|
|MDW|           26|
|LGA|           26|
|MIA|           26|
|JFK|           24|
|BWI|           23|
|ORD|           23|
|MSP|           23|
|PHX|           22|
+---+-------------+
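The send-then-aggregate pattern can be illustrated in plain Python: each edge "sends" its delay to its source, and each source averages what it received. The function name and the toy data are hypothetical; only the deltaDeparture and src field names come from the cell above.

```python
from collections import defaultdict
from math import floor

def average_delay_by_source(edges):
    """Average the deltaDeparture of the flights leaving each airport."""
    totals = defaultdict(lambda: [0, 0])  # src -> [sum of delays, message count]
    for edge in edges:
        bucket = totals[edge["src"]]      # the message goes to the edge's source
        bucket[0] += edge["deltaDeparture"]
        bucket[1] += 1
    return {src: floor(total / count) for src, (total, count) in totals.items()}

# Hypothetical toy flights.
toy_flights = [
    {"src": "BOS", "dst": "SFO", "deltaDeparture": 10},
    {"src": "BOS", "dst": "LAX", "deltaDeparture": 25},
    {"src": "SFO", "dst": "BOS", "deltaDeparture": 7},
]
average_delays = average_delay_by_source(toy_flights)
```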


In this post, we have covered several things:

  • How to use GraphFrames (and any other Spark package) within an IPython notebook, including on the IBM Analytics for Apache Spark service on Bluemix.
  • How the pixiedust module, among other things, provides a simple API to create compelling in-context interactive visualizations.
  • How to create a graph from data stored in the Cloudant JSON database service.
  • How to use a few of the graph computation APIs provided by GraphFrames. Of course there is much more to explore, but hopefully this post gave you ideas you can reuse.

All the exercises and code are conveniently available in a completed Jupyter Notebook. Feel free to import it into your own Spark environment or on the IBM Apache Spark service — and use it as a starting point in your own project.
