/*
 * Decompiled with CFR 0.152.
 */
package org.openeuler.sm4.mode;

import java.security.AlgorithmParameters;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.util.Arrays;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import org.openeuler.BGMJCEProvider;
import org.openeuler.sm4.SM4Util;
import org.openeuler.sm4.StreamModeBaseCipher;

/**
 * SM4 in CCM mode (counter mode + CBC-MAC) with a fixed 8-byte authentication tag.
 *
 * <p>CCM needs the total payload length before processing can start, so {@code engineUpdate}
 * only buffers input; all encryption/decryption happens in {@code engineDoFinal}. The nonce
 * (IV) may be 7..13 bytes long; its length fixes {@code L = 15 - nonceLength}, the width of
 * the counter / payload-length field.
 *
 * <p>NOTE(review): after {@code doFinal} the cipher is reset to the SAME nonce, so encrypting
 * a second message without re-initialising reuses the nonce. Callers must supply a fresh IV
 * per message.
 */
public class CCM extends StreamModeBaseCipher {
    /** Authentication tag length in bytes (the CCM parameter {@code M}). */
    private static final int TAG_LEN = 8;
    /** Nonce length generated when encrypting without explicit parameters. */
    private static final int DEFAULT_IV_LEN = 12;
    /** Smallest accepted nonce length (gives L = 8). */
    private static final int MIN_IV_LEN = 7;
    /** Largest accepted nonce length (gives L = 2). */
    private static final int MAX_IV_LEN = 13;
    /** SM4 block size in bytes. */
    private static final int BLOCK_SIZE = 16;
    /** AAD lengths below this value use the short two-byte length encoding. */
    private static final int SHORT_AAD_LIMIT = 0xFF00;

    /** Width in bytes of the counter / payload-length field: {@code 15 - iv.length}. */
    private int L;
    /** Counter block A0; its encryption masks the CBC-MAC to form the tag. */
    private final byte[] counter0 = new byte[BLOCK_SIZE];
    /** Associated data accumulated via engineUpdateAAD for the current message. */
    private byte[] aad;

    @Override
    public void engineInit(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
        try {
            this.engineInit(opmode, key, (AlgorithmParameterSpec) null, random);
        } catch (InvalidAlgorithmParameterException e) {
            // Decryption needs an explicit nonce; CipherSpi maps that to InvalidKeyException.
            throw new InvalidKeyException(e.getMessage(), e);
        }
    }

    /**
     * Initialises the cipher with a key and an optional nonce.
     *
     * @param params {@code null} (encryption generates a random 12-byte nonce; decryption
     *     fails) or an {@link IvParameterSpec} holding a 7..13 byte nonce
     * @throws InvalidAlgorithmParameterException if the nonce is missing, of the wrong spec
     *     type, or of unsupported length
     */
    @Override
    public void engineInit(int opmode, Key key, AlgorithmParameterSpec params, SecureRandom random)
            throws InvalidKeyException, InvalidAlgorithmParameterException {
        this.isInitialized = false;
        this.init(opmode, key);
        if (params == null) {
            if (this.opmode == Cipher.ENCRYPT_MODE) {
                SecureRandom rng = random != null ? random : BGMJCEProvider.getRandom();
                this.iv = new byte[DEFAULT_IV_LEN];
                rng.nextBytes(this.iv);
            } else if (this.opmode == Cipher.DECRYPT_MODE) {
                throw new InvalidAlgorithmParameterException("need an IV");
            }
        } else {
            if (!(params instanceof IvParameterSpec)) {
                throw new InvalidAlgorithmParameterException();
            }
            byte[] nonce = ((IvParameterSpec) params).getIV();
            if (nonce.length < MIN_IV_LEN || nonce.length > MAX_IV_LEN) {
                throw new InvalidAlgorithmParameterException("nonce must have length from 7 to 13 octets");
            }
            this.iv = nonce;
        }
        if (this.iv == null) {
            // Other modes without parameters and without a previously configured nonce.
            throw new InvalidAlgorithmParameterException("need an IV");
        }
        // A fresh init starts a fresh message: drop AAD/buffered input from any earlier use.
        this.resetMessageState();
        this.isInitialized = true;
    }

    @Override
    public void engineInit(int opmode, Key key, AlgorithmParameters params, SecureRandom random)
            throws InvalidKeyException, InvalidAlgorithmParameterException {
        IvParameterSpec spec = null;
        if (params != null) {
            try {
                spec = params.getParameterSpec(IvParameterSpec.class);
            } catch (InvalidParameterSpecException e) {
                InvalidAlgorithmParameterException ex =
                        new InvalidAlgorithmParameterException("Wrong parameter type: IV expected");
                ex.initCause(e);
                throw ex;
            }
        }
        this.engineInit(opmode, key, spec, random);
    }

    /**
     * Returns the output size of a {@code doFinal} call with {@code inputLen} more bytes,
     * including input already buffered by {@code engineUpdate}.
     */
    @Override
    public int engineGetOutputSize(int inputLen) {
        int total = this.bufferedLength() + inputLen;
        if (this.opmode == Cipher.ENCRYPT_MODE) {
            return total + TAG_LEN;
        }
        if (this.opmode == Cipher.DECRYPT_MODE) {
            return Math.max(0, total - TAG_LEN);
        }
        return 0;
    }

    /** Appends {@code src[offset, offset+len)} to the associated data of the current message. */
    @Override
    protected void engineUpdateAAD(byte[] src, int offset, int len) {
        if (!this.isInitialized) {
            throw new IllegalStateException("cipher uninitiallized");
        }
        byte[] chunk = Arrays.copyOfRange(src, offset, offset + len);
        if (this.aad == null) {
            this.aad = chunk;
        } else {
            byte[] merged = Arrays.copyOf(this.aad, this.aad.length + chunk.length);
            System.arraycopy(chunk, 0, merged, this.aad.length, chunk.length);
            this.aad = merged;
        }
    }

    /** Buffers the input (CCM cannot emit output before the full length is known); returns 0. */
    @Override
    public int engineUpdate(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset)
            throws ShortBufferException {
        if (!this.isInitialized) {
            throw new IllegalStateException("cipher uninitialized");
        }
        this.bufferInput(input, inputOffset, inputLen);
        return 0;
    }

    /** Buffers the input (CCM cannot emit output before the full length is known); returns null. */
    @Override
    public byte[] engineUpdate(byte[] input, int inputOffset, int inputLen) {
        if (!this.isInitialized) {
            throw new IllegalStateException("cipher uninitialized");
        }
        this.bufferInput(input, inputOffset, inputLen);
        return null;
    }

    @Override
    public void engineSetPadding(String padding) throws NoSuchPaddingException {
        // Locale-independent comparison (toUpperCase() misbehaves e.g. under a Turkish locale).
        if (!"NOPADDING".equalsIgnoreCase(padding)) {
            throw new NoSuchPaddingException("only nopadding can be used in this mode");
        }
        super.engineSetPadding(padding);
    }

    /**
     * Finishes the message: returns ciphertext||tag (encrypt) or the verified plaintext
     * (decrypt). The cipher is reset to its post-init state afterwards, also on MAC failure.
     *
     * @throws IllegalBlockSizeException if the ciphertext is shorter than the tag, or the
     *     payload is too long for the L-byte length field
     * @throws RuntimeException if the authentication tag does not verify
     */
    @Override
    public byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen)
            throws IllegalBlockSizeException, BadPaddingException {
        if (!this.isInitialized) {
            throw new IllegalStateException("cipher uninitialized");
        }
        int total = this.bufferedLength() + inputLen;
        this.validateFinalLength(total);
        byte[] data = this.collectInput(input, inputOffset, inputLen);
        try {
            if (this.opmode == Cipher.ENCRYPT_MODE) {
                byte[] out = new byte[data.length + TAG_LEN];
                this.encrypt(data, out, 0);
                return out;
            }
            if (this.opmode == Cipher.DECRYPT_MODE) {
                return this.decrypt(data);
            }
            return null;
        } finally {
            // Always restore the counter/AAD state, otherwise a failed tag check would poison
            // the next operation.
            this.reset();
        }
    }

    /**
     * Like {@link #engineDoFinal(byte[], int, int)} but writes into {@code output}.
     *
     * @return number of bytes written
     * @throws ShortBufferException if {@code output} is too small; state is left untouched
     *     so the call can be retried
     */
    @Override
    public int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset)
            throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
        if (!this.isInitialized) {
            throw new IllegalStateException("cipher uninitialized");
        }
        int total = this.bufferedLength() + inputLen;
        int need;
        if (this.opmode == Cipher.ENCRYPT_MODE) {
            need = total + TAG_LEN;
        } else if (this.opmode == Cipher.DECRYPT_MODE) {
            this.validateFinalLength(total);
            need = total - TAG_LEN;
        } else {
            this.reset();
            return 0;
        }
        // Written as a subtraction to avoid int overflow of outputOffset + need.
        if (output.length - outputOffset < need) {
            throw new ShortBufferException();
        }
        this.validateFinalLength(total);
        byte[] data = this.collectInput(input, inputOffset, inputLen);
        try {
            if (this.opmode == Cipher.ENCRYPT_MODE) {
                this.encrypt(data, output, outputOffset);
            } else {
                byte[] plain = this.decrypt(data);
                System.arraycopy(plain, 0, output, outputOffset, plain.length);
            }
        } finally {
            this.reset();
        }
        return need;
    }

    @Override
    public void reset() {
        super.reset();
        this.resetMessageState();
    }

    /** Number of payload bytes buffered by engineUpdate and not yet processed. */
    private int bufferedLength() {
        return this.inputLenUpdate - this.len;
    }

    /** Returns a fresh array holding buffered input followed by {@code input[off, off+n)}. */
    private byte[] collectInput(byte[] input, int off, int n) {
        int rest = this.bufferedLength();
        byte[] all = new byte[rest + n];
        if (rest > 0) {
            System.arraycopy(this.inputUpdate, this.inputOffsetUpdate + this.len, all, 0, rest);
        }
        if (n > 0) {
            // input may be null when n == 0 (Cipher.doFinal() passes null, 0, 0).
            System.arraycopy(input, off, all, rest, n);
        }
        return all;
    }

    /** Appends input to the pending buffer, copying so later caller mutations do not leak in. */
    private void bufferInput(byte[] input, int off, int n) {
        byte[] all = this.collectInput(input, off, n);
        this.inputUpdate = all;
        this.inputOffsetUpdate = 0;
        this.inputLenUpdate = all.length;
        this.len = 0;
    }

    /**
     * Checks the total doFinal input size: ciphertext must hold a tag, and the payload must fit
     * in the L-byte length field of B0 (otherwise the length would be truncated and the
     * counter would wrap into the nonce).
     */
    private void validateFinalLength(int total) throws IllegalBlockSizeException {
        int payload = total;
        if (this.opmode == Cipher.DECRYPT_MODE) {
            if (total < TAG_LEN) {
                throw new IllegalBlockSizeException();
            }
            payload = total - TAG_LEN;
        }
        if (this.L < 4 && payload >= (1 << (8 * this.L))) {
            throw new IllegalBlockSizeException(
                    "payload of " + payload + " bytes too long for a " + this.iv.length + "-byte nonce");
        }
    }

    /** Clears per-message state and rewinds the counter to A1 for the current nonce. */
    private void resetMessageState() {
        this.aad = null;
        this.inputUpdate = null;
        this.inputOffsetUpdate = 0;
        this.inputLenUpdate = 0;
        this.len = 0;
        if (this.iv == null) {
            return;
        }
        this.L = 15 - this.iv.length;
        this.buildCounter0();
        System.arraycopy(this.counter0, 0, this.counter, 0, BLOCK_SIZE);
        // Payload keystream starts at A1; A0 is reserved for masking the tag.
        this.incrementCounter();
    }

    /** Builds A0 = flags(L-1) || nonce || 0...0 from scratch (no stale bytes from an older nonce). */
    private void buildCounter0() {
        Arrays.fill(this.counter0, (byte) 0);
        this.counter0[0] = (byte) (this.L - 1);
        System.arraycopy(this.iv, 0, this.counter0, 1, this.iv.length);
    }

    /** Big-endian increment of the last L bytes of the counter, wrapping within that field only. */
    private void incrementCounter() {
        for (int idx = BLOCK_SIZE - 1; idx >= BLOCK_SIZE - this.L; --idx) {
            if (++this.counter[idx] != 0) {
                return;
            }
        }
    }

    /** XORs {@code n} bytes of {@code src} with the CTR keystream into {@code dst}. */
    private void applyKeystream(byte[] src, int srcOff, int n, byte[] dst, int dstOff) {
        for (int done = 0; done < n; done += BLOCK_SIZE) {
            byte[] keystream = this.sm4.encrypt(this.rk, this.counter, 0);
            this.incrementCounter();
            int chunk = Math.min(BLOCK_SIZE, n - done);
            for (int k = 0; k < chunk; ++k) {
                dst[dstOff + done + k] = (byte) (src[srcOff + done + k] ^ keystream[k]);
            }
        }
    }

    /** Writes ciphertext followed by the 8-byte tag into {@code out} at {@code outOff}. */
    private void encrypt(byte[] plain, byte[] out, int outOff) {
        byte[] tag = this.computeTag(plain);
        this.applyKeystream(plain, 0, plain.length, out, outOff);
        System.arraycopy(tag, 0, out, outOff + plain.length, TAG_LEN);
    }

    /** Decrypts ciphertext||tag and verifies the tag before releasing the plaintext. */
    private byte[] decrypt(byte[] data) {
        int plainLen = data.length - TAG_LEN;
        byte[] plain = new byte[plainLen];
        this.applyKeystream(data, 0, plainLen, plain, 0);
        byte[] received = Arrays.copyOfRange(data, plainLen, data.length);
        // Constant-time comparison so tag verification does not leak timing information.
        if (!MessageDigest.isEqual(received, this.computeTag(plain))) {
            Arrays.fill(plain, (byte) 0);
            // NOTE(review): JCE expects AEADBadTagException here; kept as RuntimeException for
            // compatibility with existing callers.
            throw new RuntimeException("mac check failed in CCM mode.");
        }
        return plain;
    }

    /** Computes the CCM tag: CBC-MAC over B0 || encoded AAD || payload, masked with E(A0). */
    private byte[] computeTag(byte[] payload) {
        boolean hasAad = this.aad != null && this.aad.length != 0;
        byte[] b0 = new byte[BLOCK_SIZE];
        // Flags: Adata bit (0x40), M' = (M-2)/2 in bits 3..5, L' = L-1 in bits 0..2.
        b0[0] = (byte) ((hasAad ? 0x40 : 0) | (((TAG_LEN - 2) / 2) << 3) | ((this.L - 1) & 7));
        System.arraycopy(this.iv, 0, b0, 1, this.iv.length);
        writeBigEndian(payload.length, b0, BLOCK_SIZE - this.L, this.L);

        byte[] mac = this.sm4.encrypt(this.rk, b0, 0);
        if (hasAad) {
            byte[] header = encodeAadLength(this.aad.length);
            byte[] aadBlocks = new byte[header.length + this.aad.length];
            System.arraycopy(header, 0, aadBlocks, 0, header.length);
            System.arraycopy(this.aad, 0, aadBlocks, header.length, this.aad.length);
            mac = this.cbcMac(mac, aadBlocks);
        }
        mac = this.cbcMac(mac, payload);

        byte[] s0 = this.sm4.encrypt(this.rk, this.counter0, 0);
        byte[] tag = new byte[TAG_LEN];
        for (int k = 0; k < TAG_LEN; ++k) {
            tag[k] = (byte) (mac[k] ^ s0[k]);
        }
        return tag;
    }

    /** Continues CBC-MAC from {@code state} over {@code data}, zero-padding the last block. */
    private byte[] cbcMac(byte[] state, byte[] data) {
        byte[] x = state;
        for (int off = 0; off < data.length; off += BLOCK_SIZE) {
            int n = Math.min(BLOCK_SIZE, data.length - off);
            byte[] block = new byte[BLOCK_SIZE];
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                block[k] = (byte) (x[k] ^ (k < n ? data[off + k] : 0));
            }
            x = this.sm4.encrypt(this.rk, block, 0);
        }
        return x;
    }

    /** Encodes the AAD length: 2 bytes if below 0xFF00, else 0xFF 0xFE followed by 4 bytes. */
    private static byte[] encodeAadLength(int aadLen) {
        if (aadLen < SHORT_AAD_LIMIT) {
            byte[] out = new byte[2];
            writeBigEndian(aadLen, out, 0, 2);
            return out;
        }
        byte[] out = new byte[6];
        out[0] = (byte) 0xFF;
        out[1] = (byte) 0xFE;
        writeBigEndian(aadLen, out, 2, 4);
        return out;
    }

    /** Writes {@code value} big-endian into {@code width} bytes (high bytes zero if width > 4). */
    private static void writeBigEndian(int value, byte[] dst, int off, int width) {
        int v = value;
        for (int k = width - 1; k >= 0; --k) {
            dst[off + k] = (byte) v;
            v >>>= 8;
        }
    }
}

