diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index d649651913fe2..eca2c9e1c8e0f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -512,4 +512,5 @@ exports org.elasticsearch.index.mapper.blockloader; exports org.elasticsearch.index.mapper.blockloader.docvalues; exports org.elasticsearch.index.mapper.blockloader.docvalues.fn; + exports org.elasticsearch.search.diversification; } diff --git a/server/src/main/java/org/elasticsearch/search/diversification/DenseVectorFieldVectorSupplier.java b/server/src/main/java/org/elasticsearch/search/diversification/DenseVectorFieldVectorSupplier.java new file mode 100644 index 0000000000000..4ca45ed5ab69e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/DenseVectorFieldVectorSupplier.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.elasticsearch.search.vectors.VectorData; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DenseVectorFieldVectorSupplier implements FieldVectorSupplier { + + private final String diversificationField; + private final DiversifyRetrieverBuilder.RankDocWithSearchHit[] searchHits; + private Map> fieldVectors = null; + + public DenseVectorFieldVectorSupplier(String diversificationField, DiversifyRetrieverBuilder.RankDocWithSearchHit[] hits) { + this.diversificationField = diversificationField; + this.searchHits = hits; + } + + @Override + public Map> getFieldVectors() { + if (fieldVectors != null) { + return fieldVectors; + } + + fieldVectors = new HashMap<>(); + for (DiversifyRetrieverBuilder.RankDocWithSearchHit searchHit : searchHits) { + var field = searchHit.hit().getFields().getOrDefault(diversificationField, null); + if (field != null) { + VectorData vector = extractFieldVectorData(field.getValues()); + if (vector != null) { + fieldVectors.put(searchHit.rank, List.of(vector)); + } + } + } + + return fieldVectors; + } + + public static boolean canFieldBeDenseVector(String fieldName, DiversifyRetrieverBuilder.RankDocWithSearchHit hit) { + var field = hit.hit().getFields().getOrDefault(fieldName, null); + if (field == null) { + return false; + } + + VectorData vector = extractFieldVectorData(field.getValues()); + return vector != null; + } + + private static float[] unboxedFloatArray(Float[] array) { + float[] unboxedArray = new float[array.length]; + int bIndex = 0; + for (Float b : array) { + unboxedArray[bIndex++] = b; + } + return unboxedArray; + } + + private static byte[] unboxedByteArray(Byte[] array) { + byte[] unboxedArray = new byte[array.length]; + int bIndex = 0; + for (Byte b : array) { + unboxedArray[bIndex++] = b; + } + return unboxedArray; + } + + private static VectorData extractFieldVectorData(List fieldValues) { + if (fieldValues == null || fieldValues.isEmpty()) { + return null; + } + + if (fieldValues.getFirst() instanceof Float) { + Float[] asFloatArray = fieldValues.stream().map(x -> (Float) x).toArray(Float[]::new); + return new VectorData(unboxedFloatArray(asFloatArray)); + } + + if (fieldValues.getFirst() instanceof Byte) { + Byte[] asByteArray = fieldValues.stream().map(x -> (Byte) x).toArray(Byte[]::new); + return new VectorData(unboxedByteArray(asByteArray)); + } + + return null; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/DenseVectorSupplierField.java b/server/src/main/java/org/elasticsearch/search/diversification/DenseVectorSupplierField.java new file mode 100644 index 0000000000000..f156810c726af --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/DenseVectorSupplierField.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.elasticsearch.search.vectors.VectorData; + +import java.util.List; + +public interface DenseVectorSupplierField { + List getVectorData(String key); +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java index 320c05ab559fc..6649a8bc13174 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/DiversifyRetrieverBuilder.java @@ -23,6 +23,8 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; @@ -36,11 +38,8 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; -import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Objects; import java.util.function.Supplier; @@ -302,8 +301,16 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { - SearchSourceBuilder builder = sourceBuilder.from(0); - return super.finalizeSourceBuilder(builder).docValueField(diversificationField); + StoredFieldsContext sfCtx = StoredFieldsContext.fromList(List.of("_inference_fields", diversificationField)); + FetchSourceContext fsCtx = FetchSourceContext.of(false, false, new String[] { "_inference_fields", diversificationField }, null); + + SearchSourceBuilder builder = sourceBuilder.from(0) + .excludeVectors(false) + .storedFields(sfCtx) + .fetchSource(fsCtx) + .fetchField("_inference_fields") + .fetchField(diversificationField); + return super.finalizeSourceBuilder(builder); } @Override @@ -344,26 +351,26 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b return new RankDoc[0]; } - ResultDiversificationContext diversificationContext = getResultDiversificationContext(); - // gather and set the query vectors // and create our intermediate results set - RankDoc[] results = new RankDoc[scoreDocs.length]; - Map fieldVectors = new HashMap<>(); + RankDocWithSearchHit[] results = new RankDocWithSearchHit[scoreDocs.length]; for (int i = 0; i < scoreDocs.length; i++) { - RankDocWithSearchHit asRankDoc = (RankDocWithSearchHit) scoreDocs[i]; - results[i] = asRankDoc; - - var field = asRankDoc.hit().getFields().getOrDefault(diversificationField, null); - if (field != null) { - var fieldValue = field.getValue(); - if (fieldValue != null) { - extractFieldVectorData(asRankDoc.rank, fieldValue, fieldVectors); - } - } + results[i] = (RankDocWithSearchHit) scoreDocs[i]; + } + + ResultDiversificationContext diversificationContext = getResultDiversificationContext(); + + // temporary + int vectorCount = 0; + if (SemanticTextFieldVectorSupplier.isFieldSemanticTextVector(diversificationField, results[0])) { + FieldVectorSupplier fieldVectorSupplier = new SemanticTextFieldVectorSupplier(diversificationField, results); + vectorCount = diversificationContext.setFieldVectors(fieldVectorSupplier); + } else if (DenseVectorFieldVectorSupplier.canFieldBeDenseVector(diversificationField, results[0])) { + FieldVectorSupplier fieldVectorSupplier = new DenseVectorFieldVectorSupplier(diversificationField, results); + vectorCount = diversificationContext.setFieldVectors(fieldVectorSupplier); } - if (fieldVectors.isEmpty()) { + if (vectorCount == 0) { throw new ElasticsearchStatusException( String.format( Locale.ROOT, @@ -374,8 +381,6 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b ); } - diversificationContext.setFieldVectors(fieldVectors); - try { ResultDiversification diversification = ResultDiversificationFactory.getDiversifier( diversificationType, @@ -397,72 +402,6 @@ private ResultDiversificationContext getResultDiversificationContext() { throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]"); } - private void extractFieldVectorData(int docId, Object fieldValue, Map fieldVectors) { - switch (fieldValue) { - case float[] floatArray -> { - fieldVectors.put(docId, new VectorData(floatArray)); - return; - } - case byte[] byteArray -> { - fieldVectors.put(docId, new VectorData(byteArray)); - return; - } - case Float[] boxedFloatArray -> { - fieldVectors.put(docId, new VectorData(unboxedFloatArray(boxedFloatArray))); - return; - } - case Byte[] boxedByteArray -> { - fieldVectors.put(docId, new VectorData(unboxedByteArray(boxedByteArray))); - return; - } - default -> { - } - } - - // CCS search returns a generic Object[] array, so we must - // examine the individual element type here. - if (fieldValue instanceof Object[] objectArray) { - if (objectArray.length == 0) { - return; - } - - if (objectArray[0] instanceof Byte) { - Byte[] asByteArray = Arrays.stream(objectArray).map(x -> (Byte) x).toArray(Byte[]::new); - fieldVectors.put(docId, new VectorData(unboxedByteArray(asByteArray))); - return; - } - - if (objectArray[0] instanceof Float) { - Float[] asFloatArray = Arrays.stream(objectArray).map(x -> (Float) x).toArray(Float[]::new); - fieldVectors.put(docId, new VectorData(unboxedFloatArray(asFloatArray))); - return; - } - } - - throw new ElasticsearchStatusException( - String.format(Locale.ROOT, "Failed to retrieve vectors for field [%s]. Is it a [dense_vector] field?", diversificationField), - RestStatus.BAD_REQUEST - ); - } - - private static float[] unboxedFloatArray(Float[] array) { - float[] unboxedArray = new float[array.length]; - int bIndex = 0; - for (Float b : array) { - unboxedArray[bIndex++] = b; - } - return unboxedArray; - } - - private static byte[] unboxedByteArray(Byte[] array) { - byte[] unboxedArray = new byte[array.length]; - int bIndex = 0; - for (Byte b : array) { - unboxedArray[bIndex++] = b; - } - return unboxedArray; - } - @Override public String getName() { return NAME; diff --git a/server/src/main/java/org/elasticsearch/search/diversification/FieldVectorSupplier.java b/server/src/main/java/org/elasticsearch/search/diversification/FieldVectorSupplier.java new file mode 100644 index 0000000000000..49c996e486802 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/FieldVectorSupplier.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.elasticsearch.search.vectors.VectorData; + +import java.util.List; +import java.util.Map; + +public interface FieldVectorSupplier { + Map> getFieldVectors(); +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java index 09c1fd92e6335..b6169059170c5 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.search.vectors.VectorData; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Supplier; @@ -20,7 +21,10 @@ public abstract class ResultDiversificationContext { private final String field; private final int size; private final Supplier queryVector; - private Map fieldVectors = null; + private Map> fieldVectors = null; + + private boolean retrievedQueryVector = false; + private VectorData realizedQueryVector = null; protected ResultDiversificationContext(String field, int size, @Nullable Supplier queryVector) { this.field = field; @@ -36,24 +40,25 @@ public int getSize() { return size; } - /** - * Sets the field vectors for this context. - * Note that the key should be the `RankDoc` rank in the total result set - * @param fieldVectors the vectors to set - */ - public void setFieldVectors(Map fieldVectors) { - this.fieldVectors = fieldVectors; + public int setFieldVectors(FieldVectorSupplier fieldVectorSupplier) { + this.fieldVectors = fieldVectorSupplier.getFieldVectors(); + return this.fieldVectors.size(); } public VectorData getQueryVector() { - return queryVector == null ? null : queryVector.get(); + if (retrievedQueryVector) { + return realizedQueryVector; + } + realizedQueryVector = queryVector == null ? null : queryVector.get(); + retrievedQueryVector = true; + return realizedQueryVector; } - public VectorData getFieldVector(int rank) { + public List getFieldVectorData(int rank) { return fieldVectors.getOrDefault(rank, null); } - public Set> getFieldVectorsEntrySet() { + public Set>> getFieldVectorsEntrySet() { return fieldVectors.entrySet(); } } diff --git a/server/src/main/java/org/elasticsearch/search/diversification/SemanticTextFieldVectorSupplier.java b/server/src/main/java/org/elasticsearch/search/diversification/SemanticTextFieldVectorSupplier.java new file mode 100644 index 0000000000000..1c0b5bb32e7fd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/SemanticTextFieldVectorSupplier.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.elasticsearch.search.vectors.VectorData; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class SemanticTextFieldVectorSupplier implements FieldVectorSupplier { + + private final String diversificationField; + private final DiversifyRetrieverBuilder.RankDocWithSearchHit[] searchHits; + private Map> fieldVectors = null; + + public SemanticTextFieldVectorSupplier(String diversificationField, DiversifyRetrieverBuilder.RankDocWithSearchHit[] hits) { + this.diversificationField = diversificationField; + this.searchHits = hits; + } + + @Override + public Map> getFieldVectors() { + if (fieldVectors != null) { + return fieldVectors; + } + + fieldVectors = new HashMap<>(); + + for (DiversifyRetrieverBuilder.RankDocWithSearchHit hit : searchHits) { + var inferenceFieldValue = hit.hit().getFields().getOrDefault("_inference_fields", null); + if (inferenceFieldValue == null) { + continue; + } + + var fieldValues = inferenceFieldValue.getValues(); + if (fieldValues == null || fieldValues.isEmpty()) { + continue; + } + + if (fieldValues.getFirst() instanceof Map mappedValues) { + var fieldValue = mappedValues.getOrDefault(diversificationField, null); + if (fieldValue instanceof DenseVectorSupplierField vectorSupplier) { + List vectorData = vectorSupplier.getVectorData(diversificationField); + if (vectorData != null && vectorData.isEmpty() == false) { + fieldVectors.put(hit.rank, vectorData); + } + } + } + } + + return fieldVectors; + } + + private static float[] toFloatArray(byte[] values) { + float[] floatArray = new float[values.length]; + for (int i = 0; i < values.length; i++) { + floatArray[i] = ((Byte) values[i]).floatValue(); + } + return floatArray; + } + + public static boolean isFieldSemanticTextVector(String fieldName, DiversifyRetrieverBuilder.RankDocWithSearchHit hit) { + var inferenceFieldValue = hit.hit().getFields().getOrDefault("_inference_fields", null); + if (inferenceFieldValue == null) { + return false; + } + + var fieldValues = inferenceFieldValue.getValues(); + if (fieldValues == null || fieldValues.isEmpty()) { + return false; + } + + if (fieldValues.getFirst() instanceof Map mappedValues) { + var fieldValue = mappedValues.getOrDefault(fieldName, null); + if (fieldValue instanceof DenseVectorSupplierField vectorSupplier) { + List vectorData = vectorSupplier.getVectorData(fieldName); + return (vectorData != null && vectorData.isEmpty() == false); + } + } + + return false; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java index 861b7e9130a63..1c1d303ded819 100644 --- a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java +++ b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java @@ -46,8 +46,8 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { List selectedDocRanks = new ArrayList<>(); // test the vector to see if we are using floats or bytes - VectorData firstVec = context.getFieldVector(docs[0].rank); - boolean useFloat = firstVec.isFloat(); + List firstVecData = context.getFieldVectorData(docs[0].rank); + boolean useFloat = firstVecData.getFirst().isFloat(); // cache the similarity scores for the query vector vs. searchHits Map querySimilarity = getQuerySimilarityForDocs(docs, useFloat, context); @@ -68,7 +68,8 @@ public RankDoc[] diversify(RankDoc[] docs) throws IOException { continue; } - var thisDocVector = context.getFieldVector(docRank); + // TODO - deal with multiple vectors to choose best + var thisDocVector = context.getFieldVectorData(docRank).getFirst(); if (thisDocVector == null) { continue; } @@ -144,7 +145,8 @@ private float getHighestScoreForSelectedVectors( highestScore = score; } } else { - VectorData comparisonVector = vec.getValue(); + // TODO - deal with multiple vectors to choose best + VectorData comparisonVector = vec.getValue().getFirst(); float score = useFloat ? getFloatVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector) : getByteVectorComparisonScore(similarityFunction, thisDocVector, comparisonVector); @@ -166,7 +168,8 @@ protected Map getQuerySimilarityForDocs(RankDoc[] docs, boolean } for (RankDoc doc : docs) { - VectorData vectorData = context.getFieldVector(doc.rank); + // TODO - deal with multiple vectors to choose best + VectorData vectorData = context.getFieldVectorData(doc.rank).getFirst(); if (vectorData != null) { float querySimilarityScore = useFloat ? getFloatVectorComparisonScore(similarityFunction, vectorData, queryVector) diff --git a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java index 7f8eb3509613d..9e902e9983a80 100644 --- a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java +++ b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.search.diversification.FieldVectorSupplier; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.test.ESTestCase; @@ -74,19 +75,21 @@ private MMRResultDiversificationContext getRandomFloatContext(List expe Supplier queryVectorData = () -> new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); var diversificationContext = new MMRResultDiversificationContext("dense_vector_field", 0.3f, 3, queryVectorData); diversificationContext.setFieldVectors( - Map.of( - 1, - new VectorData(new float[] { 0.4f, 0.2f, 0.4f, 0.4f }), - 2, - new VectorData(new float[] { 0.4f, 0.2f, 0.3f, 0.3f }), - 3, - new VectorData(new float[] { 0.4f, 0.1f, 0.3f, 0.3f }), - 4, - new VectorData(new float[] { 0.1f, 0.9f, 0.5f, 0.9f }), - 5, - new VectorData(new float[] { 0.1f, 0.9f, 0.5f, 0.9f }), - 6, - new VectorData(new float[] { 0.05f, 0.05f, 0.05f, 0.05f }) + new MockFieldVectorSuppler( + Map.of( + 1, + List.of(new VectorData(new float[] { 0.4f, 0.2f, 0.4f, 0.4f })), + 2, + List.of(new VectorData(new float[] { 0.4f, 0.2f, 0.3f, 0.3f })), + 3, + List.of(new VectorData(new float[] { 0.4f, 0.1f, 0.3f, 0.3f })), + 4, + List.of(new VectorData(new float[] { 0.1f, 0.9f, 0.5f, 0.9f })), + 5, + List.of(new VectorData(new float[] { 0.1f, 0.9f, 0.5f, 0.9f })), + 6, + List.of(new VectorData(new float[] { 0.05f, 0.05f, 0.05f, 0.05f })) + ) ) ); @@ -109,19 +112,21 @@ private MMRResultDiversificationContext getRandomByteContext(List expec Supplier queryVectorData = () -> new VectorData(new byte[] { 0x50, 0x20, 0x40, 0x40 }); var diversificationContext = new MMRResultDiversificationContext("dense_vector_field", 0.3f, 3, queryVectorData); diversificationContext.setFieldVectors( - Map.of( - 1, - new VectorData(new byte[] { 0x40, 0x20, 0x40, 0x40 }), - 2, - new VectorData(new byte[] { 0x40, 0x20, 0x30, 0x30 }), - 3, - new VectorData(new byte[] { 0x40, 0x10, 0x30, 0x30 }), - 4, - new VectorData(new byte[] { 0x10, (byte) 0x90, 0x50, (byte) 0x90 }), - 5, - new VectorData(new byte[] { 0x10, (byte) 0x90, 0x50, (byte) 0x90 }), - 6, - new VectorData(new byte[] { 0x50, 0x50, 0x50, 0x50 }) + new MockFieldVectorSuppler( + Map.of( + 1, + List.of(new VectorData(new byte[] { 0x40, 0x20, 0x40, 0x40 })), + 2, + List.of(new VectorData(new byte[] { 0x40, 0x20, 0x30, 0x30 })), + 3, + List.of(new VectorData(new byte[] { 0x40, 0x10, 0x30, 0x30 })), + 4, + List.of(new VectorData(new byte[] { 0x10, (byte) 0x90, 0x50, (byte) 0x90 })), + 5, + List.of(new VectorData(new byte[] { 0x10, (byte) 0x90, 0x50, (byte) 0x90 })), + 6, + List.of(new VectorData(new byte[] { 0x50, 0x50, 0x50, 0x50 })) + ) ) ); @@ -151,4 +156,17 @@ public void testMMRDiversificationIfNoSearchHits() throws IOException { assertSame(emptyDocs, resultDiversification.diversify(emptyDocs)); assertNull(resultDiversification.diversify(null)); } + + private class MockFieldVectorSuppler implements FieldVectorSupplier { + private final Map> vectors; + + MockFieldVectorSuppler(Map> vectors) { + this.vectors = vectors; + } + + @Override + public Map> getFieldVectors() { + return vectors; + } + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index eaece2974ba64..f3b3ce472e09e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -14,9 +14,12 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; +import org.elasticsearch.search.diversification.DenseVectorSupplierField; +import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -58,7 +61,7 @@ public record SemanticTextField( @Nullable List originalValues, InferenceResult inference, XContentType contentType -) implements ToXContentObject { +) implements ToXContentObject, DenseVectorSupplierField { static final String TEXT_FIELD = "text"; static final String INFERENCE_FIELD = "inference"; @@ -82,6 +85,91 @@ public record InferenceResult( public record Chunk(@Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {} + @Override + public List getVectorData(String key) { + if (this.inference == null) { + return Collections.emptyList(); + } + + if (this.inference.chunks() == null) { + return Collections.emptyList(); + } + + var elementType = this.inference().modelSettings().elementType(); + var dimensions = this.inference().modelSettings().dimensions(); + if (elementType == null || dimensions == null) { + return Collections.emptyList(); + } + + int embeddingLength = getEmbeddingLength(elementType, dimensions); + + List chunks = this.inference.chunks().getOrDefault(key, Collections.emptyList()); + List theseVectors = new ArrayList<>(); + for (Chunk chunk : chunks) { + BytesReference embeddingsBytes = chunk.rawEmbeddings(); + if (embeddingsBytes == null) { + continue; + } + double[] values = parseDenseVector(chunk.rawEmbeddings(), embeddingLength, this.contentType()); + if (values == null) { + continue; + } + + switch (elementType) { + case FLOAT, BFLOAT16 -> theseVectors.add(new VectorData(floatArrayOf(values))); + case BYTE, BIT -> theseVectors.add(new VectorData(byteArrayOf(values))); + } + } + return theseVectors; + } + + public static float[] floatArrayOf(double[] doublesArray) { + var floatArray = new float[doublesArray.length]; + for (int i = 0; i < doublesArray.length; i++) { + floatArray[i] = (float) doublesArray[i]; + } + return floatArray; + } + + private static byte[] byteArrayOf(double[] doublesArray) { + // It's fine to not check if the double values are out of range here because if any are, equality assertions on the expected vs. + // actual chunks will fail downstream + byte[] byteArray = new byte[doublesArray.length]; + for (int i = 0; i < doublesArray.length; i++) { + byteArray[i] = (byte) doublesArray[i]; + } + return byteArray; + } + + private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { + return switch (elementType) { + case FLOAT, BFLOAT16, BYTE -> dimensions; + case BIT -> { + assert dimensions % Byte.SIZE == 0; + yield dimensions / Byte.SIZE; + } + }; + } + + private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) { + try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { + parser.nextToken(); + if (parser.currentToken() != XContentParser.Token.START_ARRAY) { + return null; + } + double[] values = new double[numDims]; + for (int i = 0; i < numDims; i++) { + if (parser.nextToken() == XContentParser.Token.END_ARRAY) { + return values; + } + values[i] = parser.doubleValue(); + } + return values; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + public static String getOriginalTextFieldName(String fieldName) { return fieldName + "." + TEXT_FIELD; }