"Si un trabajador quiere hacer bien su trabajo, primero debe afilar sus herramientas." - Confucio, "Las Analectas de Confucio. Lu Linggong"
Página delantera > Programación > Predicción musical de Tensorflow

Predicción musical de Tensorflow

Publicado el 2024-11-08
Navegar:764

Tensorflow music prediction

En este artículo, muestro cómo usar tensorflow para predecir un estilo de música.
En mi ejemplo, comparo la música techno y clásica.

Puedes encontrar el código en mi github:
https://github.com/victordalet/sound_to_partition


I - Conjunto de datos

Para el primer paso, necesitas crear un carpeta de conjunto de datos y dentro agregar una carpeta para el estilo de música, por ejemplo, agrego una carpeta techno y una carpeta clásica en la que pongo mi sonido wav.

II - Tren

Creo un archivo de tren, con los argumentos max_epochs para completar.

Modifica las clases en el constructor que corresponden a tu directorio en la carpeta del conjunto de datos.

En el método de carga y procesamiento, recupero el archivo wav de un directorio diferente y obtengo el espectrograma.

Para fines de capacitación, utilizo las convoluciones y el modelo de Keras.

import os
import sys
from typing import List

import librosa
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.image import resize



class Train:

    def __init__(self):
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.data_dir: str = 'dataset'
        self.classes: List[str] = ['techno','classic']
        self.max_epochs: int = int(sys.argv[1])

    @staticmethod
    def load_and_preprocess_data(data_dir, classes, target_shape=(128, 128)):
        data = []
        labels = []

        for i, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            for filename in os.listdir(class_dir):
                if filename.endswith('.wav'):
                    file_path = os.path.join(class_dir, filename)
                    audio_data, sample_rate = librosa.load(file_path, sr=None)
                    mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
                    mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), target_shape)
                    data.append(mel_spectrogram)
                    labels.append(i)

        return np.array(data), np.array(labels)

    def create_model(self):
        data, labels = self.load_and_preprocess_data(self.data_dir, self.classes)
        labels = to_categorical(labels, num_classes=len(self.classes))  # Convert labels to one-hot encoding
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(data, labels, test_size=0.2,
                                                                                random_state=42)

        input_shape = self.X_train[0].shape
        input_layer = Input(shape=input_shape)
        x = Conv2D(32, (3, 3), activation='relu')(input_layer)
        x = MaxPooling2D((2, 2))(x)
        x = Conv2D(64, (3, 3), activation='relu')(x)
        x = MaxPooling2D((2, 2))(x)
        x = Flatten()(x)
        x = Dense(64, activation='relu')(x)
        output_layer = Dense(len(self.classes), activation='softmax')(x)
        self.model = Model(input_layer, output_layer)

        self.model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

    def train_model(self):
        self.model.fit(self.X_train, self.y_train, epochs=self.max_epochs, batch_size=32,
                       validation_data=(self.X_test, self.y_test))
        test_accuracy = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        print(test_accuracy[1])

    def save_model(self):
        self.model.save('weight.h5')


if __name__ == '__main__':
    train = Train()
    train.create_model()
    train.train_model()
    train.save_model()

III - Prueba

Para probar y usar el modelo, creé esta clase para recuperar el peso y predecir el estilo de la música.

No olvides agregar las clases correctas al constructor.

from typing import List

import librosa
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.image import resize
import tensorflow as tf



class Test:

    def __init__(self, audio_file_path: str):
        self.model = load_model('weight.h5')
        self.target_shape = (128, 128)
        self.classes: List[str] = ['techno','classic']
        self.audio_file_path: str = audio_file_path

    def test_audio(self, file_path, model):
        audio_data, sample_rate = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
        mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), self.target_shape)
        mel_spectrogram = tf.reshape(mel_spectrogram, (1,)   self.target_shape   (1,))

        predictions = model.predict(mel_spectrogram)

        class_probabilities = predictions[0]

        predicted_class_index = np.argmax(class_probabilities)

        return class_probabilities, predicted_class_index

    def test(self):
        class_probabilities, predicted_class_index = self.test_audio(self.audio_file_path, self.model)

        for i, class_label in enumerate(self.classes):
            probability = class_probabilities[i]
            print(f'Class: {class_label}, Probability: {probability:.4f}')

        predicted_class = self.classes[predicted_class_index]
        accuracy = class_probabilities[predicted_class_index]
        print(f'The audio is classified as: {predicted_class}')
        print(f'Accuracy: {accuracy:.4f}')
Declaración de liberación Este artículo se reproduce en: https://dev.to/victicordalet/tensorflow-music-predicción-4i6f?1 Si hay alguna infracción, comuníquese con [email protected] para eliminarlo.
Último tutorial Más>

Descargo de responsabilidad: Todos los recursos proporcionados provienen en parte de Internet. Si existe alguna infracción de sus derechos de autor u otros derechos e intereses, explique los motivos detallados y proporcione pruebas de los derechos de autor o derechos e intereses y luego envíelos al correo electrónico: [email protected]. Lo manejaremos por usted lo antes posible.

Copyright© 2022 湘ICP备2022001581号-3