Model
Save/Load
Ref: https://pytorch.org/tutorials/beginner/saving_loading_models.html
```python
torch.save(model.state_dict(), PATH)
```
Saves only the `state_dict`, an `OrderedDict` containing the trained weights of each layer. To load, construct the model first, then restore the weights:
```python
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
```
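The round trip above can be sketched without torch: a `state_dict` is essentially an `OrderedDict` of weights, and `torch.save` pickles it. The `TinyModel` class below is a made-up stand-in for a real `nn.Module`, used only to show the save-then-load pattern:

```python
import io
import pickle
from collections import OrderedDict

class TinyModel:
    """Hypothetical stand-in for a torch.nn.Module."""
    def __init__(self, scale=1.0):
        # Weights live in an OrderedDict, just like a real state_dict.
        self.weights = OrderedDict([("layer1.weight", [0.1 * scale, 0.2 * scale])])

    def state_dict(self):
        return self.weights

    def load_state_dict(self, state):
        self.weights = state

buf = io.BytesIO()
model = TinyModel(scale=3.0)
pickle.dump(model.state_dict(), buf)       # analogue of torch.save(model.state_dict(), PATH)

restored = TinyModel()                     # must build the architecture first
restored.load_state_dict(pickle.loads(buf.getvalue()))  # analogue of load_state_dict(torch.load(PATH))
print(restored.weights)
```

Note that only the weights travel through the file; the model class itself must already exist on the loading side.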
```python
torch.save(model, PATH)
```
Saves the entire model using pickle. Since pickle serializes everything the model references, loading it requires a more restricted environment: the class definition, the exact directory structure, and so on must all be available at load time.
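Why full-model loading is environment-restricted can be seen with plain pickle: an instance is stored by a *reference* to its class (module name plus class name), not by its code, so unpickling fails when the defining class is no longer importable. A minimal illustration (`ToyNet` is an invented class, not part of any library):

```python
import pickle

class ToyNet:
    """Invented class standing in for a saved model's definition."""
    def __init__(self):
        self.weight = 0.5

blob = pickle.dumps(ToyNet())   # the stream stores "__main__.ToyNet" by name, not its code

del ToyNet                      # simulate loading in an environment missing the class

failed = False
try:
    pickle.loads(blob)
except AttributeError:          # pickle can no longer find __main__.ToyNet
    failed = True
print("full-model load failed without the class definition:", failed)
```

A `state_dict`, being plain data, avoids this coupling, which is why it is the recommended format.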
Functions
`DistributedDataParallel.no_sync()`: a context manager that disables gradient synchronization across DDP processes. Within this context, gradients are accumulated on module variables; they are synchronized in the first forward-backward pass after exiting the context.
Example in Megatron forward_backward_no_pipelining:
```python
context_handler = dummy_handler
if isinstance(model, torchDDP):
    context_handler = model.no_sync

losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
    for i in range(get_num_microbatches() - 1):
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

# Run computation for last micro-batch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
                             input_tensor, losses_reduced)
if not forward_only:
    backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
```
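`dummy_handler` above is a no-op context manager Megatron falls back to when the model is not wrapped in DDP; the standard library's `contextlib.nullcontext` plays the same role. A minimal sketch of the handler-selection pattern, with an invented `FakeDDP` class standing in for torch DDP:

```python
from contextlib import contextmanager

@contextmanager
def dummy_handler():
    # No-op stand-in for model.no_sync(): nothing happens on enter/exit,
    # so gradient synchronization behaves as usual.
    yield

class FakeDDP:
    """Invented stand-in for torch DDP, exposing a no_sync() handler."""
    def __init__(self):
        self.sync_disabled = False

    @contextmanager
    def no_sync(self):
        self.sync_disabled = True       # gradients only accumulate in here
        try:
            yield
        finally:
            self.sync_disabled = False  # sync resumes once the context exits

model = FakeDDP()
context_handler = model.no_sync if isinstance(model, FakeDDP) else dummy_handler

with context_handler():
    inside = model.sync_disabled        # True: sync is suppressed
print("sync disabled inside context:", inside)
print("sync disabled after context:", model.sync_disabled)
```

Selecting the handler once up front keeps the micro-batch loop identical whether or not the model is DDP-wrapped.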
ONNX
ONNX export
```python
torch.onnx.export(
    model,
    args,
    output_path,
    opset_version: int = None,
    input_names: List = None,
    output_names: List = None,
    dynamic_axes: dict = None,
    ...
)
```
`args`: support for kwargs is insufficient, so positional arguments are recommended. `args` is interpreted as `(*args)` and passed to the model, so it needs a tuple wrapper at the outermost level. For example, if the model takes several inputs:
```python
torch.onnx.export(model, (input1, input2, input3), ...)
```
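The `(*args)` behaviour can be checked with a plain function. The hypothetical `fake_export` below mimics how the export call unpacks the args tuple into the model's forward call, which is why multi-input models need exactly one outermost tuple:

```python
def fake_model(a, b, c):
    # Hypothetical three-input model: forward just sums its inputs.
    return a + b + c

def fake_export(model, args):
    # Mimics how the args tuple is fed to the model: model(*args).
    return model(*args)

# The three inputs are wrapped in one outermost tuple,
# then unpacked as fake_model(1, 2, 3):
result = fake_export(fake_model, (1, 2, 3))
print(result)
```

Passing `fake_export(fake_model, 1)` instead would fail, since a single non-tuple argument cannot be unpacked into three parameters.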
`input_names` & `output_names`: if set, these names appear in the exported ONNX model, making it readable.

`dynamic_axes`: specifies the dynamic axes of the inputs (`input_names` required). A dict of dicts is recommended, for example:
```python
torch.onnx.export(
    model, (input1, input2, input3), output_path,
    dynamic_axes={
        'input1': {0: 'batch', 1: 'sequence_length'},
        'input2': {0: 'batch', 2: 'hidden_state'},
        'input3': {1: 'sequence_length'},
    },
)
```
In this case, the ONNX model replaces each listed axis of each input with its dynamic name.
`opset_version`: specifies the ONNX operator-set version used when generating the model. For example, dynamic slicing requires opset >= 10, so set `opset_version` to at least 10 for such models to export correctly.