001/*
002 * Copyright (c) 2021, Oracle and/or its affiliates.
003 *
004 * Licensed under the 2-clause BSD license.
005 *
006 * Redistribution and use in source and binary forms, with or without
007 * modification, are permitted provided that the following conditions are met:
008 *
009 * 1. Redistributions of source code must retain the above copyright notice,
010 *    this list of conditions and the following disclaimer.
011 *
012 * 2. Redistributions in binary form must reproduce the above copyright notice,
013 *    this list of conditions and the following disclaimer in the documentation
014 *    and/or other materials provided with the distribution.
015 *
016 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
017 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
018 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
019 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
020 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
021 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
022 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
023 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
024 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
025 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
026 * POSSIBILITY OF SUCH DAMAGE.
027 */
028
029package com.oracle.labs.mlrg.olcut.config.json;
030
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.oracle.labs.mlrg.olcut.provenance.io.MarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerializationException;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
049
050/**
051 * Class for serializing and deserializing provenances to/from json.
052 */
053public final class JsonProvenanceSerialization implements ProvenanceSerialization {
054
055    private static final TypeReference<List<MarshalledProvenance>> typeRef = new TypeReference<List<MarshalledProvenance>>() {};
056
057    private final ObjectMapper mapper;
058
059    /**
060     * Construct a JsonProvenanceSerialization.
061     *
062     * @param indentOutput Indent the output.
063     */
064    public JsonProvenanceSerialization(boolean indentOutput) {
065        mapper = new ObjectMapper();
066        mapper.registerModule(new JsonProvenanceModule());
067        if (indentOutput) {
068            mapper.enable(SerializationFeature.INDENT_OUTPUT);
069        }
070    }
071
072    @Override
073    public String getFileExtension() {
074        return "json";
075    }
076
077    @Override
078    public List<ObjectMarshalledProvenance> deserializeFromFile(Path path) throws ProvenanceSerializationException, IOException {
079        try {
080            List<MarshalledProvenance> jsonProvenances = mapper.readValue(path.toFile(), typeRef);
081            return convertMarshalledProvenanceList(jsonProvenances);
082        } catch (JsonParseException | JsonMappingException e) {
083            throw new ProvenanceSerializationException("Failed to parse JSON",e);
084        }
085    }
086
087    @Override
088    public List<ObjectMarshalledProvenance> deserializeFromString(String input) throws ProvenanceSerializationException {
089        try {
090            List<MarshalledProvenance> jsonProvenances = mapper.readValue(input, typeRef);
091            return convertMarshalledProvenanceList(jsonProvenances);
092        } catch (JsonProcessingException e) {
093            throw new ProvenanceSerializationException("Failed to deserialize provenance", e);
094        }
095    }
096
097    /**
098     * Converts the list of {@link MarshalledProvenance}s to a list of {@link ObjectMarshalledProvenance}s.
099     * <p>
100     * This is because Jackson's deserialization doesn't give a sharp enough type, so we have to check.
101     * It will throw {@link IllegalArgumentException} in the event the provenance stream was malformed and
102     * so contained a top level {@link com.oracle.labs.mlrg.olcut.provenance.io.FlatMarshalledProvenance}.
103     * @param provenances The provenances to cast.
104     * @return A list of {@link ObjectMarshalledProvenance}s.
105     */
106    private static List<ObjectMarshalledProvenance> convertMarshalledProvenanceList(List<MarshalledProvenance> provenances) {
107        List<ObjectMarshalledProvenance> jps = new ArrayList<>();
108        for (MarshalledProvenance mp : provenances) {
109            if (mp instanceof ObjectMarshalledProvenance) {
110                jps.add((ObjectMarshalledProvenance) mp);
111            } else {
112                throw new IllegalArgumentException("Invalid provenance found, expected ObjectMarshalledProvenance, found " + mp);
113            }
114        }
115        return jps;
116    }
117
118    @Override
119    public String serializeToString(List<ObjectMarshalledProvenance> marshalledProvenances) {
120        try {
121            return mapper.writeValueAsString(marshalledProvenances);
122        } catch (JsonProcessingException e) {
123            throw new IllegalArgumentException("Failed to serialize provenance", e);
124        }
125    }
126
127    @Override
128    public void serializeToFile(List<ObjectMarshalledProvenance> marshalledProvenances, Path path) throws IOException {
129        try (PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(path.toFile())))) {
130            writer.println(serializeToString(marshalledProvenances));
131        }
132    }
133}