From 090d5c5cb55eb66f74388d732fea8d0aa9500114 Mon Sep 17 00:00:00 2001 From: Junhan Chang Date: Tue, 24 Mar 2026 23:02:18 +0800 Subject: [PATCH] =?UTF-8?q?feat(app):=20=E6=A8=A1=E5=9E=8B=E4=B8=8A?= =?UTF-8?q?=E4=BC=A0=E4=B8=8E=E6=B3=A8=E5=86=8C=E5=A2=9E=E5=BC=BA=20?= =?UTF-8?q?=E2=80=94=20normalize=5Fmodel=E3=80=81upload=5Fmodel=5Fpackage?= =?UTF-8?q?=E3=80=81backend=20client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - model_upload.py: normalize_model_package 标准化模型目录 + upload_model_package 上传到后端 - register.py: 设备注册时自动检测并上传本地模型文件 - web/client.py: BackendClient 新增 get_model_upload_urls/publish_model/update_template_model - tests: test_model_upload.py、test_normalize_model.py 单元测试 Co-Authored-By: Claude Opus 4.6 --- tests/app/__init__.py | 172 +++++++++++++++++++++++ tests/app/test_model_upload.py | 221 ++++++++++++++++++++++++++++++ tests/app/test_normalize_model.py | 170 +++++++++++++++++++++++ unilabos/app/model_upload.py | 186 +++++++++++++++++++++++++ unilabos/app/register.py | 51 ++++++- unilabos/app/web/client.py | 57 ++++++++ 6 files changed, 856 insertions(+), 1 deletion(-) create mode 100644 tests/app/__init__.py create mode 100644 tests/app/test_model_upload.py create mode 100644 tests/app/test_normalize_model.py create mode 100644 unilabos/app/model_upload.py diff --git a/tests/app/__init__.py b/tests/app/__init__.py new file mode 100644 index 00000000..ca7ee9bc --- /dev/null +++ b/tests/app/__init__.py @@ -0,0 +1,172 @@ +"""normalize_model_for_upload 单元测试""" + +import unittest +import sys +import os + +# 添加项目根目录到 sys.path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from unilabos.app.register import normalize_model_for_upload + + +class TestNormalizeModelForUpload(unittest.TestCase): + """测试 Registry YAML model 字段标准化""" + + def test_empty_input(self): + """空 dict 直接返回""" + self.assertEqual(normalize_model_for_upload({}), {}) + self.assertIsNone(normalize_model_for_upload(None)) + + def test_format_infer_xacro(self): + """自动从 path 后缀推断 format=xacro""" + model = { + "path": "https://oss.example.com/devices/arm/macro_device.xacro", + "mesh": "arm_slider", + "type": "device", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "xacro") + + def test_format_infer_urdf(self): + """自动推断 format=urdf""" + model = {"path": "https://example.com/robot.urdf", "type": "device"} + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "urdf") + + def test_format_infer_stl(self): + """自动推断 format=stl""" + model = {"path": "https://example.com/part.stl"} + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "stl") + + def test_format_infer_gltf(self): + """自动推断 format=gltf(.gltf 和 .glb)""" + for ext in [".gltf", ".glb"]: + model = {"path": f"https://example.com/model{ext}"} + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "gltf", f"failed for {ext}") + + def test_format_not_overwritten(self): + """已有 format 字段时不覆盖""" + model = { + "path": "https://example.com/model.xacro", + "format": "custom", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "custom") + + def test_format_no_path(self): + """没有 path 时不推断 format""" + model = {"mesh": "arm_slider", "type": "device"} + result = normalize_model_for_upload(model) + self.assertNotIn("format", result) + + def test_children_mesh_string_to_struct(self): + """将 children_mesh 字符串(旧格式)转为结构化对象""" + model = { + "path": "https://example.com/rack.xacro", + "type": "resource", + "children_mesh": "tip/meshes/tip.stl", + "children_mesh_tf": [0.0045, 0.0045, 0, 0, 0, 1.57], + "children_mesh_path": "https://oss.example.com/tip.stl", + } + result = normalize_model_for_upload(model) + + # children_mesh 应变为 dict + cm = result["children_mesh"] + self.assertIsInstance(cm, dict) + self.assertEqual(cm["path"], "https://oss.example.com/tip.stl") # 优先使用 OSS URL + self.assertEqual(cm["format"], "stl") + self.assertTrue(cm["default_visible"]) + self.assertEqual(cm["local_offset"], [0.0045, 0.0045, 0]) + self.assertEqual(cm["local_rotation"], [0, 0, 1.57]) + + # 旧字段应被移除 + self.assertNotIn("children_mesh_tf", result) + self.assertNotIn("children_mesh_path", result) + + def test_children_mesh_no_oss_fallback(self): + """children_mesh 无 OSS URL 时 fallback 到本地路径""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "plate_96/meshes/plate_96.stl", + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertEqual(cm["path"], "plate_96/meshes/plate_96.stl") + self.assertEqual(cm["format"], "stl") + + def test_children_mesh_gltf_format(self): + """children_mesh .glb 文件推断 format=gltf""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "meshes/child.glb", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["children_mesh"]["format"], "gltf") + + def test_children_mesh_partial_tf(self): + """children_mesh_tf 只有 3 个值时只有 offset 无 rotation""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "tip.stl", + "children_mesh_tf": [0.01, 0.02, 0.03], + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertEqual(cm["local_offset"], [0.01, 0.02, 0.03]) + self.assertNotIn("local_rotation", cm) + + def test_children_mesh_no_tf(self): + """children_mesh 无 tf 时不加 offset/rotation""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "tip.stl", + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertNotIn("local_offset", cm) + self.assertNotIn("local_rotation", cm) + + def test_children_mesh_already_dict(self): + """children_mesh 已经是 dict 时不重新映射""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": { + "path": "https://example.com/tip.stl", + "format": "stl", + "default_visible": False, + }, + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertIsInstance(cm, dict) + self.assertFalse(cm["default_visible"]) + + def test_original_not_mutated(self): + """原始 dict 不被修改""" + original = { + "path": "https://example.com/model.xacro", + "mesh": "arm", + } + original_copy = {**original} + normalize_model_for_upload(original) + self.assertEqual(original, original_copy) + + def test_preserves_existing_fields(self): + """所有原始字段都被保留""" + model = { + "path": "https://example.com/model.xacro", + "mesh": "arm_slider", + "type": "device", + "mesh_tf": [0, 0, 0, 0, 0, 0], + "custom_field": "should_survive", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["custom_field"], "should_survive") + self.assertEqual(result["mesh_tf"], [0, 0, 0, 0, 0, 0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/app/test_model_upload.py b/tests/app/test_model_upload.py new file mode 100644 index 00000000..8de50fb7 --- /dev/null +++ b/tests/app/test_model_upload.py @@ -0,0 +1,221 @@ +"""model_upload.py 单元测试(upload_device_model / download_model_from_oss)""" + +import unittest +import tempfile +import os +import sys +from pathlib import Path +from unittest.mock import patch, MagicMock + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from unilabos.app.model_upload import ( + upload_device_model, + download_model_from_oss, + _MODEL_EXTENSIONS, +) + + +class TestUploadDeviceModel(unittest.TestCase): + """测试本地模型文件上传到 OSS""" + + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.mock_client = MagicMock() + + def _create_model_files(self, subdir: str, filenames: list[str]): + """在临时目录中创建设备模型文件""" + model_dir = Path(self.tmp_dir) / "devices" / subdir + model_dir.mkdir(parents=True, exist_ok=True) + for name in filenames: + p = model_dir / name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("dummy content") + return model_dir + + @patch("unilabos.app.model_upload._MESH_BASE_DIR") + def test_upload_success(self, mock_base): + """正常上传流程""" + mock_base.__truediv__ = lambda self, x: Path(self.tmp_dir) / x + # 直接 patch _MESH_BASE_DIR 为 Path(tmp_dir) + with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)): + self._create_model_files("arm_slider", ["macro_device.xacro", "meshes/link1.stl"]) + + self.mock_client.get_model_upload_urls.return_value = { + "files": [ + {"name": "macro_device.xacro", "upload_url": "https://oss.example.com/put1"}, + {"name": "meshes/link1.stl", "upload_url": "https://oss.example.com/put2"}, + ] + } + self.mock_client.publish_model.return_value = { + "path": "https://oss.example.com/arm_slider/macro_device.xacro" + } + + with patch("unilabos.app.model_upload._put_upload") as mock_put: + result = upload_device_model( + http_client=self.mock_client, + template_uuid="test-uuid", + mesh_name="arm_slider", + model_type="device", + version="1.0.0", + ) + + self.assertEqual(result, "https://oss.example.com/arm_slider/macro_device.xacro") + self.mock_client.get_model_upload_urls.assert_called_once() + self.mock_client.publish_model.assert_called_once() + + @patch("unilabos.app.model_upload._MESH_BASE_DIR") + def test_upload_dir_not_exists(self, mock_base): + """本地目录不存在时返回 None""" + with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)): + result = upload_device_model( + http_client=self.mock_client, + template_uuid="test-uuid", + mesh_name="nonexistent", + model_type="device", + ) + self.assertIsNone(result) + + @patch("unilabos.app.model_upload._MESH_BASE_DIR") + def test_upload_no_valid_files(self, mock_base): + """目录中无有效模型文件时返回 None""" + with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)): + model_dir = Path(self.tmp_dir) / "devices" / "empty_model" + model_dir.mkdir(parents=True, exist_ok=True) + (model_dir / "readme.txt").write_text("not a model") + + result = upload_device_model( + http_client=self.mock_client, + template_uuid="test-uuid", + mesh_name="empty_model", + model_type="device", + ) + self.assertIsNone(result) + + @patch("unilabos.app.model_upload._MESH_BASE_DIR") + def test_upload_urls_failure(self, mock_base): + """获取上传 URL 失败时返回 None""" + with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)): + self._create_model_files("arm", ["device.xacro"]) + self.mock_client.get_model_upload_urls.return_value = None + + result = upload_device_model( + http_client=self.mock_client, + template_uuid="test-uuid", + mesh_name="arm", + model_type="device", + ) + self.assertIsNone(result) + + +class TestDownloadModelFromOss(unittest.TestCase): + """测试从 OSS 下载模型文件到本地""" + + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + + def test_skip_no_mesh_name(self): + """缺少 mesh 名称时跳过""" + result = download_model_from_oss({"type": "device", "path": "https://x.com/a.xacro"}) + self.assertFalse(result) + + def test_skip_no_oss_path(self): + """缺少 OSS path 时跳过""" + result = download_model_from_oss({"mesh": "arm", "type": "device"}) + self.assertFalse(result) + + def test_skip_local_path(self): + """非 https:// 路径时跳过""" + result = download_model_from_oss({ + "mesh": "arm", + "type": "device", + "path": "file:///local/model.xacro", + }) + self.assertFalse(result) + + def test_already_exists(self): + """本地已有文件时跳过下载""" + device_dir = Path(self.tmp_dir) / "devices" / "arm" + device_dir.mkdir(parents=True, exist_ok=True) + (device_dir / "model.xacro").write_text("existing") + + result = download_model_from_oss( + {"mesh": "arm", "type": "device", "path": "https://oss.example.com/model.xacro"}, + mesh_base_dir=Path(self.tmp_dir), + ) + self.assertTrue(result) + + @patch("unilabos.app.model_upload._download_file") + def test_download_device(self, mock_download): + """下载 device 模型到 devices/ 目录""" + result = download_model_from_oss( + {"mesh": "new_arm", "type": "device", "path": "https://oss.example.com/new_arm/macro_device.xacro"}, + mesh_base_dir=Path(self.tmp_dir), + ) + self.assertTrue(result) + mock_download.assert_called_once() + call_args = mock_download.call_args + self.assertIn("macro_device.xacro", str(call_args[0][1])) + + @patch("unilabos.app.model_upload._download_file") + def test_download_resource(self, mock_download): + """下载 resource 模型到 resources/ 目录""" + result = download_model_from_oss( + { + "mesh": "plate_96/meshes/plate_96.stl", + "type": "resource", + "path": "https://oss.example.com/plate_96/modal.xacro", + }, + mesh_base_dir=Path(self.tmp_dir), + ) + self.assertTrue(result) + target_dir = Path(self.tmp_dir) / "resources" / "plate_96" + self.assertTrue(target_dir.exists()) + + @patch("unilabos.app.model_upload._download_file") + def test_download_with_children_mesh(self, mock_download): + """下载包含 children_mesh 的模型""" + result = download_model_from_oss( + { + "mesh": "tip_rack", + "type": "device", + "path": "https://oss.example.com/tip_rack/model.xacro", + "children_mesh": { + "path": "https://oss.example.com/tip_rack/meshes/tip.stl", + "format": "stl", + }, + }, + mesh_base_dir=Path(self.tmp_dir), + ) + self.assertTrue(result) + # 应调用两次:入口文件 + children_mesh + self.assertEqual(mock_download.call_count, 2) + + @patch("unilabos.app.model_upload._download_file", side_effect=Exception("network error")) + def test_download_failure_graceful(self, mock_download): + """下载失败时返回 False(不抛异常)""" + result = download_model_from_oss( + {"mesh": "broken", "type": "device", "path": "https://oss.example.com/broken.xacro"}, + mesh_base_dir=Path(self.tmp_dir), + ) + self.assertFalse(result) + + +class TestModelExtensions(unittest.TestCase): + """测试支持的模型文件后缀集合""" + + def test_standard_extensions(self): + """确认标准 3D 格式在支持列表中""" + expected = {".stl", ".gltf", ".glb", ".xacro", ".urdf", ".obj", ".dae"} + for ext in expected: + self.assertIn(ext, _MODEL_EXTENSIONS, f"{ext} should be supported") + + def test_non_model_excluded(self): + """非模型文件后缀不在列表中""" + excluded = {".txt", ".json", ".py", ".png", ".jpg"} + for ext in excluded: + self.assertNotIn(ext, _MODEL_EXTENSIONS, f"{ext} should not be supported") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/app/test_normalize_model.py b/tests/app/test_normalize_model.py new file mode 100644 index 00000000..0d45f3b5 --- /dev/null +++ b/tests/app/test_normalize_model.py @@ -0,0 +1,170 @@ +"""normalize_model_for_upload 单元测试""" + +import unittest +import sys +import os + +# 添加项目根目录到 sys.path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from unilabos.app.register import normalize_model_for_upload + + +class TestNormalizeModelForUpload(unittest.TestCase): + """测试 Registry YAML model 字段标准化""" + + def test_empty_input(self): + """空 dict 直接返回""" + self.assertEqual(normalize_model_for_upload({}), {}) + self.assertIsNone(normalize_model_for_upload(None)) + + def test_format_infer_xacro(self): + """自动从 path 后缀推断 format=xacro""" + model = { + "path": "https://oss.example.com/devices/arm/macro_device.xacro", + "mesh": "arm_slider", + "type": "device", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "xacro") + + def test_format_infer_urdf(self): + """自动推断 format=urdf""" + model = {"path": "https://example.com/robot.urdf", "type": "device"} + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "urdf") + + def test_format_infer_stl(self): + """自动推断 format=stl""" + model = {"path": "https://example.com/part.stl"} + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "stl") + + def test_format_infer_gltf(self): + """自动推断 format=gltf(.gltf 和 .glb)""" + for ext in [".gltf", ".glb"]: + model = {"path": f"https://example.com/model{ext}"} + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "gltf", f"failed for {ext}") + + def test_format_not_overwritten(self): + """已有 format 字段时不覆盖""" + model = { + "path": "https://example.com/model.xacro", + "format": "custom", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["format"], "custom") + + def test_format_no_path(self): + """没有 path 时不推断 format""" + model = {"mesh": "arm_slider", "type": "device"} + result = normalize_model_for_upload(model) + self.assertNotIn("format", result) + + def test_children_mesh_string_to_struct(self): + """将 children_mesh 字符串(旧格式)转为结构化对象""" + model = { + "path": "https://example.com/rack.xacro", + "type": "resource", + "children_mesh": "tip/meshes/tip.stl", + "children_mesh_tf": [0.0045, 0.0045, 0, 0, 0, 1.57], + "children_mesh_path": "https://oss.example.com/tip.stl", + } + result = normalize_model_for_upload(model) + + cm = result["children_mesh"] + self.assertIsInstance(cm, dict) + self.assertEqual(cm["path"], "https://oss.example.com/tip.stl") + self.assertEqual(cm["format"], "stl") + self.assertTrue(cm["default_visible"]) + self.assertEqual(cm["local_offset"], [0.0045, 0.0045, 0]) + self.assertEqual(cm["local_rotation"], [0, 0, 1.57]) + + self.assertNotIn("children_mesh_tf", result) + self.assertNotIn("children_mesh_path", result) + + def test_children_mesh_no_oss_fallback(self): + """children_mesh 无 OSS URL 时 fallback 到本地路径""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "plate_96/meshes/plate_96.stl", + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertEqual(cm["path"], "plate_96/meshes/plate_96.stl") + self.assertEqual(cm["format"], "stl") + + def test_children_mesh_gltf_format(self): + """children_mesh .glb 文件推断 format=gltf""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "meshes/child.glb", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["children_mesh"]["format"], "gltf") + + def test_children_mesh_partial_tf(self): + """children_mesh_tf 只有 3 个值时只有 offset 无 rotation""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "tip.stl", + "children_mesh_tf": [0.01, 0.02, 0.03], + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertEqual(cm["local_offset"], [0.01, 0.02, 0.03]) + self.assertNotIn("local_rotation", cm) + + def test_children_mesh_no_tf(self): + """children_mesh 无 tf 时不加 offset/rotation""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": "tip.stl", + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertNotIn("local_offset", cm) + self.assertNotIn("local_rotation", cm) + + def test_children_mesh_already_dict(self): + """children_mesh 已经是 dict 时不重新映射""" + model = { + "path": "https://example.com/rack.xacro", + "children_mesh": { + "path": "https://example.com/tip.stl", + "format": "stl", + "default_visible": False, + }, + } + result = normalize_model_for_upload(model) + cm = result["children_mesh"] + self.assertIsInstance(cm, dict) + self.assertFalse(cm["default_visible"]) + + def test_original_not_mutated(self): + """原始 dict 不被修改""" + original = { + "path": "https://example.com/model.xacro", + "mesh": "arm", + } + original_copy = {**original} + normalize_model_for_upload(original) + self.assertEqual(original, original_copy) + + def test_preserves_existing_fields(self): + """所有原始字段都被保留""" + model = { + "path": "https://example.com/model.xacro", + "mesh": "arm_slider", + "type": "device", + "mesh_tf": [0, 0, 0, 0, 0, 0], + "custom_field": "should_survive", + } + result = normalize_model_for_upload(model) + self.assertEqual(result["custom_field"], "should_survive") + self.assertEqual(result["mesh_tf"], [0, 0, 0, 0, 0, 0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/unilabos/app/model_upload.py b/unilabos/app/model_upload.py new file mode 100644 index 00000000..164d89d6 --- /dev/null +++ b/unilabos/app/model_upload.py @@ -0,0 +1,186 @@ +"""模型文件上传/下载管理。 + +提供 Edge 端本地模型文件与 OSS 之间的双向同步: +- upload_device_model: 本地模型 → OSS(Edge 首次接入时) +- download_model_from_oss: OSS → 本地(新 Edge 加入已有 Lab 时) +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import TYPE_CHECKING, Optional + +import requests + +from unilabos.utils.log import logger + +if TYPE_CHECKING: + from unilabos.app.web.client import HTTPClient + +# 设备 mesh 根目录 +_MESH_BASE_DIR = Path(__file__).parent.parent / "device_mesh" + +# 支持的模型文件后缀 +_MODEL_EXTENSIONS = frozenset({ + ".xacro", ".urdf", ".stl", ".dae", ".obj", + ".gltf", ".glb", ".yaml", ".yml", +}) + + +def upload_device_model( + http_client: "HTTPClient", + template_uuid: str, + mesh_name: str, + model_type: str, + version: str = "1.0.0", +) -> Optional[str]: + """上传本地模型文件到 OSS,返回入口文件的 OSS URL。 + + Args: + http_client: HTTPClient 实例 + template_uuid: 设备模板 UUID + mesh_name: mesh 目录名(如 "arm_slider") + model_type: "device" 或 "resource" + version: 模型版本 + + Returns: + 入口文件 OSS URL,上传失败返回 None + """ + if model_type == "device": + model_dir = _MESH_BASE_DIR / "devices" / mesh_name + else: + model_dir = _MESH_BASE_DIR / "resources" / mesh_name + + if not model_dir.exists(): + logger.warning(f"[模型上传] 本地目录不存在: {model_dir}") + return None + + # 收集所有需要上传的文件 + files = [] + for f in model_dir.rglob("*"): + if f.is_file() and f.suffix.lower() in _MODEL_EXTENSIONS: + files.append({ + "name": str(f.relative_to(model_dir)), + "size_kb": f.stat().st_size // 1024, + }) + + if not files: + logger.warning(f"[模型上传] 目录中无可上传的模型文件: {model_dir}") + return None + + try: + # 1. 获取预签名上传 URL + upload_urls_resp = http_client.get_model_upload_urls( + template_uuid=template_uuid, + files=[{"name": f["name"], "version": version} for f in files], + ) + if not upload_urls_resp: + return None + + url_items = upload_urls_resp.get("files", []) + + # 2. 逐个上传文件 + for file_info, url_info in zip(files, url_items): + local_path = model_dir / file_info["name"] + upload_url = url_info.get("upload_url", "") + if not upload_url: + continue + _put_upload(local_path, upload_url) + + # 3. 确认发布 + entry_file = "macro_device.xacro" if model_type == "device" else "modal.xacro" + # 检查入口文件是否存在,使用实际存在的文件名 + for f in files: + if f["name"].endswith(".xacro"): + entry_file = f["name"] + break + + publish_resp = http_client.publish_model( + template_uuid=template_uuid, + version=version, + entry_file=entry_file, + ) + return publish_resp.get("path") if publish_resp else None + + except Exception as e: + logger.error(f"[模型上传] 上传失败 ({mesh_name}): {e}") + return None + + +def download_model_from_oss( + model_config: dict, + mesh_base_dir: Optional[Path] = None, +) -> bool: + """检查本地模型文件是否存在,不存在则从 OSS 下载。 + + Args: + model_config: 节点的 model 配置字典 + mesh_base_dir: mesh 根目录,默认使用 device_mesh/ + + Returns: + True 表示本地文件就绪,False 表示下载失败或无需下载 + """ + if mesh_base_dir is None: + mesh_base_dir = _MESH_BASE_DIR + + mesh_name = model_config.get("mesh", "") + model_type = model_config.get("type", "") + oss_path = model_config.get("path", "") + + if not mesh_name or not oss_path or not oss_path.startswith("https://"): + return False + + # 确定本地目标目录 + if model_type == "device": + local_dir = mesh_base_dir / "devices" / mesh_name + elif model_type == "resource": + resource_name = mesh_name.split("/")[0] + local_dir = mesh_base_dir / "resources" / resource_name + else: + return False + + # 已有本地文件 → 跳过 + if local_dir.exists() and any(local_dir.iterdir()): + return True + + # 从 OSS 下载 + local_dir.mkdir(parents=True, exist_ok=True) + try: + # 下载入口文件(OSS URL 通常直接可访问) + entry_name = oss_path.rsplit("/", 1)[-1] + _download_file(oss_path, local_dir / entry_name) + + # 如果有 children_mesh,也下载 + children_mesh = model_config.get("children_mesh") + if isinstance(children_mesh, dict) and children_mesh.get("path"): + cm_path = children_mesh["path"] + if cm_path.startswith("https://"): + cm_name = cm_path.rsplit("/", 1)[-1] + meshes_dir = local_dir / "meshes" + meshes_dir.mkdir(parents=True, exist_ok=True) + _download_file(cm_path, meshes_dir / cm_name) + + logger.info(f"[模型下载] 成功下载模型到本地: {mesh_name} → {local_dir}") + return True + + except Exception as e: + logger.warning(f"[模型下载] 下载失败 ({mesh_name}): {e}") + return False + + +def _put_upload(local_path: Path, upload_url: str) -> None: + """通过预签名 URL 上传文件到 OSS。""" + with open(local_path, "rb") as f: + resp = requests.put(upload_url, data=f, timeout=120) + resp.raise_for_status() + logger.debug(f"[模型上传] 已上传: {local_path.name}") + + +def _download_file(url: str, local_path: Path) -> None: + """下载单个文件到本地路径。""" + local_path.parent.mkdir(parents=True, exist_ok=True) + resp = requests.get(url, timeout=60) + resp.raise_for_status() + local_path.write_bytes(resp.content) + logger.debug(f"[模型下载] 已下载: {local_path}") diff --git a/unilabos/app/register.py b/unilabos/app/register.py index 5940364e..30cdd109 100644 --- a/unilabos/app/register.py +++ b/unilabos/app/register.py @@ -5,6 +5,48 @@ from unilabos.utils.log import logger from unilabos.utils.tools import normalize_json as _normalize_device +def normalize_model_for_upload(model_dict: dict) -> dict: + """将 Registry YAML 的 model 字段映射为后端 DeviceModel 结构化格式。 + + 保留所有原始字段,额外做以下标准化: + 1. 自动推断 format(如果 YAML 未指定) + 2. 将 children_mesh 扁平字段映射为结构化 children_mesh 对象 + """ + if not model_dict: + return model_dict + + result = {**model_dict} + + # 自动推断 format + if "format" not in result and result.get("path"): + path = result["path"] + if path.endswith(".xacro"): + result["format"] = "xacro" + elif path.endswith(".urdf"): + result["format"] = "urdf" + elif path.endswith(".stl"): + result["format"] = "stl" + elif path.endswith((".gltf", ".glb")): + result["format"] = "gltf" + + # 将 children_mesh 扁平字段 → 结构化 children_mesh 对象 + if "children_mesh" in result and isinstance(result["children_mesh"], str): + cm_path = result.pop("children_mesh") + cm_tf = result.pop("children_mesh_tf", None) + cm_oss = result.pop("children_mesh_path", None) + result["children_mesh"] = { + "path": cm_oss or cm_path, + "format": "stl" if cm_path.endswith(".stl") else "gltf", + "default_visible": True, + } + if cm_tf and len(cm_tf) >= 3: + result["children_mesh"]["local_offset"] = cm_tf[:3] + if cm_tf and len(cm_tf) >= 6: + result["children_mesh"]["local_rotation"] = cm_tf[3:6] + + return result + + def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]: """ 注册设备和资源到服务器(仅支持HTTP) @@ -16,11 +58,18 @@ def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[ devices_to_register = {} for device_info in lab_registry.obtain_registry_device_info(): - devices_to_register[device_info["id"]] = _normalize_device(device_info) + normalized = _normalize_device(device_info) + # 标准化 model 字段 + if normalized.get("model"): + normalized["model"] = normalize_model_for_upload(normalized["model"]) + devices_to_register[device_info["id"]] = normalized logger.trace(f"[UniLab Register] 收集设备: {device_info['id']}") resources_to_register = {} for resource_info in lab_registry.obtain_registry_resource_info(): + # 标准化 model 字段 + if resource_info.get("model"): + resource_info["model"] = normalize_model_for_upload(resource_info["model"]) resources_to_register[resource_info["id"]] = resource_info logger.trace(f"[UniLab Register] 收集资源: {resource_info['id']}") diff --git a/unilabos/app/web/client.py b/unilabos/app/web/client.py index b1cc67eb..0123e655 100644 --- a/unilabos/app/web/client.py +++ b/unilabos/app/web/client.py @@ -468,6 +468,63 @@ class HTTPClient: logger.error(f"发布工作流失败: {response.status_code}, {response.text}") return {"code": response.status_code, "message": response.text} + # ──────────────────── 模型资产管理 ──────────────────── + + def get_model_upload_urls( + self, template_uuid: str, files: list[dict], + ) -> dict | None: + """获取模型文件预签名上传 URL。 + + Args: + template_uuid: 设备模板 UUID + files: 文件列表 [{"name": "...", "version": "1.0.0"}] + + Returns: + {"files": [{"name": "...", "upload_url": "...", "path": "..."}]} + """ + try: + response = requests.post( + f"{self.remote_addr}/lab/square/template/{template_uuid}/model/upload-urls", + json={"files": files}, + headers={"Authorization": f"Lab {self.auth}"}, + timeout=30, + ) + if response.status_code == 200: + data = response.json().get("data") + return data + logger.error(f"获取模型上传 URL 失败: {response.status_code}, {response.text}") + except Exception as e: + logger.error(f"获取模型上传 URL 异常: {e}") + return None + + def publish_model( + self, template_uuid: str, version: str, entry_file: str, + ) -> dict | None: + """确认模型上传完成,发布新版本。 + + Args: + template_uuid: 设备模板 UUID + version: 模型版本 + entry_file: 入口文件名 + + Returns: + {"path": "...", "oss_dir": "...", "version": "..."} + """ + try: + response = requests.post( + f"{self.remote_addr}/lab/square/template/{template_uuid}/model/publish", + json={"version": version, "entry_file": entry_file}, + headers={"Authorization": f"Lab {self.auth}"}, + timeout=30, + ) + if response.status_code == 200: + data = response.json().get("data") + return data + logger.error(f"发布模型失败: {response.status_code}, {response.text}") + except Exception as e: + logger.error(f"发布模型异常: {e}") + return None + # 创建默认客户端实例 http_client = HTTPClient()