在醫學影像 AI 模型的開發中,如何讓推論更快、更穩定且易於部署,一直是工程師需要面對的課題。本文將分享如何透過 Triton Inference Server 與 Python 程式碼整合,實現 心臟 CT 鈣化分析的推論流程,並解釋相關程式設計的重點。
為什麼要用 Triton?
Triton Inference Server(NVIDIA 開源專案)有以下優點:
-
支援多種框架:PyTorch、TensorFlow、ONNX、TensorRT。
-
統一 API:可以用 HTTP/gRPC 與伺服器溝通,不需要在 client 安裝所有框架。
-
高效能推論:支援 batch、動態 shape、自動最佳化。
-
方便部署:可用 Docker 容器化,輕鬆整合到生產環境。
在臨床研究或醫院部署場景中,Triton 可以讓模型統一管理與快速推論,避免每台機器都安裝龐大的深度學習框架。
Triton Client 實作
在 Python 端,我們建立了一個 get_triton_client
函式,專門負責和 Triton Server 溝通。
import tritonclient.http as httpclient import torch import numpy as np TRITON_URL = "10.18.6.6:8000" _triton_client = None def get_triton_client(model_name, input_shape, input_data): global _triton_client if _triton_client is None: _triton_client = httpclient.InferenceServerClient(url=TRITON_URL) # 準備輸入 inputs = [httpclient.InferInput("INPUT__0", input_shape, "FP32")] # 如果是 Tensor 就轉 numpy if isinstance(input_data, torch.Tensor): input_data = input_data.detach().cpu().numpy() elif not isinstance(input_data, np.ndarray): raise TypeError(f"input_data 必須是 numpy.ndarray 或 torch.Tensor, 但收到 {type(input_data)}") inputs[0].set_data_from_numpy(input_data.astype(np.float32)) outputs = [httpclient.InferRequestedOutput("OUTPUT__0")] # 呼叫 Triton 取得結果 print(f"##### 讀取 triton {model_name} 模型") response = _triton_client.infer(model_name, inputs=inputs, outputs=outputs) response = response.as_numpy("OUTPUT__0") return torch.from_numpy(response)
這段程式碼重點:
-
自動初始化 client(只建立一次連線)。
-
輸入格式統一:允許
torch.Tensor
或numpy.ndarray
。 -
明確的輸入/輸出名稱:與 Triton 模型設定一致(
INPUT__0
、OUTPUT__0
)。
將 Triton 整合進 nnUNet Trainer
在實際專案中,我們常常需要結合 nnUNet 模型 來做醫學影像的定位或分割。這邊我們建立一個 Triton 版本的 Trainer,繼承自 DS_nnUNetTrainer
:
from nnunetv2_ds_trainer import nnUNetTrainer class Triton_nnUNetTrainer(nnUNetTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def get_predict_xyz_network(self, image): image_np = image.numpy() # 呼叫 Triton 進行推論 response = get_triton_client("nnunet3d", image_np.shape, image_np) return response
這樣就能在 nnUNet 的 pipeline 中直接切換 本地推論或 Triton 推論,而不需要改動整體架構。
實際應用:心臟鈣化檢測
我們的流程:
-
讀取 DICOM 影像,過濾出胸部 CT axial 切片。
-
nnUNet 取得心臟 ROI(透過 Triton 或本地 nnUNet)。
-
資料正規化與 resize。
-
送進分類模型
模型透過 Triton 部署,可以:
-
減少 Python 端重複載入模型的時間。
-
讓所有模型常駐記憶體,推論延遲降低。
小結
透過 Triton Inference Server,我們把原本需要在本地 Python 程式中管理的多個模型,統一交由伺服器端處理。這帶來幾個好處:
-
加速推論:避免重複載入模型。
-
架構清晰:Client 只負責傳送資料,Server 管理模型推論。
-
方便部署:不必每台機器都安裝完整的深度學習框架。
這種設計特別適合 醫學影像分析,例如心臟鈣化檢測、腫瘤分割、器官分類等場景。
未來若要新增模型,只需要更新 Triton 的模型倉庫,而不用修改 client 程式,讓維護更加容易。
Triton 模型設定與導出流程
在 Triton 中,每個模型都需要:
-
模型權重(
.pt
,.onnx
,.plan
等)。 -
設定檔案(
config.pbtxt
),描述輸入輸出格式。
目錄結構通常如下:
model_repository/ └── c4/ # 模型名稱 └── 1/ # 版本號 (必須是數字資料夾) └── model.pt # Pytorch 導出的模型 └── config.pbtxt # 模型設定檔
Step 1. 撰寫 config.pbtxt
以「四分類模型」為例(輸入 3D patch 96×96×96,輸出四分類機率):
name: "c4" platform: "pytorch_libtorch" max_batch_size: 1 input [ { name: "INPUT__0" data_type: TYPE_FP32 dims: [1, 96, 96, 96] # [channel, depth, height, width] } ] output [ { name: "OUTPUT__0" data_type: TYPE_FP32 dims: [4] # 四分類輸出 } ]
這裡要特別注意:
-
INPUT__0
與OUTPUT__0
必須和 client 程式一致。 -
dims
不包含 batch 維度,因為 Triton 會自動處理 batch。
Step 2. 導出 PyTorch 模型
假設你有一個 MyModel
類別(用 ResNet18 做四分類):
import torch from main_train_model import MyModel # 建立模型 model = MyModel(model_name="resnet18.a1_in1k", num_classes=4) model.load_state_dict(torch.load("pths/best_weights.pth")) model.eval() # 建立假輸入 (batch=1, channel=1, depth=96, height=96, width=96) dummy_input = torch.randn(1, 1, 96, 96, 96) # 導出 TorchScript (pt) traced_model = torch.jit.trace(model, dummy_input) torch.jit.save(traced_model, "model_repository/pts/model.pt")
這樣就能得到 Triton 可讀取的 TorchScript 模型 (.pt
)。
Step 3. 啟動 Triton
假設 model_repository/
在本機:
docker run --gpus=all --rm -it \ -p 8090:8000 -p 8091:8001 -p 8092:8002 \ -v /path/to/model_repository:/models \ nvcr.io/nvidia/tritonserver:23.08-py3 \ tritonserver --model-repository=/models
-
HTTP port =
8090
-
gRPC port =
8091
-
Metrics port =
8092
Step 4. 驗證
Client 端呼叫:
output = get_triton_client("c4", now_data_nor.shape, now_data_nor) print("Triton 輸出:", output)
這樣就完成了 從 PyTorch → TorchScript → Triton 部署 → Client 推論 的流程。
心得:
1- 你必須有可以讀取GPU的 docker 的環境,請注意 nvidia-smi 必須可以執行
2- 設定檔案是 [通道, X, Y, Z] 但是輸入時候請加上BATCH維度
3- 設定檔案的 max_batch_size: 1 至少為1,否則很容易錯亂
4- 掛載image時候注意模型位置,且模型必須叫做 model.pt
5- 模型輸出不能是 list ,可以是 tensor 或是 dict 包 tensor,所以如果有list 情況必須要調整模型輸出
留言板
歡迎留下建議與分享!希望一起交流!感恩!