torch.save() and torch.load() are two functions that let you save tensors and other Python objects to disk and load them back. Saved files conventionally use the .pt or .pth extension. This article shows how to use torch.save and walks through a few code examples from popular open source projects.
torch.save arguments
torch.save() takes the following arguments:
torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
- obj is the object to be saved (required)
- f is either a file-like object (one with write and flush methods) or a string or os.PathLike object containing a file name, conventionally ending in .pt or .pth (required)
- pickle_module is the module used for pickling metadata and objects; it defaults to Python’s pickle.
- pickle_protocol – can be specified to override the default protocol
- _use_new_zipfile_serialization – The 1.6 release of PyTorch switched torch.save to use a new zipfile-based file format. torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.
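As a rough illustration of the f argument described above, the sketch below saves the same tensor three ways: to a string path, to an os.PathLike object, and to an already-open binary file. The temporary directory and file names are made up for the example.

```python
import pathlib
import tempfile

import torch

x = torch.arange(5)

with tempfile.TemporaryDirectory() as tmp:
    # f given as a plain string path
    str_path = f"{tmp}/tensor.pt"
    torch.save(x, str_path)

    # f given as an os.PathLike object
    path_obj = pathlib.Path(tmp) / "tensor_pathlike.pt"
    torch.save(x, path_obj)

    # f given as a file-like object opened in binary write mode
    file_path = pathlib.Path(tmp) / "tensor_filelike.pt"
    with open(file_path, "wb") as fh:
        torch.save(x, fh)

    # All three files should hold the same tensor
    loaded = [torch.load(p) for p in (str_path, path_obj, file_path)]
```

All three calls write the same format; which form to use mostly depends on whether your code already has an open file handle.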
torch.save examples
Simple saving with torch.save
>>> import io
>>> import torch
>>> # Save to file
>>> x = torch.tensor([0, 1, 2, 3, 4])
>>> torch.save(x, 'tensor.pt')
>>> # Save to io.BytesIO buffer
>>> buffer = io.BytesIO()
>>> torch.save(x, buffer)
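To read the data back, torch.load accepts the same kinds of targets. Here is a minimal sketch of the buffer round trip. The one detail that is easy to miss is rewinding the buffer before loading, because writing leaves the position at the end of the data.

```python
import io

import torch

x = torch.tensor([0, 1, 2, 3, 4])

buffer = io.BytesIO()
torch.save(x, buffer)

# Rewind to the start of the buffer; otherwise torch.load has nothing to read
buffer.seek(0)
restored = torch.load(buffer)
```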
Saving multiple tensors at once with torch.save
torch.save and torch.load use Python’s pickle by default, so you can save multiple tensors as part of Python objects like tuples, lists, and dicts:
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
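Loading such a file returns the same container type with the same keys. The sketch below writes the dictionary to a temporary directory (an arbitrary choice for the example) and reads the tensors back by key.

```python
import os
import tempfile

import torch

d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'tensor_dict.pt')
    torch.save(d, path)

    # torch.load gives back a dict with the original keys
    loaded = torch.load(path)

a = loaded['a']
b = loaded['b']
```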
Saving a checkpoint with torch.save
The following code snippet is taken from DDPAE – proof-of-concept code for the paper Learning to Decompose and Disentangle Representations for Video Prediction by Jun-Ting Hsieh, Bingbin Liu, De-An Huang, Li Fei-Fei, Juan Carlos Niebles.
# https://github.com/jthsieh/DDPAE-video-prediction/blob/219e68301d24615410260c3d33c80ae74f6f2dc3/models/base_model.py
def save(self, ckpt_path, epoch):
    '''
    Save checkpoint.
    '''
    for name, net in self.nets.items():
        if isinstance(net, torch.nn.DataParallel):
            module = net.module
        else:
            module = net
        path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
        torch.save(module.state_dict(), path)
    for name, optimizer in self.optimizers.items():
        path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
        torch.save(optimizer.state_dict(), path)
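The snippet above only covers saving. A matching restore step might look roughly like the sketch below, which is not taken from DDPAE: the nn.Linear model, the SGD optimizer, and the file names are stand-ins chosen for the example. The idea is to rebuild the objects first and then feed the loaded state dicts into load_state_dict.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in model and optimizer (assumptions for this example)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

with tempfile.TemporaryDirectory() as ckpt_path:
    net_path = os.path.join(ckpt_path, 'net_demo_0.pth')
    opt_path = os.path.join(ckpt_path, 'optimizer_demo_0.pth')

    # Save only the state dicts, as in the snippet above
    torch.save(model.state_dict(), net_path)
    torch.save(optimizer.state_dict(), opt_path)

    # Restore: construct fresh objects, then load the saved state into them
    new_model = nn.Linear(4, 2)
    new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
    new_model.load_state_dict(torch.load(net_path, map_location='cpu'))
    new_optimizer.load_state_dict(torch.load(opt_path, map_location='cpu'))
```

Saving state dicts rather than whole model objects keeps the checkpoint independent of the class definition's import path, which is likely why the original code unwraps DataParallel before saving.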
torch.save usage in an open source project
The following code snippet is taken from MMDetection – an open source object detection toolbox based on PyTorch. It is a part of the OpenMMLab project.
# https://github.com/open-mmlab/mmdetection/blob/bde7b4b7eea9dd6ee91a486c6996b2d68662366d/tools/model_converters/publish_model.py#L17
def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    if torch.__version__ >= '1.6':
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = out_file_name + f'-{sha[:8]}.pth'
    subprocess.Popen(['mv', out_file, final_file])
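Two details of the snippet above are worth a closer look. First, comparing version strings with >= is lexicographic when both sides are plain str values, so '1.10.0' sorts before '1.6'. Newer PyTorch releases appear to expose torch.__version__ as a version-aware string type, so the original check may still behave as intended there. Second, sha256sum and mv are external commands that are not available on every platform. The sketch below shows one portable alternative using only the standard library. The helper names version_tuple and add_hash_suffix are hypothetical and not part of MMDetection or PyTorch.

```python
import hashlib
import os
import re


def version_tuple(version):
    """Parse the leading 'major.minor' of a version string into a tuple of ints."""
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError(f'unrecognized version string: {version!r}')
    return int(match.group(1)), int(match.group(2))


def add_hash_suffix(out_file):
    """Rename out_file to '<name>-<first 8 hex chars of its sha256>.pth'."""
    with open(out_file, 'rb') as fh:
        sha = hashlib.sha256(fh.read()).hexdigest()
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = f'{out_file_name}-{sha[:8]}.pth'
    os.replace(out_file, final_file)
    return final_file


# Plain string comparison goes character by character
plain_compare = '1.10.0' >= '1.6'                           # False
# Tuple comparison treats each component as a number
tuple_compare = version_tuple('1.10.0+cu113') >= (1, 6)     # True
```

add_hash_suffix reads the whole file into memory, which is fine for a sketch but may be worth replacing with chunked reads for multi-gigabyte checkpoints.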
Common issues with torch.save
torch.save Permission denied
The save path may be wrong. Try an absolute path, ideally in a different directory, to isolate the issue.
Saving the file with a .txt extension can also cause problems. Try .pth or .pt instead.
The file might be open in another program. Double-check whether the file or its parent directory is in use by another application on your system. Turning off a file synchronization program may solve the problem.
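Before digging further, it can help to check the path itself. The helper below is a hypothetical diagnostic, not a PyTorch function, and uses only the standard library. The check for a path that points at a directory is an assumption about a likely cause rather than something confirmed above.

```python
import os


def diagnose_save_path(path):
    """Return a list of human-readable problems that could block writing to path."""
    problems = []
    path = os.path.abspath(path)
    parent = os.path.dirname(path)

    # Assumption: passing a directory instead of a file name is a likely cause
    if os.path.isdir(path):
        problems.append(f'{path} is a directory, not a file name')

    if not os.path.isdir(parent):
        problems.append(f'parent directory {parent} does not exist')
    elif not os.access(parent, os.W_OK):
        problems.append(f'no write permission on {parent}')

    if os.path.isfile(path) and not os.access(path, os.W_OK):
        problems.append(f'existing file {path} is not writable')

    return problems
```

Calling diagnose_save_path('checkpoints/model.pth') before torch.save returns an empty list when nothing obvious is wrong. It cannot detect a file locked by another application, so the advice about closing other programs still applies.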