From 373aa64a81f032ed703b806386b0f407316c0ad8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C5=A0vanda?= <46406259+Papooch@users.noreply.github.com> Date: Fri, 16 Feb 2024 23:43:52 +0100 Subject: [PATCH] feat: add option to inject tx directly as Proxy --- packages/transactional/src/index.ts | 2 + .../src/lib/inject-transaction.decorator.ts | 23 ++ packages/transactional/src/lib/interfaces.ts | 14 +- .../src/lib/plugin-transactional.ts | 37 +++- packages/transactional/src/lib/symbols.ts | 8 +- .../transactional/src/lib/transaction-host.ts | 26 ++- .../test/inject-transaction.spec.ts | 203 ++++++++++++++++++ .../test/transaction-adapter-mock.ts | 4 +- 8 files changed, 292 insertions(+), 25 deletions(-) create mode 100644 packages/transactional/src/lib/inject-transaction.decorator.ts create mode 100644 packages/transactional/test/inject-transaction.spec.ts diff --git a/packages/transactional/src/index.ts b/packages/transactional/src/index.ts index f1fb0ebb..08d855bc 100644 --- a/packages/transactional/src/index.ts +++ b/packages/transactional/src/index.ts @@ -2,9 +2,11 @@ export * from './lib/transaction-host'; export * from './lib/transactional.decorator'; export * from './lib/plugin-transactional'; export * from './lib/propagation'; +export * from './lib/inject-transaction.decorator'; export { TransactionalAdapterOptions, TransactionalOptionsAdapterFactory, TransactionalAdapter, TransactionalPluginOptions, + Transaction, } from './lib/interfaces'; diff --git a/packages/transactional/src/lib/inject-transaction.decorator.ts b/packages/transactional/src/lib/inject-transaction.decorator.ts new file mode 100644 index 00000000..0b31a623 --- /dev/null +++ b/packages/transactional/src/lib/inject-transaction.decorator.ts @@ -0,0 +1,23 @@ +import { Inject } from '@nestjs/common'; +const TRANSACTION_TOKEN = Symbol('TRANSACTION_TOKEN'); + +/** + * Get injection token for the Transaction instance. + * If name is omitted, the default instance is used. + */ +export function getTransactionToken(connectionName?: string) { + return connectionName + ? Symbol.for(`${TRANSACTION_TOKEN.description}_${connectionName}`) + : TRANSACTION_TOKEN; +} + +/** + * Inject the Transaction instance directly (that is the `tx` property of the TransactionHost) + * + * Optionally, you can provide a connection name to inject a named instance. + * + * A shorthand for `Inject(getTransactionToken(connectionName))` + */ +export function InjectTransaction(connectionName?: string) { + return Inject(getTransactionToken(connectionName)); +} diff --git a/packages/transactional/src/lib/interfaces.ts b/packages/transactional/src/lib/interfaces.ts index 0b054e1f..fc3dcba6 100644 --- a/packages/transactional/src/lib/interfaces.ts +++ b/packages/transactional/src/lib/interfaces.ts @@ -7,9 +7,10 @@ export interface TransactionalAdapterOptions { getFallbackInstance: () => TTx; } -export interface TransactionalAdapterOptionsWithName +export interface MergedTransactionalAdapterOptions extends TransactionalAdapterOptions { - connectionName: string; + connectionName: string | undefined; + enableTransactionProxy: boolean; } export type TransactionalOptionsAdapterFactory = ( @@ -49,6 +50,12 @@ export interface TransactionalPluginOptions { * An optional name of the connection. Useful when there are multiple TransactionalPlugins registered in the app. */ connectionName?: string; + /** + * Whether to enable injecting the Transaction instance directly using `@InjectTransaction()` + * + * Default: `true` + */ + enableTransactionProxy?: boolean; } export type TTxFromAdapter = TAdapter extends TransactionalAdapter< @@ -63,3 +70,6 @@ export type TOptionsFromAdapter = TAdapter extends TransactionalAdapter ? TOptions : never; + +export type Transaction> = + TTxFromAdapter; diff --git a/packages/transactional/src/lib/plugin-transactional.ts b/packages/transactional/src/lib/plugin-transactional.ts index a7f1b52f..25a033a8 100644 --- a/packages/transactional/src/lib/plugin-transactional.ts +++ b/packages/transactional/src/lib/plugin-transactional.ts @@ -1,6 +1,10 @@ import { Provider } from '@nestjs/common'; -import { ClsPlugin } from 'nestjs-cls'; -import { TransactionalPluginOptions } from './interfaces'; +import { ClsModule, ClsPlugin } from 'nestjs-cls'; +import { getTransactionToken } from './inject-transaction.decorator'; +import { + MergedTransactionalAdapterOptions, + TransactionalPluginOptions, +} from './interfaces'; import { TRANSACTIONAL_ADAPTER_OPTIONS, TRANSACTION_CONNECTION, @@ -10,14 +14,14 @@ import { getTransactionHostToken, TransactionHost } from './transaction-host'; export class ClsPluginTransactional implements ClsPlugin { name: string; providers: Provider[]; - imports?: any[]; - exports?: any[]; + imports: any[] = []; + exports: any[] = []; constructor(options: TransactionalPluginOptions) { this.name = options.connectionName ? `cls-plugin-transactional-${options.connectionName}` : 'cls-plugin-transactional'; - this.imports = options.imports; + this.imports.push(...(options.imports ?? [])); const transactionHostToken = getTransactionHostToken( options.connectionName, ); @@ -29,12 +33,16 @@ export class ClsPluginTransactional implements ClsPlugin { { provide: TRANSACTIONAL_ADAPTER_OPTIONS, inject: [TRANSACTION_CONNECTION], - useFactory: (connection: any) => { + useFactory: ( + connection: any, + ): MergedTransactionalAdapterOptions => { const adapterOptions = options.adapter.optionsFactory(connection); return { ...adapterOptions, connectionName: options.connectionName, + enableTransactionProxy: + options.enableTransactionProxy ?? false, }; }, }, @@ -43,6 +51,21 @@ export class ClsPluginTransactional implements ClsPlugin { useClass: TransactionHost, }, ]; - this.exports = [transactionHostToken]; + this.exports.push(transactionHostToken); + + if (options.enableTransactionProxy) { + const transactionProxyToken = getTransactionToken( + options.connectionName, + ); + this.imports.push( + ClsModule.forFeatureAsync({ + provide: transactionProxyToken, + inject: [transactionHostToken], + useFactory: (txHost: TransactionHost) => txHost.tx, + type: 'function', + global: true, + }), + ); + } } } diff --git a/packages/transactional/src/lib/symbols.ts b/packages/transactional/src/lib/symbols.ts index 008a72e4..9fa30857 100644 --- a/packages/transactional/src/lib/symbols.ts +++ b/packages/transactional/src/lib/symbols.ts @@ -1,9 +1,9 @@ export const TRANSACTION_CONNECTION = Symbol('TRANSACTION_CONNECTION'); export const TRANSACTIONAL_ADAPTER_OPTIONS = Symbol('TRANSACTIONAL_OPTIONS'); -const TRANSACTIONAL_INSTANCE = Symbol('TRANSACTIONAL_CLIENT'); +const TRANSACTION_CLS_KEY = Symbol('TRANSACTION_CLS_KEY'); -export const getTransactionalInstanceSymbol = (connectionName?: string) => +export const getTransactionClsKey = (connectionName?: string) => connectionName - ? Symbol.for(`${TRANSACTIONAL_INSTANCE.toString()}_${connectionName}`) - : TRANSACTIONAL_INSTANCE; + ? Symbol.for(`${TRANSACTION_CLS_KEY.description}_${connectionName}`) + : TRANSACTION_CLS_KEY; diff --git a/packages/transactional/src/lib/transaction-host.ts b/packages/transactional/src/lib/transaction-host.ts index 3b9a80f3..69445e49 100644 --- a/packages/transactional/src/lib/transaction-host.ts +++ b/packages/transactional/src/lib/transaction-host.ts @@ -1,8 +1,9 @@ import { Inject, Injectable, Logger } from '@nestjs/common'; import { ClsServiceManager } from 'nestjs-cls'; +import { getTransactionToken } from './inject-transaction.decorator'; import { TOptionsFromAdapter, - TransactionalAdapterOptionsWithName, + MergedTransactionalAdapterOptions, TTxFromAdapter, } from './interfaces'; import { @@ -11,25 +12,22 @@ import { TransactionNotActiveError, TransactionPropagationError, } from './propagation'; -import { - getTransactionalInstanceSymbol, - TRANSACTIONAL_ADAPTER_OPTIONS, -} from './symbols'; +import { getTransactionClsKey, TRANSACTIONAL_ADAPTER_OPTIONS } from './symbols'; @Injectable() export class TransactionHost { private readonly cls = ClsServiceManager.getClsService(); private readonly logger = new Logger(TransactionHost.name); - private readonly transactionalInstanceSymbol: symbol; + private readonly transactionInstanceSymbol: symbol; constructor( @Inject(TRANSACTIONAL_ADAPTER_OPTIONS) - private readonly _options: TransactionalAdapterOptionsWithName< + private readonly _options: MergedTransactionalAdapterOptions< TTxFromAdapter, TOptionsFromAdapter >, ) { - this.transactionalInstanceSymbol = getTransactionalInstanceSymbol( + this.transactionInstanceSymbol = getTransactionClsKey( this._options.connectionName, ); } @@ -47,7 +45,7 @@ export class TransactionHost { return this._options.getFallbackInstance(); } return ( - this.cls.get(this.transactionalInstanceSymbol) ?? + this.cls.get(this.transactionInstanceSymbol) ?? this._options.getFallbackInstance() ); } @@ -216,11 +214,17 @@ export class TransactionHost { if (!this.cls.isActive()) { return false; } - return !!this.cls.get(this.transactionalInstanceSymbol); + return !!this.cls.get(this.transactionInstanceSymbol); } private setTxInstance(txInstance?: TTxFromAdapter) { - this.cls.set(this.transactionalInstanceSymbol, txInstance); + this.cls.set(this.transactionInstanceSymbol, txInstance); + if (this._options.enableTransactionProxy) { + this.cls.setProxy( + getTransactionToken(this._options.connectionName), + txInstance, + ); + } } } diff --git a/packages/transactional/test/inject-transaction.spec.ts b/packages/transactional/test/inject-transaction.spec.ts new file mode 100644 index 00000000..b5bd3d05 --- /dev/null +++ b/packages/transactional/test/inject-transaction.spec.ts @@ -0,0 +1,203 @@ +import { Injectable, Module } from '@nestjs/common'; +import { Test, TestingModule } from '@nestjs/testing'; +import { ClsModule, UseCls } from 'nestjs-cls'; +import { + ClsPluginTransactional, + InjectTransaction, + InjectTransactionHost, + Transaction, + Transactional, + TransactionHost, +} from '../src'; +import { + MockDbConnection, + TransactionAdapterMock, +} from './transaction-adapter-mock'; + +class CalledService { + constructor(private readonly tx: Transaction) {} + + async doWork(num: number) { + return this.tx.query(`SELECT ${num}`); + } +} + +@Injectable() +class CalledService1 extends CalledService { + constructor( + @InjectTransaction('named-connection') + txHost: Transaction, + ) { + super(txHost); + } +} + +@Injectable() +class CalledService2 extends CalledService { + constructor( + @InjectTransaction() + txHost: Transaction, + ) { + super(txHost); + } +} + +@Injectable() +class CallingService { + constructor( + private readonly calledService1: CalledService1, + private readonly calledService2: CalledService2, + @InjectTransactionHost('named-connection') + private readonly txHost1: TransactionHost, + @InjectTransactionHost() + private readonly txHost2: TransactionHost, + ) {} + + @UseCls() + async twoUnrelatedTransactionsWithDecorators() { + const [q1, q2] = await Promise.all([ + this.nestedStartTransaction1(1), + this.nestedStartTransaction2(2), + ]); + return { q1, q2 }; + } + + @Transactional('named-connection') + private async nestedStartTransaction1(num: number) { + return this.calledService1.doWork(num); + } + + // @UseCls() + @Transactional() + private async nestedStartTransaction2(num: number) { + return this.calledService2.doWork(num); + } + + // @UseCls() + async twoUnrelatedTransactionsWithStartTransaction() { + const [q1, q2] = await Promise.all([ + this.txHost1.withTransaction(() => this.calledService1.doWork(3)), + this.txHost2.withTransaction(() => this.calledService2.doWork(4)), + ]); + return { q1, q2 }; + } + + @UseCls() + @Transactional('named-connection') + async namedTransactionWithinAnotherNamedTransaction() { + const q1 = await this.calledService1.doWork(5); + const q2 = await this.calledService2.doWork(6); + const q3 = await this.nestedStartTransaction2(7); + return { q1, q2, q3 }; + } +} + +class MockDbConnection2 extends MockDbConnection {} +class MockDbConnection1 extends MockDbConnection {} + +@Module({ + providers: [MockDbConnection1], + exports: [MockDbConnection1], +}) +class DbConnectionModule1 {} + +@Module({ + providers: [MockDbConnection2], + exports: [MockDbConnection2], +}) +class DbConnectionModule2 {} + +@Module({ + imports: [ + ClsModule.forRoot({ + plugins: [ + new ClsPluginTransactional({ + connectionName: 'named-connection', + enableTransactionProxy: true, + imports: [DbConnectionModule1], + adapter: new TransactionAdapterMock({ + connectionToken: MockDbConnection1, + }), + }), + new ClsPluginTransactional({ + enableTransactionProxy: true, + imports: [DbConnectionModule2], + adapter: new TransactionAdapterMock({ + connectionToken: MockDbConnection2, + }), + }), + ], + }), + ], + providers: [CallingService, CalledService1, CalledService2], +}) +class AppModule {} + +describe('InjectTransaction with multiple named connections', () => { + let module: TestingModule; + let callingService: CallingService; + let mockDbConnection1: MockDbConnection; + let mockDbConnection2: MockDbConnection; + beforeEach(async () => { + module = await Test.createTestingModule({ + imports: [AppModule], + }).compile(); + await module.init(); + callingService = module.get(CallingService); + mockDbConnection1 = module.get(MockDbConnection1); + mockDbConnection2 = module.get(MockDbConnection2); + }); + + describe('when using the @Transactional decorator', () => { + it('should start two transactions independently with decorator', async () => { + const result = + await callingService.twoUnrelatedTransactionsWithDecorators(); + expect(result).toEqual({ + q1: { query: 'SELECT 1' }, + q2: { query: 'SELECT 2' }, + }); + const queries1 = mockDbConnection1.getClientsQueries(); + expect(queries1).toEqual([ + ['BEGIN TRANSACTION;', 'SELECT 1', 'COMMIT TRANSACTION;'], + ]); + const queries2 = mockDbConnection2.getClientsQueries(); + expect(queries2).toEqual([ + ['BEGIN TRANSACTION;', 'SELECT 2', 'COMMIT TRANSACTION;'], + ]); + }); + it('should start two transactions independently with startTransaction', async () => { + const result = + await callingService.twoUnrelatedTransactionsWithStartTransaction(); + expect(result).toEqual({ + q1: { query: 'SELECT 3' }, + q2: { query: 'SELECT 4' }, + }); + const queries1 = mockDbConnection1.getClientsQueries(); + expect(queries1).toEqual([ + ['BEGIN TRANSACTION;', 'SELECT 3', 'COMMIT TRANSACTION;'], + ]); + const queries2 = mockDbConnection2.getClientsQueries(); + expect(queries2).toEqual([ + ['BEGIN TRANSACTION;', 'SELECT 4', 'COMMIT TRANSACTION;'], + ]); + }); + it('ignore transactions for other named connection', async () => { + const result = + await callingService.namedTransactionWithinAnotherNamedTransaction(); + expect(result).toEqual({ + q1: { query: 'SELECT 5' }, + q2: { query: 'SELECT 6' }, + q3: { query: 'SELECT 7' }, + }); + const queries1 = mockDbConnection1.getClientsQueries(); + expect(queries1).toEqual([ + ['BEGIN TRANSACTION;', 'SELECT 5', 'COMMIT TRANSACTION;'], + ]); + const queries2 = mockDbConnection2.getClientsQueries(); + expect(queries2).toEqual([ + ['SELECT 6'], + ['BEGIN TRANSACTION;', 'SELECT 7', 'COMMIT TRANSACTION;'], + ]); + }); + }); +}); diff --git a/packages/transactional/test/transaction-adapter-mock.ts b/packages/transactional/test/transaction-adapter-mock.ts index 685ee8ab..3193fefb 100644 --- a/packages/transactional/test/transaction-adapter-mock.ts +++ b/packages/transactional/test/transaction-adapter-mock.ts @@ -22,7 +22,9 @@ export class MockDbConnection { } getClientsQueries() { - return this.clients.map((c) => c.operations); + return this.clients + .map((c) => c.operations) + .filter((o) => o.length > 0); } }