Save model with torch.save – basic usage and code examples

torch.save() and torch.load() are two methods that let you easily save tensors to disk as a file and load them back. The saved files usually end with a .pt or .pth extension. This article shows you how to use torch.save and provides a few code examples from popular open source projects.

torch.save arguments

torch.save takes the following arguments:

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
  • obj is the object to be saved (required)
  • f is either a file-like object (it must implement write and flush) or a string or os.PathLike object containing a file name, conventionally ending with .pt or .pth (required)
  • pickle_module is the module used for pickling metadata and objects, defaulting to Python’s pickle.
  • pickle_protocol can be specified to override the default pickle protocol.
  • _use_new_zipfile_serialization controls the file format. The 1.6 release of PyTorch switched torch.save to a new zipfile-based format, and torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.

torch.save examples
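A minimal sketch of the f and _use_new_zipfile_serialization arguments in action, assuming PyTorch 1.6 or later. It writes the same tensor to two in-memory io.BytesIO buffers, once in the default zipfile-based format and once in the legacy format, checks which one is a zip archive with Python's zipfile.is_zipfile, and loads both back with torch.load. The names new_buf and old_buf are illustrative.

```python
import io
import zipfile

import torch

x = torch.arange(5)

# Default format (PyTorch >= 1.6): a zip archive, written to a file-like object.
new_buf = io.BytesIO()
torch.save(x, new_buf)

# Legacy (pre-1.6) format, requested with the private kwarg.
old_buf = io.BytesIO()
torch.save(x, old_buf, _use_new_zipfile_serialization=False)

print(zipfile.is_zipfile(new_buf))  # True
print(zipfile.is_zipfile(old_buf))  # False

# torch.load reads both formats; rewind each buffer before loading it.
for buf in (new_buf, old_buf):
    buf.seek(0)
    assert torch.equal(torch.load(buf), x)
```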

Simple saving with torch.save

>>> import io
>>> import torch
>>> # Save to file
>>> x = torch.tensor([0, 1, 2, 3, 4])
>>> torch.save(x, 'tensor.pt')
>>> # Save to io.BytesIO buffer
>>> buffer = io.BytesIO()
>>> torch.save(x, buffer)
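To read the tensor back, pass the same path, or the saved bytes, to torch.load. The sketch below also shows that f can be a pathlib.Path rather than a plain string, and it writes into a temporary directory so no files are left behind; the names payload, from_file and from_bytes are illustrative.

```python
import io
import pathlib
import tempfile

import torch

x = torch.tensor([0, 1, 2, 3, 4])

with tempfile.TemporaryDirectory() as tmp:
    # f may be an os.PathLike object such as pathlib.Path, not just a str.
    path = pathlib.Path(tmp) / "tensor.pt"
    torch.save(x, path)
    from_file = torch.load(path)

# Serialize to raw bytes, e.g. to store in a database or send over a socket.
buffer = io.BytesIO()
torch.save(x, buffer)
payload = buffer.getvalue()

# A fresh BytesIO starts at position 0, so it can be loaded directly.
from_bytes = torch.load(io.BytesIO(payload))

print(torch.equal(from_file, x), torch.equal(from_bytes, x))  # True True
```

torch.save leaves a buffer's position at the end of the written data, so if you load from the same buffer you need to call buffer.seek(0) first; wrapping getvalue() in a new io.BytesIO avoids that step.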

Saving multiple tensors at once with torch.save

torch.save and torch.load use Python’s pickle by default, so you can save multiple tensors as part of Python objects like tuples, lists, and dicts:

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
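Loading such a container returns the same kind of Python object, with the tensors restored inside it. A small sketch that rebuilds the dict d locally and nests it with a tuple and a list; the keys "pair" and "history" are made up for illustration.

```python
import os
import tempfile

import torch

d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
bundle = {
    "dict": d,
    "pair": (torch.zeros(2), torch.ones(3)),                # a tuple of tensors
    "history": [torch.tensor(0.5), torch.tensor(0.25)],     # a list of tensors
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tensor_dict.pt")
    torch.save(bundle, path)
    loaded = torch.load(path)

print(sorted(loaded["dict"]))            # ['a', 'b']
print(type(loaded["pair"]).__name__)     # tuple
print(type(loaded["history"]).__name__)  # list
```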

Save checkpoint with torch.save

The following code snippet is taken from DDPAE – proof-of-concept code for the paper Learning to Decompose and Disentangle Representations for Video Prediction by Jun-Ting Hsieh, Bingbin Liu, De-An Huang, Li Fei-Fei, Juan Carlos Niebles.

import os

import torch


def save(self, ckpt_path, epoch):
    ''' Save checkpoint. '''
    for name, net in self.nets.items():
        # unwrap DataParallel so the saved keys match the bare module
        if isinstance(net, torch.nn.DataParallel):
            module = net.module
        else:
            module = net
        path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
        torch.save(module.state_dict(), path)
    for name, optimizer in self.optimizers.items():
        path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
        torch.save(optimizer.state_dict(), path)
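The DDPAE snippet only shows the saving side. The sketch below pairs it with a hypothetical load() method that reads each file back with torch.load and load_state_dict. TinyModel, its single "encoder" network and the SGD optimizer are stand-ins, not part of DDPAE. Unwrapping torch.nn.DataParallel before calling state_dict() keeps the parameter names free of the "module." prefix that the wrapper adds, so the files also load into an unwrapped model.

```python
import os
import tempfile

import torch


class TinyModel:
    """Hypothetical stand-in for the DDPAE model class: it only holds the
    `nets` and `optimizers` dicts that save() iterates over."""

    def __init__(self):
        net = torch.nn.Linear(4, 2)
        self.nets = {"encoder": net}
        self.optimizers = {"encoder": torch.optim.SGD(net.parameters(), lr=0.1)}

    def save(self, ckpt_path, epoch):
        # Same file naming scheme as the DDPAE snippet.
        for name, net in self.nets.items():
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
            torch.save(module.state_dict(), path)
        for name, optimizer in self.optimizers.items():
            path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
            torch.save(optimizer.state_dict(), path)

    def load(self, ckpt_path, epoch):
        # Hypothetical counterpart: read each file back into the matching object.
        for name, net in self.nets.items():
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
            module.load_state_dict(torch.load(path, map_location='cpu'))
        for name, optimizer in self.optimizers.items():
            path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
            optimizer.load_state_dict(torch.load(path, map_location='cpu'))


with tempfile.TemporaryDirectory() as tmp:
    src = TinyModel()
    src.save(tmp, epoch=3)
    files = sorted(os.listdir(tmp))
    dst = TinyModel()  # freshly initialized weights
    dst.load(tmp, epoch=3)

same = all(
    torch.equal(p, q)
    for p, q in zip(src.nets["encoder"].parameters(), dst.nets["encoder"].parameters())
)
print(files)  # ['net_encoder_3.pth', 'optimizer_encoder_3.pth']
print(same)   # True
```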
torch.save usage in open source projects

The following code snippet is taken from MMDetection – an open source object detection toolbox based on PyTorch. It is a part of the OpenMMLab project.

import subprocess

import torch


def process_checkpoint(in_file, out_file):
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    if torch.__version__ >= '1.6':
        # keep the pre-1.6 format so older PyTorch versions can still load the file
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)
    # hash the saved file and append the first 8 hex digits to its name
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = out_file_name + f'-{sha[:8]}.pth'
    subprocess.Popen(['mv', out_file, final_file])
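The subprocess calls make process_checkpoint depend on the sha256sum and mv command-line tools, and Popen returns without waiting for mv to finish. A portable variant, sketched under the assumption that the checkpoint is a dict, swaps in Python's hashlib and os.replace for those two tools. publish_checkpoint is a hypothetical name, and this version keeps PyTorch's default zipfile format instead of passing _use_new_zipfile_serialization=False.

```python
import hashlib
import os
import tempfile

import torch


def publish_checkpoint(in_file, out_file):
    """Hypothetical portable take on process_checkpoint: hashlib and
    os.replace stand in for the `sha256sum` and `mv` subprocess calls."""
    checkpoint = torch.load(in_file, map_location='cpu')
    # Drop optimizer state so the published file is smaller.
    checkpoint.pop('optimizer', None)
    torch.save(checkpoint, out_file)

    # Hash the saved file in 1 MiB chunks.
    sha = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha.update(chunk)

    stem = out_file[:-4] if out_file.endswith('.pth') else out_file
    final_file = stem + f'-{sha.hexdigest()[:8]}.pth'
    os.replace(out_file, final_file)  # synchronous rename, unlike Popen(['mv', ...])
    return final_file


with tempfile.TemporaryDirectory() as tmp:
    net = torch.nn.Linear(3, 1)
    opt = torch.optim.SGD(net.parameters(), lr=0.01)
    src = os.path.join(tmp, 'epoch_12.pth')
    torch.save({'meta': {'epoch': 12}, 'state_dict': net.state_dict(),
                'optimizer': opt.state_dict()}, src)
    published = publish_checkpoint(src, os.path.join(tmp, 'published.pth'))
    basename = os.path.basename(published)
    loaded = torch.load(published)
```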

Common issues with torch.save

Permission denied

  • The save path might be incorrect. Try changing it to an absolute path, or to a different directory, to isolate the issue.

  • There can be a problem if you try to save the file with a .txt extension. Try changing it to .pt or .pth.

  • The file might be open in another program. Check whether the file or its parent directory is being used by another application on your system; turning off a file synchronization program may solve the problem.
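The first tip can be partly automated with a small pre-flight check before torch.save. save_checked below is a hypothetical helper, not part of PyTorch: it turns the path into an absolute one, refuses a directory path (on Windows, trying to open a directory for writing typically fails with a Permission denied error, which hides the real mistake), creates missing parent directories, and does a coarse write-permission check with os.access.

```python
import os
import tempfile

import torch


def save_checked(obj, path):
    """Hypothetical helper: run a few sanity checks, then call torch.save."""
    path = os.path.abspath(path)  # an absolute path makes the target unambiguous
    if os.path.isdir(path):
        raise IsADirectoryError(f"{path} is a directory; pass a file name such as model.pt")
    parent = os.path.dirname(path)
    os.makedirs(parent, exist_ok=True)
    # Coarse check only: on Windows os.access mostly reflects the read-only flag.
    if not os.access(parent, os.W_OK):
        raise PermissionError(f"no write permission for directory {parent}")
    torch.save(obj, path)
    return path


with tempfile.TemporaryDirectory() as tmp:
    saved = save_checked(torch.ones(2), os.path.join(tmp, "checkpoints", "model.pt"))
    saved_exists = os.path.isfile(saved)
    try:
        save_checked(torch.ones(2), tmp)  # a directory, not a file name
        dir_rejected = False
    except IsADirectoryError:
        dir_rejected = True
```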
