001package io.ebean.enhance.entity;
002
003import io.ebean.enhance.asm.ClassVisitor;
004import io.ebean.enhance.asm.FieldVisitor;
005import io.ebean.enhance.asm.Label;
006import io.ebean.enhance.asm.MethodVisitor;
007import io.ebean.enhance.asm.Opcodes;
008import io.ebean.enhance.common.ClassMeta;
009import io.ebean.enhance.common.EnhanceConstants;
010import io.ebean.enhance.common.VisitUtil;
011
012/**
013 * Generate the equals hashCode method using the identity.
014 * <p>
 * This adds _ebean_getIdentity(), equals() and hashCode() methods to entity
 * classes that have a single ID property and no existing equals() or hashCode() method.
017 * </p>
018 */
019public class MethodEquals implements Opcodes, EnhanceConstants {
020
021  private static final String _EBEAN_GET_IDENTITY = "_ebean_getIdentity";
022
023  /**
024  * Adds equals(), hashCode() and _ebean_getIdentity() methods.
025  * <p>
026  * If the class already has a equals() or hashCode() method defined then
027  * these methods are not added (its a noop).
028  * </p>
029  *
030  * @param idFieldIndex
031  *            the index of the id field
032  */
033  public static void addMethods(ClassVisitor cv, ClassMeta meta, int idFieldIndex, FieldMeta idFieldMeta) {
034
035    if (meta.hasEqualsOrHashCode()) {
036      // already has a equals or hashcode method...
037      // so we will not add our identity based one
038      if (meta.isLog(2)) {
039        meta.log("already has a equals() or hashCode() method. Not adding the identity based one.");
040      }
041    } else {
042      if (meta.isLog(2)) {
043        meta.log("adding equals() hashCode() and _ebean_getIdentity() with Id field "
044          + idFieldMeta.getName()+ " index:" + idFieldIndex+" primitive:"+idFieldMeta.isPrimitiveType());
045      }
046      if (idFieldMeta.isPrimitiveType()){
047        addGetIdentityPrimitive(cv, meta, idFieldMeta);
048      } else {
049        addGetIdentityObject(cv, meta, idFieldIndex);
050      }
051      addEquals(cv, meta);
052      addHashCode(cv, meta);
053    }
054
055  }
056
057  /**
058  * The identity field used for implementing equals via the
059  * _ebean_getIdentity() method.
060  */
061  public static void addIdentityField(ClassVisitor cv) {
062
063      int access = ACC_PROTECTED + ACC_TRANSIENT;
064    FieldVisitor f0 = cv.visitField(access, IDENTITY_FIELD, "Ljava/lang/Object;", null, null);
065    f0.visitEnd();
066  }
067
068  /**
069  * Generate the _ebean_getIdentity method for primitive id types.
070  * <p>
071  * For primitives we need to check for == 0 rather than null.
072  * <p>
073  * <p>
074  * This is used for implementing equals().
075  * <p>
076  *
077  * <pre>
078  * private Object _ebean_getIdentity() {
079  *     synchronized (this) {
080  *             if (_ebean_identity != null) {
081  *                     return _ebean_identity;
082  *             }
083  *
084  *             if (0 != getId()) {
085  *                     _ebean_identity = Integer.valueOf(getId());
086  *             } else {
087  *                     _ebean_identity = new Object();
088  *             }
089  *
090  *             return _ebean_identity;
091  *     }
092  * }
093  * </pre>
094  */
095  private static void addGetIdentityPrimitive(ClassVisitor cv, ClassMeta classMeta, FieldMeta idFieldMeta) {
096
097    String className = classMeta.getClassName();
098
099    MethodVisitor mv;
100
101    mv = cv.visitMethod(ACC_PRIVATE, _EBEAN_GET_IDENTITY, "()Ljava/lang/Object;", null, null);
102    mv.visitCode();
103    Label l0 = new Label();
104    Label l1 = new Label();
105    Label l2 = new Label();
106    mv.visitTryCatchBlock(l0, l1, l2, null);
107    Label l3 = new Label();
108    Label l4 = new Label();
109    mv.visitTryCatchBlock(l3, l4, l2, null);
110    Label l5 = new Label();
111    mv.visitTryCatchBlock(l2, l5, l2, null);
112    Label l6 = new Label();
113    mv.visitLabel(l6);
114    mv.visitLineNumber(1, l6);
115    mv.visitVarInsn(ALOAD, 0);
116    mv.visitInsn(DUP);
117    mv.visitVarInsn(ASTORE, 1);
118    mv.visitInsn(MONITORENTER);
119    mv.visitLabel(l0);
120    mv.visitLineNumber(1, l0);
121    mv.visitVarInsn(ALOAD, 0);
122    mv.visitFieldInsn(GETFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
123    mv.visitJumpInsn(IFNULL, l3);
124    Label l7 = new Label();
125    mv.visitLabel(l7);
126    mv.visitLineNumber(1, l7);
127    mv.visitVarInsn(ALOAD, 0);
128    mv.visitFieldInsn(GETFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
129    mv.visitVarInsn(ALOAD, 1);
130    mv.visitInsn(MONITOREXIT);
131    mv.visitLabel(l1);
132    mv.visitInsn(ARETURN);
133    mv.visitLabel(l3);
134    mv.visitLineNumber(1, l3);
135    mv.visitFrame(Opcodes.F_APPEND, 1, new Object[]{"java/lang/Object"}, 0, null);
136    mv.visitVarInsn(ALOAD, 0);
137    idFieldMeta.appendGetPrimitiveIdValue(mv, classMeta);
138    idFieldMeta.appendCompare(mv, classMeta);
139
140    Label l8 = new Label();
141    mv.visitJumpInsn(IFEQ, l8);
142    Label l9 = new Label();
143    mv.visitLabel(l9);
144    mv.visitLineNumber(1, l9);
145    mv.visitVarInsn(ALOAD, 0);
146    mv.visitVarInsn(ALOAD, 0);
147    idFieldMeta.appendGetPrimitiveIdValue(mv, classMeta);
148    idFieldMeta.appendValueOf(mv);
149    mv.visitFieldInsn(PUTFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
150    Label l10 = new Label();
151    mv.visitJumpInsn(GOTO, l10);
152    mv.visitLabel(l8);
153    mv.visitLineNumber(1, l8);
154    mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
155    mv.visitVarInsn(ALOAD, 0);
156    mv.visitTypeInsn(NEW, "java/lang/Object");
157    mv.visitInsn(DUP);
158    mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", INIT, NOARG_VOID, false);
159    mv.visitFieldInsn(PUTFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
160    mv.visitLabel(l10);
161    mv.visitLineNumber(1, l10);
162    mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
163    mv.visitVarInsn(ALOAD, 0);
164    mv.visitFieldInsn(GETFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
165    mv.visitVarInsn(ALOAD, 1);
166    mv.visitInsn(MONITOREXIT);
167    mv.visitLabel(l4);
168    mv.visitInsn(ARETURN);
169    mv.visitLabel(l2);
170    mv.visitLineNumber(1, l2);
171    mv.visitFrame(Opcodes.F_SAME1, 0, null, 1, new Object[]{"java/lang/Throwable"});
172    mv.visitVarInsn(ASTORE, 2);
173    mv.visitVarInsn(ALOAD, 1);
174    mv.visitInsn(MONITOREXIT);
175    mv.visitLabel(l5);
176    mv.visitVarInsn(ALOAD, 2);
177    mv.visitInsn(ATHROW);
178    Label l11 = new Label();
179    mv.visitLabel(l11);
180    mv.visitLocalVariable("this", "L"+className+";", null, l6, l11, 0);
181    mv.visitMaxs(4, 3);
182    mv.visitEnd();
183  }
184
185  /**
186  * Generate the _ebean_getIdentity method for used with equals().
187  *
188  * <pre>
189  * private Object _ebean_getIdentity() {
190  *     synchronized (this) {
191  *             if (_ebean_identity != null) {
192  *                     return _ebean_identity;
193  *             }
194  *
195  *             Object id = getId();
196  *             if (id != null) {
197  *                     _ebean_identity = id;
198  *             } else {
199  *                     _ebean_identity = new Object();
200  *             }
201  *
202  *             return _ebean_identity;
203  *     }
204  * }
205  * </pre>
206  */
207  private static void addGetIdentityObject(ClassVisitor cv, ClassMeta classMeta, int idFieldIndex) {
208
209    String className = classMeta.getClassName();
210
211    MethodVisitor mv;
212
213    mv = cv.visitMethod(ACC_PRIVATE, _EBEAN_GET_IDENTITY, "()Ljava/lang/Object;", null, null);
214    mv.visitCode();
215    Label l0 = new Label();
216    Label l1 = new Label();
217    Label l2 = new Label();
218    mv.visitTryCatchBlock(l0, l1, l2, null);
219    Label l3 = new Label();
220    Label l4 = new Label();
221    mv.visitTryCatchBlock(l3, l4, l2, null);
222    Label l5 = new Label();
223    mv.visitTryCatchBlock(l2, l5, l2, null);
224    Label l6 = new Label();
225    mv.visitLabel(l6);
226    mv.visitLineNumber(1, l6);
227    mv.visitVarInsn(ALOAD, 0);
228    mv.visitInsn(DUP);
229    mv.visitVarInsn(ASTORE, 1);
230    mv.visitInsn(MONITORENTER);
231    mv.visitLabel(l0);
232    mv.visitLineNumber(1, l0);
233    mv.visitVarInsn(ALOAD, 0);
234    mv.visitFieldInsn(GETFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
235    mv.visitJumpInsn(IFNULL, l3);
236    Label l7 = new Label();
237    mv.visitLabel(l7);
238    mv.visitLineNumber(1, l7);
239    mv.visitVarInsn(ALOAD, 0);
240    mv.visitFieldInsn(GETFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
241    mv.visitVarInsn(ALOAD, 1);
242    mv.visitInsn(MONITOREXIT);
243    mv.visitLabel(l1);
244    mv.visitInsn(ARETURN);
245    mv.visitLabel(l3);
246    mv.visitLineNumber(1, l3);
247    mv.visitFrame(Opcodes.F_APPEND, 1, new Object[]{"java/lang/Object"}, 0, null);
248    mv.visitVarInsn(ALOAD, 0);
249    VisitUtil.visitIntInsn(mv, idFieldIndex);
250    mv.visitMethodInsn(INVOKESPECIAL, className, "_ebean_getField", "(I)Ljava/lang/Object;", false);
251    mv.visitVarInsn(ASTORE, 2);
252    Label l8 = new Label();
253    mv.visitLabel(l8);
254    mv.visitLineNumber(1, l8);
255    mv.visitVarInsn(ALOAD, 2);
256    Label l9 = new Label();
257    mv.visitJumpInsn(IFNULL, l9);
258    Label l10 = new Label();
259    mv.visitLabel(l10);
260    mv.visitLineNumber(1, l10);
261    mv.visitVarInsn(ALOAD, 0);
262    mv.visitVarInsn(ALOAD, 2);
263    mv.visitFieldInsn(PUTFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
264    Label l11 = new Label();
265    mv.visitJumpInsn(GOTO, l11);
266    mv.visitLabel(l9);
267    mv.visitLineNumber(1, l9);
268    mv.visitFrame(Opcodes.F_APPEND, 1, new Object[]{"java/lang/Object"}, 0, null);
269    mv.visitVarInsn(ALOAD, 0);
270    mv.visitTypeInsn(NEW, "java/lang/Object");
271    mv.visitInsn(DUP);
272    mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", INIT, NOARG_VOID, false);
273    mv.visitFieldInsn(PUTFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
274    mv.visitLabel(l11);
275    mv.visitLineNumber(1, l11);
276    mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
277    mv.visitVarInsn(ALOAD, 0);
278    mv.visitFieldInsn(GETFIELD, className, IDENTITY_FIELD, "Ljava/lang/Object;");
279    mv.visitVarInsn(ALOAD, 1);
280    mv.visitInsn(MONITOREXIT);
281    mv.visitLabel(l4);
282    mv.visitInsn(ARETURN);
283    mv.visitLabel(l2);
284    mv.visitLineNumber(1, l2);
285    mv.visitFrame(Opcodes.F_FULL, 2, new Object[]{className, "java/lang/Object"}, 1, new Object[]{"java/lang/Throwable"});
286    mv.visitVarInsn(ASTORE, 3);
287    mv.visitVarInsn(ALOAD, 1);
288    mv.visitInsn(MONITOREXIT);
289    mv.visitLabel(l5);
290    mv.visitVarInsn(ALOAD, 3);
291    mv.visitInsn(ATHROW);
292    Label l12 = new Label();
293    mv.visitLabel(l12);
294    mv.visitLocalVariable("this", "L"+className+";", null, l6, l12, 0);
295    mv.visitLocalVariable("tmpId", "Ljava/lang/Object;", null, l8, l2, 2);
296    mv.visitMaxs(3, 4);
297    mv.visitEnd();
298  }
299
300    /**
301    * Generate the equals method.
302    *
303    * <pre>
304    * public boolean equals(Object o) {
305    *     if (o == null) {
306    *         return false;
307    *     }
308    *     if (!this.getClass().equals(o.getClass())) {
309    *         return false;
310    *     }
311    *     if (o == this) {
312    *         return true;
313    *     }
314    *     return _ebean_getIdentity().equals(((FooEntity)o)._ebean_getIdentity());
315    * }
316    * </pre>
317    */
318  private static void addEquals(ClassVisitor cv, ClassMeta classMeta) {
319
320    MethodVisitor mv;
321
322    mv = cv.visitMethod(ACC_PUBLIC, "equals", "(Ljava/lang/Object;)Z", null, null);
323    mv.visitCode();
324    Label l0 = new Label();
325    mv.visitLabel(l0);
326    mv.visitLineNumber(1, l0);
327    mv.visitVarInsn(ALOAD, 1);
328    Label l1 = new Label();
329    mv.visitJumpInsn(IFNONNULL, l1);
330    Label l2 = new Label();
331    mv.visitLabel(l2);
332    mv.visitLineNumber(2, l2);
333    mv.visitInsn(ICONST_0);
334    mv.visitInsn(IRETURN);
335    mv.visitLabel(l1);
336    mv.visitLineNumber(3, l1);
337    mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
338    mv.visitVarInsn(ALOAD, 0);
339    mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "getClass", "()Ljava/lang/Class;", false);
340    mv.visitVarInsn(ALOAD, 1);
341    mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "getClass", "()Ljava/lang/Class;", false);
342    mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "equals", "(Ljava/lang/Object;)Z", false);
343    Label l3 = new Label();
344    mv.visitJumpInsn(IFNE, l3);
345    Label l4 = new Label();
346    mv.visitLabel(l4);
347    mv.visitLineNumber(4, l4);
348    mv.visitInsn(ICONST_0);
349    mv.visitInsn(IRETURN);
350    mv.visitLabel(l3);
351    mv.visitLineNumber(5, l3);
352    mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
353    mv.visitVarInsn(ALOAD, 1);
354    mv.visitVarInsn(ALOAD, 0);
355    Label l5 = new Label();
356    mv.visitJumpInsn(IF_ACMPNE, l5);
357    Label l6 = new Label();
358    mv.visitLabel(l6);
359    mv.visitLineNumber(6, l6);
360    mv.visitInsn(ICONST_1);
361    mv.visitInsn(IRETURN);
362    mv.visitLabel(l5);
363    mv.visitLineNumber(7, l5);
364    mv.visitFrame(Opcodes.F_SAME, 0, null, 0, null);
365    mv.visitVarInsn(ALOAD, 0);
366    mv.visitMethodInsn(INVOKEVIRTUAL, classMeta.getClassName(), _EBEAN_GET_IDENTITY, "()Ljava/lang/Object;", false);
367    mv.visitVarInsn(ALOAD, 1);
368    mv.visitTypeInsn(CHECKCAST, classMeta.getClassName());
369    mv.visitMethodInsn(INVOKEVIRTUAL, classMeta.getClassName(), _EBEAN_GET_IDENTITY, "()Ljava/lang/Object;", false);
370    mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "equals", "(Ljava/lang/Object;)Z", false);
371    mv.visitInsn(IRETURN);
372    Label l7 = new Label();
373    mv.visitLabel(l7);
374    mv.visitLocalVariable("this", "L"+classMeta.getClassName()+";", null, l0, l7, 0);
375    mv.visitLocalVariable("obj", "Ljava/lang/Object;", null, l0, l7, 1);
376    mv.visitMaxs(2, 2);
377    mv.visitEnd();
378  }
379
380  /**
381  * Generate a hashCode method used to go with MethodEquals.
382  *
383  * <pre><code>
384  * public int hashCode() {
385  *     return ebeanGetIdentity().hashCode();
386  * }
387  * </code></pre>
388  */
389  private static void addHashCode(ClassVisitor cv, ClassMeta meta) {
390
391    MethodVisitor mv;
392
393    mv = cv.visitMethod(ACC_PUBLIC, "hashCode", "()I", null, null);
394    mv.visitCode();
395    Label l0 = new Label();
396    mv.visitLabel(l0);
397    mv.visitLineNumber(1, l0);
398    mv.visitVarInsn(ALOAD, 0);
399    mv.visitMethodInsn(INVOKESPECIAL, meta.getClassName(), _EBEAN_GET_IDENTITY, "()Ljava/lang/Object;", false);
400    mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Object", "hashCode", "()I", false);
401    mv.visitInsn(IRETURN);
402    Label l1 = new Label();
403    mv.visitLabel(l1);
404    mv.visitLocalVariable("this", "L" + meta.getClassName() + ";", null, l0, l1, 0);
405    mv.visitMaxs(1, 1);
406    mv.visitEnd();
407
408  }
409
410}