Save model with torch.save – basic usage and code examples

torch.save() and torch.load() are two functions that let you save objects such as tensors, model state_dicts, and checkpoints to disk and load them back. The saved files usually end with the .pt or .pth extension. This article shows how to use torch.save and walks through a few code examples from popular open source projects.

torch.save arguments

torch.save() takes the following arguments:

    torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

  • obj is the object to be saved (required).
  • f is either a file-like object (one that implements write and flush) or a string or os.PathLike object containing a file name (required). The .pt or .pth extension is a convention, not a requirement.
  • pickle_module is the module used for pickling metadata and objects. It defaults to Python's pickle.
  • pickle_protocol can be specified to override the default pickle protocol.
  • _use_new_zipfile_serialization controls the file format. The 1.6 release of PyTorch switched torch.save to a new zipfile-based file format, and torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.
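
To make the f argument more concrete, here is a minimal sketch that passes it in three forms: a string path, a pathlib.Path, and an open binary file handle. The temporary directory and the file names are placeholders chosen for this example.

```python
import tempfile
from pathlib import Path

import torch

x = torch.arange(5)

with tempfile.TemporaryDirectory() as tmp_dir:
    # 1) f as a plain string path
    str_path = str(Path(tmp_dir) / "from_str.pt")
    torch.save(x, str_path)

    # 2) f as an os.PathLike object
    path_obj = Path(tmp_dir) / "from_path.pt"
    torch.save(x, path_obj)

    # 3) f as a file-like object opened in binary write mode
    handle_path = Path(tmp_dir) / "from_handle.pt"
    with open(handle_path, "wb") as fh:
        torch.save(x, fh)

    saved_files = sorted(p.name for p in Path(tmp_dir).iterdir())
    reloaded = torch.load(handle_path)

print(saved_files)
```

All three calls write the same kind of file, so which form to use is mostly a matter of what the surrounding code already has at hand.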

torch.save examples

Simple saving with torch.save

    >>> import io
    >>> import torch
    >>> # Save to file
    >>> x = torch.tensor([0, 1, 2, 3, 4])
    >>> torch.save(x, 'tensor.pt')
    >>> # Save to io.BytesIO buffer
    >>> buffer = io.BytesIO()
    >>> torch.save(x, buffer)
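
Saving is only half of the round trip. The sketch below reads the same tensor back with torch.load, once from a file and once from an in-memory buffer. It writes to a temporary directory rather than the working directory, which is a choice made for this example only.

```python
import io
import os
import tempfile

import torch

x = torch.tensor([0, 1, 2, 3, 4])

# Round trip through a file on disk
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tensor.pt")
    torch.save(x, path)
    from_file = torch.load(path)

# Round trip through an in-memory buffer
buffer = io.BytesIO()
torch.save(x, buffer)
buffer.seek(0)  # rewind, otherwise torch.load starts reading at the end
from_buffer = torch.load(buffer)

print(torch.equal(x, from_file), torch.equal(x, from_buffer))
```

Forgetting buffer.seek(0) is an easy mistake to make, because torch.save leaves the buffer position at the end of the data it has just written.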

Saving multiple tensors at once with torch.save

torch.save and torch.load use Python’s pickle by default, so you can save multiple tensors as part of Python objects like tuples, lists, and dicts:

    >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
    >>> torch.save(d, 'tensor_dict.pt')
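
Loading such a file gives back the same container with the same keys. The following sketch, which again uses a temporary directory as a placeholder location, saves the dict and then reads individual tensors out of the loaded copy.

```python
import os
import tempfile

import torch

d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'tensor_dict.pt')
    torch.save(d, path)
    loaded = torch.load(path)

# The container type and the keys survive the round trip
total = loaded['a'] + loaded['b']
print(sorted(loaded), total)
```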

Save checkpoint with torch.save

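The following code snippet is taken from DDPAE, the proof-of-concept code for the paper Learning to Decompose and Disentangle Representations for Video Prediction by Jun-Ting Hsieh, Bingbin Liu, De-An Huang, Li Fei-Fei, Juan Carlos Niebles. It saves the state_dict of every network and every optimizer to its own file, unwrapping torch.nn.DataParallel first.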

    # https://github.com/jthsieh/DDPAE-video-prediction/blob/219e68301d24615410260c3d33c80ae74f6f2dc3/models/base_model.py
    def save(self, ckpt_path, epoch):
        '''
        Save checkpoint.
        '''
        for name, net in self.nets.items():
            if isinstance(net, torch.nn.DataParallel):
                module = net.module
            else:
                module = net
            path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
            torch.save(module.state_dict(), path)
        for name, optimizer in self.optimizers.items():
            path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
            torch.save(optimizer.state_dict(), path)
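
A common variation is to bundle the model state, the optimizer state, and the epoch number into a single file. The sketch below is not taken from DDPAE. The tiny nn.Linear model, the SGD settings, the dict keys, and the file name are all assumptions made for illustration.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_file = os.path.join(tmp_dir, "checkpoint_1.pth")  # hypothetical name
    torch.save(
        {
            "epoch": 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        ckpt_file,
    )
    checkpoint = torch.load(ckpt_file, map_location="cpu")

# Restore into freshly constructed objects of the same shape
restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(
    restored_model.parameters(), lr=0.1, momentum=0.9
)
restored_model.load_state_dict(checkpoint["model"])
restored_optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1
```

Saving state_dicts rather than whole model objects keeps the file independent of the class definitions, at the cost of having to rebuild the model and optimizer before loading.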

torch.save usage in an open source project

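The following code snippet is taken from MMDetection, an open source object detection toolbox based on PyTorch that is part of the OpenMMLab project. The function removes the optimizer state from a checkpoint to shrink the file, saves the result, and then renames the file so that its name includes the first eight characters of its SHA-256 hash.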

    # https://github.com/open-mmlab/mmdetection/blob/bde7b4b7eea9dd6ee91a486c6996b2d68662366d/tools/model_converters/publish_model.py#L17
    def process_checkpoint(in_file, out_file):
        checkpoint = torch.load(in_file, map_location='cpu')
        # remove optimizer for smaller file size
        if 'optimizer' in checkpoint:
            del checkpoint['optimizer']
        # if it is necessary to remove some sensitive data in checkpoint['meta'],
        # add the code here.
        if torch.__version__ >= '1.6':
            torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
        else:
            torch.save(checkpoint, out_file)
        sha = subprocess.check_output(['sha256sum', out_file]).decode()
        if out_file.endswith('.pth'):
            out_file_name = out_file[:-4]
        else:
            out_file_name = out_file
        final_file = out_file_name + f'-{sha[:8]}.pth'
        subprocess.Popen(['mv', out_file, final_file])
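
One detail of the version check deserves care. With plain Python strings, a comparison such as '1.10.0' >= '1.6' is lexicographic, so it evaluates to False. Whether torch.__version__ behaves like a plain string depends on the PyTorch release, so comparing numeric tuples is a defensive option. The helper version_tuple below is a hypothetical function written for this article. It is not part of PyTorch or MMDetection.

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn '1.10.0+cu113' into (1, 10, 0); non-numeric suffixes are ignored."""
    numeric_prefix = re.match(r"\d+(?:\.\d+)*", version)
    if numeric_prefix is None:
        raise ValueError(f"no leading version number in {version!r}")
    return tuple(int(part) for part in numeric_prefix.group().split("."))


# Plain strings compare character by character: '1' sorts before '6'
lexicographic = "1.10.0" >= "1.6"

# Tuples of integers compare numerically: 10 is greater than 6
numeric = version_tuple("1.10.0") >= version_tuple("1.6")

print(lexicographic, numeric)  # False True
```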

Common issues with torch.save

torch.save Permission denied

  • The save path may be wrong. For example, it may point to a directory rather than a file name, or to a location you are not allowed to write to. Try an absolute path in a different directory to isolate the issue.

  • Saving the file with a .txt extension can cause problems. Try .pth or .pt instead.

  • The file might be open elsewhere. Double-check whether the file or its parent directory is in use by another application on your system. Turning off a file synchronization program may solve the problem.
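
Some of these checks can be run before calling torch.save. The sketch below uses only the standard library. diagnose_save_path is a hypothetical helper written for this article, and os.access should be treated as a hint rather than a guarantee, since its answers can be unreliable on some platforms and file systems.

```python
import os
import tempfile


def diagnose_save_path(path: str) -> list:
    """Return a list of likely reasons why writing to `path` could fail."""
    problems = []
    target = os.path.abspath(path)
    parent = os.path.dirname(target)

    if os.path.isdir(target):
        problems.append("path is a directory, not a file name")
    if not os.path.isdir(parent):
        problems.append("parent directory does not exist")
    elif not os.access(parent, os.W_OK):
        problems.append("parent directory is not writable")
    if os.path.isfile(target) and not os.access(target, os.W_OK):
        problems.append("existing file is not writable")
    return problems


with tempfile.TemporaryDirectory() as tmp_dir:
    # Passing the directory itself instead of a file name
    dir_report = diagnose_save_path(tmp_dir)
    # A file name inside a writable directory
    file_report = diagnose_save_path(os.path.join(tmp_dir, "model.pth"))
    # A file name inside a directory that was never created
    missing_report = diagnose_save_path(
        os.path.join(tmp_dir, "missing", "model.pth")
    )

print(dir_report, file_report, missing_report)
```

An empty list does not prove that the save will succeed, for instance when another program holds a lock on the file, but a non-empty list usually points at the cause.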
