/*
 * Decompiled with CFR 0.152.
 */
package org.hipparchus.optim.nonlinear.vector.constrained;

import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.DecompositionSolver;
import org.hipparchus.linear.EigenDecompositionSymmetric;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.optim.nonlinear.vector.constrained.ADMMQPSolution;
import org.hipparchus.optim.nonlinear.vector.constrained.KarushKuhnTuckerSolver;
import org.hipparchus.util.FastMath;

/**
 * Linear-system solver used by the ADMM quadratic-programming iteration.
 * <p>
 * The instance assembles the symmetric block matrix
 * </p>
 * <pre>
 *     [ H + sigma * I      A^T    ]
 *     [       A          -R^(-1)  ]
 * </pre>
 * <p>
 * factorizes it once, and then reuses the factorization in {@link #solve(RealVector, RealVector)}
 * and {@link #iterate(RealVector...)}. Nothing is set up by the constructor:
 * {@link #initialize(RealMatrix, RealMatrix, RealVector, int, RealVector, RealVector, double, double, double)}
 * must be called before any other method, otherwise the fields are still {@code null}.
 * </p>
 */
public class ADMMQPKKT
implements KarushKuhnTuckerSolver<ADMMQPSolution> {

    /**
     * Factor applied to {@code rho} for the first {@code me} rows of the penalty matrix
     * (presumably the equality-constraint rows; confirm against the caller).
     */
    private static final double LEADING_ROWS_PENALTY_FACTOR = 1000.0;

    /** Quadratic-term matrix given to {@code initialize}, with {@link #sigma} added on its diagonal. */
    private RealMatrix H;

    /** Copy of the vector given as {@code newQ}; subtracted when building the right-hand side in {@code iterate}. */
    private RealVector q;

    /** Copy of the matrix given as {@code newA}; forms the off-diagonal blocks of {@link #M}. */
    private RealMatrix A;

    /** Diagonal shift currently contained in {@link #H}; also scales the first vector in {@code iterate}. */
    private double sigma;

    /** Diagonal penalty matrix built by {@code createPenaltyMatrix}. */
    private RealMatrix R;

    /** Inverse of {@link #R}. */
    private RealMatrix Rinv;

    /** Lower bounds used to clamp the third vector in {@code iterate}. */
    private RealVector lb;

    /** Upper bounds used to clamp the third vector in {@code iterate}. */
    private RealVector ub;

    /** Weight given to the freshly solved values in {@code iterate}; the previous values get {@code 1 - alpha}. */
    private double alpha;

    /** Assembled block matrix (see class description). */
    private RealMatrix M;

    /** Solver obtained from the symmetric eigen decomposition of {@link #M}. */
    private DecompositionSolver dsX;

    /**
     * Creates an empty instance.
     * <p>
     * No field is initialized here; {@code initialize} must be called before use.
     * </p>
     */
    ADMMQPKKT() {
        // nothing initialized yet
    }

    /**
     * Solves the factorized block system for the stacked right-hand side {@code [b1; b2]}.
     *
     * @param b1 upper part of the right-hand side
     * @param b2 lower part of the right-hand side
     * @return solution whose first component holds the leading {@code b1.getDimension()} entries of the
     *         result and whose second component holds the following {@code b2.getDimension()} entries
     */
    @Override
    public ADMMQPSolution solve(RealVector b1, RealVector b2) {
        final int upperSize = b1.getDimension();
        // Stack through plain arrays so that any RealVector implementation is accepted
        // (the previous code cast b1 to ArrayRealVector and failed for other types).
        final RealVector rhs = new ArrayRealVector(b1.toArray(), b2.toArray());
        final RealVector stacked = this.dsX.solve(rhs);
        return new ADMMQPSolution(stacked.getSubVector(0, upperSize),
                                  stacked.getSubVector(upperSize, b2.getDimension()));
    }

    /**
     * Replaces sigma and rho, then rebuilds and refactorizes the block matrix.
     * <p>
     * Only the difference between the new and the current sigma is added to the diagonal of
     * {@code H}, so repeated calls do not accumulate the shift.
     * </p>
     *
     * @param newSigma new diagonal shift
     * @param me number of leading rows of {@code A} whose penalty is {@code rho * 1000}
     * @param rho new penalty value
     */
    public void updateSigmaRho(double newSigma, int me, double rho) {
        // H already contains the current sigma: shift by the difference only.
        final double sigmaDelta = newSigma - this.sigma;
        final RealMatrix identity = MatrixUtils.createRealIdentityMatrix(this.H.getColumnDimension());
        this.H = this.H.add(identity.scalarMultiply(sigmaDelta));
        this.sigma = newSigma;
        this.createPenaltyMatrix(me, rho);
        this.assembleKktSystem();
    }

    /**
     * Stores the problem data and builds the first factorization of the block matrix.
     *
     * @param newH square matrix of the quadratic term (not modified; a shifted copy is stored)
     * @param newA matrix forming the off-diagonal blocks (copied)
     * @param newQ vector subtracted in the right-hand side of each iteration (copied)
     * @param me number of leading rows of {@code newA} whose penalty is {@code rho * 1000}
     * @param newLb lower bounds (stored by reference)
     * @param newUb upper bounds (stored by reference)
     * @param rho penalty value
     * @param newSigma diagonal shift added to {@code newH}
     * @param newAlpha weight of the freshly solved values in {@code iterate}
     */
    public void initialize(RealMatrix newH, RealMatrix newA, RealVector newQ, int me, RealVector newLb, RealVector newUb, double rho, double newSigma, double newAlpha) {
        this.lb = newLb;
        this.ub = newUb;
        this.alpha = newAlpha;
        this.sigma = newSigma;
        final RealMatrix identity = MatrixUtils.createRealIdentityMatrix(newH.getColumnDimension());
        this.H = newH.add(identity.scalarMultiply(newSigma));
        this.A = newA.copy();
        this.q = newQ.copy();
        this.createPenaltyMatrix(me, rho);
        // The shifted matrix (not the raw newH) goes into the system, matching both
        // updateSigmaRho and the sigma * x term of the right-hand side built in iterate.
        this.assembleKktSystem();
    }

    /**
     * Assembles the block matrix from the current {@code H}, {@code A} and {@code Rinv}
     * and refreshes the decomposition solver.
     */
    private void assembleKktSystem() {
        final int n = this.H.getRowDimension();
        final int total = n + this.A.getRowDimension();
        this.M = MatrixUtils.createRealMatrix(total, total);
        this.M.setSubMatrix(this.H.getData(), 0, 0);
        this.M.setSubMatrix(this.A.getData(), n, 0);
        this.M.setSubMatrix(this.A.transpose().getData(), 0, n);
        this.M.setSubMatrix(this.Rinv.scalarMultiply(-1.0).getData(), n, n);
        this.dsX = new EigenDecompositionSymmetric(this.M).getSolver();
    }

    /**
     * Builds the diagonal penalty matrix {@code R} and its inverse.
     * <p>
     * The first {@code me} diagonal entries are {@code rho * 1000}, the remaining ones are {@code rho}.
     * </p>
     *
     * @param me number of leading rows receiving the larger penalty
     * @param rho base penalty value
     */
    private void createPenaltyMatrix(int me, double rho) {
        final int rows = this.A.getRowDimension();
        this.R = MatrixUtils.createRealIdentityMatrix(rows);
        for (int i = 0; i < rows; ++i) {
            final double penalty = i < me ? rho * LEADING_ROWS_PENALTY_FACTOR : rho;
            this.R.setEntry(i, i, penalty);
        }
        this.Rinv = MatrixUtils.inverse(this.R);
    }

    /**
     * Performs one iteration starting from the three vectors in {@code previousSol}.
     * <p>
     * Side effects on the argument: slots 0 and 1 of the array are replaced by new vectors,
     * and the vector in slot 2 is overwritten in place with its clamped update.
     * </p>
     *
     * @param previousSol at least three vectors; they are read as {@code x}, {@code y} and {@code z}
     *        of the previous iterate, in that order
     * @return solution built from the updated first vector, the second component of the linear
     *         solve, and the updated second and third vectors
     */
    @Override
    public ADMMQPSolution iterate(RealVector ... previousSol) {
        final double oneMinusAlpha = 1.0 - this.alpha;
        final RealVector xold = previousSol[0].copy();
        final RealVector yold = previousSol[1].copy();
        // The copy is required: previousSol[2] is overwritten in place below.
        final RealVector zold = previousSol[2].copy();

        // Right-hand side of the block system.
        final RealVector b1 = previousSol[0].mapMultiply(this.sigma).subtract(this.q);
        final RealVector b2 = previousSol[2].subtract(this.Rinv.operate(previousSol[1]));
        final ADMMQPSolution sol = this.solve(b1, b2);
        final RealVector xtilde = sol.getX();
        final RealVector vtilde = sol.getV();
        final RealVector ztilde = zold.add(this.Rinv.operate(vtilde.subtract(yold)));

        // Blend the solved values with the previous ones (weights alpha and 1 - alpha).
        previousSol[0] = xtilde.mapMultiply(this.alpha).add(xold.mapMultiply(oneMinusAlpha));
        final RealVector zpartial = ztilde.mapMultiply(this.alpha)
                                          .add(zold.mapMultiply(oneMinusAlpha))
                                          .add(this.Rinv.operate(yold));

        // Clamp every entry to [lb, ub], writing into the caller's vector.
        for (int j = 0; j < previousSol[2].getDimension(); ++j) {
            final double clamped = FastMath.min(FastMath.max(zpartial.getEntry(j), this.lb.getEntry(j)),
                                                this.ub.getEntry(j));
            previousSol[2].setEntry(j, clamped);
        }

        final RealVector ytilde = ztilde.mapMultiply(this.alpha)
                                        .add(zold.mapMultiply(oneMinusAlpha).subtract(previousSol[2]));
        previousSol[1] = yold.add(this.R.operate(ytilde));
        return new ADMMQPSolution(previousSol[0], vtilde, previousSol[1], previousSol[2]);
    }
}

