keras回调函数的使用
回调函数
- 回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用
- 可以访问关于模型状态与模型性能的所有可用数据
- 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。
- 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。
- 在训练过程中动态调节某些参数值:比如调节优化器的学习率。
- 在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。
fit()方法中使用callbacks参数
# 这里有两个callback函数:早停和模型检查点 callbacks_list=[ keras.callbacks.EarlyStopping( monitor="val_accuracy",#监控指标 patience=2 #两轮内不再改善中断训练 ), keras.callbacks.ModelCheckpoint( filepath="checkpoint_path", monitor="val_loss", save_best_only=True ) ] #模型获取 model=get_minist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) model.fit(train_images,train_labels, epochs=10,callbacks=callbacks_list, #该参数使用回调函数 validation_data=(val_images,val_labels)) test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标 predictions=model.predict(test_images)#计算模型在新数据上的分类概率
模型的保存和加载
#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。 #重新加载模型 model_new=keras.models.load_model("checkpoint_path.keras")
通过对Callback类子类化来创建自定义回调函数
on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用
from matplotlib import pyplot as plt # 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图 class LossHistory(keras.callbacks.Callback): def on_train_begin(self, logs): self.per_batch_losses = [] def on_batch_end(self, batch, logs): self.per_batch_losses.append(logs.get("loss")) def on_epoch_end(self, epoch, logs): plt.clf() plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses, label="Training loss for each batch") plt.xlabel(f"Batch (epoch {epoch})") plt.ylabel("Loss") plt.legend() plt.savefig(f"plot_at_epoch_{epoch}") self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model() model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) model.fit(train_images, train_labels, epochs=10, callbacks=[LossHistory()], validation_data=(val_images, val_labels))
【其他】模型的定义 和 数据加载
def get_minist_model(): inputs=keras.Input(shape=(28*28,)) features=layers.Dense(512,activation="relu")(inputs) features=layers.Dropout(0.5)(features) outputs=layers.Dense(10,activation="softmax")(features) model=keras.Model(inputs,outputs) return model #datset from tensorflow.keras.datasets import mnist (train_images,train_labels),(test_images,test_labels)=mnist.load_data() train_images=train_images.reshape((60000,28*28)).astype("float32")/255 test_images=test_images.reshape((10000,28*28)).astype("float32")/255 train_images,val_images=train_images[10000:],train_images[:10000] train_labels,val_labels=train_labels[10000:],train_labels[:10000]
到此这篇关于keras回调函数的使用的文章就介绍到这了,更多相关keras回调函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
Python中staticmethod和classmethod的作用与区别
今天小编就为大家分享一篇关于Python中staticmethod和classmethod的作用与区别,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧2018-10-10使用django-guardian实现django-admin的行级权限控制的方法
这篇文章主要介绍了使用django-guardian实现django-admin的行级权限控制的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧2018-10-10
最新评论