/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.xmss;

import java.io.IOException;
import java.security.SecureRandom;
import java.text.ParseException;
import org.bouncycastle.pqc.crypto.xmss.BDS;
import org.bouncycastle.pqc.crypto.xmss.HashTreeAddress;
import org.bouncycastle.pqc.crypto.xmss.KeyedHashFunctions;
import org.bouncycastle.pqc.crypto.xmss.LTreeAddress;
import org.bouncycastle.pqc.crypto.xmss.OTSHashAddress;
import org.bouncycastle.pqc.crypto.xmss.WOTSPlus;
import org.bouncycastle.pqc.crypto.xmss.WOTSPlusPublicKeyParameters;
import org.bouncycastle.pqc.crypto.xmss.WOTSPlusSignature;
import org.bouncycastle.pqc.crypto.xmss.XMSSAddress;
import org.bouncycastle.pqc.crypto.xmss.XMSSNode;
import org.bouncycastle.pqc.crypto.xmss.XMSSParameters;
import org.bouncycastle.pqc.crypto.xmss.XMSSPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.xmss.XMSSPublicKeyParameters;
import org.bouncycastle.pqc.crypto.xmss.XMSSReducedSignature;
import org.bouncycastle.pqc.crypto.xmss.XMSSSignature;
import org.bouncycastle.pqc.crypto.xmss.XMSSUtil;

/**
 * XMSS (eXtended Merkle Signature Scheme) over a single Merkle tree: key generation, signing,
 * verification, and key import/export.
 *
 * <p>An instance holds the current private key (secret seeds, public seed, root, leaf index and
 * the BDS state used to obtain authentication paths) and the matching public key. Every call to
 * {@link #sign(byte[])} uses the leaf at the current index for a one-time WOTS+ signature and then
 * advances the index, so the private key changes with every signature.
 *
 * <p>NOTE(review): no synchronization is performed anywhere in this class; presumably an
 * instance must not be shared between threads — confirm with callers.
 */
public class XMSS {
    private final XMSSParameters params;
    private final WOTSPlus wotsPlus;
    private final SecureRandom prng;
    private final KeyedHashFunctions khf;
    private XMSSPrivateKeyParameters privateKey;
    private XMSSPublicKeyParameters publicKey;

    /**
     * Creates an XMSS instance with empty (default) key parameters for the given parameter set.
     * Call {@link #generateKeys()} or {@link #importState(byte[], byte[])} before signing.
     *
     * @param params the XMSS parameter set; must not be {@code null}
     * @throws NullPointerException if {@code params} is {@code null}
     * @throws IllegalStateException if the initial key parameters cannot be built
     */
    public XMSS(XMSSParameters params) {
        if (params == null) {
            throw new NullPointerException("params == null");
        }
        this.params = params;
        this.wotsPlus = params.getWOTSPlus();
        this.prng = params.getPRNG();
        this.khf = this.wotsPlus.getKhf();
        try {
            // NOTE(review): `this` is handed to BDS before privateKey/publicKey are assigned;
            // presumably BDS's constructor does not read them — confirm.
            this.privateKey = new XMSSPrivateKeyParameters.Builder(params).withBDSState(new BDS(this)).build();
            this.publicKey = new XMSSPublicKeyParameters.Builder(params).build();
        }
        catch (ParseException e) {
            throw builderFailure("unable to create initial key parameters", e);
        }
        catch (ClassNotFoundException e) {
            throw builderFailure("unable to create initial key parameters", e);
        }
        catch (IOException e) {
            throw builderFailure("unable to create initial key parameters", e);
        }
    }

    /**
     * Generates a fresh key pair: new random secret seed, PRF key and public seed. It then
     * initializes the BDS state to compute the tree root, and installs matching private and
     * public keys.
     *
     * @throws IllegalStateException if the key parameters cannot be built
     */
    public void generateKeys() {
        this.privateKey = this.generatePrivateKey();
        XMSSNode root = this.getBDSState().initialize((OTSHashAddress)new OTSHashAddress.Builder().build());
        this.applyRoot(root.getValue());
    }

    /**
     * Draws the secret key seed, the PRF key and the public seed (each {@code getDigestSize()}
     * bytes) from the parameter set's PRNG, in that order.
     *
     * @return a private key without root; the current BDS state object is carried over so that
     *     {@link #generateKeys()} can initialize it
     * @throws IllegalStateException if the key parameters cannot be built
     */
    private XMSSPrivateKeyParameters generatePrivateKey() {
        int n = this.params.getDigestSize();
        byte[] secretKeySeed = new byte[n];
        this.prng.nextBytes(secretKeySeed);
        byte[] secretKeyPRF = new byte[n];
        this.prng.nextBytes(secretKeyPRF);
        byte[] publicSeed = new byte[n];
        this.prng.nextBytes(publicSeed);
        try {
            return new XMSSPrivateKeyParameters.Builder(this.params)
                .withSecretKeySeed(secretKeySeed)
                .withSecretKeyPRF(secretKeyPRF)
                .withPublicSeed(publicSeed)
                .withBDSState(this.privateKey.getBDSState())
                .build();
        }
        catch (ParseException e) {
            throw builderFailure("unable to create private key", e);
        }
        catch (ClassNotFoundException e) {
            throw builderFailure("unable to create private key", e);
        }
        catch (IOException e) {
            throw builderFailure("unable to create private key", e);
        }
    }

    /**
     * Replaces the current key pair with serialized keys, after checking that they belong
     * together (same root and same public seed).
     *
     * @param privateKey serialized private key (as produced by {@link #exportPrivateKey()})
     * @param publicKey serialized public key (as produced by {@link #exportPublicKey()})
     * @throws NullPointerException if either argument is {@code null}
     * @throws IllegalStateException if root or public seed of the two keys differ
     * @throws ParseException if a key cannot be parsed
     * @throws ClassNotFoundException if the serialized state cannot be restored
     * @throws IOException if the serialized state cannot be read
     */
    public void importState(byte[] privateKey, byte[] publicKey) throws ParseException, ClassNotFoundException, IOException {
        if (privateKey == null) {
            throw new NullPointerException("privateKey == null");
        }
        if (publicKey == null) {
            throw new NullPointerException("publicKey == null");
        }
        XMSSPrivateKeyParameters tmpPrivateKey = new XMSSPrivateKeyParameters.Builder(this.params).withPrivateKey(privateKey, this).build();
        XMSSPublicKeyParameters tmpPublicKey = new XMSSPublicKeyParameters.Builder(this.params).withPublicKey(publicKey).build();
        if (!XMSSUtil.compareByteArray(tmpPrivateKey.getRoot(), tmpPublicKey.getRoot())) {
            throw new IllegalStateException("root of private key and public key do not match");
        }
        if (!XMSSUtil.compareByteArray(tmpPrivateKey.getPublicSeed(), tmpPublicKey.getPublicSeed())) {
            throw new IllegalStateException("public seed of private key and public key do not match");
        }
        this.privateKey = tmpPrivateKey;
        this.publicKey = tmpPublicKey;
        this.wotsPlus.importKeys(new byte[this.params.getDigestSize()], this.privateKey.getPublicSeed());
    }

    /**
     * Signs {@code message} with the leaf at the current index, then advances the index (and,
     * unless this was the last leaf, the BDS authentication path).
     *
     * @param message the message to sign
     * @return the serialized XMSS signature
     * @throws NullPointerException if {@code message} is {@code null}
     * @throws IllegalStateException if no keys were generated/imported, or the signature cannot
     *     be assembled (no state is advanced in that case)
     * @throws IllegalArgumentException if the current index is outside the tree
     */
    public byte[] sign(byte[] message) {
        if (message == null) {
            throw new NullPointerException("message == null");
        }
        if (this.getBDSState().getAuthenticationPath().isEmpty()) {
            throw new IllegalStateException("not initialized");
        }
        int index = this.privateKey.getIndex();
        if (!XMSSUtil.isIndexValid(this.getParams().getHeight(), index)) {
            throw new IllegalArgumentException("index out of bounds");
        }
        // Randomizer: PRF(SK_PRF, index as a fixed 32-byte big-endian value).
        byte[] random = this.khf.PRF(this.privateKey.getSecretKeyPRF(), XMSSUtil.toBytesBigEndian(index, 32));
        // Message hash keyed by randomizer || root || index (index encoded in digestSize bytes).
        byte[] concatenated = XMSSUtil.concat(random, this.privateKey.getRoot(), XMSSUtil.toBytesBigEndian(index, this.params.getDigestSize()));
        byte[] messageDigest = this.khf.HMsg(concatenated, message);
        OTSHashAddress otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder().withOTSAddress(index).build();
        WOTSPlusSignature wotsPlusSignature = this.wotsSign(messageDigest, otsHashAddress);
        XMSSSignature signature;
        try {
            signature = (XMSSSignature)new XMSSSignature.Builder(this.params).withIndex(index).withRandom(random).withWOTSPlusSignature(wotsPlusSignature).withAuthPath(this.getBDSState().getAuthenticationPath()).build();
        }
        catch (ParseException ex) {
            // Fail before touching the BDS state or index; the WOTS+ signature is discarded
            // unreleased, so this leaf can still be used safely.
            throw builderFailure("unable to assemble XMSS signature", ex);
        }
        int treeHeight = this.getParams().getHeight();
        if (index < (1 << treeHeight) - 1) {
            this.getBDSState().nextAuthenticationPath((OTSHashAddress)new OTSHashAddress.Builder().build());
        }
        this.setIndex(index + 1);
        return signature.toByteArray();
    }

    /**
     * Verifies an XMSS signature against a serialized public key.
     *
     * <p>To reuse the hashing helpers, this temporarily installs the signature's index and the
     * public key's seed into this instance. Both are always restored in a {@code finally} block.
     * Otherwise a malformed signature that makes verification throw would leave the signer's
     * index rolled back to an attacker-chosen value, and a later {@link #sign(byte[])} would
     * reuse an already-consumed one-time key.
     *
     * @param message the signed message
     * @param signature serialized XMSS signature
     * @param publicKey serialized XMSS public key
     * @return {@code true} if the root reconstructed from the signature equals the public root
     * @throws NullPointerException if any argument is {@code null}
     * @throws ParseException if the signature or public key cannot be parsed
     */
    public boolean verifySignature(byte[] message, byte[] signature, byte[] publicKey) throws ParseException {
        if (message == null) {
            throw new NullPointerException("message == null");
        }
        if (signature == null) {
            throw new NullPointerException("signature == null");
        }
        if (publicKey == null) {
            throw new NullPointerException("publicKey == null");
        }
        XMSSSignature sig = new XMSSSignature.Builder(this.params).withSignature(signature).build();
        XMSSPublicKeyParameters pubKey = new XMSSPublicKeyParameters.Builder(this.params).withPublicKey(publicKey).build();
        int savedIndex = this.privateKey.getIndex();
        byte[] savedPublicSeed = this.privateKey.getPublicSeed();
        int index = sig.getIndex();
        XMSSNode rootNodeFromSignature;
        try {
            this.setIndex(index);
            this.setPublicSeed(pubKey.getPublicSeed());
            this.wotsPlus.importKeys(new byte[this.params.getDigestSize()], this.getPublicSeed());
            byte[] concatenated = XMSSUtil.concat(sig.getRandom(), pubKey.getRoot(), XMSSUtil.toBytesBigEndian(index, this.params.getDigestSize()));
            byte[] messageDigest = this.khf.HMsg(concatenated, message);
            OTSHashAddress otsHashAddress = (OTSHashAddress)new OTSHashAddress.Builder().withOTSAddress(index).build();
            rootNodeFromSignature = this.getRootNodeFromSignature(messageDigest, sig, otsHashAddress);
        }
        finally {
            this.setIndex(savedIndex);
            this.setPublicSeed(savedPublicSeed);
        }
        return XMSSUtil.compareByteArray(rootNodeFromSignature.getValue(), pubKey.getRoot());
    }

    /** @return the serialized private key, including index and BDS state */
    public byte[] exportPrivateKey() {
        return this.privateKey.toByteArray();
    }

    /** @return the serialized public key */
    public byte[] exportPublicKey() {
        return this.publicKey.toByteArray();
    }

    /**
     * Randomized tree hash of two equal-height nodes.
     *
     * <p>A key and two bitmasks are derived as {@code PRF(publicSeed, address)}, with the
     * address's keyAndMask field set to 0, 1 and 2 respectively. The result is
     * {@code H(key, (left XOR bitmask0) || (right XOR bitmask1))}. Addresses that are neither
     * {@link LTreeAddress} nor {@link HashTreeAddress} are used unchanged for all three PRF calls.
     *
     * @param left left child
     * @param right right child
     * @param address L-tree or hash-tree address of the parent
     * @return the parent node, carrying {@code left}'s height (callers adjust it)
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalStateException if the two nodes have different heights
     */
    protected XMSSNode randomizeHash(XMSSNode left, XMSSNode right, XMSSAddress address) {
        if (left == null) {
            throw new NullPointerException("left == null");
        }
        if (right == null) {
            throw new NullPointerException("right == null");
        }
        if (left.getHeight() != right.getHeight()) {
            throw new IllegalStateException("height of both nodes must be equal");
        }
        if (address == null) {
            throw new NullPointerException("address == null");
        }
        byte[] publicSeed = this.getPublicSeed();
        byte[] key = this.khf.PRF(publicSeed, addressWithKeyAndMask(address, 0).toByteArray());
        byte[] bitmask0 = this.khf.PRF(publicSeed, addressWithKeyAndMask(address, 1).toByteArray());
        byte[] bitmask1 = this.khf.PRF(publicSeed, addressWithKeyAndMask(address, 2).toByteArray());
        int n = this.params.getDigestSize();
        // Fetch each node value once instead of once per loop iteration.
        byte[] leftValue = left.getValue();
        byte[] rightValue = right.getValue();
        byte[] tmpMask = new byte[2 * n];
        for (int i = 0; i < n; ++i) {
            tmpMask[i] = (byte)(leftValue[i] ^ bitmask0[i]);
            tmpMask[i + n] = (byte)(rightValue[i] ^ bitmask1[i]);
        }
        byte[] out = this.khf.H(key, tmpMask);
        return new XMSSNode(left.getHeight(), out);
    }

    /**
     * Returns a copy of {@code address} whose keyAndMask field is {@code keyAndMask}; all other
     * fields are copied. Addresses other than L-tree/hash-tree addresses are returned unchanged.
     */
    private static XMSSAddress addressWithKeyAndMask(XMSSAddress address, int keyAndMask) {
        if (address instanceof LTreeAddress) {
            LTreeAddress src = (LTreeAddress)address;
            return (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(src.getLayerAddress())).withTreeAddress(src.getTreeAddress())).withLTreeAddress(src.getLTreeAddress()).withTreeHeight(src.getTreeHeight()).withTreeIndex(src.getTreeIndex()).withKeyAndMask(keyAndMask)).build();
        }
        if (address instanceof HashTreeAddress) {
            HashTreeAddress src = (HashTreeAddress)address;
            return (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(src.getLayerAddress())).withTreeAddress(src.getTreeAddress())).withTreeHeight(src.getTreeHeight()).withTreeIndex(src.getTreeIndex()).withKeyAndMask(keyAndMask)).build();
        }
        return address;
    }

    /**
     * Compresses a WOTS+ public key into a single node with an L-tree.
     *
     * <p>At each level, adjacent pairs are combined with {@link #randomizeHash}. An odd trailing
     * node is carried up unchanged. This repeats until one node remains.
     *
     * @param publicKey the WOTS+ public key
     * @param address L-tree address (its tree height and index are overwritten while hashing)
     * @return the L-tree root
     * @throws NullPointerException if any argument is {@code null}
     */
    protected XMSSNode lTree(WOTSPlusPublicKeyParameters publicKey, LTreeAddress address) {
        if (publicKey == null) {
            throw new NullPointerException("publicKey == null");
        }
        if (address == null) {
            throw new NullPointerException("address == null");
        }
        int len = this.wotsPlus.getParams().getLen();
        byte[][] publicKeyBytes = publicKey.toByteArray();
        XMSSNode[] publicKeyNodes = new XMSSNode[publicKeyBytes.length];
        for (int i = 0; i < publicKeyBytes.length; ++i) {
            publicKeyNodes[i] = new XMSSNode(0, publicKeyBytes[i]);
        }
        address = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(address.getLayerAddress())).withTreeAddress(address.getTreeAddress())).withLTreeAddress(address.getLTreeAddress()).withTreeHeight(0).withTreeIndex(address.getTreeIndex()).withKeyAndMask(address.getKeyAndMask())).build();
        while (len > 1) {
            for (int i = 0; i < len / 2; ++i) {
                address = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(address.getLayerAddress())).withTreeAddress(address.getTreeAddress())).withLTreeAddress(address.getLTreeAddress()).withTreeHeight(address.getTreeHeight()).withTreeIndex(i).withKeyAndMask(address.getKeyAndMask())).build();
                publicKeyNodes[i] = this.randomizeHash(publicKeyNodes[2 * i], publicKeyNodes[2 * i + 1], address);
            }
            if (len % 2 == 1) {
                // Odd node out moves up unchanged to the next level.
                publicKeyNodes[len / 2] = publicKeyNodes[len - 1];
            }
            len = (len + 1) / 2; // ceil(len / 2) for len >= 1
            address = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(address.getLayerAddress())).withTreeAddress(address.getTreeAddress())).withLTreeAddress(address.getLTreeAddress()).withTreeHeight(address.getTreeHeight() + 1).withTreeIndex(address.getTreeIndex()).withKeyAndMask(address.getKeyAndMask())).build();
        }
        return publicKeyNodes[0];
    }

    /**
     * Produces a WOTS+ signature of {@code messageDigest} with the one-time key for
     * {@code otsHashAddress}.
     *
     * @throws IllegalArgumentException if the digest length differs from the parameter digest size
     * @throws NullPointerException if {@code otsHashAddress} is {@code null}
     */
    protected WOTSPlusSignature wotsSign(byte[] messageDigest, OTSHashAddress otsHashAddress) {
        if (messageDigest.length != this.params.getDigestSize()) {
            throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        this.wotsPlus.importKeys(this.getWOTSPlusSecretKey(otsHashAddress), this.getPublicSeed());
        return this.wotsPlus.sign(messageDigest, otsHashAddress);
    }

    /**
     * Recomputes the tree root implied by a (reduced) signature.
     *
     * <p>The WOTS+ public key is recovered from the signature and compressed with the L-tree.
     * The result is then hashed upward with the authentication path. At level {@code k}, bit
     * {@code k} of the current private-key index decides whether the running node is the left or
     * the right child.
     *
     * @param messageDigest digest that was signed (length {@code getDigestSize()})
     * @param signature signature carrying the WOTS+ signature and authentication path
     * @param otsHashAddress address of the one-time key used
     * @return the reconstructed root node
     * @throws IllegalArgumentException if the digest length is wrong
     * @throws NullPointerException if {@code signature} or {@code otsHashAddress} is {@code null}
     */
    protected XMSSNode getRootNodeFromSignature(byte[] messageDigest, XMSSReducedSignature signature, OTSHashAddress otsHashAddress) {
        if (messageDigest.length != this.params.getDigestSize()) {
            throw new IllegalArgumentException("size of messageDigest needs to be equal to size of digest");
        }
        if (signature == null) {
            throw new NullPointerException("signature == null");
        }
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        LTreeAddress lTreeAddress = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).withLTreeAddress(otsHashAddress.getOTSAddress()).build();
        HashTreeAddress hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).withTreeIndex(otsHashAddress.getOTSAddress()).build();
        WOTSPlusPublicKeyParameters wotsPlusPK = this.wotsPlus.getPublicKeyFromSignature(messageDigest, signature.getWOTSPlusSignature(), otsHashAddress);
        XMSSNode node = this.lTree(wotsPlusPK, lTreeAddress);
        int index = this.privateKey.getIndex();
        for (int k = 0; k < this.params.getHeight(); ++k) {
            hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeHeight(k).withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
            XMSSNode parent;
            // Integer form of the former floating-point test floor(index / 2^k) % 2 == 0.
            if ((index / (1 << k)) % 2 == 0) {
                // Running node is the left child.
                hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeHeight(hashTreeAddress.getTreeHeight()).withTreeIndex(hashTreeAddress.getTreeIndex() / 2).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
                parent = this.randomizeHash(node, signature.getAuthPath().get(k), hashTreeAddress);
            } else {
                // Running node is the right child.
                hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeHeight(hashTreeAddress.getTreeHeight()).withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
                parent = this.randomizeHash(signature.getAuthPath().get(k), node, hashTreeAddress);
            }
            // randomizeHash keeps the child height; the parent is one level higher.
            node = new XMSSNode(parent.getHeight() + 1, parent.getValue());
        }
        return node;
    }

    /**
     * Derives the WOTS+ secret key seed for a one-time key as
     * {@code PRF(secretKeySeed, address)}. The address is rebuilt from its layer, tree and OTS
     * fields only.
     */
    protected byte[] getWOTSPlusSecretKey(OTSHashAddress otsHashAddress) {
        otsHashAddress = (OTSHashAddress)((OTSHashAddress.Builder)((OTSHashAddress.Builder)new OTSHashAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).withOTSAddress(otsHashAddress.getOTSAddress()).build();
        return this.khf.PRF(this.privateKey.getSecretKeySeed(), otsHashAddress.toByteArray());
    }

    /** @return the parameter set this instance was created with */
    public XMSSParameters getParams() {
        return this.params;
    }

    /** @return the WOTS+ instance taken from the parameter set */
    protected WOTSPlus getWOTSPlus() {
        return this.wotsPlus;
    }

    /** @return the keyed hash functions of the WOTS+ instance */
    protected KeyedHashFunctions getKhf() {
        return this.khf;
    }

    /** @return the tree root stored in the private key */
    public byte[] getRoot() {
        return this.privateKey.getRoot();
    }

    /**
     * Replaces the root in both the private and the public key.
     *
     * @throws IllegalStateException if the key parameters cannot be rebuilt
     */
    protected void setRoot(byte[] root) {
        this.applyRoot(root);
    }

    /**
     * Rebuilds the private key (keeping index, seeds, PRF key, public seed and BDS state) and the
     * public key with the given root. Shared by {@link #generateKeys()} and {@link #setRoot}.
     */
    private void applyRoot(byte[] root) {
        try {
            this.privateKey = new XMSSPrivateKeyParameters.Builder(this.params)
                .withIndex(this.privateKey.getIndex())
                .withSecretKeySeed(this.privateKey.getSecretKeySeed())
                .withSecretKeyPRF(this.privateKey.getSecretKeyPRF())
                .withPublicSeed(this.getPublicSeed())
                .withRoot(root)
                .withBDSState(this.privateKey.getBDSState())
                .build();
            this.publicKey = new XMSSPublicKeyParameters.Builder(this.params).withRoot(root).withPublicSeed(this.getPublicSeed()).build();
        }
        catch (ParseException e) {
            throw builderFailure("unable to update root", e);
        }
        catch (ClassNotFoundException e) {
            throw builderFailure("unable to update root", e);
        }
        catch (IOException e) {
            throw builderFailure("unable to update root", e);
        }
    }

    /** @return the index of the leaf the next {@link #sign(byte[])} call will use */
    public int getIndex() {
        return this.privateKey.getIndex();
    }

    /**
     * Replaces the leaf index in the private key; all other private key fields are kept.
     *
     * @throws IllegalStateException if the private key cannot be rebuilt
     */
    protected void setIndex(int index) {
        try {
            this.privateKey = new XMSSPrivateKeyParameters.Builder(this.params)
                .withIndex(index)
                .withSecretKeySeed(this.privateKey.getSecretKeySeed())
                .withSecretKeyPRF(this.privateKey.getSecretKeyPRF())
                .withPublicSeed(this.privateKey.getPublicSeed())
                .withRoot(this.privateKey.getRoot())
                .withBDSState(this.privateKey.getBDSState())
                .build();
        }
        catch (ParseException e) {
            throw builderFailure("unable to update index", e);
        }
        catch (ClassNotFoundException e) {
            throw builderFailure("unable to update index", e);
        }
        catch (IOException e) {
            throw builderFailure("unable to update index", e);
        }
    }

    /** @return the public seed stored in the private key */
    public byte[] getPublicSeed() {
        return this.privateKey.getPublicSeed();
    }

    /**
     * Replaces the public seed in both keys and re-imports it into WOTS+ (with an all-zero
     * secret key).
     *
     * @throws IllegalStateException if the key parameters cannot be rebuilt
     */
    protected void setPublicSeed(byte[] publicSeed) {
        try {
            this.privateKey = new XMSSPrivateKeyParameters.Builder(this.params)
                .withIndex(this.privateKey.getIndex())
                .withSecretKeySeed(this.privateKey.getSecretKeySeed())
                .withSecretKeyPRF(this.privateKey.getSecretKeyPRF())
                .withPublicSeed(publicSeed)
                .withRoot(this.getRoot())
                .withBDSState(this.privateKey.getBDSState())
                .build();
            this.publicKey = new XMSSPublicKeyParameters.Builder(this.params).withRoot(this.getRoot()).withPublicSeed(publicSeed).build();
        }
        catch (ParseException e) {
            throw builderFailure("unable to update public seed", e);
        }
        catch (ClassNotFoundException e) {
            throw builderFailure("unable to update public seed", e);
        }
        catch (IOException e) {
            throw builderFailure("unable to update public seed", e);
        }
        this.wotsPlus.importKeys(new byte[this.params.getDigestSize()], publicSeed);
    }

    /** @return the BDS state held by the current private key */
    protected BDS getBDSState() {
        return this.privateKey.getBDSState();
    }

    /**
     * Converts a checked exception from a parameter builder into an unchecked one and keeps it
     * as the cause. Failing immediately stops the instance from continuing with a {@code null}
     * or stale key, which would otherwise surface later as an unrelated NullPointerException.
     */
    private static IllegalStateException builderFailure(String context, Exception cause) {
        return new IllegalStateException(context, cause);
    }
}

