refactor: engine.py
This commit is contained in:
80
memory.py
Normal file
80
memory.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
from llm_runtime import embeddings
|
||||
from time_utils import describe_relative_time
|
||||
|
||||
|
||||
class MemoryEntry:
|
||||
def __init__(self, content, event_type, timestamp_str, location, entities):
|
||||
self.content = content
|
||||
self.event_type = event_type # 'dialogue', 'observation', 'reflection'
|
||||
self.timestamp = timestamp_str
|
||||
self.location = location
|
||||
self.entities = entities
|
||||
|
||||
def __repr__(self):
|
||||
return f"[{self.timestamp}] ({self.location}): {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"content": self.content,
|
||||
"event_type": self.event_type,
|
||||
"timestamp": self.timestamp,
|
||||
"location": self.location,
|
||||
"entities": list(self.entities),
|
||||
}
|
||||
|
||||
def to_vector_text(self):
|
||||
entities = ", ".join(self.entities) if self.entities else "Unknown"
|
||||
return (
|
||||
f"{self.content}\n"
|
||||
f"Time: {self.timestamp}\n"
|
||||
f"Location: {self.location}\n"
|
||||
f"Entities: {entities}\n"
|
||||
f"Type: {self.event_type}"
|
||||
)
|
||||
|
||||
def to_relative_string(self, reference_time):
|
||||
time_label = describe_relative_time(self.timestamp, reference_time)
|
||||
return f"[{time_label}] ({self.location}): {self.content}"
|
||||
|
||||
|
||||
class EntityMemory:
|
||||
def __init__(self):
|
||||
self.vector_store = None
|
||||
self.entries = []
|
||||
|
||||
def save(self, entry: MemoryEntry):
|
||||
self.entries.append(entry)
|
||||
entry_text = entry.to_vector_text()
|
||||
if self.vector_store is None:
|
||||
self.vector_store = FAISS.from_texts(
|
||||
[entry_text],
|
||||
embeddings,
|
||||
metadatas=[{"entry_index": len(self.entries) - 1}],
|
||||
)
|
||||
else:
|
||||
self.vector_store.add_texts(
|
||||
[entry_text],
|
||||
metadatas=[{"entry_index": len(self.entries) - 1}],
|
||||
)
|
||||
|
||||
def retrieve(self, query: str, k=2, reference_time=None):
|
||||
if self.vector_store is None:
|
||||
return "No long-term memories relevant."
|
||||
docs = self.vector_store.similarity_search(query, k=k)
|
||||
memories = []
|
||||
for doc in docs:
|
||||
entry_index = doc.metadata.get("entry_index")
|
||||
if entry_index is None:
|
||||
memories.append(doc.page_content)
|
||||
continue
|
||||
entry = self.entries[entry_index]
|
||||
if reference_time is None:
|
||||
memories.append(repr(entry))
|
||||
else:
|
||||
memories.append(entry.to_relative_string(reference_time))
|
||||
return "\n".join(memories)
|
||||
|
||||
def dump_entries(self):
|
||||
return list(self.entries)
|
||||
Reference in New Issue
Block a user