Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JS translation: Dictionary.GetValueOrDefault + fixed default parameter value resolution in generic methods #1761

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,28 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
public bool HasParamsAttribute { get; set; }
}

private Expression GetDefaultValue(ParameterInfo parameter)
{
if (parameter.HasDefaultValue)
{
var value = parameter.DefaultValue;
if (value is null && parameter.ParameterType.IsValueType)
{
// null with struct type means `default(T)`
value = ReflectionUtils.GetDefaultValue(parameter.ParameterType);
}
return Expression.Constant(value, parameter.ParameterType);
}
else if (parameter.IsDefined(ParamArrayAttributeType))
{
return Expression.NewArrayInit(parameter.ParameterType.GetElementType().NotNull());
}
else
{
throw new Exception($"Internal error: parameter {parameter.Name} of method {parameter.Member.Name} does not have a default value.");
}
}

private MethodRecognitionResult? TryCallMethod(MethodInfo method, Type[]? typeArguments, Expression[] positionalArguments, IDictionary<string, Expression>? namedArguments)
{
if (positionalArguments.Contains(null)) throw new ArgumentNullException("positionalArguments[]");
Expand Down Expand Up @@ -445,7 +467,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
if (typeArgs[genericArgumentPosition] == null)
{
// try to resolve from arguments
var argType = GetGenericParameterType(genericArguments[genericArgumentPosition], parameterTypes, args.Select(s => s.Type).ToArray());
var argType = GetGenericParameterType(genericArguments[genericArgumentPosition], parameterTypes, args.Select(s => s?.Type).ToArray());
automaticTypeArgs++;
if (argType != null) typeArgs[genericArgumentPosition] = argType;
else return null;
Expand All @@ -466,11 +488,15 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
}
else if (typeArguments != null) return null;

// cast arguments
// cast arguments and fill defaults
for (int i = 0; i < args.Length; i++)
{
if (args[i] == null)
{
args[i] = GetDefaultValue(parameters[i]);
}
Type elm;
if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i].Type.IsArray)
if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i]!.Type.IsArray)
{
elm = parameters[i].ParameterType.GetElementType().NotNull();
if (positionalArguments.Skip(i).Any(s => TypeConversion.ImplicitConversion(s, elm) is null))
Expand All @@ -482,7 +508,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
{
elm = parameters[i].ParameterType;
}
var casted = TypeConversion.ImplicitConversion(args[i], elm);
var casted = TypeConversion.ImplicitConversion(args[i]!, elm);
if (casted == null)
{
return null;
Expand All @@ -492,7 +518,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
castCount++;
args[i] = casted;
}
if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i].Type.IsArray)
if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i]!.Type.IsArray)
{
var converted = positionalArguments.Skip(i)
.Select(a => TypeConversion.EnsureImplicitConversion(a, elm))
Expand All @@ -505,13 +531,13 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express
automaticTypeArgCount: automaticTypeArgs,
castCount: castCount,
method: method,
arguments: args,
arguments: args!,
paramsArrayCount: positionalArguments.Length - args.Length,
hasParamsAttribute: hasParamsArrayAttributes,
isExtension: false
);
}
private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] positionalArguments, IDictionary<string, Expression>? namedArguments, [MaybeNullWhen(false)] out Expression[] arguments, out int castCount)
private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] positionalArguments, IDictionary<string, Expression>? namedArguments, [MaybeNullWhen(false)] out Expression?[] arguments, out int castCount)
{
castCount = 0;
arguments = null;
Expand All @@ -522,16 +548,17 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[]
if (!hasParamsArrayAttribute && parameters.Length < positionalArguments.Length)
return false;

arguments = new Expression[parameters.Length];
arguments = new Expression?[parameters.Length];
var copyItemsCount = !hasParamsArrayAttribute ? positionalArguments.Length : parameters.Length;

if (hasParamsArrayAttribute && parameters.Length > positionalArguments.Length)
{
var parameter = parameters.Last();
var elementType = parameter.ParameterType.GetElementType().NotNull();

// User specified no arguments for the `params` array, we need to create an empty array
arguments[arguments.Length - 1] = Expression.NewArrayInit(elementType);
// User specified no arguments for the `params` array => use default value
// created later by the GetDefaultValue, after we know the generic arguments
arguments[arguments.Length - 1] = null;

// Last argument was just generated => do not copy
addedArguments++;
Expand Down Expand Up @@ -561,7 +588,7 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[]
else if (parameters[i].HasDefaultValue)
{
castCount++;
arguments[i] = Expression.Constant(parameters[i].DefaultValue, parameters[i].ParameterType);
arguments[i] = null;
}
else if (parameters[i].IsDefined(ParamArrayAttributeType))
{
Expand All @@ -577,29 +604,30 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[]
return true;
}

private Type? GetGenericParameterType(Type genericArg, Type[] searchedGenericTypes, Type[] expressionTypes)
private Type? GetGenericParameterType(Type genericArg, Type[] searchedGenericTypes, Type?[] expressionTypes)
{
for (var i = 0; i < searchedGenericTypes.Length; i++)
{
if (expressionTypes.Length <= i) return null;
var expression = expressionTypes[i];
if (expression == null) continue;
var sgt = searchedGenericTypes[i];
if (sgt == genericArg)
{
return expressionTypes[i];
return expression;
}
if (sgt.IsArray)
{
var elementType = sgt.GetElementType();
var expressionElementType = expressionTypes[i].GetElementType();
var expressionElementType = expression.GetElementType();
if (elementType == genericArg)
return expressionElementType;
else
return GetGenericParameterType(genericArg, searchedGenericTypes[i].GetGenericArguments(), expressionTypes[i].GetGenericArguments());
return GetGenericParameterType(genericArg, searchedGenericTypes[i].GetGenericArguments(), expression.GetGenericArguments());
}
else if (sgt.IsGenericType)
{
Type[]? genericArguments = null;
var expression = expressionTypes[i];

if (expression.IsArray)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using DotVVM.Framework.Compilation.Binding;
using System.Collections.Immutable;
using DotVVM.Framework.Configuration;
using DotVVM.Framework.Utils;

namespace DotVVM.Framework.Compilation.ControlTree.Resolved
{
Expand Down Expand Up @@ -59,7 +60,7 @@ public DirectiveCompilationService(CompiledAssemblyCache compiledAssemblyCache,

public object? ResolvePropertyInitializer(DothtmlDirectiveNode directive, Type propertyType, BindingParserNode? initializer, ImmutableList<NamespaceImport> imports)
{
if (initializer == null) { return CreateDefaultValue(propertyType); }
if (initializer == null) { return ReflectionUtils.GetDefaultValue(propertyType); }

var registry = RegisterImports(TypeRegistry.DirectivesDefault(compiledAssemblyCache), imports);

Expand All @@ -75,25 +76,16 @@ public DirectiveCompilationService(CompiledAssemblyCache compiledAssemblyCache,
var lambda = Expression.Lambda<Func<object?>>(Expression.Block(Expression.Convert(TypeConversion.EnsureImplicitConversion(initializerExpression, propertyType), typeof(object))));
var lambdaDelegate = lambda.Compile(true);

return lambdaDelegate.Invoke() ?? CreateDefaultValue(propertyType);
return lambdaDelegate.Invoke() ?? ReflectionUtils.GetDefaultValue(propertyType);
}
catch (Exception ex)
{
directive.AddError("Could not initialize property value.");
directive.AddError(ex.Message);
return CreateDefaultValue(propertyType);
return ReflectionUtils.GetDefaultValue(propertyType);
}
}

private object? CreateDefaultValue(Type? type)
{
if (type != null && type.IsValueType)
{
return Activator.CreateInstance(type);
}
return null;
}

private Expression? CompileDirectiveExpression(DothtmlDirectiveNode directive, BindingParserNode expressionSyntax, ImmutableList<NamespaceImport> imports)
{
TypeRegistry registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,20 @@ private void AddDefaultDictionaryTranslations()
AddMethodTranslator(() => default(IReadOnlyDictionary<Generic.T, Generic.T>)!.ContainsKey(null!), containsKey);
AddMethodTranslator(() => default(IDictionary<Generic.T, Generic.T>)!.Remove(null!), new GenericMethodCompiler(args =>
new JsIdentifierExpression("dotvvm").Member("translations").Member("dictionary").Member("remove").Invoke(args[0].WithAnnotation(ShouldBeObservableAnnotation.Instance), args[1])));

var getValueOrDefault = new GenericMethodCompiler((JsExpression[] args, MethodInfo method) => {
var defaultValue =
args.Length > 3 ? args[3] :
new JsLiteral(ReflectionUtils.GetDefaultValue(method.GetGenericArguments().Last()));
return new JsIdentifierExpression("dotvvm").Member("translations").Member("dictionary").Member("getItem").Invoke(args[1], args[2], defaultValue);
});
#if DotNetCore
AddMethodTranslator(() => default(IReadOnlyDictionary<Generic.T, Generic.T>)!.GetValueOrDefault(null!), getValueOrDefault);
AddMethodTranslator(() => default(IReadOnlyDictionary<Generic.T, Generic.T>)!.GetValueOrDefault(null!, null), getValueOrDefault);
#endif
AddMethodTranslator(() => default(IImmutableDictionary<Generic.T, Generic.T>)!.GetValueOrDefault(null!), getValueOrDefault);
AddMethodTranslator(() => default(IImmutableDictionary<Generic.T, Generic.T>)!.GetValueOrDefault(null!, null), getValueOrDefault);
AddMethodTranslator(() => FunctionalExtensions.GetValueOrDefault(default(IReadOnlyDictionary<Generic.T, Generic.T>)!, null!, null!, false), getValueOrDefault);
}

private bool IsDictionary(Type type) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ export function containsKey<Key, Value>(dictionary: Dictionary<Key, Value>, iden
return getKeyValueIndex(dictionary, identifier) !== null;
}

export function getItem<Key, Value>(dictionary: Dictionary<Key, Value>, identifier: Key): Value {
export function getItem<Key, Value>(dictionary: Dictionary<Key, Value>, identifier: Key, defaultValue?: Value): Value {
const index = getKeyValueIndex(dictionary, identifier);
if (index === null) {
throw Error("Provided key \"" + identifier + "\" is not present in the dictionary!");
if (defaultValue !== undefined) {
return defaultValue;
} else {
throw Error("Provided key \"" + identifier + "\" is not present in the dictionary!");
}
}

return ko.unwrap(ko.unwrap(dictionary[index]).Value);
Expand Down
27 changes: 19 additions & 8 deletions src/Framework/Framework/Utils/ReflectionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
using DotVVM.Framework.Routing;
using DotVVM.Framework.ViewModel;
using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace DotVVM.Framework.Utils
{
Expand Down Expand Up @@ -135,14 +136,7 @@ public static bool IsAssignableToGenericType(this Type givenType, Type genericTy
// handle null values
if (value == null)
{
if (type == typeof(bool))
return BoxingUtils.False;
else if (type == typeof(int))
return BoxingUtils.Zero;
else if (type.IsValueType)
return Activator.CreateInstance(type);
else
return null;
return GetDefaultValue(type);
}

if (type.IsInstanceOfType(value)) return value;
Expand Down Expand Up @@ -460,6 +454,23 @@ public static Type MakeNullableType(this Type type)
return type.IsValueType && Nullable.GetUnderlyingType(type) == null && type != typeof(void) ? typeof(Nullable<>).MakeGenericType(type) : type;
}

/// <summary> Returns the equivalent of default(T) in C#, null for reference and Nullable&lt;T>, zeroed object for structs. </summary>
public static object? GetDefaultValue(Type type)
{
if (!type.IsValueType)
return null;
if (type.IsNullable())
return null;

if (type == typeof(bool))
return BoxingUtils.False;
else if (type == typeof(int))
return BoxingUtils.Zero;
// see https://github.com/dotnet/runtime/issues/90697
// notably we can't use Activator.CreateInstance, because C# now allows default constructors in structs
return FormatterServices.GetUninitializedObject(type);
}

public static Type UnwrapTaskType(this Type type)
{
if (type.IsGenericType && typeof(Task<>).IsAssignableFrom(type.GetGenericTypeDefinition()))
Expand Down
43 changes: 43 additions & 0 deletions src/Tests/Binding/BindingCompilationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,39 @@ public void BindingCompiler_DelegateFromMethodGroup()
Assert.AreEqual(42, result(42));
}

[DataTestMethod]
[DataRow("100", typeof(int))]
[DataRow("'aa'", null)]
[DataRow("NullableDateOnly", null)]
[DataRow("DateOnly", typeof(DateOnly))]
public void BindingCompiler_GenericMethod_DefaultArgument(string expression, Type resultType)
{
var result = ExecuteBinding($"_this.GenericDefault({expression})", new [] { new TestViewModel() });
if (resultType == null)
{
Assert.IsNull(result);
}
else
{
Assert.AreEqual(resultType, result.GetType(), message: $"_this.GenericDefault({expression}) returned {result} of type {result?.GetType().FullName ?? "null"}");
Assert.AreEqual(ReflectionUtils.GetDefaultValue(resultType), result);
}
}

[TestMethod]
public void BindingCompiler_GenericMethod_ParamsEmpty()
{
var result = ExecuteBinding("_this.GenericParams<int>()", new [] { new TestViewModel() });
Assert.AreEqual((0, 0), result);
}

[TestMethod]
public void BindingCompiler_GenericMethod_Params()
{
var result = ExecuteBinding("_this.GenericParams(10, 20, 30)", new [] { new TestViewModel() });
Assert.AreEqual((10, 3), result);
}

[TestMethod]
public void BindingCompiler_ComparisonOperators()
{
Expand Down Expand Up @@ -1356,6 +1389,16 @@ public async Task<string> GetStringPropAsync()
public string MethodWithOverloads(string i) => i;
public string MethodWithOverloads(DateTime i) => i.ToString();
public int MethodWithOverloads(int a, int b) => a + b;

public T GenericDefault<T>(T something, T somethingElse = default)
{
return somethingElse;
}

public (T, int) GenericParams<T>(params T[] something)
{
return (something.FirstOrDefault(), something.Length);
}
}


Expand Down
12 changes: 12 additions & 0 deletions src/Tests/Binding/JavascriptCompilationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,18 @@ public void JsTranslator_ReadOnlyDictionaryIndexer_Get()
Assert.AreEqual("dotvvm.translations.dictionary.getItem(ReadOnlyDictionary(),1)", result);
}

[DataTestMethod]
[DataRow("Dictionary")]
[DataRow("ReadOnlyDictionary")]
public void JsTranslator_Dictionary_GetValueOrDefault(string property)
{
var imports = new NamespaceImport[] { new("System.Collections.Generic"), new("DotVVM.Framework.Utils") };
var result = CompileBinding($"{property}.GetValueOrDefault(1)", imports, typeof(TestViewModel5));
Assert.AreEqual($"dotvvm.translations.dictionary.getItem({property}(),1,0)", result);
var result2 = CompileBinding($"{property}.GetValueOrDefault(1, 1024)", imports, typeof(TestViewModel5));
Assert.AreEqual($"dotvvm.translations.dictionary.getItem({property}(),1,1024)", result2);
}

[TestMethod]
public void JsTranslator_DictionaryIndexer_Set()
{
Expand Down
Loading