/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Map;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * SameDiff layer applying dot-product attention with a fixed number of learned queries.
 *
 * <p>The query tensor {@code Q} (shape {@code [1, nIn, nQueries]}) is a trainable parameter that is
 * tiled along axis 0 to the minibatch size; the layer input is used as both keys and values. The
 * output is a recurrent activation with {@code nQueries} time steps and size {@code nOut} when
 * {@code projectInput} is true, otherwise size {@code nIn}.
 *
 * <p>When {@code projectInput} is true, multi-head attention is used with projection weights
 * {@code Wq}, {@code Wk}, {@code Wv} (each {@code [nHeads, headSize, nIn]}) and an output projection
 * {@code Wo} ({@code [nHeads * headSize, nOut]}). Otherwise plain dot-product attention is applied
 * to the unprojected input, which the builder only permits when {@code nHeads == 1} and
 * {@code nIn == nOut}.
 */
public class LearnedSelfAttentionLayer
extends SameDiffLayer {
    /** Input feature size; may be inferred from the input type in {@link #setNIn(InputType, boolean)}. */
    private long nIn;
    /** Output feature size when {@code projectInput} is true. */
    private long nOut;
    /** Number of attention heads. */
    private int nHeads;
    /** Per-head projection size; defaults to {@code nOut / nHeads} when not set on the builder. */
    private long headSize;
    /** Whether to use multi-head attention with learned input/output projections. */
    private boolean projectInput;
    /** Number of learned queries, i.e. the number of output time steps. */
    private int nQueries;
    private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
    private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
    private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
    private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
    private static final String WEIGHT_QUERIES = "Q";

    /** No-arg constructor retained for deserialization/reflection use. */
    private LearnedSelfAttentionLayer() {
    }

    /**
     * Creates the layer from a validated builder.
     *
     * @param builder builder holding the configuration; a {@code headSize} of 0 means
     *                "derive as {@code nOut / nHeads}"
     */
    protected LearnedSelfAttentionLayer(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
        this.nHeads = builder.nHeads;
        this.headSize = builder.headSize == 0 ? this.nOut / (long) this.nHeads : (long) builder.headSize;
        this.projectInput = builder.projectInput;
        this.nQueries = builder.nQueries;
    }

    /**
     * Returns the standard RNN-layer preprocessor for the given input type.
     *
     * @param inputType the incoming activation type
     * @return the preprocessor chosen by {@link InputTypeUtil#getPreprocessorForInputTypeRnnLayers}
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, this.getLayerName());
    }

    /**
     * Sets {@code nIn} from a recurrent input type.
     *
     * @param inputType must be a non-null {@link InputType.Type#RNN} type
     * @param override  if true, replace an already-set {@code nIn}; otherwise only set it when {@code nIn <= 0}
     * @throws IllegalStateException if {@code inputType} is null or not recurrent
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        if (this.nIn <= 0L || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
            this.nIn = r.getSize();
        }
    }

    /**
     * Computes the output type: a recurrent type with {@code nQueries} time steps and size
     * {@code nOut} (projected) or {@code nIn} (unprojected).
     *
     * @param layerIndex index of this layer, used only in the error message
     * @param inputType  must be a non-null {@link InputType.Type#RNN} type
     * @return the recurrent output type
     * @throws IllegalStateException if {@code inputType} is null or not recurrent
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer index = " + layerIndex + ", layer name = \"" + this.getLayerName() + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }
        long outSize = this.projectInput ? this.nOut : this.nIn;
        return InputType.recurrent(outSize, this.nQueries);
    }

    /**
     * Declares the learned queries {@code Q} and, when projecting, the attention weights.
     *
     * @param params parameter registry; cleared before the parameters are added
     */
    @Override
    public void defineParameters(SDLayerParams params) {
        params.clear();
        params.addWeightParam(WEIGHT_QUERIES, 1L, this.nIn, this.nQueries);
        if (this.projectInput) {
            params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, this.nHeads, this.headSize, this.nIn);
            params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, this.nHeads, this.headSize, this.nIn);
            params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, this.nHeads, this.headSize, this.nIn);
            params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, (long) this.nHeads * this.headSize, this.nOut);
        }
    }

    /**
     * Initializes every parameter with the configured weight init, using the fan-in/fan-out
     * implied by that parameter's shape as declared in {@link #defineParameters(SDLayerParams)}.
     *
     * @param params parameter views keyed by parameter name; initialized in place
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> e : params.entrySet()) {
                INDArray view = e.getValue();
                double fanIn;
                double fanOut;
                switch (e.getKey()) {
                    case WEIGHT_QUERIES:
                        // Q: [1, nIn, nQueries]
                        fanIn = this.nIn;
                        fanOut = this.nQueries;
                        break;
                    case WEIGHT_KEY_OUT_PROJECTION:
                        // Wo: [nHeads * headSize, nOut] maps concatenated heads to the output size
                        fanIn = (double) ((long) this.nHeads * this.headSize);
                        fanOut = this.nOut;
                        break;
                    default:
                        // Wq / Wk / Wv: [nHeads, headSize, nIn] map nIn features to headSize per head
                        fanIn = this.nIn;
                        fanOut = this.headSize;
                        break;
                }
                WeightInitUtil.initWeights(fanIn, fanOut, view.shape(), this.weightInit, null, 'c', view);
            }
        }
    }

    /**
     * Builds the attention graph: tiles the learned queries to the minibatch size and attends
     * over the layer input (used as both keys and values).
     *
     * <p>NOTE(review): assumes {@code layerInput} is rank 3, presumably {@code [minibatch, nIn, seqLength]};
     * only axis 0 of {@code Q} is replicated. Confirm against the SameDiffLayer input format.
     *
     * @param sameDiff   the SameDiff instance to build in
     * @param layerInput layer activations, used as keys and values
     * @param paramTable parameter variables keyed by parameter name
     * @param mask       optional input mask passed to the attention op
     * @return the attention output
     */
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
        SDVariable baseQueries = paramTable.get(WEIGHT_QUERIES);
        SDVariable batchSize = layerInput.shape().get(SDIndex.point(0L));
        // Repetition counts: 1 along every axis except axis 0, which is repeated batchSize times.
        SDVariable tileAxis = sameDiff.scatterUpdate(sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
        SDVariable queries = sameDiff.tile(baseQueries, tileAxis);
        // The trailing `true` is the `scaled` flag of the nd4j attention ops.
        if (this.projectInput) {
            SDVariable Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
            SDVariable Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
            SDVariable Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
            SDVariable Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
            return sameDiff.nn.multiHeadDotProductAttention(this.getLayerName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
        }
        return sameDiff.nn.dotProductAttention(this.getLayerName(), queries, layerInput, layerInput, mask, true);
    }

    /**
     * Does not propagate a mask: the output has {@code nQueries} time steps, so the input mask
     * (already applied inside the attention op) does not describe the output.
     *
     * @return always {@code null}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        return null;
    }

    public long getNIn() {
        return this.nIn;
    }

    public long getNOut() {
        return this.nOut;
    }

    public int getNHeads() {
        return this.nHeads;
    }

    public long getHeadSize() {
        return this.headSize;
    }

    public boolean isProjectInput() {
        return this.projectInput;
    }

    public int getNQueries() {
        return this.nQueries;
    }

    public void setNIn(long nIn) {
        this.nIn = nIn;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setNHeads(int nHeads) {
        this.nHeads = nHeads;
    }

    public void setHeadSize(long headSize) {
        this.headSize = headSize;
    }

    public void setProjectInput(boolean projectInput) {
        this.projectInput = projectInput;
    }

    public void setNQueries(int nQueries) {
        this.nQueries = nQueries;
    }

    @Override
    public String toString() {
        return "LearnedSelfAttentionLayer(nIn=" + this.getNIn() + ", nOut=" + this.getNOut() + ", nHeads=" + this.getNHeads() + ", headSize=" + this.getHeadSize() + ", projectInput=" + this.isProjectInput() + ", nQueries=" + this.getNQueries() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LearnedSelfAttentionLayer)) {
            return false;
        }
        LearnedSelfAttentionLayer other = (LearnedSelfAttentionLayer) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        return this.getNIn() == other.getNIn()
                && this.getNOut() == other.getNOut()
                && this.getNHeads() == other.getNHeads()
                && this.getHeadSize() == other.getHeadSize()
                && this.isProjectInput() == other.isProjectInput()
                && this.getNQueries() == other.getNQueries();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LearnedSelfAttentionLayer;
    }

    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = super.hashCode();
        result = result * PRIME + Long.hashCode(this.getNIn());
        result = result * PRIME + Long.hashCode(this.getNOut());
        result = result * PRIME + this.getNHeads();
        result = result * PRIME + Long.hashCode(this.getHeadSize());
        result = result * PRIME + (this.isProjectInput() ? 79 : 97);
        result = result * PRIME + this.getNQueries();
        return result;
    }

    /** Builder for {@link LearnedSelfAttentionLayer}; all numeric settings default to 0. */
    public static class Builder
    extends SameDiffLayer.Builder<Builder> {
        private int nIn;
        private int nOut;
        private int nHeads;
        private int headSize;
        private boolean projectInput;
        private int nQueries;

        /** @param nIn input feature size (may be left unset and inferred from the input type) */
        public Builder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** @param nOut output size; required when {@code projectInput} is true */
        public Builder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** @param nHeads number of attention heads; must be 1 when {@code projectInput} is false */
        public Builder nHeads(int nHeads) {
            this.nHeads = nHeads;
            return this;
        }

        /** @param headSize per-head size; 0 means derive as {@code nOut / nHeads} */
        public Builder headSize(int headSize) {
            this.headSize = headSize;
            return this;
        }

        /** @param projectInput whether to use multi-head attention with learned projections */
        public Builder projectInput(boolean projectInput) {
            this.projectInput = projectInput;
            return this;
        }

        /** @param nQueries number of learned queries (output time steps); must be positive */
        public Builder nQueries(int nQueries) {
            this.nQueries = nQueries;
            return this;
        }

        /**
         * Validates the configuration and builds the layer.
         *
         * @return the configured layer
         * @throws IllegalArgumentException if the configuration is inconsistent
         */
        @Override
        public LearnedSelfAttentionLayer build() {
            Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
            Preconditions.checkArgument(this.projectInput || this.nIn == this.nOut, "nIn must be equal to nOut when projectInput is false");
            Preconditions.checkArgument(!this.projectInput || this.nOut != 0, "nOut must be specified when projectInput is true");
            // Guard the modulo below and the nOut / nHeads division in the layer constructor;
            // without this an unset nHeads surfaces as an ArithmeticException.
            Preconditions.checkArgument(this.nHeads > 0, "nHeads must be greater than 0");
            Preconditions.checkArgument(this.nOut % this.nHeads == 0 || this.headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
            Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
            return new LearnedSelfAttentionLayer(this);
        }

        public int getNIn() {
            return this.nIn;
        }

        public int getNOut() {
            return this.nOut;
        }

        public int getNHeads() {
            return this.nHeads;
        }

        public int getHeadSize() {
            return this.headSize;
        }

        public boolean isProjectInput() {
            return this.projectInput;
        }

        public int getNQueries() {
            return this.nQueries;
        }

        public void setNIn(int nIn) {
            this.nIn = nIn;
        }

        public void setNOut(int nOut) {
            this.nOut = nOut;
        }

        public void setNHeads(int nHeads) {
            this.nHeads = nHeads;
        }

        public void setHeadSize(int headSize) {
            this.headSize = headSize;
        }

        public void setProjectInput(boolean projectInput) {
            this.projectInput = projectInput;
        }

        public void setNQueries(int nQueries) {
            this.nQueries = nQueries;
        }
    }
}

