From a38d078d514b619b3e80cb0bbcdbafa5ee57187e Mon Sep 17 00:00:00 2001 From: Aditya Gupta Date: Sat, 18 Apr 2026 18:04:38 +0530 Subject: [PATCH] feat: Allow changing of parent node and fix tree collapse --- src/omnia/tools/scenario_builder_tui.py | 193 +++++++++++++++++++++--- 1 file changed, 176 insertions(+), 17 deletions(-) diff --git a/src/omnia/tools/scenario_builder_tui.py b/src/omnia/tools/scenario_builder_tui.py index 22828c0..ae9ac5e 100644 --- a/src/omnia/tools/scenario_builder_tui.py +++ b/src/omnia/tools/scenario_builder_tui.py @@ -451,6 +451,7 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): self.scenario = scenario self.selected_node = None self.node_map: dict[int, Any] = {} + self.parent_option_map: dict[str, tuple[str, dict[str, Any]]] = {} def compose(self) -> ComposeResult: yield Header(show_clock=True) @@ -471,14 +472,16 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): yield Label("", id="node-type-value") yield Label("ID", id="node-id-label") yield Input(id="node_id") - yield Label("Name", id="node-name-label") - yield Input(id="node_name") - yield Label("Description", id="node-desc-label") - yield TextArea(id="node-description") - with Horizontal(id="node-actions"): - yield Button("Generate", id="node_generate") - yield Button("Update", id="node_update", variant="primary") - yield Button("Reset", id="node_reset") + yield Label("Name", id="node-name-label") + yield Input(id="node_name") + yield Label("Description", id="node-desc-label") + yield TextArea(id="node-description") + yield Label("Parent", id="node-parent-label") + yield Select([], prompt="Select parent", id="node_parent") + with Horizontal(id="node-actions"): + yield Button("Generate", id="node_generate") + yield Button("Update", id="node_update", variant="primary") + yield Button("Reset", id="node_reset") with Horizontal(id="spatial-actions"): yield Button("Save", id="spatial_save", variant="primary") yield Button("Cancel", id="spatial_cancel") @@ -513,7 +516,11 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): def _world(self) -> dict[str, Any]: return self.scenario["spatial_graph"]["world"] - def _build_tree(self, select_obj: Optional[dict[str, Any]] = None) -> None: + def _build_tree( + self, + select_obj: Optional[dict[str, Any]] = None, + expanded: Optional[set[int]] = None, + ) -> None: tree = self.query_one("#spatial-tree", Tree) world = self._world() tree.reset(f"🌍 {world.get('name', 'World')}") @@ -521,12 +528,15 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): tree.root.expand() self.selected_node = None self.node_map = {} + expanded = expanded or set() def add_region(region: dict[str, Any]) -> None: node = tree.root.add( self._label_for("region", region), data=SpatialNodeRef("region", region) ) self.node_map[id(region)] = node + if id(region) in expanded: + node.expand() for location in region.get("locations", []): add_location(node, location) @@ -536,6 +546,8 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): data=SpatialNodeRef("location", location), ) self.node_map[id(location)] = loc_node + if id(location) in expanded: + loc_node.expand() for nested in location.get("locations", []): add_location(loc_node, nested) for poi in location.get("pois", []): @@ -543,16 +555,38 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): self._label_for("poi", poi), data=SpatialNodeRef("poi", poi) ) self.node_map[id(poi)] = poi_node + if id(poi) in expanded: + poi_node.expand() for region in world.get("regions", []): add_region(region) if select_obj and id(select_obj) in self.node_map: self.selected_node = self.node_map[id(select_obj)] + self._expand_ancestors(self.selected_node) self._load_form_from_node(self.selected_node) else: self._clear_form() + def _capture_expanded_nodes(self) -> set[int]: + tree = self.query_one("#spatial-tree", Tree) + expanded: set[int] = set() + + def walk(node: Any) -> None: + if node.is_expanded and isinstance(node.data, SpatialNodeRef): + expanded.add(id(node.data.obj)) + for child in node.children: + walk(child) + + walk(tree.root) + return expanded + + def _expand_ancestors(self, node: Any) -> None: + current = node + while current is not None: + current.expand() + current = current.parent + def _label_for(self, node_type: str, obj: dict[str, Any]) -> str: name = str(obj.get("name") or "Unnamed").strip() node_id = str(obj.get("id") or "no-id").strip() @@ -564,6 +598,8 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): self.query_one("#node_id", Input).value = "" self.query_one("#node_name", Input).value = "" self.query_one("#node-description", TextArea).text = "" + self.query_one("#node_parent", Select).set_options([]) + self.query_one("#node_parent", Select).value = Select.NULL self.query_one("#node-details", Vertical).add_class("hidden") self.query_one("#node_generate", Button).disabled = True self.query_one("#node_update", Button).disabled = True @@ -594,6 +630,20 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): ) self.query_one("#spatial_promote", Button).disabled = not can_promote self.query_one("#spatial_demote", Button).disabled = not can_demote + parent_select = self.query_one("#node_parent", Select) + if is_location: + options, selected_value, option_map = self._build_parent_options(ref.obj) + self.parent_option_map = option_map + parent_select.set_options(options) + parent_select.value = ( + selected_value if selected_value is not None else Select.NULL + ) + parent_select.disabled = False + else: + self.parent_option_map = {} + parent_select.set_options([]) + parent_select.value = Select.NULL + parent_select.disabled = True def on_tree_node_selected(self, event: Tree.NodeSelected) -> None: self.selected_node = event.node @@ -657,6 +707,91 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): return None return parent_ref.obj.setdefault("pois", []) + def _collect_locations(self) -> list[dict[str, Any]]: + locations: list[dict[str, Any]] = [] + + def walk(location_list: list[dict[str, Any]]) -> None: + for location in location_list: + locations.append(location) + walk(location.get("locations", [])) + + for region in self._world().get("regions", []): + walk(region.get("locations", [])) + return locations + + def _location_descendants(self, location: dict[str, Any]) -> set[int]: + descendants: set[int] = set() + + def walk(children: list[dict[str, Any]]) -> None: + for child in children: + descendants.add(id(child)) + walk(child.get("locations", [])) + + walk(location.get("locations", [])) + return descendants + + def _find_location_parent(self, target: dict[str, Any]) -> Optional[tuple[str, dict[str, Any]]]: + world = self._world() + for region in world.get("regions", []): + if target in region.get("locations", []): + return ("region", region) + for location in region.get("locations", []): + found = self._find_location_parent_in_location(location, target) + if found: + return found + return None + + def _find_location_parent_in_location( + self, location: dict[str, Any], target: dict[str, Any] + ) -> Optional[tuple[str, dict[str, Any]]]: + if target in location.get("locations", []): + return ("location", location) + for nested in location.get("locations", []): + found = self._find_location_parent_in_location(nested, target) + if found: + return found + return None + + def _build_parent_options( + self, current: dict[str, Any] + ) -> tuple[list[tuple[str, str]], Optional[str], dict[str, tuple[str, dict[str, Any]]]]: + options: list[tuple[str, str]] = [] + option_map: dict[str, tuple[str, dict[str, Any]]] = {} + forbidden = {id(current)} | self._location_descendants(current) + selected_value: Optional[str] = None + + for region in self._world().get("regions", []): + region_id = str(region.get("id", "")).strip() + region_name = str(region.get("name", "")).strip() + value = f"region@{id(region)}" + label = f"πŸ—ΊοΈ {region_name} ({region_id})" if region_name else f"πŸ—ΊοΈ {region_id or 'region'}" + options.append((label, value)) + option_map[value] = ("region", region) + + for location in self._collect_locations(): + if id(location) in forbidden: + continue + location_id = str(location.get("id", "")).strip() + location_name = str(location.get("name", "")).strip() + value = f"location@{id(location)}" + label = ( + f"πŸ“ {location_name} ({location_id})" + if location_name + else f"πŸ“ {location_id or 'location'}" + ) + options.append((label, value)) + option_map[value] = ("location", location) + + parent_info = self._find_location_parent(current) + if parent_info: + parent_type, parent_obj = parent_info + for value, (entry_type, entry_obj) in option_map.items(): + if entry_type == parent_type and entry_obj is parent_obj: + selected_value = value + break + + return options, selected_value, option_map + def _update_selected_node(self) -> None: ref = self._selected_ref() if ref is None or ref.node_type == "world": @@ -678,6 +813,26 @@ class SpatialGraphScreen(Screen[Optional[dict[str, Any]]]): ref.obj.setdefault("pois", []) elif ref.node_type == "poi": ref.obj.setdefault("connections", []) + if ref.node_type == "location": + parent_select = self.query_one("#node_parent", Select) + if parent_select.value is not Select.NULL: + selected_value = str(parent_select.value) + parent_info = self.parent_option_map.get(selected_value) + if parent_info: + parent_type, parent_obj = parent_info + current_parent = self._find_location_parent(ref.obj) + if current_parent is None or current_parent[1] is not parent_obj: + if current_parent: + current_type, current_obj = current_parent + if current_type == "region": + current_obj.get("locations", []).remove(ref.obj) + elif current_type == "location": + current_obj.get("locations", []).remove(ref.obj) + parent_obj.setdefault("locations", []).append(ref.obj) + expanded = self._capture_expanded_nodes() + self._build_tree(select_obj=ref.obj, expanded=expanded) + self._set_status("Node updated") + return if self.selected_node: self.selected_node.label = self._label_for(ref.node_type, ref.obj) self._set_status("Node updated") @@ -791,10 +946,11 @@ Current name: {ref.obj.get("name", "")} def _add_node(self, node_type: str) -> None: world = self._world() + expanded = self._capture_expanded_nodes() if node_type == "region": region = {"id": "", "name": "", "description": "", "locations": []} world.setdefault("regions", []).append(region) - self._build_tree(select_obj=region) + self._build_tree(select_obj=region, expanded=expanded) self._set_status("Region added") return @@ -815,7 +971,7 @@ Current name: {ref.obj.get("name", "")} "pois": [], } parent_ref.obj.setdefault("locations", []).append(location) - self._build_tree(select_obj=location) + self._build_tree(select_obj=location, expanded=expanded) self._set_status("Location added") return @@ -837,7 +993,7 @@ Current name: {ref.obj.get("name", "")} return poi = {"id": "", "name": "", "description": "", "connections": []} parent_ref.obj.setdefault("pois", []).append(poi) - self._build_tree(select_obj=poi) + self._build_tree(select_obj=poi, expanded=expanded) self._set_status("POI added") def _delete_node(self) -> None: @@ -846,10 +1002,11 @@ Current name: {ref.obj.get("name", "")} self._set_status("Select a node to delete") return parent_node = self._selected_parent_node() + expanded = self._capture_expanded_nodes() if ref.node_type == "region": world = self._world() world["regions"] = [r for r in world.get("regions", []) if r is not ref.obj] - self._build_tree() + self._build_tree(expanded=expanded) self._set_status("Region deleted") return if ref.node_type == "location": @@ -860,7 +1017,7 @@ Current name: {ref.obj.get("name", "")} parent_locations[:] = [ loc for loc in parent_locations if loc is not ref.obj ] - self._build_tree() + self._build_tree(expanded=expanded) self._set_status("Location deleted") return if ref.node_type == "poi": @@ -869,7 +1026,7 @@ Current name: {ref.obj.get("name", "")} self._set_status("Unable to delete POI") return parent_pois[:] = [poi for poi in parent_pois if poi is not ref.obj] - self._build_tree() + self._build_tree(expanded=expanded) self._set_status("POI deleted") def _promote_location(self) -> None: @@ -894,7 +1051,8 @@ Current name: {ref.obj.get("name", "")} self._set_status("Unable to promote location") return target_list.append(ref.obj) - self._build_tree(select_obj=ref.obj) + expanded = self._capture_expanded_nodes() + self._build_tree(select_obj=ref.obj, expanded=expanded) self._set_status("Location promoted") def _demote_location(self) -> None: @@ -918,7 +1076,8 @@ Current name: {ref.obj.get("name", "")} new_parent = parent_locations[index + 1] new_parent.setdefault("locations", []).append(ref.obj) parent_locations.remove(ref.obj) - self._build_tree(select_obj=ref.obj) + expanded = self._capture_expanded_nodes() + self._build_tree(select_obj=ref.obj, expanded=expanded) self._set_status("Location demoted") def on_button_pressed(self, event: Button.Pressed) -> None: