https://github.com/bigdatagenomics/    https://twitter.com/bigdatagenomics/

Special thanks to Neil Ferguson for this blog post on genomic analysis using ADAM, Spark and Deep Learning

Can we use deep learning to predict which population group you belong to, based solely on your genome?

Yes, we can – and in this post, we will show you exactly how to do this in a scalable way, using Apache Spark. We will explain how to apply deep learning using artifical neural networks to predict which population group an individual belongs to – based entirely on his or her genomic data.

This is a follow-up to an earlier post: Scalable Genomes Clustering With ADAM and Spark and attempts to replicate the results of that post. However, we will use a different machine learning technique. Where the original post used k-means clustering, we will use deep learning.

We will use ADAM and Apache Spark in combination with H2O, an open source predictive analytics platform, and Sparking Water, which integrates H2O with Spark.

Code

In this section, we’ll dive straight into the code. If you’d rather get something working before looking at the code you can skip to the “Building and Running” section.

The complete Scala code for this example can be found in the PopStrat.scala class on GitHub and we’ll refer to sections of the code here. Basic familiarity with Scala and Apache Spark is assumed.

Setting-up

The first thing we need to do is to read the names of the Genotype and Panel files that are passed into our program. The Genotype file contains data about a set of individuals (referred to here as “samples”) and their genetic variation. The Panel file lists the population group (or “region”) for each sample in the Genotype file; this is what we will try to predict.

1
2
val genotypeFile = args(0)
val panelFile = args(1)

Next, we set-up our Spark Context. Our program permits the Spark master to be specified as one of its arguments. This is useful when running from an IDE, but is omitted when running from the spark-submit script (see below).

1
2
3
4
val master = if (args.length > 2) Some(args(2)) else None
val conf = new SparkConf().setAppName("PopStrat")
master.foreach(conf.setMaster)
val sc = new SparkContext(conf)

Next, we declare a set called populations which contains all of the population groups that we’re interested in predicting. We then read the Panel file into a Map, filtering it based on the population groups in the populations set. The format of the panel file is described here. Luckily it’s very simple, containing the sample ID in the first column and the population group in the second.

1
2
3
4
5
6
7
8
val populations = Set("GBR", "ASW", "CHB")
def extract(file: String, filter: (String, String) => Boolean): Map[String,String] = {
  Source.fromFile(file).getLines().map(line => {
    val tokens = line.split("\t").toList
    tokens(0) -> tokens(1)
  }).toMap.filter(tuple => filter(tuple._1, tuple._2))
}
val panel: Map[String,String] = extract(panelFile, (sampleID: String, pop: String) => populations.contains(pop))

Preparing the Genomics Data

Next, we use ADAM to read our genotype data into a Spark RDD. Since we’ve imported ADAMContext._ at the top of our class, this is simply a matter of calling loadGenotypes on the Spark Context. Then, we filter the genotype data to contain only samples that are in the population groups which we’re interested in.

1
2
val allGenotypes: RDD[Genotype] = sc.loadGenotypes(genotypeFile)
val genotypes: RDD[Genotype] = allGenotypes.filter(genotype => {panel.contains(genotype.getSampleId)})

Next, we convert the ADAM Genotype objects into our own SampleVariant objects. These objects contain just the data we need for further processing: the sample ID (which uniquely identifies a particular sample), a variant ID (which uniquely identifies a particular genetic variant) and a count of alternate alleles, where the sample differs from the reference genome. These variations will help us to classify individuals according to their population group.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
case class SampleVariant(sampleId: String, variantId: Int, alternateCount: Int)
def variantId(genotype: Genotype): String = {
  val name = genotype.getVariant.getContig.getContigName
  val start = genotype.getVariant.getStart
  val end = genotype.getVariant.getEnd
  s"$name:$start:$end"
}
def alternateCount(genotype: Genotype): Int = {
  genotype.getAlleles.asScala.count(_ != GenotypeAllele.Ref)
}
def toVariant(genotype: Genotype): SampleVariant = {
  // Intern sample IDs as they will be repeated a lot
  new SampleVariant(genotype.getSampleId.intern(), variantId(genotype).hashCode(), alternateCount(genotype))
}
val variantsRDD: RDD[SampleVariant] = genotypes.map(toVariant)

Next, we count the total number of samples (individuals) in the data. We then group the data by variant ID and filter out those variants which do not appear in all of the samples. The aim of this is to simplify the processing of the data and, since we have a very large number of variants in the data (up to 30 million, depending on the exact data set), filtering out a small number will not make a significant difference to the results. In fact, in the next step we’ll reduce the number of variants even further.

1
2
3
4
5
6
val variantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsRDD.groupBy(_.sampleId)
val sampleCount: Long = variantsBySampleId.count()
println("Found " + sampleCount + " samples")
val variantsByVariantId: RDD[(Int, Iterable[SampleVariant])] = variantsRDD.groupBy(_.variantId).filter {
  case (_, sampleVariants) => sampleVariants.size == sampleCount
}

When we train our machine learning model, each variant will be treated as a “feature)” that is used to train the model. Since it can be difficult to train machine learning models with very large numbers of features in the data (particularly if the number of samples is relatively small), we first need to try and reduce the number of variants in the data.

To do this, we first compute the frequency with which alternate alleles have occurred for each variant. We then filter the variants down to just those that appear within a certain frequency range. In this case, we’ve chosen a fairly arbitrary frequency of 11. This was chosen through experimentation as a value that leaves around 3,000 variants in the data set we are using.

There are more structured approaches to dimensionality reduction, which we perhaps could have employed, but this technique seems to work well enough for this example.

1
2
3
4
5
6
7
8
9
10
val variantFrequencies: collection.Map[Int, Int] = variantsByVariantId.map {
  case (variantId, sampleVariants) => (variantId, sampleVariants.count(_.alternateCount > 0))
}.collectAsMap()
val permittedRange = inclusive(11, 11)
val filteredVariantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsBySampleId.map {
  case (sampleId, sampleVariants) =>
    val filteredSampleVariants = sampleVariants.filter(variant => permittedRange.contains(
      variantFrequencies.getOrElse(variant.variantId, -1)))
    (sampleId, filteredSampleVariants)
}

Creating the Training Data

To train our model, we need our data to be in tabular form where each row represents a single sample, and each column represents a specific variant. The table also contains a column for the population group or “Region”, which is what we are trying to predict.

Ultimately, in order for our data to be consumed by H2O we need it to end up in an H2O DataFrame object. Currently, the best way to do this in Spark seems to be to convert our data to an RDD of Spark SQL Row objects, and then this can automatically be converted to an H2O DataFrame.

To achieve this, we first need to group the data by sample ID, and then sort the variants for each sample in a consistent manner (by variant ID). We can then create a header row for our table, containing the Region column, the sample ID and all of the variants. We then create an RDD of type Row for each sample.

1
2
3
4
5
6
7
8
9
10
11
12
val sortedVariantsBySampleId: RDD[(String, Array[SampleVariant])] = filteredVariantsBySampleId.map {
  case (sampleId, variants) =>
    (sampleId, variants.toArray.sortBy(_.variantId))
}
val header = StructType(Array(StructField("Region", StringType)) ++
  sortedVariantsBySampleId.first()._2.map(variant => {StructField(variant.variantId.toString, IntegerType)}))
val rowRDD: RDD[Row] = sortedVariantsBySampleId.map {
  case (sampleId, sortedVariants) =>
    val region: Array[String] = Array(panel.getOrElse(sampleId, "Unknown"))
    val alternateCounts: Array[Int] = sortedVariants.map(_.alternateCount)
    Row.fromSeq(region ++ alternateCounts)
}

As mentioned above, once we have our RDD of Row objects we can then convert these automatically to an H2O DataFrame using Sparking Water (H2O’s Spark integration).

1
2
3
4
5
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val schemaRDD = sqlContext.applySchema(rowRDD, header)
val h2oContext = new H2OContext(sc).start()
import h2oContext._
val dataFrame = h2oContext.toDataFrame(schemaRDD)

Now that we have a DataFrame, we want to split it into the training data (which we’ll use to train our model), and a test set (which we’ll use to ensure that overfitting has not occurred).

We will also create a “validation” set, which performs a similar purpose to the test set – in that it will be used to validate the strength of our model as it is being built, while avoiding overfitting. However, when training a neural network, we typically keep the validation set distinct from the test set, to enable us to learn hyper-parameters for the model. See chapter 3 of Michael Nielsen’s “Neural Networks and Deep Learning” for more details on this.

H2O comes with a class called FrameSplitter, so splitting the data is simply a matter of calling creating one of those and letting it split the data set.

1
2
3
4
5
val frameSplitter = new FrameSplitter(dataFrame, Array(.5, .3), Array("training", "test", "validation").map(Key.make), null)
water.H2O.submitTask(frameSplitter)
val splits = frameSplitter.getResult
val training = splits(0)
val validation = splits(2)

Training the Model

Next, we need to set the parameters for our deep learning model. We specify the training and validation data sets, as well as the column in the data which contains the item we are trying to predict (in this case, the Region). We also set some hyper-parameters which affect the way the model learns. We won’t go into detail about these here, but you can read more in the H2O documentation. These parameters have been chosen through experimentation – however, H2O provides methods for automatically tuning hyper-parameters so it may be possible to achieve better results by employing one of these methods.

1
2
3
4
5
6
7
val deepLearningParameters = new DeepLearningParameters()
deepLearningParameters._train = training
deepLearningParameters._valid = validation
deepLearningParameters._response_column = "Region"
deepLearningParameters._epochs = 10
deepLearningParameters._activation = Activation.RectifierWithDropout
deepLearningParameters._hidden = Array[Int](100,100)

Finally, we’re ready to train our deep learning model! Now that we’ve set everything up this is easy: we simply create a H2O DeepLearning object and call trainModel on it.

1
2
val deepLearning = new DeepLearning(deepLearningParameters)
val deepLearningModel = deepLearning.trainModel.get

Having trained our model in the previous step, we now need to check how well it predicts the population groups in our data set. To do this we “score” our entire data set (including training, test, and validation data) against our model:

1
deepLearningModel.score(dataFrame)('predict)

This final step will print a confusion matrix which shows how well our model predicts our population groups. All being well, the confusion matrix should look something like this:

1
2
3
4
5
6
Confusion Matrix (vertical: actual; across: predicted):
ASW    CHB GBR  Error      Rate
ASW     60   1   0 0.0164 = 1 /  61
CHB      0 103   0 0.0000 = 0 / 103
GBR      0   1  90 0.0110 = 1 /  91
Totals  60 105  90 0.0078 = 2 / 255

This tells us that the model has correctly predicted 253 out of 255 population groups correctly (an accuracy of more than 99%). Nice!

Building and Running

Prerequisites

Before building and running the example, please ensure you have version 7 or later of the Java JDK installed.

Building

To build the example, first clone the GitHub repo at https://github.com/nfergu/popstrat.

Then download and install Maven. Then, at the command line, type:

1
mvn clean package

This will build a JAR (target/uber-popstrat-0.1-SNAPSHOT.jar), containing the PopStrat class, as well as all of its dependencies.

Running

First, download Spark version 1.2.0 and unpack it on your machine.

Next you’ll need to get some genomics data. Go to your nearest mirror of the 1000 genomes FTP site. From the release/20130502/ directory download the ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz file and the integrated_call_samples_v3.20130502.ALL.panel file. The first file file is the genotype data for chromosome 22, and the second file is the panel file, which describes the population group for each sample in the genotype data.

Unzip the genotype data before continuing. This will require around 10GB of disk space.

To speed up execution and save disk space, you can convert the genotype VCF file to ADAM format (using the ADAM transform command) if you wish. However, this will take some time up-front. Both ADAM and VCF formats are supported.

Next, run the following command:

1
$ YOUR_SPARK_HOME/bin/spark-submit --class "com.neilferguson.PopStrat" --master local[6] --driver-memory 6G target/uber-popstrat-0.1-SNAPSHOT.jar <genotypesfile> <panelfile>

Replacing <genotypesfile> with the path to your genotype data file (ADAM or VCF), and <panelfile> with the panel file from 1000 genomes.

This runs the example using a local (in-process) Spark master with 6 cores and 6GB of RAM. You can run against a different Spark cluster by modifying the options in the above command line. See the Spark documentation for further details.

Using the above data, the example may take up to 2-3 hours to run, depending on hardware. When it is finished, you should see a confusion matrix which shows the predicted versus the actual populations. If all has gone well, this should show an accuracy of more than 99%. See the “Code” section above for more details on what exactly you should expect to see.

Conclusion

In this post, we have shown how to combine ADAM and Apache Spark with H2O’s deep learning capabilities to predict an individual’s population group based on his or her genomic data. Our results demonstrate that we can predict these very well, with more than 99% accuracy. Our choice of technologies makes for a relatively straightforward implementation, and we expect it to be very scalable.

Future work could involve validating the scalability of our solution on more hardware, trying to predict a wider range of population groups (currently we only predict 3 groups), and tuning the deep learning hyper-parameters to achieve even better accuracy.

Comments