package javax.visrec;

import javax.visrec.ml.ClassificationException;
import javax.visrec.ml.classification.ImageClassifier;
import javax.visrec.spi.ServiceProvider;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Skeleton abstract class that makes it easier to implement an image classifier.
 * It provides an implementation of the Classifier interface for images, along with
 * an image factory for the specific type of images.
 * This class solves the problem of using various implementations of images and machine learning
 * models in Java, and provides a standard Classifier API for clients.
 *
 * By default the key type of the Map returned by the {@link ImageClassifier} is {@code String}.
 *
 * @author Zoran Sevarac
 *
 * @param <IMAGE_CLASS> class of the images this classifier works with, for example {@link BufferedImage}
 * @param <MODEL_CLASS> class of machine learning model
 */
public abstract class AbstractImageClassifier<IMAGE_CLASS, MODEL_CLASS> implements ImageClassifier<IMAGE_CLASS> { // could also implement binary classifier

    /** Image factory implementation for the specified image class; resolved once in the constructor. */
    private final ImageFactory<IMAGE_CLASS> imageFactory;

    /** The underlying machine learning model; never {@code null} once set through {@link #setModel(Object)}. */
    private MODEL_CLASS model; // the model could be injected from machine learning container?

    /** Threshold value; only stored here, not read by this class itself. */
    private float threshold; // this should be a part of every classifier

    /**
     * Creates a classifier for the given image class backed by the given model.
     * The {@link ImageFactory} is looked up through the current {@link ServiceProvider}.
     *
     * @param imgCls class of the images to classify, used to look up the matching image factory
     * @param model  machine learning model to use, must not be {@code null}
     * @throws IllegalArgumentException if no image factory is registered for {@code imgCls}
     * @throws NullPointerException     if {@code model} is {@code null}
     */
    protected AbstractImageClassifier(final Class<IMAGE_CLASS> imgCls, final MODEL_CLASS model) {
        final Optional<ImageFactory<IMAGE_CLASS>> optionalImageFactory = ServiceProvider.current()
                .getImageFactoryService()
                .getByImageType(imgCls);
        // Report the image class that was actually requested, not a hard-coded BufferedImage.
        this.imageFactory = optionalImageFactory.orElseThrow(() -> new IllegalArgumentException(
                String.format("Could not find ImageFactory by '%s'",
                        imgCls == null ? "null" : imgCls.getName())));
        // Kept as a setter call so subclasses overriding setModel observe the same behaviour as before.
        setModel(model);
    }

    /**
     * Returns the image factory used to load images for this classifier.
     *
     * @return the image factory resolved in the constructor
     */
    public ImageFactory<IMAGE_CLASS> getImageFactory() {
        return imageFactory;
    }

    /**
     * Loads an image from the given file and classifies it.
     *
     * @param file file to load the image from
     * @return classification result as produced by {@code classify(IMAGE_CLASS)}
     * @throws ClassificationException if the image factory fails with an {@link IOException}
     */
    @Override
    public Map<String, Float> classify(File file) throws ClassificationException {
        final IMAGE_CLASS image;
        try {
            image = imageFactory.getImage(file);
        } catch (IOException e) {
            throw new ClassificationException("Couldn't transform input file into an image", e);
        }
        return classify(image);
    }

    /**
     * Loads an image from the given input stream and classifies it.
     *
     * @param inputStream stream to load the image from
     * @return classification result as produced by {@code classify(IMAGE_CLASS)}
     * @throws ClassificationException if the image factory fails with an {@link IOException}
     */
    @Override
    public Map<String, Float> classify(InputStream inputStream) throws ClassificationException {
        final IMAGE_CLASS image;
        try {
            image = imageFactory.getImage(inputStream);
        } catch (IOException e) {
            throw new ClassificationException("Couldn't transform input stream into an image", e);
        }
        return classify(image);
    }

    // todo: provide get top 1, 3, 5 results; sort and get

    // Do we need this now, when the impl is loaded using the service provider?
    // Kevin and Zoran discussed: probably not needed now that we have the service provider impl,
    // and we don't want to allow the user to mess with it.
//    public void setImageFactory(ImageFactory<IMAGE_CLASS> imageFactory) {
//        this.imageFactory = imageFactory;
//    }

    /**
     * Returns the machine learning model used by this classifier.
     *
     * @return the current model
     */
    public MODEL_CLASS getModel() {
        return model;
    }

    /**
     * Sets the machine learning model used by this classifier.
     *
     * @param model the model to use, must not be {@code null}
     * @throws NullPointerException if {@code model} is {@code null}
     */
    public void setModel(MODEL_CLASS model) {
        this.model = Objects.requireNonNull(model, "Model cannot be null!");
    }

    /**
     * Returns the threshold value stored on this classifier.
     *
     * @return the current threshold
     */
    public float getThreshold() {
        return threshold;
    }

    /**
     * Sets the threshold value stored on this classifier.
     *
     * @param threshold the new threshold
     */
    public void setThreshold(float threshold) {
        this.threshold = threshold;
    }
}