/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.geaflow.state.action.hook;

import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys;
import com.antgroup.geaflow.file.FileConfigKeys;
import com.antgroup.geaflow.file.IPersistentIO;
import com.antgroup.geaflow.file.PersistentIOBuilder;
import com.antgroup.geaflow.state.action.ActionRequest;
import com.antgroup.geaflow.state.action.ActionType;
import com.antgroup.geaflow.state.action.hook.ActionHook;
import com.antgroup.geaflow.state.context.StateContext;
import com.antgroup.geaflow.state.strategy.accessor.IAccessor;
import com.antgroup.geaflow.store.ILocalStore;
import com.antgroup.geaflow.utils.keygroup.KeyGroup;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.math.IntMath;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Action hook that records the total shard number used when state is archived and, on recovery,
 * re-initialises local stores when the job now runs with more shards than the archived state.
 *
 * <p>On {@link ActionType#ARCHIVE}, only the instance whose key group starts at 0 asks
 * {@link ScaleManager} to persist the current shard number. On {@link ActionType#RECOVER}, if
 * {@link ScaleManager#needScale(long)} reports that the shard number grew, every accessor's store is
 * initialised through {@link ILocalStore#initShardId(int)} with the shard id returned by
 * {@link ScaleManager#getRecoverShardId(int)}.
 */
public class ScaleHook
implements ActionHook {
    private static final Logger LOGGER = LoggerFactory.getLogger(ScaleHook.class);
    /** Accessors handled by this hook; the key is used as the shard id during recovery. */
    private Map<Integer, IAccessor> accessorMap;
    /** Key group owned by this hook instance. */
    private KeyGroup shardGroup;
    /** Current total shard number, taken from the state context. */
    private int totalShardNum;
    private ScaleManager scaleManager;

    /**
     * Captures the key group, total shard number and accessors, and creates the
     * {@link ScaleManager} that reads/writes the shard-number marker files.
     *
     * @param context state context supplying key group, shard number, name and configuration
     * @param accessorMap accessors keyed by shard id
     */
    @Override
    public void init(StateContext context, Map<Integer, IAccessor> accessorMap) {
        this.shardGroup = context.getKeyGroup();
        this.totalShardNum = context.getTotalShardNum();
        this.accessorMap = accessorMap;
        this.scaleManager = new ScaleManager(context.getName(), context.getConfig(), this.totalShardNum);
    }

    /**
     * Reacts to archive and recover actions; other action types are ignored.
     *
     * @param actionType the action being executed
     * @param request for ARCHIVE and RECOVER, its payload is cast to {@code Long} and used as the
     *     state version
     */
    @Override
    public void doStoreAction(ActionType actionType, ActionRequest request) {
        // Only the instance whose key group starts at 0 writes the marker; presumably so that a
        // single task records the shard number for the whole state — confirm with the owners.
        if (this.shardGroup.getStartKeyGroup() == 0 && actionType == ActionType.ARCHIVE) {
            this.scaleManager.tryStoreShardNum((Long) request.getRequest());
        }
        if (actionType == ActionType.RECOVER && this.scaleManager.needScale((Long) request.getRequest())) {
            for (Map.Entry<Integer, IAccessor> entry : this.accessorMap.entrySet()) {
                int recoverShardId = this.scaleManager.getRecoverShardId(entry.getKey());
                ((ILocalStore) entry.getValue().getStore()).initShardId(recoverShardId);
            }
        }
    }

    /**
     * Persists and reads back the shard number in effect for archived state versions.
     *
     * <p>Marker files named {@code shardNum#<shardNum>#<version>} are created under
     * {@code <root>/<jobName>/<name>}. The shard number for a version is taken from the marker with
     * the greatest version not exceeding it; if there is none, the shard number is reported as 0.
     */
    public static class ScaleManager {
        private static final String PARTITION_HEADER = "shardNum#";
        private static final String PARTITION_FILE_FORMAT = "shardNum#%d#%d";
        private static final int PARTITION_HEADER_LEN = PARTITION_HEADER.length();
        private final IPersistentIO persistIO;
        /** Whether this instance has already created its marker file. */
        private boolean hasStored = false;
        /** Directory holding the marker files. */
        private final String remotePath;
        /** Current total shard number. */
        private final int shardNum;
        /** Shard number read from the marker files: -1 until loaded, 0 if no marker applies. */
        private int lastShardNum = -1;

        /**
         * Creates a manager for the state called {@code name}.
         *
         * @param name state name, used as the last path component of the marker directory
         * @param configuration job configuration (persistent IO, job name, root path)
         * @param shardNum current total shard number
         */
        public ScaleManager(String name, Configuration configuration, int shardNum) {
            this.persistIO = PersistentIOBuilder.build(configuration);
            String jobName = configuration.getString(ExecutionConfigKeys.JOB_APP_NAME);
            String root = configuration.getString(FileConfigKeys.ROOT);
            // NOTE(review): java.nio Paths collapses "//", so a root with a URI scheme such as
            // "hdfs://host/dir" would lose its authority — confirm ROOT is always a plain path.
            this.remotePath = Paths.get(root, jobName, name).toString();
            this.shardNum = shardNum;
        }

        /**
         * Returns whether state recovered at {@code version} must be re-sharded.
         *
         * @param version the state version being recovered
         * @return {@code true} if the current shard number is larger than the recorded one;
         *     {@code false} if none is recorded or the two are equal
         * @throws IllegalArgumentException if the current shard number is smaller than the
         *     recorded one
         */
        public boolean needScale(long version) {
            this.lastShardNum = this.getShardNumByVersion(version);
            if (this.lastShardNum == 0) {
                return false;
            }
            Preconditions.checkArgument(this.shardNum >= this.lastShardNum,
                "partitionNum %s, lastPartitionNum %s", this.shardNum, this.lastShardNum);
            return this.shardNum > this.lastShardNum;
        }

        /**
         * Maps a current shard id to the shard id of the archived state it should recover from.
         *
         * @param shardId shard id under the current shard number
         * @return {@code shardId % lastShardNum}
         * @throws IllegalStateException if no positive recorded shard number has been loaded
         *     (i.e. {@link #needScale(long)} has not reported a scale-out)
         * @throws IllegalArgumentException if the current shard number is not the recorded one
         *     multiplied by a power of two
         */
        public int getRecoverShardId(int shardId) {
            // Guard against the division/modulo below running with 0 or the -1 "not loaded" value.
            Preconditions.checkState(this.lastShardNum > 0,
                "lastShardNum %s is not available, needScale must report a scale-out first",
                this.lastShardNum);
            Preconditions.checkArgument(this.shardNum % this.lastShardNum == 0
                    && IntMath.isPowerOfTwo(this.shardNum / this.lastShardNum),
                "shardNum %s, lastShardNum %s", this.shardNum, this.lastShardNum);
            int recoverShard = shardId % this.lastShardNum;
            LOGGER.info("auto scale state shard {} recover {}", shardId, recoverShard);
            return recoverShard;
        }

        /**
         * Creates a marker file recording the current shard number at {@code version}, at most
         * once per instance, and only when the current shard number exceeds the recorded one.
         *
         * @param version the state version being archived
         * @throws RuntimeException wrapping the {@link IOException} if the marker cannot be created
         */
        public void tryStoreShardNum(long version) {
            this.lastShardNum = this.getShardNumByVersion(version);
            if (this.shardNum <= this.lastShardNum || this.hasStored) {
                return;
            }
            Path markerFile = new Path(this.remotePath,
                String.format(PARTITION_FILE_FORMAT, this.shardNum, version));
            try {
                this.persistIO.createNewFile(markerFile);
            } catch (IOException e) {
                throw new RuntimeException("failed to create shard num file " + markerFile, e);
            }
            this.hasStored = true;
        }

        /**
         * Resolves the shard number for {@code version} from a list of file names.
         *
         * @param fileNames names listed in the marker directory; names not starting with
         *     {@code shardNum#} are ignored
         * @param version the state version to resolve
         * @return the shard number of the marker with the greatest version {@code <= version},
         *     or 0 if there is none
         * @throws IllegalArgumentException if a marker name does not have exactly two fields
         * @throws NumberFormatException if a marker field is not a number
         */
        private int getShardNumByVersion(List<String> fileNames, long version) {
            TreeMap<Long, Integer> shardNumByVersion = new TreeMap<>();
            for (String fileName : fileNames) {
                if (!fileName.startsWith(PARTITION_HEADER)) {
                    continue;
                }
                List<String> fields = Splitter.on('#').splitToList(fileName.substring(PARTITION_HEADER_LEN));
                Preconditions.checkArgument(fields.size() == 2, "illegal shard num file name %s", fileName);
                int fileShardNum = Integer.parseInt(fields.get(0));
                long fileVersion = Long.parseLong(fields.get(1));
                shardNumByVersion.put(fileVersion, fileShardNum);
            }
            // floorEntry already excludes markers written after the requested version.
            Map.Entry<Long, Integer> entry = shardNumByVersion.floorEntry(version);
            return entry == null ? 0 : entry.getValue();
        }

        /**
         * Returns the shard number recorded for {@code version}.
         *
         * <p>The first successful lookup is cached in {@code lastShardNum}. After that, the cached
         * value is returned and {@code version} is ignored.
         * NOTE(review): confirm that ignoring later versions is intended.
         *
         * @param version the state version to resolve
         * @return the recorded shard number, or 0 if the marker directory does not exist or no
         *     marker applies
         * @throws RuntimeException wrapping the {@link IOException} if the directory cannot be read
         */
        public int getShardNumByVersion(long version) {
            if (this.lastShardNum >= 0) {
                return this.lastShardNum;
            }
            Path markerDir = new Path(this.remotePath);
            List<String> fileNames;
            try {
                if (!this.persistIO.exists(markerDir)) {
                    return 0;
                }
                fileNames = this.persistIO.listFile(markerDir);
            } catch (IOException e) {
                throw new RuntimeException("failed to list shard num files under " + markerDir, e);
            }
            return this.getShardNumByVersion(fileNames, version);
        }
    }
}

