/*
 * Decompiled with CFR 0.152.
 */
package org.kramerlab.autoencoder;

import java.lang.reflect.Field;
import org.kramerlab.autoencoder.experiments.Metaparameters;
import org.kramerlab.autoencoder.experiments.Metaparameters$;
import org.kramerlab.autoencoder.math.matrix.Mat;
import org.kramerlab.autoencoder.math.optimization.CrossEntropyErrorFunctionFactory$;
import org.kramerlab.autoencoder.math.optimization.DifferentiableErrorFunctionFactory;
import org.kramerlab.autoencoder.math.optimization.SquareErrorFunctionFactory$;
import org.kramerlab.autoencoder.neuralnet.BiasedUnitLayer;
import org.kramerlab.autoencoder.neuralnet.FullBipartiteConnection;
import org.kramerlab.autoencoder.neuralnet.Layer;
import org.kramerlab.autoencoder.neuralnet.autoencoder.Autoencoder;
import org.kramerlab.autoencoder.neuralnet.rbm.BernoulliUnitLayer;
import org.kramerlab.autoencoder.neuralnet.rbm.GaussianUnitLayer;
import org.kramerlab.autoencoder.neuralnet.rbm.Rbm;
import org.kramerlab.autoencoder.neuralnet.rbm.RbmLayer;
import org.kramerlab.autoencoder.neuralnet.rbm.RbmStack;
import org.kramerlab.autoencoder.neuralnet.rbm.RbmTrainingStrategy;
import org.kramerlab.autoencoder.package$;
import org.kramerlab.autoencoder.visualization.TrainingObserver;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.JavaConversions$;
import scala.collection.Seq;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Stream;
import scala.collection.immutable.Stream$;
import scala.collection.parallel.ForkJoinTaskSupport;
import scala.concurrent.forkjoin.ForkJoinPool;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.RichInt$;

/**
 * Singleton backing the Scala package object {@code org.kramerlab.autoencoder}
 * (decompiled form).
 *
 * <p>Exposes the layer-type constants, the predefined pre-training strategy
 * factories, and the entry points for building and training autoencoders.
 * The single instance is published through {@link #MODULE$}.
 */
public final class package$ {
    /** The singleton instance; assigned by the private constructor. */
    public static final package$ MODULE$;
    /** Layer-type code for which {@link #mkLayer} builds a {@code GaussianUnitLayer}. */
    private final int Linear;
    /** Layer-type code for which {@link #mkLayer} builds a {@code BernoulliUnitLayer}. */
    private final int Sigmoid;
    /** Empty observer list ({@code Nil}). */
    private final List<TrainingObserver> NoObservers;
    // Predefined factories for RBM pre-training strategies.
    // NOTE(review): their implementations are not visible in this file.
    private final Function0<RbmTrainingStrategy> NoPretraining;
    private final Function0<RbmTrainingStrategy> HintonsMiraculousStrategy;
    private final Function0<RbmTrainingStrategy> RandomRetryStrategy;
    private final Function0<RbmTrainingStrategy> TournamentStrategy;

    static {
        // The constructor stores the new instance in MODULE$.
        new package$();
    }

    /** @return the layer-type code for linear layers (initialised to 0) */
    public int Linear() {
        return this.Linear;
    }

    /** @return the layer-type code for sigmoid layers (initialised to 1) */
    public int Sigmoid() {
        return this.Sigmoid;
    }

    /** @return the empty list of training observers */
    public List<TrainingObserver> NoObservers() {
        return this.NoObservers;
    }

    /** @return the {@code NoPretraining} strategy factory */
    public Function0<RbmTrainingStrategy> NoPretraining() {
        return this.NoPretraining;
    }

    /** @return the {@code HintonsMiraculousStrategy} strategy factory */
    public Function0<RbmTrainingStrategy> HintonsMiraculousStrategy() {
        return this.HintonsMiraculousStrategy;
    }

    /** @return the {@code RandomRetryStrategy} strategy factory */
    public Function0<RbmTrainingStrategy> RandomRetryStrategy() {
        return this.RandomRetryStrategy;
    }

    /** @return the {@code TournamentStrategy} strategy factory */
    public Function0<RbmTrainingStrategy> TournamentStrategy() {
        return this.TournamentStrategy;
    }

    /**
     * Trains an autoencoder on {@code data}.
     *
     * <p>Layer types and dimensions come from a {@code Metaparameters} object
     * built from {@code data.width()}, {@code compressionDimension} and
     * {@code numberOfHiddenLayers}. The fine-tuning step is capped at 5000
     * evaluations.
     *
     * @param data                    training data; its width is the input dimension
     * @param compressionDimension    passed to {@code Metaparameters}
     * @param numberOfHiddenLayers    passed to {@code Metaparameters}
     * @param useL2Error              {@code true} selects the square error function factory,
     *                                {@code false} the cross-entropy one
     * @param trainingStrategyFactory source of the per-RBM training strategies
     * @param observers               observers passed to pre-training and fine-tuning
     * @return the fine-tuned autoencoder
     */
    public Autoencoder trainAutoencoder(Mat data, int compressionDimension, int numberOfHiddenLayers, boolean useL2Error, Function0<RbmTrainingStrategy> trainingStrategyFactory, List<TrainingObserver> observers) {
        Metaparameters params = new Metaparameters(data.width(), compressionDimension, numberOfHiddenLayers, false);
        DifferentiableErrorFunctionFactory<Mat> errorFunctionFactory = useL2Error ? SquareErrorFunctionFactory$.MODULE$ : CrossEntropyErrorFunctionFactory$.MODULE$;
        // One strategy entry per RBM: numHidLayers() + 1 entries.
        // NOTE(review): List.fill presumably takes its element by name, so the
        // factory would be invoked once per entry - confirm against the Scala API.
        return this.trainAutoencoder(params.layerTypes(), params.layerDims(), data, (List<RbmTrainingStrategy>)((List)List$.MODULE$.fill(params.numHidLayers() + 1, trainingStrategyFactory)), errorFunctionFactory, 5000, observers);
    }

    /**
     * Builds a stack of RBMs from the layer descriptions, pre-trains it,
     * unfolds it into an autoencoder and fine-tunes that autoencoder.
     *
     * @param layerTypes2           layer-type code per layer (see {@link #mkLayer})
     * @param layerDims2            dimension per layer, zipped with {@code layerTypes2}
     * @param data                  training data, used as both input and target of fine-tuning
     * @param rbmTrainingStrategies strategies passed to {@code RbmStack.train}
     * @param errorFunctionFactory  error function factory used for fine-tuning
     * @param maxEvals              passed to {@code Autoencoder.optimize}
     * @param trainingObservers     observers passed to pre-training and fine-tuning
     * @return the fine-tuned autoencoder
     */
    private Autoencoder trainAutoencoder(int[] layerTypes2, int[] layerDims2, Mat data, List<RbmTrainingStrategy> rbmTrainingStrategies, DifferentiableErrorFunctionFactory<Mat> errorFunctionFactory, int maxEvals, List<TrainingObserver> trainingObservers) {
        // (type, dimension) pair for every layer.
        Tuple2[] layerDescriptions = (Tuple2[])Predef$.MODULE$.intArrayOps(layerTypes2).zip((GenIterable)Predef$.MODULE$.wrapIntArray(layerDims2), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
        // Zip the descriptions with their own tail to get adjacent
        // (visible, hidden) layer pairs, then build one RBM per pair.
        List rbms = Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])layerDescriptions).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.refArrayOps((Object[])layerDescriptions).tail()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).withFilter((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            /** Accepts only pairs whose outer tuple and both inner tuples are non-null. */
            public final boolean apply(Tuple2<Tuple2<Object, Object>, Tuple2<Object, Object>> check$ifrefutable$1) {
                return check$ifrefutable$1 != null
                    && check$ifrefutable$1._1() != null
                    && check$ifrefutable$1._2() != null;
            }
        }).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            /** Builds the RBM connecting a visible layer description to a hidden one. */
            public final Rbm apply(Tuple2<Tuple2<Object, Object>, Tuple2<Object, Object>> x$1) {
                if (x$1 != null) {
                    Tuple2 visible = (Tuple2)x$1._1();
                    Tuple2 hidden = (Tuple2)x$1._2();
                    if (visible != null && hidden != null) {
                        int visType = visible._1$mcI$sp();
                        int visDim = visible._2$mcI$sp();
                        int hidType = hidden._1$mcI$sp();
                        int hidDim = hidden._2$mcI$sp();
                        return new Rbm((RbmLayer)((Object)package$.MODULE$.mkLayer(visType, visDim)), new FullBipartiteConnection(visDim, hidDim), (RbmLayer)((Object)package$.MODULE$.mkLayer(hidType, hidDim)));
                    }
                }
                throw new MatchError(x$1);
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Rbm.class)))).toList();
        RbmStack stack = new RbmStack((List<Rbm>)rbms);
        RbmStack trainedStack = stack.train(data, rbmTrainingStrategies, trainingObservers);
        Autoencoder autoencoder = trainedStack.unfold();
        // NOTE(review): the meaning of 0.33 is defined by Autoencoder.optimize,
        // which is not visible here.
        return (Autoencoder)autoencoder.optimize(data, data, errorFunctionFactory, 0.33, maxEvals, trainingObservers);
    }

    /** Default for the 4th parameter of the private {@code trainAutoencoder}: the empty list. */
    private List<RbmTrainingStrategy> trainAutoencoder$default$4() {
        return Nil$.MODULE$;
    }

    /** Default for the 5th parameter of the private {@code trainAutoencoder}: the square error factory. */
    private DifferentiableErrorFunctionFactory<Mat> trainAutoencoder$default$5() {
        return SquareErrorFunctionFactory$.MODULE$;
    }

    /**
     * Creates a unit layer of the requested type and size.
     *
     * @param layerType {@link #Linear()} for a {@code GaussianUnitLayer},
     *                  {@link #Sigmoid()} for a {@code BernoulliUnitLayer}
     * @param layerDim  dimension passed to the layer constructor
     * @return the new layer
     * @throws MatchError if {@code layerType} is neither of the two known codes
     */
    public BiasedUnitLayer mkLayer(int layerType, int layerDim) {
        if (layerType == this.Linear()) {
            return new GaussianUnitLayer(layerDim);
        }
        if (layerType == this.Sigmoid()) {
            return new BernoulliUnitLayer(layerDim);
        }
        throw new MatchError((Object)BoxesRunTime.boxToInteger((int)layerType));
    }

    /**
     * Produces a stream of progressively deeper autoencoders.
     *
     * <p>Starting from an autoencoder that holds only the input layer, each
     * entry of {@code hiddenLayerDims} unfolds the central layer of the
     * previous autoencoder. The initial input-only autoencoder is dropped
     * from the result, and at most {@code maxDepth} elements are kept.
     *
     * @param layerType                  layer-type code for the input layer and each unfolded layer
     * @param maxDepth                   maximum number of autoencoders in the stream
     * @param hiddenLayerDims            dimension of each successive central layer
     * @param data                       training data; its width is the input dimension
     * @param useL2Error                 {@code true} selects the square error function factory,
     *                                   {@code false} the cross-entropy one
     * @param pretrainingStrategyFactory invoked once per unfolding step to get a strategy
     * @param finetuneInnerLayers        passed to {@code Autoencoder.unfoldCentralLayer}
     * @param trainingObservers          observers passed to every unfolding step
     * @return the stream of autoencoders, one per unfolding step
     */
    public Stream<Autoencoder> deepAutoencoderStream(int layerType, int maxDepth, Seq<Object> hiddenLayerDims, Mat data, boolean useL2Error, Function0<RbmTrainingStrategy> pretrainingStrategyFactory, boolean finetuneInnerLayers, List<TrainingObserver> trainingObservers) {
        DifferentiableErrorFunctionFactory<Mat> errorFunctionFactory = useL2Error ? SquareErrorFunctionFactory$.MODULE$ : CrossEntropyErrorFunctionFactory$.MODULE$;
        BiasedUnitLayer inputLayer = this.mkLayer(layerType, data.width());
        Autoencoder trivialAutoencoder = new Autoencoder((List<Layer>)List$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new BiasedUnitLayer[]{inputLayer})));
        return ((Stream)((TraversableLike)hiddenLayerDims.toStream().scanLeft((Object)trivialAutoencoder, (Function2)new Serializable(layerType, data, pretrainingStrategyFactory, finetuneInnerLayers, trainingObservers, errorFunctionFactory){
            public static final long serialVersionUID = 0L;
            // Values captured from the enclosing method call.
            private final int layerType$1;
            private final Mat data$1;
            private final Function0 pretrainingStrategyFactory$1;
            private final boolean finetuneInnerLayers$1;
            private final List trainingObservers$1;
            private final DifferentiableErrorFunctionFactory errorFunctionFactory$1;

            /**
             * Scan step: unfolds the central layer of {@code a} to dimension {@code d}.
             * Returns the unfolded autoencoder itself rather than an unassigned
             * placeholder variable.
             */
            public final Autoencoder apply(Autoencoder a, int d) {
                return a.unfoldCentralLayer(this.layerType$1, d, (RbmTrainingStrategy)this.pretrainingStrategyFactory$1.apply(), this.data$1, this.errorFunctionFactory$1, this.finetuneInnerLayers$1, (List<TrainingObserver>)this.trainingObservers$1);
            }
            {
                this.layerType$1 = layerType$1;
                this.data$1 = data$1;
                this.pretrainingStrategyFactory$1 = pretrainingStrategyFactory$1;
                this.finetuneInnerLayers$1 = finetuneInnerLayers$1;
                this.trainingObservers$1 = trainingObservers$1;
                this.errorFunctionFactory$1 = errorFunctionFactory$1;
            }
        }, Stream$.MODULE$.canBuildFrom())).tail()).take(maxDepth);
    }

    /**
     * Java-friendly variant of {@link #deepAutoencoderStream}.
     *
     * <p>The hidden-layer dimensions are generated by repeatedly applying
     * {@code Metaparameters.nextLayerDimension(compressionFactor, x)} starting
     * from {@code data.width()}. The starting value is dropped and
     * {@code maxDepth} dimensions are taken.
     *
     * @param layerType                  see {@link #deepAutoencoderStream}
     * @param maxDepth                   number of dimensions generated and maximum stream length
     * @param compressionFactor          passed to {@code Metaparameters.nextLayerDimension}
     * @param data                       training data
     * @param useL2Error                 see {@link #deepAutoencoderStream}
     * @param pretrainingStrategyFactory see {@link #deepAutoencoderStream}
     * @param finetuneInnerLayers        see {@link #deepAutoencoderStream}
     * @param trainingObservers          see {@link #deepAutoencoderStream}
     * @return the autoencoder stream wrapped as a {@code java.lang.Iterable}
     */
    public Iterable<Autoencoder> deepAutoencoderStream_java(int layerType, int maxDepth, double compressionFactor, Mat data, boolean useL2Error, Function0<RbmTrainingStrategy> pretrainingStrategyFactory, boolean finetuneInnerLayers, List<TrainingObserver> trainingObservers) {
        return JavaConversions$.MODULE$.asJavaIterable(this.deepAutoencoderStream(layerType, maxDepth, (Seq<Object>)((Stream)scala.package$.MODULE$.Stream().iterate((Object)BoxesRunTime.boxToInteger((int)data.width()), (Function1)new Serializable(compressionFactor){
            public static final long serialVersionUID = 0L;
            // Value captured from the enclosing method call.
            private final double compressionFactor$1;

            public final int apply(int x$2) {
                return this.apply$mcII$sp(x$2);
            }

            /** Computes the next layer dimension from the previous one. */
            public int apply$mcII$sp(int x$2) {
                return Metaparameters$.MODULE$.nextLayerDimension(this.compressionFactor$1, x$2);
            }
            {
                this.compressionFactor$1 = compressionFactor$1;
            }
        }).tail()).take(maxDepth), data, useL2Error, pretrainingStrategyFactory, finetuneInnerLayers, trainingObservers));
    }

    /**
     * Trains an autoencoder via {@link #deepAutoencoderStream} and returns the
     * last (deepest) element of the stream.
     *
     * <p>Uses {@link #Sigmoid()} as the layer type, {@code numberOfHiddenLayers}
     * as the maximum depth, the layer dimensions from {@code Metaparameters}
     * as hidden dimensions, and enables fine-tuning of inner layers.
     *
     * @param data                    training data
     * @param compressionDimension    passed to {@code Metaparameters}
     * @param numberOfHiddenLayers    passed to {@code Metaparameters}; also the maximum depth
     * @param useL2Error              see {@link #deepAutoencoderStream}
     * @param trainingStrategyFactory pre-training strategy factory
     * @param observers               training observers
     * @return the last autoencoder of the stream
     */
    public Autoencoder trainAutoencoder_Stream(Mat data, int compressionDimension, int numberOfHiddenLayers, boolean useL2Error, Function0<RbmTrainingStrategy> trainingStrategyFactory, List<TrainingObserver> observers) {
        Metaparameters params = new Metaparameters(data.width(), compressionDimension, numberOfHiddenLayers, false);
        return (Autoencoder)this.deepAutoencoderStream(this.Sigmoid(), numberOfHiddenLayers, (Seq<Object>)Predef$.MODULE$.wrapIntArray(params.layerDims()), data, useL2Error, trainingStrategyFactory, true, observers).last();
    }

    /**
     * Computes a list of {@code n + 1} layer dimensions, one per
     * {@code k = 0..n}, using
     * {@code round(numHid + (numVis - numHid) * (1 - (k/n)^alpha)^(1/alpha))}.
     *
     * @param numVis dimension weighted by the interpolation term
     * @param numHid base dimension added to every entry
     * @param n      upper bound of the index range (inclusive)
     * @param alpha  exponent that shapes the interpolation curve
     * @return the rounded dimensions for {@code k = 0..n}
     */
    public List<Object> layerDims(int numVis, int numHid, int n, double alpha) {
        return ((TraversableOnce)RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), n).map((Function1)new Serializable(numVis, numHid, n, alpha){
            public static final long serialVersionUID = 0L;
            // Values captured from the enclosing method call.
            private final int numVis$1;
            private final int numHid$1;
            private final int n$1;
            private final double alpha$1;

            public final int apply(int k) {
                return this.apply$mcII$sp(k);
            }

            /** Evaluates the interpolation formula for index {@code k}. */
            public int apply$mcII$sp(int k) {
                return (int)RichDouble$.MODULE$.round$extension(Predef$.MODULE$.doubleWrapper((double)this.numHid$1 + (double)(this.numVis$1 - this.numHid$1) * scala.math.package$.MODULE$.pow(1.0 - scala.math.package$.MODULE$.pow((double)k / (double)this.n$1, this.alpha$1), 1.0 / this.alpha$1)));
            }
            {
                this.numVis$1 = numVis$1;
                this.numHid$1 = numHid$1;
                this.n$1 = n$1;
                this.alpha$1 = alpha$1;
            }
        }, IndexedSeq$.MODULE$.canBuildFrom())).toList();
    }

    /**
     * Replaces the {@code defaultTaskSupport} field of the
     * {@code scala.collection.parallel} package object, via reflection, with a
     * {@code ForkJoinTaskSupport} backed by a new {@code ForkJoinPool} of
     * {@code numThreads} threads.
     *
     * <p>NOTE(review): if no declared field has that name, the {@code get()}
     * on the lookup result fails; presumably with
     * {@code NoSuchElementException} - confirm against the Scala
     * {@code Option} API.
     *
     * @param numThreads parallelism passed to the {@code ForkJoinPool} constructor
     */
    public void setParallelismGlobally(int numThreads) {
        scala.collection.parallel.package$ parPkgObj = scala.collection.parallel.package$.MODULE$;
        Field defaultTaskSupportField = (Field)Predef$.MODULE$.refArrayOps((Object[])parPkgObj.getClass().getDeclaredFields()).find((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            /** Matches the field named exactly {@code defaultTaskSupport}. */
            public final boolean apply(Field x$3) {
                return "defaultTaskSupport".equals(x$3.getName());
            }
        }).get();
        // The field is not publicly writable, so open it up before setting.
        defaultTaskSupportField.setAccessible(true);
        defaultTaskSupportField.set(parPkgObj, new ForkJoinTaskSupport(new ForkJoinPool(numThreads)));
    }

    /**
     * Initialises the singleton: publishes it in {@link #MODULE$} and sets the
     * layer-type codes, the empty observer list and the strategy factories.
     */
    private package$() {
        MODULE$ = this;
        this.Linear = 0;
        this.Sigmoid = 1;
        this.NoObservers = Nil$.MODULE$;
        // NOTE(review): anonfun.1 .. anonfun.4 appear to be the decompiler's
        // rendering of closure classes defined outside this file; their
        // behaviour cannot be determined from here.
        this.NoPretraining = new anonfun.1();
        this.HintonsMiraculousStrategy = new anonfun.2();
        this.RandomRetryStrategy = new anonfun.3();
        this.TournamentStrategy = new anonfun.4();
    }
}

