package org.apache.flink.table.functions.hive;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.utils.CallContextMock;
import org.apache.flink.types.Row;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFContextNGrams;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMin;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum;
import org.assertj.core.api.Assertions;
import org.junit.Test;

/**
 * Unit tests for {@link HiveGenericUDAF}.
 *
 * <p>Each test wraps a Hive generic UDAF class via {@link #init}, then drives it through {@code
 * createAccumulator}, {@code accumulate}, {@code merge} (with an empty list of partial
 * accumulators) and {@code getValue}, asserting on the final aggregate.
 */
public class HiveGenericUDAFTest {

    @Test
    public void testUDAFMin() throws Exception {
        HiveGenericUDAF udaf =
                init(
                        GenericUDAFMin.class,
                        new Object[] {null},
                        new DataType[] {DataTypes.BIGINT()},
                        false,
                        true);
        GenericUDAFEvaluator.AggregationBuffer acc = udaf.createAccumulator();
        udaf.accumulate(acc, new Object[] {2L});
        udaf.accumulate(acc, new Object[] {3L});
        udaf.accumulate(acc, new Object[] {1L});
        // Merging no partial accumulators is expected to leave the result unchanged.
        udaf.merge(acc, Collections.emptyList());
        Assertions.assertThat(udaf.getValue(acc)).isEqualTo(1L);
    }

    @Test
    public void testUDAFSum() throws Exception {
        // SUM over DOUBLE.
        HiveGenericUDAF doubleSum =
                init(
                        GenericUDAFSum.class,
                        new Object[] {null},
                        new DataType[] {DataTypes.DOUBLE()},
                        false,
                        true);
        GenericUDAFEvaluator.AggregationBuffer doubleAcc = doubleSum.createAccumulator();
        doubleSum.accumulate(doubleAcc, new Object[] {Double.valueOf(0.5d)});
        doubleSum.accumulate(doubleAcc, new Object[] {Double.valueOf(0.3d)});
        doubleSum.accumulate(doubleAcc, new Object[] {Double.valueOf(5.3d)});
        doubleSum.merge(doubleAcc, Collections.emptyList());
        Assertions.assertThat(doubleSum.getValue(doubleAcc)).isEqualTo(Double.valueOf(6.1d));

        // SUM over DECIMAL(5, 3). BigDecimal equality is scale-sensitive, so the expected
        // value is built from a String to pin scale 3 exactly.
        HiveGenericUDAF decimalSum =
                init(
                        GenericUDAFSum.class,
                        new Object[] {null},
                        new DataType[] {DataTypes.DECIMAL(5, 3)},
                        false,
                        true);
        GenericUDAFEvaluator.AggregationBuffer decimalAcc = decimalSum.createAccumulator();
        decimalSum.accumulate(decimalAcc, new Object[] {new BigDecimal("10.111")});
        decimalSum.accumulate(decimalAcc, new Object[] {new BigDecimal("3.222")});
        decimalSum.accumulate(decimalAcc, new Object[] {new BigDecimal("5.333")});
        decimalSum.merge(decimalAcc, Collections.emptyList());
        Assertions.assertThat(decimalSum.getValue(decimalAcc)).isEqualTo(new BigDecimal("18.666"));
    }

    @Test
    public void testUDAFCount() throws Exception {
        HiveGenericUDAF udaf =
                init(
                        GenericUDAFCount.class,
                        new Object[] {null},
                        new DataType[] {DataTypes.DOUBLE()},
                        false,
                        true);
        GenericUDAFEvaluator.AggregationBuffer acc = udaf.createAccumulator();
        udaf.accumulate(acc, new Object[] {Double.valueOf(0.5d)});
        udaf.accumulate(acc, new Object[] {Double.valueOf(0.3d)});
        udaf.accumulate(acc, new Object[] {Double.valueOf(5.3d)});
        udaf.merge(acc, Collections.emptyList());
        Assertions.assertThat(udaf.getValue(acc)).isEqualTo(3L);
    }

    @Test
    public void testUDAFResolver() throws Exception {
        // GenericUDAFContextNGrams is initialised with the second constructor flag set to false.
        HiveGenericUDAF udaf =
                init(
                        GenericUDAFContextNGrams.class,
                        new Object[] {null, null, null},
                        new DataType[] {
                            DataTypes.ARRAY(DataTypes.STRING()),
                            DataTypes.ARRAY(DataTypes.STRING()),
                            DataTypes.INT()
                        },
                        false,
                        false);
        GenericUDAFEvaluator.AggregationBuffer acc = udaf.createAccumulator();
        udaf.accumulate(
                acc,
                new Object[] {
                    new Object[] {"what", "i", "think"}, new Object[] {"what", "i", null}, 1
                });
        udaf.merge(acc, Collections.emptyList());
        Assertions.assertThat(Arrays.toString((Row[]) udaf.getValue(acc)))
                .isEqualTo("[+I[[think], 1.0]]");
    }

    @Test
    public void testUDAFWithSingleArrayAsParameter() throws Exception {
        Object[] constantArgs = {null};
        DataType[] argTypes = {DataTypes.ARRAY(DataTypes.INT().notNull())};

        // The Integer[] itself is passed as the inputs Object[] (not wrapped in an outer array),
        // so these calls exercise HiveGenericUDAF's handling of a single array-typed argument.
        // The expected results show each Integer[] is collected as one element.
        HiveGenericUDAF collectList =
                init(GenericUDAFCollectList.class, constantArgs, argTypes, false, false);
        GenericUDAFEvaluator.AggregationBuffer listAcc = collectList.createAccumulator();
        collectList.accumulate(listAcc, (Object[]) new Integer[] {1, 2});
        collectList.accumulate(listAcc, (Object[]) new Integer[] {2, 3});
        collectList.merge(listAcc, Collections.emptyList());
        org.junit.jupiter.api.Assertions.assertArrayEquals(
                new Integer[][] {{1, 2}, {2, 3}}, (Integer[][]) collectList.getValue(listAcc));

        // collect_set drops the duplicate {1, 2}.
        HiveGenericUDAF collectSet =
                init(GenericUDAFCollectSet.class, constantArgs, argTypes, false, false);
        GenericUDAFEvaluator.AggregationBuffer setAcc = collectSet.createAccumulator();
        collectSet.accumulate(setAcc, (Object[]) new Integer[] {1, 2});
        collectSet.accumulate(setAcc, (Object[]) new Integer[] {2, 3});
        collectSet.accumulate(setAcc, (Object[]) new Integer[] {1, 2});
        collectSet.merge(setAcc, Collections.emptyList());
        org.junit.jupiter.api.Assertions.assertArrayEquals(
                new Integer[][] {{1, 2}, {2, 3}}, (Integer[][]) collectSet.getValue(setAcc));
    }

    /**
     * Creates, configures and opens a {@link HiveGenericUDAF} wrapping the given Hive UDAF class.
     *
     * <p>The call context reports each argument's data type, an {@code Optional} of its constant
     * value (empty when the constant is {@code null}), and a literal flag that is {@code true}
     * exactly when the constant is non-null. The Hive shim is loaded for the Hive version detected
     * by {@link HiveShimLoader#getHiveVersion()}.
     *
     * @param hiveUdafClass the Hive UDAF implementation to wrap
     * @param constantArgs per-argument constant values; {@code null} means "not a literal"
     * @param argTypes Flink data types of the arguments, aligned positionally with {@code
     *     constantArgs}
     * @param isUDAFBridgeRequired first boolean forwarded to the {@code HiveGenericUDAF}
     *     constructor (name assumed from that constructor's parameter order — confirm)
     * @param isUDAFResolve2 second boolean forwarded to the {@code HiveGenericUDAF} constructor
     *     (presumably selects the {@code GenericUDAFResolver2} path — confirm)
     * @return the function after {@code setArguments}, {@code inferReturnType} and {@code open}
     */
    private static HiveGenericUDAF init(
            Class<?> hiveUdafClass,
            Object[] constantArgs,
            DataType[] argTypes,
            boolean isUDAFBridgeRequired,
            boolean isUDAFResolve2)
            throws Exception {
        HiveFunctionWrapper wrapper = new HiveFunctionWrapper(hiveUdafClass);

        CallContextMock callContext = new CallContextMock();
        callContext.argumentDataTypes = Arrays.asList(argTypes);
        // Target typing lets Collectors.toList() infer the field's element type; no raw cast needed.
        callContext.argumentValues =
                Arrays.stream(constantArgs).map(Optional::ofNullable).collect(Collectors.toList());
        callContext.argumentLiterals =
                Arrays.stream(constantArgs).map(Objects::nonNull).collect(Collectors.toList());

        HiveGenericUDAF udaf =
                new HiveGenericUDAF(
                        wrapper,
                        isUDAFBridgeRequired,
                        isUDAFResolve2,
                        HiveShimLoader.loadHiveShim(HiveShimLoader.getHiveVersion()));
        udaf.setArguments(callContext);
        udaf.inferReturnType();
        udaf.open((FunctionContext) null);
        return udaf;
    }
}
