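I'm trying to train an Inception model to recognize 5 emotions in images (anger, happiness, sadness, fear, and neutral) with Python in Google Colab, but I'm quite confused by some of the results I get from training.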

If I use a small batch_size such as 1 or 2, the accuracy is poor, but every label in the confusion matrix has non-zero entries.

If I use a batch_size of 16 or 32 instead, I get higher accuracy, but some labels in the confusion matrix are filled with 0 values.

So, can anyone help me figure out which result is more "correct"?
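A toy illustration of why the higher-accuracy run can still be the worse model: on an imbalanced test set, a classifier that always predicts the majority class scores well on accuracy while every other column of the confusion matrix stays at 0. The class counts below are made up, and the matrix uses the same layout as sklearn's `confusion_matrix` (rows are true labels, columns are predicted labels).

```python
import numpy as np

def confusion(y_true, y_pred, n_classes):
    """Rows = true label, columns = predicted label (same layout as sklearn)."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count every (true, predicted) pair
    return cm

# Hypothetical imbalanced test set: 80 images of class 0, 5 of each other class
y_true = np.array([0] * 80 + [1] * 5 + [2] * 5 + [3] * 5 + [4] * 5)
# A collapsed model that predicts class 0 for every image
y_pred = np.zeros_like(y_true)

cm = confusion(y_true, y_pred, n_classes=5)
accuracy = np.trace(cm) / cm.sum()               # fraction of samples on the diagonal
per_class_recall = np.diag(cm) / cm.sum(axis=1)  # correct / true count, per row
macro_recall = per_class_recall.mean()           # unweighted mean over classes

print(cm)
print(f"accuracy={accuracy:.2f}, macro recall={macro_recall:.2f}")
```

Accuracy comes out at 0.80 while macro-averaged recall (the `recall` column of the `macro avg` row in `classification_report`) is only 0.20, so comparing runs by per-class or macro metrics is more informative than accuracy alone when the classes are unbalanced.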

Oh, and here is the code I'm currently using:

import gc
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras import mixed_precision
import tensorflow as tf

# Cap TensorFlow at 1024 MB on the first GPU (must run before the GPU is initialized)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.set_logical_device_configuration(
            gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised if the GPU has already been initialized
        print(e)

# Enable mixed precision training
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

tf.keras.backend.clear_session()

# Load the images from one sub-folder per emotion; the folder's position in
# emotion_labels becomes its integer class label
def load_images(folder_path, emotion_labels, size=(128, 128)):
    images = []
    labels = []
    for label_idx, emotion_label in enumerate(emotion_labels):
        emotion_dir = os.path.join(folder_path, emotion_label)
        for image_file in os.listdir(emotion_dir):
            image_path = os.path.join(emotion_dir, image_file)
            try:
                img = Image.open(image_path).convert('RGB')
                img = img.resize(size)  # InceptionV3 needs a fixed input size (at least 75x75)
                images.append(np.array(img))
                labels.append(label_idx)
            except Exception as e:
                print(f"Error loading image: {image_path}\n{e}")
    return np.array(images), np.array(labels)

# Load and split the dataset into training and testing sets
folder_path = '/content/drive/MyDrive/QuanzengYou'
emotion_labels = ['Happy', 'Sad', 'Fear', 'Anger', 'Neutral']

X, y = load_images(folder_path, emotion_labels)
# Check how balanced the classes are before training
print("Images per class:", dict(zip(emotion_labels, np.bincount(y, minlength=len(emotion_labels)).tolist())))
# stratify=y keeps the class proportions the same in the training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

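# Scale pixels to [-1, 1], the input range the ImageNet InceptionV3 weights expect
X_train = tf.keras.applications.inception_v3.preprocess_input(X_train.astype('float32'))
X_test = tf.keras.applications.inception_v3.preprocess_input(X_test.astype('float32'))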

# Load pre-trained InceptionV3 model without the top classification layer
base_model = InceptionV3(weights='imagenet', include_top=False, input_shape=(128, 128, 3))

# Add new classification layers on top of the base model
x = base_model.output
x = GlobalAveragePooling2D()(x)
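x = Dense(64, activation='relu')(x)
# Keep the softmax output in float32; under mixed_float16 a float16 softmax can be numerically unstable
predictions = Dense(5, activation='softmax', dtype='float32')(x)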

# Create the final model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile and train the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Augment training data for better generalization
datagen = ImageDataGenerator(rotation_range=20, width_shift_range=0.2, height_shift_range=0.2,
                             horizontal_flip=True, vertical_flip=True)

batch_size = 16
epochs = 10

# Train the model on augmented batches; datagen.flow reshuffles the training
# set every epoch and also yields the final, smaller batch
with tf.device('/device:GPU:0'):
    model.fit(datagen.flow(X_train, y_train, batch_size=batch_size, shuffle=True),
              epochs=epochs)
gc.collect()

# Predict on the test set in batches (no augmentation at test time)
y_pred = model.predict(X_test, batch_size=batch_size)
y_pred_classes = np.argmax(y_pred, axis=1)

# Plot the confusion matrix; labels=class_ids keeps it 5x5 (matching the tick
# labels) even if a class is absent from both y_test and the predictions
class_ids = list(range(len(emotion_labels)))
cm = confusion_matrix(y_test, y_pred_classes, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, xticklabels=emotion_labels, yticklabels=emotion_labels)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

# Print per-class precision, recall and F1
print("Classification Report:")
print(classification_report(y_test, y_pred_classes, labels=class_ids,
                            target_names=emotion_labels, zero_division=0))
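If you keep a manual `train_on_batch` loop instead of `model.fit`, the batches should be drawn in a fresh random order every epoch, and the last, smaller batch should not be dropped. The NumPy-only sketch below shows just that indexing; the generator name `iterate_minibatches` and the seed are illustrative assumptions, not part of the original code.

```python
import numpy as np

def iterate_minibatches(n_samples, batch_size, rng):
    """Yield index arrays that cover every sample once, in a new random order."""
    order = rng.permutation(n_samples)  # fresh shuffle for this epoch
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]  # the last batch may be smaller

rng = np.random.default_rng(0)
n_samples, batch_size = 10, 4
epoch_batches = [idx.tolist() for idx in iterate_minibatches(n_samples, batch_size, rng)]
print([len(b) for b in epoch_batches])  # [4, 4, 2]

# In the training loop this would replace the fixed slicing, e.g.:
# for idx in iterate_minibatches(len(X_train), batch_size, rng):
#     model.train_on_batch(X_train[idx], y_train[idx])
```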


Recommended answer

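Thanks to user Wakeme UpNow: the 0 values in the confusion matrix did turn out to be caused by class imbalance in the data, and limiting the amount of data the code loads seems to have solved the problem. Thanks!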
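Limiting the loaded data so the classes end up closer in size amounts to undersampling. Below is a minimal NumPy sketch of capping every class at the size of the smallest one, together with the alternative of keeping all the data and passing balanced class weights to Keras via `model.fit(..., class_weight=...)`. The helper names and the toy label counts are hypothetical; the weight formula `n_samples / (n_classes * class_count)` is the one scikit-learn uses for `class_weight='balanced'`.

```python
import numpy as np

def undersample_indices(y, rng, cap=None):
    """Return shuffled indices keeping at most `cap` samples per class
    (default: the size of the smallest class)."""
    classes, counts = np.unique(y, return_counts=True)
    cap = counts.min() if cap is None else cap
    keep = [rng.permutation(np.flatnonzero(y == c))[:cap] for c in classes]
    return rng.permutation(np.concatenate(keep))

def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * class_count)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Hypothetical label counts: 60 / 20 / 10 / 5 / 5 for the five emotions
y = np.array([0] * 60 + [1] * 20 + [2] * 10 + [3] * 5 + [4] * 5)
rng = np.random.default_rng(42)

idx = undersample_indices(y, rng)
print(np.bincount(y[idx]))        # [5 5 5 5 5]
print(balanced_class_weights(y))  # rarer classes get larger weights
```

With the question's code, `idx` would be applied as `X, y = X[idx], y[idx]` before `train_test_split`; alternatively, keep all images and pass `class_weight=balanced_class_weights(y_train)` to `model.fit`. Undersampling throws data away, while class weights keep every image but make mistakes on the rare classes cost more.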
