001package javax.visrec.ml.classification; 002 003import javax.visrec.ml.ClassifierCreationException; 004import javax.visrec.spi.ServiceProvider; 005import java.io.File; 006import java.lang.reflect.InvocationTargetException; 007import java.lang.reflect.Method; 008import java.util.Map; 009 010public interface NeuralNetImageClassifier<T> extends ImageClassifier<T> { 011 012 static <IMAGE_CLASS> NeuralNetImageClassifier.Builder<IMAGE_CLASS> builder() { 013 return new Builder<>(); 014 } 015 016 class BuildingBlock<T> { 017 018 private int imageWidth; 019 private int imageHeight; 020 private File networkArchitecture; 021 private File trainingFile; 022 private File labelsFile; 023 private float maxError; 024 private float learningRate; 025 private File modelFile; 026 private int maxEpochs; 027 private Class<T> inputCls; 028 029 private BuildingBlock() { 030 } 031 032 public File getNetworkArchitecture() { 033 return networkArchitecture; 034 } 035 036 public int getImageWidth() { 037 return imageWidth; 038 } 039 040 public int getImageHeight() { 041 return imageHeight; 042 } 043 044 public File getTrainingFile() { 045 return trainingFile; 046 } 047 048 public File getLabelsFile() { 049 return labelsFile; 050 } 051 052 public float getMaxError() { 053 return maxError; 054 } 055 056 public float getLearningRate() { 057 return learningRate; 058 } 059 060 public File getModelFile() { 061 return modelFile; 062 } 063 064 public int getMaxEpochs() { 065 return maxEpochs; 066 } 067 068 public Class<T> getInputClass() { return inputCls; } 069 070 private static <R> BuildingBlock<R> copyWithNewInputClass(BuildingBlock<?> block, Class<R> cls) { 071 BuildingBlock<R> newBlock = new BuildingBlock<>(); 072 newBlock.inputCls = cls; 073 newBlock.imageHeight = block.imageHeight; 074 newBlock.imageWidth = block.imageWidth; 075 newBlock.labelsFile = block.labelsFile; 076 newBlock.modelFile = block.modelFile; 077 newBlock.networkArchitecture = block.networkArchitecture; 078 newBlock.maxError = block.maxError; 079 newBlock.maxEpochs = block.maxEpochs; 080 newBlock.learningRate = block.learningRate; 081 newBlock.trainingFile = block.trainingFile; 082 return newBlock; 083 } 084 } 085 086 class Builder<T> { 087 088 private BuildingBlock<T> block; 089 090 private Builder() { 091 block = new BuildingBlock<>(); 092 } 093 094 private Builder(BuildingBlock<T> block) { 095 this.block = block; 096 } 097 098 public <R> Builder<R> inputClass(Class<R> cls) { 099 BuildingBlock<R> newBlock = BuildingBlock.copyWithNewInputClass(block, cls); 100 return new Builder<>(newBlock); 101 } 102 103 public Builder<T> imageWidth(int imageWidth) { 104 block.imageWidth = imageWidth; 105 return this; 106 } 107 108 public Builder<T> imageHeight(int imageHeight) { 109 block.imageHeight = imageHeight; 110 return this; 111 } 112 113 public Builder<T> trainingFile(File trainingFile) { 114 block.trainingFile = trainingFile; 115 return this; 116 } 117 118 public Builder<T> labelsFile(File labelsFile) { 119 block.labelsFile = labelsFile; 120 return this; 121 } 122 123 public Builder<T> maxError(float maxError) { 124 block.maxError = maxError; 125 return this; 126 } 127 128 public Builder<T> maxEpochs(int epochs) { 129 block.maxEpochs = epochs; 130 return this; 131 } 132 133 public Builder<T> learningRate(float learningRate) { 134 block.learningRate = learningRate; 135 return this; 136 } 137 138 public Builder<T> modelFile(File modelFile) { 139 block.modelFile = modelFile; 140 return this; 141 } 142 143 public Builder<T> networkArchitecture(File architecture) { 144 block.networkArchitecture = architecture; 145 return this; 146 } 147 148 public BuildingBlock getBuildingBlock() { 149 return block; 150 } 151 152 public ImageClassifier<T> build() throws ClassifierCreationException { 153 return ServiceProvider.current().getClassifierFactoryService().createNeuralNetImageClassifier(block); 154 } 155 156 public ImageClassifier<T> build(Map<String, Object> configuration) throws ClassifierCreationException { 157 Method[] methods = this.getClass().getDeclaredMethods(); 158 for (Method method : methods) { 159 if (!method.getName().equals("build") && method.getParameterCount() == 1 160 && configuration.containsKey(method.getName())) { 161 try { 162 Object value = configuration.get(method.getName()); 163 Class<?> expectedParameterType = method.getParameterTypes()[0]; 164 // Integer casting 165 if (expectedParameterType.equals(int.class) || expectedParameterType.equals(Integer.class)) { 166 if (value instanceof String) { 167 method.invoke(this, Integer.parseInt((String) value)); 168 continue; 169 } 170 } 171 172 // Float casting 173 if (expectedParameterType.equals(float.class) || expectedParameterType.equals(Float.class)) { 174 if (value instanceof String) { 175 method.invoke(this, Float.parseFloat((String) value)); 176 continue; 177 } 178 } 179 180 // File casting 181 if (expectedParameterType.equals(File.class)) { 182 if (value instanceof String) { 183 method.invoke(this, new File((String) value)); 184 continue; 185 } 186 } 187 188 // Others 189 method.invoke(this, value); 190 } catch (IllegalAccessException | InvocationTargetException | IllegalArgumentException e) { 191 throw new ClassifierCreationException("Couldn't invoke '" + method.getName() + "'", e); 192 } 193 } 194 } 195 return build(); 196 } 197 } 198}