Model
Save/Load
Ref: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save(model.state_dict(), PATH)
Saves only the state_dict, an OrderedDict containing the trained weights for each layer. To load, construct the model first, then restore the weights:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
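As a minimal sketch of the round trip (the toy model class and temp path below are invented for illustration):

```python
import os
import tempfile
import torch
import torch.nn as nn

# A toy stand-in for TheModelClass (hypothetical).
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TheModelClass()
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), path)       # save weights only, not the class

restored = TheModelClass()                 # must construct the model first
restored.load_state_dict(torch.load(path))
restored.eval()                            # switch to eval mode before inference
```

Because only tensors are serialized, the loading code must have access to the model class definition, but is otherwise independent of the saving environment.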
torch.save(model, PATH)
Saves the entire model object using pickle. Since pickle serializes everything the model references, loading is restricted to a matching environment: the model class definition, the directory structure, and so on must be available under the same import paths.
Functions
DistributedDataParallel.no_sync()
A context manager to disable gradient synchronization across DDP processes. Within this context, gradients are accumulated on module variables; they will be synchronized in the first forward-backward pass after exiting the context.
Example: Megatron-LM uses it in forward_backward_no_pipelining to skip the gradient all-reduce for every micro-batch except the last one.
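The usual gradient-accumulation pattern around no_sync can be sketched as follows; the sync_context helper is hypothetical and falls back to a no-op context when the model is not wrapped in DDP:

```python
from contextlib import nullcontext

def sync_context(model, micro_step, num_micro_batches):
    """Return model.no_sync() for every micro-batch except the last,
    so the gradient all-reduce fires only once per optimizer step."""
    is_last = (micro_step == num_micro_batches - 1)
    if not is_last and hasattr(model, "no_sync"):
        return model.no_sync()
    return nullcontext()

# Usage inside a training loop (model would be a DistributedDataParallel
# instance; backward() under no_sync only accumulates local gradients):
#
# for i, batch in enumerate(micro_batches):
#     with sync_context(model, i, len(micro_batches)):
#         loss = model(batch).sum()
#         loss.backward()
# optimizer.step()
```

This mirrors the idea in Megatron's forward_backward_no_pipelining: deferring synchronization to the last micro-batch avoids one all-reduce per micro-batch.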
ONNX
ONNX export
args: kwargs support is insufficient; prefer positional arguments. args is unpacked as model(*args), so when the model takes several inputs they must be wrapped in a tuple at the outermost level, e.g. args=(input1, input2).
input_names & output_names: if set, these names will be readable in the exported ONNX model.
dynamic_axes: specifies the dynamic axes of the inputs (input_names required). Recommend a dict inside a dict, e.g. dynamic_axes={"input": {0: "batch_size"}}. The ONNX model then replaces each listed axis with its dynamic name in the corresponding input parameter.
opset_version: the ONNX operator set version used to generate the model. For example, dynamic slicing requires opset >= 10, so set at least opset_version=10 for it to execute correctly.