mirror of
https://github.com/deepmodeling/Uni-Lab-OS
synced 2026-03-26 20:16:50 +00:00
feat(app): 模型上传与注册增强 — normalize_model、upload_model_package、backend client
- model_upload.py: normalize_model_package 标准化模型目录 + upload_model_package 上传到后端 - register.py: 设备注册时自动检测并上传本地模型文件 - web/client.py: BackendClient 新增 get_model_upload_urls/publish_model/update_template_model - tests: test_model_upload.py、test_normalize_model.py 单元测试 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
172
tests/app/__init__.py
Normal file
172
tests/app/__init__.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""normalize_model_for_upload 单元测试"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 添加项目根目录到 sys.path
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
|
from unilabos.app.register import normalize_model_for_upload
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeModelForUpload(unittest.TestCase):
|
||||||
|
"""测试 Registry YAML model 字段标准化"""
|
||||||
|
|
||||||
|
def test_empty_input(self):
|
||||||
|
"""空 dict 直接返回"""
|
||||||
|
self.assertEqual(normalize_model_for_upload({}), {})
|
||||||
|
self.assertIsNone(normalize_model_for_upload(None))
|
||||||
|
|
||||||
|
def test_format_infer_xacro(self):
|
||||||
|
"""自动从 path 后缀推断 format=xacro"""
|
||||||
|
model = {
|
||||||
|
"path": "https://oss.example.com/devices/arm/macro_device.xacro",
|
||||||
|
"mesh": "arm_slider",
|
||||||
|
"type": "device",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "xacro")
|
||||||
|
|
||||||
|
def test_format_infer_urdf(self):
|
||||||
|
"""自动推断 format=urdf"""
|
||||||
|
model = {"path": "https://example.com/robot.urdf", "type": "device"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "urdf")
|
||||||
|
|
||||||
|
def test_format_infer_stl(self):
|
||||||
|
"""自动推断 format=stl"""
|
||||||
|
model = {"path": "https://example.com/part.stl"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "stl")
|
||||||
|
|
||||||
|
def test_format_infer_gltf(self):
|
||||||
|
"""自动推断 format=gltf(.gltf 和 .glb)"""
|
||||||
|
for ext in [".gltf", ".glb"]:
|
||||||
|
model = {"path": f"https://example.com/model{ext}"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "gltf", f"failed for {ext}")
|
||||||
|
|
||||||
|
def test_format_not_overwritten(self):
|
||||||
|
"""已有 format 字段时不覆盖"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/model.xacro",
|
||||||
|
"format": "custom",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "custom")
|
||||||
|
|
||||||
|
def test_format_no_path(self):
|
||||||
|
"""没有 path 时不推断 format"""
|
||||||
|
model = {"mesh": "arm_slider", "type": "device"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertNotIn("format", result)
|
||||||
|
|
||||||
|
def test_children_mesh_string_to_struct(self):
|
||||||
|
"""将 children_mesh 字符串(旧格式)转为结构化对象"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"type": "resource",
|
||||||
|
"children_mesh": "tip/meshes/tip.stl",
|
||||||
|
"children_mesh_tf": [0.0045, 0.0045, 0, 0, 0, 1.57],
|
||||||
|
"children_mesh_path": "https://oss.example.com/tip.stl",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
|
||||||
|
# children_mesh 应变为 dict
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertIsInstance(cm, dict)
|
||||||
|
self.assertEqual(cm["path"], "https://oss.example.com/tip.stl") # 优先使用 OSS URL
|
||||||
|
self.assertEqual(cm["format"], "stl")
|
||||||
|
self.assertTrue(cm["default_visible"])
|
||||||
|
self.assertEqual(cm["local_offset"], [0.0045, 0.0045, 0])
|
||||||
|
self.assertEqual(cm["local_rotation"], [0, 0, 1.57])
|
||||||
|
|
||||||
|
# 旧字段应被移除
|
||||||
|
self.assertNotIn("children_mesh_tf", result)
|
||||||
|
self.assertNotIn("children_mesh_path", result)
|
||||||
|
|
||||||
|
def test_children_mesh_no_oss_fallback(self):
|
||||||
|
"""children_mesh 无 OSS URL 时 fallback 到本地路径"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "plate_96/meshes/plate_96.stl",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertEqual(cm["path"], "plate_96/meshes/plate_96.stl")
|
||||||
|
self.assertEqual(cm["format"], "stl")
|
||||||
|
|
||||||
|
def test_children_mesh_gltf_format(self):
|
||||||
|
"""children_mesh .glb 文件推断 format=gltf"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "meshes/child.glb",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["children_mesh"]["format"], "gltf")
|
||||||
|
|
||||||
|
def test_children_mesh_partial_tf(self):
|
||||||
|
"""children_mesh_tf 只有 3 个值时只有 offset 无 rotation"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "tip.stl",
|
||||||
|
"children_mesh_tf": [0.01, 0.02, 0.03],
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertEqual(cm["local_offset"], [0.01, 0.02, 0.03])
|
||||||
|
self.assertNotIn("local_rotation", cm)
|
||||||
|
|
||||||
|
def test_children_mesh_no_tf(self):
|
||||||
|
"""children_mesh 无 tf 时不加 offset/rotation"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "tip.stl",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertNotIn("local_offset", cm)
|
||||||
|
self.assertNotIn("local_rotation", cm)
|
||||||
|
|
||||||
|
def test_children_mesh_already_dict(self):
|
||||||
|
"""children_mesh 已经是 dict 时不重新映射"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": {
|
||||||
|
"path": "https://example.com/tip.stl",
|
||||||
|
"format": "stl",
|
||||||
|
"default_visible": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertIsInstance(cm, dict)
|
||||||
|
self.assertFalse(cm["default_visible"])
|
||||||
|
|
||||||
|
def test_original_not_mutated(self):
|
||||||
|
"""原始 dict 不被修改"""
|
||||||
|
original = {
|
||||||
|
"path": "https://example.com/model.xacro",
|
||||||
|
"mesh": "arm",
|
||||||
|
}
|
||||||
|
original_copy = {**original}
|
||||||
|
normalize_model_for_upload(original)
|
||||||
|
self.assertEqual(original, original_copy)
|
||||||
|
|
||||||
|
def test_preserves_existing_fields(self):
|
||||||
|
"""所有原始字段都被保留"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/model.xacro",
|
||||||
|
"mesh": "arm_slider",
|
||||||
|
"type": "device",
|
||||||
|
"mesh_tf": [0, 0, 0, 0, 0, 0],
|
||||||
|
"custom_field": "should_survive",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["custom_field"], "should_survive")
|
||||||
|
self.assertEqual(result["mesh_tf"], [0, 0, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
221
tests/app/test_model_upload.py
Normal file
221
tests/app/test_model_upload.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
"""model_upload.py 单元测试(upload_device_model / download_model_from_oss)"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
|
from unilabos.app.model_upload import (
|
||||||
|
upload_device_model,
|
||||||
|
download_model_from_oss,
|
||||||
|
_MODEL_EXTENSIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadDeviceModel(unittest.TestCase):
|
||||||
|
"""测试本地模型文件上传到 OSS"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmp_dir = tempfile.mkdtemp()
|
||||||
|
self.mock_client = MagicMock()
|
||||||
|
|
||||||
|
def _create_model_files(self, subdir: str, filenames: list[str]):
|
||||||
|
"""在临时目录中创建设备模型文件"""
|
||||||
|
model_dir = Path(self.tmp_dir) / "devices" / subdir
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
for name in filenames:
|
||||||
|
p = model_dir / name
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p.write_text("dummy content")
|
||||||
|
return model_dir
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||||
|
def test_upload_success(self, mock_base):
|
||||||
|
"""正常上传流程"""
|
||||||
|
mock_base.__truediv__ = lambda self, x: Path(self.tmp_dir) / x
|
||||||
|
# 直接 patch _MESH_BASE_DIR 为 Path(tmp_dir)
|
||||||
|
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||||
|
self._create_model_files("arm_slider", ["macro_device.xacro", "meshes/link1.stl"])
|
||||||
|
|
||||||
|
self.mock_client.get_model_upload_urls.return_value = {
|
||||||
|
"files": [
|
||||||
|
{"name": "macro_device.xacro", "upload_url": "https://oss.example.com/put1"},
|
||||||
|
{"name": "meshes/link1.stl", "upload_url": "https://oss.example.com/put2"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
self.mock_client.publish_model.return_value = {
|
||||||
|
"path": "https://oss.example.com/arm_slider/macro_device.xacro"
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("unilabos.app.model_upload._put_upload") as mock_put:
|
||||||
|
result = upload_device_model(
|
||||||
|
http_client=self.mock_client,
|
||||||
|
template_uuid="test-uuid",
|
||||||
|
mesh_name="arm_slider",
|
||||||
|
model_type="device",
|
||||||
|
version="1.0.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(result, "https://oss.example.com/arm_slider/macro_device.xacro")
|
||||||
|
self.mock_client.get_model_upload_urls.assert_called_once()
|
||||||
|
self.mock_client.publish_model.assert_called_once()
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||||
|
def test_upload_dir_not_exists(self, mock_base):
|
||||||
|
"""本地目录不存在时返回 None"""
|
||||||
|
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||||
|
result = upload_device_model(
|
||||||
|
http_client=self.mock_client,
|
||||||
|
template_uuid="test-uuid",
|
||||||
|
mesh_name="nonexistent",
|
||||||
|
model_type="device",
|
||||||
|
)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||||
|
def test_upload_no_valid_files(self, mock_base):
|
||||||
|
"""目录中无有效模型文件时返回 None"""
|
||||||
|
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||||
|
model_dir = Path(self.tmp_dir) / "devices" / "empty_model"
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(model_dir / "readme.txt").write_text("not a model")
|
||||||
|
|
||||||
|
result = upload_device_model(
|
||||||
|
http_client=self.mock_client,
|
||||||
|
template_uuid="test-uuid",
|
||||||
|
mesh_name="empty_model",
|
||||||
|
model_type="device",
|
||||||
|
)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||||
|
def test_upload_urls_failure(self, mock_base):
|
||||||
|
"""获取上传 URL 失败时返回 None"""
|
||||||
|
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||||
|
self._create_model_files("arm", ["device.xacro"])
|
||||||
|
self.mock_client.get_model_upload_urls.return_value = None
|
||||||
|
|
||||||
|
result = upload_device_model(
|
||||||
|
http_client=self.mock_client,
|
||||||
|
template_uuid="test-uuid",
|
||||||
|
mesh_name="arm",
|
||||||
|
model_type="device",
|
||||||
|
)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloadModelFromOss(unittest.TestCase):
|
||||||
|
"""测试从 OSS 下载模型文件到本地"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmp_dir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def test_skip_no_mesh_name(self):
|
||||||
|
"""缺少 mesh 名称时跳过"""
|
||||||
|
result = download_model_from_oss({"type": "device", "path": "https://x.com/a.xacro"})
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_skip_no_oss_path(self):
|
||||||
|
"""缺少 OSS path 时跳过"""
|
||||||
|
result = download_model_from_oss({"mesh": "arm", "type": "device"})
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_skip_local_path(self):
|
||||||
|
"""非 https:// 路径时跳过"""
|
||||||
|
result = download_model_from_oss({
|
||||||
|
"mesh": "arm",
|
||||||
|
"type": "device",
|
||||||
|
"path": "file:///local/model.xacro",
|
||||||
|
})
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_already_exists(self):
|
||||||
|
"""本地已有文件时跳过下载"""
|
||||||
|
device_dir = Path(self.tmp_dir) / "devices" / "arm"
|
||||||
|
device_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(device_dir / "model.xacro").write_text("existing")
|
||||||
|
|
||||||
|
result = download_model_from_oss(
|
||||||
|
{"mesh": "arm", "type": "device", "path": "https://oss.example.com/model.xacro"},
|
||||||
|
mesh_base_dir=Path(self.tmp_dir),
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._download_file")
|
||||||
|
def test_download_device(self, mock_download):
|
||||||
|
"""下载 device 模型到 devices/ 目录"""
|
||||||
|
result = download_model_from_oss(
|
||||||
|
{"mesh": "new_arm", "type": "device", "path": "https://oss.example.com/new_arm/macro_device.xacro"},
|
||||||
|
mesh_base_dir=Path(self.tmp_dir),
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
mock_download.assert_called_once()
|
||||||
|
call_args = mock_download.call_args
|
||||||
|
self.assertIn("macro_device.xacro", str(call_args[0][1]))
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._download_file")
|
||||||
|
def test_download_resource(self, mock_download):
|
||||||
|
"""下载 resource 模型到 resources/ 目录"""
|
||||||
|
result = download_model_from_oss(
|
||||||
|
{
|
||||||
|
"mesh": "plate_96/meshes/plate_96.stl",
|
||||||
|
"type": "resource",
|
||||||
|
"path": "https://oss.example.com/plate_96/modal.xacro",
|
||||||
|
},
|
||||||
|
mesh_base_dir=Path(self.tmp_dir),
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
target_dir = Path(self.tmp_dir) / "resources" / "plate_96"
|
||||||
|
self.assertTrue(target_dir.exists())
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._download_file")
|
||||||
|
def test_download_with_children_mesh(self, mock_download):
|
||||||
|
"""下载包含 children_mesh 的模型"""
|
||||||
|
result = download_model_from_oss(
|
||||||
|
{
|
||||||
|
"mesh": "tip_rack",
|
||||||
|
"type": "device",
|
||||||
|
"path": "https://oss.example.com/tip_rack/model.xacro",
|
||||||
|
"children_mesh": {
|
||||||
|
"path": "https://oss.example.com/tip_rack/meshes/tip.stl",
|
||||||
|
"format": "stl",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
mesh_base_dir=Path(self.tmp_dir),
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
# 应调用两次:入口文件 + children_mesh
|
||||||
|
self.assertEqual(mock_download.call_count, 2)
|
||||||
|
|
||||||
|
@patch("unilabos.app.model_upload._download_file", side_effect=Exception("network error"))
|
||||||
|
def test_download_failure_graceful(self, mock_download):
|
||||||
|
"""下载失败时返回 False(不抛异常)"""
|
||||||
|
result = download_model_from_oss(
|
||||||
|
{"mesh": "broken", "type": "device", "path": "https://oss.example.com/broken.xacro"},
|
||||||
|
mesh_base_dir=Path(self.tmp_dir),
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelExtensions(unittest.TestCase):
|
||||||
|
"""测试支持的模型文件后缀集合"""
|
||||||
|
|
||||||
|
def test_standard_extensions(self):
|
||||||
|
"""确认标准 3D 格式在支持列表中"""
|
||||||
|
expected = {".stl", ".gltf", ".glb", ".xacro", ".urdf", ".obj", ".dae"}
|
||||||
|
for ext in expected:
|
||||||
|
self.assertIn(ext, _MODEL_EXTENSIONS, f"{ext} should be supported")
|
||||||
|
|
||||||
|
def test_non_model_excluded(self):
|
||||||
|
"""非模型文件后缀不在列表中"""
|
||||||
|
excluded = {".txt", ".json", ".py", ".png", ".jpg"}
|
||||||
|
for ext in excluded:
|
||||||
|
self.assertNotIn(ext, _MODEL_EXTENSIONS, f"{ext} should not be supported")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
170
tests/app/test_normalize_model.py
Normal file
170
tests/app/test_normalize_model.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""normalize_model_for_upload 单元测试"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 添加项目根目录到 sys.path
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
|
from unilabos.app.register import normalize_model_for_upload
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeModelForUpload(unittest.TestCase):
|
||||||
|
"""测试 Registry YAML model 字段标准化"""
|
||||||
|
|
||||||
|
def test_empty_input(self):
|
||||||
|
"""空 dict 直接返回"""
|
||||||
|
self.assertEqual(normalize_model_for_upload({}), {})
|
||||||
|
self.assertIsNone(normalize_model_for_upload(None))
|
||||||
|
|
||||||
|
def test_format_infer_xacro(self):
|
||||||
|
"""自动从 path 后缀推断 format=xacro"""
|
||||||
|
model = {
|
||||||
|
"path": "https://oss.example.com/devices/arm/macro_device.xacro",
|
||||||
|
"mesh": "arm_slider",
|
||||||
|
"type": "device",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "xacro")
|
||||||
|
|
||||||
|
def test_format_infer_urdf(self):
|
||||||
|
"""自动推断 format=urdf"""
|
||||||
|
model = {"path": "https://example.com/robot.urdf", "type": "device"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "urdf")
|
||||||
|
|
||||||
|
def test_format_infer_stl(self):
|
||||||
|
"""自动推断 format=stl"""
|
||||||
|
model = {"path": "https://example.com/part.stl"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "stl")
|
||||||
|
|
||||||
|
def test_format_infer_gltf(self):
|
||||||
|
"""自动推断 format=gltf(.gltf 和 .glb)"""
|
||||||
|
for ext in [".gltf", ".glb"]:
|
||||||
|
model = {"path": f"https://example.com/model{ext}"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "gltf", f"failed for {ext}")
|
||||||
|
|
||||||
|
def test_format_not_overwritten(self):
|
||||||
|
"""已有 format 字段时不覆盖"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/model.xacro",
|
||||||
|
"format": "custom",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["format"], "custom")
|
||||||
|
|
||||||
|
def test_format_no_path(self):
|
||||||
|
"""没有 path 时不推断 format"""
|
||||||
|
model = {"mesh": "arm_slider", "type": "device"}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertNotIn("format", result)
|
||||||
|
|
||||||
|
def test_children_mesh_string_to_struct(self):
|
||||||
|
"""将 children_mesh 字符串(旧格式)转为结构化对象"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"type": "resource",
|
||||||
|
"children_mesh": "tip/meshes/tip.stl",
|
||||||
|
"children_mesh_tf": [0.0045, 0.0045, 0, 0, 0, 1.57],
|
||||||
|
"children_mesh_path": "https://oss.example.com/tip.stl",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertIsInstance(cm, dict)
|
||||||
|
self.assertEqual(cm["path"], "https://oss.example.com/tip.stl")
|
||||||
|
self.assertEqual(cm["format"], "stl")
|
||||||
|
self.assertTrue(cm["default_visible"])
|
||||||
|
self.assertEqual(cm["local_offset"], [0.0045, 0.0045, 0])
|
||||||
|
self.assertEqual(cm["local_rotation"], [0, 0, 1.57])
|
||||||
|
|
||||||
|
self.assertNotIn("children_mesh_tf", result)
|
||||||
|
self.assertNotIn("children_mesh_path", result)
|
||||||
|
|
||||||
|
def test_children_mesh_no_oss_fallback(self):
|
||||||
|
"""children_mesh 无 OSS URL 时 fallback 到本地路径"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "plate_96/meshes/plate_96.stl",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertEqual(cm["path"], "plate_96/meshes/plate_96.stl")
|
||||||
|
self.assertEqual(cm["format"], "stl")
|
||||||
|
|
||||||
|
def test_children_mesh_gltf_format(self):
|
||||||
|
"""children_mesh .glb 文件推断 format=gltf"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "meshes/child.glb",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["children_mesh"]["format"], "gltf")
|
||||||
|
|
||||||
|
def test_children_mesh_partial_tf(self):
|
||||||
|
"""children_mesh_tf 只有 3 个值时只有 offset 无 rotation"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "tip.stl",
|
||||||
|
"children_mesh_tf": [0.01, 0.02, 0.03],
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertEqual(cm["local_offset"], [0.01, 0.02, 0.03])
|
||||||
|
self.assertNotIn("local_rotation", cm)
|
||||||
|
|
||||||
|
def test_children_mesh_no_tf(self):
|
||||||
|
"""children_mesh 无 tf 时不加 offset/rotation"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": "tip.stl",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertNotIn("local_offset", cm)
|
||||||
|
self.assertNotIn("local_rotation", cm)
|
||||||
|
|
||||||
|
def test_children_mesh_already_dict(self):
|
||||||
|
"""children_mesh 已经是 dict 时不重新映射"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/rack.xacro",
|
||||||
|
"children_mesh": {
|
||||||
|
"path": "https://example.com/tip.stl",
|
||||||
|
"format": "stl",
|
||||||
|
"default_visible": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
cm = result["children_mesh"]
|
||||||
|
self.assertIsInstance(cm, dict)
|
||||||
|
self.assertFalse(cm["default_visible"])
|
||||||
|
|
||||||
|
def test_original_not_mutated(self):
|
||||||
|
"""原始 dict 不被修改"""
|
||||||
|
original = {
|
||||||
|
"path": "https://example.com/model.xacro",
|
||||||
|
"mesh": "arm",
|
||||||
|
}
|
||||||
|
original_copy = {**original}
|
||||||
|
normalize_model_for_upload(original)
|
||||||
|
self.assertEqual(original, original_copy)
|
||||||
|
|
||||||
|
def test_preserves_existing_fields(self):
|
||||||
|
"""所有原始字段都被保留"""
|
||||||
|
model = {
|
||||||
|
"path": "https://example.com/model.xacro",
|
||||||
|
"mesh": "arm_slider",
|
||||||
|
"type": "device",
|
||||||
|
"mesh_tf": [0, 0, 0, 0, 0, 0],
|
||||||
|
"custom_field": "should_survive",
|
||||||
|
}
|
||||||
|
result = normalize_model_for_upload(model)
|
||||||
|
self.assertEqual(result["custom_field"], "should_survive")
|
||||||
|
self.assertEqual(result["mesh_tf"], [0, 0, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
186
unilabos/app/model_upload.py
Normal file
186
unilabos/app/model_upload.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""模型文件上传/下载管理。
|
||||||
|
|
||||||
|
提供 Edge 端本地模型文件与 OSS 之间的双向同步:
|
||||||
|
- upload_device_model: 本地模型 → OSS(Edge 首次接入时)
|
||||||
|
- download_model_from_oss: OSS → 本地(新 Edge 加入已有 Lab 时)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from unilabos.utils.log import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from unilabos.app.web.client import HTTPClient
|
||||||
|
|
||||||
|
# 设备 mesh 根目录
|
||||||
|
_MESH_BASE_DIR = Path(__file__).parent.parent / "device_mesh"
|
||||||
|
|
||||||
|
# 支持的模型文件后缀
|
||||||
|
_MODEL_EXTENSIONS = frozenset({
|
||||||
|
".xacro", ".urdf", ".stl", ".dae", ".obj",
|
||||||
|
".gltf", ".glb", ".yaml", ".yml",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def upload_device_model(
|
||||||
|
http_client: "HTTPClient",
|
||||||
|
template_uuid: str,
|
||||||
|
mesh_name: str,
|
||||||
|
model_type: str,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""上传本地模型文件到 OSS,返回入口文件的 OSS URL。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
http_client: HTTPClient 实例
|
||||||
|
template_uuid: 设备模板 UUID
|
||||||
|
mesh_name: mesh 目录名(如 "arm_slider")
|
||||||
|
model_type: "device" 或 "resource"
|
||||||
|
version: 模型版本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
入口文件 OSS URL,上传失败返回 None
|
||||||
|
"""
|
||||||
|
if model_type == "device":
|
||||||
|
model_dir = _MESH_BASE_DIR / "devices" / mesh_name
|
||||||
|
else:
|
||||||
|
model_dir = _MESH_BASE_DIR / "resources" / mesh_name
|
||||||
|
|
||||||
|
if not model_dir.exists():
|
||||||
|
logger.warning(f"[模型上传] 本地目录不存在: {model_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 收集所有需要上传的文件
|
||||||
|
files = []
|
||||||
|
for f in model_dir.rglob("*"):
|
||||||
|
if f.is_file() and f.suffix.lower() in _MODEL_EXTENSIONS:
|
||||||
|
files.append({
|
||||||
|
"name": str(f.relative_to(model_dir)),
|
||||||
|
"size_kb": f.stat().st_size // 1024,
|
||||||
|
})
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
logger.warning(f"[模型上传] 目录中无可上传的模型文件: {model_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 获取预签名上传 URL
|
||||||
|
upload_urls_resp = http_client.get_model_upload_urls(
|
||||||
|
template_uuid=template_uuid,
|
||||||
|
files=[{"name": f["name"], "version": version} for f in files],
|
||||||
|
)
|
||||||
|
if not upload_urls_resp:
|
||||||
|
return None
|
||||||
|
|
||||||
|
url_items = upload_urls_resp.get("files", [])
|
||||||
|
|
||||||
|
# 2. 逐个上传文件
|
||||||
|
for file_info, url_info in zip(files, url_items):
|
||||||
|
local_path = model_dir / file_info["name"]
|
||||||
|
upload_url = url_info.get("upload_url", "")
|
||||||
|
if not upload_url:
|
||||||
|
continue
|
||||||
|
_put_upload(local_path, upload_url)
|
||||||
|
|
||||||
|
# 3. 确认发布
|
||||||
|
entry_file = "macro_device.xacro" if model_type == "device" else "modal.xacro"
|
||||||
|
# 检查入口文件是否存在,使用实际存在的文件名
|
||||||
|
for f in files:
|
||||||
|
if f["name"].endswith(".xacro"):
|
||||||
|
entry_file = f["name"]
|
||||||
|
break
|
||||||
|
|
||||||
|
publish_resp = http_client.publish_model(
|
||||||
|
template_uuid=template_uuid,
|
||||||
|
version=version,
|
||||||
|
entry_file=entry_file,
|
||||||
|
)
|
||||||
|
return publish_resp.get("path") if publish_resp else None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[模型上传] 上传失败 ({mesh_name}): {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_from_oss(
|
||||||
|
model_config: dict,
|
||||||
|
mesh_base_dir: Optional[Path] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""检查本地模型文件是否存在,不存在则从 OSS 下载。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_config: 节点的 model 配置字典
|
||||||
|
mesh_base_dir: mesh 根目录,默认使用 device_mesh/
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示本地文件就绪,False 表示下载失败或无需下载
|
||||||
|
"""
|
||||||
|
if mesh_base_dir is None:
|
||||||
|
mesh_base_dir = _MESH_BASE_DIR
|
||||||
|
|
||||||
|
mesh_name = model_config.get("mesh", "")
|
||||||
|
model_type = model_config.get("type", "")
|
||||||
|
oss_path = model_config.get("path", "")
|
||||||
|
|
||||||
|
if not mesh_name or not oss_path or not oss_path.startswith("https://"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 确定本地目标目录
|
||||||
|
if model_type == "device":
|
||||||
|
local_dir = mesh_base_dir / "devices" / mesh_name
|
||||||
|
elif model_type == "resource":
|
||||||
|
resource_name = mesh_name.split("/")[0]
|
||||||
|
local_dir = mesh_base_dir / "resources" / resource_name
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 已有本地文件 → 跳过
|
||||||
|
if local_dir.exists() and any(local_dir.iterdir()):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 从 OSS 下载
|
||||||
|
local_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
try:
|
||||||
|
# 下载入口文件(OSS URL 通常直接可访问)
|
||||||
|
entry_name = oss_path.rsplit("/", 1)[-1]
|
||||||
|
_download_file(oss_path, local_dir / entry_name)
|
||||||
|
|
||||||
|
# 如果有 children_mesh,也下载
|
||||||
|
children_mesh = model_config.get("children_mesh")
|
||||||
|
if isinstance(children_mesh, dict) and children_mesh.get("path"):
|
||||||
|
cm_path = children_mesh["path"]
|
||||||
|
if cm_path.startswith("https://"):
|
||||||
|
cm_name = cm_path.rsplit("/", 1)[-1]
|
||||||
|
meshes_dir = local_dir / "meshes"
|
||||||
|
meshes_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
_download_file(cm_path, meshes_dir / cm_name)
|
||||||
|
|
||||||
|
logger.info(f"[模型下载] 成功下载模型到本地: {mesh_name} → {local_dir}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[模型下载] 下载失败 ({mesh_name}): {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _put_upload(local_path: Path, upload_url: str) -> None:
|
||||||
|
"""通过预签名 URL 上传文件到 OSS。"""
|
||||||
|
with open(local_path, "rb") as f:
|
||||||
|
resp = requests.put(upload_url, data=f, timeout=120)
|
||||||
|
resp.raise_for_status()
|
||||||
|
logger.debug(f"[模型上传] 已上传: {local_path.name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _download_file(url: str, local_path: Path) -> None:
|
||||||
|
"""下载单个文件到本地路径。"""
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
resp = requests.get(url, timeout=60)
|
||||||
|
resp.raise_for_status()
|
||||||
|
local_path.write_bytes(resp.content)
|
||||||
|
logger.debug(f"[模型下载] 已下载: {local_path}")
|
||||||
@@ -5,6 +5,48 @@ from unilabos.utils.log import logger
|
|||||||
from unilabos.utils.tools import normalize_json as _normalize_device
|
from unilabos.utils.tools import normalize_json as _normalize_device
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_model_for_upload(model_dict: dict) -> dict:
|
||||||
|
"""将 Registry YAML 的 model 字段映射为后端 DeviceModel 结构化格式。
|
||||||
|
|
||||||
|
保留所有原始字段,额外做以下标准化:
|
||||||
|
1. 自动推断 format(如果 YAML 未指定)
|
||||||
|
2. 将 children_mesh 扁平字段映射为结构化 children_mesh 对象
|
||||||
|
"""
|
||||||
|
if not model_dict:
|
||||||
|
return model_dict
|
||||||
|
|
||||||
|
result = {**model_dict}
|
||||||
|
|
||||||
|
# 自动推断 format
|
||||||
|
if "format" not in result and result.get("path"):
|
||||||
|
path = result["path"]
|
||||||
|
if path.endswith(".xacro"):
|
||||||
|
result["format"] = "xacro"
|
||||||
|
elif path.endswith(".urdf"):
|
||||||
|
result["format"] = "urdf"
|
||||||
|
elif path.endswith(".stl"):
|
||||||
|
result["format"] = "stl"
|
||||||
|
elif path.endswith((".gltf", ".glb")):
|
||||||
|
result["format"] = "gltf"
|
||||||
|
|
||||||
|
# 将 children_mesh 扁平字段 → 结构化 children_mesh 对象
|
||||||
|
if "children_mesh" in result and isinstance(result["children_mesh"], str):
|
||||||
|
cm_path = result.pop("children_mesh")
|
||||||
|
cm_tf = result.pop("children_mesh_tf", None)
|
||||||
|
cm_oss = result.pop("children_mesh_path", None)
|
||||||
|
result["children_mesh"] = {
|
||||||
|
"path": cm_oss or cm_path,
|
||||||
|
"format": "stl" if cm_path.endswith(".stl") else "gltf",
|
||||||
|
"default_visible": True,
|
||||||
|
}
|
||||||
|
if cm_tf and len(cm_tf) >= 3:
|
||||||
|
result["children_mesh"]["local_offset"] = cm_tf[:3]
|
||||||
|
if cm_tf and len(cm_tf) >= 6:
|
||||||
|
result["children_mesh"]["local_rotation"] = cm_tf[3:6]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
注册设备和资源到服务器(仅支持HTTP)
|
注册设备和资源到服务器(仅支持HTTP)
|
||||||
@@ -16,11 +58,18 @@ def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[
|
|||||||
|
|
||||||
devices_to_register = {}
|
devices_to_register = {}
|
||||||
for device_info in lab_registry.obtain_registry_device_info():
|
for device_info in lab_registry.obtain_registry_device_info():
|
||||||
devices_to_register[device_info["id"]] = _normalize_device(device_info)
|
normalized = _normalize_device(device_info)
|
||||||
|
# 标准化 model 字段
|
||||||
|
if normalized.get("model"):
|
||||||
|
normalized["model"] = normalize_model_for_upload(normalized["model"])
|
||||||
|
devices_to_register[device_info["id"]] = normalized
|
||||||
logger.trace(f"[UniLab Register] 收集设备: {device_info['id']}")
|
logger.trace(f"[UniLab Register] 收集设备: {device_info['id']}")
|
||||||
|
|
||||||
resources_to_register = {}
|
resources_to_register = {}
|
||||||
for resource_info in lab_registry.obtain_registry_resource_info():
|
for resource_info in lab_registry.obtain_registry_resource_info():
|
||||||
|
# 标准化 model 字段
|
||||||
|
if resource_info.get("model"):
|
||||||
|
resource_info["model"] = normalize_model_for_upload(resource_info["model"])
|
||||||
resources_to_register[resource_info["id"]] = resource_info
|
resources_to_register[resource_info["id"]] = resource_info
|
||||||
logger.trace(f"[UniLab Register] 收集资源: {resource_info['id']}")
|
logger.trace(f"[UniLab Register] 收集资源: {resource_info['id']}")
|
||||||
|
|
||||||
|
|||||||
@@ -468,6 +468,63 @@ class HTTPClient:
|
|||||||
logger.error(f"发布工作流失败: {response.status_code}, {response.text}")
|
logger.error(f"发布工作流失败: {response.status_code}, {response.text}")
|
||||||
return {"code": response.status_code, "message": response.text}
|
return {"code": response.status_code, "message": response.text}
|
||||||
|
|
||||||
|
# ──────────────────── 模型资产管理 ────────────────────
|
||||||
|
|
||||||
|
def get_model_upload_urls(
|
||||||
|
self, template_uuid: str, files: list[dict],
|
||||||
|
) -> dict | None:
|
||||||
|
"""获取模型文件预签名上传 URL。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_uuid: 设备模板 UUID
|
||||||
|
files: 文件列表 [{"name": "...", "version": "1.0.0"}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"files": [{"name": "...", "upload_url": "...", "path": "..."}]}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.remote_addr}/lab/square/template/{template_uuid}/model/upload-urls",
|
||||||
|
json={"files": files},
|
||||||
|
headers={"Authorization": f"Lab {self.auth}"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json().get("data")
|
||||||
|
return data
|
||||||
|
logger.error(f"获取模型上传 URL 失败: {response.status_code}, {response.text}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取模型上传 URL 异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def publish_model(
|
||||||
|
self, template_uuid: str, version: str, entry_file: str,
|
||||||
|
) -> dict | None:
|
||||||
|
"""确认模型上传完成,发布新版本。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_uuid: 设备模板 UUID
|
||||||
|
version: 模型版本
|
||||||
|
entry_file: 入口文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"path": "...", "oss_dir": "...", "version": "..."}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.remote_addr}/lab/square/template/{template_uuid}/model/publish",
|
||||||
|
json={"version": version, "entry_file": entry_file},
|
||||||
|
headers={"Authorization": f"Lab {self.auth}"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json().get("data")
|
||||||
|
return data
|
||||||
|
logger.error(f"发布模型失败: {response.status_code}, {response.text}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发布模型异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# 创建默认客户端实例
|
# 创建默认客户端实例
|
||||||
http_client = HTTPClient()
|
http_client = HTTPClient()
|
||||||
|
|||||||
Reference in New Issue
Block a user