package org.apache.storm.grouping;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.storm.daemon.GrouperFactory;
import org.apache.storm.generated.GlobalStreamId;
import org.apache.storm.generated.Grouping;
import org.apache.storm.generated.NodeInfo;
import org.apache.storm.generated.NullStruct;
import org.apache.storm.grouping.LoadAwareShuffleGrouping;
import org.apache.storm.shade.com.google.common.collect.Lists;
import org.apache.storm.shade.com.google.common.collect.Sets;
import org.apache.storm.shade.com.google.common.util.concurrent.MoreExecutors;
import org.apache.storm.task.WorkerTopologyContext;
import org.apache.storm.tuple.Fields;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/storm/grouping/LoadAwareShuffleGroupingTest.class */
public class LoadAwareShuffleGroupingTest {
    /** Tolerance of 0.015; the distribution tests in this class bound each task's share of picks by this same value. */
    public static final double ACCEPTABLE_MARGIN = 0.015d;
    /** Class-wide SLF4J logger used for diagnostic output from the tests. */
    private static final Logger LOG = LoggerFactory.getLogger(LoadAwareShuffleGroupingTest.class);

    /**
     * Builds the configuration map handed out by the mocked topology contexts.
     *
     * @return a fresh mutable map holding the rack-mapping plugin class name and the
     *         locality-aware higher (0.8) and lower (0.2) bounds
     */
    private Map<String, Object> createConf() {
        Map<String, Object> conf = new HashMap<>();
        conf.put("storm.network.topography.plugin", "org.apache.storm.networktopography.DefaultRackDNSToSwitchMapping");
        conf.put("topology.localityaware.higher.bound", 0.8d);
        conf.put("topology.localityaware.lower.bound", 0.2d);
        return conf;
    }

    /**
     * Creates a Mockito-backed {@link WorkerTopologyContext} in which every given task
     * lives on the same node and port as the current worker.
     *
     * <p>Each task id in {@code list} is mapped to one shared {@link NodeInfo}
     * ("node-id", port 6700); the mock reports that same node id and port as its own
     * assignment id and worker port, and resolves "node-id" to "hostname1".
     *
     * @param list task ids to register in the task-to-node/port map
     * @return the stubbed context
     */
    private WorkerTopologyContext mockContext(List<Integer> list) {
        WorkerTopologyContext context = Mockito.mock(WorkerTopologyContext.class);
        Mockito.when(context.getConf()).thenReturn(createConf());

        Map<Integer, NodeInfo> taskToNodePort = new HashMap<>();
        NodeInfo nodeInfo = new NodeInfo("node-id", Sets.newHashSet(new Long[]{6700L}));
        // Register every task on the shared node; previously this lambda was empty,
        // leaving the map unpopulated and nodeInfo unused.
        list.forEach(taskId -> taskToNodePort.put(taskId, nodeInfo));

        Mockito.when(context.getTaskToNodePort()).thenReturn(new AtomicReference<>(taskToNodePort));
        Mockito.when(context.getAssignmentId()).thenReturn("node-id");
        Mockito.when(context.getThisWorkerPort()).thenReturn(6700);
        Mockito.when(context.getNodeToHost()).thenReturn(new AtomicReference<>(Collections.singletonMap("node-id", "hostname1")));
        return context;
    }

    /**
     * Checks that the share of choice slots assigned to each of two tasks tracks the
     * expected weights across repeated {@code refreshLoad} calls.
     *
     * <p>Phase one reports load 1.0 for task 1 and 0.0 for task 2 and expects task 1's
     * weight to fall by 10 per refresh. Phase two swaps the loads and expects task 1's
     * weight to rise by 1 per refresh while task 2's falls by 10, floored at 0.
     */
    @Test
    public void testUnevenLoadOverTime() {
        LoadAwareShuffleGrouping grouping = new LoadAwareShuffleGrouping();
        WorkerTopologyContext context = mockContext(Arrays.asList(1, 2));
        grouping.prepare(context, new GlobalStreamId("a", "default"), Arrays.asList(1, 2));

        double weightOne = 100.0d;
        double weightTwo = 100.0d;

        Map<Integer, Double> localLoad = new HashMap<>();
        localLoad.put(1, 1.0d);
        localLoad.put(2, 0.0d);
        LoadMapping loads = new LoadMapping();
        loads.setLocal(localLoad);

        // Phase one: task 1 is reported as fully loaded, task 2 as idle.
        for (int i = 9; i >= 0; i--) {
            grouping.refreshLoad(loads);
            weightOne -= 10.0d;
            Map<Integer, Double> observed = count(grouping.choices, grouping.rets);
            LOG.info("contByType = {}", observed);
            double expectedOne = weightOne / (weightOne + 100.0d);
            double expectedTwo = 100.0d / (weightOne + 100.0d);
            Assertions.assertEquals(expectedOne, observed.getOrDefault(1, 0.0d) / grouping.getCapacity(), 0.01d, "i = " + i);
            Assertions.assertEquals(expectedTwo, observed.getOrDefault(2, 0.0d) / grouping.getCapacity(), 0.01d, "i = " + i);
        }

        // Phase two: the loads are swapped.
        localLoad.put(1, 0.0d);
        localLoad.put(2, 1.0d);
        loads.setLocal(localLoad);
        while (weightOne < 100.0d) {
            grouping.refreshLoad(loads);
            weightOne += 1.0d;
            weightTwo = Math.max(0.0d, weightTwo - 10.0d);
            Map<Integer, Double> observed = count(grouping.choices, grouping.rets);
            LOG.info("contByType = {}", observed);
            double expectedOne = weightOne / (weightOne + weightTwo);
            double expectedTwo = weightTwo / (weightOne + weightTwo);
            Assertions.assertEquals(expectedOne, observed.getOrDefault(1, 0.0d) / grouping.getCapacity(), 0.01d);
            Assertions.assertEquals(expectedTwo, observed.getOrDefault(2, 0.0d) / grouping.getCapacity(), 0.01d);
        }
    }

    /**
     * Tallies how often each task id is referenced by the choice table.
     *
     * @param iArr    indices into {@code listArr}, one per choice slot
     * @param listArr per-index lists whose first element is the task id
     * @return map from task id to the number of slots pointing at it (as a double)
     */
    private Map<Integer, Double> count(int[] iArr, List<Integer>[] listArr) {
        Map<Integer, Double> tally = new HashMap<>();
        for (int slot : iArr) {
            // Unboxed on purpose so a null entry fails here rather than becoming a null key.
            int taskId = listArr[slot].get(0);
            tally.merge(taskId, 1.0d, Double::sum);
        }
        return tally;
    }

    /** Runs the single-threaded even-load distribution check against 1000 target tasks. */
    @Test
    public void testLoadAwareShuffleGroupingWithEvenLoadWithManyTargets() {
        final int targetCount = 1000;
        testLoadAwareShuffleGroupingWithEvenLoad(targetCount);
    }

    /** Runs the single-threaded even-load distribution check against 7 target tasks. */
    @Test
    public void testLoadAwareShuffleGroupingWithEvenLoadWithLessTargets() {
        final int targetCount = 7;
        testLoadAwareShuffleGroupingWithEvenLoad(targetCount);
    }

    /**
     * Prepares a grouping over {@code i} equally loaded tasks, issues {@code i * 5000}
     * {@code chooseTasks} calls and asserts that each task's hit count lies within
     * {@link #ACCEPTABLE_MARGIN} of an even {@code 1 / i} share.
     *
     * @param i number of target tasks
     */
    private void testLoadAwareShuffleGroupingWithEvenLoad(int i) {
        LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();
        List<Integer> taskIds = getAvailableTaskIds(i);
        LoadMapping evenLoad = buildLocalTasksEvenLoadMapping(taskIds);
        grouper.prepare(mockContext(taskIds), (GlobalStreamId) null, taskIds);

        final int totalCalls = i * 5000;
        final int minHits = (int) (totalCalls * ((1.0d / i) - ACCEPTABLE_MARGIN));
        final int maxHits = (int) (totalCalls * ((1.0d / i) + ACCEPTABLE_MARGIN));

        int[] hitsPerTask = runChooseTasksWithVerification(grouper, totalCalls, i, evenLoad);
        for (int task = 0; task < i; task++) {
            int hits = hitsPerTask[task];
            boolean withinBounds = hits >= minHits && hits <= maxHits;
            Assertions.assertTrue(withinBounds, "Distribution should be even for all nodes with small delta");
        }
    }

    /** Runs the multi-threaded even-load distribution check against 1000 target tasks. */
    @Test
    public void testLoadAwareShuffleGroupingWithEvenLoadMultiThreadedWithManyTargets() throws ExecutionException, InterruptedException {
        final int targetCount = 1000;
        testLoadAwareShuffleGroupingWithEvenLoadMultiThreaded(targetCount);
    }

    /** Runs the multi-threaded even-load distribution check against 7 target tasks. */
    @Test
    public void testLoadAwareShuffleGroupingWithEvenLoadMultiThreadedWithLessTargets() throws ExecutionException, InterruptedException {
        final int targetCount = 7;
        testLoadAwareShuffleGroupingWithEvenLoadMultiThreaded(targetCount);
    }

    /**
     * Exercises one shared grouping from 10 concurrent workers and checks that the
     * combined hit counts stay evenly distributed.
     *
     * <p>Each worker issues {@code i * 5000} {@code chooseTasks} calls, validating every
     * result, and returns its own per-task counts. The counts are summed and each task's
     * total must lie within 0.015 of an even {@code 1 / i} share of all calls.
     *
     * @param i number of target tasks
     * @throws InterruptedException if interrupted while waiting for the workers
     * @throws ExecutionException   if a worker fails, including on a failed assertion
     */
    private void testLoadAwareShuffleGroupingWithEvenLoadMultiThreaded(int i) throws InterruptedException, ExecutionException {
        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();
        final List<Integer> availableTaskIds = getAvailableTaskIds(i);
        final LoadMapping evenLoad = buildLocalTasksEvenLoadMapping(availableTaskIds);
        grouper.prepare(mockContext(availableTaskIds), (GlobalStreamId) null, availableTaskIds);
        grouper.refreshLoad(evenLoad);

        final int workerCount = 10;
        final int callsPerWorker = i * 5000;
        final int totalCalls = callsPerWorker * workerCount;

        // The list needs a Callable element type; a raw list gives the lambda no target type.
        List<Callable<int[]>> workers = new ArrayList<>();
        for (int w = 0; w < workerCount; w++) {
            workers.add(() -> {
                int[] workerHits = new int[availableTaskIds.size()];
                for (int call = 1; call <= callsPerWorker; call++) {
                    List<Integer> chosen = grouper.chooseTasks(100, Lists.newArrayList());
                    Assertions.assertNotNull(chosen, "Not null taskId list returned");
                    Assertions.assertEquals(1, chosen.size(), "Single task Id not returned");
                    int taskId = chosen.get(0);
                    Assertions.assertTrue(taskId >= 0 && taskId < availableTaskIds.size(), "TaskId should exist");
                    workerHits[taskId]++;
                }
                return workerHits;
            });
        }

        int[] totalHits = new int[i];
        ExecutorService pool = Executors.newFixedThreadPool(workers.size());
        try {
            // invokeAll returns only once every task has completed, so get() does not need polling.
            for (Future<int[]> result : pool.invokeAll(workers)) {
                int[] partial = result.get();
                for (int task = 0; task < partial.length; task++) {
                    totalHits[task] += partial[task];
                }
            }
        } finally {
            // Release the worker threads even when a worker failed.
            pool.shutdownNow();
        }

        final int minHits = (int) (totalCalls * ((1.0d / i) - 0.015d));
        final int maxHits = (int) (totalCalls * ((1.0d / i) + 0.015d));
        for (int task = 0; task < i; task++) {
            Assertions.assertTrue(totalHits[task] >= minHits && totalHits[task] <= maxHits, "Distribution should be even for all nodes with small delta");
        }
    }

    /**
     * Builds a shuffle grouper for tasks 1 and 2 through {@code GrouperFactory}, reports
     * zero load for both, sends 100000 tuples and asserts that each task receives between
     * 48.5% and 51.5% of them.
     */
    @Test
    public void testShuffleLoadEven() {
        LoadAwareCustomStreamGrouping shuffler = GrouperFactory.mkGrouper(
            mockContext(Lists.newArrayList(new Integer[]{1, 2})), "comp", "stream", (Fields) null,
            Grouping.shuffle(new NullStruct()), Lists.newArrayList(new Integer[]{1, 2}), Collections.emptyMap());

        final int messages = 100000;
        final int minHits = (int) (messages * 0.485d);
        final int maxHits = (int) (messages * 0.515d);

        Map<Integer, Double> idleLoad = new HashMap<>();
        idleLoad.put(1, 0.0d);
        idleLoad.put(2, 0.0d);
        LoadMapping loads = new LoadMapping();
        loads.setLocal(idleLoad);
        shuffler.refreshLoad(loads);

        List<Object> tuple = Lists.newArrayList(new Object[]{1, 2});
        // Indexed directly by task id (1 and 2); slot 0 stays unused.
        int[] hits = new int[3];
        for (int sent = 0; sent < messages; sent++) {
            for (int taskId : shuffler.chooseTasks(1, tuple)) {
                hits[taskId]++;
            }
        }

        int hitsOne = hits[1];
        int hitsTwo = hits[2];
        LOG.info("Frequency info: load1 = {}, load2 = {}", hitsOne, hitsTwo);
        Assertions.assertTrue(hitsOne >= minHits);
        Assertions.assertTrue(hitsOne <= maxHits);
        Assertions.assertTrue(hitsTwo >= minHits);
        Assertions.assertTrue(hitsTwo <= maxHits);
    }

    /** Disabled single-threaded benchmark over 10 tasks with an even load mapping. */
    @Disabled
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingEvenLoad() {
        final List<Integer> taskIds = getAvailableTaskIds(10);
        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();
        final LoadMapping evenLoad = buildLocalTasksEvenLoadMapping(taskIds);
        runSimpleBenchmark(grouper, taskIds, evenLoad);
    }

    /** Disabled single-threaded benchmark over 10 tasks with an uneven load mapping. */
    @Disabled
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingUnevenLoad() {
        final List<Integer> taskIds = getAvailableTaskIds(10);
        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();
        final LoadMapping unevenLoad = buildLocalTasksUnevenLoadMapping(taskIds);
        runSimpleBenchmark(grouper, taskIds, unevenLoad);
    }

    /** Disabled two-thread benchmark over 10 tasks with an even load mapping. */
    @Disabled
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingEvenLoadAndMultiThreaded() throws ExecutionException, InterruptedException {
        final List<Integer> taskIds = getAvailableTaskIds(10);
        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();
        final LoadMapping evenLoad = buildLocalTasksEvenLoadMapping(taskIds);
        runMultithreadedBenchmark(grouper, taskIds, evenLoad, 2);
    }

    /** Disabled two-thread benchmark over 10 tasks with an uneven load mapping. */
    @Disabled
    @Test
    public void testBenchmarkLoadAwareShuffleGroupingUnevenLoadAndMultiThreaded() throws ExecutionException, InterruptedException {
        final List<Integer> taskIds = getAvailableTaskIds(10);
        final LoadAwareShuffleGrouping grouper = new LoadAwareShuffleGrouping();
        final LoadMapping unevenLoad = buildLocalTasksUnevenLoadMapping(taskIds);
        runMultithreadedBenchmark(grouper, taskIds, unevenLoad, 2);
    }

    /**
     * Produces the consecutive task ids {@code 0 .. i - 1}.
     *
     * @param i number of ids to generate
     * @return a mutable list of the ids in ascending order; empty when {@code i <= 0}
     */
    private List<Integer> getAvailableTaskIds(int i) {
        List<Integer> taskIds = new ArrayList<>();
        int next = 0;
        while (next < i) {
            taskIds.add(next);
            next++;
        }
        return taskIds;
    }

    /**
     * Creates a {@code LoadMapping} whose local load is 0.1 for every given task.
     *
     * @param list task ids to include
     * @return the populated load mapping
     */
    private LoadMapping buildLocalTasksEvenLoadMapping(List<Integer> list) {
        Map<Integer, Double> localLoad = new HashMap<>(list.size());
        for (Integer taskId : list) {
            localLoad.put(taskId, 0.1d);
        }
        LoadMapping mapping = new LoadMapping();
        mapping.setLocal(localLoad);
        return mapping;
    }

    /**
     * Creates a {@code LoadMapping} whose local load grows with list position: the task at
     * index {@code k} gets load {@code 0.1 * (k + 1)}.
     *
     * @param list task ids to include
     * @return the populated load mapping
     */
    private LoadMapping buildLocalTasksUnevenLoadMapping(List<Integer> list) {
        Map<Integer, Double> localLoad = new HashMap<>(list.size());
        int position = 0;
        for (Integer taskId : list) {
            position++;
            localLoad.put(taskId, 0.1d * position);
        }
        LoadMapping mapping = new LoadMapping();
        mapping.setLocal(localLoad);
        return mapping;
    }

    /**
     * Refreshes the grouping with the given load, then calls {@code chooseTasks}
     * {@code i} times, asserting each result is a single task id in {@code [0, i2)}.
     *
     * @param loadAwareShuffleGrouping grouping under test, already prepared
     * @param i                        number of {@code chooseTasks} calls to make
     * @param i2                       number of tasks; sizes the returned array
     * @param loadMapping              load applied before the calls start
     * @return hit count per task id
     */
    private int[] runChooseTasksWithVerification(LoadAwareShuffleGrouping loadAwareShuffleGrouping, int i, int i2, LoadMapping loadMapping) {
        final int[] hitsPerTask = new int[i2];
        loadAwareShuffleGrouping.refreshLoad(loadMapping);
        for (int remaining = i; remaining >= 1; remaining--) {
            List<Integer> chosen = loadAwareShuffleGrouping.chooseTasks(100, Lists.newArrayList());
            Assertions.assertNotNull(chosen, "Not null taskId list returned");
            Assertions.assertEquals(1, chosen.size(), "Single task Id not returned");
            int taskId = chosen.get(0);
            boolean known = taskId >= 0 && taskId < i2;
            Assertions.assertTrue(known, "TaskId should exist");
            hitsPerTask[taskId]++;
        }
        return hitsPerTask;
    }

    /**
     * Single-threaded benchmark of {@code chooseTasks}.
     *
     * <p>A background scheduler re-applies {@code loadMapping} once per second. The method
     * first warms up for at least 60 seconds (checking the clock every 100000 calls), then
     * times 2,000,000,000 calls and logs the elapsed milliseconds.
     *
     * @param loadAwareCustomStreamGrouping grouping to benchmark; prepared here
     * @param list                          target task ids
     * @param loadMapping                   load re-applied periodically during the run
     */
    private void runSimpleBenchmark(LoadAwareCustomStreamGrouping loadAwareCustomStreamGrouping, List<Integer> list, LoadMapping loadMapping) {
        loadAwareCustomStreamGrouping.prepare(mockContext(list), (GlobalStreamId) null, list);
        ScheduledExecutorService refresher = MoreExecutors.getExitingScheduledExecutorService(new ScheduledThreadPoolExecutor(1));
        refresher.scheduleAtFixedRate(() -> loadAwareCustomStreamGrouping.refreshLoad(loadMapping), 1L, 1L, TimeUnit.SECONDS);
        try {
            // Warm-up: only consult the clock every 100000 calls to keep the loop tight.
            long warmupStart = System.currentTimeMillis();
            int warmupCalls = 0;
            while (true) {
                loadAwareCustomStreamGrouping.chooseTasks(100, Lists.newArrayList());
                warmupCalls++;
                if (warmupCalls % 100000 == 0 && System.currentTimeMillis() - warmupStart >= 60000) {
                    break;
                }
            }

            long measureStart = System.currentTimeMillis();
            for (int call = 1; call <= 2000000000; call++) {
                loadAwareCustomStreamGrouping.chooseTasks(100, Lists.newArrayList());
            }
            LOG.info("Duration: {} ms", System.currentTimeMillis() - measureStart);
        } finally {
            // Stop the periodic refresh even if the benchmark loop throws.
            refresher.shutdownNow();
        }
    }

    /**
     * Multi-threaded benchmark of {@code chooseTasks}.
     *
     * <p>A background scheduler re-applies {@code loadMapping} once per second. After a
     * warm-up of at least 60 seconds on the calling thread, {@code i} workers each time
     * 2,000,000,000 calls against the shared grouping; the longest worker duration is logged.
     *
     * @param loadAwareCustomStreamGrouping grouping to benchmark; prepared here
     * @param list                          target task ids
     * @param loadMapping                   load re-applied periodically during the run
     * @param i                             number of worker threads
     * @throws InterruptedException if interrupted while waiting for the workers
     * @throws ExecutionException   if a worker fails
     */
    private void runMultithreadedBenchmark(LoadAwareCustomStreamGrouping loadAwareCustomStreamGrouping, List<Integer> list, LoadMapping loadMapping, int i) throws InterruptedException, ExecutionException {
        loadAwareCustomStreamGrouping.prepare(mockContext(list), (GlobalStreamId) null, list);
        ScheduledExecutorService refresher = MoreExecutors.getExitingScheduledExecutorService(new ScheduledThreadPoolExecutor(1));
        refresher.scheduleAtFixedRate(() -> loadAwareCustomStreamGrouping.refreshLoad(loadMapping), 1L, 1L, TimeUnit.SECONDS);
        try {
            // Warm-up: only consult the clock every 100000 calls to keep the loop tight.
            long warmupStart = System.currentTimeMillis();
            int warmupCalls = 0;
            while (true) {
                loadAwareCustomStreamGrouping.chooseTasks(100, Lists.newArrayList());
                warmupCalls++;
                if (warmupCalls % 100000 == 0 && System.currentTimeMillis() - warmupStart >= 60000) {
                    break;
                }
            }

            // The list needs a Callable element type; a raw list gives the lambda no target type.
            List<Callable<Long>> workers = new ArrayList<>();
            for (int w = 0; w < i; w++) {
                workers.add(() -> {
                    long workerStart = System.currentTimeMillis();
                    for (int call = 1; call <= 2000000000; call++) {
                        loadAwareCustomStreamGrouping.chooseTasks(100, Lists.newArrayList());
                    }
                    return System.currentTimeMillis() - workerStart;
                });
            }

            long maxDuration = 0L;
            ExecutorService pool = Executors.newFixedThreadPool(workers.size());
            try {
                // invokeAll returns only once every task has completed, so get() does not need polling.
                for (Future<Long> result : pool.invokeAll(workers)) {
                    maxDuration = Math.max(maxDuration, result.get());
                }
            } finally {
                // Release the worker threads even when a worker failed.
                pool.shutdownNow();
            }
            LOG.info("Max duration among threads is : {} ms", maxDuration);
        } finally {
            // Stop the periodic refresh even if the benchmark throws.
            refresher.shutdownNow();
        }
    }

    /**
     * Verifies the locality scope reported by the grouping after each load refresh.
     *
     * <p>Starting at WORKER_LOCAL, three refreshes with load 1.0 on all tasks are expected
     * to step through HOST_LOCAL, RACK_LOCAL and EVERYTHING. Three refreshes with loads
     * (0.2, 0.1, 0.1) are then expected to give RACK_LOCAL, HOST_LOCAL and HOST_LOCAL
     * again, and a final refresh with 0.1 everywhere is expected to return to WORKER_LOCAL.
     */
    @Test
    public void testLoadSwitching() {
        LoadAwareShuffleGrouping grouping = new LoadAwareShuffleGrouping();
        WorkerTopologyContext context = createLoadSwitchingContext();
        grouping.prepare(context, new GlobalStreamId("a", "default"), Arrays.asList(1, 2, 3));
        Assertions.assertEquals(LoadAwareShuffleGrouping.LocalityScope.WORKER_LOCAL, grouping.getCurrentScope());

        // Expected scope after each successive refresh under load 1.0 on every task.
        LoadAwareShuffleGrouping.LocalityScope[] underHighLoad = {
            LoadAwareShuffleGrouping.LocalityScope.HOST_LOCAL,
            LoadAwareShuffleGrouping.LocalityScope.RACK_LOCAL,
            LoadAwareShuffleGrouping.LocalityScope.EVERYTHING,
        };
        LoadMapping highLoad = createLoadMapping(1.0d, 1.0d, 1.0d);
        for (LoadAwareShuffleGrouping.LocalityScope expected : underHighLoad) {
            grouping.refreshLoad(highLoad);
            Assertions.assertEquals(expected, grouping.getCurrentScope());
        }

        // Expected scope after each successive refresh once the load drops to (0.2, 0.1, 0.1).
        LoadAwareShuffleGrouping.LocalityScope[] underReducedLoad = {
            LoadAwareShuffleGrouping.LocalityScope.RACK_LOCAL,
            LoadAwareShuffleGrouping.LocalityScope.HOST_LOCAL,
            LoadAwareShuffleGrouping.LocalityScope.HOST_LOCAL,
        };
        LoadMapping reducedLoad = createLoadMapping(0.2d, 0.1d, 0.1d);
        for (LoadAwareShuffleGrouping.LocalityScope expected : underReducedLoad) {
            grouping.refreshLoad(reducedLoad);
            Assertions.assertEquals(expected, grouping.getCurrentScope());
        }

        LoadMapping lowLoad = createLoadMapping(0.1d, 0.1d, 0.1d);
        grouping.refreshLoad(lowLoad);
        Assertions.assertEquals(LoadAwareShuffleGrouping.LocalityScope.WORKER_LOCAL, grouping.getCurrentScope());
    }

    /**
     * Creates a {@code LoadMapping} with the given local loads for tasks 1, 2 and 3.
     *
     * @param d  load for task 1
     * @param d2 load for task 2
     * @param d3 load for task 3
     * @return the populated load mapping
     */
    private LoadMapping createLoadMapping(double d, double d2, double d3) {
        Map<Integer, Double> localLoad = new HashMap<>();
        localLoad.put(1, d);
        localLoad.put(2, d2);
        localLoad.put(3, d3);

        LoadMapping mapping = new LoadMapping();
        mapping.setLocal(localLoad);
        return mapping;
    }

    /**
     * Creates the mocked {@link WorkerTopologyContext} used by the load-switching test.
     *
     * <p>Task 1 is placed on "node-id" port 6701 (the same node id and port the mock
     * reports as its own), task 2 on "node-id" port 6702, and task 3 on "node-id2" port
     * 6703. The two node ids resolve to "hostname1" and "hostname2" respectively.
     *
     * @return the stubbed context
     */
    private WorkerTopologyContext createLoadSwitchingContext() {
        WorkerTopologyContext context = Mockito.mock(WorkerTopologyContext.class);
        Mockito.when(context.getConf()).thenReturn(createConf());

        Map<Integer, NodeInfo> taskToNodePort = new HashMap<>();
        taskToNodePort.put(1, new NodeInfo("node-id", Sets.newHashSet(new Long[]{6701L})));
        taskToNodePort.put(2, new NodeInfo("node-id", Sets.newHashSet(new Long[]{6702L})));
        taskToNodePort.put(3, new NodeInfo("node-id2", Sets.newHashSet(new Long[]{6703L})));
        Mockito.when(context.getTaskToNodePort()).thenReturn(new AtomicReference<>(taskToNodePort));

        Mockito.when(context.getAssignmentId()).thenReturn("node-id");
        Mockito.when(context.getThisWorkerPort()).thenReturn(6701);

        Map<String, String> nodeToHost = new HashMap<>();
        nodeToHost.put("node-id", "hostname1");
        nodeToHost.put("node-id2", "hostname2");
        Mockito.when(context.getNodeToHost()).thenReturn(new AtomicReference<>(nodeToHost));
        return context;
    }
}
