001package javax.visrec.ml.classification;
002
import javax.visrec.ml.ClassifierCreationException;
import javax.visrec.spi.ServiceProvider;
import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Map;
import java.util.Objects;
009
010public interface NeuralNetImageClassifier<T> extends ImageClassifier<T> {
011
012    static <IMAGE_CLASS> NeuralNetImageClassifier.Builder<IMAGE_CLASS> builder() {
013        return new Builder<>();
014    }
015
016    class BuildingBlock<T> {
017
018        private int imageWidth;
019        private int imageHeight;
020        private File networkArchitecture;
021        private File trainingFile;
022        private File labelsFile;
023        private float maxError;
024        private float learningRate;
025        private File modelFile;
026        private int maxEpochs;
027        private Class<T> inputCls;
028
029        private BuildingBlock() {
030        }
031
032        public File getNetworkArchitecture() {
033            return networkArchitecture;
034        }
035
036        public int getImageWidth() {
037            return imageWidth;
038        }
039
040        public int getImageHeight() {
041            return imageHeight;
042        }
043
044        public File getTrainingFile() {
045            return trainingFile;
046        }
047
048        public File getLabelsFile() {
049            return labelsFile;
050        }
051
052        public float getMaxError() {
053            return maxError;
054        }
055
056        public float getLearningRate() {
057            return learningRate;
058        }
059
060        public File getModelFile() {
061            return modelFile;
062        }
063
064        public int getMaxEpochs() {
065            return maxEpochs;
066        }
067
068        public Class<T> getInputClass() { return inputCls; }
069
070        private static <R> BuildingBlock<R> copyWithNewInputClass(BuildingBlock<?> block, Class<R> cls) {
071            BuildingBlock<R> newBlock = new BuildingBlock<>();
072            newBlock.inputCls = cls;
073            newBlock.imageHeight = block.imageHeight;
074            newBlock.imageWidth = block.imageWidth;
075            newBlock.labelsFile = block.labelsFile;
076            newBlock.modelFile = block.modelFile;
077            newBlock.networkArchitecture = block.networkArchitecture;
078            newBlock.maxError = block.maxError;
079            newBlock.maxEpochs = block.maxEpochs;
080            newBlock.learningRate = block.learningRate;
081            newBlock.trainingFile = block.trainingFile;
082            return newBlock;
083        }
084    }
085
086    class Builder<T> {
087
088        private BuildingBlock<T> block;
089
090        private Builder() {
091            block = new BuildingBlock<>();
092        }
093
094        private Builder(BuildingBlock<T> block) {
095            this.block = block;
096        }
097
098        public <R> Builder<R> inputClass(Class<R> cls) {
099            BuildingBlock<R> newBlock = BuildingBlock.copyWithNewInputClass(block, cls);
100            return new Builder<>(newBlock);
101        }
102
103        public Builder<T> imageWidth(int imageWidth) {
104            block.imageWidth = imageWidth;
105            return this;
106        }
107
108        public Builder<T> imageHeight(int imageHeight) {
109            block.imageHeight = imageHeight;
110            return this;
111        }
112
113        public Builder<T> trainingFile(File trainingFile) {
114            block.trainingFile = trainingFile;
115            return this;
116        }
117
118        public Builder<T> labelsFile(File labelsFile) {
119            block.labelsFile = labelsFile;
120            return this;
121        }
122
123        public Builder<T> maxError(float maxError) {
124            block.maxError = maxError;
125            return this;
126        }
127
128        public Builder<T> maxEpochs(int epochs) {
129            block.maxEpochs = epochs;
130            return this;
131        }
132
133        public Builder<T> learningRate(float learningRate) {
134            block.learningRate = learningRate;
135            return this;
136        }
137
138        public Builder<T> modelFile(File modelFile) {
139            block.modelFile = modelFile;
140            return this;
141        }
142
143        public Builder<T> networkArchitecture(File architecture) {
144            block.networkArchitecture = architecture;
145            return this;
146        }
147
148        public BuildingBlock getBuildingBlock() {
149            return block;
150        }
151
152        public ImageClassifier<T> build() throws ClassifierCreationException {
153            return ServiceProvider.current().getClassifierFactoryService().createNeuralNetImageClassifier(block);
154        }
155
156        public ImageClassifier<T> build(Map<String, Object> configuration) throws ClassifierCreationException {
157            Method[] methods = this.getClass().getDeclaredMethods();
158            for (Method method : methods) {
159                if (!method.getName().equals("build") && method.getParameterCount() == 1
160                        && configuration.containsKey(method.getName())) {
161                    try {
162                        Object value = configuration.get(method.getName());
163                        Class<?> expectedParameterType = method.getParameterTypes()[0];
164                        // Integer casting
165                        if (expectedParameterType.equals(int.class) || expectedParameterType.equals(Integer.class)) {
166                            if (value instanceof String) {
167                                method.invoke(this, Integer.parseInt((String) value));
168                                continue;
169                            }
170                        }
171
172                        // Float casting
173                        if (expectedParameterType.equals(float.class) || expectedParameterType.equals(Float.class)) {
174                            if (value instanceof String) {
175                                method.invoke(this, Float.parseFloat((String) value));
176                                continue;
177                            }
178                        }
179
180                        // File casting
181                        if (expectedParameterType.equals(File.class)) {
182                            if (value instanceof String) {
183                                method.invoke(this, new File((String) value));
184                                continue;
185                            }
186                        }
187
188                        // Others
189                        method.invoke(this, value);
190                    } catch (IllegalAccessException | InvocationTargetException | IllegalArgumentException e) {
191                        throw new ClassifierCreationException("Couldn't invoke '" + method.getName() + "'", e);
192                    }
193                }
194            }
195            return build();
196        }
197    }
198}