Skip to content

Commit

Permalink
Add vectorless search to record collection interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
westey-m committed Oct 15, 2024
1 parent 524d9ab commit ba92d42
Show file tree
Hide file tree
Showing 26 changed files with 635 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public void BuildFilterStringBuildsCorrectEqualityStringForEachFilterType(string
var filter = new VectorSearchFilter().EqualTo(fieldName, fieldValue!);

// Act.
var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary<string, string> { { fieldName, "storage_" + fieldName } });
var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter.FilterClauses, new Dictionary<string, string> { { fieldName, "storage_" + fieldName } });

// Assert.
Assert.Equal(expected, actual);
Expand All @@ -34,7 +34,7 @@ public void BuildFilterStringBuildsCorrectTagContainsString()
var filter = new VectorSearchFilter().AnyTagEqualTo("Tags", "mytag");

// Act.
var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary<string, string> { { "Tags", "storage_tags" } });
var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter.FilterClauses, new Dictionary<string, string> { { "Tags", "storage_tags" } });

// Assert.
Assert.Equal("storage_tags/any(t: t eq 'mytag')", actual);
Expand All @@ -47,7 +47,7 @@ public void BuildFilterStringCombinesFilterOptions()
var filter = new VectorSearchFilter().EqualTo("intField", 5).AnyTagEqualTo("Tags", "mytag");

// Act.
var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary<string, string> { { "Tags", "storage_tags" }, { "intField", "storage_intField" } });
var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter.FilterClauses, new Dictionary<string, string> { { "Tags", "storage_tags" }, { "intField", "storage_intField" } });

// Assert.
Assert.Equal("storage_intField eq 5 and storage_tags/any(t: t eq 'mytag')", actual);
Expand All @@ -57,8 +57,8 @@ public void BuildFilterStringCombinesFilterOptions()
public void BuildFilterStringThrowsForUnknownPropertyName()
{
// Act and assert.
// Every call filters on the field "unknown" while passing an empty storage property name map,
// and each is expected to throw InvalidOperationException: once for an equality clause and
// once for a tag clause.
// NOTE(review): the first pair passes the filter object itself and the second pair passes its
// FilterClauses; these look like the before/after lines of a signature change - confirm which
// overload actually exists.
Assert.Throws<InvalidOperationException>(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().EqualTo("unknown", "value"), new Dictionary<string, string>()));
Assert.Throws<InvalidOperationException>(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value"), new Dictionary<string, string>()));
Assert.Throws<InvalidOperationException>(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().EqualTo("unknown", "value").FilterClauses, new Dictionary<string, string>()));
Assert.Throws<InvalidOperationException>(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value").FilterClauses, new Dictionary<string, string>()));
}

public static IEnumerable<object[]> DataTypeMappingOptions()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,40 @@ public async Task CanUpsertManyRecordsAsync<TKey>(bool useDefinition, TKey testK
Assert.Equal($"data {testKey1}", (collection[testKey1] as SinglePropsModel<TKey>)!.Data);
}

[Theory]
[InlineData(true, TestRecordKey1, TestRecordKey2)]
[InlineData(true, TestRecordIntKey1, TestRecordIntKey2)]
[InlineData(false, TestRecordKey1, TestRecordKey2)]
[InlineData(false, TestRecordIntKey1, TestRecordIntKey2)]
public async Task CanSearchWithoutVectorAsync<TKey>(bool useDefinition, TKey testKey1, TKey testKey2)
    where TKey : notnull
{
    // Arrange: seed the backing store with two records so the filter has something to exclude.
    var seededRecords = new ConcurrentDictionary<object, object>();
    seededRecords.TryAdd(testKey1, CreateModel(testKey1, withVectors: true, new float[] { 1, 1, 1, 1 }));
    seededRecords.TryAdd(testKey2, CreateModel(testKey2, withVectors: true, new float[] { -1, -1, -1, -1 }));
    this._collectionStore.TryAdd(TestCollectionName, seededRecords);

    var sut = this.CreateRecordCollection<TKey>(useDefinition);

    // Act: search with only a filter on the Data property of the second record.
    var searchOptions = new VectorlessSearchOptions
    {
        IncludeVectors = true,
        Filter = new VectorlessSearchFilter().EqualTo("Data", $"data {testKey2}"),
    };
    var searchResults = await sut.VectorlessSearchAsync(searchOptions, this._testCancellationToken);

    // Assert: no total count was requested, and only the second record comes back.
    Assert.NotNull(searchResults);
    Assert.Null(searchResults.TotalCount);
    var matches = await searchResults.Results.ToListAsync();
    Assert.Single(matches);
    Assert.Contains(matches, match => match.Key!.Equals(testKey2));
}

[Theory]
[InlineData(true, TestRecordKey1, TestRecordKey2)]
[InlineData(true, TestRecordIntKey1, TestRecordIntKey2)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ internal static class AzureAISearchVectorStoreCollectionSearchMapping
/// <summary>
/// Build an OData filter string from the provided <see cref="FilterClause"/> objects.
/// </summary>
/// <param name="basicVectorSearchFilter">The <see cref="VectorSearchFilter"/> to build an OData filter string from.</param>
/// <param name="filterClausees">The <see cref="FilterClause"/> objects to build an OData filter string from.</param>
/// <param name="storagePropertyNames">A mapping of data model property names to the names under which they are stored.</param>
/// <returns>The OData filter string.</returns>
/// <exception cref="InvalidOperationException">Thrown when a provided filter value is not supported.</exception>
public static string BuildFilterString(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary<string, string> storagePropertyNames)
public static string BuildFilterString(IEnumerable<FilterClause>? filterClausees, IReadOnlyDictionary<string, string> storagePropertyNames)
{
var filterString = string.Empty;
if (basicVectorSearchFilter?.FilterClauses is not null)
if (filterClausees is not null)
{
// Map Equality clauses.
var filterStrings = basicVectorSearchFilter?.FilterClauses.OfType<EqualToFilterClause>().Select(x =>
var filterStrings = filterClausees.OfType<EqualToFilterClause>().Select(x =>
{
string storageFieldName = GetStoragePropertyName(storagePropertyNames, x.FieldName);
Expand All @@ -46,7 +46,7 @@ public static string BuildFilterString(VectorSearchFilter? basicVectorSearchFilt
});

// Map tag contains clauses.
var tagListContainsStrings = basicVectorSearchFilter?.FilterClauses
var tagListContainsStrings = filterClausees
.OfType<AnyTagEqualToFilterClause>()
.Select(x =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,36 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(IEnumerable<TRecord> reco
foreach (var resultKey in resultKeys) { yield return resultKey; }
}

/// <inheritdoc />
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
{
    // Fall back to default options when the caller supplies none.
    var resolvedOptions = options ?? new VectorlessSearchOptions();

    // Translate the filter clauses (if any) into an OData filter over the stored property names.
    var odataFilter = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(
        resolvedOptions.Filter?.FilterClauses,
        this._propertyReader.JsonPropertyNamesMap);

    // Assemble the search request: paging, filter and optional total count.
    var request = new SearchOptions();
    request.VectorSearch = new();
    request.Size = resolvedOptions.Top;
    request.Skip = resolvedOptions.Skip;
    request.Filter = odataFilter;
    request.IncludeTotalCount = resolvedOptions.IncludeTotalCount;

    // When vectors are not wanted, select only the key and data fields.
    var includeVectors = resolvedOptions.IncludeVectors;
    if (includeVectors == false)
    {
        request.Select.Add(this._propertyReader.KeyPropertyJsonName);
        request.Select.AddRange(this._propertyReader.DataPropertyJsonNames);
    }

    // Run the search without search text and convert the results to the vectorless result shape.
    return VectorStoreSearchResultMapping.ConvertToVectorlessSearchResults(
        await this.SearchAndMapToDataModelAsync(null, request, includeVectors, cancellationToken).ConfigureAwait(false),
        cancellationToken);
}

/// <inheritdoc />
public Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector vector, VectorData.VectorSearchOptions? options = null, CancellationToken cancellationToken = default)
{
Expand All @@ -335,7 +365,7 @@ public Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector
// Configure search settings.
var vectorQueries = new List<VectorQuery>();
vectorQueries.Add(new VectorizedQuery(floatVector) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorFieldName } });
var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter, this._propertyReader.JsonPropertyNamesMap);
var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter?.FilterClauses, this._propertyReader.JsonPropertyNamesMap);

// Build search options.
var searchOptions = new SearchOptions
Expand Down Expand Up @@ -375,7 +405,7 @@ public Task<VectorSearchResults<TRecord>> VectorizableTextSearchAsync(string sea
// Configure search settings.
var vectorQueries = new List<VectorQuery>();
vectorQueries.Add(new VectorizableTextQuery(searchText) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorFieldName } });
var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter, this._propertyReader.JsonPropertyNamesMap);
var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter?.FilterClauses, this._propertyReader.JsonPropertyNamesMap);

// Build search options.
var searchOptions = new SearchOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,18 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(
}
}

/// <inheritdoc />
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
{
    // TODO: Switch to non-vector search to improve performance.
    // Until then the search is simulated: a vector search is issued with an all-zero vector
    // sized to the collection's vector property.
    var declaredDimensions = this._propertyReader.VectorProperty?.Dimensions;
    if (declaredDimensions is null)
    {
        throw new InvalidOperationException("The collection does not have any vector properties, so simulated vectorless search is not possible.");
    }

    var zeroVector = new ReadOnlyMemory<float>(new float[declaredDimensions.Value]);
    var simulatedOptions = VectorSearchOptions.FromVectorlessSearchOptions(options);

    return VectorStoreSearchResultMapping.ConvertToVectorlessSearchResults(
        await this.VectorizedSearchAsync(zeroVector, simulatedOptions, cancellationToken).ConfigureAwait(false),
        cancellationToken);
}

/// <inheritdoc />
public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(
TVector vector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,18 @@ async IAsyncEnumerable<AzureCosmosDBNoSQLCompositeKey> IVectorStoreRecordCollect
}
}

/// <inheritdoc />
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
{
    // TODO: Switch to non-vector search to improve performance.
    // The search is currently emulated by running a vector search with a zero-filled vector.
    var vectorDimensions = this._propertyReader.VectorProperty?.Dimensions;
    if (vectorDimensions is null)
    {
        throw new InvalidOperationException("The collection does not have any vector properties, so simulated vectorless search is not possible.");
    }

    var placeholderVector = new ReadOnlyMemory<float>(new float[vectorDimensions.Value]);
    var mappedOptions = VectorSearchOptions.FromVectorlessSearchOptions(options);
    var vectorResults = await this.VectorizedSearchAsync(placeholderVector, mappedOptions, cancellationToken).ConfigureAwait(false);

    return VectorStoreSearchResultMapping.ConvertToVectorlessSearchResults(vectorResults, cancellationToken);
}

/// <inheritdoc />
public Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(
TVector vector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ public static float ConvertScore(float score, string? distanceFunction)
/// <summary>
/// Filter the provided records using the provided filter definition.
/// </summary>
/// <param name="filter">The filter definition to filter the <paramref name="records"/> with.</param>
/// <param name="filterClauses">The filter clauses to filter the <paramref name="records"/> with.</param>
/// <param name="records">The records to filter.</param>
/// <returns>The filtered records.</returns>
/// <exception cref="InvalidOperationException">Thrown when an unsupported filter clause is encountered.</exception>
public static IEnumerable<object> FilterRecords(VectorSearchFilter? filter, IEnumerable<object> records)
public static IEnumerable<object> FilterRecords(IEnumerable<FilterClause>? filterClauses, IEnumerable<object> records)
{
if (filter == null)
if (filterClauses == null)
{
return records;
}
Expand All @@ -109,7 +109,7 @@ public static IEnumerable<object> FilterRecords(VectorSearchFilter? filter, IEnu
// Run each filter clause against the record, and AND the results together.
// Break if any clause returns false, since we are doing an AND and no need
// to check any further clauses.
foreach (var clause in filter.FilterClauses)
foreach (var clause in filterClauses)
{
if (clause is EqualToFilterClause equalToFilter)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,26 @@ public async IAsyncEnumerable<TKey> UpsertBatchAsync(IEnumerable<TRecord> record
}
}

/// <inheritdoc />
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
#pragma warning restore CS1998
{
    // Fall back to default options when the caller passes none.
    var internalOptions = options ?? new VectorlessSearchOptions();

    // Apply the filter clauses (if any) to the records currently held by the collection.
    var filteredRecords = InMemoryVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter?.FilterClauses, this.GetCollectionDictionary().Values);

    long? count = null;
    if (internalOptions.IncludeTotalCount)
    {
        // Counting and paging both enumerate the filtered sequence. FilterRecords may return a
        // lazily evaluated sequence, so snapshot it once here: the filter then runs a single time
        // and the reported total always matches the page taken below. A result that is already a
        // materialized collection is reused as-is to avoid an extra copy.
        var snapshot = filteredRecords as IReadOnlyCollection<object> ?? filteredRecords.ToList();
        count = snapshot.Count;
        filteredRecords = snapshot;
    }

    // Page the (possibly snapshotted) records according to the requested skip/top.
    var resultsPage = filteredRecords.Skip(internalOptions.Skip).Take(internalOptions.Top);

    return new VectorlessSearchResults<TRecord>(resultsPage.Cast<TRecord>().ToAsyncEnumerable()) { TotalCount = count };
}

/// <inheritdoc />
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable
public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -235,7 +255,7 @@ public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(T
}

// Filter records using the provided filter before doing the vector comparison.
var filteredRecords = InMemoryVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter, this.GetCollectionDictionary().Values);
var filteredRecords = InMemoryVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter?.FilterClauses, this.GetCollectionDictionary().Values);

// Compare each vector in the filtered results with the provided vector.
var results = filteredRecords.Select<object, (object record, float score)?>((record) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ await this.RunOperationAsync(
}
}

/// <inheritdoc />
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
{
    // TODO: Switch to non-vector search to improve performance.
    // The search is simulated by issuing a vector search with an all-zero vector.
    var declaredDimensions = this._propertyReader.VectorProperty?.Dimensions;
    if (declaredDimensions is null)
    {
        throw new InvalidOperationException("The collection does not have any vector properties, so simulated vectorless search is not possible.");
    }

    var zeroVector = new ReadOnlyMemory<float>(new float[declaredDimensions.Value]);
    var simulatedOptions = VectorSearchOptions.FromVectorlessSearchOptions(options);
    var vectorResults = await this.VectorizedSearchAsync(zeroVector, simulatedOptions, cancellationToken).ConfigureAwait(false);

    return VectorStoreSearchResultMapping.ConvertToVectorlessSearchResults(vectorResults, cancellationToken);
}

/// <inheritdoc />
public Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,17 @@ private async IAsyncEnumerable<TRecord> GetBatchByPointIdAsync<TKey>(
}
}

/// <inheritdoc />
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
{
    // TODO: Switch to non-vector search to improve performance.
    // For now a vector search with a zero-filled vector stands in for the vectorless search.
    var vectorDimensions = this._propertyReader.VectorProperty?.Dimensions;
    if (vectorDimensions is null)
    {
        throw new InvalidOperationException("The collection does not have any vector properties, so simulated vectorless search is not possible.");
    }

    var placeholderVector = new ReadOnlyMemory<float>(new float[vectorDimensions.Value]);
    var mappedOptions = VectorSearchOptions.FromVectorlessSearchOptions(options);

    return VectorStoreSearchResultMapping.ConvertToVectorlessSearchResults(
        await this.VectorizedSearchAsync(placeholderVector, mappedOptions, cancellationToken).ConfigureAwait(false),
        cancellationToken);
}

/// <inheritdoc />
public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,18 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(IEnumerable<TRecord> reco
}
}

/// <inheritdoc />
public async Task<VectorlessSearchResults<TRecord>> VectorlessSearchAsync(VectorlessSearchOptions? options = null, CancellationToken cancellationToken = default)
{
    // TODO: Switch to non-vector search to improve performance.
    // Simulated implementation: run a vector search using a vector of zeros with the
    // dimension count declared on the collection's vector property.
    var declaredDimensions = this._propertyReader.VectorProperty?.Dimensions;
    if (declaredDimensions is null)
    {
        throw new InvalidOperationException("The collection does not have any vector properties, so simulated vectorless search is not possible.");
    }

    var zeroVector = new ReadOnlyMemory<float>(new float[declaredDimensions.Value]);
    var simulatedOptions = VectorSearchOptions.FromVectorlessSearchOptions(options);
    var vectorResults = await this.VectorizedSearchAsync(zeroVector, simulatedOptions, cancellationToken).ConfigureAwait(false);

    return VectorStoreSearchResultMapping.ConvertToVectorlessSearchResults(vectorResults, cancellationToken);
}

/// <inheritdoc />
public async Task<VectorSearchResults<TRecord>> VectorizedSearchAsync<TVector>(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default)
{
Expand Down
Loading

0 comments on commit ba92d42

Please sign in to comment.