Getting Spark Data from AWS S3 using Boto and PySpark

We’ve had quite a bit of trouble getting efficient Spark operation when the data to be processed is coming from an AWS S3 bucket. Aside from pulling all of the data to the Spark driver prior to the first map step (something that defeats the purpose of map-reduce!), we experienced terrible performance. In one scenario, Spark spun up 2360 tasks to read the records from one 1.1k log file. In another, the Spark logs showed that reading every line of every file took a handful of repetitive operations: validate the file, open the file, seek to the next line, read the line, close the file, repeat. Processing 450 small log files took 42 minutes. Arrgh.

We also experienced memory problems. When processing the full set of logs, we would see out-of-memory heap errors or complaints about exceeding Spark’s data frame size. Very frustrating.
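
For context, the pattern we started with looked roughly like the sketch below (a simplified, hypothetical reconstruction rather than our actual code; the app name, bucket name, and prefix are placeholders). The driver downloads every object itself and only hands Spark the already-fetched lines:

from boto.s3.connection import S3Connection
from pyspark import SparkContext

sc = SparkContext(appName="NaiveS3Pull")       # placeholder app name

# Anti-pattern: every byte of every log file flows through the driver...
conn = S3Connection()                          # credentials from AWS_* env vars
bucket = conn.get_bucket("my-log-bucket")      # placeholder bucket name
all_lines = []
for key in bucket.list(prefix="logs/"):        # placeholder prefix
    all_lines.extend(key.get_contents_as_string().splitlines())

# ...before Spark ever sees a record, so the workers do none of the reading.
lines = sc.parallelize(all_lines)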

One of my co-workers stumbled upon this post: http://tech.kinja.com/how-not-to-pull-from-s3-using-apache-spark-1704509219. It describes a solution that overcomes the problems we were having. The solution is written in Scala; since we are working in Python, we had to translate it. Here is a snippet of the Scala version of their driver program, taken from https://gist.githubusercontent.com/pjrt/f1cad93b154ac8958e65/raw/7b0b764408f145f51477dc05ef1a99e8448bce6d/S3Puller.scala.

import com.amazonaws.services.s3._, model._
import com.amazonaws.auth.BasicAWSCredentials
import java.io.InputStream
import scala.collection.JavaConversions._
import scala.io.Source

val request = new ListObjectsRequest()
request.setBucketName(bucket)
request.setPrefix(prefix)
request.setMaxKeys(pageLength)
def s3 = new AmazonS3Client(new BasicAWSCredentials(key, secret))

val objs = s3.listObjects(request) // Note that this method returns 
                                   // truncated data if longer than
                                   // the "pageLength" above. You might
                                   // need to deal with that.
sc.parallelize(objs.getObjectSummaries.map(_.getKey).toList)
    .flatMap { key => Source.fromInputStream(s3.getObject(bucket, key).getObjectContent: InputStream).getLines }

The key to the solution is to follow a process like this:

  1. Go directly to S3 from the driver to get a list of the S3 keys for the files you care about.
  2. Parallelize the list of keys.
  3. Code the first map step to pull the data from the files.

This procedure minimizes the amount of data that gets pulled into the driver from S3: just the keys, not the file contents. Then, when the map step is executed in parallel on multiple Spark workers, each worker pulls the S3 file data only for the files whose keys it holds.

S3 access from Python was done using the Boto library (note that the code below uses the classic boto package, not boto3):

pip install boto
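
A quick way to sanity-check the install and your credentials before wiring anything into Spark is to list a few keys directly (the bucket name and prefix here are placeholders):

from boto.s3.connection import S3Connection

conn = S3Connection()                      # falls back to AWS_* environment variables
bucket = conn.get_bucket("my-log-bucket")  # placeholder bucket name
for key in bucket.list(prefix="logs/"):    # placeholder prefix
    print(key.name)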

Here’s a snippet of the Python code that parallels the Scala code above. It processes log files composed of lines of JSON text. Note that Boto’s bucket.list() returns an iterable that transparently pages through the key listing, so the truncation caveat from the Scala snippet does not need special handling here:

import argparse
import json
from pyspark import SparkContext, SparkConf
from boto.s3.connection import S3Connection

def main():
    # Use argparse to handle some argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument("-a",
                        "--aws_access_key_id",
                        help="AWS_ACCESS_KEY_ID, omit to use env settings",
                        default=None)
    parser.add_argument("-s",
                        "--aws_secret_access_key",
                        help="AWS_SECRET_ACCESS_KEY, omit to use env settings",
                        default=None)
    parser.add_argument("-b",
                        "--bucket_name",
                        help="AWS bucket name",
                        default="spirent-orion")
    # Use Boto to connect to S3 and get a list of objects from a bucket
    conn = S3Connection(args.aws_access_key_id, args.aws_secret_access_key)
    bucket = conn.get_bucket(args.bucket_name)
    keys = bucket.list()
    # Get a Spark context and use it to parallelize the keys
    conf = SparkConf().setAppName("MyFileProcessingApp")
    sc = SparkContext(conf=conf)
    pkeys = sc.parallelize(keys)
    # Call the map step to handle reading in the file contents
    activation = pkeys.flatMap(map_func)
    # Additional map or reduce steps go here...
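    # A hypothetical continuation (not from the original post): count the
    # qualifying events per user and pull a small sample back to the driver.
    counts = activation.map(lambda pair: (pair[0], 1)) \
                       .reduceByKey(lambda a, b: a + b)
    print(counts.take(10))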

def map_func(key):
    # Use the key to read in the file contents, split on line endings
    for line in key.get_contents_as_string().splitlines():
        # parse one line of json
        j = json.loads(line)
        if "user_id" in j && "event" in j:
            if j['event'] == "event_we_care_about":
                yield j['user_id'], j['event']

if __name__ == "__main__":
    main()

With these changes in place, the 42-minute job now takes under 6 minutes. That is still too long, so on to the next tuning step!
