001package javax.visrec.ml.classification; 002 003import javax.visrec.ml.ClassifierCreationException; 004import javax.visrec.spi.ServiceProvider; 005import java.io.File; 006import java.util.Map; 007 008public interface NeuralNetBinaryClassifier<T> extends BinaryClassifier<T> { 009 010 static NeuralNetBinaryClassifier.Builder<?> builder() { 011 return new NeuralNetBinaryClassifier.Builder<>(); 012 } 013 014 class BuildingBlock<T> { 015 private Class<T> inputCls; 016 private int inputsNum; 017 private int[] hiddenLayers; 018 private float maxError; 019 private int maxEpochs; 020 private float learningRate; 021 private File trainingFile; 022 023 private BuildingBlock() { 024 } 025 026 public Class<T> getInputClass() { 027 return inputCls; 028 } 029 030 public int getInputsNum() { 031 return inputsNum; 032 } 033 034 public int[] getHiddenLayers() { 035 return hiddenLayers; 036 } 037 038 public float getMaxError() { 039 return maxError; 040 } 041 042 public int getMaxEpochs() { 043 return maxEpochs; 044 } 045 046 public float getLearningRate() { 047 return learningRate; 048 } 049 050 public File getTrainingFile() { 051 return trainingFile; 052 } 053 054 private static <R> BuildingBlock<R> copyWithNewTargetClass(BuildingBlock<?> block, Class<R> cls) { 055 BuildingBlock<R> newBlock = new BuildingBlock<>(); 056 newBlock.inputCls = cls; 057 newBlock.inputsNum = block.inputsNum; 058 newBlock.hiddenLayers = block.hiddenLayers; 059 newBlock.maxError = block.maxError; 060 newBlock.maxEpochs = block.maxEpochs; 061 newBlock.learningRate = block.learningRate; 062 newBlock.trainingFile = block.trainingFile; 063 return newBlock; 064 } 065 } 066 067 class Builder<T> { 068 069 private NeuralNetBinaryClassifier.BuildingBlock<T> block; 070 071 private Builder() { 072 this(new NeuralNetBinaryClassifier.BuildingBlock<>()); 073 } 074 075 private Builder(BuildingBlock<T> block) { 076 this.block = block; 077 } 078 079 public <R> Builder<R> inputClass(Class<R> cls) { 080 BuildingBlock<R> newBlock = BuildingBlock.copyWithNewTargetClass(block, cls); 081 return new Builder<>(newBlock); 082 } 083 084 public Builder<T> inputsNum(int inputsNum) { 085 block.inputsNum = inputsNum; 086 return this; 087 } 088 089 public Builder<T> hiddenLayers(int... hiddenLayers) { 090 block.hiddenLayers = hiddenLayers; 091 return this; 092 } 093 094 public Builder<T> maxError(float maxError) { 095 block.maxError = maxError; 096 return this; 097 } 098 099 public Builder<T> maxEpochs(int maxEpochs) { 100 block.maxEpochs = maxEpochs; 101 return this; 102 } 103 104 public Builder<T> learningRate(float learningRate) { 105 block.learningRate = learningRate; 106 return this; 107 } 108 109 public Builder<T> trainingFile(File trainingFile) { 110 block.trainingFile = trainingFile; 111 return this; 112 } 113 114 public NeuralNetBinaryClassifier.BuildingBlock<T> getBuildingBlock() { 115 return block; 116 } 117 118 public BinaryClassifier<T> build() throws ClassifierCreationException { 119 return ServiceProvider.current().getClassifierFactoryService().createNeuralNetBinaryClassifier(block); 120 } 121 122 public BinaryClassifier<T> build(Map<String, Object> configuration) throws ClassifierCreationException { 123 throw new IllegalStateException("not implemented yet"); 124 } 125 } 126}