According to the Dataflow User Guide (p. 138), reshape is supported for a tensor going from (batch, height, 1, F) to (batch, height, W′, F′), where F = W′ × F′. However, there seems to be more to it than the guide covers.
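To make the quoted rule concrete, here is a small shape check. `matches_documented_reshape` is a hypothetical helper written for this post. It only encodes the rule as stated on p. 138, reading the four positions in the guide's own (batch, height, width, features) order, and it is not the compiler's actual validation logic.

```python
def matches_documented_reshape(src, dst):
    """Return True if src -> dst has the documented form
    (batch, height, 1, F) -> (batch, height, W', F') with F == W' * F'.

    Hypothetical helper: encodes only the rule quoted from the guide.
    """
    if len(src) != 4 or len(dst) != 4:
        return False
    b, h, w, f = src
    b2, h2, w2, f2 = dst
    # Batch and height must be unchanged, the source width must be 1,
    # and the feature axis must split exactly into W' x F'.
    return b == b2 and h == h2 and w == 1 and f == w2 * f2


# Splitting a 64-wide feature axis into 4 x 16 fits the rule...
print(matches_documented_reshape((1, 8, 1, 64), (1, 8, 4, 16)))   # True
# ...but changing the height does not.
print(matches_documented_reshape((1, 8, 1, 64), (1, 2, 16, 16)))  # False
```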
As I posted earlier, I am still working through the area-attention module, and I suspect the problem is in the reshape layer.
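```python
def forward(self, x):
    """Processes the input tensor 'x' through the area-attention."""
    B, C, H, W = x.shape
    N = H * W
    qk = self.qk(x).flatten(2)    # (B, 2*C, N)
    v = self.v(x)
    pp = self.pe(v)
    v = v.flatten(2)              # (B, C, N)
    if self.area > 1:
        qk = qk.unsqueeze(2)      # (B, 2*C, 1, N)
        print(qk.shape)
        # Split the N tokens into `area` chunks. The hard-coded leading 1
        # assumes B == 1; with B > 1 this reshape mixes channels across batches.
        qk = qk.reshape(1, C * 2, B * self.area, N // self.area)
        print(qk.shape)
        qk = qk.transpose(0, 2)   # (B*area, 2*C, 1, N // area)
        v = v.unsqueeze(2)        # (B, C, 1, N)
        v = v.reshape(1, C, B * self.area, N // self.area)
        v = v.transpose(0, 2)     # (B*area, C, 1, N // area)
        B, _, _, N = qk.shape
        # (remainder of forward() omitted)
```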
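For reference, a minimal sketch that traces only the shape arithmetic of the qk path in `forward()` (pure Python, no tensors are created). `area_split_shapes` is a hypothetical helper, and C = 256, N = 16384, area = 4 are inferred from the printed shapes (512 = 2·C, 16384 = 4 × 4096) rather than taken from the model config.

```python
def area_split_shapes(B, C, N, area):
    """Trace shapes through flatten / unsqueeze / reshape / transpose
    for the qk branch of the area-attention forward() above."""
    assert N % area == 0, "N must be divisible by area"
    shapes = {}
    # The qk conv must output 2*C channels for the later reshape to be valid.
    shapes["flatten(2)"] = (B, 2 * C, N)
    shapes["unsqueeze(2)"] = (B, 2 * C, 1, N)
    shapes["reshape"] = (1, 2 * C, B * area, N // area)
    s = shapes["reshape"]
    # transpose(0, 2) swaps dims 0 and 2.
    shapes["transpose(0, 2)"] = (s[2], s[1], s[0], s[3])
    return shapes


for name, shape in area_split_shapes(1, 256, 16384, 4).items():
    print(name, shape)
# flatten(2) (1, 512, 16384)
# unsqueeze(2) (1, 512, 1, 16384)
# reshape (1, 512, 4, 4096)
# transpose(0, 2) (4, 512, 1, 4096)
```

The 3D shape after `flatten(2)` appears to be the `[1, 512, 16384]` reported for `/Reshape` in the parser output.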
PyTorch prints the expected shapes before and after the reshape:
```
torch.Size([1, 512, 1, 16384])
torch.Size([1, 512, 4, 4096])
```
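One possible explanation, offered as an assumption rather than something the guide states: the guide's (batch, height, W, F) notation looks channels-last, while the shapes printed by PyTorch are NCHW. Read literally, the NCHW shapes fit the documented pattern. If the parser instead works in (batch, height, width, features) order, the same reshape splits the width axis into height × width and leaves the features untouched. The sketch below compares both readings using hypothetical helpers `nchw_to_nhwc` and `matches_documented_reshape`.

```python
def nchw_to_nhwc(shape):
    """Reorder a PyTorch (N, C, H, W) shape to channels-last (N, H, W, C)."""
    n, c, h, w = shape
    return (n, h, w, c)


def matches_documented_reshape(src, dst):
    """(batch, height, 1, F) -> (batch, height, W', F') with F == W' * F'."""
    b, h, w, f = src
    b2, h2, w2, f2 = dst
    return b == b2 and h == h2 and w == 1 and f == w2 * f2


before = (1, 512, 1, 16384)  # as printed by PyTorch (NCHW)
after = (1, 512, 4, 4096)

# Reading the NCHW shapes directly in the guide's positional order: fits.
print(matches_documented_reshape(before, after))  # True

# Assumed channels-last view: (1, 1, 16384, 512) -> (1, 4, 4096, 512),
# i.e. width is split into height x width and features are unchanged.
print(nchw_to_nhwc(before), nchw_to_nhwc(after))
print(matches_documented_reshape(nchw_to_nhwc(before), nchw_to_nhwc(after)))  # False
```

If that assumption holds, the documented case would only cover splitting the feature axis, and this area split may fall outside what the parser handles.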
However, during parsing the format of the reshape comes out as None, which raises `TypeError: object of type 'NoneType' has no len()`. These are the formats I traced for each layer:
```
input          [batch, channels, height, width]
/qk/conv/Conv  [batch, channels, height, width]
/Reshape       [1, 512, 16384] [batch, channels, height, width]
/Reshape       [batch, channels, width]
/Unsqueeze     [batch, channels, groups, width]
/Reshape_2     None
```
Are there additional reshape constraints that the guide does not document? Any help is appreciated.