亚洲激情专区-91九色丨porny丨老师-久久久久久久女国产乱让韩-国产精品午夜小视频观看

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

將自己的數據集制作成TFRecord格式教程

發布時間:2020-09-07 11:21:37 來源:腳本之家 閱讀:793 作者:v1_vivian 欄目:開發技術

在使用TensorFlow訓練神經網絡時,首先面臨的問題是:網絡的輸入

此篇文章,教大家將自己的數據集制作成TFRecord格式,feed進網絡,除了TFRecord格式,TensorFlow也支持其他格

式的數據,此處就不再介紹了。建議大家使用TFRecord格式,在后面可以通過api進行多線程的讀取文件隊列。

1. 原本的數據集

此時,我有兩類圖片,分別是xiansu100,xiansu60,每一類中有10張圖片。

將自己的數據集制作成TFRecord格式教程

2.制作成TFRecord格式

tfrecord會根據你選擇輸入文件的類,自動給每一類打上同樣的標簽。如在本例中,只有0,1 兩類,想知道文件夾名與label關系的,可以自己保存起來。

#生成整數型的屬性
def _int64_feature(value):
 return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串類型的屬性
def _bytes_feature(value):
 return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
 class_map = {}
 data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
 classes = {'xiansu60','xiansu100'}
 #輸出TFRecord文件的地址
 
 writer = tf.python_io.TFRecordWriter(filename)
 
 for index,name in enumerate(classes):
  class_path=data_dir+name+'/'
  class_map[index] = name
  for img_name in os.listdir(class_path):
   img_path = class_path + img_name #每個圖片的地址
   img = Image.open(img_path)
   img= img.resize((224,224))
   img_raw = img.tobytes()   #將圖片轉化成二進制格式
   example = tf.train.Example(features = tf.train.Features(feature = {
    'label':_int64_feature(index),
    'image_raw': _bytes_feature(img_raw)
   }))
   writer.write(example.SerializeToString())
 writer.close()
 
 txtfile = open(mapfile,'w+')
 for key in class_map.keys():
  txtfile.writelines(str(key)+":"+class_map[key]+"\n")
 txtfile.close()

此段代碼,運行完后會產生生成的.tfrecord文件。

3. 讀取TFRecord的數據,進行解析,此時使用了文件隊列以及多線程

#讀取train.tfrecord中的數據
def read_and_decode(filename): 
 #創建一個reader來讀取TFRecord文件中的樣例
 reader = tf.TFRecordReader()
 #創建一個隊列來維護輸入文件列表
 filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
 #從文件中讀出一個樣例,也可以使用read_up_to一次讀取多個樣例
 _,serialized_example = reader.read(filename_queue)
#  print _,serialized_example
 
 #解析讀入的一個樣例,如果需要解析多個,可以用parse_example
 features = tf.parse_single_example(
 serialized_example,
 features = {'label':tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),})
 #將字符串解析成圖像對應的像素數組
 img = tf.decode_raw(features['image_raw'], tf.uint8)
 img = tf.reshape(img,[224, 224, 3]) #reshape為128*128*3通道圖片
 img = tf.image.per_image_standardization(img)
 labels = tf.cast(features['label'], tf.int32)
 return img, labels

4. 將圖片幾個一打包,形成batch

def createBatch(filename,batchsize):
 images,labels = read_and_decode(filename)
 
 min_after_dequeue = 10
 capacity = min_after_dequeue + 3 * batchsize
 
 image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
              batch_size=batchsize, 
              capacity=capacity, 
              min_after_dequeue=min_after_dequeue
              )
 
 label_batch = tf.one_hot(label_batch,depth=2)
 return image_batch, label_batch

5.主函數

if __name__ =="__main__":
 #訓練圖片兩張為一個batch,進行訓練,測試圖片一起進行測試
 mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
 train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
#  createTFRecord(train_filename,mapfile)
 test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
#  createTFRecord(test_filename,mapfile)
 image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
 test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
 with tf.Session() as sess:
  initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(initop)
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
  try:
   step = 0
   while 1:
    _image_batch,_label_batch = sess.run([image_batch,label_batch])
    step += 1
    print step
    print (_label_batch)
  except tf.errors.OutOfRangeError:
   print (" trainData done!")
   
  try:
   step = 0
   while 1:
    _test_images,_test_labels = sess.run([test_images,test_labels])
    step += 1
    print step
 #     print _image_batch.shape
    print (_test_labels)
  except tf.errors.OutOfRangeError:
   print (" TEST done!")
  coord.request_stop()
  coord.join(threads)

此時,生成的batch,就可以feed進網絡了。

以上這篇將自己的數據集制作成TFRecord格式教程就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

德江县| 凉城县| 建宁县| 鹿泉市| 太和县| 乌鲁木齐县| 前郭尔| 阜康市| 长岛县| 连山| 桐庐县| 乌鲁木齐县| 泽库县| 贞丰县| 佳木斯市| 丹寨县| 舒兰市| 滨州市| 合水县| 和政县| 新河县| 牡丹江市| 阆中市| 堆龙德庆县| 靖宇县| 通州市| 通州区| 新平| 萨迦县| 临沧市| 永定县| 武陟县| 武清区| 江达县| 东平县| 定南县| 扶余县| 灌云县| 那曲县| 淮北市| 拉萨市|