Skip to content

Commit

Permalink
Merge pull request #234 from crestalnetwork/fix/wallet-create
Browse files Browse the repository at this point in the history
fix: auto create cdp wallet when create agent
  • Loading branch information
hyacinthus authored Feb 13, 2025
2 parents 5bd783c + 0e77527 commit 7256f5d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
37 changes: 28 additions & 9 deletions app/admin/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import json

from cdp import Wallet
from cdp.cdp import Cdp
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.exc import SQLAlchemyError
Expand Down Expand Up @@ -48,11 +52,26 @@ async def create_agent(
agent.owner = subject

# Get the latest agent from create_or_update
latest_agent = await agent.create_or_update()

message = "Agent created"
if latest_agent.created_at != latest_agent.updated_at:
message = "Agent updated"
latest_agent, is_new = await agent.create_or_update()

if is_new:
message = "Agent Created"
# create the wallet
Cdp.configure(
api_key_name=config.cdp_api_key_name,
private_key=config.cdp_api_key_private_key,
)
wallet = Wallet.create(network_id=latest_agent.cdp_network_id)
wallet_data = wallet.export_data().to_dict()
wallet_data["default_address_id"] = wallet.default_address.address_id
agent_data = AgentData(
id=latest_agent.id, cdp_wallet_data=json.dumps(wallet_data)
)
await agent_data.save()
else:
message = "Agent Updated"
agent_data = await AgentData.get(latest_agent.id)
wallet_data = json.loads(agent_data.cdp_wallet_data) if agent_data else {}
# Send Slack notification
send_slack_message(
message,
Expand Down Expand Up @@ -104,20 +123,20 @@ async def create_agent(
"title": "Twitter Skills",
"value": str(latest_agent.twitter_skills),
},
{
"title": "CDP Wallet Address",
"value": wallet_data.get("default_address_id"),
},
],
}
],
)

# Mask sensitive data in response
latest_agent.cdp_wallet_data = "forbidden"
if latest_agent.skill_sets is not None:
for key in latest_agent.skill_sets:
latest_agent.skill_sets[key] = {}

# Get agent data
agent_data = await AgentData.get(latest_agent.id)

# Convert to AgentResponse
return AgentResponse.from_agent(latest_agent, agent_data)

Expand Down
3 changes: 0 additions & 3 deletions app/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ async def initialize_agent(aid):
}
if agent_data and agent_data.cdp_wallet_data:
values["cdp_wallet_data"] = agent_data.cdp_wallet_data
elif agent.cdp_wallet_data:
# If there is a persisted agentic wallet, load it and pass to the CDP Agentkit Wrapper.
values["cdp_wallet_data"] = agent.cdp_wallet_data
agentkit = CdpAgentkitWrapper(**values)
# save the wallet after first create
if not agent_data or not agent_data.cdp_wallet_data:
Expand Down
12 changes: 3 additions & 9 deletions models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ class Agent(SQLModel, table=True):
cdp_network_id: Optional[str] = Field(
default="base-mainnet", description="Network identifier for CDP integration"
)
cdp_wallet_data: SkipJsonSchema[Optional[str]] = Field(
default=None, description="Deprecated: CDP wallet information"
)
# if twitter_enabled, the twitter_entrypoint will be enabled, twitter_config will be checked
twitter_entrypoint_enabled: Optional[bool] = Field(
default=False, description="Whether the agent can receive events from Twitter"
Expand Down Expand Up @@ -189,12 +186,9 @@ async def get(cls, agent_id: str) -> "Agent | None":
async with get_session() as db:
return (await db.exec(select(Agent).where(Agent.id == agent_id))).first()

async def create_or_update(self) -> "Agent":
async def create_or_update(self) -> ("Agent", bool):
"""Create the agent if not exists, otherwise update it.
Args:
db: Database session
Returns:
Agent: The created or updated agent
Expand Down Expand Up @@ -252,7 +246,7 @@ async def create_or_update(self) -> "Agent":
db.add(existing_agent)
await db.commit()
await db.refresh(existing_agent)
return existing_agent
return existing_agent, False
else:
# Check upstream_id for idempotent
async with get_session() as db:
Expand All @@ -273,7 +267,7 @@ async def create_or_update(self) -> "Agent":
db.add(self)
await db.commit()
await db.refresh(self)
return self
return self, True
except HTTPException:
await db.rollback()
raise
Expand Down

0 comments on commit 7256f5d

Please sign in to comment.