# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch

from ultralytics.utils import IS_JETSON, LOGGER


def export_onnx(
    torch_model: torch.nn.Module,
    im: torch.Tensor,
    onnx_file: str,
    opset: int = 14,
    input_names: List[str] = ["images"],
    output_names: List[str] = ["output0"],
    dynamic: Union[bool, Dict] = False,
) -> None:
"""
Export a PyTorch model to ONNX format.
Args:
torch_model (torch.nn.Module): The PyTorch model to export.
im (torch.Tensor): Example input tensor for the model.
onnx_file (str): Path to save the exported ONNX file.
opset (int): ONNX opset version to use for export.
input_names (List[str]): List of input tensor names.
output_names (List[str]): List of output tensor names.
dynamic (bool | Dict, optional): Whether to enable dynamic axes.
Notes:
Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
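
    Examples:
        A minimal sketch with a stand-in module (any eval-mode torch.nn.Module and a matching dummy input work):

        >>> import torch
        >>> model = torch.nn.Conv2d(3, 16, 3).eval()  # stand-in for a real detection model
        >>> im = torch.zeros(1, 3, 640, 640)
        >>> export_onnx(model, im, "model.onnx", opset=14, dynamic={"images": {0: "batch"}, "output0": {0: "batch"}})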
"""
torch.onnx.export(
torch_model,
im,
onnx_file,
verbose=False,
opset_version=opset,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic or None,
)
def export_engine(
onnx_file: str,
engine_file: Optional[str] = None,
workspace: Optional[int] = None,
half: bool = False,
int8: bool = False,
dynamic: bool = False,
shape: Tuple[int, int, int, int] = (1, 3, 640, 640),
dla: Optional[int] = None,
dataset=None,
metadata: Optional[Dict] = None,
verbose: bool = False,
prefix: str = "",
) -> None:
"""
Export a YOLO model to TensorRT engine format.
Args:
onnx_file (str): Path to the ONNX file to be converted.
engine_file (str, optional): Path to save the generated TensorRT engine file.
workspace (int, optional): Workspace size in GB for TensorRT.
half (bool, optional): Enable FP16 precision.
int8 (bool, optional): Enable INT8 precision.
dynamic (bool, optional): Enable dynamic input shapes.
shape (Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
dla (int, optional): DLA core to use (Jetson devices only).
dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
metadata (Dict, optional): Metadata to include in the engine file.
verbose (bool, optional): Enable verbose logging.
prefix (str, optional): Prefix for log messages.
Raises:
ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
RuntimeError: If the ONNX file cannot be parsed.
Notes:
TensorRT version compatibility is handled for workspace size and engine building.
INT8 calibration requires a dataset and generates a calibration cache.
Metadata is serialized and written to the engine file if provided.
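
    Examples:
        A hedged usage sketch (assumes "model.onnx" already exists and the tensorrt Python package is installed):

        >>> export_engine("model.onnx", engine_file="model.engine", half=True, workspace=4)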
"""
import tensorrt as trt # noqa
engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE
# Engine builder
builder = trt.Builder(logger)
config = builder.create_builder_config()
workspace = int((workspace or 0) * (1 << 30))
is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
if is_trt10 and workspace > 0:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
elif workspace > 0: # TensorRT versions 7, 8
config.max_workspace_size = workspace
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
half = builder.platform_has_fast_fp16 and half
int8 = builder.platform_has_fast_int8 and int8
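    # Note: the two assignments above clear half/int8 when the platform lacks fast FP16/INT8
    # kernels, so the built engine's precision may differ from what was requested.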
    # Optionally switch to DLA if enabled
    if dla is not None:
        if not IS_JETSON:
            raise ValueError("DLA is only available on NVIDIA Jetson devices")
        LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
        if not half and not int8:
            raise ValueError(
                "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
            )
        config.default_device_type = trt.DeviceType.DLA
        config.DLA_core = int(dla)
        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)

    # Read ONNX file
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_file):
        raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
    # Network inputs
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape {inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape {out.shape} {out.dtype}')

    if dynamic:
        if shape[0] <= 1:
            LOGGER.warning(f"{prefix} 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
        profile = builder.create_optimization_profile()
        min_shape = (1, shape[1], 32, 32)  # minimum input shape
        max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:]))  # max input shape; spatial dims scale with workspace (GiB)
        for inp in inputs:
            profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
        config.add_optimization_profile(profile)
        if int8:
            config.set_calibration_profile(profile)
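            # TensorRT calibrates dynamic-shape engines at this profile's opt dimensions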
LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
if int8:
config.set_flag(trt.BuilderFlag.INT8)
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

        class EngineCalibrator(trt.IInt8Calibrator):
            """
            Custom INT8 calibrator for TensorRT engine optimization.

            This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
            using a dataset. It handles batch generation, caching, and calibration algorithm selection.

            Attributes:
                dataset: Dataset for calibration.
                data_iter: Iterator over the calibration dataset.
                algo (trt.CalibrationAlgoType): Calibration algorithm type.
                batch (int): Batch size for calibration.
                cache (Path): Path to save the calibration cache.

            Methods:
                get_algorithm: Get the calibration algorithm to use.
                get_batch_size: Get the batch size to use for calibration.
                get_batch: Get the next batch to use for calibration.
                read_calibration_cache: Use existing cache instead of calibrating again.
                write_calibration_cache: Write calibration cache to disk.
            """

            def __init__(
                self,
                dataset,  # ultralytics.data.build.InfiniteDataLoader
                cache: str = "",
            ) -> None:
                """Initialize the INT8 calibrator with dataset and cache path."""
                trt.IInt8Calibrator.__init__(self)
                self.dataset = dataset
                self.data_iter = iter(dataset)
                self.algo = (
                    trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2  # DLA quantization needs ENTROPY_CALIBRATION_2
                    if dla is not None
                    else trt.CalibrationAlgoType.MINMAX_CALIBRATION
                )
                self.batch = dataset.batch_size
                self.cache = Path(cache)

            def get_algorithm(self) -> trt.CalibrationAlgoType:
                """Get the calibration algorithm to use."""
                return self.algo

            def get_batch_size(self) -> int:
                """Get the batch size to use for calibration."""
                return self.batch or 1

            def get_batch(self, names) -> Optional[List[int]]:
                """Get the next batch to use for calibration, as a list of device memory pointers."""
                try:
                    im0s = next(self.data_iter)["img"] / 255.0  # scale 0-255 uint8 images to 0.0-1.0 float
                    im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
                    return [int(im0s.data_ptr())]
                except StopIteration:
                    # Return None to signal to TensorRT there is no calibration data remaining
                    return None

            def read_calibration_cache(self) -> Optional[bytes]:
                """Use an existing cache instead of calibrating again; otherwise implicitly return None."""
                if self.cache.exists() and self.cache.suffix == ".cache":
                    return self.cache.read_bytes()

            def write_calibration_cache(self, cache: bytes) -> None:
                """Write calibration cache to disk."""
                _ = self.cache.write_bytes(cache)

        # Load dataset w/ builder (for batching) and calibrate
        config.int8_calibrator = EngineCalibrator(
            dataset=dataset,
            cache=str(Path(onnx_file).with_suffix(".cache")),
        )
    elif half:
        config.set_flag(trt.BuilderFlag.FP16)

    # Write file
    build = builder.build_serialized_network if is_trt10 else builder.build_engine
    with build(network, config) as engine, open(engine_file, "wb") as t:
        # Metadata
        if metadata is not None:
            meta = json.dumps(metadata)
            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
            t.write(meta.encode())
        # Model
        t.write(engine if is_trt10 else engine.serialize())
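

if __name__ == "__main__":
    # Usage sketch, not part of the library API (assumes torch is installed; the engine step
    # additionally requires the tensorrt package and a supported NVIDIA GPU).
    demo_model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.SiLU()).eval()
    demo_im = torch.zeros(1, 3, 640, 640)
    export_onnx(demo_model, demo_im, "demo.onnx")
    LOGGER.info("Exported demo.onnx")
    # export_engine("demo.onnx", half=True, workspace=2)  # uncomment on a TensorRT-capable machine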