Stroke classification


Percussion stroke classification has historically been the principal timbre-related task in the computational analysis of Indian Art Music. As seen in the instrumentation presentation, the musical arrangement in Indian Art Music is quite well-defined, while good-quality data (especially monophonic recordings) is scarce for many Carnatic- and Hindustani-specific instruments. These factors, combined with the importance of the different stroke types in the main percussion instruments, have focused research on mridangam [ABKM14, ABM13] and tabla [RBR21] stroke classification.

We will go through examples of these tasks in this walkthrough.

# Install compiam (if not already installed) and import it into the project
import importlib.util
if importlib.util.find_spec('compiam') is None:
    # Bear in mind this will only run in a Jupyter notebook / Colab session
    %pip install compiam
import compiam

# Import extras and suppress warnings to keep the tutorial clean
import os
import random
from pprint import pprint
import warnings
warnings.filterwarnings('ignore')

Mridangam stroke classification

from compiam.timbre.stroke_classification import MridangamStrokeClassification
msc = MridangamStrokeClassification()  # Let's use msc for simplicity

Let’s start by loading the mridangam stroke dataset. Since MridangamStrokeClassification is based on the Mridangam Stroke Dataset, compiam includes a specific function to load the dataset and integrate it into the pipeline.

msc.load_mridangam_dataset(data_home="../audio/mir_datasets/", download=True)

Note

This function does not return a dataloader. Instead, the dataloader lives within the tool class. We will see how this works in the following steps of this walkthrough. You may check out the MridangamStrokeClassification documentation to learn how the tool class takes advantage of the dataloader.

# Print list of available mridangam strokes in the dataset
msc.list_strokes()
['bheem', 'cha', 'dheem', 'dhin', 'num', 'ta', 'tha', 'tham', 'thi', 'thom']

Let’s train and evaluate a very basic model to classify mridangam strokes. We first use a utility function in the mirdata Dataset class to separate the Mridangam Stroke Dataset into splits. Since this dataset has no pre-determined splits, we use get_random_track_splits to create them randomly.

# Loading tracks for the mridangam dataset
mridangam_tracks = msc.dataset.load_tracks()

# Getting list of id per split
# NOTE: We use (0.9, 0.1): two splits including 90% and 10% of the whole dataset
split_dict = msc.dataset.get_random_track_splits(
    splits=(0.9, 0.1),
    split_names=("train", "validation")
)

# Get track dictionaries given the created splits
train_split = {x: mridangam_tracks[x] for x in split_dict["train"]}
evaluation_split = {y: mridangam_tracks[y] for y in split_dict["validation"]}
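To make the idea of a random split concrete, here is a minimal plain-Python sketch. It is not mirdata's implementation; the helper name `random_splits`, its parameters, and the toy IDs are all hypothetical and only illustrate shuffling a list of track IDs and cutting it into 90% and 10% chunks.

```python
import random

def random_splits(track_ids, fractions=(0.9, 0.1),
                  names=("train", "validation"), seed=42):
    """Hypothetical helper: shuffle IDs and cut them into consecutive chunks."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    ids = list(track_ids)
    # A dedicated Random instance keeps the shuffle reproducible
    random.Random(seed).shuffle(ids)
    splits, start = {}, 0
    for i, (name, fraction) in enumerate(zip(names, fractions)):
        # The last split takes the remainder so no ID is lost to rounding
        is_last = i == len(fractions) - 1
        end = len(ids) if is_last else start + round(fraction * len(ids))
        splits[name] = ids[start:end]
        start = end
    return splits

# Toy IDs standing in for the real track IDs
toy_ids = [str(n) for n in range(100)]
toy_splits = random_splits(toy_ids)
print(len(toy_splits["train"]), len(toy_splits["validation"]))
```

Fixing the seed makes the split reproducible, which helps when comparing models trained with different settings on the same partition.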

# Let's get a random track from the created evaluation split
random.choice(list(evaluation_split.items()))
('225104',
 Track(
   audio_path="../audio/mir_datasets/mridangam_stroke_1.5/B/225104__akshaylaya__thi-b-323.wav",
   stroke_name="thi",
   tonic="B",
   track_id="225104",
   audio: The track's audio
 
         Returns,
 ))
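Because the splits are drawn at random, the stroke classes may not be equally represented in each of them. The sketch below counts tracks per stroke in a split. It assumes only that each track exposes a `stroke_name` attribute, as in the output above; `ToyTrack`, `stroke_counts`, and the toy data are hypothetical stand-ins so the example runs without the dataset.

```python
from collections import Counter, namedtuple

# Stand-in for mirdata Track objects; only stroke_name is used here
ToyTrack = namedtuple("ToyTrack", ["track_id", "stroke_name"])

def stroke_counts(split):
    """Count how many tracks of each stroke a {track_id: track} split holds."""
    return Counter(track.stroke_name for track in split.values())

toy_split = {
    "1": ToyTrack("1", "thi"),
    "2": ToyTrack("2", "thi"),
    "3": ToyTrack("3", "num"),
}
counts = stroke_counts(toy_split)
print(counts.most_common())
```

With the real data, the same function could presumably be applied to `train_split` and `evaluation_split` to check that no stroke is missing from either.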

The class assumes that the entire dataset is used for training, so we need to replace the dataset stored in the class with the training split.

msc.mridangam_tracks = train_split
msc.mridangam_ids = list(train_split.keys())

Let’s now train the model! We will train a Support Vector Machine (SVM) model using scikit-learn. The mridangam stroke classification tool in compiam uses the MusicExtractor in Essentia to compute low-level features from the stroke recordings and feed them to the model.

Note

You can also train a different model and compare the performance. We offer other options (see the documentation of the tool), but feel free to open a Pull Request in compiam to add more models to the available options.

svm_accuracy = msc.train_model()

The model has been trained! The testing accuracy is also returned, in case we want to store it, re-train the model using different settings, and compare.

Now we can predict the stroke on a particular list of instances. First, we need to get the list of paths for the mirdata dataset split we generated a few steps earlier.

# Get paths from created evaluation split
eval_paths = [track.audio_path for track in evaluation_split.values()]

# Compute prediction from list of paths
prediction = msc.predict(eval_paths)
# Visualise and evaluate some predictions from the model output
pprint(random.choice(list(prediction.items())))
pprint(random.choice(list(prediction.items())))
pprint(random.choice(list(prediction.items())))
('../audio/mir_datasets/mridangam_stroke_1.5/D#/228925__akshaylaya__dhin-dsh-126.wav',
 'dhin')
('../audio/mir_datasets/mridangam_stroke_1.5/C#/226559__akshaylaya__dheem-csh-017.wav',
 'dheem')
('../audio/mir_datasets/mridangam_stroke_1.5/D/227731__akshaylaya__cha-d-045.wav',
 'cha')

In the file paths of these validation files we can already see the actual stroke present in each recording, so we can evaluate how well our model classified the mridangam strokes. Alternatively, we can get the ground-truth stroke using the mirdata loader and a particular track ID.

msc.dataset.choice_track()
Track(
  audio_path="../audio/mir_datasets/mridangam_stroke_1.5/C/226097__akshaylaya__thi-c-026.wav",
  stroke_name="thi",
  tonic="C",
  track_id="226097",
  audio: The track's audio

        Returns,
)
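As noted above, the file name itself carries the stroke label. Judging only from the examples printed in this walkthrough, the names seem to follow the pattern `<id>__<uploader>__<stroke>-<tonic code>-<index>.wav`; this pattern is an assumption inferred from those examples, and `parse_stroke_filename` is a hypothetical helper, not part of compiam or mirdata.

```python
import os

def parse_stroke_filename(path):
    """Hypothetical helper: split '<id>__<uploader>__<stroke>-<tonic>-<index>.wav'."""
    stem = os.path.splitext(os.path.basename(path))[0]
    # Assumed layout: three fields separated by double underscores
    track_id, uploader, label = stem.split("__")
    # Assumed layout of the last field: stroke, tonic code, running index
    stroke, tonic_code, index = label.split("-")
    return {
        "track_id": track_id,
        "uploader": uploader,
        "stroke": stroke,
        "tonic_code": tonic_code,
        "index": index,
    }

example = "../audio/mir_datasets/mridangam_stroke_1.5/C#/226559__akshaylaya__dheem-csh-017.wav"
parsed = parse_stroke_filename(example)
print(parsed["track_id"], parsed["stroke"])
```

The helper raises a `ValueError` on names that do not match the assumed layout, so a mismatch is noticed rather than silently mislabelled.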

Note that the track ID is taken directly from the file name of the stroke recording. Let’s use that to compare, for a random prediction, the predicted and ground-truth stroke annotations.

# Selecting a random example from the predicted files
predicted_file, predicted_stroke = random.choice(list(prediction.items()))

# Getting the ID from filepath
identifier = os.path.basename(predicted_file).split("__")[0]

# Comparing target and estimation
file_name = os.path.basename(predicted_file)
target_stroke = evaluation_split[identifier].stroke_name
if target_stroke == predicted_stroke:
    print(f"Nice! Predicted stroke in {file_name}\n coincides with ground-truth {target_stroke}")
else:
    print(f"Missed! Predicted stroke in {file_name}\n does NOT coincide with ground-truth {target_stroke}")
Nice! Predicted stroke in 226722__akshaylaya__num-csh-039.wav
 coincides with ground-truth num
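Checking one random prediction is only a spot check. A minimal sketch of scoring every prediction at once follows; `stroke_accuracy` and the toy dictionaries are hypothetical, and the code assumes only that predictions and ground truth are dictionaries sharing the same keys.

```python
from collections import Counter

def stroke_accuracy(predictions, ground_truth):
    """Return overall accuracy and a Counter of (target, predicted) mistakes."""
    if not predictions:
        raise ValueError("predictions is empty")
    confusions = Counter()
    correct = 0
    for key, predicted in predictions.items():
        target = ground_truth[key]
        if predicted == target:
            correct += 1
        else:
            # Track which strokes get mistaken for which
            confusions[(target, predicted)] += 1
    return correct / len(predictions), confusions

# Toy data standing in for the real model output and annotations
toy_prediction = {"a.wav": "thi", "b.wav": "num", "c.wav": "tha", "d.wav": "cha"}
toy_truth = {"a.wav": "thi", "b.wav": "num", "c.wav": "ta", "d.wav": "cha"}
accuracy, confusions = stroke_accuracy(toy_prediction, toy_truth)
print(f"accuracy = {accuracy:.2f}")
print(confusions.most_common())
```

With the objects from this walkthrough, the ground truth could be built as `{t.audio_path: t.stroke_name for t in evaluation_split.values()}`, assuming the keys returned by `msc.predict` are the same paths that were passed in, as the printed predictions suggest.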