PyTorch Confusables Extravaganza ep.1 [pytorch dim]

2024. 5. 9. 15:03 · Deep Learning Theory

What are dims in PyTorch?
- "dim" stands for dimension: the index of the axis an operation should run over
- it plays the same role as "axis" in NumPy (see the sketch just below)
- for a 2D torch.Tensor, dim=0 is the row index (the vertical direction) and dim=1 is the column index (the horizontal direction), so reducing over dim=0 moves down each column and reducing over dim=1 moves along each row
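
If the NumPy analogy helps, here is a minimal sanity-check sketch (the small integer array below is made up for illustration, not taken from the example that follows):

import numpy as np
import torch

x = np.arange(6).reshape(2, 3)     # [[0, 1, 2], [3, 4, 5]]
t = torch.from_numpy(x)

# Reducing over axis/dim 0 collapses the rows: one result per column.
print(np.sum(x, axis=0))           # [3 5 7]
print(torch.sum(t, dim=0))         # tensor([3, 5, 7])

# Reducing over axis/dim 1 collapses the columns: one result per row.
print(np.sum(x, axis=1))           # [ 3 12]
print(torch.sum(t, dim=1))         # tensor([ 3, 12])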
Code:
import torch

a = torch.randn(5, 5)
b = torch.argmax(a)  # signature: torch.argmax(input, dim=None, keepdim=False)
print(a)
print(b)

Output:
tensor([[-8.4741e-01,  3.6448e-01,  1.4155e+00,  2.6417e-01, -3.4608e-01],
        [-8.9281e-01, -3.1258e-01,  1.4784e+00,  2.9023e-01, -1.0445e+00],
        [ 5.4736e-01, -6.0465e-01, -8.5010e-04,  1.0265e+00, -1.8865e-01],
        [-1.3834e+00, -1.0257e+00, -5.9035e-02,  5.3723e-01,  5.0983e-01],
        [-1.1933e+00,  2.1935e-01,  4.3595e-01, -1.4600e-01,  3.2134e-01]])
tensor(7)

If dim is not provided, the tensor is flattened first, so argmax returns a single index into the flattened tensor.

Here the result is tensor(7). In the flattened 5x5 tensor, index 7 corresponds to row 1 (7 // 5) and column 2 (7 % 5), so 1.4784e+00 is the largest element. You can confirm this.
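
To confirm it in code rather than by eye, the flat index can be mapped back to (row, col) with plain integer division. A minimal sketch, assuming the same a as above:

flat = torch.argmax(a)                    # index into the flattened 5x5 tensor, here tensor(7)
row, col = flat // a.shape[1], flat % a.shape[1]
print(row, col)                           # tensor(1) tensor(2)
print(a[row, col])                        # 1.4784..., the largest entry
assert a[row, col] == a.flatten()[flat]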

 

What about with dim? (Note that a is re-sampled with torch.randn for each run below, so the values differ from the first example.)

# b = torch.argmax(a, dim=0)
tensor([[-1.3997, -0.8501, -1.1572,  0.3971,  0.0856],
        [-0.1836,  0.5280,  0.6923,  0.1533, -0.1619],
        [-0.5082, -0.9608, -0.6744, -0.7565, -0.7908],
        [ 1.3769,  0.1149,  0.2402, -1.5244,  0.6827],
        [-0.1595,  0.2942, -0.8637, -0.9577, -0.8835]])

tensor([3, 1, 1, 0, 3])
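
As an aside, the keepdim flag from the signature above only changes the shape of the result, not which indices get picked. A minimal sketch:

idx = torch.argmax(a, dim=0)                     # shape (5,): the reduced dim is dropped
idx_keep = torch.argmax(a, dim=0, keepdim=True)  # shape (1, 5): the reduced dim stays as size 1
print(idx.shape, idx_keep.shape)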

# b = torch.argmax(a, dim=1)
tensor([[ 0.1855,  0.0787,  0.7245, -0.1628, -0.0741],
        [ 0.1648,  1.1548,  0.5182,  1.4389,  2.9808],
        [-0.3670,  1.2158, -0.5439,  2.8245,  0.6805],
        [ 0.1014, -0.2867,  0.1345, -0.0341, -0.7118],
        [ 0.1943, -0.8088,  1.1831, -0.8923,  1.0722]])

tensor([2, 4, 3, 2, 2])
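
The returned indices can also be fed back into the tensor to pull out the winning values themselves, which is a handy way to double-check the rows above. A minimal sketch using torch.gather, assuming the same a:

idx = torch.argmax(a, dim=1)                     # one column index per row, shape (5,)
max_vals = torch.gather(a, 1, idx.unsqueeze(1))  # shape (5, 1): picks a[i, idx[i]] for each row i
print(max_vals.squeeze(1))
assert torch.equal(max_vals.squeeze(1), a.max(dim=1).values)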

 

Remember:

For dim=0, argmax selection is done vertically, down each column, giving one index per column (the row of that column's maximum).
For dim=1, selection is done horizontally, along each row, giving one index per row (the column of that row's maximum).

 

dim=0 (scan down each column)

tensor([[-1.3997, -0.8501, -1.1572,  0.3971,  0.0856],
        [-0.1836,  0.5280,  0.6923,  0.1533, -0.1619],
        [-0.5082, -0.9608, -0.6744, -0.7565, -0.7908],
        [ 1.3769,  0.1149,  0.2402, -1.5244,  0.6827],
        [-0.1595,  0.2942, -0.8637, -0.9577, -0.8835]])

(column-wise maxima: 1.3769, 0.5280, 0.6923, 0.3971, 0.6827, i.e. rows 3, 1, 1, 0, 3)

 

dim=1 (scan along each row)

tensor([[ 0.1855,  0.0787,  0.7245, -0.1628, -0.0741],
        [ 0.1648,  1.1548,  0.5182,  1.4389,  2.9808],
        [-0.3670,  1.2158, -0.5439,  2.8245,  0.6805],
        [ 0.1014, -0.2867,  0.1345, -0.0341, -0.7118],
        [ 0.1943, -0.8088,  1.1831, -0.8923,  1.0722]])

(row-wise maxima: 0.7245, 2.9808, 2.8245, 0.1345, 1.1831, i.e. columns 2, 4, 3, 2, 2)
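
The same mnemonic carries over beyond 2D: whichever dim you pass to argmax is the one that disappears from the result's shape, while the remaining dims keep their order. A minimal sketch with a made-up 3D tensor:

x = torch.randn(2, 3, 4)
print(torch.argmax(x, dim=0).shape)   # torch.Size([3, 4]): dim 0 reduced away
print(torch.argmax(x, dim=1).shape)   # torch.Size([2, 4]): dim 1 reduced away
print(torch.argmax(x, dim=2).shape)   # torch.Size([2, 3]): dim 2 reduced away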