Is it possible to convert models with a JAX architecture?
I tried to convert my custom model to HAR, and after running this:
hn, npz = runner.translate_onnx_model(onnx_path, onnx_model_name)
I get this error: KeyError: jax2tf_inference_/pjit_inference_/VQVAE/encoder/res_block1/norm1/Sub:0
Is there a problem with my layer naming, or is this architecture not supported at all?
Thank you!
Hi @ighgul,
We currently do not support the JAX framework in the parser. Did you try exporting your model to ONNX or TFLite? If you exported it to ONNX, did you run onnx-simplifier?
Regards,
Thank you for the answer!
- this is an ONNX model;
- I get an error at the onnx-simplifier step. The last thing I see in the log is:
[info] Simplified ONNX model for a parsing retry attempt (completion time: 00:00:05.28)
@Omer do you have any ideas about this issue and how it can be solved?
As a detail, this is the line the error occurs in:
node_value = node_value[self._info.input[order.index("initializer")]].flatten()[0]
It is located in hailo_sdk_client/model_translator/onnx_translator/onnx_graph.py.
@Omer, before this error occurs, there is this one:
IndexError: list index out of range
During handling of the above exception, another exception occurred:
Maybe that helps to find a solution.
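For context, that traceback pattern is Python's exception chaining: a new exception (here the KeyError) is raised while the IndexError is still being handled, so both appear in the traceback. A minimal, self-contained illustration with hypothetical names, not the actual Hailo code:

```python
def lookup(inputs, idx):
    try:
        return inputs[idx]  # raises IndexError: list index out of range
    except IndexError:
        # Raising a new exception while handling the IndexError produces the
        # "During handling of the above exception, another exception occurred:"
        # line in the traceback; the original error is kept in __context__.
        raise KeyError("initializer")

try:
    lookup(["a", "b"], 5)
except KeyError as exc:
    print(type(exc.__context__).__name__)  # prints "IndexError"
```

So the IndexError is the root cause and the KeyError is a symptom, which matches the order seen in the log.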
Hi @ighgul,
First, I suggest you run onnx-simplifier on the ONNX yourself before sending it to Hailo:
onnxsim YOUR_MODEL.onnx YOUR_MODEL.sim.onnx --skip-fuse-bn
and only then run the Hailo Parser.
Secondly, I believe your model is currently not supported by the Hailo SW, meaning that some layers cannot be parsed. The Parser should, however, recommend start/end nodes, and if that doesn't happen it might be because of a bug.
If you still get an error after simplifying the ONNX yourself, please open a ticket with this request plus the ONNX you used, so we can take it to R&D and check.
Regards,
@Omer, the out-of-range error was solved with your recommendation, but the KeyError is still there.
Here are my ONNX file and a script to run it:
import onnxruntime as ort
import numpy as np
import os
import random
from PIL import Image

def save_output(orig_image, output_image, name):
    # Combine original and output images vertically
    #output_image = Image.fromarray(output_image)
    combined_image = Image.new('RGB', (max(orig_image.width, output_image.width), orig_image.height + output_image.height))
    combined_image.paste(orig_image, (0, 0))
    combined_image.paste(output_image, (0, orig_image.height))
    combined_image.save(os.path.join('outputs', name))

# Load the ONNX model
model_path = 'checkpoint_step_0300.onnx'
session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
input_name = session.get_inputs()[0].name

# Prepare the input data
image_size = (1280, 720)
batch_size = 1
for i in range(10):
    print("Processing batch", i)
    original_data_path = os.path.join('data', 'original')
    images_names = [name for name in os.listdir(original_data_path) if not name.startswith('.')]
    sample_batch = sorted(random.sample(images_names, batch_size))
    images = [Image.open(os.path.join(original_data_path, name)).resize(image_size, Image.LANCZOS) for name in sample_batch]
    buffer = np.stack([np.array(img) for img in images]).astype(np.uint8)
    input_data = {"x": buffer}
    # Run the model
    output = session.run(None, input_data)
    for j, out in enumerate(output):  # use j so the batch index i is not shadowed
        print(j, out.shape, out.dtype)
ONNX file here
Hi @ighgul,
We don’t have experience with JAX, so it might be related to JAX itself or to the package you used to create the ONNX file from JAX. The model is full of “artifacts” such as Expand/Reshape sequences.
The parser doesn’t know how to handle these. Also, the fact that the tensor dimensionality deviates from 4D along the way is problematic.
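For readers unfamiliar with the pattern: an ONNX Expand is semantically a broadcast and a Reshape regroups dimensions, and exporters such as jax2tf often emit chains of these whose net effect is a no-op, while temporarily leaving the 4D layout the parser expects. A sketch of the equivalent NumPy operations (illustrative only, not the actual graph):

```python
import numpy as np

# A 4D NHWC activation tensor, e.g. batch=1, 8x8 spatial, 32 channels.
x = np.zeros((1, 8, 8, 32), dtype=np.float32)

# ONNX Expand corresponds to numpy broadcast_to (here a trivial broadcast).
expanded = np.broadcast_to(x, (1, 8, 8, 32))

# An exported Reshape chain can collapse dimensions and then restore them,
# passing through a non-4D intermediate along the way.
reshaped = expanded.reshape(1, 64, 32)    # 3D intermediate shape
restored = reshaped.reshape(1, 8, 8, 32)  # back to the original 4D shape

print(np.array_equal(x, restored))  # prints True: the whole chain is a no-op
```

Even though the chain changes nothing numerically, a parser that assumes a fixed 4D layout cannot always eliminate the intermediate non-4D tensors, which is consistent with the error above.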
@Nadav hey, do you support interpolation/resize (upscale/downscale)? And what kinds of reshape layers can be used in a model?
Since this may change between SW versions, the best answer is to refer to the supported-ops page in the DFC documentation:
Dataflow Compiler v3.28.0 (hailo.ai)