diff --git a/packages/react-refresh/src/ReactFreshRuntime.js b/packages/react-refresh/src/ReactFreshRuntime.js index 43e8148d3de11..f7bce315ec033 100644 --- a/packages/react-refresh/src/ReactFreshRuntime.js +++ b/packages/react-refresh/src/ReactFreshRuntime.js @@ -146,6 +146,21 @@ function canPreserveStateBetween(prevType: any, nextType: any) { if (isReactClass(prevType) || isReactClass(nextType)) { return false; } + + if (typeof prevType !== typeof nextType) { + return false; + } else if ( + typeof prevType === 'object' && + prevType !== null && + nextType !== null + ) { + if ( + getProperty(prevType, '$$typeof') !== getProperty(nextType, '$$typeof') + ) { + return false; + } + } + if (haveEqualSignatures(prevType, nextType)) { return true; } @@ -183,6 +198,18 @@ function getProperty(object: any, property: string) { } } +function registerRefreshUpdate( + update: RefreshUpdate, + family: Family, + shouldPreserveState: boolean, +) { + if (shouldPreserveState) { + update.updatedFamilies.add(family); + } else { + update.staleFamilies.add(family); + } +} + export function performReactRefresh(): RefreshUpdate | null { if (!__DEV__) { throw new Error( @@ -200,6 +227,11 @@ export function performReactRefresh(): RefreshUpdate | null { try { const staleFamilies = new Set(); const updatedFamilies = new Set(); + // TODO: rename these fields to something more meaningful. + const update: RefreshUpdate = { + updatedFamilies, // Families that will re-render preserving state + staleFamilies, // Families that will be remounted + }; const updates = pendingUpdates; pendingUpdates = []; @@ -211,19 +243,33 @@ export function performReactRefresh(): RefreshUpdate | null { updatedFamiliesByType.set(nextType, family); family.current = nextType; - // Determine whether this should be a re-render or a re-mount. - if (canPreserveStateBetween(prevType, nextType)) { - updatedFamilies.add(family); - } else { - staleFamilies.add(family); + const shouldPreserveState = canPreserveStateBetween(prevType, nextType); + + if (typeof prevType === 'object' && prevType !== null) { + const nextFamily = { + current: + getProperty(nextType, '$$typeof') === REACT_FORWARD_REF_TYPE + ? nextType.render + : getProperty(nextType, '$$typeof') === REACT_MEMO_TYPE + ? nextType.type + : nextType, + }; + switch (getProperty(prevType, '$$typeof')) { + case REACT_FORWARD_REF_TYPE: { + updatedFamiliesByType.set(prevType.render, nextFamily); + registerRefreshUpdate(update, nextFamily, shouldPreserveState); + break; + } + case REACT_MEMO_TYPE: + updatedFamiliesByType.set(prevType.type, nextFamily); + registerRefreshUpdate(update, nextFamily, shouldPreserveState); + break; + } } - }); - // TODO: rename these fields to something more meaningful. - const update: RefreshUpdate = { - updatedFamilies, // Families that will re-render preserving state - staleFamilies, // Families that will be remounted - }; + // Determine whether this should be a re-render or a re-mount. + registerRefreshUpdate(update, family, shouldPreserveState); + }); helpersByRendererID.forEach(helpers => { // Even if there are no roots, set the handler on first update. diff --git a/packages/react-refresh/src/__tests__/ReactFresh-test.js b/packages/react-refresh/src/__tests__/ReactFresh-test.js index 6fb00a66a24b1..f5148f36040a6 100644 --- a/packages/react-refresh/src/__tests__/ReactFresh-test.js +++ b/packages/react-refresh/src/__tests__/ReactFresh-test.js @@ -699,6 +699,242 @@ describe('ReactFresh', () => { } }); + it('can remount when change function to memo', async () => { + if (__DEV__) { + await act(async () => { + await render(() => { + function Test() { + return

hi test

; + } + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + // Check the initial render + const el = container.firstChild; + expect(el.textContent).toBe('hi test'); + + // Patch to change function to memo + await act(async () => { + await patch(() => { + function Test2() { + return

hi memo

; + } + const Test = React.memo(Test2); + $RefreshReg$(Test2, 'Test2'); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + // Check remount + expect(container.firstChild).not.toBe(el); + const nextEl = container.firstChild; + expect(nextEl.textContent).toBe('hi memo'); + + // Patch back to original function + await act(async () => { + await patch(() => { + function Test() { + return

hi test

; + } + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + // Check final remount + expect(container.firstChild).not.toBe(nextEl); + const newEl = container.firstChild; + expect(newEl.textContent).toBe('hi test'); + } + }); + + it('can remount when change memo to forwardRef', async () => { + if (__DEV__) { + await act(async () => { + await render(() => { + function Test2() { + return

hi memo

; + } + const Test = React.memo(Test2); + $RefreshReg$(Test2, 'Test2'); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + // Check the initial render + const el = container.firstChild; + expect(el.textContent).toBe('hi memo'); + + // Patch to change memo to forwardRef + await act(async () => { + await patch(() => { + function Test2() { + return

hi forwardRef

; + } + const Test = React.forwardRef(Test2); + $RefreshReg$(Test2, 'Test2'); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + // Check remount + expect(container.firstChild).not.toBe(el); + const nextEl = container.firstChild; + expect(nextEl.textContent).toBe('hi forwardRef'); + + // Patch back to memo + await act(async () => { + await patch(() => { + function Test2() { + return

hi memo

; + } + const Test = React.memo(Test2); + $RefreshReg$(Test2, 'Test2'); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + // Check final remount + expect(container.firstChild).not.toBe(nextEl); + const newEl = container.firstChild; + expect(newEl.textContent).toBe('hi memo'); + } + }); + + it('can remount when change function to forwardRef', async () => { + if (__DEV__) { + await act(async () => { + await render(() => { + function Test() { + return

hi test

; + } + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + // Check the initial render + const el = container.firstChild; + expect(el.textContent).toBe('hi test'); + + // Patch to change function to forwardRef + await act(async () => { + await patch(() => { + function Test2() { + return

hi forwardRef

; + } + const Test = React.forwardRef(Test2); + $RefreshReg$(Test2, 'Test2'); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + // Check remount + expect(container.firstChild).not.toBe(el); + const nextEl = container.firstChild; + expect(nextEl.textContent).toBe('hi forwardRef'); + + // Patch back to a new function + await act(async () => { + await patch(() => { + function Test() { + return

hi test1

; + } + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + // Check final remount + expect(container.firstChild).not.toBe(nextEl); + const newEl = container.firstChild; + expect(newEl.textContent).toBe('hi test1'); + } + }); + + it('resets state when switching between different component types', async () => { + if (__DEV__) { + await act(async () => { + await render(() => { + function Test() { + const [count, setCount] = React.useState(0); + return ( +
setCount(c => c + 1)}>count: {count}
+ ); + } + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + expect(container.firstChild.textContent).toBe('count: 0'); + await act(async () => { + container.firstChild.click(); + }); + expect(container.firstChild.textContent).toBe('count: 1'); + + await act(async () => { + await patch(() => { + function Test2() { + const [count, setCount] = React.useState(0); + return ( +
setCount(c => c + 1)}>count: {count}
+ ); + } + const Test = React.memo(Test2); + $RefreshReg$(Test2, 'Test2'); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + expect(container.firstChild.textContent).toBe('count: 0'); + await act(async () => { + container.firstChild.click(); + }); + expect(container.firstChild.textContent).toBe('count: 1'); + + await act(async () => { + await patch(() => { + const Test = React.forwardRef((props, ref) => { + const [count, setCount] = React.useState(0); + const handleClick = () => setCount(c => c + 1); + + // Ensure ref is extensible + const divRef = React.useRef(null); + React.useEffect(() => { + if (ref) { + if (typeof ref === 'function') { + ref(divRef.current); + } else if (Object.isExtensible(ref)) { + ref.current = divRef.current; + } + } + }, [ref]); + + return ( +
+ count: {count} +
+ ); + }); + $RefreshReg$(Test, 'Test'); + return Test; + }); + }); + + expect(container.firstChild.textContent).toBe('count: 0'); + await act(async () => { + container.firstChild.click(); + }); + expect(container.firstChild.textContent).toBe('count: 1'); + } + }); + it('can update simple memo function in isolation', async () => { if (__DEV__) { await render(() => {