在PyTorch中,view()
函數用于調整張量的形狀。它的使用方式如下:
output = input.view(*shape)
這里的input
是輸入張量,shape
是一個元組,用于指定調整后的形狀。具體來說:
shape
中的每個元素可以是一個具體的維度大小,或者-1表示根據其他維度的大小自動計算。下面是一些示例:
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 將x的形狀調整為(3, 2)
output = x.view(3, 2)
print(output)
# 輸出:
# tensor([[1, 2],
# [3, 4],
# [5, 6]])
# 將x的形狀調整為(6, -1),其中-1表示自動計算
output = x.view(6, -1)
print(output)
# 輸出:
# tensor([[1],
# [2],
# [3],
# [4],
# [5],
# [6]])
# 將x的形狀調整為(1, 6)
output = x.view(1, 6)
print(output)
# 輸出:
# tensor([[1, 2, 3, 4, 5, 6]])
需要注意的是,調整后的形狀必須和原始張量的元素總數保持一致,否則會拋出錯誤。