上下文对象跨线程传递方案

郭武

 最近在做个日志链路追踪的工具,遇到了上下文对象线程间传输的问题,在此分享下解决方案。 首先要说的是ThreadLocal类,其提供了在线程范围内提供了数据共享的能力。具体实现原理直接参考其源码就行,不难理解。看下面的例子:

 public static void main(String[] args) {
        ThreadLocal<String> local = new ThreadLocal<>();
        local.set("a");
        Thread t = new Thread(() -> {
            System.out.println(Thread.currentThread() + ":" + local.get());
            local.set("b");
            System.out.println(Thread.currentThread() + ":" + local.get());
        });
        t.start();
        try {
            t.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(Thread.currentThread()+ ":" + local.get());
    }

输出结果如下: 由输出结果可知:
ThreadLocal中的数据只能在当前线程内获取的,子线程Thread-0不能获取到父线程main的ThreadLocal内的数据,其修改也只是在子线程内部生效。
如果子线程想获取到父线程所具有的值,怎么办呢?这时候要请出 InheritableThreadLocal这个类了,该类扩展了 ThreadLocal,为子线程提供从父线程那里继承的值:在创建子线程时,子线程会接收所有可继承的线程局部变量的初始值,以获得父线程所具有的值。示例如下:

 public static void main(String[] args) {
        InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
        inheritableThreadLocal.set("a");

        Thread t = new Thread(() -> {
            System.out.println(Thread.currentThread() + ":" + inheritableThreadLocal.get());
            inheritableThreadLocal.set("b");
            System.out.println(Thread.currentThread() + ":" + inheritableThreadLocal.get());
        });

        t.start();
        try {
            t.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(Thread.currentThread()+ ":" + inheritableThreadLocal.get());
    }

输出如下: 由第一行输出日志可以看出,子线程已经获得了父线程那里继承到的值。 这样是不是就解决了多线程间对象的传递问题了呢?NO.

 对于使用线程池等会池化复用线程的组件的情况,线程由线程池创建好,并且线程是池化起来反复使用的;这时父子线程关系的ThreadLocal值传递已经没有意义,应用需要的实际上是把任务提交给线程池时的ThreadLocal值传递到任务执行时。 我们看下面的例子:

public static void main(String[] args) {  
        ExecutorService executorService = Executors.newFixedThreadPool(1);

        InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

        for (int i = 0; i < 10; i++) {
            inheritableThreadLocal.set("" + i);
            executorService.submit(() -> System.out.println(Thread.currentThread() + ":" + inheritableThreadLocal.get()));
        }
        try {
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

输出结果如下: 并不是我们想象的0~9的输出。怎么解决呢?我们引入阿里开源的TransmittableThreadLocal,TransmittableThreadLocal类继承并加强InheritableThreadLocal类,完美解决了上述的问题。修改如下:

public static void main(String[] args) {  
        ExecutorService executorService = Executors.newFixedThreadPool(1);

        TransmittableThreadLocal<String> transmittableThreadLocal = new TransmittableThreadLocal<>();

        for (int i = 0; i < 10; i++) {
            transmittableThreadLocal.set("" + i);
            executorService.submit(TtlRunnable.get(() -> System.out.println(Thread.currentThread() + ":" + transmittableThreadLocal.get())));
        }

        try {
            Thread.sleep(5000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

输出: 总结:
 ThreadLocal设计之初就是为了绑定当前线程,如果希望当前线程的ThreadLocal能够被子线程使用,需要使用InheritableThreadLocal,但是大多数系统又会使用线程池技术,这时InheritableThreadLocal就不能解决了,需要引入TransmittableThreadLocal来处理。

参考资料: https://github.com/alibaba/transmittable-thread-local#-user-guide