/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.utils.recalibration.covariates;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.SAMFileHeader;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.clipping.ClippingRepresentation;
import org.broadinstitute.hellbender.utils.clipping.ReadClipper;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.recalibration.RecalibrationArgumentCollection;
import org.broadinstitute.hellbender.utils.recalibration.covariates.Covariate;
import org.broadinstitute.hellbender.utils.recalibration.covariates.ReadCovariates;

import java.nio.charset.StandardCharsets;

/**
 * Covariate that keys each base of a read by the DNA context ending at (and including) that base.
 *
 * <p>Contexts are computed over the read's bases after low-quality tails have been clipped with
 * {@code ClippingRepresentation.WRITE_NS}, and after reverse-complementing reads on the negative strand.
 * Two independent context sizes are used: one for base substitution (mismatch) keys and one for indel keys.</p>
 *
 * <p>Key layout: the low {@code LENGTH_BITS} bits hold the context length. Base {@code i} of the context
 * (counting from its first base) occupies the two bits starting at {@code LENGTH_BITS + 2 * i} and stores the
 * value returned by {@code BaseUtils.simpleBaseToBaseIndex}. A key of {@code -1} marks a position whose
 * context is incomplete (too close to the read start) or contains a base that {@code simpleBaseToBaseIndex}
 * maps to {@code -1}.</p>
 */
public final class ContextCovariate
implements Covariate {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LogManager.getLogger(ContextCovariate.class);

    /** Number of low-order key bits that store the context length. */
    private static final int LENGTH_BITS = 4;
    /** Mask selecting the context-length bits of a key ({@code (1 << LENGTH_BITS) - 1}). */
    private static final int LENGTH_MASK = 15;
    /**
     * Largest supported context size. With {@code LENGTH_BITS} length bits and 2 bits per base, 13 bases use
     * bits 0-29; a larger context could reach the sign bit of the int key.
     */
    private static final int MAX_DNA_CONTEXT = 13;

    private final int mismatchesContextSize;
    private final int indelsContextSize;
    /** Mask over the base bits of a mismatch context key (length bits excluded). */
    private final int mismatchesKeyMask;
    /** Mask over the base bits of an indel context key (length bits excluded). */
    private final int indelsKeyMask;
    /** Quality threshold passed to {@code ReadClipper.clipLowQualEnds}. */
    private final byte lowQualTail;

    /**
     * Creates the covariate from the recalibration arguments.
     *
     * @param RAC recalibration arguments supplying the mismatch and indel context sizes and the
     *            low-quality tail threshold
     * @throws CommandLineException.BadArgumentValue if either context size exceeds {@code MAX_DNA_CONTEXT}
     * @throws CommandLineException if either context size is not positive
     */
    public ContextCovariate(final RecalibrationArgumentCollection RAC) {
        this.mismatchesContextSize = RAC.MISMATCHES_CONTEXT_SIZE;
        this.indelsContextSize = RAC.INDELS_CONTEXT_SIZE;
        logger.debug("\t\tContext sizes: base substitution model {}, indel substitution model {}",
                this.mismatchesContextSize, this.indelsContextSize);
        if (this.mismatchesContextSize > MAX_DNA_CONTEXT) {
            throw new CommandLineException.BadArgumentValue("mismatches_context_size",
                    String.format("context size cannot be bigger than %d, but was %d", MAX_DNA_CONTEXT, this.mismatchesContextSize));
        }
        if (this.indelsContextSize > MAX_DNA_CONTEXT) {
            throw new CommandLineException.BadArgumentValue("indels_context_size",
                    String.format("context size cannot be bigger than %d, but was %d", MAX_DNA_CONTEXT, this.indelsContextSize));
        }
        this.lowQualTail = RAC.LOW_QUAL_TAIL;
        if (this.mismatchesContextSize <= 0 || this.indelsContextSize <= 0) {
            throw new CommandLineException(String.format("Context size must be positive. Mismatches: %d Indels: %d",
                    this.mismatchesContextSize, this.indelsContextSize));
        }
        this.mismatchesKeyMask = createMask(this.mismatchesContextSize);
        this.indelsKeyMask = createMask(this.indelsContextSize);
    }

    /**
     * Computes the context keys for every base of {@code read} and stores them in {@code values}.
     *
     * <p>The mismatch key is always recorded. The indel key (passed in both of the middle slots of
     * {@code addCovariate}) is computed only when {@code recordIndelValues} is true; otherwise 0 is stored
     * there.</p>
     *
     * @param read              the read whose bases provide the contexts
     * @param header            unused by this covariate
     * @param values            destination for the per-offset covariate keys
     * @param recordIndelValues whether to compute indel context keys
     */
    @Override
    public void recordValues(final GATKRead read, final SAMFileHeader header, final ReadCovariates values, final boolean recordIndelValues) {
        final int originalReadLength = read.getLength();
        final byte[] strandedClippedBases = getStrandedClippedBytes(read, this.lowQualTail);
        final IntList mismatchKeys = contextWith(strandedClippedBases, this.mismatchesContextSize, this.mismatchesKeyMask);
        final int readLengthAfterClipping = strandedClippedBases.length;

        // Defensive: if clipping changed the length, the loops below would not overwrite every offset,
        // so zero all of them first to avoid leaving stale values behind.
        if (readLengthAfterClipping != originalReadLength) {
            for (int i = 0; i < originalReadLength; i++) {
                values.addCovariate(0, 0, 0, i);
            }
        }

        final boolean negativeStrand = read.isReverseStrand();
        // The loop is duplicated so recordIndelValues is not re-checked for every base.
        if (recordIndelValues) {
            final IntList indelKeys = contextWith(strandedClippedBases, this.indelsContextSize, this.indelsKeyMask);
            for (int i = 0; i < readLengthAfterClipping; i++) {
                final int readOffset = getStrandedOffset(negativeStrand, i, readLengthAfterClipping);
                final int indelKey = indelKeys.getInt(i);
                values.addCovariate(mismatchKeys.getInt(i), indelKey, indelKey, readOffset);
            }
        } else {
            for (int i = 0; i < readLengthAfterClipping; i++) {
                final int readOffset = getStrandedOffset(negativeStrand, i, readLengthAfterClipping);
                values.addCovariate(mismatchKeys.getInt(i), 0, 0, readOffset);
            }
        }
    }

    /**
     * Maps an offset in stranded (possibly reverse-complemented) base order back to an offset in the read's
     * stored base order.
     *
     * @param isNegativeStrand whether the bases were reversed
     * @param offset           offset in stranded order
     * @param readLength       number of bases
     * @return {@code readLength - offset - 1} for negative-strand reads, otherwise {@code offset}
     */
    public static int getStrandedOffset(final boolean isNegativeStrand, final int offset, final int readLength) {
        return isNegativeStrand ? readLength - offset - 1 : offset;
    }

    /**
     * Returns the read's bases after {@code ReadClipper.clipLowQualEnds} with {@code WRITE_NS}, reverse-complemented
     * when the read is on the negative strand.
     *
     * @param read     the read to clip
     * @param lowQTail quality threshold passed to the clipper
     * @return the clipped bases in stranded order
     */
    @VisibleForTesting
    static byte[] getStrandedClippedBytes(final GATKRead read, final byte lowQTail) {
        final GATKRead clippedRead = ReadClipper.clipLowQualEnds(read, lowQTail, ClippingRepresentation.WRITE_NS);
        final byte[] bases = clippedRead.getBases();
        if (read.isReverseStrand()) {
            return BaseUtils.simpleReverseComplement(bases);
        }
        return bases;
    }

    /**
     * Formats a key as its DNA context string.
     *
     * @param key a context key
     * @return {@code null} for the {@code -1} sentinel, otherwise the decoded context
     */
    @Override
    public String formatKey(final int key) {
        if (key == -1) {
            return null;
        }
        return contextFromKey(key);
    }

    /**
     * Parses a DNA context string into its key.
     *
     * @param value the context, expected to be a {@link String}
     * @return the key for the context (see {@link #keyFromContext(String)})
     */
    @Override
    public int keyFromValue(final Object value) {
        return keyFromContext((String) value);
    }

    /**
     * Builds a mask covering the base bits of a key holding {@code contextSize} bases; the length bits are
     * left clear.
     */
    private static int createMask(final int contextSize) {
        int mask = 0;
        for (int i = 0; i < contextSize; i++) {
            mask = (mask << 2) | 3;
        }
        return mask << LENGTH_BITS;
    }

    /**
     * Computes, for every position of {@code bases}, the key of the {@code contextSize}-base context ending at
     * that position, using a rolling key rather than re-encoding each window.
     *
     * @param bases       bases in stranded order
     * @param contextSize number of bases per context (positive)
     * @param mask        base-bit mask for this context size, as produced by {@link #createMask(int)}
     * @return one key per base; {@code -1} where the context is incomplete or contains an unencodable base
     */
    private static IntList contextWith(final byte[] bases, final int contextSize, final int mask) {
        final int readLength = bases.length;
        final IntList keys = new IntArrayList(readLength);

        // The first contextSize - 1 positions lack enough preceding bases for a full context.
        for (int i = 1; i < contextSize && i <= readLength; i++) {
            keys.add(-1);
        }
        if (readLength < contextSize) {
            return keys;
        }

        // Bit offset at which the newest (last) base of a context is stored.
        final int newBaseOffset = 2 * (contextSize - 1) + LENGTH_BITS;

        int currentKey = keyFromContext(bases, 0, contextSize);
        keys.add(currentKey);

        // Number of upcoming positions whose context still contains an unencodable base.
        int currentNPenalty = 0;
        if (currentKey == -1) {
            // Rebuild the partial key from the bases after the last unencodable base of the first window, and
            // set the penalty to that base's index so the contexts still covering it are reported as -1.
            currentKey = 0;
            currentNPenalty = contextSize - 1;
            int offset = newBaseOffset;
            int baseIndex;
            while ((baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentNPenalty])) != -1) {
                currentKey |= baseIndex << offset;
                offset -= 2;
                currentNPenalty--;
            }
        }

        for (int currentIndex = contextSize; currentIndex < readLength; currentIndex++) {
            final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentIndex]);
            if (baseIndex == -1) {
                // This base poisons the next contextSize contexts; restart the rolling key.
                currentNPenalty = contextSize;
                currentKey = 0;
            } else {
                // Drop the oldest base (and the length bits), append the new base, and restore the length.
                currentKey = (currentKey >> 2) & mask;
                currentKey |= baseIndex << newBaseOffset;
                currentKey |= contextSize;
            }
            if (currentNPenalty == 0) {
                keys.add(currentKey);
            } else {
                currentNPenalty--;
                keys.add(-1);
            }
        }
        return keys;
    }

    /**
     * Encodes a DNA context string as a key.
     *
     * @param dna the context; characters outside US-ASCII are treated as unencodable bases
     * @return the key, or {@code -1} if some base is mapped to {@code -1} by {@code BaseUtils.simpleBaseToBaseIndex}
     * @throws GATKException if {@code dna} is longer than {@code MAX_DNA_CONTEXT}, since such a key cannot be
     *                       represented without corrupting the length or base bits
     */
    public static int keyFromContext(final String dna) {
        if (dna.length() > MAX_DNA_CONTEXT) {
            throw new GATKException(String.format("context %s is longer than the maximum supported context size %d",
                    dna, MAX_DNA_CONTEXT));
        }
        // Encode explicitly instead of relying on the platform default charset.
        final byte[] bases = dna.getBytes(StandardCharsets.US_ASCII);
        return keyFromContext(bases, 0, bases.length);
    }

    /**
     * Encodes {@code dna[start, end)} as a key: the length in the low bits, then two bits per base with the
     * first base in the lowest position.
     *
     * @return the key, or {@code -1} as soon as a base maps to {@code -1}
     */
    private static int keyFromContext(final byte[] dna, final int start, final int end) {
        int key = end - start;
        int bitOffset = LENGTH_BITS;
        for (int i = start; i < end; i++) {
            final int baseIndex = BaseUtils.simpleBaseToBaseIndex(dna[i]);
            if (baseIndex == -1) {
                return -1;
            }
            key |= baseIndex << bitOffset;
            bitOffset += 2;
        }
        return key;
    }

    /**
     * Decodes a key into its DNA context string.
     *
     * @param key a non-negative context key
     * @return the context, first base first
     * @throws GATKException if {@code key} is negative
     */
    public static String contextFromKey(final int key) {
        if (key < 0) {
            throw new GATKException("dna conversion cannot handle negative numbers. Possible overflow?");
        }
        final int length = key & LENGTH_MASK;
        int mask = 3 << LENGTH_BITS;
        int offset = LENGTH_BITS;
        final StringBuilder dna = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            final int baseIndex = (key & mask) >> offset;
            dna.append((char) BaseUtils.baseIndexToSimpleBase(baseIndex));
            mask <<= 2;
            offset += 2;
        }
        return dna.toString();
    }

    /**
     * Returns the largest key this covariate can produce: a context of the larger configured size with every
     * base slot set to 3.
     */
    @Override
    public int maximumKeyValue() {
        final int length = Math.max(this.mismatchesContextSize, this.indelsContextSize);
        int key = length;
        int bitOffset = LENGTH_BITS;
        for (int i = 0; i < length; i++) {
            key |= 3 << bitOffset;
            bitOffset += 2;
        }
        return key;
    }
}

