/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnScoreDocQuery;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

/**
 * Query builder that carries a fixed, pre-computed list of kNN hits ({@link ScoreDoc}s) and turns them
 * into a {@link KnnScoreDocQuery} against the current index reader.
 * <p>
 * It may also carry the vector field name, the query vector and a similarity threshold. When both the
 * field name and the query vector are present and the rewrite context exposes
 * {@code convertToInnerHitsRewriteContext()}, the builder rewrites itself into an
 * {@link ExactKnnQueryBuilder}. An empty hit list rewrites to a {@link MatchNoneQueryBuilder}.
 */
public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQueryBuilder> {
    public static final String NAME = "knn_score_doc";
    private final ScoreDoc[] scoreDocs;
    private final String fieldName;
    private final VectorData queryVector;
    private final Float vectorSimilarity;

    /**
     * Creates a builder from already-collected hits.
     *
     * @param scoreDocs        the hits to wrap; stored as-is (not copied)
     * @param fieldName        vector field name, may be {@code null}
     * @param queryVector      query vector, may be {@code null}
     * @param vectorSimilarity optional similarity threshold, may be {@code null}
     */
    public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector, Float vectorSimilarity) {
        this.scoreDocs = scoreDocs;
        this.fieldName = fieldName;
        this.queryVector = queryVector;
        this.vectorSimilarity = vectorSimilarity;
    }

    /**
     * Deserializes a builder. The wire layout must mirror {@link #doWriteTo(StreamOutput)} exactly.
     */
    public KnnScoreDocQueryBuilder(StreamInput in) throws IOException {
        super(in);
        TransportVersion version = in.getTransportVersion();
        this.scoreDocs = in.readArray(Lucene::readScoreDoc, ScoreDoc[]::new);
        if (version.onOrAfter(TransportVersions.V_8_13_0)) {
            this.fieldName = in.readOptionalString();
            this.queryVector = readQueryVector(in, version);
        } else {
            // Older peers never sent the field name or the query vector.
            this.fieldName = null;
            this.queryVector = null;
        }
        this.vectorSimilarity = supportsVectorSimilarity(version) ? in.readOptionalFloat() : null;
    }

    /**
     * Reads the presence flag and, if set, the query vector in the version-appropriate encoding.
     * From 8.14 on the vector is an optional {@link VectorData}; before that it was a raw float array.
     */
    private static VectorData readQueryVector(StreamInput in, TransportVersion version) throws IOException {
        if (in.readBoolean() == false) {
            return null;
        }
        if (version.onOrAfter(TransportVersions.V_8_14_0)) {
            return in.readOptionalWriteable(VectorData::new);
        }
        return VectorData.fromFloats(in.readFloatArray());
    }

    /**
     * Whether the similarity threshold is part of the wire format for {@code version}. Shared by the read
     * and write paths so the two cannot diverge.
     */
    private static boolean supportsVectorSimilarity(TransportVersion version) {
        return version.onOrAfter(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS) || version.isPatchFrom(TransportVersions.V_8_15_0);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Returns the wrapped hits. This is the internal array, not a copy. */
    public ScoreDoc[] scoreDocs() {
        return this.scoreDocs;
    }

    String fieldName() {
        return this.fieldName;
    }

    VectorData queryVector() {
        return this.queryVector;
    }

    Float vectorSimilarity() {
        return this.vectorSimilarity;
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        TransportVersion version = out.getTransportVersion();
        out.writeArray(Lucene::writeScoreDoc, this.scoreDocs);
        if (version.onOrAfter(TransportVersions.V_8_13_0)) {
            out.writeOptionalString(this.fieldName);
            if (this.queryVector == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                if (version.onOrAfter(TransportVersions.V_8_14_0)) {
                    out.writeOptionalWriteable(this.queryVector);
                } else {
                    // Pre-8.14 peers only understand a raw float array.
                    out.writeFloatArray(this.queryVector.asFloatVector());
                }
            }
        }
        if (supportsVectorSimilarity(version)) {
            out.writeOptionalFloat(this.vectorSimilarity);
        }
    }

    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startArray("values");
        for (ScoreDoc scoreDoc : this.scoreDocs) {
            builder.startObject().field("doc", scoreDoc.doc).field("score", scoreDoc.score).endObject();
        }
        builder.endArray();
        if (this.fieldName != null) {
            builder.field("field", this.fieldName);
        }
        if (this.queryVector != null) {
            builder.field("query", (ToXContent) this.queryVector);
        }
        if (this.vectorSimilarity != null) {
            builder.field("similarity", this.vectorSimilarity);
        }
        this.boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    /**
     * Splits the hits into parallel doc/score arrays and builds a {@link KnnScoreDocQuery} bound to the
     * current reader context.
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        int numDocs = this.scoreDocs.length;
        int[] docs = new int[numDocs];
        float[] scores = new float[numDocs];
        for (int i = 0; i < numDocs; ++i) {
            docs[i] = this.scoreDocs[i].doc;
            scores[i] = this.scoreDocs[i].score;
        }
        IndexReader reader = context.getIndexReader();
        int[] segmentStarts = findSegmentStarts(reader, docs);
        return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id());
    }

    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
        if (this.scoreDocs.length == 0) {
            return new MatchNoneQueryBuilder("The \"" + this.getName() + "\" query was rewritten to a \"match_none\" query.");
        }
        if (queryRewriteContext.convertToInnerHitsRewriteContext() != null && this.queryVector != null && this.fieldName != null) {
            return new ExactKnnQueryBuilder(this.queryVector, this.fieldName, this.vectorSimilarity);
        }
        return super.doRewrite(queryRewriteContext);
    }

    /**
     * Computes, for each leaf of {@code reader}, the index into {@code docs} of its first hit.
     * <p>
     * The result has {@code leaves + 1} entries:
     * <ul>
     *   <li>slot 0 stays 0;</li>
     *   <li>slot {@code i} (for {@code 0 < i < leaves}) is the first index whose doc id is at least
     *       leaf {@code i}'s {@code docBase};</li>
     *   <li>the last slot is {@code docs.length}.</li>
     * </ul>
     * NOTE(review): relies on {@link Arrays#binarySearch}, so {@code docs} must be sorted ascending.
     * This method does not check that; callers are expected to supply hits ordered by doc id.
     */
    private static int[] findSegmentStarts(IndexReader reader, int[] docs) {
        int[] starts = new int[reader.leaves().size() + 1];
        starts[starts.length - 1] = docs.length;
        if (starts.length == 2) {
            // Single leaf: every hit belongs to it.
            return starts;
        }
        int resultIndex = 0;
        for (int i = 1; i < starts.length - 1; ++i) {
            int upper = reader.leaves().get(i).docBase;
            // Search only past the previous boundary; docs are sorted so boundaries are monotonic.
            resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
            if (resultIndex < 0) {
                // Not found: convert the encoded insertion point back to an index.
                resultIndex = -1 - resultIndex;
            }
            starts[i] = resultIndex;
        }
        return starts;
    }

    /**
     * Scores are compared with {@link Float#compare} rather than {@code ==}, so that equality agrees with
     * the {@link Float#hashCode} used in {@link #doHashCode()}. With {@code ==}, 0.0f and -0.0f compared
     * equal but hashed differently, and a NaN score made a builder unequal to itself.
     */
    @Override
    protected boolean doEquals(KnnScoreDocQueryBuilder other) {
        if (this.scoreDocs.length != other.scoreDocs.length) {
            return false;
        }
        for (int i = 0; i < this.scoreDocs.length; ++i) {
            ScoreDoc scoreDoc = this.scoreDocs[i];
            ScoreDoc otherScoreDoc = other.scoreDocs[i];
            if (scoreDoc.doc != otherScoreDoc.doc
                || Float.compare(scoreDoc.score, otherScoreDoc.score) != 0
                || scoreDoc.shardIndex != otherScoreDoc.shardIndex) {
                return false;
            }
        }
        return Objects.equals(this.fieldName, other.fieldName)
            && Objects.equals(this.queryVector, other.queryVector)
            && Objects.equals(this.vectorSimilarity, other.vectorSimilarity);
    }

    @Override
    protected int doHashCode() {
        int result = 1;
        for (ScoreDoc scoreDoc : this.scoreDocs) {
            // Boxing the score routes through Float.hashCode, matching Float.compare in doEquals.
            int hashCode = Objects.hash(scoreDoc.doc, Float.valueOf(scoreDoc.score), scoreDoc.shardIndex);
            result = 31 * result + hashCode;
        }
        return Objects.hash(result, this.fieldName, this.vectorSimilarity, Objects.hashCode(this.queryVector));
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_4_0;
    }
}

