您好,登錄后才能下訂單哦!
小編給大家分享一下pytorch怎么實現變分自動編碼器,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 12 11:42:19 2018
@author: www
"""
import os
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
im_tfs = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 標準化
])
train_set = MNIST('E:data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20) # mean
self.fc22 = nn.Linear(400, 20) # var
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h2 = F.relu(self.fc1(x))
return self.fc21(h2), self.fc22(h2)
def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
eps = torch.FloatTensor(std.size()).normal_()
if torch.cuda.is_available():
eps = Variable(eps.cuda())
else:
eps = Variable(eps)
return eps.mul(std).add_(mu)
def decode(self, z):
h4 = F.relu(self.fc3(z))
return F.tanh(self.fc4(h4))
def forward(self, x):
mu, logvar = self.encode(x) # 編碼
z = self.reparametrize(mu, logvar) # 重新參數化成正態分布
return self.decode(z), mu, logvar # 解碼,同時輸出均值方差
net = VAE() # 實例化網絡
if torch.cuda.is_available():
net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
x = x.cuda()
x = Variable(x)
_, mu, var = net(x)
print(mu)
#可以看到,對于輸入,網絡可以輸出隱含變量的均值和方差,這里的均值方差還沒有訓練
#下面開始訓練
reconstruction_function = nn.MSELoss(size_average=False)
def loss_function(recon_x, x, mu, logvar):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
MSE = reconstruction_function(recon_x, x)
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return MSE + KLD
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
def to_img(x):
'''
定義一個函數將最后的結果轉換回圖片
'''
x = 0.5 * (x + 1.)
x = x.clamp(0, 1)
x = x.view(x.shape[0], 1, 28, 28)
return x
for e in range(100):
for im, _ in train_data:
im = im.view(im.shape[0], -1)
im = Variable(im)
if torch.cuda.is_available():
im = im.cuda()
recon_im, mu, logvar = net(im)
loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 將 loss 平均
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (e + 1) % 20 == 0:
print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
save = to_img(recon_im.cpu().data)
if not os.path.exists('./vae_img'):
os.mkdir('./vae_img')
save_image(save, './vae_img/image_{}.png'.format(e + 1))
補充:PyTorch 深度學習快速入門——變分自動編碼器
變分編碼器是自動編碼器的升級版本,其結構跟自動編碼器是類似的,也由編碼器和解碼器構成。
回憶一下,自動編碼器有個問題,就是并不能任意生成圖片,因為我們沒有辦法自己去構造隱藏向量,需要通過一張圖片輸入編碼我們才知道得到的隱含向量是什么,這時我們就可以通過變分自動編碼器來解決這個問題。
其實原理特別簡單,只需要在編碼過程給它增加一些限制,迫使其生成的隱含向量能夠粗略的遵循一個標準正態分布,這就是其與一般的自動編碼器最大的不同。
這樣我們生成一張新圖片就很簡單了,我們只需要給它一個標準正態分布的隨機隱含向量,這樣通過解碼器就能夠生成我們想要的圖片,而不需要給它一張原始圖片先編碼。
一般來講,我們通過 encoder 得到的隱含向量并不是一個標準的正態分布,為了衡量兩種分布的相似程度,我們使用 KL divergence,利用其來表示隱含向量與標準正態分布之間差異的 loss,另外一個 loss 仍然使用生成圖片與原圖片的均方誤差來表示。
KL divergence 的公式如下
重參數 為了避免計算 KL divergence 中的積分,我們使用重參數的技巧,不是每次產生一個隱含向量,而是生成兩個向量,一個表示均值,一個表示標準差,這里我們默認編碼之后的隱含向量服從一個正態分布的之后,就可以用一個標準正態分布先乘上標準差再加上均值來合成這個正態分布,最后 loss 就是希望這個生成的正態分布能夠符合一個標準正態分布,也就是希望均值為 0,方差為 1
所以最后我們可以將我們的 loss 定義為下面的函數,由均方誤差和 KL divergence 求和得到一個總的 loss
def loss_function(recon_x, x, mu, logvar):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
MSE = reconstruction_function(recon_x, x)
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return MSE + KLD
用 mnist 數據集來簡單說明一下變分自動編碼器
import os
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
im_tfs = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 標準化
])
train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20) # mean
self.fc22 = nn.Linear(400, 20) # var
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h2 = F.relu(self.fc1(x))
return self.fc21(h2), self.fc22(h2)
def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
eps = torch.FloatTensor(std.size()).normal_()
if torch.cuda.is_available():
eps = Variable(eps.cuda())
else:
eps = Variable(eps)
return eps.mul(std).add_(mu)
def decode(self, z):
h4 = F.relu(self.fc3(z))
return F.tanh(self.fc4(h4))
def forward(self, x):
mu, logvar = self.encode(x) # 編碼
z = self.reparametrize(mu, logvar) # 重新參數化成正態分布
return self.decode(z), mu, logvar # 解碼,同時輸出均值方差
net = VAE() # 實例化網絡
if torch.cuda.is_available():
net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():
x = x.cuda()
x = Variable(x)
_, mu, var = net(x)
print(mu)
Variable containing: Columns 0 to 9 -0.0307 -0.1439 -0.0435 0.3472 0.0368 -0.0339 0.0274 -0.5608 0.0280 0.2742 Columns 10 to 19 -0.6221 -0.0894 -0.0933 0.4241 0.1611 0.3267 0.5755 -0.0237 0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]
可以看到,對于輸入,網絡可以輸出隱含變量的均值和方差,這里的均值方差還沒有訓練 下面開始訓練
reconstruction_function = nn.MSELoss(size_average=False)
def loss_function(recon_x, x, mu, logvar):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
MSE = reconstruction_function(recon_x, x)
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return MSE + KLD
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
def to_img(x):
'''
定義一個函數將最后的結果轉換回圖片
'''
x = 0.5 * (x + 1.)
x = x.clamp(0, 1)
x = x.view(x.shape[0], 1, 28, 28)
return x
for e in range(100):
for im, _ in train_data:
im = im.view(im.shape[0], -1)
im = Variable(im)
if torch.cuda.is_available():
im = im.cuda()
recon_im, mu, logvar = net(im)
loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 將 loss 平均
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (e + 1) % 20 == 0:
print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))
save = to_img(recon_im.cpu().data)
if not os.path.exists('./vae_img'):
os.mkdir('./vae_img')
save_image(save, './vae_img/image_{}.png'.format(e + 1))
epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343
變分自動編碼器雖然比一般的自動編碼器效果要好,而且也限制了其輸出的編碼 (code) 的概率分布,但是它仍然是通過直接計算生成圖片和原始圖片的均方誤差來生成 loss,這個方式并不好,生成對抗網絡中,我們會講一講這種方式計算 loss 的局限性,然后會介紹一種新的訓練辦法,就是通過生成對抗的訓練方式來訓練網絡而不是直接比較兩張圖片的每個像素點的均方誤差。
以上是“pytorch怎么實現變分自動編碼器”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。