/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.iotdb.common.rpc.thrift.TAggregationType;
import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction;
import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.ApproxCountDistinctAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AvgAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.CountAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.CountAllAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.CountIfAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.ExtremeAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.FirstAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.FirstByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.FirstByDescAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.FirstDescAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.LastAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.LastByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.LastByDescAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.LastDescAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.MaxAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.MinAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.SumAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.TableAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.TableMaxByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.TableMinByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.TableModeAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.TableVarianceAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.UserDefinedAggregateFunctionAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedApproxCountDistinctAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountIfAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedExtremeAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedFirstAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedFirstByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedLastAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedLastByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMaxAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMaxByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.UpdateMemory;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.hash.MarkDistinctHash;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.GlobalTimePredicateExtractVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;
import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils;
import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments;
import org.apache.iotdb.udf.api.relational.AggregateFunction;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.read.common.block.column.IntColumn;
import org.apache.tsfile.read.common.type.IntType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.read.common.type.TypeFactory;

/**
 * Factory for table-model (relational) aggregation accumulators.
 *
 * <p>{@link #createAccumulator} builds non-grouped accumulators and {@link #createGroupedAccumulator}
 * builds grouped ones. Both send {@code UDAF} aggregation types to the user-defined-function path and
 * everything else to the built-in accumulators. When {@code distinct} is requested, both wrap the
 * result in a de-duplicating accumulator.
 */
public class AccumulatorFactory {

    /**
     * Creates a non-grouped accumulator.
     *
     * @param functionName name of the aggregation function; compared against the built-in
     *     FIRST_BY / LAST_BY / LAST function names
     * @param aggregationType resolved aggregation type; {@code UDAF} selects a user-defined aggregate
     * @param inputDataTypes data types of the aggregation arguments, in argument order
     * @param inputExpressions argument expressions, inspected to detect whether an argument is the
     *     time column
     * @param inputAttributes attributes forwarded to user-defined aggregates
     * @param ascending {@code true} selects the ascending accumulator variants (e.g. {@code
     *     LastAccumulator}), {@code false} the {@code *Desc} variants
     * @param timeColumnName name of the time column, used for the time-column checks
     * @param distinct whether to wrap the accumulator so that only distinct argument rows are fed to it
     * @return the accumulator
     * @throws IllegalArgumentException if the aggregation type is not supported
     */
    public static TableAccumulator createAccumulator(String functionName, TAggregationType aggregationType, List<TSDataType> inputDataTypes, List<Expression> inputExpressions, Map<String, String> inputAttributes, boolean ascending, String timeColumnName, boolean distinct) {
        TableAccumulator result;
        if (aggregationType == TAggregationType.UDAF) {
            result = createUDAFAccumulator(functionName, inputDataTypes, inputAttributes);
        } else if ((TableBuiltinAggregationFunction.LAST_BY.getFunctionName().equals(functionName)
                || TableBuiltinAggregationFunction.FIRST_BY.getFunctionName().equals(functionName))
                && inputExpressions.size() > 1) {
            result = createFirstOrLastByAccumulator(functionName, inputDataTypes, inputExpressions, ascending, timeColumnName);
        } else {
            if (TableBuiltinAggregationFunction.LAST.getFunctionName().equals(functionName)) {
                // NOTE(review): LAST returns here and is never wrapped by the DISTINCT accumulator below;
                // presumably because de-duplicating rows cannot change which value is last -- confirm.
                boolean isTimeColumn = GlobalTimePredicateExtractVisitor.isTimeColumn(inputExpressions.get(0), timeColumnName);
                return ascending
                        ? new LastAccumulator(inputDataTypes.get(0), isTimeColumn)
                        : new LastDescAccumulator(inputDataTypes.get(0), isTimeColumn);
            }
            result = createBuiltinAccumulator(aggregationType, inputDataTypes, inputExpressions, inputAttributes, ascending);
        }
        if (distinct) {
            result = new DistinctAccumulator(result, inputDataTypes.stream().map(InternalTypeManager::fromTSDataType).collect(Collectors.toList()));
        }
        return result;
    }

    /**
     * Builds a FIRST_BY / LAST_BY accumulator for the two-argument form {@code f(x, y)}, recording
     * which argument (if any) is the time column.
     */
    private static TableAccumulator createFirstOrLastByAccumulator(String functionName, List<TSDataType> inputDataTypes, List<Expression> inputExpressions, boolean ascending, String timeColumnName) {
        boolean xIsTimeColumn = false;
        boolean yIsTimeColumn = false;
        // At most one flag is set: y is checked first, x only if y is not the time column.
        if (GlobalTimePredicateExtractVisitor.isTimeColumn(inputExpressions.get(1), timeColumnName)) {
            yIsTimeColumn = true;
        } else if (GlobalTimePredicateExtractVisitor.isTimeColumn(inputExpressions.get(0), timeColumnName)) {
            xIsTimeColumn = true;
        }
        TSDataType xDataType = inputDataTypes.get(0);
        TSDataType yDataType = inputDataTypes.get(1);
        if (TableBuiltinAggregationFunction.LAST_BY.getFunctionName().equals(functionName)) {
            return ascending
                    ? new LastByAccumulator(xDataType, yDataType, xIsTimeColumn, yIsTimeColumn)
                    : new LastByDescAccumulator(xDataType, yDataType, xIsTimeColumn, yIsTimeColumn);
        }
        return ascending
                ? new FirstByAccumulator(xDataType, yDataType, xIsTimeColumn, yIsTimeColumn)
                : new FirstByDescAccumulator(xDataType, yDataType, xIsTimeColumn, yIsTimeColumn);
    }

    /**
     * Creates a grouped accumulator (one state per group id).
     *
     * @param functionName function name, used to look up user-defined aggregates
     * @param aggregationType resolved aggregation type; {@code UDAF} selects a user-defined aggregate
     * @param inputDataTypes data types of the aggregation arguments, in argument order
     * @param inputExpressions argument expressions (passed through to the built-in factory)
     * @param inputAttributes attributes forwarded to user-defined aggregates
     * @param ascending passed through to the built-in factory
     * @param distinct whether to de-duplicate (group id, arguments) rows before accumulation
     * @return the grouped accumulator
     * @throws IllegalArgumentException if the aggregation type is not supported
     */
    public static GroupedAccumulator createGroupedAccumulator(String functionName, TAggregationType aggregationType, List<TSDataType> inputDataTypes, List<Expression> inputExpressions, Map<String, String> inputAttributes, boolean ascending, boolean distinct) {
        GroupedAccumulator result = aggregationType == TAggregationType.UDAF
                ? createGroupedUDAFAccumulator(functionName, inputDataTypes, inputAttributes)
                : createBuiltinGroupedAccumulator(aggregationType, inputDataTypes, inputExpressions, inputAttributes, ascending);
        if (distinct) {
            result = new DistinctGroupedAccumulator(result, inputDataTypes.stream().map(InternalTypeManager::fromTSDataType).collect(Collectors.toList()));
        }
        return result;
    }

    /** Looks up the UDAF by name, runs {@code beforeStart}, and wraps it as a non-grouped accumulator. */
    private static TableAccumulator createUDAFAccumulator(String functionName, List<TSDataType> inputDataTypes, Map<String, String> inputAttributes) {
        AggregateFunction aggregateFunction = TableUDFUtils.getAggregateFunction(functionName);
        FunctionArguments functionArguments = new FunctionArguments(UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), inputAttributes);
        aggregateFunction.beforeStart(functionArguments);
        return new UserDefinedAggregateFunctionAccumulator(aggregateFunction.analyze(functionArguments), aggregateFunction, inputDataTypes.stream().map(TypeFactory::getType).collect(Collectors.toList()));
    }

    /** Looks up the UDAF by name, runs {@code beforeStart}, and wraps it as a grouped accumulator. */
    private static GroupedAccumulator createGroupedUDAFAccumulator(String functionName, List<TSDataType> inputDataTypes, Map<String, String> inputAttributes) {
        AggregateFunction aggregateFunction = TableUDFUtils.getAggregateFunction(functionName);
        FunctionArguments functionArguments = new FunctionArguments(UDFDataTypeTransformer.transformToUDFDataTypeList(inputDataTypes), inputAttributes);
        aggregateFunction.beforeStart(functionArguments);
        return new GroupedUserDefinedAggregateAccumulator(aggregateFunction, inputDataTypes.stream().map(TypeFactory::getType).collect(Collectors.toList()));
    }

    /**
     * Maps a built-in aggregation type to its grouped accumulator.
     *
     * @throws IllegalArgumentException for aggregation types without a grouped implementation
     */
    private static GroupedAccumulator createBuiltinGroupedAccumulator(TAggregationType aggregationType, List<TSDataType> inputDataTypes, List<Expression> inputExpressions, Map<String, String> inputAttributes, boolean ascending) {
        switch (aggregationType) {
            case COUNT:
                return new GroupedCountAccumulator();
            case COUNT_IF:
                return new GroupedCountIfAccumulator();
            case AVG:
                return new GroupedAvgAccumulator(inputDataTypes.get(0));
            case SUM:
                return new GroupedSumAccumulator(inputDataTypes.get(0));
            case LAST:
                return new GroupedLastAccumulator(inputDataTypes.get(0));
            case FIRST:
                return new GroupedFirstAccumulator(inputDataTypes.get(0));
            case MAX:
                return new GroupedMaxAccumulator(inputDataTypes.get(0));
            case MIN:
                return new GroupedMinAccumulator(inputDataTypes.get(0));
            case EXTREME:
                return new GroupedExtremeAccumulator(inputDataTypes.get(0));
            case LAST_BY:
                return new GroupedLastByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case FIRST_BY:
                return new GroupedFirstByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case MAX_BY:
                return new GroupedMaxByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case MIN_BY:
                return new GroupedMinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case MODE:
                return new GroupedModeAccumulator(inputDataTypes.get(0));
            case STDDEV: // STDDEV is treated as STDDEV_SAMP
            case STDDEV_SAMP:
                return new GroupedVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.STDDEV_SAMP);
            case STDDEV_POP:
                return new GroupedVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.STDDEV_POP);
            case VARIANCE: // VARIANCE is treated as VAR_SAMP
            case VAR_SAMP:
                return new GroupedVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.VAR_SAMP);
            case VAR_POP:
                return new GroupedVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.VAR_POP);
            case APPROX_COUNT_DISTINCT:
                return new GroupedApproxCountDistinctAccumulator(inputDataTypes.get(0));
            default:
                throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
        }
    }

    /**
     * Maps a built-in aggregation type to its non-grouped accumulator.
     *
     * <p>{@code ascending} selects the {@code *Desc} variants of LAST / FIRST / LAST_BY / FIRST_BY
     * when {@code false}. Accumulators created here are never told that an argument is the time
     * column (the corresponding flags are always {@code false}).
     *
     * @throws IllegalArgumentException for unsupported aggregation types
     */
    public static TableAccumulator createBuiltinAccumulator(TAggregationType aggregationType, List<TSDataType> inputDataTypes, List<Expression> inputExpressions, Map<String, String> inputAttributes, boolean ascending) {
        switch (aggregationType) {
            case COUNT:
                return new CountAccumulator();
            case COUNT_ALL:
                return new CountAllAccumulator();
            case COUNT_IF:
                return new CountIfAccumulator();
            case AVG:
                return new AvgAccumulator(inputDataTypes.get(0));
            case SUM:
                return new SumAccumulator(inputDataTypes.get(0));
            case LAST:
                return ascending
                        ? new LastAccumulator(inputDataTypes.get(0), false)
                        : new LastDescAccumulator(inputDataTypes.get(0), false);
            case FIRST:
                return ascending
                        ? new FirstAccumulator(inputDataTypes.get(0))
                        : new FirstDescAccumulator(inputDataTypes.get(0));
            case MAX:
                return new MaxAccumulator(inputDataTypes.get(0));
            case MIN:
                return new MinAccumulator(inputDataTypes.get(0));
            case LAST_BY:
                return ascending
                        ? new LastByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1), false, false)
                        : new LastByDescAccumulator(inputDataTypes.get(0), inputDataTypes.get(1), false, false);
            case FIRST_BY:
                return ascending
                        ? new FirstByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1), false, false)
                        : new FirstByDescAccumulator(inputDataTypes.get(0), inputDataTypes.get(1), false, false);
            case MAX_BY:
                return new TableMaxByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case MIN_BY:
                return new TableMinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case EXTREME:
                return new ExtremeAccumulator(inputDataTypes.get(0));
            case MODE:
                return new TableModeAccumulator(inputDataTypes.get(0));
            case STDDEV: // STDDEV is treated as STDDEV_SAMP
            case STDDEV_SAMP:
                return new TableVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.STDDEV_SAMP);
            case STDDEV_POP:
                return new TableVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.STDDEV_POP);
            case VARIANCE: // VARIANCE is treated as VAR_SAMP
            case VAR_SAMP:
                return new TableVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.VAR_SAMP);
            case VAR_POP:
                return new TableVarianceAccumulator(inputDataTypes.get(0), VarianceAccumulator.VarianceType.VAR_POP);
            case APPROX_COUNT_DISTINCT:
                return new ApproxCountDistinctAccumulator(inputDataTypes.get(0));
            default:
                throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
        }
    }

    /** Returns {@code true} for the aggregation types that take two inputs here: MAX_BY and MIN_BY. */
    public static boolean isMultiInputAggregation(TAggregationType aggregationType) {
        return aggregationType == TAggregationType.MAX_BY || aggregationType == TAggregationType.MIN_BY;
    }

    /**
     * Creates the accumulator for a two-input aggregation (MAX_BY / MIN_BY).
     *
     * @param aggregationType MAX_BY or MIN_BY
     * @param inputDataTypes exactly two data types: the value argument and the comparison argument
     * @return the accumulator
     * @throws IllegalStateException if {@code inputDataTypes} does not have exactly two entries
     * @throws IllegalArgumentException for any other aggregation type
     */
    public static TableAccumulator createBuiltinMultiInputAccumulator(TAggregationType aggregationType, List<TSDataType> inputDataTypes) {
        // Each case must return: previously the cases fell through into the throw below, so even
        // valid MAX_BY / MIN_BY requests always failed.
        switch (aggregationType) {
            case MAX_BY:
                Preconditions.checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
                return new TableMaxByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            case MIN_BY:
                Preconditions.checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
                return new TableMinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
            default:
                throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
        }
    }

    /**
     * Creates a single-input accumulator; only COUNT and AVG are handled.
     *
     * <p>NOTE(review): not referenced anywhere in this class -- confirm whether it is still needed.
     */
    private static TableAccumulator createBuiltinSingleInputAccumulator(TAggregationType aggregationType, TSDataType tsDataType, List<Expression> inputExpressions, Map<String, String> inputAttributes, boolean ascending) {
        switch (aggregationType) {
            case COUNT:
                return new CountAccumulator();
            case AVG:
                return new AvgAccumulator(tsDataType);
            default:
                throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
        }
    }

    /**
     * Non-grouped accumulator that feeds the wrapped accumulator only rows whose argument values have
     * not been seen before (per {@link MarkDistinctHash}). Only final evaluation is supported;
     * intermediate results, statistics push-down and copying throw.
     */
    private static final class DistinctAccumulator implements TableAccumulator {
        private final TableAccumulator accumulator;
        private final List<Type> inputTypes;
        // Rebuilt on reset() so distinctness starts fresh.
        private MarkDistinctHash hash;

        private DistinctAccumulator(TableAccumulator accumulator, List<Type> inputTypes) {
            this.accumulator = Objects.requireNonNull(accumulator, "accumulator is null");
            this.hash = new MarkDistinctHash(inputTypes, false, UpdateMemory.NOOP);
            this.inputTypes = inputTypes;
        }

        @Override
        public long getEstimatedSize() {
            return hash.getEstimatedSize() + accumulator.getEstimatedSize();
        }

        @Override
        public TableAccumulator copy() {
            throw new UnsupportedOperationException("Distinct aggregation function state can not be copied");
        }

        @Override
        public void addInput(Column[] arguments, AggregationMask mask) {
            // Nothing is selected, so nothing can be added; skip the hash entirely.
            if (mask.isSelectNone()) {
                return;
            }
            Column[] filtered = mask.filterBlock(arguments);
            Column distinctMask = hash.markDistinctRows(filtered);
            // The mask now describes the filtered columns: keep only first-seen rows.
            mask.reset(filtered[0].getPositionCount());
            mask.applyMaskBlock(distinctMask);
            if (mask.isSelectNone()) {
                return;
            }
            accumulator.addInput(filtered, mask);
        }

        @Override
        public void addIntermediate(Column argument) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateIntermediate(ColumnBuilder columnBuilder) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateFinal(ColumnBuilder columnBuilder) {
            accumulator.evaluateFinal(columnBuilder);
        }

        @Override
        public boolean hasFinalResult() {
            return accumulator.hasFinalResult();
        }

        @Override
        public void addStatistics(Statistics[] statistics) {
            throw new UnsupportedOperationException("Distinct aggregation function can not be push down");
        }

        @Override
        public void reset() {
            accumulator.reset();
            hash = new MarkDistinctHash(inputTypes, false, UpdateMemory.NOOP);
        }
    }

    /**
     * Grouped accumulator that feeds the wrapped accumulator only rows whose (group id, arguments)
     * combination has not been seen before. The group id is prepended as an INT32 key column so that
     * distinctness is tracked per group. Intermediate results are not supported.
     */
    private static final class DistinctGroupedAccumulator implements GroupedAccumulator {
        private final GroupedAccumulator accumulator;
        // INT32 group-id type followed by the argument types.
        private final List<Type> inputTypes;
        // Rebuilt on reset() so distinctness starts fresh.
        private MarkDistinctHash hash;

        private DistinctGroupedAccumulator(GroupedAccumulator accumulator, List<Type> inputTypes) {
            this.accumulator = Objects.requireNonNull(accumulator, "accumulator is null");
            this.inputTypes = ImmutableList.<Type>builder().add(IntType.INT32).addAll(inputTypes).build();
            this.hash = new MarkDistinctHash(this.inputTypes, false, UpdateMemory.NOOP);
        }

        @Override
        public long getEstimatedSize() {
            return hash.getEstimatedSize() + accumulator.getEstimatedSize();
        }

        @Override
        public void setGroupCount(long groupCount) {
            accumulator.setGroupCount(groupCount);
        }

        @Override
        public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
            // maskGroupIds leaves groupIds unnarrowed for an empty selection, so it would no longer line
            // up with the filtered argument columns. No rows are selected, so nothing is added either way.
            if (mask.isSelectNone()) {
                return;
            }
            int[] selectedGroupIds = maskGroupIds(groupIds, mask);
            Column[] filtered = mask.filterBlock(arguments);
            Column groupIdColumn = new IntColumn(selectedGroupIds.length, Optional.empty(), selectedGroupIds);
            Column[] keyColumns = Stream.concat(Stream.of(groupIdColumn), Arrays.stream(filtered)).toArray(Column[]::new);
            Column distinctMask = hash.markDistinctRows(keyColumns);
            // The mask now describes the filtered columns: keep only first-seen (group, arguments) rows.
            mask.reset(filtered[0].getPositionCount());
            mask.applyMaskBlock(distinctMask);
            if (mask.isSelectNone()) {
                return;
            }
            accumulator.addInput(selectedGroupIds, filtered, mask);
        }

        /** Returns the group ids of the selected positions, or {@code groupIds} itself when all or none are selected. */
        private static int[] maskGroupIds(int[] groupIds, AggregationMask mask) {
            if (mask.isSelectAll() || mask.isSelectNone()) {
                return groupIds;
            }
            int[] selectedPositions = mask.getSelectedPositions();
            int[] newGroupIds = new int[mask.getSelectedPositionCount()];
            for (int i = 0; i < newGroupIds.length; ++i) {
                newGroupIds[i] = groupIds[selectedPositions[i]];
            }
            return newGroupIds;
        }

        @Override
        public void addIntermediate(int[] groupIds, Column argument) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
            accumulator.evaluateFinal(groupId, columnBuilder);
        }

        @Override
        public void prepareFinal() {
            // Forward the lifecycle hook, as setGroupCount/evaluateFinal/reset are, so a wrapped
            // accumulator that prepares state before final evaluation still gets the call.
            accumulator.prepareFinal();
        }

        @Override
        public void reset() {
            accumulator.reset();
            hash = new MarkDistinctHash(inputTypes, false, UpdateMemory.NOOP);
        }
    }

    /**
     * Functional predicate over a {@code long} argument.
     *
     * <p>NOTE(review): not referenced within this class; its semantics are defined by its users.
     */
    @FunctionalInterface
    public interface KeepEvaluator {
        boolean apply(long value);
    }
}

