/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action.filter;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkItemRequest;
import org.elasticsearch.action.bulk.BulkShardRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.action.support.MappedActionFilter;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.inference.InferenceException;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

/**
 * A {@link MappedActionFilter} for the shard-level bulk action ({@code indices:data/write/bulk[s]}) that runs
 * inference for the request's inference fields before letting the request continue down the filter chain.
 * <p>
 * For every index/update item, the values of each inference field's source fields are sent to the inference
 * service of that field's inference id, in batches of at most {@code batchSize} inputs. The results are written
 * back into the item's source with {@link SemanticTextFieldMapper#insertValue}. Items for which inference could
 * not be computed are aborted on the {@link BulkShardRequest} instead.
 */
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
    /** Default maximum number of inputs sent to an inference service in a single {@code chunkedInfer} call. */
    protected static final int DEFAULT_BATCH_SIZE = 512;

    private static final String BULK_SHARD_ACTION_NAME = "indices:data/write/bulk[s]";

    private final InferenceServiceRegistry inferenceServiceRegistry;
    private final ModelRegistry modelRegistry;
    private final int batchSize;

    /**
     * Creates a filter that sends at most {@link #DEFAULT_BATCH_SIZE} inputs per inference call.
     */
    public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry) {
        this(inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE);
    }

    /**
     * @param inferenceServiceRegistry resolves the inference service named by a stored model
     * @param modelRegistry            loads the model (with secrets) for an inference id
     * @param batchSize                maximum number of inputs per inference call; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public ShardBulkInferenceActionFilter(InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, int batchSize) {
        if (batchSize <= 0) {
            // An empty batch never consumes any request, so the same work would be rescheduled forever.
            throw new IllegalArgumentException("batchSize must be positive but was [" + batchSize + "]");
        }
        this.inferenceServiceRegistry = inferenceServiceRegistry;
        this.modelRegistry = modelRegistry;
        this.batchSize = batchSize;
    }

    public String actionName() {
        return BULK_SHARD_ACTION_NAME;
    }

    /**
     * Runs inference for shard bulk requests that carry a non-empty inference field map and proceeds with the
     * chain once inference has completed. Every other request proceeds immediately.
     */
    public <Request extends ActionRequest, Response extends ActionResponse> void apply(
        Task task,
        String action,
        Request request,
        ActionListener<Response> listener,
        ActionFilterChain<Request, Response> chain
    ) {
        if (BULK_SHARD_ACTION_NAME.equals(action)) {
            BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
            Map<String, InferenceFieldMetadata> fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
            if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
                Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
                return;
            }
        }
        chain.proceed(task, action, request, listener);
    }

    private void processBulkShardRequest(
        Map<String, InferenceFieldMetadata> fieldInferenceMap,
        BulkShardRequest bulkShardRequest,
        Runnable onCompletion
    ) {
        new AsyncBulkShardInferenceAction(fieldInferenceMap, bulkShardRequest, onCompletion).run();
    }

    /**
     * Converts a value extracted from a document source into the strings to run inference on.
     * Strings are used as-is, numbers and booleans through {@code toString()}, and collections element by element
     * using the same rules.
     *
     * @param field    field name reported in the error message
     * @param valueObj the extracted value (callers pass a non-null value)
     * @return the input strings, in iteration order
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} for any other value or element type
     */
    private static List<String> nodeStringValues(String field, Object valueObj) {
        if (valueObj instanceof Number || valueObj instanceof Boolean) {
            return List.of(valueObj.toString());
        } else if (valueObj instanceof String value) {
            return List.of(value);
        } else if (valueObj instanceof Collection<?> values) {
            List<String> valuesString = new ArrayList<>(values.size());
            for (Object v : values) {
                if (v instanceof Number || v instanceof Boolean) {
                    valuesString.add(v.toString());
                } else if (v instanceof String value) {
                    valuesString.add(value);
                } else {
                    // Report the offending element's type, not the type of the (valid) enclosing collection.
                    throw invalidFormatException(field, v);
                }
            }
            return valuesString;
        }
        throw invalidFormatException(field, valueObj);
    }

    /** Builds the BAD_REQUEST error raised when {@code value} cannot be used as inference input. */
    private static ElasticsearchStatusException invalidFormatException(String field, Object value) {
        String actualType = value == null ? "null" : value.getClass().getSimpleName();
        return new ElasticsearchStatusException(
            "Invalid format for field [{}], expected [String] got [{}]",
            RestStatus.BAD_REQUEST,
            field,
            actualType
        );
    }

    /**
     * Returns the {@link IndexRequest} whose source carries the document: the request itself for index requests,
     * the partial document ({@code doc()}) for update requests, and {@code null} for any other request type.
     */
    static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
        if (docWriteRequest instanceof IndexRequest indexRequest) {
            return indexRequest;
        } else if (docWriteRequest instanceof UpdateRequest updateRequest) {
            return updateRequest.doc();
        } else {
            return null;
        }
    }

    /**
     * Runs all inference needed by one {@link BulkShardRequest}. Requests are grouped by inference id, and each
     * group is processed independently, so groups may complete on different threads. Once every group is done,
     * results are applied to (or failures recorded on) the bulk items, and {@code onCompletion} is run.
     */
    private class AsyncBulkShardInferenceAction implements Runnable {
        private final Map<String, InferenceFieldMetadata> fieldInferenceMap;
        private final BulkShardRequest bulkShardRequest;
        private final Runnable onCompletion;
        // Indexed by the item's position in bulkShardRequest.items().
        private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;

        private AsyncBulkShardInferenceAction(
            Map<String, InferenceFieldMetadata> fieldInferenceMap,
            BulkShardRequest bulkShardRequest,
            Runnable onCompletion
        ) {
            this.fieldInferenceMap = fieldInferenceMap;
            this.bulkShardRequest = bulkShardRequest;
            this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
            this.onCompletion = onCompletion;
        }

        @Override
        public void run() {
            Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
            Runnable onInferenceCompletion = () -> {
                try {
                    for (FieldInferenceResponseAccumulator inferenceResponse : inferenceResults.asList()) {
                        BulkItemRequest request = bulkShardRequest.items()[inferenceResponse.id()];
                        try {
                            applyInferenceResponses(request, inferenceResponse);
                        } catch (Exception exc) {
                            request.abort(bulkShardRequest.index(), exc);
                        }
                    }
                } finally {
                    onCompletion.run();
                }
            };
            // Each inference id holds one reference; onInferenceCompletion is expected to run after all of them
            // (and this try block's own reference) have been released.
            try (RefCountingRunnable releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
                for (Map.Entry<String, List<FieldInferenceRequest>> entry : inferenceRequests.entrySet()) {
                    executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
                }
            }
        }

        /**
         * Runs inference for {@code requests} with the model identified by {@code inferenceId}.
         * <p>
         * When {@code inferenceProvider} is null, the model is loaded first and this method is re-entered with the
         * resolved provider. Otherwise, the first {@code batchSize} requests are sent to the service, and the
         * remainder is processed once that call completes. {@code onFinish} is released exactly once, after the
         * last batch has completed or the model could not be loaded.
         */
        private void executeShardBulkInferenceAsync(
            final String inferenceId,
            @Nullable final InferenceProvider inferenceProvider,
            final List<FieldInferenceRequest> requests,
            final Releasable onFinish
        ) {
            if (inferenceProvider == null) {
                loadInferenceProviderAndExecute(inferenceId, requests, onFinish);
                return;
            }
            final int currentBatchSize = Math.min(requests.size(), batchSize);
            final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
            final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
            final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
            final Model model = inferenceProvider.model();

            ActionListener<List<ChunkedInferenceServiceResults>> completionListener = new ActionListener<>() {
                public void onResponse(List<ChunkedInferenceServiceResults> results) {
                    try {
                        // Results are positional: the i-th result belongs to the i-th request of this batch.
                        Iterator<FieldInferenceRequest> requestsIterator = currentBatch.iterator();
                        for (ChunkedInferenceServiceResults result : results) {
                            FieldInferenceRequest request = requestsIterator.next();
                            FieldInferenceResponseAccumulator acc = inferenceResults.get(request.index());
                            if (result instanceof ErrorChunkedInferenceResults error) {
                                acc.addFailure(
                                    new InferenceException(
                                        "Exception when running inference id [{}] on field [{}]",
                                        error.getException(),
                                        model.getInferenceEntityId(),
                                        request.field()
                                    )
                                );
                            } else {
                                acc.addOrUpdateResponse(
                                    new FieldInferenceResponse(
                                        request.field(),
                                        request.input(),
                                        request.inputOrder(),
                                        request.isOriginalFieldInput(),
                                        model,
                                        result
                                    )
                                );
                            }
                        }
                    } finally {
                        onBatchFinished();
                    }
                }

                public void onFailure(Exception exc) {
                    try {
                        // Only the requests sent in this call failed; those in nextBatch are still sent on their own.
                        for (FieldInferenceRequest request : currentBatch) {
                            addInferenceResponseFailure(
                                request.index(),
                                new InferenceException(
                                    "Exception when running inference id [{}] on field [{}]",
                                    exc,
                                    model.getInferenceEntityId(),
                                    request.field()
                                )
                            );
                        }
                    } finally {
                        onBatchFinished();
                    }
                }

                /** Either releases {@code onFinish} or continues with the remaining requests. */
                private void onBatchFinished() {
                    if (nextBatch.isEmpty()) {
                        onFinish.close();
                    } else {
                        executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish);
                    }
                }
            };
            inferenceProvider.service()
                .chunkedInfer(
                    model,
                    null,
                    inputs,
                    Map.of(),
                    InputType.INGEST,
                    new ChunkingOptions(null, null),
                    TimeValue.MAX_VALUE,
                    completionListener
                );
        }

        /**
         * Loads the model for {@code inferenceId} and resolves its service, then runs {@code requests} against it.
         * If the model or its service cannot be found or loaded, every request gets a failure and
         * {@code onFinish} is released.
         */
        private void loadInferenceProviderAndExecute(
            final String inferenceId,
            final List<FieldInferenceRequest> requests,
            final Releasable onFinish
        ) {
            ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
                public void onResponse(UnparsedModel unparsedModel) {
                    Optional<InferenceService> service = inferenceServiceRegistry.getService(unparsedModel.service());
                    if (service.isPresent()) {
                        InferenceService inferenceService = service.get();
                        Model model = inferenceService.parsePersistedConfigWithSecrets(
                            inferenceId,
                            unparsedModel.taskType(),
                            unparsedModel.settings(),
                            unparsedModel.secrets()
                        );
                        executeShardBulkInferenceAsync(inferenceId, new InferenceProvider(inferenceService, model), requests, onFinish);
                    } else {
                        try (onFinish) {
                            for (FieldInferenceRequest request : requests) {
                                addInferenceResponseFailure(
                                    request.index(),
                                    new ResourceNotFoundException(
                                        "Inference service [{}] not found for field [{}]",
                                        unparsedModel.service(),
                                        request.field()
                                    )
                                );
                            }
                        }
                    }
                }

                public void onFailure(Exception exc) {
                    try (onFinish) {
                        boolean notFound = ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException;
                        for (FieldInferenceRequest request : requests) {
                            Exception failure = notFound
                                ? new ResourceNotFoundException("Inference id [{}] not found for field [{}]", inferenceId, request.field())
                                : new InferenceException(
                                    "Error loading inference for inference id [{}] on field [{}]",
                                    exc,
                                    inferenceId,
                                    request.field()
                                );
                            addInferenceResponseFailure(request.index(), failure);
                        }
                    }
                }
            };
            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
        }

        /**
         * Returns the accumulator for item {@code id}, creating it if absent.
         * The get-then-set is not atomic. It is safe because accumulators for every item that has inference
         * requests are created in {@link #createFieldInferenceRequests} before any asynchronous work starts, so
         * later callers only read existing slots.
         */
        private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
            FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
            if (acc == null) {
                acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
                inferenceResults.set(id, acc);
            }
            return acc;
        }

        /** Records {@code failure} for item {@code id}; the item is aborted when responses are applied. */
        private void addInferenceResponseFailure(int id, Exception failure) {
            ensureResponseAccumulatorSlot(id).addFailure(failure);
        }

        /**
         * Aborts {@code item} with every recorded failure or, when there is none, writes one
         * {@link SemanticTextField} per inference field into the item's source.
         */
        private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) {
            if (response.failures().isEmpty() == false) {
                for (Exception failure : response.failures()) {
                    item.abort(item.index(), failure);
                }
                return;
            }

            IndexRequest indexRequest = getIndexRequestOrNull(item.request());
            Map<String, Object> newDocMap = indexRequest.sourceAsMap();
            for (Map.Entry<String, List<FieldInferenceResponse>> entry : response.responses().entrySet()) {
                String fieldName = entry.getKey();
                List<FieldInferenceResponse> responses = entry.getValue();
                Model model = responses.get(0).model();
                // Put inputs and their chunked results back in the order they were extracted from the source.
                responses.sort(Comparator.comparingInt(FieldInferenceResponse::inputOrder));
                List<String> inputs = responses.stream()
                    .filter(FieldInferenceResponse::isOriginalFieldInput)
                    .map(FieldInferenceResponse::input)
                    .collect(Collectors.toList());
                List<ChunkedInferenceServiceResults> results = responses.stream()
                    .map(FieldInferenceResponse::chunkedResults)
                    .collect(Collectors.toList());
                SemanticTextField result = new SemanticTextField(
                    fieldName,
                    inputs,
                    new SemanticTextField.InferenceResult(
                        model.getInferenceEntityId(),
                        new SemanticTextField.ModelSettings(model),
                        SemanticTextField.toSemanticTextFieldChunks(results, indexRequest.getContentType())
                    ),
                    indexRequest.getContentType()
                );
                SemanticTextFieldMapper.insertValue(fieldName, newDocMap, result);
            }
            indexRequest.source(newDocMap, indexRequest.getContentType());
        }

        /**
         * Builds the inference requests for all index/update items of {@code bulkShardRequest} that do not already
         * have a primary response, grouped by inference id (in first-seen order). Items that cannot be processed
         * (scripted updates, missing update fields, invalid values) get a recorded failure instead.
         */
        private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
            BulkItemRequest[] items = bulkShardRequest.items();
            for (int itemIndex = 0; itemIndex < items.length; itemIndex++) {
                BulkItemRequest item = items[itemIndex];
                if (item.getPrimaryResponse() != null) {
                    // The item already has a response; skip it.
                    continue;
                }
                boolean isUpdateRequest = false;
                final IndexRequest indexRequest;
                if (item.request() instanceof IndexRequest ir) {
                    indexRequest = ir;
                } else if (item.request() instanceof UpdateRequest updateRequest) {
                    isUpdateRequest = true;
                    if (updateRequest.script() != null) {
                        addInferenceResponseFailure(
                            itemIndex,
                            new ElasticsearchStatusException(
                                "Cannot apply update with a script on indices that contain [{}] field(s)",
                                RestStatus.BAD_REQUEST,
                                "semantic_text"
                            )
                        );
                        continue;
                    }
                    indexRequest = updateRequest.doc();
                } else {
                    // Other request types carry no document to run inference on.
                    continue;
                }
                Map<String, Object> docMap = indexRequest.sourceAsMap();
                for (InferenceFieldMetadata entry : fieldInferenceMap.values()) {
                    addFieldInferenceRequests(itemIndex, isUpdateRequest, docMap, entry, fieldRequestsMap);
                }
            }
            return fieldRequestsMap;
        }

        /**
         * Appends to {@code fieldRequestsMap} one request per input value of {@code entry}'s source fields for
         * the item at {@code itemIndex}. Stops at the first problem and records a failure instead: a source field
         * missing from an update request, or a value that cannot be turned into strings. Requests already added
         * for earlier source fields are kept.
         */
        private void addFieldInferenceRequests(
            int itemIndex,
            boolean isUpdateRequest,
            Map<String, Object> docMap,
            InferenceFieldMetadata entry,
            Map<String, List<FieldInferenceRequest>> fieldRequestsMap
        ) {
            String field = entry.getName();
            String inferenceId = entry.getInferenceId();
            Object originalFieldValue = XContentMapValues.extractValue(field, docMap);
            // NOTE(review): a Map value presumably means the field already holds inference results (TODO confirm);
            // an absent value with a single source field leaves nothing to infer.
            if (originalFieldValue instanceof Map<?, ?> || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
                return;
            }
            // Position of each input across all source fields, used to restore ordering when applying results.
            int order = 0;
            for (String sourceField : entry.getSourceFields()) {
                boolean isOriginalFieldInput = sourceField.equals(field);
                Object valueObj = XContentMapValues.extractValue(sourceField, docMap);
                if (valueObj == null) {
                    if (isUpdateRequest) {
                        addInferenceResponseFailure(
                            itemIndex,
                            new ElasticsearchStatusException(
                                "Field [{}] must be specified on an update request to calculate inference for field [{}]",
                                RestStatus.BAD_REQUEST,
                                sourceField,
                                field
                            )
                        );
                        return;
                    }
                    continue;
                }
                ensureResponseAccumulatorSlot(itemIndex);
                final List<String> values;
                try {
                    values = nodeStringValues(field, valueObj);
                } catch (Exception exc) {
                    addInferenceResponseFailure(itemIndex, exc);
                    return;
                }
                List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                for (String v : values) {
                    fieldRequests.add(new FieldInferenceRequest(itemIndex, field, v, order++, isOriginalFieldInput));
                }
            }
        }
    }

    /**
     * Per-item collector of inference responses (grouped by field name) and failures.
     * Writers may run on different threads (one per inference id), so mutations are synchronized on the record.
     *
     * @param id position of the item in the shard request's items
     */
    private record FieldInferenceResponseAccumulator(
        int id,
        Map<String, List<FieldInferenceResponse>> responses,
        List<Exception> failures
    ) {
        synchronized void addOrUpdateResponse(FieldInferenceResponse response) {
            responses.computeIfAbsent(response.field(), k -> new ArrayList<>()).add(response);
        }

        synchronized void addFailure(Exception exc) {
            failures.add(exc);
        }
    }

    /** The chunked inference result for one input of one field. */
    private record FieldInferenceResponse(
        String field,
        String input,
        int inputOrder,
        boolean isOriginalFieldInput,
        Model model,
        ChunkedInferenceServiceResults chunkedResults
    ) {}

    /**
     * One input to send for inference.
     *
     * @param index                position of the item in the shard request's items
     * @param inputOrder           position of this input among all inputs of {@code field} for the item
     * @param isOriginalFieldInput whether the input comes from {@code field} itself rather than another source field
     */
    private record FieldInferenceRequest(int index, String field, String input, int inputOrder, boolean isOriginalFieldInput) {}

    /** A resolved inference service together with the model to run on it. */
    private record InferenceProvider(InferenceService service, Model model) {}
}

