/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.solver;

import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.PlusMult;
import org.apache.mahout.math.solver.Preconditioner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Solves the linear system {@code A x = b} with the (optionally preconditioned) conjugate gradient method.
 *
 * <p>{@code A} must be square (checked) and is expected to be symmetric and positive definite (not checked).
 * Every call to {@code solve} starts from the zero vector and records the number of iterations performed and
 * the final residual norm, which can be read back via {@link #getIterations()} and {@link #getResidualNorm()}.
 *
 * <p>A single instance is not safe for concurrent use: each {@code solve} call writes the instance's
 * iteration count and residual fields.
 */
public class ConjugateGradientSolver {
    /** Default convergence threshold on the Euclidean norm of the residual {@code b - A x}. */
    public static final double DEFAULT_MAX_ERROR = 1.0E-9;
    private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class);
    /** Number of iterations performed by the most recent {@code solve} call. */
    private int iterations = 0;
    /** Squared residual norm after the most recent {@code solve} call; NaN before the first call. */
    private double residualNormSquared = Double.NaN;

    /**
     * Solves {@code a x = b} without preconditioning, allowing at most {@code b.size()} iterations and using
     * {@link #DEFAULT_MAX_ERROR} as the residual tolerance.
     *
     * @param a square, symmetric, positive definite matrix
     * @param b right-hand side; its size must equal the number of columns of {@code a}
     * @return the approximate solution {@code x}
     */
    public Vector solve(VectorIterable a, Vector b) {
        return this.solve(a, b, null, b.size(), DEFAULT_MAX_ERROR);
    }

    /**
     * Solves {@code a x = b} with the given preconditioner, allowing at most {@code b.size()} iterations and
     * using {@link #DEFAULT_MAX_ERROR} as the residual tolerance.
     *
     * @param a square, symmetric, positive definite matrix
     * @param b right-hand side; its size must equal the number of columns of {@code a}
     * @param precond preconditioner applied to the residual each iteration, or {@code null} for none
     * @return the approximate solution {@code x}
     */
    public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
        return this.solve(a, b, precond, b.size(), DEFAULT_MAX_ERROR);
    }

    /**
     * Solves {@code a x = b} with the conjugate gradient method, starting from {@code x = 0}.
     *
     * <p>Iteration stops when the Euclidean norm of the residual {@code b - a x} is at most {@code maxError}
     * or when {@code maxIterations} iterations have been performed, whichever comes first.
     *
     * @param a square, symmetric, positive definite matrix
     * @param b right-hand side; its size must equal the number of columns of {@code a}
     * @param preconditioner preconditioner applied to the residual each iteration, or {@code null} for none
     * @param maxIterations upper bound on the number of iterations; must be positive
     * @param maxError residual-norm tolerance; must be non-negative (NaN is rejected)
     * @return the approximate solution {@code x}
     * @throws IllegalArgumentException if {@code a} is not square, {@code maxIterations <= 0}, or
     *     {@code maxError} is negative or NaN
     * @throws CardinalityException if the size of {@code b} does not match the column count of {@code a}
     */
    public Vector solve(VectorIterable a, Vector b, Preconditioner preconditioner, int maxIterations, double maxError) {
        if (a.numRows() != a.numCols()) {
            throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
        }
        if (a.numCols() != b.size()) {
            throw new CardinalityException(a.numCols(), b.size());
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Max iterations must be positive.");
        }
        // Written as !(>= 0) so NaN is rejected too; a NaN tolerance would make the loop condition false and
        // silently return the zero vector.
        if (!(maxError >= 0.0)) {
            throw new IllegalArgumentException("Max error must be non-negative.");
        }
        // PlusMult is mutable (its multiplicator is reset twice per iteration), so each call gets its own
        // instance; a shared static one lets concurrent solves corrupt each other's updates.
        PlusMult plusMult = new PlusMult(1.0);
        DenseVector x = new DenseVector(b.size());
        this.iterations = 0;
        Vector residual = b.minus(a.times(x));
        this.residualNormSquared = residual.dot(residual);
        log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(this.residualNormSquared));
        double previousConditionedNormSqr = 0.0;
        DenseVector updateDirection = null;
        while (Math.sqrt(this.residualNormSquared) > maxError && this.iterations < maxIterations) {
            // z = preconditioned residual (or r itself when unpreconditioned); conditionedNormSqr = r . z.
            double conditionedNormSqr;
            Vector conditionedResidual;
            if (preconditioner == null) {
                // Note: this aliases residual, which is updated in place further down.
                conditionedResidual = residual;
                conditionedNormSqr = this.residualNormSquared;
            } else {
                conditionedResidual = preconditioner.precondition(residual);
                conditionedNormSqr = residual.dot(conditionedResidual);
            }
            ++this.iterations;
            if (this.iterations == 1) {
                // First search direction is the conditioned residual, held in its own DenseVector so the
                // in-place updates below operate on updateDirection rather than on the residual reference.
                updateDirection = new DenseVector(conditionedResidual);
            } else {
                // p = z + beta * p with beta = (r_k . z_k) / (r_{k-1} . z_{k-1}).
                double beta = conditionedNormSqr / previousConditionedNormSqr;
                updateDirection.assign(Functions.MULT, beta);
                updateDirection.assign(conditionedResidual, Functions.PLUS);
            }
            Vector aTimesUpdate = a.times(updateDirection);
            // Step length alpha = (r . z) / (p . A p). The denominator is not checked here; if A is not
            // positive definite it may be zero or negative and the iterate can become non-finite.
            double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);
            // x += alpha * p
            plusMult.setMultiplicator(alpha);
            x.assign(updateDirection, plusMult);
            // r -= alpha * A p
            plusMult.setMultiplicator(-alpha);
            residual.assign(aTimesUpdate, plusMult);
            previousConditionedNormSqr = conditionedNormSqr;
            this.residualNormSquared = residual.dot(residual);
            log.info("Conjugate gradient iteration {} residual norm = {}",
                    this.iterations, Math.sqrt(this.residualNormSquared));
        }
        return x;
    }

    /**
     * Returns the number of iterations performed by the most recent {@code solve} call (0 before any call).
     *
     * @return iteration count of the last solve
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Returns the Euclidean norm of the residual {@code b - A x} after the most recent {@code solve} call.
     *
     * @return the final residual norm, or NaN if {@code solve} has not been called yet
     */
    public double getResidualNorm() {
        return Math.sqrt(this.residualNormSquared);
    }
}

