Skip to content

Commit

Permalink
Merge pull request #53 from crestalnetwork/feat/twitter-limit
Browse files Browse the repository at this point in the history
Feat: twitter limit
  • Loading branch information
taiyangc authored Jan 17, 2025
2 parents 72e1f9c + 5b313e2 commit 07a1f6a
Show file tree
Hide file tree
Showing 13 changed files with 195 additions and 122 deletions.
86 changes: 41 additions & 45 deletions app/admin/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,52 +46,48 @@ def create_agent(agent: Agent, db: Session = Depends(get_db)) -> Agent:
# Get the latest agent from create_or_update
latest_agent = agent.create_or_update(db)

# Send Slack notification only for new agents
# Send Slack notification
total_agents = db.exec(select(func.count()).select_from(Agent)).one()
if (
total_agents == 1
or not db.exec(select(Agent).filter(Agent.id == agent.id)).first()
):
send_slack_message(
"New agent created ",
attachments=[
{
"color": "good",
"fields": [
{"title": "ENV", "short": True, "value": config.env},
{"title": "Total", "short": True, "value": total_agents},
{"title": "ID", "short": True, "value": latest_agent.id},
{"title": "Name", "short": True, "value": latest_agent.name},
{"title": "Model", "short": True, "value": latest_agent.model},
{
"title": "Autonomous",
"short": True,
"value": str(latest_agent.autonomous_enabled),
},
{
"title": "Twitter",
"short": True,
"value": str(latest_agent.twitter_enabled),
},
{
"title": "Telegram",
"short": True,
"value": str(latest_agent.telegram_enabled),
},
{
"title": "CDP Enabled",
"short": True,
"value": str(latest_agent.cdp_enabled),
},
{
"title": "CDP Network",
"short": True,
"value": latest_agent.cdp_network_id or "Default",
},
],
}
],
)
send_slack_message(
"New agent created ",
attachments=[
{
"color": "good",
"fields": [
{"title": "ENV", "short": True, "value": config.env},
{"title": "Total", "short": True, "value": total_agents},
{"title": "ID", "short": True, "value": latest_agent.id},
{"title": "Name", "short": True, "value": latest_agent.name},
{"title": "Model", "short": True, "value": latest_agent.model},
{
"title": "Autonomous",
"short": True,
"value": str(latest_agent.autonomous_enabled),
},
{
"title": "Twitter",
"short": True,
"value": str(latest_agent.twitter_enabled),
},
{
"title": "Telegram",
"short": True,
"value": str(latest_agent.telegram_enabled),
},
{
"title": "CDP Enabled",
"short": True,
"value": str(latest_agent.cdp_enabled),
},
{
"title": "CDP Network",
"short": True,
"value": latest_agent.cdp_network_id or "Default",
},
],
}
],
)

# Mask sensitive data in response
latest_agent.cdp_wallet_data = "forbidden"
Expand Down
32 changes: 22 additions & 10 deletions app/services/twitter/oauth2_refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,38 +35,50 @@ def get_expiring_tokens(db: Session, minutes_threshold: int = 10) -> list[AgentD
).all()


def refresh_token(db: Session, agent: AgentData) -> bool:
def refresh_token(db: Session, agent: AgentData):
"""Refresh Twitter OAuth2 token for an agent.
Args:
db: Database session
agent: Agent data record containing refresh token
Returns:
bool: True if refresh successful, False otherwise
"""
try:
# Get new token using refresh token
token = oauth2_user_handler.refresh(agent.twitter_refresh_token)

token = {} if token is None else token

# Update token information
agent.twitter_access_token = token["access_token"]
if "access_token" in token:
agent.twitter_access_token = token["access_token"]
else:
agent.twitter_access_token = None
if "refresh_token" in token: # Some providers return new refresh tokens
agent.twitter_refresh_token = token["refresh_token"]
agent.twitter_access_token_expires_at = datetime.fromtimestamp(
token["expires_at"], tz=timezone.utc
)
else:
agent.twitter_refresh_token = None
if "expires_at" in token:
agent.twitter_access_token_expires_at = datetime.fromtimestamp(
token["expires_at"], tz=timezone.utc
)
else:
agent.twitter_access_token_expires_at = None

# Save changes
db.add(agent)
db.commit()
db.refresh(agent)

logger.info(f"Refreshed token for agent {agent.id}")
return True
except Exception as e:
logger.error(f"Failed to refresh token for agent {agent.id}: {str(e)}")
return False
# if error, reset token
agent.twitter_access_token = None
agent.twitter_refresh_token = None
agent.twitter_access_token_expires_at = None
db.add(agent)
db.commit()
db.refresh(agent)


def refresh_expiring_tokens():
Expand Down
35 changes: 0 additions & 35 deletions debug/create_agent.py

This file was deleted.

10 changes: 9 additions & 1 deletion docs/create_agent.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ END_OF_APPEND
# If you enable autonomous mode, the agent will automatically run the autonomous_prompt every N minutes
AUTONOMOUS_ENABLED=false
AUTONOMOUS_MINUTES=60
AUTONOMOUS_PROMPT="Autonomous mode prompt"
read -r -d '' AUTONOMOUS_PROMPT_TEXT << 'END_OF_AUTONOMOUS_PROMPT'
Check twitter for new mentions, choose the best one and reply it. If there is no mention, just have a rest, don't post anything.
END_OF_AUTONOMOUS_PROMPT

# CDP settings (optional)
# Skill list: https://docs.cdp.coinbase.com/agentkit/docs/wallet-management
Expand Down Expand Up @@ -82,6 +84,9 @@ PROMPT="$(echo "$PROMPT_TEXT" | awk '{printf "%s\\n", $0}' | sed 's/"/\\"/g' | s
# Convert multiline text to escaped string
PROMPT_APPEND="$(echo "$PROMPT_APPEND_TEXT" | awk '{printf "%s\\n", $0}' | sed 's/"/\\"/g' | sed '$ s/\\n$//')"

# Autonomous mode prompt
AUTONOMOUS_PROMPT="$(echo "$AUTONOMOUS_PROMPT_TEXT" | awk '{printf "%s\\n", $0}' | sed 's/"/\\"/g' | sed '$ s/\\n$//')"

# Create JSON payload
JSON_DATA=$(cat << EOF
{
Expand All @@ -97,6 +102,9 @@ JSON_DATA=$(cat << EOF
"cdp_skills": $CDP_SKILLS,
"cdp_wallet_data": "$CDP_WALLET_DATA",
"cdp_network_id": "$CDP_NETWORK_ID",
"enso_enabled": $ENSO_ENABLED,
"enso_config": $ENSO_CONFIG,
"enso_skills": $ENSO_SKILLS,
"twitter_enabled": $TWITTER_ENTRYPOINT_ENABLED,
"twitter_entrypoint_enabled": $TWITTER_ENTRYPOINT_ENABLED,
"twitter_config": $TWITTER_CONFIG,
Expand Down
67 changes: 55 additions & 12 deletions skills/twitter/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import Type

from pydantic import BaseModel, Field
Expand All @@ -8,17 +8,6 @@
from abstracts.twitter import TwitterABC


class Tweet(BaseModel):
"""Model representing a Twitter tweet."""

id: str
text: str
author_id: str
created_at: datetime
referenced_tweets: list[dict] | None = None
attachments: dict | None = None


class TwitterBaseTool(IntentKitSkill):
"""Base class for Twitter tools."""

Expand All @@ -31,3 +20,57 @@ class TwitterBaseTool(IntentKitSkill):
description="The agent store for persisting data"
)
store: SkillStoreABC = Field(description="The skill store for persisting data")

def check_rate_limit(
self, max_requests: int = 1, interval: int = 15
) -> tuple[bool, str | None]:
"""Check if the rate limit has been exceeded.
Args:
max_requests: Maximum number of requests allowed within the rate limit window.
interval: Time interval in minutes for the rate limit window.
Returns:
tuple[bool, str | None]: (is_rate_limited, error_message)
"""
rate_limit = self.store.get_agent_skill_data(
self.agent_id, self.name, "rate_limit"
)

current_time = datetime.now(tz=timezone.utc)

if (
rate_limit
and rate_limit.get("reset_time")
and rate_limit["count"] is not None
and datetime.fromisoformat(rate_limit["reset_time"]) > current_time
):
if rate_limit["count"] >= max_requests:
return True, "Rate limit exceeded"
else:
rate_limit["count"] += 1
self.store.save_agent_skill_data(
self.agent_id, self.name, "rate_limit", rate_limit
)
return False, None

# If no rate limit exists or it has expired, create a new one
new_rate_limit = {
"count": 1,
"reset_time": (current_time + timedelta(minutes=interval)).isoformat(),
}
self.store.save_agent_skill_data(
self.agent_id, self.name, "rate_limit", new_rate_limit
)
return False, None


class Tweet(BaseModel):
"""Model representing a Twitter tweet."""

id: str
text: str
author_id: str
created_at: datetime
referenced_tweets: list[dict] | None = None
attachments: dict | None = None
7 changes: 7 additions & 0 deletions skills/twitter/follow_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def _run(self, user_id: str) -> TwitterFollowUserOutput:
Exception: If there's an error accessing the Twitter API.
"""
try:
# Check rate limit
is_rate_limited, error = self.check_rate_limit(max_requests=1, interval=15)
if is_rate_limited:
return TwitterFollowUserOutput(
success=False, message=f"Error following user: {error}"
)

client = self.twitter.get_client()
if not client:
return TwitterFollowUserOutput(
Expand Down
9 changes: 9 additions & 0 deletions skills/twitter/get_mentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def _run(self) -> TwitterGetMentionsOutput:
Exception: If there's an error accessing the Twitter API.
"""
try:
# Check rate limit
is_rate_limited, error = self.check_rate_limit(max_requests=1)
if is_rate_limited:
return TwitterGetMentionsOutput(
mentions=[],
error=error,
)

# get since id from store
last = self.store.get_agent_skill_data(self.agent_id, self.name, "last")
last = last or {}
Expand Down Expand Up @@ -88,6 +96,7 @@ def _run(self) -> TwitterGetMentionsOutput:

result = []
if mentions.data:
# Process and return results
for tweet in mentions.data:
mention = Tweet(
id=str(tweet.id),
Expand Down
15 changes: 11 additions & 4 deletions skills/twitter/get_timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,27 @@ class TwitterGetTimeline(TwitterBaseTool):
description: str = "Get tweets from the authenticated user's timeline"
args_schema: Type[BaseModel] = TwitterGetTimelineInput

def _run(self) -> TwitterGetTimelineOutput:
"""Run the tool to get timeline tweets.
def _run(self, max_results: int = 10) -> TwitterGetTimelineOutput:
"""Run the tool to get the user's timeline.
Args:
max_results (int, optional): Maximum number of tweets to retrieve. Defaults to 10.
Returns:
TwitterGetTimelineOutput: A structured output containing the timeline tweets data.
TwitterGetTimelineOutput: A structured output containing the timeline data.
Raises:
Exception: If there's an error accessing the Twitter API.
"""
try:
# Check rate limit
is_rate_limited, error = self.check_rate_limit(max_requests=1, interval=15)
if is_rate_limited:
return TwitterGetTimelineOutput(tweets=[], error=error)

# get since id from store
last = self.store.get_agent_skill_data(self.agent_id, self.name, "last")
last = last or {}
max_results = 10
since_id = last.get("since_id")
if since_id:
max_results = 100
Expand Down
7 changes: 7 additions & 0 deletions skills/twitter/like_tweet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def _run(self, tweet_id: str) -> TwitterLikeTweetOutput:
Exception: If there's an error accessing the Twitter API.
"""
try:
# Check rate limit
is_rate_limited, error = self.check_rate_limit(max_requests=1, interval=15)
if is_rate_limited:
return TwitterLikeTweetOutput(
success=False, message=f"Error liking tweet: {error}"
)

client = self.twitter.get_client()
if not client:
return TwitterLikeTweetOutput(
Expand Down
5 changes: 5 additions & 0 deletions skills/twitter/post_tweet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def _run(self, text: str) -> str:
Exception: If there's an error posting to the Twitter API.
"""
try:
# Check rate limit
is_rate_limited, error = self.check_rate_limit(max_requests=1, interval=15)
if is_rate_limited:
return f"Error posting tweet: {error}"

client = self.twitter.get_client()
if not client:
return "Failed to get Twitter client. Please check your authentication."
Expand Down
Loading

0 comments on commit 07a1f6a

Please sign in to comment.