/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.KMerCounter;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.Kmer;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.ReadErrorCorrector;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.clipping.ReadClipper;
import org.broadinstitute.hellbender.utils.read.GATKRead;

/**
 * Read error corrector that rewrites rarely observed kmers to a nearby kmer from the same read set.
 *
 * <p>Usage: feed reads through {@link #addReadsToKmers(Collection)} to populate the kmer counts, then call
 * {@link #correctReads(Collection)}. Kmers observed at least {@code minObservationsForKmerToBeSolid} times are
 * treated as solid and never rewritten. Kmers observed at most {@code maxObservationsForKmerToBeCorrectable}
 * times are mapped to the closest other counted kmer within {@code maxMismatchesToCorrect} mismatches. A base in
 * a read is changed only when every correction proposed for that position agrees on the replacement base. The
 * changed base is then given quality {@code qualityOfCorrectedBases}.
 */
public final class NearbyKmerErrorCorrector
implements ReadErrorCorrector {
    private static final Logger logger = LogManager.getLogger(NearbyKmerErrorCorrector.class);

    /** Occurrence counts of every kmer added through {@link #addReadsToKmers(Collection)}. */
    final KMerCounter countsByKMer;
    /** Solid kmers map to themselves; correctable kmers map to their nearest neighbor. */
    private final Map<Kmer, Kmer> kmerCorrectionMap = new HashMap<>();
    /** Same keys as kmerCorrectionMap: (offsets within the kmer, replacement bases) describing the rewrite. */
    private final Map<Kmer, Pair<int[], byte[]>> kmerDifferingBases = new HashMap<>();
    private final int kmerLength;
    private final boolean debug;
    private final boolean trimLowQualityBases;
    private final byte minTailQuality;
    private final int maxMismatchesToCorrect;
    private final byte qualityOfCorrectedBases;
    private final int maxObservationsForKmerToBeCorrectable;
    // Computed in the constructor but not read anywhere else in this class.
    private final int maxHomopolymerLengthInRegion;
    private final int minObservationsForKmerToBeSolid;

    // NOTE(review): doInplaceErrorCorrection, DONT_CORRECT_IN_LONG_HOMOPOLYMERS and MAX_HOMOPOLYMER_THRESHOLD
    // are not referenced by this class.
    private static final boolean doInplaceErrorCorrection = false;
    private static final int MAX_MISMATCHES_TO_CORRECT = 2;
    private static final byte QUALITY_OF_CORRECTED_BASES = 30;
    private static final int MAX_OBSERVATIONS_FOR_KMER_TO_BE_CORRECTABLE = 1;
    private static final boolean TRIM_LOW_QUAL_TAILS = false;
    private static final boolean DONT_CORRECT_IN_LONG_HOMOPOLYMERS = false;
    private static final int MAX_HOMOPOLYMER_THRESHOLD = 12;
    private final ReadErrorCorrectionStats readErrorCorrectionStats = new ReadErrorCorrectionStats();

    /**
     * Creates a fully configured corrector.
     *
     * @param kmerLength length of the kmers counted and corrected; must be > 0
     * @param maxMismatchesToCorrect maximum Hamming distance to a replacement kmer; must be >= 1
     * @param maxObservationsForKmerToBeCorrectable kmers seen more often than this are never rewritten
     * @param qualityOfCorrectedBases base quality assigned to corrected bases; must be in [2, 60]
     * @param minObservationsForKmerToBeSolid kmers seen at least this often are treated as solid
     * @param trimLowQualityBases whether to hard-clip low-quality read ends after correction
     * @param minTailQuality quality threshold used when clipping read ends
     * @param debug whether to log kmer counts and correction statistics
     * @param fullReferenceWithPadding reference bases for the region; only used to compute the longest homopolymer run
     * @throws IllegalArgumentException if an argument violates the constraints above
     */
    public NearbyKmerErrorCorrector(int kmerLength, int maxMismatchesToCorrect, int maxObservationsForKmerToBeCorrectable, byte qualityOfCorrectedBases, int minObservationsForKmerToBeSolid, boolean trimLowQualityBases, byte minTailQuality, boolean debug, byte[] fullReferenceWithPadding) {
        Utils.validateArg(kmerLength > 0, () -> "kmerLength must be > 0 but got " + kmerLength);
        Utils.validateArg(maxMismatchesToCorrect > 0, () -> "maxMismatchesToCorrect must be >= 1 but got " + maxMismatchesToCorrect);
        Utils.validateArg(qualityOfCorrectedBases >= 2 && qualityOfCorrectedBases <= 60, () -> "qualityOfCorrectedBases must be >= 2 and <= MAX_REASONABLE_Q_SCORE but got " + qualityOfCorrectedBases);
        this.countsByKMer = new KMerCounter(kmerLength);
        this.kmerLength = kmerLength;
        this.maxMismatchesToCorrect = maxMismatchesToCorrect;
        this.qualityOfCorrectedBases = qualityOfCorrectedBases;
        this.minObservationsForKmerToBeSolid = minObservationsForKmerToBeSolid;
        this.trimLowQualityBases = trimLowQualityBases;
        this.minTailQuality = minTailQuality;
        this.debug = debug;
        this.maxObservationsForKmerToBeCorrectable = maxObservationsForKmerToBeCorrectable;
        this.maxHomopolymerLengthInRegion = NearbyKmerErrorCorrector.computeMaxHLen(fullReferenceWithPadding);
    }

    /**
     * Creates a corrector using the class defaults for mismatches, correctable-observation limit,
     * corrected-base quality and tail trimming.
     */
    public NearbyKmerErrorCorrector(int kmerLength, byte minTailQuality, int minObservationsForKmerToBeSolid, boolean debug, byte[] fullReferenceWithPadding) {
        // Named constants (not bare int literals): a literal int cannot be narrowed to the byte parameter here.
        this(kmerLength, MAX_MISMATCHES_TO_CORRECT, MAX_OBSERVATIONS_FOR_KMER_TO_BE_CORRECTABLE, QUALITY_OF_CORRECTED_BASES,
                minObservationsForKmerToBeSolid, TRIM_LOW_QUAL_TAILS, minTailQuality, debug, fullReferenceWithPadding);
    }

    /** Counts every kmer of length kmerLength that starts within the read's bases. */
    protected void addReadKmers(GATKRead read) {
        Utils.nonNull(read);
        byte[] readBases = read.getBases();
        for (int offset = 0; offset <= readBases.length - this.kmerLength; ++offset) {
            this.countsByKMer.addKmer(new Kmer(readBases, offset, this.kmerLength), 1);
        }
    }

    /**
     * Builds the kmer correction map from the counts gathered so far, then corrects each read.
     *
     * @return one read per input read, in input order: a corrected copy, or the input read itself when nothing
     *         changed, hard-clipped of low-quality ends when trimLowQualityBases is set
     */
    @Override
    public List<GATKRead> correctReads(Collection<GATKRead> reads) {
        List<GATKRead> correctedReads = new ArrayList<>(reads.size());
        this.computeKmerCorrectionMap();
        for (GATKRead read : reads) {
            GATKRead correctedRead = this.correctRead(read);
            correctedReads.add(this.trimLowQualityBases
                    ? ReadClipper.hardClipLowQualEnds(correctedRead, this.minTailQuality)
                    : correctedRead);
        }
        if (this.debug) {
            logger.info("Number of corrected bases:{}", this.readErrorCorrectionStats.numBasesCorrected);
            logger.info("Number of corrected reads:{}", this.readErrorCorrectionStats.numReadsCorrected);
            logger.info("Number of skipped reads:{}", this.readErrorCorrectionStats.numReadsUncorrected);
            logger.info("Number of solid kmers:{}", this.readErrorCorrectionStats.numSolidKmers);
            logger.info("Number of corrected kmers:{}", this.readErrorCorrectionStats.numCorrectedKmers);
            logger.info("Number of uncorrectable kmers:{}", this.readErrorCorrectionStats.numUncorrectableKmers);
        }
        return correctedReads;
    }

    /**
     * Applies consensus kmer corrections to one read.
     *
     * @return a deep copy carrying the corrected bases/qualities, or {@code inputRead} itself if no base changed
     */
    private GATKRead correctRead(GATKRead inputRead) {
        Utils.nonNull(inputRead);
        // NOTE(review): assumes getBases()/getBaseQualities() return copies; both arrays are modified below.
        byte[] correctedBases = inputRead.getBases();
        byte[] correctedQuals = inputRead.getBaseQualities();
        CorrectionSet correctionSet = this.buildCorrectionMap(correctedBases);
        boolean corrected = false;
        for (int offset = 0; offset < correctedBases.length; ++offset) {
            Byte b = correctionSet.getConsensusCorrection(offset);
            if (b == null || b == correctedBases[offset]) {
                continue;
            }
            correctedBases[offset] = b;
            // Guard against a quality array shorter than the bases rather than throwing.
            if (offset < correctedQuals.length) {
                correctedQuals[offset] = this.qualityOfCorrectedBases;
            }
            corrected = true;
            // Count only bases that were actually changed.
            ++this.readErrorCorrectionStats.numBasesCorrected;
        }
        if (!corrected) {
            ++this.readErrorCorrectionStats.numReadsUncorrected;
            return inputRead;
        }
        ++this.readErrorCorrectionStats.numReadsCorrected;
        GATKRead correctedRead = inputRead.deepCopy();
        correctedRead.setBaseQualities(correctedQuals);
        correctedRead.setBases(correctedBases);
        correctedRead.setReadGroup(inputRead.getReadGroup());
        return correctedRead;
    }

    /**
     * Collects, per read position, the replacement bases proposed by every kmer of the read that maps to a
     * different kmer in the correction map.
     */
    private CorrectionSet buildCorrectionMap(byte[] correctedBases) {
        Utils.nonNull(correctedBases);
        CorrectionSet correctionSet = new CorrectionSet(correctedBases.length);
        for (int offset = 0; offset <= correctedBases.length - this.kmerLength; ++offset) {
            Kmer kmer = new Kmer(correctedBases, offset, this.kmerLength);
            Kmer newKmer = this.kmerCorrectionMap.get(kmer);
            if (newKmer == null || newKmer.equals(kmer)) {
                continue;
            }
            // Both maps are populated together, so a non-null newKmer implies an entry here.
            Pair<int[], byte[]> differingPositions = this.kmerDifferingBases.get(kmer);
            int[] differingIndices = differingPositions.getLeft();
            byte[] differingBases = differingPositions.getRight();
            for (int k = 0; k < differingIndices.length; ++k) {
                correctionSet.add(offset + differingIndices[k], differingBases[k]);
            }
        }
        return correctionSet;
    }

    /** Counts the kmers of every read; when debug is on, logs each counted kmer with its count. */
    public void addReadsToKmers(Collection<GATKRead> reads) {
        Utils.nonNull(reads);
        for (GATKRead read : reads) {
            this.addReadKmers(read);
        }
        if (this.debug) {
            for (KMerCounter.CountedKmer countedKmer : this.countsByKMer.getCountedKmers()) {
                logger.info("{}\t{}", countedKmer.getKmer(), countedKmer.getCount());
            }
        }
    }

    /**
     * (Re)builds kmerCorrectionMap and kmerDifferingBases from the current counts. Solid kmers map to
     * themselves with no differing bases. Sufficiently rare kmers map to their nearest neighbor, if one exists
     * within maxMismatchesToCorrect.
     */
    private void computeKmerCorrectionMap() {
        // Start fresh so entries from an earlier call cannot survive a change in counts.
        this.kmerCorrectionMap.clear();
        this.kmerDifferingBases.clear();
        for (KMerCounter.CountedKmer storedKmer : this.countsByKMer.getCountedKmers()) {
            Kmer kmer = storedKmer.getKmer();
            if (storedKmer.getCount() >= this.minObservationsForKmerToBeSolid) {
                this.kmerCorrectionMap.put(kmer, kmer);
                this.kmerDifferingBases.put(kmer, Pair.of(new int[0], new byte[0]));
                ++this.readErrorCorrectionStats.numSolidKmers;
                continue;
            }
            if (storedKmer.getCount() > this.maxObservationsForKmerToBeCorrectable) {
                continue;
            }
            Pair<Kmer, Pair<int[], byte[]>> nearestNeighbor = this.findNearestNeighbor(kmer, this.countsByKMer, this.maxMismatchesToCorrect);
            if (nearestNeighbor == null) {
                ++this.readErrorCorrectionStats.numUncorrectableKmers;
                continue;
            }
            this.kmerCorrectionMap.put(kmer, nearestNeighbor.getLeft());
            this.kmerDifferingBases.put(kmer, nearestNeighbor.getRight());
            ++this.readErrorCorrectionStats.numCorrectedKmers;
        }
    }

    /**
     * Finds the counted kmer (other than {@code kmer} itself) with the smallest Hamming distance to {@code kmer}.
     * When several candidates tie, the first one reported by getCountedKmers() wins.
     *
     * @return (closest kmer, (offsets within the kmer, replacement bases)) with arrays exactly as long as the
     *         distance, or {@code null} when no candidate lies within {@code maxDistance}
     */
    private Pair<Kmer, Pair<int[], byte[]>> findNearestNeighbor(Kmer kmer, KMerCounter countsByKMer, int maxDistance) {
        Utils.nonNull(kmer, "kmer");
        Utils.nonNull(countsByKMer, "countsByKMer");
        Utils.validateArg(maxDistance >= 1, "maxDistance must be >= 1");
        // Scratch buffers keep the original maxDistance + 1 sizing.
        // NOTE(review): presumably getDifferingPositions may record one entry past maxDistance before returning
        // a negative value; confirm against Kmer.
        int[] scratchIndices = new int[maxDistance + 1];
        byte[] scratchBases = new byte[maxDistance + 1];
        int minimumDistance = Integer.MAX_VALUE;
        Kmer closestKmer = null;
        int[] closestIndices = null;
        byte[] closestBases = null;
        for (KMerCounter.CountedKmer candidate : countsByKMer.getCountedKmers()) {
            Kmer candidateKmer = candidate.getKmer();
            if (candidateKmer.equals(kmer)) {
                continue;
            }
            int hammingDistance = kmer.getDifferingPositions(candidateKmer, maxDistance, scratchIndices, scratchBases);
            if (hammingDistance < 0 || hammingDistance >= minimumDistance) {
                continue;
            }
            minimumDistance = hammingDistance;
            closestKmer = candidateKmer;
            // Copy only the first hammingDistance entries: the scratch buffers are reused across candidates, so
            // entries past that point are leftovers and must not become corrections.
            // NOTE(review): assumes getDifferingPositions fills entries [0, hammingDistance); confirm against Kmer.
            closestIndices = new int[hammingDistance];
            closestBases = new byte[hammingDistance];
            System.arraycopy(scratchIndices, 0, closestIndices, 0, hammingDistance);
            System.arraycopy(scratchBases, 0, closestBases, 0, hammingDistance);
        }
        if (closestKmer == null) {
            return null;
        }
        return Pair.of(closestKmer, Pair.of(closestIndices, closestBases));
    }

    /**
     * Returns the length of the longest run of identical consecutive bytes, or 0 for an empty array.
     */
    private static int computeMaxHLen(byte[] fullReferenceWithPadding) {
        Utils.nonNull(fullReferenceWithPadding);
        int maxRun = 0;
        int currentRun = 0;
        for (int i = 0; i < fullReferenceWithPadding.length; ++i) {
            boolean extendsRun = i > 0 && fullReferenceWithPadding[i] == fullReferenceWithPadding[i - 1];
            currentRun = extendsRun ? currentRun + 1 : 1;
            // Track the maximum inside the loop so runs before the last one are considered.
            maxRun = Math.max(maxRun, currentRun);
        }
        return maxRun;
    }

    /** Per-position lists of proposed replacement bases for a single read. */
    protected static class CorrectionSet {
        private final int size;
        private final List<List<Byte>> corrections;

        /** Creates an empty correction list for each of {@code size} positions. */
        public CorrectionSet(int size) {
            this.size = size;
            this.corrections = new ArrayList<>(size);
            for (int k = 0; k < size; ++k) {
                this.corrections.add(new ArrayList<>());
            }
        }

        /**
         * Records a proposed base at {@code offset}; bases that BaseUtils does not consider regular are ignored.
         *
         * @throws IllegalStateException if offset is outside [0, size)
         */
        public void add(int offset, byte base) {
            if (offset >= this.size || offset < 0) {
                throw new IllegalStateException("Bad entry into CorrectionSet: offset " + offset + " is outside [0, " + this.size + ")");
            }
            if (!BaseUtils.isRegularBase(base)) {
                return;
            }
            this.corrections.get(offset).add(base);
        }

        /** Returns the (live, mutable) list of proposed bases at {@code offset}. */
        public List<Byte> get(int offset) {
            Utils.validateArg(offset >= 0 && offset < this.size, "Illegal call of CorrectionSet.get(): offset must be < size");
            return this.corrections.get(offset);
        }

        /**
         * Returns the proposed base at {@code offset} if all proposals agree, or {@code null} if there are none
         * or they disagree. Does not modify the stored proposals, so repeated calls give the same answer.
         */
        public Byte getConsensusCorrection(int offset) {
            Utils.validateArg(offset >= 0 && offset < this.size, "Illegal call of CorrectionSet.getConsensusCorrection(): offset must be < size");
            List<Byte> storedBytes = this.corrections.get(offset);
            if (storedBytes.isEmpty()) {
                return null;
            }
            Byte consensus = storedBytes.get(storedBytes.size() - 1);
            for (Byte b : storedBytes) {
                if (!b.equals(consensus)) {
                    return null;
                }
            }
            return consensus;
        }
    }

    /** Counters reported by correctReads when debug logging is enabled. */
    private static final class ReadErrorCorrectionStats {
        public int numReadsCorrected;
        public int numReadsUncorrected;
        public int numBasesCorrected;
        public int numSolidKmers;
        public int numUncorrectableKmers;
        public int numCorrectedKmers;

        private ReadErrorCorrectionStats() {
        }
    }
}

