001package javax.visrec;
002
003import javax.visrec.ml.ClassificationException;
004import javax.visrec.ml.classification.ImageClassifier;
005import javax.visrec.spi.ServiceProvider;
006import java.awt.image.BufferedImage;
007import java.io.File;
008import java.io.IOException;
009import java.io.InputStream;
010import java.util.Map;
011import java.util.Objects;
012import java.util.Optional;
013
014/**
015 * Skeleton abstract class to make it easier to implement image classifier.
016 * It provides implementation of Classifier interface for images, along with
017 * image factory for specific type of images.
018 * This class solves the problem of using various implementation of images and machine learning models in Java,
019 * and provides standard Classifier API for clients.
020 *
021 * By default the type of key in the Map the {@link ImageClassifier} is {@code String}
022 *
023 * @author Zoran Sevarac
024 *
025 * @param <MODEL_CLASS> class of machine learning model
026 */
027public abstract class AbstractImageClassifier<IMAGE_CLASS, MODEL_CLASS> implements ImageClassifier<IMAGE_CLASS> { // could also implement binary classifier
028
029    private ImageFactory<IMAGE_CLASS> imageFactory; // image factory impl for the specified image class
030    private MODEL_CLASS model; // the model could be injected from machine learning container?
031
032    private float threshold; // this should ba a part of every classifier
033
034    protected AbstractImageClassifier(final Class<IMAGE_CLASS> imgCls, final MODEL_CLASS model) {
035        final Optional<ImageFactory<IMAGE_CLASS>> optionalImageFactory = ServiceProvider.current()
036                .getImageFactoryService()
037                .getByImageType(imgCls);
038        if (!optionalImageFactory.isPresent()) {
039            throw new IllegalArgumentException(String.format("Could not find ImageFactory by '%s'", BufferedImage.class.getName()));
040        }
041        imageFactory = optionalImageFactory.get();
042        setModel(model);
043    }
044
045    public ImageFactory<IMAGE_CLASS> getImageFactory() {
046        return imageFactory;
047    }
048
049    @Override
050    public Map<String, Float> classify(File file) throws ClassificationException {
051        IMAGE_CLASS image;
052        try {
053            image = imageFactory.getImage(file);
054        } catch (IOException e) {
055            throw new ClassificationException("Couldn't transform input into a BufferedImage", e);
056        }
057        return classify(image);
058    }
059
060    @Override
061    public Map<String, Float> classify(InputStream inputStream) throws ClassificationException {
062        IMAGE_CLASS image;
063        try {
064            image = imageFactory.getImage(inputStream);
065        } catch (IOException e) {
066            throw new ClassificationException("Couldn't transform input into a BufferedImage", e);
067        }
068        return classify(image);
069    }
070    
071    // todo: provide get top 1, 3, 5 results; sort and get
072
073    // do we need this now, when impl is loaded using service provider?
074    // Kevin and Zoran disussed: probably not needed now when we have service provider impl, and we dont want to allow user to mess with it
075//    public void setImageFactory(ImageFactory<IMAGE_CLASS> imageFactory) {
076//        this.imageFactory = imageFactory;
077//    }
078
079    public MODEL_CLASS getModel() {
080        return model;
081    }
082
083    public void setModel(MODEL_CLASS model) {
084        this.model = Objects.requireNonNull(model, "Model cannot bu null!");         
085    }
086
087    public float getThreshold() {
088        return threshold;
089    }
090
091    public void setThreshold(float threshold) {
092        this.threshold = threshold;
093    }
094}