mirror of
https://github.com/deepmodeling/Uni-Lab-OS
synced 2026-03-24 09:17:39 +00:00
1038 lines
38 KiB
Python
1038 lines
38 KiB
Python
"""
|
||
AST-based Registry Scanner
|
||
|
||
Statically parse Python files to extract @device, @action, @topic_config, @resource
|
||
decorator metadata without importing any modules. This is ~100x faster than importlib
|
||
since it only reads and parses text files.
|
||
|
||
Includes a file-level cache: each file's MD5 hash, size and mtime are tracked so
|
||
unchanged files skip AST parsing entirely. The cache is persisted as JSON in the
|
||
working directory (``unilabos_data/ast_scan_cache.json``).
|
||
|
||
Usage:
|
||
from unilabos.registry.ast_registry_scanner import scan_directory
|
||
|
||
# Scan all device and resource files under a package directory
|
||
result = scan_directory("unilabos", python_path="/project")
|
||
# => {"devices": {device_id: {...}, ...}, "resources": {resource_id: {...}, ...}}
|
||
"""
|
||
|
||
import ast
|
||
import hashlib
|
||
import json
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Constants
|
||
# ---------------------------------------------------------------------------
|
||
|
||
MAX_SCAN_DEPTH = 10 # 最大目录递归深度
|
||
MAX_SCAN_FILES = 1000 # 最大扫描文件数量
|
||
_CACHE_VERSION = 1 # 缓存格式版本号,格式变更时递增
|
||
|
||
# 合法的装饰器来源模块
|
||
_REGISTRY_DECORATOR_MODULE = "unilabos.registry.decorators"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# File-level cache helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _file_fingerprint(filepath: Path) -> Dict[str, Any]:
|
||
"""Return size, mtime and MD5 hash for *filepath*."""
|
||
stat = filepath.stat()
|
||
md5 = hashlib.md5(filepath.read_bytes()).hexdigest()
|
||
return {"size": stat.st_size, "mtime": stat.st_mtime, "md5": md5}
|
||
|
||
|
||
def load_scan_cache(cache_path: Optional[Path]) -> Dict[str, Any]:
|
||
"""Load the AST scan cache from *cache_path*. Returns empty structure on any error."""
|
||
if cache_path is None or not cache_path.is_file():
|
||
return {"version": _CACHE_VERSION, "files": {}}
|
||
try:
|
||
raw = cache_path.read_text(encoding="utf-8")
|
||
data = json.loads(raw)
|
||
if data.get("version") != _CACHE_VERSION:
|
||
return {"version": _CACHE_VERSION, "files": {}}
|
||
return data
|
||
except Exception:
|
||
return {"version": _CACHE_VERSION, "files": {}}
|
||
|
||
|
||
def save_scan_cache(cache_path: Optional[Path], cache: Dict[str, Any]) -> None:
|
||
"""Persist *cache* to *cache_path* (atomic-ish via temp file)."""
|
||
if cache_path is None:
|
||
return
|
||
try:
|
||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||
tmp = cache_path.with_suffix(".tmp")
|
||
tmp.write_text(json.dumps(cache, ensure_ascii=False, indent=1), encoding="utf-8")
|
||
tmp.replace(cache_path)
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
def _is_cache_hit(entry: Dict[str, Any], fp: Dict[str, Any]) -> bool:
|
||
"""Check if a cache entry matches the current file fingerprint."""
|
||
return (
|
||
entry.get("md5") == fp["md5"]
|
||
and entry.get("size") == fp["size"]
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Public API
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _collect_py_files(
|
||
root_dir: Path,
|
||
max_depth: int = MAX_SCAN_DEPTH,
|
||
max_files: int = MAX_SCAN_FILES,
|
||
exclude_files: Optional[set] = None,
|
||
) -> List[Path]:
|
||
"""
|
||
收集 root_dir 下的 .py 文件,限制最大递归深度和文件数量。
|
||
|
||
Args:
|
||
root_dir: 扫描根目录
|
||
max_depth: 最大递归深度 (默认 10 层)
|
||
max_files: 最大文件数量 (默认 1000 个)
|
||
exclude_files: 要排除的文件名集合 (如 {"lab_resources.py"})
|
||
|
||
Returns:
|
||
排序后的 .py 文件路径列表
|
||
"""
|
||
result: List[Path] = []
|
||
_exclude = exclude_files or set()
|
||
|
||
def _walk(dir_path: Path, depth: int):
|
||
if depth > max_depth or len(result) >= max_files:
|
||
return
|
||
try:
|
||
entries = sorted(dir_path.iterdir())
|
||
except (PermissionError, OSError):
|
||
return
|
||
for entry in entries:
|
||
if len(result) >= max_files:
|
||
return
|
||
if entry.is_file() and entry.suffix == ".py" and not entry.name.startswith("__"):
|
||
if entry.name not in _exclude:
|
||
result.append(entry)
|
||
elif entry.is_dir() and not entry.name.startswith(("__", ".")):
|
||
_walk(entry, depth + 1)
|
||
|
||
_walk(root_dir, 0)
|
||
return result
|
||
|
||
|
||
def scan_directory(
|
||
root_dir: Union[str, Path],
|
||
python_path: Union[str, Path] = "",
|
||
max_depth: int = MAX_SCAN_DEPTH,
|
||
max_files: int = MAX_SCAN_FILES,
|
||
executor: ThreadPoolExecutor = None,
|
||
exclude_files: Optional[set] = None,
|
||
cache: Optional[Dict[str, Any]] = None,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
Recursively scan .py files under *root_dir* for @device and @resource
|
||
decorated classes/functions.
|
||
|
||
Uses a thread pool to parse files in parallel for faster I/O.
|
||
When *cache* is provided, files whose fingerprint (MD5 + size) hasn't
|
||
changed since the last scan are served from cache without re-parsing.
|
||
|
||
Returns:
|
||
{"devices": {device_id: meta, ...}, "resources": {resource_id: meta, ...}}
|
||
|
||
Args:
|
||
root_dir: Directory to scan (e.g. "unilabos/devices").
|
||
python_path: The directory that should be on sys.path, i.e. the parent
|
||
of the top-level package. Module paths are derived as
|
||
filepath relative to this directory. If empty, defaults to
|
||
root_dir's parent.
|
||
max_depth: Maximum directory recursion depth (default 10).
|
||
max_files: Maximum number of .py files to scan (default 1000).
|
||
executor: Shared ThreadPoolExecutor (required). The caller manages its
|
||
lifecycle.
|
||
exclude_files: 要排除的文件名集合 (如 {"lab_resources.py"})
|
||
cache: Mutable cache dict (``load_scan_cache()`` result). Hits are read
|
||
from here; misses are written back so the caller can persist later.
|
||
"""
|
||
if executor is None:
|
||
raise ValueError("executor is required and must not be None")
|
||
|
||
root_dir = Path(root_dir).resolve()
|
||
if not python_path:
|
||
python_path = root_dir.parent
|
||
else:
|
||
python_path = Path(python_path).resolve()
|
||
|
||
# --- Collect files (depth/count limited) ---
|
||
py_files = _collect_py_files(root_dir, max_depth=max_depth, max_files=max_files, exclude_files=exclude_files)
|
||
|
||
cache_files: Dict[str, Any] = cache.get("files", {}) if cache else {}
|
||
|
||
# --- Parallel scan (with cache fast-path) ---
|
||
devices: Dict[str, dict] = {}
|
||
resources: Dict[str, dict] = {}
|
||
cache_hits = 0
|
||
cache_misses = 0
|
||
|
||
def _parse_one_cached(py_file: Path) -> Tuple[List[dict], List[dict], bool]:
|
||
"""Returns (devices, resources, was_cache_hit)."""
|
||
key = str(py_file)
|
||
try:
|
||
fp = _file_fingerprint(py_file)
|
||
except OSError:
|
||
return [], [], False
|
||
|
||
cached_entry = cache_files.get(key)
|
||
if cached_entry and _is_cache_hit(cached_entry, fp):
|
||
return cached_entry.get("devices", []), cached_entry.get("resources", []), True
|
||
|
||
try:
|
||
devs, ress = _parse_file(py_file, python_path)
|
||
except (SyntaxError, Exception):
|
||
devs, ress = [], []
|
||
|
||
cache_files[key] = {
|
||
"md5": fp["md5"],
|
||
"size": fp["size"],
|
||
"mtime": fp["mtime"],
|
||
"devices": devs,
|
||
"resources": ress,
|
||
}
|
||
return devs, ress, False
|
||
|
||
def _collect_results(futures_dict: Dict):
|
||
nonlocal cache_hits, cache_misses
|
||
for future in as_completed(futures_dict):
|
||
devs, ress, hit = future.result()
|
||
if hit:
|
||
cache_hits += 1
|
||
else:
|
||
cache_misses += 1
|
||
for dev in devs:
|
||
device_id = dev.get("device_id")
|
||
if device_id:
|
||
if device_id in devices:
|
||
existing = devices[device_id].get("file_path", "?")
|
||
new_file = dev.get("file_path", "?")
|
||
raise ValueError(
|
||
f"@device id 重复: '{device_id}' 同时出现在 {existing} 和 {new_file}"
|
||
)
|
||
devices[device_id] = dev
|
||
for res in ress:
|
||
resource_id = res.get("resource_id")
|
||
if resource_id:
|
||
if resource_id in resources:
|
||
existing = resources[resource_id].get("file_path", "?")
|
||
new_file = res.get("file_path", "?")
|
||
raise ValueError(
|
||
f"@resource id 重复: '{resource_id}' 同时出现在 {existing} 和 {new_file}"
|
||
)
|
||
resources[resource_id] = res
|
||
|
||
futures = {executor.submit(_parse_one_cached, f): f for f in py_files}
|
||
_collect_results(futures)
|
||
|
||
if cache is not None:
|
||
cache["files"] = cache_files
|
||
|
||
return {
|
||
"devices": devices,
|
||
"resources": resources,
|
||
"_cache_stats": {"hits": cache_hits, "misses": cache_misses, "total": len(py_files)},
|
||
}
|
||
|
||
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# File-level parsing
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# 已知继承自 rclpy.node.Node 的基类名 (用于 AST 静态检测)
|
||
_KNOWN_ROS2_BASE_CLASSES = {"Node", "BaseROS2DeviceNode"}
|
||
_KNOWN_ROS2_MODULES = {"rclpy", "rclpy.node"}
|
||
|
||
|
||
def _detect_class_type(cls_node: ast.ClassDef, import_map: Dict[str, str]) -> str:
|
||
"""
|
||
检测类是否继承自 rclpy Node,返回 'ros2' 或 'python'。
|
||
|
||
通过检查类的基类名称和 import_map 中的模块路径来判断:
|
||
1. 基类名在已知 ROS2 基类集合中
|
||
2. 基类在 import_map 中解析到 rclpy 相关模块
|
||
3. 基类在 import_map 中解析到 BaseROS2DeviceNode
|
||
"""
|
||
for base in cls_node.bases:
|
||
base_name = ""
|
||
if isinstance(base, ast.Name):
|
||
base_name = base.id
|
||
elif isinstance(base, ast.Attribute):
|
||
base_name = base.attr
|
||
elif isinstance(base, ast.Subscript) and isinstance(base.value, ast.Name):
|
||
# Generic[T] 形式,如 BaseROS2DeviceNode[SomeType]
|
||
base_name = base.value.id
|
||
|
||
if not base_name:
|
||
continue
|
||
|
||
# 直接匹配已知 ROS2 基类名
|
||
if base_name in _KNOWN_ROS2_BASE_CLASSES:
|
||
return "ros2"
|
||
|
||
# 通过 import_map 检查模块路径
|
||
module_path = import_map.get(base_name, "")
|
||
if any(mod in module_path for mod in _KNOWN_ROS2_MODULES):
|
||
return "ros2"
|
||
if "BaseROS2DeviceNode" in module_path:
|
||
return "ros2"
|
||
|
||
return "python"
|
||
|
||
|
||
def _parse_file(
|
||
filepath: Path,
|
||
python_path: Path,
|
||
) -> Tuple[List[dict], List[dict]]:
|
||
"""
|
||
Parse a single .py file using ast and extract all @device-decorated classes
|
||
and @resource-decorated functions/classes.
|
||
|
||
Returns:
|
||
(devices, resources) -- two lists of metadata dicts.
|
||
"""
|
||
source = filepath.read_text(encoding="utf-8", errors="replace")
|
||
tree = ast.parse(source, filename=str(filepath))
|
||
|
||
# Derive module path from file path
|
||
module_path = _filepath_to_module(filepath, python_path)
|
||
|
||
# Build import map from the file (includes same-file class defs)
|
||
import_map = _collect_imports(tree, module_path)
|
||
|
||
devices: List[dict] = []
|
||
resources: List[dict] = []
|
||
|
||
for node in ast.iter_child_nodes(tree):
|
||
# --- @device on classes ---
|
||
if isinstance(node, ast.ClassDef):
|
||
device_decorator = _find_decorator(node, "device")
|
||
if device_decorator is not None and _is_registry_decorator("device", import_map):
|
||
device_args = _extract_decorator_args(device_decorator, import_map)
|
||
class_body = _extract_class_body(node, import_map)
|
||
|
||
# Support ids + id_meta (multi-device) or id (single device)
|
||
device_ids: List[str] = []
|
||
if device_args.get("ids") is not None:
|
||
device_ids = list(device_args["ids"])
|
||
else:
|
||
did = device_args.get("id") or device_args.get("device_id")
|
||
device_ids = [did] if did else [f"{module_path}:{node.name}"]
|
||
|
||
id_meta = device_args.get("id_meta") or {}
|
||
base_meta = {
|
||
"class_name": node.name,
|
||
"module": f"{module_path}:{node.name}",
|
||
"file_path": str(filepath).replace("\\", "/"),
|
||
"category": device_args.get("category", []),
|
||
"description": device_args.get("description", ""),
|
||
"display_name": device_args.get("display_name", ""),
|
||
"icon": device_args.get("icon", ""),
|
||
"version": device_args.get("version", "1.0.0"),
|
||
"device_type": _detect_class_type(node, import_map),
|
||
"handles": device_args.get("handles", []),
|
||
"model": device_args.get("model"),
|
||
"hardware_interface": device_args.get("hardware_interface"),
|
||
"actions": class_body.get("actions", {}),
|
||
"status_properties": class_body.get("status_properties", {}),
|
||
"init_params": class_body.get("init_params", []),
|
||
"auto_methods": class_body.get("auto_methods", {}),
|
||
"import_map": import_map,
|
||
}
|
||
for did in device_ids:
|
||
meta = dict(base_meta)
|
||
meta["device_id"] = did
|
||
overrides = id_meta.get(did, {})
|
||
for key in ("handles", "description", "icon", "model", "hardware_interface"):
|
||
if key in overrides:
|
||
meta[key] = overrides[key]
|
||
devices.append(meta)
|
||
|
||
# --- @resource on classes ---
|
||
resource_decorator = _find_decorator(node, "resource")
|
||
if resource_decorator is not None and _is_registry_decorator("resource", import_map):
|
||
res_meta = _extract_resource_meta(
|
||
resource_decorator, node.name, module_path, filepath, import_map,
|
||
is_function=False,
|
||
init_node=_find_init_in_class(node),
|
||
)
|
||
resources.append(res_meta)
|
||
|
||
# --- @resource on module-level functions ---
|
||
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||
resource_decorator = _find_method_decorator(node, "resource")
|
||
if resource_decorator is not None and _is_registry_decorator("resource", import_map):
|
||
res_meta = _extract_resource_meta(
|
||
resource_decorator, node.name, module_path, filepath, import_map,
|
||
is_function=True,
|
||
func_node=node,
|
||
)
|
||
resources.append(res_meta)
|
||
|
||
return devices, resources
|
||
|
||
|
||
def _find_init_in_class(cls_node: ast.ClassDef) -> Optional[ast.FunctionDef]:
|
||
"""Find __init__ method in a class."""
|
||
for item in cls_node.body:
|
||
if isinstance(item, ast.FunctionDef) and item.name == "__init__":
|
||
return item
|
||
return None
|
||
|
||
|
||
def _extract_resource_meta(
|
||
decorator_node: Union[ast.Call, ast.Name],
|
||
name: str,
|
||
module_path: str,
|
||
filepath: Path,
|
||
import_map: Dict[str, str],
|
||
is_function: bool = False,
|
||
func_node: Optional[Union[ast.FunctionDef, ast.AsyncFunctionDef]] = None,
|
||
init_node: Optional[ast.FunctionDef] = None,
|
||
) -> dict:
|
||
"""
|
||
Extract resource metadata from a @resource decorator on a function or class.
|
||
"""
|
||
res_args = _extract_decorator_args(decorator_node, import_map)
|
||
|
||
resource_id = res_args.get("id") or res_args.get("resource_id")
|
||
if resource_id is None:
|
||
resource_id = f"{module_path}:{name}"
|
||
|
||
# Extract init/function params
|
||
init_params: List[dict] = []
|
||
if is_function and func_node is not None:
|
||
init_params = _extract_method_params(func_node, import_map)
|
||
elif not is_function and init_node is not None:
|
||
init_params = _extract_method_params(init_node, import_map)
|
||
|
||
return {
|
||
"resource_id": resource_id,
|
||
"name": name,
|
||
"module": f"{module_path}:{name}",
|
||
"file_path": str(filepath).replace("\\", "/"),
|
||
"is_function": is_function,
|
||
"category": res_args.get("category", []),
|
||
"description": res_args.get("description", ""),
|
||
"icon": res_args.get("icon", ""),
|
||
"version": res_args.get("version", "1.0.0"),
|
||
"class_type": res_args.get("class_type", "pylabrobot"),
|
||
"handles": res_args.get("handles", []),
|
||
"model": res_args.get("model"),
|
||
"init_params": init_params,
|
||
}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Import map collection
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _collect_imports(tree: ast.Module, module_path: str = "") -> Dict[str, str]:
|
||
"""
|
||
Walk all Import/ImportFrom nodes in the AST tree, build a mapping from
|
||
local name to fully-qualified import path.
|
||
|
||
Also includes top-level class/function definitions from the same file,
|
||
so that same-file TypedDict / Enum / dataclass references can be resolved.
|
||
|
||
Returns:
|
||
{"SendCmd": "unilabos_msgs.action:SendCmd",
|
||
"StrSingleInput": "unilabos_msgs.action:StrSingleInput",
|
||
"InputHandle": "unilabos.registry.decorators:InputHandle",
|
||
"SetLiquidReturn": "unilabos.devices.liquid_handling.liquid_handler_abstract:SetLiquidReturn",
|
||
...}
|
||
"""
|
||
import_map: Dict[str, str] = {}
|
||
|
||
for node in ast.walk(tree):
|
||
if isinstance(node, ast.ImportFrom):
|
||
module = node.module or ""
|
||
for alias in node.names:
|
||
local_name = alias.asname if alias.asname else alias.name
|
||
import_map[local_name] = f"{module}:{alias.name}"
|
||
elif isinstance(node, ast.Import):
|
||
for alias in node.names:
|
||
local_name = alias.asname if alias.asname else alias.name
|
||
import_map[local_name] = alias.name
|
||
|
||
# 同文件顶层 class / function 定义
|
||
if module_path:
|
||
for node in tree.body:
|
||
if isinstance(node, ast.ClassDef):
|
||
import_map.setdefault(node.name, f"{module_path}:{node.name}")
|
||
elif isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef):
|
||
import_map.setdefault(node.name, f"{module_path}:{node.name}")
|
||
elif isinstance(node, ast.Assign):
|
||
# 顶层赋值 (如 MotorAxis = Enum(...))
|
||
for target in node.targets:
|
||
if isinstance(target, ast.Name):
|
||
import_map.setdefault(target.id, f"{module_path}:{target.id}")
|
||
|
||
return import_map
|
||
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Decorator finding & argument extraction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _find_decorator(
|
||
node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef],
|
||
decorator_name: str,
|
||
) -> Optional[ast.Call]:
|
||
"""
|
||
Find a specific decorator call on a class or function definition.
|
||
|
||
Handles both:
|
||
- @device(...) -> ast.Call with func=ast.Name(id="device")
|
||
- @module.device(...) -> ast.Call with func=ast.Attribute(attr="device")
|
||
"""
|
||
for dec in node.decorator_list:
|
||
if isinstance(dec, ast.Call):
|
||
if isinstance(dec.func, ast.Name) and dec.func.id == decorator_name:
|
||
return dec
|
||
if isinstance(dec.func, ast.Attribute) and dec.func.attr == decorator_name:
|
||
return dec
|
||
elif isinstance(dec, ast.Name) and dec.id == decorator_name:
|
||
# @device without parens (unlikely but handle it)
|
||
return None # Can't extract args from bare decorator
|
||
return None
|
||
|
||
|
||
def _find_method_decorator(func_node: ast.FunctionDef, decorator_name: str) -> Optional[Union[ast.Call, ast.Name]]:
|
||
"""Find a decorator on a method."""
|
||
for dec in func_node.decorator_list:
|
||
if isinstance(dec, ast.Call):
|
||
if isinstance(dec.func, ast.Name) and dec.func.id == decorator_name:
|
||
return dec
|
||
if isinstance(dec.func, ast.Attribute) and dec.func.attr == decorator_name:
|
||
return dec
|
||
elif isinstance(dec, ast.Name) and dec.id == decorator_name:
|
||
# @action without parens, or @topic_config without parens
|
||
return dec
|
||
return None
|
||
|
||
|
||
def _has_decorator(func_node: ast.FunctionDef, decorator_name: str) -> bool:
|
||
"""Check if a method has a specific decorator (with or without call)."""
|
||
for dec in func_node.decorator_list:
|
||
if isinstance(dec, ast.Call):
|
||
if isinstance(dec.func, ast.Name) and dec.func.id == decorator_name:
|
||
return True
|
||
if isinstance(dec.func, ast.Attribute) and dec.func.attr == decorator_name:
|
||
return True
|
||
elif isinstance(dec, ast.Name) and dec.id == decorator_name:
|
||
return True
|
||
return False
|
||
|
||
|
||
def _is_registry_decorator(name: str, import_map: Dict[str, str]) -> bool:
|
||
"""Check that *name* was imported from ``unilabos.registry.decorators``."""
|
||
source = import_map.get(name, "")
|
||
return _REGISTRY_DECORATOR_MODULE in source
|
||
|
||
|
||
def _extract_decorator_args(
|
||
node: Union[ast.Call, ast.Name],
|
||
import_map: Dict[str, str],
|
||
) -> dict:
|
||
"""
|
||
Extract keyword arguments from a decorator call AST node.
|
||
|
||
Resolves Name references (e.g. SendCmd, Side.NORTH) via import_map.
|
||
Handles literal values (strings, ints, bools, lists, dicts, None).
|
||
"""
|
||
if isinstance(node, ast.Name):
|
||
return {} # Bare decorator, no args
|
||
if not isinstance(node, ast.Call):
|
||
return {}
|
||
|
||
result: dict = {}
|
||
|
||
for kw in node.keywords:
|
||
if kw.arg is None:
|
||
continue # **kwargs, skip
|
||
result[kw.arg] = _ast_node_to_value(kw.value, import_map)
|
||
|
||
return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# AST node value conversion
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _ast_node_to_value(node: ast.expr, import_map: Dict[str, str]) -> Any:
|
||
"""
|
||
Convert an AST expression node to a Python value.
|
||
|
||
Handles:
|
||
- Literals (str, int, float, bool, None)
|
||
- Lists, Tuples, Dicts, Sets
|
||
- Name references (e.g. SendCmd -> resolved via import_map)
|
||
- Attribute access (e.g. Side.NORTH -> resolved)
|
||
- Function/class calls (e.g. InputHandle(...) -> structured dict)
|
||
- Unary operators (e.g. -1)
|
||
"""
|
||
# --- Constant (str, int, float, bool, None) ---
|
||
if isinstance(node, ast.Constant):
|
||
return node.value
|
||
|
||
# --- Name (e.g. SendCmd, True, False, None) ---
|
||
if isinstance(node, ast.Name):
|
||
return _resolve_name(node.id, import_map)
|
||
|
||
# --- Attribute (e.g. Side.NORTH, DataSource.HANDLE) ---
|
||
if isinstance(node, ast.Attribute):
|
||
return _resolve_attribute(node, import_map)
|
||
|
||
# --- List ---
|
||
if isinstance(node, ast.List):
|
||
return [_ast_node_to_value(elt, import_map) for elt in node.elts]
|
||
|
||
# --- Tuple ---
|
||
if isinstance(node, ast.Tuple):
|
||
return [_ast_node_to_value(elt, import_map) for elt in node.elts]
|
||
|
||
# --- Dict ---
|
||
if isinstance(node, ast.Dict):
|
||
result = {}
|
||
for k, v in zip(node.keys, node.values):
|
||
if k is None:
|
||
continue # **kwargs spread
|
||
key = _ast_node_to_value(k, import_map)
|
||
val = _ast_node_to_value(v, import_map)
|
||
result[key] = val
|
||
return result
|
||
|
||
# --- Set ---
|
||
if isinstance(node, ast.Set):
|
||
return [_ast_node_to_value(elt, import_map) for elt in node.elts]
|
||
|
||
# --- Call (e.g. InputHandle(...), OutputHandle(...)) ---
|
||
if isinstance(node, ast.Call):
|
||
return _ast_call_to_value(node, import_map)
|
||
|
||
# --- UnaryOp (e.g. -1, -0.5) ---
|
||
if isinstance(node, ast.UnaryOp):
|
||
if isinstance(node.op, ast.USub):
|
||
operand = _ast_node_to_value(node.operand, import_map)
|
||
if isinstance(operand, (int, float)):
|
||
return -operand
|
||
elif isinstance(node.op, ast.Not):
|
||
operand = _ast_node_to_value(node.operand, import_map)
|
||
return not operand
|
||
|
||
# --- BinOp (e.g. "a" + "b") ---
|
||
if isinstance(node, ast.BinOp):
|
||
if isinstance(node.op, ast.Add):
|
||
left = _ast_node_to_value(node.left, import_map)
|
||
right = _ast_node_to_value(node.right, import_map)
|
||
if isinstance(left, str) and isinstance(right, str):
|
||
return left + right
|
||
|
||
# --- JoinedStr (f-string) ---
|
||
if isinstance(node, ast.JoinedStr):
|
||
return "<f-string>"
|
||
|
||
# Fallback: return the AST dump as a string marker
|
||
return f"<ast:{type(node).__name__}>"
|
||
|
||
|
||
def _resolve_name(name: str, import_map: Dict[str, str]) -> str:
|
||
"""
|
||
Resolve a bare Name reference via import_map.
|
||
|
||
E.g. "SendCmd" -> "unilabos_msgs.action:SendCmd"
|
||
"True" -> True (handled by ast.Constant in Python 3.8+)
|
||
"""
|
||
if name in import_map:
|
||
return import_map[name]
|
||
# Fallback: return the name as-is
|
||
return name
|
||
|
||
|
||
def _resolve_attribute(node: ast.Attribute, import_map: Dict[str, str]) -> str:
|
||
"""
|
||
Resolve an attribute access like Side.NORTH or DataSource.HANDLE.
|
||
|
||
Returns a string like "NORTH" for enum values, or
|
||
"module.path:Class.attr" for imported references.
|
||
"""
|
||
# Get the full dotted path
|
||
parts = []
|
||
current = node
|
||
while isinstance(current, ast.Attribute):
|
||
parts.append(current.attr)
|
||
current = current.value
|
||
if isinstance(current, ast.Name):
|
||
parts.append(current.id)
|
||
|
||
parts.reverse()
|
||
# parts = ["Side", "NORTH"] or ["DataSource", "HANDLE"]
|
||
|
||
if len(parts) >= 2:
|
||
base = parts[0]
|
||
attr = ".".join(parts[1:])
|
||
|
||
# If the base is an imported name, resolve it
|
||
if base in import_map:
|
||
return f"{import_map[base]}.{attr}"
|
||
|
||
# For known enum-like patterns, return just the value
|
||
# e.g. Side.NORTH -> "NORTH"
|
||
if base in ("Side", "DataSource"):
|
||
return parts[-1]
|
||
|
||
return ".".join(parts)
|
||
|
||
|
||
def _ast_call_to_value(node: ast.Call, import_map: Dict[str, str]) -> dict:
|
||
"""
|
||
Convert a function/class call like InputHandle(key="in", ...) to a structured dict.
|
||
|
||
Returns:
|
||
{"_call": "unilabos.registry.decorators:InputHandle",
|
||
"key": "in", "data_type": "fluid", ...}
|
||
"""
|
||
# Resolve the call target
|
||
if isinstance(node.func, ast.Name):
|
||
call_name = _resolve_name(node.func.id, import_map)
|
||
elif isinstance(node.func, ast.Attribute):
|
||
call_name = _resolve_attribute(node.func, import_map)
|
||
else:
|
||
call_name = "<unknown>"
|
||
|
||
result: dict = {"_call": call_name}
|
||
|
||
# Positional args
|
||
for i, arg in enumerate(node.args):
|
||
result[f"_pos_{i}"] = _ast_node_to_value(arg, import_map)
|
||
|
||
# Keyword args
|
||
for kw in node.keywords:
|
||
if kw.arg is None:
|
||
continue
|
||
result[kw.arg] = _ast_node_to_value(kw.value, import_map)
|
||
|
||
return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Class body extraction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _extract_class_body(
|
||
cls_node: ast.ClassDef,
|
||
import_map: Dict[str, str],
|
||
) -> dict:
|
||
"""
|
||
Walk the class body to extract:
|
||
- @action-decorated methods
|
||
- @property with @topic_config (status properties)
|
||
- get_* methods with @topic_config
|
||
- __init__ parameters
|
||
- Public methods without @action (auto-actions)
|
||
"""
|
||
result: dict = {
|
||
"actions": {}, # method_name -> action_info
|
||
"status_properties": {}, # prop_name -> status_info
|
||
"init_params": [], # [{"name": ..., "type": ..., "default": ...}, ...]
|
||
"auto_methods": {}, # method_name -> method_info (no @action decorator)
|
||
}
|
||
|
||
for item in cls_node.body:
|
||
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||
continue
|
||
|
||
method_name = item.name
|
||
|
||
# --- __init__ ---
|
||
if method_name == "__init__":
|
||
result["init_params"] = _extract_method_params(item, import_map)
|
||
continue
|
||
|
||
# --- Skip private/dunder ---
|
||
if method_name.startswith("_"):
|
||
continue
|
||
|
||
# --- Check for @property or @topic_config → status property ---
|
||
is_property = _has_decorator(item, "property")
|
||
has_topic = (
|
||
_has_decorator(item, "topic_config")
|
||
and _is_registry_decorator("topic_config", import_map)
|
||
)
|
||
|
||
if is_property or has_topic:
|
||
topic_args = {}
|
||
topic_dec = _find_method_decorator(item, "topic_config")
|
||
if topic_dec is not None:
|
||
topic_args = _extract_decorator_args(topic_dec, import_map)
|
||
|
||
return_type = _get_annotation_str(item.returns, import_map)
|
||
# 非 @property 的 @topic_config 方法,用去掉 get_ 前缀的名称
|
||
prop_name = method_name[4:] if method_name.startswith("get_") and not is_property else method_name
|
||
|
||
result["status_properties"][prop_name] = {
|
||
"name": prop_name,
|
||
"return_type": return_type,
|
||
"is_property": is_property,
|
||
"topic_config": topic_args if topic_args else None,
|
||
}
|
||
continue
|
||
|
||
# --- Check for @action ---
|
||
action_dec = _find_method_decorator(item, "action")
|
||
if action_dec is not None and _is_registry_decorator("action", import_map):
|
||
action_args = _extract_decorator_args(action_dec, import_map)
|
||
# 补全 @action 装饰器的默认值(与 decorators.py 中 action() 签名一致)
|
||
action_args.setdefault("action_type", None)
|
||
action_args.setdefault("goal", {})
|
||
action_args.setdefault("feedback", {})
|
||
action_args.setdefault("result", {})
|
||
action_args.setdefault("handles", {})
|
||
action_args.setdefault("goal_default", {})
|
||
action_args.setdefault("placeholder_keys", {})
|
||
action_args.setdefault("always_free", False)
|
||
action_args.setdefault("is_protocol", False)
|
||
action_args.setdefault("description", "")
|
||
action_args.setdefault("auto_prefix", False)
|
||
action_args.setdefault("parent", False)
|
||
method_params = _extract_method_params(item, import_map)
|
||
return_type = _get_annotation_str(item.returns, import_map)
|
||
is_async = isinstance(item, ast.AsyncFunctionDef)
|
||
method_doc = ast.get_docstring(item)
|
||
|
||
result["actions"][method_name] = {
|
||
"action_args": action_args,
|
||
"params": method_params,
|
||
"return_type": return_type,
|
||
"is_async": is_async,
|
||
"docstring": method_doc,
|
||
}
|
||
continue
|
||
|
||
# --- Check for @not_action ---
|
||
if _has_decorator(item, "not_action") and _is_registry_decorator("not_action", import_map):
|
||
continue
|
||
|
||
# --- get_ 前缀且无额外参数(仅 self)→ status property ---
|
||
if method_name.startswith("get_"):
|
||
real_args = [a for a in item.args.args if a.arg != "self"]
|
||
if len(real_args) == 0:
|
||
prop_name = method_name[4:]
|
||
return_type = _get_annotation_str(item.returns, import_map)
|
||
if prop_name not in result["status_properties"]:
|
||
result["status_properties"][prop_name] = {
|
||
"name": prop_name,
|
||
"return_type": return_type,
|
||
"is_property": False,
|
||
"topic_config": None,
|
||
}
|
||
continue
|
||
|
||
# --- Public method without @action => auto-action ---
|
||
if method_name in ("post_init", "__str__", "__repr__"):
|
||
continue
|
||
|
||
method_params = _extract_method_params(item, import_map)
|
||
return_type = _get_annotation_str(item.returns, import_map)
|
||
is_async = isinstance(item, ast.AsyncFunctionDef)
|
||
method_doc = ast.get_docstring(item)
|
||
|
||
auto_entry: dict = {
|
||
"params": method_params,
|
||
"return_type": return_type,
|
||
"is_async": is_async,
|
||
"docstring": method_doc,
|
||
}
|
||
if _has_decorator(item, "always_free") and _is_registry_decorator("always_free", import_map):
|
||
auto_entry["always_free"] = True
|
||
result["auto_methods"][method_name] = auto_entry
|
||
|
||
return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Method parameter extraction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
_PARAM_SKIP_NAMES = frozenset({"sample_uuids"})
|
||
|
||
|
||
def _extract_method_params(
|
||
func_node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
|
||
import_map: Optional[Dict[str, str]] = None,
|
||
) -> List[dict]:
|
||
"""
|
||
Extract parameters from a class method definition.
|
||
|
||
Automatically skips the first positional argument (self / cls) and any
|
||
domain-specific names listed in ``_PARAM_SKIP_NAMES``.
|
||
|
||
Returns:
|
||
[{"name": "position", "type": "str", "default": None, "required": True}, ...]
|
||
"""
|
||
if import_map is None:
|
||
import_map = {}
|
||
params: List[dict] = []
|
||
|
||
args = func_node.args
|
||
|
||
# Skip the first positional arg (self/cls) -- always present for class methods
|
||
# noinspection PyUnresolvedReferences
|
||
positional_args = args.args[1:] if args.args else []
|
||
|
||
# defaults align to the *end* of the args list; offset must account for the skipped arg
|
||
num_args = len(args.args)
|
||
num_defaults = len(args.defaults)
|
||
first_default_idx = num_args - num_defaults
|
||
|
||
for i, arg in enumerate(positional_args, start=1):
|
||
name = arg.arg
|
||
if name in _PARAM_SKIP_NAMES:
|
||
continue
|
||
|
||
param_info: dict = {"name": name}
|
||
|
||
# Type annotation
|
||
if arg.annotation:
|
||
param_info["type"] = _get_annotation_str(arg.annotation, import_map)
|
||
else:
|
||
param_info["type"] = ""
|
||
|
||
# Default value
|
||
default_idx = i - first_default_idx
|
||
if 0 <= default_idx < len(args.defaults):
|
||
default_val = _ast_node_to_value(args.defaults[default_idx], import_map)
|
||
param_info["default"] = default_val
|
||
param_info["required"] = False
|
||
else:
|
||
param_info["default"] = None
|
||
param_info["required"] = True
|
||
|
||
params.append(param_info)
|
||
|
||
# Keyword-only arguments (self/cls never appear here)
|
||
for i, arg in enumerate(args.kwonlyargs):
|
||
name = arg.arg
|
||
if name in _PARAM_SKIP_NAMES:
|
||
continue
|
||
|
||
param_info: dict = {"name": name}
|
||
|
||
if arg.annotation:
|
||
param_info["type"] = _get_annotation_str(arg.annotation, import_map)
|
||
else:
|
||
param_info["type"] = ""
|
||
|
||
if i < len(args.kw_defaults) and args.kw_defaults[i] is not None:
|
||
param_info["default"] = _ast_node_to_value(args.kw_defaults[i], import_map)
|
||
param_info["required"] = False
|
||
else:
|
||
param_info["default"] = None
|
||
param_info["required"] = True
|
||
|
||
params.append(param_info)
|
||
|
||
return params
|
||
|
||
|
||
def _get_annotation_str(node: Optional[ast.expr], import_map: Dict[str, str]) -> str:
|
||
"""Convert a type annotation AST node to a string representation.
|
||
|
||
保持类型字符串为合法 Python 表达式 (可被 ast.parse 解析)。
|
||
不在此处做 import_map 替换 — 由上层在需要时通过 import_map 解析。
|
||
"""
|
||
if node is None:
|
||
return ""
|
||
|
||
if isinstance(node, ast.Constant):
|
||
return str(node.value)
|
||
|
||
if isinstance(node, ast.Name):
|
||
return node.id
|
||
|
||
if isinstance(node, ast.Attribute):
|
||
parts = []
|
||
current = node
|
||
while isinstance(current, ast.Attribute):
|
||
parts.append(current.attr)
|
||
current = current.value
|
||
if isinstance(current, ast.Name):
|
||
parts.append(current.id)
|
||
parts.reverse()
|
||
return ".".join(parts)
|
||
|
||
# Handle subscript types like List[str], Dict[str, int], Optional[str]
|
||
if isinstance(node, ast.Subscript):
|
||
base = _get_annotation_str(node.value, import_map)
|
||
if isinstance(node.slice, ast.Tuple):
|
||
args = ", ".join(_get_annotation_str(elt, import_map) for elt in node.slice.elts)
|
||
else:
|
||
args = _get_annotation_str(node.slice, import_map)
|
||
return f"{base}[{args}]"
|
||
|
||
# Handle Union types (X | Y in Python 3.10+)
|
||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||
left = _get_annotation_str(node.left, import_map)
|
||
right = _get_annotation_str(node.right, import_map)
|
||
return f"Union[{left}, {right}]"
|
||
|
||
return ast.dump(node)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Module path derivation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _filepath_to_module(filepath: Path, python_path: Path) -> str:
|
||
"""
|
||
通过 *python_path*(sys.path 中的根目录)推导 Python 模块路径。
|
||
|
||
做法:取 filepath 相对于 python_path 的路径,将目录分隔符替换为 '.'。
|
||
|
||
E.g. filepath = "/project/unilabos/devices/pump/valve.py"
|
||
python_path = "/project"
|
||
=> "unilabos.devices.pump.valve"
|
||
"""
|
||
try:
|
||
relative = filepath.relative_to(python_path)
|
||
except ValueError:
|
||
return str(filepath)
|
||
|
||
parts = list(relative.parts)
|
||
# 去掉 .py 后缀
|
||
if parts and parts[-1].endswith(".py"):
|
||
parts[-1] = parts[-1][:-3]
|
||
# 去掉 __init__
|
||
if parts and parts[-1] == "__init__":
|
||
parts.pop()
|
||
|
||
return ".".join(parts)
|