Model

Save/Load

Ref: https://pytorch.org/tutorials/beginner/saving_loading_models.html

  • torch.save(model.state_dict(), PATH)

    • only save state_dict, which is an OrderedDict containing trained weights for each layer

    • to load, first instantiate the model, then restore the weights: model = TheModelClass(*args, **kwargs) followed by model.load_state_dict(torch.load(PATH))

  • torch.save(model, PATH)

    • saves the entire model object using pickle. Since pickle serializes everything the model references, loading is more restrictive: the exact class definition, the directory structure, and so on must all be available in the loading environment.
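The state_dict approach above can be sketched as follows (TheModelClass here is a hypothetical placeholder model, not from the original notes):

```python
import torch
import torch.nn as nn

# Hypothetical model class used only for illustration.
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TheModelClass()
# Save only the weights (an OrderedDict of parameter tensors).
torch.save(model.state_dict(), "model.pt")

# To load: instantiate the class first, then restore the state_dict.
model2 = TheModelClass()
model2.load_state_dict(torch.load("model.pt"))
model2.eval()  # switch to inference mode before evaluation
```

Note that load_state_dict expects a dict, not a path, which is why torch.load is called first.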

Functions

no_sync(): a context manager that disables gradient synchronization across DDP processes. Within this context, gradients are accumulated on module variables; they are synchronized in the first forward-backward pass after exiting the context.

Example in Megatron forward_backward_no_pipelining:

context_handler = dummy_handler
if isinstance(model, torchDDP):
    context_handler = model.no_sync

losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
    for i in range(get_num_microbatches() - 1):
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

# Run computation for last micro-batch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
                              input_tensor, losses_reduced)
if not forward_only:
    backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
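Stripped of the Megatron machinery, the same pattern can be sketched in plain PyTorch. This is a minimal illustration, not Megatron code; it spins up a single-process gloo group only so the snippet is self-contained:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group, just to make the sketch runnable.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(8, 4) for _ in range(4)]

# Accumulate gradients locally for all but the last micro-batch ...
with model.no_sync():
    for x in micro_batches[:-1]:
        model(x).sum().backward()

# ... and let the last micro-batch's backward trigger the all-reduce.
model(micro_batches[-1]).sum().backward()
optimizer.step()
optimizer.zero_grad()

dist.destroy_process_group()
```

This avoids one all-reduce per micro-batch and pays the communication cost once per optimizer step.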

ONNX

ONNX export

  • args: support for kwargs is limited, so positional arguments are recommended. args is unpacked as *args and passed to the model, so multiple inputs must be wrapped in an outermost tuple, e.g. args=(input1, input2).

  • input_names & output_names: the tensor names will be readable in the ONNX model if set

  • dynamic_axes: specifies which axes of each input are dynamic (input_names is required). A dict of dicts is recommended, e.g. {"input": {0: "batch_size"}}.

    In this case, the ONNX model replaces each listed axis of the input with its dynamic name.

  • opset_version: specifies the operator-set version used to generate the ONNX model. For example, dynamic slicing requires opset >= 10, so set opset_version to at least 10 for it to work.
