Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add drift tools to langchain #212

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/langchain/drift/create_user_account.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaCreateDriftUserAccountTool extends Tool {
name = "create_drift_user_account";
description = `Create a new user account with a deposit on Drift protocol.

Inputs (JSON string):
- amount: number, amount of the token to deposit (required)
- symbol: string, symbol of the token to deposit (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const res = await this.solanaKit.createDriftUserAccount(
parsedInput.amount,
parsedInput.symbol,
);

return JSON.stringify({
status: "success",
message: `User account created with ${parsedInput.amount} ${parsedInput.symbol} successfully deposited`,
account: res.account,
signature: res.txSignature,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CREATE_DRIFT_USER_ACCOUNT_ERROR",
});
}
}
}
42 changes: 42 additions & 0 deletions src/langchain/drift/create_vault.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaCreateDriftVaultTool extends Tool {
name = "create_drift_vault";
description = `Create a new drift vault delegating the agents address as the owner.

Inputs (JSON string):
- name: string, unique vault name (min 5 chars)
- marketName: string, market name in TOKEN-SPOT format
- redeemPeriod: number, days to wait before funds can be redeemed (min 1)
- maxTokens: number, maximum tokens vault can accommodate (min 100)
- minDepositAmount: number, minimum deposit amount
- managementFee: number, fee percentage for managing funds (max 20)
- profitShare: number, profit sharing percentage (max 90, default 5)
- hurdleRate: number, optional hurdle rate
- permissioned: boolean, whether vault has whitelist`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.createDriftVault(parsedInput);

return JSON.stringify({
status: "success",
message: "Drift vault created successfully",
vaultName: parsedInput.name,
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CREATE_DRIFT_VAULT_ERROR",
});
}
}
}
37 changes: 37 additions & 0 deletions src/langchain/drift/deposit_into_vault.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDepositIntoDriftVaultTool extends Tool {
name = "deposit_into_drift_vault";
description = `Deposit funds into an existing drift vault.

Inputs (JSON string):
- vaultAddress: string, address of the vault (required)
- amount: number, amount to deposit (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.depositIntoDriftVault(
parsedInput.amount,
parsedInput.vaultAddress,
);

return JSON.stringify({
status: "success",
message: "Funds deposited successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DEPOSIT_INTO_VAULT_ERROR",
});
}
}
}
39 changes: 39 additions & 0 deletions src/langchain/drift/deposit_to_user_account.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDepositToDriftUserAccountTool extends Tool {
name = "deposit_to_drift_user_account";
description = `Deposit funds into your drift user account.

Inputs (JSON string):
- amount: number, amount to deposit (required)
- symbol: string, token symbol (required)
- repay: boolean, whether to repay borrowed funds (optional, default: false)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.depositToDriftUserAccount(
parsedInput.amount,
parsedInput.symbol,
parsedInput.repay,
);

return JSON.stringify({
status: "success",
message: "Funds deposited successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DEPOSIT_TO_DRIFT_ACCOUNT_ERROR",
});
}
}
}
32 changes: 32 additions & 0 deletions src/langchain/drift/derive_vault_address.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDeriveVaultAddressTool extends Tool {
name = "derive_drift_vault_address";
description = `Derive a drift vault address from the vault's name.

Inputs (JSON string):
- name: string, name of the vault to derive the address of (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const address = await this.solanaKit.deriveDriftVaultAddress(input);

return JSON.stringify({
status: "success",
message: "Vault address derived successfully",
address,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DERIVE_VAULT_ADDRESS_ERROR",
});
}
}
}
38 changes: 38 additions & 0 deletions src/langchain/drift/does_user_have_drift_account.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaCheckDriftAccountTool extends Tool {
name = "does_user_have_drift_account";
description = `Check if a user has a Drift account.

Inputs: No inputs required - checks the current user's account`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(_input: string): Promise<string> {
try {
const res = await this.solanaKit.doesUserHaveDriftAccount();

if (!res.hasAccount) {
return JSON.stringify({
status: "error",
message: "You do not have a Drift account",
});
}

return JSON.stringify({
status: "success",
message: "Nice! You have a Drift account",
account: res.account,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CHECK_DRIFT_ACCOUNT_ERROR",
});
}
}
}
29 changes: 29 additions & 0 deletions src/langchain/drift/drift_user_account_info.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaDriftUserAccountInfoTool extends Tool {
name = "drift_user_account_info";
description = `Get information about your drift account.

Inputs: No inputs required - retrieves current user's account info`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(_input: string): Promise<string> {
try {
const accountInfo = await this.solanaKit.driftUserAccountInfo();
return JSON.stringify({
status: "success",
data: accountInfo,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DRIFT_ACCOUNT_INFO_ERROR",
});
}
}
}
15 changes: 15 additions & 0 deletions src/langchain/drift/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export * from "./create_user_account";
export * from "./create_vault";
export * from "./deposit_into_vault";
export * from "./deposit_to_user_account";
export * from "./derive_vault_address";
export * from "./does_user_have_drift_account";
export * from "./drift_user_account_info";
export * from "./request_withdrawal";
export * from "./trade_delegated_vault";
export * from "./trade_perp_account";
export * from "./update_drift_vault_delegate";
export * from "./update_vault";
export * from "./vault_info";
export * from "./withdraw_from_account";
export * from "./withdraw_from_vault";
37 changes: 37 additions & 0 deletions src/langchain/drift/request_withdrawal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaRequestDriftWithdrawalTool extends Tool {
name = "request_withdrawal_from_drift_vault";
description = `Request a withdrawal from an existing drift vault.

Inputs (JSON string):
- vaultAddress: string, vault address (required)
- amount: number, amount of shares to withdraw (required)`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.requestWithdrawalFromDriftVault(
parsedInput.amount,
parsedInput.vaultAddress,
);

return JSON.stringify({
status: "success",
message: "Withdrawal request successful",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "REQUEST_DRIFT_WITHDRAWAL_ERROR",
});
}
}
}
49 changes: 49 additions & 0 deletions src/langchain/drift/trade_delegated_vault.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";

export class SolanaTradeDelegatedDriftVaultTool extends Tool {
name = "trade_delegated_drift_vault";
description = `Carry out trades in a Drift vault.

Inputs (JSON string):
- vaultAddress: string, address of the Drift vault
- amount: number, amount to trade
- symbol: string, symbol of the token to trade
- action: "long" | "short", trade direction
- type: "market" | "limit", order type
- price: number, optional limit price`;

constructor(private solanaKit: SolanaAgentKit) {
super();
}

protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.tradeUsingDelegatedDriftVault(
parsedInput.vaultAddress,
parsedInput.amount,
parsedInput.symbol,
parsedInput.action,
parsedInput.type,
parsedInput.price,
);

return JSON.stringify({
status: "success",
message:
parsedInput.type === "limit"
? "Order placed successfully"
: "Trade successful",
transactionId: tx,
...parsedInput,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "TRADE_DRIFT_VAULT_ERROR",
});
}
}
}
Loading
Loading