🍪
cookielau

Model


Last updated 3 years ago


Save/Load

Ref: https://pytorch.org/tutorials/beginner/saving_loading_models.html

  • torch.save(model.state_dict(), PATH)

    • only saves the state_dict, an OrderedDict containing the trained weights for each layer

    • to load, first build the model with model = TheModelClass(*args, **kwargs), then call model.load_state_dict(torch.load(PATH))

  • torch.save(model, PATH)

    • saves the entire model with pickle. Because pickle serializes everything related to the model, loading is more environment-dependent: the model class definition, the directory structure, and so on must match at load time.
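The state_dict route above can be sketched as follows (TinyModel is a hypothetical example class; the class definition must be importable at load time):

```python
import torch
import torch.nn as nn

# Hypothetical model used only to illustrate the save/load round trip.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyModel()

# Recommended: save only the state_dict (an OrderedDict of tensors).
torch.save(model.state_dict(), "tiny_model.pt")

# To load, rebuild the model first, then restore the weights.
restored = TinyModel()
restored.load_state_dict(torch.load("tiny_model.pt"))
restored.eval()  # set eval mode before inference
```

The whole-model variant, torch.save(model, PATH), skips the rebuild step but breaks if the class definition moves or changes between save and load.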

Functions

no_sync

A context manager to disable gradient synchronization across DDP processes. Within this context, gradients are accumulated on module variables and will be synchronized in the first forward-backward pass after exiting the context.

Example from Megatron (forward_backward_no_pipelining):

# dummy_handler is a no-op context manager (equivalent to
# contextlib.nullcontext); no_sync exists only on a DDP-wrapped model.
context_handler = dummy_handler
if isinstance(model, torchDDP):
    context_handler = model.no_sync

losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
    for i in range(get_num_microbatches() - 1):
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

# Run computation for last micro-batch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
                              input_tensor, losses_reduced)
if not forward_only:
    backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
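The accumulate-then-synchronize idea can be sketched on a single process. This is an illustrative stand-in (names are invented): under DDP, the first N-1 backward passes would run inside model.no_sync(); here we only show that .grad accumulates across backward() calls until step().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

num_microbatches = 4
microbatches = [torch.randn(2, 3) for _ in range(num_microbatches)]

opt.zero_grad()
for x in microbatches:
    # Scale each microbatch loss so the accumulated gradient equals
    # the gradient of the averaged global-batch loss.
    loss = model(x).sum() / num_microbatches
    loss.backward()  # gradients add into .grad, no sync needed yet
opt.step()           # one optimizer step per global batch
```

Under DDP the only change is wrapping the first num_microbatches - 1 iterations in model.no_sync(), so the all-reduce happens once, on the final backward pass.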

ONNX

ONNX export

torch.onnx.export(
    model,
    args,
    output_path,
    opset_version: int = None,
    input_names: List = None,
    output_names: List = None,
    dynamic_axes: dict = None,
    ...
)
  • args: keyword-argument support is limited, so positional arguments are recommended. args is unpacked as model(*args), so multiple inputs must be wrapped in a single tuple at the outermost level. For example, if the model takes several inputs:

    torch.onnx.export(model, (input1, input2, input3), ...)
  • input_names & output_names: if set, the graph inputs and outputs get readable names in the exported ONNX model

  • dynamic_axes: specifies which axes of the inputs are dynamic (input_names required). A nested dict is recommended, for example:

    torch.onnx.export(model, (input1, input2, input3), output_path,
      dynamic_axes={
          'input1': {
              0: 'batch',
              1: 'sequence_length'
          },
          'input2': {
              0: 'batch',
              2: 'hidden_state'
          },
          'input3': {
              1: 'sequence_length'
          }
      }
    )

    In this case, the exported ONNX model marks each listed axis of each input with its dynamic name instead of a fixed size.

  • opset_version: specifies the ONNX operator-set version used for export. For example, dynamic slices require opset >= 10, so set opset_version to at least 10 in that case.
