I think I found an error in class ONNXGraphNode(NNGraphNode), in the function get_spatial_unflatten_reshape_info, which computes spatial_reshape_sizes.
Old function:
def get_spatial_unflatten_reshape_info(self):
    if self.op != "Reshape":
        consumed_vertices = [look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])]
    else:
        transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        consumed_vertices = [self, transpose] if transpose else [self]
    if len(consumed_vertices) < 1:
        raise UnexpectedNodeError(f"Failed to find reshape node in format conversion layer near {self.name}.")
        print(f"Failed near {self.name}.")
    output_shape = consumed_vertices[-1].get_output_shapes()[0]
    print(output_shape)
    spatial_reshape_sizes = [output_shape[1], output_shape[2]]
    return consumed_vertices, [output_shape], spatial_reshape_sizes
Changed function:
def get_spatial_unflatten_reshape_info(self):
    if self.op == "Reshape":
        consumed_vertices = [look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])]
    else:
        transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        consumed_vertices = [self, transpose] if transpose else [self]
    if len(consumed_vertices) < 1:
        raise UnexpectedNodeError(f"Failed to find reshape node in format conversion layer near {self.name}.")
        print(f"Failed near {self.name}.")
    output_shape = consumed_vertices[-1].get_output_shapes()[0]
    print(output_shape)
    spatial_reshape_sizes = [output_shape[1], output_shape[2]]
    return consumed_vertices, [output_shape], spatial_reshape_sizes
I changed the condition
if self.op != "Reshape":
to
if self.op == "Reshape":
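To make the effect of the swapped condition concrete, here is a minimal sketch that traces which vertices each version would consume on a toy Conv -> Reshape -> Transpose chain. ToyNode, toy_look_for_node and trace are hypothetical stand-ins, not the real converter classes. The sketch assumes that look_for_node follows successors one step per FwdChainNode and returns None when nothing matches. It only traces the branch logic; it does not show which version is correct for a real graph.

```python
class ToyNode:
    """Hypothetical stand-in for a graph node: just a name, an op and one successor."""

    def __init__(self, name, op):
        self.name = name
        self.op = op
        self.successor = None


def toy_look_for_node(start, ops):
    """Hypothetical simplification of look_for_node: each next successor must
    match the next op in `ops`; return the last matched node, or None."""
    node = start
    for op in ops:
        node = node.successor
        if node is None or node.op != op:
            return None
    return node


def trace(node, use_changed_condition):
    """Return the names of the vertices the branch logic would consume."""
    # Old code branches on `!=`, the changed code on `==`.
    if use_changed_condition:
        search_for_reshape = node.op == "Reshape"
    else:
        search_for_reshape = node.op != "Reshape"
    if search_for_reshape:
        vertices = [toy_look_for_node(node, ["Reshape"])]
    else:
        transpose = toy_look_for_node(node, ["Transpose"])
        vertices = [node, transpose] if transpose else [node]
    return [v.name if v is not None else None for v in vertices]


# Toy chain: Conv -> Reshape -> Transpose
conv = ToyNode("conv", "Conv")
reshape = ToyNode("reshape", "Reshape")
transpose = ToyNode("transpose", "Transpose")
conv.successor = reshape
reshape.successor = transpose

for start in (conv, reshape):
    for changed in (False, True):
        label = "changed" if changed else "old"
        print(start.name, label, trace(start, changed))
```

In this toy chain, starting from the Reshape node, the changed condition yields [None]. Because len([None]) is 1, the len(consumed_vertices) < 1 check in either version would not catch that case.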
Now it compiles. Maybe I am using an old version of ONNX.
Please tell me if I am completely wrong about this.