How PyJava Implements Ray foreach/map_iter

In the previous post, An Overview of the PyJava Python SDK, we sketched the big picture. Here we take a close look at how foreach and map_iter are implemented on top of Ray.

Ray foreach

The semantics of foreach: use a Ray cluster to process each record coming from Spark (the Java side). It is used like this:

Note: to use foreach, you must be connected to a Ray cluster.


ray_context = RayContext.connect(globals(), "xxxx:xxxx")

def echo(row):
    row1 = {}
    row1["ProductName"] = "jackm"
    row1["SubProduct"] = row["SubProduct"]
    return row1

buffer = ray_context.foreach(echo)

The number of Ray tasks for this job (its parallelism) equals the number of data partitions on the Spark (Java) side.

Now the question is: how does PyJava implement this?

Let's start with the function signature (python/pyjava/api/mlsql.py):

def foreach(self, func_for_row):
    return self.setup(func_for_row)

We pass a function to foreach, which hands it off to setup. Here is setup:

# func_for_row:  passed in by foreach
# func_for_rows: passed in by map_iter
def setup(self, func_for_row, func_for_rows=None):
    if self.is_setup:
        raise ValueError("setup can be only invoke once")
    self.is_setup = True
    import ray

    # In test mode (i.e. not backed by Spark-side data), run against mock data
    if not self.is_in_mlsql:
        # map_iter logic
        if func_for_rows is not None:
            func = ray.remote(func_for_rows)
            return ray.get(func.remote(self.mock_data))
        else:  # foreach logic
            func = ray.remote(func_for_row)

            def iter_all(rows):
                return [ray.get(func.remote(row)) for row in rows]

            iter_all_func = ray.remote(iter_all)
            return ray.get(iter_all_func.remote(self.mock_data))

    buffer = []
    # Iterate over all data-shard addresses and build one Ray actor per shard
    for server_info in self.build_servers_in_ray():
        server = ray.experimental.get_actor(server_info.server_id)
        buffer.append(ray.get(server.connect_info.remote()))
        # Start each Ray actor's socket server (asynchronous call)
        server.serve.remote(func_for_row, func_for_rows)
    # Collect all Ray actor addresses and hand them back to the Java/Spark side
    items = [vars(server) for server in buffer]
    self.python_context.build_result(items, 1024)
    return buffer

The comments in the code above should make the flow easy to follow.
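To make the test branch concrete, here is a minimal, self-contained sketch of the foreach mock path. The names mock_data and echo here are stand-ins invented for this illustration, not PyJava API:

import ray

ray.init()

# Mock rows standing in for Spark-side data (hypothetical values)
mock_data = [{"SubProduct": "x"}, {"SubProduct": "y"}]

def echo(row):
    return {"ProductName": "jackm", "SubProduct": row["SubProduct"]}

# The foreach branch: wrap the per-row function as a Ray task, then fan out
# one task per row from a single outer task, exactly like iter_all above
func = ray.remote(echo)

def iter_all(rows):
    return [ray.get(func.remote(row)) for row in rows]

iter_all_func = ray.remote(iter_all)
print(ray.get(iter_all_func.remote(mock_data)))
# [{'ProductName': 'jackm', 'SubProduct': 'x'}, {'ProductName': 'jackm', 'SubProduct': 'y'}]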

The core pieces here are self.build_servers_in_ray() and server.serve.remote.

self.build_servers_in_ray() simply builds a set of actors equal in number to the partitions. Here is the code:

def build_servers_in_ray(self):
    import ray
    from pyjava.api.serve import RayDataServer
    buffer = []
    for (server_id, java_server) in zip(self.server_ids_in_ray, self.servers):
        rds = RayDataServer.options(
            name=server_id, detached=True, max_concurrency=2).remote(
            server_id, java_server, 0, java_server.timezone)
        self.rds_list.append(rds)
        res = ray.get(rds.connect_info.remote())
        if self.is_dev:
            print("build RayDataServer server_id:{} java_server: {} servers:{}".format(
                server_id, str(vars(java_server)), str(vars(res))))
        buffer.append(res)
    return buffer

We use Ray detached actors, with max_concurrency set to 2 (one thread for the socket, one thread for running the task).
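As an aside, here is a minimal sketch of a named detached actor with bounded concurrency. Note that recent Ray versions spell this lifetime="detached" and look actors up with ray.get_actor; the excerpt above uses the older detached=True / ray.experimental.get_actor API. The actor and its name below are invented for illustration:

import ray

ray.init()

@ray.remote
class EchoServer:
    def ping(self):
        return "pong"

# A named, detached actor that outlives its creating driver; at most
# 2 method calls run concurrently inside the actor
actor = EchoServer.options(
    name="echo-server", lifetime="detached", max_concurrency=2).remote()

# Any driver connected to the same cluster can now look it up by name
same_actor = ray.get_actor("echo-server")
print(ray.get(same_actor.ping.remote()))  # "pong"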

server.serve.remote then starts the socket server inside the actor:

def serve(self, data):
    from pyjava.api.mlsql import PythonContext
    if not self.is_bind:
        raise SocketNotBindException(
            "Please invoke server.bind() before invoke server.serve")
    conn, addr = self.socket.accept()
    sockfile = conn.makefile("rwb", int(
        os.environ.get("BUFFER_SIZE", 65536)))
    infile = sockfile  # os.fdopen(os.dup(conn.fileno()), "rb", 65536)
    out = sockfile  # os.fdopen(os.dup(conn.fileno()), "wb", 65536)
    try:
        write_int(SpecialLengths.START_ARROW_STREAM, out)
        out_data = ([df[name] for name in df] for df in
                    PythonContext.build_chunk_result(data, 1024))
        self.out_ser.dump_stream(out_data, out)

        write_int(SpecialLengths.END_OF_DATA_SECTION, out)
        write_int(SpecialLengths.END_OF_STREAM, out)
        out.flush()
        if self.is_dev:
            print("all data in ray task have been consumed.")
        read_int(infile)
    finally:
        # The excerpt ends here in the original post; the remaining
        # cleanup (closing the connection, etc.) is elided
        conn.close()
The overall logic is straightforward. Calls like write_int(SpecialLengths.END_OF_DATA_SECTION, out) are part of the wire protocol agreed with the Java side: they signal when the data section is finished, when the whole stream ends, and so on.
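For intuition, here is a minimal sketch of this kind of length/flag framing. The helpers below are illustrative reimplementations, and the specific SpecialLengths values follow the PySpark convention that pyjava mirrors; treat the exact numbers as an assumption:

import struct

class SpecialLengths:
    END_OF_DATA_SECTION = -1   # no more data chunks in this section
    END_OF_STREAM = -4         # the whole stream is done
    START_ARROW_STREAM = -6    # an Arrow IPC stream follows

def write_int(value, stream):
    # 4-byte big-endian signed int, matching java.io.DataInputStream.readInt()
    stream.write(struct.pack("!i", value))

def read_int(stream):
    raw = stream.read(4)
    if not raw:
        raise EOFError("stream closed")
    return struct.unpack("!i", raw)[0]

Because negative values can never be a valid payload length, the Java side can use them as in-band control flags while reading ordinary length-prefixed frames.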

Ray map_iter

map_iter is used like this:

import time

ray_context = RayContext.connect(globals(), "xxxx:xxxx")

def echo(rows):
    count = 0
    for row in rows:
        row1 = {}
        row1["ProductName"] = "jackm"
        row1["SubProduct"] = row["SubProduct"]
        count = count + 1
        if count % 1000 == 0:
            print("=====> " + str(time.time()) + " ====>" + str(count))
        yield row1

ray_context.map_iter(echo)

The only difference from foreach is that the function receives an iterator rather than a single record. The implementation is otherwise identical, so we won't repeat it here.
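For completeness, a minimal sketch of the map_iter branch in mock mode, mirroring the func_for_rows path in setup above (mock_data is again an invented stand-in). Note the function returns a list rather than yielding, since the result of a Ray task must be serializable:

import ray

ray.init()

mock_data = [{"SubProduct": "x"}, {"SubProduct": "y"}]

def echo(rows):
    # The whole partition arrives as one iterable
    return [{"ProductName": "jackm", "SubProduct": row["SubProduct"]}
            for row in rows]

# The map_iter branch wraps the per-partition function as a single Ray task
func = ray.remote(echo)
print(ray.get(func.remote(mock_data)))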
