From 2a37d5047226ac2ead536a67dd32e628997a58d3 Mon Sep 17 00:00:00 2001 From: Stanley Yuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:52:15 +0800 Subject: [PATCH] chore: add chain rpc controller (#443) --- .../abstract/chain-rpc-controller.test.ts | 48 +++++++++++++++++++ .../src/rpcs/abstract/chain-rpc-controller.ts | 48 +++++++++++++++++++ .../src/state/__tests__/helper.ts | 2 +- 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.test.ts create mode 100644 packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.ts diff --git a/packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.test.ts b/packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.test.ts new file mode 100644 index 00000000..b4a4d505 --- /dev/null +++ b/packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.test.ts @@ -0,0 +1,48 @@ +import { string } from 'superstruct'; + +import { mockNetworkStateManager } from '../../state/__tests__/helper'; +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; +import { InvalidNetworkError } from '../../utils/exceptions'; +import { BaseRequestStruct } from '../../utils/superstruct'; +import { ChainRpcController } from './chain-rpc-controller'; + +describe('ChainRpcController', () => { + type Request = { chainId: string }; + class MockRpc extends ChainRpcController { + protected requestStruct = BaseRequestStruct; + + protected responseStruct = string(); + + // Set it to public to be able to spy on it + async handleRequest(params: Request) { + return `tested with ${params.chainId}`; + } + } + + it('executes request', async () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const { getNetworkSpy } = mockNetworkStateManager(network); + const { chainId } = network; + + const rpc = new MockRpc(); + const result = await rpc.execute({ + chainId, + }); + + expect(getNetworkSpy).toHaveBeenCalledWith({ chainId }); + expect(result).toBe(`tested with ${chainId}`); + }); + + it('throws `InvalidNetworkError` error if the given chainId not found.', async () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + mockNetworkStateManager(null); + const { chainId } = network; + + const rpc = new MockRpc(); + await expect( + rpc.execute({ + chainId, + }), + ).rejects.toThrow(InvalidNetworkError); + }); +}); diff --git a/packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.ts b/packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.ts new file mode 100644 index 00000000..1042f8f3 --- /dev/null +++ b/packages/starknet-snap/src/rpcs/abstract/chain-rpc-controller.ts @@ -0,0 +1,48 @@ +import type { Json } from '@metamask/snaps-sdk'; + +import { NetworkStateManager } from '../../state/network-state-manager'; +import type { Network } from '../../types/snapState'; +import { InvalidNetworkError } from '../../utils/exceptions'; +import { RpcController } from '../../utils/rpc'; + +/** + * A base class for all RPC controllers that require a chainId to be provided in the request parameters. + * + * @template Request - The expected structure of the request parameters that contains the chainId property. + * @template Response - The expected structure of the response. + * @augments RpcController - The base class for all RPC controllers. + * @class ChainRpcController + */ +export abstract class ChainRpcController< + Request extends { + chainId: string; + }, + Response extends Json, +> extends RpcController { + protected network: Network; + + protected networkStateMgr: NetworkStateManager; + + constructor() { + super(); + this.networkStateMgr = new NetworkStateManager(); + } + + protected async getNetwork(chainId: string): Promise { + const network = await this.networkStateMgr.getNetwork({ chainId }); + // if the network is not in the list of networks that we support, we throw an error + if (!network) { + throw new InvalidNetworkError() as unknown as Error; + } + + return network; + } + + protected async preExecute(params: Request): Promise { + await super.preExecute(params); + + const { chainId } = params; + + this.network = await this.getNetwork(chainId); + } +} diff --git a/packages/starknet-snap/src/state/__tests__/helper.ts b/packages/starknet-snap/src/state/__tests__/helper.ts index 6f31e09c..05044039 100644 --- a/packages/starknet-snap/src/state/__tests__/helper.ts +++ b/packages/starknet-snap/src/state/__tests__/helper.ts @@ -95,7 +95,7 @@ export const mockTransactionRequestStateManager = () => { }; }; -export const mockNetworkStateManager = (network: Network) => { +export const mockNetworkStateManager = (network: Network | null) => { const getNetworkSpy = jest.spyOn(NetworkStateManager.prototype, 'getNetwork'); getNetworkSpy.mockResolvedValue(network); return {