/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.tcvdbtext.encoder;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.tencent.tcvdbtext.encoder.BaseSparseEncoder;
import com.tencent.tcvdbtext.encoder.Bm25Parameter;
import com.tencent.tcvdbtext.tokenizer.BaseTokenizer;
import com.tencent.tcvdbtext.tokenizer.JiebaTokenizer;
import com.tencent.tcvdbtext.util.JsonUtils;
import java.io.BufferedReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;

/**
 * Sparse-vector encoder based on the BM25 weighting scheme.
 *
 * <p>Text is turned into token ids by a {@link BaseTokenizer}.
 * <ul>
 *   <li>Documents are encoded with the BM25 term-frequency component ({@link #encodeTexts}).</li>
 *   <li>Queries are encoded with inverse-document-frequency weights divided by their sum
 *       ({@link #encodeQueries}).</li>
 * </ul>
 *
 * <p>Both need the corpus statistics {@code tokenFreq}, {@code docCount} and
 * {@code averageDocLength}. Provide them with {@link #fitCorpus}, {@link #setParams}, or the
 * setters before encoding.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
@JsonInclude(value = JsonInclude.Include.NON_NULL)
public class SparseVectorBm25Encoder implements BaseSparseEncoder {

    /** Default BM25 length-normalization parameter {@code b}. */
    private static final double DEFAULT_B = 0.75;
    /** Default BM25 term-frequency saturation parameter {@code k1}. */
    private static final double DEFAULT_K1 = 1.2;

    @JsonIgnore
    private BaseTokenizer tokenizer;
    private Double b;
    private Double k1;
    /** Token id (as a decimal string) mapped to the number of fitted documents containing it. */
    private Map<String, Integer> tokenFreq;
    private Integer docCount;
    /** Mean total term count per fitted document. */
    private Double averageDocLength;
    private Boolean enableStopWords;
    private Boolean lowerCase;

    /** Creates an unfitted encoder with a {@link JiebaTokenizer}, b = 0.75 and k1 = 1.2. */
    public SparseVectorBm25Encoder() {
        this(new JiebaTokenizer(), DEFAULT_B, DEFAULT_K1);
    }

    /**
     * Creates an unfitted encoder.
     *
     * @param tokenizer tokenizer used to turn text into token ids
     * @param b         BM25 length-normalization parameter
     * @param k1        BM25 term-frequency saturation parameter
     */
    public SparseVectorBm25Encoder(BaseTokenizer tokenizer, Double b, Double k1) {
        this.tokenizer = tokenizer;
        this.b = b;
        this.k1 = k1;
    }

    public BaseTokenizer getTokenizer() {
        return this.tokenizer;
    }

    public void setTokenizer(BaseTokenizer tokenizer) {
        this.tokenizer = tokenizer;
    }

    public Double getB() {
        return this.b;
    }

    public Double getK1() {
        return this.k1;
    }

    public Map<String, Integer> getTokenFreq() {
        return this.tokenFreq;
    }

    public Integer getDocCount() {
        return this.docCount;
    }

    public Double getAverageDocLength() {
        return this.averageDocLength;
    }

    public void setB(Double b) {
        this.b = b;
    }

    public void setK1(Double k1) {
        this.k1 = k1;
    }

    public void setTokenFreq(Map<String, Integer> tokenFreq) {
        this.tokenFreq = tokenFreq;
    }

    public void setDocCount(Integer docCount) {
        this.docCount = docCount;
    }

    public void setAverageDocLength(Double averageDocLength) {
        this.averageDocLength = averageDocLength;
    }

    public Boolean getEnableStopWords() {
        return this.enableStopWords;
    }

    public Boolean getLowerCase() {
        return this.lowerCase;
    }

    /** Records the stop-word flag and forwards it to the current tokenizer. */
    public void setEnableStopWords(Boolean enableStopWords) {
        this.enableStopWords = enableStopWords;
        this.tokenizer.setEnableStopWords(enableStopWords);
    }

    /** Records the lower-casing flag and forwards it to the current tokenizer. */
    public void setLowerCase(Boolean lowerCase) {
        this.lowerCase = lowerCase;
        this.tokenizer.setLowerCase(lowerCase);
    }

    /**
     * Creates an encoder preloaded with the bundled default parameters for a language.
     *
     * @param language {@code "zh"} or {@code "en"}
     * @return a fitted encoder using a {@link JiebaTokenizer}
     * @throws IllegalArgumentException if {@code language} is null or not {@code "zh"}/{@code "en"}
     */
    public static SparseVectorBm25Encoder getBm25Encoder(String language) {
        String path;
        // Constant-first equals so a null language reaches the IllegalArgumentException below.
        if ("zh".equals(language)) {
            path = "data/bm25_zh_default.json";
        } else if ("en".equals(language)) {
            path = "data/bm25_en_default.json";
        } else {
            throw new IllegalArgumentException("language must be zh or en");
        }
        SparseVectorBm25Encoder encoder = new SparseVectorBm25Encoder(new JiebaTokenizer(), DEFAULT_B, DEFAULT_K1);
        encoder.setParams(path);
        return encoder;
    }

    /** Returns {@link #getBm25Encoder(String)} for {@code "zh"}. */
    public static SparseVectorBm25Encoder getDefaultBm25Encoder() {
        return SparseVectorBm25Encoder.getBm25Encoder("zh");
    }

    /**
     * Tokenizes {@code text} and counts each distinct token id.
     *
     * @return one (token id, occurrence count) pair per distinct token, in first-occurrence order
     */
    private List<Pair<Long, Integer>> getTokenTF(String text) {
        List<Long> tokens = this.tokenizer.encode(text);
        // LinkedHashMap keeps first-occurrence order so the output order is deterministic.
        Map<Long, Integer> counts = new LinkedHashMap<>();
        for (Long token : tokens) {
            counts.merge(token, 1, Integer::sum);
        }
        List<Pair<Long, Integer>> result = new ArrayList<>(counts.size());
        for (Map.Entry<Long, Integer> entry : counts.entrySet()) {
            result.add(Pair.of(entry.getKey(), entry.getValue()));
        }
        return result;
    }

    /** Sums the occurrence counts of all pairs, i.e. the total token count of a text. */
    private static int sumTermFrequencies(List<Pair<Long, Integer>> tokenPairs) {
        int total = 0;
        for (Pair<Long, Integer> pair : tokenPairs) {
            total += pair.getRight();
        }
        return total;
    }

    /** Fails fast when corpus statistics have not been fitted or loaded yet. */
    private void requireFitted() {
        if (this.tokenFreq == null || this.docCount == null || this.averageDocLength == null) {
            throw new IllegalArgumentException("BM25 must be fit before encoding documents");
        }
    }

    /**
     * Encodes documents with the BM25 term-frequency component.
     *
     * <p>Each distinct token gets the weight {@code tf / (k1 * (1 - b + b * len / avgLen) + tf)}.
     * Here {@code tf} is the token's count in the document and {@code len} is the document's
     * total token count.
     *
     * @param texts documents to encode; must be non-empty
     * @return one sparse vector of (token id, weight) pairs per input text
     * @throws IllegalArgumentException if {@code texts} is null/empty or the encoder is not fitted
     */
    @Override
    public List<List<Pair<Long, Float>>> encodeTexts(List<String> texts) {
        if (texts == null || texts.isEmpty()) {
            throw new IllegalArgumentException("texts is empty");
        }
        requireFitted();
        List<List<Pair<Long, Float>>> sparseVectors = new ArrayList<>(texts.size());
        for (String text : texts) {
            List<Pair<Long, Integer>> tokenPairs = this.getTokenTF(text);
            int docLength = sumTermFrequencies(tokenPairs);
            // Identical for every token of this document, so compute it once.
            double lengthNorm = this.k1 * (1.0 - this.b + this.b * (docLength / this.averageDocLength));
            List<Pair<Long, Float>> sparseVector = new ArrayList<>(tokenPairs.size());
            for (Pair<Long, Integer> token : tokenPairs) {
                double freq = token.getRight();
                double score = freq / (lengthNorm + freq);
                sparseVector.add(Pair.of(token.getLeft(), Float.valueOf((float) score)));
            }
            sparseVectors.add(sparseVector);
        }
        return sparseVectors;
    }

    /**
     * Encodes queries with BM25 inverse-document-frequency weights.
     *
     * <p>Each distinct token gets {@code log((docCount + 1) / (df + 0.5))}. The weights are then
     * divided by their sum within the query.
     *
     * @param texts queries to encode; an empty list yields an empty result
     * @return one sparse vector of (token id, weight) pairs per input text
     * @throws IllegalArgumentException if {@code texts} is null or the encoder is not fitted
     */
    @Override
    public List<List<Pair<Long, Float>>> encodeQueries(List<String> texts) {
        if (texts == null) {
            throw new IllegalArgumentException("texts is null");
        }
        requireFitted();
        double docCountPlusOne = this.docCount + 1.0;
        List<List<Pair<Long, Float>>> sparseVectors = new ArrayList<>(texts.size());
        for (String text : texts) {
            List<Pair<Long, Integer>> tokenPairs = this.getTokenTF(text);
            double[] idfs = new double[tokenPairs.size()];
            double idfSum = 0.0;
            for (int i = 0; i < idfs.length; i++) {
                // Tokens absent from the fitted vocabulary default to a document frequency of 1.
                int df = this.tokenFreq.getOrDefault(tokenPairs.get(i).getLeft().toString(), 1);
                idfs[i] = Math.log(docCountPlusOne / (df + 0.5));
                idfSum += idfs[i];
            }
            List<Pair<Long, Float>> sparseVector = new ArrayList<>(idfs.length);
            for (int i = 0; i < idfs.length; i++) {
                sparseVector.add(Pair.of(tokenPairs.get(i).getLeft(), Float.valueOf((float) (idfs[i] / idfSum))));
            }
            sparseVectors.add(sparseVector);
        }
        return sparseVectors;
    }

    /**
     * Fits (or incrementally extends) the corpus statistics from {@code texts}.
     *
     * <p>On the first fit the statistics are replaced. On later fits the document frequencies
     * are added, the document count grows, and the average length becomes the weighted mean of
     * the old and new batches.
     *
     * @param texts corpus documents; must be non-empty
     * @throws IllegalArgumentException if {@code texts} is null or empty
     */
    @Override
    public void fitCorpus(List<String> texts) {
        if (texts == null || texts.isEmpty()) {
            throw new IllegalArgumentException("texts is empty");
        }
        Map<String, Integer> docFreq = new HashMap<>();
        int docNum = 0;
        long sumDocLen = 0L;
        for (String text : texts) {
            List<Pair<Long, Integer>> tokenPairs = this.getTokenTF(text);
            docNum++;
            // Total token count; this is the same length measure encodeTexts normalizes against.
            sumDocLen += sumTermFrequencies(tokenPairs);
            // getTokenTF yields each distinct token once, so this counts documents containing it.
            for (Pair<Long, Integer> token : tokenPairs) {
                docFreq.merge(token.getLeft().toString(), 1, Integer::sum);
            }
        }
        if (this.tokenFreq == null || this.docCount == null || this.averageDocLength == null) {
            this.tokenFreq = docFreq;
            this.docCount = docNum;
            this.averageDocLength = (double) sumDocLen / docNum;
        } else {
            int previousCount = this.docCount;
            // The weighted mean must use the document count from before this batch.
            this.averageDocLength =
                    (this.averageDocLength * previousCount + sumDocLen) / (double) (previousCount + docNum);
            this.docCount = previousCount + docNum;
            for (Map.Entry<String, Integer> entry : docFreq.entrySet()) {
                this.tokenFreq.merge(entry.getKey(), entry.getValue(), Integer::sum);
            }
        }
    }

    /**
     * Serializes this encoder's parameters as UTF-8 JSON to a file.
     *
     * @param paramsFilePath destination path; an existing file is overwritten
     * @throws UncheckedIOException if the file cannot be written
     */
    @Override
    public void downloadParams(String paramsFilePath) {
        try (Writer writer = Files.newBufferedWriter(Paths.get(paramsFilePath), StandardCharsets.UTF_8)) {
            writer.write(JsonUtils.toJsonString(this));
        } catch (IOException e) {
            throw new UncheckedIOException("failed to write BM25 params to " + paramsFilePath, e);
        }
    }

    /**
     * Loads parameters from a JSON file and applies them to this encoder.
     *
     * <p>The JSON is parsed as a {@link Bm25Parameter}. The file is looked up first as a
     * classpath resource and then as a file-system path, so files written by
     * {@link #downloadParams} can be loaded back.
     *
     * @param paramsFile classpath resource name or file-system path
     * @throws IllegalArgumentException if the file is found in neither location
     * @throws RuntimeException         wrapping any {@link IOException} while reading
     */
    @Override
    public void setParams(String paramsFile) {
        try (InputStream inputStream = openParamsStream(paramsFile);
             BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            StringBuilder fileContent = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                fileContent.append(line);
            }
            Bm25Parameter bm25Parameter = JsonUtils.parseObject(fileContent.toString(), Bm25Parameter.class);
            this.tokenFreq = bm25Parameter.getTokenFreq();
            this.docCount = bm25Parameter.getDocCount();
            this.averageDocLength = bm25Parameter.getAverageDocLength();
            this.b = bm25Parameter.getB();
            this.k1 = bm25Parameter.getK1();
            // The setters also forward the flags to the tokenizer.
            this.setEnableStopWords(bm25Parameter.getStopWords());
            this.setLowerCase(bm25Parameter.getLowerCase());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Opens {@code paramsFile} from the classpath, falling back to the file system. */
    private InputStream openParamsStream(String paramsFile) throws IOException {
        InputStream resource = this.getClass().getClassLoader().getResourceAsStream(paramsFile);
        if (resource != null) {
            return resource;
        }
        Path path = Paths.get(paramsFile);
        if (Files.isRegularFile(path)) {
            return Files.newInputStream(path);
        }
        throw new IllegalArgumentException("BM25 params file not found on classpath or file system: " + paramsFile);
    }

    /** Loads a dictionary file into the tokenizer by delegating to its {@code loadDict}. */
    @Override
    public void setDict(String dictFile) {
        this.tokenizer.loadDict(dictFile);
    }

    /**
     * Fluent builder for {@link SparseVectorBm25Encoder}.
     *
     * <p>Values that are never set default to a new {@link JiebaTokenizer}, b = 0.75 and
     * k1 = 1.2.
     */
    public static class Builder {
        private BaseTokenizer tokenizer;
        private Double b = DEFAULT_B;
        private Double k1 = DEFAULT_K1;

        public Builder withTokenizer(BaseTokenizer tokenizer) {
            this.tokenizer = tokenizer;
            return this;
        }

        public Builder withB(Double b) {
            this.b = b;
            return this;
        }

        public Builder withK1(Double k1) {
            this.k1 = k1;
            return this;
        }

        /** Builds an unfitted encoder from the configured (or default) values. */
        public SparseVectorBm25Encoder build() {
            // Create the default tokenizer only when needed; a null tokenizer would NPE on first use.
            BaseTokenizer effectiveTokenizer = this.tokenizer != null ? this.tokenizer : new JiebaTokenizer();
            return new SparseVectorBm25Encoder(effectiveTokenizer, this.b, this.k1);
        }
    }
}

