Skip to content

Commit d2703e1

Browse files
authored
Merge pull request #33 from DataObjects-NET/master-bulk-with-in
Makes IEnumberable<T>.Contains work with Bulk operations
2 parents eb5e356 + 95eaa78 commit d2703e1

File tree

2 files changed

+139
-21
lines changed

2 files changed

+139
-21
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Copyright (C) 2020 Xtensive LLC.
2+
// This code is distributed under MIT license terms.
3+
// See the License.txt file in the project root for more information.
4+
5+
using System;
6+
using System.Linq;
7+
using NUnit.Framework;
8+
using Xtensive.Orm.BulkOperations.ContainsTestModel;
9+
using Xtensive.Orm.Configuration;
10+
11+
namespace Xtensive.Orm.BulkOperations.ContainsTestModel
12+
{
13+
[HierarchyRoot]
14+
[KeyGenerator(KeyGeneratorKind.None)]
15+
public class TagType : Entity
16+
{
17+
[Field, Key]
18+
public long Id { get; private set; }
19+
20+
[Field]
21+
public int ProjectedValueAdjustment { get; set; }
22+
23+
public TagType(Session session, long id)
24+
:base(session, id)
25+
{
26+
}
27+
}
28+
}
29+
30+
namespace Xtensive.Orm.BulkOperations.Tests
31+
{
32+
public class ContainsTest : AutoBuildTest
33+
{
34+
private long[] tagIds;
35+
36+
protected override DomainConfiguration BuildConfiguration()
37+
{
38+
var configuration = base.BuildConfiguration();
39+
configuration.Types.Register(typeof(TagType).Assembly, typeof(TagType).Namespace);
40+
return configuration;
41+
}
42+
43+
protected override void PopulateData()
44+
{
45+
tagIds = Enumerable.Range(0, 100).Select(i => (long) i).ToArray();
46+
using (var session = Domain.OpenSession())
47+
using (var transaction = session.OpenTransaction()) {
48+
foreach (var id in tagIds.Concat(Enumerable.Repeat(1000, 1).Select(i => (long) i)))
49+
new TagType(session, id) { ProjectedValueAdjustment = -1 };
50+
transaction.Complete();
51+
}
52+
}
53+
54+
[Test]
55+
public void Test1()
56+
{
57+
using (var session = Domain.OpenSession())
58+
using (var tx = session.OpenTransaction()) {
59+
var updatedRows = session.Query.All<TagType>()
60+
.Where(t => t.Id.In(tagIds))
61+
.Set(t => t.ProjectedValueAdjustment, 2)
62+
.Update();
63+
Assert.That(updatedRows, Is.EqualTo(100));
64+
Assert.That(session.Query.All<TagType>().Count(t => t.ProjectedValueAdjustment == 2 && t.Id <= 200), Is.EqualTo(100));
65+
Assert.That(session.Query.All<TagType>().Count(t => t.ProjectedValueAdjustment == -1 && t.Id > 700), Is.EqualTo(1));
66+
}
67+
}
68+
69+
[Test]
70+
public void Test2()
71+
{
72+
using (var session = Domain.OpenSession())
73+
using (var tx = session.OpenTransaction()) {
74+
var updatedRows = session.Query.All<TagType>()
75+
.Where(t => t.Id.In(IncludeAlgorithm.ComplexCondition, tagIds))
76+
.Set(t => t.ProjectedValueAdjustment, 2)
77+
.Update();
78+
Assert.That(updatedRows, Is.EqualTo(100));
79+
Assert.That(session.Query.All<TagType>().Count(t => t.ProjectedValueAdjustment == 2 && t.Id <= 200), Is.EqualTo(100));
80+
Assert.That(session.Query.All<TagType>().Count(t => t.ProjectedValueAdjustment == -1 && t.Id > 700), Is.EqualTo(1));
81+
}
82+
}
83+
84+
[Test]
85+
public void Test3()
86+
{
87+
using (var session = Domain.OpenSession())
88+
using (var tx = session.OpenTransaction()) {
89+
Assert.Throws<NotSupportedException>(() => session.Query.All<TagType>()
90+
.Where(t => t.Id.In(IncludeAlgorithm.TemporaryTable, tagIds))
91+
.Set(t => t.ProjectedValueAdjustment, 2)
92+
.Update());
93+
}
94+
}
95+
96+
[Test]
97+
public void Test4()
98+
{
99+
using (var session = Domain.OpenSession())
100+
using (var tx = session.OpenTransaction()) {
101+
var updatedRows = session.Query.All<TagType>()
102+
.Where(t => tagIds.Contains(t.Id))
103+
.Set(t => t.ProjectedValueAdjustment, 2)
104+
.Update();
105+
Assert.That(updatedRows, Is.EqualTo(100));
106+
Assert.That(session.Query.All<TagType>().Count(t => t.ProjectedValueAdjustment == 2 && t.Id <= 200), Is.EqualTo(100));
107+
Assert.That(session.Query.All<TagType>().Count(t => t.ProjectedValueAdjustment == -1 && t.Id > 700), Is.EqualTo(1));
108+
}
109+
}
110+
}
111+
}

Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
using System;
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Linq.Expressions;
@@ -14,7 +14,7 @@ namespace Xtensive.Orm.BulkOperations
1414
internal abstract class QueryOperation<T> : Operation<T>
1515
where T : class, IEntity
1616
{
17-
private static MethodInfo inMethod = GetInMethod();
17+
private readonly static MethodInfo inMethod = GetInMethod();
1818
protected IQueryable<T> query;
1919

2020
protected QueryOperation(QueryProvider queryProvider)
@@ -24,10 +24,9 @@ protected QueryOperation(QueryProvider queryProvider)
2424

2525
private static MethodInfo GetInMethod()
2626
{
27-
foreach (var method in typeof (QueryableExtensions).GetMethods().Where(a=>a.Name=="In"))
28-
{
27+
foreach (var method in typeof (QueryableExtensions).GetMethods().Where(a => string.Equals(a.Name, "In", StringComparison.Ordinal))) {
2928
var parameters = method.GetParameters();
30-
if (parameters.Length == 3 && parameters[2].ParameterType.Name == "IEnumerable`1")
29+
if (parameters.Length == 3 && string.Equals(parameters[2].ParameterType.Name, "IEnumerable`1", StringComparison.Ordinal))
3130
return method;
3231
}
3332
return null;
@@ -37,30 +36,38 @@ protected override int ExecuteInternal()
3736
{
3837
Expression e = query.Expression.Visit((MethodCallExpression ex) =>
3938
{
40-
if (ex.Method.DeclaringType == typeof (QueryableExtensions) && ex.Method.Name == "In" &&
41-
ex.Arguments.Count > 1)
42-
{
43-
if (ex.Arguments[1].Type == typeof (IncludeAlgorithm))
44-
{
45-
var v = (IncludeAlgorithm) ex.Arguments[1].Invoke();
46-
if (v == IncludeAlgorithm.TemporaryTable)
47-
{
39+
var methodInfo = ex.Method;
40+
//rewrite localCollection.Contains(entity.SomeField) -> entity.SomeField.In(localCollection)
41+
if (methodInfo.DeclaringType == typeof(Enumerable) &&
42+
string.Equals(methodInfo.Name, "Contains", StringComparison.Ordinal) &&
43+
ex.Arguments.Count == 2) {
44+
var localCollection = ex.Arguments[0];//IEnumerable<T>
45+
var valueToCheck = ex.Arguments[1];
46+
var genericInMethod = inMethod.MakeGenericMethod(new[] { valueToCheck.Type });
47+
ex = Expression.Call(genericInMethod, valueToCheck, Expression.Constant(IncludeAlgorithm.ComplexCondition), localCollection);
48+
methodInfo = ex.Method;
49+
}
50+
51+
if (methodInfo.DeclaringType == typeof(QueryableExtensions) &&
52+
string.Equals(methodInfo.Name, "In", StringComparison.Ordinal) &&
53+
ex.Arguments.Count > 1) {
54+
if (ex.Arguments[1].Type == typeof(IncludeAlgorithm)) {
55+
var algorithm = (IncludeAlgorithm) ex.Arguments[1].Invoke();
56+
if (algorithm == IncludeAlgorithm.TemporaryTable) {
4857
throw new NotSupportedException("IncludeAlgorithm.TemporaryTable is not supported");
4958
}
50-
if (v == IncludeAlgorithm.Auto)
51-
{
59+
if (algorithm == IncludeAlgorithm.Auto) {
5260
List<Expression> arguments = ex.Arguments.ToList();
5361
arguments[1] = Expression.Constant(IncludeAlgorithm.ComplexCondition);
54-
ex = Expression.Call(ex.Method, arguments);
62+
ex = Expression.Call(methodInfo, arguments);
5563
}
5664
}
57-
else
58-
{
65+
else {
5966
List<Expression> arguments = ex.Arguments.ToList();
6067
arguments.Insert(1, Expression.Constant(IncludeAlgorithm.ComplexCondition));
61-
List<Type> types = ex.Method.GetParameters().Select(a => a.ParameterType).ToList();
62-
types.Insert(1, typeof (IncludeAlgorithm));
63-
ex = Expression.Call(inMethod.MakeGenericMethod(ex.Method.GetGenericArguments()),
68+
List<Type> types = methodInfo.GetParameters().Select(a => a.ParameterType).ToList();
69+
types.Insert(1, typeof(IncludeAlgorithm));
70+
ex = Expression.Call(inMethod.MakeGenericMethod(methodInfo.GetGenericArguments()),
6471
arguments.ToArray());
6572
}
6673
}

0 commit comments

Comments
 (0)