/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.queries;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.queries.SemanticQueryRewriteInterceptor;

/**
 * Intercepts {@code knn} queries whose target field is a semantic text field and rewrites them into
 * {@code nested} queries over the field's chunk embeddings.
 *
 * <p>When the indices being searched use more than one inference ID, or mix semantic and
 * non-semantic mappings for the field, the rewrite produces a {@code bool} query with one
 * {@code should} clause per group of indices. Each clause is restricted to its indices via an
 * {@code _index} filter / {@link #createSubQueryForIndices}.
 */
public class SemanticKnnVectorQueryRewriteInterceptor
extends SemanticQueryRewriteInterceptor {
    /** Node feature advertising support for semantic {@code knn} query rewrite interception. */
    public static final NodeFeature SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature("search.semantic_knn_vector_query_rewrite_interception_supported");
    /** Node feature for the semantic {@code knn} filter fix. */
    public static final NodeFeature SEMANTIC_KNN_FILTER_FIX = new NodeFeature("search.semantic_knn_filter_fix");

    /**
     * Returns the field targeted by the intercepted {@code knn} query.
     *
     * @param queryBuilder the intercepted query; expected to be a {@link KnnVectorQueryBuilder}
     * @return the {@code knn} query's field name
     */
    @Override
    protected String getFieldName(QueryBuilder queryBuilder) {
        assert queryBuilder instanceof KnnVectorQueryBuilder;
        KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
        return knnVectorQueryBuilder.getFieldName();
    }

    /**
     * Returns the text that will be embedded, if the {@code knn} query carries a query vector builder.
     *
     * @param queryBuilder the intercepted query; expected to be a {@link KnnVectorQueryBuilder}
     * @return the model text of its {@link TextEmbeddingQueryVectorBuilder}, or {@code null} when the
     *     query supplies a raw query vector instead of a builder
     */
    @Override
    protected String getQuery(QueryBuilder queryBuilder) {
        assert queryBuilder instanceof KnnVectorQueryBuilder;
        KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
        TextEmbeddingQueryVectorBuilder queryVectorBuilder = getTextEmbeddingQueryBuilderFromQuery(knnVectorQueryBuilder);
        return queryVectorBuilder != null ? queryVectorBuilder.getModelText() : null;
    }

    /**
     * Rewrites the query for the case where every targeted index maps the field as a semantic text field.
     *
     * <p>With a single inference ID a lone {@code nested} query is returned; otherwise a {@code bool}
     * query with one {@code should} clause per inference ID. The original query's boost and
     * {@code _name} are carried over to the returned query.
     *
     * @param queryBuilder the intercepted query; expected to be a {@link KnnVectorQueryBuilder}
     * @param indexInformation inference ID to index grouping for the field
     * @return the rewritten query
     */
    @Override
    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, SemanticQueryRewriteInterceptor.InferenceIndexInformationForField indexInformation) {
        assert queryBuilder instanceof KnnVectorQueryBuilder;
        KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
        Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
        QueryBuilder rewritten;
        if (inferenceIdsIndices.size() == 1) {
            // All inference indices share one inference ID, so no per-ID bool wrapper is needed.
            Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
            rewritten = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, inferenceIdIndex.getValue(), inferenceIdIndex.getKey());
        } else {
            rewritten = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
        }
        return copyBoostAndQueryName(queryBuilder, rewritten);
    }

    /**
     * Builds a {@code bool} query with one {@code should} clause per inference ID, each scoped to the
     * indices that use that inference ID.
     */
    private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(KnnVectorQueryBuilder queryBuilder, Map<String, List<String>> inferenceIdsIndices) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        addInferenceIdClauses(boolQueryBuilder, queryBuilder, inferenceIdsIndices);
        return boolQueryBuilder;
    }

    /**
     * Rewrites the query for the case where some targeted indices map the field as a semantic text
     * field and others do not.
     *
     * <p>The result is a {@code bool} query whose first {@code should} clause is the original
     * {@code knn} query restricted (via an {@code _index} filter) to the non-inference indices. It is
     * followed by one {@code should} clause per inference ID. The original query's boost and
     * {@code _name} are carried over to the returned query.
     *
     * @param queryBuilder the intercepted query; expected to be a {@link KnnVectorQueryBuilder}
     * @param indexInformation inference / non-inference index grouping for the field
     * @return the rewritten query
     */
    @Override
    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(QueryBuilder queryBuilder, SemanticQueryRewriteInterceptor.InferenceIndexInformationForField indexInformation) {
        assert queryBuilder instanceof KnnVectorQueryBuilder;
        KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
        Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.should(addIndexFilterToKnnVectorQuery(indexInformation.nonInferenceIndices(), knnVectorQueryBuilder));
        addInferenceIdClauses(boolQueryBuilder, knnVectorQueryBuilder, inferenceIdsIndices);
        return copyBoostAndQueryName(queryBuilder, boolQueryBuilder);
    }

    /**
     * Adds one {@code should} clause per inference ID to {@code boolQueryBuilder}. Each clause is a
     * nested semantic {@code knn} query scoped to the indices using that inference ID.
     */
    private void addInferenceIdClauses(BoolQueryBuilder boolQueryBuilder, KnnVectorQueryBuilder knnVectorQueryBuilder, Map<String, List<String>> inferenceIdsIndices) {
        for (Map.Entry<String, List<String>> entry : inferenceIdsIndices.entrySet()) {
            String inferenceId = entry.getKey();
            List<String> indices = entry.getValue();
            QueryBuilder nestedQuery = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, inferenceId);
            boolQueryBuilder.should(createSubQueryForIndices(indices, nestedQuery));
        }
    }

    /**
     * Copies the boost and query name of the intercepted query onto its rewritten replacement.
     *
     * <p>Without this, a boosted or named {@code knn} query would silently lose both once
     * intercepted. The sub-queries built by this class do not copy them, so applying them here on
     * the outermost query applies them exactly once.
     *
     * @param source the original intercepted query
     * @param target the rewritten query to update
     * @return {@code target}, for chaining
     */
    private static QueryBuilder copyBoostAndQueryName(QueryBuilder source, QueryBuilder target) {
        target.boost(source.boost());
        target.queryName(source.queryName());
        return target;
    }

    /**
     * Builds a {@code nested} query (score mode {@code max}) over the field's chunks. It wraps a
     * {@code knn} query on the chunk embeddings field and is restricted to {@code indices}.
     *
     * <p>If the query uses a {@link TextEmbeddingQueryVectorBuilder} without a model ID, the model ID
     * defaults to {@code searchInferenceId}.
     */
    private QueryBuilder buildNestedQueryFromKnnVectorQuery(KnnVectorQueryBuilder knnVectorQueryBuilder, List<String> indices, String searchInferenceId) {
        KnnVectorQueryBuilder filteredKnnVectorQueryBuilder = addIndexFilterToKnnVectorQuery(indices, knnVectorQueryBuilder);
        TextEmbeddingQueryVectorBuilder queryVectorBuilder = getTextEmbeddingQueryBuilderFromQuery(filteredKnnVectorQueryBuilder);
        if (queryVectorBuilder != null && queryVectorBuilder.getModelId() == null && searchInferenceId != null) {
            // No explicit model ID: embed the text with the inference ID configured for these indices.
            queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(searchInferenceId, queryVectorBuilder.getModelText());
        }
        String fieldName = filteredKnnVectorQueryBuilder.getFieldName();
        return QueryBuilders.nestedQuery(
            SemanticTextField.getChunksFieldName(fieldName),
            buildNewKnnVectorQuery(SemanticTextField.getEmbeddingsFieldName(fieldName), filteredKnnVectorQueryBuilder, queryVectorBuilder),
            ScoreMode.Max
        );
    }

    /**
     * Returns a copy of {@code original} that keeps its filters and adds a {@code terms} filter on
     * {@code _index}, restricting it to {@code indices}. The original query is not modified.
     */
    private KnnVectorQueryBuilder addIndexFilterToKnnVectorQuery(Collection<String> indices, KnnVectorQueryBuilder original) {
        KnnVectorQueryBuilder copy;
        if (original.queryVectorBuilder() != null) {
            copy = new KnnVectorQueryBuilder(original.getFieldName(), original.queryVectorBuilder(), original.k(), original.numCands(), original.getVectorSimilarity());
        } else {
            copy = new KnnVectorQueryBuilder(original.getFieldName(), original.queryVector(), original.k(), original.numCands(), original.rescoreVectorBuilder(), original.getVectorSimilarity());
        }
        copy.addFilterQueries(original.filterQueries());
        copy.addFilterQuery(new TermsQueryBuilder("_index", indices));
        return copy;
    }

    /**
     * Returns the query's vector builder as a {@link TextEmbeddingQueryVectorBuilder}, or {@code null}
     * when the query carries a raw query vector. Any other builder type is an assertion failure.
     */
    private TextEmbeddingQueryVectorBuilder getTextEmbeddingQueryBuilderFromQuery(KnnVectorQueryBuilder knnVectorQueryBuilder) {
        QueryVectorBuilder queryVectorBuilder = knnVectorQueryBuilder.queryVectorBuilder();
        if (queryVectorBuilder == null) {
            return null;
        }
        assert queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder;
        return (TextEmbeddingQueryVectorBuilder) queryVectorBuilder;
    }

    /**
     * Builds a {@code knn} query on {@code fieldName} that copies {@code original}'s parameters and
     * filters.
     *
     * <p>When {@code original} uses a query vector builder, {@code queryVectorBuilder} is used in its
     * place. Otherwise the raw query vector and rescore settings of {@code original} are reused.
     */
    private KnnVectorQueryBuilder buildNewKnnVectorQuery(String fieldName, KnnVectorQueryBuilder original, QueryVectorBuilder queryVectorBuilder) {
        KnnVectorQueryBuilder newQueryBuilder;
        if (original.queryVectorBuilder() != null) {
            newQueryBuilder = new KnnVectorQueryBuilder(fieldName, queryVectorBuilder, original.k(), original.numCands(), original.getVectorSimilarity());
        } else {
            newQueryBuilder = new KnnVectorQueryBuilder(fieldName, original.queryVector(), original.k(), original.numCands(), original.rescoreVectorBuilder(), original.getVectorSimilarity());
        }
        newQueryBuilder.addFilterQueries(original.filterQueries());
        return newQueryBuilder;
    }

    /** Returns the name of the query type this interceptor handles ({@code knn}). */
    @Override
    public String getQueryName() {
        return KnnVectorQueryBuilder.NAME;
    }
}

