/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.server.kex;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.math.BigInteger;
import java.net.URL;
import java.security.KeyPair;
import java.util.ArrayList;
import java.util.List;
import org.apache.sshd.common.FactoryManagerUtils;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.kex.DHFactory;
import org.apache.sshd.common.kex.DHG;
import org.apache.sshd.common.kex.DHGroupData;
import org.apache.sshd.common.kex.KexProposalOption;
import org.apache.sshd.common.kex.KeyExchange;
import org.apache.sshd.common.kex.KeyExchangeFactory;
import org.apache.sshd.common.random.Random;
import org.apache.sshd.common.session.AbstractSession;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.SecurityUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.apache.sshd.server.ServerFactoryManager;
import org.apache.sshd.server.kex.AbstractDHServerKeyExchange;
import org.apache.sshd.server.kex.Moduli;

public class DHGEXServer
extends AbstractDHServerKeyExchange {
    protected final DHFactory factory;
    protected DHG dh;
    protected int min;
    protected int prf;
    protected int max;
    protected byte expected;
    protected boolean oldRequest;

    protected DHGEXServer(DHFactory factory) {
        this.factory = ValidateUtils.checkNotNull(factory, "No factory");
    }

    public static KeyExchangeFactory newFactory(final DHFactory factory) {
        return new KeyExchangeFactory(){

            @Override
            public KeyExchange create() {
                return new DHGEXServer(factory);
            }

            @Override
            public String getName() {
                return factory.getName();
            }

            public String toString() {
                return NamedFactory.class.getSimpleName() + "<" + KeyExchange.class.getSimpleName() + ">" + "[" + this.getName() + "]";
            }
        };
    }

    @Override
    public void init(AbstractSession s, byte[] v_s, byte[] v_c, byte[] i_s, byte[] i_c) throws Exception {
        super.init(s, v_s, v_c, i_s, i_c);
        this.expected = (byte)34;
    }

    @Override
    public boolean next(Buffer buffer) throws Exception {
        int cmd = buffer.getUByte();
        if (cmd == 30 && this.expected == 34) {
            this.log.debug("Received SSH_MSG_KEX_DH_GEX_REQUEST_OLD");
            this.oldRequest = true;
            this.min = 1024;
            this.prf = buffer.getInt();
            this.max = 8192;
            if (this.max < this.min || this.prf < this.min || this.max < this.prf) {
                throw new SshException(3, "Protocol error: bad parameters " + this.min + " !< " + this.prf + " !< " + this.max);
            }
            this.dh = this.chooseDH(this.min, this.prf, this.max);
            this.f = this.dh.getE();
            this.hash = this.dh.getHash();
            this.hash.init();
            this.log.debug("Send SSH_MSG_KEX_DH_GEX_GROUP");
            buffer = this.session.createBuffer((byte)31);
            buffer.putMPInt(this.dh.getP());
            buffer.putMPInt(this.dh.getG());
            this.session.writePacket(buffer);
            this.expected = (byte)32;
            return false;
        }
        if (cmd == 34 && this.expected == 34) {
            this.log.debug("Received SSH_MSG_KEX_DH_GEX_REQUEST");
            this.min = buffer.getInt();
            this.prf = buffer.getInt();
            this.max = buffer.getInt();
            if (this.prf < this.min || this.max < this.prf) {
                throw new SshException(3, "Protocol error: bad parameters " + this.min + " !< " + this.prf + " !< " + this.max);
            }
            this.dh = this.chooseDH(this.min, this.prf, this.max);
            this.f = this.dh.getE();
            this.hash = this.dh.getHash();
            this.hash.init();
            this.log.debug("Send SSH_MSG_KEX_DH_GEX_GROUP");
            buffer = this.session.createBuffer((byte)31);
            buffer.putMPInt(this.dh.getP());
            buffer.putMPInt(this.dh.getG());
            this.session.writePacket(buffer);
            this.expected = (byte)32;
            return false;
        }
        if (cmd != this.expected) {
            throw new SshException(3, "Protocol error: expected packet " + this.expected + ", got " + cmd);
        }
        if (cmd == 32) {
            this.log.debug("Received SSH_MSG_KEX_DH_GEX_INIT");
            this.e = buffer.getMPIntAsBytes();
            this.dh.setF(this.e);
            this.k = this.dh.getK();
            KeyPair kp = ValidateUtils.checkNotNull(this.session.getHostKey(), "No server key pair available");
            String algo = this.session.getNegotiatedKexParameter(KexProposalOption.SERVERKEYS);
            ServerFactoryManager manager = this.session.getFactoryManager();
            Signature sig = ValidateUtils.checkNotNull(NamedFactory.Utils.create(manager.getSignatureFactories(), algo), "Unknown negotiated server keys: %s", (Object)algo);
            sig.initSigner(kp.getPrivate());
            buffer = new ByteArrayBuffer();
            buffer.putRawPublicKey(kp.getPublic());
            byte[] k_s = buffer.getCompactData();
            buffer.clear();
            buffer.putBytes(this.v_c);
            buffer.putBytes(this.v_s);
            buffer.putBytes(this.i_c);
            buffer.putBytes(this.i_s);
            buffer.putBytes(k_s);
            if (this.oldRequest) {
                buffer.putInt(this.prf);
            } else {
                buffer.putInt(this.min);
                buffer.putInt(this.prf);
                buffer.putInt(this.max);
            }
            buffer.putMPInt(this.dh.getP());
            buffer.putMPInt(this.dh.getG());
            buffer.putMPInt(this.e);
            buffer.putMPInt(this.f);
            buffer.putMPInt(this.k);
            this.hash.update(buffer.array(), 0, buffer.available());
            this.h = this.hash.digest();
            buffer.clear();
            sig.update(this.h, 0, this.h.length);
            buffer.putString(algo);
            buffer.putBytes(sig.sign());
            byte[] sigH = buffer.getCompactData();
            if (this.log.isDebugEnabled()) {
                this.log.debug("K_S:  {}", (Object)BufferUtils.printHex(k_s));
                this.log.debug("f:    {}", (Object)BufferUtils.printHex(this.f));
                this.log.debug("sigH: {}", (Object)BufferUtils.printHex(sigH));
            }
            this.log.debug("Send SSH_MSG_KEX_DH_GEX_REPLY");
            buffer.clear();
            buffer.rpos(5);
            buffer.wpos(5);
            buffer.putByte((byte)33);
            buffer.putBytes(k_s);
            buffer.putBytes(this.f);
            buffer.putBytes(sigH);
            this.session.writePacket(buffer);
            return true;
        }
        return false;
    }

    private DHG chooseDH(int min, int prf, int max) throws Exception {
        List<Moduli.DhGroup> groups = this.loadModuliGroups();
        min = Math.max(min, 1024);
        prf = Math.max(prf, 1024);
        prf = Math.min(prf, SecurityUtils.isBouncyCastleRegistered() ? 8192 : 1024);
        max = Math.min(max, 8192);
        int bestSize = 0;
        ArrayList<Moduli.DhGroup> selected = new ArrayList<Moduli.DhGroup>();
        for (Moduli.DhGroup group : groups) {
            if (group.size < min || group.size > max) continue;
            if (group.size > prf && group.size < bestSize || group.size > bestSize && bestSize < prf) {
                bestSize = group.size;
                selected.clear();
            }
            if (group.size != bestSize) continue;
            selected.add(group);
        }
        if (selected.isEmpty()) {
            this.log.warn("No suitable primes found, defaulting to DHG1");
            return this.getDH(new BigInteger(DHGroupData.getP1()), new BigInteger(DHGroupData.getG()));
        }
        Random random = this.session.getFactoryManager().getRandomFactory().create();
        int which = random.random(selected.size());
        Moduli.DhGroup group = (Moduli.DhGroup)selected.get(which);
        return this.getDH(group.p, group.g);
    }

    protected List<Moduli.DhGroup> loadModuliGroups() throws IOException {
        URL moduli;
        List<Moduli.DhGroup> groups = null;
        String moduliStr = FactoryManagerUtils.getString(this.session, "moduli-url");
        if (!GenericUtils.isEmpty(moduliStr)) {
            try {
                moduli = new URL(moduliStr);
                groups = Moduli.parseModuli(moduli);
            }
            catch (IOException e) {
                this.log.warn("Error (" + e.getClass().getSimpleName() + ") loading external moduli from " + moduliStr + ": " + e.getMessage());
            }
        }
        if (groups == null) {
            moduliStr = "/org/apache/sshd/moduli";
            try {
                moduli = this.getClass().getResource(moduliStr);
                if (moduli == null) {
                    throw new FileNotFoundException("Missing internal moduli file");
                }
                moduliStr = moduli.toExternalForm();
                groups = Moduli.parseModuli(moduli);
            }
            catch (IOException e) {
                this.log.warn("Error (" + e.getClass().getSimpleName() + ") loading internal moduli from " + moduliStr + ": " + e.getMessage());
                throw e;
            }
        }
        this.log.debug("Loaded moduli groups from {}", (Object)moduliStr);
        return groups;
    }

    protected DHG getDH(BigInteger p, BigInteger g) throws Exception {
        return (DHG)this.factory.create(p, g);
    }
}

