在醫學影像 AI 模型的開發中,如何讓推論更快、更穩定且易於部署,一直是工程師需要面對的課題。本文將分享如何透過 Triton Inference Server 與 Python 程式碼整合,實現 心臟 CT 鈣化分析的推論流程,並解釋相關程式設計的重點。


為什麼要用 Triton?

Triton Inference Server(NVIDIA 開源專案)有以下優點:

  1. 支援多種框架:PyTorch、TensorFlow、ONNX、TensorRT。

  2. 統一 API:可以用 HTTP/gRPC 與伺服器溝通,不需要在 client 安裝所有框架。

  3. 高效能推論:支援 batch、動態 shape、自動最佳化。

  4. 方便部署:可用 Docker 容器化,輕鬆整合到生產環境。

在臨床研究或醫院部署場景中,Triton 可以讓模型統一管理與快速推論,避免每台機器都安裝龐大的深度學習框架。


Triton Client 實作

在 Python 端,我們建立了一個 get_triton_client 函式,專門負責和 Triton Server 溝通。

import tritonclient.http as httpclient
import torch
import numpy as np

TRITON_URL = "10.18.6.6:8000"
_triton_client = None

def get_triton_client(model_name, input_shape, input_data):
    global _triton_client
    if _triton_client is None:
        _triton_client = httpclient.InferenceServerClient(url=TRITON_URL)

    # 準備輸入
    inputs = [httpclient.InferInput("INPUT__0", input_shape, "FP32")]

    # 如果是 Tensor 就轉 numpy
    if isinstance(input_data, torch.Tensor):
        input_data = input_data.detach().cpu().numpy()
    elif not isinstance(input_data, np.ndarray):
        raise TypeError(f"input_data 必須是 numpy.ndarray 或 torch.Tensor, 但收到 {type(input_data)}")

    inputs[0].set_data_from_numpy(input_data.astype(np.float32))

    outputs = [httpclient.InferRequestedOutput("OUTPUT__0")]

    # 呼叫 Triton 取得結果
    print(f"##### 讀取 triton {model_name} 模型")
    response = _triton_client.infer(model_name, inputs=inputs, outputs=outputs)
    response = response.as_numpy("OUTPUT__0")

    return torch.from_numpy(response)

這段程式碼重點:

  • 自動初始化 client(只建立一次連線)。

  • 輸入格式統一:允許 torch.Tensornumpy.ndarray

  • 明確的輸入/輸出名稱:與 Triton 模型設定一致(INPUT__0OUTPUT__0)。


將 Triton 整合進 nnUNet Trainer

在實際專案中,我們常常需要結合 nnUNet 模型 來做醫學影像的定位或分割。這邊我們建立一個 Triton 版本的 Trainer,繼承自 DS_nnUNetTrainer

from nnunetv2_ds_trainer import nnUNetTrainer

class Triton_nnUNetTrainer(nnUNetTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_predict_xyz_network(self, image):
        image_np = image.numpy()
        # 呼叫 Triton 進行推論
        response = get_triton_client("nnunet3d", image_np.shape, image_np)
        return response

這樣就能在 nnUNet 的 pipeline 中直接切換 本地推論Triton 推論,而不需要改動整體架構。


實際應用:心臟鈣化檢測

我們的流程:

  1. 讀取 DICOM 影像,過濾出胸部 CT axial 切片。

  2. nnUNet 取得心臟 ROI(透過 Triton 或本地 nnUNet)。

  3. 資料正規化與 resize

  4. 送進分類模型

模型透過 Triton 部署,可以:

  • 減少 Python 端重複載入模型的時間。

  • 讓所有模型常駐記憶體,推論延遲降低。


小結

透過 Triton Inference Server,我們把原本需要在本地 Python 程式中管理的多個模型,統一交由伺服器端處理。這帶來幾個好處:

  • 加速推論:避免重複載入模型。

  • 架構清晰:Client 只負責傳送資料,Server 管理模型推論。

  • 方便部署:不必每台機器都安裝完整的深度學習框架。

這種設計特別適合 醫學影像分析,例如心臟鈣化檢測、腫瘤分割、器官分類等場景。

未來若要新增模型,只需要更新 Triton 的模型倉庫,而不用修改 client 程式,讓維護更加容易。

 

 


Triton 模型設定與導出流程

在 Triton 中,每個模型都需要:

  1. 模型權重.pt, .onnx, .plan 等)。

  2. 設定檔案config.pbtxt),描述輸入輸出格式。

目錄結構通常如下:

model_repository/
 └── c4/                      # 模型名稱
     └── 1/                   # 版本號 (必須是數字資料夾)
         └── model.pt         # Pytorch 導出的模型
     └── config.pbtxt         # 模型設定檔

Step 1. 撰寫 config.pbtxt

以「四分類模型」為例(輸入 3D patch 96×96×96,輸出四分類機率):

name: "c4"
platform: "pytorch_libtorch"

max_batch_size: 1

input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [1, 96, 96, 96]   # [channel, depth, height, width]
  }
]

output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [4]               # 四分類輸出
  }
]

這裡要特別注意:

  • INPUT__0OUTPUT__0 必須和 client 程式一致。

  • dims 不包含 batch 維度,因為 Triton 會自動處理 batch。


 

Step 2. 導出 PyTorch 模型

假設你有一個 MyModel 類別(用 ResNet18 做四分類):

import torch
from main_train_model import MyModel

# 建立模型
model = MyModel(model_name="resnet18.a1_in1k", num_classes=4)
model.load_state_dict(torch.load("pths/best_weights.pth"))
model.eval()

# 建立假輸入 (batch=1, channel=1, depth=96, height=96, width=96)
dummy_input = torch.randn(1, 1, 96, 96, 96)

# 導出 TorchScript (pt)
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, "model_repository/pts/model.pt")

這樣就能得到 Triton 可讀取的 TorchScript 模型 (.pt)


Step 3. 啟動 Triton

假設 model_repository/ 在本機:

docker run --gpus=all --rm -it \
   -p 8090:8000 -p 8091:8001 -p 8092:8002 \
   -v /path/to/model_repository:/models \
   nvcr.io/nvidia/tritonserver:23.08-py3 \
   tritonserver --model-repository=/models
  • HTTP port = 8090

  • gRPC port = 8091

  • Metrics port = 8092


Step 4. 驗證

Client 端呼叫:

output = get_triton_client("c4", now_data_nor.shape, now_data_nor)
print("Triton 輸出:", output)

這樣就完成了 從 PyTorch → TorchScript → Triton 部署 → Client 推論 的流程。





心得:

1- 你必須有可以讀取GPU的 docker 的環境,請注意 nvidia-smi 必須可以執行

2- 設定檔案是 [通道, X, Y, Z] 但是輸入時候請加上BATCH維度

3- 設定檔案的 max_batch_size: 1 至少為1,否則很容易錯亂

4- 掛載image時候注意模型位置,且模型必須叫做 model.pt

5- 模型輸出不能是 list ,可以是 tensor 或是 dict 包 tensor,所以如果有list 情況必須要調整模型輸出