/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.ttl.threadpool.agent.internal.transformlet.impl;

import com.alibaba.ttl.internal.javassist.CannotCompileException;
import com.alibaba.ttl.internal.javassist.CtClass;
import com.alibaba.ttl.internal.javassist.CtConstructor;
import com.alibaba.ttl.internal.javassist.CtField;
import com.alibaba.ttl.internal.javassist.CtMethod;
import com.alibaba.ttl.internal.javassist.NotFoundException;
import com.alibaba.ttl.spi.TtlEnhanced;
import com.alibaba.ttl.threadpool.agent.internal.logging.Logger;
import com.alibaba.ttl.threadpool.agent.internal.transformlet.ClassInfo;
import com.alibaba.ttl.threadpool.agent.internal.transformlet.JavassistTransformlet;
import com.alibaba.ttl.threadpool.agent.internal.transformlet.impl.Utils;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.io.IOException;

/**
 * Javassist transformlet that instruments the JDK fork/join framework for TransmittableThreadLocal.
 *
 * <ul>
 *   <li>{@code java.util.concurrent.ForkJoinTask} is always instrumented. A field holding the captured
 *       values is added, and {@code doExec()} is wrapped with
 *       {@code Transmitter.replay(...)} / {@code Transmitter.restore(...)}.</li>
 *   <li>{@code java.util.concurrent.ForkJoinPool} is instrumented only when
 *       {@code disableInheritableForThreadPool} is {@code true}. Every constructor argument of type
 *       {@code ForkJoinPool.ForkJoinWorkerThreadFactory} is replaced with the result of
 *       {@code TtlForkJoinPoolHelper.getDisableInheritableForkJoinWorkerThreadFactory(...)}.</li>
 * </ul>
 */
public class TtlForkJoinTransformlet implements JavassistTransformlet {
    private static final Logger logger = Logger.getLogger(TtlForkJoinTransformlet.class);

    private static final String FORK_JOIN_TASK_CLASS_NAME = "java.util.concurrent.ForkJoinTask";
    private static final String FORK_JOIN_POOL_CLASS_NAME = "java.util.concurrent.ForkJoinPool";
    private static final String FORK_JOIN_WORKER_THREAD_FACTORY_CLASS_NAME = "java.util.concurrent.ForkJoinPool$ForkJoinWorkerThreadFactory";

    /** Name of the field added to {@code ForkJoinTask} that holds the captured values. */
    private static final String CAPTURED_FIELD_NAME = "captured$field$added$by$ttl";

    /** Fully qualified transmitter name, as referenced from generated Javassist source. */
    private static final String TRANSMITTER_CLASS_NAME = "com.alibaba.ttl.TransmittableThreadLocal.Transmitter";

    private final boolean disableInheritableForThreadPool;

    /**
     * Creates the transformlet.
     *
     * @param disableInheritableForThreadPool when {@code true}, {@code ForkJoinPool} constructors are
     *        instrumented as well (see class documentation); otherwise only {@code ForkJoinTask} is
     */
    public TtlForkJoinTransformlet(boolean disableInheritableForThreadPool) {
        this.disableInheritableForThreadPool = disableInheritableForThreadPool;
    }

    /**
     * Instruments {@code ForkJoinTask}, or {@code ForkJoinPool} when enabled, and marks the class modified.
     * Any other class is left untouched and not marked modified.
     */
    @Override
    public void doTransform(@NonNull ClassInfo classInfo) throws IOException, NotFoundException, CannotCompileException {
        final String className = classInfo.getClassName();
        if (FORK_JOIN_TASK_CLASS_NAME.equals(className)) {
            updateForkJoinTaskClass(classInfo.getCtClass());
            classInfo.setModified();
        } else if (disableInheritableForThreadPool && FORK_JOIN_POOL_CLASS_NAME.equals(className)) {
            updateConstructorDisableInheritable(classInfo.getCtClass());
            classInfo.setModified();
        }
    }

    /**
     * Adds the capture field to {@code ForkJoinTask} and wraps its {@code doExec()} method.
     *
     * <p>The field is initialized with {@code Utils.doCaptureWhenNotTtlEnhanced(this)}. The wrapped
     * {@code doExec()} replays the captured value before running and restores the backup afterwards.
     * Instances implementing {@code TtlEnhanced} skip replay/restore and call the renamed method directly.
     *
     * @param clazz the {@code java.util.concurrent.ForkJoinTask} class being transformed
     * @throws NotFoundException if {@code doExec()} is not declared on {@code clazz}
     * @throws CannotCompileException if Javassist rejects the generated field or method source
     */
    private void updateForkJoinTaskClass(@NonNull CtClass clazz) throws CannotCompileException, NotFoundException {
        final String className = clazz.getName();

        final CtField capturedField = CtField.make("private final Object " + CAPTURED_FIELD_NAME + ";", clazz);
        clazz.addField(capturedField, "com.alibaba.ttl.threadpool.agent.internal.transformlet.impl.Utils.doCaptureWhenNotTtlEnhanced(this);");
        logger.info("add new field " + CAPTURED_FIELD_NAME + " to class " + className);

        final CtMethod doExecMethod = clazz.getDeclaredMethod("doExec", new CtClass[0]);
        final String doExecRenamedMethodName = Utils.renamedMethodNameByTtl(doExecMethod);

        // Tasks already implementing TtlEnhanced bypass replay/restore and delegate straight to the
        // renamed method (presumably because they transmit values themselves).
        // `backup` is declared here and read by finallyCode, so Utils.doTryFinallyForMethod is
        // expected to emit both snippets into the same scope -- see that helper.
        final String beforeCode = "if (this instanceof " + TtlEnhanced.class.getName() + ") {\n"
                + "    return " + doExecRenamedMethodName + "($$);\n"
                + "}\n"
                + "Object backup = " + TRANSMITTER_CLASS_NAME + ".replay(" + CAPTURED_FIELD_NAME + ");";
        final String finallyCode = TRANSMITTER_CLASS_NAME + ".restore(backup);";

        Utils.doTryFinallyForMethod(doExecMethod, doExecRenamedMethodName, beforeCode, finallyCode);
    }

    /**
     * Rewrites every {@code ForkJoinWorkerThreadFactory} argument of each declared constructor.
     *
     * <p>Each such argument is passed through
     * {@code TtlForkJoinPoolHelper.getDisableInheritableForkJoinWorkerThreadFactory}. The reassignment
     * is inserted at the start of the constructor. Constructors without such a parameter are not modified.
     *
     * @param clazz the {@code java.util.concurrent.ForkJoinPool} class being transformed
     * @throws NotFoundException if a parameter type cannot be resolved
     * @throws CannotCompileException if Javassist rejects the inserted source
     */
    private void updateConstructorDisableInheritable(@NonNull CtClass clazz) throws NotFoundException, CannotCompileException {
        for (CtConstructor constructor : clazz.getDeclaredConstructors()) {
            final CtClass[] parameterTypes = constructor.getParameterTypes();
            final StringBuilder insertCode = new StringBuilder();
            for (int i = 0; i < parameterTypes.length; ++i) {
                if (!FORK_JOIN_WORKER_THREAD_FACTORY_CLASS_NAME.equals(parameterTypes[i].getName())) {
                    continue;
                }
                // Javassist numbers method/constructor parameters from $1, hence i + 1.
                final String code = disableInheritableFactoryCode(i + 1);
                logger.info("insert code before method " + Utils.signatureOfMethod(constructor)
                        + " of class " + constructor.getDeclaringClass().getName() + ": " + code);
                insertCode.append(code);
            }
            if (insertCode.length() > 0) {
                constructor.insertBefore(insertCode.toString());
            }
        }
    }

    /**
     * Builds the Javassist statement that replaces parameter {@code $paramPosition} with its
     * disable-inheritable wrapper.
     *
     * @param paramPosition 1-based Javassist parameter position
     * @return a single Javassist source statement
     */
    private static String disableInheritableFactoryCode(int paramPosition) {
        // Deliberately not String.format("$%d", ...): %d uses the default FORMAT locale and can emit
        // non-ASCII digits (e.g. Arabic-Indic), yielding source Javassist cannot compile.
        // int concatenation always produces ASCII digits.
        final String param = "$" + paramPosition;
        return param + " = com.alibaba.ttl.threadpool.TtlForkJoinPoolHelper.getDisableInheritableForkJoinWorkerThreadFactory(" + param + ");";
    }
}

