/*
 * Decompiled with CFR 0.152.
 */
package nom.bdezonia.zorbage.algorithm;

import nom.bdezonia.zorbage.sampling.IntegerIndex;
import nom.bdezonia.zorbage.sampling.SamplingCartesianIntegerGrid;
import nom.bdezonia.zorbage.sampling.SamplingIterator;
import nom.bdezonia.zorbage.type.algebra.Addition;
import nom.bdezonia.zorbage.type.algebra.Algebra;
import nom.bdezonia.zorbage.type.algebra.TensorMember;

/**
 * Tensor contraction: sums a tensor over a pair of its axes, producing a tensor of rank two less.
 */
public class TensorContract {

    private TensorContract() {
    }

    /**
     * Contracts tensor {@code a} along axes {@code i} and {@code j} and writes the result into {@code b}.
     * <p>
     * Each element of the result is the sum of the elements of {@code a} whose {@code i}-th and
     * {@code j}-th coordinates are equal and whose remaining coordinates equal the result's
     * coordinates (the remaining axes keep their original order). The result has rank
     * {@code aRank - 2}; when that is 0 a single value is stored at the empty index.
     * {@code b.alloc(...)} is called with the result dimensions before any value is written.
     *
     * @param numberAlg algebra used to construct, zero and add element values
     * @param aRank rank of {@code a}
     * @param i first contraction axis, {@code 0 <= i < aRank}
     * @param j second contraction axis, {@code 0 <= j < aRank}, different from {@code i}
     * @param a source tensor; only read
     * @param b destination tensor; must not be the same object as {@code a}
     * @throws IllegalArgumentException if {@code i == j}, either axis is negative or not below
     *     {@code aRank}, {@code a == b}, or the two contracted axes have different lengths
     */
    public static <M extends Algebra<M, NUMBER> & Addition<NUMBER>, NUMBER> void compute(M numberAlg, Integer aRank, Integer i, Integer j, TensorMember<NUMBER> a, TensorMember<NUMBER> b) {
        // Unbox once: comparing two Integer objects with == compares references, which only
        // gives the right answer by accident for values inside the Integer cache.
        final int rank = aRank;
        final int axisI = i;
        final int axisJ = j;
        if (axisI == axisJ) {
            throw new IllegalArgumentException("cannot contract along a single axis");
        }
        if (axisI < 0 || axisJ < 0) {
            throw new IllegalArgumentException("negative contraction indices given");
        }
        if (axisI >= rank || axisJ >= rank) {
            throw new IllegalArgumentException("contraction indices cannot be out of bounds of input tensor's rank");
        }
        if (a == b) {
            throw new IllegalArgumentException("src cannot equal dest: contraction is not an in place operation");
        }
        // The diagonal being summed only exists when both contracted axes have the same length.
        final long len = a.dimension(axisI);
        if (a.dimension(axisJ) != len) {
            throw new IllegalArgumentException("contraction axes must have the same length");
        }

        // The result keeps every axis except i and j, in their original order.
        final int newRank = rank - 2;
        final long[] newDims = new long[newRank];
        int q = 0;
        for (int r = 0; r < rank; ++r) {
            if (r != axisI && r != axisJ) {
                newDims[q++] = a.dimension(r);
            }
        }
        b.alloc(newDims);

        NUMBER sum = numberAlg.construct();
        NUMBER tmp = numberAlg.construct();

        if (newRank == 0) {
            // Rank 2 input (trace): only axes i and j exist, so the diagonal sum is the whole result.
            IntegerIndex pos = new IntegerIndex(rank);
            sumDiagonal(numberAlg, a, pos, axisI, axisJ, len, sum, tmp);
            b.setV(new IntegerIndex(0), sum);
            return;
        }

        // Visit every coordinate of the result tensor.
        IntegerIndex point1 = new IntegerIndex(newRank);
        IntegerIndex point2 = new IntegerIndex(newRank);
        for (int k = 0; k < newRank; ++k) {
            point2.set(k, newDims[k] - 1L);
        }
        SamplingCartesianIntegerGrid sampling = new SamplingCartesianIntegerGrid(point1, point2);
        SamplingIterator<IntegerIndex> iter = sampling.iterator();
        IntegerIndex contractedPos = new IntegerIndex(sampling.numDimensions());
        IntegerIndex origPos = new IntegerIndex(rank);
        while (iter.hasNext()) {
            iter.next(contractedPos);
            // Copy the result coordinate into the non-contracted axes of the source index;
            // sumDiagonal fills in axes i and j.
            int p = 0;
            for (int r = 0; r < rank; ++r) {
                if (r != axisI && r != axisJ) {
                    origPos.set(r, contractedPos.get(p++));
                }
            }
            sumDiagonal(numberAlg, a, origPos, axisI, axisJ, len, sum, tmp);
            b.setV(contractedPos, sum);
        }
    }

    /**
     * Sets {@code sum} to the total of {@code a} at {@code pos} as coordinates {@code axisI} and
     * {@code axisJ} both run together over {@code 0 .. len-1}. The other entries of {@code pos}
     * must already be set; entries {@code axisI} and {@code axisJ} are overwritten.
     * {@code tmp} is scratch space.
     */
    private static <M extends Algebra<M, NUMBER> & Addition<NUMBER>, NUMBER> void sumDiagonal(M numberAlg, TensorMember<NUMBER> a, IntegerIndex pos, int axisI, int axisJ, long len, NUMBER sum, NUMBER tmp) {
        numberAlg.zero().call(sum);
        for (long k = 0; k < len; ++k) {
            pos.set(axisI, k);
            pos.set(axisJ, k);
            a.v(pos, tmp);
            numberAlg.add().call(sum, tmp, sum);
        }
    }
}

