• 设为首页
  • 收藏本站
  • 积分充值
  • VIP赞助
  • 手机版
  • 微博
  • 微信
    微信公众号 添加方式:
    1:搜索微信号(888888
    2:扫描左侧二维码
  • 快捷导航
    福建二哥 门户 查看主题

    Pytorch之上/下采样函数torch.nn.functional.interpolate插值详解

    发布者: 土豆服务器 | 发布时间: 2025-6-17 08:08| 查看数: 33| 评论数: 0|帖子模式

    Pytorch上/下采样函数torch.nn.functional.interpolate插值
    1. torch.nn.functional.interpolate(input_tensor, size=None, scale_factor=8, mode='bilinear', align_corners=False)
    2. '''
    3. Down/up samples the input to either the given size or the given scale_factor
    4. The algorithm used for interpolation is determined by mode.
    5. Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape.
    6. The input dimensions are interpreted in the form: mini-batch x channels x [optional depth] x [optional height] x width.
    7. The modes available for resizing are: nearest, linear (3D-only), bilinear, bicubic (4D-only), trilinear (5D-only), area
    8. '''
    复制代码
    这个函数是用来
    1. 上采样
    复制代码
    1. 下采样
    复制代码
    tensor的
    1. 空间维度(h,w)
    复制代码
    1. input_tensor
    复制代码
    支持输入3D
    1. (b, c, w)或(batch,seq_len,dim)
    复制代码
    、4D
    1. (b, c, h, w)
    复制代码
    、5D
    1. (b, c, f, h, w)
    复制代码
    的 tensor shape。其中b表示batch_size,c表示channel,f表示frames,h表示height,w表示weight。
    1. size
    复制代码
    是目标tensor的
    1. (w)/(h,w)/(f,h,w)
    复制代码
    的形状;
    1. scale_factor
    复制代码
    是采样tensor的saptial shape
    1. (w)/(h,w)/(f,h,w)
    复制代码
    的缩放系数,
    1. size
    复制代码
    1. scale_factor
    复制代码
    两个参数只能定义一个,具体是上采样,还是下采样根据这两个参数判断。如果
    1. size
    复制代码
    或者
    1. scale_factor
    复制代码
    1. list序列
    复制代码
    ,则必须匹配输入的大小。

    • 如果输入3D,则它们的序列长度必须是1(只缩放最后1个维度w)。
    • 如果输入4D,则它们的序列长度必须是2(缩放最后2个维度h,w)。
    • 如果输入是5D,则它们的序列长度必须是3(缩放最后3个维度f,h,w)。
    插值算法
    1. mode
    复制代码
    可选:
    1. 最近邻(nearest, 默认)
    复制代码
    1. 线性(linear, 3D-only)
    复制代码
    1. 双线性(bilinear, 4D-only)
    复制代码
    1. 三线性(trilinear, 5D-only)
    复制代码
    等等。
    是否
    1. align_corners
    复制代码
    对齐角点:可选的bool值, 如果
    1. align_corners=True
    复制代码
    ,则对齐 input 和 output 的角点像素(corner pixels),保持在角点像素的值. 只会对
    1. mode=linear, bilinear, trilinear
    复制代码
    有作用. 默认是 False。一图看懂
    1. align_corners
    复制代码
    =
    1. True
    复制代码
    1. False
    复制代码
    的区别,从4×4上采样成8×8。
    一个是按四角的像素点中心对齐,另一个是按四角的像素角点对齐:
    1. import torch
    2. import torch.nn.functional as F
    3. b, c, f, h, w = 1, 3, 8, 64, 64
    复制代码
    1. upsample/downsample 3D tensor
    1. # interpolate 3D tensor
    2. x = torch.randn([b, c, w])
    3. ## downsample to (b, c, w/2)
    4. y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
    5. y1 = F.interpolate(x, size=[w//2], mode='nearest')
    6. y2 = F.interpolate(x, scale_factor=0.5, mode='linear')  # only 3D
    7. y3 = F.interpolate(x, size=[w//2], mode='linear')  # only 3D
    8. print(y0.shape, y1.shape, y2.shape, y3.shape)
    9. # torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32])

    10. ## upsample to (b, c, w*2)
    11. y0 = F.interpolate(x, scale_factor=2, mode='nearest')
    12. y1 = F.interpolate(x, size=[w*2], mode='nearest')
    13. y2 = F.interpolate(x, scale_factor=2, mode='linear')  # only 3D
    14. y3 = F.interpolate(x, size=[w*2], mode='linear')  # only 3D
    15. print(y0.shape, y1.shape, y2.shape, y3.shape)
    16. # torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128])
    复制代码
    2. upsample/downsample 4D tensor
    1. # interpolate 4D tensor
    2. x = torch.randn(b, c, h, w)
    3. ## downsample to (b, c, h/2, w/2)
    4. y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
    5. y1 = F.interpolate(x, size=[h//2, w//2], mode='nearest')
    6. y2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')  # only 4D
    7. y3 = F.interpolate(x, size=[h//2, w//2], mode='bilinear')  # only 4D
    8. print(y0.shape, y1.shape, y2.shape, y3.shape)
    9. # torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32])

    10. ## upsample to (b, c, h*2, w*2)
    11. y0 = F.interpolate(x, scale_factor=2, mode='nearest')
    12. y1 = F.interpolate(x, size=[h*2, w*2], mode='nearest')
    13. y2 = F.interpolate(x, scale_factor=2, mode='bilinear')  # only 4D
    14. y3 = F.interpolate(x, size=[h*2, w*2], mode='bilinear')  # only 4D
    15. print(y0.shape, y1.shape, y2.shape, y3.shape)
    16. # torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128])
    复制代码
    3. upsample/downsample 5D tensor
    1. # interpolate 5D tensor
    2. x = torch.randn(b, c, f, h, w)
    3. ## downsample to (b, c, f/2, h/2, w/2)
    4. y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
    5. y1 = F.interpolate(x, size=[f//2, h//2, w//2], mode='nearest')
    6. y2 = F.interpolate(x, scale_factor=2, mode='trilinear')  # only 5D
    7. y3 = F.interpolate(x, size=[f//2, h//2, w//2], mode='trilinear')  # only 5D
    8. print(y0.shape, y1.shape, y2.shape, y3.shape)
    9. # torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 4, 32, 32])

    10. ## upsample to (b, c, f*2, h*2, w*2)
    11. y0 = F.interpolate(x, scale_factor=2, mode='nearest')
    12. y1 = F.interpolate(x, size=[f*2, h*2, w*2], mode='nearest')
    13. y2 = F.interpolate(x, scale_factor=2, mode='trilinear')  # only 5D
    14. y3 = F.interpolate(x, size=[f*2, h*2, w*2], mode='trilinear')  # only 5D
    15. print(y0.shape, y1.shape, y2.shape, y3.shape)
    16. # torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128])
    复制代码
    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

    来源:https://www.jb51.net/python/339694uzg.htm
    免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

    本帖子中包含更多资源

    您需要 登录 才可以下载或查看,没有账号?立即注册

    ×

    最新评论

    QQ Archiver 手机版 小黑屋 福建二哥 ( 闽ICP备2022004717号|闽公网安备35052402000345号 )

    Powered by Discuz! X3.5 © 2001-2023

    快速回复 返回顶部 返回列表