Skip to content

Commit

Permalink
Merge branch 'feat/enable-multiple-accounts' into feat/account-event-…
Browse files Browse the repository at this point in the history
…change-get-starknet
  • Loading branch information
khanti42 authored Jan 20, 2025
2 parents e945f78 + d26d8de commit 280bbd7
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 154 deletions.
139 changes: 26 additions & 113 deletions packages/starknet-snap/src/on-home-page.test.ts
Original file line number Diff line number Diff line change
@@ -1,79 +1,48 @@
import { ethers } from 'ethers';
import { constants } from 'starknet';

import { generateAccounts, type StarknetAccount } from './__tests__/helper';
import { HomePageController } from './on-home-page';
import type { Network, SnapState } from './types/snapState';
import { setupAccountController } from './rpcs/__tests__/helper';
import type { Network } from './types/snapState';
import {
BlockIdentifierEnum,
ETHER_MAINNET,
STARKNET_SEPOLIA_TESTNET_NETWORK,
STARKNET_MAINNET_NETWORK,
} from './utils/constants';
import * as snapHelper from './utils/snap';
import * as starknetUtils from './utils/starknetUtils';

jest.mock('./utils/snap');
jest.mock('./utils/logger');

describe('homepageController', () => {
const state: SnapState = {
accContracts: [],
erc20Tokens: [],
networks: [STARKNET_SEPOLIA_TESTNET_NETWORK],
transactions: [],
currentNetwork: STARKNET_SEPOLIA_TESTNET_NETWORK,
};

const mockAccount = async (chainId: constants.StarknetChainId) => {
return (await generateAccounts(chainId, 1))[0];
};

const mockState = async () => {
const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData');
getStateDataSpy.mockResolvedValue(state);
return {
getStateDataSpy,
};
};
const currentNetwork = STARKNET_MAINNET_NETWORK;

class MockHomePageController extends HomePageController {
async getAddress(network: Network): Promise<string> {
return super.getAddress(network);
}

async getBalance(network: Network, address: string): Promise<string> {
return super.getBalance(network, address);
}
}

describe('execute', () => {
const prepareExecuteMock = (account: StarknetAccount, balance: string) => {
const getAddressSpy = jest.spyOn(
MockHomePageController.prototype,
'getAddress',
);
const setupExecuteTest = async (network: Network, balance = '1000') => {
const { account } = await setupAccountController({ network });

const getBalanceSpy = jest.spyOn(
MockHomePageController.prototype,
'getBalance',
);
getAddressSpy.mockResolvedValue(account.address);
getBalanceSpy.mockResolvedValue(balance);

return {
getAddressSpy,
account,
getBalanceSpy,
};
};

it('returns the correct homepage response', async () => {
const { currentNetwork } = state;
await mockState();
const account = await mockAccount(
currentNetwork?.chainId as unknown as constants.StarknetChainId,
);
const balance = '100';

const { getAddressSpy, getBalanceSpy } = prepareExecuteMock(
account,
const { getBalanceSpy, account } = await setupExecuteTest(
currentNetwork,
balance,
);

Expand Down Expand Up @@ -119,20 +88,15 @@ describe('homepageController', () => {
type: 'panel',
},
});
expect(getAddressSpy).toHaveBeenCalledWith(currentNetwork);
expect(getBalanceSpy).toHaveBeenCalledWith(
currentNetwork,
account.address,
);
});

it('throws `Failed to initialize Snap HomePage` error if an error was thrown', async () => {
await mockState();
const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA);
const balance = '100';

const { getAddressSpy } = prepareExecuteMock(account, balance);
getAddressSpy.mockReset().mockRejectedValue(new Error('error'));
const { getBalanceSpy } = await setupExecuteTest(currentNetwork);
getBalanceSpy.mockReset().mockRejectedValue(new Error('error'));

const homepageController = new MockHomePageController();
await expect(homepageController.execute()).rejects.toThrow(
Expand All @@ -141,84 +105,33 @@ describe('homepageController', () => {
});
});

describe('getAddress', () => {
const prepareGetAddressMock = async (account: StarknetAccount) => {
const getKeysFromAddressSpy = jest.spyOn(
starknetUtils,
'getKeysFromAddressIndex',
);

getKeysFromAddressSpy.mockResolvedValue({
privateKey: account.privateKey,
publicKey: account.publicKey,
addressIndex: account.addressIndex,
derivationPath: account.derivationPath as unknown as any,
});

const getCorrectContractAddressSpy = jest.spyOn(
starknetUtils,
'getCorrectContractAddress',
);
getCorrectContractAddressSpy.mockResolvedValue({
address: account.address,
signerPubKey: account.publicKey,
upgradeRequired: false,
deployRequired: false,
});
return {
getKeysFromAddressSpy,
getCorrectContractAddressSpy,
};
};

it('returns the correct homepage response', async () => {
const network = STARKNET_SEPOLIA_TESTNET_NETWORK;
await mockState();
const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA);
const { getKeysFromAddressSpy, getCorrectContractAddressSpy } =
await prepareGetAddressMock(account);

const homepageController = new MockHomePageController();
const result = await homepageController.getAddress(network);

expect(result).toStrictEqual(account.address);
expect(getKeysFromAddressSpy).toHaveBeenCalledWith(
// BIP44 Deriver has mocked as undefined, hence this argument should be undefined
undefined,
network.chainId,
state,
0,
);
expect(getCorrectContractAddressSpy).toHaveBeenCalledWith(
network,
account.publicKey,
);
});
});

describe('getBalance', () => {
const prepareGetBalanceMock = async (balance: number) => {
const setupGetBalanceTest = async (network: Network, balance: number) => {
const { account } = await setupAccountController({ network });

const getBalanceSpy = jest.spyOn(starknetUtils, 'getBalance');

getBalanceSpy.mockResolvedValue(balance.toString(16));

return {
account,
getBalanceSpy,
};
};

it('returns the balance on pending block', async () => {
const network = STARKNET_SEPOLIA_TESTNET_NETWORK;
const token = ETHER_MAINNET;
const expectedBalance = 100;
await mockState();
const { address } = await mockAccount(
constants.StarknetChainId.SN_SEPOLIA,
const { getBalanceSpy, account } = await setupGetBalanceTest(
currentNetwork,
expectedBalance,
);
const { getBalanceSpy } = await prepareGetBalanceMock(expectedBalance);

const homepageController = new MockHomePageController();
const result = await homepageController.getBalance(network, address);
const result = await homepageController.getBalance(
currentNetwork,
account.address,
);

expect(result).toStrictEqual(
ethers.utils.formatUnits(
Expand All @@ -227,9 +140,9 @@ describe('homepageController', () => {
),
);
expect(getBalanceSpy).toHaveBeenCalledWith(
address,
account.address,
token.address,
network,
currentNetwork,
BlockIdentifierEnum.Pending,
);
});
Expand Down
50 changes: 14 additions & 36 deletions packages/starknet-snap/src/on-home-page.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,11 @@ import {
import { ethers } from 'ethers';

import { NetworkStateManager } from './state/network-state-manager';
import type { Network, SnapState } from './types/snapState';
import {
getBip44Deriver,
getDappUrl,
getStateData,
logger,
toJson,
} from './utils';
import type { Network } from './types/snapState';
import { getDappUrl, logger, toJson } from './utils';
import { BlockIdentifierEnum, ETHER_MAINNET } from './utils/constants';
import {
getBalance,
getCorrectContractAddress,
getKeysFromAddressIndex,
} from './utils/starknetUtils';

import { createAccountService } from './utils/factory';
import { getBalance } from './utils/starknetUtils';
/**
* The onHomePage handler to execute the home page event operation.
*/
Expand All @@ -37,20 +27,24 @@ export class HomePageController {

/**
* Execute the on home page event operation.
* It derives an account address with index 0 and retrieves the spendable balance of ETH.
* It returns a snap panel component with the address, network, and balance.
* It returns the component that contains the address, network, and balance for the current account.
*
* @returns A promise that resolve to a OnHomePageResponse object.
* @returns A promise that resolve to a `OnHomePageResponse` object.
*/
async execute(): Promise<OnHomePageResponse> {
try {
const network = await this.networkStateMgr.getCurrentNetwork();

const address = await this.getAddress(network);
const accountService = createAccountService(network);

const balance = await this.getBalance(network, address);
const account = await accountService.getCurrentAccount();

return this.buildComponents(address, network, balance);
const balance = await this.getBalance(network, account.address);

// FIXME: The SNAP UI render method in buildComponents is deprecated,
// However, there is some tricky issue when using JSX components here,
// so we will keep using the deprecated method for now.
return this.buildComponents(account.address, network, balance);
} catch (error) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
logger.error('Failed to execute onHomePage', toJson(error));
Expand All @@ -59,22 +53,6 @@ export class HomePageController {
}
}

protected async getAddress(network: Network): Promise<string> {
const deriver = await getBip44Deriver();
const state = await getStateData<SnapState>();

const { publicKey } = await getKeysFromAddressIndex(
deriver,
network.chainId,
state,
0,
);

const { address } = await getCorrectContractAddress(network, publicKey);

return address;
}

protected async getBalance(
network: Network,
address: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ export const AccountSwitchModalView = ({
starkName,
}: Props) => {
const networks = useAppSelector((state) => state.networks);
const { switchAccount, addNewAccount } = useStarkNetSnap();
const { switchAccount, initWalletData, addNewAccount } = useStarkNetSnap();
const chainId = networks?.items[networks.activeNetwork]?.chainId;

const changeAccount = async (currentAddress: string) => {
await switchAccount(chainId, currentAddress);
const account = await switchAccount(chainId, currentAddress);
await initWalletData({
account,
chainId,
});
};

return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export const PopperTooltipView = ({
});

return (
<>
<div style={{ zIndex: 1 }}>
<Wrapper
ref={setTriggerRef}
onClick={handleOnClick}
Expand All @@ -91,6 +91,6 @@ export const PopperTooltipView = ({
<ToolTipContent style={contentStyle}>{content}</ToolTipContent>
</PopperContainer>
)}
</>
</div>
);
};
2 changes: 1 addition & 1 deletion packages/wallet-ui/src/services/useStarkNetSnap.ts
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ export const useStarkNetSnap = () => {
);
try {
const account = await invokeSnap<Account>({
method: 'starkNet_switchAccount',
method: 'starkNet_swtichAccount',
params: {
chainId,
address,
Expand Down

0 comments on commit 280bbd7

Please sign in to comment.