PyJava Python SDK导读
PyJava项目是一个混合项目,主要包含Java/Scala SDK和Python SDK。 该项目要求也比较高,需要你熟悉Java/Scala/Python等语言,同时需要熟悉Ray/Spark等项目。
PyJava项目不是双向操作的,主要是为了方便在Java代码侧调用Python代码片段,最典型的比如:
val session = spark
....
val abc = df.rdd.mapPartitions { iter =>
val enconder = RowEncoder.apply(struct).resolveAndBind()
val envs = new util.HashMap[String, String]()
envs.put(str(PythonConf.PYTHON_ENV), "source activate streamingpro-spark-2.4.x")
val batch = new ArrowPythonRunner(
Seq(ChainedPythonFunctions(Seq(PythonFunction(
"""
|import pandas as pd
|import numpy as np
|for item in data_manager.fetch_once():
| print(item)
|df = pd.DataFrame({'AAA': [4, 5, 6, 7],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
|data_manager.set_output([[df['AAA'],df['BBB']]])
""".stripMargin, envs, "python", "3.6")))), struct,
timezoneid, Map()
)
val newIter = iter.map { irow =>
enconder.toRow(irow)
}
val commonTaskContext = new SparkContextImp(TaskContext.get(), batch)
val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
columnarBatchIter.flatMap { batch =>
batch.rowIterator.asScala.map(_.copy)
}
}
val wow = SparkUtils.internalCreateDataFrame(session, abc, StructType(Seq(StructField("AAA", LongType), StructField("BBB", LongType))), false)
wow.show()
PyJava使用Arrow实现Java进程和Python进程的数据传递。这里面就涉及到了Arrow版本在两边协调的问题。 同时从这个实例也可以看出,传统如PySpark,Py4J主要是从Python侧调用Java实例,而PyJava则主要是在Java侧调用Python代码片段。不过无论如何,他们都解决了数据传递问题。
今天我们的重点是Python SDK部分。
入口daemon.py
python/pyjava/daemon.py
是整个Python SDK的入口。该类会在Java中被调用,其主要作用是Fork Python Worker进程。 Woker进程主要逻辑则在 python/pyjava/worker.py
里。
主要逻辑worker.py
worker.py的核心逻辑是在main
函数里, 参数infile/outfile其实是输入输出流。main通过infile获取数据,通过outfile输出自己产生的数据,这些数据都是来自或者输出到java进程中。当然我们为了测试方便,也可以完全使用python去调用这段代码。
main的主要流程包括:
- 导入ray
- 设置自己的内存限制
- 读取配置参数
- 读取python脚本
command = utf8_deserializer.loads(infile)
- 执行python脚本,然后并且荣通过ArrowStreamSerializer来获取该python脚本要处理的数据,执行使用
exec
函数 - 最后通过ArrowStreamPandasSerializer把自己处理完的数据写回去
所以流程其实还是比较简单的。
用户自定义代码如何和外部交互
我们知道,用户会传递一段Python脚本给worker,worker会利用exec
函数执行。那么为了正确执行代码,用户会给两个东西给代码,
- 配置,是一个map,系统会根据配置决定如何如何运行这段代码,
- 数据,Python代码最终是为了处理数据的。
这里主要体现在worker.py
的第125-127行代码。
125 data_manager = PythonContext(context_id, input_data, conf)
126 n_local = {"data_manager": data_manager, "context": data_manager}
127 exec(code, n_local, n_local)
我们创建了一个PythonContext对象,然后将该对象传递进去。用户在自己的代码片段中可以通过直接使用名字为data_manager/context对象,该对象也是用户和系统交的核心。
尽管如此,exec
其实存在非常多的问题,我们希望未来可以改写这一块。
PythonContext
现在进入 python/pyjava/api/PythonContext.py
,他核心提供了两类API
- 获取数据
- 构建输出数据
获取数据的,比如 fetch_once_as_dataframe,fetch_once_as_rows等等,数量比较多,主要是为了方便。 构建输出数据的,则有非常底层的set_output,也有build_result. build_result 接受行迭代器,他内部会把数据转化为多个dataframe,接着将dataframe拆解成列,得到的格式如下:
[
[
[1,2,3], #column1 data
["a","b","c"], #column2 data
], # block
[
[1,3], #column1 data
["b","c"], #column2 data
], # block
.....
]
然后再赋值给output_data变量。至于output_data变量为啥需要这么奇怪的格式,原因是因为他最终是通过pyarrow将数据传输出去的,而pyarrow本质上是列式存储格式,需要上面的格式。
RayContext
PyJava整合了Ray,我们希望使用Ray做真正的数据处理。所以PyJava通过RayContext提供了Ray的整合。首先,RayConext持有PythonContext的引用。其次,如果使用RayContext,我们获取数据的方式和PythonContext并不一样。Ray是通过Spark(Java端的Socket地址)获取数据的。那么我们又是如何获取这些地址的?答案是通过PythonContext,此时它得到的数据其实是一个或者多个地址。每个地址分别是对应一个Spark(Java端的Socket地址)分区。
具体获取地址的方式如下:
for item in self.python_context.fetch_once_as_rows():
self.server_ids_in_ray.append(str(uuid.uuid4()))
self.servers.append(DataServer(
item["host"], int(item["port"]), item["timezone"]))
通过这些地址,我们可以简单的使用如下方式获取某个地址的数据。
rows = RayContext.fetch_once_as_rows(self.java_server))
那如何把处理完的数据输出到Spark侧呢?我们通过Ray的Actor,启动serving服务,也就是在Ray中对外提供socket地址,那么Spark侧就可以通过这些socket地址拉取到数据了。对应的完整实现在python/pyjava/serve.py
。
这里我们说下,使用Ray做数据处理,其实是拉取模式的,spark侧提供socket连接,pyjava侧则通过ray提供python侧的socket链接,数据源则是spark侧。接着spark侧开始从pyjava拉取数据,然后pyjava又从spark侧拉取数据,也就是整个流程是spark(java)侧推动的。数据只是从spark流入到pyjava里,pyjava处理完成再回流到spark侧。
这整个逻辑其实就几行代码就可以表示出来:
if func_for_row is not None:
data = (func_for_row(item)
for item in RayContext.fetch_once_as_rows(self.java_server))
elif func_for_rows is not None:
data = func_for_rows(
RayContext.fetch_once_as_rows(self.java_server))
self.server.serve(data)
通过RayContext.fetch_once_as_rows(self.java_server)
拉取数据,拉取的数据马上又通过server.serve
对外提供服务。整个过程使用迭代器,所以不会造成Python侧的内存压力。