/*
 * Decompiled with CFR 0.152.
 */
package marytts.fst;

import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import marytts.fst.StringPair;
import marytts.util.MaryUtils;
import org.apache.log4j.Logger;

/**
 * Learns costs for aligning sequences of input symbols with sequences of output symbols, and
 * computes alignments using those costs.
 *
 * <p>Training entries are pairs of symbol arrays kept in {@link #inSplit} and {@link #outSplit}
 * (optionally with one info string per entry in {@link #optInfo}). An alignment assigns to every
 * input symbol a contiguous, possibly empty, run of output symbols, preserving order. Each call to
 * {@link #alignIteration()} re-aligns all entries with the current costs and then re-estimates the
 * costs from those alignments. Judging by the name {@code graphemeSet}, the input side is
 * presumably graphemes and the output side phones; TODO confirm against callers.
 */
public class AlignerTrainer {
    /** Learned cost of aligning an input symbol (string 1) with an output symbol (string 2). */
    private HashMap<StringPair, Integer> aligncost;
    /** Cost for symbol pairs without a learned cost; learned costs at or above it are dropped. */
    private final int defaultcost = 10;
    /** Penalty for an input symbol that is aligned with no output symbol (see {@link #align}). */
    private int skipcost;
    private static final double LOG_OF_2 = Math.log(2.0);
    /**
     * One optional info string (may be null) per entry, or null if the trainer keeps no info.
     * Only {@link #readLexicon} and the three-argument {@code addAlreadySplit} overloads append
     * to it, so entries added through other methods do not get a slot here.
     */
    protected List<String> optInfo;
    /** Input symbols of each training entry. */
    protected List<String[]> inSplit;
    /** Output symbols of each training entry, parallel to {@link #inSplit}. */
    protected List<String[]> outSplit;
    /** Input symbols seen so far; seeded with the string "null" by the constructor. */
    protected Set<String> graphemeSet;
    protected Logger logger;
    /** If true, identical input/output symbols align at cost 0 unless a learned cost exists. */
    private boolean inIsOut;

    /**
     * Creates an empty trainer.
     *
     * @param inIsOutAlphabet true if input and output symbols come from the same alphabet, so that
     *     identical symbols align for free when no learned cost is available
     * @param hasOptInfo true to keep one optional info string per training entry
     */
    public AlignerTrainer(boolean inIsOutAlphabet, boolean hasOptInfo) {
        this.skipcost = this.defaultcost;
        this.aligncost = new HashMap<StringPair, Integer>();
        this.inSplit = new ArrayList<String[]>();
        this.outSplit = new ArrayList<String[]>();
        this.graphemeSet = new HashSet<String>();
        this.graphemeSet.add("null");
        this.inIsOut = inIsOutAlphabet;
        if (hasOptInfo) {
            this.optInfo = new ArrayList<String>();
        }
        this.logger = MaryUtils.getLogger(this.getClass());
    }

    /** Creates an empty trainer with distinct input/output alphabets and no optional info. */
    public AlignerTrainer() {
        this(false, false);
    }

    /**
     * Reads training entries from a lexicon, one entry per line.
     *
     * <p>Each line is trimmed and split with the regular expression {@code splitSym}. The first
     * field is the input string, the second the output string (both split by
     * {@link #splitAndAdd}). If this trainer keeps optional info, the third field (or null when
     * absent) is stored as the entry's info. Blank lines are skipped.
     *
     * @param lexicon source of lexicon lines
     * @param splitSym regular expression separating the fields of a line
     * @throws IOException if reading fails, or if a non-blank line has fewer than two fields
     */
    public void readLexicon(BufferedReader lexicon, String splitSym) throws IOException {
        String line;
        int lineNr = 0;
        while ((line = lexicon.readLine()) != null) {
            ++lineNr;
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                // A blank line used to crash with ArrayIndexOutOfBoundsException on lineParts[1].
                continue;
            }
            String[] lineParts = trimmed.split(splitSym);
            if (lineParts.length < 2) {
                throw new IOException("Malformed lexicon line " + lineNr + ": '" + line + "'");
            }
            this.splitAndAdd(lineParts[0], lineParts[1]);
            if (this.optInfo != null) {
                this.optInfo.add(lineParts.length > 2 ? lineParts[2] : null);
            }
        }
    }

    /**
     * Splits an input/output string pair into symbols and adds it as a training entry.
     *
     * <p>The input string is split into single characters (Unicode code points). If the output
     * string contains a space, it is split at single spaces and every token except the first keeps
     * its leading space. Otherwise it is split into single characters as well. Does not touch
     * {@link #optInfo}.
     *
     * @param inStr input string
     * @param outStr output string
     */
    public void splitAndAdd(String inStr, String outStr) {
        String[] inStrSplit = splitIntoCodePoints(inStr);
        String[] outStrSplit;
        if (outStr.contains(" ")) {
            outStrSplit = outStr.split(" ");
            for (int i = 1; i < outStrSplit.length; ++i) {
                outStrSplit[i] = " " + outStrSplit[i];
            }
        } else {
            outStrSplit = splitIntoCodePoints(outStr);
        }
        this.addEntry(inStrSplit, outStrSplit);
    }

    /** Splits {@code s} into one-code-point strings, keeping surrogate pairs together. */
    private static String[] splitIntoCodePoints(String s) {
        String[] parts = new String[s.codePointCount(0, s.length())];
        int start = 0;
        for (int k = 0; k < parts.length; ++k) {
            int end = s.offsetByCodePoints(start, 1);
            parts[k] = s.substring(start, end);
            start = end;
        }
        return parts;
    }

    /**
     * Adds an already split entry without optional info.
     *
     * @param inStr input symbols
     * @param outStr output symbols
     */
    public void addAlreadySplit(List<String> inStr, List<String> outStr) {
        this.addEntry(inStr.toArray(new String[0]), outStr.toArray(new String[0]));
    }

    /**
     * Adds an already split entry without optional info. The arrays are stored as given, not
     * copied.
     *
     * @param inStr input symbols
     * @param outStr output symbols
     */
    public void addAlreadySplit(String[] inStr, String[] outStr) {
        this.addEntry(inStr, outStr);
    }

    /**
     * Adds an already split entry together with its optional info.
     *
     * @param inStr input symbols
     * @param outStr output symbols
     * @param optionalInfo info stored for this entry (may be null)
     * @throws IllegalStateException if this trainer was created without optional info
     */
    public void addAlreadySplit(List<String> inStr, List<String> outStr, String optionalInfo) {
        this.requireOptInfo();
        this.addEntry(inStr.toArray(new String[0]), outStr.toArray(new String[0]));
        this.optInfo.add(optionalInfo);
    }

    /**
     * Adds an already split entry together with its optional info. The arrays are stored as
     * given, not copied.
     *
     * @param inStr input symbols
     * @param outStr output symbols
     * @param optionalInfo info stored for this entry (may be null)
     * @throws IllegalStateException if this trainer was created without optional info
     */
    public void addAlreadySplit(String[] inStr, String[] outStr, String optionalInfo) {
        this.requireOptInfo();
        this.addEntry(inStr, outStr);
        this.optInfo.add(optionalInfo);
    }

    /** Stores one entry and records its input symbols, so getInputSyms() reports them. */
    private void addEntry(String[] inStr, String[] outStr) {
        if (this.graphemeSet != null) {
            for (String sym : inStr) {
                this.graphemeSet.add(sym);
            }
        }
        this.inSplit.add(inStr);
        this.outSplit.add(outStr);
    }

    /** Fails before any state is modified when optional info is not kept by this trainer. */
    private void requireOptInfo() {
        if (this.optInfo == null) {
            throw new IllegalStateException(
                    "optional info is not kept by this trainer (created with hasOptInfo=false)");
        }
    }

    /**
     * Performs one training iteration: aligns every entry with the current costs, then
     * re-estimates the costs from those alignments.
     *
     * <p>The skip cost becomes {@code (int) -log2(d)}, where d is the fraction of input symbols
     * aligned with no output symbol. For every observed (input, output) symbol pair, the cost
     * becomes {@code (int) -log2(pairCount / outputsAlignedToThatInputSymbol)}. Pair costs of
     * {@code defaultcost} or more are dropped, so they fall back to the default.
     */
    public void alignIteration() {
        HashMap<String, Integer> symMapCount = new HashMap<String, Integer>();
        HashMap<StringPair, Integer> sym2symCount = new HashMap<StringPair, Integer>();
        int symCount = 0;
        int symDels = 0;
        for (int i = 0; i < this.outSplit.size(); ++i) {
            String[] in = this.inSplit.get(i);
            String[] out = this.outSplit.get(i);
            int[] alignment = this.align(in, out);
            symCount += in.length;
            int pre = 0;
            for (int inNr = 0; inNr < in.length; ++inNr) {
                int end = alignment[inNr];
                if (end == pre) {
                    // This input symbol consumed no output symbols.
                    ++symDels;
                } else {
                    increment(symMapCount, in[inNr], end - pre);
                    for (int outNr = pre; outNr < end; ++outNr) {
                        increment(sym2symCount, new StringPair(in[inNr], out[outNr]), 1);
                    }
                }
                pre = end;
            }
        }
        if (symCount > 0) {
            // Without any input symbol there is no evidence; keep the previous skip cost instead
            // of the former 0/0 = NaN, which the int cast silently turned into 0.
            this.skipcost = skipCostFor(symDels, symCount);
        }
        this.aligncost.clear();
        for (StringPair mapping : sym2symCount.keySet()) {
            int pairCount = sym2symCount.get(mapping);
            int symTotal = symMapCount.get(mapping.getString1());
            int cost = (int) (-this.log2((double) pairCount / (double) symTotal));
            if (cost < this.defaultcost) {
                this.aligncost.put(mapping, cost);
            }
        }
    }

    /**
     * Computes the skip cost from the number of skipped input symbols out of a positive total.
     * With zero skips, -log2(0) would be +Infinity, which the int cast turns into
     * Integer.MAX_VALUE and which then overflows when align() adds it to path costs. That case
     * is therefore treated as slightly rarer than one skip in {@code total} symbols.
     */
    private int skipCostFor(int skipped, int total) {
        double delFraction = skipped > 0
                ? (double) skipped / (double) total
                : 1.0 / (total + 1.0);
        return (int) (-this.log2(delFraction));
    }

    /** Adds {@code amount} to the count stored for {@code key}, starting from 0. */
    private static <K> void increment(HashMap<K, Integer> counts, K key, int amount) {
        Integer old = counts.get(key);
        counts.put(key, old == null ? amount : old + amount);
    }

    /** @return the number of training entries */
    public int lexiconSize() {
        return this.inSplit.size();
    }

    /**
     * Aligns one training entry with the current costs.
     *
     * @param entryNr index of the entry
     * @return one pair per input symbol: the symbol and the concatenation of the output symbols
     *     aligned with it (empty if none)
     */
    public StringPair[] getAlignment(int entryNr) {
        return this.buildAlignment(entryNr, 0);
    }

    /**
     * Aligns one training entry and returns, per input symbol, the symbol followed by each
     * aligned output symbol, with a space before each output symbol.
     *
     * @param entryNr index of the entry
     * @return one string per input symbol
     */
    public String[] getAlignmentString(int entryNr) {
        String[] in = this.inSplit.get(entryNr);
        String[] out = this.outSplit.get(entryNr);
        int[] align = this.align(in, out);
        String[] stringArray = new String[in.length];
        int pre = 0;
        for (int pos = 0; pos < in.length; ++pos) {
            StringBuilder sb = new StringBuilder();
            sb.append(in[pos]);
            for (int alPos = pre; alPos < align[pos]; ++alPos) {
                sb.append(' ').append(out[alPos]);
            }
            pre = align[pos];
            stringArray[pos] = sb.toString();
        }
        return stringArray;
    }

    /**
     * Like {@link #getAlignment}, but if the entry has non-null optional info, an extra final pair
     * (info, "") is appended. Trainers without optional info behave exactly like getAlignment.
     *
     * @param entryNr index of the entry
     * @return the alignment pairs, possibly followed by the info pair
     */
    public StringPair[] getInfoAlignment(int entryNr) {
        String info = this.optInfo == null ? null : this.optInfo.get(entryNr);
        if (info == null) {
            return this.getAlignment(entryNr);
        }
        StringPair[] listArray = this.buildAlignment(entryNr, 1);
        listArray[listArray.length - 1] = new StringPair(info, "");
        return listArray;
    }

    /** Aligns entry {@code entryNr}; the result has {@code extraSlots} unfilled trailing slots. */
    private StringPair[] buildAlignment(int entryNr, int extraSlots) {
        String[] in = this.inSplit.get(entryNr);
        String[] out = this.outSplit.get(entryNr);
        int[] align = this.align(in, out);
        StringPair[] listArray = new StringPair[in.length + extraSlots];
        int pre = 0;
        for (int pos = 0; pos < in.length; ++pos) {
            StringBuilder outRun = new StringBuilder();
            for (int alPos = pre; alPos < align[pos]; ++alPos) {
                outRun.append(out[alPos]);
            }
            pre = align[pos];
            listArray[pos] = new StringPair(in[pos], outRun.toString());
        }
        return listArray;
    }

    /**
     * Returns the live (not copied) set of input symbols. It is rebuilt from all entries only if
     * the set is null or empty; since the constructor seeds it with "null", that happens only if
     * a subclass replaced or cleared it.
     */
    public Set<String> getInputSyms() {
        if (this.graphemeSet == null || this.graphemeSet.isEmpty()) {
            return this.collectInputSyms();
        }
        return this.graphemeSet;
    }

    /** Rebuilds graphemeSet as "null" plus every input symbol of every entry. */
    private Set<String> collectInputSyms() {
        this.graphemeSet = new HashSet<String>();
        this.graphemeSet.add("null");
        for (String[] is : this.inSplit) {
            for (String sym : is) {
                this.graphemeSet.add(sym);
            }
        }
        return this.graphemeSet;
    }

    private double log2(double d) {
        return Math.log(d) / LOG_OF_2;
    }

    /**
     * Cost of aligning key's input symbol (string 1) with its output symbol (string 2): the learned
     * cost if any, else 0 for identical symbols when the alphabets are shared, else defaultcost.
     */
    private int symDist(StringPair key) {
        Integer cost = this.aligncost.get(key);
        if (null == cost) {
            if (this.inIsOut) {
                return key.getString1().equals(key.getString2()) ? 0 : this.defaultcost;
            }
            return this.defaultcost;
        }
        return cost;
    }

    /**
     * Computes a minimum-cost monotone alignment by dynamic programming over input symbols
     * (rows) and output prefixes (columns).
     *
     * @param istr input symbols
     * @param ostr output symbols
     * @return an array with one entry per input symbol; entry k is the exclusive end index in
     *     {@code ostr} of the run aligned with input k, so input k covers [result[k-1], result[k])
     *     (with result[-1] = 0). The last entry equals ostr.length when istr is non-empty. Empty
     *     input yields an empty array.
     */
    public int[] align(String[] istr, String[] ostr) {
        if (istr.length == 0) {
            // No input symbol to assign output runs to (previously crashed on istr[0]).
            return new int[0];
        }
        StringPair key = new StringPair(null, null);
        // p_* hold the previous row (inputs 0..i-1), the un-prefixed arrays the current row.
        // d[j]: cost of the best path covering the first j output symbols.
        // sk[j]: whether that path reached column j by letting the latest input consume nothing.
        // al[j]: that path's alignment, laid out like the return value.
        int[] p_d = new int[ostr.length + 1];
        int[] d = new int[ostr.length + 1];
        boolean[] p_sk = new boolean[ostr.length + 1];
        boolean[] sk = new boolean[ostr.length + 1];
        int[][] p_al = new int[ostr.length + 1][istr.length];
        int[][] al = new int[ostr.length + 1][istr.length];
        // Row for input 0: it takes output symbols 0..j-1 (nothing, at no cost, for j = 0).
        p_d[0] = 0;
        p_sk[0] = true;
        for (int j = 1; j < ostr.length + 1; ++j) {
            p_al[j][0] = j;
            key.setString1(istr[0]);
            key.setString2(ostr[j - 1]);
            p_d[j] = p_d[j - 1] + this.symDist(key);
            p_sk[j] = false;
        }
        int skConst = this.skipcost;
        for (int i = 1; i < istr.length; ++i) {
            // Column 0: input i consumes nothing and always pays the skip cost; al[0] stays all 0.
            d[0] = p_d[0] + skConst;
            sk[0] = true;
            for (int j = 1; j < ostr.length + 1; ++j) {
                key.setString1(istr[i]);
                key.setString2(ostr[j - 1]);
                int tr_cost = this.symDist(key);
                // The skip penalty applies only if the previous input was itself a skip at column
                // j; skipping right after an input that consumed output symbols is free.
                int sk_cost = p_sk[j] ? skConst : 0;
                if (sk_cost + p_d[j] < tr_cost + d[j - 1]) {
                    // Skip: input i consumes nothing. The row array is shared with p_al[j] rather
                    // than copied; entries 0..i-1 are unchanged, and p_al[j] is not read again in
                    // this row, so writing index i is safe.
                    d[j] = sk_cost + p_d[j];
                    al[j] = p_al[j];
                    al[j][i] = j;
                    sk[j] = true;
                    continue;
                }
                // Transition: input i additionally consumes output symbol j-1 (ties prefer this).
                d[j] = tr_cost + d[j - 1];
                System.arraycopy(al[j - 1], 0, al[j], 0, i);
                al[j][i] = j;
                sk[j] = false;
            }
            int[] _d = p_d;
            p_d = d;
            d = _d;
            boolean[] _sk = p_sk;
            p_sk = sk;
            sk = _sk;
            int[][] _al = p_al;
            p_al = al;
            al = _al;
        }
        return p_al[ostr.length];
    }
}

