diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs index 75a49c362ff2a0..87eb31d58022fe 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs @@ -40,37 +40,36 @@ public abstract partial class ComWrappers private static readonly Guid IID_IInspectable = new Guid(0xAF86E2E0, 0xB12D, 0x4c6a, 0x9C, 0x5A, 0xD7, 0xAA, 0x65, 0x10, 0x1E, 0x90); private static readonly Guid IID_IWeakReferenceSource = new Guid(0x00000038, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46); - private static readonly ConditionalWeakTable s_rcwTable = new ConditionalWeakTable(); + private static readonly ConditionalWeakTable s_nativeObjectWrapperTable = new ConditionalWeakTable(); private static readonly GCHandleSet s_referenceTrackerNativeObjectWrapperCache = new GCHandleSet(); - private readonly ConditionalWeakTable _ccwTable = new ConditionalWeakTable(); - private readonly Lock _lock = new Lock(useTrivialWaits: true); - private readonly Dictionary _rcwCache = new Dictionary(); + private readonly ConditionalWeakTable _managedObjectWrapperTable = new ConditionalWeakTable(); + private readonly RcwCache _rcwCache = new(); internal static bool TryGetComInstanceForIID(object obj, Guid iid, out IntPtr unknown, out long wrapperId) { if (obj == null - || !s_rcwTable.TryGetValue(obj, out NativeObjectWrapper? wrapper)) + || !s_nativeObjectWrapperTable.TryGetValue(obj, out NativeObjectWrapper? wrapper)) { unknown = IntPtr.Zero; wrapperId = 0; return false; } - wrapperId = wrapper._comWrappers.id; - return Marshal.QueryInterface(wrapper._externalComObject, iid, out unknown) == HResults.S_OK; + wrapperId = wrapper.ComWrappers.id; + return Marshal.QueryInterface(wrapper.ExternalComObject, iid, out unknown) == HResults.S_OK; } public static unsafe bool TryGetComInstance(object obj, out IntPtr unknown) { unknown = IntPtr.Zero; if (obj == null - || !s_rcwTable.TryGetValue(obj, out NativeObjectWrapper? wrapper)) + || !s_nativeObjectWrapperTable.TryGetValue(obj, out NativeObjectWrapper? wrapper)) { return false; } - return Marshal.QueryInterface(wrapper._externalComObject, IID_IUnknown, out unknown) == HResults.S_OK; + return Marshal.QueryInterface(wrapper.ExternalComObject, IID_IUnknown, out unknown) == HResults.S_OK; } public static unsafe bool TryGetObject(IntPtr unknown, [NotNullWhen(true)] out object? obj) @@ -484,9 +483,6 @@ public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper) // There are still outstanding references on the COM side. // This case should only be hit when an outstanding // tracker refcount exists from AddRefFromReferenceTracker. - // When implementing IReferenceTrackerHost, this should be - // reconsidered. - // https://github.com/dotnet/runtime/issues/85137 GC.ReRegisterForFinalize(this); } } @@ -494,12 +490,13 @@ public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper) internal unsafe class NativeObjectWrapper { - internal IntPtr _externalComObject; + private IntPtr _externalComObject; private IntPtr _inner; - internal ComWrappers _comWrappers; - internal readonly GCHandle _proxyHandle; - internal readonly GCHandle _proxyHandleTrackingResurrection; - internal readonly bool _aggregatedManagedObjectWrapper; + private ComWrappers _comWrappers; + private GCHandle _proxyHandle; + private GCHandle _proxyHandleTrackingResurrection; + private readonly bool _aggregatedManagedObjectWrapper; + private readonly bool _uniqueInstance; static NativeObjectWrapper() { @@ -522,18 +519,19 @@ public static NativeObjectWrapper Create(IntPtr externalComObject, IntPtr inner, } } - public NativeObjectWrapper(IntPtr externalComObject, IntPtr inner, ComWrappers comWrappers, object comProxy, CreateObjectFlags flags) + protected NativeObjectWrapper(IntPtr externalComObject, IntPtr inner, ComWrappers comWrappers, object comProxy, CreateObjectFlags flags) { _externalComObject = externalComObject; _inner = inner; _comWrappers = comWrappers; + _uniqueInstance = flags.HasFlag(CreateObjectFlags.UniqueInstance); _proxyHandle = GCHandle.Alloc(comProxy, GCHandleType.Weak); // We have a separate handle tracking resurrection as we want to make sure // we clean up the NativeObjectWrapper only after the RCW has been finalized // due to it can access the native object in the finalizer. At the same time, - // we want other callers which are using _proxyHandle such as the RCW cache to - // see the object as not alive once it is eligible for finalization. + // we want other callers which are using ProxyHandle such as the reference tracker runtime + // to see the object as not alive once it is eligible for finalization. _proxyHandleTrackingResurrection = GCHandle.Alloc(comProxy, GCHandleType.WeakTrackResurrection); // If this is an aggregation scenario and the identity object @@ -548,11 +546,17 @@ public NativeObjectWrapper(IntPtr externalComObject, IntPtr inner, ComWrappers c } } + internal IntPtr ExternalComObject => _externalComObject; + internal ComWrappers ComWrappers => _comWrappers; + internal GCHandle ProxyHandle => _proxyHandle; + internal bool IsUniqueInstance => _uniqueInstance; + internal bool IsAggregatedWithManagedObjectWrapper => _aggregatedManagedObjectWrapper; + public virtual void Release() { - if (_comWrappers != null) + if (!_uniqueInstance && _comWrappers is not null) { - _comWrappers.RemoveRCWFromCache(_externalComObject, _proxyHandle); + _comWrappers._rcwCache.Remove(_externalComObject, this); _comWrappers = null; } @@ -712,23 +716,23 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom { ArgumentNullException.ThrowIfNull(instance); - ManagedObjectWrapperHolder? ccwValue; - if (_ccwTable.TryGetValue(instance, out ccwValue)) + ManagedObjectWrapperHolder? managedObjectWrapper; + if (_managedObjectWrapperTable.TryGetValue(instance, out managedObjectWrapper)) { - ccwValue.AddRef(); - return ccwValue.ComIp; + managedObjectWrapper.AddRef(); + return managedObjectWrapper.ComIp; } - ccwValue = _ccwTable.GetValue(instance, (c) => + managedObjectWrapper = _managedObjectWrapperTable.GetValue(instance, (c) => { - ManagedObjectWrapper* value = CreateCCW(c, flags); + ManagedObjectWrapper* value = CreateManagedObjectWrapper(c, flags); return new ManagedObjectWrapperHolder(value, c); }); - ccwValue.AddRef(); - return ccwValue.ComIp; + managedObjectWrapper.AddRef(); + return managedObjectWrapper.ComIp; } - private unsafe ManagedObjectWrapper* CreateCCW(object instance, CreateComInterfaceFlags flags) + private unsafe ManagedObjectWrapper* CreateManagedObjectWrapper(object instance, CreateComInterfaceFlags flags) { ComInterfaceEntry* userDefined = ComputeVtables(instance, flags, out int userDefinedCount); if ((userDefined == null && userDefinedCount != 0) || userDefinedCount < 0) @@ -799,7 +803,7 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, null, out obj)) throw new ArgumentNullException(nameof(externalComObject)); - return obj!; + return obj; } /// @@ -841,7 +845,7 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, inner, flags, wrapper, out obj)) throw new ArgumentNullException(nameof(externalComObject)); - return obj!; + return obj; } private static unsafe ComInterfaceDispatch* TryGetComInterfaceDispatch(IntPtr comObject) @@ -917,7 +921,6 @@ private static void DetermineIdentityAndInner( } } -#pragma warning disable IDE0060 /// /// Get the currently registered managed object or creates a new managed object and registers it. /// @@ -932,7 +935,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( IntPtr innerMaybe, CreateObjectFlags flags, object? wrapperMaybe, - out object? retValue) + [NotNullWhen(true)] out object? retValue) { if (externalComObject == IntPtr.Zero) throw new ArgumentNullException(nameof(externalComObject)); @@ -949,137 +952,149 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( using ComHolder releaseIdentity = new ComHolder(identity); - if (!flags.HasFlag(CreateObjectFlags.UniqueInstance)) + // If the user has requested a unique instance, + // we will immediately create the object, register it, + // and return. + if (flags.HasFlag(CreateObjectFlags.UniqueInstance)) { - using (_lock.EnterScope()) - { - if (_rcwCache.TryGetValue(identity, out GCHandle handle)) - { - object? cachedWrapper = handle.Target; - if (cachedWrapper is not null) - { - retValue = cachedWrapper; - return true; - } - else - { - // The GCHandle has been clear out but the NativeObjectWrapper - // finalizer has not yet run to remove the entry from _rcwCache - _rcwCache.Remove(identity); - } - } + retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags); + return retValue is not null; + } - if (wrapperMaybe is not null) - { - retValue = wrapperMaybe; - NativeObjectWrapper wrapper = NativeObjectWrapper.Create( - identity, - inner, - this, - retValue, - flags); - if (!s_rcwTable.TryAdd(retValue, wrapper)) - { - wrapper.Release(); - throw new NotSupportedException(); - } - _rcwCache.Add(identity, wrapper._proxyHandle); - AddWrapperToReferenceTrackerHandleCache(wrapper); - return true; - } - } - if (flags.HasFlag(CreateObjectFlags.Unwrap)) + // If we have a live cached wrapper currently, + // return that. + if (_rcwCache.FindProxyForComInstance(identity) is object liveCachedWrapper) + { + retValue = liveCachedWrapper; + return true; + } + + // If the user tried to provide a pre-created managed wrapper, try to register + // that object as the wrapper. + if (wrapperMaybe is not null) + { + retValue = RegisterObjectForComInstance(identity, inner, wrapperMaybe, flags); + return retValue is not null; + } + + // Check if the provided COM instance is actually a managed object wrapper from this + // ComWrappers instance, and use it if it is. + if (flags.HasFlag(CreateObjectFlags.Unwrap)) + { + ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(identity); + if (comInterfaceDispatch != null) { - ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(identity); - if (comInterfaceDispatch != null) + // If we found a managed object wrapper in this ComWrappers instance + // and it has the same identity pointer as the one we're creating a NativeObjectWrapper for, + // unwrap it. We don't AddRef the wrapper as we don't take a reference to it. + // + // A managed object can have multiple managed object wrappers, with a max of one per context. + // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the + // managed object wrappers for A created with C1 and C2 respectively. + // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance, + // we will create a new wrapper. In this scenario, we'll only unwrap B2. + object unwrapped = ComInterfaceDispatch.GetInstance(comInterfaceDispatch); + if (_managedObjectWrapperTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext)) { - // If we found a managed object wrapper in this ComWrappers instance - // and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for, - // unwrap it. We don't AddRef the wrapper as we don't take a reference to it. - // - // A managed object can have multiple managed object wrappers, with a max of one per context. - // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the - // managed object wrappers for A created with C1 and C2 respectively. - // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance, - // we will create a new wrapper. In this scenario, we'll only unwrap B2. - object unwrapped = ComInterfaceDispatch.GetInstance(comInterfaceDispatch); - if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext)) + // The unwrapped object has a CCW in this context. Compare with identity + // so we can see if it's the CCW for the unwrapped object in this context. + if (unwrappedWrapperInThisContext.ComIp == identity) { - // The unwrapped object has a CCW in this context. Compare with identity - // so we can see if it's the CCW for the unwrapped object in this context. - if (unwrappedWrapperInThisContext.ComIp == identity) - { - retValue = unwrapped; - return true; - } + retValue = unwrapped; + return true; } } } } - retValue = CreateObject(identity, flags); - if (retValue == null) + // If the user didn't provide a wrapper and couldn't unwrap a managed object wrapper, + // create a new wrapper. + retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags); + return retValue is not null; + } + + private object? CreateAndRegisterObjectForComInstance(IntPtr identity, IntPtr inner, CreateObjectFlags flags) + { + object? retValue = CreateObject(identity, flags); + if (retValue is null) { // If ComWrappers instance cannot create wrapper, we can do nothing here. - return false; + return null; } - if (flags.HasFlag(CreateObjectFlags.UniqueInstance)) + return RegisterObjectForComInstance(identity, inner, retValue, flags); + } + + private object RegisterObjectForComInstance(IntPtr identity, IntPtr inner, object comProxy, CreateObjectFlags flags) + { + NativeObjectWrapper nativeObjectWrapper = NativeObjectWrapper.Create( + identity, + inner, + this, + comProxy, + flags); + + object actualProxy = comProxy; + NativeObjectWrapper actualWrapper = nativeObjectWrapper; + if (!nativeObjectWrapper.IsUniqueInstance) { - NativeObjectWrapper wrapper = NativeObjectWrapper.Create( - identity, - inner, - null, // No need to cache NativeObjectWrapper for unique instances. They are not cached. - retValue, - flags); - if (!s_rcwTable.TryAdd(retValue, wrapper)) + // Add our entry to the cache here, using an already existing entry if someone else beat us to it. + (actualWrapper, actualProxy) = _rcwCache.GetOrAddProxyForComInstance(identity, nativeObjectWrapper, comProxy); + if (actualWrapper != nativeObjectWrapper) { - wrapper.Release(); - throw new NotSupportedException(); + // We raced with another thread to map identity to nativeObjectWrapper + // and lost the race. We will use the other thread's nativeObjectWrapper, so we can release ours. + nativeObjectWrapper.Release(); } - AddWrapperToReferenceTrackerHandleCache(wrapper); - return true; } - using (_lock.EnterScope()) + // At this point, actualProxy is the RCW object for the identity + // and actualWrapper is the NativeObjectWrapper that is in the RCW cache (if not unique) that associates the identity with actualProxy. + // Register the NativeObjectWrapper to handle lifetime tracking of the references to the COM object. + RegisterWrapperForObject(actualWrapper, actualProxy); + + return actualProxy; + } + + private void RegisterWrapperForObject(NativeObjectWrapper wrapper, object comProxy) + { + // When we call into RegisterWrapperForObject, there is only one valid non-"unique instance" wrapper for a given + // COM instance, which is already registered in the RCW cache. + // If we find a wrapper in the table that is a different NativeObjectWrapper instance + // then it must be for a different COM instance. + // It's possible that we could race here with another thread that is trying to register the same comProxy + // for the same COM instance, but in that case we'll be passed the same NativeObjectWrapper instance + // for both threads. In that case, it doesn't matter which thread adds the entry to the NativeObjectWrapper table + // as the entry is always the same pair. + Debug.Assert(wrapper.ProxyHandle.Target == comProxy); + Debug.Assert(wrapper.IsUniqueInstance || _rcwCache.FindProxyForComInstance(wrapper.ExternalComObject) == comProxy); + + if (s_nativeObjectWrapperTable.TryGetValue(comProxy, out NativeObjectWrapper? registeredWrapper) + && registeredWrapper != wrapper) { - object? cachedWrapper = null; - if (_rcwCache.TryGetValue(identity, out var existingHandle)) - { - cachedWrapper = existingHandle.Target; - if (cachedWrapper is null) - { - // The GCHandle has been clear out but the NativeObjectWrapper - // finalizer has not yet run to remove the entry from _rcwCache - _rcwCache.Remove(identity); - } - } + Debug.Assert(registeredWrapper.ExternalComObject != wrapper.ExternalComObject); + wrapper.Release(); + throw new NotSupportedException(); + } - if (cachedWrapper is not null) - { - retValue = cachedWrapper; - } - else - { - NativeObjectWrapper wrapper = NativeObjectWrapper.Create( - identity, - inner, - this, - retValue, - flags); - if (!s_rcwTable.TryAdd(retValue, wrapper)) - { - wrapper.Release(); - throw new NotSupportedException(); - } - _rcwCache.Add(identity, wrapper._proxyHandle); - AddWrapperToReferenceTrackerHandleCache(wrapper); - } + registeredWrapper = GetValueFromRcwTable(comProxy, wrapper); + if (registeredWrapper != wrapper) + { + Debug.Assert(registeredWrapper.ExternalComObject != wrapper.ExternalComObject); + wrapper.Release(); + throw new NotSupportedException(); } - return true; + // Always register our wrapper to the reference tracker handle cache here. + // We may not be the thread that registered the handle, but we need to ensure that the wrapper + // is registered before we return to user code. Otherwise the wrapper won't be walked by the + // TrackerObjectManager and we could end up missing a section of the object graph. + // This cache deduplicates, so it is okay that the wrapper will be registered multiple times. + AddWrapperToReferenceTrackerHandleCache(registeredWrapper); + + // Separate out into a local function to avoid the closure and delegate allocation unless we need it. + static NativeObjectWrapper GetValueFromRcwTable(object userObject, NativeObjectWrapper newWrapper) => s_nativeObjectWrapperTable.GetValue(userObject, _ => newWrapper); } -#pragma warning restore IDE0060 private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper wrapper) { @@ -1089,16 +1104,94 @@ private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper } } - private void RemoveRCWFromCache(IntPtr comPointer, GCHandle expectedValue) + private sealed class RcwCache { - using (_lock.EnterScope()) + private readonly Lock _lock = new Lock(useTrivialWaits: true); + private readonly Dictionary _cache = []; + + /// + /// Gets the current RCW proxy object for if it exists in the cache or inserts a new entry with . + /// + /// The com instance we want to get or record an RCW for. + /// The for . + /// The proxy object that is associated with . + /// The proxy object currently in the cache for or the proxy object owned by if no entry exists and the corresponding native wrapper. + public (NativeObjectWrapper actualWrapper, object actualProxy) GetOrAddProxyForComInstance(IntPtr comPointer, NativeObjectWrapper wrapper, object comProxy) { - // TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache - // in the time between the GC cleared the contents of the GC handle but before the - // NativeObjectWrapper finalizer ran. - if (_rcwCache.TryGetValue(comPointer, out GCHandle cachedValue) && expectedValue.Equals(cachedValue)) + lock (_lock) { - _rcwCache.Remove(comPointer); + Debug.Assert(wrapper.ProxyHandle.Target == comProxy); + ref GCHandle rcwEntry = ref CollectionsMarshal.GetValueRefOrAddDefault(_cache, comPointer, out bool exists); + if (!exists) + { + // Someone else didn't beat us to adding the entry to the cache. + // Add our entry here. + rcwEntry = GCHandle.Alloc(wrapper, GCHandleType.Weak); + } + else if (rcwEntry.Target is not (NativeObjectWrapper cachedWrapper)) + { + Debug.Assert(rcwEntry.IsAllocated); + // The target was collected, so we need to update the cache entry. + rcwEntry.Target = wrapper; + } + else + { + object? existingProxy = cachedWrapper.ProxyHandle.Target; + // The target NativeObjectWrapper was not collected, but we need to make sure + // that the proxy object is still alive. + if (existingProxy is not null) + { + // The existing proxy object is still alive, we will use that. + return (cachedWrapper, existingProxy); + } + + // The proxy object was collected, so we need to update the cache entry. + rcwEntry.Target = wrapper; + } + + // We either added an entry to the cache or updated an existing entry that was dead. + // Return our target object. + return (wrapper, comProxy); + } + } + + public object? FindProxyForComInstance(IntPtr comPointer) + { + lock (_lock) + { + if (_cache.TryGetValue(comPointer, out GCHandle existingHandle)) + { + if (existingHandle.Target is NativeObjectWrapper { ProxyHandle.Target: object cachedProxy }) + { + // The target exists and is still alive. Return it. + return cachedProxy; + } + + // The target was collected, so we need to remove the entry from the cache. + _cache.Remove(comPointer); + existingHandle.Free(); + } + + return null; + } + } + + public void Remove(IntPtr comPointer, NativeObjectWrapper wrapper) + { + lock (_lock) + { + // TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache + // in the time between the GC cleared the contents of the GC handle but before the + // NativeObjectWrapper finalizer ran. + // Only remove the entry if the target of the GC handle is the NativeObjectWrapper + // or is null (indicating that the corresponding NativeObjectWrapper has been scheduled for finalization). + if (_cache.TryGetValue(comPointer, out GCHandle cachedRef) + && (wrapper == cachedRef.Target + || cachedRef.Target is null)) + { + _cache.Remove(comPointer); + cachedRef.Free(); + } } } } @@ -1233,7 +1326,7 @@ internal static void ReleaseExternalObjectsFromCurrentThread() if (nativeObjectWrapper != null && nativeObjectWrapper._contextToken == contextToken) { - object? target = nativeObjectWrapper._proxyHandle.Target; + object? target = nativeObjectWrapper.ProxyHandle.Target; if (target != null) { objects.Add(target); @@ -1260,7 +1353,7 @@ internal static unsafe void WalkExternalTrackerObjects() if (nativeObjectWrapper != null && nativeObjectWrapper.TrackerObject != IntPtr.Zero) { - FindReferenceTargetsCallback.s_currentRootObjectHandle = nativeObjectWrapper._proxyHandle; + FindReferenceTargetsCallback.s_currentRootObjectHandle = nativeObjectWrapper.ProxyHandle; if (IReferenceTracker.FindTrackerTargets(nativeObjectWrapper.TrackerObject, TrackerObjectManager.s_findReferencesTargetCallback) != HResults.S_OK) { walkFailed = true; @@ -1287,7 +1380,7 @@ internal static void DetachNonPromotedObjects() ReferenceTrackerNativeObjectWrapper? nativeObjectWrapper = Unsafe.As(weakNativeObjectWrapperHandle.Target); if (nativeObjectWrapper != null && nativeObjectWrapper.TrackerObject != IntPtr.Zero && - !RuntimeImports.RhIsPromoted(nativeObjectWrapper._proxyHandle.Target)) + !RuntimeImports.RhIsPromoted(nativeObjectWrapper.ProxyHandle.Target)) { // Notify the wrapper it was not promoted and is being collected. TrackerObjectManager.BeforeWrapperFinalized(nativeObjectWrapper.TrackerObject); @@ -1608,7 +1701,7 @@ private static unsafe bool PossiblyComObject(object target) // If the RCW is an aggregated RCW, then the managed object cannot be recreated from the IUnknown // as the outer IUnknown wraps the managed object. In this case, don't create a weak reference backed // by a COM weak reference. - return s_rcwTable.TryGetValue(target, out NativeObjectWrapper? wrapper) && !wrapper._aggregatedManagedObjectWrapper; + return s_nativeObjectWrapperTable.TryGetValue(target, out NativeObjectWrapper? wrapper) && !wrapper.IsAggregatedWithManagedObjectWrapper; } private static unsafe IntPtr ObjectToComWeakRef(object target, out long wrapperId) diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index 7ef5d29c0ab635..3ed01e8604ad3e 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -10,6 +10,7 @@ namespace ComWrappersTests using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; + using System.Threading; using ComWrappersTests.Common; using TestLibrary; @@ -962,6 +963,167 @@ public void ValidateAggregationWithReferenceTrackerObject() Assert.Equal(0, allocTracker.GetCount()); } + + [Fact] + public void ComWrappersNoLockAroundQueryInterface() + { + Console.WriteLine($"Running {nameof(ComWrappersNoLockAroundQueryInterface)}..."); + + var cw = new RecursiveSimpleComWrappers(); + + IntPtr comObject = cw.GetOrCreateComInterfaceForObject(new RecursiveCrossThreadQI(cw), CreateComInterfaceFlags.None); + try + { + _ = cw.GetOrCreateObjectForComInstance(comObject, CreateObjectFlags.TrackerObject); + } + finally + { + Marshal.Release(comObject); + } + } + + private class RecursiveCrossThreadQI(ComWrappers? wrappers) : ICustomQueryInterface + { + CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv) + { + ppv = IntPtr.Zero; + if (iid == ComWrappersHelper.IID_IReferenceTracker && wrappers is not null) + { + Console.WriteLine("Attempting to create a new COM object on a different thread."); + Thread thread = new Thread(() => + { + IntPtr comObject = wrappers.GetOrCreateComInterfaceForObject(new RecursiveCrossThreadQI(null), CreateComInterfaceFlags.None); + try + { + // Make sure that ComWrappers isn't locking in GetOrCreateObjectForComInstance + // around the QI call by calling it on a different thread from within a QI call to register a new managed wrapper + // for a COM object representing "this". + _ = wrappers.GetOrCreateObjectForComInstance(comObject, CreateObjectFlags.None); + } + finally + { + Marshal.Release(comObject); + } + }); + thread.Start(); + thread.Join(TimeSpan.FromSeconds(20)); // 20 seconds should be more than long enough for the thread to complete + } + + return CustomQueryInterfaceResult.Failed; + } + } + + private unsafe class RecursiveSimpleComWrappers : ComWrappers + { + protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + count = 0; + return null; + } + + protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flags) + { + return new object(); + } + + protected override void ReleaseObjects(IEnumerable objects) + { + throw new NotImplementedException(); + } + } + + [Fact] + [PlatformSpecific(TestPlatforms.Windows)] // COM apartments are Windows-specific + public unsafe void CrossApartmentQueryInterface_NoDeadlock() + { + Console.WriteLine($"Running {nameof(CrossApartmentQueryInterface_NoDeadlock)}..."); + using ManualResetEvent hasAgileReference = new(false); + using ManualResetEvent testCompleted = new(false); + + IntPtr agileReference = IntPtr.Zero; + try + { + Thread staThread = new(() => + { + var cw = new RecursiveSimpleComWrappers(); + IntPtr comObject = cw.GetOrCreateComInterfaceForObject(new RecursiveQI(cw), CreateComInterfaceFlags.None); + try + { + Marshal.ThrowExceptionForHR(RoGetAgileReference(0, IUnknownVtbl.IID_IUnknown, comObject, out agileReference)); + } + finally + { + Marshal.Release(comObject); + } + hasAgileReference.Set(); + testCompleted.WaitOne(); + }); + staThread.SetApartmentState(ApartmentState.STA); + + Thread mtaThread = new(() => + { + hasAgileReference.WaitOne(); + IntPtr comObject; + int hr = ((delegate* unmanaged)(*(*(void***)agileReference + 3 /* IAgileReference.Resolve slot */)))(agileReference, IUnknownVtbl.IID_IUnknown, out comObject); + Marshal.ThrowExceptionForHR(hr); + try + { + var cw = new RecursiveSimpleComWrappers(); + // Make sure that ComWrappers isn't locking in GetOrCreateObjectForComInstance + // across the QI call + // by forcing marshalling across COM apartments. + _ = cw.GetOrCreateObjectForComInstance(comObject, CreateObjectFlags.TrackerObject); + } + finally + { + Marshal.Release(comObject); + } + testCompleted.Set(); + }); + mtaThread.SetApartmentState(ApartmentState.MTA); + + staThread.Start(); + mtaThread.Start(); + } + finally + { + if (agileReference != IntPtr.Zero) + { + Marshal.Release(agileReference); + } + } + + testCompleted.WaitOne(); + } + + [DllImport("ole32.dll")] + private static extern int RoGetAgileReference(int options, in Guid iid, IntPtr unknown, out IntPtr agileReference); + + private class RecursiveQI(ComWrappers? wrappers) : ICustomQueryInterface + { + CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv) + { + ppv = IntPtr.Zero; + if (wrappers is not null) + { + Console.WriteLine("Attempting to create a new COM object on the same thread."); + IntPtr comObject = wrappers.GetOrCreateComInterfaceForObject(new RecursiveQI(null), CreateComInterfaceFlags.None); + try + { + // Make sure that ComWrappers isn't locking in GetOrCreateObjectForComInstance + // around the QI call by calling it on a different thread from within a QI call to register a new managed wrapper + // for a COM object representing "this". + _ = wrappers.GetOrCreateObjectForComInstance(comObject, CreateObjectFlags.None); + } + finally + { + Marshal.Release(comObject); + } + } + + return CustomQueryInterfaceResult.Failed; + } + } } } diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs index 96d21a349d69dc..dbcefa47ac175e 100644 --- a/src/tests/Interop/COM/ComWrappers/Common.cs +++ b/src/tests/Interop/COM/ComWrappers/Common.cs @@ -339,7 +339,7 @@ CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out class ComWrappersHelper { - private static Guid IID_IReferenceTracker = new Guid("11d3b13a-180e-4789-a8be-7712882893e6"); + public static readonly Guid IID_IReferenceTracker = new Guid("11d3b13a-180e-4789-a8be-7712882893e6"); [Flags] public enum ReleaseFlags