Skip to content

Commit 4b68850

Browse files
authored
Merge pull request #374 from DataObjects-NET/6.0-typeas-transation-issue
Fixes certain translation issues connected to casts via "as"
2 parents 43e263d + 8e3b173 commit 4b68850

File tree

13 files changed

+1800
-259
lines changed

13 files changed

+1800
-259
lines changed

ChangeLog/6.0.13_dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[main] Fixed certain cases of bad translation of casts via 'as' operator in LINQ queries

Orm/Xtensive.Orm.Tests/Issues/IssueJira0720_IncorrectTypeAsChainTranslation.cs

Lines changed: 1132 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
// Copyright (C) 2019 Xtensive LLC.
2+
// This code is distributed under MIT license terms.
3+
// See the License.txt file in the project root for more information.
4+
// Created by: Denis Kudelin
5+
// Created: 2018.01.16
6+
7+
using System;
8+
using System.Collections;
9+
using System.Collections.Generic;
10+
using System.Linq;
11+
using System.Linq.Expressions;
12+
using NUnit.Framework;
13+
using Xtensive.Core;
14+
using Xtensive.Orm.Configuration;
15+
using Xtensive.Orm.Tests.Linq.TypeAsTestModels;
16+
17+
namespace Xtensive.Orm.Tests.Linq
18+
{
19+
public class TypeAsTest : AutoBuildTest
20+
{
21+
#region Nested types
22+
23+
public class CustomExpressionReplacer : Xtensive.Linq.ExpressionVisitor
24+
{
25+
private readonly Func<Expression, Func<Expression, Expression>, Expression> visit;
26+
27+
public static Expression Visit(Expression expression, Func<Expression, Func<Expression, Expression>, Expression> visit) =>
28+
new CustomExpressionReplacer(visit).Visit(expression);
29+
30+
protected override Expression Visit(Expression exp) => visit(exp, base.Visit);
31+
32+
private CustomExpressionReplacer(Func<Expression, Func<Expression, Expression>, Expression> visit)
33+
{
34+
this.visit = visit;
35+
}
36+
}
37+
38+
private sealed class ComparisonComparer<T> : Comparer<T>
39+
{
40+
private readonly Comparison<T> comparison;
41+
42+
public static new Comparer<T> Create(Comparison<T> comparison) =>
43+
comparison == null ? throw new ArgumentNullException("comparison") : new ComparisonComparer<T>(comparison);
44+
45+
public override int Compare(T x, T y) => comparison(x, y);
46+
47+
private ComparisonComparer(Comparison<T> comparison)
48+
{
49+
this.comparison = comparison;
50+
}
51+
}
52+
#endregion
53+
54+
[Test]
55+
public void Test1()
56+
{
57+
Require.AllFeaturesSupported(Providers.ProviderFeatures.Apply);
58+
59+
using(var session = Domain.OpenSession())
60+
using (var tx = session.OpenTransaction()) {
61+
QueryExpressionTest(
62+
() => session.Query.All<TestEntity1>().SelectMany(
63+
x => x.EntitySet.SelectMany(
64+
y => y.EntitySet.Select(
65+
z => (x.Value1 as TestEntity3).EntitySet.Any()))));
66+
}
67+
}
68+
69+
[Test]
70+
public void Test2()
71+
{
72+
Require.AllFeaturesSupported(Providers.ProviderFeatures.Apply);
73+
74+
using (var session = Domain.OpenSession())
75+
using (var tx = session.OpenTransaction()) {
76+
QueryExpressionTest(
77+
() => session.Query.All<TestEntity1>().SelectMany(
78+
x => x.EntitySet.SelectMany(
79+
y => y.EntitySet.Select(
80+
z => (y.Value1 as TestEntity3).EntitySet.Any()))));
81+
}
82+
}
83+
84+
[Test]
85+
public void Test3()
86+
{
87+
Require.AllFeaturesSupported(Providers.ProviderFeatures.Apply);
88+
89+
using (var session = Domain.OpenSession())
90+
using (var tx = session.OpenTransaction()) {
91+
QueryExpressionTest(
92+
() =>
93+
session.Query.All<TestEntity1>().SelectMany(
94+
x => x.EntitySet.SelectMany(y => y.EntitySet.Select(z => (z.Value1 as TestEntity3).EntitySet.Any()))));
95+
}
96+
}
97+
98+
[Test]
99+
public void Test4()
100+
{
101+
using (var session = Domain.OpenSession())
102+
using (var tx = session.OpenTransaction()) {
103+
QueryExpressionTest(
104+
() =>
105+
session.Query.All<TestEntity2>()
106+
.Where(x => ((x.Value1 as TestEntity3).Value1 as TestEntity3).EntitySet.Any()).Select(x => x.Id2));
107+
}
108+
}
109+
110+
[Test]
111+
public void Test5()
112+
{
113+
using (var session = Domain.OpenSession())
114+
using (var tx = session.OpenTransaction()) {
115+
QueryExpressionTest(
116+
() => session.Query.All<TestEntity2>().Where(
117+
x => x.EntitySet.Any(y => ((x.Value1 as TestEntity3).Value1 as TestEntity3).EntitySet.Any())));
118+
}
119+
}
120+
121+
private void QueryExpressionTest<TResult>(Expression<Func<TResult>> queryExpression)
122+
{
123+
var result1 = RewriteQueryExpressionAndInvoke(queryExpression, false);
124+
var result2 = RewriteQueryExpressionAndInvoke(queryExpression, true);
125+
126+
if (!(result1 is IEnumerable) || !(result2 is IEnumerable)) {
127+
Assert.That(result1, Is.EqualTo(result2));
128+
return;
129+
}
130+
131+
var result1Array = ((IEnumerable) result1).Cast<object>().ToArray();
132+
var result2Array = ((IEnumerable) result2).Cast<object>().ToArray();
133+
134+
Assert.That(result1Array.SequenceEqual(result2Array));
135+
}
136+
137+
private TResult RewriteQueryExpressionAndInvoke<TResult>(Expression<Func<TResult>> expression, bool asEnumerable)
138+
{
139+
var orderByMethod = (asEnumerable ? typeof(Enumerable) : typeof(Queryable)).GetMethods()
140+
.Single(x => x.Name == "OrderBy" && x.GetParameters().Length == (asEnumerable ? 3 : 2));
141+
var toArrayMethod = typeof(Enumerable).GetMethod("ToArray");
142+
var asQueryableMethod = typeof(Queryable).GetMethods().Single(x => x.Name == "AsQueryable" && x.IsGenericMethod);
143+
var keyPropertyInfo = typeof(IEntity).GetProperty("Key");
144+
145+
expression = ((Expression<Func<TResult>>) CustomExpressionReplacer.Visit(
146+
expression,
147+
(e, visit) => {
148+
var result = (e = visit(e));
149+
150+
if (result != null && typeof(IQueryable<IEntity>).IsAssignableFrom(result.Type)) {
151+
var isOrderedQueryable = typeof(IOrderedQueryable).IsAssignableFrom(result.Type);
152+
var entityType = result.Type.GetGenericArguments().Single();
153+
154+
if (asEnumerable)
155+
result = Expression.Call(toArrayMethod.MakeGenericMethod(entityType), result);
156+
157+
if (!isOrderedQueryable) {
158+
var keyParameter = Expression.Parameter(entityType);
159+
var keyProperty = Expression.Property(keyParameter, keyPropertyInfo);
160+
var orderByMethodGeneric = orderByMethod.MakeGenericMethod(entityType, keyPropertyInfo.PropertyType);
161+
var parameters = orderByMethodGeneric.GetParameters();
162+
var keySelectorType = asEnumerable
163+
? parameters[1].ParameterType
164+
: parameters[1].ParameterType.GetGenericArguments().Single();
165+
var keySelector = (Expression) Expression.Lambda(keySelectorType, keyProperty, keyParameter);
166+
167+
if (asEnumerable) {
168+
keySelector = Expression.Constant(((LambdaExpression) keySelector).Compile());
169+
var comparer = Expression.Constant(
170+
ComparisonComparer<Key>.Create(
171+
(k1, k2) => Comparer.Default.Compare(k1.Value.GetValue(0), k2.Value.GetValue(0))));
172+
result = Expression.Call(orderByMethodGeneric, result, keySelector, comparer);
173+
result = Expression.Call(toArrayMethod.MakeGenericMethod(entityType), result);
174+
}
175+
else {
176+
result = Expression.Call(orderByMethodGeneric, result, keySelector);
177+
}
178+
}
179+
180+
if (asEnumerable) {
181+
result = Expression.Call(asQueryableMethod.MakeGenericMethod(entityType), result);
182+
}
183+
}
184+
185+
if (asEnumerable && (e is MemberExpression || e is MethodCallExpression)) {
186+
Expression obj;
187+
188+
var methodCall = e as MethodCallExpression;
189+
if (methodCall != null) {
190+
obj = methodCall.Object ?? methodCall.Arguments.FirstOrDefault();
191+
}
192+
else {
193+
obj = ((MemberExpression) e).Expression;
194+
}
195+
196+
if (obj != null && (typeof(IQueryable<IEntity>).IsAssignableFrom(obj.Type)
197+
|| typeof(IEntity).IsAssignableFrom(obj.Type)
198+
|| typeof(Structure).IsAssignableFrom(obj.Type))) {
199+
result = Expression.Condition(
200+
Expression.Equal(obj, Expression.Constant(null, obj.Type)),
201+
Expression.Default(result.Type),
202+
result);
203+
return result;
204+
}
205+
}
206+
207+
return result;
208+
}));
209+
210+
return expression.Compile()();
211+
}
212+
213+
protected override void PopulateData()
214+
{
215+
using (var session = Domain.OpenSession())
216+
using (var tx = session.OpenTransaction()) {
217+
var entity1a = new TestEntity1(1);
218+
219+
var entity2a = new TestEntity2(2);
220+
var entity2b = new TestEntity2(3);
221+
var entity2c = new TestEntity2(4);
222+
var entity2d = new TestEntity2(5);
223+
224+
var entity3a = new TestEntity3(6);
225+
var entity3b = new TestEntity3(7);
226+
var entity3c = new TestEntity3(8);
227+
228+
_ = entity1a.EntitySet.Add(entity2a);
229+
_ = entity1a.EntitySet.Add(entity2b);
230+
_ = entity1a.EntitySet.Add(entity2c);
231+
_ = entity1a.EntitySet.Add(entity2d);
232+
entity1a.Value1 = entity3a;
233+
234+
_ = entity2a.EntitySet.Add(entity3a);
235+
entity2a.Value1 = entity3a;
236+
_ = entity2b.EntitySet.Add(entity3b);
237+
entity2b.Value1 = entity3b;
238+
_ = entity2c.EntitySet.Add(entity3c);
239+
entity2d.Value1 = entity3c;
240+
241+
_ = entity3a.EntitySet.Add(entity1a);
242+
entity3a.Value1 = entity3c;
243+
_ = entity3b.EntitySet.Add(entity2a);
244+
entity3b.Value1 = entity3a;
245+
entity3c.Value1 = entity3b;
246+
247+
tx.Complete();
248+
}
249+
}
250+
251+
protected override DomainConfiguration BuildConfiguration()
252+
{
253+
var config = base.BuildConfiguration();
254+
config.Types.Register(typeof(ITestEntity).Assembly, typeof(ITestEntity).Namespace);
255+
return config;
256+
}
257+
}
258+
}
259+
260+
namespace Xtensive.Orm.Tests.Linq.TypeAsTestModels
261+
{
262+
public interface ITestEntity : IEntity
263+
{
264+
265+
}
266+
267+
[HierarchyRoot]
268+
public class TestEntity1 : Entity, ITestEntity
269+
{
270+
[Field]
271+
public EntitySet<TestEntity2> EntitySet { get; set; }
272+
273+
[Field]
274+
public ITestEntity Value1 { get; set; }
275+
276+
[Field, Key]
277+
public int Id { get; set; }
278+
279+
[Field]
280+
public int Id2 { get; set; }
281+
282+
public TestEntity1(int id2)
283+
{
284+
Id2 = id2;
285+
}
286+
}
287+
288+
[HierarchyRoot]
289+
public class TestEntity2 : Entity, ITestEntity
290+
{
291+
[Field]
292+
public EntitySet<TestEntity3> EntitySet { get; set; }
293+
294+
[Field, Key]
295+
public int Id { get; set; }
296+
297+
[Field]
298+
public ITestEntity Value1 { get; set; }
299+
300+
[Field]
301+
public int Id2 { get; set; }
302+
303+
public TestEntity2(int id2)
304+
{
305+
Id2 = id2;
306+
}
307+
}
308+
309+
[HierarchyRoot]
310+
public class TestEntity3 : Entity, ITestEntity
311+
{
312+
[Field, Key]
313+
public int Id { get; set; }
314+
315+
[Field]
316+
public EntitySet<ITestEntity> EntitySet { get; set; }
317+
318+
[Field]
319+
public ITestEntity Value1 { get; set; }
320+
321+
[Field]
322+
public int Id2 { get; set; }
323+
324+
public TestEntity3(int id2)
325+
{
326+
Id2 = id2;
327+
}
328+
}
329+
}

0 commit comments

Comments
 (0)