/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.function.CollectionUDF;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.function.Function2;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexLambda;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.expression.function.CollectionUDF.LambdaUtils;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

public class TransformFunctionImpl
extends ImplementorUDF {
    public TransformFunctionImpl() {
        super(new TransformImplementor(), NullPolicy.ANY);
    }

    @Override
    public SqlReturnTypeInference getReturnTypeInference() {
        return sqlOperatorBinding -> {
            RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory();
            RexCallBinding rexCallBinding = (RexCallBinding)sqlOperatorBinding;
            List<RexNode> operands = rexCallBinding.operands();
            RelDataType lambdaReturnType = ((RexLambda)operands.get(1)).getExpression().getType();
            return SqlTypeUtil.createArrayType(typeFactory, typeFactory.createTypeWithNullability(lambdaReturnType, true), true);
        };
    }

    @Override
    public UDFOperandMetadata getOperandMetadata() {
        return null;
    }

    public static Object eval(Object ... args2) {
        List target = (List)args2[0];
        ArrayList<Object> results = new ArrayList<Object>();
        SqlTypeName returnType = (SqlTypeName)((Object)args2[args2.length - 1]);
        if (args2[1] instanceof Function1) {
            Function1 lambdaFunction = (Function1)args2[1];
            try {
                for (Object candidate : target) {
                    results.add(LambdaUtils.transferLambdaOutputToTargetType(lambdaFunction.apply(candidate), returnType));
                }
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            return results;
        }
        if (args2[1] instanceof Function2) {
            Function2 lambdaFunction = (Function2)args2[1];
            try {
                for (int i = 0; i < target.size(); ++i) {
                    results.add(LambdaUtils.transferLambdaOutputToTargetType(lambdaFunction.apply(target.get(i), i), returnType));
                }
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            return results;
        }
        throw new IllegalArgumentException("wrong lambda function input");
    }

    public static class TransformImplementor
    implements NotNullImplementor {
        @Override
        public Expression implement(RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
            ArraySqlType arrayType = (ArraySqlType)call.getType();
            ArrayList<Expression> withReturnTypeList = new ArrayList<Expression>(translatedOperands);
            withReturnTypeList.add(Expressions.constant((Object)arrayType.getComponentType().getSqlTypeName()));
            return Expressions.call(Types.lookupMethod(TransformFunctionImpl.class, "eval", Object[].class), withReturnTypeList);
        }
    }
}

