PyTorch confusables festival ep.2 [pytorch squeeze/unsqueeze]

2024. 6. 22. 14:11 · Deep Learning Theory

next_q_values = torch.gather(next_qa_values, 1, next_action.unsqueeze(dim=-1)).squeeze(dim=1)
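This one-liner is a common pattern in DQN-style code: pick the Q-value of one chosen action per sample. Below is a minimal sketch with made-up numbers. It assumes `next_qa_values` has shape (batch, num_actions) and `next_action` holds one action index per sample with shape (batch,). Both are hypothetical stand-ins for whatever the surrounding training loop produces.

```python
import torch

# Hypothetical batch: 3 samples, 4 possible actions each -> shape (3, 4)
next_qa_values = torch.tensor([[0.1, 0.5, 0.2, 0.0],
                               [1.0, 0.3, 0.4, 0.9],
                               [0.7, 0.8, 0.6, 0.2]])
next_action = torch.tensor([1, 0, 2])  # shape (3,): one chosen action per sample

# gather needs an index with the same number of dims as the input,
# so unsqueeze turns (3,) into (3, 1)
index = next_action.unsqueeze(dim=-1)
picked = torch.gather(next_qa_values, 1, index)  # shape (3, 1)

# squeeze(dim=1) drops the trailing size-1 dim: (3, 1) -> (3,)
next_q_values = picked.squeeze(dim=1)

print(index.shape)     # torch.Size([3, 1])
print(picked.shape)    # torch.Size([3, 1])
print(next_q_values)   # tensor([0.5000, 1.0000, 0.6000])
```

unsqueeze is needed on the way in because gather requires an index tensor with the same number of dimensions as the input. squeeze is needed on the way out to get back a flat per-sample vector.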


What are PyTorch's squeeze and unsqueeze?

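- squeeze removes ("squeezes out") dimensions of size 1

- unsqueeze is the reverse: it inserts a new dimension of size 1 at the given position

- The dim argument specifies which dimension to operate on (optional for squeeze, required for unsqueeze)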
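A short sketch of unsqueeze and dim on a 2x3 tensor follows. The tensor and variable names are arbitrary, chosen only for illustration.

```python
import torch

x = torch.zeros(2, 3)            # shape (2, 3)

# unsqueeze inserts a size-1 dim at the requested position (negative dims count from the end)
front = x.unsqueeze(0)           # shape (1, 2, 3)
middle = x.unsqueeze(1)          # shape (2, 1, 3)
back = x.unsqueeze(-1)           # shape (2, 3, 1)

# squeeze with the same dim undoes the unsqueeze
restored = middle.squeeze(1)     # shape (2, 3)

# squeezing a dim whose size is not 1 leaves the tensor unchanged (no error)
unchanged = x.squeeze(0)         # shape (2, 3)

print(front.shape, middle.shape, back.shape, restored.shape, unchanged.shape)
```

Note the last case. Asking squeeze to remove a dimension that is not size 1 is silently a no-op rather than an error.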

import torch
a = torch.randn(2, 1, 2)  # shape (2, 1, 2)
b = torch.squeeze(a)      # shape (2, 2): the size-1 dim is removed
print(a, b, sep="\n\n")

Output:
tensor([[[-1.6897, -0.6981]],

        [[-1.7473,  1.1294]]])

tensor([[-1.6897, -0.6981],
        [-1.7473,  1.1294]])

The size-1 dimension of the 2x1x2 tensor is removed, leaving a 2x2 tensor.

The same happens for a 2x2x1 or 1x2x2 tensor: both also become 2x2.
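To check this for all three shapes at once, here is a small loop. The values are random, so only the shapes matter.

```python
import torch

squeezed_shapes = {}
for shape in [(2, 1, 2), (2, 2, 1), (1, 2, 2)]:
    a = torch.randn(*shape)
    squeezed_shapes[shape] = tuple(torch.squeeze(a).shape)
    print(shape, "->", squeezed_shapes[shape])
# (2, 1, 2) -> (2, 2)
# (2, 2, 1) -> (2, 2)
# (1, 2, 2) -> (2, 2)
```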


Remember: set dim!

Without dim, PyTorch removes every dimension of size 1, including ones you may want to keep.
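Here is a minimal sketch of why this matters, assuming a hypothetical model output of shape (batch, 1). When the batch happens to contain a single sample, a bare squeeze() also removes the batch dimension.

```python
import torch

# Hypothetical per-sample scores for a batch that happens to contain one sample
scores = torch.randn(1, 1)          # shape (batch=1, 1)

no_dim = scores.squeeze()           # drops BOTH size-1 dims -> shape ()
with_dim = scores.squeeze(dim=1)    # drops only dim 1 -> shape (1,), batch dim kept

print(no_dim.shape)    # torch.Size([])
print(with_dim.shape)  # torch.Size([1])

# With a batch of 4 the two calls agree, so the bug only shows up when batch size is 1
scores4 = torch.randn(4, 1)
print(scores4.squeeze().shape, scores4.squeeze(dim=1).shape)  # torch.Size([4]) torch.Size([4])
```

This is why the DQN line at the top passes dim=1 explicitly. It keeps the batch dimension even for a batch of one.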