PyTorch confusables extravaganza ep.2 [pytorch squeeze/unsqueeze]
2024. 6. 22. 14:11 · Deep Learning Theory
next_q_values = torch.gather(next_qa_values, 1, next_action.unsqueeze(dim=-1)).squeeze(dim=1)
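The line above looks like part of a Q-learning target computation, and it uses both unsqueeze and squeeze. Below is a minimal sketch that traces the shapes step by step. It assumes `next_qa_values` has shape (batch, num_actions) and `next_action` holds one integer action index per sample. The sizes 4 and 3 are made up for illustration.

```python
import torch

# Hypothetical sizes: a batch of 4 next states, 3 possible actions each
batch_size, num_actions = 4, 3
next_qa_values = torch.randn(batch_size, num_actions)       # (B, A): Q value for every action
next_action = torch.randint(0, num_actions, (batch_size,))  # (B,): chosen action per sample, int64

index = next_action.unsqueeze(dim=-1)            # (B,) -> (B, 1): gather needs index with the same ndim as input
picked = torch.gather(next_qa_values, 1, index)  # (B, 1): Q value of the chosen action in each row
next_q_values = picked.squeeze(dim=1)            # (B, 1) -> (B,)

print(index.shape, picked.shape, next_q_values.shape)
# torch.Size([4, 1]) torch.Size([4, 1]) torch.Size([4])

# Sanity check: advanced indexing picks the same values
assert torch.equal(next_q_values, next_qa_values[torch.arange(batch_size), next_action])
```

Without the unsqueeze, `gather` would reject the 1-D index. Without the squeeze, the result would keep a trailing size-1 dimension.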
What are PyTorch squeeze and unsqueeze?
- squeeze "squeezes out" dimensions of size 1
- unsqueeze is the reverse: it inserts a new dimension of size 1
- Set dim to choose which dimension squeeze removes, or where unsqueeze inserts the new one
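Here is a small sketch of unsqueeze with different values of dim, and of squeeze undoing it. The tensor values are arbitrary.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])   # shape (3,)

# unsqueeze inserts a new size-1 dimension at position dim
row = x.unsqueeze(dim=0)    # (3,) -> (1, 3)
col = x.unsqueeze(dim=1)    # (3,) -> (3, 1)
last = x.unsqueeze(dim=-1)  # (3,) -> (3, 1): -1 places the new dimension at the end

print(row.shape, col.shape, last.shape)
# torch.Size([1, 3]) torch.Size([3, 1]) torch.Size([3, 1])

# squeeze on the same dim undoes the unsqueeze
restored = col.squeeze(dim=1)
assert torch.equal(restored, x)
```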
import torch
a = torch.randn(2, 1, 2)  # shape (2, 1, 2)
b = torch.squeeze(a)      # shape (2, 2)
print(a, b, sep="\n")
Output:
tensor([[[-1.6897, -0.6981]],

        [[-1.7473,  1.1294]]])
tensor([[-1.6897, -0.6981],
        [-1.7473,  1.1294]])
The size-1 dimension of the 2x1x2 tensor is removed, leaving a 2x2 tensor.
The same happens for a 2x2x1 or 1x2x2 tensor, and so on.
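A quick check of that claim, plus one detail worth knowing: calling squeeze with a dim whose size is not 1 leaves the tensor unchanged instead of raising an error. Zero tensors are used here only because the values do not matter.

```python
import torch

# squeeze() without dim, applied to each of the shapes mentioned above
squeezed = {}
for shape in [(2, 1, 2), (2, 2, 1), (1, 2, 2)]:
    squeezed[shape] = tuple(torch.zeros(shape).squeeze().shape)
    print(shape, "->", squeezed[shape])
# (2, 1, 2) -> (2, 2)
# (2, 2, 1) -> (2, 2)
# (1, 2, 2) -> (2, 2)

# squeeze(dim) only removes that dim if its size is 1
t = torch.zeros(2, 1, 2)
untouched = tuple(t.squeeze(dim=0).shape)  # dim 0 has size 2, nothing removed
removed = tuple(t.squeeze(dim=1).shape)    # dim 1 has size 1, removed
print(untouched, removed)
# (2, 1, 2) (2, 2)
```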
Remember: set dim!
If you don't, PyTorch automatically removes every dimension of size 1.
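This matters most when a batch happens to contain a single sample. The sketch below assumes a (batch, 1) tensor like the gather output in the opening line, with a made-up value and a batch size of 1. Without dim, the batch dimension disappears as well.

```python
import torch

# Hypothetical (batch, 1) tensor with batch size 1
picked = torch.tensor([[0.7]])      # shape (1, 1)

all_gone = picked.squeeze()         # no dim: every size-1 dim removed -> 0-d tensor
batch_kept = picked.squeeze(dim=1)  # only dim 1 removed -> batch dim survives

print(all_gone.shape, batch_kept.shape)
# torch.Size([]) torch.Size([1])
```

Code further down that expects a 1-D batch (indexing, concatenation, loss reduction) can then misbehave only when the batch size is 1, which makes the bug hard to spot.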