Java 8 Lambdas explained #1: Map/Reduce with Fork/Join and the Beauty of a One-Line Lambda

Lambdas are cool. Lambdas are sexy. Lambdas are expressive and let you write less code. Are these just clichés? No! They are so true…

Recently we ran a local session explaining lambdas through side-by-side comparisons of some coding patterns written in Java 7 and in Java 8. We used the best of each version of the language and its APIs, including lambdas, stream processing and the new Date and Time API. It was fun to see people’s reactions to how dramatically code shrinks when moving from Java 7 to Java 8 style.

In this post I want to show one example of the above: how a simple Map/Reduce pattern can be radically simplified when you switch from the Java 7 Fork/Join API to Java 8 parallel stream processing.

Setting up the project

Let’s start by setting up the project: a very simple Maven configuration that explicitly sets Java 8 as the source and target for compilation and adds a dependency on the JavaTuples library:

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>deors.demos</groupId>
  <artifactId>deors.demos.java8</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  <packaging>jar</packaging>
  <name>deors.demos.java8</name>
  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>
  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.1</version>
        <configuration>
          <verbose>true</verbose>
          <compilerVersion>1.8</compilerVersion>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
    </plugins>
  </build>
  <dependencies>
    <dependency>
      <groupId>org.javatuples</groupId>
      <artifactId>javatuples</artifactId>
      <version>1.2</version>
      <scope>compile</scope>
    </dependency>
  </dependencies>
</project>

The Problem

For this example, the problem we want to solve is: given a list of integer tuples (pairs of integers), calculate the sum of the products of each pair. Because we want the solution to scale, it should use parallel computing, so that the JVM running the code can take advantage of multi-core environments by splitting the problem into smaller pieces and processing them faster.
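To pin down exactly what is being computed, here is a minimal sequential sketch with no parallelism at all. It assumes each pair is a plain int[2] array instead of a JavaTuples Pair, and it accumulates into a long (a choice of this sketch, not of the original code) so that long lists of large products cannot overflow an int.

```java
// Sequential reference: sum of a * b over every pair (a, b), no parallelism.
// Each pair is an int[2] here, a stand-in for org.javatuples.Pair<Integer, Integer>.
class SequentialSumProduct {

  // The article's sample data: (10, 1), (12, 2), ..., (36, 14).
  static int[][] samplePairs() {
    int[][] pairs = new int[14][];
    for (int i = 0; i < 14; i++) {
      pairs[i] = new int[] {10 + 2 * i, i + 1};
    }
    return pairs;
  }

  static long sumOfProducts(int[][] pairs) {
    long total = 0; // long accumulator avoids the int overflow large inputs could cause
    for (int[] pair : pairs) {
      total += (long) pair[0] * pair[1];
    }
    return total;
  }

  public static void main(String[] args) {
    System.out.println(sumOfProducts(samplePairs())); // prints 2870
  }
}
```

Any parallel version, whether Fork/Join or streams, should produce exactly the same number as this loop, which makes it a handy baseline for testing.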

In Java 7 fashion, a clever programmer could choose the Fork/Join API and create a task that, given the list of tuples, splits it in half whenever its size exceeds a defined threshold, processes each half recursively and in parallel, and aggregates the results of each piece until the final result is obtained. It is not complex, but it is a bit verbose even with the Fork/Join API. This is the code (no need to type it in; it is available on GitHub at https://github.com/deors/deors.demos.java8):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import org.javatuples.Pair;

public class SumProductCalculationJava7 extends RecursiveTask<Integer> {
  private static final long serialVersionUID = 6939566748704874245L;
  private final int threshold = 10; // fragments of this size or smaller are computed directly
  private final List<Pair<Integer, Integer>> pairList;

  public SumProductCalculationJava7(List<Pair<Integer, Integer>> pairList) {
    super();
    this.pairList = pairList;
  }

  @Override
  protected Integer compute() {
    System.out.printf("fragment %s size %d\n", this, pairList.size());
    if (pairList.size() <= threshold) {
       return computeDirect();
    }
    int split = pairList.size() / 2;

    List<Pair<Integer, Integer>> forkedList1 = pairList.subList(0, split);
    SumProductCalculationJava7 forkedTask1 = new SumProductCalculationJava7(forkedList1);
    forkedTask1.fork();

    List<Pair<Integer, Integer>> forkedList2 = pairList.subList(split, pairList.size());
    SumProductCalculationJava7 forkedTask2 = new SumProductCalculationJava7(forkedList2);
    forkedTask2.fork();

    return forkedTask1.join() + forkedTask2.join();
  }

  private Integer computeDirect() {
    int sumproduct = 0; // primitive accumulator avoids boxing on every iteration
    for (Pair<Integer, Integer> pair : pairList) {
      sumproduct += pair.getValue0() * pair.getValue1();
    }
    System.out.printf("fragment %s total %s\n", this, sumproduct);
    return sumproduct;
  }
}

In short, the compute method, required by the RecursiveTask contract, checks the size of the list. If it is not larger than the threshold, it calculates the sum of products directly, using a private method with a for-each loop. Otherwise, it splits the list in two, forks a subtask for each fragment and processes them recursively. The result for the whole list is then simply the sum of the two fragment results, obtained by joining both subtasks.
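A common refinement of this pattern is sketched below as a variant, not as the author’s code: fork only one half and let the current thread compute the other half itself, instead of forking both. The Fibonacci example in the RecursiveTask Javadoc uses the same fork, compute, join shape. The sketch also works on index ranges of one int[][] array (a stand-in for the list of Pair objects) rather than creating subLists, and it runs on the shared ForkJoinPool.commonPool() instead of a new pool.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Variant of SumProductCalculationJava7: splits index ranges of one shared
// array instead of creating subLists, and forks only one half, letting the
// current thread compute the other half itself.
class SumProductRangeTask extends RecursiveTask<Long> {
  private static final long serialVersionUID = 1L;

  private final int[][] pairs;   // each element is {a, b}
  private final int from;        // first index, inclusive
  private final int to;          // last index, exclusive
  private final int threshold;   // largest range computed directly

  SumProductRangeTask(int[][] pairs, int from, int to, int threshold) {
    this.pairs = pairs;
    this.from = from;
    this.to = to;
    this.threshold = threshold;
  }

  @Override
  protected Long compute() {
    if (to - from <= threshold) {
      long total = 0;
      for (int i = from; i < to; i++) {
        total += (long) pairs[i][0] * pairs[i][1];
      }
      return total;
    }
    int mid = from + (to - from) / 2;
    SumProductRangeTask left = new SumProductRangeTask(pairs, from, mid, threshold);
    SumProductRangeTask right = new SumProductRangeTask(pairs, mid, to, threshold);
    left.fork();                        // make the left half available to other workers
    long rightResult = right.compute(); // work on the right half in this thread
    return left.join() + rightResult;   // wait for the left half, then combine
  }

  static long sumOfProducts(int[][] pairs, int threshold) {
    if (threshold < 1) {
      // a threshold of 0 would keep splitting one-element ranges forever
      throw new IllegalArgumentException("threshold must be at least 1");
    }
    return ForkJoinPool.commonPool().invoke(new SumProductRangeTask(pairs, 0, pairs.length, threshold));
  }
}
```

Forking one half and computing the other saves one task submission per split and keeps the current worker busy on useful work while the forked half waits to be picked up.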

It is verbose, yes, but it does a lot of work behind the scenes: the ForkJoinPool manages a set of worker threads and balances the load between them by work stealing, runs the subtasks, waits for each fragment to finish before aggregating the results bottom-up, and returns the final glorious result. Try doing the same with raw threads, Java 2 style, and watch your hair turn grey in the process!
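For contrast, here is a rough sketch of what the "Java 2 style" alternative might look like: raw Thread objects and an anonymous Runnable, a thread count fixed by the caller, one contiguous slice per thread and manual joins. Everything the pool handles for you (reusing threads, splitting work adaptively, balancing load) has to be decided up front here. Pairs are again int[2] arrays, an assumption of the sketch.

```java
// Hand-rolled "Java 2 style" sum-product: no pool, no lambdas. The work is cut
// into a fixed number of slices, one thread per slice, joined by hand.
class ManualThreadSumProduct {

  static long sumOfProducts(final int[][] pairs, int threadCount) {
    if (threadCount < 1) {
      throw new IllegalArgumentException("threadCount must be at least 1");
    }
    final long[] partials = new long[threadCount]; // one result slot per thread
    Thread[] threads = new Thread[threadCount];
    int chunk = (pairs.length + threadCount - 1) / threadCount; // ceiling division
    for (int t = 0; t < threadCount; t++) {
      final int slot = t;
      final int from = Math.min(pairs.length, t * chunk);
      final int to = Math.min(pairs.length, from + chunk);
      threads[t] = new Thread(new Runnable() {
        public void run() {
          long sum = 0;
          for (int i = from; i < to; i++) {
            sum += (long) pairs[i][0] * pairs[i][1];
          }
          partials[slot] = sum;
        }
      });
      threads[t].start();
    }
    long total = 0;
    try {
      for (int t = 0; t < threadCount; t++) {
        threads[t].join(); // join() also makes partials[t] visible to this thread
        total += partials[t];
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt(); // restore the interrupt flag for callers
      throw new IllegalStateException("interrupted while waiting for workers", e);
    }
    return total;
  }
}
```

If one slice happens to be slower than the others, its thread becomes the bottleneck and nothing redistributes the remaining work, which is exactly the problem work stealing addresses.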

However, Java 8 can still beat the Fork/Join version by a large margin.

Testing the Problem

Now let’s run this task from a main method in the same class, or from a JUnit test:

  public static void main(String[] args) {
    List<Pair<Integer, Integer>> thePairList = new ArrayList<>();
    thePairList.add(new Pair<Integer, Integer>(10, 1));
    thePairList.add(new Pair<Integer, Integer>(12, 2));
    thePairList.add(new Pair<Integer, Integer>(14, 3));
    thePairList.add(new Pair<Integer, Integer>(16, 4));
    thePairList.add(new Pair<Integer, Integer>(18, 5));
    thePairList.add(new Pair<Integer, Integer>(20, 6));
    thePairList.add(new Pair<Integer, Integer>(22, 7));
    thePairList.add(new Pair<Integer, Integer>(24, 8));
    thePairList.add(new Pair<Integer, Integer>(26, 9));
    thePairList.add(new Pair<Integer, Integer>(28, 10));
    thePairList.add(new Pair<Integer, Integer>(30, 11));
    thePairList.add(new Pair<Integer, Integer>(32, 12));
    thePairList.add(new Pair<Integer, Integer>(34, 13));
    thePairList.add(new Pair<Integer, Integer>(36, 14));

    SumProductCalculationJava7 theTask = new SumProductCalculationJava7(thePairList);
    ForkJoinPool thePool = new ForkJoinPool();
    Integer result = thePool.invoke(theTask);
    System.out.printf("the final result is %s\n", result);
  }

Well done! As expected, the final result is 2870!

Now try different lists of tuples, or copy and paste lines to build a very long list in a few seconds. Change the threshold, measure execution times and fine-tune the program until it is as good as it can be.
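Instead of copying and pasting, a long list can also be generated in code. The following harness is a rough sketch, not a rigorous benchmark: it builds a million hypothetical pairs of the form (i % 10, 3), so the expected total is easy to check (13,500,000), and times the same divide-and-conquer task with several thresholds. Here the threshold is a constructor parameter instead of a field fixed at 10, and pairs are int[2] arrays.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Rough threshold experiment: generated input, one timed run per threshold.
class ThresholdExperiment {

  // Same divide-and-conquer as SumProductCalculationJava7, but the threshold
  // is a parameter and the sum is accumulated in a long.
  static final class Task extends RecursiveTask<Long> {
    private static final long serialVersionUID = 1L;
    private final List<int[]> pairs;
    private final int threshold;

    Task(List<int[]> pairs, int threshold) {
      this.pairs = pairs;
      this.threshold = threshold;
    }

    @Override
    protected Long compute() {
      if (pairs.size() <= threshold) {
        long total = 0;
        for (int[] p : pairs) {
          total += (long) p[0] * p[1];
        }
        return total;
      }
      int split = pairs.size() / 2;
      Task left = new Task(pairs.subList(0, split), threshold);
      Task right = new Task(pairs.subList(split, pairs.size()), threshold);
      left.fork();
      return right.compute() + left.join();
    }
  }

  // n pairs of the form (i % 10, 3): each block of 10 pairs contributes 3 * 45.
  static List<int[]> buildPairs(int n) {
    List<int[]> pairs = new ArrayList<>(n);
    for (int i = 0; i < n; i++) {
      pairs.add(new int[] {i % 10, 3});
    }
    return pairs;
  }

  static long run(List<int[]> pairs, int threshold) {
    if (threshold < 1) {
      throw new IllegalArgumentException("threshold must be at least 1");
    }
    return ForkJoinPool.commonPool().invoke(new Task(pairs, threshold));
  }

  public static void main(String[] args) {
    List<int[]> pairs = buildPairs(1_000_000);
    for (int threshold : new int[] {10, 100, 1_000, 10_000, 100_000}) {
      long start = System.nanoTime();
      long result = run(pairs, threshold);
      long micros = (System.nanoTime() - start) / 1_000;
      System.out.printf("threshold %,d: result %d in %,d us%n", threshold, result, micros);
    }
  }
}
```

Treat the timings with care: early runs include JIT warm-up and each threshold is measured only once, so for serious numbers a dedicated harness such as JMH is a better fit.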

You still cannot be sure it will perform as well on a different JVM, though. With different CPU and memory resources, performance may change, and the best threshold is likely to differ depending on the JVM, the machine’s resources, the OS, the workload and so on. Still, you can deliver this with confidence and fine-tune it for production later.
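One way to reduce that machine dependence, sketched here as a hypothetical helper, is to derive the threshold from the list size and the pool’s parallelism instead of hard-coding 10. The factor of about four leaf tasks per worker thread is an assumption to tune, not a rule.

```java
import java.util.concurrent.ForkJoinPool;

// Hypothetical helper: derives the Fork/Join threshold from the input size and
// the pool's parallelism, so the number of leaf tasks scales with the workers.
class AdaptiveThreshold {

  // Assumption: roughly four leaf tasks per worker leaves slack for work
  // stealing to even out uneven progress. Tune as needed.
  static final int LEAVES_PER_WORKER = 4;

  static int suggestThreshold(int size, int parallelism) {
    long leaves = (long) Math.max(1, parallelism) * LEAVES_PER_WORKER;
    return (int) Math.max(1, size / leaves); // never below 1, or splitting never stops
  }

  public static void main(String[] args) {
    int parallelism = ForkJoinPool.commonPool().getParallelism();
    System.out.println("parallelism " + parallelism
        + ", suggested threshold " + suggestThreshold(1_000_000, parallelism));
  }
}
```

For the article’s 14-pair list this always yields 1 on a multi-core machine, a reminder that for tiny inputs the splitting overhead dominates whatever threshold you pick.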

Same Problem. Different Approach

Now let’s do the same calculation with Java 8 parallel streams and lambda expressions. Don’t expect anything fancy: it really is as simple as it looks. In fact, it is so simple that I coded everything in the main method:

import java.util.ArrayList;
import java.util.List;
import org.javatuples.Pair;

public class SumProductCalculationJava8 {
  public static void main(String[] args) {
    List<Pair<Integer, Integer>> thePairList = new ArrayList<>();
    thePairList.add(new Pair<Integer, Integer>(10, 1));
    thePairList.add(new Pair<Integer, Integer>(12, 2));
    thePairList.add(new Pair<Integer, Integer>(14, 3));
    thePairList.add(new Pair<Integer, Integer>(16, 4));
    thePairList.add(new Pair<Integer, Integer>(18, 5));
    thePairList.add(new Pair<Integer, Integer>(20, 6));
    thePairList.add(new Pair<Integer, Integer>(22, 7));
    thePairList.add(new Pair<Integer, Integer>(24, 8));
    thePairList.add(new Pair<Integer, Integer>(26, 9));
    thePairList.add(new Pair<Integer, Integer>(28, 10));
    thePairList.add(new Pair<Integer, Integer>(30, 11));
    thePairList.add(new Pair<Integer, Integer>(32, 12));
    thePairList.add(new Pair<Integer, Integer>(34, 13));
    thePairList.add(new Pair<Integer, Integer>(36, 14));

    Integer result = thePairList.parallelStream()
      .mapToInt(p -> p.getValue0() * p.getValue1()) // map: product of each pair
      .sum();                                       // reduce: add up all products
    System.out.printf("the final result is %s\n", result);
  }
}

Just one statement, built around a one-line lambda, solves the same problem. Same result, and the same Fork/Join machinery behind the scenes (parallel streams run on the common ForkJoinPool by default), but simple, expressive and productive!
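For readers without JavaTuples on the classpath, here is a self-contained sketch of the same pipeline. IntPair is a hypothetical stand-in for Pair<Integer, Integer> with the same getValue0() and getValue1() accessors. The sketch also swaps mapToInt for mapToLong, a deliberate change: IntStream.sum() returns an int and overflows silently, which starts to matter once the list gets long. A sequential version is included as a reference result.

```java
import java.util.ArrayList;
import java.util.List;

// Self-contained version of the Java 8 pipeline, without JavaTuples.
class StreamSumProduct {

  // Hypothetical stand-in for org.javatuples.Pair<Integer, Integer>.
  static final class IntPair {
    private final int value0;
    private final int value1;

    IntPair(int value0, int value1) {
      this.value0 = value0;
      this.value1 = value1;
    }

    int getValue0() { return value0; }
    int getValue1() { return value1; }
  }

  // The article's sample data: (10, 1), (12, 2), ..., (36, 14).
  static List<IntPair> samplePairs() {
    List<IntPair> pairs = new ArrayList<>();
    for (int i = 0; i < 14; i++) {
      pairs.add(new IntPair(10 + 2 * i, i + 1));
    }
    return pairs;
  }

  // mapToLong widens each product before summing; mapToInt could silently
  // overflow on long lists of large values.
  static long parallelSum(List<IntPair> pairs) {
    return pairs.parallelStream()
        .mapToLong(p -> (long) p.getValue0() * p.getValue1())
        .sum();
  }

  // Same pipeline on a sequential stream, handy as a reference result.
  static long sequentialSum(List<IntPair> pairs) {
    return pairs.stream()
        .mapToLong(p -> (long) p.getValue0() * p.getValue1())
        .sum();
  }

  public static void main(String[] args) {
    List<IntPair> pairs = samplePairs();
    System.out.println(parallelSum(pairs));   // prints 2870
    System.out.println(sequentialSum(pairs)); // prints 2870
  }
}
```

Switching between parallel and sequential execution is a one-word change (parallelStream() versus stream()), which makes it cheap to measure whether parallelism pays off for a given input size.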

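Moreover, you no longer tune a threshold by hand. The stream library decides how to split the list, and the common ForkJoinPool that runs the pieces sizes itself by default from the number of available processors, so the same code adapts to the machine it runs on. That is not a promise of the fastest possible execution in every condition, though: for a list as short as this one, the cost of going parallel likely outweighs the gain, and a plain sequential stream() may well be quicker.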
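To see some of that machinery at work, this sketch prints the number of available processors, the common pool’s target parallelism and the names of the threads that actually process a parallel stream. The pool size can be overridden at startup with the documented system property java.util.concurrent.ForkJoinPool.common.parallelism (for example -Djava.util.concurrent.ForkJoinPool.common.parallelism=2 on the java command line). The exact thread names, and how many of them show up, vary between runs and machines; typically the calling thread takes part alongside common pool workers.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

// Peeks at the machinery behind parallel streams: the common pool's size and
// the names of the threads that run the elements of a parallel stream.
class CommonPoolInspection {

  static Set<String> threadsUsed(int size) {
    Set<String> names = ConcurrentHashMap.newKeySet(); // thread-safe set
    IntStream.range(0, size).parallel()
        .forEach(i -> names.add(Thread.currentThread().getName()));
    return names;
  }

  public static void main(String[] args) {
    System.out.println("available processors:    " + Runtime.getRuntime().availableProcessors());
    System.out.println("common pool parallelism: " + ForkJoinPool.getCommonPoolParallelism());
    System.out.println("threads used:            " + threadsUsed(1_000_000));
  }
}
```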

This is just one example of how Java 8 lets you write better code, so don’t wait for others to ask you to do it: start adopting Java 8 today!

Author: deors

senior technology architect at Accenture, with a passion for technology-related stuff, Celtic music and the best sci-fi, among a thousand other things!
