PyTorch → TFLite Convert

cornpip
|2025. 11. 19. 21:16

PyTorch to TFLite Github

https://github.com/cornpip/pt_to_tflite

 

PyTorch to TFLite Docker Image

https://hub.docker.com/r/cornpip77/tf_213_converter

tensorflow:2.13.0-gpu 이미지 위에

`pip install torch torchvision onnx onnx-tf tensorflow-addons tensorflow-probability` 패키지를 설치한 이미지이다.

 

Exmaple

변환 예시로 mobilenet_v2를 tflite로 변환해 보자.

README에 MobileNet Test Sciprt를 따라가면 된다.

// torch model (mobilenet_v2-b0353104.pth)
258: 0.8303 (Samoyed)
259: 0.0699 (Pomeranian)
261: 0.0130 (keeshond)
231: 0.0108 (collie)
257: 0.0099 (Great Pyrenees)

// tflite model (mobilenet_v2.tflite)
258: 0.8692 (Samoyed)
259: 0.0498 (Pomeranian)
261: 0.0243 (keeshond)
257: 0.0163 (Great Pyrenees)
231: 0.0047 (collie)

// tflite no optimizer model (양자화X)
258: 0.8595 (Samoyed)
259: 0.0568 (Pomeranian)
261: 0.0252 (keeshond)
257: 0.0177 (Great Pyrenees)
231: 0.0060 (collie)

MobileNet Test Sciprt 따라 변환 후, 결과를 확인하면 추론 결과에 차이가 있다.

양자화를 하지 않아도 torch 모델과 추론 결과에 차이가 발생한다.

 

경험상, 학습한 모델마다 다른 유형을 보인다.

어떤 모델은 양자화 후 GT 정확도가 60~70 퍼까지 떨어지는 경우도 있고,

반대로 GT 정확도는 거의 감소하지 않는 모델도 있다. 물론 이 경우에도 결과 확률 값에는 차이가 있다.

 

변환한 TFLite를 다른 플랫폼에서 사용하기 전에

변환 한 tflite를 다른 플랫폼에서 사용하기 전에, 먼저 파이썬 환경에서 TFLite 추론을 검증해 보는 것을 권장한다.

이때 GT 정확도에 차이가 크다면 뭔가 틀렸다 생각하고 디버깅을 해봐야 한다.(확률 값의 차이는 GT 다음 확인)

  • build_model이 동일하지 않거나
  • 양자화에 영향을 크게 받는 모델이거나
  • 변환 과정에서 놓친 부분이 존재할 수 있다

파이썬 환경에서 먼저 확인하는 게 좋은 이유는,

전처리 파이프라인을 동일하게 구현해도 리사이즈 보간 방식, float↔int 캐스팅 정책, 원본 로딩 방식 등
플랫폼 간에 세부적인 차이가 있을 수밖에 없고 추론에 영향을 미친다.

그래서 TFLite 변환이 제대로 되었는지 먼저 파이썬에서 검증하지 않으면 나중에 디버깅이 더 힘들다.

 

변환 코드 살펴보기

main

변환 흐름은 다음과 같다.

PyTorch → ONNX → TensorFlow →TFLite

def main() -> None:
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[device] Using {device}")

    # 1) 모델 준비 + 가중치 로드
    model = build_model(
        model_type=args.model,
        num_classes=args.num_classes,
        backbone=args.backbone,
        use_pretrained_backbone=args.use_pretrained_backbone,
    )
    state_dict = load_state_dict_flexible(args.pt_path, device=torch.device("cpu"))
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("[state_dict] missing:", missing)
    print("[state_dict] unexpected:", unexpected)
    model.eval().to(device)

    # 2) PyTorch → ONNX
    dummy_input = torch.randn(1, 3, args.input_height, args.input_width, device=device)
    onnx_model_path = f"./onnx/{args.result_name}.onnx"
    export_onnx(model, dummy_input, onnx_model_path)
    print(f"[onnx] saved to {onnx_model_path}")

    # 3) ONNX → TensorFlow SavedModel
    saved_model_dir = f"./saved_model/{args.result_name}"
    print("[tf] converting onnx → saved_model ...")
    onnx_to_tf_nhwc(onnx_model_path, saved_model_dir)
    print(f"[tf] saved_model ready at {saved_model_dir}")

    # 4) TensorFlow → TFLite
    tflite_model_path = f"./tflite/{args.result_name}.tflite"
    print("[tflite] converting saved_model → tflite ...")
    tf_to_tflite(saved_model_dir, tflite_model_path, optimize=True)
    print(f"[tflite] saved to {tflite_model_path}")
    
    
 def export_onnx(model: nn.Module, dummy_input: torch.Tensor, onnx_path: str) -> None:
    """PyTorch → ONNX"""
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        opset_version=11,
    )
    
  def tf_to_tflite(saved_model_dir: str, tflite_path: str, optimize: bool = True) -> None:
    """TensorFlow SavedModel → TFLite"""
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    if optimize:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
  • 양자화하지 않는다면 `tf_to_tflite()`의 해당 부분을 제거할 수도 있다.(tf.lite.Optimize.DEFAULT)

build_model

학습한 모델과 동일한 레이어를 정확히 재현하는 것이 가장 중요하다.

import torch.nn as nn
from torchvision import models
from torchvision.models import EfficientNet_B4_Weights, ResNet50_Weights

def build_model(model_type: str, num_classes: int, backbone: str, use_pretrained_backbone: bool) -> nn.Module:
    if model_type == "resnet50":
        model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if model_type == "efficientnet_b4":
        model = models.efficientnet_b4(weights=EfficientNet_B4_Weights.DEFAULT)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model

    if model_type == "custom_head":
        return CustomHeadModel(
            backbone_name=backbone,
            num_classes=num_classes,
            pretrained=use_pretrained_backbone,
        )

    raise ValueError(f"Unsupported model type: {model_type}")
    
class CustomHeadModel(nn.Module):
    """
    예시용 커스텀 모델.
    """
    
    def __init__(self, backbone_name: str = "efficientnet_b4", num_classes: int = 3, pretrained: bool = False):
        super().__init__()

        if "efficientnet" in backbone_name:
            base = getattr(models, backbone_name)(
                weights=models.EfficientNet_B4_Weights.DEFAULT if pretrained else None
            )
            in_features = base.classifier[1].in_features
            base.classifier = nn.Identity()
        elif "resnet" in backbone_name:
            base = getattr(models, backbone_name)(
                weights=models.ResNet50_Weights.DEFAULT if pretrained else None
            )
            in_features = base.fc.in_features
            base.fc = nn.Identity()
        else:
            raise ValueError(f"지원하지 않는 backbone: {backbone_name}")

        self.backbone = base
        self.bn = nn.BatchNorm1d(in_features)
        self.head = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.backbone(x)
        if feat.dim() == 4:
            feat = torch.flatten(feat, 1)
        feat = self.bn(feat)
        return self.head(feat)
  • 마지막 fc만 설정한 모델의 예시 (resnet50, efficientnet_b4)
  • fc 앞에 레이어를 조금 더 쌓은 모델의 예시 (custom_head)

onnx_to_tf_nhwc

TFLite 입력은 NHWC 이다. 그에 맞춰 인터페이스를 수정한다.

import tensorflow as tf
import onnx
from onnx_tf.backend import prepare

def onnx_to_tf_nhwc(onnx_model_path, saved_model_dir):
    # Load ONNX model
    onnx_model = onnx.load(onnx_model_path)
    tf_rep = prepare(onnx_model)
    tf_rep.export_graph(saved_model_dir)

    model = tf.saved_model.load(saved_model_dir)
    concrete_func = model.signatures["serving_default"]
    input_tensor = concrete_func.inputs[0]
    input_shape = input_tensor.shape.as_list()  # [1, C, H, W]
    nhwc_shape = [input_shape[0], input_shape[2], input_shape[3], input_shape[1]]

    @tf.function(input_signature=[tf.TensorSpec(shape=nhwc_shape, dtype=tf.float32)])
    def new_serving_fn(inputs):
        nchw_input = tf.transpose(inputs, [0, 3, 1, 2])
        outputs = concrete_func(nchw_input)
        return outputs

    tf.saved_model.save(model, saved_model_dir, signatures={'serving_default': new_serving_fn})
  • ONNX(NCHW) → TF SavedModel로 변환한 뒤, 입력만 NHWC로 받도록 래핑해 다시 저장한다.
  • 내부 연산은 NCHW 그대로 두고, 외부 인터페이스만 NHWC로 바꿔 TFLite 변환/추론 시 채널 순서가 어긋나지 않게 한다.