"use strict";
/**
 * @license
 * Copyright (C) 2006-2020  Music Technology Group - Universitat Pompeu Fabra
 *
 * This file is part of Essentia
 *
 * Essentia is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the Free
 * Software Foundation (FSF), either version 3 of the License, or (at your
 * option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the Affero GNU General Public License
 * version 3 along with this program.  If not, see http://www.gnu.org/licenses/
 */
/**
 * Class-inheritance helper for constructor functions.
 *
 * Reuses an `__extends` already present on `this`, otherwise builds one that
 *  1. links the static side of `derived` to `base` (via `Object.setPrototypeOf`,
 *     a writable `__proto__`, or a plain own-property copy, whichever is
 *     available first), and
 *  2. gives `derived` a fresh prototype object that inherits from
 *     `base.prototype` and whose `constructor` points back at `derived`.
 *     When `base` is `null` the prototype is created with `Object.create(null)`.
 */
var __extends = (this && this.__extends) || (function () {
    function assignProto(target, source) {
        target.__proto__ = source;
    }
    function copyOwnProps(target, source) {
        for (var key in source) {
            if (source.hasOwnProperty(key)) {
                target[key] = source[key];
            }
        }
    }
    // Resolved lazily on the first call, then replaced by the chosen strategy.
    var linkStatics = function (derived, base) {
        var supportsProto = { __proto__: [] } instanceof Array;
        linkStatics = Object.setPrototypeOf || (supportsProto && assignProto) || copyOwnProps;
        return linkStatics(derived, base);
    };
    return function (derived, base) {
        linkStatics(derived, base);
        function Surrogate() {
            this.constructor = derived;
        }
        if (base === null) {
            derived.prototype = Object.create(base);
        }
        else {
            Surrogate.prototype = base.prototype;
            derived.prototype = new Surrogate();
        }
    };
})();
/**
 * Drives a generator-style function to completion and exposes the outcome as a promise.
 *
 * @param {*} thisArg `this` value used when invoking `generator`.
 * @param {Array} [_arguments] arguments forwarded to `generator` (defaults to none).
 * @param {Function} [P] promise constructor to use; falls back to the global `Promise`.
 * @param {Function} generator function returning an iterator with `next` and `throw`.
 * @returns {Promise} settles with the iterator's final value, or rejects with
 * whatever the iterator throws.
 */
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) {
    if (!P) {
        P = Promise;
    }
    // Wrap plain values so every yielded item can be chained with `.then`.
    function toPromise(candidate) {
        if (candidate instanceof P) {
            return candidate;
        }
        return new P(function (settle) { settle(candidate); });
    }
    return new P(function (resolve, reject) {
        function advance(result) {
            if (result.done) {
                resolve(result.value);
                return;
            }
            toPromise(result.value).then(onFulfilled, onRejected);
        }
        function onFulfilled(value) {
            try {
                advance(generator.next(value));
            }
            catch (error) {
                reject(error);
            }
        }
        function onRejected(reason) {
            try {
                advance(generator["throw"](reason));
            }
            catch (error) {
                reject(error);
            }
        }
        // From here on `generator` refers to the iterator, not the factory function.
        generator = generator.apply(thisArg, _arguments || []);
        advance(generator.next());
    });
};
/**
 * Runtime for down-levelled generator functions: wraps a state-machine `body`
 * in an iterator-like object exposing `next`, `throw` and `return`.
 *
 * `body` is invoked as `body.call(thisArg, _)` and must return an instruction
 * array `[opcode, operand]`. Opcodes as handled by `step` below:
 *   0 / 1  resume with a value / with an error (created by `next` / `throw`);
 *   2      return (created by `return`; annotated `return` at this file's call sites);
 *   3      jump: set `_.label` to the operand;
 *   4      yield the operand (annotated `yield` at this file's call sites);
 *   5      delegate to the operand as an inner iterator (stored in `y`);
 *   6      an exception, as produced by the `catch` clause of `step`;
 *   7      pop the pending instruction from `_.ops` and the innermost `_.trys` entry.
 *
 * State object `_`:
 *   label  the `case` the body should run next;
 *   sent() returns the value handed to the last resume, or throws it when the
 *          resume instruction had its low bit set (i.e. came from `throw`);
 *   trys   stack of protected-region descriptors; entries are indexed 0..3 as
 *          label bounds (presumably [try, catch, finally, end] - confirm against tslib);
 *   ops    instructions parked while a `finally` region runs.
 *
 * Reuses a `__generator` already present on `this` when there is one.
 */
var __generator = (this && this.__generator) || function (thisArg, body) {
    // f: re-entrancy guard, y: delegated inner iterator, t: scratch register, g: the returned iterator object.
    var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
    return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
    // Builds the iterator method that issues instruction `n` with the caller's value.
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        if (f) throw new TypeError("Generator is already executing.");
        // `_` is set to 0 once the generator has finished, which ends the loop.
        while (_) try {
            // While delegating, forward the instruction to the inner iterator `y`; return its result until it is done.
            if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
            // Inner iterator finished: continue the outer body with its final value.
            if (y = 0, t) op = [op[0] & 2, t.value];
            switch (op[0]) {
                case 0: case 1: t = op; break;
                case 4: _.label++; return { value: op[1], done: false };
                case 5: _.label++; y = op[1]; op = [0]; continue;
                case 7: op = _.ops.pop(); _.trys.pop(); continue;
                default:
                    // No enclosing protected region and the instruction is a throw (6) or return (2): finish.
                    if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
                    // Jump that stays inside the current region (or there is no region): just move the label.
                    if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
                    // Exception raised before label t[1]: continue at t[1], keeping the error in `t` for `sent()`.
                    if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
                    // Otherwise run the region starting at t[2] first and park the instruction in `_.ops` (opcode 7 picks it up).
                    if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
                    if (t[2]) _.ops.pop();
                    _.trys.pop(); continue;
            }
            op = body.call(thisArg, _);
        } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
        // Bits 1|4 set (opcodes 1, 5, 6, ...) means the final instruction carries an error to rethrow.
        if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
    }
};
// Set the `__esModule` flag on the CommonJS exports object (interop marker
// conventionally read by ES-module-aware loaders and bundlers).
exports.__esModule = true;
/**
 * Base class for loading a pre-trained Essentia-Tensorflow.js model and running
 * inference on it with TensorFlow.js.
 *
 * Instances own a few tensors (`IS_TRAIN`, `randomTensorInput`) besides the loaded
 * graph model; call `dispose()` to release all of them.
 * @class
 */
var EssentiaTensorflowJSModel = /** @class */ (function () {
    /**
     * @param {object} tfjs the imported `@tensorflow/tfjs*` module object.
     * @param {string} modelPath path or URL handed to `tf.loadGraphModel`.
     * @param {boolean} [verbose=false] accepted for API compatibility; currently unused.
     */
    function EssentiaTensorflowJSModel(tfjs, modelPath, verbose) {
        if (verbose === void 0) { verbose = false; }
        this.model = null;
        this.audioSampleRate = 16000;
        // Becomes true only once `initialize` has resolved.
        this.isReady = false;
        this.randomTensorInput = null;
        this.minimumInputFrameSize = null;
        this.tf = tfjs;
        this.modelPath = modelPath;
        // Boolean flag tensor fed to models that declare an extra "is training" input.
        this.IS_TRAIN = this.tf.tensor([0], [1], 'bool');
    }
    /**
     * Promise for loading & initialising an Essentia.js-TensorFlow.js model.
     * Can be called again after `dispose()` to reload the model.
     * @async
     * @method
     * @memberof EssentiaTensorflowJSModel
     */
    EssentiaTensorflowJSModel.prototype.initialize = function () {
        return __awaiter(this, void 0, void 0, function () {
            var _a;
            return __generator(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        // `dispose()` releases the flag tensor; recreate it when re-initialising.
                        if (!this.IS_TRAIN) {
                            this.IS_TRAIN = this.tf.tensor([0], [1], 'bool');
                        }
                        _a = this;
                        return [4 /*yield*/, this.tf.loadGraphModel(this.modelPath)];
                    case 1:
                        _a.model = _b.sent();
                        this.isReady = true;
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Converts an input 1D (flattened) or 2D array into a 3D tensor (tfjs) given its shape and
     * required patchSize. If `zeroPadding=true`, the feature is zero-padded up to a whole
     * number of patches; otherwise it must contain exactly `patchSize` frames.
     *
     * All intermediate tensors are released before returning; the caller owns (and must
     * dispose) only the returned tensor.
     *
     * @method
     * @param {Float32Array|any[]} inputfeatureArray input feature array as either 1D or 2D array
     * @param {any[]} inputShape shape of the input feature array in 2D, `[frames, melBands]`.
     * @param {number} patchSize required patchSize to dynamically make batches of feature
     * @param {boolean} [zeroPadding=false] whether to enable zero-padding if less frames found for a batch.
     * @returns {tf.Tensor3D} tensor of shape `[batchSize, patchSize, melBands]`.
     * @throws {Error} when `zeroPadding` is false and the frame count differs from `patchSize`.
     * @memberof EssentiaTensorflowJSModel
     */
    EssentiaTensorflowJSModel.prototype.arrayToTensorAsBatches = function (inputfeatureArray, inputShape, patchSize, zeroPadding) {
        if (zeroPadding === void 0) { zeroPadding = false; }
        var self = this;
        // The frame count comes from the declared shape, not from the array length, so that
        // flattened 1D input (length = frames * bands) is handled the same as nested 2D input.
        var frameCount = inputShape[0];
        var melBandsSize = inputShape[1];
        // `tidy` disposes every tensor created inside except the returned one, also on throw.
        return this.tf.tidy(function () {
            var featureTensor = self.tf.tensor(inputfeatureArray, inputShape, 'float32');
            if (!zeroPadding) {
                self.assertMinimumFeatureInputSize({
                    melSpectrum: inputfeatureArray,
                    frameSize: frameCount,
                    melBandsSize: melBandsSize,
                    patchSize: patchSize
                });
                return featureTensor.as3D(1, patchSize, melBandsSize);
            }
            // At least one patch, even for inputs shorter than `patchSize`.
            var batchSize = Math.max(1, Math.ceil(frameCount / patchSize));
            var missingFrames = batchSize * patchSize - frameCount;
            if (missingFrames > 0) {
                featureTensor = featureTensor.concat(self.tf.zeros([missingFrames, melBandsSize], 'float32'));
            }
            return featureTensor.as3D(batchSize, patchSize, melBandsSize);
        });
    };
    /**
     * Releases the loaded model and the helper tensors owned by this instance.
     * Safe to call before `initialize` and more than once. Afterwards the instance
     * reports `isReady === false` until `initialize` is called again.
     * @method
     * @memberof EssentiaTensorflowJSModel
     */
    EssentiaTensorflowJSModel.prototype.dispose = function () {
        if (this.model) {
            this.model.dispose();
            this.model = null;
        }
        if (this.randomTensorInput) {
            this.randomTensorInput.dispose();
            this.randomTensorInput = null;
        }
        if (this.IS_TRAIN) {
            this.IS_TRAIN.dispose();
            this.IS_TRAIN = null;
        }
        this.isReady = false;
    };
    /**
     * Checks that a non-padded input holds exactly one patch worth of frames.
     * Also records `patchSize` in `minimumInputFrameSize`.
     * @method
     * @param {object} inputFeature feature descriptor with `melSpectrum`, `patchSize` and
     * (optionally) `frameSize`; `frameSize` is preferred as the frame count, falling back
     * to `melSpectrum.length` when it is not a number.
     * @throws {Error} when the frame count differs from `patchSize`.
     * @memberof EssentiaTensorflowJSModel
     */
    EssentiaTensorflowJSModel.prototype.assertMinimumFeatureInputSize = function (inputFeature) {
        this.minimumInputFrameSize = inputFeature.patchSize; // at least 1 full patch
        var givenFrames = typeof inputFeature.frameSize === 'number'
            ? inputFeature.frameSize
            : inputFeature.melSpectrum.length;
        if (givenFrames !== this.minimumInputFrameSize) {
            // An accurate minimum audio duration cannot be reported without the model input hopSize.
            throw new Error("When `padding=false` in `predict` method, the model expect audio feature for a minimum frame size of "
                + this.minimumInputFrameSize + ". Was given " + givenFrames + " melband frames");
        }
    };
    /**
     * Builds the extra (non-feature) inputs the loaded graph expects, based on how many
     * inputs the model declares. The returned tensors are fresh clones; the caller is
     * responsible for disposing them.
     * @method
     * @returns {tf.Tensor[]} `[]` for 1 input, `[isTrain]` for 2, `[random, isTrain]` for 3.
     * @throws {Error} when the model is not loaded or declares an unsupported number of inputs.
     * @memberof EssentiaTensorflowJSModel
     */
    EssentiaTensorflowJSModel.prototype.disambiguateExtraInputs = function () {
        if (!this.isReady)
            throw new Error("No loaded tfjs model found! Make sure to call `initialize` method and resolve the promise before calling `predict` method.");
        var modelInputInfos = this.model.executor.inputs;
        var inputsCount = modelInputInfos.length;
        if (inputsCount === 1) {
            return [];
        }
        if (inputsCount === 2) {
            return [this.IS_TRAIN.clone()];
        }
        if (inputsCount === 3) {
            // Overhead from the tensorflowjs-converter: for some vggish models trained on
            // audioset it creates a random tensor input that is not connected to other layers.
            // The tfjs model still needs this insignificant tensor on the prediction call.
            // This will be removed once it has been sorted out in the conversion process.
            if (!this.randomTensorInput)
                this.randomTensorInput = this.tf.zeros([1, modelInputInfos[0].shape[1]]);
            return [this.randomTensorInput.clone(), this.IS_TRAIN.clone()];
        }
        // List input names; concatenating the raw objects would print "[object Object]".
        var inputNames = modelInputInfos.map(function (inputInfo) { return inputInfo.name; }).join(", ");
        throw new Error("Found unsupported number of input requirements for the model. Expects the following inputs -> " + inputNames);
    };
    return EssentiaTensorflowJSModel;
}());
// Expose the base model class on the CommonJS exports object.
exports.EssentiaTensorflowJSModel = EssentiaTensorflowJSModel;
/**
 * Class with methods for computing inference of
 * Essentia-Tensorflow.js MusiCNN-based pre-trained models.
 * The `predict` method expects an input audio feature computed
 * using `EssentiaTFInputExtractor`.
 * @class
 * @example
 *
 * // FEATURE EXTRACTION
 * // Create `EssentiaTFInputExtractor` instance by passing
 * // essentia-wasm import `EssentiaWASM` global object and `extractorType=musicnn`.
 * const inputFeatureExtractor = new EssentiaTFInputExtractor(EssentiaWASM, "musicnn");
 * // Compute feature for a given audio signal
 * let inputMusiCNN = inputFeatureExtractor.computeFrameWise(audioSignal);
 * // INFERENCE
 * const modelURL = "./autotagging/msd/msd-musicnn-1/model.json"
 * // Where `tf` is the global import object from the `@tensorflow/tfjs*` package.
 * const musicnn = new TensorflowMusiCNN(tf, modelURL);
 * // Promise for loading the model
 * await musicnn.initialize();
 * // Compute predictions for a given input feature.
 * let predictions = await musicnn.predict(inputMusiCNN);
 * @extends {EssentiaTensorflowJSModel}
 */
var TensorflowMusiCNN = /** @class */ (function (_super) {
    __extends(TensorflowMusiCNN, _super);
    /**
     * @param {object} tfjs the imported `@tensorflow/tfjs*` module object.
     * @param {string} model_url path or URL of the model's `model.json`.
     * @param {boolean} [verbose=false] forwarded to the base class.
     */
    function TensorflowMusiCNN(tfjs, model_url, verbose) {
        if (verbose === void 0) { verbose = false; }
        var _this = _super.call(this, tfjs, model_url, verbose) || this;
        _this.minimumInputFrameSize = 3;
        return _this;
    }
    /**
     * Run inference on the given audio feature input and returns the activations.
     * Every tensor created for the call (feature tensor, extra model inputs and the
     * result tensor) is disposed before the promise settles, including on failure.
     * @param {InputMusiCNN} inputFeature audio feature required by the MusiCNN model.
     * @param {boolean} [zeroPadding=false] whether to do zero-padding to the input feature.
     * @returns {array} activations of the output layer of the model
     * @memberof TensorflowMusiCNN
     */
    TensorflowMusiCNN.prototype.predict = function (inputFeature, zeroPadding) {
        if (zeroPadding === void 0) { zeroPadding = false; }
        return __awaiter(this, void 0, void 0, function () {
            var featureTensor, extraInputs, ownedTensors, results;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        // Everything pushed here is released in the `finally`, whether or not
                        // building the inputs or executing the graph throws.
                        ownedTensors = [];
                        try {
                            featureTensor = this.arrayToTensorAsBatches(inputFeature.melSpectrum, [inputFeature.frameSize, inputFeature.melBandsSize], inputFeature.patchSize, zeroPadding);
                            ownedTensors.push(featureTensor);
                            // The extra inputs are clones handed over by the base class, so they are ours to free.
                            extraInputs = this.disambiguateExtraInputs();
                            ownedTensors = ownedTensors.concat(extraInputs);
                            // The feature tensor goes last, after the extra inputs.
                            results = this.model.execute(extraInputs.concat([featureTensor]));
                        }
                        finally {
                            ownedTensors.forEach(function (tensor) { tensor.dispose(); });
                        }
                        _a.label = 1;
                    case 1:
                        // Protected region 1..4 with its `finally` at label 3: the result tensor
                        // is disposed whether `array()` resolves or rejects.
                        _a.trys.push([1, , 3, 4]);
                        return [4 /*yield*/, results.array()];
                    case 2: return [2 /*return*/, _a.sent()];
                    case 3:
                        results.dispose();
                        return [7 /*endfinally*/];
                    case 4: return [2 /*return*/];
                }
            });
        });
    };
    return TensorflowMusiCNN;
}(EssentiaTensorflowJSModel));
// Expose the MusiCNN model wrapper on the CommonJS exports object.
exports.TensorflowMusiCNN = TensorflowMusiCNN;
/**
 * Class with methods for computing inference of
 * Essentia-Tensorflow.js VGGish-based pre-trained models.
 * The `predict` method expects an input audio feature computed
 * using `EssentiaTFInputExtractor`.
 * @class
 * @example
 * // FEATURE EXTRACTION
 * // Create `EssentiaTFInputExtractor` instance by passing
 * // essentia-wasm import `EssentiaWASM` global object and `extractorType=vggish`.
 * const inputFeatureExtractor = new EssentiaTFInputExtractor(EssentiaWASM, "vggish");
 * // Compute feature for a given audio signal array
 * let inputVGGish = inputFeatureExtractor.computeFrameWise(audioSignal);
 * // INFERENCE
 * const modelURL = "./classifiers/danceability/danceability-vggish-audioset-1/model.json"
 * // Where `tf` is the global import object from the `@tensorflow/tfjs*` package.
 * const vggish = new TensorflowVGGish(tf, modelURL);
 * // Promise for loading the model
 * await vggish.initialize();
 * // Compute predictions for a given input feature.
 * let predictions = await vggish.predict(inputVGGish);
 * @extends {EssentiaTensorflowJSModel}
 */
var TensorflowVGGish = /** @class */ (function (_super) {
    __extends(TensorflowVGGish, _super);
    /**
     * @param {object} tfjs the imported `@tensorflow/tfjs*` module object.
     * @param {string} model_url path or URL of the model's `model.json`.
     * @param {boolean} [verbose=false] forwarded to the base class.
     */
    function TensorflowVGGish(tfjs, model_url, verbose) {
        if (verbose === void 0) { verbose = false; }
        return _super.call(this, tfjs, model_url, verbose) || this;
    }
    /**
     * Run inference on the given audio feature input and returns the activations.
     * Every tensor created for the call (feature tensor, extra model inputs and the
     * result tensor) is disposed before the promise settles, including on failure.
     * @param {InputVGGish} inputFeature audio feature required by the VGGish model.
     * @param {boolean} [zeroPadding=false] whether to do zero-padding to the input feature.
     * @returns {array} activations of the output layer of the model
     * @memberof TensorflowVGGish
     */
    TensorflowVGGish.prototype.predict = function (inputFeature, zeroPadding) {
        if (zeroPadding === void 0) { zeroPadding = false; }
        return __awaiter(this, void 0, void 0, function () {
            var featureTensor, extraInputs, ownedTensors, results;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        // Everything pushed here is released in the `finally`, whether or not
                        // building the inputs or executing the graph throws.
                        ownedTensors = [];
                        try {
                            featureTensor = this.arrayToTensorAsBatches(inputFeature.melSpectrum, [inputFeature.frameSize, inputFeature.melBandsSize], inputFeature.patchSize, zeroPadding);
                            ownedTensors.push(featureTensor);
                            // The extra inputs are clones handed over by the base class, so they are ours to free.
                            extraInputs = this.disambiguateExtraInputs();
                            ownedTensors = ownedTensors.concat(extraInputs);
                            // The feature tensor goes last, after the extra inputs.
                            results = this.model.execute(extraInputs.concat([featureTensor]));
                        }
                        finally {
                            ownedTensors.forEach(function (tensor) { tensor.dispose(); });
                        }
                        _a.label = 1;
                    case 1:
                        // Protected region 1..4 with its `finally` at label 3: the result tensor
                        // is disposed whether `array()` resolves or rejects.
                        _a.trys.push([1, , 3, 4]);
                        return [4 /*yield*/, results.array()];
                    case 2: return [2 /*return*/, _a.sent()];
                    case 3:
                        results.dispose();
                        return [7 /*endfinally*/];
                    case 4: return [2 /*return*/];
                }
            });
        });
    };
    return TensorflowVGGish;
}(EssentiaTensorflowJSModel));
// Expose the VGGish model wrapper on the CommonJS exports object.
exports.TensorflowVGGish = TensorflowVGGish;