torch.squeeze and torch.unsqueeze – usage and code examples

torch.squeeze and torch.unsqueeze are two popular but often confusing functions in PyTorch. According to the official documentation, squeeze "returns a tensor with all the dimensions of input of size 1 removed", while unsqueeze "returns a new tensor with a dimension of size one inserted at the specified position".

This short article explains the two functions and shows a few examples of how you would use torch.squeeze and torch.unsqueeze.

torch.squeeze vs. torch.unsqueeze

Simply put, torch.unsqueeze adds a dimension of size 1 to a tensor at the position you specify, while torch.squeeze removes dimensions of size 1 from a tensor: all of them by default, or only the one you select with dim.
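To make the shape changes concrete, here is a small sketch that prints shapes before and after each call. The shape (2, 1, 3, 1) is an arbitrary choice for illustration, picked because it has two size-1 dimensions:

```python
import torch

# Arbitrary tensor with size-1 dimensions at positions 1 and 3
t = torch.zeros(2, 1, 3, 1)

squeezed = t.squeeze()       # removes every size-1 dimension
unsqueezed = t.unsqueeze(0)  # inserts a new size-1 dimension at position 0

print(t.shape)           # torch.Size([2, 1, 3, 1])
print(squeezed.shape)    # torch.Size([2, 3])
print(unsqueezed.shape)  # torch.Size([1, 2, 1, 3, 1])
```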

Below is a visual representation of what squeeze and unsqueeze do to a 2D matrix:

[Figure: torch.squeeze vs. torch.unsqueeze]

Looking at the image, you can get a glimpse of how things work. If you apply squeeze to any of the 3D tensors above, you'll get the same result.

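>>> import torch
>>> x = torch.rand(3, 2, dtype=torch.float)
>>> x
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])
>>> # Squeezing the (1, 3, 2), (3, 1, 2) or (3, 2, 1) versions of x gives back x
>>> torch.equal(x.unsqueeze(0).squeeze(), x.unsqueeze(1).squeeze())
True
>>> torch.equal(x.unsqueeze(2).squeeze(), x)
True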

Unsqueeze is a little harder to understand. One does not simply unsqueeze a tensor without knowing which dimension to add (turning a vector into a row or a column, for example). That's why you need to specify a dim argument whenever you use torch.unsqueeze.
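The row-versus-column choice is easiest to see on a 1D vector. In this minimal sketch (the vector values are arbitrary), the same three numbers become a 1 x 3 row or a 3 x 1 column depending only on dim, and squeezing either one gives the original vector back:

```python
import torch

v = torch.tensor([1.0, 2.0, 3.0])  # 1D tensor, shape (3,)

row = v.unsqueeze(0)     # new dimension in front: a 1 x 3 row
column = v.unsqueeze(1)  # new dimension at the end: a 3 x 1 column

print(row.shape)     # torch.Size([1, 3])
print(column.shape)  # torch.Size([3, 1])

# squeeze() undoes either choice
print(torch.equal(row.squeeze(), column.squeeze()))  # True
```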

torch.squeeze usage

torch.squeeze takes in an input tensor and returns a tensor with all the dimensions of input of size 1 removed.

torch.squeeze(input, dim=None, *, out=None) → Tensor

Required arguments:

  • input (Tensor) – the input tensor.

Optional arguments:

  • dim (int, optional) – if given, the input will be squeezed only in this dimension
  • out (Tensor, optional) – the output tensor.
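Most of the subtlety is in the dim argument. A short sketch on an arbitrary (2, 1, 3, 1) tensor shows that squeezing a specific dimension removes only that one, that negative indices count from the end, and that squeezing a dimension whose size is not 1 leaves the tensor unchanged instead of raising an error:

```python
import torch

t = torch.zeros(2, 1, 3, 1)

print(t.squeeze().shape)    # torch.Size([2, 3])        every size-1 dim removed
print(t.squeeze(1).shape)   # torch.Size([2, 3, 1])     only dim 1 removed
print(t.squeeze(-1).shape)  # torch.Size([2, 1, 3])     -1 refers to the last dim
print(t.squeeze(0).shape)   # torch.Size([2, 1, 3, 1])  dim 0 has size 2: no change
```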

Basic usage:

>>> b = torch.randn(4, 1, 4)
>>> b
tensor([[[ 1.2912, -1.9050,  1.4771,  1.5517]],

        [[-0.3359, -0.2381, -0.3590,  0.0406]],

        [[-0.2460, -0.2326,  0.4511,  0.7255]],

        [[-0.1456, -0.0857, -0.8443,  1.1423]]])
>>> c = b.squeeze(1)
>>> c
tensor([[ 1.2912, -1.9050,  1.4771,  1.5517],
        [-0.3359, -0.2381, -0.3590,  0.0406],
        [-0.2460, -0.2326,  0.4511,  0.7255],
        [-0.1456, -0.0857, -0.8443,  1.1423]])
>>> b.size()
torch.Size([4, 1, 4])
>>> c.size()
torch.Size([4, 4])
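One detail worth knowing: the PyTorch documentation notes that the tensor returned by squeeze shares storage with its input, so it behaves like a view and writing into c also changes b. A minimal sketch with a fresh random tensor of the same shape:

```python
import torch

b = torch.randn(4, 1, 4)
c = b.squeeze(1)  # shape (4, 4), shares storage with b (no data copied)

c[0, 0] = 100.0           # write through the squeezed tensor...
print(b[0, 0, 0].item())  # ...and b sees it: 100.0

# Use .clone() when you need an independent copy
d = b.squeeze(1).clone()
d[0, 0] = -1.0
print(b[0, 0, 0].item())  # still 100.0
```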

torch.unsqueeze usage

torch.unsqueeze takes an input tensor and a dimension index, and returns a new tensor with a dimension of size one inserted at the specified position.

torch.unsqueeze(input, dim) → Tensor

Required arguments:

  • input (Tensor) – the input tensor.
  • dim (int) – the index at which to insert the singleton dimension
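The dim argument also accepts negative values. Per the PyTorch documentation, a negative dim is treated as dim + input.dim() + 1, so -1 always appends a new last dimension. A quick check on a (4, 4, 4) tensor, where the valid range is -4 to 3:

```python
import torch

a = torch.randn(4, 4, 4)  # a.dim() == 3, so valid dims are -4 .. 3

print(a.unsqueeze(-1).shape)  # torch.Size([4, 4, 4, 1])  same as dim=3
print(a.unsqueeze(-2).shape)  # torch.Size([4, 4, 1, 4])  same as dim=2
print(a.unsqueeze(-4).shape)  # torch.Size([1, 4, 4, 4])  same as dim=0
```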

Basic usage:

>>> a = torch.randn(4, 4, 4)
>>> torch.unsqueeze(a, 0).size()
torch.Size([1, 4, 4, 4])

>>> torch.unsqueeze(a, 1).size()
torch.Size([4, 1, 4, 4])

>>> torch.unsqueeze(a, 2).size()
torch.Size([4, 4, 1, 4])

>>> torch.unsqueeze(a, 3).size()
torch.Size([4, 4, 4, 1])
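Two common places where unsqueeze shows up in practice are adding a batch dimension to a single sample and lining up shapes for broadcasting. The sketch below assumes a hypothetical 3-channel 32 x 32 image and a small example matrix; it also shows that indexing with None is an equivalent shorthand for inserting a size-1 dimension:

```python
import torch

# Batch dimension: models are usually written for batched (N, C, H, W) input,
# so a single (C, H, W) image gets a leading axis of size 1
image = torch.randn(3, 32, 32)  # hypothetical single image
batch = image.unsqueeze(0)
print(batch.shape)              # torch.Size([1, 3, 32, 32])

# Broadcasting: subtract each row's mean from that row
m = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
row_means = m.mean(dim=1)               # shape (2,)
centered = m - row_means.unsqueeze(1)   # (2, 3) - (2, 1) broadcasts across columns
print(centered.mean(dim=1))             # every row now has mean 0

# Indexing with None also inserts a size-1 dimension
print(torch.equal(row_means.unsqueeze(1), row_means[:, None]))  # True
```

Without the unsqueeze(1), row_means would have shape (2,) and PyTorch would try to broadcast it against the last dimension of m (size 3), which fails.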


We hope that the short explanation above helps you understand the difference between torch.squeeze and torch.unsqueeze. If you want more code examples of those two functions in real open source projects, check out our selected torch.squeeze code examples and torch.unsqueeze code examples.

If you have any suggestions or spot an error in the article, feel free to leave a comment below to let us know.
