Hey @chenyao,
This is indeed a common challenge when deploying models to environments with limited operator support. Based on your tensor shapes and use case, I recommend the following:
Analysis of Your Current Setup
From your debug output, I can see you’re working with:
- Feature maps: (1, 6, 256, 64, 176) → flattened to (6, 256, 64, 176)
- Sampling points: (6, 900, 13, 2), in normalized coordinates [-1, 1]
- Output shape: (6, 256, 900, 13)
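To sanity-check these shapes independently of any replacement, the sketch below runs the stock `F.grid_sample` on random placeholder tensors of the same sizes. Only the shapes come from your debug output; the values are made up.

```python
import torch
import torch.nn.functional as F

# Random placeholder tensors with the shapes from the debug output
fm = torch.randn(1, 6, 256, 64, 176)
points_2d = torch.rand(6, 900, 13, 2) * 2 - 1  # normalized (x, y) in [-1, 1)

# Merge the first two dims: (1, 6, C, H, W) -> (6, C, H, W)
x = fm.flatten(end_dim=1)
print(x.shape)  # torch.Size([6, 256, 64, 176])

# grid is (N, Hout, Wout, 2) = (6, 900, 13, 2), so the output is (N, C, Hout, Wout)
out = F.grid_sample(x, points_2d, mode="bilinear", padding_mode="zeros", align_corners=False)
print(out.shape)  # torch.Size([6, 256, 900, 13])
```

The 900 x 13 sampling points simply take the place of the output "image" height and width, which is why they end up as the last two output dimensions.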
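Our primary recommendation is to replace grid_sample with a custom bilinear interpolation built only from basic tensor operations (floor, elementwise arithmetic, comparisons, and gather). For mode='bilinear' with padding_mode='zeros', it should give the same result as grid_sample up to floating-point rounding, while avoiding the GridSample operator that some export and inference toolchains do not support.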
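For reference, this is the per-point computation that grid_sample performs for mode='bilinear' and padding_mode='zeros', and that the replacement has to reproduce. For a normalized sample point $(g_x, g_y)$ on an $H \times W$ feature map $I$, the pixel coordinates with align_corners=False (the PyTorch default) are

$$p_x = \frac{(g_x + 1)\,W - 1}{2}, \qquad p_y = \frac{(g_y + 1)\,H - 1}{2}$$

and with align_corners=True they are

$$p_x = \frac{(g_x + 1)(W - 1)}{2}, \qquad p_y = \frac{(g_y + 1)(H - 1)}{2}$$

With $x_0 = \lfloor p_x \rfloor$, $y_0 = \lfloor p_y \rfloor$, $\alpha = p_x - x_0$ and $\beta = p_y - y_0$, the sampled value is

$$\text{out} = (1-\alpha)(1-\beta)\,v_{00} + (1-\alpha)\,\beta\,v_{01} + \alpha\,(1-\beta)\,v_{10} + \alpha\,\beta\,v_{11}$$

where $v_{ab} = I[y_0 + b,\; x_0 + a]$ if that pixel lies inside the map and $v_{ab} = 0$ otherwise. Each out-of-bounds corner is zeroed on its own; the sample as a whole is not discarded just because one of its four corners falls outside.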
```python
import torch


def _normalize_to_img_coords(grid, H, W, align_corners=False):
    """Convert normalized grid coordinates in [-1, 1] to pixel coordinates."""
    x = grid[..., 0]
    y = grid[..., 1]
    if align_corners:
        # -1 and 1 refer to the centers of the corner pixels
        x = 0.5 * (x + 1) * (W - 1)
        y = 0.5 * (y + 1) * (H - 1)
    else:
        # PyTorch default: -1 and 1 refer to the outer edges of the corner pixels
        x = ((x + 1) * W - 1) * 0.5
        y = ((y + 1) * H - 1) * 0.5
    return x, y


def bilinear_sample_2d(x, grid, align_corners=False):
    """
    Bilinear sampling intended to match
    F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=align_corners)

    Args:
        x: (N, C, H, W) feature tensor
        grid: (N, Hout, Wout, 2) sampling coordinates (x, y) in [-1, 1]
        align_corners: coordinate alignment mode, same meaning as in grid_sample
    Returns:
        (N, C, Hout, Wout) sampled features
    """
    N, C, H, W = x.shape
    Hout, Wout = grid.shape[1], grid.shape[2]

    # Convert to pixel coordinates
    gx, gy = _normalize_to_img_coords(grid, H, W, align_corners)

    # Top-left neighbor and fractional offsets inside the pixel cell
    x0 = torch.floor(gx)
    y0 = torch.floor(gy)
    x1 = x0 + 1
    y1 = y0 + 1
    wx1 = gx - x0
    wy1 = gy - y0
    wx0 = 1 - wx1
    wy0 = 1 - wy1

    # padding_mode='zeros': each corner that falls outside the image contributes
    # zero on its own, while in-bounds corners keep their bilinear weight
    def in_bounds(ix, iy):
        return ((ix >= 0) & (ix <= W - 1) & (iy >= 0) & (iy <= H - 1)).to(gx.dtype)

    # Bilinear weights for the four corners, shape (N, 1, Hout, Wout)
    w00 = (wx0 * wy0 * in_bounds(x0, y0)).unsqueeze(1)
    w01 = (wx0 * wy1 * in_bounds(x0, y1)).unsqueeze(1)
    w10 = (wx1 * wy0 * in_bounds(x1, y0)).unsqueeze(1)
    w11 = (wx1 * wy1 * in_bounds(x1, y1)).unsqueeze(1)

    # Flatten spatial dims so each corner can be fetched with a single gather
    x_flat = x.reshape(N, C, H * W)

    def gather_at(ix, iy):
        # Clamp so indexing is always valid; out-of-bounds corners already have zero weight
        ix = ix.clamp(0, W - 1).long()
        iy = iy.clamp(0, H - 1).long()
        idx = (iy * W + ix).reshape(N, 1, -1).expand(N, C, -1)
        return torch.gather(x_flat, 2, idx).reshape(N, C, Hout, Wout)

    # Sample at the four corners
    v00 = gather_at(x0, y0)
    v01 = gather_at(x0, y1)
    v10 = gather_at(x1, y0)
    v11 = gather_at(x1, y1)

    # Weighted sum of the four corner values
    return w00 * v00 + w01 * v01 + w10 * v10 + w11 * v11
```
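Before swapping it into the model, it is worth checking the sampling rule against grid_sample itself. The sketch below uses a hypothetical scalar helper, `bilinear_point`, that samples one point with plain Python arithmetic and per-corner zero padding, plus a hypothetical wrapper `grid_sample_points` around `F.grid_sample`, and compares the two on a small random image for both align_corners settings. It also includes a point half a pixel outside the left edge, where zero padding keeps half of the edge pixel's value rather than returning 0.

```python
import math

import torch
import torch.nn.functional as F


def bilinear_point(img, gx, gy, align_corners=False):
    """Hypothetical reference: bilinearly sample a (H, W) image at one
    normalized point (gx, gy), treating out-of-bounds corners as zero."""
    H, W = img.shape
    if align_corners:
        px = 0.5 * (gx + 1) * (W - 1)
        py = 0.5 * (gy + 1) * (H - 1)
    else:
        px = ((gx + 1) * W - 1) * 0.5
        py = ((gy + 1) * H - 1) * 0.5
    x0, y0 = math.floor(px), math.floor(py)
    fx, fy = px - x0, py - y0
    total = 0.0
    for dx, wx in ((0, 1 - fx), (1, fx)):
        for dy, wy in ((0, 1 - fy), (1, fy)):
            ix, iy = x0 + dx, y0 + dy
            if 0 <= ix < W and 0 <= iy < H:  # zero padding: skip outside corners
                total += wx * wy * img[iy, ix].item()
    return total


def grid_sample_points(img, pts, align_corners):
    """Hypothetical wrapper: run F.grid_sample on a (H, W) image and (K, 2) points."""
    grid = pts.reshape(1, 1, -1, 2)
    out = F.grid_sample(img[None, None], grid, mode="bilinear",
                        padding_mode="zeros", align_corners=align_corners)
    return out.reshape(-1)


# Half a pixel outside the left edge: x = -1 -> pixel -0.5, y = -0.5 -> row 0
edge = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
edge_pt = torch.tensor([[-1.0, -0.5]])
print(grid_sample_points(edge, edge_pt, False).item(), bilinear_point(edge, -1.0, -0.5))
# both give 0.5: half of the edge pixel's value, not 0

# Random image and points, including some slightly outside [-1, 1]
torch.manual_seed(0)
img = torch.randn(5, 7)
pts = torch.rand(50, 2) * 2.4 - 1.2
for ac in (False, True):
    ref = torch.tensor([bilinear_point(img, x.item(), y.item(), ac) for x, y in pts])
    out = grid_sample_points(img, pts, ac)
    print(ac, torch.allclose(out, ref, atol=1e-4))
```

The same `torch.allclose` comparison can then be run between `bilinear_sample_2d` and `F.grid_sample` on your real feature maps and sampling points.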
Integration with Your Existing Code
Replace your current implementation as follows:
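```python
features = []
for fm in feature_maps:
    print("Feature map shape:", fm.shape, "Points shape:", points_2d.shape)
    # Replace grid_sample with the custom bilinear sampling
    sampled = bilinear_sample_2d(
        fm.flatten(end_dim=1),  # (1, 6, C, H, W) -> (6, C, H, W)
        points_2d,              # (6, 900, 13, 2), normalized to [-1, 1]
        align_corners=False,    # use the same value as your original grid_sample call (PyTorch default: False)
    )
    features.append(sampled)
    print("Output shape:", sampled.shape)  # (6, C, 900, 13)
```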
Hope this helps!