/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableList;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportResponseHandler;

/**
 * Task runner for the "train and predict" action: it trains a model on the request's input and
 * immediately runs prediction with it, returning the prediction output in the response.
 *
 * <p>The task is created as a non-async task whose only worker node is the local node. Input given
 * as a search query is first resolved into a dataset by {@link MLInputDatasetHandler}; the actual
 * training/prediction work runs on the {@code "opensearch_ml_train"} executor.
 */
public class MLTrainAndPredictTaskRunner
extends MLTaskRunner<MLTrainingTaskRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTrainAndPredictTaskRunner.class);

    /** Name of the thread pool executor that runs train-and-predict work. */
    private static final String TRAIN_THREAD_POOL = "opensearch_ml_train";

    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLInputDatasetHandler mlInputDatasetHandler;
    protected final DiscoveryNodeHelper nodeFilter;
    private final MLEngine mlEngine;

    public MLTrainAndPredictTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService, DiscoveryNodeHelper nodeFilter, MLEngine mlEngine) {
        super(mlTaskManager, mlStats, nodeFilter, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
        this.nodeFilter = nodeFilter;
        this.mlEngine = mlEngine;
    }

    /** Returns the transport action name this runner handles. */
    @Override
    protected String getTransportActionName() {
        return "cluster:admin/opensearch/ml/trainAndPredict";
    }

    /**
     * Builds the transport response handler that deserializes an {@link MLTaskResponse} and
     * forwards it to {@code listener}.
     */
    @Override
    protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
        return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
    }

    /**
     * Creates the ML task for this request and schedules train-and-predict on the train executor.
     *
     * <p>When the input is a search query, the query is first resolved into a dataset. The
     * continuation is then handed off to the train executor. For any other input type, the work is
     * submitted to the train executor directly.
     *
     * @param request  the training request carrying the {@link MLInput}
     * @param listener receives the prediction output or the failure
     */
    @Override
    protected void executeTask(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
        MLInput mlInput = request.getMlInput();
        MLInputDataset inputDataset = mlInput.getInputDataset();
        MLInputDataType inputDataType = inputDataset.getInputDataType();
        Instant now = Instant.now();
        MLTask mlTask = MLTask.builder()
            .taskId(UUID.randomUUID().toString())
            .taskType(MLTaskType.TRAINING_AND_PREDICTION)
            .inputType(inputDataType)
            .functionName(mlInput.getFunctionName())
            .state(MLTaskState.CREATED)
            .workerNodes(ImmutableList.of(this.clusterService.localNode().getId()))
            .createTime(now)
            .lastUpdateTime(now)
            .async(false)
            .build();

        if (inputDataType == MLInputDataType.SEARCH_QUERY) {
            ActionListener<MLInputDataset> dataFrameActionListener = ActionListener.wrap(dataSet -> {
                MLInput newInput = mlInput.toBuilder().inputDataset(dataSet).build();
                this.trainAndPredict(mlTask, newInput, listener);
            }, e -> {
                log.error("Failed to generate DataFrame from search query", e);
                // trainAndPredict has not run yet, so the task was never registered or counted:
                // report on the raw listener and do not track this as a train/predict failure.
                this.handlePredictFailure(mlTask, listener, e, false);
            });
            // Resolve the query, then continue on the train executor rather than the calling thread.
            this.mlInputDatasetHandler.parseSearchQueryInput(
                inputDataset,
                new ThreadedActionListener<>(log, this.threadPool, TRAIN_THREAD_POOL, dataFrameActionListener, false)
            );
        } else {
            this.threadPool.executor(TRAIN_THREAD_POOL).execute(() -> this.trainAndPredict(mlTask, mlInput, listener));
        }
    }

    /**
     * Registers and counts the task, runs {@link MLEngine#trainAndPredict}, and reports the result.
     *
     * @param mlTask   the task created by {@link #executeTask}; its state is set to RUNNING here
     * @param mlInput  the fully-resolved input (search queries already replaced by a dataset)
     * @param listener the caller's listener
     */
    private void trainAndPredict(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
        // Wrapper built by MLTaskRunner#wrappedCleanupListener. It presumably undoes the
        // executing-task count and task registration done below (NOTE(review): confirm in MLTaskRunner).
        // Every completion path, success or failure, must go through it.
        ActionListener<MLTaskResponse> internalListener = this.wrappedCleanupListener(listener, mlTask.getTaskId());
        this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
        this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        mlTask.setState(MLTaskState.RUNNING);
        this.mlTaskManager.add(mlTask);
        try {
            this.mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.getTenantId(), mlTask.isAsync());
            MLOutput output = this.mlEngine.trainAndPredict(mlInput);
            this.handleAsyncMLTaskComplete(mlTask);
            if (output instanceof MLPredictionOutput) {
                ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
            }
            MLTaskResponse response = MLTaskResponse.builder().output(output).build();
            log.debug("Train and predict task done for algorithm: {}, task id: {}", mlTask.getFunctionName(), mlTask.getTaskId());
            internalListener.onResponse(response);
        } catch (Exception e) {
            log.error("Failed to train and predict {}", mlInput.getAlgorithm(), e);
            // Report through internalListener (not the raw listener). Otherwise the cleanup for
            // the task registered and counted above would never run on the failure path.
            this.handlePredictFailure(mlTask, internalListener, e, true);
        }
    }

    /**
     * Reports a train-and-predict failure.
     *
     * @param mlTask        the task that failed
     * @param listener      the listener to notify via {@code onFailure}
     * @param e             the failure cause
     * @param trackFailure  whether to increment the TRAIN_PREDICT action-level failure counter and
     *                      the node-level failure counter
     */
    private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e, boolean trackFailure) {
        if (trackFailure) {
            this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN_PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            this.mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
        }
        this.handleAsyncMLTaskFailure(mlTask, e);
        listener.onFailure(e);
    }
}

