Files
omnia-langchain/scenario_loader.py

144 lines
4.5 KiB
Python

import json
import logging
from dataclasses import dataclass
from pathlib import Path
from entities import Entity, Player
from memory import MemoryEntry
from time_utils import WorldClock
from world_architect import WorldState
# Module-level logger; used below to report scenario load/save progress.
logger = logging.getLogger(__name__)
@dataclass
class Scenario:
    """In-memory form of a scenario file.

    Built by ``load_scenario`` and serialized back by ``dump_scenario``.
    """

    # Contents of the file's "scenario" section, kept as parsed.
    metadata: dict
    # Entity (or Player) objects keyed by their normalized entity id.
    entities: dict[str, Entity]
    # Id of the entity constructed as a ``Player``, if the metadata names one.
    player_id: str | None = None
    # Initial world state derived from the scenario (clock, entities, location).
    world_state: WorldState | None = None
def _normalize_entity_id(value) -> str:
    """Return the canonical entity id: stringified, stripped and lowercased."""
    return str(value).strip().lower()


def _build_memory_entry(memory, owner_id: str, world_time: str, location: str) -> MemoryEntry:
    """Convert one raw memory record from a scenario file into a MemoryEntry.

    A bare string becomes an "observation" stamped with the scenario's world
    time and location and no referenced entities. A dict must provide
    "content", "event_type", "timestamp" and "location". Its optional
    "entities" list is normalized, blank refs are dropped, and references to
    the owning entity (``owner_id``) are removed.
    """
    if isinstance(memory, str):
        return MemoryEntry(
            content=memory,
            event_type="observation",
            timestamp_str=world_time,
            location=location,
            entities=[],
        )
    return MemoryEntry(
        content=memory["content"],
        event_type=memory["event_type"],
        timestamp_str=memory["timestamp"],
        location=memory["location"],
        entities=[
            normalized
            for entity_ref in memory.get("entities", [])
            if (normalized := _normalize_entity_id(entity_ref))
            and normalized != owner_id
        ],
    )


def _build_world_state(entities: dict, world_time: str, location: str) -> WorldState:
    """Create the initial WorldState for a freshly loaded scenario.

    Sets the clock from ``world_time``. Places every entity at the single
    scenario ``location`` with health 100 and status "calm". Registers that
    location under its lowercased, underscore-joined name.
    """
    world_state = WorldState()
    world_state.world_clock = WorldClock.from_time_str(world_time)
    for entity_id, entity in entities.items():
        world_state.entities[entity_id] = {
            "name": entity.name,
            "location": location,
            "health": 100,
            "status": "calm",
            "mood": entity.current_mood,
        }
    location_key = location.lower().replace(" ", "_")
    world_state.locations[location_key] = {
        "name": location,
        "description": f"The {location}",
        "occupants": len(entities),
        "visibility": "clear",
    }
    return world_state


def load_scenario(path: Path) -> Scenario:
    """Load a scenario JSON file and build its entities and world state.

    The file must contain a non-empty "entities" list. An optional "scenario"
    object supplies metadata; ``player_id``, ``world_time`` and ``location``
    are read from it, and the last two default to "Unknown".

    Each entity's id comes from its "id" field, falling back to "name", and
    is stringified, stripped and lowercased. ``player_id`` is normalized the
    same way so that it matches its entity regardless of case or surrounding
    whitespace.

    Args:
        path: Path to the scenario file. It is read as UTF-8.

    Returns:
        A Scenario containing:
        - the metadata;
        - the entities, keyed by normalized id;
        - the normalized player id (or the original falsy value);
        - an initialized WorldState.

    Raises:
        ValueError: if there are no entities, an entity id is blank after
            normalization, or two entities share an id.
        KeyError: if an entity or a dict-form memory lacks a required field.
    """
    logger.info("Loading scenario from %s", path)
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    payload = json.loads(path.read_text(encoding="utf-8"))
    metadata = payload.get("scenario", {})
    raw_player_id = metadata.get("player_id")
    # Normalize exactly like entity ids below. Otherwise "Alice" never matches
    # the entity stored under "alice", and the player loads as a plain Entity.
    player_id = (
        _normalize_entity_id(raw_player_id) if raw_player_id else raw_player_id
    )
    world_time = metadata.get("world_time", "Unknown")
    location = metadata.get("location", "Unknown")

    raw_entities = payload.get("entities", [])
    if not raw_entities:
        raise ValueError(f"No entities found in scenario: {path}")

    entities = {}
    for raw in raw_entities:
        entity_id = _normalize_entity_id(raw.get("id") or raw["name"])
        if not entity_id:
            raise ValueError(f"Entity with blank id in {path}")
        if entity_id in entities:
            raise ValueError(f"Duplicate Entity id '{entity_id}' in {path}")
        entity_class = Player if player_id and entity_id == player_id else Entity
        entity = entity_class(
            name=raw["name"],
            traits=list(raw["traits"]),
            stats=dict(raw["stats"]),
            voice_sample=raw["voice_sample"],
            current_mood=raw.get("current_mood", "Neutral"),
            entity_id=entity_id,
        )
        for memory in raw.get("memories", []):
            entity.perceive(
                _build_memory_entry(memory, entity_id, world_time, location)
            )
        entities[entity_id] = entity

    if player_id and player_id not in entities:
        logger.warning(
            "Player id '%s' does not match any entity in %s", player_id, path
        )
    logger.info("Loaded %s entities from scenario.", len(entities))

    world_state = _build_world_state(entities, world_time, location)
    logger.info("Initialized WorldState for scenario")
    return Scenario(
        metadata=metadata,
        entities=entities,
        player_id=player_id,
        world_state=world_state,
    )
def dump_scenario(scenario: Scenario) -> dict:
    """Serialize a Scenario into a JSON-ready dict.

    Entities are emitted in sorted-id order. Each carries its id, name,
    traits, stats, voice sample, mood and dumped memory entries.

    The metadata is shallow-copied. Its "player_id" is overwritten with
    ``scenario.player_id`` when that value is truthy and differs from the
    stored one.

    Returns:
        ``{"scenario": <metadata copy>, "entities": [<entity record>, ...]}``
    """

    def entity_record(key):
        # One entity -> plain dict; key order here fixes the JSON field order.
        ent = scenario.entities[key]
        return {
            "id": key,
            "name": ent.name,
            "traits": list(ent.traits),
            "stats": dict(ent.stats),
            "voice_sample": ent.voice_sample,
            "current_mood": ent.current_mood,
            "memories": [item.to_dict() for item in ent.memory.dump_entries()],
        }

    entity_records = [entity_record(key) for key in sorted(scenario.entities)]

    scenario_meta = dict(scenario.metadata)
    player = scenario.player_id
    needs_update = bool(player) and scenario_meta.get("player_id") != player
    if needs_update:
        scenario_meta["player_id"] = player

    return {"scenario": scenario_meta, "entities": entity_records}
def dumps_scenario(scenario: Scenario) -> str:
    """Render a Scenario as JSON text indented by two spaces."""
    document = dump_scenario(scenario)
    text = json.dumps(document, indent=2)
    return text
def save_scenario(path: Path, scenario: Scenario) -> str:
    """Write a Scenario to ``path`` as JSON followed by a trailing newline.

    The file is written explicitly as UTF-8 (the JSON interchange encoding)
    rather than in the platform's locale-default encoding.

    Args:
        path: Destination file; it is created or overwritten.
        scenario: Scenario to serialize.

    Returns:
        The JSON text that was written, without the trailing newline.
    """
    dumped = dumps_scenario(scenario)
    path.write_text(f"{dumped}\n", encoding="utf-8")
    logger.info("Saved scenario to %s", path)
    return dumped