/*
 * Decompiled with CFR 0.152.
 */
package marytts.signalproc.adaptation.prosody;

import java.io.IOException;
import marytts.signalproc.adaptation.BaselineAdaptationSet;
import marytts.signalproc.adaptation.IndexMap;
import marytts.signalproc.adaptation.codebook.WeightedCodebookFeatureCollection;
import marytts.signalproc.adaptation.codebook.WeightedCodebookTrainerParams;
import marytts.signalproc.adaptation.prosody.PitchMappingFile;
import marytts.signalproc.adaptation.prosody.PitchMappingFileHeader;
import marytts.signalproc.adaptation.prosody.PitchStatistics;
import marytts.signalproc.analysis.PitchReaderWriter;
import marytts.util.math.MathUtils;
import marytts.util.signal.SignalProcUtils;

/**
 * Collects pitch (F0) statistics from paired source/target training sets and writes them to a
 * {@link PitchMappingFile}.
 */
public class PitchTrainer {
    // NOTE(review): assigned in the constructor but not read anywhere in this class.
    private WeightedCodebookTrainerParams params;

    /**
     * Creates a trainer that stores a copy of the given parameters.
     *
     * @param pa trainer parameters, copied with the {@code WeightedCodebookTrainerParams} copy constructor
     */
    public PitchTrainer(WeightedCodebookTrainerParams pa) {
        this.params = new WeightedCodebookTrainerParams(pa);
    }

    /**
     * Writes a default header to {@code pitchMappingFile}, followed by four groups of pitch statistics
     * in this order: source in Hz, source in log-Hz, target in Hz, target in log-Hz.
     *
     * @param pitchMappingFile  destination for the header and all statistics entries
     * @param fcol              feature collection whose {@code indexMapFiles} enumerate the training pairs
     * @param sourceTrainingSet source-side training items (item {@code i} belongs to pair {@code i})
     * @param targetTrainingSet target-side training items (item {@code map[i]} belongs to pair {@code i})
     * @param map               for pair {@code i}, the index of the matching item in {@code targetTrainingSet}
     */
    public void learnMapping(PitchMappingFile pitchMappingFile, WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map) {
        PitchMappingFileHeader header = new PitchMappingFileHeader();
        pitchMappingFile.writePitchMappingHeader(header);
        this.getStatistics(pitchMappingFile, fcol, sourceTrainingSet, true, map, PitchStatistics.STATISTICS_IN_HERTZ);
        this.getStatistics(pitchMappingFile, fcol, sourceTrainingSet, true, map, PitchStatistics.STATISTICS_IN_LOGHERTZ);
        this.getStatistics(pitchMappingFile, fcol, targetTrainingSet, false, map, PitchStatistics.STATISTICS_IN_HERTZ);
        this.getStatistics(pitchMappingFile, fcol, targetTrainingSet, false, map, PitchStatistics.STATISTICS_IN_LOGHERTZ);
    }

    /**
     * Computes per-item and global pitch statistics for one side of the training data and writes them
     * to {@code pitchMappingFile}.
     *
     * <p>One local entry is written for every pair whose index map was read successfully and which has a
     * training item. It holds the mean and standard deviation of the voiced F0 values, the F0 range, and
     * the intercept/slope of a least-squares line fitted to the interpolated contour. After all pairs, one
     * global entry is written. The global mean and the sample standard deviation (n-1 denominator) are
     * taken over all voiced frames. The global range is the largest local range, and the intercept and
     * slope are averaged over the contributing items. With {@code STATISTICS_IN_LOGHERTZ}, every
     * quantity, including the global standard deviation, is computed on log-F0 values.
     *
     * @param pitchMappingFile destination for the local and global statistics entries
     * @param fcol             feature collection whose {@code indexMapFiles} enumerate the training pairs
     * @param trainingSet      training items to analyse
     * @param isSource         {@code true} to use item {@code i} for pair {@code i}; {@code false} to use
     *                         item {@code map[i]}
     * @param map              pair-to-item index mapping, only read when {@code isSource} is {@code false}
     * @param statisticsType   {@code PitchStatistics.STATISTICS_IN_HERTZ} or
     *                         {@code PitchStatistics.STATISTICS_IN_LOGHERTZ}
     */
    public void getStatistics(PitchMappingFile pitchMappingFile, WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet trainingSet, boolean isSource, int[] map, int statisticsType) {
        final boolean inLogHertz = statisticsType == PitchStatistics.STATISTICS_IN_LOGHERTZ;
        final int numPairs = fcol.indexMapFiles.length;
        PitchStatistics global = new PitchStatistics(statisticsType, isSource, true);
        PitchStatistics local = new PitchStatistics(statisticsType, isSource, false);
        // Voiced F0 values of every contributing item, already converted to the requested domain.
        // Reusing them in the standard-deviation pass guarantees it sees exactly the data (and the
        // Hz/log-Hz domain) that the mean was computed from, without re-reading any files.
        double[][] voicedsPerPair = new double[numPairs][];
        int globalCount = 0;
        int tiltCount = 0;
        global.init();
        IndexMap imap = new IndexMap();
        for (int i = 0; i < numPairs; i++) {
            System.out.println("Pitch mapping for pair " + (i + 1) + " of " + numPairs + ":");
            try {
                imap.readFromFile(fcol.indexMapFiles[i]);
            } catch (IOException e) {
                // After a failed read, imap still holds the previous pair's contents; skip this pair
                // rather than treating the stale map as valid.
                e.printStackTrace();
                continue;
            }
            if (imap.files == null || trainingSet.items.length <= i) {
                continue;
            }
            local.init();
            PitchReaderWriter f0s = readPitch(trainingSet, i, isSource, map);
            double[] voiceds = f0s.getVoiceds();
            local.range = SignalProcUtils.getF0Range(voiceds);
            if (inLogHertz) {
                f0s.contour = SignalProcUtils.getLogF0s(f0s.contour);
                if (voiceds != null) {
                    voiceds = SignalProcUtils.getLogF0s(voiceds);
                }
                local.range = Math.log(local.range);
            }
            // An empty array would make the local mean 0/0 = NaN; treat it like "no voiced frames".
            if (voiceds != null && voiceds.length > 0) {
                double sum = MathUtils.sum(voiceds);
                global.mean += sum;
                globalCount += voiceds.length;
                voicedsPerPair[i] = voiceds;
                local.mean = sum / voiceds.length;
                local.standardDeviation = MathUtils.standardDeviation(voiceds, local.mean);
                // The first contributing item seeds the global range; after that keep the maximum.
                if (tiltCount == 0 || local.range > global.range) {
                    global.range = local.range;
                }
                double[] contourInt = SignalProcUtils.interpolate_pitch_uv(f0s.contour);
                double[] line = SignalProcUtils.getContourLSFit(contourInt, false);
                local.intercept = line[0];
                local.slope = line[1];
                global.intercept += local.intercept;
                global.slope += local.slope;
                ++tiltCount;
            }
            pitchMappingFile.writeF0StatisticsEntry(local);
        }
        global.mean = globalCount > 0 ? global.mean / globalCount : 0.0;
        if (tiltCount > 0) {
            global.intercept /= tiltCount;
            global.slope /= tiltCount;
        }
        System.out.println("Computing global pitch standard deviations...");
        double sumSquaredDeviations = 0.0;
        for (double[] voiceds : voicedsPerPair) {
            if (voiceds != null) {
                sumSquaredDeviations += MathUtils.sumSquared(voiceds, -1.0 * global.mean);
            }
        }
        global.standardDeviation = globalCount > 1 ? Math.sqrt(sumSquaredDeviations / (globalCount - 1)) : 1.0;
        pitchMappingFile.writeF0StatisticsEntry(global);
    }

    /**
     * Opens the pitch file for pair {@code i}. This is item {@code i} on the source side and item
     * {@code map[i]} on the target side.
     */
    private static PitchReaderWriter readPitch(BaselineAdaptationSet trainingSet, int i, boolean isSource, int[] map) {
        int itemIndex = isSource ? i : map[i];
        return new PitchReaderWriter(trainingSet.items[itemIndex].pitchFile);
    }
}

