001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.util.infotheory.impl;
018
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.NoSuchElementException;
import java.util.Set;
028
029/**
030 * An implementation of a List which wraps a set of lists.
031 * <p>
032 * Each access returns a {@link Row} drawn by taking an element from each list.
033 * <p>
034 * The rows only expose equals and hashcode, as the information theoretic calculations
035 * only care about equality.
036 * @param <T> The type stored in the lists.
037 */
038public final class RowList<T> implements List<Row<T>> {
039    private final Set<List<T>> set;
040    private final int size;
041
042    /**
043     * Constructs a RowList from a set of lists.
044     * @param set The feature lists.
045     */
046    public RowList(Set<List<T>> set) {
047        this.set = Collections.unmodifiableSet(new LinkedHashSet<>(set));
048        size = set.iterator().next().size();
049        for (Collection<T> element : this.set) {
050            if (size != element.size()) {
051                throw new IllegalArgumentException("Not all the collections in the set are the same length");
052            }
053        }
054    }
055    
056    @Override
057    public int size() {
058        return size;
059    }
060
061    @Override
062    public boolean isEmpty() {
063        return size == 0;
064    }
065
066    @Override
067    public boolean contains(Object o) {
068        if (o instanceof Row) {
069            Row<?> otherRow = (Row<?>) o;
070            boolean found = false;
071            for (Row<T> row : this) {
072                if (otherRow.equals(row)) {
073                    found = true;
074                    break;
075                }
076            }
077            return found;
078        } else {
079            return false;
080        }
081    }
082
083    @Override
084    public Iterator<Row<T>> iterator() {
085        return new RowListIterator<>(set);
086    }
087
088    @Override
089    public Object[] toArray() {
090        Object[] output = new Object[size];
091        int counter = 0;
092        for (Row<T> row : this) {
093            output[counter] = row;
094            counter++;
095        }
096        return output;
097    }
098
099    @Override
100    @SuppressWarnings("unchecked")
101    public <U> U[] toArray(U[] a) {
102        U[] output = a;
103        if (output.length < size) {
104            output = (U[]) Array.newInstance(a[0].getClass(), size);
105        }
106        int counter = 0;
107        for (Row<T> row : this) {
108            output[counter] = (U) row;
109            counter++;
110        }
111        if (output.length > size) {
112            //fill with nulls if bigger.
113            for (; counter < output.length; counter++) {
114                output[counter] = null;
115            }
116        }
117        return output;
118    }
119
120    @Override
121    public Row<T> get(int index) {
122        ArrayList<T> list = new ArrayList<>(set.size());
123        int counter = 0;
124        for (List<T> element : set) {
125            list.add(counter, element.get(index));
126            counter++;
127        }
128        return new Row<>(list);
129    }
130
131    @Override
132    public boolean containsAll(Collection<?> c) {
133        boolean found = true;
134        Iterator<?> itr = c.iterator();
135        while (itr.hasNext() && found) {
136            found = this.contains(itr.next());
137        }
138        return found;
139    }
140
141    @Override
142    public int indexOf(Object o) {
143        if (o instanceof Row) {
144            Row<?> otherRow = (Row<?>) o;
145            int counter = 0;
146            int found = -1;
147            Iterator<Row<T>> itr = this.iterator();
148            while (itr.hasNext() && found == -1) {
149                if (itr.next().equals(otherRow)) {
150                    found = counter;
151                }
152                counter++;
153            }
154            return found;
155        } else {
156            return -1;
157        }
158    }
159
160    @Override
161    public int lastIndexOf(Object o) {
162        if (o instanceof Row) {
163            Row<?> otherRow = (Row<?>) o;
164            int counter = 0;
165            int found = -1;
166            for (Row<T> tRow : this) {
167                if (tRow.equals(otherRow)) {
168                    found = counter;
169                }
170                counter++;
171            }
172            return found;
173        } else {
174            return -1;
175        }
176    }
177
178    @Override
179    public ListIterator<Row<T>> listIterator() {
180        return new RowListIterator<>(set);
181    }
182
183    @Override
184    public ListIterator<Row<T>> listIterator(int index) {
185        return new RowListIterator<>(set,index);
186    }
187
188    /**
189     * Unsupported. Throws UnsupportedOperationException.
190     * @param fromIndex n/a
191     * @param toIndex n/a
192     * @return n/a
193     */
194    @Override
195    public List<Row<T>> subList(int fromIndex, int toIndex) {
196        throw new UnsupportedOperationException("Views are not supported on a RowList.");
197    }
198
199    //*************************************************************************
200    // The remaining operations are unsupported as this list is immutable.
201    //*************************************************************************
202    /**
203     * Unsupported. Throws UnsupportedOperationException.
204     * @param e n/a
205     * @return n/a
206     */
207    @Override
208    public boolean add(Row<T> e) {
209        throw new UnsupportedOperationException("This list is immutable.");
210    }
211
212    /**
213     * Unsupported. Throws UnsupportedOperationException.
214     * @param o n/a
215     * @return n/a
216     */
217    @Override
218    public boolean remove(Object o) {
219        throw new UnsupportedOperationException("This list is immutable.");
220    }
221
222    /**
223     * Unsupported. Throws UnsupportedOperationException.
224     * @param c n/a
225     * @return n/a
226     */
227    @Override
228    public boolean addAll(Collection<? extends Row<T>> c) {
229        throw new UnsupportedOperationException("This list is immutable.");
230    }
231
232    /**
233     * Unsupported. Throws UnsupportedOperationException.
234     * @param index n/a
235     * @param c n/a
236     * @return n/a
237     */
238    @Override
239    public boolean addAll(int index, Collection<? extends Row<T>> c) {
240        throw new UnsupportedOperationException("This list is immutable.");
241    }
242
243    /**
244     * Unsupported. Throws UnsupportedOperationException.
245     * @param c n/a
246     * @return n/a
247     */
248    @Override
249    public boolean removeAll(Collection<?> c) {
250        throw new UnsupportedOperationException("This list is immutable.");
251    }
252
253    /**
254     * Unsupported. Throws UnsupportedOperationException.
255     * @param c n/a
256     * @return n/a
257     */
258    @Override
259    public boolean retainAll(Collection<?> c) {
260        throw new UnsupportedOperationException("This list is immutable.");
261    }
262
263    /**
264     * Unsupported. Throws UnsupportedOperationException.
265     */
266    @Override
267    public void clear() {
268        throw new UnsupportedOperationException("This list is immutable.");
269    }
270
271    /**
272     * Unsupported. Throws UnsupportedOperationException.
273     * @param index n/a
274     * @param element n/a
275     * @return n/a
276     */
277    @Override
278    public Row<T> set(int index, Row<T> element) {
279        throw new UnsupportedOperationException("This list is immutable.");
280    }
281
282    /**
283     * Unsupported. Throws UnsupportedOperationException.
284     * @param index n/a
285     * @param element n/a
286     */
287    @Override
288    public void add(int index, Row<T> element) {
289        throw new UnsupportedOperationException("This list is immutable.");
290    }
291
292    /**
293     * Unsupported. Throws UnsupportedOperationException.
294     * @param index n/a
295     * @return n/a
296     */
297    @Override
298    public Row<T> remove(int index) {
299        throw new UnsupportedOperationException("This list is immutable.");
300    }
301
302    /**
303     * The iterator over the rows.
304     * @param <T> The type of the row.
305     */
306    private static class RowListIterator<T> implements ListIterator<Row<T>> {
307        private int curIndex;
308        private final int size;
309        private final Set<List<T>> set;
310
311        public RowListIterator(Set<List<T>> set) {
312            this(set,0);
313        }
314
315        public RowListIterator(Set<List<T>> set, int curIndex) {
316            this.curIndex = curIndex;
317            this.set = set;
318            this.size = set.iterator().next().size();
319        }
320
321        @Override
322        public boolean hasNext() {
323            return curIndex < size;
324        }
325
326        @Override
327        public Row<T> next() {
328            ArrayList<T> list = new ArrayList<>(set.size());
329            int counter = 0;
330            for (List<T> element : set) {
331                list.add(counter, element.get(curIndex));
332                counter++;
333            }
334            curIndex++;
335            return new Row<>(list);
336        }
337
338        @Override
339        public boolean hasPrevious() {
340            return curIndex > 0;
341        }
342
343        @Override
344        public Row<T> previous() {
345            ArrayList<T> list = new ArrayList<>(set.size());
346            curIndex--;
347            int counter = 0;
348            for (List<T> element : set) {
349                list.add(counter, element.get(curIndex));
350                counter++;
351            }
352            return new Row<>(list);
353        }
354
355        @Override
356        public int nextIndex() {
357            return curIndex;
358        }
359
360        @Override
361        public int previousIndex() {
362            return curIndex - 1;
363        }
364
365        @Override
366        public void remove() {
367            throw new UnsupportedOperationException("The list backing this iterator is immutable.");
368        }
369
370        @Override
371        public void set(Row<T> e) {
372            throw new UnsupportedOperationException("The list backing this iterator is immutable.");
373        }
374
375        @Override
376        public void add(Row<T> e) {
377            throw new UnsupportedOperationException("The list backing this iterator is immutable.");
378        }
379    }
380}
381