19 Java 8 Stream并行计算原理

19.1 Java 8 Stream简介

从Java 8 开始,我们可以使用Stream接口以及lambda表达式进行“流式计算”。它可以让我们对集合的操作更加简洁、更加可读、更加高效。

Stream接口有非常多用于集合计算的方法,比如判空操作empty、过滤操作filter、求最max值、查找操作findFirst和findAny等等。

19.2 Stream单线程串行计算

Stream接口默认是使用串行的方式,也就是说在一个线程里执行。下面举一个例子:

public class StreamDemo {
public static void main(String[] args) {
Stream.of(1, 2, 3, 4, 5, 6, 7, 8, 9)
.reduce((a, b) -> {
System.out.println(String.format("%s: %d + %d = %d",
Thread.currentThread().getName(), a, b, a + b));
return a + b;
})
.ifPresent(System.out::println);
}
}

我们来理解一下这个方法。首先我们用整数1~9创建了一个Stream。这里的Stream.of(T... values)方法是Stream接口的一个静态方法,其底层调用的是Arrays.stream(T[] array)方法。

然后我们使用了reduce方法来计算这个集合的累加和。reduce方法这里做的是:从前两个元素开始,进行某种操作(我这里进行的是加法操作)后,返回一个结果,然后再拿这个结果跟第三个元素执行同样的操作,以此类推,直到最后的一个元素。

我们来打印一下当前这个reduce操作的线程以及它们被操作的元素和返回的结果以及最后所有reduce方法的结果,也就代表的是数字1到9的累加和。

main: 1 + 2 = 3 main: 3 + 3 = 6 main: 6 + 4 = 10 main: 10 + 5 = 15 main: 15 + 6 = 21 main: 21 + 7 = 28 main: 28 + 8 = 36 main: 36 + 9 = 45 45

可以看到,默认情况下,它是在一个单线程运行的,也就是main线程。然后每次reduce操作都是串行起来的,首先计算前两个数字的和,然后再往后依次计算。

19.3 Stream多线程并行计算

我们思考上面一个例子,是不是一定要在单线程里进行串行地计算呢?假如我的计算机是一个多核计算机,我们在理论上能否利用多核来进行并行计算,提高计算效率呢?

当然可以,比如我们在计算前两个元素1 + 2 = 3的时候,其实我们也可以同时在另一个核计算 3 + 4 = 7。然后等它们都计算完成之后,再计算 3 + 7 = 10的操作。

是不是很熟悉这样的操作手法?没错,它就是ForkJoin框架的思想。下面小小地修改一下上面的代码,增加一行代码,使Stream使用多线程来并行计算:

public class StreamParallelDemo {
public static void main(String[] args) {
Stream.of(1, 2, 3, 4, 5, 6, 7, 8, 9)
.parallel()
.reduce((a, b) -> {
System.out.println(String.format("%s: %d + %d = %d",
Thread.currentThread().getName(), a, b, a + b));
return a + b;
})
.ifPresent(System.out::println);
}
}

可以看到,与上一个案例的代码只有一点点区别,就是在reduce方法被调用之前,调用了parallel()方法。下面来看看这个方法的输出:

ForkJoinPool.commonPool-worker-1: 3 + 4 = 7 ForkJoinPool.commonPool-worker-4: 8 + 9 = 17 ForkJoinPool.commonPool-worker-2: 5 + 6 = 11 ForkJoinPool.commonPool-worker-3: 1 + 2 = 3 ForkJoinPool.commonPool-worker-4: 7 + 17 = 24 ForkJoinPool.commonPool-worker-4: 11 + 24 = 35 ForkJoinPool.commonPool-worker-3: 3 + 7 = 10 ForkJoinPool.commonPool-worker-3: 10 + 35 = 45 45

可以很明显地看到,它使用的线程是ForkJoinPool里面的commonPool里面的worker线程。并且它们是并行计算的,并不是串行计算的。但由于Fork/Join框架的作用,它最终能很好的协调计算结果,使得计算结果完全正确。

如果我们用Fork/Join代码去实现这样一个功能,那无疑是非常复杂的。但Java8提供了并行式的流式计算,大大简化了我们的代码量,使得我们只需要写很少很简单的代码就可以利用计算机底层的多核资源。

19.4 从源码看Stream并行计算原理

上面我们通过在控制台输出线程的名字,看到了Stream的并行计算底层其实是使用的Fork/Join框架。那它到底是在哪使用Fork/Join的呢?我们从源码上来解析一下上述案例。

Stream.of方法就不说了,它只是生成一个简单的Stream。先来看看parallel()方法的源码。这里由于我的数据是int类型的,所以它其实是使用的BaseStream接口的parallel()方法。而BaseStream接口的JDK唯一实现类是一个叫AbstractPipeline的类。下面我们来看看这个类的parallel()方法的代码:

public final S parallel() {
sourceStage.parallel = true;
return (S) this;
}

这个方法很简单,就是把一个标识sourceStage.parallel设置为true。然后返回实例本身。

接着我们再来看reduce这个方法的内部实现。

Stream.reduce()方法的具体实现是交给了ReferencePipeline这个抽象类,它是继承了AbstractPipeline这个类的:

// ReferencePipeline抽象类的reduce方法
@Override
public final Optional<P_OUT> reduce(BinaryOperator<P_OUT> accumulator) {
// 调用evaluate方法
return evaluate(ReduceOps.makeRef(accumulator));
}
final <R> R evaluate(TerminalOp<E_OUT, R> terminalOp) {
assert getOutputShape() == terminalOp.inputShape();
if (linkedOrConsumed)
throw new IllegalStateException(MSG_STREAM_LINKED);
linkedOrConsumed = true;
return isParallel() // 调用isParallel()判断是否使用并行模式
? terminalOp.evaluateParallel(this, sourceSpliterator(terminalOp.getOpFlags()))
: terminalOp.evaluateSequential(this, sourceSpliterator(terminalOp.getOpFlags()));
}
@Override
public final boolean isParallel() {
// 根据之前在parallel()方法设置的那个flag来判断。
return sourceStage.parallel;
}

从它的源码可以知道,reduce方法调用了evaluate方法,而evaluate方法会先去检查当前的flag,是否使用并行模式,如果是则会调用evaluateParallel方法执行并行计算,否则,会调用evaluateSequential方法执行串行计算。

这里我们再看看TerminalOp(注意这里是字母l O,而不是数字1 0)接口的evaluateParallel方法。TerminalOp接口的实现类有这样几个内部类:

  • java.util.stream.FindOps.FindOp

  • java.util.stream.ForEachOps.ForEachOp

  • java.util.stream.MatchOps.MatchOp

  • java.util.stream.ReduceOps.ReduceOp

可以看到,对应的是Stream的几种主要的计算操作。我们这里的示例代码使用的是reduce计算,那我们就看看ReduceOp类的这个方法的源码:

// java.util.stream.ReduceOps.ReduceOp.evaluateParallel
@Override
public <P_IN> R evaluateParallel(PipelineHelper<T> helper,
Spliterator<P_IN> spliterator) {
return new ReduceTask<>(this, helper, spliterator).invoke().get();
}

evaluateParallel方法创建了一个新的ReduceTask实例,并且调用了invoke()方法后再调用get()方法,然后返回这个结果。那这个ReduceTask是什么呢?它的invoke方法内部又是什么呢?

追溯源码我们可以发现,ReduceTask类是ReduceOps类的一个内部类,它继承了AbstractTask类,而AbstractTask类又继承了CountedCompleter类,而CountedCompleter类又继承了ForkJoinTask类!

它们的继承关系如下:

ReduceTask -> AbstractTask -> CountedCompleter -> ForkJoinTask

这里的ReduceTask的invoke方法,其实是调用的ForkJoinTask的invoke方法,中间三层继承并没有覆盖这个方法的实现。

所以这就从源码层面解释了Stream并行的底层原理是使用了Fork/Join框架。

19.5 Stream并行计算的性能提升

我们可以在本地测试一下如果在多核情况下,Stream并行计算会给我们的程序带来多大的效率上的提升。用以下示例代码来计算一千万个随机数的和:

public class StreamParallelDemo {
public static void main(String[] args) {
System.out.println(String.format("本计算机的核数:%d", Runtime.getRuntime().availableProcessors()));
// 产生100w个随机数(1 ~ 100),组成列表
Random random = new Random();
List<Integer> list = new ArrayList<>(1000_0000);
for (int i = 0; i < 1000_0000; i++) {
list.add(random.nextInt(100));
}
long prevTime = getCurrentTime();
list.stream().reduce((a, b) -> a + b).ifPresent(System.out::println);
System.out.println(String.format("单线程计算耗时:%d", getCurrentTime() - prevTime));
prevTime = getCurrentTime();
list.stream().parallel().reduce((a, b) -> a + b).ifPresent(System.out::println);
System.out.println(String.format("多线程计算耗时:%d", getCurrentTime() - prevTime));
}
private static long getCurrentTime() {
return System.currentTimeMillis();
}
}

输出:

本计算机的核数:8 495156156 单线程计算耗时:223 495156156 多线程计算耗时:95

所以在多核的情况下,使用Stream的并行计算确实比串行计算能带来很大效率上的提升,并且也能保证结果计算完全准确。

本文一直在强调的“多核”的情况。其实可以看到,我的本地电脑有8核,但并行计算耗时并不是单线程计算耗时除以8,因为线程的创建、销毁以及维护线程上下文的切换等等都有一定的开销。所以如果你的服务器并不是多核服务器,那也没必要用Stream的并行计算。因为在单核的情况下,往往Stream的串行计算比并行计算更快,因为它不需要线程切换的开销。

参考资料