Pytorch提交模型训练任务

最近更新时间:2021-04-09 18:18:02

查看PDF

本文档将向您介绍如何在Notebook通过SDK将训练的Pytorch模型发布到模型训练模块中。示例代码及注释如下:

示例:remote模式

from kai_sdk.train.pytorch.estimator import Pytorch
from kai_sdk.conf.kai_conf import KaiConf

def remote_train():
    kai_conf = KaiConf()
    pytorch_estimator = Pytorch(                              #框架为Pytorch
        framework_version='1.6.0-torchvision0.7-py3.6',       #输入Pytorch的框架版本
        entry_point='train.py',                               #入口执行文件
        source_dir='ks3://xc-train-ksc/yolo3-pytorch-master', #代码文件存放的目录 
        hyperparameters={                                     #训练任务的参数设置
            # "epochs": 1,
            # "batch-size": 200
        },
        envs={                                                #训练使用的环境变量信息
            # "ADDRESS": "beijing"
        },
        train_instance_type='remote',                         #训练的类型
        # base_job_name='base_job',                           #作业名称的前缀
        dependences=[                                         #训练代码依赖的pip包
            # 'numpy==1.15.1',
            # 'Pillow == 5.3.0',
            # 'scipy==1.1.0',
            # 'wget==3.2',
            # 'seaborn==0.9.0',
            # 'opencv-python',
            # 'easydict',
            # 'tqdm'
        ],
        # output_path='model',                                #模型输出目录
        kai_conf=kai_conf,
        register_model_config={                               #模型注册时,使用的配置信息
            'experiment_id': '439642fb-4cbd-4af0-8ee2-490e282b3f00',
            'model_name': 'remote-train-test-pytorch'
        },
        resource_dict={                                       #训练使用的资源
            'cpu': 4,
            'memory': 8,
            "gpu": 0,
            "type": "p40"
        },
        preprocess = "cp sources.list /etc/apt/sources.list && apt-get update && apt-get install sudo && sudo apt-get install -y libglib2.0-dev libsm6 libxext6 libxrender-dev libgl1-mesa-glx" #训练前置处理

    )
    pytorch_estimator.fit({'data_path': 'ks3://xc-train-ksc/VOC_test'})

if __name__ == '__main__':
    remote_train()

文档内容是否对您有帮助?

根本没帮助
文档较差
文档一般
文档不错
文档很好

在文档使用中是否遇到以下问题

内容不全,不深入
内容更新不及时
描述不清晰,比较混乱
系统或功能太复杂,缺乏足够的引导
内容冗长

更多建议

0/200

评价建议不能为空

提交成功!

非常感谢您的反馈,我们会继续努力做到更好!

问题反馈