我正在try 开发一个小功能,它可以动态绘制TensorFlow模型训练过程中的损失或精度.我基本上绘制了每个时期的每个批处理结束时的精度历史(代码仍然需要一些修正,但目前它可以正常工作).

我有一个小问题,因为我在jupyter笔记本电脑单元中运行以下代码.我有想要的行为,有一个动态发展的情节.然而,在训练结束时,由于某种原因,最终的情节被重复了,我不知道为什么会这样.

from IPython.display import display, clear_output
import tensorflow as tf
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt


class CustomCallback(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.epoch = 0  # Initialize the epoch counter
        self.accuracies = []
        self.fig, self.ax = plt.subplots()
        self.line, = self.ax.plot([], [])
        self.ax.set_xlim(0, 30)
        self.ax.set_ylim(0, 1)
        self.displayed = False
        display(self.fig)
    
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch  # Update the current epoch at the beginning of each epoch

    def on_train_batch_end(self, batch, logs=None):
        accuracy = logs['accuracy']
        self.accuracies.append(accuracy)
        self.line.set_data(range(1, len(self.accuracies) + 1), self.accuracies)
        self.ax.relim()
        self.ax.autoscale_view()
        clear_output(wait=True)
        display(self.fig)


custom_callback = CustomCallback()

model = Sequential()
model.add(tf.keras.layers.Dense(units=16, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.35))
model.add(tf.keras.layers.Dense(units=1, activation='tanh'))

model.compile(optimizer=tf.keras.optimizers.Adam(), loss="binary_crossentropy", metrics=["accuracy"])

X = np.random.randn(10**2, 10**4)
y = np.random.randint(2, size=10**2)

abc = model.fit(X, y, epochs=7, batch_size=32, validation_split=0.025, verbose=False, callbacks=[custom_callback])

推荐答案

这是因为Jupyter笔记本已经内联显示了一个数字,所以拨打display()就是在复制它.例如,下面的代码在jupyter笔记本中显示了两次相同的线条图.

import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(range(3))
display(fig)

要关闭交互模式,请在matplotlib导入后立即调用plt.ioff().或者,您也可以在培训结束时通过在课程中加入以下方法来结束该图.

    def on_train_end(self, logs=None):
        plt.close(self.fig)

Python相关问答推荐

我对打乒乓球有问题

使用Curses for Python保存和恢复终端窗口内容

将嵌套列表的字典转换为数据框中的行

如何在Python中增量更新DF

"Discord机器人中缺少所需的位置参数ctx

opencv Python稳定的图标识别

Python中的负前瞻性regex遇到麻烦

Python 3.12中的通用[T]类方法隐式类型检索

如何计算两极打印机中 * 所有列 * 的出现次数?

如何使用Python将工作表从一个Excel工作簿复制粘贴到另一个工作簿?

从收件箱中的列中删除html格式

类型错误:输入类型不支持ufuncisnan-在执行Mann-Whitney U测试时[SOLVED]

追溯(最近最后一次调用):文件C:\Users\Diplom/PycharmProject\Yolo01\Roboflow-4.py,第4行,在模块导入roboflow中

将pandas Dataframe转换为3D numpy矩阵

如何获取numpy数组的特定索引值?

对象的`__call__`方法的setattr在Python中不起作用'

利用Selenium和Beautiful Soup实现Web抓取JavaScript表

下三角形掩码与seaborn clustermap bug

(Python/Pandas)基于列中非缺失值的子集DataFrame

在代码执行后关闭ChromeDriver窗口