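I followed a YouTube tutorial that showed me how to classify two datasets (cough, not cough). Now I need to add an extra class, sneeze, so there are 3 classes to train (cough, sneeze, other), and I don't know how to do this. Please help!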

In the code below, the model is trained on 2 classes (cough, not_cough) and performs quite well, but I can't get it to work with multiple classes (cough, sneeze, other).

import os
from matplotlib import pyplot as plt
import tensorflow as tf 
import tensorflow_io as tfio
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPool2D, Dropout,TimeDistributed, Reshape
from tensorflow.keras.optimizers.legacy import Adam
from keras import layers
from keras.utils import to_categorical

def load_wav_16k_mono(filename):
    # Load encoded wav file
    file_contents = tf.io.read_file(filename)
    # Decode wav (tensors by channels) 
    wav, sample_rate = tf.audio.decode_wav(file_contents, desired_channels=1)
    # Removes trailing axis
    wav = tf.squeeze(wav, axis=-1)
    sample_rate = tf.cast(sample_rate, dtype=tf.int64)
    # Resample from the file's native rate (e.g. 44100 Hz) to 16000 Hz
    wav = tfio.audio.resample(wav, rate_in=sample_rate, rate_out=16000)
    return wav

def preprocess(file_path, label): 
    wav = load_wav_16k_mono(file_path)
    wav = wav[:8000]
    zero_padding = tf.zeros([8000] - tf.shape(wav), dtype=tf.float32)
    wav = tf.concat([zero_padding, wav],0)
    
    spectrogram = tf.signal.stft(wav, frame_length=100, frame_step=20)
    spectrogram = tf.abs(spectrogram)
    spectrogram = tf.expand_dims(spectrogram, axis=2)
    return spectrogram, label


def get_CNN(input_shape):
    model = Sequential()
    model.add(Conv2D(16, (3,3), activation='relu', input_shape=input_shape))
    model.add(Conv2D(16, (3,3), activation='relu'))
    model.add(MaxPool2D((2,2)))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))  # sigmoid, not softmax: a softmax over a single unit always outputs 1.0
    
    model.compile('Adam', loss='BinaryCrossentropy', metrics=[tf.keras.metrics.Recall(),tf.keras.metrics.Precision(),'accuracy'])
    model.summary() # drop in some max pool layers to reduce params
    return model
    

def main():
    POS_COUGH = "./data/cough"
    NEG_COUGH = "./data/not_cough"
  
    #POS_SPEECH = "./data/speech"

    pos_cough = tf.data.Dataset.list_files(os.path.join(POS_COUGH, '*.wav'))  # portable glob instead of a Windows-only '\*.wav'
    neg_cough = tf.data.Dataset.list_files(os.path.join(NEG_COUGH, '*.wav'))
    
    #pos_speech = tf.data.Dataset.list_files(POS_SPEECH +'\*.wav')

    cough_labels = tf.data.Dataset.from_tensor_slices(tf.ones(len(pos_cough)))  
    
    not_cough_labels = tf.data.Dataset.from_tensor_slices(tf.zeros(len(neg_cough)))  # negatives must be labeled 0, not 1
    
    # Add labels and Combine Positive and Negative Samples
    cough = tf.data.Dataset.zip((pos_cough, cough_labels))
    
    not_cough = tf.data.Dataset.zip((neg_cough, not_cough_labels))
   
    negatives = not_cough
    positives = cough
    # Join positive and negative samples
    data = positives.concatenate(negatives)

    ### 2. Create a Tensorflow Data Pipeline
    data = data.map(preprocess)
    data = data.cache()
    data = data.shuffle(buffer_size=1000)
    data = data.batch(16)
    data = data.prefetch(8)
    
    ## 3. Split data into train and test data
    train = data.take(int(len(data) * 0.7))
    test = data.skip(int(len(data) * 0.7)).take(int(len(data) - len(data) * 0.7))   #test.as_numpy_iterator().next()

    input_shape_spectrogram = (396, 65,1)
    model = get_CNN(input_shape_spectrogram)
    hist = model.fit(train, epochs=2, validation_data=test)


if __name__ == "__main__":
    main()

Recommended answer

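First, your dataset needs 3 classes, which means you have to label the sneeze samples separately, just as you already did for cough/not cough. Then you need to convert the labels into one-hot encoded vectors, in which every element is zero except the one at the class index. For example, with not_cough = 0, cough = 1 and sneeze = 2, a sneeze sample must be [0,0,1], a cough sample [0,1,0] and a not-cough sample [1,0,0]. Since the labels are now one-hot vectors over 3 classes, the loss should also change from BinaryCrossentropy to categorical cross-entropy. Finally, the output layer should have 3 neurons: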

model.add(Dense(3, activation='softmax'))
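A minimal sketch of the labeling and model changes described above, assuming the sneeze clips sit in a hypothetical `./data/sneeze` folder next to the existing ones and that `preprocess` from the question is reused unchanged (it passes the label through untouched). `CLASS_DIRS`, `label_files`, `build_dataset` and `get_multiclass_cnn` are illustrative names, not part of TensorFlow; the class order follows the answer's not_cough = 0, cough = 1, sneeze = 2.

```python
import os

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPool2D
from tensorflow.keras.models import Sequential

# Hypothetical layout: one folder per class. The list position is the class
# index, so not_cough = 0, cough = 1, sneeze = 2.
CLASS_DIRS = ["./data/not_cough", "./data/cough", "./data/sneeze"]


def label_files(files, class_index, num_classes=len(CLASS_DIRS)):
    """Pair each file path in `files` with the one-hot label of `class_index`."""
    one_hot = tf.one_hot(class_index, depth=num_classes)  # e.g. 2 -> [0., 0., 1.]
    return files.map(lambda path: (path, one_hot))


def build_dataset(class_dirs=CLASS_DIRS):
    """Concatenate the (path, one-hot label) pairs of every class folder."""
    data = None
    for index, folder in enumerate(class_dirs):
        files = tf.data.Dataset.list_files(os.path.join(folder, "*.wav"), shuffle=False)
        labeled = label_files(files, index, len(class_dirs))
        data = labeled if data is None else data.concatenate(labeled)
    return data


def get_multiclass_cnn(input_shape, num_classes=len(CLASS_DIRS)):
    """Same layers as get_CNN in the question, with a num_classes-way softmax head."""
    model = Sequential()
    model.add(Conv2D(16, (3, 3), activation="relu", input_shape=input_shape))
    model.add(Conv2D(16, (3, 3), activation="relu"))
    model.add(MaxPool2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation="relu"))
    # One output unit per class; softmax makes the 3 outputs sum to 1
    model.add(Dense(num_classes, activation="softmax"))
    # One-hot labels call for categorical (not binary) cross-entropy
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model
```

With these helpers, `data = build_dataset().map(preprocess)` would replace the manual zip/concatenate steps in `main`, followed by the same cache/shuffle/batch/prefetch and train/test split, and `model = get_multiclass_cnn((396, 65, 1))`. Because the classes are concatenated in order, a shuffle buffer at least as large as the dataset keeps every class represented in both splits. If you prefer to keep plain integer labels 0/1/2, skip `tf.one_hot` and compile with `loss="sparse_categorical_crossentropy"` instead.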
