pyspark与py4j线程模型简析

事由

上周工作中遇到一个bug,现象是一个spark streaming的job会不定期地hang住,不退出也不继续运行。这个job经是用pyspark写的,以kafka为数据源,会在每个batch结束时将统计结果写入mysql。经过排查,我们在driver进程中发现有有若干线程都出于Sl状态(睡眠状态),进而使用gdb调试发现了一处死锁。

这是MySQLdb库旧版本中的一处bug,在此不再赘述,有兴趣的可以看这个issue。不过这倒是提起了我对另外一件事的兴趣,就是driver进程——严格的说应该是driver进程的python子进程——中的这些线程是从哪来的?当然,这些线程的存在很容易理解,我们开启了spark.streaming.concurrentJobs参数,有多个batch可以同时执行,每个线程对应一个batch。但翻遍pyspark的python代码,都没有找到有相关线程启动的地方,于是简单调研了一下pyspark到底是怎么工作的,做个记录。

本文概括

  1. Py4J的线程模型
  2. pyspark基本原理(driver端)
  3. CPython中的deque的线程安全

涉及软件版本

  • spark: 2.1.0
  • py4j: 0.10.4

Py4J

spark是由scala语言编写的,pyspark并没有像豆瓣开源的dpark用python复刻了spark,而只是提供了一层可以与原生JVM通信的python API,Py4J就是python与JVM之间的这座桥梁。这个库分为Java和Python两部分,基本原理是:

  1. Java部分,通过py4j.GatewayServer监听一个tcp socket(记做server_socket)
  2. Python部分,所有对JVM中对象的访问或者方法的调用,都是通过py4j.JavaGateway向上面这个socket完成的。
  3. 另外,Python部分在创建JavaGateway对象时,可以选择同时创建一个CallbackServer,它会在Python这册监听一个tcp socket(记做callback_socket),用来给Java回调Python代码提供一条渠道。
  4. Py4J提供了一套文本协议用来在tcp socket间传递命令。

pyspark driver工作流程

  1. 首先,一个spark job被提交后,如果被判定这是一个python的job,spark driver会找到相应的入口,即org.apache.spark.deploy.PythonRunnermain函数,这个函数中会启动GatewayServer
    // Launch a Py4J gateway server for the process to connect to; this will let it see our
    // Java system properties and such
    val gatewayServer = new py4j.GatewayServer(null, 0)
    val thread = new Thread(new Runnable() {
      override def run(): Unit = Utils.logUncaughtExceptions {
        gatewayServer.start()
      }
    })
    thread.setName("py4j-gateway-init")
    thread.setDaemon(true)
    thread.start()
  1. 然后,会创建一个Python子进程来运行我们提交上来的python入口文件,并把刚才GatewayServer监听的那个端口写入到子进程的环境变量中去(这样Python才知道要通过那个端口访问JVM)
    // Launch Python process
    val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
    val env = builder.environment()
    env.put("PYTHONPATH", pythonPath)
    // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
    // pass conf spark.pyspark.python to python process, the only way to pass info to
    // python process is through environment variable.
    sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
    builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
  1. Python子进程这边,我们是通过pyspark提供的python API编写的这个程序,在创建SparkContext(python)时,会初始化_gateway变量(JavaGateway对象)和_jvm变量(JVMView对象)
    @classmethod
    def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
        """
        Checks whether a SparkContext is initialized or not.
        Throws error if a SparkContext is already running.
        """
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = gateway or launch_gateway(conf)
                SparkContext._jvm = SparkContext._gateway.jvm

            if instance:
                if (SparkContext._active_spark_context and
                        SparkContext._active_spark_context != instance):
                    currentMaster = SparkContext._active_spark_context.master
                    currentAppName = SparkContext._active_spark_context.appName
                    callsite = SparkContext._active_spark_context._callsite

                    # Raise error if there is already a running Spark context
                    raise ValueError(
                        "Cannot run multiple SparkContexts at once; "
                        "existing SparkContext(app=%s, master=%s)"
                        " created by %s at %s:%s "
                        % (currentAppName, currentMaster,
                            callsite.function, callsite.file, callsite.linenum))
                else:
                    SparkContext._active_spark_context = instance

其中launch_gateway函数可见pyspark/java_gateway.py

  1. 上面初始化的这个_jvm对象值得一说,在pyspark中很多对JVM的调用其实都是通过它来进行的,比如很多python种对应的spark对象都有一个_jsc变量,它是JVM中的SparkContext对象在Python中的影子,它是这么初始化的
    def _initialize_context(self, jconf):
        """
        Initialize SparkContext in function to allow subclass specific initialization
        """
        return self._jvm.JavaSparkContext(jconf)

这里_jvm为什么能直接调用JavaSparkContext这个JVM环境中的构造函数呢?我们看JVMView中的__getattr__方法:

    def __getattr__(self, name):
        if name == UserHelpAutoCompletion.KEY:
            return UserHelpAutoCompletion()

        answer = self._gateway_client.send_command(
            proto.REFLECTION_COMMAND_NAME +
            proto.REFL_GET_UNKNOWN_SUB_COMMAND_NAME + name + "\n" + self._id +
            "\n" + proto.END_COMMAND_PART)
        if answer == proto.SUCCESS_PACKAGE:
            return JavaPackage(name, self._gateway_client, jvm_id=self._id)
        elif answer.startswith(proto.SUCCESS_CLASS):
            return JavaClass(
                answer[proto.CLASS_FQN_START:], self._gateway_client)
        else:
            raise Py4JError("{0} does not exist in the JVM".format(name))

self._gateway_client.send_command其实就是向server_socket发送访问对象请求的命令了,最后根据响应值生成不同类型的影子对象,针对我们这里的JavaSparkContext,就是一个JavaClass对象。这个系列的类型还包括了JavaMemberJavaPackage等等,他们也通过__getattr__来实现Java对象属性访问以及方法的调用。

  1. 我们刚才介绍Py4j时说过Python端在创建JavaGateway时,可以选择同时创建一个CallbackClient,默认情况下,一个普通的pyspark job是不会启动回调服务的,因为用不着,所有的交互都是Python --> JVM这种模式的。那什么时候需要呢?streaming job就需要(具体流程我们稍后介绍),这就(终于!)引出了我们今天主要讨论的Py4J线程模型的问题。

Py4J线程模型

我们已经知道了Python与JVM双方向的通信分别是通过server_socketcallack_socket来完成的,这两个socket的处理模型都是多线程模型,即,每收到一个连接就启动一个线程来处理。我们只看Python --> JVM这条通路的情况,另外一边是一样的

Server端(Java)

    protected void processSocket(Socket socket) {
        try {
            this.lock.lock();
            if(!this.isShutdown) {
                socket.setSoTimeout(this.readTimeout);
                Py4JServerConnection gatewayConnection = this.createConnection(this.gateway, socket);
                this.connections.add(gatewayConnection);
                this.fireConnectionStarted(gatewayConnection);
            }
        } catch (Exception var6) {
            this.fireConnectionError(var6);
        } finally {
            this.lock.unlock();
        }
    }

继续看createConnection:

    protected Py4JServerConnection createConnection(Gateway gateway, Socket socket) throws IOException {
        GatewayConnection connection = new GatewayConnection(gateway, socket, this.customCommands, this.listeners);
        connection.startConnection();
        return connection;
    }

其中connection.startConnection其实就是创建了一个新线程,来负责处理这个连接。

Client端(Python)

我们来看GatewayClient中的send_command方法:

    def send_command(self, command, retry=True, binary=False):
        """Sends a command to the JVM. This method is not intended to be
           called directly by Py4J users. It is usually called by
           :class:`JavaMember` instances.

        :param command: the `string` command to send to the JVM. The command
         must follow the Py4J protocol.

        :param retry: if `True`, the GatewayClient tries to resend a message
         if it fails.

        :param binary: if `True`, we won't wait for a Py4J-protocol response
         from the other end; we'll just return the raw connection to the
         caller. The caller becomes the owner of the connection, and is
         responsible for closing the connection (or returning it this
         `GatewayClient` pool using `_give_back_connection`).

        :rtype: the `string` answer received from the JVM (The answer follows
         the Py4J protocol). The guarded `GatewayConnection` is also returned
         if `binary` is `True`.
        """
        connection = self._get_connection()
        try:
            response = connection.send_command(command)
            if binary:
                return response, self._create_connection_guard(connection)
            else:
                self._give_back_connection(connection)
        except Py4JNetworkError as pne:
            if connection:
                reset = False
                if isinstance(pne.cause, socket.timeout):
                    reset = True
                connection.close(reset)
            if self._should_retry(retry, connection, pne):
                logging.info("Exception while sending command.", exc_info=True)
                response = self.send_command(command, binary=binary)
            else:
                logging.exception(
                    "Exception while sending command.")
                response = proto.ERROR

        return response

这里这个self._get_connection是这么实现的

    def _get_connection(self):
        if not self.is_connected:
            raise Py4JNetworkError("Gateway is not connected.")
        try:
            connection = self.deque.pop()
        except IndexError:
            connection = self._create_connection()
        return connection

这里使用了一个deque(也就是Python标准库中的collections.deque)来维护一个连接池,如果有空闲的连接,就可以直接使用,如果没有,就新建一个连接。现在问题来了,如果deque不是线程安全的,那么这段代码在多线程环境就会有问题。那么deque是不是线程安全的呢?

deque的线程安全

当然是了,Py4J当然不会犯这样的低级错误,我们看标准库的文档:

Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.

是线程安全的,不过措辞有点模糊,没有明确指出哪些方法是线程安全的,不过可以明确的是至少append的pop都是。之所以去查一下,是因为我也有点含糊,因为Python标准库还有另外一个Queue.Queue,在多线程编程中经常使用,肯定是线程安全的,于是很容易误以为deque不是线程安全的,所以我们才要一个新的Queue。这个问题,推荐阅读stackoverflow上Jonathan的这个答案——他的回答不是被采纳的最高票,不过我认为他的回答比高票更有说服力

  1. 高票答案一直强调说deque是线程安全的这个事实是个意外,是CPython中存在GIL造成的,其他Python解释器就不一定遵守。关于这一点我是不认同的,deque在CPython中的实现确实依赖的GIL才变成了线程安全的,但deque的双端append的pop是线程安全的这件事是白纸黑字写在Python文档中的,其他虚拟机的实现必须遵守,否则就不能称之为合格的Python实现。
  2. 那为什么还要有一个内部显式用了锁来做线程同步的Queue.Queue呢?Jonathan给出的回答是Queueputget可以是blocking的,而deque不行,这样一来,当你需要在多个线程中进行通信时(比如最简单的一个Producer – Consumer模式的实现),Queue往往是最佳选择。

关于deque是否是线程安全这个问题,我将调研的结果写在了这个知乎问题的答案下Python中的deque是线程安全的吗?,就不在赘述了,这篇文章已经太长了。

关于Py4J线程模型的问题,还可以参考官方文档中的解释

pyspark streaming与CallbackServer

刚才提到,如果是streaming的job,GatewayServer在初始化时会同时创建一个CallbackServer,提供JVM --> Python这条通路。

    @classmethod
    def _ensure_initialized(cls):
        SparkContext._ensure_initialized()
        gw = SparkContext._gateway

        java_import(gw.jvm, "org.apache.spark.streaming.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

        # start callback server
        # getattr will fallback to JVM, so we cannot test by hasattr()
        if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
            gw.callback_server_parameters.eager_load = True
            gw.callback_server_parameters.daemonize = True
            gw.callback_server_parameters.daemonize_connections = True
            gw.callback_server_parameters.port = 0
            gw.start_callback_server(gw.callback_server_parameters)
            cbport = gw._callback_server.server_socket.getsockname()[1]
            gw._callback_server.port = cbport
            # gateway with real port
            gw._python_proxy_port = gw._callback_server.port
            # get the GatewayServer object in JVM by ID
            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
            # update the port of CallbackClient with real port
            jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)

        # register serializer for TransformFunction
        # it happens before creating SparkContext when loading from checkpointing
        cls._transformerSerializer = TransformFunctionSerializer(
            SparkContext._active_spark_context, CloudPickleSerializer(), gw)

为什么需要这样呢?一个streaming job通常需要调用foreachRDD,并提供一个函数,这个函数会在每个batch被回调:

    def foreachRDD(self, func):
        """
        Apply a function to each RDD in this DStream.
        """
        if func.__code__.co_argcount == 1:
            old_func = func
            func = lambda t, rdd: old_func(rdd)
        jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
        api = self._ssc._jvm.PythonDStream
        api.callForeachRDD(self._jdstream, jfunc)

这里,Python函数func被封装成了一个TransformFunction对象,在scala端spark也定义了同样接口一个trait:

/**
 * Interface for Python callback function which is used to transform RDDs
 */
private[python] trait PythonTransformFunction {
  def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]

  /**
   * Get the failure, if any, in the last call to `call`.
   *
   * @return the failure message if there was a failure, or `null` if there was no failure.
   */
  def getLastFailure: String
}

这样是Py4J提供的机制,这样就可以让JVM通过这个影子接口回调Python中的对象了,下面就是scala中的callForeachRDD函数,它把PythonTransformFunction又封装了一层成为scala中的TransformFunction, 但不管如何封装,最后都会调用PythonTransformFunction接口中的call方法完成对Python的回调。

  /**
   * helper function for DStream.foreachRDD(),
   * cannot be `foreachRDD`, it will confusing py4j
   */
  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
    val func = new TransformFunction((pfunc))
    jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
  }

所以,终于要回答这个问题了,我们一开始看到的driver中的多个线程是怎么来的

  1. python调用foreachRDD提供一个TranformFunction给scala端
  2. scala端调用自己的foreachRDD进行正常的spark streaming作业
  3. 由于我们开启了spark.streaming.concurrentJobs,多个batch可以同时运行,这在scala端是通过线程池来进行的,每个batch都需要回调Python中的TranformFunction,而按照我们之前介绍的Py4J线程模型,多个并发的回调会发现没有可用的socket连接而生成新的,而在CallbackServer(Python)这端,每个新连接都会创建一个新线程来处理。这样就出现了driver的Python进程中出现多个线程的现象。

参考阅读

  1. MySQLdb1中的死锁issue
  2. queue-queue-vs-collections-deque
  3. Python中的deque是线程安全的吗?
  4. py4j线程模型官方文档
    原文作者:Garfieldog
    原文地址: https://www.jianshu.com/p/013fe44422c9
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞