/*
 * Decompiled with CFR 0.152.
 */
package org.opengion.penguin.math.statistics;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Multinomial (softmax) logistic regression trained with per-sample gradient updates.
 *
 * <p>The model computes {@code p = softmax(W x + b)}, where {@code W} has one row per output and
 * one column per input feature, and {@code b} has one entry per output. Training visits a freshly
 * shuffled subset of the samples each epoch. For each sample it moves {@code W} and {@code b}
 * along {@code (t - p)}, the gradient of the cross-entropy loss for target vector {@code t}.
 *
 * <p>Not synchronized. {@link #getW()} and {@link #getB()} return the live parameter arrays, so
 * mutating them changes the model.
 */
public class HybsLogisticRegression {
    /** Multiplicative decay applied to the learning rate after every epoch. */
    private static final double LEARNING_RATE_DECAY = 0.95;

    /** Number of training samples; 0 when the model was built from existing parameters. */
    private int sampleCount;
    /** Number of input features (columns of the weight matrix). */
    private int inputSize;
    /** Number of outputs / classes (rows of the weight matrix). */
    private int outputSize;
    /** Weight matrix, {@code weights[output][input]}. */
    private double[][] weights;
    /** Bias vector, one entry per output. */
    private double[] biases;

    /**
     * Trains a new model on the given samples.
     *
     * @param inputs       training inputs, one row per sample; the feature count is taken from
     *                     {@code inputs[0].length}
     * @param labels       target vectors, one row per sample (one-hot in the bundled example); the
     *                     output count is taken from {@code labels[0].length}
     * @param learningRate initial learning rate; multiplied by 0.95 after every epoch
     * @param epochs       number of passes over the shuffled training set
     * @param sampleRatio  fraction of the samples visited per epoch ({@code ceil(N * ratio)}
     *                     samples); values above 1.0 visit every sample exactly once per epoch
     * @throws ArrayIndexOutOfBoundsException if {@code inputs} or {@code labels} is empty
     */
    public HybsLogisticRegression(double[][] inputs, int[][] labels, double learningRate, int epochs, double sampleRatio) {
        this.sampleCount = inputs.length;
        this.inputSize = inputs[0].length;
        this.outputSize = labels[0].length;
        this.weights = new double[outputSize][inputSize];
        this.biases = new double[outputSize];

        final Integer[] order = new Integer[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            order[i] = i;
        }
        // Arrays.asList is a fixed-size view: shuffling it permutes `order` in place,
        // so no per-epoch copy is needed.
        final List<Integer> orderView = Arrays.asList(order);

        double rate = learningRate;
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(orderView);
            // Same sample count as "k < N * ratio", but never indexes past the last sample
            // (a ratio > 1.0 previously threw ArrayIndexOutOfBoundsException).
            for (int k = 0; k < sampleCount && k < sampleCount * sampleRatio; k++) {
                final int sample = order[k];
                train(inputs[sample], labels[sample], rate);
            }
            rate *= LEARNING_RATE_DECAY;
        }
    }

    /**
     * Wraps previously trained parameters, for example the arrays returned by {@link #getW()} and
     * {@link #getB()}. The arrays are used directly, not copied.
     *
     * @param weights weight matrix, one row per output; the feature count is {@code weights[0].length}
     * @param biases  bias vector; its length is the number of outputs
     * @throws ArrayIndexOutOfBoundsException if {@code weights} is empty
     */
    public HybsLogisticRegression(double[][] weights, double[] biases) {
        this.inputSize = weights[0].length;
        this.outputSize = biases.length;
        this.weights = weights;
        this.biases = biases;
    }

    /**
     * Applies one gradient step for a single sample:
     * {@code W += rate * (t - p) * x / N} and {@code b += rate * (t - p) / N}.
     */
    private void train(double[] x, int[] target, double rate) {
        final double[] probabilities = forward(x);
        for (int i = 0; i < outputSize; i++) {
            final double error = target[i] - probabilities[i];
            final double[] row = weights[i];
            for (int j = 0; j < inputSize; j++) {
                // The step is divided by the training-set size, as in the original algorithm.
                row[j] += rate * error * x[j] / sampleCount;
            }
            biases[i] += rate * error / sampleCount;
        }
    }

    /** Computes {@code softmax(W x + b)} into a new array of length {@code outputSize}. */
    private double[] forward(double[] x) {
        final double[] out = new double[outputSize];
        for (int i = 0; i < outputSize; i++) {
            final double[] row = weights[i];
            double activation = 0.0;
            for (int j = 0; j < inputSize; j++) {
                activation += row[j] * x[j];
            }
            out[i] = activation + biases[i];
        }
        softmax(out);
        return out;
    }

    /**
     * Replaces {@code v} with {@code softmax(v)}, in place.
     *
     * <p>The maximum is subtracted before exponentiating. The result is mathematically the same,
     * but {@code Math.exp} can no longer overflow to {@code Infinity}, which previously turned
     * large logits into {@code NaN} probabilities.
     */
    private static void softmax(double[] v) {
        double max = Double.NEGATIVE_INFINITY;
        for (double value : v) {
            max = Math.max(max, value);
        }
        double sum = 0.0;
        for (int i = 0; i < v.length; i++) {
            v[i] = Math.exp(v[i] - max);
            sum += v[i];
        }
        for (int i = 0; i < v.length; i++) {
            v[i] /= sum;
        }
    }

    /**
     * Returns the weight matrix.
     *
     * @return the live internal array, {@code [output][input]}; not a copy
     */
    public double[][] getW() {
        return weights;
    }

    /**
     * Returns the bias vector.
     *
     * @return the live internal array, one entry per output; not a copy
     */
    public double[] getB() {
        return biases;
    }

    /**
     * Computes the output probabilities {@code softmax(W x + b)} for one input.
     *
     * @param x input features; the first {@code getW()[0].length} values are read
     * @return a new array with one probability per output, summing to 1 (up to rounding)
     */
    public double[] predict(double[] x) {
        return forward(x);
    }

    /**
     * Small demo. Trains on three 2-D clusters with one-hot labels, then prints the predicted
     * probabilities for three probe points.
     */
    public static void main(String[] args) {
        final double[][] trainInputs = {
            {-2.0, 2.0}, {-2.1, 1.9}, {-1.8, 2.1},
            {0.0, 0.0}, {0.2, -0.2}, {-0.1, 0.1},
            {2.0, -2.0}, {2.2, -2.1}, {1.9, -2.0}};
        final int[][] trainLabels = {
            {1, 0, 0}, {1, 0, 0}, {1, 0, 0},
            {0, 1, 0}, {0, 1, 0}, {0, 1, 0},
            {0, 0, 1}, {0, 0, 1}, {0, 0, 1}};
        final double[][] probes = {{-2.5, 2.0}, {0.1, -0.1}, {1.5, -2.5}};

        final HybsLogisticRegression model = new HybsLogisticRegression(trainInputs, trainLabels, 0.1, 500, 1.0);
        for (double[] probe : probes) {
            System.out.println(Arrays.toString(model.predict(probe)));
        }
    }
}

