Getting Started with Deep Learning in MLSQL (Part 2): Distributed Model Training

All code examples in this article are based on the latest MLSQL Engine version, 2.1.0-SNAPSHOT.

This article uses a notebook in the MLSQL Console to walk through the "Hello world" of deep learning: the MNIST dataset.

For environment requirements and data preparation, see the previous article: Getting Started with Deep Learning in MLSQL (Part 1).

Articles in this series:

  1. The Simplest MLSQL Machine Learning Tutorial (No Python Required!)
  2. Getting Started with Deep Learning in MLSQL (Part 1)
  3. Getting Started with Deep Learning in MLSQL (Part 2): Distributed Model Training
  4. Getting Started with Deep Learning in MLSQL (Part 3): Feature Engineering
  5. Getting Started with Deep Learning in MLSQL (Part 4): Serving

We will continue to use the dataset from the previous article.

Note: in this article, a Ray environment is required.

Loading the Data

We can load the data in the MLSQL Console notebook as follows:

load parquet.`/tmp/mnist` as mnist;

Configuring the Python Environment

Next, we tell the engine to run the Python client on the driver and activate the appropriate conda environment (adjust the activate path below to your own installation). Because the output we return from Python is a model directory, we set the schema to file. For non-ETL jobs like this one, dataMode is always set to model.

!python env "PYTHON_ENV=source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.3.0";
!python conf "runIn=driver";
!python conf "schema=file";
!python conf "dataMode=model";

Basic Concepts

Now we can start writing the Python code; here is how it is put together, step by step.

First, as before, the notebook cell needs to declare its language, its input table, its output table, and whether the result should be cached. Concretely, that means the following annotations:

--%python
--%input=mnist
--%cache=true
--%output=mnist_model

Next, we obtain the RayContext session object:

ray_context = RayContext.connect(globals(),"127.0.0.1:10001")

This object is how we read input data from, and return output data to, the engine in Python. The input here is the mnist table declared in the annotation above, i.e. the table we loaded earlier with the load statement. We can obtain a reference to each of its data partitions with the following code:

data_servers = ray_context.data_servers()

data_servers is a list; its length is the number of data partitions.

replica_num = len(data_servers)
print(f"total workers {replica_num}")

In the code that follows, we will launch replica_num TF workers to train our model in a distributed fashion.

The Model and the Partitioned Data

We will reuse the model from the previous article; it looks roughly like this:

def create_tf_model():
    network = models.Sequential()
    network.add(layers.Dense(512, activation="relu", input_shape=(28*28,)))
    network.add(layers.Dense(10, activation="softmax"))
    network.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])
    return network
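As a quick sanity check you can build the model locally and print its summary (a sketch; it assumes the tensorflow.keras imports shown in the full listing at the end of this article):

network = create_tf_model()
network.summary()
# Dense(512): 784*512 + 512 = 401,920 weights; Dense(10): 512*10 + 10 = 5,130;
# roughly 407k trainable parameters in total.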

Next, we can fetch the data of any single partition with the following function:

def data_partition_creater(data_server):
    # Pull all rows of this partition into memory
    temp_data = [item for item in RayContext.collect_from([data_server])]
    train_images = np.array([np.array(item["image"]) for item in temp_data])
    train_labels = np_utils.to_categorical(np.array([item["label"] for item in temp_data]))
    train_images = train_images.reshape((len(temp_data), 28*28))
    return train_images, train_labels
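To check the function itself (the same exploratory caveat as above applies), fetch one partition and inspect the shapes: the flattened images come back as (N, 784) and the one-hot labels as (N, 10), where N is however many rows landed in that partition:

images, labels = data_partition_creater(data_servers[0])
print(images.shape, labels.shape)  # e.g. (N, 784) (N, 10)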

In this article we use Ray Actors as the TF workers that carry out the distributed training. The number of actors is determined by the number of data partitions.

Creating the TF Workers

We will build the workers with Ray Actors and train them in a parameter-server style: after each epoch, the workers' weights are averaged and the average is pushed back to every worker (here the driver plays the role of the parameter server).

The worker definition looks like this:

@ray.remote
class Network(object):
    def __init__(self, data_server):
        self.model = create_tf_model()
        # You can also spill the data to local disk if it does
        # not fit in memory.
        self.train_images, self.train_labels = data_partition_creater(data_server)

    def train(self):
        # Train one epoch on this worker's partition
        history = self.model.fit(self.train_images, self.train_labels, batch_size=128)
        return history.history

    def get_weights(self):
        return self.model.get_weights()

    def set_weights(self, weights):
        # Note that for simplicity this does not handle the optimizer state.
        self.model.set_weights(weights)

    def get_final_model(self):
        # Save the model to a local directory, then stream it back as rows
        model_path = os.path.join("/", "tmp", "mnist_model")
        self.model.save(model_path)
        model_binary = [item for item in streaming_tar.build_rows_from_file(model_path)]
        return model_binary

    def shutdown(self):
        ray.actor.exit_actor()

Launching the TF Workers

Now, launch one worker per data partition (note that these workers are independent processes distributed across the Ray cluster):

workers = [Network.remote(data_server) for data_server in data_servers]

Training

Start by running a first epoch of training, then fetch the resulting model weights from every worker:

ray.get([worker.train.remote() for worker in workers])
_weights = ray.get([worker.get_weights.remote() for worker in workers])
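_weights now holds one entry per worker, and each entry is that worker's list of layer weight arrays; for this model that is four arrays (kernel and bias for each of the two Dense layers). A quick check:

print(len(_weights))     # equals replica_num
print(len(_weights[0]))  # 4: kernel + bias for each of the 2 Dense layers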

Define a function that averages the weights across workers, pushes the average back, and trains for another epoch:

def epoch_train(weights):
    # Sum the corresponding layer tensors across all workers...
    sum_weights = reduce(lambda a, b: [a1 + b1 for a1, b1 in zip(a, b)], weights)
    # ...then divide by the worker count to get the average
    averaged_weights = [layer / replica_num for layer in sum_weights]
    ray.get([worker.set_weights.remote(averaged_weights) for worker in workers])
    ray.get([worker.train.remote() for worker in workers])
    return ray.get([worker.get_weights.remote() for worker in workers])
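The reduce/zip combination sums the corresponding layer tensors across workers, and the division by replica_num turns the sums into averages. A toy illustration of just that averaging step, on hypothetical two-worker "weights":

import numpy as np
from functools import reduce

w1 = [np.array([1.0, 2.0]), np.array([3.0])]  # worker 1's layer tensors
w2 = [np.array([3.0, 4.0]), np.array([5.0])]  # worker 2's layer tensors
sums = reduce(lambda a, b: [a1 + b1 for a1, b1 in zip(a, b)], [w1, w2])
avg = [layer / 2 for layer in sums]
print(avg)  # [array([2., 3.]), array([4.])]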

Now run the training loop:

for epoch in range(6):
    _weights = epoch_train(_weights)

During training you will see the standard per-epoch Keras progress logs (loss and accuracy) from each worker.

Returning the Model

Finally, pick any one of the workers and have it save and return the model:

model_binary = ray.get(workers[0].get_final_model.remote())

Shut down all the workers:

[worker.shutdown.remote() for worker in workers]

Return the model to the engine; these rows become the cell's output table, mnist_model:

ray_context.build_result(model_binary)

Saving the Model to the Data Lake

save overwrite mnist_model as delta.`ai_model.mnist_model`;

The model table is now persisted as the Delta table ai_model.mnist_model.


Full Python Code

--%python
--%input=mnist
--%output=mnist_model
--%cache=true

from functools import reduce
import os
import ray
import numpy as np
from tensorflow.keras import models,layers
from tensorflow.keras import utils as np_utils
from pyjava.api.mlsql import RayContext
from pyjava.storage import streaming_tar


ray_context = RayContext.connect(globals(),"127.0.0.1:10001")
data_servers = ray_context.data_servers()
replica_num = len(data_servers)
print(f"total workers {replica_num}")
def data_partition_creater(data_server):
    # Pull all rows of this partition into memory
    temp_data = [item for item in RayContext.collect_from([data_server])]
    train_images = np.array([np.array(item["image"]) for item in temp_data])
    train_labels = np_utils.to_categorical(np.array([item["label"] for item in temp_data]))
    train_images = train_images.reshape((len(temp_data), 28*28))
    return train_images, train_labels

def create_tf_model():
    network = models.Sequential()
    network.add(layers.Dense(512, activation="relu", input_shape=(28*28,)))
    network.add(layers.Dense(10, activation="softmax"))
    network.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])
    return network

@ray.remote
class Network(object):
    def __init__(self, data_server):
        self.model = create_tf_model()
        # You can also spill the data to local disk if it does
        # not fit in memory.
        self.train_images, self.train_labels = data_partition_creater(data_server)

    def train(self):
        # Train one epoch on this worker's partition
        history = self.model.fit(self.train_images, self.train_labels, batch_size=128)
        return history.history

    def get_weights(self):
        return self.model.get_weights()

    def set_weights(self, weights):
        # Note that for simplicity this does not handle the optimizer state.
        self.model.set_weights(weights)

    def get_final_model(self):
        # Save the model to a local directory, then stream it back as rows
        model_path = os.path.join("/", "tmp", "mnist_model")
        self.model.save(model_path)
        model_binary = [item for item in streaming_tar.build_rows_from_file(model_path)]
        return model_binary

    def shutdown(self):
        ray.actor.exit_actor()

workers = [Network.remote(data_server) for data_server in data_servers]
ray.get([worker.train.remote() for worker in workers])
_weights = ray.get([worker.get_weights.remote() for worker in workers])

def epoch_train(weights):
    # Average the layer weights across workers, broadcast, train one epoch
    sum_weights = reduce(lambda a, b: [a1 + b1 for a1, b1 in zip(a, b)], weights)
    averaged_weights = [layer / replica_num for layer in sum_weights]
    ray.get([worker.set_weights.remote(averaged_weights) for worker in workers])
    ray.get([worker.train.remote() for worker in workers])
    return ray.get([worker.get_weights.remote() for worker in workers])

for epoch in range(6):
    _weights = epoch_train(_weights)

model_binary = ray.get(workers[0].get_final_model.remote())
[worker.shutdown.remote() for worker in workers]
ray_context.build_result(model_binary)
