/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison;

import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Comparisons;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.core.util.StringUtils;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InBooleanEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InBytesRefEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InDoubleEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InIntEvaluator;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InLongEvaluator;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;

public class In
extends EsqlScalarFunction {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "In", In::new);
    private final Expression value;
    private final List<Expression> list;

    @FunctionInfo(returnType={"boolean"}, description="The `IN` operator allows testing whether a field or expression equals an element in a list of literals, fields or expressions.", examples={@Example(file="row", tag="in-with-expressions")})
    public In(Source source, @Param(name="field", type={"boolean", "cartesian_point", "cartesian_shape", "double", "geo_point", "geo_shape", "integer", "ip", "keyword", "long", "text", "version"}, description="An expression.") Expression value, @Param(name="inlist", type={"boolean", "cartesian_point", "cartesian_shape", "double", "geo_point", "geo_shape", "integer", "ip", "keyword", "long", "text", "version"}, description="A list of items.") List<Expression> list) {
        super(source, CollectionUtils.combine(list, (Object[])new Expression[]{value}));
        this.value = value;
        this.list = list;
    }

    public Expression value() {
        return this.value;
    }

    public List<Expression> list() {
        return this.list;
    }

    public DataType dataType() {
        return DataType.BOOLEAN;
    }

    private In(StreamInput in) throws IOException {
        this(Source.readFrom((StreamInput)((PlanStreamInput)in)), (Expression)in.readNamedWriteable(Expression.class), in.readNamedWriteableCollectionAsList(Expression.class));
    }

    public void writeTo(StreamOutput out) throws IOException {
        this.source().writeTo(out);
        out.writeNamedWriteable((NamedWriteable)this.value);
        out.writeNamedWriteableCollection(this.list);
    }

    public String getWriteableName() {
        return In.ENTRY.name;
    }

    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create((Node)this, In::new, (Object)this.value, this.list);
    }

    public Expression replaceChildren(List<Expression> newChildren) {
        return new In(this.source(), newChildren.get(newChildren.size() - 1), newChildren.subList(0, newChildren.size() - 1));
    }

    public boolean foldable() {
        return Expressions.isNull((Expression)this.value) || Expressions.foldable((List)this.children()) || Expressions.foldable(this.list) && this.list.stream().allMatch(Expressions::isNull);
    }

    @Override
    public Object fold() {
        if (Expressions.isNull((Expression)this.value) || this.list.stream().allMatch(Expressions::isNull)) {
            return null;
        }
        return super.fold();
    }

    protected boolean areCompatible(DataType left, DataType right) {
        if (left == DataType.UNSIGNED_LONG || right == DataType.UNSIGNED_LONG) {
            return left == right;
        }
        if (DataType.isSpatial((DataType)left) && DataType.isSpatial((DataType)right)) {
            return left == right;
        }
        return DataType.areCompatible((DataType)left, (DataType)right);
    }

    protected Expression.TypeResolution resolveType() {
        Expression.TypeResolution resolution = EsqlTypeResolutions.isExact(this.value, this.functionName(), TypeResolutions.ParamOrdinal.DEFAULT);
        if (resolution.unresolved()) {
            return resolution;
        }
        DataType dt = this.value.dataType();
        for (int i = 0; i < this.list.size(); ++i) {
            Expression listValue = this.list.get(i);
            if (this.areCompatible(dt, listValue.dataType())) continue;
            return new Expression.TypeResolution(LoggerMessageFormat.format(null, (String)"{} argument of [{}] must be [{}], found value [{}] type [{}]", (Object[])new Object[]{StringUtils.ordinal((int)(i + 1)), this.sourceText(), dt.typeName(), Expressions.name((Expression)listValue), listValue.dataType().typeName()}));
        }
        return Expression.TypeResolution.TYPE_RESOLVED;
    }

    protected Expression canonicalize() {
        List canonicalValues = Expressions.canonicalize(this.list);
        Collections.sort(canonicalValues, (l, r) -> Integer.compare(l.hashCode(), r.hashCode()));
        return new In(this.source(), this.value, canonicalValues);
    }

    @Override
    public EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
        EvalOperator.ExpressionEvaluator.Factory[] factories;
        EvalOperator.ExpressionEvaluator.Factory lhs;
        DataType commonType = this.commonType();
        if (commonType.isNumeric()) {
            lhs = Cast.cast(this.source(), this.value.dataType(), commonType, toEvaluator.apply(this.value));
            factories = (EvalOperator.ExpressionEvaluator.Factory[])this.list.stream().map(e -> Cast.cast(this.source(), e.dataType(), commonType, toEvaluator.apply((Expression)e))).toArray(EvalOperator.ExpressionEvaluator.Factory[]::new);
        } else {
            lhs = toEvaluator.apply(this.value);
            factories = (EvalOperator.ExpressionEvaluator.Factory[])this.list.stream().map(toEvaluator::apply).toArray(EvalOperator.ExpressionEvaluator.Factory[]::new);
        }
        if (commonType == DataType.BOOLEAN) {
            return new InBooleanEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.DOUBLE) {
            return new InDoubleEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.INTEGER) {
            return new InIntEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.LONG || commonType == DataType.DATETIME || commonType == DataType.UNSIGNED_LONG) {
            return new InLongEvaluator.Factory(this.source(), lhs, factories);
        }
        if (commonType == DataType.KEYWORD || commonType == DataType.TEXT || commonType == DataType.IP || commonType == DataType.VERSION || commonType == DataType.UNSUPPORTED || DataType.isSpatial((DataType)commonType)) {
            return new InBytesRefEvaluator.Factory(this.source(), toEvaluator.apply(this.value), factories);
        }
        if (commonType == DataType.NULL) {
            return EvalOperator.CONSTANT_NULL_FACTORY;
        }
        throw EsqlIllegalArgumentException.illegalDataType(commonType);
    }

    private DataType commonType() {
        DataType commonType = this.value.dataType();
        for (Expression e : this.list) {
            if (e.dataType() == DataType.NULL && this.value.dataType() != DataType.NULL) continue;
            if (DataType.isSpatial((DataType)commonType)) {
                if (e.dataType() == commonType) continue;
                commonType = DataType.NULL;
                break;
            }
            commonType = EsqlDataTypeConverter.commonType(commonType, e.dataType());
        }
        return commonType;
    }

    static boolean process(BitSet nulls, BitSet mvs, int lhs, int[] rhs) {
        for (int i = 0; i < rhs.length; ++i) {
            Boolean compResult;
            if (nulls != null && nulls.get(i) || mvs != null && mvs.get(i) || (compResult = Comparisons.eq((Object)lhs, (Object)rhs[i])) != Boolean.TRUE) continue;
            return true;
        }
        return false;
    }

    static boolean process(BitSet nulls, BitSet mvs, long lhs, long[] rhs) {
        for (int i = 0; i < rhs.length; ++i) {
            Boolean compResult;
            if (nulls != null && nulls.get(i) || mvs != null && mvs.get(i) || (compResult = Comparisons.eq((Object)lhs, (Object)rhs[i])) != Boolean.TRUE) continue;
            return true;
        }
        return false;
    }

    static boolean process(BitSet nulls, BitSet mvs, double lhs, double[] rhs) {
        for (int i = 0; i < rhs.length; ++i) {
            Boolean compResult;
            if (nulls != null && nulls.get(i) || mvs != null && mvs.get(i) || (compResult = Comparisons.eq((Object)lhs, (Object)rhs[i])) != Boolean.TRUE) continue;
            return true;
        }
        return false;
    }

    static boolean process(BitSet nulls, BitSet mvs, BytesRef lhs, BytesRef[] rhs) {
        for (int i = 0; i < rhs.length; ++i) {
            Boolean compResult;
            if (nulls != null && nulls.get(i) || mvs != null && mvs.get(i) || (compResult = Comparisons.eq((Object)lhs, (Object)rhs[i])) != Boolean.TRUE) continue;
            return true;
        }
        return false;
    }
}

