tensor.squeeze函數和tensor.unsqueeze函數的使用詳解

tensor.squeeze() 和 tensor.unsqueeze() 是 PyTorch 中用于改變 tensor 形狀的兩個函數,它們的作用如下:

  • tensor.squeeze(dim=None, *, out=None) : 壓縮 tensor 中尺寸為 1 的維度,并返回新的 tensor。可以指定要壓縮的維度(默認為所有尺寸為 1 的維度均壓縮)。
  • tensor.unsqueeze(dim, *, out=None) : 在指定的位置插入一個新維度,并返回新的 tensor。dim 參數表示新插入的維度在哪個位置(從 0 開始),可以是負數,表示倒數第幾個維度。
  • squeeze 是壓縮維度,unsqueeze是增加維度.

下面給出例子來說明它們的使用。文章源自四五設計網-http://www.133122.cn/49496.html

tensor.squeeze()

import torch
?
# 創建一個形狀為 (1, 3, 1, 2) 的 tensor
x = torch.randn(1, 3, 1, 2)
print(x.shape)? # torch.Size([1, 3, 1, 2])
?
# 壓縮尺寸為 1 的維度
y = x.squeeze()
print(y.shape)? # torch.Size([3, 2])
?
# 指定要壓縮的維度
y = x.squeeze(dim=0)
print(y.shape)? # torch.Size([3, 1, 2])

在上面的例子中,我們創建了一個形狀為 (1, 3, 1, 2) 的 tensor,然后使用 squeeze() 函數壓縮了尺寸為 1 的維度。在第二個 squeeze() 調用中,我們指定了要壓縮的維度為 0,也就是第一個維度,因此第一個維度的大小被壓縮為 1,變成了形狀為 (3, 1, 2) 的 tensor。文章源自四五設計網-http://www.133122.cn/49496.html

tensor.unsqueeze()

import torch
?
# 創建一個形狀為 (3, 2) 的 tensor
x = torch.randn(3, 2)
print(x.shape)? # torch.Size([3, 2])
?
# 在維度 0 上插入新維度
y = x.unsqueeze(dim=0)
print(y.shape)? # torch.Size([1, 3, 2])
?
# 在維度 1 上插入新維度
y = x.unsqueeze(dim=1)
print(y.shape)? # torch.Size([3, 1, 2])
?
# 在倒數第二個維度上插入新維度
y = x.unsqueeze(dim=-2)
print(y.shape)? # torch.Size([3, 1, 2])

在上面的例子中,我們創建了一個形狀為 (3, 2) 的 tensor,然后使用 unsqueeze() 函數在不同的位置插入了新維度。在第一個 unsqueeze() 調用中,我們在維度 0 上插入了新維度,因此新的 tensor 形狀為 (1, 3, 2)。在第二個和第三個 unsqueeze() 調用中,我們分別在維度 1 和倒數第二個維度上插入了新維度,分別得到了形狀為 (3, 1, 2) 和 (3, 2, 1) 的 tensor。文章源自四五設計網-http://www.133122.cn/49496.html

到此這篇關于tensor.squeeze函數和tensor.unsqueeze函數的使用詳解的文章就介紹到這了文章源自四五設計網-http://www.133122.cn/49496.html 文章源自四五設計網-http://www.133122.cn/49496.html

繼續閱讀
我的微信
微信掃一掃
weinxin
我的微信
惠生活福利社
微信掃一掃
weinxin
我的公眾號
 
  • 本文由 四五設計網小助手 發表于 2024年7月2日10:24:42
  • 轉載請務必保留本文鏈接:http://www.133122.cn/49496.html

發表評論

匿名網友
:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

拖動滑塊以完成驗證