/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.ml;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.reflect.TypeToken;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import java.io.IOException;
import java.lang.reflect.Type;
import java.security.AccessController;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneOffset;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.ad.indices.ADIndex;
import org.opensearch.ad.indices.ADIndexManagement;
import org.opensearch.ad.ml.ThresholdingModel;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
import org.opensearch.timeseries.model.Entity;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.transport.client.Client;

public class ADCheckpointDao
extends CheckpointDao<ThresholdedRandomCutForest, ADIndex, ADIndexManagement> {
    private static final Logger logger = LogManager.getLogger(ADCheckpointDao.class);
    public static final String ENTITY_RCF = "rcf";
    public static final String ENTITY_THRESHOLD = "th";
    public static final String ENTITY_TRCF = "trcf";
    public static final String FIELD_MODELV2 = "modelV2";
    public static final String DETECTOR_ID = "detectorId";
    private RandomCutForestMapper mapper;
    private V1JsonToV3StateConverter converter;
    private ThresholdedRandomCutForestMapper trcfMapper;
    private Schema<ThresholdedRandomCutForestState> trcfSchema;
    private final Class<? extends ThresholdingModel> thresholdingModelClass;
    private final ADIndexManagement indexUtil;
    private final JsonParser parser = new JsonParser();
    private double anomalyRate;
    private final Type doubleArrayType;

    public ADCheckpointDao(Client client, ClientUtil clientUtil, Gson gson, RandomCutForestMapper mapper, V1JsonToV3StateConverter converter, ThresholdedRandomCutForestMapper trcfMapper, Schema<ThresholdedRandomCutForestState> trcfSchema, Class<? extends ThresholdingModel> thresholdingModelClass, ADIndexManagement indexUtil, int maxCheckpointBytes, GenericObjectPool<LinkedBuffer> serializeRCFBufferPool, int serializeRCFBufferSize, double anomalyRate, Clock clock) {
        super(client, clientUtil, ".opendistro-anomaly-checkpoints", gson, maxCheckpointBytes, serializeRCFBufferPool, serializeRCFBufferSize, indexUtil, clock);
        this.mapper = mapper;
        this.converter = converter;
        this.trcfMapper = trcfMapper;
        this.trcfSchema = trcfSchema;
        this.thresholdingModelClass = thresholdingModelClass;
        this.indexUtil = indexUtil;
        this.anomalyRate = anomalyRate;
        this.doubleArrayType = new TypeToken<double[][]>(this){}.getType();
    }

    public void putTRCFCheckpoint(String modelId, ThresholdedRandomCutForest forest, ActionListener<Void> listener) {
        HashMap<String, Object> source = new HashMap<String, Object>();
        String modelCheckpoint = this.toCheckpoint(forest);
        if (modelCheckpoint != null) {
            source.put(FIELD_MODELV2, modelCheckpoint);
            source.put("timestamp", this.clock.instant().atZone(ZoneOffset.UTC));
            this.putModelCheckpoint(modelId, source, listener);
        } else {
            listener.onFailure((Exception)new RuntimeException("Fail to create checkpoint to save"));
        }
    }

    public void putThresholdCheckpoint(String modelId, ThresholdingModel threshold, ActionListener<Void> listener) {
        String modelCheckpoint = AccessController.doPrivileged(() -> this.gson.toJson((Object)threshold));
        HashMap<String, Object> source = new HashMap<String, Object>();
        source.put("model", modelCheckpoint);
        source.put("timestamp", this.clock.instant().atZone(ZoneOffset.UTC));
        this.putModelCheckpoint(modelId, source, listener);
    }

    @Override
    public Map<String, Object> toIndexSource(ModelState<ThresholdedRandomCutForest> modelState) throws IOException {
        Optional<Sample[]> samples;
        String modelId = modelState.getModelId();
        HashMap<String, Object> source = new HashMap<String, Object>();
        Optional<ThresholdedRandomCutForest> model = modelState.getModel();
        if (model.isPresent()) {
            ThresholdedRandomCutForest entityModel = model.get();
            Optional<String> serializedModel = this.toCheckpoint(entityModel, modelId);
            if (!serializedModel.isPresent() || serializedModel.get().length() > this.maxCheckpointBytes) {
                logger.warn((Message)new ParameterizedMessage("[{}]'s model is empty or too large: [{}] bytes", (Object)modelState.getModelId(), (Object)(serializedModel.isPresent() ? serializedModel.get().length() : 0)));
                return source;
            }
            source.put(FIELD_MODELV2, serializedModel.get());
        }
        if ((samples = this.toCheckpoint(modelState.getSamples())).isPresent()) {
            source.put("samples", samples.get());
        }
        if (!source.containsKey("samples") && !source.containsKey(FIELD_MODELV2)) {
            return source;
        }
        String detectorId = modelState.getConfigId();
        source.put(DETECTOR_ID, detectorId);
        source.put("timestamp", this.clock.instant().atZone(ZoneOffset.UTC));
        source.put("schema_version", this.indexUtil.getSchemaVersion(ADIndex.CHECKPOINT));
        Optional<Entity> entity = modelState.getEntity();
        if (entity.isPresent()) {
            source.put("entity", entity.get());
        }
        return source;
    }

    public Optional<String> toCheckpoint(ThresholdedRandomCutForest model, String modelId) {
        return AccessController.doPrivileged(() -> {
            if (model == null) {
                logger.warn("Empty model");
                return Optional.empty();
            }
            try {
                JsonObject json = new JsonObject();
                if (model != null) {
                    json.addProperty(ENTITY_TRCF, this.toCheckpoint(model));
                }
                return json.entrySet().isEmpty() ? Optional.empty() : Optional.ofNullable(this.gson.toJson((JsonElement)json));
            }
            catch (Exception ex) {
                logger.warn((Message)new ParameterizedMessage("fail to generate checkpoint for [{}]", (Object)modelId), (Throwable)ex);
                return Optional.empty();
            }
        });
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    String toCheckpoint(ThresholdedRandomCutForest trcf) {
        String checkpoint;
        block16: {
            checkpoint = null;
            Map.Entry<LinkedBuffer, Boolean> result = this.checkoutOrNewBuffer();
            LinkedBuffer buffer = result.getKey();
            boolean needCheckin = result.getValue();
            try {
                checkpoint = this.toCheckpoint(trcf, buffer);
            }
            catch (Exception e) {
                logger.error("Failed to serialize model", (Throwable)e);
                if (!needCheckin) break block16;
                try {
                    this.serializeRCFBufferPool.invalidateObject((Object)buffer);
                    needCheckin = false;
                }
                catch (Exception x) {
                    logger.warn("Failed to invalidate buffer", (Throwable)x);
                }
                try {
                    checkpoint = this.toCheckpoint(trcf, LinkedBuffer.allocate((int)this.serializeRCFBufferSize));
                }
                catch (Exception ex) {
                    logger.warn("Failed to generate checkpoint", (Throwable)ex);
                }
            }
            finally {
                if (needCheckin) {
                    try {
                        this.serializeRCFBufferPool.returnObject((Object)buffer);
                    }
                    catch (Exception e) {
                        logger.warn("Failed to return buffer to pool", (Throwable)e);
                    }
                }
            }
        }
        return checkpoint;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) {
        try {
            byte[] bytes = AccessController.doPrivileged(() -> {
                ThresholdedRandomCutForestState trcfState = this.trcfMapper.toState(trcf);
                return ProtostuffIOUtil.toByteArray((Object)trcfState, this.trcfSchema, (LinkedBuffer)buffer);
            });
            String string = Base64.getEncoder().encodeToString(bytes);
            return string;
        }
        finally {
            buffer.clear();
        }
    }

    @Override
    protected ModelState<ThresholdedRandomCutForest> fromEntityModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) {
        try {
            return AccessController.doPrivileged(() -> {
                Object modelObj = checkpoint.get(FIELD_MODELV2);
                if (modelObj == null) {
                    modelObj = checkpoint.get("model");
                }
                if (modelObj == null) {
                    logger.warn((Message)new ParameterizedMessage("Empty model for [{}]", (Object)modelId));
                    return null;
                }
                String model = (String)modelObj;
                if (model.length() > this.maxCheckpointBytes) {
                    logger.warn((Message)new ParameterizedMessage("[{}]'s model too large: [{}] bytes", (Object)modelId, (Object)model.length()));
                    return null;
                }
                JsonObject json = this.parser.parse(model).getAsJsonObject();
                ThresholdedRandomCutForest trcf = null;
                if (json.has(ENTITY_TRCF)) {
                    trcf = this.toTrcf(json.getAsJsonPrimitive(ENTITY_TRCF).getAsString());
                } else {
                    Optional<ThresholdedRandomCutForest> convertedTRCF;
                    Optional<Object> rcf = Optional.empty();
                    Optional<ThresholdingModel> threshold = Optional.empty();
                    if (json.has(ENTITY_RCF)) {
                        String serializedRCF = json.getAsJsonPrimitive(ENTITY_RCF).getAsString();
                        rcf = this.deserializeRCFModel(serializedRCF, modelId);
                    }
                    if (json.has(ENTITY_THRESHOLD)) {
                        threshold = Optional.ofNullable((ThresholdingModel)this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), this.thresholdingModelClass));
                    }
                    if (rcf.isPresent() && (convertedTRCF = this.convertToTRCF((RandomCutForest)rcf.get(), threshold)).isPresent()) {
                        trcf = convertedTRCF.get();
                    }
                }
                Deque<Sample> sampleQueue = this.processSampleQueue(json, checkpoint, modelId);
                String lastCheckpointTimeString = (String)checkpoint.get("timestamp");
                Instant timestamp = Instant.parse(lastCheckpointTimeString);
                Entity entity = null;
                Object serializedEntity = checkpoint.get("entity");
                if (serializedEntity != null) {
                    try {
                        entity = Entity.fromJsonArray(serializedEntity);
                    }
                    catch (Exception e) {
                        logger.error((Message)new ParameterizedMessage("fail to parse entity", serializedEntity), (Throwable)e);
                    }
                }
                ModelState<ThresholdedRandomCutForest> modelState = new ModelState<ThresholdedRandomCutForest>(trcf, modelId, configId, ModelManager.ModelType.TRCF.getName(), this.clock, 0.0f, Optional.ofNullable(entity), sampleQueue);
                modelState.setLastCheckpointTime(timestamp);
                return modelState;
            });
        }
        catch (Exception e) {
            logger.warn("Exception while deserializing checkpoint " + modelId, (Throwable)e);
            return null;
        }
    }

    private Deque<Sample> processSampleQueue(JsonObject json, Map<String, Object> checkpoint, String modelId) {
        ArrayDeque<Sample> sampleQueue = new ArrayDeque();
        if (json.has("sp")) {
            double[][] samplesArray = (double[][])this.gson.fromJson((JsonElement)json.getAsJsonArray("sp"), this.doubleArrayType);
            Arrays.stream(samplesArray).map(sampleArray -> new Sample((double[])sampleArray, Instant.ofEpochMilli(0L), Instant.ofEpochMilli(0L))).forEach(sampleQueue::add);
        } else {
            sampleQueue = this.loadSampleQueue(checkpoint, modelId);
        }
        return sampleQueue;
    }

    ThresholdedRandomCutForest toTrcf(String checkpoint) {
        ThresholdedRandomCutForest trcf = null;
        if (checkpoint != null && !checkpoint.isEmpty()) {
            try {
                byte[] bytes = Base64.getDecoder().decode(checkpoint);
                ThresholdedRandomCutForestState state = (ThresholdedRandomCutForestState)this.trcfSchema.newMessage();
                AccessController.doPrivileged(() -> {
                    ProtostuffIOUtil.mergeFrom((byte[])bytes, (Object)state, this.trcfSchema);
                    return null;
                });
                trcf = (ThresholdedRandomCutForest)this.trcfMapper.toModel((Object)state);
            }
            catch (RuntimeException e) {
                logger.info("checkpoint to restore: " + checkpoint);
                logger.error("Failed to deserialize TRCF model", (Throwable)e);
            }
        }
        return trcf;
    }

    private Optional<RandomCutForest> deserializeRCFModel(String checkpoint, String modelId) {
        if (checkpoint == null || checkpoint.isEmpty()) {
            return Optional.empty();
        }
        return Optional.ofNullable(AccessController.doPrivileged(() -> {
            try {
                RandomCutForestState state = this.converter.convert(checkpoint, Precision.FLOAT_32);
                return this.mapper.toModel(state);
            }
            catch (Exception e) {
                logger.error("Unexpected error when deserializing " + modelId, (Throwable)e);
                return null;
            }
        }));
    }

    private void deserializeTRCFModel(GetResponse response, String rcfModelId, ActionListener<Optional<ThresholdedRandomCutForest>> listener) {
        block6: {
            Object model = null;
            if (response.isExists()) {
                try {
                    model = response.getSource().get(FIELD_MODELV2);
                    if (model != null) {
                        listener.onResponse(Optional.ofNullable(this.toTrcf(model)));
                        break block6;
                    }
                    Object modelV1 = response.getSource().get("model");
                    Optional<RandomCutForest> forest = this.deserializeRCFModel((String)modelV1, rcfModelId);
                    if (!forest.isPresent()) {
                        logger.error("Unexpected error when deserializing [{}]", (Object)rcfModelId);
                        listener.onResponse(Optional.empty());
                        return;
                    }
                    String thresholdingModelId = SingleStreamModelIdMapper.getThresholdModelIdFromRCFModelId(rcfModelId);
                    this.getThresholdModel(thresholdingModelId, (ActionListener<Optional<ThresholdingModel>>)ActionListener.wrap(thresholdingModel -> listener.onResponse(this.convertToTRCF((RandomCutForest)forest.get(), (Optional<ThresholdingModel>)thresholdingModel)), arg_0 -> listener.onFailure(arg_0)));
                }
                catch (Exception e) {
                    logger.error((Message)new ParameterizedMessage("Unexpected error when deserializing [{}]", (Object)rcfModelId), (Throwable)e);
                    listener.onResponse(Optional.empty());
                }
            } else {
                listener.onResponse(Optional.empty());
            }
        }
    }

    @Override
    protected ModelState<ThresholdedRandomCutForest> fromSingleStreamModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) {
        throw new UnsupportedOperationException("This method is not supported");
    }

    public void getTRCFModel(String modelId, ActionListener<Optional<ThresholdedRandomCutForest>> listener) {
        this.clientUtil.asyncRequest(new GetRequest(this.indexName, modelId), (arg_0, arg_1) -> ((Client)this.client).get(arg_0, arg_1), ActionListener.wrap(response -> this.deserializeTRCFModel((GetResponse)response, modelId, listener), exception -> {
            if (exception instanceof IndexNotFoundException) {
                listener.onResponse(Optional.empty());
            } else {
                listener.onFailure(exception);
            }
        }));
    }

    public void getThresholdModel(String modelId, ActionListener<Optional<ThresholdingModel>> listener) {
        this.clientUtil.asyncRequest(new GetRequest(this.indexName, modelId), (arg_0, arg_1) -> ((Client)this.client).get(arg_0, arg_1), ActionListener.wrap(response -> {
            Optional<Object> thresholdCheckpoint = this.processThresholdModelCheckpoint((GetResponse)response);
            if (!thresholdCheckpoint.isPresent()) {
                listener.onFailure((Exception)new ResourceNotFoundException("", "Fail to find model " + modelId));
                return;
            }
            Optional<ThresholdingModel> model = thresholdCheckpoint.map(checkpoint -> AccessController.doPrivileged(() -> (ThresholdingModel)this.gson.fromJson((String)checkpoint, this.thresholdingModelClass)));
            listener.onResponse(model);
        }, exception -> {
            if (exception instanceof IndexNotFoundException) {
                listener.onResponse(Optional.empty());
            } else {
                listener.onFailure(exception);
            }
        }));
    }

    private Optional<Object> processThresholdModelCheckpoint(GetResponse response) {
        return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource).map(source -> source.get("model"));
    }

    private Optional<ThresholdedRandomCutForest> convertToTRCF(RandomCutForest rcf, Optional<ThresholdingModel> kllThreshold) {
        if (rcf == null) {
            return Optional.empty();
        }
        List<Object> scores = new ArrayList();
        if (kllThreshold.isPresent()) {
            scores = kllThreshold.get().extractScores();
        }
        return Optional.of(new ThresholdedRandomCutForest(rcf, this.anomalyRate, scores, new double[rcf.getDimensions()]));
    }

    @Override
    protected DeleteByQueryRequest createDeleteCheckpointRequest(String detectorId) {
        return (DeleteByQueryRequest)((DeleteByQueryRequest)new DeleteByQueryRequest(new String[]{this.indexName}).setQuery((QueryBuilder)new MatchQueryBuilder(DETECTOR_ID, (Object)detectorId)).setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN).setAbortOnVersionConflict(false)).setRequestsPerSecond(500.0f);
    }
}

