Files
omnia-langchain/memory.py
2026-04-12 03:33:34 +05:30

81 lines
2.6 KiB
Python

from langchain_community.vectorstores import FAISS
from llm_runtime import embeddings
from time_utils import describe_relative_time
class MemoryEntry:
    """A single remembered event together with the context it was recorded in.

    Attributes:
        content: The text of the memory.
        event_type: Category label, e.g. 'dialogue', 'observation',
            'reflection'.
        timestamp: The timestamp value exactly as passed to the constructor
            (``timestamp_str``); it is never parsed inside this class.
        location: Where the event took place.
        entities: Iterable of entity names involved in the event. May be
            empty or ``None``.
    """

    def __init__(self, content, event_type, timestamp_str, location, entities):
        self.content = content
        self.event_type = event_type  # 'dialogue', 'observation', 'reflection'
        self.timestamp = timestamp_str
        self.location = location
        self.entities = entities

    def __repr__(self) -> str:
        """Return ``[timestamp] (location): content``."""
        return f"[{self.timestamp}] ({self.location}): {self.content}"

    def to_dict(self) -> dict:
        """Return a plain-dict snapshot of the entry.

        ``entities`` is always emitted as a fresh list. A missing (``None``)
        or empty value becomes ``[]`` rather than raising, which matches how
        ``to_vector_text`` already tolerates absent entities.
        """
        return {
            "content": self.content,
            "event_type": self.event_type,
            "timestamp": self.timestamp,
            "location": self.location,
            "entities": list(self.entities) if self.entities else [],
        }

    def to_vector_text(self) -> str:
        """Return the multi-line text that is embedded for similarity search.

        The content comes first, followed by ``Time``, ``Location``,
        ``Entities`` and ``Type`` lines. When there are no entities the
        ``Entities`` line reads ``Unknown``.
        """
        entities = ", ".join(self.entities) if self.entities else "Unknown"
        return (
            f"{self.content}\n"
            f"Time: {self.timestamp}\n"
            f"Location: {self.location}\n"
            f"Entities: {entities}\n"
            f"Type: {self.event_type}"
        )

    def to_relative_string(self, reference_time) -> str:
        """Return the entry like ``__repr__`` but with a relative time label.

        The label is whatever ``describe_relative_time`` returns for this
        entry's timestamp and ``reference_time``.
        """
        time_label = describe_relative_time(self.timestamp, reference_time)
        return f"[{time_label}] ({self.location}): {self.content}"
class EntityMemory:
    """Long-term memory store searchable by embedding similarity.

    ``entries`` keeps the saved entry objects in insertion order. Every
    document written to ``vector_store`` carries an ``entry_index`` metadata
    value that points back into ``entries``, so search hits can be rendered
    from the original entry instead of the embedded text.
    """

    # Returned whenever there is nothing to report, whether the store was
    # never populated or the search produced no hits.
    _NO_MEMORIES = "No long-term memories relevant."

    def __init__(self):
        self.vector_store = None  # created lazily on the first save()
        self.entries = []

    def save(self, entry: "MemoryEntry") -> None:
        """Index ``entry`` in the vector store and record it in ``entries``.

        The first call builds the FAISS store from the entry's vector text;
        later calls add to the existing store. The entry is appended to
        ``entries`` only after the vector store has accepted it, so an
        exception from the embedding/indexing call leaves ``entries``
        without an un-indexed entry.
        """
        entry_index = len(self.entries)
        texts = [entry.to_vector_text()]
        metadatas = [{"entry_index": entry_index}]
        if self.vector_store is None:
            self.vector_store = FAISS.from_texts(
                texts,
                embeddings,
                metadatas=metadatas,
            )
        else:
            self.vector_store.add_texts(texts, metadatas=metadatas)
        self.entries.append(entry)

    def retrieve(self, query: str, k=2, reference_time=None) -> str:
        """Return up to ``k`` memories similar to ``query``, one per line.

        Args:
            query: Text to search for.
            k: Maximum number of hits requested from the vector store.
            reference_time: When given, each entry is rendered with
                ``to_relative_string(reference_time)``; otherwise with
                ``repr``.

        Returns:
            The rendered memories joined by newlines, or the sentence
            "No long-term memories relevant." when nothing has been saved
            yet or the search returns no documents.
        """
        if self.vector_store is None:
            return self._NO_MEMORIES
        docs = self.vector_store.similarity_search(query, k=k)
        memories = [self._render(doc, reference_time) for doc in docs]
        if not memories:
            return self._NO_MEMORIES
        return "\n".join(memories)

    def _render(self, doc, reference_time):
        """Render one search hit as a display string."""
        entry_index = doc.metadata.get("entry_index")
        if entry_index is None:
            # No back-reference in the metadata: fall back to the raw
            # embedded text.
            return doc.page_content
        entry = self.entries[entry_index]
        if reference_time is None:
            return repr(entry)
        return entry.to_relative_string(reference_time)

    def dump_entries(self) -> list:
        """Return a shallow copy of all saved entries in insertion order."""
        return list(self.entries)