# Source code for yolort.v5.helper

# Copyright (c) 2021, yolort team. All rights reserved.
import contextlib
import sys
from pathlib import Path

import torch

from .models import AutoShape
from .models.yolo import Model
from .utils import attempt_download, intersect_dicts, set_logging

# Public API exported by this helper module.
__all__ = [
    "add_yolov5_context",
    "load_yolov5_model",
    "get_yolov5_size",
]


@contextlib.contextmanager
def add_yolov5_context():
    """
    Temporarily add yolov5 folder to `sys.path`.

    The directory containing this module is prepended to ``sys.path`` for the
    duration of the ``with`` block and removed again on exit, even if the block
    raises.

    Adapted from https://github.com/fcakyon/yolov5-pip/blob/0d03de6/yolov5/utils/general.py#L739-L754
    torch.hub handles it in the same way: https://github.com/pytorch/pytorch/blob/d3e36fa/torch/hub.py#L387-L416
    """
    path_ultralytics_yolov5 = str(Path(__file__).parent.resolve())
    # Insert before entering ``try`` so the ``finally`` clause only ever undoes
    # an insertion that actually happened (same ordering as torch.hub).
    sys.path.insert(0, path_ultralytics_yolov5)
    try:
        yield
    finally:
        # Code inside the ``with`` body may already have removed the entry; a
        # ValueError here would otherwise mask any exception raised in the body.
        with contextlib.suppress(ValueError):
            sys.path.remove(path_ultralytics_yolov5)
def get_yolov5_size(depth_multiple, width_multiple):
    """
    Map a YOLOv5 (depth_multiple, width_multiple) pair to its model-size tag.

    Args:
        depth_multiple (float): the model's ``depth_multiple`` scaling factor.
        width_multiple (float): the model's ``width_multiple`` scaling factor.

    Returns:
        str: one of ``"n"``, ``"s"``, ``"m"``, ``"l"`` or ``"x"``.

    Raises:
        NotImplementedError: if the pair matches none of the supported sizes.
    """
    # Ordered (depth, width) -> size table. Compared with ``==`` rather than
    # hashed so the accepted inputs match a plain equality check exactly.
    known_sizes = (
        ((0.33, 0.25), "n"),
        ((0.33, 0.5), "s"),
        ((0.67, 0.75), "m"),
        ((1.0, 1.0), "l"),
        ((1.33, 1.25), "x"),
    )
    for (depth, width), size in known_sizes:
        if depth_multiple == depth and width_multiple == width:
            return size

    raise NotImplementedError(
        f"Currently doesn't support architecture with depth: {depth_multiple} "
        f"and width: {width_multiple}, feel free to create a ticket labeled enhancement to us"
    )
def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bool = True):
    """
    Creates a specified YOLOv5 model

    The checkpoint is loaded onto the CPU, a fresh ``Model`` is built from the
    checkpoint's ``yaml`` config, and the checkpoint weights (cast to FP32) are
    copied into it non-strictly, skipping ``anchors``.

    Args:
        checkpoint_path (str): path of the YOLOv5 model, i.e. 'yolov5s.pt'
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model. Default: False.
        verbose (bool): print all information to screen. Default: True.

    Returns:
        YOLOv5 pytorch model
    """
    set_logging(verbose=verbose)

    # NOTE(review): presumably the bundled yolov5 folder must be on sys.path so
    # that module paths referenced by the pickled checkpoint resolve — confirm.
    with add_yolov5_context():
        ckpt = torch.load(attempt_download(checkpoint_path), map_location=torch.device("cpu"))
        # Dict checkpoints keep the model under "model". Previously a non-dict
        # checkpoint left ``model_ckpt`` unbound and crashed; treat it as the
        # pickled model object itself instead.
        model_ckpt = ckpt["model"] if isinstance(ckpt, dict) else ckpt

        model = Model(model_ckpt.yaml)  # create model
        ckpt_state_dict = model_ckpt.float().state_dict()  # checkpoint state_dict as FP32
        # Keep only entries compatible with the fresh model; anchors are excluded.
        ckpt_state_dict = intersect_dicts(ckpt_state_dict, model.state_dict(), exclude=["anchors"])
        model.load_state_dict(ckpt_state_dict, strict=False)

        if autoshape:
            model = AutoShape(model)

    return model