001package javax.visrec.spi;
002
003import javax.visrec.ml.ClassifierCreationException;
004import javax.visrec.ml.classification.*;
005import java.util.HashMap;
006import java.util.Map;
007import java.util.ServiceLoader;
008
009/**
010 * Service to provide the correct {@link Classifier} implementation.
011 *
012 * @author Kevin Berendsen
013 * @since 1.0
014 */
015public final class ClassifierFactoryService {
016
017    private Map<Class<?>, ImageClassifierFactory<?>> imageClassifierFactories;
018    private Map<Class<?>, BinaryClassifierFactory<?>> binaryClassifierFactories;
019
020    private static ClassifierFactoryService instance;
021    static ClassifierFactoryService getInstance() {
022        if (instance == null) {
023            instance = new ClassifierFactoryService();
024        }
025        return instance;
026    }
027
028    private ClassifierFactoryService() {
029        // Prevent instantiation
030    }
031
032    /**
033     * Creates a new {@link ImageClassifier} by providing the {@link NeuralNetImageClassifier.BuildingBlock} to tune
034     * the implementation's image classifier.
035     *
036     * @param block {@link NeuralNetImageClassifier.BuildingBlock} is provided to tune the building of the image classifier.
037     * @return {@link ImageClassifier}
038     * @throws ClassifierCreationException if the classifier can not be created due to any reason.
039     */
040    public <T> ImageClassifier<T> createNeuralNetImageClassifier(NeuralNetImageClassifier.BuildingBlock<T> block) throws ClassifierCreationException {
041        if (imageClassifierFactories == null) {
042            imageClassifierFactories = new HashMap<>();
043            for (ImageClassifierFactory<?> classifierCreator : ServiceLoader.load(ImageClassifierFactory.class)) {
044                imageClassifierFactories.put(classifierCreator.getImageClass(), classifierCreator);
045            }
046        }
047
048        ImageClassifierFactory<?> creator = imageClassifierFactories.get(block.getInputClass());
049        if (creator == null) {
050            throw new ClassifierCreationException("Unsupported image class");
051        }
052
053        @SuppressWarnings("unchecked")
054        ImageClassifierFactory<T> castedCreator = (ImageClassifierFactory<T>) creator;
055        return castedCreator.create(block);
056    }
057
058    /**
059     * Creates a new {@link BinaryClassifier} by providing the {@link NeuralNetBinaryClassifier.BuildingBlock} to tune
060     * the implementation's binary classifier.
061     *
062     * @param block {@link NeuralNetBinaryClassifier.BuildingBlock} is provided to tune the building of the binary classifier.
063     * @return {@link BinaryClassifier}
064     * @throws ClassifierCreationException if the classifier can not be created due to any reason.
065     */
066    public <T> BinaryClassifier<T> createNeuralNetBinaryClassifier(NeuralNetBinaryClassifier.BuildingBlock<T> block) throws ClassifierCreationException {
067        if (binaryClassifierFactories == null) {
068            binaryClassifierFactories = new HashMap<>();
069            for (BinaryClassifierFactory<?> classifierCreator : ServiceLoader.load(BinaryClassifierFactory.class)) {
070                binaryClassifierFactories.put(classifierCreator.getTargetClass(), classifierCreator);
071            }
072        }
073
074        BinaryClassifierFactory<?> creator = binaryClassifierFactories.get(block.getInputClass());
075        if (creator == null) {
076            throw new ClassifierCreationException("Unsupported target class");
077        }
078
079        @SuppressWarnings("unchecked")
080        BinaryClassifierFactory<T> castedCreator = (BinaryClassifierFactory<T>) creator;
081        return castedCreator.create(block);
082    }
083}