在Keras中,回調函數是一種在訓練過程中自定義的操作,可以在每個訓練周期的不同階段執行。回調函數可以用于監控模型的性能、保存模型、調整學習率等。以下是如何在Keras中使用回調函數的步驟:
from keras.callbacks import EarlyStopping, ModelCheckpoint
callbacks = [EarlyStopping(monitor='val_loss', patience=5),
ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]
model.fit(x_train, y_train, validation_data=(x_val, y_val), callbacks=callbacks)
在上面的例子中,我們添加了兩個回調函數:一個是EarlyStopping,用于在驗證集上的損失不再減小時停止訓練;另一個是ModelCheckpoint,用于保存在驗證集上表現最好的模型。
from keras.callbacks import Callback
class CustomCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
print('End of epoch:', epoch)
print('Training loss:', logs.get('loss'))
print('Validation loss:', logs.get('val_loss'))
callbacks = [CustomCallback()]
model.fit(x_train, y_train, validation_data=(x_val, y_val), callbacks=callbacks)
在上面的例子中,我們定義了一個自定義的回調函數CustomCallback,用于在每個訓練周期結束時輸出訓練損失和驗證損失。
通過以上步驟,您可以很容易地在Keras中使用回調函數來監控和控制模型的訓練過程。