/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.planner.physical;

import java.time.Instant;
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor;
import shaded.com.google.common.collect.EvictingQueue;
import shaded.com.google.common.collect.ImmutableMap;

/**
 * Physical operator that appends simple-moving-average columns to each row of its input.
 *
 * <p>Every trendline computation reads one data field from the incoming row, feeds non-null
 * values into its own sliding-window accumulator, and emits the current average under the
 * computation's alias once the window is full. Until then the alias is absent from the row.
 * Input fields whose name collides with an alias are dropped from the output row.
 */
public class TrendlineOperator
extends PhysicalPlan {
    private final PhysicalPlan input;
    private final List<Pair<Trendline.TrendlineComputation, ExprCoreType>> computations;
    /** One accumulator per computation, index-aligned with {@link #computations}. */
    private final List<TrendlineAccumulator> accumulators;
    /**
     * Data field name -> indices of every computation reading that field. A list is required
     * because several computations may share one field (e.g. two window sizes over one column).
     */
    private final Map<String, List<Integer>> fieldToIndexMap;
    /** Output names produced by the computations. */
    private final HashSet<String> aliases;

    public TrendlineOperator(PhysicalPlan input, List<Pair<Trendline.TrendlineComputation, ExprCoreType>> computations) {
        this.input = input;
        this.computations = computations;
        this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList();
        this.fieldToIndexMap = new HashMap<>(computations.size());
        this.aliases = new HashSet<>(computations.size());
        for (int i = 0; i < computations.size(); ++i) {
            Trendline.TrendlineComputation computation = computations.get(i).getKey();
            String dataField = computation.getDataField().getChild().get(0).toString();
            // Keep every index: a plain field -> index map would silently drop all but the
            // last computation over a shared field, starving the others of input.
            this.fieldToIndexMap.computeIfAbsent(dataField, key -> new ArrayList<>()).add(i);
            this.aliases.add(computation.getAlias());
        }
    }

    @Override
    public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor2, C context) {
        return visitor2.visitTrendline(this, context);
    }

    @Override
    public List<PhysicalPlan> getChild() {
        return Collections.singletonList(this.input);
    }

    @Override
    public boolean hasNext() {
        return this.input.hasNext();
    }

    /**
     * Pulls the next input row, updates every accumulator with it and returns the row extended
     * with each moving average that is currently available.
     */
    @Override
    public ExprValue next() {
        Map<String, ExprValue> inputStruct = this.consumeInputTuple(this.input.next());
        ImmutableMap.Builder<String, ExprValue> mapBuilder = new ImmutableMap.Builder<>();
        for (Map.Entry<String, ExprValue> entry : inputStruct.entrySet()) {
            // Input fields sharing a name with an alias are dropped (without mutating the input
            // row) so the alias is absent, not stale, while its window is still filling.
            if (!this.aliases.contains(entry.getKey())) {
                mapBuilder.put(entry.getKey(), entry.getValue());
            }
        }
        for (int i = 0; i < this.accumulators.size(); ++i) {
            ExprValue calculateResult = this.accumulators.get(i).calculate();
            if (calculateResult != null) {
                mapBuilder.put(this.computations.get(i).getKey().getAlias(), calculateResult);
            }
        }
        // buildKeepingLast: if two computations share an alias, the later one wins.
        return ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast());
    }

    /**
     * Feeds the non-null values of tracked fields in {@code inputValue} to their accumulators.
     *
     * @param inputValue the incoming row; expected to be a tuple value
     * @return the row's field map, unmodified
     */
    private Map<String, ExprValue> consumeInputTuple(ExprValue inputValue) {
        Map<String, ExprValue> tupleValue = ExprValueUtils.getTupleValue(inputValue);
        for (Map.Entry<String, ExprValue> entry : tupleValue.entrySet()) {
            List<Integer> indices = this.fieldToIndexMap.get(entry.getKey());
            ExprValue fieldValue = entry.getValue();
            if (indices == null || fieldValue.isNull()) {
                continue;
            }
            for (int index : indices) {
                this.accumulators.get(index).accumulate(fieldValue);
            }
        }
        return tupleValue;
    }

    private static TrendlineAccumulator createAccumulator(Pair<Trendline.TrendlineComputation, ExprCoreType> computation) {
        return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue());
    }

    @Generated
    public String toString() {
        return "TrendlineOperator(input=" + String.valueOf(this.getInput()) + ", computations=" + String.valueOf(this.getComputations()) + ", accumulators=" + String.valueOf(this.accumulators) + ", fieldToIndexMap=" + String.valueOf(this.fieldToIndexMap) + ", aliases=" + String.valueOf(this.aliases) + ")";
    }

    /** Equality is defined by the input plan and the computations only. */
    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TrendlineOperator)) {
            return false;
        }
        TrendlineOperator other = (TrendlineOperator)o;
        if (!other.canEqual(this)) {
            return false;
        }
        PhysicalPlan this$input = this.getInput();
        PhysicalPlan other$input = other.getInput();
        if (this$input == null ? other$input != null : !this$input.equals(other$input)) {
            return false;
        }
        List<Pair<Trendline.TrendlineComputation, ExprCoreType>> this$computations = this.getComputations();
        List<Pair<Trendline.TrendlineComputation, ExprCoreType>> other$computations = other.getComputations();
        return !(this$computations == null ? other$computations != null : !((Object)this$computations).equals(other$computations));
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof TrendlineOperator;
    }

    @Generated
    public int hashCode() {
        int result2 = 1;
        PhysicalPlan $input = this.getInput();
        result2 = result2 * 59 + ($input == null ? 43 : $input.hashCode());
        List<Pair<Trendline.TrendlineComputation, ExprCoreType>> $computations = this.getComputations();
        result2 = result2 * 59 + ($computations == null ? 43 : ((Object)$computations).hashCode());
        return result2;
    }

    @Generated
    public PhysicalPlan getInput() {
        return this.input;
    }

    @Generated
    public List<Pair<Trendline.TrendlineComputation, ExprCoreType>> getComputations() {
        return this.computations;
    }

    /** Stateful per-computation aggregator over the values of one field. */
    private static interface TrendlineAccumulator {
        /** Adds one non-null data point. */
        public void accumulate(ExprValue var1);

        /** Returns the current result, or {@code null} when not enough points have been seen. */
        public ExprValue calculate();

        /**
         * Selects the arithmetic used to average values of {@code type2}.
         *
         * @throws IllegalArgumentException if the type is not DOUBLE, DATE, TIME or TIMESTAMP
         */
        public static ArithmeticEvaluator getEvaluator(ExprCoreType type2) {
            return switch (type2) {
                case DOUBLE -> NumericArithmeticEvaluator.INSTANCE;
                case DATE -> DateArithmeticEvaluator.INSTANCE;
                case TIME -> TimeArithmeticEvaluator.INSTANCE;
                case TIMESTAMP -> TimestampArithmeticEvaluator.INSTANCE;
                default -> throw new IllegalArgumentException(String.format("Invalid type %s used for moving average.", type2.typeName()));
            };
        }
    }

    /**
     * Simple moving average over the last {@code numberOfDataPoints} values, kept as a running
     * total that is updated incrementally (add incoming, subtract evicted) once the window fills.
     */
    private static class SimpleMovingAverageAccumulator
    implements TrendlineAccumulator {
        /** Window size as a double literal, used as the divisor. */
        private final LiteralExpression dataPointsNeeded;
        /** Window size as a plain int, for size comparisons. */
        private final int windowSize;
        private final EvictingQueue<ExprValue> receivedValues;
        private final ArithmeticEvaluator evaluator;
        /** Sum of the values in the full window; null until the window first fills. */
        private Expression runningTotal = null;

        public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation, ExprCoreType type2) {
            this.windowSize = computation.getNumberOfDataPoints().intValue();
            this.dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue());
            this.receivedValues = EvictingQueue.create(this.windowSize);
            this.evaluator = TrendlineAccumulator.getEvaluator(type2);
        }

        @Override
        public void accumulate(ExprValue value) {
            if (this.windowSize == 1) {
                // A one-point window is just the latest value; calculate() reads it from the queue,
                // so no running total is needed.
                this.receivedValues.add(value);
                return;
            }
            // Capture the value about to leave the window so it can be subtracted from the total.
            ExprValue valueToRemove = this.receivedValues.size() == this.windowSize ? this.receivedValues.remove() : null;
            this.receivedValues.add(value);
            if (this.receivedValues.size() == this.windowSize) {
                if (this.runningTotal != null) {
                    this.runningTotal = this.evaluator.add(this.runningTotal, value, valueToRemove);
                } else {
                    this.runningTotal = this.evaluator.calculateFirstTotal(this.receivedValues.stream().toList());
                }
            }
        }

        @Override
        public ExprValue calculate() {
            if (this.receivedValues.size() < this.windowSize) {
                return null;
            }
            if (this.windowSize == 1) {
                return this.receivedValues.peek();
            }
            return this.evaluator.evaluate(this.runningTotal, this.dataPointsNeeded);
        }
    }

    /** Averages timestamps via their epoch-millisecond values (sub-millisecond precision is lost). */
    private static class TimestampArithmeticEvaluator
    implements ArithmeticEvaluator {
        private static final TimestampArithmeticEvaluator INSTANCE = new TimestampArithmeticEvaluator();

        private TimestampArithmeticEvaluator() {
        }

        @Override
        public Expression calculateFirstTotal(List<ExprValue> dataPoints) {
            Expression total = DSL.literal(0);
            for (ExprValue dataPoint : dataPoints) {
                total = DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli()));
            }
            return DSL.literal(total.valueOf().longValue());
        }

        @Override
        public Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) {
            return DSL.literal(DSL.add(runningTotal, DSL.subtract(DSL.literal(incomingValue.timestampValue().toEpochMilli()), DSL.literal(evictedValue.timestampValue().toEpochMilli()))).valueOf());
        }

        @Override
        public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) {
            return ExprValueUtils.timestampValue(Instant.ofEpochMilli(DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue()));
        }
    }

    /** Averages times of day via their millisecond offset from midnight. */
    private static class TimeArithmeticEvaluator
    implements ArithmeticEvaluator {
        private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator();

        private TimeArithmeticEvaluator() {
        }

        @Override
        public Expression calculateFirstTotal(List<ExprValue> dataPoints) {
            Expression total = DSL.literal(0);
            for (ExprValue dataPoint : dataPoints) {
                total = DSL.add(total, DSL.literal(ChronoUnit.MILLIS.between(LocalTime.MIN, dataPoint.timeValue())));
            }
            return DSL.literal(total.valueOf().longValue());
        }

        @Override
        public Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) {
            return DSL.literal(DSL.add(runningTotal, DSL.subtract(DSL.literal(ChronoUnit.MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), DSL.literal(ChronoUnit.MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))).valueOf());
        }

        @Override
        public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) {
            return ExprValueUtils.timeValue(LocalTime.MIN.plus(DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), ChronoUnit.MILLIS));
        }
    }

    /** Averages dates by delegating to timestamp arithmetic and truncating the result to a date. */
    private static class DateArithmeticEvaluator
    implements ArithmeticEvaluator {
        private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator();

        private DateArithmeticEvaluator() {
        }

        @Override
        public Expression calculateFirstTotal(List<ExprValue> dataPoints) {
            return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints);
        }

        @Override
        public Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) {
            return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue);
        }

        @Override
        public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) {
            ExprValue timestampResult = TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints);
            return ExprValueUtils.dateValue(timestampResult.dateValue());
        }
    }

    /** Averages numeric values as doubles. */
    private static class NumericArithmeticEvaluator
    implements ArithmeticEvaluator {
        private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator();

        private NumericArithmeticEvaluator() {
        }

        @Override
        public Expression calculateFirstTotal(List<ExprValue> dataPoints) {
            Expression total = DSL.literal(0.0);
            for (ExprValue dataPoint : dataPoints) {
                total = DSL.add(total, DSL.literal(dataPoint.doubleValue()));
            }
            return DSL.literal(total.valueOf().doubleValue());
        }

        @Override
        public Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) {
            return DSL.literal(DSL.add(runningTotal, DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))).valueOf().doubleValue());
        }

        @Override
        public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) {
            return DSL.divide(runningTotal, numberOfDataPoints).valueOf();
        }
    }

    /** Type-specific arithmetic needed to maintain and evaluate a moving-average running total. */
    private static interface ArithmeticEvaluator {
        /** Sums a full window of data points from scratch. */
        public Expression calculateFirstTotal(List<ExprValue> var1);

        /** Updates the running total: adds the incoming value and subtracts the evicted one. */
        public Expression add(Expression var1, ExprValue var2, ExprValue var3);

        /** Divides the running total by the window size and converts back to the value type. */
        public ExprValue evaluate(Expression var1, LiteralExpression var2);
    }
}

