跳至主要內容

Fork/Join使用详解

fangzhipeng约 1401 字大约 5 分钟

简介

Java Fork/Join 是 Java 7 引入的一个框架,用于实现并行计算。它基于 "分而治之" 的思想,使用递归的方式将一个大任务拆分成多个小任务,然后并行地执行这些小任务,最后将结果合并起来得到最终结果。

Fork/Join的运行流程大致如下所示:

image-20240121201736740
image-20240121201736740

核心模块

Java Fork/Join 框架的核心是 ForkJoinPool 类,它是一个特殊的线程池,内部使用工作窃取算法来实现任务的并行执行。

Fork/Join 框架中的主要组件包括:

  1. ForkJoinPool:是一个线程池,用于执行 Fork/Join 任务。它管理和调度任务的执行,并可根据需要创建新的工作线程。

  2. ForkJoinTask:是一个抽象类,表示 Fork/Join 框架中的任务。ForkJoinTask 分为两个子类:

    • RecursiveTask:用于返回结果的任务,继承它并实现 compute() 方法来执行任务并返回结果。

    • RecursiveAction:用于不返回结果的任务,继承它并实现 compute() 方法来执行任务。

ForkJoinPool继承关系

img
img

内部类介绍:

  • ForkJoinWorkerThreadFactory: 内部线程工厂接口,用于创建工作线程ForkJoinWorkerThread
  • DefaultForkJoinWorkerThreadFactory: ForkJoinWorkerThreadFactory 的默认实现类
  • InnocuousForkJoinWorkerThreadFactory: 实现了 ForkJoinWorkerThreadFactory,无许可线程工厂,当系统变量中有系统安全管理相关属性时,默认使用这个工厂创建工作线程。
  • EmptyTask: 内部占位类,用于替换队列中 join 的任务。
  • ManagedBlocker: 为 ForkJoinPool 中的任务提供扩展管理并行数的接口,一般用在可能会阻塞的任务(如在 Phaser 中用于等待 phase 到下一个generation)。
  • WorkQueue: ForkJoinPool 的核心数据结构,本质上是work-stealing 模式的双端任务队列,内部存放 ForkJoinTask 对象任务,使用 @Contented 注解修饰防止伪共享。
    • 工作线程在运行中产生新的任务(通常是因为调用了 fork())时,此时可以把 WorkQueue 的数据结构视为一个栈,新的任务会放入栈顶(top 位);工作线程在处理自己工作队列的任务时,按照 LIFO 的顺序。
    • 工作线程在处理自己的工作队列同时,会尝试窃取一个任务(可能是来自于刚刚提交到 pool 的任务,或是来自于其他工作线程的队列任务),此时可以把 WorkQueue 的数据结构视为一个 FIFO 的队列,窃取的任务位于其他线程的工作队列的队首(base位)。
  • 伪共享状态: 缓存系统中是以缓存行(cache line)为单位存储的。缓存行是2的整数幂个连续字节,一般为32-256个字节。最常见的缓存行大小是64个字节。当多线程修改互相独立的变量时,如果这些变量共享同一个缓存行,就会无意中影响彼此的性能,这就是伪共享。

ForkJoinTask继承关系

img
img

ForkJoinTask 实现了 Future 接口,说明它也是一个可取消的异步运算任务,实际上ForkJoinTask 是 Future 的轻量级实现,主要用在纯粹是计算的函数式任务或者操作完全独立的对象计算任务。fork 是主运行方法,用于异步执行;而 join 方法在任务结果计算完毕之后才会运行,用来合并或返回计算结果。 其内部类都比较简单,ExceptionNode 是用于存储任务执行期间的异常信息的单向链表;其余四个类是为 Runnable/Callable 任务提供的适配器类,用于把 Runnable/Callable 转化为 ForkJoinTask 类型的任务(因为 ForkJoinPool 只可以运行 ForkJoinTask 类型的任务)。

使用示例

以下是一个简单的使用 Fork/Join 框架的示例,假设我们要计算一个较大数组中所有元素的总和:

package io.github.forezp.concurrentlab.threadpool;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ForkJoinExample {

   static class SumTask extends RecursiveTask<Long> {
        private static final int THRESHOLD = 1000; // 设置阈值,小于该阈值的任务将不再细分,直接计算结果
        private int[] array;
        private int start;
        private int end;

        public SumTask(int[] array, int start, int end) {
            this.array = array;
            this.start = start;
            this.end = end;
        }

        @Override
        protected Long compute() {
            if (end - start <= THRESHOLD) { // 如果任务足够小,直接计算结果
                long sum = 0;
                for (int i = start; i < end; i++) {
                    sum += array[i];
                }
                return sum;
            } else { // 否则细分为更小的子任务
                int mid = (start + end) >>> 1;
                SumTask left = new SumTask(array, start, mid);
                SumTask right = new SumTask(array, mid, end);
                left.fork(); // 异步执行左边的子任务
                long rightResult = right.compute(); // 同步执行右边的子任务
                long leftResult = left.join(); // 获取左边子任务的结果
                return leftResult + rightResult;
            }
        }
    }

    public static void main(String[] args) {
        int[] array = {1,2,3,4,5,6,7,8}; // 假设有一个很大的数组

        ForkJoinPool forkJoinPool = new ForkJoinPool();
        long result = forkJoinPool.invoke(new SumTask(array, 0, array.length)); // 同步执行任务,并获取结果
        System.out.println("Sum: " + result);
    }
}

在上面的示例中,我们定义了一个继承 RecursiveTaskSumTask 类来表示计算数组元素总和的任务。在 compute() 方法中,我们检查了任务的大小是否小于阈值,如果小于阈值,则直接计算结果;否则将任务拆分为更小的子任务,并使用 fork()join() 方法实现子任务的并行执行和结果的合并。

具体将任务拆成了四个子任务:

  • 1+2
  • 3+4
  • 5+6
  • 7+8
image-20240121211859054
image-20240121211859054

最终等待子任务执行完成,合并结果。

参考

https://pdai.tech/md/java/thread/java-thread-x-juc-executor-ForkJoinPool.htmlopen in new window

方志朋_官方公众号
方志朋_官方公众号