/*
 * Decompiled with CFR 0.152.
 */
package marytts.htsengine;

import java.util.Arrays;
import marytts.htsengine.HMMData;
import marytts.htsengine.HTSParameterGeneration;
import marytts.util.MaryUtils;
import org.apache.log4j.Logger;

/**
 * Parameter stream used for HTS-style parameter generation.
 *
 * <p>For an utterance of {@code nT} frames this object holds, per frame, a mean vector
 * ({@code mseq}) and an inverse-variance vector ({@code ivseq}) of length {@code vSize}. Each
 * vector is made of {@link #NUM} consecutive parts of {@code order = vSize / NUM} coefficients,
 * one per window in {@link #xcoefs}: static {0, 1, 0}, first difference {-0.5, 0, 0.5} and
 * second difference {1, -2, 1}.
 *
 * <p>{@link #mlpg(HMMData, boolean)} fills the static trajectory {@code par} by building the
 * banded system {@code (W'UW) c = W'UM} for each static dimension and solving it with an LDL
 * factorisation. Optionally it then refines the trajectory so that its global variance (GV)
 * moves towards the target set with {@link #setGvMeanVar(double[], double[])}.
 *
 * <p>Intermediate results are kept in mutable fields without any synchronisation.
 */
public class HTSPStream {
    /** Public constant kept for callers; not used inside this class (presumably a window-boundary index). */
    public static final int WLEFT = 0;
    /** Public constant kept for callers; not used inside this class (presumably a window-boundary index). */
    public static final int WRIGHT = 1;
    /** Number of windows (static, delta, delta-delta); also the divisor giving {@code order}. */
    public static final int NUM = 3;
    /** Number of stored diagonals per row of the banded matrix W'UW (main diagonal + 2 upper). */
    private static final int WIDTH = 3;
    /** Number of coefficients stored per window in {@link #xcoefs}; window i is centred at index 1 + 3*i. */
    private static final int COEFS_PER_WINDOW = 3;

    public final HMMData.FeatureType feaType;
    private final int vSize;
    private final int order;
    private final int nT;
    /** Generated static parameters, [frame][static dimension]. */
    private final double[][] par;
    /** Per-frame mean vectors, [frame][vSize]. */
    private final double[][] mseq;
    /** Per-frame inverse variances, [frame][vSize]. */
    private final double[][] ivseq;
    /** Work vector: right-hand side / intermediate solution, later the GV update direction. */
    private final double[] g;
    /** Band storage of W'UW: wuw[t][0] is the diagonal, wuw[t][k] the entry (t, t + k). */
    private final double[][] wuw;
    /** W'UM for the current static dimension. */
    private final double[] wum;

    static final int[] leftWidths = new int[]{0, -1, -1};
    static final int[] rightWidths = new int[]{0, 1, 1};
    static final double[] xcoefs = new double[]{0.0, 1.0, 0.0, -0.5, 0.0, 0.5, 1.0, -2.0, 1.0};

    /** Mean and variance of the current dimension over GV-enabled frames (see calcGV). */
    private double mean;
    private double var;
    private final int maxGVIter;

    private static final double GV_EPSILON = 1.0E-4;
    private static final double MIN_EUC_NORM = 0.01;
    private static final double STEP_INIT = 0.1;
    private static final double STEP_DEC = 0.5;
    private static final double STEP_INC = 1.2;
    /** Weight of the HMM likelihood term in the GV objective. */
    private static final double W1 = 1.0;
    /** Weight of the GV term in the GV objective. */
    private static final double W2 = 1.0;
    private static final double LZERO = -1.0E10;
    /** Maximum number of consecutive step-size reductions before the gradient search gives up. */
    private static final int MAX_NUM_DOWN = 100;

    private double norm = 0.0;
    private double GVobj = 0.0;
    private double HMMobj = 0.0;
    private double[] gvmean;
    private double[] gvcovInv;
    private final boolean[] gvSwitch;
    /** Number of frames whose gvSwitch entry is true; used as the divisor in calcGV. */
    private int gvLength;
    private final Logger logger = MaryUtils.getLogger("PStream");

    /**
     * Returns the left width (smallest frame offset) of window {@code i}.
     *
     * @param i window index in [0, NUM)
     * @return value of {@code leftWidths[i]}
     */
    public int getDWLeftBoundary(int i) {
        return leftWidths[i];
    }

    /**
     * Returns the right width (largest frame offset) of window {@code i}.
     *
     * @param i window index in [0, NUM)
     * @return value of {@code rightWidths[i]}
     */
    public int getDWRightBoundary(int i) {
        return rightWidths[i];
    }

    /**
     * Allocates all per-frame buffers and enables GV on every frame.
     *
     * @param vector_size length of the mean / inverse-variance vectors (static + dynamic parts)
     * @param utt_length number of frames
     * @param fea_type feature type, stored as {@link #feaType}
     * @param maxIterationsGV maximum number of GV optimisation iterations
     * @throws Exception never thrown here; the clause is kept so existing callers keep compiling
     */
    public HTSPStream(int vector_size, int utt_length, HMMData.FeatureType fea_type, int maxIterationsGV) throws Exception {
        this.feaType = fea_type;
        this.vSize = vector_size;
        this.order = vector_size / NUM;
        this.nT = utt_length;
        this.maxGVIter = maxIterationsGV;
        this.par = new double[this.nT][this.order];
        this.mseq = new double[this.nT][this.vSize];
        this.ivseq = new double[this.nT][this.vSize];
        this.g = new double[this.nT];
        this.wuw = new double[this.nT][WIDTH];
        this.wum = new double[this.nT];
        this.gvSwitch = new boolean[this.nT];
        Arrays.fill(this.gvSwitch, true);
        this.gvLength = this.nT;
    }

    /** @return length of the mean / inverse-variance vectors */
    public int getVsize() {
        return this.vSize;
    }

    /** @return number of static dimensions ({@code vSize / NUM}) */
    public int getOrder() {
        return this.order;
    }

    /** Sets the generated parameter of frame {@code i}, static dimension {@code j}. */
    public void setPar(int i, int j, double val) {
        this.par[i][j] = val;
    }

    /** @return the generated parameter of frame {@code i}, static dimension {@code j} */
    public double getPar(int i, int j) {
        return this.par[i][j];
    }

    /** @return a copy of the static parameter vector of frame {@code i} */
    public double[] getParVec(int i) {
        return this.par[i].clone();
    }

    /** @return number of frames */
    public int getT() {
        return this.nT;
    }

    /** Sets one mean coefficient of frame {@code i}. */
    public void setMseq(int i, int j, double val) {
        this.mseq[i][j] = val;
    }

    /**
     * Replaces the mean vector of frame {@code i}.
     *
     * <p>The array is stored by reference (not copied), so later changes by the caller are
     * visible to this stream.
     */
    public void setMseq(int i, double[] vec) {
        this.mseq[i] = vec;
    }

    /**
     * Sets the inverse-variance vector of frame {@code i} from the variances {@code vec}.
     *
     * <p>Each entry is converted with {@code HTSParameterGeneration.finv} (presumably a guarded
     * reciprocal; TODO confirm).
     */
    public void setVseq(int i, double[] vec) {
        assert vec.length == this.ivseq[i].length;
        final double[] row = this.ivseq[i];
        for (int j = 0; j < row.length; ++j) {
            row[j] = HTSParameterGeneration.finv(vec[j]);
        }
    }

    /** Sets one inverse-variance coefficient of frame {@code i}. */
    public void setIvseq(int i, int j, double val) {
        this.ivseq[i][j] = val;
    }

    /**
     * Sets the GV target used by the GV refinement; arrays are stored by reference.
     *
     * @param mean GV mean per static dimension
     * @param ivar GV inverse variance per static dimension
     */
    public void setGvMeanVar(double[] mean, double[] ivar) {
        this.gvmean = mean;
        this.gvcovInv = ivar;
    }

    /**
     * Enables or disables GV for frame {@code i}.
     *
     * <p>The count of enabled frames is updated only when the flag actually changes, so repeated
     * calls with the same value and re-enabling a frame keep it consistent.
     */
    public void setGvSwitch(int i, boolean bv) {
        if (this.gvSwitch[i] != bv) {
            this.gvLength += bv ? 1 : -1;
        }
        this.gvSwitch[i] = bv;
    }

    /**
     * Sets the inverse variance to 0 for coefficients 1 .. vSize-1 of the first and last frame.
     * Does nothing for an empty utterance.
     */
    public void fixDynFeatOnBoundaries() {
        if (this.nT == 0) {
            return;
        }
        // NOTE(review): the loop starts at 1, so for order > 1 it also clears static coefficients
        // 1..order-1; if only dynamic parts were meant, it should start at `order`. TODO confirm.
        for (int k = 1; k < this.vSize; ++k) {
            this.setIvseq(0, k, 0.0);
            this.setIvseq(this.nT - 1, k, 0.0);
        }
    }

    /** Runs {@link #mlpg(HMMData, boolean)} with {@code htsData.getUseGV()}. */
    public void mlpg(HMMData htsData) {
        this.mlpg(htsData, htsData.getUseGV());
    }

    /**
     * Generates the static trajectory for every static dimension.
     *
     * <p>For each dimension m, builds W'UW and W'UM, solves the banded system by LDL
     * factorisation and forward/backward substitution, and, when {@code useGV} is set and at least
     * one frame has GV enabled, refines the result with the gradient or derivative GV method
     * selected by {@code htsData.getGvMethodGradient()}.
     */
    public void mlpg(HMMData htsData, boolean useGV) {
        if (useGV) {
            if (htsData.getUseContextDependentGV()) {
                this.logger.info("Context-dependent global variance optimization: gvLength = " + this.gvLength);
            } else {
                this.logger.info("Global variance optimization");
            }
        }
        final boolean applyGV = useGV && this.gvLength > 0;
        for (int m = 0; m < this.order; ++m) {
            this.calcWUWandWUM(m);
            // The factorisation works in place, so factor a copy and keep wuw intact.
            double[][] factor = new double[this.nT][];
            for (int t = 0; t < this.nT; ++t) {
                factor[t] = this.wuw[t].clone();
            }
            ldlFactorization(factor);
            this.forwardSubstitution(this.wum, factor);
            this.backwardSubstitution(m, factor);
            if (!applyGV) {
                continue;
            }
            if (htsData.getGvMethodGradient()) {
                this.gvParmGenGradient(m, false);
            } else {
                this.gvParmGenDerivative(m, false);
            }
        }
    }

    /** Fills {@code wuw} (band storage of W'UW) and {@code wum} (W'UM) for static dimension m. */
    private void calcWUWandWUM(int m) {
        Arrays.fill(this.wum, 0, this.nT, 0.0);
        for (int t = 0; t < this.nT; ++t) {
            final double[] wuwRow = this.wuw[t];
            Arrays.fill(wuwRow, 0.0);
            for (int i = 0; i < NUM; ++i) {
                final int right = rightWidths[i];
                final int iorder = i * this.order + m;
                final int center = 1 + i * COEFS_PER_WINDOW;
                for (int j = leftWidths[i]; j <= right; ++j) {
                    final int tj = t + j;
                    if (tj < 0 || tj >= this.nT) {
                        continue;
                    }
                    final double coefJ = xcoefs[center - j];
                    if (coefJ == 0.0) {
                        continue;
                    }
                    final double wu = coefJ * this.ivseq[tj][iorder];
                    this.wum[t] += wu * this.mseq[tj][iorder];
                    for (int k = 0; k < WIDTH && t + k < this.nT; ++k) {
                        if (k - j > right) {
                            continue;
                        }
                        final double coefKJ = xcoefs[center + k - j];
                        if (coefKJ == 0.0) {
                            continue;
                        }
                        wuwRow[k] += wu * coefKJ;
                    }
                }
            }
        }
    }

    /**
     * In-place LDL factorisation of a banded matrix in band storage: afterwards row[0] holds D
     * and row[1..WIDTH-1] the off-diagonal entries of the unit triangular factor.
     */
    private static void ldlFactorization(double[][] mywuw) {
        for (int t = 0; t < mywuw.length; ++t) {
            final double[] row = mywuw[t];
            for (int i = 1; i < WIDTH && t - i >= 0; ++i) {
                final double[] prevRow = mywuw[t - i];
                row[0] -= prevRow[i] * prevRow[i] * prevRow[0];
            }
            for (int i = 2; i <= WIDTH; ++i) {
                for (int j = 1; i + j <= WIDTH && t - j >= 0; ++j) {
                    final double[] prevRow = mywuw[t - j];
                    row[i - 1] -= prevRow[j] * prevRow[i + j - 1] * prevRow[0];
                }
                row[i - 1] /= row[0];
            }
        }
    }

    /** Solves L y = mywum into {@code g} using the factor from {@link #ldlFactorization}. */
    private void forwardSubstitution(double[] mywum, double[][] mywuw) {
        System.arraycopy(mywum, 0, this.g, 0, mywum.length);
        for (int t = 0; t < this.nT; ++t) {
            for (int i = 1; i < WIDTH && t - i >= 0; ++i) {
                this.g[t] -= mywuw[t - i][i] * this.g[t - i];
            }
        }
    }

    /** Solves D L' c = g and writes c into column m of {@code par}. */
    private void backwardSubstitution(int m, double[][] mywuw) {
        for (int t = this.nT - 1; t >= 0; --t) {
            double value = this.g[t] / mywuw[t][0];
            for (int i = 1; i < WIDTH && t + i < this.nT; ++i) {
                value -= mywuw[t][i] * this.par[t + i][m];
            }
            this.par[t][m] = value;
        }
    }

    /**
     * GV refinement using {@link #calcDerivative}: runs exactly {@code maxGVIter} update steps,
     * shrinking the step when the returned objective grew and enlarging it when it shrank.
     * The {@code debug} flag is currently unused.
     */
    private void gvParmGenDerivative(int m, boolean debug) {
        double step = STEP_INIT;
        double prev = -LZERO;
        Arrays.fill(this.g, 0, this.nT, 0.0);
        this.convGV(m);
        this.calcWUWandWUM(m);
        int iter;
        for (iter = 1; iter <= this.maxGVIter; ++iter) {
            final double obj = this.calcDerivative(m);
            if (obj > prev) {
                step *= STEP_DEC;
            } else if (obj < prev) {
                step *= STEP_INC;
            }
            for (int t = 0; t < this.nT; ++t) {
                this.par[t][m] += step * this.g[t];
            }
            prev = obj;
        }
        this.logger.info("Derivative GV optimization for feature: (" + m + ")  number of iterations=" + (iter - 1));
    }

    /**
     * GV refinement using {@link #calcGradient}.
     *
     * <p>The step grows when the objective increases. When it decreases, the last step is undone,
     * the step is halved and the iteration retried (at most {@link #MAX_NUM_DOWN} times in a row).
     * The search stops when the update norm or the objective change becomes small. If the
     * iteration limit is reached instead, column m of {@code par} is restored to its value
     * before the refinement.
     */
    private void gvParmGenGradient(int m, boolean debug) {
        double step = STEP_INIT;
        double prev = 0.0;
        final double[] lastDir = new double[this.nT];
        final double[] parOri = new double[this.nT];
        int numDown = 0;
        for (int t = 0; t < this.nT; ++t) {
            this.g[t] = 0.0;
            parOri[t] = this.par[t][m];
        }
        this.convGV(m);
        this.calcWUWandWUM(m);
        int iter;
        for (iter = 1; iter <= this.maxGVIter; ++iter) {
            final double obj = this.calcGradient(m);
            if (iter > 1) {
                if (obj > prev) {
                    step *= STEP_INC;
                    numDown = 0;
                } else if (obj < prev) {
                    // Objective decreased: undo the last step, halve the step size, retry this iteration.
                    for (int t = 0; t < this.nT; ++t) {
                        this.par[t][m] -= step * lastDir[t];
                    }
                    step *= STEP_DEC;
                    for (int t = 0; t < this.nT; ++t) {
                        this.par[t][m] += step * lastDir[t];
                    }
                    --iter;
                    if (++numDown < MAX_NUM_DOWN) {
                        continue;
                    }
                    this.logger.info("  ***Convergence problems....optimization stopped. Number of iterations: " + iter);
                    break;
                }
            } else if (debug) {
                this.logger.info("  First iteration:  GVobj=" + obj + " (HMMobj=" + this.HMMobj + "  GVobj=" + this.GVobj + ")");
            }
            if (this.norm < MIN_EUC_NORM || (iter > 1 && Math.abs(obj - prev) < GV_EPSILON)) {
                if (debug) {
                    this.logger.info("  Number of iterations: [   " + iter + "   ] GVobj=" + obj + " (HMMobj=" + this.HMMobj + "  GVobj=" + this.GVobj + ")");
                    if (iter > 1) {
                        this.logger.info("  Converged (norm=" + this.norm + ", change=" + Math.abs(obj - prev) + ")");
                    } else {
                        this.logger.info("  Converged (norm=" + this.norm + ")");
                    }
                }
                break;
            }
            for (int t = 0; t < this.nT; ++t) {
                this.par[t][m] += step * this.g[t];
                lastDir[t] = this.g[t];
            }
            prev = obj;
        }
        if (iter > this.maxGVIter) {
            this.logger.info("   optimization stopped by reaching max number of iterations (no global variance applied)");
            for (int t = 0; t < this.nT; ++t) {
                this.par[t][m] = parOri[t];
            }
        }
        this.logger.info("Gradient GV optimization for feature: (" + m + ")  number of iterations=" + iter);
    }

    /** Writes (W'UW) c into {@code g}, where c is column m of {@code par} and W'UW is read from {@code wuw}. */
    private void multiplyWuwByPar(int m) {
        for (int t = 0; t < this.nT; ++t) {
            double acc = this.wuw[t][0] * this.par[t][m];
            for (int d = 1; d < WIDTH; ++d) {
                if (t + d < this.nT) {
                    acc += this.wuw[t][d] * this.par[t + d][m];
                }
                if (t - d >= 0) {
                    acc += this.wuw[t - d][d] * this.par[t - d][m];
                }
            }
            this.g[t] = acc;
        }
    }

    /**
     * Computes the objective HMMobj + GVobj for dimension m and stores the scaled update
     * direction in {@code g} and its Euclidean norm in {@code norm}.
     */
    private double calcGradient(int m) {
        // Widen before multiplying so large nT cannot overflow int arithmetic.
        final double w = 1.0 / (3.0 * this.nT);
        final double nT2 = (double) this.nT * this.nT;
        this.calcGV(m);
        final double gvDiff = this.var - this.gvmean[m];
        this.GVobj = -0.5 * gvDiff * this.gvcovInv[m] * gvDiff;
        final double vd = this.gvcovInv[m] * gvDiff;
        this.multiplyWuwByPar(m);
        this.HMMobj = 0.0;
        this.norm = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            final double c = this.par[t][m];
            final double dev = c - this.mean;
            this.HMMobj += -0.5 * W1 * w * c * (this.g[t] - 2.0 * this.wum[t]);
            double h = (double) (this.nT - 1) * vd + 2.0 * this.gvcovInv[m] * dev * dev;
            h = -W1 * w * this.wuw[t][0] - W2 * 2.0 / nT2 * h;
            h = -1.0 / h;
            final double hmmTerm = W1 * w * (-this.g[t] + this.wum[t]);
            if (this.gvSwitch[t]) {
                this.g[t] = h * (hmmTerm + W2 * -2.0 / this.nT * (dev * vd));
            } else {
                this.g[t] = h * hmmTerm;
            }
            this.norm += this.g[t] * this.g[t];
        }
        this.norm = Math.sqrt(this.norm);
        return this.HMMobj + this.GVobj;
    }

    /**
     * Computes {@code -(HMMobj + GVobj)} for dimension m and stores the per-frame update
     * direction in {@code g}.
     */
    private double calcDerivative(int m) {
        // Widen before multiplying so large nT cannot overflow int arithmetic.
        final double w = 1.0 / (3.0 * this.nT);
        final double nT2 = (double) this.nT * this.nT;
        this.calcGV(m);
        this.GVobj = -0.5 * this.var * this.gvcovInv[m] * (this.var - 2.0 * this.gvmean[m]);
        final double vd = -2.0 * this.gvcovInv[m] * (this.var - this.gvmean[m]) / this.nT;
        this.multiplyWuwByPar(m);
        this.HMMobj = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            final double c = this.par[t][m];
            final double dev = c - this.mean;
            this.HMMobj += W1 * w * c * (this.wum[t] - 0.5 * this.g[t]);
            final double h = -W1 * w * this.wuw[t][0]
                    - W2 * 2.0 / nT2 * ((double) (this.nT - 1) * this.gvcovInv[m] * (this.var - this.gvmean[m])
                            + 2.0 * this.gvcovInv[m] * dev * dev);
            final double hmmTerm = W1 * w * (-this.g[t] + this.wum[t]);
            this.g[t] = this.gvSwitch[t] ? 1.0 / h * (hmmTerm + W2 * vd * dev) : 1.0 / h * hmmTerm;
        }
        return -(this.HMMobj + this.GVobj);
    }

    /**
     * Rescales GV-enabled frames of dimension m around their mean by sqrt(gvmean[m] / var).
     * The scaling is skipped when that ratio is not finite (e.g. zero variance), which would
     * otherwise fill the trajectory with NaN.
     */
    private void convGV(int m) {
        this.calcGV(m);
        final double ratio = Math.sqrt(this.gvmean[m] / this.var);
        if (Double.isNaN(ratio) || Double.isInfinite(ratio)) {
            this.logger.debug("GV conversion skipped for feature (" + m + "): var=" + this.var + ", gvmean=" + this.gvmean[m]);
            return;
        }
        for (int t = 0; t < this.nT; ++t) {
            if (this.gvSwitch[t]) {
                this.par[t][m] = ratio * (this.par[t][m] - this.mean) + this.mean;
            }
        }
    }

    /** Sets {@code mean} and {@code var} of dimension m over the GV-enabled frames (divisor gvLength). */
    private void calcGV(int m) {
        double sum = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            if (this.gvSwitch[t]) {
                sum += this.par[t][m];
            }
        }
        this.mean = sum / this.gvLength;
        double sumSq = 0.0;
        for (int t = 0; t < this.nT; ++t) {
            if (this.gvSwitch[t]) {
                final double d = this.par[t][m] - this.mean;
                sumSq += d * d;
            }
        }
        this.var = sumSq / this.gvLength;
    }
}

