PyJava Python SDK导读

PyJava项目是一个混合项目,主要包含Java/Scala SDK和Python SDK。 该项目要求也比较高,需要你熟悉Java/Scala/Python等语言,同时需要熟悉Ray/Spark等项目。

PyJava项目不是双向操作的,主要是为了方便在Java代码侧调用Python代码片段,最典型的比如:

val session = spark
....
val abc = df.rdd.mapPartitions { iter =>
  val enconder = RowEncoder.apply(struct).resolveAndBind()
  val envs = new util.HashMap[String, String]()
  envs.put(str(PythonConf.PYTHON_ENV), "source activate streamingpro-spark-2.4.x")
  val batch = new ArrowPythonRunner(
    Seq(ChainedPythonFunctions(Seq(PythonFunction(
      """
        |import pandas as pd
        |import numpy as np
        |for item in data_manager.fetch_once():
        |    print(item)
        |df = pd.DataFrame({'AAA': [4, 5, 6, 7],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
        |data_manager.set_output([[df['AAA'],df['BBB']]])
      """.stripMargin, envs, "python", "3.6")))), struct,
    timezoneid, Map()
  )
  val newIter = iter.map { irow =>
    enconder.toRow(irow)
  }
  val commonTaskContext = new SparkContextImp(TaskContext.get(), batch)
  val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
  columnarBatchIter.flatMap { batch =>
    batch.rowIterator.asScala.map(_.copy)
  }
}

val wow = SparkUtils.internalCreateDataFrame(session, abc, StructType(Seq(StructField("AAA", LongType), StructField("BBB", LongType))), false)
wow.show()

PyJava使用Arrow实现Java进程和Python进程的数据传递。这里面就涉及到了Arrow版本在两边协调的问题。 同时从这个实例也可以看出,传统如PySpark,Py4J主要是从Python侧调用Java实例,而PyJava则主要是在Java侧调用Python代码片段。不过无论如何,他们都解决了数据传递问题。

今天我们的重点是Python SDK部分。

入口daemon.py

python/pyjava/daemon.py 是整个Python SDK的入口。该类会在Java中被调用,其主要作用是Fork Python Worker进程。 Woker进程主要逻辑则在 python/pyjava/worker.py里。

主要逻辑worker.py

worker.py的核心逻辑是在main函数里, 参数infile/outfile其实是输入输出流。main通过infile获取数据,通过outfile输出自己产生的数据,这些数据都是来自或者输出到java进程中。当然我们为了测试方便,也可以完全使用python去调用这段代码。

main的主要流程包括:

  1. 导入ray
  2. 设置自己的内存限制
  3. 读取配置参数
  4. 读取python脚本 command = utf8_deserializer.loads(infile)
  5. 执行python脚本,然后并且荣通过ArrowStreamSerializer来获取该python脚本要处理的数据,执行使用exec函数
  6. 最后通过ArrowStreamPandasSerializer把自己处理完的数据写回去

所以流程其实还是比较简单的。

用户自定义代码如何和外部交互

我们知道,用户会传递一段Python脚本给worker,worker会利用exec函数执行。那么为了正确执行代码,用户会给两个东西给代码,

  1. 配置,是一个map,系统会根据配置决定如何如何运行这段代码,
  2. 数据,Python代码最终是为了处理数据的。

这里主要体现在worker.py的第125-127行代码。

125 data_manager = PythonContext(context_id, input_data, conf)
126 n_local = {"data_manager": data_manager, "context": data_manager}
127 exec(code, n_local, n_local)

我们创建了一个PythonContext对象,然后将该对象传递进去。用户在自己的代码片段中可以通过直接使用名字为data_manager/context对象,该对象也是用户和系统交的核心。 尽管如此,exec其实存在非常多的问题,我们希望未来可以改写这一块。

PythonContext

现在进入 python/pyjava/api/PythonContext.py,他核心提供了两类API

  1. 获取数据
  2. 构建输出数据

获取数据的,比如 fetch_once_as_dataframe,fetch_once_as_rows等等,数量比较多,主要是为了方便。 构建输出数据的,则有非常底层的set_output,也有build_result. build_result 接受行迭代器,他内部会把数据转化为多个dataframe,接着将dataframe拆解成列,得到的格式如下:

[

    [
       [1,2,3], #column1 data
       ["a","b","c"], #column2 data
    ], # block
        [
         [1,3], #column1 data
         ["b","c"], #column2 data
    ], # block
    .....
]

然后再赋值给output_data变量。至于output_data变量为啥需要这么奇怪的格式,原因是因为他最终是通过pyarrow将数据传输出去的,而pyarrow本质上是列式存储格式,需要上面的格式。

RayContext

PyJava整合了Ray,我们希望使用Ray做真正的数据处理。所以PyJava通过RayContext提供了Ray的整合。首先,RayConext持有PythonContext的引用。其次,如果使用RayContext,我们获取数据的方式和PythonContext并不一样。Ray是通过Spark(Java端的Socket地址)获取数据的。那么我们又是如何获取这些地址的?答案是通过PythonContext,此时它得到的数据其实是一个或者多个地址。每个地址分别是对应一个Spark(Java端的Socket地址)分区。

具体获取地址的方式如下:

for item in self.python_context.fetch_once_as_rows():
            self.server_ids_in_ray.append(str(uuid.uuid4()))
            self.servers.append(DataServer(
                item["host"], int(item["port"]), item["timezone"]))

通过这些地址,我们可以简单的使用如下方式获取某个地址的数据。


rows = RayContext.fetch_once_as_rows(self.java_server))

那如何把处理完的数据输出到Spark侧呢?我们通过Ray的Actor,启动serving服务,也就是在Ray中对外提供socket地址,那么Spark侧就可以通过这些socket地址拉取到数据了。对应的完整实现在python/pyjava/serve.py

这里我们说下,使用Ray做数据处理,其实是拉取模式的,spark侧提供socket连接,pyjava侧则通过ray提供python侧的socket链接,数据源则是spark侧。接着spark侧开始从pyjava拉取数据,然后pyjava又从spark侧拉取数据,也就是整个流程是spark(java)侧推动的。数据只是从spark流入到pyjava里,pyjava处理完成再回流到spark侧。

这整个逻辑其实就几行代码就可以表示出来:

if func_for_row is not None:
    data = (func_for_row(item)
            for item in RayContext.fetch_once_as_rows(self.java_server))
elif func_for_rows is not None:
    data = func_for_rows(
        RayContext.fetch_once_as_rows(self.java_server))
self.server.serve(data)

通过RayContext.fetch_once_as_rows(self.java_server)拉取数据,拉取的数据马上又通过server.serve对外提供服务。整个过程使用迭代器,所以不会造成Python侧的内存压力。

results matching ""

    No results matching ""