decouple scenario, add structure to memories
This commit is contained in:
435
engine.py
Normal file
435
engine.py
Normal file
@@ -0,0 +1,435 @@
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import multiprocessing
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
DEFAULT_MODEL_PATH = "/home/sortedcord/.cache/huggingface/hub/models--ggml-org--gemma-4-E4B-it-GGUF/snapshots/6b352c53e1d2e4bb974d9f8cafcf85887c224219/gemma-4-e4b-it-Q4_K_M.gguf"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
||||
|
||||
llm = ChatLlamaCpp(
|
||||
temperature=0.2,
|
||||
model_path=DEFAULT_MODEL_PATH,
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=8,
|
||||
max_tokens=512,
|
||||
n_threads=multiprocessing.cpu_count() - 1,
|
||||
repeat_penalty=1.5,
|
||||
)
|
||||
|
||||
|
||||
def _format_prompt(messages):
|
||||
formatted = []
|
||||
for message in messages:
|
||||
formatted.append(f"{message.__class__.__name__}:\n{message.content}")
|
||||
return "\n\n".join(formatted)
|
||||
|
||||
|
||||
def _normalize_llm_output(text: str) -> str:
|
||||
return text.replace("\r", "").replace("\n", "").strip()
|
||||
|
||||
|
||||
def _time_of_day_label(hour: int, *, for_today: bool) -> str:
|
||||
if 5 <= hour < 12:
|
||||
return "morning"
|
||||
if 12 <= hour < 17:
|
||||
return "afternoon"
|
||||
return "tonight" if for_today else "night"
|
||||
|
||||
|
||||
def describe_relative_time(
|
||||
timestamp_str: str,
|
||||
reference_time: datetime,
|
||||
*,
|
||||
prefer_day_part_for_today: bool = False,
|
||||
) -> str:
|
||||
try:
|
||||
timestamp = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M")
|
||||
except ValueError:
|
||||
return "a long time ago"
|
||||
|
||||
delta = reference_time - timestamp
|
||||
seconds = delta.total_seconds()
|
||||
if seconds < 0:
|
||||
return "just now"
|
||||
|
||||
if not prefer_day_part_for_today:
|
||||
if seconds < 120:
|
||||
return "just now"
|
||||
if seconds < 15 * 60:
|
||||
return "a few minutes ago"
|
||||
if seconds < 90 * 60:
|
||||
return "an hour ago"
|
||||
if seconds < 3 * 60 * 60:
|
||||
return "a couple hours ago"
|
||||
|
||||
day_diff = (reference_time.date() - timestamp.date()).days
|
||||
if day_diff == 0:
|
||||
return f"today {_time_of_day_label(timestamp.hour, for_today=True)}"
|
||||
if day_diff == 1:
|
||||
return f"yesterday {_time_of_day_label(timestamp.hour, for_today=False)}"
|
||||
if day_diff == 2:
|
||||
return "2 days ago"
|
||||
if day_diff == 3:
|
||||
return "3 days ago"
|
||||
if day_diff <= 6:
|
||||
return "a couple days ago"
|
||||
if day_diff <= 10:
|
||||
return "a week ago"
|
||||
if day_diff <= 20:
|
||||
return "a couple weeks ago"
|
||||
if day_diff <= 45:
|
||||
return "a month ago"
|
||||
if day_diff <= 75:
|
||||
return "a couple months ago"
|
||||
if day_diff <= 420:
|
||||
return "a year ago"
|
||||
return "a long time ago"
|
||||
|
||||
|
||||
class WorldClock:
|
||||
def __init__(self, start_year=1999, month=5, day=14, hour=18, minute=0):
|
||||
# We use a standard datetime object for easy math
|
||||
self.current_time = datetime(start_year, month, day, hour, minute)
|
||||
|
||||
def advance_time(self, minutes=0, hours=0, days=0):
|
||||
self.current_time += timedelta(minutes=minutes, hours=hours, days=days)
|
||||
|
||||
def get_time_str(self):
|
||||
# 1999-05-14 18:00
|
||||
return self.current_time.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
def get_vibe(self):
|
||||
"""Helper to tell the LLM the 'feel' of the time."""
|
||||
hour = self.current_time.hour
|
||||
if 5 <= hour < 12:
|
||||
return "Morning"
|
||||
if 12 <= hour < 17:
|
||||
return "Afternoon"
|
||||
if 17 <= hour < 21:
|
||||
return "Evening"
|
||||
return "Night"
|
||||
|
||||
@classmethod
|
||||
def from_time_str(cls, time_str: str | None):
|
||||
if not time_str:
|
||||
return cls()
|
||||
parsed = datetime.strptime(time_str, "%Y-%m-%d %H:%M")
|
||||
return cls(
|
||||
start_year=parsed.year,
|
||||
month=parsed.month,
|
||||
day=parsed.day,
|
||||
hour=parsed.hour,
|
||||
minute=parsed.minute,
|
||||
)
|
||||
|
||||
|
||||
class MemoryEntry:
|
||||
def __init__(self, content, event_type, timestamp_str, location, entities):
|
||||
self.content = content
|
||||
self.event_type = event_type # 'dialogue', 'observation', 'reflection'
|
||||
self.timestamp = timestamp_str
|
||||
self.location = location
|
||||
self.entities = entities
|
||||
|
||||
def __repr__(self):
|
||||
return f"[{self.timestamp}] ({self.location}): {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"content": self.content,
|
||||
"event_type": self.event_type,
|
||||
"timestamp": self.timestamp,
|
||||
"location": self.location,
|
||||
"entities": list(self.entities),
|
||||
}
|
||||
|
||||
def to_vector_text(self):
|
||||
entities = ", ".join(self.entities) if self.entities else "Unknown"
|
||||
return (
|
||||
f"{self.content}\n"
|
||||
f"Time: {self.timestamp}\n"
|
||||
f"Location: {self.location}\n"
|
||||
f"Entities: {entities}\n"
|
||||
f"Type: {self.event_type}"
|
||||
)
|
||||
|
||||
def to_relative_string(self, reference_time: datetime):
|
||||
time_label = describe_relative_time(self.timestamp, reference_time)
|
||||
return f"[{time_label}] ({self.location}): {self.content}"
|
||||
|
||||
|
||||
class EntityMemory:
|
||||
def __init__(self):
|
||||
self.vector_store = None
|
||||
self.entries = []
|
||||
|
||||
def save(self, entry: MemoryEntry):
|
||||
self.entries.append(entry)
|
||||
entry_text = entry.to_vector_text()
|
||||
if self.vector_store is None:
|
||||
self.vector_store = FAISS.from_texts(
|
||||
[entry_text],
|
||||
embeddings,
|
||||
metadatas=[{"entry_index": len(self.entries) - 1}],
|
||||
)
|
||||
else:
|
||||
self.vector_store.add_texts(
|
||||
[entry_text],
|
||||
metadatas=[{"entry_index": len(self.entries) - 1}],
|
||||
)
|
||||
|
||||
def retrieve(self, query: str, k=2, reference_time: datetime | None = None):
|
||||
if self.vector_store is None:
|
||||
return "No long-term memories relevant."
|
||||
docs = self.vector_store.similarity_search(query, k=k)
|
||||
memories = []
|
||||
for doc in docs:
|
||||
entry_index = doc.metadata.get("entry_index")
|
||||
if entry_index is None:
|
||||
memories.append(doc.page_content)
|
||||
continue
|
||||
entry = self.entries[entry_index]
|
||||
if reference_time is None:
|
||||
memories.append(repr(entry))
|
||||
else:
|
||||
memories.append(entry.to_relative_string(reference_time))
|
||||
return "\n".join(memories)
|
||||
|
||||
def dump_entries(self):
|
||||
return list(self.entries)
|
||||
|
||||
|
||||
class Entity:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
traits,
|
||||
stats,
|
||||
voice_sample,
|
||||
current_mood="Neutral",
|
||||
entity_id=None,
|
||||
):
|
||||
self.name = name
|
||||
self.traits = traits
|
||||
self.stats = stats
|
||||
self.current_mood = current_mood
|
||||
self.memory = EntityMemory()
|
||||
# TIER 1: The Short-Term Buffer (Verbatim)
|
||||
self.chat_buffer = []
|
||||
self.voice_sample = voice_sample
|
||||
self.entity_id = entity_id
|
||||
|
||||
def perceive(self, entry: MemoryEntry):
|
||||
self.memory.save(entry)
|
||||
|
||||
def reflect_and_summarize(self, world_clock: WorldClock, location: str):
|
||||
"""Converts Tier 1 (Buffer) into Tier 2 (Long-term Subjective Memory)."""
|
||||
if not self.chat_buffer:
|
||||
return
|
||||
|
||||
dialogue_text = "\n".join(
|
||||
[f"{m['role_name']}: {m['content']}" for m in self.chat_buffer]
|
||||
)
|
||||
|
||||
# The Subjective Filter Prompt
|
||||
summary_prompt = [
|
||||
SystemMessage(
|
||||
content=f"""
|
||||
You are the private inner thoughts of {self.name}.
|
||||
Traits: {", ".join(self.traits)}.
|
||||
Mood: {self.current_mood}.
|
||||
Voice Reference: {self.voice_sample}
|
||||
|
||||
Think about what just happened.
|
||||
- No META-TALK, Do not use 'player', 'interaction', 'entity', or 'dialogue'
|
||||
- BE SUBJECTIVE. If you hated the talk or loved it, then express that.
|
||||
- USE YOUR VOICE. Match the style of your Voice Reference
|
||||
- Focus only on facts learned or feelings toward the person"""
|
||||
),
|
||||
HumanMessage(
|
||||
content=f"""
|
||||
What just happened? Context:\n{dialogue_text}"""
|
||||
),
|
||||
]
|
||||
|
||||
logger.info("LLM prompt (reflection):\n%s", _format_prompt(summary_prompt))
|
||||
summary = _normalize_llm_output(llm.invoke(summary_prompt).content)
|
||||
logger.info("SYSTEM: %s reflected on the talk: '%s'", self.name, summary)
|
||||
|
||||
chat_entities = sorted(
|
||||
{
|
||||
m["role_id"]
|
||||
for m in self.chat_buffer
|
||||
if m.get("role_id") and m.get("role_id") != self.entity_id
|
||||
}
|
||||
)
|
||||
reflection = MemoryEntry(
|
||||
content=f"Past Conversation Summary: {summary}",
|
||||
event_type="reflection",
|
||||
timestamp_str=world_clock.get_time_str(),
|
||||
location=location,
|
||||
entities=chat_entities,
|
||||
)
|
||||
self.perceive(reflection)
|
||||
self.chat_buffer = [] # Clear buffer after archiving
|
||||
|
||||
|
||||
class Player(Entity):
|
||||
pass
|
||||
|
||||
|
||||
def ask_entity(
|
||||
entity: Entity,
|
||||
player: Entity,
|
||||
player_query: str,
|
||||
world_clock: WorldClock,
|
||||
location: str,
|
||||
):
|
||||
facts = entity.memory.retrieve(
|
||||
player_query,
|
||||
reference_time=world_clock.current_time,
|
||||
)
|
||||
|
||||
recent_context = "\n".join(
|
||||
[f"{m['role_name']}: {m['content']}" for m in entity.chat_buffer[-5:]]
|
||||
)
|
||||
|
||||
world_time_label = describe_relative_time(
|
||||
world_clock.get_time_str(),
|
||||
world_clock.current_time,
|
||||
prefer_day_part_for_today=True,
|
||||
)
|
||||
|
||||
prompt = [
|
||||
SystemMessage(content=f"WORLD TIME: {world_time_label}"),
|
||||
SystemMessage(
|
||||
content=f"""
|
||||
### ROLE
|
||||
You are {entity.name}. Persona: {", ".join(entity.traits)}.
|
||||
Current Mood: {entity.current_mood}.
|
||||
Vibe Time: {world_clock.get_vibe()}.
|
||||
Location: {location}.
|
||||
|
||||
### WRITING STYLE RULES
|
||||
1. NO META-TALK. Never mention "memory," "records," "claims," or "narratives."
|
||||
2. ACT, DON'T EXPLAIN. If you don't know something, just say "Never heard of it" or "I wasn't there." Do not explain WHY you don't know.
|
||||
|
||||
### KNOWLEDGE
|
||||
MEMORIES: {facts}
|
||||
RECENT CHAT: {recent_context}
|
||||
"""
|
||||
),
|
||||
HumanMessage(content=f"{player.name} speaks to you: {player_query}"),
|
||||
]
|
||||
|
||||
logger.info("LLM prompt (dialogue):\n%s", _format_prompt(prompt))
|
||||
response = _normalize_llm_output(llm.invoke(prompt).content)
|
||||
|
||||
entity.chat_buffer.append(
|
||||
{
|
||||
"role_id": player.entity_id,
|
||||
"role_name": player.name,
|
||||
"content": player_query,
|
||||
}
|
||||
)
|
||||
entity.chat_buffer.append(
|
||||
{
|
||||
"role_id": entity.entity_id,
|
||||
"role_name": entity.name,
|
||||
"content": response,
|
||||
}
|
||||
)
|
||||
|
||||
player.chat_buffer.append(
|
||||
{
|
||||
"role_id": player.entity_id,
|
||||
"role_name": player.name,
|
||||
"content": player_query,
|
||||
}
|
||||
)
|
||||
player.chat_buffer.append(
|
||||
{
|
||||
"role_id": entity.entity_id,
|
||||
"role_name": entity.name,
|
||||
"content": response,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("[%s]: %s", entity.name.upper(), response)
|
||||
|
||||
|
||||
def _build_name_lookup(entities):
|
||||
name_lookup = {}
|
||||
for entity_key, entity in entities.items():
|
||||
name_lookup[entity_key.lower()] = entity_key
|
||||
name_lookup[entity.name.lower()] = entity_key
|
||||
return name_lookup
|
||||
|
||||
|
||||
def start_game(entities, player_id=None, world_time=None, location="Unknown"):
|
||||
player = None
|
||||
if player_id:
|
||||
player = entities.get(player_id)
|
||||
if player is None:
|
||||
raise ValueError(f"Player entity '{player_id}' not found in scenario.")
|
||||
else:
|
||||
player = Player(
|
||||
name="Player",
|
||||
traits=["Curious"],
|
||||
stats={},
|
||||
voice_sample="Voice: 'Direct and concise.'",
|
||||
entity_id="player",
|
||||
)
|
||||
|
||||
available_entities = {
|
||||
entity_id: entity
|
||||
for entity_id, entity in entities.items()
|
||||
if entity_id != player_id
|
||||
}
|
||||
|
||||
world_clock = WorldClock.from_time_str(world_time)
|
||||
current_entity = None
|
||||
name_lookup = _build_name_lookup(available_entities)
|
||||
entity_names = "/".join(
|
||||
[entity.name for entity in available_entities.values()] + ["Exit"]
|
||||
)
|
||||
logger.info("--- WORLD INITIALIZED ---")
|
||||
logger.info("World initialized with %s active entities.", len(available_entities))
|
||||
logger.info("Current location: %s", location)
|
||||
logger.info(
|
||||
"World time: %s (%s)", world_clock.get_time_str(), world_clock.get_vibe()
|
||||
)
|
||||
|
||||
while True:
|
||||
target_name = (
|
||||
input(f"\nWho do you want to talk to? ({entity_names}): ").lower().strip()
|
||||
)
|
||||
|
||||
if target_name in ["exit", "quit"]:
|
||||
if current_entity:
|
||||
current_entity.reflect_and_summarize(world_clock, location)
|
||||
break
|
||||
|
||||
target_key = name_lookup.get(target_name)
|
||||
if target_key is None:
|
||||
logger.warning("Target not found.")
|
||||
continue
|
||||
|
||||
new_entity = available_entities[target_key]
|
||||
if current_entity and current_entity != new_entity:
|
||||
logger.info(
|
||||
"You leave %s and approach %s.", current_entity.name, new_entity.name
|
||||
)
|
||||
current_entity.reflect_and_summarize(world_clock, location)
|
||||
|
||||
current_entity = new_entity
|
||||
|
||||
user_msg = input(f"You to {current_entity.name}: ")
|
||||
ask_entity(current_entity, player, user_msg, world_clock, location)
|
||||
Reference in New Issue
Block a user