/*
 * Decompiled with CFR 0.152.
 */
package marytts.machinelearning;

import java.io.IOException;
import marytts.machinelearning.GaussianComponent;
import marytts.machinelearning.KMeansClusteringTrainer;
import marytts.util.io.MaryRandomAccessFile;
import marytts.util.math.MathUtils;

/**
 * A Gaussian mixture model: a set of {@link GaussianComponent}s combined with per-component
 * mixture weights, plus binary (de)serialization through {@link MaryRandomAccessFile}.
 */
public class GMM {
    /** Mixture weight of each component; {@code init} sets them uniformly to 1/totalComponents. */
    public double[] weights;
    /** The mixture components; {@code null} when the model has no components. */
    public GaussianComponent[] components;
    /** Free-form description string, persisted by {@link #write(MaryRandomAccessFile)}. */
    public String info;
    /** Dimension of the feature vectors the components are defined over. */
    public int featureDimension;
    /** Number of mixture components. */
    public int totalComponents;
    /** Whether the components use diagonal covariance matrices. */
    public boolean isDiagonalCovariance;

    /** Creates an empty model (no components, dimension 0, diagonal covariance). */
    public GMM() {
        this(0, 0);
    }

    /**
     * Creates a model with {@code totalMixturesIn} default-constructed components, uniform
     * weights and diagonal covariance.
     *
     * @param featureDimensionIn feature vector dimension
     * @param totalMixturesIn number of components; values {@code <= 0} produce an empty model
     */
    public GMM(int featureDimensionIn, int totalMixturesIn) {
        this.init(featureDimensionIn, totalMixturesIn, true);
    }

    /**
     * Creates a model with {@code totalComponentsIn} default-constructed components and uniform
     * weights.
     *
     * @param featureDimensionIn feature vector dimension
     * @param totalComponentsIn number of components; values {@code <= 0} produce an empty model
     * @param isDiagonalCovarIn whether components use diagonal covariance
     */
    public GMM(int featureDimensionIn, int totalComponentsIn, boolean isDiagonalCovarIn) {
        this.init(featureDimensionIn, totalComponentsIn, isDiagonalCovarIn);
    }

    /**
     * Creates a model with one component per k-means cluster. Weights are left uniform, as set
     * by {@link #init}.
     *
     * @param kmeansClusterer a trainer whose clusters seed the components
     */
    public GMM(KMeansClusteringTrainer kmeansClusterer) {
        this.init(kmeansClusterer.getFeatureDimension(), kmeansClusterer.getTotalClusters(),
                kmeansClusterer.isDiagonalCovariance());
        for (int i = 0; i < kmeansClusterer.getTotalClusters(); ++i) {
            this.components[i] = new GaussianComponent(kmeansClusterer.clusters[i]);
        }
    }

    /**
     * Copy constructor. Components are copied through {@link GaussianComponent}'s copy
     * constructor and the weight array is cloned. The {@code info} string is shared (strings
     * are immutable).
     *
     * @param existing the model to copy
     */
    public GMM(GMM existing) {
        this.featureDimension = existing.featureDimension;
        this.totalComponents = existing.totalComponents;
        this.isDiagonalCovariance = existing.isDiagonalCovariance;
        if (existing.totalComponents > 0 && existing.components != null) {
            this.components = new GaussianComponent[this.totalComponents];
            for (int i = 0; i < this.totalComponents; ++i) {
                this.components[i] = new GaussianComponent(existing.components[i]);
            }
        } else {
            this.components = null;
            this.totalComponents = 0;
        }
        this.weights = existing.weights != null ? existing.weights.clone() : null;
        this.info = existing.info;
    }

    /**
     * Loads a model from a file written by {@link #write(String)}.
     *
     * <p>Errors are reported (stack trace) but not propagated. On failure the model is reset to
     * a consistent empty state instead of being left half-read.
     *
     * @param gmmFile path of the model file
     */
    public GMM(String gmmFile) {
        try {
            this.read(gmmFile);
        }
        catch (IOException e) {
            e.printStackTrace();
            // A partially read model (e.g. components array with null slots) would fail later
            // in probability(); fall back to an empty but internally consistent model.
            this.init(0, 0, true);
        }
    }

    /**
     * (Re)initialises the model with {@code totalMixturesIn} default components and uniform
     * weights. A non-positive component count produces an empty model with {@code null} arrays.
     * A negative feature dimension is clamped to 0 only in that empty case. {@code info} is
     * reset to the empty string.
     *
     * @param featureDimensionIn feature vector dimension
     * @param totalMixturesIn number of components
     * @param isDiagonalCovarIn whether components use diagonal covariance
     */
    public void init(int featureDimensionIn, int totalMixturesIn, boolean isDiagonalCovarIn) {
        this.featureDimension = featureDimensionIn;
        this.totalComponents = totalMixturesIn;
        this.isDiagonalCovariance = isDiagonalCovarIn;
        if (this.totalComponents > 0) {
            this.components = new GaussianComponent[this.totalComponents];
            this.weights = new double[this.totalComponents];
            for (int i = 0; i < this.totalComponents; ++i) {
                this.components[i] = new GaussianComponent(featureDimensionIn, isDiagonalCovarIn);
                this.weights[i] = 1.0 / (double) this.totalComponents;
            }
        } else {
            this.components = null;
            this.weights = null;
            this.totalComponents = 0;
            if (this.featureDimension < 0) {
                this.featureDimension = 0;
            }
        }
        this.info = "";
    }

    /**
     * Evaluates the mixture at {@code x}: the sum over components of
     * {@code weights[i] * components[i].probability(x)}. Returns 0 for an empty model.
     *
     * @param x feature vector
     * @return weighted sum of component probabilities
     */
    public double probability(double[] x) {
        double score = 0.0;
        for (int i = 0; i < this.totalComponents; ++i) {
            score += this.weights[i] * this.components[i].probability(x);
        }
        return score;
    }

    /**
     * Computes each component's weighted likelihood for {@code x}, normalized so the entries
     * sum to one. The diagonal and full-covariance cases use different
     * {@link MathUtils#getGaussianPdfValue} overloads.
     *
     * <p>If every weighted likelihood is 0 (e.g. numerical underflow far from all components),
     * the normalization is undefined. In that case the all-zero array is returned instead of
     * an array of NaNs.
     *
     * @param x feature vector
     * @return normalized per-component probabilities (length {@code totalComponents})
     */
    public double[] componentProbabilities(double[] x) {
        double[] probs = new double[this.totalComponents];
        double totalProb = 0.0;
        if (this.isDiagonalCovariance) {
            for (int i = 0; i < this.totalComponents; ++i) {
                probs[i] = this.weights[i] * MathUtils.getGaussianPdfValue(x, this.components[i].meanVector,
                        this.components[i].covMatrix[0], this.components[i].getConstantTerm());
                totalProb += probs[i];
            }
        } else {
            for (int i = 0; i < this.totalComponents; ++i) {
                probs[i] = this.weights[i] * MathUtils.getGaussianPdfValue(x, this.components[i].meanVector,
                        this.components[i].getDetCovMatrix(), this.components[i].getInvCovMatrix());
                totalProb += probs[i];
            }
        }
        // Dividing by zero would turn every entry into NaN and poison downstream accumulations.
        if (totalProb > 0.0) {
            for (int i = 0; i < this.totalComponents; ++i) {
                probs[i] /= totalProb;
            }
        }
        return probs;
    }

    /**
     * Writes the model to {@code gmmFile}, always closing the file, even on error.
     *
     * @param gmmFile destination path
     * @throws IOException if opening or writing the file fails
     */
    public void write(String gmmFile) throws IOException {
        MaryRandomAccessFile stream = new MaryRandomAccessFile(gmmFile, "rw");
        try {
            this.write(stream);
        } finally {
            stream.close();
        }
    }

    /**
     * Serializes the model in this order:
     * <ol>
     * <li>featureDimension (int)</li>
     * <li>totalComponents (int)</li>
     * <li>isDiagonalCovariance (boolean)</li>
     * <li>info length (int), followed by the info characters if the length is non-zero</li>
     * <li>{@code totalComponents} weights (double)</li>
     * <li>each component via {@link GaussianComponent#write}</li>
     * </ol>
     * All values are written with the stream's {@code *Endian} methods.
     *
     * @param stream open output stream
     * @throws IOException if writing fails
     */
    public void write(MaryRandomAccessFile stream) throws IOException {
        stream.writeIntEndian(this.featureDimension);
        stream.writeIntEndian(this.totalComponents);
        stream.writeBooleanEndian(this.isDiagonalCovariance);
        if (this.info != null && this.info.length() > 0) {
            stream.writeIntEndian(this.info.length());
            stream.writeCharEndian(this.info.toCharArray());
        } else {
            stream.writeIntEndian(0);
        }
        // read() consumes exactly totalComponents weights. Skip the write for an empty model,
        // where weights may be null or hold stale values the reader would never consume.
        if (this.totalComponents > 0) {
            stream.writeDoubleEndian(this.weights);
        }
        for (int i = 0; i < this.totalComponents; ++i) {
            this.components[i].write(stream);
        }
    }

    /**
     * Reads the model from {@code gmmFile}, always closing the file, even on error.
     *
     * @param gmmFile source path
     * @throws IOException if opening or reading the file fails or its contents are invalid
     */
    public void read(String gmmFile) throws IOException {
        MaryRandomAccessFile stream = new MaryRandomAccessFile(gmmFile, "r");
        try {
            this.read(stream);
        } finally {
            stream.close();
        }
    }

    /**
     * Deserializes a model in the layout produced by {@link #write(MaryRandomAccessFile)},
     * replacing all fields of this object.
     *
     * @param stream open input stream
     * @throws IOException if reading fails or the header holds a negative dimension or
     *     component count
     */
    public void read(MaryRandomAccessFile stream) throws IOException {
        this.featureDimension = stream.readIntEndian();
        this.totalComponents = stream.readIntEndian();
        if (this.featureDimension < 0 || this.totalComponents < 0) {
            throw new IOException("Invalid GMM header: featureDimension=" + this.featureDimension
                    + ", totalComponents=" + this.totalComponents);
        }
        this.isDiagonalCovariance = stream.readBooleanEndian();
        int tmpLen = stream.readIntEndian();
        // Match init(): a model without info carries "" rather than a stale or null value.
        this.info = tmpLen > 0 ? String.copyValueOf(stream.readCharEndian(tmpLen)) : "";
        this.weights = stream.readDoubleEndian(this.totalComponents);
        this.components = new GaussianComponent[this.totalComponents];
        for (int i = 0; i < this.totalComponents; ++i) {
            this.components[i] = new GaussianComponent();
            this.components[i].read(stream);
        }
    }
}

