001package javax.visrec.ml.classification;
002
003import javax.visrec.ml.ClassifierCreationException;
004import javax.visrec.spi.ServiceProvider;
005import java.io.File;
006import java.util.Map;
007
008public interface NeuralNetBinaryClassifier<T> extends BinaryClassifier<T> {
009
010    static NeuralNetBinaryClassifier.Builder<?> builder() {
011        return new NeuralNetBinaryClassifier.Builder<>();
012    }
013
014    class BuildingBlock<T> {
015        private Class<T> inputCls;
016        private int inputsNum;
017        private int[] hiddenLayers;
018        private float maxError;
019        private int maxEpochs;
020        private float learningRate;
021        private File trainingFile;
022
023        private BuildingBlock() {
024        }
025
026        public Class<T> getInputClass() {
027            return inputCls;
028        }
029
030        public int getInputsNum() {
031            return inputsNum;
032        }
033
034        public int[] getHiddenLayers() {
035            return hiddenLayers;
036        }
037
038        public float getMaxError() {
039            return maxError;
040        }
041
042        public int getMaxEpochs() {
043            return maxEpochs;
044        }
045
046        public float getLearningRate() {
047            return learningRate;
048        }
049
050        public File getTrainingFile() {
051            return trainingFile;
052        }
053
054        private static <R> BuildingBlock<R> copyWithNewTargetClass(BuildingBlock<?> block, Class<R> cls) {
055            BuildingBlock<R> newBlock = new BuildingBlock<>();
056            newBlock.inputCls = cls;
057            newBlock.inputsNum = block.inputsNum;
058            newBlock.hiddenLayers = block.hiddenLayers;
059            newBlock.maxError = block.maxError;
060            newBlock.maxEpochs = block.maxEpochs;
061            newBlock.learningRate = block.learningRate;
062            newBlock.trainingFile = block.trainingFile;
063            return newBlock;
064        }
065    }
066
067    class Builder<T> {
068
069        private NeuralNetBinaryClassifier.BuildingBlock<T> block;
070
071        private Builder() {
072            this(new NeuralNetBinaryClassifier.BuildingBlock<>());
073        }
074
075        private Builder(BuildingBlock<T> block) {
076            this.block = block;
077        }
078
079        public <R> Builder<R> inputClass(Class<R> cls) {
080            BuildingBlock<R> newBlock = BuildingBlock.copyWithNewTargetClass(block, cls);
081            return new Builder<>(newBlock);
082        }
083
084        public Builder<T> inputsNum(int inputsNum) {
085            block.inputsNum = inputsNum;
086            return this;
087        }
088
089        public Builder<T> hiddenLayers(int... hiddenLayers) {
090            block.hiddenLayers = hiddenLayers;
091            return this;
092        }
093
094        public Builder<T> maxError(float maxError) {
095            block.maxError = maxError;
096            return this;
097        }
098
099        public Builder<T> maxEpochs(int maxEpochs) {
100            block.maxEpochs = maxEpochs;
101            return this;
102        }
103
104        public Builder<T> learningRate(float learningRate) {
105            block.learningRate = learningRate;
106            return this;
107        }
108
109        public Builder<T> trainingFile(File trainingFile) {
110            block.trainingFile = trainingFile;
111            return this;
112        }
113
114        public NeuralNetBinaryClassifier.BuildingBlock<T> getBuildingBlock() {
115            return block;
116        }
117
118        public BinaryClassifier<T> build() throws ClassifierCreationException {
119            return ServiceProvider.current().getClassifierFactoryService().createNeuralNetBinaryClassifier(block);
120        }
121
122        public BinaryClassifier<T> build(Map<String, Object> configuration) throws ClassifierCreationException {
123            throw new IllegalStateException("not implemented yet");
124        }
125    }
126}