/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl;

import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.leaderboard.Leaderboard;
import hex.Model;
import java.util.Arrays;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.apache.log4j.Logger;
import water.Key;
import water.util.ArrayUtils;

/**
 * Collection of {@link ModelSelectionStrategy} implementations that decide which models to keep
 * by ranking original and new models together on a temporary {@link Leaderboard}.
 */
public final class ModelSelectionStrategies {
    private static final Logger LOG = Logger.getLogger(ModelSelectionStrategies.class);

    /**
     * Provides the leaderboard used for one selection round, plus a hook invoked when that round
     * is finished.
     */
    public interface LeaderboardHolder {
        /**
         * Returns the leaderboard on which models are ranked during a selection.
         *
         * @return the selection leaderboard
         */
        Leaderboard get();

        /**
         * Hook called once a selection using {@link #get()} is finished; implementations may
         * release the leaderboard here. The default implementation does nothing.
         */
        default void cleanup() {
        }
    }

    /**
     * Applies {@link KeepBestN} only to the models whose keys satisfy a criterion.
     * <p>
     * Both the original and the new model keys are filtered with the criterion before ranking,
     * so models outside the subgroup are never proposed for addition or removal.
     *
     * @param <M> model type
     */
    public static class KeepBestNFromSubgroup<M extends Model>
    extends LeaderboardBasedSelectionStrategy<M> {
        private final Predicate<Key<M>> _criterion;
        private final int _N;

        /**
         * @param n                   maximum number of subgroup models to keep; must be non-negative
         * @param criterion           accepts the keys that belong to the subgroup
         * @param leaderboardSupplier supplies the leaderboard used to rank the subgroup
         * @throws IllegalArgumentException if {@code n} is negative
         * @throws NullPointerException     if {@code criterion} is null
         */
        public KeepBestNFromSubgroup(int n, Predicate<Key<M>> criterion, Supplier<LeaderboardHolder> leaderboardSupplier) {
            super(leaderboardSupplier);
            if (n < 0) {
                throw new IllegalArgumentException("N must be non-negative, got " + n);
            }
            this._criterion = Objects.requireNonNull(criterion, "criterion");
            this._N = n;
        }

        /**
         * Keeps at most {@code N} of the subgroup models (original + new), as ranked by the
         * selection leaderboard.
         *
         * @param originalModels keys of the models currently kept
         * @param newModels      keys of the candidate models
         * @return keys to add and keys to remove, both restricted to the subgroup
         */
        @Override
        public ModelSelectionStrategy.Selection<M> select(Key<M>[] originalModels, Key<M>[] newModels) {
            // Key[]::new creates a raw array; the elements are Key<M> by construction.
            @SuppressWarnings("unchecked")
            Key<M>[] originalSubgroup = Arrays.stream(originalModels).filter(_criterion).toArray(Key[]::new);
            @SuppressWarnings("unchecked")
            Key<M>[] newSubgroup = Arrays.stream(newModels).filter(_criterion).toArray(Key[]::new);
            return new KeepBestN<M>(_N, _leaderboardSupplier).select(originalSubgroup, newSubgroup);
        }
    }

    /**
     * Keeps as many models as were originally present, i.e. {@link KeepBestN} with
     * {@code N = originalModels.length}.
     *
     * @param <M> model type
     */
    public static class KeepBestConstantSize<M extends Model>
    extends LeaderboardBasedSelectionStrategy<M> {
        /**
         * @param leaderboardSupplier supplies the leaderboard used to rank the models
         */
        public KeepBestConstantSize(Supplier<LeaderboardHolder> leaderboardSupplier) {
            super(leaderboardSupplier);
        }

        /**
         * @param originalModels keys of the models currently kept; their count is the size to keep
         * @param newModels      keys of the candidate models
         * @return keys to add and keys to remove
         */
        @Override
        public ModelSelectionStrategy.Selection<M> select(Key<M>[] originalModels, Key<M>[] newModels) {
            return new KeepBestN<M>(originalModels.length, _leaderboardSupplier).select(originalModels, newModels);
        }
    }

    /**
     * Ranks original and new models together on a fresh selection leaderboard and keeps the first
     * {@code N} keys it reports.
     *
     * @param <M> model type
     */
    public static class KeepBestN<M extends Model>
    extends LeaderboardBasedSelectionStrategy<M> {
        private final int _N;

        /**
         * @param n                   maximum number of models to keep; must be non-negative
         * @param leaderboardSupplier supplies the leaderboard used to rank the models
         * @throws IllegalArgumentException if {@code n} is negative
         */
        public KeepBestN(int n, Supplier<LeaderboardHolder> leaderboardSupplier) {
            super(leaderboardSupplier);
            if (n < 0) {
                throw new IllegalArgumentException("N must be non-negative, got " + n);
            }
            this._N = n;
        }

        /**
         * Adds all original and new models to a temporary leaderboard, takes its first
         * {@code min(#models, N)} keys, and diffs them against {@code originalModels}.
         * <p>
         * NOTE(review): assumes {@link Leaderboard#getModelKeys()} returns keys best-first — confirm.
         *
         * @param originalModels keys of the models currently kept
         * @param newModels      keys of the candidate models
         * @return keys among the top {@code N} that are not original ({@code add}), and original
         *         keys that fell out of the top {@code N} ({@code remove})
         */
        @Override
        public ModelSelectionStrategy.Selection<M> select(Key<M>[] originalModels, Key<M>[] newModels) {
            LeaderboardHolder lbHolder = makeSelectionLeaderboard();
            try {
                Leaderboard tmpLeaderboard = lbHolder.get();
                tmpLeaderboard.addModels(originalModels);
                tmpLeaderboard.addModels(newModels);
                if (LOG.isDebugEnabled()) {
                    LOG.debug(tmpLeaderboard.toLogString());
                }
                Key<Model>[] sortedKeys = tmpLeaderboard.getModelKeys();
                Key<Model>[] bestN = ArrayUtils.subarray(sortedKeys, 0, Math.min(sortedKeys.length, _N));
                // Key[]::new creates a raw array; the keys originate from models of type M.
                @SuppressWarnings("unchecked")
                Key<M>[] toAdd = Arrays.stream(bestN)
                        .filter(k -> !ArrayUtils.contains(originalModels, k))
                        .toArray(Key[]::new);
                @SuppressWarnings("unchecked")
                Key<M>[] toRemove = Arrays.stream(originalModels)
                        .filter(k -> !ArrayUtils.contains(bestN, k))
                        .toArray(Key[]::new);
                return new ModelSelectionStrategy.Selection<>(toAdd, toRemove);
            } finally {
                // Always run the holder's cleanup hook, even if ranking failed.
                lbHolder.cleanup();
            }
        }
    }

    /**
     * Base class for strategies that rank models on a leaderboard obtained from a supplier.
     *
     * @param <M> model type
     */
    public static abstract class LeaderboardBasedSelectionStrategy<M extends Model>
    implements ModelSelectionStrategy<M> {
        /** Source of a fresh {@link LeaderboardHolder} for each selection. */
        final Supplier<LeaderboardHolder> _leaderboardSupplier;

        /**
         * @param leaderboardSupplier supplies the leaderboard holder used for each selection
         */
        public LeaderboardBasedSelectionStrategy(Supplier<LeaderboardHolder> leaderboardSupplier) {
            this._leaderboardSupplier = leaderboardSupplier;
        }

        /**
         * Obtains the leaderboard holder for one selection round from the supplier.
         *
         * @return the holder returned by the supplier
         */
        LeaderboardHolder makeSelectionLeaderboard() {
            return _leaderboardSupplier.get();
        }
    }
}

