import os
import librosa
import numpy as np
from tqdm import tqdm
from compiam.exceptions import ModelNotTrainedError
# Functions shared with FTANetCarnatic for vocal pitch extraction
from compiam.utils.pitch import normalisation, resampling
from compiam.melody.pitch_extraction.ftanet_carnatic.pitch_processing import (
est,
std_normalize,
)
from compiam.melody.pitch_extraction.ftanet_carnatic.cfp import cfp_process
from compiam.io import write_csv
from compiam.utils import get_logger, WORKDIR
from compiam.utils.download import download_remote_model
logger = get_logger(__name__)
class FTAResNetCarnatic(object):
"""FTA-ResNet melody extraction tuned to Carnatic Music."""
def __init__(
self,
model_path=None,
download_link=None,
download_checksum=None,
sample_rate=44100,
gpu="-1",
):
"""FTA-Net melody extraction init method.
:param model_path: path to file to the model weights.
:param download_link: link to the remote pre-trained model.
:param download_checksum: checksum of the model file.
:param sample_rate: Sample rate to which the audio is sampled for extraction.
"""
### IMPORTING OPTIONAL DEPENDENCIES
try:
global torch
import torch
global FTAnet
from compiam.melody.pitch_extraction.ftaresnet_carnatic.model import FTAnet
        except ImportError:
raise ImportError(
"In order to use this tool you need to have torch installed. "
"Install compIAM with torch support: pip install 'compiam[torch]'"
)
###
## Setting up GPU if specified
self.gpu = gpu
self.device = None
self.select_gpu(gpu)
self.model = self._build_model()
self.sample_rate = sample_rate
self.trained = False
self.model_path = model_path
self.download_link = download_link
self.download_checksum = download_checksum
if self.model_path is not None:
self.load_model(self.model_path)
def _build_model(self):
"""Build the FTA-Net model."""
ftanet = FTAnet().to(self.device)
ftanet.eval()
return ftanet
def load_model(self, model_path):
"""Load pre-trained model weights."""
if not os.path.exists(model_path):
self.download_model(model_path) # Downloading model weights
        ## Ensuring we can load the model across torch versions
        ## -- (the weights_only argument is not available in older releases)
        try:
            self.model.load_state_dict(
                torch.load(model_path, weights_only=True, map_location=self.device)
            )
        except TypeError:
            self.model.load_state_dict(
                torch.load(model_path, map_location=self.device)
            )
self.model_path = model_path
self.trained = True
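
    # Loading sketch (illustrative; the path below is a placeholder, not a real
    # compIAM asset). load_model first downloads the weights if the file is
    # missing, then loads them onto the selected device:
    #
    #   extractor.load_model("./models/ftaresnet_carnatic.pth")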
def download_model(self, model_path=None, force_overwrite=False):
"""Download pre-trained model."""
download_path = (
os.sep + os.path.join(*model_path.split(os.sep)[:-2])
if model_path is not None
else os.path.join(WORKDIR, "models", "melody", "ftaresnet_carnatic")
)
# Creating model folder to store the weights
if not os.path.exists(download_path):
os.makedirs(download_path)
download_remote_model(
self.download_link,
self.download_checksum,
download_path,
force_overwrite=force_overwrite,
)
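
    # download_model is normally invoked indirectly through load_model, but it
    # can also be called directly, reusing the link and checksum given at init
    # time (sketch; argument shown for illustration):
    #
    #   extractor.download_model(force_overwrite=True)  # re-fetch the weights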
def predict(
self,
input_data,
input_sr=44100,
hop_size=441,
time_frame=128,
out_step=None,
gpu="-1",
):
"""Extract melody from input_data.
Implementation taken (and slightly adapted) from https://github.com/yushuai/FTANet-melodic.
:param input_data: path to audio file or numpy array like audio signal.
:param input_sr: sampling rate of the input array of data (if any). This variable is only
relevant if the input is an array of data instead of a filepath.
:param hop_size: hop size between frequency estimations.
:param batch_size: batches of seconds that are passed through the model
(defaulted to 5, increase if enough computational power, reduce if
needed).
:param out_step: particular time-step duration if needed at output
:param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc.
:returns: a 2-D list with time-stamps and pitch values per timestamp.
"""
## Setting up GPU if any
if gpu != self.gpu:
self.select_gpu(gpu)
if self.trained is False:
raise ModelNotTrainedError(
"""Model is not trained. Please load model before running inference!
You can load the pre-trained instance with the load_model wrapper."""
)
# Loading and resampling audio
if isinstance(input_data, str):
if not os.path.exists(input_data):
raise FileNotFoundError("Target audio not found.")
audio, _ = librosa.load(input_data, sr=self.sample_rate)
elif isinstance(input_data, np.ndarray):
            logger.warning(
                f"Resampling... (input sampling rate is assumed {input_sr}Hz, "
                "make sure this is correct and change input_sr otherwise)"
            )
audio = librosa.resample(
input_data, orig_sr=input_sr, target_sr=self.sample_rate
)
else:
raise ValueError("Input must be path to audio signal or an audio array")
        # Collapse multi-channel audio to mono (the smaller dimension is
        # assumed to be the channel axis)
        audio_shape = audio.shape
        if len(audio_shape) > 1:
            audio_channels = min(audio_shape)
            if audio_channels == 1:
                audio = audio.flatten()
            else:
                audio = np.mean(audio, axis=np.argmin(audio_shape))
        # Extracting pitch
        audio_len = len(audio) // hop_size  # total number of analysis frames
        with torch.no_grad():
            # 321 output rows: 320 frequency bins plus one non-melody bin
            prediction_pitch = torch.zeros(321, audio_len).to(self.device)
for i in tqdm(range(0, audio_len, time_frame)):
W, Cen_freq, _ = cfp_process(
y=audio[i * hop_size : (i + time_frame) * hop_size + 1],
sr=self.sample_rate,
hop=hop_size,
)
value = W.shape[-1]
                W = np.concatenate(
                    (W, np.zeros((3, 320, 128 - W.shape[-1]))), axis=-1
                )  # Zero-pad the chunk to a fixed width of 128 frames
W_norm = std_normalize(W)
                w = np.stack((W_norm, W_norm))  # stack two copies to form a batched input
prediction_pitch[:, i : i + value] = self.model(
torch.Tensor(w).to(self.device)
)[0][0][0][:, :value]
# Convert to Hz
frame_time = hop_size / self.sample_rate
y_hat = est(
prediction_pitch.to("cpu"),
Cen_freq,
torch.linspace(
0,
frame_time * ((audio_len // time_frame) * time_frame),
((audio_len // time_frame) * time_frame),
),
)
# Filter low frequency values
freqs = y_hat[:, 1]
freqs[freqs < 50] = 0
# Format output
TStamps = y_hat[:, 0]
output = np.array([TStamps, freqs]).transpose()
        if out_step is not None:
            new_len = int((len(audio) / self.sample_rate) // out_step)
            output = resampling(output, new_len)
        return output
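
    # Prediction sketch (illustrative; the file name is a placeholder). The
    # returned array has one row per frame: column 0 holds time-stamps in
    # seconds, column 1 pitch values in Hz (0 where no melody is detected):
    #
    #   pitch_track = extractor.predict("my_carnatic_recording.wav")
    #   times, freqs = pitch_track[:, 0], pitch_track[:, 1]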
@staticmethod
def normalise_pitch(pitch, tonic, bins_per_octave=120, max_value=4):
"""Normalise pitch given a tonic.
:param pitch: a 2-D list with time-stamps and pitch values per timestamp.
:param tonic: recording tonic to normalize the pitch to.
:param bins_per_octave: number of frequency bins per octave.
:param max_value: maximum value to clip the normalized pitch to.
        :returns: a 2-D list with time-stamps and pitch values per timestamp,
            normalised to the given tonic.
"""
return normalisation(
pitch, tonic, bins_per_octave=bins_per_octave, max_value=max_value
)
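
    # Normalisation sketch (the tonic value below is a made-up example, not a
    # real annotation):
    #
    #   norm_track = FTAResNetCarnatic.normalise_pitch(pitch_track, tonic=146.8)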
@staticmethod
def save_pitch(data, output_path):
"""Calling the write_csv function in compiam.io to write the output pitch curve in a file
:param data: the data to write
:param output_path: the path where the data is going to be stored
:returns: None
"""
return write_csv(data, output_path)
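
    # Saving sketch (illustrative; the output path is a placeholder):
    #
    #   extractor.save_pitch(pitch_track, "./pitch_track.csv")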
def select_gpu(self, gpu="-1"):
"""Select the GPU to use for inference.
:param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc.
:returns: None
"""
if int(gpu) == -1:
self.device = torch.device("cpu")
else:
if torch.cuda.is_available():
self.device = torch.device("cuda:" + str(gpu))
elif torch.backends.mps.is_available():
self.device = torch.device("mps:" + str(gpu))
else:
self.device = torch.device("cpu")
logger.warning("No GPU available. Running on CPU.")
self.gpu = gpu
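

# Minimal end-to-end sketch (illustrative only: the weights URL, checksum, and
# file paths below are placeholders, not real compIAM assets):
if __name__ == "__main__":
    extractor = FTAResNetCarnatic(
        model_path="./models/ftaresnet_carnatic.pth",  # hypothetical local path
        download_link="<remote-model-url>",  # placeholder
        download_checksum="<md5-checksum>",  # placeholder
        gpu="-1",  # run on CPU; pass "0", "1", ... to select a GPU
    )
    pitch_track = extractor.predict("my_recording.wav", out_step=0.01)
    extractor.save_pitch(pitch_track, "./pitch_track.csv")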