Pytorch上/下采样函数torch.nn.functional.interpolate插值
- torch.nn.functional.interpolate(input_tensor, size=None, scale_factor=8, mode='bilinear', align_corners=False)
- '''
- Down/up samples the input to either the given size or the given scale_factor
- The algorithm used for interpolation is determined by mode.
- Currently temporal, spatial and volumetric sampling are supported, i.e. expected inputs are 3-D, 4-D or 5-D in shape.
- The input dimensions are interpreted in the form: mini-batch x channels x [optional depth] x [optional height] x width.
- The modes available for resizing are: nearest, linear (3D-only), bilinear, bicubic (4D-only), trilinear (5D-only), area
- '''
复制代码 这个函数是用来或tensor的:支持输入3D- (b, c, w)或(batch,seq_len,dim)
复制代码 、4D、5D的 tensor shape。其中b表示batch_size,c表示channel,f表示frames,h表示height,w表示weight。是目标tensor的的形状;是采样tensor的saptial shape的缩放系数,和两个参数只能定义一个,具体是上采样,还是下采样根据这两个参数判断。如果或者是,则必须匹配输入的大小。
- 如果输入3D,则它们的序列长度必须是1(只缩放最后1个维度w)。
- 如果输入4D,则它们的序列长度必须是2(缩放最后2个维度h,w)。
- 如果输入是5D,则它们的序列长度必须是3(缩放最后3个维度f,h,w)。
插值算法可选:、、、等等。
是否对齐角点:可选的bool值, 如果,则对齐 input 和 output 的角点像素(corner pixels),保持在角点像素的值. 只会对- mode=linear, bilinear, trilinear
复制代码 有作用. 默认是 False。一图看懂=与的区别,从4×4上采样成8×8。
一个是按四角的像素点中心对齐,另一个是按四角的像素角点对齐:
- import torch
- import torch.nn.functional as F
- b, c, f, h, w = 1, 3, 8, 64, 64
复制代码 1. upsample/downsample 3D tensor
- # interpolate 3D tensor
- x = torch.randn([b, c, w])
- ## downsample to (b, c, w/2)
- y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
- y1 = F.interpolate(x, size=[w//2], mode='nearest')
- y2 = F.interpolate(x, scale_factor=0.5, mode='linear') # only 3D
- y3 = F.interpolate(x, size=[w//2], mode='linear') # only 3D
- print(y0.shape, y1.shape, y2.shape, y3.shape)
- # torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32]) torch.Size([1, 3, 32])
- ## upsample to (b, c, w*2)
- y0 = F.interpolate(x, scale_factor=2, mode='nearest')
- y1 = F.interpolate(x, size=[w*2], mode='nearest')
- y2 = F.interpolate(x, scale_factor=2, mode='linear') # only 3D
- y3 = F.interpolate(x, size=[w*2], mode='linear') # only 3D
- print(y0.shape, y1.shape, y2.shape, y3.shape)
- # torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128]) torch.Size([1, 3, 128])
复制代码 2. upsample/downsample 4D tensor
- # interpolate 4D tensor
- x = torch.randn(b, c, h, w)
- ## downsample to (b, c, h/2, w/2)
- y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
- y1 = F.interpolate(x, size=[h//2, w//2], mode='nearest')
- y2 = F.interpolate(x, scale_factor=0.5, mode='bilinear') # only 4D
- y3 = F.interpolate(x, size=[h//2, w//2], mode='bilinear') # only 4D
- print(y0.shape, y1.shape, y2.shape, y3.shape)
- # torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 32, 32])
- ## upsample to (b, c, h*2, w*2)
- y0 = F.interpolate(x, scale_factor=2, mode='nearest')
- y1 = F.interpolate(x, size=[h*2, w*2], mode='nearest')
- y2 = F.interpolate(x, scale_factor=2, mode='bilinear') # only 4D
- y3 = F.interpolate(x, size=[h*2, w*2], mode='bilinear') # only 4D
- print(y0.shape, y1.shape, y2.shape, y3.shape)
- # torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 128, 128])
复制代码 3. upsample/downsample 5D tensor
- # interpolate 5D tensor
- x = torch.randn(b, c, f, h, w)
- ## downsample to (b, c, f/2, h/2, w/2)
- y0 = F.interpolate(x, scale_factor=0.5, mode='nearest')
- y1 = F.interpolate(x, size=[f//2, h//2, w//2], mode='nearest')
- y2 = F.interpolate(x, scale_factor=2, mode='trilinear') # only 5D
- y3 = F.interpolate(x, size=[f//2, h//2, w//2], mode='trilinear') # only 5D
- print(y0.shape, y1.shape, y2.shape, y3.shape)
- # torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 4, 32, 32]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 4, 32, 32])
- ## upsample to (b, c, f*2, h*2, w*2)
- y0 = F.interpolate(x, scale_factor=2, mode='nearest')
- y1 = F.interpolate(x, size=[f*2, h*2, w*2], mode='nearest')
- y2 = F.interpolate(x, scale_factor=2, mode='trilinear') # only 5D
- y3 = F.interpolate(x, size=[f*2, h*2, w*2], mode='trilinear') # only 5D
- print(y0.shape, y1.shape, y2.shape, y3.shape)
- # torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128]) torch.Size([1, 3, 16, 128, 128])
复制代码 总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。
来源:https://www.jb51.net/python/339694uzg.htm
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |