From 0f6264503a6c39559733f1802aa2203f6cf428fd Mon Sep 17 00:00:00 2001 From: Xuwznln <18435084+Xuwznln@users.noreply.github.com> Date: Sat, 21 Mar 2026 19:24:14 +0800 Subject: [PATCH] new registry sys exp. support with add device --- .github/workflows/ci-check.yml | 2 +- .gitignore | 1 + docs/developer_guide/add_device.md | 109 +- docs/developer_guide/add_registry.md | 65 +- docs/developer_guide/networking_overview.md | 4 +- docs/user_guide/launch.md | 5 +- unilabos/app/main.py | 148 +- unilabos/app/register.py | 63 +- unilabos/app/web/api.py | 2 +- unilabos/app/web/client.py | 64 +- unilabos/app/web/server.py | 4 +- unilabos/app/ws_client.py | 68 +- unilabos/config/config.py | 1 + unilabos/device_comms/universal_driver.py | 1 - unilabos/devices/virtual/workbench.py | 351 ++- unilabos/registry/ast_registry_scanner.py | 1022 ++++++++ unilabos/registry/decorators.py | 614 +++++ unilabos/registry/registry.py | 2598 ++++++++++++++----- unilabos/registry/utils.py | 699 +++++ unilabos/resources/graphio.py | 20 +- unilabos/resources/resource_tracker.py | 58 +- unilabos/ros/msgs/message_converter.py | 42 +- unilabos/ros/nodes/base_device_node.py | 127 +- unilabos/ros/nodes/presets/camera.py | 7 + unilabos/ros/nodes/presets/host_node.py | 238 +- unilabos/ros/nodes/presets/workstation.py | 99 +- unilabos/utils/decorator.py | 200 +- unilabos/utils/environment_check.py | 3 +- unilabos/utils/import_manager.py | 14 +- unilabos/utils/log.py | 1 - unilabos/utils/requirements.txt | 3 +- 31 files changed, 5453 insertions(+), 1180 deletions(-) create mode 100644 unilabos/registry/ast_registry_scanner.py create mode 100644 unilabos/registry/decorators.py create mode 100644 unilabos/registry/utils.py diff --git a/.github/workflows/ci-check.yml b/.github/workflows/ci-check.yml index 57245d94..402edc26 100644 --- a/.github/workflows/ci-check.yml +++ b/.github/workflows/ci-check.yml @@ -49,7 +49,7 @@ jobs: uv pip uninstall enum34 || echo enum34 not installed, skipping uv pip install . 
- - name: Run check mode (complete_registry) + - name: Run check mode (AST registry validation) run: | call conda activate check-env echo Running check mode... diff --git a/.gitignore b/.gitignore index 838331e3..12b344d6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ output/ unilabos_data/ pyrightconfig.json .cursorignore +device_package*/ ## Python # Byte-compiled / optimized / DLL files diff --git a/docs/developer_guide/add_device.md b/docs/developer_guide/add_device.md index dc95274f..15ba4e08 100644 --- a/docs/developer_guide/add_device.md +++ b/docs/developer_guide/add_device.md @@ -15,6 +15,9 @@ Python 类设备驱动在完成注册表后可以直接在 Uni-Lab 中使用, **示例:** ```python +from unilabos.registry.decorators import device, topic_config + +@device(id="mock_gripper", category=["gripper"], description="Mock Gripper") class MockGripper: def __init__(self): self._position: float = 0.0 @@ -23,19 +26,23 @@ class MockGripper: self._status = "Idle" @property + @topic_config() # 添加 @topic_config 才会定时广播 def position(self) -> float: return self._position @property + @topic_config() def velocity(self) -> float: return self._velocity @property + @topic_config() def torque(self) -> float: return self._torque - # 会被自动识别的设备属性,接入 Uni-Lab 时会定时对外广播 + # 使用 @topic_config 装饰的属性,接入 Uni-Lab 时会定时对外广播 @property + @topic_config(period=2.0) # 可自定义发布周期 def status(self) -> str: return self._status @@ -149,7 +156,7 @@ my_device: # 设备唯一标识符 系统会自动分析您的 Python 驱动类并生成: -- `status_types`:从 `@property` 装饰的方法自动识别状态属性 +- `status_types`:从 `@topic_config` 装饰的 `@property` 或方法自动识别状态属性 - `action_value_mappings`:从类方法自动生成动作映射 - `init_param_schema`:从 `__init__` 方法分析初始化参数 - `schema`:前端显示用的属性类型定义 @@ -179,7 +186,9 @@ Uni-Lab 设备驱动是一个 Python 类,需要遵循以下结构: ```python from typing import Dict, Any +from unilabos.registry.decorators import device, topic_config +@device(id="my_device", category=["general"], description="My Device") class MyDevice: """设备类文档字符串 @@ -198,8 +207,9 @@ class MyDevice: # 初始化硬件连接 @property + @topic_config() # 
必须添加 @topic_config 才会广播 def status(self) -> str: - """设备状态(会自动广播)""" + """设备状态(通过 @topic_config 广播)""" return self._status def my_action(self, param: float) -> Dict[str, Any]: @@ -217,34 +227,61 @@ class MyDevice: ## 状态属性 vs 动作方法 -### 状态属性(@property) +### 状态属性(@property + @topic_config) -状态属性会被自动识别并定期广播: +状态属性需要同时使用 `@property` 和 `@topic_config` 装饰器才会被识别并定期广播: ```python +from unilabos.registry.decorators import topic_config + @property +@topic_config() # 必须添加,否则不会广播 def temperature(self) -> float: """当前温度""" return self._read_temperature() @property +@topic_config(period=2.0) # 可自定义发布周期(秒) def status(self) -> str: """设备状态: idle, running, error""" return self._status @property +@topic_config(name="ready") # 可自定义发布名称 def is_ready(self) -> bool: """设备是否就绪""" return self._status == "idle" ``` +也可以使用普通方法(非 @property)配合 `@topic_config`: + +```python +@topic_config(period=10.0) +def get_sensor_data(self) -> Dict[str, float]: + """获取传感器数据(get_ 前缀会自动去除,发布名为 sensor_data)""" + return {"temp": self._temp, "humidity": self._humidity} +``` + +**`@topic_config` 参数**: + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `period` | float | 5.0 | 发布周期(秒) | +| `print_publish` | bool | 节点默认 | 是否打印发布日志 | +| `qos` | int | 10 | QoS 深度 | +| `name` | str | None | 自定义发布名称 | + +**发布名称优先级**:`@topic_config(name=...)` > `get_` 前缀去除 > 方法名 + **特点**: -- 使用`@property`装饰器 -- 只读,不能有参数 -- 自动添加到注册表的`status_types` +- 必须使用 `@topic_config` 装饰器 +- 支持 `@property` 和普通方法 +- 添加到注册表的 `status_types` - 定期发布到 ROS2 topic +> **⚠️ 重要:** 仅有 `@property` 装饰器而没有 `@topic_config` 的属性**不会**被广播。这是一个 Breaking Change。 + ### 动作方法 动作方法是设备可以执行的操作: @@ -497,6 +534,7 @@ class LiquidHandler: self._status = "idle" @property + @topic_config() def status(self) -> str: return self._status @@ -886,7 +924,52 @@ class MyDevice: ## 最佳实践 -### 1. 类型注解 +### 1. 
使用 `@device` 装饰器标识设备类 + +```python +from unilabos.registry.decorators import device + +@device(id="my_device", category=["heating"], description="My Heating Device", icon="heater.webp") +class MyDevice: + ... +``` + +- `id`:设备唯一标识符,用于注册表匹配 +- `category`:分类列表,前端用于分组显示 +- `description`:设备描述 +- `icon`:图标文件名(可选) + +### 2. 使用 `@topic_config` 声明需要广播的状态 + +```python +from unilabos.registry.decorators import topic_config + +# ✓ @property + @topic_config → 会广播 +@property +@topic_config(period=2.0) +def temperature(self) -> float: + return self._temp + +# ✓ 普通方法 + @topic_config → 会广播(get_ 前缀自动去除) +@topic_config(period=10.0) +def get_sensor_data(self) -> Dict[str, float]: + return {"temp": self._temp} + +# ✓ 使用 name 参数自定义发布名称 +@property +@topic_config(name="ready") +def is_ready(self) -> bool: + return self._status == "idle" + +# ✗ 仅有 @property,没有 @topic_config → 不会广播 +@property +def internal_state(self) -> str: + return self._state +``` + +> **注意:** 与 `@property` 连用时,`@topic_config` 必须放在 `@property` 下面。 + +### 3. 类型注解 ```python from typing import Dict, Any, Optional, List @@ -901,7 +984,7 @@ def method( pass ``` -### 2. 文档字符串 +### 4. 文档字符串 ```python def method(self, param: float) -> Dict[str, Any]: @@ -923,7 +1006,7 @@ def method(self, param: float) -> Dict[str, Any]: pass ``` -### 3. 配置验证 +### 5. 配置验证 ```python def __init__(self, config: Dict[str, Any]): @@ -937,7 +1020,7 @@ def __init__(self, config: Dict[str, Any]): self.baudrate = config['baudrate'] ``` -### 4. 资源清理 +### 6. 资源清理 ```python def __del__(self): @@ -946,7 +1029,7 @@ def __del__(self): self.connection.close() ``` -### 5. 设计前端友好的返回值 +### 7. 
设计前端友好的返回值 **记住:返回值会直接显示在 Web 界面** diff --git a/docs/developer_guide/add_registry.md b/docs/developer_guide/add_registry.md index 36caa943..38d3f893 100644 --- a/docs/developer_guide/add_registry.md +++ b/docs/developer_guide/add_registry.md @@ -422,18 +422,20 @@ placeholder_keys: ### status_types -系统会扫描你的 Python 类,从状态方法(property 或 get\_方法)自动生成这部分: +系统会扫描你的 Python 类,从带有 `@topic_config` 装饰器的 `@property` 或方法自动生成这部分: ```yaml status_types: - current_temperature: float # 从 get_current_temperature() 或 @property current_temperature - is_heating: bool # 从 get_is_heating() 或 @property is_heating - status: str # 从 get_status() 或 @property status + current_temperature: float # 从 @topic_config 装饰的 @property 或方法 + is_heating: bool + status: str ``` **注意事项**: -- 系统会查找所有 `get_` 开头的方法和 `@property` 装饰的属性 +- 仅有带 `@topic_config` 装饰器的 `@property` 或方法才会被识别为状态属性 +- 没有 `@topic_config` 的 `@property` 不会生成 status_types,也不会广播 +- `get_` 前缀的方法名会自动去除前缀(如 `get_temperature` → `temperature`) - 类型会自动转成相应的类型(如 `str`、`float`、`bool`) - 如果类型是 `Any`、`None` 或未知的,默认使用 `String` @@ -537,11 +539,13 @@ class AdvancedLiquidHandler: self._temperature = 25.0 @property + @topic_config() def status(self) -> str: """设备状态""" return self._status @property + @topic_config() def temperature(self) -> float: """当前温度""" return self._temperature @@ -809,21 +813,23 @@ my_temperature_controller: 你的设备类需要符合以下要求: ```python -from unilabos.common.device_base import DeviceBase +from unilabos.registry.decorators import device, topic_config -class MyDevice(DeviceBase): +@device(id="my_device", category=["temperature"], description="My Device") +class MyDevice: def __init__(self, config): """初始化,参数会自动分析到 init_param_schema.config""" - super().__init__(config) self.port = config.get('port', '/dev/ttyUSB0') - # 状态方法(会自动生成到 status_types) + # 状态方法(必须添加 @topic_config 才会生成到 status_types 并广播) @property + @topic_config() def status(self): """返回设备状态""" return "idle" @property + @topic_config() def temperature(self): """返回当前温度""" return 25.0 @@ 
-1039,7 +1045,34 @@ resource.type # "resource" ### 代码规范 -1. **始终使用类型注解** +1. **使用 `@device` 装饰器标识设备类** + +```python +from unilabos.registry.decorators import device + +@device(id="my_device", category=["heating"], description="My Device") +class MyDevice: + ... +``` + +2. **使用 `@topic_config` 声明广播属性** + +```python +from unilabos.registry.decorators import topic_config + +# ✓ 需要广播的状态属性 +@property +@topic_config(period=2.0) +def temperature(self) -> float: + return self._temp + +# ✗ 仅有 @property 不会广播 +@property +def internal_counter(self) -> int: + return self._counter +``` + +3. **始终使用类型注解** ```python # ✓ 好 @@ -1051,7 +1084,7 @@ def method(self, resource, device): pass ``` -2. **提供有意义的参数名** +4. **提供有意义的参数名** ```python # ✓ 好 - 清晰的参数名 @@ -1063,7 +1096,7 @@ def transfer(self, r1: ResourceSlot, r2: ResourceSlot): pass ``` -3. **使用 Optional 表示可选参数** +5. **使用 Optional 表示可选参数** ```python from typing import Optional @@ -1076,7 +1109,7 @@ def method( pass ``` -4. **添加详细的文档字符串** +6. **添加详细的文档字符串** ```python def method( @@ -1096,13 +1129,13 @@ def method( pass ``` -5. **方法命名规范** +7. **方法命名规范** - - 状态方法使用 `@property` 装饰器或 `get_` 前缀 + - 状态方法使用 `@property` + `@topic_config` 装饰器,或普通方法 + `@topic_config` - 动作方法使用动词开头 - 保持命名清晰、一致 -6. **完善的错误处理** +8. 
**完善的错误处理** - 实现完善的错误处理 - 添加日志记录 - 提供有意义的错误信息 diff --git a/docs/developer_guide/networking_overview.md b/docs/developer_guide/networking_overview.md index 40b308d3..19f16312 100644 --- a/docs/developer_guide/networking_overview.md +++ b/docs/developer_guide/networking_overview.md @@ -221,10 +221,10 @@ Laboratory A Laboratory B ```bash # 实验室A -unilab --ak your_ak --sk your_sk --upload_registry --use_remote_resource +unilab --ak your_ak --sk your_sk --upload_registry # 实验室B -unilab --ak your_ak --sk your_sk --upload_registry --use_remote_resource +unilab --ak your_ak --sk your_sk --upload_registry ``` --- diff --git a/docs/user_guide/launch.md b/docs/user_guide/launch.md index 402e39aa..34caa5b9 100644 --- a/docs/user_guide/launch.md +++ b/docs/user_guide/launch.md @@ -22,7 +22,6 @@ options: --is_slave Run the backend as slave node (without host privileges). --slave_no_host Skip waiting for host service in slave mode --upload_registry Upload registry information when starting unilab - --use_remote_resource Use remote resources when starting unilab --config CONFIG Configuration file path, supports .py format Python config files --port PORT Port for web service information page --disable_browser Disable opening information page on startup @@ -85,7 +84,7 @@ Uni-Lab 的启动过程分为以下几个阶段: 支持两种方式: - **本地文件**:使用 `-g` 指定图谱文件(支持 JSON 和 GraphML 格式) -- **远程资源**:使用 `--use_remote_resource` 从云端获取 +- **远程资源**:不指定本地文件即可 ### 7. 
注册表构建 @@ -196,7 +195,7 @@ unilab --config path/to/your/config.py unilab --ak your_ak --sk your_sk -g path/to/graph.json --upload_registry # 使用远程资源启动 -unilab --ak your_ak --sk your_sk --use_remote_resource +unilab --ak your_ak --sk your_sk # 更新注册表 unilab --ak your_ak --sk your_sk --complete_registry diff --git a/unilabos/app/main.py b/unilabos/app/main.py index 93751262..6b507c6e 100644 --- a/unilabos/app/main.py +++ b/unilabos/app/main.py @@ -4,6 +4,7 @@ import os import platform import shutil import signal +import subprocess import sys import threading import time @@ -25,6 +26,84 @@ from unilabos.config.config import load_config, BasicConfig, HTTPConfig _restart_requested: bool = False _restart_reason: str = "" +RESTART_EXIT_CODE = 42 + + +def _build_child_argv(): + """Build sys.argv for child process, stripping supervisor-only arguments.""" + result = [] + skip_next = False + for arg in sys.argv: + if skip_next: + skip_next = False + continue + if arg in ("--restart_mode", "--restart-mode"): + continue + if arg in ("--auto_restart_count", "--auto-restart-count"): + skip_next = True + continue + if arg.startswith("--auto_restart_count=") or arg.startswith("--auto-restart-count="): + continue + result.append(arg) + return result + + +def _run_as_supervisor(max_restarts: int): + """ + Supervisor process that spawns and monitors child processes. + + Similar to Uvicorn's --reload: the supervisor itself does no heavy work, + it only launches the real process as a child and restarts it when the child + exits with RESTART_EXIT_CODE. 
+ """ + child_argv = [sys.executable] + _build_child_argv() + restart_count = 0 + + print_status( + f"[Supervisor] Restart mode enabled (max restarts: {max_restarts}), " + f"child command: {' '.join(child_argv)}", + "info", + ) + + while True: + print_status( + f"[Supervisor] Launching process (restart {restart_count}/{max_restarts})...", + "info", + ) + + try: + process = subprocess.Popen(child_argv) + exit_code = process.wait() + except KeyboardInterrupt: + print_status("[Supervisor] Interrupted, terminating child process...", "info") + process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + sys.exit(1) + + if exit_code == RESTART_EXIT_CODE: + restart_count += 1 + if restart_count > max_restarts: + print_status( + f"[Supervisor] Maximum restart count ({max_restarts}) reached, exiting", + "warning", + ) + sys.exit(1) + print_status( + f"[Supervisor] Child requested restart ({restart_count}/{max_restarts}), restarting in 2s...", + "info", + ) + time.sleep(2) + else: + if exit_code != 0: + print_status(f"[Supervisor] Child exited with code {exit_code}", "warning") + else: + print_status("[Supervisor] Child exited normally", "info") + sys.exit(exit_code) + def load_config_from_file(config_path): if config_path is None: @@ -66,6 +145,13 @@ def parse_args(): action="append", help="Path to the registry directory", ) + parser.add_argument( + "--devices", + type=str, + default=None, + action="append", + help="Path to Python code directory for AST-based device/resource scanning", + ) parser.add_argument( "--working_dir", type=str, @@ -155,12 +241,6 @@ def parse_args(): action="store_true", help="Skip environment dependency check on startup", ) - parser.add_argument( - "--complete_registry", - action="store_true", - default=False, - help="Complete registry information", - ) parser.add_argument( "--check_mode", action="store_true", @@ -178,6 +258,24 @@ def parse_args(): default=False, help="Test mode: all 
actions simulate execution and return mock results without running real hardware", ) + parser.add_argument( + "--extra_resource", + action="store_true", + default=False, + help="Load extra lab_ prefixed labware resources (529 auto-generated definitions from lab_resources.py)", + ) + parser.add_argument( + "--restart_mode", + action="store_true", + default=False, + help="Enable supervisor mode: automatically restart the process when triggered via WebSocket", + ) + parser.add_argument( + "--auto_restart_count", + type=int, + default=500, + help="Maximum number of automatic restarts in restart mode (default: 500)", + ) # workflow upload subcommand workflow_parser = subparsers.add_parser( "workflow_upload", @@ -228,6 +326,11 @@ def main(): args = parser.parse_args() args_dict = vars(args) + # Supervisor mode: spawn child processes and monitor for restart + if args_dict.get("restart_mode", False): + _run_as_supervisor(args_dict.get("auto_restart_count", 5)) + return + # 环境检查 - 检查并自动安装必需的包 (可选) skip_env_check = args_dict.get("skip_env_check", False) check_mode = args_dict.get("check_mode", False) @@ -358,6 +461,9 @@ def main(): BasicConfig.test_mode = args_dict.get("test_mode", False) if BasicConfig.test_mode: print_status("启用测试模式:所有动作将模拟执行,不调用真实硬件", "warning") + BasicConfig.extra_resource = args_dict.get("extra_resource", False) + if BasicConfig.extra_resource: + print_status("启用额外资源加载:将加载lab_开头的labware资源定义", "info") BasicConfig.communication_protocol = "websocket" machine_name = platform.node() machine_name = "".join([c if c.isalnum() or c == "_" else "_" for c in machine_name]) @@ -382,22 +488,30 @@ def main(): # 显示启动横幅 print_unilab_banner(args_dict) - # 注册表 - check_mode 时强制启用 complete_registry - complete_registry = args_dict.get("complete_registry", False) or check_mode - lab_registry = build_registry(args_dict["registry_path"], complete_registry, BasicConfig.upload_registry) + # Step 0: AST 分析优先 + YAML 注册表加载 + # check_mode 和 upload_registry 都会执行实际 import 验证 + 
devices_dirs = args_dict.get("devices", None) + lab_registry = build_registry( + registry_paths=args_dict["registry_path"], + devices_dirs=devices_dirs, + upload_registry=BasicConfig.upload_registry, + check_mode=check_mode, + ) - # Check mode: complete_registry 完成后直接退出,git diff 检测由 CI workflow 执行 + # Check mode: 注册表验证完成后直接退出 if check_mode: - print_status("Check mode: complete_registry 完成,退出", "info") + device_count = len(lab_registry.device_type_registry) + resource_count = len(lab_registry.resource_type_registry) + print_status(f"Check mode: 注册表验证完成 ({device_count} 设备, {resource_count} 资源),退出", "info") os._exit(0) + # Step 1: 上传全部注册表到服务端,同步保存到 unilabos_data if BasicConfig.upload_registry: - # 设备注册到服务端 - 需要 ak 和 sk if BasicConfig.ak and BasicConfig.sk: - print_status("开始注册设备到服务端...", "info") + # print_status("开始注册设备到服务端...", "info") try: register_devices_and_resources(lab_registry) - print_status("设备注册完成", "info") + # print_status("设备注册完成", "info") except Exception as e: print_status(f"设备注册失败: {e}", "error") else: @@ -482,7 +596,7 @@ def main(): continue # 如果从远端获取了物料信息,则与本地物料进行同步 - if request_startup_json and "nodes" in request_startup_json: + if file_path is not None and request_startup_json and "nodes" in request_startup_json: print_status("开始同步远端物料到本地...", "info") remote_tree_set = ResourceTreeSet.from_raw_dict_list(request_startup_json["nodes"]) resource_tree_set.merge_remote_resources(remote_tree_set) @@ -579,6 +693,10 @@ def main(): open_browser=not args_dict["disable_browser"], port=BasicConfig.port, ) + if restart_requested: + print_status("[Main] Restart requested, cleaning up...", "info") + cleanup_for_restart() + os._exit(RESTART_EXIT_CODE) if __name__ == "__main__": diff --git a/unilabos/app/register.py b/unilabos/app/register.py index 5918b43a..69355da9 100644 --- a/unilabos/app/register.py +++ b/unilabos/app/register.py @@ -1,60 +1,83 @@ import json import time -from typing import Optional, Tuple, Dict, Any +from typing import Any, Dict, Optional, 
Tuple from unilabos.utils.log import logger from unilabos.utils.type_check import TypeEncoder +try: + import orjson + + def _normalize_device(info: dict) -> dict: + """Serialize via orjson to strip non-JSON types (type objects etc.).""" + return orjson.loads(orjson.dumps(info, default=str)) +except ImportError: + def _normalize_device(info: dict) -> dict: + return json.loads(json.dumps(info, ensure_ascii=False, cls=TypeEncoder)) + def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]: """ 注册设备和资源到服务器(仅支持HTTP) """ - # 注册资源信息 - 使用HTTP方式 from unilabos.app.web.client import http_client logger.info("[UniLab Register] 开始注册设备和资源...") - # 注册设备信息 devices_to_register = {} for device_info in lab_registry.obtain_registry_device_info(): - devices_to_register[device_info["id"]] = json.loads( - json.dumps(device_info, ensure_ascii=False, cls=TypeEncoder) - ) - logger.debug(f"[UniLab Register] 收集设备: {device_info['id']}") + devices_to_register[device_info["id"]] = _normalize_device(device_info) + logger.trace(f"[UniLab Register] 收集设备: {device_info['id']}") resources_to_register = {} for resource_info in lab_registry.obtain_registry_resource_info(): resources_to_register[resource_info["id"]] = resource_info - logger.debug(f"[UniLab Register] 收集资源: {resource_info['id']}") + logger.trace(f"[UniLab Register] 收集资源: {resource_info['id']}") if gather_only: return devices_to_register, resources_to_register - # 注册设备 + if devices_to_register: try: start_time = time.time() - response = http_client.resource_registry({"resources": list(devices_to_register.values())}) + response = http_client.resource_registry( + {"resources": list(devices_to_register.values())}, + tag="device_registry", + ) cost_time = time.time() - start_time - if response.status_code in [200, 201]: - logger.info(f"[UniLab Register] 成功注册 {len(devices_to_register)} 个设备 {cost_time}s") + res_data = response.json() if response.status_code == 200 else {} + skipped = 
res_data.get("data", {}).get("skipped", False) + if skipped: + logger.info( + f"[UniLab Register] 设备注册跳过(内容未变化)" + f" {len(devices_to_register)} 个 {cost_time:.3f}s" + ) + elif response.status_code in [200, 201]: + logger.info(f"[UniLab Register] 成功注册 {len(devices_to_register)} 个设备 {cost_time:.3f}s") else: - logger.error(f"[UniLab Register] 设备注册失败: {response.status_code}, {response.text} {cost_time}s") + logger.error(f"[UniLab Register] 设备注册失败: {response.status_code}, {response.text} {cost_time:.3f}s") except Exception as e: logger.error(f"[UniLab Register] 设备注册异常: {e}") - # 注册资源 if resources_to_register: try: start_time = time.time() - response = http_client.resource_registry({"resources": list(resources_to_register.values())}) + response = http_client.resource_registry( + {"resources": list(resources_to_register.values())}, + tag="resource_registry", + ) cost_time = time.time() - start_time - if response.status_code in [200, 201]: - logger.info(f"[UniLab Register] 成功注册 {len(resources_to_register)} 个资源 {cost_time}s") + res_data = response.json() if response.status_code == 200 else {} + skipped = res_data.get("data", {}).get("skipped", False) + if skipped: + logger.info( + f"[UniLab Register] 资源注册跳过(内容未变化)" + f" {len(resources_to_register)} 个 {cost_time:.3f}s" + ) + elif response.status_code in [200, 201]: + logger.info(f"[UniLab Register] 成功注册 {len(resources_to_register)} 个资源 {cost_time:.3f}s") else: - logger.error(f"[UniLab Register] 资源注册失败: {response.status_code}, {response.text} {cost_time}s") + logger.error(f"[UniLab Register] 资源注册失败: {response.status_code}, {response.text} {cost_time:.3f}s") except Exception as e: logger.error(f"[UniLab Register] 资源注册异常: {e}") - - logger.info("[UniLab Register] 设备和资源注册完成.") diff --git a/unilabos/app/web/api.py b/unilabos/app/web/api.py index a67d09d2..99981f77 100644 --- a/unilabos/app/web/api.py +++ b/unilabos/app/web/api.py @@ -1052,7 +1052,7 @@ async def handle_file_import(websocket: WebSocket, request_data: dict): 
"result": {}, "schema": lab_registry._generate_unilab_json_command_schema(v["args"], k), "goal_default": {i["name"]: i["default"] for i in v["args"]}, - "handles": [], + "handles": {}, } # 不生成已配置action的动作 for k, v in enhanced_info["action_methods"].items() diff --git a/unilabos/app/web/client.py b/unilabos/app/web/client.py index 75b9e343..41e32514 100644 --- a/unilabos/app/web/client.py +++ b/unilabos/app/web/client.py @@ -8,6 +8,25 @@ import json import os from typing import List, Dict, Any, Optional +try: + import orjson as _json_fast + + def _fast_dumps(obj, **kwargs) -> bytes: + return _json_fast.dumps(obj, option=_json_fast.OPT_NON_STR_KEYS, default=str) + + def _fast_dumps_pretty(obj, **kwargs) -> bytes: + return _json_fast.dumps( + obj, option=_json_fast.OPT_NON_STR_KEYS | _json_fast.OPT_INDENT_2, default=str, + ) +except ImportError: + _json_fast = None # type: ignore[assignment] + + def _fast_dumps(obj, **kwargs) -> bytes: + return json.dumps(obj, ensure_ascii=False, default=str).encode("utf-8") + + def _fast_dumps_pretty(obj, **kwargs) -> bytes: + return json.dumps(obj, indent=2, ensure_ascii=False, default=str).encode("utf-8") + import requests from unilabos.resources.resource_tracker import ResourceTreeSet from unilabos.utils.log import info @@ -280,29 +299,54 @@ class HTTPClient: ) return response - def resource_registry(self, registry_data: Dict[str, Any] | List[Dict[str, Any]]) -> requests.Response: + def resource_registry( + self, registry_data: Dict[str, Any] | List[Dict[str, Any]], tag: str = "registry", + ) -> requests.Response: """ - 注册资源到服务器 + 注册资源到服务器,同步保存请求/响应到 unilabos_data Args: registry_data: 注册表数据,格式为 {resource_id: resource_info} / [{resource_info}] + tag: 保存文件的标签后缀 (如 "device_registry" / "resource_registry") Returns: Response: API响应对象 """ - compressed_body = gzip.compress( - json.dumps(registry_data, ensure_ascii=False, default=str).encode("utf-8") - ) + # 序列化一次,同时用于保存和发送 + json_bytes = _fast_dumps(registry_data) + + # 保存请求数据到 
unilabos_data + req_path = os.path.join(BasicConfig.working_dir, f"req_{tag}_upload.json") + try: + os.makedirs(BasicConfig.working_dir, exist_ok=True) + with open(req_path, "wb") as f: + f.write(_fast_dumps_pretty(registry_data)) + logger.trace(f"注册表请求数据已保存: {req_path}") + except Exception as e: + logger.warning(f"保存注册表请求数据失败: {e}") + + compressed_body = gzip.compress(json_bytes) + headers = { + "Authorization": f"Lab {self.auth}", + "Content-Type": "application/json", + "Content-Encoding": "gzip", + } response = requests.post( f"{self.remote_addr}/lab/resource", data=compressed_body, - headers={ - "Authorization": f"Lab {self.auth}", - "Content-Type": "application/json", - "Content-Encoding": "gzip", - }, + headers=headers, timeout=30, ) + + # 保存响应数据到 unilabos_data + res_path = os.path.join(BasicConfig.working_dir, f"res_{tag}_upload.json") + try: + with open(res_path, "w", encoding="utf-8") as f: + f.write(f"{response.status_code}\n{response.text}") + logger.trace(f"注册表响应数据已保存: {res_path}") + except Exception as e: + logger.warning(f"保存注册表响应数据失败: {e}") + if response.status_code not in [200, 201]: logger.error(f"注册资源失败: {response.status_code}, {response.text}") if response.status_code == 200: diff --git a/unilabos/app/web/server.py b/unilabos/app/web/server.py index 8d090162..981edeca 100644 --- a/unilabos/app/web/server.py +++ b/unilabos/app/web/server.py @@ -86,7 +86,7 @@ def setup_server() -> FastAPI: # 设置页面路由 try: setup_web_pages(pages) - info("[Web] 已加载Web UI模块") + # info("[Web] 已加载Web UI模块") except ImportError as e: info(f"[Web] 未找到Web页面模块: {str(e)}") except Exception as e: @@ -138,7 +138,7 @@ def start_server(host: str = "0.0.0.0", port: int = 8002, open_browser: bool = T server_thread = threading.Thread(target=server.run, daemon=True, name="uvicorn_server") server_thread.start() - info("[Web] Server started, monitoring for restart requests...") + # info("[Web] Server started, monitoring for restart requests...") # 监控重启标志 import unilabos.app.main as 
main_module diff --git a/unilabos/app/ws_client.py b/unilabos/app/ws_client.py index faaa3075..cbbb58ef 100644 --- a/unilabos/app/ws_client.py +++ b/unilabos/app/ws_client.py @@ -26,6 +26,7 @@ from enum import Enum from typing_extensions import TypedDict from unilabos.app.model import JobAddReq +from unilabos.resources.resource_tracker import ResourceDictType from unilabos.ros.nodes.presets.host_node import HostNode from unilabos.utils.type_check import serialize_result_info from unilabos.app.communication import BaseCommunicationClient @@ -408,6 +409,7 @@ class MessageProcessor: # 线程控制 self.is_running = False self.thread = None + self._loop = None # asyncio event loop引用,用于外部关闭websocket self.reconnect_count = 0 logger.info(f"[MessageProcessor] Initialized for URL: {websocket_url}") @@ -434,22 +436,31 @@ class MessageProcessor: def stop(self) -> None: """停止消息处理线程""" self.is_running = False + # 主动关闭websocket以快速中断消息接收循环 + ws = self.websocket + loop = self._loop + if ws and loop and loop.is_running(): + try: + asyncio.run_coroutine_threadsafe(ws.close(), loop) + except Exception: + pass if self.thread and self.thread.is_alive(): self.thread.join(timeout=2) logger.info("[MessageProcessor] Stopped") def _run(self): """运行消息处理主循环""" - loop = asyncio.new_event_loop() + self._loop = asyncio.new_event_loop() try: - asyncio.set_event_loop(loop) - loop.run_until_complete(self._connection_handler()) + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self._connection_handler()) except Exception as e: logger.error(f"[MessageProcessor] Thread error: {str(e)}") logger.error(traceback.format_exc()) finally: - if loop: - loop.close() + if self._loop: + self._loop.close() + self._loop = None async def _connection_handler(self): """处理WebSocket连接和重连逻辑""" @@ -648,6 +659,10 @@ class MessageProcessor: # elif message_type == "session_id": # self.session_id = message_data.get("session_id") # logger.info(f"[MessageProcessor] Session ID: {self.session_id}") + elif message_type 
== "add_device": + await self._handle_device_manage(message_data, "add") + elif message_type == "remove_device": + await self._handle_device_manage(message_data, "remove") elif message_type == "request_restart": await self._handle_request_restart(message_data) else: @@ -984,6 +999,37 @@ class MessageProcessor: ) thread.start() + async def _handle_device_manage(self, device_list: list[ResourceDictType], action: str): + """Handle add_device / remove_device from LabGo server.""" + if not device_list: + return + + for item in device_list: + target_node_id = item.get("target_node_id", "host_node") + + def _notify(target_id: str, act: str, cfg: ResourceDictType): + try: + host_node = HostNode.get_instance(timeout=5) + if not host_node: + logger.error(f"[DeviceManage] HostNode not available for {act}_device") + return + success = host_node.notify_device_manage(target_id, act, cfg) + if success: + logger.info(f"[DeviceManage] {act}_device completed on {target_id}") + else: + logger.warning(f"[DeviceManage] {act}_device failed on {target_id}") + except Exception as e: + logger.error(f"[DeviceManage] Error in {act}_device: {e}") + logger.error(traceback.format_exc()) + + thread = threading.Thread( + target=_notify, + args=(target_node_id, action, item), + daemon=True, + name=f"DeviceManage-{action}-{item.get('id', '')}", + ) + thread.start() + async def _handle_request_restart(self, data: Dict[str, Any]): """ 处理重启请求 @@ -995,10 +1041,9 @@ class MessageProcessor: logger.info(f"[MessageProcessor] Received restart request, reason: {reason}, delay: {delay}s") # 发送确认消息 - if self.websocket_client: - await self.websocket_client.send_message( - {"action": "restart_acknowledged", "data": {"reason": reason, "delay": delay}} - ) + self.send_message( + {"action": "restart_acknowledged", "data": {"reason": reason, "delay": delay}} + ) # 设置全局重启标志 import unilabos.app.main as main_module @@ -1100,6 +1145,7 @@ class QueueProcessor: def stop(self) -> None: """停止队列处理线程""" self.is_running = 
False + self.queue_update_event.set() # 立即唤醒等待中的线程 if self.thread and self.thread.is_alive(): self.thread.join(timeout=2) logger.info("[QueueProcessor] Stopped") @@ -1353,8 +1399,8 @@ class WebSocketClient(BaseCommunicationClient): message = {"action": "normal_exit", "data": {"session_id": session_id}} self.message_processor.send_message(message) logger.info(f"[WebSocketClient] Sent normal_exit message with session_id: {session_id}") - # 给一点时间让消息发送出去 - time.sleep(1) + # send_handler 每100ms检查一次队列,等300ms足以让消息发出 + time.sleep(0.3) except Exception as e: logger.warning(f"[WebSocketClient] Failed to send normal_exit message: {str(e)}") diff --git a/unilabos/config/config.py b/unilabos/config/config.py index d66b399d..b80d3b60 100644 --- a/unilabos/config/config.py +++ b/unilabos/config/config.py @@ -24,6 +24,7 @@ class BasicConfig: port = 8002 # 本地HTTP服务 check_mode = False # CI 检查模式,用于验证 registry 导入和文件一致性 test_mode = False # 测试模式,所有动作不实际执行,返回模拟结果 + extra_resource = False # 是否加载lab_开头的额外资源 # 'TRACE', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' log_level: Literal["TRACE", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "DEBUG" diff --git a/unilabos/device_comms/universal_driver.py b/unilabos/device_comms/universal_driver.py index 281e0cd9..0ff41805 100644 --- a/unilabos/device_comms/universal_driver.py +++ b/unilabos/device_comms/universal_driver.py @@ -1,4 +1,3 @@ - from abc import abstractmethod from functools import wraps import inspect diff --git a/unilabos/devices/virtual/workbench.py b/unilabos/devices/virtual/workbench.py index f5fae47e..d67db398 100644 --- a/unilabos/devices/virtual/workbench.py +++ b/unilabos/devices/virtual/workbench.py @@ -1,15 +1,15 @@ """ Virtual Workbench Device - 模拟工作台设备 -包含: +包含: - 1个机械臂 (每次操作3s, 独占锁) - 3个加热台 (每次加热10s, 可并行) -工作流程: -1. A1-A5 物料同时启动,竞争机械臂 +工作流程: +1. A1-A5 物料同时启动, 竞争机械臂 2. 机械臂将物料移动到空闲加热台 -3. 加热完成后,机械臂将物料移动到C1-C5 +3. 
加热完成后, 机械臂将物料移动到C1-C5 -注意:调用来自线程池,使用 threading.Lock 进行同步 +注意: 调用来自线程池, 使用 threading.Lock 进行同步 """ import logging @@ -21,9 +21,11 @@ from threading import Lock, RLock from typing_extensions import TypedDict +from unilabos.registry.decorators import ( + device, action, ActionInputHandle, ActionOutputHandle, DataSource, topic_config, not_action +) from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode -from unilabos.utils.decorator import not_action, always_free -from unilabos.resources.resource_tracker import SampleUUIDsType, LabSample, RETURN_UNILABOS_SAMPLES +from unilabos.resources.resource_tracker import SampleUUIDsType, LabSample # ============ TypedDict 返回类型定义 ============ @@ -57,6 +59,8 @@ class MoveToOutputResult(TypedDict): success: bool station_id: int material_id: str + output_position: str + message: str unilabos_samples: List[LabSample] @@ -81,9 +85,9 @@ class HeatingStationState(Enum): """加热台状态枚举""" IDLE = "idle" # 空闲 - OCCUPIED = "occupied" # 已放置物料,等待加热 + OCCUPIED = "occupied" # 已放置物料, 等待加热 HEATING = "heating" # 加热中 - COMPLETED = "completed" # 加热完成,等待取走 + COMPLETED = "completed" # 加热完成, 等待取走 class ArmState(Enum): @@ -105,19 +109,24 @@ class HeatingStation: heating_progress: float = 0.0 +@device( + id="virtual_workbench", + category=["virtual_device"], + description="Virtual Workbench with 1 robotic arm and 3 heating stations for concurrent material processing", +) class VirtualWorkbench: """ Virtual Workbench Device - 虚拟工作台设备 模拟一个包含1个机械臂和3个加热台的工作站 - - 机械臂操作耗时3秒,同一时间只能执行一个操作 - - 加热台加热耗时10秒,3个加热台可并行工作 + - 机械臂操作耗时3秒, 同一时间只能执行一个操作 + - 加热台加热耗时10秒, 3个加热台可并行工作 工作流: - 1. 物料A1-A5并发启动(线程池),竞争机械臂使用权 - 2. 获取机械臂后,查找空闲加热台 - 3. 机械臂将物料放入加热台,开始加热 - 4. 加热完成后,机械臂将物料移动到目标位置Cn + 1. 物料A1-A5并发启动(线程池), 竞争机械臂使用权 + 2. 获取机械臂后, 查找空闲加热台 + 3. 机械臂将物料放入加热台, 开始加热 + 4. 
加热完成后, 机械臂将物料移动到目标位置Cn """ _ros_node: BaseROS2DeviceNode @@ -145,19 +154,19 @@ class VirtualWorkbench: self.HEATING_TIME = float(self.config.get("heating_time", self.HEATING_TIME)) self.NUM_HEATING_STATIONS = int(self.config.get("num_heating_stations", self.NUM_HEATING_STATIONS)) - # 机械臂状态和锁 (使用threading.Lock) + # 机械臂状态和锁 self._arm_lock = Lock() self._arm_state = ArmState.IDLE self._arm_current_task: Optional[str] = None - # 加热台状态 (station_id -> HeatingStation) - 立即初始化,不依赖initialize() + # 加热台状态 self._heating_stations: Dict[int, HeatingStation] = { i: HeatingStation(station_id=i) for i in range(1, self.NUM_HEATING_STATIONS + 1) } - self._stations_lock = RLock() # 可重入锁,保护加热台状态 + self._stations_lock = RLock() # 任务追踪 - self._active_tasks: Dict[str, Dict[str, Any]] = {} # material_id -> task_info + self._active_tasks: Dict[str, Dict[str, Any]] = {} self._tasks_lock = Lock() # 处理其他kwargs参数 @@ -183,7 +192,6 @@ class VirtualWorkbench: """初始化虚拟工作台""" self.logger.info(f"初始化虚拟工作台 {self.device_id}") - # 重置加热台状态 (已在__init__中创建,这里重置为初始状态) with self._stations_lock: for station in self._heating_stations.values(): station.state = HeatingStationState.IDLE @@ -191,7 +199,6 @@ class VirtualWorkbench: station.material_number = None station.heating_progress = 0.0 - # 初始化状态 self.data.update( { "status": "Ready", @@ -257,11 +264,7 @@ class VirtualWorkbench: self.data["message"] = message def _find_available_heating_station(self) -> Optional[int]: - """查找空闲的加热台 - - Returns: - 空闲加热台ID,如果没有则返回None - """ + """查找空闲的加热台""" with self._stations_lock: for station_id, station in self._heating_stations.items(): if station.state == HeatingStationState.IDLE: @@ -269,23 +272,12 @@ class VirtualWorkbench: return None def _acquire_arm(self, task_description: str) -> bool: - """获取机械臂使用权(阻塞直到获取) - - Args: - task_description: 任务描述,用于日志 - - Returns: - 是否成功获取 - """ + """获取机械臂使用权(阻塞直到获取)""" self.logger.info(f"[{task_description}] 等待获取机械臂...") - - # 阻塞等待获取锁 self._arm_lock.acquire() - self._arm_state = 
ArmState.BUSY self._arm_current_task = task_description self._update_data_status(f"机械臂执行: {task_description}") - self.logger.info(f"[{task_description}] 成功获取机械臂使用权") return True @@ -298,6 +290,22 @@ class VirtualWorkbench: self._update_data_status(f"机械臂已释放 (完成: {task})") self.logger.info(f"机械臂已释放 (完成: {task})") + @action( + auto_prefix=True, + description="批量准备物料 - 虚拟起始节点, 生成A1-A5物料, 输出5个handle供后续节点使用", + handles=[ + ActionOutputHandle(key="channel_1", data_type="workbench_material", + label="实验1", data_key="material_1", data_source=DataSource.EXECUTOR), + ActionOutputHandle(key="channel_2", data_type="workbench_material", + label="实验2", data_key="material_2", data_source=DataSource.EXECUTOR), + ActionOutputHandle(key="channel_3", data_type="workbench_material", + label="实验3", data_key="material_3", data_source=DataSource.EXECUTOR), + ActionOutputHandle(key="channel_4", data_type="workbench_material", + label="实验4", data_key="material_4", data_source=DataSource.EXECUTOR), + ActionOutputHandle(key="channel_5", data_type="workbench_material", + label="实验5", data_key="material_5", data_source=DataSource.EXECUTOR), + ], + ) def prepare_materials( self, sample_uuids: SampleUUIDsType, @@ -306,19 +314,14 @@ class VirtualWorkbench: """ 批量准备物料 - 虚拟起始节点 - 作为工作流的起始节点,生成指定数量的物料编号供后续节点使用。 - 输出5个handle (material_1 ~ material_5),分别对应实验1~5。 - - Args: - count: 待生成的物料数量,默认5 (生成 A1-A5) - - Returns: - PrepareMaterialsResult: 包含 material_1 ~ material_5 用于传递给 move_to_heating_station + 作为工作流的起始节点, 生成指定数量的物料编号供后续节点使用。 + 输出5个handle (material_1 ~ material_5), 分别对应实验1~5。 """ - # 生成物料列表 A1 - A{count} materials = [i for i in range(1, count + 1)] - self.logger.info(f"[准备物料] 生成 {count} 个物料: " f"A1-A{count} -> material_1~material_{count}") + self.logger.info( + f"[准备物料] 生成 {count} 个物料: A1-A{count} -> material_1~material_{count}" + ) return { "success": True, @@ -329,9 +332,28 @@ class VirtualWorkbench: "material_4": materials[3] if len(materials) > 3 else 0, "material_5": materials[4] if 
len(materials) > 4 else 0, "message": f"已准备 {count} 个物料: A1-A{count}", - "unilabos_samples": [LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for sample_uuid, content in sample_uuids.items()] + "unilabos_samples": [ + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra={"material_uuid": content} if isinstance(content, str) else (content.serialize() if content else {}), + ) + for sample_uuid, content in sample_uuids.items() + ], } + @action( + auto_prefix=True, + description="将物料从An位置移动到空闲加热台, 返回分配的加热台ID", + handles=[ + ActionInputHandle(key="material_input", data_type="workbench_material", + label="物料编号", data_key="material_number", data_source=DataSource.HANDLE), + ActionOutputHandle(key="heating_station_output", data_type="workbench_station", + label="加热台ID", data_key="station_id", data_source=DataSource.EXECUTOR), + ActionOutputHandle(key="material_number_output", data_type="workbench_material", + label="物料编号", data_key="material_number", data_source=DataSource.EXECUTOR), + ], + ) def move_to_heating_station( self, sample_uuids: SampleUUIDsType, @@ -340,20 +362,12 @@ class VirtualWorkbench: """ 将物料从An位置移动到加热台 - 多线程并发调用时,会竞争机械臂使用权,并自动查找空闲加热台 - - Args: - material_number: 物料编号 (1-5) - - Returns: - MoveToHeatingStationResult: 包含 station_id, material_number 等用于传递给下一个节点 + 多线程并发调用时, 会竞争机械臂使用权, 并自动查找空闲加热台 """ - # 根据物料编号生成物料ID material_id = f"A{material_number}" task_desc = f"移动{material_id}到加热台" self.logger.info(f"[任务] {task_desc} - 开始执行") - # 记录任务 with self._tasks_lock: self._active_tasks[material_id] = { "status": "waiting_for_arm", @@ -361,33 +375,27 @@ class VirtualWorkbench: } try: - # 步骤1: 等待获取机械臂使用权(竞争) with self._tasks_lock: self._active_tasks[material_id]["status"] = "waiting_for_arm" self._acquire_arm(task_desc) - # 步骤2: 查找空闲加热台 with self._tasks_lock: self._active_tasks[material_id]["status"] = "finding_station" station_id = None - # 循环等待直到找到空闲加热台 while station_id 
is None: station_id = self._find_available_heating_station() if station_id is None: - self.logger.info(f"[{material_id}] 没有空闲加热台,等待中...") - # 释放机械臂,等待后重试 + self.logger.info(f"[{material_id}] 没有空闲加热台, 等待中...") self._release_arm() time.sleep(0.5) self._acquire_arm(task_desc) - # 步骤3: 占用加热台 - 立即标记为OCCUPIED,防止其他任务选择同一加热台 with self._stations_lock: self._heating_stations[station_id].state = HeatingStationState.OCCUPIED self._heating_stations[station_id].current_material = material_id self._heating_stations[station_id].material_number = material_number - # 步骤4: 模拟机械臂移动操作 (3秒) with self._tasks_lock: self._active_tasks[material_id]["status"] = "arm_moving" self._active_tasks[material_id]["assigned_station"] = station_id @@ -395,11 +403,11 @@ class VirtualWorkbench: time.sleep(self.ARM_OPERATION_TIME) - # 步骤5: 放入加热台完成 self._update_data_status(f"{material_id}已放入加热台{station_id}") - self.logger.info(f"[{material_id}] 已放入加热台{station_id} (用时{self.ARM_OPERATION_TIME}s)") + self.logger.info( + f"[{material_id}] 已放入加热台{station_id} (用时{self.ARM_OPERATION_TIME}s)" + ) - # 释放机械臂 self._release_arm() with self._tasks_lock: @@ -412,8 +420,16 @@ class VirtualWorkbench: "material_number": material_number, "message": f"{material_id}已成功移动到加热台{station_id}", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } except Exception as e: @@ -427,11 +443,33 @@ class VirtualWorkbench: "material_number": material_number, "message": f"移动失败: {str(e)}", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - 
sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } - @always_free + @action( + auto_prefix=True, + always_free=True, + description="启动指定加热台的加热程序", + handles=[ + ActionInputHandle(key="station_id_input", data_type="workbench_station", + label="加热台ID", data_key="station_id", data_source=DataSource.HANDLE), + ActionInputHandle(key="material_number_input", data_type="workbench_material", + label="物料编号", data_key="material_number", data_source=DataSource.HANDLE), + ActionOutputHandle(key="heating_done_station", data_type="workbench_station", + label="加热完成-加热台ID", data_key="station_id", data_source=DataSource.EXECUTOR), + ActionOutputHandle(key="heating_done_material", data_type="workbench_material", + label="加热完成-物料编号", data_key="material_number", data_source=DataSource.EXECUTOR), + ], + ) def start_heating( self, sample_uuids: SampleUUIDsType, @@ -440,13 +478,6 @@ class VirtualWorkbench: ) -> StartHeatingResult: """ 启动指定加热台的加热程序 - - Args: - station_id: 加热台ID (1-3),从 move_to_heating_station 的 handle 传入 - material_number: 物料编号,从 move_to_heating_station 的 handle 传入 - - Returns: - StartHeatingResult: 包含 station_id, material_number 等用于传递给下一个节点 """ self.logger.info(f"[加热台{station_id}] 开始加热") @@ -458,8 +489,16 @@ class VirtualWorkbench: "material_number": material_number, "message": f"无效的加热台ID: {station_id}", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } 
with self._stations_lock: @@ -473,8 +512,16 @@ class VirtualWorkbench: "material_number": material_number, "message": f"加热台{station_id}上没有物料", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } if station.state == HeatingStationState.HEATING: @@ -485,13 +532,20 @@ class VirtualWorkbench: "material_number": material_number, "message": f"加热台{station_id}已经在加热中", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } material_id = station.current_material - # 开始加热 station.state = HeatingStationState.HEATING station.heating_start_time = time.time() station.heating_progress = 0.0 @@ -502,7 +556,6 @@ class VirtualWorkbench: self._update_data_status(f"加热台{station_id}开始加热{material_id}") - # 打印当前所有正在加热的台位 with self._stations_lock: heating_list = [ f"加热台{sid}:{s.current_material}" @@ -511,7 +564,6 @@ class VirtualWorkbench: ] self.logger.info(f"[并行加热] 当前同时加热中: {', '.join(heating_list)}") - # 模拟加热过程 start_time = time.time() last_countdown_log = start_time while True: @@ -524,7 +576,6 @@ class VirtualWorkbench: self._update_data_status(f"加热台{station_id}加热中: {progress:.1f}%") - # 每5秒打印一次倒计时 if time.time() - last_countdown_log >= 5.0: self.logger.info(f"[加热台{station_id}] {material_id} 剩余 
{remaining:.1f}s") last_countdown_log = time.time() @@ -534,7 +585,6 @@ class VirtualWorkbench: time.sleep(1.0) - # 加热完成 with self._stations_lock: self._heating_stations[station_id].state = HeatingStationState.COMPLETED self._heating_stations[station_id].heating_progress = 100.0 @@ -553,10 +603,28 @@ class VirtualWorkbench: "material_number": material_number, "message": f"加热台{station_id}加热完成", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } + @action( + auto_prefix=True, + description="将物料从加热台移动到输出位置Cn", + handles=[ + ActionInputHandle(key="output_station_input", data_type="workbench_station", + label="加热台ID", data_key="station_id", data_source=DataSource.HANDLE), + ActionInputHandle(key="output_material_input", data_type="workbench_material", + label="物料编号", data_key="material_number", data_source=DataSource.HANDLE), + ], + ) def move_to_output( self, sample_uuids: SampleUUIDsType, @@ -565,15 +633,8 @@ class VirtualWorkbench: ) -> MoveToOutputResult: """ 将物料从加热台移动到输出位置Cn - - Args: - station_id: 加热台ID (1-3),从 start_heating 的 handle 传入 - material_number: 物料编号,从 start_heating 的 handle 传入,用于确定输出位置 Cn - - Returns: - MoveToOutputResult: 包含执行结果 """ - output_number = material_number # 物料编号决定输出位置 + output_number = material_number if station_id not in self._heating_stations: return { @@ -583,8 +644,16 @@ class VirtualWorkbench: "output_position": f"C{output_number}", "message": f"无效的加热台ID: {station_id}", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, 
content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } with self._stations_lock: @@ -599,8 +668,16 @@ class VirtualWorkbench: "output_position": f"C{output_number}", "message": f"加热台{station_id}上没有物料", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } if station.state != HeatingStationState.COMPLETED: @@ -611,8 +688,16 @@ class VirtualWorkbench: "output_position": f"C{output_number}", "message": f"加热台{station_id}尚未完成加热 (当前状态: {station.state.value})", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } output_position = f"C{output_number}" @@ -624,18 +709,17 @@ class VirtualWorkbench: if material_id in self._active_tasks: self._active_tasks[material_id]["status"] = "waiting_for_arm_output" - # 获取机械臂 self._acquire_arm(task_desc) with self._tasks_lock: if material_id in self._active_tasks: self._active_tasks[material_id]["status"] = "arm_moving_to_output" - # 模拟机械臂操作 (3秒) - self.logger.info(f"[{material_id}] 机械臂正在从加热台{station_id}取出并移动到{output_position}...") + self.logger.info( + 
f"[{material_id}] 机械臂正在从加热台{station_id}取出并移动到{output_position}..." + ) time.sleep(self.ARM_OPERATION_TIME) - # 清空加热台 with self._stations_lock: self._heating_stations[station_id].state = HeatingStationState.IDLE self._heating_stations[station_id].current_material = None @@ -643,17 +727,17 @@ class VirtualWorkbench: self._heating_stations[station_id].heating_progress = 0.0 self._heating_stations[station_id].heating_start_time = None - # 释放机械臂 self._release_arm() - # 任务完成 with self._tasks_lock: if material_id in self._active_tasks: self._active_tasks[material_id]["status"] = "completed" self._active_tasks[material_id]["end_time"] = time.time() self._update_data_status(f"{material_id}已移动到{output_position}") - self.logger.info(f"[{material_id}] 已成功移动到{output_position} (用时{self.ARM_OPERATION_TIME}s)") + self.logger.info( + f"[{material_id}] 已成功移动到{output_position} (用时{self.ARM_OPERATION_TIME}s)" + ) return { "success": True, @@ -662,8 +746,17 @@ class VirtualWorkbench: "output_position": output_position, "message": f"{material_id}已成功移动到{output_position}", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) + else (content.serialize() if content is not None else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } except Exception as e: @@ -677,83 +770,105 @@ class VirtualWorkbench: "output_position": output_position, "message": f"移动失败: {str(e)}", "unilabos_samples": [ - LabSample(sample_uuid=sample_uuid, oss_path="", extra={"material_uuid": content} if isinstance(content, str) else content.serialize()) for - sample_uuid, content in sample_uuids.items()] + LabSample( + sample_uuid=sample_uuid, + oss_path="", + extra=( + {"material_uuid": content} + if isinstance(content, str) 
else (content.serialize() if content else {}) + ), + ) + for sample_uuid, content in sample_uuids.items() + ], } # ============ 状态属性 ============ @property + @topic_config() def status(self) -> str: return self.data.get("status", "Unknown") @property + @topic_config() def arm_state(self) -> str: return self._arm_state.value @property + @topic_config() def arm_current_task(self) -> str: return self._arm_current_task or "" @property + @topic_config() def heating_station_1_state(self) -> str: with self._stations_lock: station = self._heating_stations.get(1) return station.state.value if station else "unknown" @property + @topic_config() def heating_station_1_material(self) -> str: with self._stations_lock: station = self._heating_stations.get(1) return station.current_material or "" if station else "" @property + @topic_config() def heating_station_1_progress(self) -> float: with self._stations_lock: station = self._heating_stations.get(1) return station.heating_progress if station else 0.0 @property + @topic_config() def heating_station_2_state(self) -> str: with self._stations_lock: station = self._heating_stations.get(2) return station.state.value if station else "unknown" @property + @topic_config() def heating_station_2_material(self) -> str: with self._stations_lock: station = self._heating_stations.get(2) return station.current_material or "" if station else "" @property + @topic_config() def heating_station_2_progress(self) -> float: with self._stations_lock: station = self._heating_stations.get(2) return station.heating_progress if station else 0.0 @property + @topic_config() def heating_station_3_state(self) -> str: with self._stations_lock: station = self._heating_stations.get(3) return station.state.value if station else "unknown" @property + @topic_config() def heating_station_3_material(self) -> str: with self._stations_lock: station = self._heating_stations.get(3) return station.current_material or "" if station else "" @property + @topic_config() def 
heating_station_3_progress(self) -> float: with self._stations_lock: station = self._heating_stations.get(3) return station.heating_progress if station else 0.0 @property + @topic_config() def active_tasks_count(self) -> int: with self._tasks_lock: return len(self._active_tasks) @property + @topic_config() def message(self) -> str: return self.data.get("message", "") diff --git a/unilabos/registry/ast_registry_scanner.py b/unilabos/registry/ast_registry_scanner.py new file mode 100644 index 00000000..2fa87873 --- /dev/null +++ b/unilabos/registry/ast_registry_scanner.py @@ -0,0 +1,1022 @@ +""" +AST-based Registry Scanner + +Statically parse Python files to extract @device, @action, @topic_config, @resource +decorator metadata without importing any modules. This is ~100x faster than importlib +since it only reads and parses text files. + +Includes a file-level cache: each file's MD5 hash, size and mtime are tracked so +unchanged files skip AST parsing entirely. The cache is persisted as JSON in the +working directory (``unilabos_data/ast_scan_cache.json``). 
+ +Usage: + from unilabos.registry.ast_registry_scanner import scan_directory + + # Scan all device and resource files under a package directory + result = scan_directory("unilabos", python_path="/project") + # => {"devices": {device_id: {...}, ...}, "resources": {resource_id: {...}, ...}} +""" + +import ast +import hashlib +import json +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MAX_SCAN_DEPTH = 10 # 最大目录递归深度 +MAX_SCAN_FILES = 1000 # 最大扫描文件数量 +_CACHE_VERSION = 1 # 缓存格式版本号,格式变更时递增 + +# 合法的装饰器来源模块 +_REGISTRY_DECORATOR_MODULE = "unilabos.registry.decorators" + + +# --------------------------------------------------------------------------- +# File-level cache helpers +# --------------------------------------------------------------------------- + + +def _file_fingerprint(filepath: Path) -> Dict[str, Any]: + """Return size, mtime and MD5 hash for *filepath*.""" + stat = filepath.stat() + md5 = hashlib.md5(filepath.read_bytes()).hexdigest() + return {"size": stat.st_size, "mtime": stat.st_mtime, "md5": md5} + + +def load_scan_cache(cache_path: Optional[Path]) -> Dict[str, Any]: + """Load the AST scan cache from *cache_path*. 
Returns empty structure on any error.""" + if cache_path is None or not cache_path.is_file(): + return {"version": _CACHE_VERSION, "files": {}} + try: + raw = cache_path.read_text(encoding="utf-8") + data = json.loads(raw) + if data.get("version") != _CACHE_VERSION: + return {"version": _CACHE_VERSION, "files": {}} + return data + except Exception: + return {"version": _CACHE_VERSION, "files": {}} + + +def save_scan_cache(cache_path: Optional[Path], cache: Dict[str, Any]) -> None: + """Persist *cache* to *cache_path* (atomic-ish via temp file).""" + if cache_path is None: + return + try: + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp = cache_path.with_suffix(".tmp") + tmp.write_text(json.dumps(cache, ensure_ascii=False, indent=1), encoding="utf-8") + tmp.replace(cache_path) + except Exception: + pass + + +def _is_cache_hit(entry: Dict[str, Any], fp: Dict[str, Any]) -> bool: + """Check if a cache entry matches the current file fingerprint.""" + return ( + entry.get("md5") == fp["md5"] + and entry.get("size") == fp["size"] + ) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def _collect_py_files( + root_dir: Path, + max_depth: int = MAX_SCAN_DEPTH, + max_files: int = MAX_SCAN_FILES, + exclude_files: Optional[set] = None, +) -> List[Path]: + """ + 收集 root_dir 下的 .py 文件,限制最大递归深度和文件数量。 + + Args: + root_dir: 扫描根目录 + max_depth: 最大递归深度 (默认 10 层) + max_files: 最大文件数量 (默认 1000 个) + exclude_files: 要排除的文件名集合 (如 {"lab_resources.py"}) + + Returns: + 排序后的 .py 文件路径列表 + """ + result: List[Path] = [] + _exclude = exclude_files or set() + + def _walk(dir_path: Path, depth: int): + if depth > max_depth or len(result) >= max_files: + return + try: + entries = sorted(dir_path.iterdir()) + except (PermissionError, OSError): + return + for entry in entries: + if len(result) >= max_files: + return + if entry.is_file() and entry.suffix == ".py" 
and not entry.name.startswith("__"): + if entry.name not in _exclude: + result.append(entry) + elif entry.is_dir() and not entry.name.startswith(("__", ".")): + _walk(entry, depth + 1) + + _walk(root_dir, 0) + return result + + +def scan_directory( + root_dir: Union[str, Path], + python_path: Union[str, Path] = "", + max_depth: int = MAX_SCAN_DEPTH, + max_files: int = MAX_SCAN_FILES, + executor: ThreadPoolExecutor = None, + exclude_files: Optional[set] = None, + cache: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Recursively scan .py files under *root_dir* for @device and @resource + decorated classes/functions. + + Uses a thread pool to parse files in parallel for faster I/O. + When *cache* is provided, files whose fingerprint (MD5 + size) hasn't + changed since the last scan are served from cache without re-parsing. + + Returns: + {"devices": {device_id: meta, ...}, "resources": {resource_id: meta, ...}} + + Args: + root_dir: Directory to scan (e.g. "unilabos/devices"). + python_path: The directory that should be on sys.path, i.e. the parent + of the top-level package. Module paths are derived as + filepath relative to this directory. If empty, defaults to + root_dir's parent. + max_depth: Maximum directory recursion depth (default 10). + max_files: Maximum number of .py files to scan (default 1000). + executor: Shared ThreadPoolExecutor (required). The caller manages its + lifecycle. + exclude_files: 要排除的文件名集合 (如 {"lab_resources.py"}) + cache: Mutable cache dict (``load_scan_cache()`` result). Hits are read + from here; misses are written back so the caller can persist later. 
+ """ + if executor is None: + raise ValueError("executor is required and must not be None") + + root_dir = Path(root_dir).resolve() + if not python_path: + python_path = root_dir.parent + else: + python_path = Path(python_path).resolve() + + # --- Collect files (depth/count limited) --- + py_files = _collect_py_files(root_dir, max_depth=max_depth, max_files=max_files, exclude_files=exclude_files) + + cache_files: Dict[str, Any] = cache.get("files", {}) if cache else {} + + # --- Parallel scan (with cache fast-path) --- + devices: Dict[str, dict] = {} + resources: Dict[str, dict] = {} + cache_hits = 0 + cache_misses = 0 + + def _parse_one_cached(py_file: Path) -> Tuple[List[dict], List[dict], bool]: + """Returns (devices, resources, was_cache_hit).""" + key = str(py_file) + try: + fp = _file_fingerprint(py_file) + except OSError: + return [], [], False + + cached_entry = cache_files.get(key) + if cached_entry and _is_cache_hit(cached_entry, fp): + return cached_entry.get("devices", []), cached_entry.get("resources", []), True + + try: + devs, ress = _parse_file(py_file, python_path) + except (SyntaxError, Exception): + devs, ress = [], [] + + cache_files[key] = { + "md5": fp["md5"], + "size": fp["size"], + "mtime": fp["mtime"], + "devices": devs, + "resources": ress, + } + return devs, ress, False + + def _collect_results(futures_dict: Dict): + nonlocal cache_hits, cache_misses + for future in as_completed(futures_dict): + devs, ress, hit = future.result() + if hit: + cache_hits += 1 + else: + cache_misses += 1 + for dev in devs: + device_id = dev.get("device_id") + if device_id: + if device_id in devices: + existing = devices[device_id].get("file_path", "?") + new_file = dev.get("file_path", "?") + raise ValueError( + f"@device id 重复: '{device_id}' 同时出现在 {existing} 和 {new_file}" + ) + devices[device_id] = dev + for res in ress: + resource_id = res.get("resource_id") + if resource_id: + if resource_id in resources: + existing = 
resources[resource_id].get("file_path", "?") + new_file = res.get("file_path", "?") + raise ValueError( + f"@resource id 重复: '{resource_id}' 同时出现在 {existing} 和 {new_file}" + ) + resources[resource_id] = res + + futures = {executor.submit(_parse_one_cached, f): f for f in py_files} + _collect_results(futures) + + if cache is not None: + cache["files"] = cache_files + + return { + "devices": devices, + "resources": resources, + "_cache_stats": {"hits": cache_hits, "misses": cache_misses, "total": len(py_files)}, + } + + + + +# --------------------------------------------------------------------------- +# File-level parsing +# --------------------------------------------------------------------------- + +# 已知继承自 rclpy.node.Node 的基类名 (用于 AST 静态检测) +_KNOWN_ROS2_BASE_CLASSES = {"Node", "BaseROS2DeviceNode"} +_KNOWN_ROS2_MODULES = {"rclpy", "rclpy.node"} + + +def _detect_class_type(cls_node: ast.ClassDef, import_map: Dict[str, str]) -> str: + """ + 检测类是否继承自 rclpy Node,返回 'ros2' 或 'python'。 + + 通过检查类的基类名称和 import_map 中的模块路径来判断: + 1. 基类名在已知 ROS2 基类集合中 + 2. 基类在 import_map 中解析到 rclpy 相关模块 + 3. 
基类在 import_map 中解析到 BaseROS2DeviceNode + """ + for base in cls_node.bases: + base_name = "" + if isinstance(base, ast.Name): + base_name = base.id + elif isinstance(base, ast.Attribute): + base_name = base.attr + elif isinstance(base, ast.Subscript) and isinstance(base.value, ast.Name): + # Generic[T] 形式,如 BaseROS2DeviceNode[SomeType] + base_name = base.value.id + + if not base_name: + continue + + # 直接匹配已知 ROS2 基类名 + if base_name in _KNOWN_ROS2_BASE_CLASSES: + return "ros2" + + # 通过 import_map 检查模块路径 + module_path = import_map.get(base_name, "") + if any(mod in module_path for mod in _KNOWN_ROS2_MODULES): + return "ros2" + if "BaseROS2DeviceNode" in module_path: + return "ros2" + + return "python" + + +def _parse_file( + filepath: Path, + python_path: Path, +) -> Tuple[List[dict], List[dict]]: + """ + Parse a single .py file using ast and extract all @device-decorated classes + and @resource-decorated functions/classes. + + Returns: + (devices, resources) -- two lists of metadata dicts. 
+ """ + source = filepath.read_text(encoding="utf-8", errors="replace") + tree = ast.parse(source, filename=str(filepath)) + + # Derive module path from file path + module_path = _filepath_to_module(filepath, python_path) + + # Build import map from the file (includes same-file class defs) + import_map = _collect_imports(tree, module_path) + + devices: List[dict] = [] + resources: List[dict] = [] + + for node in ast.iter_child_nodes(tree): + # --- @device on classes --- + if isinstance(node, ast.ClassDef): + device_decorator = _find_decorator(node, "device") + if device_decorator is not None and _is_registry_decorator("device", import_map): + device_args = _extract_decorator_args(device_decorator, import_map) + class_body = _extract_class_body(node, import_map) + + # Support ids + id_meta (multi-device) or id (single device) + device_ids: List[str] = [] + if device_args.get("ids") is not None: + device_ids = list(device_args["ids"]) + else: + did = device_args.get("id") or device_args.get("device_id") + device_ids = [did] if did else [f"{module_path}:{node.name}"] + + id_meta = device_args.get("id_meta") or {} + base_meta = { + "class_name": node.name, + "module": f"{module_path}:{node.name}", + "file_path": str(filepath).replace("\\", "/"), + "category": device_args.get("category", []), + "description": device_args.get("description", ""), + "display_name": device_args.get("display_name", ""), + "icon": device_args.get("icon", ""), + "version": device_args.get("version", "1.0.0"), + "device_type": _detect_class_type(node, import_map), + "handles": device_args.get("handles", []), + "model": device_args.get("model"), + "hardware_interface": device_args.get("hardware_interface"), + "actions": class_body.get("actions", {}), + "status_properties": class_body.get("status_properties", {}), + "init_params": class_body.get("init_params", []), + "auto_methods": class_body.get("auto_methods", {}), + "import_map": import_map, + } + for did in device_ids: + meta = 
dict(base_meta) + meta["device_id"] = did + overrides = id_meta.get(did, {}) + for key in ("handles", "description", "icon", "model", "hardware_interface"): + if key in overrides: + meta[key] = overrides[key] + devices.append(meta) + + # --- @resource on classes --- + resource_decorator = _find_decorator(node, "resource") + if resource_decorator is not None and _is_registry_decorator("resource", import_map): + res_meta = _extract_resource_meta( + resource_decorator, node.name, module_path, filepath, import_map, + is_function=False, + init_node=_find_init_in_class(node), + ) + resources.append(res_meta) + + # --- @resource on module-level functions --- + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + resource_decorator = _find_method_decorator(node, "resource") + if resource_decorator is not None and _is_registry_decorator("resource", import_map): + res_meta = _extract_resource_meta( + resource_decorator, node.name, module_path, filepath, import_map, + is_function=True, + func_node=node, + ) + resources.append(res_meta) + + return devices, resources + + +def _find_init_in_class(cls_node: ast.ClassDef) -> Optional[ast.FunctionDef]: + """Find __init__ method in a class.""" + for item in cls_node.body: + if isinstance(item, ast.FunctionDef) and item.name == "__init__": + return item + return None + + +def _extract_resource_meta( + decorator_node: Union[ast.Call, ast.Name], + name: str, + module_path: str, + filepath: Path, + import_map: Dict[str, str], + is_function: bool = False, + func_node: Optional[Union[ast.FunctionDef, ast.AsyncFunctionDef]] = None, + init_node: Optional[ast.FunctionDef] = None, +) -> dict: + """ + Extract resource metadata from a @resource decorator on a function or class. 
+ """ + res_args = _extract_decorator_args(decorator_node, import_map) + + resource_id = res_args.get("id") or res_args.get("resource_id") + if resource_id is None: + resource_id = f"{module_path}:{name}" + + # Extract init/function params + init_params: List[dict] = [] + if is_function and func_node is not None: + init_params = _extract_method_params(func_node, import_map) + elif not is_function and init_node is not None: + init_params = _extract_method_params(init_node, import_map) + + return { + "resource_id": resource_id, + "name": name, + "module": f"{module_path}:{name}", + "file_path": str(filepath).replace("\\", "/"), + "is_function": is_function, + "category": res_args.get("category", []), + "description": res_args.get("description", ""), + "icon": res_args.get("icon", ""), + "version": res_args.get("version", "1.0.0"), + "class_type": res_args.get("class_type", "pylabrobot"), + "handles": res_args.get("handles", []), + "model": res_args.get("model"), + "init_params": init_params, + } + + +# --------------------------------------------------------------------------- +# Import map collection +# --------------------------------------------------------------------------- + + +def _collect_imports(tree: ast.Module, module_path: str = "") -> Dict[str, str]: + """ + Walk all Import/ImportFrom nodes in the AST tree, build a mapping from + local name to fully-qualified import path. + + Also includes top-level class/function definitions from the same file, + so that same-file TypedDict / Enum / dataclass references can be resolved. 
+ + Returns: + {"SendCmd": "unilabos_msgs.action:SendCmd", + "StrSingleInput": "unilabos_msgs.action:StrSingleInput", + "InputHandle": "unilabos.registry.decorators:InputHandle", + "SetLiquidReturn": "unilabos.devices.liquid_handling.liquid_handler_abstract:SetLiquidReturn", + ...} + """ + import_map: Dict[str, str] = {} + + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + module = node.module or "" + for alias in node.names: + local_name = alias.asname if alias.asname else alias.name + import_map[local_name] = f"{module}:{alias.name}" + elif isinstance(node, ast.Import): + for alias in node.names: + local_name = alias.asname if alias.asname else alias.name + import_map[local_name] = alias.name + + # 同文件顶层 class / function 定义 + if module_path: + for node in tree.body: + if isinstance(node, ast.ClassDef): + import_map.setdefault(node.name, f"{module_path}:{node.name}") + elif isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): + import_map.setdefault(node.name, f"{module_path}:{node.name}") + elif isinstance(node, ast.Assign): + # 顶层赋值 (如 MotorAxis = Enum(...)) + for target in node.targets: + if isinstance(target, ast.Name): + import_map.setdefault(target.id, f"{module_path}:{target.id}") + + return import_map + + + +# --------------------------------------------------------------------------- +# Decorator finding & argument extraction +# --------------------------------------------------------------------------- + + +def _find_decorator( + node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef], + decorator_name: str, +) -> Optional[ast.Call]: + """ + Find a specific decorator call on a class or function definition. + + Handles both: + - @device(...) -> ast.Call with func=ast.Name(id="device") + - @module.device(...) 
-> ast.Call with func=ast.Attribute(attr="device") + """ + for dec in node.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, ast.Name) and dec.func.id == decorator_name: + return dec + if isinstance(dec.func, ast.Attribute) and dec.func.attr == decorator_name: + return dec + elif isinstance(dec, ast.Name) and dec.id == decorator_name: + # @device without parens (unlikely but handle it) + return None # Can't extract args from bare decorator + return None + + +def _find_method_decorator(func_node: ast.FunctionDef, decorator_name: str) -> Optional[Union[ast.Call, ast.Name]]: + """Find a decorator on a method.""" + for dec in func_node.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, ast.Name) and dec.func.id == decorator_name: + return dec + if isinstance(dec.func, ast.Attribute) and dec.func.attr == decorator_name: + return dec + elif isinstance(dec, ast.Name) and dec.id == decorator_name: + # @action without parens, or @topic_config without parens + return dec + return None + + +def _has_decorator(func_node: ast.FunctionDef, decorator_name: str) -> bool: + """Check if a method has a specific decorator (with or without call).""" + for dec in func_node.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, ast.Name) and dec.func.id == decorator_name: + return True + if isinstance(dec.func, ast.Attribute) and dec.func.attr == decorator_name: + return True + elif isinstance(dec, ast.Name) and dec.id == decorator_name: + return True + return False + + +def _is_registry_decorator(name: str, import_map: Dict[str, str]) -> bool: + """Check that *name* was imported from ``unilabos.registry.decorators``.""" + source = import_map.get(name, "") + return _REGISTRY_DECORATOR_MODULE in source + + +def _extract_decorator_args( + node: Union[ast.Call, ast.Name], + import_map: Dict[str, str], +) -> dict: + """ + Extract keyword arguments from a decorator call AST node. + + Resolves Name references (e.g. 
SendCmd, Side.NORTH) via import_map. + Handles literal values (strings, ints, bools, lists, dicts, None). + """ + if isinstance(node, ast.Name): + return {} # Bare decorator, no args + if not isinstance(node, ast.Call): + return {} + + result: dict = {} + + for kw in node.keywords: + if kw.arg is None: + continue # **kwargs, skip + result[kw.arg] = _ast_node_to_value(kw.value, import_map) + + return result + + +# --------------------------------------------------------------------------- +# AST node value conversion +# --------------------------------------------------------------------------- + + +def _ast_node_to_value(node: ast.expr, import_map: Dict[str, str]) -> Any: + """ + Convert an AST expression node to a Python value. + + Handles: + - Literals (str, int, float, bool, None) + - Lists, Tuples, Dicts, Sets + - Name references (e.g. SendCmd -> resolved via import_map) + - Attribute access (e.g. Side.NORTH -> resolved) + - Function/class calls (e.g. InputHandle(...) -> structured dict) + - Unary operators (e.g. -1) + """ + # --- Constant (str, int, float, bool, None) --- + if isinstance(node, ast.Constant): + return node.value + + # --- Name (e.g. SendCmd, True, False, None) --- + if isinstance(node, ast.Name): + return _resolve_name(node.id, import_map) + + # --- Attribute (e.g. 
Side.NORTH, DataSource.HANDLE) --- + if isinstance(node, ast.Attribute): + return _resolve_attribute(node, import_map) + + # --- List --- + if isinstance(node, ast.List): + return [_ast_node_to_value(elt, import_map) for elt in node.elts] + + # --- Tuple --- + if isinstance(node, ast.Tuple): + return [_ast_node_to_value(elt, import_map) for elt in node.elts] + + # --- Dict --- + if isinstance(node, ast.Dict): + result = {} + for k, v in zip(node.keys, node.values): + if k is None: + continue # **kwargs spread + key = _ast_node_to_value(k, import_map) + val = _ast_node_to_value(v, import_map) + result[key] = val + return result + + # --- Set --- + if isinstance(node, ast.Set): + return [_ast_node_to_value(elt, import_map) for elt in node.elts] + + # --- Call (e.g. InputHandle(...), OutputHandle(...)) --- + if isinstance(node, ast.Call): + return _ast_call_to_value(node, import_map) + + # --- UnaryOp (e.g. -1, -0.5) --- + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.USub): + operand = _ast_node_to_value(node.operand, import_map) + if isinstance(operand, (int, float)): + return -operand + elif isinstance(node.op, ast.Not): + operand = _ast_node_to_value(node.operand, import_map) + return not operand + + # --- BinOp (e.g. "a" + "b") --- + if isinstance(node, ast.BinOp): + if isinstance(node.op, ast.Add): + left = _ast_node_to_value(node.left, import_map) + right = _ast_node_to_value(node.right, import_map) + if isinstance(left, str) and isinstance(right, str): + return left + right + + # --- JoinedStr (f-string) --- + if isinstance(node, ast.JoinedStr): + return "" + + # Fallback: return the AST dump as a string marker + return f"" + + +def _resolve_name(name: str, import_map: Dict[str, str]) -> str: + """ + Resolve a bare Name reference via import_map. + + E.g. 
"SendCmd" -> "unilabos_msgs.action:SendCmd" + "True" -> True (handled by ast.Constant in Python 3.8+) + """ + if name in import_map: + return import_map[name] + # Fallback: return the name as-is + return name + + +def _resolve_attribute(node: ast.Attribute, import_map: Dict[str, str]) -> str: + """ + Resolve an attribute access like Side.NORTH or DataSource.HANDLE. + + Returns a string like "NORTH" for enum values, or + "module.path:Class.attr" for imported references. + """ + # Get the full dotted path + parts = [] + current = node + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + + parts.reverse() + # parts = ["Side", "NORTH"] or ["DataSource", "HANDLE"] + + if len(parts) >= 2: + base = parts[0] + attr = ".".join(parts[1:]) + + # If the base is an imported name, resolve it + if base in import_map: + return f"{import_map[base]}.{attr}" + + # For known enum-like patterns, return just the value + # e.g. Side.NORTH -> "NORTH" + if base in ("Side", "DataSource"): + return parts[-1] + + return ".".join(parts) + + +def _ast_call_to_value(node: ast.Call, import_map: Dict[str, str]) -> dict: + """ + Convert a function/class call like InputHandle(key="in", ...) to a structured dict. 
+ + Returns: + {"_call": "unilabos.registry.decorators:InputHandle", + "key": "in", "data_type": "fluid", ...} + """ + # Resolve the call target + if isinstance(node.func, ast.Name): + call_name = _resolve_name(node.func.id, import_map) + elif isinstance(node.func, ast.Attribute): + call_name = _resolve_attribute(node.func, import_map) + else: + call_name = "" + + result: dict = {"_call": call_name} + + # Positional args + for i, arg in enumerate(node.args): + result[f"_pos_{i}"] = _ast_node_to_value(arg, import_map) + + # Keyword args + for kw in node.keywords: + if kw.arg is None: + continue + result[kw.arg] = _ast_node_to_value(kw.value, import_map) + + return result + + +# --------------------------------------------------------------------------- +# Class body extraction +# --------------------------------------------------------------------------- + + +def _extract_class_body( + cls_node: ast.ClassDef, + import_map: Dict[str, str], +) -> dict: + """ + Walk the class body to extract: + - @action-decorated methods + - @property with @topic_config (status properties) + - get_* methods with @topic_config + - __init__ parameters + - Public methods without @action (auto-actions) + """ + result: dict = { + "actions": {}, # method_name -> action_info + "status_properties": {}, # prop_name -> status_info + "init_params": [], # [{"name": ..., "type": ..., "default": ...}, ...] 
+ "auto_methods": {}, # method_name -> method_info (no @action decorator) + } + + for item in cls_node.body: + if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + + method_name = item.name + + # --- __init__ --- + if method_name == "__init__": + result["init_params"] = _extract_method_params(item, import_map) + continue + + # --- Skip private/dunder --- + if method_name.startswith("_"): + continue + + # --- Check for @property or @topic_config → status property --- + is_property = _has_decorator(item, "property") + has_topic = ( + _has_decorator(item, "topic_config") + and _is_registry_decorator("topic_config", import_map) + ) + + if is_property or has_topic: + topic_args = {} + topic_dec = _find_method_decorator(item, "topic_config") + if topic_dec is not None: + topic_args = _extract_decorator_args(topic_dec, import_map) + + return_type = _get_annotation_str(item.returns, import_map) + # 非 @property 的 @topic_config 方法,用去掉 get_ 前缀的名称 + prop_name = method_name[4:] if method_name.startswith("get_") and not is_property else method_name + + result["status_properties"][prop_name] = { + "name": prop_name, + "return_type": return_type, + "is_property": is_property, + "topic_config": topic_args if topic_args else None, + } + continue + + # --- Check for @action --- + action_dec = _find_method_decorator(item, "action") + if action_dec is not None and _is_registry_decorator("action", import_map): + action_args = _extract_decorator_args(action_dec, import_map) + # 补全 @action 装饰器的默认值(与 decorators.py 中 action() 签名一致) + action_args.setdefault("action_type", None) + action_args.setdefault("goal", {}) + action_args.setdefault("feedback", {}) + action_args.setdefault("result", {}) + action_args.setdefault("handles", {}) + action_args.setdefault("goal_default", {}) + action_args.setdefault("placeholder_keys", {}) + action_args.setdefault("always_free", False) + action_args.setdefault("is_protocol", False) + action_args.setdefault("description", "") + 
action_args.setdefault("auto_prefix", False) + action_args.setdefault("parent", False) + method_params = _extract_method_params(item, import_map) + return_type = _get_annotation_str(item.returns, import_map) + is_async = isinstance(item, ast.AsyncFunctionDef) + method_doc = ast.get_docstring(item) + + result["actions"][method_name] = { + "action_args": action_args, + "params": method_params, + "return_type": return_type, + "is_async": is_async, + "docstring": method_doc, + } + continue + + # --- Check for @not_action --- + if _has_decorator(item, "not_action") and _is_registry_decorator("not_action", import_map): + continue + + # --- Public method without @action => auto-action --- + # Skip lifecycle / dunder methods that should never be auto-actions + if method_name in ("post_init", "__str__", "__repr__"): + continue + + # 'close' and 'cleanup' could be actions in some drivers -- keep them + + method_params = _extract_method_params(item, import_map) + return_type = _get_annotation_str(item.returns, import_map) + is_async = isinstance(item, ast.AsyncFunctionDef) + method_doc = ast.get_docstring(item) + + result["auto_methods"][method_name] = { + "params": method_params, + "return_type": return_type, + "is_async": is_async, + "docstring": method_doc, + } + + return result + + +# --------------------------------------------------------------------------- +# Method parameter extraction +# --------------------------------------------------------------------------- + + +_PARAM_SKIP_NAMES = frozenset({"sample_uuids"}) + + +def _extract_method_params( + func_node: Union[ast.FunctionDef, ast.AsyncFunctionDef], + import_map: Optional[Dict[str, str]] = None, +) -> List[dict]: + """ + Extract parameters from a class method definition. + + Automatically skips the first positional argument (self / cls) and any + domain-specific names listed in ``_PARAM_SKIP_NAMES``. + + Returns: + [{"name": "position", "type": "str", "default": None, "required": True}, ...] 
+ """ + if import_map is None: + import_map = {} + params: List[dict] = [] + + args = func_node.args + + # Skip the first positional arg (self/cls) -- always present for class methods + # noinspection PyUnresolvedReferences + positional_args = args.args[1:] if args.args else [] + + # defaults align to the *end* of the args list; offset must account for the skipped arg + num_args = len(args.args) + num_defaults = len(args.defaults) + first_default_idx = num_args - num_defaults + + for i, arg in enumerate(positional_args, start=1): + name = arg.arg + if name in _PARAM_SKIP_NAMES: + continue + + param_info: dict = {"name": name} + + # Type annotation + if arg.annotation: + param_info["type"] = _get_annotation_str(arg.annotation, import_map) + else: + param_info["type"] = "" + + # Default value + default_idx = i - first_default_idx + if 0 <= default_idx < len(args.defaults): + default_val = _ast_node_to_value(args.defaults[default_idx], import_map) + param_info["default"] = default_val + param_info["required"] = False + else: + param_info["default"] = None + param_info["required"] = True + + params.append(param_info) + + # Keyword-only arguments (self/cls never appear here) + for i, arg in enumerate(args.kwonlyargs): + name = arg.arg + if name in _PARAM_SKIP_NAMES: + continue + + param_info: dict = {"name": name} + + if arg.annotation: + param_info["type"] = _get_annotation_str(arg.annotation, import_map) + else: + param_info["type"] = "" + + if i < len(args.kw_defaults) and args.kw_defaults[i] is not None: + param_info["default"] = _ast_node_to_value(args.kw_defaults[i], import_map) + param_info["required"] = False + else: + param_info["default"] = None + param_info["required"] = True + + params.append(param_info) + + return params + + +def _get_annotation_str(node: Optional[ast.expr], import_map: Dict[str, str]) -> str: + """Convert a type annotation AST node to a string representation. 
+ + 保持类型字符串为合法 Python 表达式 (可被 ast.parse 解析)。 + 不在此处做 import_map 替换 — 由上层在需要时通过 import_map 解析。 + """ + if node is None: + return "" + + if isinstance(node, ast.Constant): + return str(node.value) + + if isinstance(node, ast.Name): + return node.id + + if isinstance(node, ast.Attribute): + parts = [] + current = node + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + parts.reverse() + return ".".join(parts) + + # Handle subscript types like List[str], Dict[str, int], Optional[str] + if isinstance(node, ast.Subscript): + base = _get_annotation_str(node.value, import_map) + if isinstance(node.slice, ast.Tuple): + args = ", ".join(_get_annotation_str(elt, import_map) for elt in node.slice.elts) + else: + args = _get_annotation_str(node.slice, import_map) + return f"{base}[{args}]" + + # Handle Union types (X | Y in Python 3.10+) + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + left = _get_annotation_str(node.left, import_map) + right = _get_annotation_str(node.right, import_map) + return f"Union[{left}, {right}]" + + return ast.dump(node) + + +# --------------------------------------------------------------------------- +# Module path derivation +# --------------------------------------------------------------------------- + + +def _filepath_to_module(filepath: Path, python_path: Path) -> str: + """ + 通过 *python_path*(sys.path 中的根目录)推导 Python 模块路径。 + + 做法:取 filepath 相对于 python_path 的路径,将目录分隔符替换为 '.'。 + + E.g. 
filepath = "/project/unilabos/devices/pump/valve.py" + python_path = "/project" + => "unilabos.devices.pump.valve" + """ + try: + relative = filepath.relative_to(python_path) + except ValueError: + return str(filepath) + + parts = list(relative.parts) + # 去掉 .py 后缀 + if parts and parts[-1].endswith(".py"): + parts[-1] = parts[-1][:-3] + # 去掉 __init__ + if parts and parts[-1] == "__init__": + parts.pop() + + return ".".join(parts) diff --git a/unilabos/registry/decorators.py b/unilabos/registry/decorators.py new file mode 100644 index 00000000..e8c65ac8 --- /dev/null +++ b/unilabos/registry/decorators.py @@ -0,0 +1,614 @@ +""" +装饰器注册表系统 + +通过 @device, @action, @resource 装饰器替代 YAML 配置文件来定义设备/动作/资源注册表信息。 + +Usage: + from unilabos.registry.decorators import ( + device, action, resource, + InputHandle, OutputHandle, + ActionInputHandle, ActionOutputHandle, + HardwareInterface, Side, DataSource, + ) + + @device( + id="solenoid_valve.mock", + category=["pump_and_valve"], + description="模拟电磁阀设备", + handles=[ + InputHandle(key="in", data_type="fluid", label="in", side=Side.NORTH), + OutputHandle(key="out", data_type="fluid", label="out", side=Side.SOUTH), + ], + hardware_interface=HardwareInterface( + name="hardware_interface", + read="send_command", + write="send_command", + ), + ) + class SolenoidValveMock: + @action(action_type=EmptyIn) + def close(self): + ... + + @action( + handles=[ + ActionInputHandle(key="in", data_type="fluid", label="in"), + ActionOutputHandle(key="out", data_type="fluid", label="out"), + ], + ) + def set_valve_position(self, position): + ... + + # 无 @action 装饰器 => auto- 前缀动作 + def is_open(self): + ... 
+""" + +from enum import Enum +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + +F = TypeVar("F", bound=Callable[..., Any]) + +# --------------------------------------------------------------------------- +# 枚举 +# --------------------------------------------------------------------------- + + +class Side(str, Enum): + """UI 上 Handle 的显示位置""" + + NORTH = "NORTH" + SOUTH = "SOUTH" + EAST = "EAST" + WEST = "WEST" + + +class DataSource(str, Enum): + """Handle 的数据来源""" + + HANDLE = "handle" # 从上游 handle 获取数据 (用于 InputHandle) + EXECUTOR = "executor" # 从执行器输出数据 (用于 OutputHandle) + + +# --------------------------------------------------------------------------- +# Device / Resource Handle (设备/资源级别端口, 序列化时包含 io_type) +# --------------------------------------------------------------------------- + + +class _DeviceHandleBase(BaseModel): + """设备/资源端口基类 (内部使用)""" + + model_config = ConfigDict(populate_by_name=True) + + key: str = Field(serialization_alias="handler_key") + data_type: str + label: str + side: Optional[Side] = None + data_key: Optional[str] = None + data_source: Optional[str] = None + description: Optional[str] = None + + # 子类覆盖 + io_type: str = "" + + def to_registry_dict(self) -> Dict[str, Any]: + return self.model_dump(by_alias=True, exclude_none=True) + + +class InputHandle(_DeviceHandleBase): + """ + 输入端口 (io_type="target"), 用于 @device / @resource handles + + Example: + InputHandle(key="in", data_type="fluid", label="in", side=Side.NORTH) + """ + + io_type: str = "target" + + +class OutputHandle(_DeviceHandleBase): + """ + 输出端口 (io_type="source"), 用于 @device / @resource handles + + Example: + OutputHandle(key="out", data_type="fluid", label="out", side=Side.SOUTH) + """ + + io_type: str = "source" + + +# --------------------------------------------------------------------------- +# Action Handle (动作级别端口, 序列化时不含 io_type, 按类型自动分组) +# 
--------------------------------------------------------------------------- + + +class _ActionHandleBase(BaseModel): + """动作端口基类 (内部使用)""" + + model_config = ConfigDict(populate_by_name=True) + + key: str = Field(serialization_alias="handler_key") + data_type: str + label: str + side: Optional[Side] = None + data_key: Optional[str] = None + data_source: Optional[str] = None + description: Optional[str] = None + io_type: Optional[str] = None # source/sink (dataflow) or target/source (device-style) + + def to_registry_dict(self) -> Dict[str, Any]: + return self.model_dump(by_alias=True, exclude_none=True) + + +class ActionInputHandle(_ActionHandleBase): + """ + 动作输入端口, 用于 @action handles, 序列化后归入 "input" 组 + + Example: + ActionInputHandle( + key="material_input", data_type="workbench_material", + label="物料编号", data_key="material_number", data_source="handle", + ) + """ + + pass + + +class ActionOutputHandle(_ActionHandleBase): + """ + 动作输出端口, 用于 @action handles, 序列化后归入 "output" 组 + + Example: + ActionOutputHandle( + key="station_output", data_type="workbench_station", + label="加热台ID", data_key="station_id", data_source="executor", + ) + """ + + pass + + +# --------------------------------------------------------------------------- +# HardwareInterface +# --------------------------------------------------------------------------- + + +class HardwareInterface(BaseModel): + """ + 硬件通信接口定义 + + 描述设备与底层硬件通信的方式 (串口、Modbus 等)。 + + Example: + HardwareInterface(name="hardware_interface", read="send_command", write="send_command") + """ + + name: str + read: Optional[str] = None + write: Optional[str] = None + extra_info: Optional[List[str]] = None + + +# --------------------------------------------------------------------------- +# 全局注册表 -- 记录所有被装饰器标记的类/函数 +# --------------------------------------------------------------------------- +_registered_devices: Dict[str, type] = {} # device_id -> class +_registered_resources: Dict[str, Any] = {} # resource_id -> class or function + + 
+def _device_handles_to_list( + handles: Optional[List[_DeviceHandleBase]], +) -> List[Dict[str, Any]]: + """将设备/资源 Handle 列表序列化为字典列表 (含 io_type)""" + if handles is None: + return [] + return [h.to_registry_dict() for h in handles] + + +def _action_handles_to_dict( + handles: Optional[List[_ActionHandleBase]], +) -> Dict[str, Any]: + """ + 将动作 Handle 列表序列化为 {"input": [...], "output": [...]} 格式。 + + ActionInputHandle => "input", ActionOutputHandle => "output" + """ + if handles is None: + return {} + input_list = [h.to_registry_dict() for h in handles if isinstance(h, ActionInputHandle)] + output_list = [h.to_registry_dict() for h in handles if isinstance(h, ActionOutputHandle)] + result: Dict[str, Any] = {} + if input_list: + result["input"] = input_list + if output_list: + result["output"] = output_list + return result + + +# --------------------------------------------------------------------------- +# @device 类装饰器 +# --------------------------------------------------------------------------- + + +# noinspection PyShadowingBuiltins +def device( + id: Optional[str] = None, + ids: Optional[List[str]] = None, + id_meta: Optional[Dict[str, Dict[str, Any]]] = None, + category: Optional[List[str]] = None, + description: str = "", + display_name: str = "", + icon: str = "", + version: str = "1.0.0", + handles: Optional[List[_DeviceHandleBase]] = None, + model: Optional[Dict[str, Any]] = None, + device_type: str = "python", + hardware_interface: Optional[HardwareInterface] = None, +): + """ + 设备类装饰器 + + 将类标记为一个 UniLab-OS 设备,并附加注册表元数据。 + + 支持两种模式: + 1. 单设备: id="xxx", category=[...] + 2. 
多设备: ids=["id1","id2"], id_meta={"id1":{handles:[...]}, "id2":{...}} + + Args: + id: 单设备时的注册表唯一标识 + ids: 多设备时的 id 列表,与 id_meta 配合使用 + id_meta: 每个 device_id 的覆盖元数据 (handles/description/icon/model) + category: 设备分类标签列表 (必填) + description: 设备描述 + display_name: 人类可读的设备显示名称,缺失时默认使用 id + icon: 图标路径 + version: 版本号 + handles: 设备端口列表 (单设备或 id_meta 未覆盖时使用) + model: 可选的 3D 模型配置 + device_type: 设备实现类型 ("python" / "ros2") + hardware_interface: 硬件通信接口 (HardwareInterface) + """ + # Resolve device ids + if ids is not None: + device_ids = list(ids) + if not device_ids: + raise ValueError("@device ids 不能为空") + id_meta = id_meta or {} + elif id is not None: + device_ids = [id] + id_meta = {} + else: + raise ValueError("@device 必须提供 id 或 ids") + + if category is None: + raise ValueError("@device category 必填") + + base_meta = { + "category": category, + "description": description, + "display_name": display_name, + "icon": icon, + "version": version, + "handles": _device_handles_to_list(handles), + "model": model, + "device_type": device_type, + "hardware_interface": (hardware_interface.model_dump(exclude_none=True) if hardware_interface else None), + } + + def decorator(cls): + cls._device_registry_meta = base_meta + cls._device_registry_id_meta = id_meta + cls._device_registry_ids = device_ids + + for did in device_ids: + if did in _registered_devices: + raise ValueError(f"@device id 重复: '{did}' 已被 {_registered_devices[did]} 注册") + _registered_devices[did] = cls + + return cls + + return decorator + + +# --------------------------------------------------------------------------- +# @action 方法装饰器 +# --------------------------------------------------------------------------- + +# 区分 "用户没传 action_type" 和 "用户传了 None" +_ACTION_TYPE_UNSET = object() + + +# noinspection PyShadowingNames +def action( + action_type: Any = _ACTION_TYPE_UNSET, + goal: Optional[Dict[str, str]] = None, + feedback: Optional[Dict[str, str]] = None, + result: Optional[Dict[str, str]] = None, + handles: 
Optional[List[_ActionHandleBase]] = None, + goal_default: Optional[Dict[str, Any]] = None, + placeholder_keys: Optional[Dict[str, str]] = None, + always_free: bool = False, + is_protocol: bool = False, + description: str = "", + auto_prefix: bool = False, + parent: bool = False, +): + """ + 动作方法装饰器 + + 标记方法为注册表动作。有三种用法: + 1. @action(action_type=EmptyIn, ...) -- 非 auto, 使用指定 ROS Action 类型 + 2. @action() -- 非 auto, UniLabJsonCommand (从方法签名生成 schema) + 3. 不加 @action -- auto- 前缀, UniLabJsonCommand + + Protocol 用法: + @action(action_type=Add, is_protocol=True) + def AddProtocol(self): ... + 标记该动作为高级协议 (protocol),运行时通过 ROS Action 路由到 + protocol generator 执行。action_type 指向 unilabos_msgs 的 Action 类型。 + + Args: + action_type: ROS Action 消息类型 (如 EmptyIn, SendCmd, HeatChill). + 不传/默认 = UniLabJsonCommand (非 auto). + goal: Goal 字段映射 (ROS字段名 -> 设备参数名). + protocol 模式下可留空,系统自动生成 identity 映射. + feedback: Feedback 字段映射 + result: Result 字段映射 + handles: 动作端口列表 (ActionInputHandle / ActionOutputHandle) + goal_default: Goal 字段默认值映射 (字段名 -> 默认值), 与自动生成的 goal_default 合并 + placeholder_keys: 参数占位符配置 + always_free: 是否为永久闲置动作 (不受排队限制) + is_protocol: 是否为工作站协议 (protocol)。True 时运行时走 protocol generator 路径。 + description: 动作描述 + auto_prefix: 若为 True,动作名使用 auto-{method_name} 形式(与无 @action 时一致) + parent: 若为 True,当方法参数为空 (*args, **kwargs) 时,通过 MRO 从父类获取真实方法参数 + """ + + def decorator(func: F) -> F: + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # action_type 为哨兵值 => 用户没传, 视为 None (UniLabJsonCommand) + resolved_type = None if action_type is _ACTION_TYPE_UNSET else action_type + + meta = { + "action_type": resolved_type, + "goal": goal or {}, + "feedback": feedback or {}, + "result": result or {}, + "handles": _action_handles_to_dict(handles), + "goal_default": goal_default or {}, + "placeholder_keys": placeholder_keys or {}, + "always_free": always_free, + "is_protocol": is_protocol, + "description": description, + "auto_prefix": auto_prefix, + "parent": parent, + } + 
wrapper._action_registry_meta = meta # type: ignore[attr-defined] + + # 设置 _is_always_free 保持与旧 @always_free 装饰器兼容 + if always_free: + wrapper._is_always_free = True # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + return decorator + + +def get_action_meta(func) -> Optional[Dict[str, Any]]: + """获取方法上的 @action 装饰器元数据""" + return getattr(func, "_action_registry_meta", None) + + +def has_action_decorator(func) -> bool: + """检查函数是否带有 @action 装饰器""" + return hasattr(func, "_action_registry_meta") + + +# --------------------------------------------------------------------------- +# @resource 类/函数装饰器 +# --------------------------------------------------------------------------- + + +def resource( + id: str, + category: List[str], + description: str = "", + icon: str = "", + version: str = "1.0.0", + handles: Optional[List[_DeviceHandleBase]] = None, + model: Optional[Dict[str, Any]] = None, + class_type: str = "pylabrobot", +): + """ + 资源类/函数装饰器 + + 将类或工厂函数标记为一个 UniLab-OS 资源,附加注册表元数据。 + + Args: + id: 注册表唯一标识 (必填, 不可重复) + category: 资源分类标签列表 (必填) + description: 资源描述 + icon: 图标路径 + version: 版本号 + handles: 端口列表 (InputHandle / OutputHandle) + model: 可选的 3D 模型配置 + class_type: 资源实现类型 ("python" / "pylabrobot" / "unilabos") + """ + + def decorator(obj): + meta = { + "resource_id": id, + "category": category, + "description": description, + "icon": icon, + "version": version, + "handles": _device_handles_to_list(handles), + "model": model, + "class_type": class_type, + } + obj._resource_registry_meta = meta + + if id in _registered_resources: + raise ValueError(f"@resource id 重复: '{id}' 已被 {_registered_resources[id]} 注册") + _registered_resources[id] = obj + + return obj + + return decorator + + +def get_device_meta(cls, device_id: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + 获取类上的 @device 装饰器元数据。 + + 当 device_id 存在且类使用 ids+id_meta 时,返回合并后的 meta + (base_meta 与 id_meta[device_id] 深度合并)。 + """ + base = getattr(cls, 
"_device_registry_meta", None) + if base is None: + return None + id_meta = getattr(cls, "_device_registry_id_meta", None) or {} + if device_id is None or device_id not in id_meta: + result = dict(base) + ids = getattr(cls, "_device_registry_ids", None) + result["device_id"] = device_id if device_id is not None else (ids[0] if ids else None) + return result + + overrides = id_meta[device_id] + result = dict(base) + result["device_id"] = device_id + for key in ["handles", "description", "icon", "model"]: + if key in overrides: + val = overrides[key] + if key == "handles" and isinstance(val, list): + # handles 必须是 Handle 对象列表 + result[key] = [h.to_registry_dict() for h in val] + else: + result[key] = val + return result + + +def get_resource_meta(obj) -> Optional[Dict[str, Any]]: + """获取对象上的 @resource 装饰器元数据""" + return getattr(obj, "_resource_registry_meta", None) + + +def get_all_registered_devices() -> Dict[str, type]: + """获取所有已注册的设备类""" + return _registered_devices.copy() + + +def get_all_registered_resources() -> Dict[str, Any]: + """获取所有已注册的资源""" + return _registered_resources.copy() + + +def clear_registry(): + """清空全局注册表 (用于测试)""" + _registered_devices.clear() + _registered_resources.clear() + + +# --------------------------------------------------------------------------- +# topic_config / not_action / always_free 装饰器 +# --------------------------------------------------------------------------- + + +def topic_config( + period: Optional[float] = None, + print_publish: Optional[bool] = None, + qos: Optional[int] = None, + name: Optional[str] = None, +) -> Callable[[F], F]: + """ + Topic发布配置装饰器 + + 用于装饰 get_{attr_name} 方法或 @property,控制对应属性的ROS topic发布行为。 + + Args: + period: 发布周期(秒)。None 表示使用默认值 5.0 + print_publish: 是否打印发布日志。None 表示使用节点默认配置 + qos: QoS深度配置。None 表示使用默认值 10 + name: 自定义发布名称。None 表示使用方法名(去掉 get_ 前缀) + + Note: + 与 @property 连用时,@topic_config 必须放在 @property 下面, + 这样装饰器执行顺序为:先 topic_config 添加配置,再 property 包装。 + """ + + def decorator(func: F) -> F: + 
@wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + wrapper._topic_period = period # type: ignore[attr-defined] + wrapper._topic_print_publish = print_publish # type: ignore[attr-defined] + wrapper._topic_qos = qos # type: ignore[attr-defined] + wrapper._topic_name = name # type: ignore[attr-defined] + wrapper._has_topic_config = True # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + return decorator + + +def get_topic_config(func) -> dict: + """获取函数上的 topic 配置 (period, print_publish, qos, name)""" + if hasattr(func, "_has_topic_config") and getattr(func, "_has_topic_config", False): + return { + "period": getattr(func, "_topic_period", None), + "print_publish": getattr(func, "_topic_print_publish", None), + "qos": getattr(func, "_topic_qos", None), + "name": getattr(func, "_topic_name", None), + } + return {} + + +def always_free(func: F) -> F: + """ + 标记动作为永久闲置(不受busy队列限制)的装饰器 + + 被此装饰器标记的 action 方法,在执行时不会受到设备级别的排队限制, + 任何时候请求都可以立即执行。适用于查询类、状态读取类等轻量级操作。 + """ + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + wrapper._is_always_free = True # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + +def is_always_free(func) -> bool: + """检查函数是否被标记为永久闲置""" + return getattr(func, "_is_always_free", False) + + +def not_action(func: F) -> F: + """ + 标记方法为非动作的装饰器 + + 用于装饰 driver 类中的方法,使其在注册表扫描时不被识别为动作。 + 适用于辅助方法、内部工具方法等不应暴露为设备动作的公共方法。 + """ + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + wrapper._is_not_action = True # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + +def is_not_action(func) -> bool: + """检查函数是否被标记为非动作""" + return getattr(func, "_is_not_action", False) diff --git a/unilabos/registry/registry.py b/unilabos/registry/registry.py index 2a277664..355cd6d6 100644 --- a/unilabos/registry/registry.py +++ b/unilabos/registry/registry.py @@ -1,20 +1,57 @@ +""" +统一注册表系统 + +合并了原 Registry 
(YAML 加载) 和 DecoratorRegistry (装饰器/AST 扫描) 的功能, +提供单一入口来构建、验证和查询设备/资源注册表。 +""" + import copy +import importlib +import inspect import io import os import sys -import inspect -import importlib import threading import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Any, Dict, List, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import yaml +from unilabos_msgs.action import EmptyIn, ResourceCreateFromOuter, ResourceCreateFromOuterEasy from unilabos_msgs.msg import Resource from unilabos.config.config import BasicConfig +from unilabos.registry.decorators import ( + get_device_meta, + get_action_meta, + get_resource_meta, + has_action_decorator, + get_all_registered_devices, + get_all_registered_resources, + is_not_action, + is_always_free, + get_topic_config, +) +from unilabos.registry.utils import ( + ROSMsgNotFound, + parse_docstring, + get_json_schema_type, + parse_type_node, + type_node_to_schema, + resolve_type_object, + type_to_schema, + detect_slot_type, + detect_placeholder_keys, + normalize_ast_handles, + normalize_ast_action_handles, + wrap_action_schema, + preserve_field_descriptions, + resolve_method_params_via_import, + SIMPLE_TYPE_MAP, +) from unilabos.resources.graphio import resource_plr_to_ulab, tree_to_list +from unilabos.resources.resource_tracker import ResourceTreeSet from unilabos.ros.msgs.message_converter import ( msg_converter_manager, ros_action_to_json_schema, @@ -23,244 +60,1714 @@ from unilabos.ros.msgs.message_converter import ( ) from unilabos.utils import logger from unilabos.utils.decorator import singleton -from unilabos.utils.import_manager import get_enhanced_class_info, get_class +from unilabos.utils.cls_creator import import_class +from unilabos.utils.import_manager import get_enhanced_class_info from unilabos.utils.type_check import NoAliasDumper +from msgcenterpy.instances.json_schema_instance import JSONSchemaMessageInstance +from 
msgcenterpy.instances.ros2_instance import ROS2MessageInstance -DEFAULT_PATHS = [Path(__file__).absolute().parent] - - -class ROSMsgNotFound(Exception): - pass +_module_hash_cache: Dict[str, Optional[str]] = {} @singleton class Registry: + """ + 统一注册表。 + + 核心流程: + 1. AST 静态扫描 @device/@resource 装饰器 (快速, 无需 import) + 2. 加载 YAML 注册表 (兼容旧格式) + 3. 设置 host_node 内置设备 + 4. verify & resolve (实际 import 验证 + 类型解析) + """ + def __init__(self, registry_paths=None): import ctypes try: + # noinspection PyUnusedImports import unilabos_msgs except ImportError: logger.error("[UniLab Registry] unilabos_msgs模块未找到,请确保已根据官方文档安装unilabos_msgs包。") sys.exit(1) try: ctypes.CDLL(str(Path(unilabos_msgs.__file__).parent / "unilabos_msgs_s__rosidl_typesupport_c.pyd")) - except OSError as e: + except OSError: pass - self.registry_paths = DEFAULT_PATHS.copy() # 使用copy避免修改默认值 + self.registry_paths = [Path(__file__).absolute().parent] if registry_paths: self.registry_paths.extend(registry_paths) - self.ResourceCreateFromOuter = self._replace_type_with_class( - "ResourceCreateFromOuter", "host_node", f"动作 create_resource_detailed" - ) - self.ResourceCreateFromOuterEasy = self._replace_type_with_class( - "ResourceCreateFromOuterEasy", "host_node", f"动作 create_resource" - ) - self.EmptyIn = self._replace_type_with_class("EmptyIn", "host_node", f"") - self.StrSingleInput = self._replace_type_with_class("StrSingleInput", "host_node", f"") - self.device_type_registry = {} - self.device_module_to_registry = {} - self.resource_type_registry = {} - self._setup_called = False # 跟踪setup是否已调用 - self._registry_lock = threading.Lock() # 多线程加载时的锁 - # 其他状态变量 - # self.is_host_mode = False # 移至BasicConfig中 + logger.debug(f"[UniLab Registry] registry_paths: {self.registry_paths}") - def setup(self, complete_registry=False, upload_registry=False): - # 检查是否已调用过setup + self.device_type_registry: Dict[str, Any] = {} + self.resource_type_registry: Dict[str, Any] = {} + + self._setup_called = False + self._startup_executor: 
Optional[ThreadPoolExecutor] = None + + # ------------------------------------------------------------------ + # 统一入口 + # ------------------------------------------------------------------ + + def setup(self, devices_dirs=None, upload_registry=False): + """统一构建注册表入口。""" if self._setup_called: logger.critical("[UniLab Registry] setup方法已被调用过,不允许多次调用") return - from unilabos.app.web.utils.action_utils import get_yaml_from_goal_type - - # 获取 HostNode 类的增强信息,用于自动生成 action schema - host_node_enhanced_info = get_enhanced_class_info( - "unilabos.ros.nodes.presets.host_node:HostNode", use_dynamic=True + self._startup_executor = ThreadPoolExecutor( + max_workers=8, thread_name_prefix="RegistryStartup" ) - # 为 test_latency 生成 schema,保留原有 description - test_latency_method_info = host_node_enhanced_info.get("action_methods", {}).get("test_latency", {}) - test_latency_schema = self._generate_unilab_json_command_schema( - test_latency_method_info.get("args", []), - "test_latency", - test_latency_method_info.get("return_annotation"), - ) - test_latency_schema["description"] = "用于测试延迟的动作,返回延迟时间和时间差。" + # 1. AST 静态扫描 (快速, 无需 import) + self._run_ast_scan(devices_dirs, upload_registry=upload_registry) - test_resource_method_info = host_node_enhanced_info.get("action_methods", {}).get("test_resource", {}) - test_resource_schema = self._generate_unilab_json_command_schema( - test_resource_method_info.get("args", []), - "test_resource", - test_resource_method_info.get("return_annotation"), - ) - test_resource_schema["description"] = "用于测试物料、设备和样本。" + # 2. 
Host node 内置设备 + self._setup_host_node() - create_resource_method_info = host_node_enhanced_info.get("action_methods", {}).get("create_resource", {}) - create_resource_schema = self._generate_unilab_json_command_schema( - create_resource_method_info.get("args", []), - "create_resource", - create_resource_method_info.get("return_annotation"), - ) - create_resource_schema["description"] = "用于创建物料" - raw_create_resource_schema = ros_action_to_json_schema( - self.ResourceCreateFromOuterEasy, "用于创建或更新物料资源,每次传入一个物料信息。" - ) - raw_create_resource_schema["properties"]["result"] = create_resource_schema["properties"]["result"] - - self.device_type_registry.update( - { - "host_node": { - "description": "UniLabOS主机节点", - "class": { - "module": "unilabos.ros.nodes.presets.host_node", - "type": "python", - "status_types": {}, - "action_value_mappings": { - "create_resource_detailed": { - "type": self.ResourceCreateFromOuter, - "goal": { - "resources": "resources", - "device_ids": "device_ids", - "bind_parent_ids": "bind_parent_ids", - "bind_locations": "bind_locations", - "other_calling_params": "other_calling_params", - }, - "feedback": {}, - "result": {"success": "success"}, - "schema": ros_action_to_json_schema( - self.ResourceCreateFromOuter, "用于创建或更新物料资源,每次传入多个物料信息。" - ), - "goal_default": yaml.safe_load( - io.StringIO(get_yaml_from_goal_type(self.ResourceCreateFromOuter.Goal)) - ), - "handles": {}, - }, - "create_resource": { - "type": self.ResourceCreateFromOuterEasy, - "goal": { - "res_id": "res_id", - "class_name": "class_name", - "parent": "parent", - "device_id": "device_id", - "bind_locations": "bind_locations", - "liquid_input_slot": "liquid_input_slot[]", - "liquid_type": "liquid_type[]", - "liquid_volume": "liquid_volume[]", - "slot_on_deck": "slot_on_deck", - }, - "feedback": {}, - "result": {"success": "success"}, - "schema": raw_create_resource_schema, - "goal_default": yaml.safe_load( - 
io.StringIO(get_yaml_from_goal_type(self.ResourceCreateFromOuterEasy.Goal)) - ), - "handles": { - "output": [ - { - "handler_key": "labware", - "data_type": "resource", - "label": "Labware", - "data_source": "executor", - "data_key": "created_resource_tree.@flatten", - }, - { - "handler_key": "liquid_slots", - "data_type": "resource", - "label": "LiquidSlots", - "data_source": "executor", - "data_key": "liquid_input_resource_tree.@flatten", - }, - { - "handler_key": "materials", - "data_type": "resource", - "label": "AllMaterials", - "data_source": "executor", - "data_key": "[created_resource_tree,liquid_input_resource_tree].@flatten.@flatten", - }, - ] - }, - "placeholder_keys": { - "res_id": "unilabos_resources", # 将当前实验室的全部物料id作为下拉框可选择 - "device_id": "unilabos_devices", # 将当前实验室的全部设备id作为下拉框可选择 - "parent": "unilabos_nodes", # 将当前实验室的设备/物料作为下拉框可选择 - "class_name": "unilabos_class", # 当前实验室物料的class name - "slot_on_deck": "unilabos_resource_slot:parent", # 勾选的parent的config中的sites的name,展示name,参数对应slot(index) - }, - }, - "test_latency": { - "type": ( - "UniLabJsonCommandAsync" - if test_latency_method_info.get("is_async", False) - else "UniLabJsonCommand" - ), - "goal": {}, - "feedback": {}, - "result": {}, - "schema": test_latency_schema, - "goal_default": { - arg["name"]: arg["default"] for arg in test_latency_method_info.get("args", []) - }, - "handles": {}, - }, - "auto-test_resource": { - "type": "UniLabJsonCommand", - "goal": {}, - "feedback": {}, - "result": {}, - "schema": test_resource_schema, - "placeholder_keys": { - "device": "unilabos_devices", - "devices": "unilabos_devices", - "resource": "unilabos_resources", - "resources": "unilabos_resources", - }, - "goal_default": {}, - "handles": { - "input": [ - { - "handler_key": "input_resources", - "data_type": "resource", - "label": "InputResources", - "data_source": "handle", - "data_key": "resources", # 不为空 - }, - ] - }, - }, - }, - }, - "version": "1.0.0", - "category": [], - "config_info": [], - "icon": 
"icon_device.webp", - "registry_type": "device", - "handles": [], # virtue采用了不同的handle - "init_param_schema": {}, - "file_path": "/", - } - } - ) - # 为host_node添加内置的驱动命令动作 - self._add_builtin_actions(self.device_type_registry["host_node"], "host_node") - logger.trace(f"[UniLab Registry] ----------Setup----------") + # 3. YAML 注册表加载 (兼容旧格式) self.registry_paths = [Path(path).absolute() for path in self.registry_paths] for i, path in enumerate(self.registry_paths): sys_path = path.parent logger.trace(f"[UniLab Registry] Path {i+1}/{len(self.registry_paths)}: {sys_path}") sys.path.append(str(sys_path)) - self.load_device_types(path, complete_registry) + self.load_device_types(path) if BasicConfig.enable_resource_load: - self.load_resource_types(path, complete_registry, upload_registry) + self.load_resource_types(path, upload_registry) else: - logger.warning("跳过了资源注册表加载!") - logger.info("[UniLab Registry] 注册表设置完成") - # 标记setup已被调用 + logger.warning( + "[UniLab Registry] 资源加载已禁用 (enable_resource_load=False),跳过资源注册表加载" + ) + self._startup_executor.shutdown(wait=True) + self._startup_executor = None self._setup_called = True + logger.trace(f"[UniLab Registry] ----------Setup Complete----------") + + # ------------------------------------------------------------------ + # Host node 设置 + # ------------------------------------------------------------------ + + def _setup_host_node(self): + """设置 host_node 内置设备 — 基于 _run_ast_scan 已扫描的结果进行覆写。""" + # 从 AST 扫描结果中取出 host_node 的 action_value_mappings + ast_entry = self.device_type_registry.get("host_node", {}) + ast_actions = ast_entry.get("class", {}).get("action_value_mappings", {}) + + # 取出 AST 生成的 auto-method entries, 补充特定覆写 + test_latency_action = ast_actions.get("auto-test_latency", {}) + test_resource_action = ast_actions.get("auto-test_resource", {}) + test_resource_action["handles"] = { + "input": [ + { + "handler_key": "input_resources", + "data_type": "resource", + "label": "InputResources", + "data_source": "handle", + 
"data_key": "resources", + }, + ] + } + + create_resource_action = ast_actions.get("auto-create_resource", {}) + raw_create_resource_schema = ros_action_to_json_schema( + ResourceCreateFromOuterEasy, "用于创建或更新物料资源,每次传入一个物料信息。" + ) + raw_create_resource_schema["properties"]["result"] = create_resource_action["schema"]["properties"]["result"] + + # 覆写: 保留硬编码的 ROS2 action + AST 生成的 auto-method + self.device_type_registry["host_node"] = { + "class": { + "module": "unilabos.ros.nodes.presets.host_node:HostNode", + "status_types": {}, + "action_value_mappings": { + "create_resource_detailed": { + "type": ResourceCreateFromOuter, + "goal": { + "resources": "resources", + "device_ids": "device_ids", + "bind_parent_ids": "bind_parent_ids", + "bind_locations": "bind_locations", + "other_calling_params": "other_calling_params", + }, + "feedback": {}, + "result": {"success": "success"}, + "schema": ros_action_to_json_schema(ResourceCreateFromOuter), + "goal_default": ROS2MessageInstance(ResourceCreateFromOuter.Goal()).get_python_dict(), + "handles": {}, + }, + "create_resource": { + "type": ResourceCreateFromOuterEasy, + "goal": { + "res_id": "res_id", + "class_name": "class_name", + "parent": "parent", + "device_id": "device_id", + "bind_locations": "bind_locations", + "liquid_input_slot": "liquid_input_slot[]", + "liquid_type": "liquid_type[]", + "liquid_volume": "liquid_volume[]", + "slot_on_deck": "slot_on_deck", + }, + "feedback": {}, + "result": {"success": "success"}, + "schema": raw_create_resource_schema, + "goal_default": ROS2MessageInstance(ResourceCreateFromOuterEasy.Goal()).get_python_dict(), + "handles": { + "output": [ + { + "handler_key": "labware", + "data_type": "resource", + "label": "Labware", + "data_source": "executor", + "data_key": "created_resource_tree.@flatten", + }, + { + "handler_key": "liquid_slots", + "data_type": "resource", + "label": "LiquidSlots", + "data_source": "executor", + "data_key": "liquid_input_resource_tree.@flatten", + }, + { + 
"handler_key": "materials", + "data_type": "resource", + "label": "AllMaterials", + "data_source": "executor", + "data_key": "[created_resource_tree,liquid_input_resource_tree].@flatten.@flatten", + }, + ] + }, + "placeholder_keys": { + "res_id": "unilabos_resources", + "device_id": "unilabos_devices", + "parent": "unilabos_nodes", + "class_name": "unilabos_class", + }, + }, + "test_latency": test_latency_action, + "auto-test_resource": test_resource_action, + }, + "init_params": {}, + }, + "version": "1.0.0", + "category": [], + "config_info": [], + "icon": "icon_device.webp", + "registry_type": "device", + "description": "Host Node", + "handles": [], + "init_param_schema": {}, + "file_path": "/", + } + self._add_builtin_actions(self.device_type_registry["host_node"], "host_node") + + # ------------------------------------------------------------------ + # AST 静态扫描 + # ------------------------------------------------------------------ + + def _run_ast_scan(self, devices_dirs=None, upload_registry=False): + """ + 执行 AST 静态扫描,从 Python 代码中提取 @device / @resource 装饰器元数据。 + 无需 import 任何驱动模块,速度极快。 + + 启用文件级缓存:对每个 .py 文件记录 MD5/size/mtime,未变化的文件直接 + 复用上次的扫描结果,大幅减少重复启动时的耗时。 + + 扫描策略: + - 默认扫描 unilabos 包所在目录(即 unilabos 的父目录) + - 如果传入 devices_dirs,额外扫描这些目录(并将其父目录加入 sys.path) + """ + import time as _time + from unilabos.registry.ast_registry_scanner import ( + scan_directory, load_scan_cache, save_scan_cache, + ) + + scan_t0 = _time.perf_counter() + + # 确保 executor 存在 + own_executor = False + if self._startup_executor is None: + self._startup_executor = ThreadPoolExecutor( + max_workers=8, thread_name_prefix="RegistryStartup" + ) + own_executor = True + + # 加载缓存 + cache_path = None + if BasicConfig.working_dir: + cache_path = Path(BasicConfig.working_dir) / "ast_scan_cache.json" + cache = load_scan_cache(cache_path) + + # 默认:扫描 unilabos 包所在的父目录 + pkg_root = Path(__file__).resolve().parent.parent # .../unilabos + python_path = pkg_root.parent # .../Uni-Lab-OS + scan_root = 
pkg_root # 扫描 unilabos/ 整个包 + + # 额外的 --devices 目录:把它们的父目录加入 sys.path + extra_dirs: list[Path] = [] + if devices_dirs: + for d in devices_dirs: + d_path = Path(d).resolve() + if not d_path.is_dir(): + logger.warning(f"[UniLab Registry] --devices 路径不存在或不是目录: {d_path}") + continue + parent_dir = str(d_path.parent) + if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + logger.info(f"[UniLab Registry] 添加 Python 路径: {parent_dir}") + extra_dirs.append(d_path) + + # 主扫描 + exclude_files = {"lab_resources.py"} if not BasicConfig.extra_resource else None + scan_result = scan_directory( + scan_root, python_path=python_path, executor=self._startup_executor, + exclude_files=exclude_files, cache=cache, + ) + if exclude_files: + logger.info( + f"[UniLab Registry] 排除扫描文件: {exclude_files} " + f"(可通过 --extra_resource 启用加载)" + ) + + # 合并缓存统计 + total_stats = scan_result.pop("_cache_stats", {"hits": 0, "misses": 0, "total": 0}) + + # 额外目录逐个扫描并合并 + for d_path in extra_dirs: + extra_result = scan_directory( + d_path, python_path=str(d_path.parent), executor=self._startup_executor, + cache=cache, + ) + extra_stats = extra_result.pop("_cache_stats", {"hits": 0, "misses": 0, "total": 0}) + total_stats["hits"] += extra_stats["hits"] + total_stats["misses"] += extra_stats["misses"] + total_stats["total"] += extra_stats["total"] + + for did, dmeta in extra_result.get("devices", {}).items(): + if did in scan_result.get("devices", {}): + existing = scan_result["devices"][did].get("file_path", "?") + new_file = dmeta.get("file_path", "?") + raise ValueError( + f"@device id 重复: '{did}' 同时出现在 {existing} 和 {new_file}" + ) + scan_result.setdefault("devices", {})[did] = dmeta + for rid, rmeta in extra_result.get("resources", {}).items(): + if rid in scan_result.get("resources", {}): + existing = scan_result["resources"][rid].get("file_path", "?") + new_file = rmeta.get("file_path", "?") + raise ValueError( + f"@resource id 重复: '{rid}' 同时出现在 {existing} 和 {new_file}" + ) + 
scan_result.setdefault("resources", {})[rid] = rmeta + + # 持久化缓存 + cache["saved_at"] = _time.strftime("%Y-%m-%d %H:%M:%S") + save_scan_cache(cache_path, cache) + + # 缓存命中统计 + if total_stats["total"] > 0: + logger.info( + f"[UniLab Registry] AST 缓存统计: " + f"{total_stats['hits']}/{total_stats['total']} 命中, " + f"{total_stats['misses']} 重新解析" + ) + + ast_devices = scan_result.get("devices", {}) + ast_resources = scan_result.get("resources", {}) + + # build 结果缓存:当所有 AST 文件命中时跳过 _build_*_entry_from_ast + all_ast_hit = total_stats["misses"] == 0 and total_stats["total"] > 0 + build_cache = self._load_config_cache() if all_ast_hit else {} + cached_build = build_cache.get("_build_results") + + if all_ast_hit and cached_build: + cached_devices = cached_build.get("devices", {}) + cached_resources = cached_build.get("resources", {}) + if set(cached_devices) == set(ast_devices) and set(cached_resources) == set(ast_resources): + self.device_type_registry.update(cached_devices) + self.resource_type_registry.update(cached_resources) + logger.info( + f"[UniLab Registry] build 缓存命中: 跳过 {len(cached_devices)} 设备 + " + f"{len(cached_resources)} 资源的 entry 构建" + ) + else: + cached_build = None + + if not cached_build: + build_t0 = _time.perf_counter() + + for device_id, ast_meta in ast_devices.items(): + entry = self._build_device_entry_from_ast(device_id, ast_meta) + if entry: + self.device_type_registry[device_id] = entry + + for resource_id, ast_meta in ast_resources.items(): + entry = self._build_resource_entry_from_ast(resource_id, ast_meta) + if entry: + self.resource_type_registry[resource_id] = entry + + build_elapsed = _time.perf_counter() - build_t0 + logger.info(f"[UniLab Registry] entry 构建耗时: {build_elapsed:.2f}s") + + if not build_cache: + build_cache = self._load_config_cache() + build_cache["_build_results"] = { + "devices": {k: v for k, v in self.device_type_registry.items() if k in ast_devices}, + "resources": {k: v for k, v in self.resource_type_registry.items() if k 
in ast_resources}, + } + + # upload 模式下,利用线程池并行 import pylabrobot 资源并生成 config_info + if upload_registry: + if build_cache: + self._populate_resource_config_info(config_cache=build_cache) + self._save_config_cache(build_cache) + else: + self._populate_resource_config_info() + elif build_cache and not cached_build: + self._save_config_cache(build_cache) + + ast_device_count = len(ast_devices) + ast_resource_count = len(ast_resources) + scan_elapsed = _time.perf_counter() - scan_t0 + if ast_device_count > 0 or ast_resource_count > 0: + logger.info( + f"[UniLab Registry] AST 扫描完成: {ast_device_count} 设备, " + f"{ast_resource_count} 资源 (耗时 {scan_elapsed:.2f}s)" + ) + + if own_executor: + self._startup_executor.shutdown(wait=False) + self._startup_executor = None + + # ------------------------------------------------------------------ + # 类型辅助 (共享, 去重后的单一实现) + # ------------------------------------------------------------------ + + def _replace_type_with_class(self, type_name: str, device_id: str, field_name: str) -> Any: + """将类型名称替换为实际的 ROS 消息类对象""" + if not type_name or type_name == "": + return type_name + + # 泛型类型映射 + if "[" in type_name: + generic_mapping = { + "List[int]": "Int64MultiArray", + "list[int]": "Int64MultiArray", + "List[float]": "Float64MultiArray", + "list[float]": "Float64MultiArray", + "List[bool]": "Int8MultiArray", + "list[bool]": "Int8MultiArray", + } + mapped = generic_mapping.get(type_name) + if mapped: + cls = msg_converter_manager.search_class(mapped) + if cls: + return cls + logger.debug( + f"[Registry] 设备 {device_id} 的 {field_name} " + f"泛型类型 '{type_name}' 映射为 String" + ) + return String + + convert_manager = { + "str": "String", + "bool": "Bool", + "int": "Int64", + "float": "Float64", + } + type_name = convert_manager.get(type_name, type_name) + if ":" in type_name: + type_class = msg_converter_manager.get_class(type_name) + else: + type_class = msg_converter_manager.search_class(type_name) + if type_class: + return type_class + else: + # 
dataclass / TypedDict 等非 ROS2 类型,序列化为 JSON 字符串 + logger.trace( + f"[Registry] 类型 '{type_name}' 非 ROS2 消息类型 (设备 {device_id} {field_name}),映射为 String" + ) + return String + + # ---- 类型字符串 -> JSON Schema type ---- + # (常量和工具函数已移至 unilabos.registry.utils) + + def _generate_schema_from_info( + self, param_name: str, param_type: Union[str, Tuple[str]], param_default: Any, + import_map: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """根据参数信息生成 JSON Schema。 + 支持复杂类型字符串如 'Optional[Dict[str, Any]]'、'List[int]' 等。 + 当提供 import_map 时,可解析 TypedDict 等自定义类型。""" + + prop_schema: Dict[str, Any] = {} + + if isinstance(param_type, str) and ("[" in param_type or "|" in param_type): + # 复杂泛型 — ast.parse 解析结构,递归生成 schema + node = parse_type_node(param_type) + if node is not None: + prop_schema = type_node_to_schema(node, import_map) + # slot 标记 fallback(正常不应走到这里,上层会拦截) + if "$slot" in prop_schema: + prop_schema = {"type": "object"} + else: + prop_schema["type"] = "string" + elif isinstance(param_type, str): + # 简单类型名,但可能是 import_map 中的自定义类型 + json_type = SIMPLE_TYPE_MAP.get(param_type.lower()) + if json_type: + prop_schema["type"] = json_type + elif import_map and param_type in import_map: + type_obj = resolve_type_object(import_map[param_type]) + if type_obj is not None: + prop_schema = type_to_schema(type_obj) + else: + # 无法 import 的自定义类型,默认当 object 处理(与 YAML runtime 路径一致) + prop_schema["type"] = "object" + else: + json_type = get_json_schema_type(param_type) + if json_type == "string" and param_type and param_type.lower() not in SIMPLE_TYPE_MAP: + # 不在已知简单类型中的未知类型名,当 object 处理 + prop_schema["type"] = "object" + else: + prop_schema["type"] = json_type + elif isinstance(param_type, tuple): + if len(param_type) == 2: + outer_type, inner_type = param_type + outer_json_type = get_json_schema_type(outer_type) + prop_schema["type"] = outer_json_type + # Any 值类型不加 additionalProperties/items (等同于无约束) + if isinstance(inner_type, str) and inner_type in ("Any", "None", "Unknown"): + 
pass + else: + inner_json_type = get_json_schema_type(inner_type) + if outer_json_type == "array": + prop_schema["items"] = {"type": inner_json_type} + elif outer_json_type == "object": + prop_schema["additionalProperties"] = {"type": inner_json_type} + else: + prop_schema["type"] = "string" + else: + prop_schema["type"] = get_json_schema_type(param_type) + + if param_default is not None: + prop_schema["default"] = param_default + + return prop_schema + + def _generate_unilab_json_command_schema( + self, method_args: list, docstring: Optional[str] = None + ) -> Dict[str, Any]: + """根据方法参数和 docstring 生成 UniLabJsonCommand schema""" + doc_info = parse_docstring(docstring) + param_descs = doc_info.get("params", {}) + + schema = { + "type": "object", + "properties": {}, + "required": [], + } + for arg_info in method_args: + param_name = arg_info.get("name", "") + param_type = arg_info.get("type", "") + param_default = arg_info.get("default") + param_required = arg_info.get("required", True) + + is_slot, is_list_slot = detect_slot_type(param_type) + if is_slot == "ResourceSlot": + if is_list_slot: + schema["properties"][param_name] = { + "items": ros_message_to_json_schema(Resource, param_name), + "type": "array", + } + else: + schema["properties"][param_name] = ros_message_to_json_schema( + Resource, param_name + ) + elif is_slot == "DeviceSlot": + schema["properties"][param_name] = {"type": "string", "description": "device reference"} + else: + schema["properties"][param_name] = self._generate_schema_from_info( + param_name, param_type, param_default + ) + + if param_name in param_descs: + schema["properties"][param_name]["description"] = param_descs[param_name] + + if param_required: + schema["required"].append(param_name) + + return schema + + def _generate_status_types_schema(self, status_methods: Dict[str, Any]) -> Dict[str, Any]: + """根据 status 方法信息生成 status_types schema""" + status_schema: Dict[str, Any] = { + "type": "object", + "properties": {}, + "required": 
[], + } + for status_name, status_info in status_methods.items(): + return_type = status_info.get("return_type", "str") + status_schema["properties"][status_name] = self._generate_schema_from_info( + status_name, return_type, None + ) + status_schema["required"].append(status_name) + return status_schema + + # ------------------------------------------------------------------ + # 方法签名分析 -- 委托给 ImportManager + # ------------------------------------------------------------------ + + @staticmethod + def _analyze_method_signature(method) -> Dict[str, Any]: + """分析方法签名,提取参数信息""" + from unilabos.utils.import_manager import default_manager + try: + return default_manager._analyze_method_signature(method) + except (ValueError, TypeError): + return {"args": [], "is_async": inspect.iscoroutinefunction(method)} + + @staticmethod + def _get_return_type_from_method(method) -> str: + """获取方法的返回类型字符串""" + from unilabos.utils.import_manager import default_manager + return default_manager._get_return_type_from_method(method) + + # ------------------------------------------------------------------ + # 动态类信息提取 (import-based) + # ------------------------------------------------------------------ + + def _extract_class_info(self, cls) -> Dict[str, Any]: + """ + 从类中提取 init 参数、状态方法和动作方法信息。 + """ + result = { + "class_name": cls.__name__, + "init_params": self._analyze_method_signature(cls.__init__)["args"], + "status_methods": {}, + "action_methods": {}, + "explicit_actions": {}, + "decorated_no_type_actions": {}, + } + + for name, method in cls.__dict__.items(): + if name.startswith("_"): + continue + + # property => status + if isinstance(method, property): + return_type = self._get_return_type_from_method(method.fget) if method.fget else "Any" + status_entry = { + "name": name, + "return_type": return_type, + } + if method.fget: + tc = get_topic_config(method.fget) + if tc: + status_entry["topic_config"] = tc + result["status_methods"][name] = status_entry + + if method.fset: + 
setter_info = self._analyze_method_signature(method.fset) + action_meta = get_action_meta(method.fset) + if action_meta and action_meta.get("action_type") is not None: + result["explicit_actions"][name] = { + "method_info": setter_info, + "action_meta": action_meta, + } + continue + + if not callable(method): + continue + + if is_not_action(method): + continue + + # @topic_config 装饰的非 property 方法视为状态方法,不作为 action + tc = get_topic_config(method) + if tc: + return_type = self._get_return_type_from_method(method) + prop_name = name[4:] if name.startswith("get_") else name + result["status_methods"][prop_name] = { + "name": prop_name, + "return_type": return_type, + "topic_config": tc, + } + continue + + method_info = self._analyze_method_signature(method) + action_meta = get_action_meta(method) + + if action_meta: + action_type = action_meta.get("action_type") + if action_type is not None: + result["explicit_actions"][name] = { + "method_info": method_info, + "action_meta": action_meta, + } + else: + result["decorated_no_type_actions"][name] = { + "method_info": method_info, + "action_meta": action_meta, + } + elif has_action_decorator(method): + result["explicit_actions"][name] = { + "method_info": method_info, + "action_meta": action_meta or {}, + } + else: + result["action_methods"][name] = method_info + + return result + + # ------------------------------------------------------------------ + # 设备注册表条目构建 (import-based) + # ------------------------------------------------------------------ + + def _build_device_entry(self, cls, device_meta: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + 根据类和装饰器元数据构建一个设备的完整注册表条目。 + """ + class_info = self._extract_class_info(cls) + module_str = f"{cls.__module__}:{cls.__name__}" + + # --- status_types --- + status_types_str = {} + status_types_ros = {} + status_str_type_mapping = {} + + for name, info in class_info["status_methods"].items(): + ret_type = info.get("return_type", "str") + if isinstance(ret_type, 
tuple) or ret_type in ["Any", "None", "Unknown"]: + ret_type = "String" + status_types_str[name] = ret_type + + target_type = self._replace_type_with_class(ret_type, device_meta.get("device_id", ""), f"状态 {name}") + if target_type in [dict, list]: + target_type = String + if target_type: + status_types_ros[name] = target_type + status_str_type_mapping[ret_type] = target_type + + status_types_str = dict(sorted(status_types_str.items())) + + # --- action_value_mappings --- + action_value_mappings_yaml = {} + action_value_mappings_runtime = {} + action_str_type_mapping = { + "UniLabJsonCommand": "UniLabJsonCommand", + "UniLabJsonCommandAsync": "UniLabJsonCommandAsync", + } + + # 1) auto- 动作 + for method_name, method_info in class_info["action_methods"].items(): + is_async = method_info.get("is_async", False) + type_str = "UniLabJsonCommandAsync" if is_async else "UniLabJsonCommand" + schema = self._generate_unilab_json_command_schema( + method_info["args"], + docstring=getattr(getattr(cls, method_name, None), "__doc__", None), + ) + goal_default = {a["name"]: a.get("default") for a in method_info["args"]} + + action_entry = { + "type": type_str, + "goal": {}, + "feedback": {}, + "result": {}, + "schema": schema, + "goal_default": goal_default, + "handles": {}, + } + action_value_mappings_yaml[f"auto-{method_name}"] = action_entry + action_value_mappings_runtime[f"auto-{method_name}"] = copy.deepcopy(action_entry) + + # 2) @action() 无 action_type + for method_name, info in class_info["decorated_no_type_actions"].items(): + method_info = info["method_info"] + action_meta = info["action_meta"] + is_async = method_info.get("is_async", False) + type_str = "UniLabJsonCommandAsync" if is_async else "UniLabJsonCommand" + schema = self._generate_unilab_json_command_schema( + method_info["args"], + docstring=getattr(getattr(cls, method_name, None), "__doc__", None), + ) + goal_default = {a["name"]: a.get("default") for a in method_info["args"]} + + action_name = 
action_meta.get("action_name", method_name) + action_entry = { + "type": type_str, + "goal": {}, + "feedback": {}, + "result": {}, + "schema": schema, + "goal_default": goal_default, + "handles": {}, + } + if is_always_free(getattr(cls, method_name, None)): + action_entry["always_free"] = True + action_value_mappings_yaml[action_name] = action_entry + action_value_mappings_runtime[action_name] = copy.deepcopy(action_entry) + + # 3) @action(action_type=X) + for method_name, info in class_info["explicit_actions"].items(): + method_info = info["method_info"] + action_meta = info["action_meta"] + action_type_raw = action_meta.get("action_type", "") + action_name = action_meta.get("action_name", method_name) + + action_type_obj = None + if isinstance(action_type_raw, type): + action_type_obj = action_type_raw + action_type_str = f"{action_type_raw.__module__}:{action_type_raw.__name__}" + elif isinstance(action_type_raw, str) and "." in action_type_raw and ":" not in action_type_raw: + parts = action_type_raw.rsplit(".", 1) + action_type_str = f"{parts[0]}:{parts[1]}" if len(parts) == 2 else action_type_raw + action_type_obj = resolve_type_object(action_type_str) + else: + action_type_str = str(action_type_raw) + if ":" in action_type_str: + action_type_obj = resolve_type_object(action_type_str) + + action_str_type_mapping[action_type_str] = action_type_str + + # goal: 优先方法参数 identity, 其次 MRO 父类参数 (需 parent=True), 最后 ROS2 Goal identity + method_args = method_info.get("args", []) + goal = {a["name"]: a["name"] for a in method_args} + if not goal and action_meta.get("parent"): + for base_cls in cls.__mro__: + if method_name not in base_cls.__dict__: + continue + base_method = base_cls.__dict__[method_name] + actual = getattr(base_method, "__wrapped__", base_method) + if isinstance(actual, (staticmethod, classmethod)): + actual = actual.__func__ + if not callable(actual): + continue + try: + sig = inspect.signature(actual, follow_wrapped=True) + params = [ + p.name for p 
in sig.parameters.values() + if p.name not in ("self", "cls") + and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + ] + if params: + goal = {p: p for p in params} + break + except (ValueError, TypeError): + continue + if not goal and action_type_obj is not None and hasattr(action_type_obj, "Goal"): + try: + goal = {k: k for k in action_type_obj.Goal.get_fields_and_field_types()} + except Exception: + pass + goal_mapping_override = action_meta.get("goal_mapping", {}) + if goal_mapping_override: + override_values = set(goal_mapping_override.values()) + goal = {k: v for k, v in goal.items() if not (k == v and v in override_values)} + goal.update(goal_mapping_override) + + # feedback / result: ROS2 identity + override + feedback = {} + if action_type_obj is not None and hasattr(action_type_obj, "Feedback"): + try: + feedback = {k: k for k in action_type_obj.Feedback.get_fields_and_field_types()} + except Exception: + pass + feedback.update(action_meta.get("feedback_mapping", {})) + + result_mapping = {} + if action_type_obj is not None and hasattr(action_type_obj, "Result"): + try: + result_mapping = {k: k for k in action_type_obj.Result.get_fields_and_field_types()} + except Exception: + pass + result_mapping.update(action_meta.get("result_mapping", {})) + + goal_default = {} + if action_type_obj is not None and hasattr(action_type_obj, "Goal"): + try: + goal_default = ROS2MessageInstance(action_type_obj.Goal()).get_python_dict() + except Exception: + pass + + action_entry = { + "type": action_type_str, + "goal": goal, + "feedback": feedback, + "result": result_mapping, + "schema": ros_action_to_json_schema(action_type_str), + "goal_default": goal_default, + "handles": {}, + } + if is_always_free(getattr(cls, method_name, None)): + action_entry["always_free"] = True + action_value_mappings_yaml[action_name] = action_entry + action_value_mappings_runtime[action_name] = copy.deepcopy(action_entry) + + action_value_mappings_yaml = 
dict(sorted(action_value_mappings_yaml.items())) + action_value_mappings_runtime = dict(sorted(action_value_mappings_runtime.items())) + + # --- init_param_schema --- + init_schema = self._generate_unilab_json_command_schema(class_info["init_params"]) + + # --- handles --- + handles_raw = device_meta.get("handles", []) + handles = [] + for h in handles_raw: + if isinstance(h, dict): + handles.append(h) + elif hasattr(h, "to_dict"): + handles.append(h.to_dict()) + + # --- 构建 YAML 版本 --- + yaml_entry: Dict[str, Any] = { + "category": device_meta.get("category", []), + "class": { + "module": module_str, + "status_types": status_types_str, + "action_value_mappings": action_value_mappings_yaml, + "init_params": {a["name"]: a.get("type", "") for a in class_info["init_params"]}, + }, + "description": device_meta.get("description", ""), + "handles": handles, + "icon": device_meta.get("icon", ""), + "init_param_schema": init_schema, + "version": device_meta.get("version", "1.0.0"), + } + + # --- 构建运行时版本 --- + runtime_entry: Dict[str, Any] = { + "category": device_meta.get("category", []), + "class": { + "module": module_str, + "status_types": status_types_ros, + "action_value_mappings": action_value_mappings_runtime, + "init_params": {a["name"]: a.get("type", "") for a in class_info["init_params"]}, + }, + "description": device_meta.get("description", ""), + "handles": handles, + "icon": device_meta.get("icon", ""), + "init_param_schema": init_schema, + "version": device_meta.get("version", "1.0.0"), + } + + return yaml_entry, runtime_entry + + def _build_resource_entry(self, obj, resource_meta: Dict[str, Any]) -> Dict[str, Any]: + """根据 @resource 元数据构建资源注册表条目""" + module_str = f"{obj.__module__}:{obj.__name__}" if hasattr(obj, "__name__") else "" + + entry = { + "category": resource_meta.get("category") or [], + "class": { + "module": module_str, + "type": resource_meta.get("class_type", "python"), + }, + "description": resource_meta.get("description", ""), + "handles": 
[], + "icon": resource_meta.get("icon", ""), + "init_param_schema": {}, + "version": resource_meta.get("version", "1.0.0"), + } + + if resource_meta.get("model"): + entry["model"] = resource_meta["model"] + + return entry + + # ------------------------------------------------------------------ + # 内置动作 + # ------------------------------------------------------------------ + + def _add_builtin_actions(self, device_config: Dict[str, Any], device_id: str): + """为设备添加内置的驱动命令动作""" + str_single_input = self._replace_type_with_class("StrSingleInput", device_id, "内置动作") + for additional_action in ["_execute_driver_command", "_execute_driver_command_async"]: + try: + goal_default = ROS2MessageInstance(str_single_input.Goal()).get_python_dict() + except Exception: + goal_default = {"string": ""} + + device_config["class"]["action_value_mappings"][additional_action] = { + "type": str_single_input, + "goal": {"string": "string"}, + "feedback": {}, + "result": {}, + "schema": ros_action_to_json_schema(str_single_input), + "goal_default": goal_default, + "handles": {}, + } + + # ------------------------------------------------------------------ + # AST-based 注册表条目构建 + # ------------------------------------------------------------------ + + def _build_device_entry_from_ast(self, device_id: str, ast_meta: dict) -> Dict[str, Any]: + """ + Build a device registry entry from AST-scanned metadata. + Uses only string types -- no module imports required (except for TypedDict resolution). 
+ """ + module_str = ast_meta.get("module", "") + file_path = ast_meta.get("file_path", "") + imap = ast_meta.get("import_map") or {} + + # --- status_types (string version) --- + status_types_str: Dict[str, str] = {} + for name, info in ast_meta.get("status_properties", {}).items(): + ret_type = info.get("return_type", "str") + if not ret_type or ret_type in ("Any", "None", "Unknown", ""): + ret_type = "String" + # 归一化泛型容器类型: Dict[str, Any] → dict, List[int] → list 等 + elif "[" in ret_type: + base = ret_type.split("[", 1)[0].strip() + base_lower = base.lower() + if base_lower in ("dict", "mapping", "ordereddict"): + ret_type = "dict" + elif base_lower in ("list", "tuple", "set", "sequence", "iterable"): + ret_type = "list" + elif base_lower == "optional": + # Optional[X] → 取内部类型再归一化 + inner = ret_type.split("[", 1)[1].rsplit("]", 1)[0].strip() + inner_lower = inner.lower() + if inner_lower in ("dict", "mapping"): + ret_type = "dict" + elif inner_lower in ("list", "tuple", "set"): + ret_type = "list" + else: + ret_type = inner + status_types_str[name] = ret_type + status_types_str = dict(sorted(status_types_str.items())) + + # --- action_value_mappings --- + action_value_mappings: Dict[str, Any] = {} + + def _build_json_command_entry(method_name, method_info, action_args=None): + """构建 UniLabJsonCommand 类型的 action entry""" + is_async = method_info.get("is_async", False) + type_str = "UniLabJsonCommandAsync" if is_async else "UniLabJsonCommand" + params = method_info.get("params", []) + method_doc = method_info.get("docstring") + goal_schema = self._generate_schema_from_ast_params(params, method_name, method_doc, imap) + + if action_args is not None: + action_name = action_args.get("action_name", method_name) + if action_args.get("auto_prefix"): + action_name = f"auto-{action_name}" + else: + action_name = f"auto-{method_name}" + + # Source C: 从 schema 生成类型默认值 + goal_default = JSONSchemaMessageInstance.generate_default_from_schema(goal_schema) + # Source B: method 
param 显式 default 覆盖 Source C + for p in params: + if p.get("default") is not None: + goal_default[p["name"]] = p["default"] + # goal 为 identity mapping {param_name: param_name}, 默认值只放在 goal_default + goal = {p["name"]: p["name"] for p in params} + + # @action 中的显式 goal/goal_default 覆盖 + goal_override = dict((action_args or {}).get("goal", {})) + goal_default_override = dict((action_args or {}).get("goal_default", {})) + if goal_override: + override_values = set(goal_override.values()) + goal = {k: v for k, v in goal.items() if not (k == v and v in override_values)} + goal.update(goal_override) + goal_default.update(goal_default_override) + + # action handles: 从 @action(handles=[...]) 提取并转换为标准格式 + raw_handles = (action_args or {}).get("handles") + handles = normalize_ast_action_handles(raw_handles) if isinstance(raw_handles, list) else (raw_handles or {}) + + # placeholder_keys: 优先用装饰器显式配置,否则从参数类型检测 + pk = (action_args or {}).get("placeholder_keys") or detect_placeholder_keys(params) + + # 从方法返回值类型生成 result schema + result_schema = None + ret_type_str = method_info.get("return_type", "") + if ret_type_str and ret_type_str not in ("None", "Any", ""): + result_schema = self._generate_schema_from_info( + "result", ret_type_str, None, imap + ) + + entry = { + "type": type_str, + "goal": goal, + "feedback": (action_args or {}).get("feedback") or {}, + "result": (action_args or {}).get("result") or {}, + "schema": wrap_action_schema(goal_schema, action_name, result_schema=result_schema), + "goal_default": goal_default, + "handles": handles, + "placeholder_keys": pk, + } + if (action_args or {}).get("always_free") or method_info.get("always_free"): + entry["always_free"] = True + return action_name, entry + + # 1) auto- actions + for method_name, method_info in ast_meta.get("auto_methods", {}).items(): + action_name, action_entry = _build_json_command_entry(method_name, method_info) + action_value_mappings[action_name] = action_entry + + # 2) @action() without action_type 
+ for method_name, method_info in ast_meta.get("actions", {}).items(): + action_args = method_info.get("action_args", {}) + if action_args.get("action_type"): + continue + action_name, action_entry = _build_json_command_entry(method_name, method_info, action_args) + action_value_mappings[action_name] = action_entry + + # 3) @action(action_type=X) + for method_name, method_info in ast_meta.get("actions", {}).items(): + action_args = method_info.get("action_args", {}) + action_type = action_args.get("action_type") + if not action_type: + continue + + action_name = action_args.get("action_name", method_name) + if action_args.get("auto_prefix"): + action_name = f"auto-{action_name}" + + raw_handles = action_args.get("handles") + handles = normalize_ast_action_handles(raw_handles) if isinstance(raw_handles, list) else (raw_handles or {}) + + method_params = method_info.get("params", []) + + # goal/feedback/result: 字段映射 + # parent=True 时直接通过 import class + MRO 获取; 否则从 AST 方法参数获取, 最后从 ROS2 Goal 获取 + # feedback/result 从 ROS2 获取; 默认 identity mapping {k: k}, 再用 @action 参数 update + goal_override = dict(action_args.get("goal", {})) + feedback_override = dict(action_args.get("feedback", {})) + result_override = dict(action_args.get("result", {})) + goal_default_override = dict(action_args.get("goal_default", {})) + + if action_args.get("parent"): + # @action(parent=True): 直接通过 import class + MRO 获取父类方法签名 + goal = resolve_method_params_via_import(module_str, method_name) + else: + # 从 AST 方法参数构建 goal identity mapping + real_params = [p for p in method_params if p["name"] not in ("self", "cls")] + goal = {p["name"]: p["name"] for p in real_params} + + feedback = {} + result = {} + schema = {} + goal_default = {} + + # 尝试 import ROS2 action type 获取 feedback/result/schema/goal_default, 以及 goal fallback + if ":" not in action_type: + action_type = imap.get(action_type, action_type) + action_type_obj = resolve_type_object(action_type) if ":" in action_type else None + if 
action_type_obj is None: + logger.warning( + f"[AST] device action '{action_name}': resolve_type_object('{action_type}') returned None" + ) + if action_type_obj is not None: + # 始终从 ROS2 Goal 获取字段作为基础, 再用方法参数覆盖 + try: + if hasattr(action_type_obj, "Goal"): + goal_fields = action_type_obj.Goal.get_fields_and_field_types() + ros2_goal = {k: k for k in goal_fields} + ros2_goal.update(goal) + goal = ros2_goal + except Exception as e: + logger.debug(f"[AST] device action '{action_name}': Goal enrichment from ROS2 failed: {e}") + try: + if hasattr(action_type_obj, "Feedback"): + fb_fields = action_type_obj.Feedback.get_fields_and_field_types() + feedback = {k: k for k in fb_fields} + except Exception as e: + logger.debug(f"[AST] device action '{action_name}': Feedback enrichment failed: {e}") + try: + if hasattr(action_type_obj, "Result"): + res_fields = action_type_obj.Result.get_fields_and_field_types() + result = {k: k for k in res_fields} + except Exception as e: + logger.debug(f"[AST] device action '{action_name}': Result enrichment failed: {e}") + try: + schema = ros_action_to_json_schema(action_type_obj) + except Exception: + pass + # 直接从 ROS2 Goal 实例获取默认值 (msgcenterpy) + try: + goal_default = ROS2MessageInstance(action_type_obj.Goal()).get_python_dict() + except Exception: + pass + + # 如果 ROS2 action type 未提供 result schema, 用方法返回值类型生成 fallback + if not schema.get("properties", {}).get("result"): + ret_type_str = method_info.get("return_type", "") + if ret_type_str and ret_type_str not in ("None", "Any", ""): + ret_schema = self._generate_schema_from_info( + "result", ret_type_str, None, imap + ) + if ret_schema: + schema.setdefault("properties", {})["result"] = ret_schema + + # @action 中的显式 goal/feedback/result/goal_default 覆盖默认值 + # 移除被 override 取代的 identity 条目 (如 {source: source} 被 {sources: source} 取代) + if goal_override: + override_values = set(goal_override.values()) + goal = {k: v for k, v in goal.items() if not (k == v and v in override_values)} + 
goal.update(goal_override) + feedback.update(feedback_override) + result.update(result_override) + goal_default.update(goal_default_override) + + action_entry = { + "type": action_type.split(":")[-1], + "goal": goal, + "feedback": feedback, + "result": result, + "schema": schema, + "goal_default": goal_default, + "handles": handles, + "placeholder_keys": action_args.get("placeholder_keys") or detect_placeholder_keys(method_params), + } + if action_args.get("always_free") or method_info.get("always_free"): + action_entry["always_free"] = True + action_value_mappings[action_name] = action_entry + + action_value_mappings = dict(sorted(action_value_mappings.items())) + + # --- init_param_schema = { config: , data: } --- + init_params = ast_meta.get("init_params", []) + config_schema = self._generate_schema_from_ast_params(init_params, "__init__", import_map=imap) + data_schema = self._generate_status_schema_from_ast( + ast_meta.get("status_properties", {}), imap + ) + init_schema: Dict[str, Any] = { + "config": config_schema, + "data": data_schema, + } + + # --- handles --- + handles_raw = ast_meta.get("handles", []) + handles = normalize_ast_handles(handles_raw) + + entry: Dict[str, Any] = { + "category": ast_meta.get("category", []), + "class": { + "module": module_str, + "status_types": status_types_str, + "action_value_mappings": action_value_mappings, + "type": ast_meta.get("device_type", "python"), + }, + "config_info": [], + "description": ast_meta.get("description", ""), + "handles": handles, + "icon": ast_meta.get("icon", ""), + "init_param_schema": init_schema, + "version": ast_meta.get("version", "1.0.0"), + "registry_type": "device", + "file_path": file_path, + } + model = ast_meta.get("model") + if model is not None: + entry["model"] = model + hardware_interface = ast_meta.get("hardware_interface") + if hardware_interface is not None: + # AST 解析 HardwareInterface(...) 
得到 {"_call": "...", "name": ..., "read": ..., "write": ...} + # 归一化为 YAML 格式,去掉 _call + if isinstance(hardware_interface, dict) and "_call" in hardware_interface: + hardware_interface = {k: v for k, v in hardware_interface.items() if k != "_call"} + entry["class"]["hardware_interface"] = hardware_interface + return entry + + def _generate_schema_from_ast_params( + self, params: list, method_name: str, docstring: Optional[str] = None, + import_map: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """Generate JSON Schema from AST-extracted parameter list.""" + doc_info = parse_docstring(docstring) + param_descs = doc_info.get("params", {}) + + schema: Dict[str, Any] = { + "type": "object", + "properties": {}, + "required": [], + } + for p in params: + pname = p.get("name", "") + ptype = p.get("type", "") + pdefault = p.get("default") + prequired = p.get("required", True) + + # --- 检测 ResourceSlot / DeviceSlot (兼容 runtime 和 AST 两种格式) --- + is_slot, is_list_slot = detect_slot_type(ptype) + if is_slot == "ResourceSlot": + if is_list_slot: + schema["properties"][pname] = { + "items": ros_message_to_json_schema(Resource, pname), + "type": "array", + } + else: + schema["properties"][pname] = ros_message_to_json_schema(Resource, pname) + elif is_slot == "DeviceSlot": + schema["properties"][pname] = {"type": "string", "description": "device reference"} + else: + schema["properties"][pname] = self._generate_schema_from_info( + pname, ptype, pdefault, import_map + ) + + if pname in param_descs: + schema["properties"][pname]["description"] = param_descs[pname] + + if prequired: + schema["required"].append(pname) + + return schema + + def _generate_status_schema_from_ast( + self, status_properties: Dict[str, Any], + import_map: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """Generate status_types schema from AST-extracted status properties.""" + schema: Dict[str, Any] = { + "type": "object", + "properties": {}, + "required": [], + } + for name, info in 
status_properties.items(): + ret_type = info.get("return_type", "str") + schema["properties"][name] = self._generate_schema_from_info( + name, ret_type, None, import_map + ) + schema["required"].append(name) + return schema + + def _build_resource_entry_from_ast(self, resource_id: str, ast_meta: dict) -> Dict[str, Any]: + """Build a resource registry entry from AST-scanned metadata.""" + module_str = ast_meta.get("module", "") + file_path = ast_meta.get("file_path", "") + + handles_raw = ast_meta.get("handles", []) + handles = normalize_ast_handles(handles_raw) + + entry: Dict[str, Any] = { + "category": ast_meta.get("category", []), + "class": { + "module": module_str, + "type": ast_meta.get("class_type", "python"), + }, + "config_info": [], + "description": ast_meta.get("description", ""), + "handles": handles, + "icon": ast_meta.get("icon", ""), + "init_param_schema": {}, + "version": ast_meta.get("version", "1.0.0"), + "registry_type": "resource", + "file_path": file_path, + } + + if ast_meta.get("model"): + entry["model"] = ast_meta["model"] + + return entry + + # ------------------------------------------------------------------ + # config_info 缓存 (pickle 格式,比 JSON 快 ~10x,debug 模式下差异更大) + # ------------------------------------------------------------------ + + @staticmethod + def _get_config_cache_path() -> Optional[Path]: + if BasicConfig.working_dir: + return Path(BasicConfig.working_dir) / "resource_config_cache.pkl" + return None + + def _load_config_cache(self) -> dict: + import pickle + cache_path = self._get_config_cache_path() + if cache_path is None or not cache_path.is_file(): + return {} + try: + data = pickle.loads(cache_path.read_bytes()) + if not isinstance(data, dict) or data.get("_version") != 2: + return {} + return data + except Exception: + return {} + + def _save_config_cache(self, cache: dict) -> None: + import pickle + cache_path = self._get_config_cache_path() + if cache_path is None: + return + try: + cache["_version"] = 2 + 
cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp = cache_path.with_suffix(".tmp") + tmp.write_bytes(pickle.dumps(cache, protocol=pickle.HIGHEST_PROTOCOL)) + tmp.replace(cache_path) + except Exception: + pass + + @staticmethod + def _module_source_hash(module_str: str) -> Optional[str]: + """Fast MD5 of the source file backing *module_str*. Results are + cached for the process lifetime so the same file is never read twice.""" + if module_str in _module_hash_cache: + return _module_hash_cache[module_str] + + import hashlib + import importlib.util + mod_part = module_str.split(":")[0] if ":" in module_str else module_str + result = None + try: + spec = importlib.util.find_spec(mod_part) + if spec and spec.origin and os.path.isfile(spec.origin): + result = hashlib.md5(open(spec.origin, "rb").read()).hexdigest() + except Exception: + pass + _module_hash_cache[module_str] = result + return result + + def _populate_resource_config_info(self, config_cache: Optional[dict] = None): + """ + 利用线程池并行 import pylabrobot 资源类,生成 config_info。 + 仅在 upload_registry=True 时调用。 + + 启用缓存:以 module_str 为 key,记录源文件 MD5。若源文件未变则 + 直接复用上次的 config_info,跳过 import + 实例化 + dump。 + + Args: + config_cache: 共享的缓存 dict。未提供时自行加载/保存; + 由 load_resource_types 传入时由调用方统一保存。 + """ + import time as _time + + executor = self._startup_executor + if executor is None: + return + + # 筛选需要 import 的 pylabrobot 资源(跳过已有 config_info 的缓存条目) + pylabrobot_entries = { + rid: entry + for rid, entry in self.resource_type_registry.items() + if entry.get("class", {}).get("type") == "pylabrobot" + and entry.get("class", {}).get("module") + and not entry.get("config_info") + } + if not pylabrobot_entries: + return + + t0 = _time.perf_counter() + own_cache = config_cache is None + if own_cache: + config_cache = self._load_config_cache() + cache_hits = 0 + cache_misses = 0 + + def _import_and_dump(resource_id: str, module_str: str): + """Import class, create instance, dump tree. 
Returns (rid, config_info).""" + try: + res_class = import_class(module_str) + if callable(res_class) and not isinstance(res_class, type): + res_instance = res_class(res_class.__name__) + tree_set = ResourceTreeSet.from_plr_resources([res_instance], known_newly_created=True, old_size=True) + dumped = tree_set.dump(old_position=True) + return resource_id, dumped[0] if dumped else [] + except Exception as e: + logger.warning(f"[UniLab Registry] 资源 {resource_id} config_info 生成失败: {e}") + return resource_id, [] + + # Separate into cache-hit vs cache-miss + need_generate: dict = {} # rid -> module_str + for rid, entry in pylabrobot_entries.items(): + module_str = entry["class"]["module"] + cached = config_cache.get(module_str) + if cached and isinstance(cached, dict) and "config_info" in cached: + src_hash = self._module_source_hash(module_str) + if src_hash is not None and cached.get("src_hash") == src_hash: + self.resource_type_registry[rid]["config_info"] = cached["config_info"] + cache_hits += 1 + continue + need_generate[rid] = module_str + + cache_misses = len(need_generate) + + if need_generate: + future_to_rid = { + executor.submit(_import_and_dump, rid, mod): rid + for rid, mod in need_generate.items() + } + for future in as_completed(future_to_rid): + try: + resource_id, config_info = future.result() + self.resource_type_registry[resource_id]["config_info"] = config_info + module_str = need_generate[resource_id] + src_hash = self._module_source_hash(module_str) + config_cache[module_str] = { + "src_hash": src_hash, + "config_info": config_info, + } + except Exception as e: + rid = future_to_rid[future] + logger.warning(f"[UniLab Registry] 资源 {rid} config_info 线程异常: {e}") + + if own_cache: + self._save_config_cache(config_cache) + + elapsed = _time.perf_counter() - t0 + total = cache_hits + cache_misses + logger.info( + f"[UniLab Registry] config_info 缓存统计: " + f"{cache_hits}/{total} 命中, {cache_misses} 重新生成 " + f"(耗时 {elapsed:.2f}s)" + ) + + # 
------------------------------------------------------------------ + # Verify & Resolve (实际 import 验证) + # ------------------------------------------------------------------ + + def verify_and_resolve_registry(self): + """ + 对 AST 扫描得到的注册表执行实际 import 验证(使用共享线程池并行)。 + """ + errors = [] + import_success_count = 0 + resolved_count = 0 + total_items = len(self.device_type_registry) + len(self.resource_type_registry) + + lock = threading.Lock() + + def _verify_device(device_id: str, entry: dict): + nonlocal import_success_count, resolved_count + module_str = entry.get("class", {}).get("module", "") + if not module_str or ":" not in module_str: + with lock: + import_success_count += 1 + return None + + try: + cls = import_class(module_str) + with lock: + import_success_count += 1 + resolved_count += 1 + + # 尝试用动态信息增强注册表 + try: + self.resolve_types_for_device(device_id, cls) + except Exception as e: + logger.debug(f"[UniLab Registry/Verify] 设备 {device_id} 类型解析失败: {e}") + + return None + except Exception as e: + logger.warning( + f"[UniLab Registry/Verify] 设备 {device_id}: " + f"导入模块 {module_str} 失败: {e}" + ) + return f"device:{device_id}: {e}" + + def _verify_resource(resource_id: str, entry: dict): + nonlocal import_success_count + module_str = entry.get("class", {}).get("module", "") + if not module_str or ":" not in module_str: + with lock: + import_success_count += 1 + return None + + try: + import_class(module_str) + with lock: + import_success_count += 1 + return None + except Exception as e: + logger.warning( + f"[UniLab Registry/Verify] 资源 {resource_id}: " + f"导入模块 {module_str} 失败: {e}" + ) + return f"resource:{resource_id}: {e}" + + executor = self._startup_executor or ThreadPoolExecutor(max_workers=8) + try: + device_futures = {} + resource_futures = {} + + for device_id, entry in list(self.device_type_registry.items()): + fut = executor.submit(_verify_device, device_id, entry) + device_futures[fut] = device_id + + for resource_id, entry in 
list(self.resource_type_registry.items()): + fut = executor.submit(_verify_resource, resource_id, entry) + resource_futures[fut] = resource_id + + for future in as_completed(device_futures): + result = future.result() + if result: + errors.append(result) + + for future in as_completed(resource_futures): + result = future.result() + if result: + errors.append(result) + finally: + if self._startup_executor is None: + executor.shutdown(wait=True) + + if errors: + logger.warning( + f"[UniLab Registry/Verify] 验证完成: {import_success_count}/{total_items} 成功, " + f"{len(errors)} 个错误" + ) + else: + logger.info( + f"[UniLab Registry/Verify] 验证完成: {import_success_count}/{total_items} 全部通过, " + f"{resolved_count} 设备类型已解析" + ) + + return errors + + def resolve_types_for_device(self, device_id: str, cls=None): + """ + 将 AST 扫描得到的字符串类型引用替换为实际的 ROS 消息类对象。 + """ + entry = self.device_type_registry.get(device_id) + if not entry: + return + + class_info = entry.get("class", {}) + + # 解析 status_types + status_types = class_info.get("status_types", {}) + resolved_status = {} + for name, type_ref in status_types.items(): + if isinstance(type_ref, str): + resolved = self._replace_type_with_class(type_ref, device_id, f"状态 {name}") + if resolved: + resolved_status[name] = resolved + else: + resolved_status[name] = type_ref + else: + resolved_status[name] = type_ref + class_info["status_types"] = resolved_status + + # 解析 action_value_mappings + _KEEP_AS_STRING = {"UniLabJsonCommand", "UniLabJsonCommandAsync"} + action_mappings = class_info.get("action_value_mappings", {}) + for action_name, action_config in action_mappings.items(): + type_ref = action_config.get("type", "") + if isinstance(type_ref, str) and type_ref and type_ref not in _KEEP_AS_STRING: + resolved = self._replace_type_with_class(type_ref, device_id, f"动作 {action_name}") + if resolved: + action_config["type"] = resolved + if not action_config.get("schema"): + try: + action_config["schema"] = 
ros_action_to_json_schema(resolved) + except Exception: + pass + if not action_config.get("goal_default"): + try: + action_config["goal_default"] = ROS2MessageInstance(resolved.Goal()).get_python_dict() + except Exception: + pass + + # 如果提供了类,用动态信息增强 + if cls is not None: + try: + dynamic_info = self._extract_class_info(cls) + + for name, info in dynamic_info.get("status_methods", {}).items(): + if name not in resolved_status: + ret_type = info.get("return_type", "str") + resolved = self._replace_type_with_class(ret_type, device_id, f"状态 {name}") + if resolved: + class_info["status_types"][name] = resolved + + for action_name_key, action_config in action_mappings.items(): + type_obj = action_config.get("type") + if isinstance(type_obj, str) and type_obj in ( + "UniLabJsonCommand", "UniLabJsonCommandAsync" + ): + method_name = action_name_key + if method_name.startswith("auto-"): + method_name = method_name[5:] + + actual_method = getattr(cls, method_name, None) + if actual_method: + method_info = self._analyze_method_signature(actual_method) + schema = self._generate_unilab_json_command_schema( + method_info["args"], + docstring=getattr(actual_method, "__doc__", None), + ) + action_config["schema"] = schema + except Exception as e: + logger.debug(f"[Registry] 设备 {device_id} 动态增强失败: {e}") + + # 添加内置动作 + self._add_builtin_actions(entry, device_id) + + def resolve_all_types(self): + """将所有注册表条目中的字符串类型引用替换为实际的 ROS2 消息类对象。 + + 仅做 ROS2 消息类型查找,不 import 任何设备模块,速度快且无副作用。 + """ + for device_id in list(self.device_type_registry): + try: + self.resolve_types_for_device(device_id) + except Exception as e: + logger.debug(f"[Registry] 设备 {device_id} 类型解析失败: {e}") + + # ------------------------------------------------------------------ + # 模块加载 (import-based) + # ------------------------------------------------------------------ + + def load_modules(self, module_paths: List[str]): + """导入指定的 Python 模块,触发其中的装饰器执行。""" + for module_path in module_paths: + try: + 
importlib.import_module(module_path) + logger.debug(f"[Registry] 已导入模块: {module_path}") + except Exception as e: + logger.warning(f"[Registry] 导入模块 {module_path} 失败: {e}") + + def setup_from_imports(self, module_paths: Optional[List[str]] = None): + """ + 通过实际 import 构建注册表 (较慢路径)。 + """ + if module_paths: + self.load_modules(module_paths) + + for device_id, cls in get_all_registered_devices().items(): + device_meta = get_device_meta(cls, device_id) + if device_meta is None: + continue + + try: + yaml_entry, runtime_entry = self._build_device_entry(cls, device_meta) + runtime_entry["registry_type"] = "device" + runtime_entry["file_path"] = str(Path(inspect.getfile(cls)).absolute()).replace("\\", "/") + self._add_builtin_actions(runtime_entry, device_id) + self.device_type_registry[device_id] = runtime_entry + logger.debug(f"[Registry] 注册设备: {device_id}") + except Exception as e: + logger.warning(f"[Registry] 生成设备 {device_id} 注册表失败: {e}") + traceback.print_exc() + + for resource_id, obj in get_all_registered_resources().items(): + resource_meta = get_resource_meta(obj) + if resource_meta is None: + continue + + try: + entry = self._build_resource_entry(obj, resource_meta) + entry["registry_type"] = "resource" + if hasattr(obj, "__module__"): + try: + entry["file_path"] = str(Path(inspect.getfile(obj)).absolute()).replace("\\", "/") + except (TypeError, OSError): + entry["file_path"] = "" + self.resource_type_registry[resource_id] = entry + logger.debug(f"[Registry] 注册资源: {resource_id}") + except Exception as e: + logger.warning(f"[Registry] 生成资源 {resource_id} 注册表失败: {e}") + + # ------------------------------------------------------------------ + # YAML 注册表加载 (兼容旧格式) + # ------------------------------------------------------------------ def _load_single_resource_file( - self, file: Path, complete_registry: bool, upload_registry: bool + self, file: Path, complete_registry: bool ) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: """ 加载单个资源文件 (线程安全) @@ -280,6 +1787,8 @@ 
class Registry: complete_data = {} for resource_id, resource_info in data.items(): + if not isinstance(resource_info, dict): + continue if "version" not in resource_info: resource_info["version"] = "1.0.0" if "category" not in resource_info: @@ -301,426 +1810,116 @@ class Registry: if "file_path" in resource_info: del resource_info["file_path"] complete_data[resource_id] = copy.deepcopy(dict(sorted(resource_info.items()))) - if upload_registry: - class_info = resource_info.get("class", {}) - if len(class_info) and "module" in class_info: - if class_info.get("type") == "pylabrobot": - res_class = get_class(class_info["module"]) - if callable(res_class) and not isinstance(res_class, type): - res_instance = res_class(res_class.__name__) - res_ulr = tree_to_list([resource_plr_to_ulab(res_instance)]) - resource_info["config_info"] = res_ulr resource_info["registry_type"] = "resource" resource_info["file_path"] = str(file.absolute()).replace("\\", "/") complete_data = dict(sorted(complete_data.items())) - complete_data = copy.deepcopy(complete_data) - - if complete_registry: - try: - with open(file, "w", encoding="utf-8") as f: - yaml.dump(complete_data, f, allow_unicode=True, default_flow_style=False, Dumper=NoAliasDumper) - except Exception as e: - logger.warning(f"[UniLab Registry] 写入资源文件失败: {file}, 错误: {e}") return data, complete_data, True - def load_resource_types(self, path: os.PathLike, complete_registry: bool, upload_registry: bool): + def load_resource_types(self, path: os.PathLike, upload_registry: bool, complete_registry: bool = False): abs_path = Path(path).absolute() - resource_path = abs_path / "resources" - files = list(resource_path.glob("*/*.yaml")) - logger.debug(f"[UniLab Registry] resources: {resource_path.exists()}, total: {len(files)}") + resources_path = abs_path / "resources" + files = list(resources_path.rglob("*.yaml")) + logger.trace( + f"[UniLab Registry] resources: {resources_path.exists()}, total: {len(files)}" + ) if not files: return - # 
使用线程池并行加载 - max_workers = min(8, len(files)) - results = [] + import hashlib as _hl - with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_file = { - executor.submit(self._load_single_resource_file, file, complete_registry, upload_registry): file - for file in files - } - for future in as_completed(future_to_file): - file = future_to_file[future] - try: - data, complete_data, is_valid = future.result() - if is_valid: - results.append((file, data)) - except Exception as e: - logger.warning(f"[UniLab Registry] 处理资源文件异常: {file}, 错误: {e}") + # --- YAML-level cache: per-file entries with config_info --- + config_cache = self._load_config_cache() if upload_registry else None + yaml_cache: dict = config_cache.get("_yaml_resources", {}) if config_cache else {} + yaml_cache_hits = 0 + yaml_cache_misses = 0 + uncached_files: list[Path] = [] + yaml_file_rids: dict[str, list[str]] = {} - # 线程安全地更新注册表 - current_resource_number = len(self.resource_type_registry) + 1 - with self._registry_lock: - for i, (file, data) in enumerate(results): - self.resource_type_registry.update(data) - logger.trace( - f"[UniLab Registry] Resource-{current_resource_number} File-{i+1}/{len(results)} " - + f"Add {list(data.keys())}" - ) - current_resource_number += 1 - - # 记录无效文件 - valid_files = {r[0] for r in results} for file in files: - if file not in valid_files: - logger.debug(f"[UniLab Registry] Res File Not Valid YAML File: {file.absolute()}") + file_key = str(file.absolute()).replace("\\", "/") + if upload_registry and yaml_cache: + try: + yaml_md5 = _hl.md5(file.read_bytes()).hexdigest() + except OSError: + uncached_files.append(file) + yaml_cache_misses += 1 + continue + cached = yaml_cache.get(file_key) + if cached and cached.get("yaml_md5") == yaml_md5: + module_hashes: dict = cached.get("module_hashes", {}) + all_ok = all( + self._module_source_hash(m) == h + for m, h in module_hashes.items() + ) if module_hashes else True + if all_ok and cached.get("entries"): + for 
rid, entry in cached["entries"].items(): + self.resource_type_registry[rid] = entry + yaml_cache_hits += 1 + continue + uncached_files.append(file) + yaml_cache_misses += 1 - def _extract_class_docstrings(self, module_string: str) -> Dict[str, str]: - """ - 从模块字符串中提取类和方法的docstring信息 - - Args: - module_string: 模块字符串,格式为 "module.path:ClassName" - - Returns: - 包含类和方法docstring信息的字典 - """ - docstrings = {"class_docstring": "", "methods": {}} - - if not module_string or ":" not in module_string: - return docstrings - - try: - module_path, class_name = module_string.split(":", 1) - - # 动态导入模块 - module = importlib.import_module(module_path) - - # 获取类 - if hasattr(module, class_name): - cls = getattr(module, class_name) - - # 获取类的docstring - class_doc = inspect.getdoc(cls) - if class_doc: - docstrings["class_docstring"] = class_doc.strip() - - # 获取所有方法的docstring - for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction): - method_doc = inspect.getdoc(method) - if method_doc: - docstrings["methods"][method_name] = method_doc.strip() - - # 也获取属性方法的docstring - for method_name, method in inspect.getmembers(cls, predicate=lambda x: isinstance(x, property)): - if hasattr(method, "fget") and method.fget: - method_doc = inspect.getdoc(method.fget) - if method_doc: - docstrings["methods"][method_name] = method_doc.strip() - - except Exception as e: - logger.warning(f"[UniLab Registry] 无法提取docstring信息,模块: {module_string}, 错误: {str(e)}") - - return docstrings - - def _replace_type_with_class(self, type_name: str, device_id: str, field_name: str) -> Any: - """ - 将类型名称替换为实际的类对象 - - Args: - type_name: 类型名称 - device_id: 设备ID,用于错误信息 - field_name: 字段名称,用于错误信息 - - Returns: - 找到的类对象或原始字符串 - - Raises: - SystemExit: 如果找不到类型则终止程序 - """ - # 如果类型名为空,跳过替换 - if not type_name or type_name == "": - logger.warning(f"[UniLab Registry] 设备 {device_id} 的 {field_name} 类型为空,跳过替换") - return type_name - convert_manager = { # 将python基本对象转为ros2基本对象 - "str": "String", - "bool": "Bool", - 
"int": "Int64", - "float": "Float64", - } - type_name = convert_manager.get(type_name, type_name) # 替换为ROS2类型 - if ":" in type_name: - type_class = msg_converter_manager.get_class(type_name) - else: - type_class = msg_converter_manager.search_class(type_name) - if type_class: - return type_class - else: - logger.error(f"[UniLab Registry] 无法找到类型 '{type_name}' 用于设备 {device_id} 的 {field_name}") - raise ROSMsgNotFound(f"类型 '{type_name}' 未找到,用于设备 {device_id} 的 {field_name}") - - def _get_json_schema_type(self, type_str: str) -> str: - """ - 根据类型字符串返回对应的JSON Schema类型 - - Args: - type_str: 类型字符串 - - Returns: - JSON Schema类型字符串 - """ - type_lower = type_str.lower() - type_mapping = { - ("str", "string"): "string", - ("int", "integer"): "integer", - ("float", "number"): "number", - ("bool", "boolean"): "boolean", - ("list", "array"): "array", - ("dict", "object"): "object", + # Process uncached YAML files with thread pool + executor = self._startup_executor + future_to_file = { + executor.submit(self._load_single_resource_file, file, complete_registry): file + for file in uncached_files } - # 遍历映射找到匹配的类型 - for type_variants, json_type in type_mapping.items(): - if type_lower in type_variants: - return json_type + for future in as_completed(future_to_file): + file = future_to_file[future] + try: + data, complete_data, is_valid = future.result() + if is_valid: + self.resource_type_registry.update(complete_data) + file_key = str(file.absolute()).replace("\\", "/") + yaml_file_rids[file_key] = list(complete_data.keys()) + except Exception as e: + logger.warning(f"[UniLab Registry] 加载资源文件失败: {file}, 错误: {e}") - # 特殊处理包含冒号的类型(如ROS消息类型) - if ":" in type_lower: - return "object" + # upload 模式下,统一利用线程池为 pylabrobot 资源生成 config_info + if upload_registry: + self._populate_resource_config_info(config_cache=config_cache) - # 默认返回字符串类型 - return "string" + # Update YAML cache for newly processed files (entries now have config_info) + if yaml_file_rids and config_cache is not None: + for 
file_key, rids in yaml_file_rids.items(): + entries = {} + module_hashes = {} + for rid in rids: + entry = self.resource_type_registry.get(rid) + if entry: + entries[rid] = copy.deepcopy(entry) + mod_str = entry.get("class", {}).get("module", "") + if mod_str and mod_str not in module_hashes: + src_h = self._module_source_hash(mod_str) + if src_h: + module_hashes[mod_str] = src_h + try: + yaml_md5 = _hl.md5(Path(file_key).read_bytes()).hexdigest() + except OSError: + continue + yaml_cache[file_key] = { + "yaml_md5": yaml_md5, + "module_hashes": module_hashes, + "entries": entries, + } + config_cache["_yaml_resources"] = yaml_cache + self._save_config_cache(config_cache) - def _generate_schema_from_info( - self, - param_name: str, - param_type: Union[str, Tuple[str]], - param_default: Any, - ) -> Dict[str, Any]: - """ - 根据参数信息生成JSON Schema - """ - prop_schema = {} - - # 处理嵌套类型(Tuple[str]) - if isinstance(param_type, tuple): - if len(param_type) == 2: - outer_type, inner_type = param_type - outer_json_type = self._get_json_schema_type(outer_type) - inner_json_type = self._get_json_schema_type(inner_type) - - prop_schema["type"] = outer_json_type - - # 根据外层类型设置内层类型信息 - if outer_json_type == "array": - prop_schema["items"] = {"type": inner_json_type} - elif outer_json_type == "object": - prop_schema["additionalProperties"] = {"type": inner_json_type} - else: - # 不是标准的嵌套类型,默认为字符串 - prop_schema["type"] = "string" - else: - # 处理非嵌套类型 - if param_type: - prop_schema["type"] = self._get_json_schema_type(param_type) - else: - # 如果没有类型信息,默认为字符串 - prop_schema["type"] = "string" - - # 设置默认值 - if param_default is not None: - prop_schema["default"] = param_default - - return prop_schema - - def _generate_status_types_schema(self, status_types: Dict[str, Any]) -> Dict[str, Any]: - """ - 根据状态类型生成JSON Schema - """ - status_schema = { - "type": "object", - "properties": {}, - "required": [], - } - for status_name, status_type in status_types.items(): - 
status_schema["properties"][status_name] = self._generate_schema_from_info( - status_name, status_type["return_type"], None + total_yaml = yaml_cache_hits + yaml_cache_misses + if upload_registry and total_yaml > 0: + logger.info( + f"[UniLab Registry] YAML 资源缓存: " + f"{yaml_cache_hits}/{total_yaml} 文件命中, " + f"{yaml_cache_misses} 重新加载" ) - status_schema["required"].append(status_name) - return status_schema - - def _generate_unilab_json_command_schema( - self, - method_args: List[Dict[str, Any]], - method_name: str, - return_annotation: Any = None, - previous_schema: Dict[str, Any] | None = None, - ) -> Dict[str, Any]: - """ - 根据UniLabJsonCommand方法信息生成JSON Schema,暂不支持嵌套类型 - - Args: - method_args: 方法信息字典,包含args等 - method_name: 方法名称 - return_annotation: 返回类型注解,用于生成result schema(仅支持TypedDict) - previous_schema: 之前的 schema,用于保留 goal/feedback/result 下一级字段的 description - - Returns: - JSON Schema格式的参数schema - """ - schema = { - "type": "object", - "properties": {}, - "required": [], - } - for arg_info in method_args: - param_name = arg_info.get("name", "") - param_type = arg_info.get("type", "") - param_default = arg_info.get("default") - param_required = arg_info.get("required", True) - if param_type == "unilabos.registry.placeholder_type:ResourceSlot": - schema["properties"][param_name] = ros_message_to_json_schema(Resource, param_name) - elif param_type == ("list", "unilabos.registry.placeholder_type:ResourceSlot"): - schema["properties"][param_name] = { - "items": ros_message_to_json_schema(Resource, param_name), - "type": "array", - } - else: - schema["properties"][param_name] = self._generate_schema_from_info( - param_name, param_type, param_default - ) - if param_required: - schema["required"].append(param_name) - - # 生成result schema(仅当return_annotation是TypedDict时) - result_schema = {} - if return_annotation is not None and self._is_typed_dict(return_annotation): - result_schema = self._generate_typed_dict_result_schema(return_annotation) - - final_schema = { - 
"title": f"{method_name}参数", - "description": f"", - "type": "object", - "properties": {"goal": schema, "feedback": {}, "result": result_schema}, - "required": ["goal"], - } - - # 保留之前 schema 中 goal/feedback/result 下一级字段的 description - if previous_schema: - self._preserve_field_descriptions(final_schema, previous_schema) - - return final_schema - - def _preserve_field_descriptions(self, new_schema: Dict[str, Any], previous_schema: Dict[str, Any]) -> None: - """ - 保留之前 schema 中 goal/feedback/result 下一级字段的 description 和 title - - Args: - new_schema: 新生成的 schema(会被修改) - previous_schema: 之前的 schema - """ - for section in ["goal", "feedback", "result"]: - new_section = new_schema.get("properties", {}).get(section, {}) - prev_section = previous_schema.get("properties", {}).get(section, {}) - - if not new_section or not prev_section: - continue - - new_props = new_section.get("properties", {}) - prev_props = prev_section.get("properties", {}) - - for field_name, field_schema in new_props.items(): - if field_name in prev_props: - prev_field = prev_props[field_name] - # 保留字段的 description - if "description" in prev_field and prev_field["description"]: - field_schema["description"] = prev_field["description"] - # 保留字段的 title(用户自定义的中文名) - if "title" in prev_field and prev_field["title"]: - field_schema["title"] = prev_field["title"] - - def _is_typed_dict(self, annotation: Any) -> bool: - """ - 检查类型注解是否是TypedDict - - Args: - annotation: 类型注解对象 - - Returns: - 是否为TypedDict - """ - if annotation is None or annotation == inspect.Parameter.empty: - return False - - # 使用 typing_extensions.is_typeddict 进行检查(Python < 3.12 兼容) - try: - from typing_extensions import is_typeddict - - return is_typeddict(annotation) - except ImportError: - # 回退方案:检查 TypedDict 特有的属性 - if isinstance(annotation, type): - return hasattr(annotation, "__required_keys__") and hasattr(annotation, "__optional_keys__") - return False - - def _generate_typed_dict_result_schema(self, return_annotation: Any) -> 
Dict[str, Any]: - """ - 根据TypedDict类型生成result的JSON Schema - - Args: - return_annotation: TypedDict类型注解 - - Returns: - JSON Schema格式的result schema - """ - if not self._is_typed_dict(return_annotation): - return {} - - try: - from msgcenterpy.instances.typed_dict_instance import TypedDictMessageInstance - - result_schema = TypedDictMessageInstance.get_json_schema_from_typed_dict(return_annotation) - return result_schema - except ImportError: - logger.warning("[UniLab Registry] msgcenterpy未安装,无法生成TypedDict的result schema") - return {} - except Exception as e: - logger.warning(f"[UniLab Registry] 生成TypedDict result schema失败: {e}") - return {} - - def _add_builtin_actions(self, device_config: Dict[str, Any], device_id: str): - """ - 为设备配置添加内置的执行驱动命令动作 - - Args: - device_config: 设备配置字典 - device_id: 设备ID - """ - from unilabos.app.web.utils.action_utils import get_yaml_from_goal_type - - if "class" not in device_config: - return - - if "action_value_mappings" not in device_config["class"]: - device_config["class"]["action_value_mappings"] = {} - - for additional_action in ["_execute_driver_command", "_execute_driver_command_async"]: - device_config["class"]["action_value_mappings"][additional_action] = { - "type": self._replace_type_with_class("StrSingleInput", device_id, f"动作 {additional_action}"), - "goal": {"string": "string"}, - "feedback": {}, - "result": {}, - "schema": ros_action_to_json_schema( - self._replace_type_with_class("StrSingleInput", device_id, f"动作 {additional_action}") - ), - "goal_default": yaml.safe_load( - io.StringIO( - get_yaml_from_goal_type( - self._replace_type_with_class( - "StrSingleInput", device_id, f"动作 {additional_action}" - ).Goal - ) - ) - ), - "handles": {}, - } def _load_single_device_file( - self, file: Path, complete_registry: bool, get_yaml_from_goal_type + self, file: Path, complete_registry: bool ) -> Tuple[Dict[str, Any], Dict[str, Any], bool, List[str]]: """ 加载单个设备文件 (线程安全) @@ -747,6 +1946,10 @@ class Registry: device_ids = [] 
for device_id, device_config in data.items(): + if not isinstance(device_config, dict): + continue + + # 补全默认字段 if "version" not in device_config: device_config["version"] = "1.0.0" if "category" not in device_config: @@ -763,6 +1966,7 @@ class Registry: device_config["handles"] = [] if "init_param_schema" not in device_config: device_config["init_param_schema"] = {} + if "class" in device_config: if "status_types" not in device_config["class"] or device_config["class"]["status_types"] is None: device_config["class"]["status_types"] = {} @@ -771,6 +1975,7 @@ class Registry: or device_config["class"]["action_value_mappings"] is None ): device_config["class"]["action_value_mappings"] = {} + enhanced_info = {} if complete_registry: device_config["class"]["status_types"].clear() @@ -780,6 +1985,8 @@ class Registry: device_config["class"]["status_types"].update( {k: v["return_type"] for k, v in enhanced_info["status_methods"].items()} ) + + # --- status_types: 字符串 → class 映射 --- for status_name, status_type in device_config["class"]["status_types"].items(): if isinstance(status_type, tuple) or status_type in ["Any", "None", "Unknown"]: status_type = "String" @@ -792,6 +1999,7 @@ class Registry: target_type = String status_str_type_mapping[status_type] = target_type device_config["class"]["status_types"] = dict(sorted(device_config["class"]["status_types"].items())) + if complete_registry: old_action_configs = {} for action_name, action_config in device_config["class"]["action_value_mappings"].items(): @@ -806,15 +2014,10 @@ class Registry: { f"auto-{k}": { "type": "UniLabJsonCommandAsync" if v["is_async"] else "UniLabJsonCommand", - "goal": {}, + "goal": {i["name"]: i["default"] for i in v["args"] if i["default"] is not None}, "feedback": {}, "result": {}, - "schema": self._generate_unilab_json_command_schema( - v["args"], - k, - v.get("return_annotation"), - old_action_configs.get(f"auto-{k}", {}).get("schema"), - ), + "schema": 
self._generate_unilab_json_command_schema(v["args"]), "goal_default": {i["name"]: i["default"] for i in v["args"]}, "handles": old_action_configs.get(f"auto-{k}", {}).get("handles", []), "placeholder_keys": { @@ -839,13 +2042,16 @@ class Registry: if k not in device_config["class"]["action_value_mappings"] } ) + # 保留旧 schema 中的 description for action_name, old_config in old_action_configs.items(): if action_name in device_config["class"]["action_value_mappings"]: old_schema = old_config.get("schema", {}) - if "description" in old_schema and old_schema["description"]: - device_config["class"]["action_value_mappings"][action_name]["schema"][ - "description" - ] = old_schema["description"] + new_schema = device_config["class"]["action_value_mappings"][action_name].get("schema", {}) + if old_schema: + preserve_field_descriptions(new_schema, old_schema) + if "description" in old_schema and old_schema["description"]: + new_schema["description"] = old_schema["description"] + device_config["init_param_schema"] = {} device_config["init_param_schema"]["config"] = self._generate_unilab_json_command_schema( enhanced_info["init_params"], "__init__" @@ -854,6 +2060,7 @@ class Registry: enhanced_info["status_methods"] ) + # --- action_value_mappings: 处理非 UniLabJsonCommand 类型 --- device_config.pop("schema", None) device_config["class"]["action_value_mappings"] = dict( sorted(device_config["class"]["action_value_mappings"].items()) @@ -878,37 +2085,50 @@ class Registry: continue action_str_type_mapping[action_type_str] = target_type if target_type is not None: - action_config["goal_default"] = yaml.safe_load( - io.StringIO(get_yaml_from_goal_type(target_type.Goal)) - ) + try: + action_config["goal_default"] = ROS2MessageInstance(target_type.Goal()).get_python_dict() + except Exception: + action_config["goal_default"] = {} action_config["schema"] = ros_action_to_json_schema(target_type) else: logger.warning( f"[UniLab Registry] 设备 {device_id} 的动作 {action_name} 类型为空,跳过替换" ) + + # 
deepcopy 保存可序列化的 complete_data(此时 type 字段仍为字符串) + device_config["file_path"] = str(file.absolute()).replace("\\", "/") + device_config["registry_type"] = "device" complete_data[device_id] = copy.deepcopy(dict(sorted(device_config.items()))) + + # 之后才把 type 字符串替换为 class 对象(仅用于运行时 data) for status_name, status_type in device_config["class"]["status_types"].items(): - device_config["class"]["status_types"][status_name] = status_str_type_mapping[status_type] + if status_type in status_str_type_mapping: + device_config["class"]["status_types"][status_name] = status_str_type_mapping[status_type] for action_name, action_config in device_config["class"]["action_value_mappings"].items(): - if action_config["type"] not in action_str_type_mapping: - continue - action_config["type"] = action_str_type_mapping[action_config["type"]] + if action_config.get("type") in action_str_type_mapping: + action_config["type"] = action_str_type_mapping[action_config["type"]] + self._add_builtin_actions(device_config, device_id) - device_config["file_path"] = str(file.absolute()).replace("\\", "/") - device_config["registry_type"] = "device" + device_ids.append(device_id) complete_data = dict(sorted(complete_data.items())) complete_data = copy.deepcopy(complete_data) - try: - with open(file, "w", encoding="utf-8") as f: - yaml.dump(complete_data, f, allow_unicode=True, default_flow_style=False, Dumper=NoAliasDumper) - except Exception as e: - logger.warning(f"[UniLab Registry] 写入设备文件失败: {file}, 错误: {e}") + if complete_registry: + # 仅在 complete_registry 模式下回写 YAML,排除运行时字段 + write_data = copy.deepcopy(complete_data) + for dev_id, dev_cfg in write_data.items(): + dev_cfg.pop("file_path", None) + dev_cfg.pop("registry_type", None) + try: + with open(file, "w", encoding="utf-8") as f: + yaml.dump(write_data, f, allow_unicode=True, default_flow_style=False, Dumper=NoAliasDumper) + except Exception as e: + logger.warning(f"[UniLab Registry] 写入设备文件失败: {file}, 错误: {e}") return data, complete_data, 
True, device_ids - def load_device_types(self, path: os.PathLike, complete_registry: bool): + def load_device_types(self, path: os.PathLike, complete_registry: bool = False): abs_path = Path(path).absolute() devices_path = abs_path / "devices" device_comms_path = abs_path / "device_comms" @@ -921,44 +2141,27 @@ class Registry: if not files: return - from unilabos.app.web.utils.action_utils import get_yaml_from_goal_type + executor = self._startup_executor + future_to_file = { + executor.submit( + self._load_single_device_file, file, complete_registry + ): file + for file in files + } - # 使用线程池并行加载 - max_workers = min(8, len(files)) - results = [] + for future in as_completed(future_to_file): + file = future_to_file[future] + try: + data, _complete_data, is_valid, device_ids = future.result() + if is_valid: + runtime_data = {did: data[did] for did in device_ids if did in data} + self.device_type_registry.update(runtime_data) + except Exception as e: + logger.warning(f"[UniLab Registry] 加载设备文件失败: {file}, 错误: {e}") - with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_file = { - executor.submit(self._load_single_device_file, file, complete_registry, get_yaml_from_goal_type): file - for file in files - } - for future in as_completed(future_to_file): - file = future_to_file[future] - try: - data, complete_data, is_valid, device_ids = future.result() - if is_valid: - results.append((file, data, device_ids)) - except Exception as e: - traceback.print_exc() - logger.warning(f"[UniLab Registry] 处理设备文件异常: {file}, 错误: {e}") - - # 线程安全地更新注册表 - current_device_number = len(self.device_type_registry) + 1 - with self._registry_lock: - for file, data, device_ids in results: - self.device_type_registry.update(data) - for device_id in device_ids: - logger.trace( - f"[UniLab Registry] Device-{current_device_number} Add {device_id} " - + f"[{data[device_id].get('name', '未命名设备')}]" - ) - current_device_number += 1 - - # 记录无效文件 - valid_files = {r[0] for r in 
results} - for file in files: - if file not in valid_files: - logger.debug(f"[UniLab Registry] Device File Not Valid YAML File: {file.absolute()}") + # ------------------------------------------------------------------ + # 注册表信息输出 + # ------------------------------------------------------------------ def obtain_registry_device_info(self): devices = [] @@ -966,7 +2169,6 @@ class Registry: device_info_copy = copy.deepcopy(device_info) if "class" in device_info_copy and "action_value_mappings" in device_info_copy["class"]: action_mappings = device_info_copy["class"]["action_value_mappings"] - # 过滤掉内置的驱动命令动作 builtin_actions = ["_execute_driver_command", "_execute_driver_command_async"] filtered_action_mappings = { action_name: action_config @@ -976,6 +2178,9 @@ class Registry: device_info_copy["class"]["action_value_mappings"] = filtered_action_mappings for action_name, action_config in filtered_action_mappings.items(): + type_obj = action_config.get("type") + if hasattr(type_obj, "__name__"): + action_config["type"] = type_obj.__name__ if "schema" in action_config and action_config["schema"]: schema = action_config["schema"] # 确保schema结构存在 @@ -999,6 +2204,10 @@ class Registry: action_config["schema"]["properties"]["goal"]["_unilabos_placeholder_info"] = action_config[ "placeholder_keys" ] + status_types = device_info_copy["class"].get("status_types", {}) + for status_name, status_type in status_types.items(): + if hasattr(status_type, "__name__"): + status_types[status_name] = status_type.__name__ msg = {"id": device_id, **device_info_copy} devices.append(msg) @@ -1011,35 +2220,76 @@ class Registry: resources.append(msg) return resources + def get_yaml_output(self, device_id: str) -> str: + """将指定设备的注册表条目导出为 YAML 字符串。""" + entry = self.device_type_registry.get(device_id) + if not entry: + return "" + + entry = copy.deepcopy(entry) + + if "class" in entry: + status_types = entry["class"].get("status_types", {}) + for name, type_obj in status_types.items(): + if 
hasattr(type_obj, "__name__"): + status_types[name] = type_obj.__name__ + + for action_name, action_config in entry["class"].get("action_value_mappings", {}).items(): + type_obj = action_config.get("type") + if hasattr(type_obj, "__name__"): + action_config["type"] = type_obj.__name__ + + entry.pop("registry_type", None) + entry.pop("file_path", None) + + if "class" in entry and "action_value_mappings" in entry["class"]: + entry["class"]["action_value_mappings"] = { + k: v + for k, v in entry["class"]["action_value_mappings"].items() + if not k.startswith("_execute_driver_command") + } + + return yaml.dump( + {device_id: entry}, + allow_unicode=True, + default_flow_style=False, + Dumper=NoAliasDumper, + ) + + +# --------------------------------------------------------------------------- +# 全局单例实例 & 构建入口 +# --------------------------------------------------------------------------- -# 全局单例实例 lab_registry = Registry() -def build_registry(registry_paths=None, complete_registry=False, upload_registry=False): +def build_registry(registry_paths=None, devices_dirs=None, upload_registry=False, check_mode=False): """ 构建或获取Registry单例实例 - - Args: - registry_paths: 额外的注册表路径列表 - - Returns: - Registry实例 """ logger.info("[UniLab Registry] 构建注册表实例") - # 由于使用了单例,这里不需要重新创建实例 global lab_registry - # 如果有额外路径,添加到registry_paths if registry_paths: current_paths = lab_registry.registry_paths.copy() - # 检查是否有新路径需要添加 for path in registry_paths: if path not in current_paths: lab_registry.registry_paths.append(path) - # 初始化注册表 - lab_registry.setup(complete_registry, upload_registry) + lab_registry.setup(devices_dirs=devices_dirs, upload_registry=upload_registry) + + # 将 AST 扫描的字符串类型替换为实际 ROS2 消息类(仅查找 ROS2 类型,不 import 设备模块) + lab_registry.resolve_all_types() + + if check_mode: + lab_registry.verify_and_resolve_registry() + + # noinspection PyProtectedMember + if lab_registry._startup_executor is not None: + # noinspection PyProtectedMember + lab_registry._startup_executor.shutdown(wait=False) 
+ lab_registry._startup_executor = None return lab_registry diff --git a/unilabos/registry/utils.py b/unilabos/registry/utils.py new file mode 100644 index 00000000..bc65450e --- /dev/null +++ b/unilabos/registry/utils.py @@ -0,0 +1,699 @@ +""" +注册表工具函数 + +从 registry.py 中提取的纯工具函数,包括: +- docstring 解析 +- 类型字符串 → JSON Schema 转换 +- AST 类型节点解析 +- TypedDict / Slot / Handle 等辅助检测 +""" + +import inspect +import logging +import re +import typing +from typing import Any, Dict, List, Optional, Tuple, Union + +from msgcenterpy.instances.typed_dict_instance import TypedDictMessageInstance + +from unilabos.utils.cls_creator import import_class + +_logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# 异常 +# --------------------------------------------------------------------------- + + +class ROSMsgNotFound(Exception): + pass + + +# --------------------------------------------------------------------------- +# Docstring 解析 (Google-style) +# --------------------------------------------------------------------------- + +_SECTION_RE = re.compile(r"^(\w[\w\s]*):\s*$") + + +def parse_docstring(docstring: Optional[str]) -> Dict[str, Any]: + """ + 解析 Google-style docstring,提取描述和参数说明。 + + Returns: + {"description": "短描述", "params": {"param1": "参数1描述", ...}} + """ + result: Dict[str, Any] = {"description": "", "params": {}} + if not docstring: + return result + + lines = docstring.strip().splitlines() + if not lines: + return result + + result["description"] = lines[0].strip() + + in_args = False + current_param: Optional[str] = None + current_desc_parts: list = [] + + for line in lines[1:]: + stripped = line.strip() + section_match = _SECTION_RE.match(stripped) + if section_match: + if current_param is not None: + result["params"][current_param] = "\n".join(current_desc_parts).strip() + current_param = None + current_desc_parts = [] + section_name = section_match.group(1).lower() + in_args = section_name in ("args", 
"arguments", "parameters", "params") + continue + + if not in_args: + continue + + if ":" in stripped and not stripped.startswith(" "): + if current_param is not None: + result["params"][current_param] = "\n".join(current_desc_parts).strip() + param_part, _, desc_part = stripped.partition(":") + param_name = param_part.strip().split("(")[0].strip() + current_param = param_name + current_desc_parts = [desc_part.strip()] + elif current_param is not None: + aline = line + if aline.startswith(" "): + aline = aline[4:] + elif aline.startswith("\t"): + aline = aline[1:] + current_desc_parts.append(aline.strip()) + + if current_param is not None: + result["params"][current_param] = "\n".join(current_desc_parts).strip() + + return result + + +# --------------------------------------------------------------------------- +# 类型常量 +# --------------------------------------------------------------------------- + +SIMPLE_TYPE_MAP = { + "str": "string", + "string": "string", + "int": "integer", + "integer": "integer", + "float": "number", + "number": "number", + "bool": "boolean", + "boolean": "boolean", + "list": "array", + "array": "array", + "dict": "object", + "object": "object", +} + +ARRAY_TYPES = {"list", "List", "tuple", "Tuple", "set", "Set", "Sequence", "Iterable"} +OBJECT_TYPES = {"dict", "Dict", "Mapping"} +WRAPPER_TYPES = {"Optional"} +SLOT_TYPES = {"ResourceSlot", "DeviceSlot"} + + +# --------------------------------------------------------------------------- +# 简单类型映射 +# --------------------------------------------------------------------------- + + +def get_json_schema_type(type_str: str) -> str: + """简单类型名 -> JSON Schema type""" + return SIMPLE_TYPE_MAP.get(type_str.lower(), "string") + + +# --------------------------------------------------------------------------- +# AST 类型解析 +# --------------------------------------------------------------------------- + + +def parse_type_node(type_str: str): + """将类型注解字符串解析为 AST 节点,失败返回 None。""" + import ast as _ast + + try: + 
return _ast.parse(type_str.strip(), mode="eval").body + except Exception: + return None + + +def _collect_bitor(node, out: list): + """递归收集 X | Y | Z 的所有分支。""" + import ast as _ast + + if isinstance(node, _ast.BinOp) and isinstance(node.op, _ast.BitOr): + _collect_bitor(node.left, out) + _collect_bitor(node.right, out) + else: + out.append(node) + + +def type_node_to_schema( + node, + import_map: Optional[Dict[str, str]] = None, +) -> Dict[str, Any]: + """将 AST 类型注解节点递归转换为 JSON Schema dict。 + + 当提供 import_map 时,对于未知类名会尝试通过 import_map 解析模块路径, + 然后 import 真实类型对象来生成 schema (支持 TypedDict 等)。 + + 映射规则: + - Optional[X] → X 的 schema (剥掉 Optional) + - Union[X, Y] → {"anyOf": [X_schema, Y_schema]} + - List[X] / Tuple[X] / Set[X] → {"type": "array", "items": X_schema} + - Dict[K, V] → {"type": "object", "additionalProperties": V_schema} + - Literal["a", "b"] → {"type": "string", "enum": ["a", "b"]} + - TypedDict (via import_map) → {"type": "object", "properties": {...}} + - 基本类型 str/int/... → {"type": "string"/"integer"/...} + """ + import ast as _ast + + # --- Name 节点: str / int / dict / ResourceSlot / 自定义类 --- + if isinstance(node, _ast.Name): + name = node.id + if name in SLOT_TYPES: + return {"$slot": name} + json_type = SIMPLE_TYPE_MAP.get(name.lower()) + if json_type: + return {"type": json_type} + # 尝试通过 import_map 解析并 import 真实类型 + if import_map and name in import_map: + type_obj = resolve_type_object(import_map[name]) + if type_obj is not None: + return type_to_schema(type_obj) + # 未知类名 → 无法转 schema 的自定义类型默认当 object + return {"type": "object"} + + if isinstance(node, _ast.Constant): + if isinstance(node.value, str): + return {"type": SIMPLE_TYPE_MAP.get(node.value.lower(), "string")} + return {"type": "string"} + + # --- Subscript 节点: List[X], Dict[K,V], Optional[X], Literal[...] 
等 --- + if isinstance(node, _ast.Subscript): + base_name = node.value.id if isinstance(node.value, _ast.Name) else "" + + # Optional[X] → 剥掉 + if base_name in WRAPPER_TYPES: + return type_node_to_schema(node.slice, import_map) + + # Union[X, None] → 剥掉 None; Union[X, Y] → anyOf + if base_name == "Union": + elts = node.slice.elts if isinstance(node.slice, _ast.Tuple) else [node.slice] + non_none = [ + e + for e in elts + if not (isinstance(e, _ast.Constant) and e.value is None) + and not (isinstance(e, _ast.Name) and e.id == "None") + ] + if len(non_none) == 1: + return type_node_to_schema(non_none[0], import_map) + if len(non_none) > 1: + return {"anyOf": [type_node_to_schema(e, import_map) for e in non_none]} + return {"type": "string"} + + # Literal["a", "b", 1] → enum + if base_name == "Literal": + elts = node.slice.elts if isinstance(node.slice, _ast.Tuple) else [node.slice] + values = [] + for e in elts: + if isinstance(e, _ast.Constant): + values.append(e.value) + elif isinstance(e, _ast.Name): + values.append(e.id) + if values: + return {"type": "string", "enum": values} + return {"type": "string"} + + # List / Tuple / Set → array + if base_name in ARRAY_TYPES: + if isinstance(node.slice, _ast.Tuple) and node.slice.elts: + inner_node = node.slice.elts[0] + else: + inner_node = node.slice + return {"type": "array", "items": type_node_to_schema(inner_node, import_map)} + + # Dict → object + if base_name in OBJECT_TYPES: + schema: Dict[str, Any] = {"type": "object"} + if isinstance(node.slice, _ast.Tuple) and len(node.slice.elts) >= 2: + val_node = node.slice.elts[1] + # Dict[str, Any] → 不加 additionalProperties (Any 等同于无约束) + is_any = (isinstance(val_node, _ast.Name) and val_node.id == "Any") or ( + isinstance(val_node, _ast.Constant) and val_node.value is None + ) + if not is_any: + val_schema = type_node_to_schema(val_node, import_map) + schema["additionalProperties"] = val_schema + return schema + + # --- BinOp: X | Y (Python 3.10+) → 当 Union 处理 --- + if 
isinstance(node, _ast.BinOp) and isinstance(node.op, _ast.BitOr): + parts: list = [] + _collect_bitor(node, parts) + non_none = [ + p + for p in parts + if not (isinstance(p, _ast.Constant) and p.value is None) + and not (isinstance(p, _ast.Name) and p.id == "None") + ] + if len(non_none) == 1: + return type_node_to_schema(non_none[0], import_map) + if len(non_none) > 1: + return {"anyOf": [type_node_to_schema(p, import_map) for p in non_none]} + return {"type": "string"} + + return {"type": "string"} + + +# --------------------------------------------------------------------------- +# 真实类型对象解析 (import-based) +# --------------------------------------------------------------------------- + + +def resolve_type_object(type_ref: str) -> Optional[Any]: + """通过 'module.path:ClassName' 格式的引用 import 并返回真实类型对象。 + + 对于 typing 内置名 (str, int, List 等) 直接返回 None (由 AST 路径处理)。 + import 失败时静默返回 None。 + """ + if ":" not in type_ref: + return None + try: + return import_class(type_ref) + except Exception: + return None + + +def is_typed_dict_class(obj: Any) -> bool: + """检查对象是否是 TypedDict 类。""" + if obj is None: + return False + try: + from typing_extensions import is_typeddict + + return is_typeddict(obj) + except ImportError: + if isinstance(obj, type): + return hasattr(obj, "__required_keys__") and hasattr(obj, "__optional_keys__") + return False + + +def type_to_schema(tp: Any) -> Dict[str, Any]: + """将真实 typing 对象递归转换为 JSON Schema dict。 + + 支持: + - 基本类型: str, int, float, bool → {"type": "string"/"integer"/...} + - typing 泛型: List[X], Dict[K,V], Optional[X], Union[X,Y], Literal[...] 
+ - TypedDict → {"type": "object", "properties": {...}, "required": [...]} + - 自定义类 (ResourceSlot 等) → {"$slot": "..."} 或 {"type": "string"} + """ + origin = getattr(tp, "__origin__", None) + args = getattr(tp, "__args__", None) + + # --- None / NoneType --- + if tp is type(None): + return {"type": "null"} + + # --- 基本类型 --- + if tp is str: + return {"type": "string"} + if tp is int: + return {"type": "integer"} + if tp is float: + return {"type": "number"} + if tp is bool: + return {"type": "boolean"} + + # --- TypedDict --- + if is_typed_dict_class(tp): + try: + return TypedDictMessageInstance.get_json_schema_from_typed_dict(tp) + except Exception: + return {"type": "object"} + + # --- Literal --- + if origin is typing.Literal: + values = list(args) if args else [] + return {"type": "string", "enum": values} + + # --- Optional / Union --- + if origin is typing.Union: + non_none = [a for a in (args or ()) if a is not type(None)] + if len(non_none) == 1: + return type_to_schema(non_none[0]) + if len(non_none) > 1: + return {"anyOf": [type_to_schema(a) for a in non_none]} + return {"type": "string"} + + # --- List / Sequence / Set / Tuple / Iterable --- + if origin in (list, tuple, set, frozenset) or ( + origin is not None + and getattr(origin, "__name__", "") in ("Sequence", "Iterable", "Iterator", "MutableSequence") + ): + if args: + return {"type": "array", "items": type_to_schema(args[0])} + return {"type": "array"} + + # --- Dict / Mapping --- + if origin in (dict,) or (origin is not None and getattr(origin, "__name__", "") in ("Mapping", "MutableMapping")): + schema: Dict[str, Any] = {"type": "object"} + if args and len(args) >= 2: + schema["additionalProperties"] = type_to_schema(args[1]) + return schema + + # --- Slot 类型 --- + if isinstance(tp, type): + name = tp.__name__ + if name in SLOT_TYPES: + return {"$slot": name} + + # --- 其他未知类型 fallback --- + if isinstance(tp, type): + return {"type": "object"} + return {"type": "string"} + + +# 
--------------------------------------------------------------------------- +# Slot / Placeholder 检测 +# --------------------------------------------------------------------------- + + +def detect_slot_type(ptype) -> Tuple[Optional[str], bool]: + """检测参数类型是否为 ResourceSlot / DeviceSlot。 + + 兼容多种格式: + - runtime: "unilabos.registry.placeholder_type:ResourceSlot" + - runtime tuple: ("list", "unilabos.registry.placeholder_type:ResourceSlot") + - AST 裸名: "ResourceSlot", "List[ResourceSlot]", "Optional[ResourceSlot]" + + Returns: (slot_name | None, is_list) + """ + ptype_str = str(ptype) + + # 快速路径: 字符串里根本没有 Slot + if "ResourceSlot" not in ptype_str and "DeviceSlot" not in ptype_str: + return (None, False) + + # runtime 格式: 完整模块路径 + if isinstance(ptype, str): + if ptype.endswith(":ResourceSlot") or ptype == "ResourceSlot": + return ("ResourceSlot", False) + if ptype.endswith(":DeviceSlot") or ptype == "DeviceSlot": + return ("DeviceSlot", False) + # AST 复杂格式: List[ResourceSlot], Optional[ResourceSlot] 等 + if "[" in ptype: + node = parse_type_node(ptype) + if node is not None: + schema = type_node_to_schema(node) + # 直接是 slot + if "$slot" in schema: + return (schema["$slot"], False) + # array 包裹 slot: {"type": "array", "items": {"$slot": "..."}} + items = schema.get("items", {}) + if isinstance(items, dict) and "$slot" in items: + return (items["$slot"], True) + return (None, False) + + # runtime tuple 格式 + if isinstance(ptype, tuple) and len(ptype) == 2: + inner_str = str(ptype[1]) + if "ResourceSlot" in inner_str: + return ("ResourceSlot", True) + if "DeviceSlot" in inner_str: + return ("DeviceSlot", True) + + return (None, False) + + +def detect_placeholder_keys(params: list) -> Dict[str, str]: + """Detect parameters that reference ResourceSlot or DeviceSlot.""" + result: Dict[str, str] = {} + for p in params: + ptype = p.get("type", "") + if "ResourceSlot" in str(ptype): + result[p["name"]] = "unilabos_resources" + elif "DeviceSlot" in str(ptype): + result[p["name"]] = 
"unilabos_devices" + return result + + +# --------------------------------------------------------------------------- +# Handle 规范化 +# --------------------------------------------------------------------------- + + +def normalize_ast_handles(handles_raw: Any) -> List[Dict[str, Any]]: + """Convert AST-parsed handle structures to the standard registry format.""" + if not handles_raw: + return [] + + # handle_type → io_type 映射 (AST 内部类名 → YAML 标准字段值) + _HANDLE_TYPE_TO_IO_TYPE = { + "input": "target", + "output": "source", + "action_input": "action_target", + "action_output": "action_source", + } + + result: List[Dict[str, Any]] = [] + for h in handles_raw: + if isinstance(h, dict): + call = h.get("_call", "") + if "InputHandle" in call: + handle_type = "input" + elif "OutputHandle" in call: + handle_type = "output" + elif "ActionInputHandle" in call: + handle_type = "action_input" + elif "ActionOutputHandle" in call: + handle_type = "action_output" + else: + handle_type = h.get("handle_type", "unknown") + + io_type = _HANDLE_TYPE_TO_IO_TYPE.get(handle_type, handle_type) + + entry: Dict[str, Any] = { + "handler_key": h.get("key", ""), + "data_type": h.get("data_type", ""), + "io_type": io_type, + } + side = h.get("side") + if side: + if isinstance(side, str) and "." in side: + val = side.rsplit(".", 1)[-1] + side = val.lower() if val in ("LEFT", "RIGHT", "TOP", "BOTTOM") else val + entry["side"] = side + label = h.get("label") + if label: + entry["label"] = label + data_key = h.get("data_key") + if data_key: + entry["data_key"] = data_key + data_source = h.get("data_source") + if data_source: + if isinstance(data_source, str) and "." 
in data_source: + val = data_source.rsplit(".", 1)[-1] + data_source = val.lower() if val in ("HANDLE", "EXECUTOR") else val + entry["data_source"] = data_source + description = h.get("description") + if description: + entry["description"] = description + + result.append(entry) + return result + + +def normalize_ast_action_handles(handles_raw: Any) -> Dict[str, Any]: + """Convert AST-parsed action handle list to {"input": [...], "output": [...]}. + + Mirrors the runtime behavior of decorators._action_handles_to_dict: + - ActionInputHandle => grouped under "input" + - ActionOutputHandle => grouped under "output" + Field mapping: key -> handler_key (matches Pydantic serialization_alias). + """ + if not handles_raw or not isinstance(handles_raw, list): + return {} + + input_list: List[Dict[str, Any]] = [] + output_list: List[Dict[str, Any]] = [] + + for h in handles_raw: + if not isinstance(h, dict): + continue + call = h.get("_call", "") + is_input = "ActionInputHandle" in call or "InputHandle" in call + is_output = "ActionOutputHandle" in call or "OutputHandle" in call + + entry: Dict[str, Any] = { + "handler_key": h.get("key", ""), + "data_type": h.get("data_type", ""), + "label": h.get("label", ""), + } + for opt_key in ("side", "data_key", "data_source", "description", "io_type"): + val = h.get(opt_key) + if val is not None: + # Only resolve enum-style refs (e.g. DataSource.HANDLE -> handle) for data_source/side + # data_key values like "wells.@flatten", "@this.0@@@plate" must be preserved as-is + if ( + isinstance(val, str) + and "." 
 in val
+                    and opt_key not in ("io_type", "data_key")
+                ):
+                    val = val.rsplit(".", 1)[-1].lower()
+                entry[opt_key] = val
+
+        # io_type: inputs default to "source" when not explicitly set; outputs never get a default io_type (YAML convention omits it)
+        if "io_type" not in entry and is_input:
+            entry["io_type"] = "source"
+
+        if is_input:
+            input_list.append(entry)
+        elif is_output:
+            output_list.append(entry)
+
+    result: Dict[str, Any] = {}
+    if input_list:
+        result["input"] = input_list
+    # Always include output (empty list when no outputs) to match YAML
+    result["output"] = output_list
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Schema 辅助
+# ---------------------------------------------------------------------------
+
+
+def wrap_action_schema(
+    goal_schema: Dict[str, Any],
+    action_name: str,
+    description: str = "",
+    result_schema: Optional[Dict[str, Any]] = None,
+    feedback_schema: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+    """
+    将 goal 参数 schema 包装为标准的 action schema 格式:
+    { "properties": { "goal": ..., "feedback": ..., "result": ... }, ... 
}
+    """
+    # 去掉 auto- 前缀用于 title/description,与 YAML 路径保持一致
+    display_name = action_name.removeprefix("auto-")
+    return {
+        "title": f"{display_name}参数",
+        "description": description or f"{display_name}的参数schema",
+        "type": "object",
+        "properties": {
+            "goal": goal_schema,
+            "feedback": feedback_schema or {},
+            "result": result_schema or {},
+        },
+        "required": ["goal"],
+    }
+
+
+def preserve_field_descriptions(new_schema: Dict[str, Any], prev_schema: Dict[str, Any]):
+    """保留之前 schema 中的 field titles(NOTE(review): 函数名提到 descriptions,当前实现仅拷贝 "title" —— 待确认是否符合预期)"""
+    if not prev_schema or not new_schema:
+        return
+    prev_props = prev_schema.get("properties", {})
+    new_props = new_schema.get("properties", {})
+    for field_name, prev_field in prev_props.items():
+        if field_name in new_props and "title" in prev_field:
+            new_props[field_name].setdefault("title", prev_field["title"])
+
+
+# ---------------------------------------------------------------------------
+# 深度对比
+# ---------------------------------------------------------------------------
+
+
+def _short(val, limit=120):
+    """截断过长的值用于日志显示。"""
+    s = repr(val)
+    return s if len(s) <= limit else s[:limit] + "..."
+ + +def deep_diff(old, new, path="", max_depth=10) -> list: + """递归对比两个对象,返回所有差异的描述列表。""" + diffs = [] + if max_depth <= 0: + if old != new: + diffs.append(f"{path}: (达到最大深度) OLD≠NEW") + return diffs + + if type(old) != type(new): + diffs.append(f"{path}: 类型不同 OLD={type(old).__name__}({_short(old)}) NEW={type(new).__name__}({_short(new)})") + return diffs + + if isinstance(old, dict): + old_keys = set(old.keys()) + new_keys = set(new.keys()) + for k in sorted(new_keys - old_keys): + diffs.append(f"{path}.{k}: 新增字段 (AST有, YAML无) = {_short(new[k])}") + for k in sorted(old_keys - new_keys): + diffs.append(f"{path}.{k}: 缺失字段 (YAML有, AST无) = {_short(old[k])}") + for k in sorted(old_keys & new_keys): + diffs.extend(deep_diff(old[k], new[k], f"{path}.{k}", max_depth - 1)) + elif isinstance(old, (list, tuple)): + if len(old) != len(new): + diffs.append(f"{path}: 列表长度不同 OLD={len(old)} NEW={len(new)}") + for i in range(min(len(old), len(new))): + diffs.extend(deep_diff(old[i], new[i], f"{path}[{i}]", max_depth - 1)) + if len(new) > len(old): + for i in range(len(old), len(new)): + diffs.append(f"{path}[{i}]: 新增元素 = {_short(new[i])}") + elif len(old) > len(new): + for i in range(len(new), len(old)): + diffs.append(f"{path}[{i}]: 缺失元素 = {_short(old[i])}") + else: + if old != new: + diffs.append(f"{path}: OLD={_short(old)} NEW={_short(new)}") + return diffs + + +# --------------------------------------------------------------------------- +# MRO 方法参数解析 +# --------------------------------------------------------------------------- + + +def resolve_method_params_via_import(module_str: str, method_name: str) -> Dict[str, str]: + """当 AST 方法参数为空 (如 *args, **kwargs) 时, import class 并通过 MRO 获取真实方法参数. + + 返回 identity mapping {param_name: param_name}. 
+ """ + if not module_str or ":" not in module_str: + return {} + try: + cls = import_class(module_str) + except Exception as e: + _logger.debug(f"[AST] resolve_method_params_via_import: import_class('{module_str}') failed: {e}") + return {} + + try: + for base_cls in cls.__mro__: + if method_name not in base_cls.__dict__: + continue + method = base_cls.__dict__[method_name] + actual = getattr(method, "__wrapped__", method) + if isinstance(actual, (staticmethod, classmethod)): + actual = actual.__func__ + if not callable(actual): + continue + sig = inspect.signature(actual, follow_wrapped=True) + params = [ + p.name for p in sig.parameters.values() + if p.name not in ("self", "cls") + and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + ] + if params: + return {p: p for p in params} + except Exception as e: + _logger.debug(f"[AST] resolve_method_params_via_import: MRO walk for '{method_name}' failed: {e}") + return {} diff --git a/unilabos/resources/graphio.py b/unilabos/resources/graphio.py index 38f96968..c8f1cc2c 100644 --- a/unilabos/resources/graphio.py +++ b/unilabos/resources/graphio.py @@ -76,7 +76,7 @@ def canonicalize_nodes_data( if sample_id: logger.error(f"{node}的sample_id参数已弃用,sample_id: {sample_id}") for k in list(node.keys()): - if k not in ["id", "uuid", "name", "description", "schema", "model", "icon", "parent_uuid", "parent", "type", "class", "position", "config", "data", "children", "pose", "extra"]: + if k not in ["id", "uuid", "name", "description", "schema", "model", "icon", "parent_uuid", "parent", "type", "class", "position", "config", "data", "children", "pose", "extra", "machine_name"]: v = node.pop(k) node["config"][k] = v if outer_host_node_id is not None: @@ -288,6 +288,15 @@ def read_node_link_json( physical_setup_graph = nx.node_link_graph(graph_data, edges="links", multigraph=False) handle_communications(physical_setup_graph) + # Stamp machine_name on device trees only (resources are cloud-managed) + 
local_machine = BasicConfig.machine_name or "本地" + for tree in resource_tree_set.trees: + if tree.root_node.res_content.type != "device": + continue + for node in tree.get_all_nodes(): + if not node.res_content.machine_name: + node.res_content.machine_name = local_machine + return physical_setup_graph, resource_tree_set, standardized_links @@ -372,6 +381,15 @@ def read_graphml(graphml_file: str) -> tuple[nx.Graph, ResourceTreeSet, List[Dic physical_setup_graph = nx.node_link_graph(graph_data, link="links", multigraph=False) handle_communications(physical_setup_graph) + # Stamp machine_name on device trees only (resources are cloud-managed) + local_machine = BasicConfig.machine_name or "本地" + for tree in resource_tree_set.trees: + if tree.root_node.res_content.type != "device": + continue + for node in tree.get_all_nodes(): + if not node.res_content.machine_name: + node.res_content.machine_name = local_machine + return physical_setup_graph, resource_tree_set, standardized_links diff --git a/unilabos/resources/resource_tracker.py b/unilabos/resources/resource_tracker.py index b34d10cc..3fb945b6 100644 --- a/unilabos/resources/resource_tracker.py +++ b/unilabos/resources/resource_tracker.py @@ -120,6 +120,7 @@ class ResourceDictType(TypedDict): config: Dict[str, Any] data: Dict[str, Any] extra: Dict[str, Any] + machine_name: str # 统一的资源字典模型,parent 自动序列化为 parent_uuid,children 不序列化 @@ -141,6 +142,7 @@ class ResourceDict(BaseModel): config: Dict[str, Any] = Field(description="Resource configuration") data: Dict[str, Any] = Field(description="Resource data, eg: container liquid data") extra: Dict[str, Any] = Field(description="Extra data, eg: slot index") + machine_name: str = Field(description="Machine this resource belongs to", default="") @field_serializer("parent_uuid") def _serialize_parent(self, parent_uuid: Optional["ResourceDict"]): @@ -196,22 +198,30 @@ class ResourceDictInstance(object): self.typ = "dict" @classmethod - def get_resource_instance_from_dict(cls, 
content: Dict[str, Any]) -> "ResourceDictInstance": + def get_resource_instance_from_dict(cls, content: ResourceDictType) -> "ResourceDictInstance": """从字典创建资源实例""" if "id" not in content: content["id"] = content["name"] if "uuid" not in content: content["uuid"] = str(uuid.uuid4()) if "description" in content and content["description"] is None: + # noinspection PyTypedDict del content["description"] if "model" in content and content["model"] is None: + # noinspection PyTypedDict del content["model"] + # noinspection PyTypedDict if "schema" in content and content["schema"] is None: + # noinspection PyTypedDict del content["schema"] + # noinspection PyTypedDict if "x" in content.get("position", {}): # 说明是老版本的position格式,转换成新的 + # noinspection PyTypedDict content["position"] = {"position": content["position"]} + # noinspection PyTypedDict if not content.get("class"): + # noinspection PyTypedDict content["class"] = "" if not content.get("config"): # todo: 后续从后端保证字段非空 content["config"] = {} @@ -222,16 +232,18 @@ class ResourceDictInstance(object): if "position" in content: pose = content.get("pose", {}) if "position" not in pose: + # noinspection PyTypedDict if "position" in content["position"]: + # noinspection PyTypedDict pose["position"] = content["position"]["position"] else: - pose["position"] = {"x": 0, "y": 0, "z": 0} + pose["position"] = ResourceDictPositionObjectType(x=0, y=0, z=0) if "size" not in pose: - pose["size"] = { - "width": content["config"].get("size_x", 0), - "height": content["config"].get("size_y", 0), - "depth": content["config"].get("size_z", 0), - } + pose["size"] = ResourceDictPositionSizeType( + width= content["config"].get("size_x", 0), + height= content["config"].get("size_y", 0), + depth= content["config"].get("size_z", 0), + ) content["pose"] = pose try: res_dict = ResourceDict.model_validate(content) @@ -399,7 +411,7 @@ class ResourceTreeSet(object): ) @classmethod - def from_plr_resources(cls, resources: List["PLRResource"], 
known_newly_created=False) -> "ResourceTreeSet": + def from_plr_resources(cls, resources: List["PLRResource"], known_newly_created=False, old_size=False) -> "ResourceTreeSet": """ 从plr资源创建ResourceTreeSet """ @@ -422,13 +434,20 @@ class ResourceTreeSet(object): "resource_group": "resource_group", "trash": "trash", "plate_adapter": "plate_adapter", + "consumable": "consumable", + "tool": "tool", + "condenser": "condenser", + "crucible": "crucible", + "reagent_bottle": "reagent_bottle", + "flask": "flask", + "beaker": "beaker", } if source in replace_info: return replace_info[source] elif source is None: return "" else: - print("转换pylabrobot的时候,出现未知类型", source) + logger.trace(f"转换pylabrobot的时候,出现未知类型 {source}") return source def build_uuid_mapping(res: "PLRResource", uuid_list: list, parent_uuid: Optional[str] = None): @@ -483,7 +502,7 @@ class ResourceTreeSet(object): k: v for k, v in d.items() if k - not in [ + not in ([ "name", "children", "parent_name", @@ -494,7 +513,15 @@ class ResourceTreeSet(object): "size_z", "cross_section_type", "bottom_type", - ] + ] if not old_size else [ + "name", + "children", + "parent_name", + "location", + "rotation", + "cross_section_type", + "bottom_type", + ]) }, "data": states[d["name"]], "extra": extra, @@ -793,7 +820,8 @@ class ResourceTreeSet(object): if remote_root_type == "device": # 情况1: 一级是 device if remote_root_id not in local_device_map: - logger.warning(f"Device '{remote_root_id}' 在本地不存在,跳过该 device 下的物料同步") + if remote_root_id != "host_node": + logger.warning(f"Device '{remote_root_id}' 在本地不存在,跳过该 device 下的物料同步") continue local_device = local_device_map[remote_root_id] @@ -883,7 +911,7 @@ class ResourceTreeSet(object): return self - def dump(self) -> List[List[Dict[str, Any]]]: + def dump(self, old_position=False) -> List[List[Dict[str, Any]]]: """ 将 ResourceTreeSet 序列化为嵌套列表格式 @@ -899,6 +927,10 @@ class ResourceTreeSet(object): # 获取树的所有节点并序列化 tree_nodes = [node.res_content.model_dump(by_alias=True) for node in 
tree.get_all_nodes()] result.append(tree_nodes) + if old_position: + for r in result: + for rr in r: + rr["position"] = rr["pose"]["position"] return result @classmethod diff --git a/unilabos/ros/msgs/message_converter.py b/unilabos/ros/msgs/message_converter.py index b526d5f5..1451ee5c 100644 --- a/unilabos/ros/msgs/message_converter.py +++ b/unilabos/ros/msgs/message_converter.py @@ -11,6 +11,7 @@ from io import StringIO from typing import Iterable, Any, Dict, Type, TypeVar, Union import yaml +from msgcenterpy.instances.ros2_instance import ROS2MessageInstance from pydantic import BaseModel from dataclasses import asdict, is_dataclass @@ -727,46 +728,9 @@ def ros_message_to_json_schema(msg_class: Any, field_name: str) -> Dict[str, Any Returns: 对应的 JSON Schema 定义 """ - schema = {"type": "object", "properties": {}, "required": []} - - # 优先使用字段名作为标题,否则使用类名 + schema = ROS2MessageInstance(msg_class()).get_json_schema() schema["title"] = field_name - - # 获取消息的字段和字段类型 - try: - for ind, slot_info in enumerate(msg_class._fields_and_field_types.items()): - slot_name, slot_type = slot_info - type_info = msg_class.SLOT_TYPES[ind] - field_schema = ros_field_type_to_json_schema(type_info, slot_name) - schema["properties"][slot_name] = field_schema - schema["required"].append(slot_name) - # if hasattr(msg_class, 'get_fields_and_field_types'): - # fields_and_types = msg_class.get_fields_and_field_types() - # - # for field_name, field_type in fields_and_types.items(): - # # 将 ROS 字段类型转换为 JSON Schema - # field_schema = ros_field_type_to_json_schema(field_type) - # - # schema['properties'][field_name] = field_schema - # schema['required'].append(field_name) - # elif hasattr(msg_class, '__slots__') and hasattr(msg_class, '_fields_and_field_types'): - # # 直接从实例属性获取 - # for field_name in msg_class.__slots__: - # # 移除前导下划线(如果有) - # clean_name = field_name[1:] if field_name.startswith('_') else field_name - # - # # 从 _fields_and_field_types 获取类型 - # if clean_name in 
msg_class._fields_and_field_types: - # field_type = msg_class._fields_and_field_types[clean_name] - # field_schema = ros_field_type_to_json_schema(field_type) - # - # schema['properties'][clean_name] = field_schema - # schema['required'].append(clean_name) - except Exception as e: - # 如果获取字段类型失败,添加错误信息 - schema["description"] = f"解析消息字段时出错: {str(e)}" - logger.error(f"解析 {msg_class.__name__} 消息字段失败: {str(e)}") - + schema.pop("description") return schema diff --git a/unilabos/ros/nodes/base_device_node.py b/unilabos/ros/nodes/base_device_node.py index 772e667b..28fe92a5 100644 --- a/unilabos/ros/nodes/base_device_node.py +++ b/unilabos/ros/nodes/base_device_node.py @@ -34,7 +34,8 @@ from unilabos_msgs.action import SendCmd from unilabos_msgs.srv._serial_command import SerialCommand_Request, SerialCommand_Response from unilabos.config.config import BasicConfig -from unilabos.utils.decorator import get_topic_config, get_all_subscriptions +from unilabos.registry.decorators import get_topic_config +from unilabos.utils.decorator import get_all_subscriptions from unilabos.resources.container import RegularContainer from unilabos.resources.graphio import ( @@ -57,6 +58,7 @@ from unilabos_msgs.msg import Resource # type: ignore from unilabos.resources.resource_tracker import ( DeviceNodeResourceTracker, + ResourceDictType, ResourceTreeSet, ResourceTreeInstance, ResourceDictInstance, @@ -194,9 +196,9 @@ class PropertyPublisher: self._value = None try: self.publisher_ = node.create_publisher(msg_type, f"{name}", qos) - except AttributeError as ex: + except Exception as e: self.node.lab_logger().error( - f"创建发布者 {name} 失败,可能由于注册表有误,类型: {msg_type},错误: {ex}\n{traceback.format_exc()}" + f"StatusError, DeviceId: {self.node.device_id} 创建发布者 {name} 失败,可能由于注册表有误,类型: {msg_type},错误: {e}" ) self.timer = node.create_timer(self.timer_period, self.publish_property) self.__loop = ROS2DeviceNode.get_asyncio_loop() @@ -596,6 +598,12 @@ class BaseROS2DeviceNode(Node, Generic[T]): 
self.s2c_resource_tree, # type: ignore callback_group=self.callback_group, ), + "s2c_device_manage": self.create_service( + SerialCommand, + f"/srv{self.namespace}/s2c_device_manage", + self.s2c_device_manage, # type: ignore + callback_group=self.callback_group, + ), } # 向全局在线设备注册表添加设备信息 @@ -1064,6 +1072,48 @@ class BaseROS2DeviceNode(Node, Generic[T]): return res + async def s2c_device_manage(self, req: SerialCommand_Request, res: SerialCommand_Response): + """Handle add/remove device requests from HostNode via SerialCommand.""" + try: + cmd = json.loads(req.command) + action = cmd.get("action", "") + data = cmd.get("data", {}) + device_id = data.get("device_id", "") + + if not device_id: + res.response = json.dumps({"success": False, "error": "device_id required"}) + return res + + if action == "add": + result = self.create_device(device_id, data) + elif action == "remove": + result = self.destroy_device(device_id) + else: + result = {"success": False, "error": f"Unknown action: {action}"} + + res.response = json.dumps(result, ensure_ascii=False) + + except NotImplementedError as e: + self.lab_logger().warning(f"[DeviceManage] {e}") + res.response = json.dumps({"success": False, "error": str(e)}) + except Exception as e: + self.lab_logger().error(f"[DeviceManage] Error: {e}") + res.response = json.dumps({"success": False, "error": str(e)}) + + return res + + def create_device(self, device_id: str, config: "ResourceDictType") -> dict: + """Create a sub-device dynamically. Override in HostNode / WorkstationNode.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support dynamic device creation" + ) + + def destroy_device(self, device_id: str) -> dict: + """Destroy a sub-device dynamically. 
Override in HostNode / WorkstationNode.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not support dynamic device removal" + ) + async def transfer_resource_to_another( self, plr_resources: List["ResourcePLR"], @@ -1206,22 +1256,40 @@ class BaseROS2DeviceNode(Node, Generic[T]): return self._lab_logger def create_ros_publisher(self, attr_name, msg_type, initial_period=5.0): - """创建ROS发布者""" - # 检测装饰器配置(支持 get_{attr_name} 方法和 @property) + """创建ROS发布者,仅当方法/属性有 @topic_config 装饰器时才创建。""" + # 检测 @topic_config 装饰器配置 topic_config = {} + driver_class = type(self.driver_instance) - # 优先检测 get_{attr_name} 方法 - if hasattr(self.driver_instance, f"get_{attr_name}"): - getter_method = getattr(self.driver_instance, f"get_{attr_name}") - topic_config = get_topic_config(getter_method) + # 区分 @property 和普通方法两种情况 + is_prop = hasattr(driver_class, attr_name) and isinstance( + getattr(driver_class, attr_name), property + ) - # 如果没有配置,检测 @property 装饰的属性 + if is_prop: + # @property: 检测 fget 上的 @topic_config + class_attr = getattr(driver_class, attr_name) + if class_attr.fget is not None: + topic_config = get_topic_config(class_attr.fget) + else: + # 普通方法: 直接检测 attr_name 方法上的 @topic_config + if hasattr(self.driver_instance, attr_name): + method = getattr(self.driver_instance, attr_name) + if callable(method): + topic_config = get_topic_config(method) + + # 没有 @topic_config 装饰器则跳过发布 if not topic_config: - driver_class = type(self.driver_instance) - if hasattr(driver_class, attr_name): - class_attr = getattr(driver_class, attr_name) - if isinstance(class_attr, property) and class_attr.fget is not None: - topic_config = get_topic_config(class_attr.fget) + return + + # 发布名称优先级: @topic_config(name=...) 
> get_ 前缀去除 > attr_name + cfg_name = topic_config.get("name") + if cfg_name: + publish_name = cfg_name + elif attr_name.startswith("get_"): + publish_name = attr_name[4:] + else: + publish_name = attr_name # 使用装饰器配置或默认值 cfg_period = topic_config.get("period") @@ -1234,10 +1302,10 @@ class BaseROS2DeviceNode(Node, Generic[T]): # 获取属性值的方法 def get_device_attr(): try: - if hasattr(self.driver_instance, f"get_{attr_name}"): - return getattr(self.driver_instance, f"get_{attr_name}")() - else: + if is_prop: return getattr(self.driver_instance, attr_name) + else: + return getattr(self.driver_instance, attr_name)() except AttributeError as ex: if ex.args[0].startswith(f"AttributeError: '{self.driver_instance.__class__.__name__}' object"): self.lab_logger().error( @@ -1249,8 +1317,8 @@ class BaseROS2DeviceNode(Node, Generic[T]): ) self.lab_logger().error(traceback.format_exc()) - self._property_publishers[attr_name] = PropertyPublisher( - self, attr_name, get_device_attr, msg_type, period, print_publish, qos + self._property_publishers[publish_name] = PropertyPublisher( + self, publish_name, get_device_attr, msg_type, period, print_publish, qos ) def create_ros_action_server(self, action_name, action_value_mapping): @@ -1258,14 +1326,17 @@ class BaseROS2DeviceNode(Node, Generic[T]): action_type = action_value_mapping["type"] str_action_type = str(action_type)[8:-2] - self._action_servers[action_name] = ActionServer( - self, - action_type, - action_name, - execute_callback=self._create_execute_callback(action_name, action_value_mapping), - callback_group=self.callback_group, - ) - + try: + self._action_servers[action_name] = ActionServer( + self, + action_type, + action_name, + execute_callback=self._create_execute_callback(action_name, action_value_mapping), + callback_group=self.callback_group, + ) + except Exception as e: + self.lab_logger().error(f"创建ActionServer失败,Device: {self.device_id}, Action Name: {action_name}, Action Type: {action_type}, Error: {e}") + return 
self.lab_logger().trace(f"发布动作: {action_name}, 类型: {str_action_type}") def _setup_decorated_subscribers(self): diff --git a/unilabos/ros/nodes/presets/camera.py b/unilabos/ros/nodes/presets/camera.py index 2267f676..e94f001f 100644 --- a/unilabos/ros/nodes/presets/camera.py +++ b/unilabos/ros/nodes/presets/camera.py @@ -4,7 +4,14 @@ import cv2 from sensor_msgs.msg import Image from cv_bridge import CvBridge from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker +from unilabos.registry.decorators import device + +@device( + id="camera", + category=["camera"], + description="""VideoPublisher摄像头设备节点,用于实时视频采集和流媒体发布。该设备通过OpenCV连接本地摄像头(如USB摄像头、内置摄像头等),定时采集视频帧并将其转换为ROS2的sensor_msgs/Image消息格式发布到视频话题。主要用于实验室自动化系统中的视觉监控、图像分析、实时观察等应用场景。支持可配置的摄像头索引、发布频率等参数。""", +) class VideoPublisher(BaseROS2DeviceNode): def __init__(self, device_id='video_publisher', registry_name="", device_uuid='', camera_index=0, period: float = 0.1, resource_tracker: DeviceNodeResourceTracker = None): # 初始化BaseROS2DeviceNode,使用自身作为driver_instance diff --git a/unilabos/ros/nodes/presets/host_node.py b/unilabos/ros/nodes/presets/host_node.py index aa8b813f..eb139f1f 100644 --- a/unilabos/ros/nodes/presets/host_node.py +++ b/unilabos/ros/nodes/presets/host_node.py @@ -12,6 +12,7 @@ from geometry_msgs.msg import Point from rclpy.action import ActionClient, get_action_server_names_and_types_by_node from rclpy.service import Service from typing_extensions import TypedDict +from unilabos_msgs.action import EmptyIn, StrSingleInput, ResourceCreateFromOuterEasy, ResourceCreateFromOuter from unilabos_msgs.msg import Resource # type: ignore from unilabos_msgs.srv import ( ResourceAdd, @@ -23,6 +24,7 @@ from unilabos_msgs.srv import ( from unilabos_msgs.srv._serial_command import SerialCommand_Request, SerialCommand_Response from unique_identifier_msgs.msg import UUID +from unilabos.registry.decorators import device from unilabos.registry.placeholder_type import ResourceSlot, 
DeviceSlot from unilabos.registry.registry import lab_registry from unilabos.resources.container import RegularContainer @@ -30,6 +32,7 @@ from unilabos.resources.graphio import initialize_resource from unilabos.resources.registry import add_schema from unilabos.resources.resource_tracker import ( ResourceDict, + ResourceDictType, ResourceDictInstance, ResourceTreeSet, ResourceTreeInstance, @@ -86,6 +89,7 @@ class TestLatencyReturn(TypedDict): status: str +@device(id="host_node", category=[], description="Host Node", icon="icon_device.webp") class HostNode(BaseROS2DeviceNode): """ 主机节点类,负责管理设备、资源和控制器 @@ -274,44 +278,42 @@ class HostNode(BaseROS2DeviceNode): self._action_clients: Dict[str, ActionClient] = { # 为了方便了解实际的数据类型,host的默认写好 "/devices/host_node/create_resource": ActionClient( self, - lab_registry.ResourceCreateFromOuterEasy, + ResourceCreateFromOuterEasy, "/devices/host_node/create_resource", callback_group=self.callback_group, ), "/devices/host_node/create_resource_detailed": ActionClient( self, - lab_registry.ResourceCreateFromOuter, + ResourceCreateFromOuter, "/devices/host_node/create_resource_detailed", callback_group=self.callback_group, ), "/devices/host_node/test_latency": ActionClient( self, - lab_registry.EmptyIn, + EmptyIn, "/devices/host_node/test_latency", callback_group=self.callback_group, ), "/devices/host_node/test_resource": ActionClient( self, - lab_registry.EmptyIn, + EmptyIn, "/devices/host_node/test_resource", callback_group=self.callback_group, ), "/devices/host_node/_execute_driver_command": ActionClient( self, - lab_registry.StrSingleInput, + StrSingleInput, "/devices/host_node/_execute_driver_command", callback_group=self.callback_group, ), "/devices/host_node/_execute_driver_command_async": ActionClient( self, - lab_registry.StrSingleInput, + StrSingleInput, "/devices/host_node/_execute_driver_command_async", callback_group=self.callback_group, ), } # 用来存储多个ActionClient实例 - self._action_value_mappings: Dict[str, Dict] = ( - {} - ) 
# device_id -> action_value_mappings(本地+远程设备统一存储) + self._action_value_mappings: Dict[str, Dict] = {} # device_id -> action_value_mappings(本地+远程设备统一存储) self._slave_registry_configs: Dict[str, Dict] = {} # registry_name -> registry_config(含action_value_mappings) self._goals: Dict[str, Any] = {} # 用来存储多个目标的状态 self._online_devices: Set[str] = {f"{self.namespace}/{device_id}"} # 用于跟踪在线设备 @@ -329,10 +331,18 @@ class HostNode(BaseROS2DeviceNode): self._discover_devices() # 初始化所有本机设备节点,多一次过滤,防止重复初始化 + local_machine = BasicConfig.machine_name for device_config in devices_config.root_nodes: device_id = device_config.res_content.id if device_config.res_content.type != "device": continue + dev_machine = device_config.res_content.machine_name + if dev_machine and local_machine and dev_machine != local_machine: + self.lab_logger().info( + f"[Host Node] Device {device_id} belongs to machine '{dev_machine}', " + f"local is '{local_machine}', skipping initialization." + ) + continue if device_id not in self.devices_names: self.initialize_device(device_id, device_config) else: @@ -658,7 +668,12 @@ class HostNode(BaseROS2DeviceNode): action_id = f"/devices/{device_id}/{action_name}" if action_id not in self._action_clients: action_type = action_value_mapping["type"] - self._action_clients[action_id] = ActionClient(self, action_type, action_id) + try: + self._action_clients[action_id] = ActionClient(self, action_type, action_id) + except Exception as e: + self.lab_logger().error( + f"创建ActionClient失败,Device: {device_id}, Action Name: {action_name}, Action Type: {action_type}, Error: {e}") + continue self.lab_logger().trace( f"[Host Node] Created ActionClient (Local): {action_id}" ) # 子设备再创建用的是Discover发现的 @@ -1258,9 +1273,9 @@ class HostNode(BaseROS2DeviceNode): # 用 registry_name 索引已存储的 registry_config,获取 action_value_mappings if registry_name and registry_name in self._slave_registry_configs: - action_mappings = self._slave_registry_configs[registry_name].get( - "class", {} - 
).get("action_value_mappings", {}) + action_mappings = ( + self._slave_registry_configs[registry_name].get("class", {}).get("action_value_mappings", {}) + ) if action_mappings: self._action_value_mappings[edge_device_id] = action_mappings self.lab_logger().info( @@ -1280,14 +1295,19 @@ class HostNode(BaseROS2DeviceNode): # 解析 devices_config,建立 device_id -> action_value_mappings 映射 if devices_config: + machine_name = info["machine_name"] + # Stamp machine_name on each device dict before parsing for device_tree in devices_config: for device_dict in device_tree: + device_dict["machine_name"] = machine_name device_id = device_dict.get("id", "") class_name = device_dict.get("class", "") if device_id and class_name and class_name in self._slave_registry_configs: - action_mappings = self._slave_registry_configs[class_name].get( - "class", {} - ).get("action_value_mappings", {}) + action_mappings = ( + self._slave_registry_configs[class_name] + .get("class", {}) + .get("action_value_mappings", {}) + ) if action_mappings: self._action_value_mappings[device_id] = action_mappings self.lab_logger().info( @@ -1295,6 +1315,18 @@ class HostNode(BaseROS2DeviceNode): f"for remote device {device_id} (class: {class_name})" ) + # Merge slave devices_config into self.devices_config tree + try: + slave_tree_set = ResourceTreeSet.load(devices_config) # slave一定是根节点的tree + for tree in slave_tree_set.trees: + self.devices_config.trees.append(tree) + self.lab_logger().info( + f"[Host Node] Merged {len(slave_tree_set.trees)} slave device trees " + f"(machine: {machine_name}) into devices_config" + ) + except Exception as e: + self.lab_logger().error(f"[Host Node] Failed to merge slave devices_config: {e}") + self.lab_logger().debug(f"[Host Node] Node info update: {info}") response.response = "OK" except Exception as e: @@ -1703,3 +1735,177 @@ class HostNode(BaseROS2DeviceNode): self.lab_logger().error(f"[Host Node-Resource] Error notifying resource tree update: {str(e)}") 
self.lab_logger().error(traceback.format_exc()) return False + + # ------------------------------------------------------------------ + # Device lifecycle (add / remove) — pure forwarder + # ------------------------------------------------------------------ + + def notify_device_manage(self, target_node_id: str, action: str, config: ResourceDictType) -> bool: + """Forward an add/remove device command to the target node via ROS2 SerialCommand. + + The HostNode does NOT interpret the command; it simply resolves the + target namespace and forwards the request to ``s2c_device_manage``. + + If *target_node_id* equals the HostNode's own device_id (i.e. the + command targets the host itself), we call our local ``create_device`` + / ``destroy_device`` directly instead of going through ROS2. + """ + try: + # If the target is the host itself, handle locally + device_id = config["id"] + if target_node_id == self.device_id: + if action == "add": + return self.create_device(device_id, config).get("success", False) + elif action == "remove": + return self.destroy_device(device_id).get("success", False) + + if target_node_id not in self.devices_names: + self.lab_logger().error( + f"[Host Node-DeviceMgr] Target {target_node_id} not found in devices_names" + ) + return False + + namespace = self.devices_names[target_node_id] + device_key = f"{namespace}/{target_node_id}" + if device_key not in self._online_devices: + self.lab_logger().error(f"[Host Node-DeviceMgr] Target {device_key} is offline") + return False + + srv_address = f"/srv{namespace}/s2c_device_manage" + self.lab_logger().info( + f"[Host Node-DeviceMgr] Forwarding {action}_device to {target_node_id} ({srv_address})" + ) + + sclient = self.create_client(SerialCommand, srv_address) + if not sclient.wait_for_service(timeout_sec=5.0): + self.lab_logger().error(f"[Host Node-DeviceMgr] Service {srv_address} not available") + return False + + request = SerialCommand.Request() + request.command = json.dumps({"action": action, 
"data": config}, ensure_ascii=False) + + future = sclient.call_async(request) + timeout = 30.0 + start_time = time.time() + while not future.done(): + if time.time() - start_time > timeout: + self.lab_logger().error( + f"[Host Node-DeviceMgr] Timeout waiting for {action}_device on {target_node_id}" + ) + return False + time.sleep(0.05) + + response = future.result() + self.lab_logger().info( + f"[Host Node-DeviceMgr] {action}_device on {target_node_id} completed" + ) + return True + + except Exception as e: + self.lab_logger().error(f"[Host Node-DeviceMgr] Error: {e}") + self.lab_logger().error(traceback.format_exc()) + return False + + def create_device(self, device_id: str, config: ResourceDictType) -> dict: + """Dynamically create a root-level device on the host.""" + if not device_id: + return {"success": False, "error": "device_id required"} + + if device_id in self.devices_names: + return {"success": False, "error": f"Device {device_id} already exists"} + + try: + config.setdefault("id", device_id) + config.setdefault("type", "device") + config.setdefault("machine_name", BasicConfig.machine_name or "本地") + res_dict = ResourceDictInstance.get_resource_instance_from_dict(config) + + self.initialize_device(device_id, res_dict) + + if device_id not in self.devices_names: + return {"success": False, "error": f"initialize_device failed for {device_id}"} + + # Add to config tree (devices_config) + tree = ResourceTreeInstance(res_dict) + self.devices_config.trees.append(tree) + + # Add to resource tracker so s2c_resource_tree can find it + try: + for plr_resource in ResourceTreeSet([tree]).to_plr_resources(): + self._resource_tracker.add_resource(plr_resource) + except Exception as ex: + self.lab_logger().warning(f"[Host Node-DeviceMgr] PLR resource registration skipped: {ex}") + + self.lab_logger().info(f"[Host Node-DeviceMgr] Device {device_id} created successfully") + return {"success": True, "device_id": device_id} + + except Exception as e: + 
self.lab_logger().error(f"[Host Node-DeviceMgr] Failed to create {device_id}: {e}") + self.lab_logger().error(traceback.format_exc()) + return {"success": False, "error": str(e)} + + def destroy_device(self, device_id: str) -> dict: + """Remove a root-level device from the host.""" + if not device_id: + return {"success": False, "error": "device_id required"} + + if device_id not in self.devices_names: + return {"success": False, "error": f"Device {device_id} not found"} + + if device_id == self.device_id: + return {"success": False, "error": "Cannot destroy host_node itself"} + + try: + namespace = self.devices_names[device_id] + device_key = f"{namespace}/{device_id}" + + # Remove action clients + action_prefix = f"/devices/{device_id}/" + to_remove = [k for k in self._action_clients if k.startswith(action_prefix)] + for k in to_remove: + try: + self._action_clients[k].destroy() + except Exception: + pass + del self._action_clients[k] + + # Remove from config tree (devices_config) + self.devices_config.trees = [ + t for t in self.devices_config.trees + if t.root_node.res_content.id != device_id + ] + + # Remove from resource tracker + try: + tracked = self._resource_tracker.uuid_to_resources.copy() + for uid, res in tracked.items(): + res_id = res.get("id") if isinstance(res, dict) else getattr(res, "name", None) + if res_id == device_id: + self._resource_tracker.remove_resource(res) + except Exception as ex: + self.lab_logger().warning(f"[Host Node-DeviceMgr] Resource tracker cleanup: {ex}") + + # Clean internal state + self._online_devices.discard(device_key) + self.devices_names.pop(device_id, None) + self.device_machine_names.pop(device_id, None) + self._action_value_mappings.pop(device_id, None) + + # Destroy the ROS2 node of the device + instance = self.devices_instances.pop(device_id, None) + if instance is not None: + try: + # noinspection PyProtectedMember + ros_node = getattr(instance, "_ros_node", None) + if ros_node is not None: + 
ros_node.destroy_node() + except Exception as e: + self.lab_logger().warning(f"[Host Node-DeviceMgr] Error destroying ROS node for {device_id}: {e}") + + self.lab_logger().info(f"[Host Node-DeviceMgr] Device {device_id} destroyed") + return {"success": True, "device_id": device_id} + + except Exception as e: + self.lab_logger().error(f"[Host Node-DeviceMgr] Failed to destroy {device_id}: {e}") + self.lab_logger().error(traceback.format_exc()) + return {"success": False, "error": str(e)} diff --git a/unilabos/ros/nodes/presets/workstation.py b/unilabos/ros/nodes/presets/workstation.py index 902e2967..7f9f2aed 100644 --- a/unilabos/ros/nodes/presets/workstation.py +++ b/unilabos/ros/nodes/presets/workstation.py @@ -20,7 +20,7 @@ from unilabos.ros.msgs.message_converter import ( convert_from_ros_msg_with_mapping, ) from unilabos.ros.nodes.base_device_node import BaseROS2DeviceNode, DeviceNodeResourceTracker, ROS2DeviceNode -from unilabos.resources.resource_tracker import ResourceTreeSet, ResourceDictInstance +from unilabos.resources.resource_tracker import ResourceDictType, ResourceTreeSet, ResourceDictInstance from unilabos.utils.type_check import get_result_info_str if TYPE_CHECKING: @@ -177,6 +177,103 @@ class ROS2WorkstationNode(BaseROS2DeviceNode): self.lab_logger().trace(f"为子设备 {device_id} 创建动作客户端: {action_name}") return d + def create_device(self, device_id: str, config: ResourceDictType) -> dict: + """Dynamically add a sub-device to this workstation.""" + if not device_id: + return {"success": False, "error": "device_id required"} + + if device_id in self.sub_devices: + return {"success": False, "error": f"Sub-device {device_id} already exists"} + + try: + from unilabos.config.config import BasicConfig + config.setdefault("id", device_id) + config.setdefault("type", "device") + config.setdefault("machine_name", BasicConfig.machine_name or "本地") + res_dict = ResourceDictInstance.get_resource_instance_from_dict(config) + + d = self.initialize_device(device_id, 
res_dict) + if d is None: + return {"success": False, "error": f"initialize_device returned None for {device_id}"} + + # Add to children config list + self.children.append(res_dict) + + # Add to resource tracker + try: + from unilabos.resources.resource_tracker import ResourceTreeInstance + tree = ResourceTreeInstance(res_dict) + for plr_resource in ResourceTreeSet([tree]).to_plr_resources(): + self.resource_tracker.add_resource(plr_resource) + except Exception as ex: + self.lab_logger().warning(f"[Workstation-DeviceMgr] PLR resource registration skipped: {ex}") + + self.lab_logger().info(f"[Workstation-DeviceMgr] Sub-device {device_id} created") + return {"success": True, "device_id": device_id} + + except Exception as e: + self.lab_logger().error(f"[Workstation-DeviceMgr] Failed to create {device_id}: {e}") + self.lab_logger().error(traceback.format_exc()) + return {"success": False, "error": str(e)} + + def destroy_device(self, device_id: str) -> dict: + """Dynamically remove a sub-device from this workstation.""" + if not device_id: + return {"success": False, "error": "device_id required"} + + if device_id not in self.sub_devices: + return {"success": False, "error": f"Sub-device {device_id} not found"} + + try: + # Remove from children config list + self.children = [ + c for c in self.children + if c.res_content.id != device_id + ] + + # Remove from resource tracker + try: + tracked = self.resource_tracker.uuid_to_resources.copy() + for uid, res in tracked.items(): + res_id = res.get("id") if isinstance(res, dict) else getattr(res, "name", None) + if res_id == device_id: + self.resource_tracker.remove_resource(res) + except Exception as ex: + self.lab_logger().warning(f"[Workstation-DeviceMgr] Resource tracker cleanup: {ex}") + + # Remove action clients for this sub-device + action_prefix = f"/devices/{device_id}/" + to_remove = [k for k in self._action_clients if k.startswith(action_prefix)] + for k in to_remove: + try: + self._action_clients[k].destroy() + 
except Exception: + pass + del self._action_clients[k] + + # Destroy the ROS2 node + instance = self.sub_devices.pop(device_id, None) + if instance is not None: + ros_node = getattr(instance, "ros_node_instance", None) + if ros_node is not None: + try: + ros_node.destroy_node() + except Exception as e: + self.lab_logger().warning( + f"[Workstation-DeviceMgr] Error destroying ROS node for {device_id}: {e}" + ) + + # Remove from communication map if present + self.communication_node_id_to_instance.pop(device_id, None) + + self.lab_logger().info(f"[Workstation-DeviceMgr] Sub-device {device_id} destroyed") + return {"success": True, "device_id": device_id} + + except Exception as e: + self.lab_logger().error(f"[Workstation-DeviceMgr] Failed to destroy {device_id}: {e}") + self.lab_logger().error(traceback.format_exc()) + return {"success": False, "error": str(e)} + def create_ros_action_server(self, action_name, action_value_mapping): """创建ROS动作服务器""" if action_name not in self.protocol_names: diff --git a/unilabos/utils/decorator.py b/unilabos/utils/decorator.py index 22a90736..15793b14 100644 --- a/unilabos/utils/decorator.py +++ b/unilabos/utils/decorator.py @@ -19,74 +19,6 @@ def singleton(cls): return get_instance -def topic_config( - period: Optional[float] = None, - print_publish: Optional[bool] = None, - qos: Optional[int] = None, -) -> Callable[[F], F]: - """ - Topic发布配置装饰器 - - 用于装饰 get_{attr_name} 方法或 @property,控制对应属性的ROS topic发布行为。 - - Args: - period: 发布周期(秒)。None 表示使用默认值 5.0 - print_publish: 是否打印发布日志。None 表示使用节点默认配置 - qos: QoS深度配置。None 表示使用默认值 10 - - Example: - class MyDriver: - # 方式1: 装饰 get_{attr_name} 方法 - @topic_config(period=1.0, print_publish=False, qos=5) - def get_temperature(self): - return self._temperature - - # 方式2: 与 @property 连用(topic_config 放在下面) - @property - @topic_config(period=0.1) - def position(self): - return self._position - - Note: - 与 @property 连用时,@topic_config 必须放在 @property 下面, - 这样装饰器执行顺序为:先 topic_config 添加配置,再 property 包装。 - 
""" - - def decorator(func: F) -> F: - @wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - # 在函数上附加配置属性 (type: ignore 用于动态属性) - wrapper._topic_period = period # type: ignore[attr-defined] - wrapper._topic_print_publish = print_publish # type: ignore[attr-defined] - wrapper._topic_qos = qos # type: ignore[attr-defined] - wrapper._has_topic_config = True # type: ignore[attr-defined] - - return wrapper # type: ignore[return-value] - - return decorator - - -def get_topic_config(func) -> dict: - """ - 获取函数上的topic配置 - - Args: - func: 被装饰的函数 - - Returns: - 包含 period, print_publish, qos 的配置字典 - """ - if hasattr(func, "_has_topic_config") and getattr(func, "_has_topic_config", False): - return { - "period": getattr(func, "_topic_period", None), - "print_publish": getattr(func, "_topic_print_publish", None), - "qos": getattr(func, "_topic_qos", None), - } - return {} - - def subscribe( topic: str, msg_type: Optional[type] = None, @@ -104,24 +36,6 @@ def subscribe( - {namespace}: 完整命名空间 (如 "/devices/pump_1") msg_type: ROS 消息类型。如果为 None,需要在回调函数的类型注解中指定 qos: QoS 深度配置,默认为 10 - - Example: - from std_msgs.msg import String, Float64 - - class MyDriver: - @subscribe(topic="/devices/{device_id}/set_speed", msg_type=Float64) - def on_speed_update(self, msg: Float64): - self._speed = msg.data - print(f"Speed updated to: {self._speed}") - - @subscribe(topic="{namespace}/command") - def on_command(self, msg: String): - # msg_type 可从类型注解推断 - self.execute_command(msg.data) - - Note: - - 回调方法的第一个参数是 self,第二个参数是收到的 ROS 消息 - - topic 中的占位符会在创建订阅时被实际值替换 """ def decorator(func: F) -> F: @@ -129,7 +43,6 @@ def subscribe( def wrapper(*args, **kwargs): return func(*args, **kwargs) - # 在函数上附加订阅配置 wrapper._subscribe_topic = topic # type: ignore[attr-defined] wrapper._subscribe_msg_type = msg_type # type: ignore[attr-defined] wrapper._subscribe_qos = qos # type: ignore[attr-defined] @@ -141,15 +54,7 @@ def subscribe( def get_subscribe_config(func) -> dict: - """ - 
获取函数上的订阅配置 - - Args: - func: 被装饰的函数 - - Returns: - 包含 topic, msg_type, qos 的配置字典 - """ + """获取函数上的订阅配置 (topic, msg_type, qos)""" if hasattr(func, "_has_subscribe") and getattr(func, "_has_subscribe", False): return { "topic": getattr(func, "_subscribe_topic", None), @@ -163,9 +68,6 @@ def get_all_subscriptions(instance) -> list: """ 扫描实例的所有方法,获取带有 @subscribe 装饰器的方法及其配置 - Args: - instance: 要扫描的实例 - Returns: 包含 (method_name, method, config) 元组的列表 """ @@ -184,92 +86,14 @@ def get_all_subscriptions(instance) -> list: return subscriptions -def always_free(func: F) -> F: - """ - 标记动作为永久闲置(不受busy队列限制)的装饰器 - - 被此装饰器标记的 action 方法,在执行时不会受到设备级别的排队限制, - 任何时候请求都可以立即执行。适用于查询类、状态读取类等轻量级操作。 - - Example: - class MyDriver: - @always_free - def query_status(self, param: str): - # 这个动作可以随时执行,不需要排队 - return self._status - - def transfer(self, volume: float): - # 这个动作会按正常排队逻辑执行 - pass - - Note: - - 可以与其他装饰器组合使用,@always_free 应放在最外层 - - 仅影响 WebSocket 调度层的 busy/free 判断,不影响 ROS2 层 - """ - - @wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - wrapper._is_always_free = True # type: ignore[attr-defined] - - return wrapper # type: ignore[return-value] - - -def is_always_free(func) -> bool: - """ - 检查函数是否被标记为永久闲置 - - Args: - func: 被检查的函数 - - Returns: - 如果函数被 @always_free 装饰则返回 True,否则返回 False - """ - return getattr(func, "_is_always_free", False) - - -def not_action(func: F) -> F: - """ - 标记方法为非动作的装饰器 - - 用于装饰 driver 类中的方法,使其在 complete_registry 时不被识别为动作。 - 适用于辅助方法、内部工具方法等不应暴露为设备动作的公共方法。 - - Example: - class MyDriver: - @not_action - def helper_method(self): - # 这个方法不会被注册为动作 - pass - - def actual_action(self, param: str): - # 这个方法会被注册为动作 - self.helper_method() - - Note: - - 可以与其他装饰器组合使用,@not_action 应放在最外层 - - 仅影响 complete_registry 的动作识别,不影响方法的正常调用 - """ - - @wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - # 在函数上附加标记 - wrapper._is_not_action = True # type: ignore[attr-defined] - - return wrapper # type: ignore[return-value] - - -def 
is_not_action(func) -> bool: - """ - 检查函数是否被标记为非动作 - - Args: - func: 被检查的函数 - - Returns: - 如果函数被 @not_action 装饰则返回 True,否则返回 False - """ - return getattr(func, "_is_not_action", False) +# --------------------------------------------------------------------------- +# 向后兼容重导出 -- 已迁移到 unilabos.registry.decorators +# --------------------------------------------------------------------------- +from unilabos.registry.decorators import ( # noqa: E402, F401 + topic_config, + get_topic_config, + always_free, + is_always_free, + not_action, + is_not_action, +) diff --git a/unilabos/utils/environment_check.py b/unilabos/utils/environment_check.py index 73c0b10b..fa43d977 100644 --- a/unilabos/utils/environment_check.py +++ b/unilabos/utils/environment_check.py @@ -22,6 +22,7 @@ class EnvironmentChecker: # "pymodbus.framer.FramerType": "pymodbus==3.9.2", "websockets": "websockets", "msgcenterpy": "msgcenterpy", + "orjson": "orjson", "opentrons_shared_data": "opentrons_shared_data", "typing_extensions": "typing_extensions", "crcmod": "crcmod-plus", @@ -32,7 +33,7 @@ class EnvironmentChecker: # 包版本要求(包名: 最低版本) self.version_requirements = { - "msgcenterpy": "0.1.5", # msgcenterpy 最低版本要求 + "msgcenterpy": "0.1.7", # msgcenterpy 最低版本要求 } self.missing_packages = [] diff --git a/unilabos/utils/import_manager.py b/unilabos/utils/import_manager.py index dabbe1a7..a14702f0 100644 --- a/unilabos/utils/import_manager.py +++ b/unilabos/utils/import_manager.py @@ -29,7 +29,7 @@ from ast import Constant from unilabos.resources.resource_tracker import PARAM_SAMPLE_UUIDS from unilabos.utils import logger -from unilabos.utils.decorator import is_not_action, is_always_free +from unilabos.registry.decorators import is_not_action, is_always_free class ImportManager: @@ -481,10 +481,16 @@ class ImportManager: return False def _is_always_free_method(self, node: ast.FunctionDef) -> bool: - """检查是否是@always_free装饰的方法""" + """检查是否是@always_free装饰的方法,或 @action(always_free=True) 装饰的方法""" for decorator in 
node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == "always_free": - return True + # 检查 @action(always_free=True) + if isinstance(decorator, ast.Call): + func = decorator.func + if isinstance(func, ast.Name) and func.id == "action": + for keyword in decorator.keywords: + if keyword.arg == "always_free": + if isinstance(keyword.value, Constant) and keyword.value.value is True: + return True return False def _get_property_name_from_setter(self, node: ast.FunctionDef) -> str: diff --git a/unilabos/utils/log.py b/unilabos/utils/log.py index be5d8c31..da085f14 100644 --- a/unilabos/utils/log.py +++ b/unilabos/utils/log.py @@ -217,7 +217,6 @@ def configure_logger(loglevel=None, working_dir=None): return log_filepath - # 配置日志系统 configure_logger() diff --git a/unilabos/utils/requirements.txt b/unilabos/utils/requirements.txt index 65d724fc..2d849b86 100644 --- a/unilabos/utils/requirements.txt +++ b/unilabos/utils/requirements.txt @@ -1,7 +1,8 @@ networkx typing_extensions websockets -msgcenterpy>=0.1.5 +msgcenterpy>=0.1.7 +orjson>=3.11 opentrons_shared_data pint fastapi