/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math3.distribution.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.util.Localizable;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.util.Pair;

/**
 * Fits a mixture of multivariate normal distributions to a data set using the
 * expectation-maximization (EM) algorithm.
 * <p>
 * Usage: construct with the observations, call {@link #fit} with a starting mixture
 * (for example one produced by {@link #estimate(double[][], int)}), then read the
 * result through {@link #getFittedModel()} and {@link #getLogLikelihood()}.
 * <p>
 * Instances are stateful: every call to {@code fit} overwrites the fitted model and
 * the stored log-likelihood. No synchronization is performed.
 */
public class MultivariateNormalMixtureExpectationMaximization {
    /** Maximum number of EM iterations used by {@link #fit(MixtureMultivariateNormalDistribution)}. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /** Convergence threshold used by {@link #fit(MixtureMultivariateNormalDistribution)}. */
    private static final double DEFAULT_THRESHOLD = 1.0E-5;
    /** Defensive copy of the observations: one row per observation, one column per dimension. */
    private final double[][] data;
    /** Model produced by the most recent call to {@link #fit}; {@code null} before the first call. */
    private MixtureMultivariateNormalDistribution fittedModel;
    /**
     * Average (per-observation) log-likelihood computed in the last E-step of {@link #fit}.
     * It is evaluated under the model in effect at the start of that final pass.
     */
    private double logLikelihood = 0.0;

    /**
     * Creates an EM fitter for the given observations.
     *
     * @param data2 observations, one row per observation. All rows must have the same
     *        length, and that length must be at least 2. Each row is copied.
     * @throws NotStrictlyPositiveException if {@code data2} has no rows
     * @throws DimensionMismatchException if the rows do not all have the same length
     * @throws NumberIsTooSmallException if the rows have fewer than 2 columns
     */
    public MultivariateNormalMixtureExpectationMaximization(double[][] data2) throws NotStrictlyPositiveException, DimensionMismatchException, NumberIsTooSmallException {
        if (data2.length < 1) {
            throw new NotStrictlyPositiveException(data2.length);
        }
        final int numCols = data2[0].length;
        // Each row is filled with a defensive copy below, so only the outer array is allocated here.
        this.data = new double[data2.length][];
        for (int i = 0; i < data2.length; ++i) {
            if (data2[i].length != numCols) {
                throw new DimensionMismatchException(data2[i].length, numCols);
            }
            if (data2[i].length < 2) {
                throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL, data2[i].length, 2, true);
            }
            this.data[i] = MathArrays.copyOf(data2[i], data2[i].length);
        }
    }

    /**
     * Runs EM starting from {@code initialMixture} until the average log-likelihood changes
     * by no more than {@code threshold} between two passes, or the iteration budget runs out.
     * <p>
     * NOTE: the loop guard {@code numIterations++ <= maxIterations} admits up to
     * {@code maxIterations + 1} passes. This is kept as-is so existing results do not change.
     *
     * @param initialMixture starting model. Only its first component's mean vector is checked
     *        against the number of data columns.
     * @param maxIterations iteration budget; must be at least 1
     * @param threshold tolerance on the change in average log-likelihood; must be at least
     *        {@link Double#MIN_VALUE}
     * @throws NotStrictlyPositiveException if {@code maxIterations < 1} or {@code threshold} is too small
     * @throws DimensionMismatchException if the initial mixture's dimension differs from the data's column count
     * @throws SingularMatrixException declared; presumably raised while building component distributions
     *         from a singular covariance estimate
     * @throws ConvergenceException if the final change in log-likelihood still exceeds {@code threshold}
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture, int maxIterations, double threshold) throws SingularMatrixException, NotStrictlyPositiveException, DimensionMismatchException {
        if (maxIterations < 1) {
            throw new NotStrictlyPositiveException(maxIterations);
        }
        if (threshold < Double.MIN_VALUE) {
            throw new NotStrictlyPositiveException(threshold);
        }
        final int numRows = this.data.length;
        final int numCols = this.data[0].length;
        final int numComponents = initialMixture.getComponents().size();
        final int numMeanColumns = initialMixture.getComponents().get(0).getSecond().getMeans().length;
        if (numMeanColumns != numCols) {
            throw new DimensionMismatchException(numMeanColumns, numCols);
        }

        int numIterations = 0;
        double previousLogLikelihood = 0.0;
        // Starting at -infinity guarantees the first convergence check fails, so at least one pass runs.
        this.logLikelihood = Double.NEGATIVE_INFINITY;
        this.fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());

        while (numIterations++ <= maxIterations && FastMath.abs(previousLogLikelihood - this.logLikelihood) > threshold) {
            previousLogLikelihood = this.logLikelihood;

            // Snapshot the current model's weights and component distributions.
            final List<Pair<Double, MultivariateNormalDistribution>> components = this.fittedModel.getComponents();
            final double[] weights = new double[numComponents];
            final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[numComponents];
            for (int j = 0; j < numComponents; ++j) {
                weights[j] = components.get(j).getFirst();
                mvns[j] = components.get(j).getSecond();
            }

            // E-step: responsibilities gamma[i][j] plus the sums the M-step needs.
            final double[][] gamma = new double[numRows][numComponents];
            final double[] gammaSums = new double[numComponents];
            final double[][] gammaDataProdSums = new double[numComponents][numCols];
            double sumLogLikelihood = 0.0;
            for (int i = 0; i < numRows; ++i) {
                final double[] row = this.data[i];
                final double rowDensity = this.fittedModel.density(row);
                sumLogLikelihood += FastMath.log(rowDensity);
                for (int j = 0; j < numComponents; ++j) {
                    gamma[i][j] = weights[j] * mvns[j].density(row) / rowDensity;
                    gammaSums[j] += gamma[i][j];
                    for (int col = 0; col < numCols; ++col) {
                        gammaDataProdSums[j][col] += gamma[i][j] * row[col];
                    }
                }
            }
            this.logLikelihood = sumLogLikelihood / (double) numRows;

            // M-step: replace the model with re-estimated parameters.
            this.fittedModel = maximizationStep(gamma, gammaSums, gammaDataProdSums);
        }

        if (FastMath.abs(previousLogLikelihood - this.logLikelihood) > threshold) {
            throw new ConvergenceException();
        }
    }

    /**
     * M-step: re-estimates mixture weights, means and covariance matrices from the
     * responsibilities computed in the E-step.
     *
     * @param gamma responsibilities; {@code gamma[i][j]} is the weight of observation {@code i}
     *        in component {@code j}
     * @param gammaSums per-component sums of {@code gamma} over all observations
     * @param gammaDataProdSums per-component, per-column sums of {@code gamma[i][j] * data[i][col]}
     * @return a mixture built from the re-estimated parameters
     */
    private MixtureMultivariateNormalDistribution maximizationStep(double[][] gamma, double[] gammaSums, double[][] gammaDataProdSums) {
        final int numRows = this.data.length;
        final int numCols = this.data[0].length;
        final int numComponents = gammaSums.length;

        final double[] newWeights = new double[numComponents];
        final double[][] newMeans = new double[numComponents][numCols];
        for (int j = 0; j < numComponents; ++j) {
            newWeights[j] = gammaSums[j] / (double) numRows;
            for (int col = 0; col < numCols; ++col) {
                newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
            }
        }

        // Responsibility-weighted scatter matrices around the new means, accumulated in observation order.
        final RealMatrix[] newCovMats = new RealMatrix[numComponents];
        for (int j = 0; j < numComponents; ++j) {
            newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
        }
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numComponents; ++j) {
                final Array2DRowRealMatrix vec = new Array2DRowRealMatrix(MathArrays.ebeSubtract(this.data[i], newMeans[j]));
                final RealMatrix dataCov = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
                newCovMats[j] = newCovMats[j].add(dataCov);
            }
        }

        // Normalize each scatter matrix by its component's total responsibility.
        final double[][][] newCovMatArrays = new double[numComponents][][];
        for (int j = 0; j < numComponents; ++j) {
            newCovMats[j] = newCovMats[j].scalarMultiply(1.0 / gammaSums[j]);
            newCovMatArrays[j] = newCovMats[j].getData();
        }
        return new MixtureMultivariateNormalDistribution(newWeights, newMeans, newCovMatArrays);
    }

    /**
     * Runs {@link #fit(MixtureMultivariateNormalDistribution, int, double)} with
     * {@link #DEFAULT_MAX_ITERATIONS} and {@link #DEFAULT_THRESHOLD}.
     *
     * @param initialMixture starting model
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture) throws SingularMatrixException, NotStrictlyPositiveException {
        this.fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    /**
     * Builds a starting mixture for {@link #fit} by sorting the rows by their mean value
     * and splitting them into {@code numComponents} contiguous, near-equal bins.
     * <p>
     * Each bin contributes one component with weight {@code 1 / numComponents}, whose mean
     * and covariance are the sample mean and covariance of that bin. A bin containing a
     * single row may be rejected by {@code Covariance}; this is not checked here.
     *
     * @param data2 observations, one row per observation; all rows must have the same length
     * @param numComponents number of mixture components; between 2 and {@code data2.length}
     * @return the initial mixture
     * @throws NotStrictlyPositiveException if {@code data2} has fewer than 2 rows
     * @throws NumberIsTooSmallException if {@code numComponents < 2}
     * @throws NumberIsTooLargeException if {@code numComponents > data2.length}
     * @throws DimensionMismatchException if the rows do not all have the same length
     */
    public static MixtureMultivariateNormalDistribution estimate(double[][] data2, int numComponents) throws NotStrictlyPositiveException, DimensionMismatchException {
        if (data2.length < 2) {
            throw new NotStrictlyPositiveException(data2.length);
        }
        if (numComponents < 2) {
            throw new NumberIsTooSmallException(numComponents, (Number) 2, true);
        }
        if (numComponents > data2.length) {
            throw new NumberIsTooLargeException(numComponents, (Number) data2.length, true);
        }
        final int numRows = data2.length;
        final int numCols = data2[0].length;

        final DataRow[] sortedData = new DataRow[numRows];
        for (int i = 0; i < numRows; ++i) {
            if (data2[i].length != numCols) {
                throw new DimensionMismatchException(data2[i].length, numCols);
            }
            sortedData[i] = new DataRow(data2[i]);
        }
        Arrays.sort(sortedData);

        final double weight = 1.0 / (double) numComponents;
        final List<Pair<Double, MultivariateNormalDistribution>> components = new ArrayList<Pair<Double, MultivariateNormalDistribution>>(numComponents);
        for (int binIndex = 0; binIndex < numComponents; ++binIndex) {
            // Computed in long: binIndex * numRows can exceed Integer.MAX_VALUE for large inputs.
            final int minIndex = (int) ((long) binIndex * numRows / numComponents);
            final int maxIndex = (int) ((long) (binIndex + 1) * numRows / numComponents);
            final int numBinRows = maxIndex - minIndex;
            final double[][] binData = new double[numBinRows][numCols];
            final double[] columnMeans = new double[numCols];
            for (int i = minIndex, iBin = 0; i < maxIndex; ++i, ++iBin) {
                final double[] row = sortedData[i].getRow();
                for (int j = 0; j < numCols; ++j) {
                    final double val = row[j];
                    columnMeans[j] += val;
                    binData[iBin][j] = val;
                }
            }
            MathArrays.scaleInPlace(1.0 / (double) numBinRows, columnMeans);
            final double[][] covMat = new Covariance(binData).getCovarianceMatrix().getData();
            final MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(columnMeans, covMat);
            components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
        }
        return new MixtureMultivariateNormalDistribution(components);
    }

    /**
     * Returns the average (per-observation) log-likelihood computed in the last E-step of {@link #fit}.
     * Before any call to {@code fit}, this is {@code 0.0}.
     *
     * @return the average log-likelihood
     */
    public double getLogLikelihood() {
        return this.logLikelihood;
    }

    /**
     * Returns a copy of the mixture produced by the most recent {@link #fit}.
     * NOTE(review): throws {@code NullPointerException} if {@code fit} has never been called.
     *
     * @return a new mixture with the fitted components
     */
    public MixtureMultivariateNormalDistribution getFittedModel() {
        return new MixtureMultivariateNormalDistribution(this.fittedModel.getComponents());
    }

    /**
     * One observation, ordered by the arithmetic mean of its entries. Used by
     * {@link #estimate(double[][], int)} to sort rows before binning.
     * <p>
     * The ordering is not consistent with {@link #equals(Object)}: distinct rows with equal
     * means compare as 0. Equality and hashing both use the exact contents of the row.
     */
    private static class DataRow implements Comparable<DataRow> {
        /** The wrapped row (not copied). */
        private final double[] row;
        /** Arithmetic mean of {@link #row}, used as the sort key. */
        private final double mean;

        DataRow(double[] data2) {
            this.row = data2;
            double sum = 0.0;
            for (int i = 0; i < data2.length; ++i) {
                sum += data2[i];
            }
            this.mean = sum / (double) data2.length;
        }

        @Override
        public int compareTo(DataRow other) {
            return Double.compare(this.mean, other.mean);
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (other instanceof DataRow) {
                // Exact comparison, matching Arrays.hashCode below (a ULP-tolerant check would break the contract).
                return Arrays.equals(this.row, ((DataRow) other).row);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(this.row);
        }

        /** @return the wrapped row (not a copy) */
        public double[] getRow() {
            return this.row;
        }
    }
}

