How PyJava Implements Ray foreach/map_iter
In the previous article, a guided tour of the PyJava Python SDK, we sketched the overall structure. Here we focus on how foreach/map_iter are implemented on top of Ray.
Ray foreach
The semantics of foreach: use the Ray cluster to process each record coming from Spark (the Java side). It is used like this:
Note: to use foreach, you must be connected to a Ray cluster.
ray_context = RayContext.connect(globals(), "xxxx:xxxx")

def echo(row):
    row1 = {}
    row1["ProductName"] = "jackm"
    row1["SubProduct"] = row["SubProduct"]
    return row1

buffer = ray_context.foreach(echo)
The number of Ray tasks for this job (its parallelism) equals the number of data partitions on the Spark (Java) side.
So how does PyJava implement this?
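The parallelism contract can be illustrated with a small sketch in plain Python (no Ray or Spark involved; the partition data below is made up): one task per Spark partition, each task applying the user function row by row.

```python
# Hypothetical mock of Spark-side partitions, for illustration only.
partitions = [
    [{"SubProduct": "a1"}, {"SubProduct": "a2"}],  # partition 0
    [{"SubProduct": "b1"}],                        # partition 1
]

def echo(row):
    return {"ProductName": "jackm", "SubProduct": row["SubProduct"]}

# One "task" per partition, mirroring how foreach fans out over Ray.
tasks = [[echo(row) for row in part] for part in partitions]
num_tasks = len(tasks)  # equals the number of partitions
```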
Let's start with the function signature (python/pyjava/api/mlsql.py):
def foreach(self, func_for_row):
    return self.setup(func_for_row)
We pass a function to foreach, which forwards it to setup. The setup code is as follows:
# func_for_row  -> foreach
# func_for_rows -> map_iter
def setup(self, func_for_row, func_for_rows=None):
    if self.is_setup:
        raise ValueError("setup can be only invoke once")
    self.is_setup = True
    import ray

    # In test mode (i.e. not backed by Spark-side data), run against mock data.
    if not self.is_in_mlsql:
        # map_iter logic
        if func_for_rows is not None:
            func = ray.remote(func_for_rows)
            return ray.get(func.remote(self.mock_data))
        else:
            # foreach logic
            func = ray.remote(func_for_row)

            def iter_all(rows):
                return [ray.get(func.remote(row)) for row in rows]

            iter_all_func = ray.remote(iter_all)
            return ray.get(iter_all_func.remote(self.mock_data))

    buffer = []
    # Iterate over all data-partition addresses, and build one Ray actor
    # per partition address.
    for server_info in self.build_servers_in_ray():
        server = ray.experimental.get_actor(server_info.server_id)
        buffer.append(ray.get(server.connect_info.remote()))
        # Start each Ray actor's socket service; this call is asynchronous.
        server.serve.remote(func_for_row, func_for_rows)
    # Collect all Ray actor addresses and return them to the Java/Spark side.
    items = [vars(server) for server in buffer]
    self.python_context.build_result(items, 1024)
    return buffer
The comments in the code above should make the flow easy to follow.
The core pieces are self.build_servers_in_ray() and server.serve.remote.
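The mock-data branch of setup() can be boiled down to a dispatch rule: a map_iter function consumes the whole dataset at once, while a foreach function is applied once per row. A minimal sketch (plain Python, no Ray; run_mock and the lambdas are illustrative names, not part of pyjava):

```python
def run_mock(mock_data, func_for_row, func_for_rows=None):
    if func_for_rows is not None:
        # map_iter logic: the function consumes the whole iterable.
        return list(func_for_rows(mock_data))
    # foreach logic: apply the function to each row individually.
    return [func_for_row(row) for row in mock_data]

mock_data = [{"SubProduct": "a"}, {"SubProduct": "b"}]

foreach_out = run_mock(mock_data, lambda row: row["SubProduct"].upper())
map_iter_out = run_mock(mock_data, None,
                        func_for_rows=lambda rows: (r["SubProduct"] for r in rows))
```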
self.build_servers_in_ray() builds as many actors as there are partitions:
def build_servers_in_ray(self):
    import ray
    from pyjava.api.serve import RayDataServer
    buffer = []
    for (server_id, java_server) in zip(self.server_ids_in_ray, self.servers):
        rds = RayDataServer.options(name=server_id, detached=True, max_concurrency=2).remote(
            server_id, java_server, 0, java_server.timezone)
        self.rds_list.append(rds)
        res = ray.get(rds.connect_info.remote())
        if self.is_dev:
            print("build RayDataServer server_id:{} java_server: {} servers:{}".format(
                server_id, str(vars(java_server)), str(vars(res))))
        buffer.append(res)
    return buffer
We use Ray detached actors with max_concurrency set to 2: one thread for the socket, one thread for executing the task.
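Why two threads? While serve() is blocked in accept() waiting for the Java side to connect, the actor must still be able to answer control calls such as connect_info(). A sketch of this with plain sockets and threads (no Ray; MiniDataServer is an illustrative stand-in for RayDataServer):

```python
import socket
import threading

class MiniDataServer:
    def __init__(self):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
        self.sock.listen(1)

    def connect_info(self):
        # Control-plane call: must stay responsive while serve() blocks.
        return self.sock.getsockname()

    def serve(self, payload):
        # Data-plane call: blocks in accept() until the client connects.
        conn, _ = self.sock.accept()
        with conn:
            conn.sendall(payload)

server = MiniDataServer()
# serve() runs on its own thread, mirroring max_concurrency=2 on the actor.
threading.Thread(target=server.serve, args=(b"hello",), daemon=True).start()

host, port = server.connect_info()  # still answered while serve() blocks
client = socket.create_connection((host, port))
data = client.recv(1024)
client.close()
```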
server.serve.remote starts the socket service inside the actor:
def serve(self, data):
    from pyjava.api.mlsql import PythonContext
    if not self.is_bind:
        raise SocketNotBindException(
            "Please invoke server.bind() before invoke server.serve")
    conn, addr = self.socket.accept()
    sockfile = conn.makefile("rwb", int(
        os.environ.get("BUFFER_SIZE", 65536)))
    infile = sockfile  # os.fdopen(os.dup(conn.fileno()), "rb", 65536)
    out = sockfile  # os.fdopen(os.dup(conn.fileno()), "wb", 65536)
    try:
        write_int(SpecialLengths.START_ARROW_STREAM, out)
        out_data = ([df[name] for name in df] for df in
                    PythonContext.build_chunk_result(data, 1024))
        self.out_ser.dump_stream(out_data, out)
        write_int(SpecialLengths.END_OF_DATA_SECTION, out)
        write_int(SpecialLengths.END_OF_STREAM, out)
        out.flush()
        if self.is_dev:
            print("all data in ray task have been consumed.")
        read_int(infile)
        # ... (except/finally clauses omitted in this excerpt)
The overall logic is straightforward. Calls like write_int(SpecialLengths.END_OF_DATA_SECTION, out) are part of the protocol agreed with the Java side: they mark when a data section has finished, when the whole stream has ended, and so on.
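The framing itself is just length-prefixed 32-bit integers on the wire. A minimal sketch of such a protocol (the marker values below are assumptions for illustration, not the real constants in pyjava's SpecialLengths):

```python
import struct
import io

# Hypothetical marker values, chosen only for this example.
END_OF_DATA_SECTION = -1
END_OF_STREAM = -2

def write_int(value, stream):
    # Big-endian 32-bit int, matching what Java's DataInputStream.readInt expects.
    stream.write(struct.pack("!i", value))

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

buf = io.BytesIO()
write_int(42, buf)                   # an ordinary value/length
write_int(END_OF_DATA_SECTION, buf)  # "this data section is done"
write_int(END_OF_STREAM, buf)        # "the whole stream is done"

buf.seek(0)
values = [read_int(buf) for _ in range(3)]
```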
Ray map_iter
Here is a usage example of map_iter:
ray_context = RayContext.connect(globals(), "xxxx:xxxx")

def echo(rows):
    count = 0
    for row in rows:
        row1 = {}
        row1["ProductName"] = "jackm"
        row1["SubProduct"] = row["SubProduct"]
        count = count + 1
        if count % 1000 == 0:
            print("=====> " + str(time.time()) + " ====>" + str(count))
        yield row1

ray_context.map_iter(echo)
The only difference from foreach is that the function receives an iterator rather than a single record. The implementation is otherwise identical to foreach, so we won't repeat it here.
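What the iterator form buys you is state across records: because the map_iter function sees the whole stream, it can keep, say, a running counter, which a per-row foreach function cannot. A short sketch (plain Python, no Ray; the field names are made up):

```python
def echo(rows):
    # Generator over all records of a partition; count survives across rows.
    count = 0
    for row in rows:
        count += 1
        yield {"SubProduct": row["SubProduct"], "seen_so_far": count}

data = [{"SubProduct": "a"}, {"SubProduct": "b"}, {"SubProduct": "c"}]
out = list(echo(iter(data)))
```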