/*
 * Decompiled with CFR 0.152.
 */
package be.ac.ulg.montefiore.run.jahmm.learn;

import be.ac.ulg.montefiore.run.jahmm.ForwardBackwardCalculator;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.Observation;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.learn.KMeansLearner;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Iteratively re-estimates the parameters of an {@link Hmm} from a set of
 * observation sequences.
 *
 * <p>Each call to {@link #iterate} performs one re-estimation pass: it derives
 * per-sequence "xi" and "gamma" arrays (see {@link #estimateXi} and
 * {@link #estimateGamma}) and uses them to rebuild the transition matrix, the
 * initial-state distribution and the observation distributions of a clone of
 * the input model. {@link #learn} chains {@link #getNbIterations()} such
 * passes.
 *
 * <p>The three {@code protected} methods are extension points invoked by
 * {@link #iterate}.
 */
public class BaumWelchLearner {
    /** Number of re-estimation passes performed by {@link #learn}. */
    private int nbIterations = 9;

    /**
     * Performs one re-estimation pass.
     *
     * @param <O> the observation type
     * @param hmm the current model; it is only read, the result is built on a
     *        clone of it
     * @param list the observation sequences to learn from; every sequence must
     *        hold at least two observations (enforced by {@link #estimateXi})
     * @return a re-estimated clone of {@code hmm}
     * @throws IllegalArgumentException if a sequence is too short
     * @throws InternalError if {@code hmm} cannot be cloned
     */
    public <O extends Observation> Hmm<O> iterate(Hmm<O> hmm, List<? extends List<? extends O>> list) {
        Hmm<O> nhmm;
        try {
            nhmm = hmm.clone();
        }
        catch (CloneNotSupportedException e) {
            // Keep the original failure attached so it is not lost in the stack trace.
            InternalError error = new InternalError("Unable to clone the HMM");
            error.initCause(e);
            throw error;
        }

        final int nbStates = hmm.nbStates();

        // allGamma[s] is the gamma array computed for the s-th sequence.
        double[][][] allGamma = new double[list.size()][][];

        // New a[i][j] = aijNum[i][j] / aijDen[i]. Freshly allocated double
        // arrays are already zero-filled, so no explicit reset is needed.
        double[][] aijNum = new double[nbStates][nbStates];
        double[] aijDen = new double[nbStates];

        int seqIndex = 0;
        for (List<? extends O> sequence : list) {
            ForwardBackwardCalculator fbc = this.generateForwardBackwardCalculator(sequence, hmm);
            double[][][] xi = this.estimateXi(sequence, fbc, hmm);
            double[][] gamma = this.estimateGamma(xi, fbc);
            allGamma[seqIndex++] = gamma;

            // Only the first size-1 time steps contribute to transitions.
            for (int i = 0; i < nbStates; ++i) {
                for (int t = 0; t < sequence.size() - 1; ++t) {
                    aijDen[i] += gamma[t][i];
                    for (int j = 0; j < nbStates; ++j) {
                        aijNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        // Transition probabilities.
        for (int i = 0; i < nbStates; ++i) {
            if (aijDen[i] == 0.0) {
                // No accumulated weight for state i: keep its previous row
                // instead of dividing by zero.
                for (int j = 0; j < nbStates; ++j) {
                    nhmm.setAij(i, j, hmm.getAij(i, j));
                }
            } else {
                for (int j = 0; j < nbStates; ++j) {
                    nhmm.setAij(i, j, aijNum[i][j] / aijDen[i]);
                }
            }
        }

        // Initial-state probabilities: average of gamma at t = 0 over all sequences.
        for (int i = 0; i < nbStates; ++i) {
            nhmm.setPi(i, 0.0);
        }
        for (int s = 0; s < list.size(); ++s) {
            for (int i = 0; i < nbStates; ++i) {
                nhmm.setPi(i, nhmm.getPi(i) + allGamma[s][0][i] / (double) list.size());
            }
        }

        // Observation distributions: fit each state's distribution on all
        // observations, weighted by that state's normalised gamma values.
        for (int i = 0; i < nbStates; ++i) {
            List<O> observations = KMeansLearner.flat(list);
            double[] weights = new double[observations.size()];
            double sum = 0.0;
            int filled = 0;

            int s = 0;
            for (List<? extends O> sequence : list) {
                for (int t = 0; t < sequence.size(); ++t, ++filled) {
                    double g = allGamma[s][t][i];
                    weights[filled] = g;
                    sum += g;
                }
                ++s;
            }

            for (int k = 0; k < filled; ++k) {
                weights[k] /= sum;
            }

            nhmm.getOpdf(i).fit(observations, weights);
        }
        return nhmm;
    }

    /**
     * Creates the forward-backward calculator used by {@link #iterate} for one
     * sequence, requesting every available {@code Computation}.
     *
     * @param <O> the observation type
     * @param list the observation sequence
     * @param hmm the model the sequence is evaluated against
     * @return a new calculator for {@code list} and {@code hmm}
     */
    protected <O extends Observation> ForwardBackwardCalculator generateForwardBackwardCalculator(List<? extends O> list, Hmm<O> hmm) {
        return new ForwardBackwardCalculator(list, hmm, EnumSet.allOf(ForwardBackwardCalculator.Computation.class));
    }

    /**
     * Runs {@link #iterate} {@link #getNbIterations()} times, feeding each
     * result into the next pass.
     *
     * @param <O> the observation type
     * @param hmm the initial model
     * @param list the observation sequences to learn from
     * @return the model produced by the last pass, or {@code hmm} itself when
     *         the iteration count is zero
     */
    public <O extends Observation> Hmm<O> learn(Hmm<O> hmm, List<? extends List<? extends O>> list) {
        Hmm<O> current = hmm;
        for (int pass = 0; pass < this.nbIterations; ++pass) {
            current = this.iterate(current, list);
        }
        return current;
    }

    /**
     * Computes the xi array of one sequence:
     * {@code xi[t][i][j] = alpha(t, i) * a(i, j) * b_j(o[t+1]) * beta(t+1, j) / p},
     * where {@code p} is {@code forwardBackwardCalculator.probability()}.
     *
     * @param <O> the observation type
     * @param list the observation sequence
     * @param forwardBackwardCalculator supplies alpha, beta and the sequence probability
     * @param hmm the model providing transition and observation probabilities
     * @return an array of dimensions {@code [list.size() - 1][nbStates][nbStates]}
     * @throws IllegalArgumentException if the sequence has fewer than two observations
     */
    protected <O extends Observation> double[][][] estimateXi(List<? extends O> list, ForwardBackwardCalculator forwardBackwardCalculator, Hmm<O> hmm) {
        if (list.size() <= 1) {
            throw new IllegalArgumentException("Observation sequence too short");
        }

        final int nbStates = hmm.nbStates();
        final int steps = list.size() - 1;
        double[][][] xi = new double[steps][nbStates][nbStates];
        double probability = forwardBackwardCalculator.probability();

        // Skip the first observation: xi[t] uses the observation at t + 1.
        Iterator<? extends O> it = list.iterator();
        it.next();

        for (int t = 0; t < steps; ++t) {
            O next = it.next();
            for (int i = 0; i < nbStates; ++i) {
                for (int j = 0; j < nbStates; ++j) {
                    xi[t][i][j] = forwardBackwardCalculator.alphaElement(t, i) * hmm.getAij(i, j) * hmm.getOpdf(j).probability(next) * forwardBackwardCalculator.betaElement(t + 1, j) / probability;
                }
            }
        }
        return xi;
    }

    /**
     * Derives the gamma array from a xi array.
     *
     * <p>For {@code t < dArray.length}, {@code gamma[t][i]} is the sum of
     * {@code dArray[t][i][j]} over {@code j}. The extra last row
     * {@code gamma[dArray.length][j]} is the sum of
     * {@code dArray[dArray.length - 1][i][j]} over {@code i}.
     *
     * @param dArray the xi array; must contain at least one time step
     * @param forwardBackwardCalculator not used by this implementation
     * @return an array of dimensions {@code [dArray.length + 1][dArray[0].length]}
     */
    protected double[][] estimateGamma(double[][][] dArray, ForwardBackwardCalculator forwardBackwardCalculator) {
        final int steps = dArray.length;
        final int nbStates = dArray[0].length;

        // Freshly allocated, hence already zero-filled.
        double[][] gamma = new double[steps + 1][nbStates];

        for (int t = 0; t < steps; ++t) {
            for (int i = 0; i < nbStates; ++i) {
                for (int j = 0; j < nbStates; ++j) {
                    gamma[t][i] += dArray[t][i][j];
                }
            }
        }

        // Last time step: sum over the source state of the final transition.
        for (int j = 0; j < nbStates; ++j) {
            for (int i = 0; i < nbStates; ++i) {
                gamma[steps][j] += dArray[steps - 1][i][j];
            }
        }
        return gamma;
    }

    /**
     * Returns the number of passes performed by {@link #learn}.
     *
     * @return the current iteration count
     */
    public int getNbIterations() {
        return this.nbIterations;
    }

    /**
     * Sets the number of passes performed by {@link #learn}.
     *
     * @param n the new iteration count; zero is accepted
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public void setNbIterations(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("Positive number expected");
        }
        this.nbIterations = n;
    }
}

