/*
 * Decompiled with CFR 0.152.
 */
package marytts.machinelearning;

import java.awt.Color;
import java.util.Arrays;
import javax.swing.JFrame;
import marytts.machinelearning.KMeansClusteringTrainerParams;
import marytts.machinelearning.PolynomialCluster;
import marytts.signalproc.display.FunctionGraph;
import marytts.util.math.MathUtils;
import marytts.util.math.Polynomial;

/**
 * K-means clustering of {@link Polynomial} observations, using
 * {@link Polynomial#polynomialDistance(Polynomial)} as the distance measure.
 */
public class PolynomialKMeansClusteringTrainer {

    /**
     * Clusters the given polynomials with k-means.
     *
     * <p>Initial means are chosen greedily: the k-th mean is a copy of the observation whose
     * average distance to the means already chosen and to the global mean is largest. Each
     * iteration then assigns every observation to its nearest mean and recomputes the means.
     * A cluster that is empty or has fewer than {@code minSamplesInOneCluster} members is
     * re-seeded with a copy of one of the largest clusters' means, with each coefficient
     * perturbed by up to +/-1%. Convergence is checked from the second iteration on (so at
     * least two iterations always run): training stops once the iteration count reaches
     * {@code maxIterations} or fewer than {@code minClusterChangePercent} percent of the
     * observations changed cluster.
     *
     * @param polynomials observations to cluster; all are assumed to have at least as many
     *     coefficients as the order of the first one implies (only that many are read)
     * @param kmeansParams clustering parameters (number of clusters, minimum cluster size,
     *     iteration limit, convergence threshold)
     * @return {@code kmeansParams.numClusters} clusters, each holding its mean and the
     *     observations assigned to it in the final iteration
     * @throws IllegalArgumentException if {@code polynomials} is empty or fewer than one
     *     cluster is requested
     */
    public static PolynomialCluster[] train(Polynomial[] polynomials, KMeansClusteringTrainerParams kmeansParams) {
        int numClusters = kmeansParams.numClusters;
        if (polynomials.length == 0) {
            throw new IllegalArgumentException("Need at least one polynomial to cluster");
        }
        if (numClusters < 1) {
            throw new IllegalArgumentException("Need at least one cluster, got " + numClusters);
        }
        int observations = polynomials.length;
        int polynomialOrder = polynomials[0].getOrder();

        Polynomial[] clusterMeans = initializeMeans(polynomials, numClusters, polynomialOrder);

        int[] assignment = new int[observations];
        int[] previousAssignment = new int[observations];
        int[] clusterSizes = new int[numClusters];
        int iteration = 0;
        boolean keepGoing = true;
        while (keepGoing) {
            // Assign all observations against the current means before any mean is updated.
            for (int t = 0; t < observations; ++t) {
                assignment[t] = nearestCluster(clusterMeans, polynomials[t]);
            }
            updateMeans(polynomials, assignment, clusterMeans, clusterSizes, polynomialOrder, kmeansParams);

            // There is no previous assignment to compare against on the first iteration.
            if (++iteration > 1) {
                if (iteration >= kmeansParams.maxIterations) {
                    keepGoing = false;
                }
                int changed = 0;
                for (int t = 0; t < observations; ++t) {
                    if (assignment[t] != previousAssignment[t]) {
                        ++changed;
                    }
                }
                double changedPerc = (double) changed / (double) observations * 100.0;
                if (changedPerc < kmeansParams.minClusterChangePercent) {
                    keepGoing = false;
                }
            }
            System.arraycopy(assignment, 0, previousAssignment, 0, observations);
        }

        return buildClusters(polynomials, assignment, clusterMeans, clusterSizes);
    }

    /** Chooses initial cluster means by greedy farthest-observation selection. */
    private static Polynomial[] initializeMeans(Polynomial[] polynomials, int numClusters, int polynomialOrder) {
        int observations = polynomials.length;
        Polynomial globalMean = Polynomial.mean(polynomials);
        Polynomial[] means = new Polynomial[numClusters];
        double[] dists = new double[observations];
        // Holds distances to the k means chosen so far plus the distance to the global mean.
        double[] tmp = new double[numClusters + 1];
        for (int k = 0; k < numClusters; ++k) {
            for (int t = 0; t < observations; ++t) {
                if (k == 0) {
                    dists[t] = globalMean.polynomialDistance(polynomials[t]);
                } else {
                    for (int i = 0; i < k; ++i) {
                        tmp[i] = means[i].polynomialDistance(polynomials[t]);
                    }
                    tmp[k] = globalMean.polynomialDistance(polynomials[t]);
                    dists[t] = MathUtils.mean(tmp, 0, k);
                }
            }
            // Use the first observation as the baseline instead of a sentinel such as
            // Double.MIN_VALUE (the smallest *positive* double): with that sentinel, all-zero
            // distances left no candidate and caused an out-of-bounds index.
            int farthest = 0;
            for (int t = 1; t < observations; ++t) {
                if (dists[t] > dists[farthest]) {
                    farthest = t;
                }
            }
            means[k] = new Polynomial(polynomialOrder);
            means[k].copyCoeffs(polynomials[farthest]);
        }
        return means;
    }

    /** Returns the index of the first mean with the smallest distance to {@code p}. */
    private static int nearestCluster(Polynomial[] clusterMeans, Polynomial p) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < clusterMeans.length; ++i) {
            double dist = clusterMeans[i].polynomialDistance(p);
            // The first cluster is always a candidate, so a valid index is returned even
            // when distances are infinite.
            if (best < 0 || dist < bestDist) {
                best = i;
                bestDist = dist;
            }
        }
        return best;
    }

    /**
     * Recomputes {@code clusterMeans} in place from {@code assignment} and fills
     * {@code clusterSizes}. Undersized or empty clusters are re-seeded from the largest clusters.
     */
    private static void updateMeans(Polynomial[] polynomials, int[] assignment, Polynomial[] clusterMeans,
            int[] clusterSizes, int polynomialOrder, KMeansClusteringTrainerParams kmeansParams) {
        int numClusters = clusterMeans.length;
        Polynomial[] sums = new Polynomial[numClusters];
        for (int i = 0; i < numClusters; ++i) {
            sums[i] = new Polynomial(polynomialOrder);
        }
        Arrays.fill(clusterSizes, 0);
        for (int t = 0; t < polynomials.length; ++t) {
            int cluster = assignment[t];
            for (int d = 0; d <= polynomialOrder; ++d) {
                sums[cluster].coeffs[d] += polynomials[t].coeffs[d];
            }
            ++clusterSizes[cluster];
        }

        double[] sizes = new double[numClusters];
        for (int i = 0; i < numClusters; ++i) {
            sizes[i] = clusterSizes[i];
        }
        // Cluster indices ordered by size; the last entries are the largest clusters.
        int[] bySize = MathUtils.quickSort(sizes, 0, numClusters - 1);

        int reseeded = 0;
        for (int i = 0; i < numClusters; ++i) {
            // An empty cluster is always re-seeded; otherwise 0/0 would make its mean NaN.
            if (clusterSizes[i] > 0 && clusterSizes[i] >= kmeansParams.minSamplesInOneCluster) {
                for (int d = 0; d <= polynomialOrder; ++d) {
                    clusterMeans[i].coeffs[d] = sums[i].coeffs[d] / (double) clusterSizes[i];
                }
                continue;
            }
            // NOTE: means are updated in place in index order, so the donor's mean may already
            // be this iteration's value (donor index < i) or still the previous one.
            Polynomial donor = clusterMeans[bySize[numClusters - reseeded - 1]];
            for (int d = 0; d <= polynomialOrder; ++d) {
                double rnd = 2.0 * (Math.random() - 0.5) * donor.coeffs[d] * 0.01;
                clusterMeans[i].coeffs[d] = donor.coeffs[d] + rnd;
            }
            ++reseeded;
        }
    }

    /** Packs each cluster's mean and its assigned observations, in input order, into a PolynomialCluster. */
    private static PolynomialCluster[] buildClusters(Polynomial[] polynomials, int[] assignment,
            Polynomial[] clusterMeans, int[] clusterSizes) {
        PolynomialCluster[] clusters = new PolynomialCluster[clusterMeans.length];
        for (int i = 0; i < clusterMeans.length; ++i) {
            Polynomial[] members = new Polynomial[clusterSizes[i]];
            int m = 0;
            for (int t = 0; t < polynomials.length; ++t) {
                if (assignment[t] == i) {
                    members[m++] = polynomials[t];
                }
            }
            assert m == members.length;
            clusters[i] = new PolynomialCluster(clusterMeans[i], members);
        }
        return clusters;
    }

    /**
     * Demo: clusters random cubic polynomials and shows each cluster's mean (blue) and members
     * (gray) in a window, one cluster every 500 ms, then exits.
     */
    public static void main(String[] args) {
        int order = 3;
        int numPolynomials = 1000;
        int numClusters = 50;
        Polynomial[] ps = new Polynomial[numPolynomials];
        for (int i = 0; i < numPolynomials; ++i) {
            double[] coeffs = new double[order + 1];
            for (int c = 0; c < coeffs.length; ++c) {
                coeffs[c] = Math.random();
            }
            ps[i] = new Polynomial(coeffs);
        }
        KMeansClusteringTrainerParams params = new KMeansClusteringTrainerParams();
        params.numClusters = numClusters;
        PolynomialCluster[] clusters = PolynomialKMeansClusteringTrainer.train(ps, params);
        FunctionGraph clusterGraph = new FunctionGraph(0.0, 1.0, new double[1]);
        clusterGraph.setYMinMax(0.0, 5.0);
        clusterGraph.setPrimaryDataSeriesStyle(Color.BLUE, 2, 1);
        JFrame jf = clusterGraph.showInJFrame("", false, true);
        for (int i = 0; i < clusters.length; ++i) {
            double[] meanValues = clusters[i].getMeanPolynomial().generatePolynomialValues(100, 0.0, 1.0);
            clusterGraph.updateData(0.0, 1.0 / (double) meanValues.length, meanValues);
            Polynomial[] members = clusters[i].getClusterMembers();
            // NOTE(review): member series are added but never removed here, so earlier clusters'
            // members presumably stay on screen - confirm against FunctionGraph.
            for (int m = 0; m < members.length; ++m) {
                double[] pred = members[m].generatePolynomialValues(meanValues.length, 0.0, 1.0);
                clusterGraph.addDataSeries(pred, Color.GRAY, 1, -1);
                jf.repaint();
            }
            jf.setTitle("Cluster " + (i + 1) + " of " + clusters.length + ": " + members.length + " members");
            jf.repaint();
            try {
                Thread.sleep(500L);
            } catch (InterruptedException ie) {
                // Preserve the interrupt for callers and stop the animation instead of ignoring it.
                Thread.currentThread().interrupt();
                break;
            }
        }
        System.exit(0);
    }
}

