Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -512,4 +512,5 @@
exports org.elasticsearch.index.mapper.blockloader;
exports org.elasticsearch.index.mapper.blockloader.docvalues;
exports org.elasticsearch.index.mapper.blockloader.docvalues.fn;
exports org.elasticsearch.search.diversification;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.diversification;

import org.elasticsearch.search.vectors.VectorData;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * {@link FieldVectorSupplier} that reads one vector per search hit from the fetched values of a single field.
 *
 * <p>For every hit, the values of the configured field are looked up in the hit's fetched fields. When those
 * values are a non-empty list whose first element is a {@link Float} or a {@link Byte}, they are converted to a
 * primitive array and wrapped in a {@link VectorData}. Hits that do not have the field, or whose values are of any
 * other type, are left out of the resulting map.
 *
 * <p>The resulting map is computed on the first call to {@link #getFieldVectors()} and cached afterwards.
 * No synchronization is performed.
 */
public class DenseVectorFieldVectorSupplier implements FieldVectorSupplier {

    private final String diversificationField;
    private final DiversifyRetrieverBuilder.RankDocWithSearchHit[] searchHits;
    // Cache filled by the first successful getFieldVectors() call; stays null until then.
    private Map<Integer, List<VectorData>> fieldVectors = null;

    /**
     * @param diversificationField name of the field whose fetched values hold the vector
     * @param hits                 the hits to read vectors from
     */
    public DenseVectorFieldVectorSupplier(String diversificationField, DiversifyRetrieverBuilder.RankDocWithSearchHit[] hits) {
        this.diversificationField = diversificationField;
        this.searchHits = hits;
    }

    /**
     * Returns the vectors found for the supplied hits, keyed by each hit's {@code rank}.
     * Every value is a single-element list holding the hit's vector.
     *
     * @return the cached map of rank to vector list; hits without a usable vector have no entry
     */
    @Override
    public Map<Integer, List<VectorData>> getFieldVectors() {
        if (fieldVectors != null) {
            return fieldVectors;
        }

        // Build into a local map and publish it only once complete, so that an exception thrown while
        // converting one hit cannot leave a partially populated map cached for later calls.
        Map<Integer, List<VectorData>> vectorsByRank = new HashMap<>();
        for (DiversifyRetrieverBuilder.RankDocWithSearchHit searchHit : searchHits) {
            VectorData vector = extractHitVector(diversificationField, searchHit);
            if (vector != null) {
                vectorsByRank.put(searchHit.rank, List.of(vector));
            }
        }

        fieldVectors = vectorsByRank;
        return fieldVectors;
    }

    /**
     * Tells whether the given hit carries values for {@code fieldName} that can be read as a float or byte vector.
     *
     * @param fieldName name of the field to inspect
     * @param hit       the hit whose fetched fields are inspected
     * @return {@code true} if the field is present and its values convert to a {@link VectorData}
     */
    public static boolean canFieldBeDenseVector(String fieldName, DiversifyRetrieverBuilder.RankDocWithSearchHit hit) {
        return extractHitVector(fieldName, hit) != null;
    }

    /**
     * Looks up {@code fieldName} in the hit's fetched fields and converts its values to a vector.
     *
     * @return the vector, or {@code null} if the field is absent or its values are not a float/byte list
     */
    private static VectorData extractHitVector(String fieldName, DiversifyRetrieverBuilder.RankDocWithSearchHit hit) {
        var field = hit.hit().getFields().get(fieldName);
        return field == null ? null : extractFieldVectorData(field.getValues());
    }

    /**
     * Converts a list of boxed floats or bytes into a {@link VectorData}.
     * Only the first element is inspected to pick the element type; every element is then cast to that type.
     *
     * @param fieldValues the raw field values, may be {@code null}
     * @return the vector, or {@code null} if the list is {@code null}, empty, or does not start with a
     *         {@link Float} or {@link Byte}
     */
    private static VectorData extractFieldVectorData(List<?> fieldValues) {
        if (fieldValues == null || fieldValues.isEmpty()) {
            return null;
        }

        Object first = fieldValues.getFirst();

        if (first instanceof Float) {
            // Unbox straight into the primitive array; no intermediate boxed array is needed.
            float[] floats = new float[fieldValues.size()];
            int index = 0;
            for (Object value : fieldValues) {
                floats[index++] = (Float) value;
            }
            return new VectorData(floats);
        }

        if (first instanceof Byte) {
            byte[] bytes = new byte[fieldValues.size()];
            int index = 0;
            for (Object value : fieldValues) {
                bytes[index++] = (Byte) value;
            }
            return new VectorData(bytes);
        }

        return null;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
Expand All @@ -36,11 +38,8 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

Expand Down Expand Up @@ -302,8 +301,16 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {

@Override
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
SearchSourceBuilder builder = sourceBuilder.from(0);
return super.finalizeSourceBuilder(builder).docValueField(diversificationField);
StoredFieldsContext sfCtx = StoredFieldsContext.fromList(List.of("_inference_fields", diversificationField));
FetchSourceContext fsCtx = FetchSourceContext.of(false, false, new String[] { "_inference_fields", diversificationField }, null);

SearchSourceBuilder builder = sourceBuilder.from(0)
.excludeVectors(false)
.storedFields(sfCtx)
.fetchSource(fsCtx)
.fetchField("_inference_fields")
.fetchField(diversificationField);
return super.finalizeSourceBuilder(builder);
}

@Override
Expand Down Expand Up @@ -344,26 +351,26 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
return new RankDoc[0];
}

ResultDiversificationContext diversificationContext = getResultDiversificationContext();

// gather and set the query vectors
// and create our intermediate results set
RankDoc[] results = new RankDoc[scoreDocs.length];
Map<Integer, VectorData> fieldVectors = new HashMap<>();
RankDocWithSearchHit[] results = new RankDocWithSearchHit[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
RankDocWithSearchHit asRankDoc = (RankDocWithSearchHit) scoreDocs[i];
results[i] = asRankDoc;

var field = asRankDoc.hit().getFields().getOrDefault(diversificationField, null);
if (field != null) {
var fieldValue = field.getValue();
if (fieldValue != null) {
extractFieldVectorData(asRankDoc.rank, fieldValue, fieldVectors);
}
}
results[i] = (RankDocWithSearchHit) scoreDocs[i];
}

ResultDiversificationContext diversificationContext = getResultDiversificationContext();

// temporary
int vectorCount = 0;
if (SemanticTextFieldVectorSupplier.isFieldSemanticTextVector(diversificationField, results[0])) {
FieldVectorSupplier fieldVectorSupplier = new SemanticTextFieldVectorSupplier(diversificationField, results);
vectorCount = diversificationContext.setFieldVectors(fieldVectorSupplier);
} else if (DenseVectorFieldVectorSupplier.canFieldBeDenseVector(diversificationField, results[0])) {
FieldVectorSupplier fieldVectorSupplier = new DenseVectorFieldVectorSupplier(diversificationField, results);
vectorCount = diversificationContext.setFieldVectors(fieldVectorSupplier);
}

if (fieldVectors.isEmpty()) {
if (vectorCount == 0) {
throw new ElasticsearchStatusException(
String.format(
Locale.ROOT,
Expand All @@ -374,8 +381,6 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
);
}

diversificationContext.setFieldVectors(fieldVectors);

try {
ResultDiversification<?> diversification = ResultDiversificationFactory.getDiversifier(
diversificationType,
Expand All @@ -397,72 +402,6 @@ private ResultDiversificationContext getResultDiversificationContext() {
throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]");
}

private void extractFieldVectorData(int docId, Object fieldValue, Map<Integer, VectorData> fieldVectors) {
switch (fieldValue) {
case float[] floatArray -> {
fieldVectors.put(docId, new VectorData(floatArray));
return;
}
case byte[] byteArray -> {
fieldVectors.put(docId, new VectorData(byteArray));
return;
}
case Float[] boxedFloatArray -> {
fieldVectors.put(docId, new VectorData(unboxedFloatArray(boxedFloatArray)));
return;
}
case Byte[] boxedByteArray -> {
fieldVectors.put(docId, new VectorData(unboxedByteArray(boxedByteArray)));
return;
}
default -> {
}
}

// CCS search returns a generic Object[] array, so we must
// examine the individual element type here.
if (fieldValue instanceof Object[] objectArray) {
if (objectArray.length == 0) {
return;
}

if (objectArray[0] instanceof Byte) {
Byte[] asByteArray = Arrays.stream(objectArray).map(x -> (Byte) x).toArray(Byte[]::new);
fieldVectors.put(docId, new VectorData(unboxedByteArray(asByteArray)));
return;
}

if (objectArray[0] instanceof Float) {
Float[] asFloatArray = Arrays.stream(objectArray).map(x -> (Float) x).toArray(Float[]::new);
fieldVectors.put(docId, new VectorData(unboxedFloatArray(asFloatArray)));
return;
}
}

throw new ElasticsearchStatusException(
String.format(Locale.ROOT, "Failed to retrieve vectors for field [%s]. Is it a [dense_vector] field?", diversificationField),
RestStatus.BAD_REQUEST
);
}

private static float[] unboxedFloatArray(Float[] array) {
float[] unboxedArray = new float[array.length];
int bIndex = 0;
for (Float b : array) {
unboxedArray[bIndex++] = b;
}
return unboxedArray;
}

private static byte[] unboxedByteArray(Byte[] array) {
byte[] unboxedArray = new byte[array.length];
int bIndex = 0;
for (Byte b : array) {
unboxedArray[bIndex++] = b;
}
return unboxedArray;
}

@Override
public String getName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.diversification;

import org.elasticsearch.search.vectors.VectorData;

import java.util.List;
import java.util.Map;

/**
 * Source of the per-document vectors handed to a result diversification context.
 */
public interface FieldVectorSupplier {
/**
 * Returns the vectors available for each document.
 *
 * @return a map from an integer document key to the list of vectors for that document.
 *         NOTE(review): the key is presumably the document's rank in the result set, as
 *         {@code DenseVectorFieldVectorSupplier} uses it - confirm this holds for every implementation.
 */
Map<Integer, List<VectorData>> getFieldVectors();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.diversification;

import org.elasticsearch.inference.ChunkedInference;

import java.util.List;

/**
 * Provides lists of {@link ChunkedInference.Chunk} looked up by a string key.
 */
public interface InferenceChunkSupplier {
/**
 * Returns the chunks associated with the given key.
 *
 * @param key the lookup key. NOTE(review): what the key identifies (field name, document id, ...) is not
 *            visible from this interface - confirm against the implementations.
 * @return the chunks for {@code key}; whether an unknown key yields {@code null} or an empty list is not
 *         specified here - TODO confirm
 */
List<ChunkedInference.Chunk> getChunks(String key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.vectors.VectorData;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
Expand All @@ -20,7 +21,10 @@ public abstract class ResultDiversificationContext {
private final String field;
private final int size;
private final Supplier<VectorData> queryVector;
private Map<Integer, VectorData> fieldVectors = null;
private Map<Integer, List<VectorData>> fieldVectors = null;

private boolean retrievedQueryVector = false;
private VectorData realizedQueryVector = null;

protected ResultDiversificationContext(String field, int size, @Nullable Supplier<VectorData> queryVector) {
this.field = field;
Expand All @@ -36,24 +40,25 @@ public int getSize() {
return size;
}

/**
* Sets the field vectors for this context.
* Note that the key should be the `RankDoc` rank in the total result set
* @param fieldVectors the vectors to set
*/
public void setFieldVectors(Map<Integer, VectorData> fieldVectors) {
this.fieldVectors = fieldVectors;
public int setFieldVectors(FieldVectorSupplier fieldVectorSupplier) {
this.fieldVectors = fieldVectorSupplier.getFieldVectors();
return this.fieldVectors.size();
}

public VectorData getQueryVector() {
return queryVector == null ? null : queryVector.get();
if (retrievedQueryVector) {
return realizedQueryVector;
}
realizedQueryVector = queryVector == null ? null : queryVector.get();
retrievedQueryVector = true;
return realizedQueryVector;
}

public VectorData getFieldVector(int rank) {
public List<VectorData> getFieldVectorData(int rank) {
return fieldVectors.getOrDefault(rank, null);
}

public Set<Map.Entry<Integer, VectorData>> getFieldVectorsEntrySet() {
public Set<Map.Entry<Integer, List<VectorData>>> getFieldVectorsEntrySet() {
return fieldVectors.entrySet();
}
}
Loading