fix(layout_optimizer): apply code review follow-ups

This commit is contained in:
yexiaozhou
2026-04-03 01:42:22 +08:00
parent 00bdf9b822
commit a7a6d77d7a
12 changed files with 336 additions and 68 deletions

View File

@@ -19,6 +19,8 @@
from __future__ import annotations
from collections import defaultdict
import itertools
import logging
import logging.handlers
import math
@@ -120,6 +122,7 @@ else:
# ---------- 设备目录缓存 ----------
_device_cache: list[dict] | None = None
_DEVICE_PARAM_KEYS = {"device_a", "device_b", "arm_id", "target_device_id", "device"}
# 消耗品/配件关键词(不独立放置于实验台)
@@ -199,6 +202,80 @@ def _build_device_list() -> list[dict]:
return _device_cache
def _catalog_id_from_internal(device_id: str) -> str:
"""内部实例 ID → catalog ID。"""
return device_id.split("#", 1)[0]
def _expand_constraints_for_duplicates(
constraints: list[Constraint], devices: list,
) -> list[Constraint]:
"""将引用 bare catalog ID 的约束扩展到所有重复实例。"""
catalog_instances: dict[str, list[str]] = defaultdict(list)
for dev in devices:
catalog_instances[_catalog_id_from_internal(dev.id)].append(dev.id)
expanded_constraints: list[Constraint] = []
for constraint in constraints:
fan_out_keys: list[str] = []
fan_out_values: list[list[str]] = []
for key in _DEVICE_PARAM_KEYS:
if key not in constraint.params:
continue
ref_id = constraint.params[key]
if "#" in ref_id:
continue
instances = catalog_instances.get(ref_id, [])
if len(instances) > 1:
fan_out_keys.append(key)
fan_out_values.append(instances)
logger.info(
"Fan-out: %s %s=%s -> %d instances",
constraint.rule_name, key, ref_id, len(instances),
)
if not fan_out_keys:
expanded_constraints.append(constraint)
continue
for combo in itertools.product(*fan_out_values):
new_params = dict(constraint.params)
for key, internal_id in zip(fan_out_keys, combo):
new_params[key] = internal_id
expanded_constraints.append(
Constraint(
type=constraint.type,
rule_name=constraint.rule_name,
params=new_params,
weight=constraint.weight,
)
)
return expanded_constraints
def _maybe_add_prefer_aligned_constraint(
constraints: list[Constraint], align_weight: float,
) -> list[Constraint]:
"""仅在用户未显式提供 prefer_aligned 时注入对齐约束。"""
if align_weight <= 0:
return constraints
if any(c.rule_name == "prefer_aligned" for c in constraints):
logger.info("Skipping auto-injected prefer_aligned because one already exists")
return constraints
constraints.append(
Constraint(
type="soft",
rule_name="prefer_aligned",
weight=align_weight,
)
)
return constraints
# ---------- 路由 ----------
@@ -322,6 +399,14 @@ async def interpret_schema():
},
"generates": "soft maximize_distance for each pair",
},
"keep_adjacent": {
"description": "Devices should stay adjacent, similar to close_together",
"params": {
"devices": {"type": "list[string]", "required": True, "description": "Device IDs (min 2)"},
"priority": {"type": "string", "required": False, "default": "medium", "enum": ["low", "medium", "high"]},
},
"generates": "soft minimize_distance for each pair",
},
"max_distance": {
"description": "Two devices must be within a maximum distance",
"params": {
@@ -410,6 +495,7 @@ class OptimizeRequest(BaseModel):
seed: int | None = None
snap_cardinal: bool = False
angle_granularity: int | None = None
arm_reach: dict[str, float] = {}
class PositionXYZ(BaseModel):
@@ -439,7 +525,7 @@ async def run_optimize(request: OptimizeRequest):
from fastapi import HTTPException
from .constraints import evaluate_default_hard_constraints, evaluate_constraints
from .mock_checkers import MockCollisionChecker
from .mock_checkers import MockCollisionChecker, MockReachabilityChecker
from .optimizer import optimize, snap_theta, snap_theta_safe
from .seeders import resolve_seeder_params, seed_layout
@@ -460,19 +546,12 @@ async def run_optimize(request: OptimizeRequest):
detail="angle_granularity must be one of: 4, 8, 12, 24",
)
# Build mapping: internal uuid-based id → (catalog_id, uuid)
# create_devices_from_list uses uuid as Device.id when available
id_to_catalog: dict[str, str] = {}
id_to_uuid: dict[str, str] = {}
for d in request.devices:
internal_id = d.uuid or d.id
id_to_catalog[internal_id] = d.id
id_to_uuid[internal_id] = d.uuid or d.id
# 转换输入
devices = create_devices_from_list(
[d.model_dump() for d in request.devices]
)
id_to_catalog = {dev.id: _catalog_id_from_internal(dev.id) for dev in devices}
id_to_uuid = {dev.id: (dev.uuid or dev.id) for dev in devices}
lab = parse_lab(request.lab.model_dump())
constraints = [
Constraint(
@@ -483,6 +562,7 @@ async def run_optimize(request: OptimizeRequest):
)
for c in request.constraints
]
constraints = _expand_constraints_for_duplicates(constraints, devices)
# 1. Resolve seeder
try:
@@ -509,23 +589,22 @@ async def run_optimize(request: OptimizeRequest):
weight=request.seeder_overrides.get("orientation_weight", DEFAULT_WEIGHT_ANGLE),
))
# prefer_aligned: penalize non-cardinal angles默认关闭用户可通过 align_cardinal intent 或 seeder_overrides 开启)
align_weight = request.seeder_overrides.get("align_weight", 0)
if align_weight > 0:
constraints.append(Constraint(
type="soft",
rule_name="prefer_aligned",
weight=align_weight,
))
constraints = _maybe_add_prefer_aligned_constraint(
constraints,
request.seeder_overrides.get("align_weight", 0),
)
# 4. Conditional Differential Evolution
de_ran = False
checker = MockCollisionChecker()
reachability_checker = MockReachabilityChecker(request.arm_reach or None)
if request.run_de:
result_placements = optimize(
devices=devices,
lab=lab,
constraints=constraints,
collision_checker=checker,
reachability_checker=reachability_checker,
seed_placements=seed_placements,
maxiter=request.maxiter,
seed=request.seed,
@@ -552,7 +631,7 @@ async def run_optimize(request: OptimizeRequest):
# 也检查用户硬约束binary 模式)
if constraints and not math.isinf(final_cost):
user_hard_cost = evaluate_constraints(
devices, result_placements, lab, constraints, checker,
devices, result_placements, lab, constraints, checker, reachability_checker,
graduated=False,
)
if math.isinf(user_hard_cost):