diff --git a/src/Framework/Framework/Compilation/Binding/MemberExpressionFactory.cs b/src/Framework/Framework/Compilation/Binding/MemberExpressionFactory.cs index da6ac12ce7..149f658ccc 100644 --- a/src/Framework/Framework/Compilation/Binding/MemberExpressionFactory.cs +++ b/src/Framework/Framework/Compilation/Binding/MemberExpressionFactory.cs @@ -415,6 +415,28 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express public bool HasParamsAttribute { get; set; } } + private Expression GetDefaultValue(ParameterInfo parameter) + { + if (parameter.HasDefaultValue) + { + var value = parameter.DefaultValue; + if (value is null && parameter.ParameterType.IsValueType) + { + // null with struct type means `default(T)` + value = ReflectionUtils.GetDefaultValue(parameter.ParameterType); + } + return Expression.Constant(value, parameter.ParameterType); + } + else if (parameter.IsDefined(ParamArrayAttributeType)) + { + return Expression.NewArrayInit(parameter.ParameterType.GetElementType().NotNull()); + } + else + { + throw new Exception($"Internal error: parameter {parameter.Name} of method {parameter.Member.Name} does not have a default value."); + } + } + private MethodRecognitionResult? TryCallMethod(MethodInfo method, Type[]? typeArguments, Expression[] positionalArguments, IDictionary? namedArguments) { if (positionalArguments.Contains(null)) throw new ArgumentNullException("positionalArguments[]"); @@ -445,7 +467,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express if (typeArgs[genericArgumentPosition] == null) { // try to resolve from arguments - var argType = GetGenericParameterType(genericArguments[genericArgumentPosition], parameterTypes, args.Select(s => s.Type).ToArray()); + var argType = GetGenericParameterType(genericArguments[genericArgumentPosition], parameterTypes, args.Select(s => s?.Type).ToArray()); automaticTypeArgs++; if (argType != null) typeArgs[genericArgumentPosition] = argType; else return null; @@ -466,11 +488,15 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express } else if (typeArguments != null) return null; - // cast arguments + // cast arguments and fill defaults for (int i = 0; i < args.Length; i++) { + if (args[i] == null) + { + args[i] = GetDefaultValue(parameters[i]); + } Type elm; - if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i].Type.IsArray) + if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i]!.Type.IsArray) { elm = parameters[i].ParameterType.GetElementType().NotNull(); if (positionalArguments.Skip(i).Any(s => TypeConversion.ImplicitConversion(s, elm) is null)) @@ -482,7 +508,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express { elm = parameters[i].ParameterType; } - var casted = TypeConversion.ImplicitConversion(args[i], elm); + var casted = TypeConversion.ImplicitConversion(args[i]!, elm); if (casted == null) { return null; @@ -492,7 +518,7 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express castCount++; args[i] = casted; } - if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i].Type.IsArray) + if (args.Length == i + 1 && hasParamsArrayAttributes && !args[i]!.Type.IsArray) { var converted = positionalArguments.Skip(i) .Select(a => TypeConversion.EnsureImplicitConversion(a, elm)) @@ -505,13 +531,13 @@ public MethodRecognitionResult(int automaticTypeArgCount, int castCount, Express automaticTypeArgCount: automaticTypeArgs, castCount: castCount, method: method, - arguments: args, + arguments: args!, paramsArrayCount: positionalArguments.Length - args.Length, hasParamsAttribute: hasParamsArrayAttributes, isExtension: false ); } - private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] positionalArguments, IDictionary? namedArguments, [MaybeNullWhen(false)] out Expression[] arguments, out int castCount) + private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] positionalArguments, IDictionary? namedArguments, [MaybeNullWhen(false)] out Expression?[] arguments, out int castCount) { castCount = 0; arguments = null; @@ -522,7 +548,7 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] if (!hasParamsArrayAttribute && parameters.Length < positionalArguments.Length) return false; - arguments = new Expression[parameters.Length]; + arguments = new Expression?[parameters.Length]; var copyItemsCount = !hasParamsArrayAttribute ? positionalArguments.Length : parameters.Length; if (hasParamsArrayAttribute && parameters.Length > positionalArguments.Length) @@ -530,8 +556,9 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] var parameter = parameters.Last(); var elementType = parameter.ParameterType.GetElementType().NotNull(); - // User specified no arguments for the `params` array, we need to create an empty array - arguments[arguments.Length - 1] = Expression.NewArrayInit(elementType); + // User specified no arguments for the `params` array => use default value + // created later by the GetDefaultValue, after we know the generic arguments + arguments[arguments.Length - 1] = null; // Last argument was just generated => do not copy addedArguments++; @@ -561,7 +588,7 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] else if (parameters[i].HasDefaultValue) { castCount++; - arguments[i] = Expression.Constant(parameters[i].DefaultValue, parameters[i].ParameterType); + arguments[i] = null; } else if (parameters[i].IsDefined(ParamArrayAttributeType)) { @@ -577,29 +604,30 @@ private static bool TryPrepareArguments(ParameterInfo[] parameters, Expression[] return true; } - private Type? GetGenericParameterType(Type genericArg, Type[] searchedGenericTypes, Type[] expressionTypes) + private Type? GetGenericParameterType(Type genericArg, Type[] searchedGenericTypes, Type?[] expressionTypes) { for (var i = 0; i < searchedGenericTypes.Length; i++) { if (expressionTypes.Length <= i) return null; + var expression = expressionTypes[i]; + if (expression == null) continue; var sgt = searchedGenericTypes[i]; if (sgt == genericArg) { - return expressionTypes[i]; + return expression; } if (sgt.IsArray) { var elementType = sgt.GetElementType(); - var expressionElementType = expressionTypes[i].GetElementType(); + var expressionElementType = expression.GetElementType(); if (elementType == genericArg) return expressionElementType; else - return GetGenericParameterType(genericArg, searchedGenericTypes[i].GetGenericArguments(), expressionTypes[i].GetGenericArguments()); + return GetGenericParameterType(genericArg, searchedGenericTypes[i].GetGenericArguments(), expression.GetGenericArguments()); } else if (sgt.IsGenericType) { Type[]? genericArguments = null; - var expression = expressionTypes[i]; if (expression.IsArray) { diff --git a/src/Framework/Framework/Compilation/ControlTree/Resolved/DirectiveCompilationService.cs b/src/Framework/Framework/Compilation/ControlTree/Resolved/DirectiveCompilationService.cs index b602c75505..3bc17d053c 100644 --- a/src/Framework/Framework/Compilation/ControlTree/Resolved/DirectiveCompilationService.cs +++ b/src/Framework/Framework/Compilation/ControlTree/Resolved/DirectiveCompilationService.cs @@ -6,6 +6,7 @@ using DotVVM.Framework.Compilation.Binding; using System.Collections.Immutable; using DotVVM.Framework.Configuration; +using DotVVM.Framework.Utils; namespace DotVVM.Framework.Compilation.ControlTree.Resolved { @@ -59,7 +60,7 @@ public DirectiveCompilationService(CompiledAssemblyCache compiledAssemblyCache, public object? ResolvePropertyInitializer(DothtmlDirectiveNode directive, Type propertyType, BindingParserNode? initializer, ImmutableList imports) { - if (initializer == null) { return CreateDefaultValue(propertyType); } + if (initializer == null) { return ReflectionUtils.GetDefaultValue(propertyType); } var registry = RegisterImports(TypeRegistry.DirectivesDefault(compiledAssemblyCache), imports); @@ -75,25 +76,16 @@ public DirectiveCompilationService(CompiledAssemblyCache compiledAssemblyCache, var lambda = Expression.Lambda>(Expression.Block(Expression.Convert(TypeConversion.EnsureImplicitConversion(initializerExpression, propertyType), typeof(object)))); var lambdaDelegate = lambda.Compile(true); - return lambdaDelegate.Invoke() ?? CreateDefaultValue(propertyType); + return lambdaDelegate.Invoke() ?? ReflectionUtils.GetDefaultValue(propertyType); } catch (Exception ex) { directive.AddError("Could not initialize property value."); directive.AddError(ex.Message); - return CreateDefaultValue(propertyType); + return ReflectionUtils.GetDefaultValue(propertyType); } } - private object? CreateDefaultValue(Type? type) - { - if (type != null && type.IsValueType) - { - return Activator.CreateInstance(type); - } - return null; - } - private Expression? CompileDirectiveExpression(DothtmlDirectiveNode directive, BindingParserNode expressionSyntax, ImmutableList imports) { TypeRegistry registry; diff --git a/src/Framework/Framework/Compilation/Javascript/JavascriptTranslatableMethodCollection.cs b/src/Framework/Framework/Compilation/Javascript/JavascriptTranslatableMethodCollection.cs index 96f9f1e2a2..d53d1636ce 100644 --- a/src/Framework/Framework/Compilation/Javascript/JavascriptTranslatableMethodCollection.cs +++ b/src/Framework/Framework/Compilation/Javascript/JavascriptTranslatableMethodCollection.cs @@ -681,6 +681,20 @@ private void AddDefaultDictionaryTranslations() AddMethodTranslator(() => default(IReadOnlyDictionary)!.ContainsKey(null!), containsKey); AddMethodTranslator(() => default(IDictionary)!.Remove(null!), new GenericMethodCompiler(args => new JsIdentifierExpression("dotvvm").Member("translations").Member("dictionary").Member("remove").Invoke(args[0].WithAnnotation(ShouldBeObservableAnnotation.Instance), args[1]))); + + var getValueOrDefault = new GenericMethodCompiler((JsExpression[] args, MethodInfo method) => { + var defaultValue = + args.Length > 3 ? args[3] : + new JsLiteral(ReflectionUtils.GetDefaultValue(method.GetGenericArguments().Last())); + return new JsIdentifierExpression("dotvvm").Member("translations").Member("dictionary").Member("getItem").Invoke(args[1], args[2], defaultValue); + }); +#if DotNetCore + AddMethodTranslator(() => default(IReadOnlyDictionary)!.GetValueOrDefault(null!), getValueOrDefault); + AddMethodTranslator(() => default(IReadOnlyDictionary)!.GetValueOrDefault(null!, null), getValueOrDefault); +#endif + AddMethodTranslator(() => default(IImmutableDictionary)!.GetValueOrDefault(null!), getValueOrDefault); + AddMethodTranslator(() => default(IImmutableDictionary)!.GetValueOrDefault(null!, null), getValueOrDefault); + AddMethodTranslator(() => FunctionalExtensions.GetValueOrDefault(default(IReadOnlyDictionary)!, null!, null!, false), getValueOrDefault); } private bool IsDictionary(Type type) => diff --git a/src/Framework/Framework/Resources/Scripts/translations/dictionaryHelper.ts b/src/Framework/Framework/Resources/Scripts/translations/dictionaryHelper.ts index bc06aafd58..d19e7671b7 100644 --- a/src/Framework/Framework/Resources/Scripts/translations/dictionaryHelper.ts +++ b/src/Framework/Framework/Resources/Scripts/translations/dictionaryHelper.ts @@ -8,10 +8,14 @@ export function containsKey(dictionary: Dictionary, iden return getKeyValueIndex(dictionary, identifier) !== null; } -export function getItem(dictionary: Dictionary, identifier: Key): Value { +export function getItem(dictionary: Dictionary, identifier: Key, defaultValue?: Value): Value { const index = getKeyValueIndex(dictionary, identifier); if (index === null) { - throw Error("Provided key \"" + identifier + "\" is not present in the dictionary!"); + if (defaultValue !== undefined) { + return defaultValue; + } else { + throw Error("Provided key \"" + identifier + "\" is not present in the dictionary!"); + } } return ko.unwrap(ko.unwrap(dictionary[index]).Value); diff --git a/src/Framework/Framework/Utils/ReflectionUtils.cs b/src/Framework/Framework/Utils/ReflectionUtils.cs index bf5e8ea0ef..847090e336 100644 --- a/src/Framework/Framework/Utils/ReflectionUtils.cs +++ b/src/Framework/Framework/Utils/ReflectionUtils.cs @@ -27,6 +27,7 @@ using DotVVM.Framework.Routing; using DotVVM.Framework.ViewModel; using System.Diagnostics; +using System.Runtime.CompilerServices; namespace DotVVM.Framework.Utils { @@ -135,14 +136,7 @@ public static bool IsAssignableToGenericType(this Type givenType, Type genericTy // handle null values if (value == null) { - if (type == typeof(bool)) - return BoxingUtils.False; - else if (type == typeof(int)) - return BoxingUtils.Zero; - else if (type.IsValueType) - return Activator.CreateInstance(type); - else - return null; + return GetDefaultValue(type); } if (type.IsInstanceOfType(value)) return value; @@ -460,6 +454,23 @@ public static Type MakeNullableType(this Type type) return type.IsValueType && Nullable.GetUnderlyingType(type) == null && type != typeof(void) ? typeof(Nullable<>).MakeGenericType(type) : type; } + /// Returns the equivalent of default(T) in C#, null for reference and Nullable<T>, zeroed object for structs. + public static object? GetDefaultValue(Type type) + { + if (!type.IsValueType) + return null; + if (type.IsNullable()) + return null; + + if (type == typeof(bool)) + return BoxingUtils.False; + else if (type == typeof(int)) + return BoxingUtils.Zero; + // see https://github.com/dotnet/runtime/issues/90697 + // notably we can't use Activator.CreateInstance, because C# now allows default constructors in structs + return FormatterServices.GetUninitializedObject(type); + } + public static Type UnwrapTaskType(this Type type) { if (type.IsGenericType && typeof(Task<>).IsAssignableFrom(type.GetGenericTypeDefinition())) diff --git a/src/Tests/Binding/BindingCompilationTests.cs b/src/Tests/Binding/BindingCompilationTests.cs index 8ad41f60e9..81ff02e6f4 100755 --- a/src/Tests/Binding/BindingCompilationTests.cs +++ b/src/Tests/Binding/BindingCompilationTests.cs @@ -1055,6 +1055,39 @@ public void BindingCompiler_DelegateFromMethodGroup() Assert.AreEqual(42, result(42)); } + [DataTestMethod] + [DataRow("100", typeof(int))] + [DataRow("'aa'", null)] + [DataRow("NullableDateOnly", null)] + [DataRow("DateOnly", typeof(DateOnly))] + public void BindingCompiler_GenericMethod_DefaultArgument(string expression, Type resultType) + { + var result = ExecuteBinding($"_this.GenericDefault({expression})", new [] { new TestViewModel() }); + if (resultType == null) + { + Assert.IsNull(result); + } + else + { + Assert.AreEqual(resultType, result.GetType(), message: $"_this.GenericDefault({expression}) returned {result} of type {result?.GetType().FullName ?? "null"}"); + Assert.AreEqual(ReflectionUtils.GetDefaultValue(resultType), result); + } + } + + [TestMethod] + public void BindingCompiler_GenericMethod_ParamsEmpty() + { + var result = ExecuteBinding("_this.GenericParams()", new [] { new TestViewModel() }); + Assert.AreEqual((0, 0), result); + } + + [TestMethod] + public void BindingCompiler_GenericMethod_Params() + { + var result = ExecuteBinding("_this.GenericParams(10, 20, 30)", new [] { new TestViewModel() }); + Assert.AreEqual((10, 3), result); + } + [TestMethod] public void BindingCompiler_ComparisonOperators() { @@ -1356,6 +1389,16 @@ public async Task GetStringPropAsync() public string MethodWithOverloads(string i) => i; public string MethodWithOverloads(DateTime i) => i.ToString(); public int MethodWithOverloads(int a, int b) => a + b; + + public T GenericDefault(T something, T somethingElse = default) + { + return somethingElse; + } + + public (T, int) GenericParams(params T[] something) + { + return (something.FirstOrDefault(), something.Length); + } } diff --git a/src/Tests/Binding/JavascriptCompilationTests.cs b/src/Tests/Binding/JavascriptCompilationTests.cs index 1c323b8337..561e1f7921 100644 --- a/src/Tests/Binding/JavascriptCompilationTests.cs +++ b/src/Tests/Binding/JavascriptCompilationTests.cs @@ -538,6 +538,18 @@ public void JsTranslator_ReadOnlyDictionaryIndexer_Get() Assert.AreEqual("dotvvm.translations.dictionary.getItem(ReadOnlyDictionary(),1)", result); } + [DataTestMethod] + [DataRow("Dictionary")] + [DataRow("ReadOnlyDictionary")] + public void JsTranslator_Dictionary_GetValueOrDefault(string property) + { + var imports = new NamespaceImport[] { new("System.Collections.Generic"), new("DotVVM.Framework.Utils") }; + var result = CompileBinding($"{property}.GetValueOrDefault(1)", imports, typeof(TestViewModel5)); + Assert.AreEqual($"dotvvm.translations.dictionary.getItem({property}(),1,0)", result); + var result2 = CompileBinding($"{property}.GetValueOrDefault(1, 1024)", imports, typeof(TestViewModel5)); + Assert.AreEqual($"dotvvm.translations.dictionary.getItem({property}(),1,1024)", result2); + } + [TestMethod] public void JsTranslator_DictionaryIndexer_Set() {