Java AIO 源码解析

本篇为 Java 原生 AIO 编程的源码浅析。

Demo

@Slf4j
public class AioTimeServer {
    public static void main(String[] args) throws Exception {
        AsynchronousChannelGroup group =
                AsynchronousChannelGroup.withFixedThreadPool(4, r -> new Thread(r, "I/O Thread"));
        AsynchronousServerSocketChannel server = AsynchronousServerSocketChannel.open(group);
        server.bind(new InetSocketAddress(8081));
        log.debug("started.");
        server.accept(null, new CompletionHandler<AsynchronousSocketChannel, Void>() {
            @Override
            public void completed(AsynchronousSocketChannel channel, Void att) {
                server.accept(null, this);
                log.debug("connected");
                ByteBuffer recvBuf = ByteBuffer.wrap(new byte[4]);
                channel.read(recvBuf, null, new CompletionHandler<Integer, Object>() {
                    @Override
                    public void completed(Integer result, Object attachment) {
                        if (result == -1) {
                            try {
                                channel.close();
                            } catch (Throwable e) {
                                log.error("close failed.", e);
                            }
                            log.debug("channel closed.");
                            return;
                        }
                        if (result < 4) {
                            channel.read(recvBuf, null, this);
                            return;
                        }
                        byte[] content = recvBuf.array();
                        log.debug("request <== {}", new String(content, StandardCharsets.UTF_8));
                        // 异步业务处理
                        // pool.submit(new BizTask(Arrays.copyOf(content, content.length)));
                        recvBuf.clear();
                        channel.read(recvBuf, null, this);

                        String resp = String.valueOf(System.currentTimeMillis());
                        ByteBuffer sendBuf = ByteBuffer.wrap(resp.getBytes(StandardCharsets.UTF_8));
                        channel.write(sendBuf, null, new CompletionHandler<Integer, Object>() {
                            @Override
                            public void completed(Integer result, Object attachment) {
                                if (result == -1) {
                                    try {
                                        channel.close();
                                    } catch (Throwable e) {
                                        log.error("close failed.", e);
                                    }
                                    log.debug("channel closed.");
                                    return;
                                }
                                if (result < 13) {
                                    channel.write(sendBuf, null, this);
                                    return;
                                }
                                log.debug("response ==> {}", new String(sendBuf.array(), StandardCharsets.UTF_8));
                                sendBuf.clear();
                            }

                            @Override
                            public void failed(Throwable exc, Object attachment) {
                                log.error("write error.", exc);
                            }
                        });
                    }

                    @Override
                    public void failed(Throwable exc, Object attachment) {
                        log.error("read error.", exc);
                    }
                });
            }

            @Override
            public void failed(Throwable exc, Void att) {
                log.error("accept error.", exc);
            }
        });
    }
}

源码分析

创建 AsynchronousChannelGroup

image-20191207104847766

AsynchronousChannelGroup 是一组互相可以共享资源的异步 Channel,这个“资源”正是线程资源,也即一系列的线程池。AsynchronousChannelGroup 有三个静态工厂方法 withFixedThreadPool、withCachedThreadPool、withThreadPool,提供三种初始化内部线程池组的方式,逻辑均为先拿到 AsynchronousChannelProvider,再调用重载的两个 openAsynchronousChannelGroup 方法。

获取 AsynchronousChannelProvider

与 NIO 的 Selector 类似,AsynchronousChannelProvider.provider() 方法所获取到的 Provider 的加载也是依次通过 查找系统属性SPI 机制使用默认 的优先级进行的:

AsynchronousChannelProvider p;
// 相关的系统属性为:-Djava.nio.channels.spi.AsynchronousChannelProvider
p = loadProviderFromProperty();
if (p != null)
    return p;
p = loadProviderAsService();
if (p != null)
    return p;
return sun.nio.ch.DefaultAsynchronousChannelProvider.create();

不特别指定的话,会使用平台相关实现 sun.nio.ch.DefaultAsynchronousChannelProvider 来创建,Linux 下的逻辑为:

String osname = AccessController.doPrivileged(new GetPropertyAction("os.name"));
if (osname.equals("SunOS"))
    return createProvider("sun.nio.ch.SolarisAsynchronousChannelProvider");
if (osname.equals("Linux"))
    return createProvider("sun.nio.ch.LinuxAsynchronousChannelProvider");
if (osname.contains("OS X"))
    return createProvider("sun.nio.ch.BsdAsynchronousChannelProvider");
if (osname.equals("AIX"))
    return createProvider("sun.nio.ch.AixAsynchronousChannelProvider");
throw new InternalError("platform not recognized");

Windows 下则使用 WindowsAsynchronousChannelProvider。

初始化 AsynchronousChannelGroup

我们查看 Linux 系统的 JRE 实现 LinuxAsynchronousChannelProvider:

@Override
public AsynchronousChannelGroup openAsynchronousChannelGroup(int nThreads, ThreadFactory factory) throws IOException {
    return new EPollPort(this, ThreadPool.create(nThreads, factory)).start();
}

@Override
public AsynchronousChannelGroup openAsynchronousChannelGroup(ExecutorService executor, int initialSize) throws IOException {
    return new EPollPort(this, ThreadPool.wrap(executor, initialSize)).start();
}

ThreadPool 是一个线程池与 Channel 组的封装,内部持有一个 ExecutorService,create() 工厂方法使用 Executors.newFixedThreadPool() 创建内部 ExecutorService,wrap() 方法使用参数传入的 ExecutorService,而 createDefault() 则使用 Executors.newCachedThreadPool()

EPollPort 是通过 Linux epoll 机制实现的 AsynchronousChannelGroup,其继承关系如下:

EpollPort

构造 EpollPort 传入的 provider 自身和一个 ThreadPool 均直接持有,EpollPort 构造器中使用了 epoll 的几个系统调用创建 epoll 结构。

额外地,在构造器中还初始化了一个 BlockingQueue 用作事件的轮询源(一次轮询可能拿到多个 Event),并且直接添加了一个初始的特殊事件 NEED_TO_POLL。

“NEED_TO_POLL” 事件用于通知一个消费者(Handler 线程)所有事件均已派发完,重新轮询去吧。

EPollPort(AsynchronousChannelProvider provider, ThreadPool pool) throws IOException {
    // epoll 初始化、使用 ”Self-Pipe Trick”
    // ...
    this.queue = new ArrayBlockingQueue<Event>(MAX_EPOLL_EVENTS);
    this.queue.offer(NEED_TO_POLL);
}

启动事件处理任务

初始化 EpollPort 后 LinuxAsynchronousChannelProvider 调用了 start() 方法。

EPollPort start() {
    startThreads(new EventHandlerTask());
    return this;
}
protected final void startThreads(Runnable task) {
    if (!isFixedThreadPool()) {
        for (int i=0; i<internalThreadCount; i++) {
            startInternalThread(task);
            threadCount.incrementAndGet();
        }
    }
    if (pool.poolSize() > 0) {
        task = bindToGroup(task);
        try {
            for (int i=0; i<pool.poolSize(); i++) {
                pool.executor().execute(task);
                threadCount.incrementAndGet();
            }
        } catch (RejectedExecutionException  x) {
            // nothing we can do
        }
    }
}

这里 isFixedThreadPool() 的要求必须是通过 ThreadPool.create() 创建的 ThreadPool,而使用 ThreadPool.wrap() 或者 ThreadPool.createDefault() 创建的线程池,前者不确定其实现(只确定是 ExecutorService),后者核心线程池为空(Cached),所以使用使用这两个方法创建的 ThreadPool 时,将会开启 internalThreadCount 数量的内部 I/O 线程去执行 task,保证至少有一个线程在跑 task(polling)。

internalThreadCount 受系统属性 -Dsun.nio.ch.internalThreadPoolSize 支配,默认为 1。内部线程不绑定 ChannelGroup,不参与 CompletionHandler 的 I/O 事件回调,仅等待 I/O 事件并提交任务给线程池,伪代码如下,同图中示意(Invoker 类下文介绍):

if (notInternalThread()) {
    Invoker.invokeUnchecked();
} else {
    Invoker.invokeIndirectly();
}

fixed thread pool:

image-20191215212509720

cached 和自定义的 thread pools:

image-20191215212517601

// AsynchronousChannelGroupImpl
private Runnable bindToGroup(final Runnable task) {
    final AsynchronousChannelGroupImpl thisGroup = this;
    return new Runnable() {
        public void run() {
            Invoker.bindToGroup(thisGroup);
            task.run();
        }
    };
}

// sun.nio.ch.Invoker
static void bindToGroup(AsynchronousChannelGroupImpl group) {
    myGroupAndInvokeCount.set(new GroupAndInvokeCount(group));
}

static class GroupAndInvokeCount {
    private final AsynchronousChannelGroupImpl group;
    private int handlerInvokeCount;
    GroupAndInvokeCount(AsynchronousChannelGroupImpl group) {
        this.group = group;
    }
    AsynchronousChannelGroupImpl group() {
        return group;
    }
    int invokeCount() {
        return handlerInvokeCount;
    }
    void setInvokeCount(int value) {
        handlerInvokeCount = value;
    }
    void resetInvokeCount() {
        handlerInvokeCount = 0;
    }
    void incrementInvokeCount() {
        handlerInvokeCount++;
    }
}

private static final ThreadLocal<GroupAndInvokeCount> myGroupAndInvokeCount =
    new ThreadLocal<GroupAndInvokeCount>() {
        @Override protected GroupAndInvokeCount initialValue() {
            return null;
        }
    };

bindToGroup() 通过 ThreadLocal 机制将 EpollPort 绑定到运行 task 的线程,并且还将一个用于“已触发过的 CompletionHandler”计数的 handlerInvokeCount 计数器绑定到了将来这个线程。

说了这么多,其实就是用某种方式搞了个 ExecutorService 实例,然后将 task 交给这个 ExecutorService 的核心线程去 run,下面我们看看这个事件处理任务 EventHandlerTask。

private class EventHandlerTask implements Runnable {
    private Event poll() throws IOException {
        try {
            for (;;) {
                // epoll 拿到 n 个事件
                int n = epollWait(epfd, address, MAX_EPOLL_EVENTS);
                // 对 fdToChannel 映射需要一致读,故加对应读锁
                fdToChannelLock.readLock().lock();
                try {
                    while (n-- > 0) {
                        // 从 address 拿到第 n 个事件对应的 channel 事件地址
                        long eventAddress = getEvent(address, n);
                        // 通过事件地址拿到事件对应 fd(channel)
                        int fd = getDescriptor(eventAddress);
                        // 如果是 sp[0] 的事件,则意味着被 wakeup() 强制唤醒或者需要停止 epoll
                        if (fd == sp[0]) {
                            // 被唤醒,计数器 wakeupCount 每次唤醒都会 +1,如果当前的唤醒已经全部处理完,则“排干”管道,即把 sp[0] 待读的全部读完
                            if (wakeupCount.decrementAndGet() == 0) {
                                drain1(sp[0]);
                            }
                            // 多余入队,最后一个直接返回给当前线程处理
                            if (n > 0) {
                                queue.offer(EXECUTE_TASK_OR_SHUTDOWN);
                                continue;
                            }
                            return EXECUTE_TASK_OR_SHUTDOWN;
                        }
                        // 拿到 fd 对应的 channel
                        PollableChannel channel = fdToChannel.get(fd);
                        if (channel != null) {
                            // 根据事件地址拿到事件
                            int events = getEvents(eventAddress);
                            // 将 channel 与其上发生的事件包装成一个 Event 对象
                            Event ev = new Event(channel, events);
                            // 多余入队,最后一个直接返回给当前线程处理(省一次入队出队的过程,提高性能)
                            if (n > 0) {
                                queue.offer(ev);
                            } else {
                                return ev;
                            }
                        }
                    }
                } finally {
                    // 解读锁
                    fdToChannelLock.readLock().unlock();
                }
            }
        } finally {
            // 保证触发下一次 poll
            queue.offer(NEED_TO_POLL);
        }
    }

    public void run() {
        Invoker.GroupAndInvokeCount myGroupAndInvokeCount =
            Invoker.getGroupAndInvokeCount();
        // 根据是否 bindToGroup 过,来区分是否是核心线程
        final boolean isPooledThread = (myGroupAndInvokeCount != null);
        boolean replaceMe = false;
        Event ev;
        try {
            for (;;) {
                // 每次 poll 都重置计数
                if (isPooledThread)
                    myGroupAndInvokeCount.resetInvokeCount();

                try {
                    // “当前线程是否即将退出”的标识符
                    replaceMe = false;
                    // 从队列中取一个事件
                    ev = queue.take();
                    // 如果是特殊的 NEED_TO_POLL 则重新去 poll
                    if (ev == NEED_TO_POLL) {
                        try {
                            // poll 直接拿到当前线程需要处理的事件(直接返回而未入队的事件)
                            ev = poll();
                        } catch (IOException x) {
                            x.printStackTrace();
                            return;
                        }
                    }
                } catch (InterruptedException x) {
                    continue;
                }
                // 处理特殊事件--唤醒
                if (ev == EXECUTE_TASK_OR_SHUTDOWN) {
                    Runnable task = pollTask();
                    // 如果 taskQueue 没有任务,意味着事件的意思是停止任务
                    if (task == null) {
                        return;
                    }
                    // taskQueue 里还有任务,执行 task,这里提前设置 replaceMe 是因为 run() 可能抛异常出来
                    replaceMe = true;
                    task.run();
                    continue;
                }
                // 回调 CompletionHandler
                try {
                    ev.channel().onEvent(ev.events(), isPooledThread);
                } catch (Error x) {
                    replaceMe = true; throw x;
                } catch (RuntimeException x) {
                    replaceMe = true; throw x;
                }
            }
        } finally {
            // 退出循环,replaceMe 若为 true,则使用其它线程代替当前线程继续进行本 task 的执行,否则执行线程退出操作,如果是最后退出的线程,则执行 implClose() 清理资源
            int remaining = threadExit(this, replaceMe);
            if (remaining == 0 && isShutdown()) {
                implClose();
            }
        }
    }
}

这里有两个 Queue,分别装着包装好的 Event 和需要执行的 Runnable。

replaceMe 这个变量是为了防止执行时出现 Error 或 Exception,如果出现了,就进入 threadExit() 判断一下是否要替换掉当前线程,而如果是 shutdown 流程,则直接结束 task 的执行。

整个逻辑其实很清楚,就是在不断的 poll() 进行 epoll_wait 取事件,然后回调事件关注者(CompletionHandler),但是这个流程可能在多个 ThreadPool 核心线程中执行,所以需要一些加锁逻辑。

用伪代码概括一下 EventHandlerTask:

private class EventHandlerTask implements Runnable {
    private Event poll() {
        int result = epollWait();
        Event ev = wrap(result);
        queue.offer(ev);
    }

    public void run() {
        for (;;) {
            poll();
            Event ev = queue.take();
            ev.channel().onEvent(ev.events());
        }
    }
}

ev.channel().onEvent(ev.events(), isPooledThread) 需要我们关注 Event 的 channel,即 Port.PollableChannel 实现类,让我们回到 demo 中,截至目前我们已经拥有了一个 AsynchronousChannelGroup,具备了 epoll 的能力,接下来我们来看 PollableChannel 是如何进行事件监听和回调的。

AsynchronousServerSocketChannelImpl 建连

还是老套路,AsynchronousServerSocketChannel.open() 使用 provider 查找实现类,Linux 下使用的是 UnixAsynchronousServerSocketChannelImpl。

server.accept() 调用 UnixAsynchronousServerSocketChannelImpl#implAccept()(已精简):

// 先直接尝试建连
int n = accept(this.fd, newfd, isaa);

// 还未有可建连 remote,注册 server socket 到 epoll 并关注建连事件,然后组装 future 返回
if (n == IOStatus.UNAVAILABLE) {
    PendingFuture<AsynchronousSocketChannel,Object> result = ...;
    port.startPoll(fdVal, Net.POLLIN);
    return result;
}

// invokeIndirectly 建连
AsynchronousSocketChannel child = finishAccept(newfd, isaa[0], null);
Invoker.invokeIndirectly(this, handler, att, child, exc);

可以看到,这里有个优化,判断当前是否可以直接建连,如果成功,则通过 Invoker.invokeIndirectly() 提交 channel group 线程池触发 CompletionHandler 回调,否则才将 server socket 注册到 epoll 数据结构中去,通过上面 EventHandlerTask 的事件通知机制去回调 onEvent(),然后在 onEvent() 中回调 CompletionHandler。而 onEvent() 逻辑与“先尝试建连”时的逻辑一样,只不过使用的是 Invoker.invoke()

int n = accept(this.fd, newfd, isaa);
AsynchronousSocketChannel child = finishAccept(newfd, isaa[0]);
Invoker.invoke(this, acceptHandler, acceptAttachment, child);

AsynchronousSocketChannelImpl 读写

不废话,拿到 UnixAsynchronousSocketChannelImpl,读、写实现类似,以读为例。

channel.read() 实际调用 UnixAsynchronousSocketChannelImpl#implRead()(已精简):

// 先直接尝试读
int n = IOUtil.read(fd, dst, -1, nd);

// 还不可读写,注册 server socket 到 epoll 并关注建连事件,然后组装 future 返回
if (n == IOStatus.UNAVAILABLE) {
    PendingFuture<V,A> result = ...;
    // 为未完成的 I/O 操作注册事件
    updateEvents();
    return result;
}

// 一些判断,决定 invokeDirect 还是 invokeIndirectly...

if (invokeDirect) {
    Invoker.invokeDirect(myGroupAndInvokeCount, handler, attachment, (V)result, exc);
} else {
    Invoker.invokeIndirectly(this, handler, attachment, (V)result, exc);
}

与 AsynchronousServerSocketChannelImpl 类似,channel.read()channel.write() 都会先进行一番尝试,而不是直接抛给 epoll,但是它们还会再加一层判断,决定调 Invoker.invokeDirect() 还是 Invoker.invokeIndirectly()

onEvent() 的实现也类似,这里不赘述,参见 finishRead() 方法。

UnixAsynchronousSocketChannelImpl 还实现了读写超时的功能,Demo 中没有设置,但最好设置一下,否则万一所有线程都在 CompletionHandler 回调中死锁,那整个应用就完了。并且 CompletionHandler 最好不要做一些阻塞的操作,业务处理最好能异步进行(如 Demo 所示)。

Java AIO 中,epoll 的感兴趣的事件的模式设置为了 EPOLLONESHOT,一旦事件被 epoll_wait 轮询到,除非重新调用 epoll_ctl,否则这个 channel 将在 epoll 内部被禁用。

Invoker

sun.nio.ch.Invoker 包含一系列静态方法,用于触发 CompletionHandler 或任意的 Task。由于在回调方法中我们会继续 accept、read 或者 write,当一个线程压栈太多次,有可能出现栈溢出,为避免栈溢出,并且使任务尽量均匀,不至于某些线程饥饿,Invoker 中有个 maxHandlerInvokeCount 变量(系统属性 sun.nio.ch.maxCompletionHandlersOnStack,默认 16),每调用一次 invokeDirect() 方法,都会使用 incrementInvokeCount() 增加 ThreadLocal<GroupAndInvokeCount> 变量 myGroupAndInvokeCount 的 invokeCount 计数器的值。

invoke() 方法中,会判断当前线程是否为传入的 AsynchronousChannel 所绑定的线程,并且判断计数器的值是否小于 maxHandlerInvokeCount,仅二者均满足时才在当前线程触发回调,否则会转而使用 invokeIndirectly() 方法提交给线程池来做一次“再均衡”,让其它线程有机会执行。

GroupAndInvokeCount thisGroupAndInvokeCount = myGroupAndInvokeCount.get();
if (thisGroupAndInvokeCount != null) {
    if ((thisGroupAndInvokeCount.group() == ((Groupable)channel).group()) &&
        thisGroupAndInvokeCount.invokeCount() < maxHandlerInvokeCount)) {
        invokeDirect = true;
    }
}
if (invokeDirect) {
    invokeDirect(thisGroupAndInvokeCount, handler, attachment, result, exc);
} else {
    try {
        invokeIndirectly(channel, handler, attachment, result, exc);
    } catch (RejectedExecutionException ree) {
        // ...
    }
}

总结

我们知道,epoll 只是让我们具备一种管理海量连接的功能,不关心具体的 I/O,而我们发现,在实际读写时我们仍旧使用同步的 read()write(),即使可以设置为非阻塞。

所以总的来说,在 Linux 下,Java AIO 仍是个 伪异步 I/O,是使用 epoll 事件进行 Future 封装模拟出来的。

并且,与 NIO 相比,Java AIO实现更加复杂,但底层原理没变,所以还是老老实实用 Netty 的 NIO 吧。

Windows 下的实现是真正的异步 I/O 实现,使用 IOCP,但是 Windows 不作为服务器最常用的 OS,很遗憾。