亚洲激情专区-91九色丨porny丨老师-久久久久久久女国产乱让韩-国产精品午夜小视频观看

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch 利用lstm做mnist手寫數字識別分類的實例

發布時間:2020-08-29 04:15:00 來源:腳本之家 閱讀:496 作者:xckkcxxck 欄目:開發技術

代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數據的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定義數據
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定義模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結果
     
   def forward(self, x):
     #x的大小為(batch,1,28,28),所以我們需要將其轉化為rnn的輸入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28)
     x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28)
     out, _ = self.rnn(x) #使用默認的隱藏狀態,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分類結果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定義訓練過程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上這篇pytorch 利用lstm做mnist手寫數字識別分類的實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

昆明市| 当雄县| 余干县| 枣强县| 克什克腾旗| 武穴市| 峡江县| 偏关县| 巧家县| 辽中县| 周至县| 沛县| 玛沁县| 凌云县| 泽普县| 肃宁县| 钟山县| 砚山县| 新乐市| 柏乡县| 喀喇沁旗| 罗源县| 平阳县| 漯河市| 兰考县| 芦溪县| 桐乡市| 镇原县| 林周县| 武夷山市| 柘荣县| 南开区| 扎鲁特旗| 麦盖提县| 邵阳市| 新安县| 屏东市| 融水| 上林县| 武乡县| 南溪县|