亚洲激情专区-91九色丨porny丨老师-久久久久久久女国产乱让韩-国产精品午夜小视频观看

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch 圖像中的數據預處理和批標準化實例

發布時間:2020-09-03 15:40:52 來源:腳本之家 閱讀:166 作者:xckkcxxck 欄目:開發技術

目前數據預處理最常見的方法就是中心化和標準化。

中心化相當于修正數據的中心位置,實現方法非常簡單,就是在每個特征維度上減去對應的均值,最后得到 0 均值的特征。

標準化也非常簡單,在數據變成 0 均值之后,為了使得不同的特征維度有著相同的規模,可以除以標準差近似為一個標準正態分布,也可以依據最大值和最小值將其轉化為 -1 ~ 1 之間

批標準化:BN

在數據預處理的時候,我們盡量輸入特征不相關且滿足一個標準的正態分布,這樣模型的表現一般也較好。但是對于很深的網路結構,網路的非線性層會使得輸出的結果變得相關,且不再滿足一個標準的 N(0, 1) 的分布,甚至輸出的中心已經發生了偏移,這對于模型的訓練,特別是深層的模型訓練非常的困難。

所以在 2015 年一篇論文提出了這個方法,批標準化,簡而言之,就是對于每一層網絡的輸出,對其做一個歸一化,使其服從標準的正態分布,這樣后一層網絡的輸入也是一個標準的正態分布,所以能夠比較好的進行訓練,加快收斂速度。

batch normalization 的實現非常簡單,接下來寫一下對應的python代碼:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留維度進行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

測試的時候該使用批標準化嗎?

答案是肯定的,因為訓練的時候使用了,而測試的時候不使用肯定會導致結果出現偏差,但是測試的時候如果只有一個數據集,那么均值不就是這個值,方差為 0 嗎?這顯然是隨機的,所以測試的時候不能用測試的數據集去算均值和方差,而是用訓練的時候算出的移動平均均值和方差去代替

下面我們實現以下能夠區分訓練狀態和測試狀態的批標準化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留維度進行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我們在卷積網絡下試用一下批標準化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 數據預處理,標準化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新載入數據集,申明定義的數據變換
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批標準化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用隨機梯度下降,學習率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上這篇pytorch 圖像中的數據預處理和批標準化實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

毕节市| 夏邑县| 孟津县| 清新县| 崇州市| 临高县| 青铜峡市| 高要市| 曲周县| 兴仁县| 蓬莱市| 民丰县| 东海县| 揭西县| 辽阳县| 曲靖市| 女性| 扬中市| 黎城县| 乐平市| 锦屏县| 类乌齐县| 左云县| 台州市| 太湖县| 新源县| 东乡| 揭东县| 洛宁县| 建昌县| 三都| 年辖:市辖区| 应用必备| 红河县| 祁阳县| 琼海市| 洮南市| 安康市| 宁德市| 萨嘎县| 玉山县|