Skip to content

Commit

Permalink
Added ArrayContains to CosmosLinqExtensions to allow partial matching
Browse files Browse the repository at this point in the history
The `Array_Contains` funtion CosmosDB Sql has a 3rd parameter which allows it to do a partial match on the given item. This is unable to be called with the built in Linq `array.Contains(item)` extension methods.

This adds this adds an explicit mapping to this function to allow it to be called in Linq like this:
`documents.Where(document => document.ObjectArray.ArrayContains(new { Name = "abc" }, true))`
  • Loading branch information
Ben Robinson committed Feb 3, 2025
1 parent 0958198 commit 8ec125f
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ public ArrayContainsVisitor()
{
}

public bool UsePartialMatchParameter { get; set; }

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
Expression searchList = null;
Expression searchExpression = null;
Expression partialMatchExpression = null;

// If non static Contains
if (methodCallExpression.Arguments.Count == 1)
Expand All @@ -59,6 +62,13 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
searchList = methodCallExpression.Arguments[0];
searchExpression = methodCallExpression.Arguments[1];
}
// if CosmosLinqExtensions.ArrayContains extension method which includes partial match parameter
else if (this.UsePartialMatchParameter && methodCallExpression.Arguments.Count == 3)
{
searchList = methodCallExpression.Arguments[0];
searchExpression = methodCallExpression.Arguments[1];
partialMatchExpression = methodCallExpression.Arguments[2];
}

if (searchList == null || searchExpression == null)
{
Expand All @@ -72,7 +82,20 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method

SqlScalarExpression array = ExpressionToSql.VisitScalarExpression(searchList, context);
SqlScalarExpression expression = ExpressionToSql.VisitScalarExpression(searchExpression, context);
return SqlFunctionCallScalarExpression.CreateBuiltin("ARRAY_CONTAINS", array, expression);

SqlScalarExpression[] arrayContainsArgs;

if (partialMatchExpression is null)
{
arrayContainsArgs = new[] { array, expression };
}
else
{
SqlScalarExpression partialMatch = ExpressionToSql.VisitScalarExpression(partialMatchExpression, context);
arrayContainsArgs = new[] { array, expression, partialMatch };
}

return SqlFunctionCallScalarExpression.CreateBuiltin("ARRAY_CONTAINS", arrayContainsArgs);
}

private SqlScalarExpression VisitIN(Expression expression, ConstantExpression constantExpressionList, TranslationContext context)
Expand Down Expand Up @@ -177,6 +200,10 @@ static ArrayBuiltinFunctions()
{
"ToList",
new ArrayToArrayVisitor()
},
{
nameof(CosmosLinqExtensions.ArrayContains),
new ArrayContainsVisitor() { UsePartialMatchParameter = true }
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ public static SqlScalarExpression VisitBuiltinFunctionCall(MethodCallExpression
{
return OtherBuiltinSystemFunctions.Visit(methodCallExpression, context);
}

if (methodCallExpression.Method.Name == nameof(CosmosLinqExtensions.ArrayContains))
{
return ArrayBuiltinFunctions.Visit(methodCallExpression, context);
}

return TypeCheckFunctions.Visit(methodCallExpression, context);
}
Expand Down
36 changes: 33 additions & 3 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

namespace Microsoft.Azure.Cosmos.Linq
{
using System;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
Expand Down Expand Up @@ -235,8 +236,37 @@ public static bool RegexMatch(this object obj, string regularExpression)
public static bool RegexMatch(this object obj, string regularExpression, string searchModifier)
{
throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented);
}

}

/// <summary>
/// Returns a boolean indicating whether the array contains the specified value.
/// You can check for a partial or full match of an object by using a boolean expression within the function.
/// For more information, see https://learn.microsoft.com/en-gb/azure/cosmos-db/nosql/query/array-contains.
/// This method is to be used in LINQ expressions only and will be evaluated on server.
/// There's no implementation provided in the client library.
/// </summary>
/// <param name="obj"></param>
/// <param name="itemToMatch">The value to search within the array.</param>
/// <param name="partialMatch">Indicating whether the search should check for a partial match (true) or a full match (false).</param>
/// <returns>Returns true if the array array contains the specified value; otherwise, false.</returns>
/// <example>
/// <code>
/// <![CDATA[
/// var matched = documents.Where(document => document.Namess.ArrayContains(<itemToMatch>, <partialMatch>));
/// // To do a partial match on an array of objects, pass in an anonymous object set partialMatch to true
/// var matched = documents.Where(document => document.ObjectArray.ArrayContains(new { Name = <name> }, true));
/// ]]>
/// </code>
/// </example>
public static bool ArrayContains(this IEnumerable obj, object itemToMatch, bool partialMatch)
{
// The signature for this is not generic so the user can pass in anonymous type for the item to match
// e.g documents.Where(document => document.FooItems.ArrayContains(new { Name = "Bar" }, true)
// partialMatch could have a default values (bool partialMatch = false) but those are not valid in expressions
// (see error CS0854) and this method will only be used in expressions, so not point adding it
throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented);
}

/// <summary>
/// This method generate query definition from LINQ query.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
<Results>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Select clause with int value and match partial true]]></Description>
<Expression><![CDATA[query.Select(doc => doc.ArrayField.ArrayContains(Convert(1, Object), True))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE ARRAY_CONTAINS(root["ArrayField"], 1, true)
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Filter clause with int value and match partial true]]></Description>
<Expression><![CDATA[query.Where(doc => doc.ArrayField.ArrayContains(Convert(1, Object), True))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root
FROM root
WHERE ARRAY_CONTAINS(root["ArrayField"], 1, true)]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Select clause with object value and match partial true]]></Description>
<Expression><![CDATA[query.Select(doc => doc.ObjectArrayField.ArrayContains(new AnonymousType(Field = "abc"), True))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE ARRAY_CONTAINS(root["ObjectArrayField"], {"Field": "abc"}, true)
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Filter clause with object value and match partial true]]></Description>
<Expression><![CDATA[query.Where(doc => doc.ObjectArrayField.ArrayContains(new AnonymousType(Field = "abc"), True))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root
FROM root
WHERE ARRAY_CONTAINS(root["ObjectArrayField"], {"Field": "abc"}, true)]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Select clause with int value and match partial false]]></Description>
<Expression><![CDATA[query.Select(doc => doc.ArrayField.ArrayContains(Convert(1, Object), False))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE ARRAY_CONTAINS(root["ArrayField"], 1, false)
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Filter clause with int value and match partial false]]></Description>
<Expression><![CDATA[query.Where(doc => doc.ArrayField.ArrayContains(Convert(1, Object), False))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root
FROM root
WHERE ARRAY_CONTAINS(root["ArrayField"], 1, false)]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Select clause with object value and match partial false]]></Description>
<Expression><![CDATA[query.Select(doc => doc.ObjectArrayField.ArrayContains(new AnonymousType(Field = "abc"), False))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE ARRAY_CONTAINS(root["ObjectArrayField"], {"Field": "abc"}, false)
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[ArrayContains in Filter clause with object value and match partial false]]></Description>
<Expression><![CDATA[query.Where(doc => doc.ObjectArrayField.ArrayContains(new AnonymousType(Field = "abc"), False))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root
FROM root
WHERE ARRAY_CONTAINS(root["ObjectArrayField"], {"Field": "abc"}, false)]]></SqlQuery>
</Output>
</Result>
</Results>
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ internal class DataObject : LinqTestObject
#pragma warning disable CS0649 // Field is never assigned to, and will always have its default value false
public bool BooleanField;
public SimpleObject ObjectField = new SimpleObject();
public SimpleObject[] ObjectArrayField = new SimpleObject[0];
public Guid GuidField;
#pragma warning restore // Field is never assigned to, and will always have its default value false

Expand Down Expand Up @@ -343,6 +344,31 @@ public void TestRegexMatchFunction()

new LinqTestInput("RegexMatch with 2nd argument invalid string options", b => getQuery(b).Where(doc => doc.StringField.RegexMatch("abcd", "this should error out on the back end"))),
};
this.ExecuteTestSuite(inputs);
}

[TestMethod]
public void TestArrayContainsBuiltinFunction()
{
// Similar to the type checking function, Array_Contains are not supported client side.
// Therefore these methods are verified with baseline only.
List<DataObject> data = new List<DataObject>();
IOrderedQueryable<DataObject> query = testContainer.GetItemLinqQueryable<DataObject>(allowSynchronousQueryExecution: true);
Func<bool, IQueryable<DataObject>> getQuery = useQuery => useQuery ? query : data.AsQueryable();

List<LinqTestInput> inputs = new List<LinqTestInput>
{
new LinqTestInput("ArrayContains in Select clause with int value and match partial true", b => getQuery(b).Select(doc => doc.ArrayField.ArrayContains(1, true))),
new LinqTestInput("ArrayContains in Filter clause with int value and match partial true", b => getQuery(b).Where(doc => doc.ArrayField.ArrayContains(1, true))),
new LinqTestInput("ArrayContains in Select clause with object value and match partial true", b => getQuery(b).Select(doc => doc.ObjectArrayField.ArrayContains(new { Field = "abc" }, true))),
new LinqTestInput("ArrayContains in Filter clause with object value and match partial true", b => getQuery(b).Where(doc => doc.ObjectArrayField.ArrayContains(new { Field = "abc" }, true))),

new LinqTestInput("ArrayContains in Select clause with int value and match partial false", b => getQuery(b).Select(doc => doc.ArrayField.ArrayContains(1, false))),
new LinqTestInput("ArrayContains in Filter clause with int value and match partial false", b => getQuery(b).Where(doc => doc.ArrayField.ArrayContains(1, false))),
new LinqTestInput("ArrayContains in Select clause with object value and match partial false", b => getQuery(b).Select(doc => doc.ObjectArrayField.ArrayContains(new { Field = "abc" }, false))),
new LinqTestInput("ArrayContains in Filter clause with object value and match partial false", b => getQuery(b).Where(doc => doc.ObjectArrayField.ArrayContains(new { Field = "abc" }, false))),
};

this.ExecuteTestSuite(inputs);
}

Expand Down Expand Up @@ -449,10 +475,11 @@ public void TestDateTimeJsonConverter()
public void TestDateTimeJsonConverterTimezones()
{
const int Records = 10;
DateTime midDateTime = new (2016, 9, 13, 0, 0, 0);
DateTime midDateTime = new(2016, 9, 13, 0, 0, 0);
Func<Random, DataObject> createDataObj = (random) =>
{
DataObject obj = new() {
DataObject obj = new()
{
IsoDateOnly = LinqTestsCommon.RandomDateTime(random, midDateTime),
Id = Guid.NewGuid().ToString(),
Pk = "Test"
Expand Down Expand Up @@ -739,7 +766,7 @@ private Func<bool, IQueryable<DataObject>> CreateDataTestStringFunctions()
Pk = "Test",

// For ToString tests
ArrayField = new int[] {},
ArrayField = new int[] { },
Point = new Point(0, 0)
};
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@
<Content Include="BaselineTest\TestBaseline\IndexMetricsParserBaselineTest.IndexUtilizationHeaderLengthTest.xml">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="BaselineTest\TestBaseline\LinqTranslationBaselineTests.TestArrayContainsBuiltinFunction.xml">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Include="BaselineTest\TestBaseline\QueryAdvisorBaselineTest.QueryAdviceParse.xml">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6680,6 +6680,13 @@
"Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions;System.Object;IsAbstract:True;IsSealed:True;IsInterface:False;IsEnum:False;IsClass:True;IsValueType:False;IsNested:False;IsGenericType:False;IsSerializable:False": {
"Subclasses": {},
"Members": {
"Boolean ArrayContains(System.Collections.IEnumerable, System.Object, Boolean)[System.Runtime.CompilerServices.ExtensionAttribute()]": {
"Type": "Method",
"Attributes": [
"ExtensionAttribute"
],
"MethodInfo": "Boolean ArrayContains(System.Collections.IEnumerable, System.Object, Boolean);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
},
"Boolean IsArray(System.Object)[System.Runtime.CompilerServices.ExtensionAttribute()]": {
"Type": "Method",
"Attributes": [
Expand Down

0 comments on commit 8ec125f

Please sign in to comment.