pytorch 自定义参数不更新方式
更新时间:2020年01月06日 09:56:54 作者:ShellCollector
今天小编就为大家分享一篇pytorch 自定义参数不更新方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
nn.Module中定义参数:不需要加cuda,可以求导,反向传播
class BiFPN(nn.Module): def __init__(self, fpn_sizes): self.w1 = nn.Parameter(torch.rand(1)) print("no---------------------------------------------------",self.w1.data, self.w1.grad)
下面这个例子说明中间变量可能没有梯度,但是最终变量有梯度:
cy1 cd都有梯度
import torch xP=torch.Tensor([[ 3233.8557, 3239.0657, 3243.4355, 3234.4507, 3241.7087, 3243.7292, 3234.6826, 3237.6609, 3249.7937, 3244.8623, 3239.5349, 3241.4626, 3251.3457, 3247.4263, 3236.4924, 3251.5735, 3246.4731, 3242.4692, 3239.4958, 3247.7283, 3251.7134, 3249.0237, 3247.5637], [ 1619.9011, 1619.7140, 1620.4883, 1620.0642, 1620.2191, 1619.9796, 1617.6597, 1621.1522, 1621.0869, 1620.9725, 1620.7130, 1620.6071, 1620.7437, 1621.4825, 1620.5107, 1621.1519, 1620.8462, 1620.5944, 1619.8038, 1621.3364, 1620.7399, 1621.1178, 1618.7080], [ 1619.9330, 1619.8542, 1620.5176, 1620.1167, 1620.1577, 1620.0579, 1617.7155, 1621.1718, 1621.1338, 1620.9572, 1620.6288, 1620.6621, 1620.7074, 1621.5305, 1620.5656, 1621.2281, 1620.8346, 1620.6021, 1619.8228, 1621.3936, 1620.7616, 1621.1954, 1618.7983], [ 1922.6078, 1922.5680, 1923.1331, 1922.6604, 1922.9589, 1922.8818, 1920.4602, 1923.8107, 1924.0142, 1923.6907, 1923.4465, 1923.2820, 1923.5728, 1924.4071, 1922.8853, 1924.1107, 1923.5465, 1923.5121, 1922.4673, 1924.1871, 1923.6248, 1923.9086, 1921.9496], [ 1922.5948, 1922.5311, 1923.2850, 1922.6613, 1922.9734, 1922.9271, 1920.5950, 1923.8757, 1924.0422, 1923.7318, 1923.4889, 1923.3296, 1923.5752, 1924.4948, 1922.9866, 1924.1642, 1923.6427, 1923.6067, 1922.5214, 1924.2761, 1923.6636, 1923.9481, 1921.9005]]) yP=torch.Tensor([[ 2577.7729, 2590.9868, 2600.9712, 2579.0195, 2596.3684, 2602.2771, 2584.0305, 2584.7749, 2615.4897, 2603.3164, 2589.8406, 2595.3486, 2621.9116, 2608.2820, 2582.9534, 2619.2073, 2607.1233, 2597.7888, 2591.5735, 2608.9060, 2620.8992, 2613.3511, 2614.2195], [ 673.7830, 693.8904, 709.2661, 675.4254, 702.4049, 711.2085, 683.1571, 684.6160, 731.3878, 712.7546, 692.3011, 701.0069, 740.6815, 720.4229, 681.8199, 736.9869, 718.5508, 704.3666, 695.0511, 721.5912, 739.6672, 728.0584, 729.3143], [ 673.8367, 693.9529, 709.3196, 675.5266, 702.3820, 711.2159, 683.2151, 684.6421, 731.5291, 712.6366, 692.1913, 701.0057, 740.6229, 720.4082, 681.8656, 737.0168, 718.4943, 704.2719, 695.0775, 721.5616, 739.7233, 728.1235, 729.3387], [ 872.9419, 891.7061, 905.8004, 874.6565, 899.2053, 907.5082, 881.5528, 883.0028, 926.3083, 908.9742, 890.0403, 897.8606, 934.6913, 916.0902, 880.4689, 931.3562, 914.4233, 901.2154, 892.5759, 916.9590, 933.9291, 923.0745, 924.4461], [ 872.9661, 891.7683, 905.8128, 874.6301, 899.2887, 907.5155, 881.6916, 883.0234, 926.3242, 908.9561, 890.0731, 897.9221, 934.7324, 916.0806, 880.4300, 931.3933, 914.5662, 901.2715, 892.5501, 916.9894, 933.9813, 923.0823, 924.3654]]) shape=[4000, 6000] cx,cy1=torch.rand(1,requires_grad=True),torch.rand(1,requires_grad=True) cd=torch.rand(1,requires_grad=True) ox,oy=cx,cy1 print('cx:{},cy:{}'.format(id(cx),id(cy1))) print('ox:{},oy:{}'.format(id(ox),id(oy))) cx,cy=cx*shape[1],cy1*shape[0] print('cx:{},cy:{}'.format(id(cx),id(cy))) print('ox:{},oy:{}'.format(id(ox),id(oy))) distance=torch.sqrt(torch.pow((xP-cx),2)+torch.pow((yP-cy),2)) mean=torch.mean(distance,1) starsFC=cd*torch.pow((distance-mean[...,None]),2) loss=torch.sum(torch.mean(starsFC,1).squeeze(),0) loss.backward() print(loss) print(cx) print(cy1) print("cx",cx.grad) print("cy",cy1.grad) print("cd",cd.grad) print(ox.grad) print(oy.grad) print('cx:{},cy:{}'.format(id(cx),id(cy))) print('ox:{},oy:{}'.format(id(ox),id(oy)))
以上这篇pytorch 自定义参数不更新方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。
相关文章
Django1.9 加载通过ImageField上传的图片方法
今天小编就为大家分享一篇Django1.9 加载通过ImageField上传的图片方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2018-05-05
最新评论