/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.memory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.client.Requests;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.memory.action.conversation.CreateConversationAction;
import org.opensearch.ml.memory.action.conversation.CreateConversationRequest;
import org.opensearch.ml.memory.action.conversation.CreateConversationResponse;
import org.opensearch.ml.memory.action.conversation.CreateInteractionAction;
import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
import org.opensearch.ml.memory.action.conversation.GetTracesAction;
import org.opensearch.ml.memory.action.conversation.GetTracesRequest;
import org.opensearch.ml.memory.action.conversation.UpdateInteractionAction;
import org.opensearch.ml.memory.action.conversation.UpdateInteractionRequest;
import org.opensearch.ml.memory.index.ConversationMetaIndex;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

/**
 * Entry point for reading and writing conversational memory (conversations, interactions and traces).
 *
 * <p>All public operations are asynchronous: results are delivered through the supplied
 * {@link ActionListener}. Exceptions thrown while dispatching a transport request are caught and
 * forwarded to the listener's {@code onFailure}. The {@link Preconditions} null checks are the
 * exception to this: they throw directly to the caller.
 */
public class MLMemoryManager {
    @Generated
    private static final Logger log = LogManager.getLogger(MLMemoryManager.class);

    /** Thread-context transient key read to identify the calling user in access-denied messages. */
    private static final String SECURITY_USER_INFO_KEY = "_opendistro_security_user_info";

    private final Client client;
    private final ClusterService clusterService;
    private final ConversationMetaIndex conversationMetaIndex;

    @Generated
    public MLMemoryManager(Client client, ClusterService clusterService, ConversationMetaIndex conversationMetaIndex) {
        this.client = client;
        this.clusterService = clusterService;
        this.conversationMetaIndex = conversationMetaIndex;
    }

    /**
     * Creates a conversation by executing {@link CreateConversationAction}.
     *
     * @param name            conversation name
     * @param applicationType application type passed through to the create request
     * @param actionListener  receives the create response, or any failure
     */
    public void createConversation(String name, String applicationType, ActionListener<CreateConversationResponse> actionListener) {
        try {
            client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name, applicationType), actionListener);
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Stores an interaction (or a trace, when {@code parentIntId}/{@code traceNum} are given) by executing
     * {@link CreateInteractionAction}.
     *
     * @param conversationId  id of the owning conversation; must not be null
     * @param input           interaction input; must not be null
     * @param promptTemplate  prompt template passed through to the request
     * @param response        interaction response; must not be null
     * @param origin          origin passed through to the request
     * @param additionalInfo  extra key/value data; {@code null} is replaced by an empty map
     * @param parentIntId     parent interaction id passed through to the request
     * @param traceNum        trace number passed through to the request
     * @param actionListener  receives the create response, or any failure raised while dispatching
     * @throws NullPointerException if {@code conversationId}, {@code input} or {@code response} is null
     */
    public void createInteraction(
        String conversationId,
        String input,
        String promptTemplate,
        String response,
        String origin,
        Map<String, String> additionalInfo,
        String parentIntId,
        Integer traceNum,
        ActionListener<CreateInteractionResponse> actionListener
    ) {
        Preconditions.checkNotNull(conversationId);
        Preconditions.checkNotNull(input);
        Preconditions.checkNotNull(response);
        // Never send a null map in the request; a mutable map is used as in the original code.
        Map<String, String> info = additionalInfo == null ? new HashMap<>() : additionalInfo;
        try {
            client.execute(
                CreateInteractionAction.INSTANCE,
                new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, info, parentIntId, traceNum),
                actionListener
            );
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Returns the last {@code lastNInteraction} final (non-trace) interactions of a conversation, oldest first.
     *
     * <p>If the interactions index does not exist, the listener immediately receives an empty list.
     * Otherwise access to the conversation is checked first. When access is denied, the listener fails
     * with an {@link OpenSearchSecurityException}.
     *
     * @param conversationId   conversation to read
     * @param lastNInteraction maximum number of interactions to return
     * @param actionListener   receives the interactions, or any failure
     */
    public void getFinalInteractions(String conversationId, int lastNInteraction, ActionListener<List<Interaction>> actionListener) {
        log.debug("Getting Interactions, conversationId {}, lastN {}", conversationId, lastNInteraction);
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
            if (!clusterService.state().metadata().hasIndex(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME)) {
                // Without the interactions index there is nothing to read.
                actionListener.onResponse(List.of());
                return;
            }
            ActionListener<Boolean> accessListener = ActionListener.wrap(access -> {
                if (!access) {
                    // Thrown inside wrap(), so it is routed to actionListener.onFailure.
                    throw new OpenSearchSecurityException(
                        "User [" + currentUserName() + "] does not have access to conversation " + conversationId
                    );
                }
                innerGetFinalInteractions(conversationId, lastNInteraction, actionListener);
            }, actionListener::onFailure);
            conversationMetaIndex.checkAccess(conversationId, accessListener);
        } catch (Exception e) {
            // Parameterized form: same message text, and the throwable is still logged with its stack trace.
            log.error("Failed to get final interactions for conversation {}", conversationId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Returns the name of the user stored in the thread context, or {@code ""} if none can be parsed.
     */
    private String currentUserName() {
        String userStr = client.threadPool().getThreadContext().getTransient(SECURITY_USER_INFO_KEY);
        User user = User.parse(userStr);
        return user == null ? "" : user.getName();
    }

    /**
     * Searches the interactions index for the newest {@code lastNInteraction} interactions of
     * {@code conversationId} that have no {@code trace_number} field. The listener receives them
     * oldest first.
     *
     * <p>The search runs with a stashed thread context. That context is restored before
     * {@code listener} is notified.
     */
    @VisibleForTesting
    void innerGetFinalInteractions(String conversationId, int lastNInteraction, ActionListener<List<Interaction>> listener) {
        // Traces carry a trace_number, so excluding documents that have one leaves only final interactions.
        BoolQueryBuilder query = QueryBuilders.boolQuery()
            .mustNot(QueryBuilders.existsQuery("trace_number"))
            .must(QueryBuilders.termQuery("memory_id", conversationId));
        SearchRequest searchRequest = Requests.searchRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME);
        // Newest first, so the size limit keeps the most recent N interactions.
        searchRequest.source(new SearchSourceBuilder().query(query).size(lastNInteraction).sort("create_time", SortOrder.DESC));
        try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<List<Interaction>> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
            client.search(searchRequest, ActionListener.wrap(response -> {
                // Prepend each hit to turn the newest-first hits into oldest-first order.
                LinkedList<Interaction> result = new LinkedList<>();
                for (SearchHit hit : response.getHits()) {
                    result.addFirst(Interaction.fromSearchHit(hit));
                }
                internalListener.onResponse(result);
            }, internalListener::onFailure));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Fetches the traces recorded under a parent interaction by executing {@link GetTracesAction}.
     *
     * @param parentInteractionId id of the parent interaction; must not be null
     * @param actionListener      receives the traces, or any failure
     * @throws NullPointerException if {@code parentInteractionId} is null
     */
    public void getTraces(String parentInteractionId, ActionListener<List<Interaction>> actionListener) {
        Preconditions.checkNotNull(parentInteractionId);
        // The id is a parent interaction id, not a conversation id.
        log.debug("Getting traces for parent interaction ID {}", parentInteractionId);
        try {
            client.execute(
                GetTracesAction.INSTANCE,
                new GetTracesRequest(parentInteractionId),
                ActionListener.wrap(tracesResponse -> actionListener.onResponse(tracesResponse.getTraces()), actionListener::onFailure)
            );
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Updates fields of an interaction by executing {@link UpdateInteractionAction}.
     *
     * @param interactionId  id of the interaction to update; must not be null
     * @param updateContent  field values to apply; must not be null
     * @param actionListener receives the update response, or any failure raised while dispatching
     * @throws NullPointerException if {@code interactionId} or {@code updateContent} is null
     */
    public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> actionListener) {
        Preconditions.checkNotNull(interactionId);
        Preconditions.checkNotNull(updateContent);
        try {
            client.execute(UpdateInteractionAction.INSTANCE, new UpdateInteractionRequest(interactionId, updateContent), actionListener);
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Deletes an interaction together with all of its traces using a refreshing delete-by-query.
     *
     * @param interactionId id of the interaction to delete
     * @param listener      receives {@code true} on success, {@code false} when the delete reported
     *                      bulk or search failures, or the exception on error
     */
    public void deleteInteractionAndTrace(String interactionId, ActionListener<Boolean> listener) {
        DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME);
        deleteByQueryRequest.setQuery(buildDeleteInteractionQuery(interactionId));
        deleteByQueryRequest.setRefresh(true);
        innerDeleteInteractionAndTrace(deleteByQueryRequest, interactionId, listener);
    }

    /**
     * Executes {@code deleteByQueryRequest} with a stashed thread context and reports the outcome
     * to {@code listener} as a boolean.
     */
    @VisibleForTesting
    void innerDeleteInteractionAndTrace(DeleteByQueryRequest deleteByQueryRequest, String interactionId, ActionListener<Boolean> listener) {
        try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) {
            client.execute(DeleteByQueryAction.INSTANCE, deleteByQueryRequest, ActionListener.wrap(bulkResponse -> {
                // NOTE(review): a null response is treated as success (original behavior) — confirm this is intended.
                boolean hasFailures = bulkResponse != null
                    && (!bulkResponse.getBulkFailures().isEmpty() || !bulkResponse.getSearchFailures().isEmpty());
                if (hasFailures) {
                    log.info("Failed to delete the interaction with ID: {}", interactionId);
                    listener.onResponse(false);
                    return;
                }
                log.info("Successfully delete the interaction with ID: {}", interactionId);
                listener.onResponse(true);
            }, exception -> {
                // The exception is passed as a trailing argument (not consumed by a placeholder),
                // so Log4j records its stack trace.
                log.error("Failed to delete interaction with ID {}", interactionId, exception);
                listener.onFailure(exception);
            }));
        } catch (Exception e) {
            log.error("Failed to delete interaction with ID {}", interactionId, e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds a query that matches the interaction document whose id is {@code interactionId},
     * or any trace document (one with a {@code trace_number}) whose {@code parent_message_id}
     * is {@code interactionId}.
     */
    @VisibleForTesting
    QueryBuilder buildDeleteInteractionQuery(String interactionId) {
        BoolQueryBuilder traceQuery = QueryBuilders.boolQuery()
            .must(QueryBuilders.existsQuery("trace_number"))
            .must(QueryBuilders.termQuery("parent_message_id", interactionId));
        return QueryBuilders.boolQuery().should(QueryBuilders.idsQuery().addIds(interactionId)).should(traceQuery);
    }
}

