-
Notifications
You must be signed in to change notification settings - Fork 576
/
Copy pathskill.py
275 lines (228 loc) · 9 KB
/
skill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, NotRequired, Optional, TypedDict
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import Column, DateTime, String, delete, func, select
from sqlalchemy.dialects.postgresql import JSONB
from models.base import Base
from models.db import get_session
class SkillConfig(TypedDict):
"""Abstract base class for skill configuration."""
public_skills: List[str]
private_skills: NotRequired[List[str]]
__extra__: NotRequired[Dict[str, Any]]
class AgentSkillDataTable(Base):
    """ORM mapping for the ``agent_skill_data`` table.

    Each row stores one JSON value for an (agent_id, skill, key) triple.
    Those three columns together form the composite primary key.
    """

    __tablename__ = "agent_skill_data"

    # Composite primary key: (agent_id, skill, key).
    agent_id = Column(String, primary_key=True)
    skill = Column(String, primary_key=True)
    key = Column(String, primary_key=True)

    # JSON payload; the column itself permits NULL.
    data = Column(JSONB, nullable=True)

    # Set by the database server when the row is inserted.
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Server default on insert; on ORM-issued UPDATEs the application
    # supplies the current UTC time instead.
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=lambda: datetime.now(timezone.utc),
        server_default=func.now(),
        nullable=False,
    )
class AgentSkillDataCreate(BaseModel):
    """Input model for creating or updating one row of agent skill data.

    ``from_attributes=True`` allows instances to be validated directly from
    ORM objects.
    """

    model_config = ConfigDict(from_attributes=True)

    agent_id: Annotated[str, Field(description="ID of the agent this data belongs to")]
    skill: Annotated[str, Field(description="Name of the skill this data is for")]
    key: Annotated[str, Field(description="Key for this specific piece of data")]
    data: Annotated[Dict[str, Any], Field(description="JSON data stored for this key")]

    async def save(self) -> "AgentSkillData":
        """Insert this record, or overwrite ``data`` on the existing row.

        The row is looked up by its (agent_id, skill, key) primary key. If it
        exists, only its ``data`` column is replaced. Otherwise a new row is
        added. The transaction is then committed and the row refreshed so that
        database-generated values such as the timestamps are loaded.

        Returns:
            AgentSkillData: The persisted row, including timestamps.
        """
        lookup = select(AgentSkillDataTable).where(
            AgentSkillDataTable.agent_id == self.agent_id,
            AgentSkillDataTable.skill == self.skill,
            AgentSkillDataTable.key == self.key,
        )
        async with get_session() as db:
            existing = await db.scalar(lookup)
            if not existing:
                # No row for this key yet: insert a new one from all fields.
                existing = AgentSkillDataTable(**self.model_dump())
                db.add(existing)
            else:
                # Row already present: only the payload is overwritten.
                existing.data = self.data
            await db.commit()
            await db.refresh(existing)
            return AgentSkillData.model_validate(existing)
class AgentSkillData(AgentSkillDataCreate):
    """Persisted agent skill data, including row timestamps.

    Rows are identified by the composite primary key (agent_id, skill, key).
    Each skill can therefore keep any number of independent JSON values per
    agent.
    """

    model_config = ConfigDict(
        from_attributes=True,
        json_encoders={datetime: lambda v: v.isoformat(timespec="milliseconds")},
    )

    created_at: Annotated[
        datetime, Field(description="Timestamp when this data was created")
    ]
    updated_at: Annotated[
        datetime, Field(description="Timestamp when this data was updated")
    ]

    @classmethod
    async def get(cls, agent_id: str, skill: str, key: str) -> Optional[dict]:
        """Fetch the stored JSON value for one (agent_id, skill, key) triple.

        Args:
            agent_id: ID of the agent.
            skill: Name of the skill.
            key: Data key within the skill.

        Returns:
            The row's ``data`` value. Returns None when no matching row
            exists, or when the stored ``data`` is NULL.
        """
        query = select(AgentSkillDataTable).where(
            AgentSkillDataTable.agent_id == agent_id,
            AgentSkillDataTable.skill == skill,
            AgentSkillDataTable.key == key,
        )
        async with get_session() as db:
            row = await db.scalar(query)
            if not row:
                return None
            return row.data

    @classmethod
    async def clean_data(cls, agent_id: str):
        """Delete every skill data row belonging to an agent.

        Args:
            agent_id: ID of the agent whose rows are removed.
        """
        purge = delete(AgentSkillDataTable).where(
            AgentSkillDataTable.agent_id == agent_id
        )
        async with get_session() as db:
            await db.execute(purge)
            await db.commit()
class ThreadSkillDataTable(Base):
    """ORM mapping for the ``thread_skill_data`` table.

    Each row stores one JSON value for a (thread_id, skill, key) triple,
    which forms the composite primary key. ``agent_id`` is a required,
    non-key column recording the owning agent; it lets all of an agent's
    thread data be filtered or deleted together.
    """

    __tablename__ = "thread_skill_data"

    # Composite primary key: (thread_id, skill, key).
    thread_id = Column(String, primary_key=True)
    skill = Column(String, primary_key=True)
    key = Column(String, primary_key=True)

    # Owning agent; mandatory but not part of the key.
    agent_id = Column(String, nullable=False)

    # JSON payload; the column itself permits NULL.
    data = Column(JSONB, nullable=True)

    # Set by the database server when the row is inserted.
    created_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Server default on insert; on ORM-issued UPDATEs the application
    # supplies the current UTC time instead.
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=lambda: datetime.now(timezone.utc),
        server_default=func.now(),
        nullable=False,
    )
class ThreadSkillDataCreate(BaseModel):
    """Input model for creating or updating one row of thread skill data.

    ``from_attributes=True`` allows instances to be validated directly from
    ORM objects.
    """

    model_config = ConfigDict(from_attributes=True)

    thread_id: Annotated[
        str, Field(description="ID of the thread this data belongs to")
    ]
    skill: Annotated[str, Field(description="Name of the skill this data is for")]
    key: Annotated[str, Field(description="Key for this specific piece of data")]
    agent_id: Annotated[str, Field(description="ID of the agent that owns this thread")]
    data: Annotated[Dict[str, Any], Field(description="JSON data stored for this key")]

    async def save(self) -> "ThreadSkillData":
        """Insert this record, or overwrite ``data``/``agent_id`` on the existing row.

        The row is looked up by its (thread_id, skill, key) primary key. If it
        exists, its ``data`` and ``agent_id`` columns are replaced. Otherwise a
        new row is added. The transaction is then committed and the row
        refreshed so that database-generated timestamps are loaded.

        Returns:
            ThreadSkillData: The persisted row, including timestamps.
        """
        lookup = select(ThreadSkillDataTable).where(
            ThreadSkillDataTable.thread_id == self.thread_id,
            ThreadSkillDataTable.skill == self.skill,
            ThreadSkillDataTable.key == self.key,
        )
        async with get_session() as db:
            existing = await db.scalar(lookup)
            if not existing:
                # No row for this key yet: insert a new one from all fields.
                existing = ThreadSkillDataTable(**self.model_dump())
                db.add(existing)
            else:
                # Row already present: refresh the payload and the owner.
                existing.data = self.data
                existing.agent_id = self.agent_id
            await db.commit()
            await db.refresh(existing)
            return ThreadSkillData.model_validate(existing)
class ThreadSkillData(ThreadSkillDataCreate):
    """Persisted thread skill data, including row timestamps.

    Rows are identified by the composite primary key (thread_id, skill, key).
    Each skill can therefore keep any number of independent JSON values per
    thread. ``agent_id`` is required and records which agent owns the thread.
    """

    model_config = ConfigDict(
        from_attributes=True,
        json_encoders={datetime: lambda v: v.isoformat(timespec="milliseconds")},
    )

    created_at: Annotated[
        datetime, Field(description="Timestamp when this data was created")
    ]
    updated_at: Annotated[
        datetime, Field(description="Timestamp when this data was updated")
    ]

    @classmethod
    async def get(cls, thread_id: str, skill: str, key: str) -> Optional[dict]:
        """Fetch the stored JSON value for one (thread_id, skill, key) triple.

        Args:
            thread_id: ID of the thread.
            skill: Name of the skill.
            key: Data key within the skill.

        Returns:
            The row's ``data`` value. Returns None when no matching row
            exists, or when the stored ``data`` is NULL.
        """
        async with get_session() as db:
            record = await db.scalar(
                select(ThreadSkillDataTable).where(
                    ThreadSkillDataTable.thread_id == thread_id,
                    ThreadSkillDataTable.skill == skill,
                    ThreadSkillDataTable.key == key,
                )
            )
            return record.data if record else None

    @classmethod
    async def clean_data(
        cls,
        agent_id: str,
        # A real Python default is required here. ``Field(default=...)`` inside
        # ``Annotated`` is inert metadata on a plain function parameter, which
        # previously made ``thread_id`` a mandatory argument.
        thread_id: Annotated[
            str,
            Field(
                description="Optional ID of the thread. If provided, only cleans data for that thread.",
            ),
        ] = "",
    ):
        """Delete skill data for one thread, or for all of an agent's threads.

        Args:
            agent_id: ID of the agent. Rows are always restricted to this agent.
            thread_id: Optional ID of the thread. If provided (non-empty), only
                that thread's rows for the agent are deleted. If omitted or
                empty, every row belonging to the agent is deleted.
        """
        # Always scope the delete to the agent, and narrow to one thread if given.
        conditions = [ThreadSkillDataTable.agent_id == agent_id]
        if thread_id:
            conditions.append(ThreadSkillDataTable.thread_id == thread_id)

        async with get_session() as db:
            await db.execute(delete(ThreadSkillDataTable).where(*conditions))
            await db.commit()