/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.test;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hdfs.MiniDFSCluster;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.client.api.YarnClient;
import org.apache.tez.client.TezClient;
import org.apache.tez.client.TezClientUtils;
import org.apache.tez.common.Preconditions;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.DAGCounter;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.DAG;
import org.apache.tez.dag.api.DataSinkDescriptor;
import org.apache.tez.dag.api.DataSourceDescriptor;
import org.apache.tez.dag.api.Edge;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.ProcessorDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.Vertex;
import org.apache.tez.dag.api.client.DAGClient;
import org.apache.tez.dag.api.client.DAGStatus;
import org.apache.tez.dag.api.client.StatusGetOpts;
import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
import org.apache.tez.dag.app.RecoveryParser;
import org.apache.tez.dag.history.HistoryEvent;
import org.apache.tez.dag.history.HistoryEventType;
import org.apache.tez.dag.history.events.TaskAttemptFinishedEvent;
import org.apache.tez.mapreduce.input.MRInput;
import org.apache.tez.mapreduce.output.MROutput;
import org.apache.tez.mapreduce.processor.SimpleMRProcessor;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.api.LogicalOutput;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.library.api.KeyValueReader;
import org.apache.tez.runtime.library.api.KeyValueWriter;
import org.apache.tez.runtime.library.api.KeyValuesReader;
import org.apache.tez.runtime.library.conf.OrderedPartitionedKVEdgeConfig;
import org.apache.tez.runtime.library.conf.UnorderedKVEdgeConfig;
import org.apache.tez.runtime.library.partitioner.HashPartitioner;
import org.apache.tez.test.MiniTezCluster;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Integration test for AM recovery on a three-vertex DAG:
 * {@code TableScan --(ordered partitioned edge)--> Aggregation --(broadcast edge)--> MapJoin}.
 *
 * <p>Each test optionally makes one vertex sleep during DAG attempt 1. In those tests the first
 * application attempt is failed via YARN while that vertex is still running. The test then checks
 * three things: the DAG still succeeds, the output is correct, and the per-attempt recovery logs
 * record which task attempts finished in which AM attempt.
 */
public class TestAMRecoveryAggregationBroadcast {
    private static final Logger LOG = LoggerFactory.getLogger(TestAMRecoveryAggregationBroadcast.class);
    // Data-source names of the TableScan and MapJoin vertices (both read the same input file).
    private static final String INPUT1 = "Input";
    private static final String INPUT2 = "Input";
    private static final String OUTPUT = "Output";
    private static final String TABLE_SCAN = "TableScan";
    private static final String AGGREGATION = "Aggregation";
    private static final String MAP_JOIN = "MapJoin";
    private static final String TEST_ROOT_DIR = "target/" + TestAMRecoveryAggregationBroadcast.class.getName() + "-tmpDir";
    private static final Path INPUT_FILE = new Path(TEST_ROOT_DIR, "input.csv");
    private static final Path OUT_PATH = new Path(TEST_ROOT_DIR, "out-groups");
    // Output file written under OUT_PATH by the MapJoin vertex's MROutput.
    private static final String OUTPUT_FILE_NAME = "part-v002-o000-r-00000";
    // Each input value "v" occurs 6 - v times, so every line is "<v>-<count of v>".
    private static final String EXPECTED_OUTPUT = "1-5\n1-5\n1-5\n1-5\n1-5\n2-4\n2-4\n2-4\n2-4\n3-3\n3-3\n3-3\n4-2\n4-2\n5-1\n";
    // Configuration flags (carried to processors via the user payload) that enable the attempt-1 sleep.
    private static final String TABLE_SCAN_SLEEP = "tez.test.table.scan.sleep";
    private static final String AGGREGATION_SLEEP = "tez.test.aggregation.sleep";
    private static final String MAP_JOIN_SLEEP = "tez.test.map.join.sleep";
    // How long a flagged processor sleeps in DAG attempt 1; must exceed KILL_AM_DELAY_SECONDS.
    private static final long FIRST_ATTEMPT_SLEEP_SECONDS = 60L;
    // How long after DAG submission the first application attempt is failed.
    private static final long KILL_AM_DELAY_SECONDS = 10L;
    private static Configuration dfsConf;
    private static MiniDFSCluster dfsCluster;
    private static MiniTezCluster tezCluster;
    private static FileSystem remoteFs;
    private TezConfiguration tezConf;
    private TezClient tezSession;

    /**
     * Starts a 3-datanode MiniDFSCluster, writes the sample input file and starts a MiniTezCluster.
     *
     * @throws RuntimeException if the DFS cluster cannot be started or the input cannot be written
     */
    @BeforeClass
    public static void setupAll() {
        try {
            dfsConf = new Configuration();
            dfsConf.set("hdfs.minidfs.basedir", TEST_ROOT_DIR);
            dfsCluster = new MiniDFSCluster.Builder(dfsConf).numDataNodes(3).format(true).racks(null).build();
            remoteFs = dfsCluster.getFileSystem();
            createSampleFile();
        } catch (IOException io) {
            throw new RuntimeException("problem starting mini dfs cluster", io);
        }
        if (tezCluster == null) {
            tezCluster = new MiniTezCluster(TestAMRecoveryAggregationBroadcast.class.getName(), 1, 1, 1);
            Configuration conf = new Configuration(dfsConf);
            conf.set("fs.defaultFS", remoteFs.getUri().toString());
            conf.setInt("yarn.nodemanager.delete.debug-delay-sec", 20000);
            conf.setLong("tez.am.sleep.time.before.exit.millis", 500L);
            tezCluster.init(conf);
            tezCluster.start();
        }
    }

    /**
     * Writes INPUT_FILE with one integer per line: value {@code v} (1..5) is written {@code 6 - v} times.
     * The writer is closed even if a write fails, and the text is encoded as UTF-8 to match how the
     * output is decoded in {@link #readOutput()}.
     */
    private static void createSampleFile() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(remoteFs.create(INPUT_FILE), StandardCharsets.UTF_8))) {
            for (int value = 1; value <= 5; ++value) {
                int repetitions = 6 - value;
                for (int n = 0; n < repetitions; ++n) {
                    writer.write(String.valueOf(value));
                    writer.newLine();
                }
            }
        }
    }

    /** Stops the Tez and DFS mini clusters if they were started. */
    @AfterClass
    public static void tearDownAll() {
        if (tezCluster != null) {
            tezCluster.stop();
            tezCluster = null;
        }
        if (dfsCluster != null) {
            dfsCluster.shutdown(true);
            dfsCluster = null;
        }
    }

    /** Creates a fresh staging dir and Tez session configuration for each test, then starts the session. */
    @Before
    public void setup() throws Exception {
        Path remoteStagingDir = remoteFs.makeQualified(new Path(TEST_ROOT_DIR, String.valueOf(new Random().nextInt(100000))));
        TezClientUtils.ensureStagingDirExists(dfsConf, remoteStagingDir);
        this.tezConf = new TezConfiguration(tezCluster.getConfig());
        this.tezConf.setInt("tez.dag.recovery.max.unflushed.events", 0);
        this.tezConf.set("tez.am.log.level", "INFO");
        this.tezConf.set("tez.staging-dir", remoteStagingDir.toString());
        this.tezConf.setInt("tez.am.resource.memory.mb", 500);
        this.tezConf.set("tez.am.launch.cmd-opts", " -Xmx256m");
        // Keep the staging dir after the run so readRecoveryLog() can still find the recovery files.
        this.tezConf.setBoolean("tez.am.staging.scratch-data.auto-delete", false);
        this.tezConf.setBoolean("tez.test.recovery.drain_event", true);
        this.tezSession = TezClient.create("TestAMRecoveryAggregationBroadcast", this.tezConf);
        this.tezSession.start();
    }

    /** Stops the Tez session; a failure to stop is logged rather than failing the test. */
    @After
    public void teardown() throws InterruptedException {
        if (this.tezSession != null) {
            try {
                LOG.info("Stopping Tez Session");
                this.tezSession.stop();
            } catch (Exception e) {
                LOG.error("Failed to stop Tez session", e);
            }
        }
        this.tezSession = null;
    }

    /** Without an AM failure, every task finishes in attempt 1 and attempt 2 has no recovery log. */
    @Test(timeout = 120000)
    public void testSucceed() throws Exception {
        DAG dag = createDAG("Succeed");
        TezCounters counters = runDAGAndVerify(dag, false);
        Assert.assertEquals(3L, counters.findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        assertFinishedAttempts(readRecoveryLog(1), 1, 1, 1);
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(2));
    }

    /** The AM fails while TableScan sleeps, so all tasks finish in attempt 2. */
    @Test(timeout = 120000)
    public void testTableScanTemporalFailure() throws Exception {
        this.tezConf.setBoolean(TABLE_SCAN_SLEEP, true);
        DAG dag = createDAG("TableScanTemporalFailure");
        TezCounters counters = runDAGAndVerify(dag, true);
        Assert.assertEquals(3L, counters.findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        assertFinishedAttempts(readRecoveryLog(1), 0, 0, 0);
        assertFinishedAttempts(readRecoveryLog(2), 1, 1, 1);
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(3));
    }

    /** The AM fails while Aggregation sleeps; TableScan's result from attempt 1 is recovered. */
    @Test(timeout = 120000)
    public void testAggregationTemporalFailure() throws Exception {
        this.tezConf.setBoolean(AGGREGATION_SLEEP, true);
        DAG dag = createDAG("AggregationTemporalFailure");
        TezCounters counters = runDAGAndVerify(dag, true);
        Assert.assertEquals(3L, counters.findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        assertFinishedAttempts(readRecoveryLog(1), 1, 0, 0);
        assertFinishedAttempts(readRecoveryLog(2), 0, 1, 1);
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(3));
    }

    /** The AM fails while MapJoin sleeps; TableScan and Aggregation results from attempt 1 are recovered. */
    @Test(timeout = 120000)
    public void testMapJoinTemporalFailure() throws Exception {
        this.tezConf.setBoolean(MAP_JOIN_SLEEP, true);
        DAG dag = createDAG("MapJoinTemporalFailure");
        TezCounters counters = runDAGAndVerify(dag, true);
        Assert.assertEquals(3L, counters.findCounter(DAGCounter.NUM_SUCCEEDED_TASKS).getValue());
        assertFinishedAttempts(readRecoveryLog(1), 1, 1, 0);
        assertFinishedAttempts(readRecoveryLog(2), 0, 0, 1);
        Assert.assertEquals(Collections.emptyList(), readRecoveryLog(3));
    }

    /**
     * Builds the TableScan -> Aggregation (parallelism 1) -> MapJoin DAG.
     *
     * <p>The test's tezConf is serialized into every processor's user payload. This is how the
     * *_SLEEP flags reach the processors.
     *
     * @param dagName suffix appended to the DAG name
     * @return the assembled DAG
     */
    private DAG createDAG(String dagName) throws Exception {
        UserPayload payload = TezUtils.createUserPayloadFromConf(this.tezConf);
        DataSourceDescriptor dataSource = MRInput.createConfigBuilder(
                new Configuration(this.tezConf), TextInputFormat.class, INPUT_FILE.toString()).build();
        DataSinkDescriptor dataSink = MROutput.createConfigBuilder(
                new Configuration(this.tezConf), TextOutputFormat.class, OUT_PATH.toString()).build();

        Vertex tableScanVertex = Vertex.create(TABLE_SCAN, processor(TableScanProcessor.class, payload))
            .addDataSource(INPUT1, dataSource);
        Vertex aggregationVertex = Vertex.create(AGGREGATION, processor(AggregationProcessor.class, payload), 1);
        Vertex mapJoinVertex = Vertex.create(MAP_JOIN, processor(MapJoinProcessor.class, payload))
            .addDataSource(INPUT2, dataSource)
            .addDataSink(OUTPUT, dataSink);

        EdgeProperty orderedEdge = OrderedPartitionedKVEdgeConfig
            .newBuilder(Text.class.getName(), IntWritable.class.getName(), HashPartitioner.class.getName())
            .setFromConfiguration(this.tezConf)
            .build()
            .createDefaultEdgeProperty();
        EdgeProperty broadcastEdge = UnorderedKVEdgeConfig
            .newBuilder(Text.class.getName(), IntWritable.class.getName())
            .setFromConfiguration(this.tezConf)
            .build()
            .createDefaultBroadcastEdgeProperty();

        DAG dag = DAG.create("TestAMRecoveryAggregationBroadcast_" + dagName);
        dag.addVertex(tableScanVertex)
            .addVertex(aggregationVertex)
            .addVertex(mapJoinVertex)
            .addEdge(Edge.create(tableScanVertex, aggregationVertex, orderedEdge))
            .addEdge(Edge.create(aggregationVertex, mapJoinVertex, broadcastEdge));
        return dag;
    }

    /** Creates a ProcessorDescriptor for {@code processorClass} carrying {@code payload}. */
    private static ProcessorDescriptor processor(Class<?> processorClass, UserPayload payload) {
        ProcessorDescriptor descriptor = ProcessorDescriptor.create(processorClass.getName());
        descriptor.setUserPayload(payload);
        return descriptor;
    }

    /**
     * Submits {@code dag}, optionally fails the first application attempt, and waits for completion.
     * It then asserts that the DAG succeeded and that its output equals EXPECTED_OUTPUT.
     *
     * @param dag the DAG to run
     * @param killAM if true, fail application attempt 1 KILL_AM_DELAY_SECONDS after submission
     * @return the DAG counters of the completed DAG
     */
    TezCounters runDAGAndVerify(DAG dag, boolean killAM) throws Exception {
        this.tezSession.waitTillReady();
        DAGClient dagClient = this.tezSession.submitDAG(dag);
        if (killAM) {
            TimeUnit.SECONDS.sleep(KILL_AM_DELAY_SECONDS);
            failFirstAppAttempt();
        }
        DAGStatus dagStatus = dagClient.waitForCompletionWithStatusUpdates(EnumSet.of(StatusGetOpts.GET_COUNTERS));
        LOG.info("Diagnosis: {}", dagStatus.getDiagnostics());
        Assert.assertEquals(DAGStatus.State.SUCCEEDED, dagStatus.getState());
        Assert.assertEquals(EXPECTED_OUTPUT, readOutput());
        return dagStatus.getDAGCounters();
    }

    /** Asks YARN to fail attempt 1 of the session's AM application; the client is always closed. */
    private void failFirstAppAttempt() throws Exception {
        YarnClient yarnClient = YarnClient.createYarnClient();
        try {
            yarnClient.init(this.tezConf);
            yarnClient.start();
            ApplicationAttemptId attemptId =
                ApplicationAttemptId.newInstance(this.tezSession.getAppMasterApplicationId(), 1);
            yarnClient.failApplicationAttempt(attemptId);
        } finally {
            yarnClient.close();
        }
    }

    /**
     * Reads the whole MapJoin output file as UTF-8. The file's length is used so the content is
     * read completely; a single {@code read} call may return fewer bytes than requested.
     */
    private static String readOutput() throws IOException {
        Path outputFile = new Path(OUT_PATH, OUTPUT_FILE_NAME);
        byte[] content = new byte[(int) remoteFs.getFileStatus(outputFile).getLen()];
        try (FSDataInputStream in = remoteFs.open(outputFile)) {
            in.readFully(content);
        }
        return new String(content, StandardCharsets.UTF_8);
    }

    /**
     * Parses the recovery file of the session's first DAG ({@code dag_..._1.recovery}) for the given
     * AM attempt.
     *
     * @param attemptNum AM attempt number whose recovery directory is read
     * @return the parsed events, or an empty list if that attempt wrote no recovery file
     */
    private List<HistoryEvent> readRecoveryLog(int attemptNum) throws IOException {
        ApplicationId appId = this.tezSession.getAppMasterApplicationId();
        Path tezSystemStagingDir = TezCommonUtils.getTezSystemStagingPath(this.tezConf, appId.toString());
        Path recoveryDataDir = TezCommonUtils.getRecoveryPath(tezSystemStagingDir, this.tezConf);
        FileSystem fs = tezSystemStagingDir.getFileSystem(this.tezConf);
        Path attemptRecoveryDataDir = TezCommonUtils.getAttemptRecoveryPath(recoveryDataDir, attemptNum);
        Path recoveryFilePath = new Path(attemptRecoveryDataDir,
            appId.toString().replace("application", "dag") + "_1" + ".recovery");
        List<HistoryEvent> historyEvents = new ArrayList<>();
        if (fs.exists(recoveryFilePath)) {
            LOG.info("Read recovery file:{}", recoveryFilePath);
            try (FSDataInputStream in = fs.open(recoveryFilePath)) {
                historyEvents.addAll(RecoveryParser.parseDAGRecoveryFile(in));
            }
        }
        printHistoryEvents(historyEvents, attemptNum);
        return historyEvents;
    }

    /** Logs every parsed recovery event of one attempt, followed by an empty separator line. */
    private void printHistoryEvents(List<HistoryEvent> historyEvents, int attemptId) {
        LOG.info("RecoveryLogs from attempt:{}", attemptId);
        for (HistoryEvent historyEvent : historyEvents) {
            LOG.info("Parsed event from recovery stream, eventType={}, event={}", historyEvent.getEventType(), historyEvent);
        }
        LOG.info("");
    }

    /**
     * Asserts how many non-KILLED finished attempts of task 0 each vertex has in a recovery log.
     * The expected counts are given in vertex-id order 0, 1, 2, ...
     * NOTE(review): the tests assume ids follow addVertex order (TableScan, Aggregation, MapJoin).
     */
    private void assertFinishedAttempts(List<HistoryEvent> historyEvents, int... expectedPerVertex) {
        for (int vertexId = 0; vertexId < expectedPerVertex.length; ++vertexId) {
            Assert.assertEquals("finished attempts of task 0 in vertex " + vertexId,
                expectedPerVertex[vertexId], findTaskAttemptFinishedEvent(historyEvents, vertexId, 0).size());
        }
    }

    /**
     * Selects the TASK_ATTEMPT_FINISHED events whose state is not KILLED and which belong to the given
     * vertex id and task id.
     */
    private List<TaskAttemptFinishedEvent> findTaskAttemptFinishedEvent(List<HistoryEvent> historyEvents, int vertexId, int taskId) {
        List<TaskAttemptFinishedEvent> resultEvents = new ArrayList<>();
        for (HistoryEvent historyEvent : historyEvents) {
            if (historyEvent.getEventType() != HistoryEventType.TASK_ATTEMPT_FINISHED) {
                continue;
            }
            TaskAttemptFinishedEvent taFinishedEvent = (TaskAttemptFinishedEvent) historyEvent;
            if (taFinishedEvent.getState() == TaskAttemptState.KILLED) {
                continue;
            }
            if (taFinishedEvent.getVertexID().getId() == vertexId && taFinishedEvent.getTaskID().getId() == taskId) {
                resultEvents.add(taFinishedEvent);
            }
        }
        return resultEvents;
    }

    /**
     * Reads the boolean {@code key} (default false) from the Configuration serialized into the
     * processor's user payload.
     *
     * @throws RuntimeException wrapping the IOException if the payload cannot be deserialized
     */
    private static boolean readSleepFlag(ProcessorContext context, String key) {
        try {
            Configuration conf = TezUtils.createConfFromUserPayload(context.getUserPayload());
            return conf.getBoolean(key, false);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sleeps FIRST_ATTEMPT_SLEEP_SECONDS if {@code sleep} is set and this is DAG attempt 1. This gives
     * the test time to fail the AM while the processor is running.
     */
    private static void sleepIfFirstAttempt(ProcessorContext context, boolean sleep) throws InterruptedException {
        if (context.getDAGAttemptNumber() == 1 && sleep) {
            TimeUnit.SECONDS.sleep(FIRST_ATTEMPT_SLEEP_SECONDS);
        }
    }

    /**
     * Joins every input line with the broadcast counts from Aggregation. For each line of INPUT2 it
     * writes {@code "<line>-<count>"}, using count 0 when the line has no broadcast entry.
     */
    public static class MapJoinProcessor extends SimpleMRProcessor {
        private final boolean sleep;

        public MapJoinProcessor(ProcessorContext context) {
            super(context);
            this.sleep = readSleepFlag(context, MAP_JOIN_SLEEP);
        }

        @Override
        public void run() throws Exception {
            sleepIfFirstAttempt(getContext(), this.sleep);
            Preconditions.checkArgument(getInputs().size() == 2);
            Preconditions.checkArgument(getOutputs().size() == 1);

            KeyValueReader broadcastKvReader = (KeyValueReader) getInputs().get(AGGREGATION).getReader();
            HashMap<String, Integer> countMap = new HashMap<>();
            while (broadcastKvReader.next()) {
                String key = broadcastKvReader.getCurrentKey().toString();
                int value = ((IntWritable) broadcastKvReader.getCurrentValue()).get();
                countMap.put(key, value);
            }

            KeyValueReader kvReader = (KeyValueReader) getInputs().get(INPUT2).getReader();
            KeyValueWriter kvWriter = (KeyValueWriter) getOutputs().get(OUTPUT).getWriter();
            while (kvReader.next()) {
                String line = kvReader.getCurrentValue().toString();
                int count = countMap.getOrDefault(line, 0);
                kvWriter.write(NullWritable.get(), String.format("%s-%d", line, count));
            }
        }
    }

    /** Sums the IntWritable values grouped under each Text key from TableScan and emits (key, sum). */
    public static class AggregationProcessor extends SimpleMRProcessor {
        private final boolean sleep;

        public AggregationProcessor(ProcessorContext context) {
            super(context);
            this.sleep = readSleepFlag(context, AGGREGATION_SLEEP);
        }

        @Override
        public void run() throws Exception {
            sleepIfFirstAttempt(getContext(), this.sleep);
            Preconditions.checkArgument(getInputs().size() == 1);
            Preconditions.checkArgument(getOutputs().size() == 1);

            KeyValuesReader kvReader = (KeyValuesReader) getInputs().get(TABLE_SCAN).getReader();
            KeyValueWriter kvWriter = (KeyValueWriter) getOutputs().get(MAP_JOIN).getWriter();
            while (kvReader.next()) {
                Text word = (Text) kvReader.getCurrentKey();
                int sum = 0;
                for (Object value : kvReader.getCurrentValues()) {
                    sum += ((IntWritable) value).get();
                }
                kvWriter.write(word, new IntWritable(sum));
            }
        }
    }

    /** Emits (line, 1) for every line of INPUT1. */
    public static class TableScanProcessor extends SimpleMRProcessor {
        private static final IntWritable ONE = new IntWritable(1);
        private final boolean sleep;

        public TableScanProcessor(ProcessorContext context) {
            super(context);
            this.sleep = readSleepFlag(context, TABLE_SCAN_SLEEP);
        }

        @Override
        public void run() throws Exception {
            sleepIfFirstAttempt(getContext(), this.sleep);
            Preconditions.checkArgument(getInputs().size() == 1);
            Preconditions.checkArgument(getOutputs().size() == 1);

            KeyValueReader kvReader = (KeyValueReader) getInputs().get(INPUT1).getReader();
            KeyValueWriter kvWriter = (KeyValueWriter) getOutputs().get(AGGREGATION).getWriter();
            while (kvReader.next()) {
                Text line = (Text) kvReader.getCurrentValue();
                kvWriter.write(line, ONE);
            }
        }
    }
}

