ONNX-to-HEF conversion fails with "Invalid kernel shape … Kernel features: 3, Input features: 11, Groups: 3" — Questions

Hello, I am reporting an error when converting ONNX to HEF. Could you please help me take a look? The relevant code is as follows:
Convert onnx code:

import argparse
import os
import sys
import torch
import torch.nn as nn

# Make the repository root importable so the `projects.*` plugin package
# resolves when this script is run directly (the file sits three levels deep).
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

from projects.mmdet3d_plugin.models.detection3d.detection3d_blocks import SparseBox3DEncoder


class LastDimBatchNorm(nn.BatchNorm1d):
    """BatchNorm1d applied over the last dimension of an arbitrary-rank tensor.

    All leading dimensions are flattened into a single batch axis, the
    standard 1-D batch norm runs over the feature axis, and the original
    shape is restored afterwards.
    """

    def forward(self, x):
        original_shape = x.shape
        flat = x.reshape(-1, original_shape[-1])
        normalized = super().forward(flat)
        return normalized.view(original_shape)


class LayerNormViaLinear(nn.Module):
    """LayerNorm re-expressed through matmul reductions (export-friendly).

    The mean and second moment over the last dimension are computed by
    multiplying with all-ones column vectors, so the exported graph contains
    only MatMul / elementwise ops instead of reduction ops.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, (list, tuple)):
            num_features = int(normalized_shape[0])
        else:
            num_features = int(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # All-ones columns: x @ w_sum sums over the feature axis.
        self.w_sum = nn.Parameter(torch.ones(num_features, 1), requires_grad=False)
        self.w_sumsq = nn.Parameter(torch.ones(num_features, 1), requires_grad=False)
        self.register_buffer("invC", torch.tensor(1.0 / num_features))
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        # E[x] and E[x^2] via matmuls with the ones vectors; Var = E[x^2] - E[x]^2.
        mean = torch.matmul(x, self.w_sum) * self.invC
        second_moment = torch.matmul(x * x, self.w_sumsq) * self.invC
        var = second_moment - mean * mean
        normalized = (x - mean) * torch.rsqrt(var + self.eps)
        if self.elementwise_affine:
            normalized = normalized * self.weight + self.bias
        return normalized


def collect_layernorm_stats(model: nn.Module, data_list):
    """Run calibration batches through *model* and gather LayerNorm input stats.

    For every ``nn.LayerNorm`` inside *model*, a forward hook accumulates the
    mean and biased variance of the layer's input over dims (0, 1) for each
    batch in *data_list*.

    Returns:
        dict mapping each LayerNorm module to ``(avg_mean, avg_var)``,
        averaged over all calibration batches.
    """
    accumulated = {}

    def record(mod, inputs, _out):
        tensor = inputs[0]
        batch_mean = tensor.mean(dim=(0, 1)).detach()
        batch_var = tensor.var(dim=(0, 1), unbiased=False).detach()
        entry = accumulated.get(mod)
        if entry is None:
            accumulated[mod] = {"mean": batch_mean, "var": batch_var, "count": 1}
        else:
            entry["mean"] += batch_mean
            entry["var"] += batch_var
            entry["count"] += 1

    hooks = [
        m.register_forward_hook(record)
        for m in model.modules()
        if isinstance(m, nn.LayerNorm)
    ]

    model.eval()
    with torch.no_grad():
        for batch in data_list:
            model(batch)

    for handle in hooks:
        handle.remove()

    return {
        mod: (entry["mean"] / entry["count"], entry["var"] / entry["count"])
        for mod, entry in accumulated.items()
    }


def patch_norm(module: nn.Module, mode: str, stats: dict = None):
    """Recursively replace every ``nn.LayerNorm`` inside *module*.

    mode:
      - "batchnorm": swap in ``LastDimBatchNorm``, copying affine parameters
        and, when *stats* (from ``collect_layernorm_stats``) contains the
        layer, its calibrated running mean/var; otherwise identity stats.
      - "linearln":  swap in ``LayerNormViaLinear`` with the same eps/affine.
      - "none":      replace with ``nn.Identity``.
    Any other mode leaves LayerNorms untouched.
    """
    for child_name, child in list(module.named_children()):
        if not isinstance(child, nn.LayerNorm):
            patch_norm(child, mode, stats)
            continue

        if mode == "batchnorm":
            replacement = LastDimBatchNorm(
                child.normalized_shape[-1], affine=True, track_running_stats=True, momentum=0.0
            )
            replacement.eps = child.eps
            with torch.no_grad():
                if child.elementwise_affine:
                    replacement.weight.copy_(child.weight)
                    replacement.bias.copy_(child.bias)
                if stats is not None and child in stats:
                    calib_mean, calib_var = stats[child]
                    replacement.running_mean.copy_(calib_mean)
                    replacement.running_var.copy_(calib_var)
                else:
                    replacement.running_mean.zero_()
                    replacement.running_var.fill_(1.0)
            setattr(module, child_name, replacement)
        elif mode == "linearln":
            replacement = LayerNormViaLinear(
                child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine
            )
            if child.elementwise_affine:
                with torch.no_grad():
                    replacement.weight.copy_(child.weight)
                    replacement.bias.copy_(child.bias)
            setattr(module, child_name, replacement)
        elif mode == "none":
            setattr(module, child_name, nn.Identity())


def load_checkpoint_into_module(module: torch.nn.Module, ckpt_path: str):
    """Load anchor-encoder weights from a (possibly wrapped) checkpoint.

    Accepts either a raw state dict or a dict carrying a ``state_dict``
    entry, strips the first matching known prefix from each key, and loads
    the result non-strictly, printing any missing/unexpected keys.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        raw_state = checkpoint["state_dict"]
    else:
        raw_state = checkpoint

    known_prefixes = ["head.anchor_encoder.", "module.head.anchor_encoder.", "anchor_encoder."]
    filtered = {}
    for key, tensor in raw_state.items():
        matched = next((p for p in known_prefixes if key.startswith(p)), None)
        if matched is not None:
            filtered[key[len(matched):]] = tensor

    missing, unexpected = module.load_state_dict(filtered, strict=False)
    print(f"Loaded checkpoint: {ckpt_path}")
    if missing:
        print("Missing keys in encoder:", missing)
    if unexpected:
        print("Unexpected keys ignored:", unexpected)


class EncoderExportWrapper(torch.nn.Module):
    """Single-input export wrapper around the anchor encoder.

    The full anchor tensor is split by plain slicing
    (pos 0:3 | size 3:6 | yaw 6:8 | vel 8:), each slice runs through its
    branch, and the branch features are combined by add or concat without
    any 4-D reshapes, returning a (B, N, C_total) tensor.
    """

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, anchor):
        vel_start = 8
        # Slicing instead of torch.split keeps the exported graph simple.
        pos_slice = anchor[..., 0:3]
        size_slice = anchor[..., 3:6]
        yaw_slice = anchor[..., 6:vel_start]
        vel_slice = anchor[..., vel_start:]

        branch_features = [
            self.encoder.pos_fc(pos_slice),
            self.encoder.size_fc(size_slice),
            self.encoder.yaw_fc(yaw_slice),
        ]
        if self.encoder.vel_dims > 0:
            branch_features.append(self.encoder.vel_fc(vel_slice))

        if self.encoder.mode == "add":
            # Elementwise sum over the feature dimension.
            output = branch_features[0]
            for feature in branch_features[1:]:
                output = output + feature
        else:
            # "cat" mode: concatenate along the feature dimension.
            output = torch.cat(branch_features, dim=-1)

        if self.encoder.output_fc is not None:
            output = self.encoder.output_fc(output)

        # No trailing reshape — the (B, N, C_total) tensor is returned as-is.
        return output


class EncoderMultiInputWrapper(torch.nn.Module):
    """Multi-input export wrapper: pos/size/yaw(/vel) arrive already split.

    Each component is encoded by its own branch and the branch outputs are
    combined by add or concat (no 4-D reshapes, no trailing reshape).
    """

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, pos, size, yaw, vel=None):
        branch_outputs = [
            self.encoder.pos_fc(pos),
            self.encoder.size_fc(size),
            self.encoder.yaw_fc(yaw),
        ]
        # vel is optional; the branch is skipped when absent or disabled.
        if vel is not None and self.encoder.vel_dims > 0:
            branch_outputs.append(self.encoder.vel_fc(vel))

        if self.encoder.mode == "add":
            combined = branch_outputs[0]
            for feature in branch_outputs[1:]:
                combined = combined + feature
        else:
            combined = torch.cat(branch_outputs, dim=-1)

        if self.encoder.output_fc is not None:
            combined = self.encoder.output_fc(combined)
        return combined


def parse_args(argv=None):
    """Parse command-line options for the encoder ONNX export.

    Args:
        argv: optional list of argument strings; ``None`` means use
            ``sys.argv[1:]`` (passing a list makes this testable without
            touching sys.argv).  Backward compatible with the zero-arg call.
    """
    p = argparse.ArgumentParser()
    # FIX: the original used full-width smart quotes (default=“”), which is
    # a Python SyntaxError; replaced with an ASCII empty string.
    p.add_argument("--ckpt", default="")
    p.add_argument("--out", default="")
    p.add_argument("--bs", type=int, default=6)
    p.add_argument("--num_anchor", type=int, default=900)
    p.add_argument("--state_dim", type=int, default=11)
    p.add_argument("--opset", type=int, default=13)
    p.add_argument("--no_decouple", action="store_true")
    p.add_argument("--norm", choices=["layernorm", "batchnorm", "none", "linearln"], default="layernorm")
    p.add_argument("--calib_batches", type=int, default=0)
    p.add_argument("--no_simplify", action="store_true")
    p.add_argument("--multi_input", action="store_true", help="Export 4 inputs (pos/size/yaw/vel)")
    return p.parse_args(argv)


def main():
    """Export the SparseBox3DEncoder anchor encoder to ONNX.

    Pipeline: parse CLI args -> build encoder (config depends on
    decouple_attn) -> load checkpoint weights -> optionally replace
    LayerNorms (batchnorm/linearln/none) -> wrap for single- or
    multi-input export -> torch.onnx.export -> optional onnx-simplifier.
    """
    args = parse_args()
    decouple_attn = not args.no_decouple
    # NOTE(review): fails when --out has no directory component
    # (os.path.dirname("") == "" and makedirs("") raises) — confirm intended.
    os.makedirs(os.path.dirname(args.out), exist_ok=True)
    print(f"Configuring Encoder with decouple_attn={decouple_attn}...")

    # Decoupled attention: per-branch embed dims combined by concat;
    # otherwise a single 256-dim embedding combined by addition.
    if decouple_attn:
        embed_dims_arg, mode_arg, output_fc_arg, out_loops_arg = [128, 32, 32, 64], "cat", False, 4
    else:
        embed_dims_arg, mode_arg, output_fc_arg, out_loops_arg = 256, "add", True, 2

    encoder = SparseBox3DEncoder(
        vel_dims=3,
        embed_dims=embed_dims_arg,
        mode=mode_arg,
        output_fc=output_fc_arg,
        in_loops=1,
        out_loops=out_loops_arg,
    )

    # A valid checkpoint is mandatory — abort rather than export random weights.
    if os.path.exists(args.ckpt):
        load_checkpoint_into_module(encoder, args.ckpt)
    else:
        print("Error: checkpoint not found at", args.ckpt)
        return

    # Optionally swap LayerNorms for export-friendlier norms; batchnorm mode
    # can first calibrate running stats on random anchor batches.
    if args.norm != "layernorm":
        stats = None
        if args.norm == "batchnorm" and args.calib_batches > 0:
            calib_data = [
                torch.randn(args.bs, args.num_anchor, args.state_dim, dtype=torch.float32)
                for _ in range(args.calib_batches)
            ]
            stats = collect_layernorm_stats(encoder, calib_data)
        print(f"Applying norm override: {args.norm}")
        patch_norm(encoder, args.norm, stats)

    wrapper = EncoderMultiInputWrapper(encoder) if args.multi_input else EncoderExportWrapper(encoder)
    wrapper = wrapper.eval()

    # Dummy inputs follow the anchor layout: pos[0:3] | size[3:6] | yaw[6:8] | vel[8:11].
    dummy_anchor = torch.randn(args.bs, args.num_anchor, args.state_dim, dtype=torch.float32)
    dummy_pos = dummy_anchor[..., :3]
    dummy_size = dummy_anchor[..., 3:6]
    dummy_yaw = dummy_anchor[..., 6:8]
    dummy_vel = dummy_anchor[..., 8 : 8 + 3] if encoder.vel_dims > 0 else None
    export_path = os.path.abspath(args.out)
    print(f"Exporting encoder to ONNX -> {export_path}")

    with torch.no_grad():
        if args.multi_input:
            inputs = (dummy_pos, dummy_size, dummy_yaw, dummy_vel)
            input_names = ["pos", "size", "yaw", "vel"]
        else:
            inputs = (dummy_anchor,)
            input_names = ["anchor"]

        # dynamic_axes=None keeps every shape static in the exported graph.
        torch.onnx.export(
            wrapper,
            inputs,
            export_path,
            opset_version=args.opset,
            do_constant_folding=True,
            input_names=input_names,
            output_names=["anchor_embed"],
            dynamic_axes=None,
            verbose=True,
        )

    print("ONNX export finished.")
    # Best-effort simplification pass; any failure leaves the raw export intact.
    if not args.no_simplify:
        try:
            import onnx
            from onnxsim import simplify

            print("Trying to simplify ONNX model with onnx-simplifier...")
            model = onnx.load(export_path)
            model_simp, check = simplify(model)
            if check:
                onnx.save(model_simp, export_path)
                print("Simplified and overwritten ONNX model.")
            else:
                print("onnx-simplifier check failed; keeping original export.")
        except ImportError:
            print("onnx-simplifier module not found. Skipping simplification.")
        except Exception as e:
            print("onnx-simplifier failed:", e)
    else:
        print("Skipped onnx-simplifier as requested.")

    print("Done.")


# Script entry point: run the export pipeline.
if __name__ == "__main__":
    main()

onnx to hef:

import os
import sys
import logging
import numpy as np
import onnx
import onnx.checker
from hailo_sdk_client import ClientRunner

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger("onnx2hef_encoder")

# --- Configuration ---
ONNX_PATH = "1.onnx"
# Single "anchor" input; avoids the KeyError seen with multi-input calibration.
USE_MULTI_INPUT = False
B, NUM_ANCHOR, STATE_DIM = 6, 900, 11

# Per-component widths of the 11-dim anchor state (pos | size | yaw | vel),
# matching the export script's slicing 0:3 / 3:6 / 6:8 / 8:11.
# FIX: main() reads STATE_SPLIT for calibration shapes, but this dict was
# never defined anywhere in the script — a guaranteed NameError at runtime.
STATE_SPLIT = {"pos": 3, "size": 3, "yaw": 2, "vel": 3}

OUTDIR = "onnx_models/hailo_outputs_encoder"
NET_NAME = "box3d_encoder"
HW_ARCH = "hailo8"

HAR_FLOAT = os.path.join(OUTDIR, f"{NET_NAME}_hailo_model.har")
HAR_QUANT = os.path.join(OUTDIR, f"{NET_NAME}_hailo_quantized_model.har")
HEF_PATH = os.path.join(OUTDIR, f"{NET_NAME}.hef")

CALIB_SAMPLES = 32
RNG_SEED = 42

def verify_onnx(path: str, input_names):
    """Sanity-check the exported ONNX file before handing it to Hailo.

    Verifies the file exists, passes ``onnx.checker``, and logs a warning
    for every expected input name missing from the graph.

    Raises:
        FileNotFoundError: if *path* does not exist.  (Was an ``assert``,
            which is silently stripped when Python runs with -O.)
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"ONNX 不存在: {path}")
    m = onnx.load(path)
    onnx.checker.check_model(m)
    inputs = [i.name for i in m.graph.input]
    log.info(f"ONNX 输入: {inputs}")
    for name in input_names:
        if name not in inputs:
            log.warning(f"期望输入名 '{name}',实际 {inputs}。")

def main():
    """Run the full ONNX -> HAR -> optimize (quantize) -> HEF pipeline.

    NOTE(review): this function reads a module-level STATE_SPLIT dict for
    the calibration array shapes — make sure it is defined in the config
    section (per-component feature widths), otherwise this raises NameError.
    """
    os.makedirs(OUTDIR, exist_ok=True)
    input_names = ["anchor"]
    # Declared input shape must match the ONNX graph's actual input.
    net_input_shapes = {"anchor": (B, NUM_ANCHOR, STATE_DIM)}

    verify_onnx(ONNX_PATH, input_names)

    # Parse the ONNX into Hailo's internal representation (float HAR).
    runner = ClientRunner(hw_arch=HW_ARCH)
    runner.translate_onnx_model(
        model=ONNX_PATH,
        net_name=NET_NAME,
        net_input_shapes=net_input_shapes,
        start_node_names=input_names,
    )
    runner.save_har(HAR_FLOAT)
    log.info(f"HAR(float) 已保存: {HAR_FLOAT}")

    # Optional model script: optimization flavor and resource-utilization caps.
    script_lines = [
        "model_optimization_flavor(optimization_level=2, compression_level=0)",
        "resources_param(max_control_utilization=0.6, max_compute_utilization=0.6, max_memory_utilization=0.6)",
    ]
    try:
        runner.load_model_script("\n".join(script_lines))
        log.info("model script 已加载。")
    except Exception as e:
        log.warning(f"model script 加载失败,可忽略/按需调整: {e}")

    np.random.seed(RNG_SEED)
    if USE_MULTI_INPUT:
        # Multi-input calibration: dict of arrays, first axis is the number
        # of calibration samples, keys aligned with start_node_names.
        calib = {
            "pos": np.random.randn(CALIB_SAMPLES, B, NUM_ANCHOR, STATE_SPLIT["pos"]).astype(np.float32),
            "size": np.random.randn(CALIB_SAMPLES, B, NUM_ANCHOR, STATE_SPLIT["size"]).astype(np.float32),
            "yaw": np.random.randn(CALIB_SAMPLES, B, NUM_ANCHOR, STATE_SPLIT["yaw"]).astype(np.float32),
            "vel": np.random.randn(CALIB_SAMPLES, B, NUM_ANCHOR, STATE_SPLIT["vel"]).astype(np.float32),
        }
        log.info("校准数据: multi-input random anchors (dict of np_array)")
    else:
        # Single-input calibration: one array covering the full state vector.
        calib = np.random.randn(CALIB_SAMPLES, B, NUM_ANCHOR, sum(STATE_SPLIT.values())).astype(np.float32)
        log.info(f"校准数据: shape={calib.shape}, dtype={calib.dtype}")

    # Pass the calibration data type explicitly.
    runner.optimize(calib, data_type="np_array")
    log.info("Optimize 完成(生成量化权重)。")

    # Saving the quantized HAR is best-effort; some SDK versions cannot
    # write quantization info back.
    try:
        runner.save_har(HAR_QUANT)
        log.info(f"HAR(quantized) 已保存: {HAR_QUANT}")
    except Exception as e:
        log.info(f"保存量化 HAR 失败(可能 SDK 不支持写回量化信息): {e}")

    # compile() may return a HEF object with .save(), raw bytes, or the SDK
    # may expose save_hef() instead — handle all three.
    try:
        hef = runner.compile()
        try:
            hef.save(HEF_PATH)
        except Exception:
            with open(HEF_PATH, "wb") as f:
                f.write(hef)
        log.info(f"HEF 已导出: {HEF_PATH}")
    except AttributeError:
        runner.save_hef(HEF_PATH)
        log.info(f"HEF 已导出(save_hef): {HEF_PATH}")

    log.info("完成 ONNX -> HAR -> Optimize(量化) -> HEF 全流程。")

# Script entry point: run the conversion pipeline.
if __name__ == "__main__":
    main()

error:

[info] Simplified ONNX model for a parsing retry attempt (completion time: 00:00:00.51)
Traceback (most recent call last):
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/sdk_backend/parser/parser.py", line 179, in translate_onnx_model
    parsing_results = self._parse_onnx_model_to_hn(onnx_model, valid_net_name, start_node_names,
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/sdk_backend/parser/parser.py", line 237, in _parse_onnx_model_to_hn
    return self.parse_model_to_hn(onnx_model, None, net_name, start_node_names, end_node_names,
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/sdk_backend/parser/parser.py", line 263, in parse_model_to_hn
    fuser = HailoNNFuser(converter.convert_model(), net_name, converter.end_node_names)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/translator.py", line 76, in convert_model
    self._create_layers()
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/edge_nn_translator.py", line 28, in _create_layers
    self._add_direct_layers()
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/edge_nn_translator.py", line 121, in _add_direct_layers
    node.update_output_format(vertex)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/onnx_translator/onnx_graph.py", line 220, in update_output_format
    elif self.op == 'MatMul' and self.is_matmul_layer():
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/onnx_translator/onnx_graph.py", line 2948, in is_matmul_layer
    kernel, _ = self.get_kernel()
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/onnx_translator/onnx_graph.py", line 443, in get_kernel
    const_shape = pred._info.attribute[0].t.dims
IndexError: list index (0) out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "projects/mmdet3d_plugin/tools/head/hailloe.py", line 105, in <module>
    main()
  File "projects/mmdet3d_plugin/tools/head/hailloe.py", line 47, in main
    runner.translate_onnx_model(
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_common/states/states.py", line 16, in wrapped_func
    return func(self, *args, **kwargs)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/runner/client_runner.py", line 876, in translate_onnx_model
    parser.translate_onnx_model(model=model, net_name=net_name, start_node_names=start_node_names,
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/sdk_backend/parser/parser.py", line 211, in translate_onnx_model
    parsing_results = self._parse_onnx_model_to_hn(simplified_model, valid_net_name,
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/sdk_backend/parser/parser.py", line 237, in _parse_onnx_model_to_hn
    return self.parse_model_to_hn(onnx_model, None, net_name, start_node_names, end_node_names,
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/sdk_backend/parser/parser.py", line 263, in parse_model_to_hn
    fuser = HailoNNFuser(converter.convert_model(), net_name, converter.end_node_names)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/translator.py", line 85, in convert_model
    self._calculate_shapes(validate_shapes=False)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_client/model_translator/onnx_translator/onnx_translator.py", line 110, in _calculate_shapes
    self._layers_graph.calculate_shapes(meta_edges_graph=self._meta_graph, validate_shapes=validate_shapes)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_common/hailo_nn/hailo_nn.py", line 627, in calculate_shapes
    self.update_input_shapes_from_predecessors(layer)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_common/hailo_nn/hailo_nn.py", line 668, in update_input_shapes_from_predecessors
    layer.input_shapes = input_shapes
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_common/hailo_nn/hn_layers/layer.py", line 505, in input_shapes
    self.set_input_shapes(input_shapes)
  File "/home/ubuntu/.conda/envs/sparse4d_new/lib/python3.8/site-packages/hailo_sdk_common/hailo_nn/hn_layers/conv2d.py", line 511, in set_input_shapes
    raise UnsupportedModelError(
hailo_sdk_common.hailo_nn.exceptions.UnsupportedModelError: Invalid kernel shape for base conv layer base_conv1 (translated from /pos_fc/pos_fc.0/Add).
Either the input shape doesn't match the kernel shape, or the calculated groups number doesn't match the expected ratio between kernel shape and input shape.
Kernel features: 3 Input features: 11 Groups: 3

Hey @chenyao,

The error you’re seeing is coming from Hailo’s parser during shape inference — not from your Python code. The important part of the message is:

Invalid kernel shape for base conv layer base_conv1 (translated from /pos_fc/pos_fc.0/Add). Either the input shape doesn't match the kernel shape, or the calculated groups number doesn't match the expected ratio between kernel shape and input shape. Kernel features: 3 Input features: 11 Groups: 3

Basically, Hailo thinks this layer (which came from your pos_fc/pos_fc.0/Add node) looks like a convolution, but it’s getting an input with 11 channels while the kernel expects 3 channels with 3 groups. That mismatch makes the parser reject the model.

Here’s what I’d check:

First, make sure the ONNX model itself is valid:

python -c "import onnx; onnx.checker.check_model(onnx.load('1.onnx'))"

Then open 1.onnx in Netron and inspect the /pos_fc/pos_fc.0/Add node. Look at:

  • What shape the input tensor actually has at that point
  • The shapes of the weights/biases
  • Whether a reshape/transpose right before this node is changing the layout in a way you didn’t expect

In your code you have:

pos  = anchor[..., 0:3]
size = anchor[..., 3:6]
yaw  = anchor[..., 6:8]
vel  = anchor[..., 8:]

So STATE_DIM should indeed be 11. But since the error shows a kernel expecting only 3 features, it looks like pos_fc might be getting the full 11-dim vector in the ONNX graph instead of just the 3D slice you intended. Netron will tell you immediately whether the slicing is being exported correctly.

Also double-check that your declared input shape matches what’s actually inside the ONNX:

net_input_shapes = {"anchor": (B, NUM_ANCHOR, STATE_DIM)}  # e.g., (6, 900, 11)

If the ONNX input got transposed to something like (B, STATE_DIM, NUM_ANCHOR), that could also trigger this kind of mismatch.

If everything seems correct but Hailo still complains, try passing the model through onnx-simplifier and recompile — it often cleans up graph patterns that Hailo struggles with.

Hope this gives you a good direction! Let me know what you find when you inspect the ONNX graph. You can post the picture of the problem node from Netron and we can help you better!