这篇文章主要给大家分享pytorch函数的内容,本文给大家介绍两个函数,分别是squeeze函数、cat函数。那么这两个函数有什么用呢?用法是什么?下面我们一起来学习一下。
1 squeeze(): 去除size为1的维度,包括行和列。
至于维度大于等于2时,squeeze()不起作用。
行、例:
>>> torch.rand(4, 1, 3)
(0 ,.,.) =
0.5391 0.8523 0.9260
(1 ,.,.) =
0.2507 0.9512 0.6578
(2 ,.,.) =
0.7302 0.3531 0.9442
(3 ,.,.) =
0.2689 0.4367 0.6610
[torch.FloatTensor of size 4x1x3]
>>> torch.rand(4, 1, 3).squeeze()
0.0801 0.4600 0.1799
0.0236 0.7137 0.6128
0.0242 0.3847 0.4546
0.9004 0.5018 0.4021
[torch.FloatTensor of size 4x3]
列、例:
>>> torch.rand(4, 3, 1)
(0 ,.,.) =
0.7013
0.9818
0.9723
(1 ,.,.) =
0.9902
0.8354
0.3864
(2 ,.,.) =
0.4620
0.0844
0.5707
(3 ,.,.) =
0.5722
0.2494
0.5815
[torch.FloatTensor of size 4x3x1]
>>> torch.rand(4, 3, 1).squeeze()
0.8784 0.6203 0.8213
0.7238 0.5447 0.8253
0.1719 0.7830 0.1046
0.0233 0.9771 0.2278
[torch.FloatTensor of size 4x3]
不变、例:
>>> torch.rand(4, 3, 2)
(0 ,.,.) =
0.6618 0.1678
0.3476 0.0329
0.1865 0.4349
(1 ,.,.) =
0.7588 0.8972
0.3339 0.8376
0.6289 0.9456
(2 ,.,.) =
0.1392 0.0320
0.0033 0.0187
0.8229 0.0005
(3 ,.,.) =
0.2327 0.6264
0.4810 0.6642
0.8625 0.6334
[torch.FloatTensor of size 4x3x2]
>>> torch.rand(4, 3, 2).squeeze()
(0 ,.,.) =
0.0593 0.8910
0.9779 0.1530
0.9210 0.2248
(1 ,.,.) =
0.7938 0.9362
0.1064 0.6630
0.9321 0.0453
(2 ,.,.) =
0.0189 0.9187
0.4458 0.9925
0.9928 0.7895
(3 ,.,.) =
0.5116 0.7253
0.0132 0.6673
0.9410 0.8159
[torch.FloatTensor of size 4x3x2]
2 cat函数
>>> t1=torch.FloatTensor(torch.randn(2,3))
>>> t1
-1.9405 1.2009 0.0018
0.9463 0.4409 -1.9017
[torch.FloatTensor of size 2x3]
>>> t2=torch.FloatTensor(torch.randn(2,2))
>>> t2
0.0942 0.1581
1.1621 1.2617
[torch.FloatTensor of size 2x2]
>>> torch.cat((t1, t2), 1)
-1.9405 1.2009 0.0018 0.0942 0.1581
0.9463 0.4409 -1.9017 1.1621 1.2617
[torch.FloatTensor of size 2x5]
补充:pytorch中 max()、view()、 squeeze()、 unsqueeze()
查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。
一、torch.max()
import torch
a=torch.randn(3)
print("a:\n",a)
print('max(a):',torch.max(a))
大型站长资讯类网站!