您好,登錄后才能下訂單哦!
tensorflow由于其基于靜態圖的模式,導致寫代碼的時候很難調試,除了用官方的調試工具外,最直接的方法就是把中間結果輸出出來查看,然而,直接用print函數只能輸出tensor變量的形狀,而不是數值,想要輸出tensor的具體數值需要用tf.Print函數。網上有很多關于這個函數使用方法的說明,這里簡要介紹:
Print( input_, data, message=None, first_n=None, summarize=None, name=None )
參數:
input_:通過這個操作的張量。 (流入的數據流)
data:計算 op 時要打印的張量列表。(用[ ]引起來的一串需要打印的東西,用逗號隔開)
message:一個字符串,錯誤消息的前綴。
first_n:只記錄 first_n 次數。負數日志,這是默認的。
summarize:只打印每個張量的固定數目的條目。如果沒有,則每個輸入張量最多打印3個元素。
name:操作的名稱(可選)
然而網上大部分資源都是介紹如何在主函數中先建立一個op,再開啟一個Session執行sess.run(op)的方法,但是如果想要輸出函數中的中間值而該值又未傳回主函數呢?這種情況下無法在函數中開啟一個新的Session,但是仍然可以用tf.Print建立op來實現。
import tensorflow as tf import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" def test(): a=tf.constant(0) for i in range(10): a_print = tf.Print(a,['a_value: ',a]) a=a_print+1 return a if __name__=='__main__': with tf.Session() as sess: sess.run(test())
運行結果:
a_print可以理解為在圖中新增了一個節點,在后續代碼中當有別的變量使用了a_print時(如上例a=a_print+1),就會有數據從a_print節點上流過,就會輸出值,而究竟會輸出幾次值呢?這其實并不是看下文中a_print被使用了幾次,而是看數據流要從該節點上流經幾次,可以理解為a_print這個op被“定義”了幾次。
def test(): a=tf.constant(0) a_print = tf.Print(a,['a_value: ',a]) for i in range(10): a=a_print+1 return a if __name__=='__main__': with tf.Session() as sess: sess.run(test())
如果把test()函數改成這樣,則運行結果為:
輸出僅被執行了一次,因為a_print這個op只被定義了一次,雖然后面在循環里不斷被a使用,但是數據只從它身上經過了一次,所以只會print一次,并且a_print的值永遠為0,最終返回的a的值也為1。
再把代碼改成下例:
def test(): a=tf.constant(0) a_print = tf.Print(a,['a_value: ',a]) for i in range(10): a_print=a_print+1 return a if __name__=='__main__': with tf.Session() as sess: sess.run(test())
運行結果是什么也不會輸出,因為a_print這個op沒有和別的變量發生關系,它沒有被別的變量使用,在圖里為孤立的一個節點,沒有數據流過,就不會被執行。
而如果改成這樣
def test(): a=tf.constant(0) a_print = tf.Print(a,['a_value: ',a]) for i in range(10): a_print=a_print+1 return a_print if __name__=='__main__': with tf.Session() as sess: sess.run(test())
運行結果
返回的a_print值為10也是正確的,因為a_print在下文被返回,所以有數據流流經,會被執行,而因為a_print的定義只執行一次,所以只會輸出一次。
以上這篇tensorflow實現在函數中用tf.Print輸出中間值就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。