亚洲激情专区-91九色丨porny丨老师-久久久久久久女国产乱让韩-国产精品午夜小视频观看

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

怎么在Pytorch中利用WGAN生成動漫頭像

發布時間:2021-03-04 14:03:32 來源:億速云 閱讀:355 作者:Leah 欄目:開發技術

本篇文章為大家展示了怎么在Pytorch中利用WGAN生成動漫頭像,內容簡明扼要并且容易理解,絕對能使你眼前一亮,通過這篇文章的詳細介紹希望你能有所收獲。

WGAN與GAN的不同

  • 去除sigmoid

  • 使用具有動量的優化方法,比如使用RMSProp

  • 要對Discriminator的權重做修整限制以確保lipschitz連續約

WGAN實戰卷積生成動漫頭像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 創建文件夾
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因為我們在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 輸入是一個nz維度的噪聲,我們可以認為它是一個1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的輸出形狀:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的輸出形狀: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的輸出形狀: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的輸出形狀:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 輸出范圍 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 輸出形狀:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 輸出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 輸出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 輸出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 輸出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 輸出一個數(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 創建對象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################進入訓練##判別器的判斷過程#####################
  for epoch in range(num_epoch): # 進行多個epoch的訓練
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 將tensor變成Variable放入計算圖中
      # 這里的優化器是D的優化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判別器訓練train#####################
      # 分為兩部分:1、真的圖像判別為真;2、假的圖像判別為假
 
      # 計算真實圖片的損失
      d_optimizer.zero_grad() # 在反向傳播之前,先將梯度歸0
      real_out = D(real_img) # 將真實圖片放入判別器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 計算生成圖片的損失
      z = torch.randn(num_img, z_dimension).cuda() # 隨機生成一些噪聲
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 隨機噪聲放入生成網絡中,生成一張假的圖片。 # 避免梯度傳到G,因為G不用更新, detach分離
      fake_out = D(fake_img) # 判別器判斷假的圖片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新參數
 
      # 每次更新判別器的參數之后把它們的絕對值截斷到不超過一個固定常數c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================訓練生成器============================
      # ###############################生成網絡的訓練###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 這里的優化器是G的優化器,所以不需要凍結D的梯度,因為不是D的優化器,不會更新D
      g_optimizer.zero_grad() # 梯度歸0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 隨機噪聲輸入到生成器中,得到一副假的圖片
      output = D(fake_img) # 經過判別器得到的結果
      # g_loss = criterion(output, real_label) # 得到的假的圖片與真實的圖片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 進行反向傳播
      g_optimizer.step() # .step()一般用在反向傳播后面,用于更新生成網絡的參數
 
      # 打印中間的損失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

上述內容就是怎么在Pytorch中利用WGAN生成動漫頭像,你們學到知識或技能了嗎?如果還想學到更多技能或者豐富自己的知識儲備,歡迎關注億速云行業資訊頻道。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

邵东县| 即墨市| 井陉县| 铜山县| 孟村| 舞钢市| 广平县| 赤峰市| 青河县| 察哈| 长丰县| 综艺| 龙里县| 五河县| 德钦县| 辽宁省| 醴陵市| 宽城| 佛学| 荣成市| 澄城县| 二连浩特市| 田林县| 甘肃省| 莆田市| 昌平区| 新平| 玛纳斯县| 科尔| 湖州市| 鹤山市| 盘山县| 略阳县| 建阳市| 金昌市| 纳雍县| 绵竹市| 遂昌县| 浪卡子县| 新蔡县| 固原市|