大数据
流式处理
Spark
源码解析

Spark源码分析05 - 通信架构03:高层实现(1)RpcEnv和Dispatcher

简介:Spark针对各类场景,实现了不同的RpcHandler和StreamManager,在Spark Core模块的org.apache.spark.rpc包下,包含了RPC通信框架的高层实现,在本文中我们将以具体流程来对它们进行解析。

1. 通信组件架构

在前面两篇文章中,我们详细讨论了Spark通信架构里传输层的实现,回顾前面的讲解,我们知道在TransportChannelHandler中,将RPC请求消息委托给了RpcHandler处理,将流请求消息委托给了StreamManager处理;RpcHandler和StreamManager的具体实现类将根据具体的业务逻辑来定制。Spark针对各类场景,实现了不同的RpcHandler和StreamManager,在Spark Core模块的org.apache.spark.rpc包下,包含了RPC通信框架的高层实现,在本文中我们将以具体流程来对它们进行解析。

2. RpcEnv

熟悉Spark部署模式的读者一定知道,在Spark提供的Standalone部署模式中存在两种角色:Master和Worker。Spark中对这两种角色分别用org.apache.spark.deploy.master.Master类和org.apache.spark.deploy.worker.Worker类来表示;因此,在本文中,我们以Master角色的启动流程,来探讨通信架构的高层实现。

注:Worker和Client(org.apache.spark.deploy包下)的启动流程与Master非常类似。

Master有两种启动模式:Local-Cluster模式和Standalone模式。Local模式是以Java对象的方式启动的,即由SparkContext调用Master伴生对象的startRpcEnvAndEndpoint(...)方法,调用方法栈如下:

  • Master.startRpcEnvAndEndpoint(String, int, int, SparkConf)
  • ↖ LocalSparkCluster.start()
  • ↖ SparkContext.createTaskScheduler(SparkContext, String, String)

而Standalone模式则是使用start-master.sh脚本来启动,start-master.sh脚本最终会启动一个JVM进程执行org.apache.spark.deploy.master.Master类的main(...)方法;Master类的main(...)方法中其实还是调用了startRpcEnvAndEndpoint(...)方法:

注:关于start-master.sh脚本将在后面的文章中介绍。

org.apache.spark.deploy.master.Master#main
  • // 以JVM进程方式启动Master时的入口方法
  • def main(argStrings: Array[String]) {
  • Utils.initDaemon(log)
  • // 创建SparkConf
  • val conf = new SparkConf
  • /**
  • * 解析传入的参数。
  • * 命令行参数指定的值会覆盖系统环境变量指定的值。
  • * 属性指定的值会覆盖系统环境变量或命令行参数的值。
  • */
  • val args = new MasterArguments(argStrings, conf)
  • // 调用startRpcEnvAndEndpoint()方法启动
  • val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf)
  • // 该操作最后调用了Dispatcher中线程池的awaitTermination()方法
  • rpcEnv.awaitTermination()
  • }

可见,main(...)方法会从运行JVM进程时传入的参数解析得到Master运行的主机地址、端口号和Web UI端口号,然后调用startRpcEnvAndEndpoint(...)方法,该方法的源码如下:

org.apache.spark.deploy.master.Master#startRpcEnvAndEndpoint
  • /**
  • * Start the Master and return a three tuple of:
  • * (1) The Master RpcEnv
  • * (2) The web UI bound port
  • * (3) The REST server bound port, if any
  • *
  • * 用于创建Master对象,并将Master对象注册到RpcEnv中完成对Master对象的启动。
  • */
  • def startRpcEnvAndEndpoint(
  • host: String,
  • port: Int,
  • webUiPort: Int,
  • conf: SparkConf): (RpcEnv, Int, Option[Int]) = {
  • // 创建SecurityManager
  • val securityMgr = new SecurityManager(conf)
  • // 创建RpcEnv
  • val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
  • // 创建Master,将Master(Master继承了ThreadSafeRpcEndpoint)注册到RpcEnv中,获得Master的RpcEndpointRef。
  • val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
  • new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
  • // 向Master发送BoundPortsRequest消息,并获得返回的BoundPortsResponse消息。
  • val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)
  • // 返回创建的RpcEnv、BoundPortsResponse消息携带的WebUIPort、REST服务的端口(restPort)等信息。
  • (rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
  • }

startRpcEnvAndEndpoint(...)方法首先创建了SecurityManager安全管理器,然后调用RpcEnv的create(...)方法创建RpcEnv对象。

从RpcEnv的命名即可推测,它肯定和Spark的RPC环境息息相关;确实,RpcEnv是Spark内RPC环境的大管家,每个Spark节点都会存在一个对应的RpcEnv实例。RpcEnv是一个抽象类,它内部定义了很多规范方法,源码如下:

org.apache.spark.rpc.RpcEnv
  • /**
  • * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to
  • * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote
  • * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by
  • * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the
  • * sender, or logging them if no such sender or `NotSerializableException`.
  • *
  • * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri.
  • */
  • private[spark] abstract class RpcEnv(conf: SparkConf) {
  • /**
  • * 根据"spark.rpc.lookupTimeout"或"spark.network.timeout"配置构造RpcTimeout对象,默认超时时间为120秒
  • * 用于某些场景下将异步执行转换为同步执行
  • */
  • private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf)
  • /**
  • * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
  • * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist.
  • *
  • * 根据RpcEndpoint查找对应的RpcEndpointRef
  • */
  • private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef
  • /**
  • * Return the address that [[RpcEnv]] is listening to.
  • *
  • * 返回RpcEnv监听的地址
  • */
  • def address: RpcAddress
  • /**
  • * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not
  • * guarantee thread-safety.
  • *
  • * 以特定名称注册一个RpcEndpoint,将返回对应的RpcEndpointRef对象
  • */
  • def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
  • /**
  • * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously.
  • *
  • * 异步方式,通过URI查找对应的RpcEndpointRef对象
  • */
  • def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef]
  • /**
  • * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
  • *
  • * 同步方式,通过URI查找对应的RpcEndpointRef对象
  • */
  • def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
  • /**
  • * asyncSetupEndpointRefByURI的功能是向远端NettyRpcEnv询问指定名称的RpcEndpoint的NettyRpcEndpointRef
  • * defaultLookupTimeout是根据参数构造的RpcTimeout对象,调用其awaitResult()方法进行超时操作
  • */
  • defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
  • }
  • /**
  • * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName`.
  • * This is a blocking action.
  • * 使用RpcAddress和RpcEndpoint的名称,得到对应的RpcEndpointRef
  • */
  • def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef = {
  • // 封装RpcAddress和endpointName为RpcEndpointAddress对象
  • setupEndpointRefByURI(RpcEndpointAddress(address, endpointName).toString)
  • }
  • /**
  • * Stop [[RpcEndpoint]] specified by `endpoint`.
  • *
  • * 停止制定的RpcEndpoint
  • */
  • def stop(endpoint: RpcEndpointRef): Unit
  • /**
  • * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully,
  • * call [[awaitTermination()]] straight after [[shutdown()]].
  • *
  • * 关闭当前RpcEnv
  • */
  • def shutdown(): Unit
  • /**
  • * Wait until [[RpcEnv]] exits.
  • *
  • * TODO do we need a timeout parameter?
  • *
  • * 等待知道RpcEnv退出
  • */
  • def awaitTermination(): Unit
  • /**
  • * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
  • * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
  • *
  • * 反序列化操作
  • */
  • def deserialize[T](deserializationAction: () => T): T
  • /**
  • * Return the instance of the file server used to serve files. This may be `null` if the
  • * RpcEnv is not operating in server mode.
  • *
  • * 获取文件服务器
  • */
  • def fileServer: RpcEnvFileServer
  • /**
  • * Open a channel to download a file from the given URI. If the URIs returned by the
  • * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to
  • * retrieve the files.
  • *
  • * 打开一个Channel用于从给定URI下载文件
  • *
  • * @param uri URI with location of the file.
  • */
  • def openChannel(uri: String): ReadableByteChannel
  • }

在RpcEnv的内部,几乎都是在操作RpcEndpoint和RpcEndpointRef对象,因此我们需要先了解这两个类。

2.1. RpcEndpoint

RpcEndpoint是一个特质,用于表示一个RPC端点,Spark中每个节点(如Master、Worker、Client等,它们都直接或间接实现了RpcEndpoint)都称之为一个RPC端点,RPC端点内部根据业务需求,设计了不同的消息处理方式;RpcEndpoint的定义如下:

org.apache.spark.rpc.RpcEndpoint
  • /**
  • * An end point for the RPC that defines what functions to trigger given a message.
  • *
  • * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
  • *
  • * The life-cycle of an endpoint is:
  • *
  • * constructor -> onStart -> receive* -> onStop
  • *
  • * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
  • * [[ThreadSafeRpcEndpoint]]
  • *
  • * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
  • * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it.
  • *
  • * 对Spark的RPC通信实体的统一抽象,所有运行于RPC框架之上的实体都应该继承RpcEndpoint
  • */
  • private[spark] trait RpcEndpoint {
  • /**
  • * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to.
  • * 当前RpcEndpoint所属的RpcEnv
  • */
  • val rpcEnv: RpcEnv
  • ...
  • }

RpcEndpoint中定义了大量的方法接口,这里将其分为以下三类:

  1. 状态相关。

onStart()onConnected(...)onDisconnected(...)onError(...)onNetworkError(...)onStop()六个,它们都会在RpcEndpoint出现相关的状态时被调用,源码如下:

org.apache.spark.rpc.RpcEndpoint
  • /**
  • * Invoked before [[RpcEndpoint]] starts to handle any message.
  • *
  • * 在RpcEndpoint开始处理消息之前调用,可以在RpcEndpoint正式工作之前做一些准备工作。
  • */
  • def onStart(): Unit = {
  • // By default, do nothing.
  • }
  • /**
  • * Invoked when `remoteAddress` is connected to the current node.
  • *
  • * 当客户端与当前节点连接上之后调用,可以针对连接进行一些处理。
  • */
  • def onConnected(remoteAddress: RpcAddress): Unit = {
  • // By default, do nothing.
  • }
  • /**
  • * Invoked when `remoteAddress` is lost.
  • *
  • * 当客户端与当前节点的连接断开之后调用,可以针对断开连接进行一些处理。
  • */
  • def onDisconnected(remoteAddress: RpcAddress): Unit = {
  • // By default, do nothing.
  • }
  • /**
  • * Invoked when any exception is thrown during handling messages.
  • *
  • * 当处理消息发生异常时调用,可以对异常进行一些处理。
  • */
  • def onError(cause: Throwable): Unit = {
  • // By default, throw e and let RpcEnv handle it
  • throw cause
  • }
  • /**
  • * Invoked when some network error happens in the connection between the current node and
  • * `remoteAddress`.
  • *
  • * 当客户端与当前节点之间的连接发生网络错误时调用,可以针对连接发生的网络错误进行一些处理。
  • */
  • def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
  • // By default, do nothing.
  • }
  • /**
  • * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot
  • * use it to send or ask messages.
  • *
  • * 在停止RpcEndpoint时调用,可以在RpcEndpoint停止的时候做一些收尾工作。
  • */
  • def onStop(): Unit = {
  • // By default, do nothing.
  • }
  1. 处理消息相关。

receive(...)receiveAndReply(...),其中receive(...)用于接收消息,但不需要回复,而receiveAndReply(...)则处理需要回复的消息;源码如下:

org.apache.spark.rpc.RpcEndpoint
  • /**
  • * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a
  • * unmatched message, [[SparkException]] will be thrown and sent to `onError`.
  • *
  • * 接收消息并处理,但不需要给客户端回复
  • */
  • def receive: PartialFunction[Any, Unit] = {
  • case _ => throw new SparkException(self + " does not implement 'receive'")
  • }
  • /**
  • * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message,
  • * [[SparkException]] will be thrown and sent to `onError`.
  • *
  • * 接收消息并处理,需要给客户端回复。回复是通过RpcCallContext来实现的。
  • */
  • def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  • case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
  • }
  1. 辅助方法。

self()stop()self()用于获取与当前RpcEndpoint对应的RpcEndpointRef端点引用对象,而stop()方法则用于停止当前的RpcEndpoint端点;源码如下:

  • /**
  • * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
  • * called. And `self` will become `null` when `onStop` is called.
  • *
  • * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not
  • * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called.
  • *
  • * 获取RpcEndpoint相关联的RpcEndpointRef
  • */
  • final def self: RpcEndpointRef = {
  • require(rpcEnv != null, "rpcEnv has not been initialized")
  • // 实际调用了RpcEnv的endpointRef方法
  • rpcEnv.endpointRef(this)
  • }
  • /**
  • * A convenient method to stop [[RpcEndpoint]].
  • *
  • * 用于停止当前RpcEndpoint。
  • */
  • final def stop(): Unit = {
  • val _self = self
  • if (_self != null) {
  • // 实际调用了RpcEnv的stop方法
  • rpcEnv.stop(_self)
  • }
  • }

从RpcEndpoint的设计来看,RPC端点更像是充当着服务端的抽象角色,主要用于处理客户端的连接和断开、接收客户端消息等功能。

2.2. RpcEndpointRef

RpcEndpointRef则与RpcEndpoint恰好相反,从它的命名就可以看出来它更像是RPC端点的一个引用,使用RpcEndpointRef可以向与之存在“引用”关系的RpcEndpoint端点发送消息;它是一个抽象类,定义和重要字段如下:

org.apache.spark.rpc.RpcEndpointRef
  • /**
  • * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
  • */
  • private[spark] abstract class RpcEndpointRef(conf: SparkConf)
  • extends Serializable with Logging {
  • // RPC最大重新连接次数。可以使用spark.rpc.numRetries属性进行配置,默认为3次。
  • private[this] val maxRetries = RpcUtils.numRetries(conf)
  • // RPC每次重新连接需要等待的毫秒数。可以使用spark.rpc.retry.wait属性进行配置,默认值为3秒。
  • private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
  • /**
  • * RPC的ask操作的默认超时时间。
  • * 可以使用spark.rpc.askTimeout或者spark.network.timeout属性进行配置,默认值为120秒。
  • * spark.rpc.askTimeout属性的优先级更高。
  • */
  • private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
  • ...
  • }

它的三个字段是针对通信操作而设置的,都是从SparkConf所携带的配置信息中进行初始化的。

RpcEndpointRef也定义了大量的方法,这里将其分为以下两类:

  1. 处理消息相关。

send(...)ask(...)askWithRetry(...);其中send(...)方法用于发送无状态的消息,它不需要任何回复,ask(...)用于发送需要回复的消息,askWithRetry(...)发送的消息也需要回复,但它还附带了重试机制。ask(...)askWithRetry(...)都有重载版本,实现了超时等待;它们的源码如下:

org.apache.spark.rpc.RpcEndpointRef
  • /**
  • * Sends a one-way asynchronous message. Fire-and-forget semantics.
  • * 发送单向异步的消息。
  • * 所谓“单向”就是发送完后就会忘记此次发送,不会有任何状态要记录,也不会期望得到服务端的回复。
  • * send采用了at-most-once的投递规则。
  • */
  • def send(message: Any): Unit
  • /**
  • * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to
  • * receive the reply within the specified timeout.
  • *
  • * This method only sends the message once and never retries.
  • *
  • * 发送消息并在指定超时时间内等待响应。
  • * 该方法只会发送一次,不会重试。
  • */
  • def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
  • /**
  • * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to
  • * receive the reply within a default timeout.
  • *
  • * This method only sends the message once and never retries.
  • *
  • * 以默认的超时时间作为timeout参数,调用ask[T:ClassTag](message:Any,timeout:RpcTimeout)方法。
  • */
  • def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)
  • /**
  • * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default
  • * timeout, or throw a SparkException if this fails even after the default number of retries.
  • * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this
  • * method retries, the message handling in the receiver side should be idempotent.
  • *
  • * Note: this is a blocking action which may cost a lot of time, so don't call it in a message
  • * loop of [[RpcEndpoint]].
  • *
  • * @param message the message to send
  • * @tparam T type of the reply message
  • * @return the reply message from the corresponding [[RpcEndpoint]]
  • */
  • def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)
  • /**
  • * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a
  • * specified timeout, throw a SparkException if this fails even after the specified number of
  • * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method
  • * retries, the message handling in the receiver side should be idempotent.
  • *
  • * Note: this is a blocking action which may cost a lot of time, so don't call it in a message
  • * loop of [[RpcEndpoint]].
  • *
  • * 发送同步的请求,此类请求将会被RpcEndpoint接收,并在指定的超时时间内等待返回类型为T的处理结果。
  • * 当此方法抛出SparkException时,将会进行请求重试,直到超过了默认的重试次数为止。
  • * 由于此类方法会重试,因此要求服务端对消息的处理是幂等的。
  • * 此方法也采用了at-least-once的投递规则。
  • *
  • * @param message the message to send
  • * @param timeout the timeout duration
  • * @tparam T type of the reply message
  • * @return the reply message from the corresponding [[RpcEndpoint]]
  • */
  • def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
  • // TODO: Consider removing multiple attempts
  • // 尝试次数
  • var attempts = 0
  • // 用于记录异常
  • var lastException: Exception = null
  • while (attempts < maxRetries) { // 尝试次数小于最大可尝试次数
  • // 尝试次数自增
  • attempts += 1
  • try {
  • // 调用ask方法
  • val future = ask[T](message, timeout)
  • // 超时等待以获取结果
  • val result = timeout.awaitResult(future)
  • // 获取结果为空,抛出异常,会被捕获并记录在lastException中
  • if (result == null) {
  • throw new SparkException("RpcEndpoint returned null")
  • }
  • // 获取结果不为空,返回
  • return result
  • } catch {
  • // 除了中断异常会向外抛出,其他异常都会被记录到lastException中
  • case ie: InterruptedException => throw ie
  • case e: Exception =>
  • lastException = e
  • logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
  • }
  • // 尝试次数小于最大可尝试次数,说明还可以尝试,需要等待一段时间后再进行
  • if (attempts < maxRetries) {
  • Thread.sleep(retryWaitMs)
  • }
  • }
  • throw new SparkException(
  • s"Error sending message [message = $message]", lastException)
  • }

从上面的源码可知,askWithRetry(...)内部重用了ask(...)方法,只是额外增加了重试功能。ask(...)方法有两个版本,带有超时机制的异步发送版本调用了另一个抽象版本;该抽象的ask(...)方法是需要子类具体实现的。

  1. 辅助方法。

address()name(),分别用于获取当前RpcEndpointRef的地址和名称,源码如下:

org.apache.spark.rpc.RpcEndpointRef
  • /**
  • * return the address for the [[RpcEndpointRef]]
  • * 返回当前RpcEndpointRef对应RpcEndpoint的RPC地址(RpcAddress)。
  • */
  • def address: RpcAddress
  • // 返回当前RpcEndpointRef对应RpcEndpoint的名称。
  • def name: String

2.3. NettyRpcEnvFactory

有了对RpcEndpoint和RpcEndpointRef的了解之后,让我们回到RpcEnv创建流程的讲解上。在RpcEnv的伴生对象中,定义了两个create(...)方法用于创建RpcEnv实例:

org.apache.spark.rpc.RpcEnv
  • /**
  • * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor
  • * so that it can be created via Reflection.
  • */
  • private[spark] object RpcEnv {
  • // 创建RpcEnv实例,调用了重载方法
  • def create(
  • name: String,
  • host: String,
  • port: Int,
  • conf: SparkConf,
  • securityManager: SecurityManager,
  • clientMode: Boolean = false): RpcEnv = {
  • create(name, host, host, port, conf, securityManager, clientMode)
  • }
  • // 创建RpcEnv实例
  • def create(
  • name: String,
  • bindAddress: String,
  • advertiseAddress: String,
  • port: Int,
  • conf: SparkConf,
  • securityManager: SecurityManager,
  • clientMode: Boolean): RpcEnv = {
  • // 构造配置
  • val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
  • clientMode)
  • // 使用工厂创建
  • new NettyRpcEnvFactory().create(config)
  • }
  • }

从重载的create(...)方法中可以得知,它构造了一个RpcEnvConfig对象,然后创建了NettyRpcEnvFactory工厂类,使用该工厂类的create()方法创建RpcEnv并返回;注意在整个过程中,clientMode参数一直是默认值false。其中RpcEnvConfig类是一个样例类,只是用来保存各项属性,这里不多赘述。

NettyRpcEnvFactory是创建RpcEnv实例的工厂类,它只有create(...)一个方法:

org.apache.spark.rpc.netty.NettyRpcEnvFactory
  • private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  • // 创建NettyRpcEnv
  • def create(config: RpcEnvConfig): RpcEnv = {
  • // 创建序列化器
  • val sparkConf = config.conf
  • // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
  • // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
  • val javaSerializerInstance =
  • new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
  • // 通过序列化器、监听地址、安全管理器等构造NettyRpcEnv
  • val nettyEnv =
  • new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
  • config.securityManager)
  • if (!config.clientMode) { // 如果是Driver
  • // 定义启动RpcEnv的偏函数,该偏函数中会创建TransportServer
  • val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
  • // 调用NettyRpcEnv的startServer(),这里会创建TransportServer
  • nettyEnv.startServer(config.bindAddress, actualPort)
  • // 返回NettyRpcEnv和端口
  • (nettyEnv, nettyEnv.address.port)
  • }
  • try {
  • // 在指定端口启动NettyRpcEnv
  • Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
  • } catch {
  • case NonFatal(e) =>
  • nettyEnv.shutdown()
  • throw e
  • }
  • }
  • // 返回NettyRpcEnv
  • nettyEnv
  • }
  • }

从上述方法可以得知,创建的RpcEnv其实是NettyRpcEnv实例,同时如果clientMode参数为false(Master启动时,该参数一直默认为false),会调用NettyRpcEnv的startServer(...)方法创建TransportServer对象,然后在返回的监听端口启动NettyRpcEnv服务,最终返回NettyRpcEnv实例。

3. NettyRpcEnv

在上面的介绍中我们已经知道最终创建的是NettyRpcEnv对象,NettyRpcEnv是Spark2.1.0版本中RpcEnv目前唯一的实现类,它的定义如下:

org.apache.spark.rpc.netty.NettyRpcEnv
  • private[netty] class NettyRpcEnv(
  • val conf: SparkConf,
  • javaSerializerInstance: JavaSerializerInstance,
  • host: String,
  • securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
  • ...
  • }

3.1. NettyRpcEnv实现的方法

我们来分析一下NettyRpcEnv对RpcEnv几个方法的实现。

3.1.1. RpcEndpoint相关

用于操作RpcEndpoint的方法主要有以下几个:

  1. 为NettyRpcEnv设置RpcEndpoint的setupEndpoint(...)方法,内部使用了Dispatcher的registerRpcEndpoint(...)方法:
org.apache.spark.rpc.netty.NettyRpcEnv#setupEndpoint
  • // 设置RpcEndpoint
  • override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
  • // 使用Dispatcher注册RpcEndpoint
  • dispatcher.registerRpcEndpoint(name, endpoint)
  • }
  1. 查找指定的RpcEndpoint所对应的RpcEndpointRef,即endpointRef(...)方法:
org.apache.spark.rpc.netty.NettyRpcEnv#endpointRef
  • // 根据RpcEndpoint获取对应的RpcEndpointRef
  • override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
  • dispatcher.getRpcEndpointRef(endpoint)
  • }
  1. 根据指定URI查找并设置RpcEndpointRef,即asyncSetupEndpointRefByURI(...)方法,该方法需要向URI指定的地址发送消息请求,以确认是否存在对应的RpcEndpoint,如果存在就创建响应的RpcEndpointRef;关于该方法的实现原理我们将在后面讲解:
org.apache.spark.rpc.netty.NettyRpcEnv#asyncSetupEndpointRefByURI
  • def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
  • // 得到RpcEndpointAddress对象
  • val addr = RpcEndpointAddress(uri)
  • // 构建NettyRpcEndpointRef
  • val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
  • // 获取远端NettyRpcEnv的RpcEndpointVerifier
  • val verifier = new NettyRpcEndpointRef(
  • conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this)
  • // 使用ask方法进行询问,向远端NettyRpcEnv的RpcEndpointVerifier发送RpcEndpointVerifier.CheckExistence消息
  • verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
  • if (find) {
  • // 能够查询到
  • Future.successful(endpointRef)
  • } else {
  • // 没有查询到
  • Future.failed(new RpcEndpointNotFoundException(uri))
  • }
  • }(ThreadUtils.sameThread)
  • }

关于上述两个方法,都调用了Dispatcher的相关方法,后面会详细介绍。

3.1.2. 下载文件流相关

openChannel(...)方法可以根据URI创建Channel用于下载文件,源码如下:

  • // 根据URI打开一个ReadableByteChannel用于下载文件
  • override def openChannel(uri: String): ReadableByteChannel = {
  • // 解析URI为URI对象
  • val parsedUri = new URI(uri)
  • // 检查URI是否合法
  • require(parsedUri.getHost() != null, "Host name must be defined.")
  • require(parsedUri.getPort() > 0, "Port must be defined.")
  • require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
  • // 创建单向管道,数据会从pipe的sink写入,然后可以从source读取
  • val pipe = Pipe.open()
  • // 将pipe的source包装为FileDownloadChannel
  • val source = new FileDownloadChannel(pipe.source())
  • try {
  • // 创建TransportClient客户端
  • val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
  • // 创建下载回调,此回调会把读到的数据写入到pipe的sink中
  • val callback = new FileDownloadCallback(pipe.sink(), source, client)
  • // 使用TransportClient的stream()方法发送流请求进行文件下载,响应的数据将交给callback处理
  • client.stream(parsedUri.getPath(), callback)
  • } catch {
  • case e: Exception =>
  • pipe.sink().close()
  • source.close()
  • throw e
  • }
  • // 返回FileDownloadChannel类型的source,它的读取方法会从pipe的source读取数据
  • source
  • }

openChannel(...)方法实现比较复杂,这里着重讲解一下。openChannel(...)方法主要是通过JDK NIO提供的Pipe功能实现数据读写的,Pipe提供了双线程单向读取的功能,数据会被一个线程写到Pipe的Sink中,然后另一个线程可以从其Source中读取数据。

在检查URI合法之后,openChannel(...)方法会创建一个Pipe对象,然后使用该Pipe的Source创建一个FileDownloadChannel对象作为新的source,FileDownloadChannel对象内部其实仅仅是对Pipe的Source进行了一层封装,在调用其read(...)方法读取数据时,其实是从Pipe的Source中读取的:

org.apache.spark.rpc.netty.NettyRpcEnv.FileDownloadChannel
  • private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
  • @volatile private var error: Throwable = _
  • // 出现错误
  • def setError(e: Throwable): Unit = {
  • error = e
  • // 关闭Source
  • source.close()
  • }
  • // 读取数据
  • override def read(dst: ByteBuffer): Int = {
  • // 从source中读取数据放入传入的dst中
  • Try(source.read(dst)) match {
  • // 读取成功,返回读取字节数
  • case Success(bytesRead) => bytesRead
  • // 读取出错,抛出异常
  • case Failure(readErr) =>
  • if (error != null) {
  • throw error
  • } else {
  • throw readErr
  • }
  • }
  • }
  • // 关闭source
  • override def close(): Unit = source.close()
  • // 判断source是否打开
  • override def isOpen(): Boolean = source.isOpen()
  • }

接下来,openChannel(...)方法会根据解析得到的URI,使用downloadClient(...)创建TransportClient客户端:

org.apache.spark.rpc.netty.NettyRpcEnv#downloadClient
  • private def downloadClient(host: String, port: Int): TransportClient = {
  • // 检查创建文件下载客户端的工厂是否为空
  • if (fileDownloadFactory == null) synchronized {
  • if (fileDownloadFactory == null) {
  • val module = "files"
  • val prefix = "spark.rpc.io."
  • // 克隆一份SparkConf
  • val clone = conf.clone()
  • // Copy any RPC configuration that is not overridden in the spark.files namespace.
  • // 检查file模块一些有关RPC的配置是否缺失了,如果是将设置到克隆的SparkConf中
  • conf.getAll.foreach { case (key, value) =>
  • if (key.startsWith(prefix)) {
  • val opt = key.substring(prefix.length())
  • clone.setIfMissing(s"spark.$module.io.$opt", value)
  • }
  • }
  • // 文件下载的IO线程数
  • val ioThreads = clone.getInt("spark.files.io.threads", 1)
  • // 创建TransportConf
  • val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
  • // 创建TransportContext
  • val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
  • // 创建用于文件下载客户端的TransportClientFactory
  • fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
  • }
  • }
  • // 创建TransportClient
  • fileDownloadFactory.createClient(host, port)
  • }

该方法的实现其实是比较简单的,就是使用传输层的TransportContext创建TransportClientFactory,然后使用该工厂对象创建TransportClient;不过这里有一个细节,由于此处创建的TransportClient只用于下载文件,它并不需要RpcHandler,因此在创建TransportContext时,传入的RpcHandler其实是没有任何操作的NoOpRpcHandler。

有了客户端之后,还会创建一个FileDownloadCallback回调对象,这个对象是给后面TransportClient的stream(...)方法用的;Pipe的Sink和前面创建的FileDownloadChannel对象source,以及TransportClient都会封装到FileDownloadCallback中,我们看一下它的实现:

org.apache.spark.rpc.netty.NettyRpcEnv.FileDownloadCallback
  • private class FileDownloadCallback(
  • sink: WritableByteChannel,
  • source: FileDownloadChannel,
  • client: TransportClient) extends StreamCallback {
  • // 接收到数据
  • override def onData(streamId: String, buf: ByteBuffer): Unit = {
  • // 将buf中的数据全部写入到WritableByteChannel中
  • while (buf.remaining() > 0) {
  • sink.write(buf)
  • }
  • }
  • // 数据读写完毕
  • override def onComplete(streamId: String): Unit = {
  • // 关闭WritableByteChannel
  • sink.close()
  • }
  • // 数据读写出错
  • override def onFailure(streamId: String, cause: Throwable): Unit = {
  • logDebug(s"Error downloading stream $streamId.", cause)
  • // 设置错误到FileDownloadChannel
  • source.setError(cause)
  • // 关闭WritableByteChannel
  • sink.close()
  • }
  • }

回顾一下流请求的处理流程:在TransportClient发送流请求StreamRequest之前,它会将流ID和回调函数存放到TransportResponseHandler的streamCallbacks队列中,然后将StreamRequest发出;在收到服务端返回的StreamResponse时,TransportResponseHandler的handle(...)方法会解析StreamResponse,如果存在需要接收的数据,就向处理器链中TransportFrameDecoder帧解码器设置StreamInterceptor拦截器,往后传输过来的流数据都会交给该拦截器处理。StreamInterceptor在接收到源源不断的数据后,每次都会将数据传给StreamCallback回调对象的onData(...)方法以便处理,当数据接收完成,就会告知StreamCallback回调对象的onComplete(...)方法。

上面的FileDownloadCallback对象,就充当了流请求处理步骤中StreamCallback回调对象,在接收到文件的流数据后,FileDownloadCallback的onData(...)方法会不断地将数据写入到Pipe的Sink中,这样一来,通过FileDownloadChannel的read(...)方法就可以从Pipe的Source中读取到下载的文件数据了。

3.1.3. 其他辅助方法

  1. 获取RpcEnv的RpcAddress,即address()方法,该方法在NettyRpcEnv中表现为lazy常量address,实现也非常简单:
org.apache.spark.rpc.netty.NettyRpcEnv#address
  • @Nullable
  • override lazy val address: RpcAddress = {
  • if (server != null) RpcAddress(host, server.getPort()) else null
  • }
  1. 反序列化方法,即deserialize(...),它在NettyRpcEnv中实现如下:
org.apache.spark.rpc.netty.NettyRpcEnv#deserialize
  • // 反序列化操作
  • override def deserialize[T](deserializationAction: () => T): T = {
  • NettyRpcEnv.currentEnv.withValue(this) {
  • deserializationAction()
  • }
  • }
  1. 获取文件服务器的fileServer()方法的实现就是返回创建的NettyStreamManager对象:
org.apache.spark.rpc.netty.NettyRpcEnv#fileServer
  • // 文件服务器
  • override def fileServer: RpcEnvFileServer = streamManager
  1. awaitTermination()方法用于保持当前NettyRpcEnv一直处于运行状态,不会退出:
org.apache.spark.rpc.netty.NettyRpcEnv#awaitTermination
  • override def awaitTermination(): Unit = {
  • dispatcher.awaitTermination()
  • }

它是通过调用Dispatcher的awaitTermination()方法实现的,底层其实是通过ThreadPoolExecutor线程池的awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)实现的:

org.apache.spark.rpc.netty.Dispatcher#awaitTermination
  • def awaitTermination(): Unit = {
  • threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)
  • }
  1. stop(...)方法停止指定的RpcEndpointRef,内部也是通过Dispatcher的stop(...)方法实现的,底层会取消对指定RpcEndpointRef的注册,并停止其对应的消息处理业务(后面会详细介绍):
org.apache.spark.rpc.netty.NettyRpcEnv#stop
  • override def stop(endpointRef: RpcEndpointRef): Unit = {
  • require(endpointRef.isInstanceOf[NettyRpcEndpointRef])
  • // 停止Dispatcher,对对应的RpcEndpoint取消注册
  • dispatcher.stop(endpointRef)
  • }
  1. shutdown(...)用于关闭当前的NettyRpcEnv,内部调用了clean()方法,会将与之相关的所有组件全部关闭:
org.apache.spark.rpc.netty.NettyRpcEnv#cleanup
  • override def shutdown(): Unit = {
  • cleanup()
  • }
  • private def cleanup(): Unit = {
  • // CAS方式修改标识
  • if (!stopped.compareAndSet(false, true)) {
  • return
  • }
  • // 关闭OutBox
  • val iter = outboxes.values().iterator()
  • while (iter.hasNext()) {
  • val outbox = iter.next()
  • outboxes.remove(outbox.address)
  • outbox.stop()
  • }
  • // 关闭超时调度器
  • if (timeoutScheduler != null) {
  • timeoutScheduler.shutdownNow()
  • }
  • // 关闭Dispatcher
  • if (dispatcher != null) {
  • dispatcher.stop()
  • }
  • // 关闭TransportServer
  • if (server != null) {
  • server.close()
  • }
  • // 关闭TransportClientFactory
  • if (clientFactory != null) {
  • clientFactory.close()
  • }
  • // 关闭创建TransportClient的线程池
  • if (clientConnectionExecutor != null) {
  • clientConnectionExecutor.shutdownNow()
  • }
  • // 关闭创建文件下载器的工厂
  • if (fileDownloadFactory != null) {
  • fileDownloadFactory.close()
  • }
  • }

3.2. NettyRpcEnv中的组件

NettyRpcEnv是非常重要的RPC环境管理类,在它初始化的过程中,创建了大量与之相关的通信组件,列举如下:

org.apache.spark.rpc.netty.NettyRpcEnv
  • ...
  • // 创建TransportConf
  • private[netty] val transportConf = SparkTransportConf.fromSparkConf(
  • // 对SparkConf进行克隆,并设置spark.rpc.io.numConnectionsPerPeer为1,用于指定对等节点间的连接数
  • conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
  • "rpc", // 模块名为rpc
  • conf.getInt("spark.rpc.io.threads", 0)) // 设置Netty传输线程数
  • // 消息调度器
  • private val dispatcher: Dispatcher = new Dispatcher(this)
  • // 创建流管理器
  • private val streamManager = new NettyStreamManager(this)
  • // 创建TransportContext
  • private val transportContext = new TransportContext(transportConf,
  • new NettyRpcHandler(dispatcher, this, streamManager))
  • // 创建TransportClientBootstrap
  • private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
  • if (securityManager.isAuthenticationEnabled()) {
  • java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
  • securityManager.isSaslEncryptionEnabled()))
  • } else {
  • java.util.Collections.emptyList[TransportClientBootstrap]
  • }
  • }
  • // 创建TransportClientFactory工厂,用于常规的发送请求和接收响应
  • private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
  • /**
  • * A separate client factory for file downloads. This avoids using the same RPC handler as
  • * the main RPC context, so that events caused by these clients are kept isolated from the
  • * main RPC traffic.
  • *
  • * It also allows for different configuration of certain properties, such as the number of
  • * connections per peer.
  • *
  • * 该TransportClientFactory用于文件下载
  • */
  • @volatile private var fileDownloadFactory: TransportClientFactory = _
  • // 用于处理请求超时的调度器,即单线程的ScheduledThreadPoolExecutor线程池
  • val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
  • // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
  • // to implement non-blocking send/ask.
  • // TODO: a non-blocking TransportClientFactory.createClient in future
  • /**
  • * 用于异步处理TransportClientFactory.createClient()方法调用的线程池。
  • * 线程池的大小默认为64,可以使用spark.rpc.connect.threads属性进行配置。
  • */
  • private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
  • "netty-rpc-connection",
  • conf.getInt("spark.rpc.connect.threads", 64))
  • // TransportServer
  • @volatile private var server: TransportServer = _
  • // 标识NettyRpcEnv是否停止
  • private val stopped = new AtomicBoolean(false)
  • /**
  • * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
  • * we just put messages to its [[Outbox]] to implement a non-blocking `send` method.
  • * RpcAddress与Outbox的映射关系的缓存。
  • * 每次向远端发送请求时,此请求消息首先放入此远端地址对应的Outbox,然后使用线程异步发送。
  • */
  • private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
  • ...

其中有我们熟悉的TransportConf、NettyStreamManager、TransportContext、TransportClientBootstrap、SaslClientBootstrap、TransportClientFactory和TransportServer,对于这些组件的创建注释已经讲解得非常清楚,这里不再赘述。

也有陌生的组件,如Dispatcher、NettyRpcHandler和Outbox,后面将详细介绍。

4. Dispatcher

Dispatcher是Spark RPC环境里最重要的一个通信组件,我们通常叫它消息调度器,它通常用于缓存来自各个节点的请求消息,然后用线程池对其进行处理;从它的创建时机我们可以知道,每一个NettyRpcEnv环境都有一个Dispatcher组件,也就是说,每一个Spark节点都会有一个Dispatcher组件。它的定义和重要字段如下:

  • /**
  • * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
  • */
  • private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
  • ...
  • // 端点实例名称与端点数据EndpointData之间映射关系的缓存,可以使用端点名称从中快速获取或删除EndpointData。
  • private val endpoints: ConcurrentMap[String, EndpointData] =
  • new ConcurrentHashMap[String, EndpointData]
  • // 端点实例RpcEndpoint与端点实例引用RpcEndpointRef之间映射关系的缓存,可以使用端点实例从中快速获取或删除端点实例引用。
  • private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] =
  • new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
  • // Track the receivers whose inboxes may contain messages.
  • // 存储端点数据EndpointData的阻塞队列。只有Inbox中有消息的EndpointData才会被放入此阻塞队列。
  • private val receivers = new LinkedBlockingQueue[EndpointData]
  • /**
  • * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
  • * immediately.
  • *
  • * Dispatcher是否停止的状态
  • */
  • @GuardedBy("this")
  • private var stopped = false
  • ...
  • }

Dispatcher接收NettyRpcEnv传入的NettyRpcEnv对象,也即是NettyRpcEnv实例自己;在Dispatcher中,存在三个集合对象,endpoints保存了RpcEndpoint实例名与EndpointData实例的映射,endpointRefs保存了RpcEndpoint实例与RpcEndpointRef实例的映射,receivers是一个阻塞队列,保存了EndpointData实例。

4.1. EndpointData

EndpointData是Dispatcher的私有内部类,根据上面Dispatcher中字段可知,每个RpcEndpoint端点名称都有对应的EndpointData实例;它的定义如下:

org.apache.spark.rpc.netty.Dispatcher.EndpointData
  • // RPC端点数据,Inbox与RpcEndpoint、NettyRpcEndpointRef通过此EndpointData相关联。
  • private class EndpointData(
  • val name: String,
  • val endpoint: RpcEndpoint,
  • val ref: NettyRpcEndpointRef) {
  • val inbox = new Inbox(ref, endpoint)
  • }

EndpointData初始化时内部会创建一个Inbox对象,我们后面再讨论该类的实现。对于EndpointData来说,其实我们比较关注它的创建和销毁操作。

4.2. RpcEndpoint的维护

EndpointData的存在与RpcEndpoint及RpcEndpointRef息息相关。从前面对NettyRpcEnv的分析可知,它对RpcEnv指定的setupEndpoint(...)endpointRef(...)两个方法的具体实现都是通过Dispatcher来完成的,其中setupEndpoint(...)用于注册RpcEndpoint,内部使用的是Dispatcher的registerRpcEndpoint(...)方法,endpointRef(...)用于查找指定的RpcEndpoint所对应的RpcEndpointRef,内部使用的是getRpcEndpointRef(...)方法;读者可以自行回顾源码。

不过除了注册和查找,Dispatcher还维护了对RpcEndpoint的关闭、移除等操作,下面将一一介绍。

4.2.1. 注册RpcEndpoint

Dispatcher的registerRpcEndpoint(...)方法用于注册RpcEndpoint,它会根据指定的RpcEndpoint端点名称和RpcEndpoint对象,创建NettyRpcEndpointRef和EndpointData实例,维护endpointsendpointRefs两个集合,源码如下:

  • // 注册RpcEndpoint
  • def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  • // 根据RpcEndpoint所在NettyRpcEnv的地址和名称构造RpcEndpointAddress对象
  • val addr = RpcEndpointAddress(nettyEnv.address, name)
  • // 构造NettyRpcEndpointRef对象
  • val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  • synchronized {
  • // 检查状态
  • if (stopped) {
  • throw new IllegalStateException("RpcEnv has been stopped")
  • }
  • /**
  • * 将RpcEndpoint、NettyRpcEndpointRef包装为EndpointData对象,
  • * 并放入endpoints字典中,如果返回值不为null说明已经存在了同名的
  • */
  • if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
  • throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
  • }
  • // 放入成功,还需要将对应的RpcEndpoint和RpcEndpointRef存入endpointRefs字典
  • val data = endpoints.get(name)
  • endpointRefs.put(data.endpoint, data.ref)
  • /**
  • * 将EndpointData放入到receivers队尾,MessageLoop线程异步获取到此EndpointData,
  • * 并处理其Inbox中刚刚放入的OnStart消息,注意该OnStart消息是在Inbox初始化时放入的
  • * 最终调用RpcEndpoint的OnStart方法在RpcEndpoint开始处理消息之前做一些准备工作。
  • */
  • receivers.offer(data) // for the OnStart message
  • }
  • endpointRef
  • }

可见,对于注册的RpcEndpoint端点,会创建包装了其地址和指定名称的RpcEndpointAddress地址对象,然后将该地址对象与SparkConf、Dispatcher所属的NettyRpcEnv对象封装到一个RpcEndpointRef引用中。

然后,会以参数指定的名称为键,包装了指定名称、RpcEndpoint和对应RpcEndpointRef引用的EndpointData实例为值,尝试存放到endpoints字典中;同时RpcEndpoint和与之对应的RpcEndpointRef也会被存放到endpointRefs字典中。

registerRpcEndpoint(...)方法中,还有一个非常重要的操作,即它会将参数指定的名称所对应的EndpointData实例添加到receivers队列,以处理OnStart消息。读者朋友分析到这里可能觉得非常奇怪,不知道OnStart消息是什么,从哪里来;这需要在后面的Inbox类中详细介绍。

4.2.2. 获取RpcEndpointRef

registerRpcEndpoint(...)方法会为注册的RpcEndpoint创建一个对应的RpcEndpointRef引用对象,并存放到内部的集合结构中,有存必有取,getRpcEndpointRef(...)方法可以根据指定的RpcEndpoint获取对应的RpcEndpointRef,实现非常简单,就是从endpointRefs字典中获取:

org.apache.spark.rpc.netty.Dispatcher#getRpcEndpointRef
  • def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)

4.2.3. 移除RpcEndpoint

removeRpcEndpointRef(...)方法用于根据指定的RpcEndpoint从endpointRefs字典中移除对应的RpcEndpointRef,源码非常简单:

org.apache.spark.rpc.netty.Dispatcher#removeRpcEndpointRef
  • def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)

4.2.4. 关闭RpcEndpoint

Dispatcher的stop(...)方法用于根据指定的RpcEndpointRef停止RpcEndpoint:

org.apache.spark.rpc.netty.Dispatcher#stop
  • // 取消RpcEndpoint的注册
  • def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
  • synchronized {
  • if (stopped) {
  • // This endpoint will be stopped by Dispatcher.stop() method.
  • return
  • }
  • // 调用unregisterRpcEndpoint()取消注册
  • unregisterRpcEndpoint(rpcEndpointRef.name)
  • }
  • }

可见,其内部其实调用了私有的unregisterRpcEndpoint(...)方法取消对应RpcEndpoint的注册:

org.apache.spark.rpc.netty.Dispatcher#unregisterRpcEndpoint
  • // Should be idempotent
  • private def unregisterRpcEndpoint(name: String): Unit = {
  • // 从endpoints字典移除,得到移除的EndpointData
  • val data = endpoints.remove(name)
  • if (data != null) {
  • // 停止EndpointData中的Inbox
  • data.inbox.stop()
  • /**
  • * 将EndpointData放入receivers,
  • * 注意,在Inbox的stop()方法中,会向自己的Message队列放入一个OnStop消息,
  • * Inbox在处理OnStop消息时,会调用Dispatcher的removeRpcEndpointRef()方法移除对应的RpcEndpoint,
  • * 并调用RpcEndpoint的onStop()方法告知该RpcEndpoint已暂停,
  • * 可以参考 {@link Inbox#process} 方法处理OnStop消息的分支
  • */
  • receivers.offer(data) // for the OnStop message
  • }
  • // Don't clean `endpointRefs` here because it's possible that some messages are being processed
  • // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
  • // `removeRpcEndpointRef`.
  • }

取消注册操作首先将对应的EndpointData从endpoints中移除了,然后调用EndpointData内部Inbox的stop()方法进行停止操作,该方法会向Inbox的消息队列中放入一个OnStop消息,Inbox对OnStop消息的处理会触发Dispatcher移除并关闭对应的RpcEndpoint。关于Inbox和OnStop消息会在下面的内容中讲解。

4.3. Inbox

Inbox在前面出现的次数很多,意为“收件箱”,它是端点进行接收消息的抽象。Inbox实例是随着EndpointData实例化而创建的,回顾源码:

org.apache.spark.rpc.netty.Dispatcher.EndpointData
  • // RPC端点数据,Inbox与RpcEndpoint、NettyRpcEndpointRef通过此EndpointData相关联。
  • private class EndpointData(
  • val name: String,
  • val endpoint: RpcEndpoint,
  • val ref: NettyRpcEndpointRef) {
  • val inbox = new Inbox(ref, endpoint)
  • }

从上面对registerRpcEndpoint(...)方法的分析可知,每个RPC端点都会对应一个EndpointData实例,因此每个RPC端点都有一个自己的Inbox实例。Inbox的定义和重要字段如下:

org.apache.spark.rpc.netty.Inbox
  • /**
  • * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely.
  • *
  • * 端点内的盒子。
  • * 每个RpcEndpoint都有一个对应的盒子,这个盒子里有个存储InboxMessage消息的列表messages。
  • * 所有的消息将缓存在messages列表里面,并由RpcEndpoint异步处理这些消息。
  • */
  • private[netty] class Inbox(
  • val endpointRef: NettyRpcEndpointRef,
  • val endpoint: RpcEndpoint)
  • extends Logging {
  • inbox => // Give this an alias so we can use it more clearly in closures.
  • // 使用链表容器保存Box内的消息
  • @GuardedBy("this")
  • protected val messages = new java.util.LinkedList[InboxMessage]()
  • /**
  • * True if the inbox (and its associated endpoint) is stopped.
  • * 标识当前Box是否停止
  • **/
  • @GuardedBy("this")
  • private var stopped = false
  • /**
  • * Allow multiple threads to process messages at the same time.
  • * 是否允许线程并发访问
  • **/
  • @GuardedBy("this")
  • private var enableConcurrent = false
  • /**
  • * The number of threads processing messages for this inbox.
  • * 处理Box中消息的线程数量
  • **/
  • @GuardedBy("this")
  • private var numActiveThreads = 0
  • // OnStart should be the first message to process
  • // InBox启动时,会默认放入一条OnStart消息
  • inbox.synchronized {
  • messages.add(OnStart)
  • }
  • ...
  • }

其中,stoppedenableConcurrentnumActiveThreads其实都比较好理解,注释都说明得比较清楚了,这里我们关注一下messages字段,它是一个LinkedList链表,泛型为InboxMessage,这是由于,每个进入Inbox的消息都是InboxMessage类型的,它会被添加到messages链表的尾部。

通过上面对EndpointData的和Inbox结构的分析,我们能够得到Dispatcher内部的结构示意图,如下:

1.Dispatcher内部结构示意图.png

在上面Inbox类定义的最后,默认向messages链表尾部添加了一个OnStart消息,而这段代码会在Inbox启动的时候就被执行,因此当EndpointData被创建时,其对应的Inbox内就已经放入了一条OnStart消息,这也就解释了前面Dispatcher的registerRpcEndpoint(...)方法最后是为了处理Inbox消息了。

同时,在上面讲解Dispatcher关闭RpcEndpoint操作时,其私有方法unregisterRpcEndpoint(...)会调用RpcEndpoint对应的Inbox的stop()方法,该方法会向Inbox的消息队列添加一个OnStop消息用于完成RpcEndpoint的关闭,源码回溯如下:

org.apache.spark.rpc.netty.Dispatcher#unregisterRpcEndpoint
  • // Should be idempotent
  • private def unregisterRpcEndpoint(name: String): Unit = {
  • // 从endpoints字典移除,得到移除的EndpointData
  • val data = endpoints.remove(name)
  • if (data != null) {
  • // 停止EndpointData中的Inbox
  • data.inbox.stop()
  • /**
  • * 将EndpointData放入receivers,
  • * 注意,在Inbox的stop()方法中,会向自己的Message队列放入一个OnStop消息,
  • * Inbox在处理OnStop消息时,会调用Dispatcher的removeRpcEndpointRef()方法移除对应的RpcEndpoint,
  • * 并调用RpcEndpoint的onStop()方法告知该RpcEndpoint已暂停,
  • * 可以参考 {@link Inbox#process} 方法处理OnStop消息的分支
  • */
  • receivers.offer(data) // for the OnStop message
  • }
  • // Don't clean `endpointRefs` here because it's possible that some messages are being processed
  • // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via
  • // `removeRpcEndpointRef`.
  • }

data.inbox.stop()调用的是以下方法:

org.apache.spark.rpc.netty.Inbox#stop
  • def stop(): Unit = inbox.synchronized {
  • // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last
  • // message
  • if (!stopped) {
  • // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only
  • // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources
  • // safely.
  • // 不允许并发操作
  • enableConcurrent = false
  • // 停止Inbox
  • stopped = true
  • // 添加OnStop消息
  • messages.add(OnStop)
  • // Note: The concurrent events in messages will be processed one by one.
  • }
  • }

我们了解了OnStart和OnStop方法的投递,下面就来关注投递到Inbox的消息是如何被处理的。

4.4. 消息的处理

知道了Inbox消息从何而来,那Inbox内的消息是怎样被处理的呢?其实在前面讲解Dispatcher,它的threadpool字段我们并没有提到,定义如下:

org.apache.spark.rpc.netty.Dispatcher#threadpool
  • /**
  • * Thread pool used for dispatching messages.
  • * 用于对消息进行调度的线程池。
  • * 此线程池运行的任务都是MessageLoop线程任务
  • **/
  • private val threadpool: ThreadPoolExecutor = {
  • // 调度线程数
  • val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
  • math.max(2, Runtime.getRuntime.availableProcessors()))
  • // 创建固定线程数线程池,线程名前缀为dispatcher-event-loop
  • val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
  • // 调度MessageLoop线程对象
  • for (i <- 0 until numThreads) {
  • pool.execute(new MessageLoop)
  • }
  • pool
  • }

该字段在Dispatcher初始化的时候就被创建了,它会根据spark.rpc.netty.dispatcher.numThreads配置的线程数,启动一个固定线程数量的ThreadPoolExecutor线程池,然后往其中提交了与线程数量个数相同的MessageLoop任务,这个MessageLoop就是用于处理Inbox内消息的任务线程,它的源码如下:

org.apache.spark.rpc.netty.Dispatcher.MessageLoop
  • /** Message loop used for dispatching messages. */
  • private class MessageLoop extends Runnable {
  • override def run(): Unit = {
  • try {
  • while (true) { // 不断循环
  • try {
  • // 从receivers中取出EndpointData对象
  • val data = receivers.take()
  • if (data == PoisonPill) { // PoisonPill意思是"毒药",用于终止当前线程
  • // Put PoisonPill back so that other MessageLoops can see it.
  • // 取出的是"毒药",重新放入队列,以便终止其它线程
  • receivers.offer(PoisonPill)
  • // 直接返回,终止当前线程
  • return
  • }
  • // 处理消息
  • data.inbox.process(Dispatcher.this)
  • } catch {
  • case NonFatal(e) => logError(e.getMessage, e)
  • }
  • }
  • } catch {
  • case ie: InterruptedException => // exit
  • }
  • }
  • }

该任务的主体是一个死循环,它不断地从Dispatcher的receivers阻塞队列中取出EndpointData对象,然后使用该EndpointData对象内部Inbox实例的process(...)方法进行消息处理。

在前面Dispatcher的registerRpcEndpoint(...)方法最后,正是将新创建的EndpointData放入了receivers阻塞队列,因此其Inbox中新放入的OnStart消息也会被处理了。

我们来看看Inbox的process(...)方法的实现,该方法代码非常多,但总体来说比较简单,源码如下:

org.apache.spark.rpc.netty.Inbox#process
  • /**
  • * Process stored messages.
  • * 处理消息
  • */
  • def process(dispatcher: Dispatcher): Unit = {
  • var message: InboxMessage = null
  • inbox.synchronized { // 并发控制
  • // 如果不允许并发操作,但已激活线程数不为0,则说明已有线程在处理消息
  • if (!enableConcurrent && numActiveThreads != 0) {
  • // 直接返回
  • return
  • }
  • // 从自己的message链表头取出消息
  • message = messages.poll()
  • if (message != null) { // 消息不为空
  • // 激活线程数自增
  • numActiveThreads += 1
  • } else {
  • // 消息为空,直接返回
  • return
  • }
  • }
  • // 走到这里说明取到消息了
  • while (true) {
  • safelyCall(endpoint) { // 对下面操作中出现非致命的异常,都会传递给endpoint的onError()方法
  • // 根据消息类型进行匹配,分别处理
  • message match {
  • case RpcMessage(_sender, content, context) =>
  • try {
  • // 发送并要求回复
  • endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
  • throw new SparkException(s"Unsupported message $message from ${_sender}")
  • })
  • } catch {
  • case NonFatal(e) =>
  • context.sendFailure(e)
  • // Throw the exception -- this exception will be caught by the safelyCall function.
  • // The endpoint's onError function will be called.
  • throw e
  • }
  • case OneWayMessage(_sender, content) =>
  • // 发送不要求回复
  • endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
  • throw new SparkException(s"Unsupported message $message from ${_sender}")
  • })
  • case OnStart =>
  • // 收到Inbox的OnStart消息,调用RpcEndpoint的onStart()方法告知该RpcEndpoint已启动
  • endpoint.onStart()
  • // endpoint不是要求线程安全的RpcEndpoint
  • if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
  • inbox.synchronized {
  • if (!stopped) {
  • // 则允许Inbox的并发操作
  • enableConcurrent = true
  • }
  • }
  • }
  • case OnStop =>
  • // 激活线程数
  • val activeThreads = inbox.synchronized { inbox.numActiveThreads }
  • assert(activeThreads == 1,
  • s"There should be only a single active thread but found $activeThreads threads.")
  • // 从Dispatcher中移除对应的RpcEndpoint
  • dispatcher.removeRpcEndpointRef(endpoint)
  • // 调用RpcEndpoint的onStop()方法告知该RpcEndpoint已暂停
  • endpoint.onStop()
  • // OnStop消息应该是最后一条消息
  • assert(isEmpty, "OnStop should be the last message")
  • case RemoteProcessConnected(remoteAddress) =>
  • // 调用RpcEndpoint的onConnected()方法告知该RpcEndpoint收到远程连接
  • endpoint.onConnected(remoteAddress)
  • case RemoteProcessDisconnected(remoteAddress) =>
  • // 调用RpcEndpoint的onDisconnected()方法告知该RpcEndpoint断开远程连接
  • endpoint.onDisconnected(remoteAddress)
  • case RemoteProcessConnectionError(cause, remoteAddress) =>
  • // 调用RpcEndpoint的onNetworkError()方法告知该RpcEndpoint处理连接错误
  • endpoint.onNetworkError(cause, remoteAddress)
  • }
  • }
  • inbox.synchronized { // 加锁
  • // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
  • // every time.
  • // 不允许并发操作且激活线程数为1时
  • if (!enableConcurrent && numActiveThreads != 1) {
  • // If we are not the only one worker, exit
  • // 需要需要减少激活线程数
  • numActiveThreads -= 1
  • return
  • }
  • // 再次尝试取出一条消息,如果取不到则将激活线程数再减1
  • message = messages.poll()
  • if (message == null) {
  • numActiveThreads -= 1
  • return
  • }
  • }
  • }
  • }

process(...)方法会从Inbox自己的message链表中取出消息,如果消息不为null则自增numActiveThreads计数器,然后针对消息的类型分别进行处理,可见,大部分的操作就交给了Inbox所对应的RpcEndpoint及Dispatcher进行处理。这里我们先了解一下OnStart和OnStop消息的处理,其他消息将在后面的章节中详细介绍。

4.4.1. OnStart消息的处理

处理OnStart消息的源码片段如下:

org.apache.spark.rpc.netty.Inbox#process
  • case OnStart =>
  • // 收到Inbox的OnStart消息,调用RpcEndpoint的onStart()方法告知该RpcEndpoint已启动
  • endpoint.onStart()
  • // endpoint不是要求线程安全的RpcEndpoint
  • if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
  • inbox.synchronized {
  • if (!stopped) {
  • // 则允许Inbox的并发操作
  • enableConcurrent = true
  • }
  • }
  • }

前面分析过,当一个RpcEndpoint初次被注册时,它对应的EndpointData就会被创建,此时对应Inbox会自己向自己投递一个OnStart消息;从OnStart消息的处理流程可知,它会调用注册的RpcEndpoint的onStart()方法通知RpcEndpoint是时候启动了。

4.4.2. OnStop消息的处理

处理OnStop消息的源码片段如下:

org.apache.spark.rpc.netty.Inbox#process
  • case OnStop =>
  • // 激活线程数
  • val activeThreads = inbox.synchronized { inbox.numActiveThreads }
  • assert(activeThreads == 1,
  • s"There should be only a single active thread but found $activeThreads threads.")
  • // 从Dispatcher中移除对应的RpcEndpoint
  • dispatcher.removeRpcEndpointRef(endpoint)
  • // 调用RpcEndpoint的onStop()方法告知该RpcEndpoint已暂停
  • endpoint.onStop()
  • // OnStop消息应该是最后一条消息
  • assert(isEmpty, "OnStop should be the last message")

前面分析过,当NettyRpcEnv在调用自己的stop(...)方法停止指定的RpcEndpoint时,最终会触发向RpcEndpoint对应的Inbox中投递OnStop消息;对OnStop消息的处理中,会从Dispatcher里移除RpcEndpoint的注册,然后调用该RpcEndpoint的onStop()方法通知RpcEndpoint是时候停止了。

通过分析消息的处理流程,我们能够得到Inbox对消息进行处理的示意图:

2.Inbox消息处理示意图.png

4.5. 消息的投递

明白了OnStart消息的产生以及消息的处理,我们来考察一下Dispatcher中对消息投递实现的方法。Dispatcher提供了多个消息投递方法,下面分别介绍。

4.5.1. 私有投递方法

postMessage(...)是Dispatcher的私有方法,但它是所有消息投递的底层方法,Dispatcher对外开放的投递方法底层都调用了该方法;它的源码如下:

org.apache.spark.rpc.netty.Dispatcher#postMessage
  • /**
  • * Posts a message to a specific endpoint.
  • *
  • * 将消息投递给特定的RpcEndpoint
  • *
  • * @param endpointName name of the endpoint.
  • * @param message the message to post
  • * @param callbackIfStopped callback function if the endpoint is stopped.
  • */
  • private def postMessage(
  • endpointName: String,
  • message: InboxMessage,
  • callbackIfStopped: (Exception) => Unit): Unit = {
  • val error = synchronized { // 加锁
  • // 获取对应的EndpointData
  • val data = endpoints.get(endpointName)
  • if (stopped) { // 判断Dispatcher是否在运行
  • Some(new RpcEnvStoppedException())
  • } else if (data == null) { // 获取的EndpointData为空
  • Some(new SparkException(s"Could not find $endpointName."))
  • } else {
  • // 将消息添加到EndpointData中的Inbox中
  • data.inbox.post(message)
  • // 将EndpointData放入receivers队列等待处理
  • receivers.offer(data)
  • None
  • }
  • }
  • // We don't need to call `onStop` in the `synchronized` block
  • // 有错误,交给回调
  • error.foreach(callbackIfStopped)
  • }

可见,postMessage(...)方法会根据发送消息的客户端地址的RpcEndpoint名称去endpoints中查找对应的EndpointData对象,如果找到了,就讲投递进来的消息存放到EndpointData的Inbox中,然后将该EndpointData放入receivers队列等待MessageLoop任务处理。如果Dispatcher停止了,或者是获取不到对应的EndpointData都会抛出异常,异常会交给回调函数callbackIfStopped处理。

4.5.2. 投递OneWayMessage

postOneWayMessage(...)用于专门负责投递OneWayMessage消息,源码如下:

org.apache.spark.rpc.netty.Dispatcher#postOneWayMessage
  • /** Posts a one-way message.
  • * 投递不需要回复的RPC消息
  • **/
  • def postOneWayMessage(message: RequestMessage): Unit = {
  • // 使用postMessage()方法发送
  • postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
  • (e) => throw e)
  • }

由于该消息不用回复,因此其内部直接调用postMessage(...)方法投递,传入的回调函数会直接抛出产生的异常。

4.5.3. 投递来自远端RpcEndpoint的消息

postRemoteMessage(...)方法用于投递来自远端RpcEndpoint的消息,它的源码如下:

org.apache.spark.rpc.netty.Dispatcher#postRemoteMessage
  • /**
  • * Posts a message sent by a remote endpoint.
  • * 投递来自远程RpcEndpoint发送的消息
  • **/
  • def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
  • // 回调上下文
  • val rpcCallContext =
  • new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
  • // 构造RpcMessage,包装了回调上下文
  • val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  • // 使用postMessage()方法发送
  • postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
  • }

它会将NettyRpcEnv对象、对应的回调对象以及RequestMessage消息对象再次封装为RpcMessage对象,然后调用postMessage(...)方法进行处理,传入的回调函数会在抛出异常的时候调用callbackonFailure(...)方法。

4.5.4. 投递本地消息

postLocalMessage(...)方法用于投递来自本地RpcEndpoint发送的消息;RpcEndpoint端点内部的通信也是通过消息机制实现的,因此RpcEndpoint自己向自己投递消息的情况很常见;对于这种本地消息则单独由postLocalMessage(...)方法来处理,它的源码如下:

  • /**
  • * Posts a message sent by a local endpoint.
  • * 投递来自本地RpcEndpoint发送的消息
  • **/
  • def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
  • // 回调上下文
  • val rpcCallContext =
  • new LocalNettyRpcCallContext(message.senderAddress, p)
  • // 构造RPCMessage,包装了回调上下文
  • val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  • // 使用postMessage()方法发送
  • postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
  • }

postLocalMessage(...)方法将传入的回调Promise和RequestMessage消息对象再次封装为RpcMessage对象,然后调用postMessage(...)方法进行处理,传入的回调函数(e) => p.tryFailure(e)会在抛出异常的时候调用类型为Promise的参数ptryFailure(...)方法。

与投递来自远端RpcEndpoint的消息相比,本地消息构造的NettyRpcCallContext不一样,postRemoteMessage(...)方法中构造的NettyRpcCallContext是RemoteNettyRpcCallContext,而postLocalMessage(...)方法中的则是LocalNettyRpcCallContext,下面我们来讨论一下它们的实现。

4.5.5. 投递消息的回调

RemoteNettyRpcCallContext和LocalNettyRpcCallContext是用于对回调进行封装的上下文,它们都继承自NettyRpcCallContext抽象类,而NettyRpcCallContext又实现了RpcCallContext特质:

org.apache.spark.rpc.RpcCallContext
  • /**
  • * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe
  • * and can be called in any thread.
  • */
  • private[spark] trait RpcCallContext {
  • /**
  • * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]]
  • * will be called.
  • *
  • * 用于向发送者回复信息。
  • */
  • def reply(response: Any): Unit
  • /**
  • * Report a failure to the sender.
  • *
  • * 用于向发送者发送失败信息。
  • */
  • def sendFailure(e: Throwable): Unit
  • /**
  • * The sender of this message.
  • *
  • * 用于获取发送者的地址。
  • */
  • def senderAddress: RpcAddress
  • }

NettyRpcCallContext抽象类重写了RpcCallContext定义的几个规范方法:

org.apache.spark.rpc.netty.NettyRpcCallContext
  • private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress)
  • extends RpcCallContext with Logging {
  • // 用于发送。
  • protected def send(message: Any): Unit
  • // 用于向发送者回复信息。
  • override def reply(response: Any): Unit = {
  • send(response)
  • }
  • // 用于向发送者发送失败信息。
  • override def sendFailure(e: Throwable): Unit = {
  • send(RpcFailure(e))
  • }
  • }

可见,NettyRpcCallContext虽然实现了RpcCallContext定义的接口方法,但又定义了自己的send(...)方法要求子类实现;RpcCallContext规定的reply(...)sendFailure(...)方法都调用了send(...)方法。

RemoteNettyRpcCallContext用于处理远程回复的回调上下文,它的源码如下:

org.apache.spark.rpc.netty.RemoteNettyRpcCallContext
  • /**
  • * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back.
  • */
  • private[netty] class RemoteNettyRpcCallContext(
  • nettyEnv: NettyRpcEnv,
  • callback: RpcResponseCallback,
  • senderAddress: RpcAddress)
  • extends NettyRpcCallContext(senderAddress) {
  • // 向客户端发送消息
  • override protected def send(message: Any): Unit = {
  • // 序列化消息
  • val reply = nettyEnv.serialize(message)
  • // 使用回调的onSuccess()方法进行发送
  • callback.onSuccess(reply)
  • }
  • }

RemoteNettyRpcCallContext重写了send(...),它会使用构造RemoteNettyRpcCallContext时传入的RpcResponseCallback类型的callback回调函数的onSuccess(...)方法发送消息。

读者可能对callback这个回调函数的来源有疑惑,其实这个回调来自于传输层TransportRequestHandler;以RPC消息为例,我们来跟踪一下:

  1. TransportRequestHandler处理RpcRequest消息时,会调用RpcHandler的receive(...)方法,此时传给该方法的最后一个参数就是RpcResponseCallback回调:
org.apache.spark.network.server.TransportRequestHandler#processRpcRequest
  • // 处理需要回复的RPC请求
  • private void processRpcRequest(final RpcRequest req) {
  • ...
  • rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
  • @Override
  • public void onSuccess(ByteBuffer response) {
  • respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
  • }
  • @Override
  • public void onFailure(Throwable e) {
  • respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
  • }
  • });
  • ...
  • }
  1. 由于我们这里使用的是NettyRpcHandler,它在处理RpcRequest的时候,最终会把传入的RpcResponseCallback回调对象传给Dispatcher的postRemoteMessage(...)方法:
org.apache.spark.rpc.netty.NettyRpcHandler#receive
  • override def receive(
  • client: TransportClient,
  • message: ByteBuffer,
  • callback: RpcResponseCallback): Unit = {
  • // 转换消息数据为RequestMessage对象
  • val messageToDispatch = internalReceive(client, message)
  • // 将消息投递到Dispatcher中对应的Inbox中
  • dispatcher.postRemoteMessage(messageToDispatch, callback)
  • }
  1. Dispatcher的postRemoteMessage(...)方法会将该RpcResponseCallback回调对象封装到RemoteNettyRpcCallContext中:
  • def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
  • // 回调上下文
  • val rpcCallContext =
  • new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
  • // 构造RpcMessage,包装了回调上下文
  • val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  • // 使用postMessage()方法发送
  • postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
  • }

可见,对于客户端发送的RpcRequest消息而言,RemoteNettyRpcCallContext的send(...)方法把消息交给RpcResponseCallback的onSuccess(...)方法后,会由TransportRequestHandler构造为RpcResponse消息,然后经由其respond()通过Channel返回给发送RpcRequest的客户端。

4.6. 终止消息处理

从上面的分析我们知道,MessageLoop任务使用了while死循环的方式一直在线程池中运行,关于它的终止使用了一种巧妙的方式。在Dispatcher中定义了一个特殊的EndpointData,称之为“毒药”:

org.apache.spark.rpc.netty.Dispatcher#PoisonPill
  • /** A poison endpoint that indicates MessageLoop should exit its message loop.
  • * 毒药消息
  • **/
  • private val PoisonPill = new EndpointData(null, null, null)

在MessageLoop任务中,如果从Dispatcher的receivers阻塞队列中取到的EndpointData是该“毒药”消息,会将其放回然后直接return终止无限while循环,回顾源码:

  • /** Message loop used for dispatching messages. */
  • private class MessageLoop extends Runnable {
  • override def run(): Unit = {
  • try {
  • while (true) { // 不断循环
  • try {
  • // 从receivers中取出EndpointData对象
  • val data = receivers.take()
  • if (data == PoisonPill) { // PoisonPill意思是"毒药",用于终止当前线程
  • // Put PoisonPill back so that other MessageLoops can see it.
  • // 取出的是"毒药",重新放入队列,以便终止其它线程
  • receivers.offer(PoisonPill)
  • // 直接返回,终止当前线程
  • return
  • }
  • // 处理消息
  • data.inbox.process(Dispatcher.this)
  • } catch {
  • case NonFatal(e) => logError(e.getMessage, e)
  • }
  • }
  • } catch {
  • case ie: InterruptedException => // exit
  • }
  • }
  • }

之所以将“毒药”消息放回receivers队列,是为了终止其他的MessageLoop任务。

当调用Dispatcher的stop()方法终止Dispatcher时,会向自己的receivers队列中投放“毒药”消息:

org.apache.spark.rpc.netty.Dispatcher#stop
  • // 停止Dispatcher
  • def stop(): Unit = {
  • synchronized {
  • // 如果已经停止则直接返回
  • if (stopped) {
  • return
  • }
  • stopped = true
  • }
  • // Stop all endpoints. This will queue all endpoints for processing by the message loops.
  • // 将endpoints中所有的EndpointData全部移除,该操作会停止对应的Inbox
  • endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)
  • // Enqueue a message that tells the message loops to stop.
  • // 向receivers队尾放入"毒药"消息,它会控制关闭MessageLoop线程
  • receivers.offer(PoisonPill)
  • // 关掉线程池
  • threadpool.shutdown()
  • }