# Source code for compiam.rhythm.meter.tcn_carnatic

import os
import sys
import numpy as np
from typing import Dict
from tqdm import tqdm
from compiam.exceptions import ModelNotTrainedError

from compiam.utils.download import download_remote_model
from compiam.utils import get_logger, WORKDIR
from compiam.io import write_csv

logger = get_logger(__name__)

class TCNTracker(object):
    """TCN beat tracker tuned to Carnatic Music."""

    def __init__(
        self,
        post_processor="joint",
        model_version=42,
        model_path=None,
        download_link=None,
        download_checksum=None,
        gpu=-1,
    ):
        """TCN beat tracker init method.

        :param post_processor: Post-processing method to use. Choose from
            'joint', or 'sequential'.
        :param model_version: Version of the pre-trained model to use.
            Choose from 42, 52, or 62.
        :param model_path: path to file to the model weights.
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: Id of the available GPU to use (-1 by default, to run
            on CPU).
        """
        ### IMPORTING OPTIONAL DEPENDENCIES
        try:
            global torch
            import torch
        except ImportError:
            raise ImportError(
                "Torch is required to use TCNTracker. "
                "Install compIAM with torch support: pip install 'compiam[torch]'"
            )
        try:
            global madmom
            import madmom
        except ImportError:
            raise ImportError(
                "Madmom is required to use TCNTracker. "
                "Install compIAM with madmom support: pip install 'compiam[madmom]'"
            )
        ###

        global MultiTracker, PreProcessor, joint_tracker, sequential_tracker
        from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
        from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
        from compiam.rhythm.meter.tcn_carnatic.post import (
            joint_tracker,
            sequential_tracker,
        )

        # NOTE(review): "beat" passes validation here but is neither
        # documented nor handled explicitly below (it falls through to
        # sequential_tracker) — confirm whether it is a deliberate alias.
        if post_processor not in ["beat", "joint", "sequential"]:
            raise ValueError(
                f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'."
            )
        if model_version not in [42, 52, 62]:
            raise ValueError(
                f"Invalid model_version: {model_version}. Choose from 42, 52, or 62."
            )

        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        self.model_version = f"multitracker_{model_version}.pth"
        self.download_link = download_link
        self.download_checksum = download_checksum
        self.trained = False

        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)

        # Number of frames replicated at each end of the input features so
        # the network has context at the boundaries (see preprocess_audio).
        self.pad_frames = 2
        self.post_processor = (
            joint_tracker if post_processor == "joint" else sequential_tracker
        )

    def _build_model(self):
        """Build the TCN model on the selected device and set it to eval mode."""
        model = MultiTracker().to(self.device)
        model.eval()
        return model
[docs] def load_model(self, model_path): """Load pre-trained model weights.""" if not os.path.exists(os.path.join(model_path, self.model_version)): self.download_model(model_path) # Downloading model weights self.model.load_weights(os.path.join(model_path, self.model_version), self.device) self.model_path = model_path self.trained = True
[docs] def download_model(self, model_path=None, force_overwrite=True): """Download pre-trained model.""" download_path = ( #os.sep + os.path.join(*model_path.split(os.sep)[:-2]) model_path if model_path is not None else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic") ) # Creating model folder to store the weights if not os.path.exists(download_path): os.makedirs(download_path) download_remote_model( self.download_link, self.download_checksum, download_path, force_overwrite=force_overwrite, )
[docs] def predict(self, input_data: str, sr: int = 44100, min_bpm=55, max_bpm=230, beats_per_bar=[3, 5, 7, 8]) -> Dict: """Run inference on input audio file. :param input_data: path to audio file or numpy array like audio signal. :param sr: sampling rate of the input audio signal (default: 44100). :param min_bpm: minimum BPM for beat tracking (default: 55). :param max_bpm: maximum BPM for beat tracking (default: 230). :param beats_per_bar: list of possible beats per bar for downbeat tracking (default: [3, 5, 7, 8]). :returns: a 2-D list with beats and beat positions. """ if self.trained is False: raise ModelNotTrainedError( """Model is not trained. Please load model before running inference! You can load the pre-trained instance with the load_model wrapper.""" ) features = self.preprocess_audio(input_data, sr) x = torch.from_numpy(features).to(self.device) output = self.model(x) beats_act = output["beats"].squeeze().detach().cpu().numpy() downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy() pred = self.post_processor(beats_act, downbeats_act, min_bpm=min_bpm, max_bpm=max_bpm, beats_per_bar=beats_per_bar) return pred
[docs] def preprocess_audio(self, input_data: str, input_sr: int) -> np.ndarray: """Preprocess input audio file to extract features for inference. :param audio_path: Path to the input audio file. :param input_sr: Sampling rate of the input audio file. :returns: Preprocessed features as a numpy array. """ if isinstance(input_data, str): if not os.path.exists(input_data): raise FileNotFoundError("Target audio not found.") audio, sr = madmom.io.audio.load_audio_file(input_data) if audio.shape[0] == 2: audio = audio.mean(axis=0) signal = madmom.audio.Signal(audio, sr, num_channels=1) elif isinstance(input_data, np.ndarray): audio = input_data if audio.shape[0] == 2: audio = audio.mean(axis=0) signal = madmom.audio.Signal(audio, input_sr, num_channels=1) sr = input_sr else: raise ValueError("Input must be path to audio signal or an audio array") x = PreProcessor(sample_rate=sr)(signal) pad_start = np.repeat(x[:1], self.pad_frames, axis=0) pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0) x_padded = np.concatenate((pad_start, x, pad_stop)) x_final = np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0) return x_final
[docs] @staticmethod def save_pitch(data, output_path): """Calling the write_csv function in compiam.io to write the output beat track in a file :param data: the data to write :param output_path: the path where the data is going to be stored :returns: None """ return write_csv(data, output_path)
[docs] def select_gpu(self, gpu="-1"): """Select the GPU to use for inference. :param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc. :returns: None """ if int(gpu) == -1: self.device = torch.device("cpu") else: if torch.cuda.is_available(): self.device = torch.device("cuda:" + str(gpu)) elif torch.backends.mps.is_available(): self.device = torch.device("mps:" + str(gpu)) else: self.device = torch.device("cpu") logger.warning("No GPU available. Running on CPU.") self.gpu = gpu