import torch
import numpy as np
PyTorch Basics
torch.cuda.is_available()
True
Data Types
float_tensor = torch.ones(1, dtype=torch.float)
float_tensor.dtype
torch.float32
double_tensor = torch.ones(1, dtype=torch.double)
double_tensor.dtype
torch.float64
complex_float_tensor = torch.ones(1, dtype=torch.complex64)
complex_float_tensor.dtype
torch.complex64
complex_double_tensor = torch.ones(1, dtype=torch.complex128)
complex_double_tensor.dtype
torch.complex128
int_tensor = torch.ones(1, dtype=torch.int)
int_tensor.dtype
torch.int32
long_tensor = torch.ones(1, dtype=torch.long)
long_tensor.dtype
torch.int64
uint_tensor = torch.ones(1, dtype=torch.uint8)
uint_tensor.dtype
torch.uint8
bool_tensor = torch.ones(1, dtype=torch.bool)
bool_tensor.dtype
torch.bool
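A tensor can be cast to another dtype with Tensor.to(); mixed-dtype arithmetic follows PyTorch's type-promotion rules. A minimal sketch, not from the original cells:
f = torch.ones(3, dtype=torch.float32)
f.to(torch.int64).dtype                          # explicit cast: torch.int64
(f + torch.ones(3, dtype=torch.float64)).dtype   # promotion: torch.float64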
Creation Operations
torch.is_tensor
x = torch.tensor([1, 2, 3])
x, torch.is_tensor(x)
(tensor([1, 2, 3]), True)
torch.set_default_device
torch.tensor([1.2, 3]).device
device(type='cpu')
torch.set_default_device('cuda')  # current device is 0
torch.tensor([1.2, 3]).device
device(type='cuda', index=0)
torch.set_default_device('cpu')
a = torch.arange(1000000)
a
tensor([ 0, 1, 2, ..., 999997, 999998, 999999])
%time a + 1
CPU times: user 20.7 ms, sys: 25.4 ms, total: 46.1 ms
Wall time: 7.37 ms
tensor([ 1, 2, 3, ..., 999998, 999999, 1000000])
torch.set_default_device('cuda')
a = torch.arange(1000000)
a
tensor([ 0, 1, 2, ..., 999997, 999998, 999999], device='cuda:0')
%time a + 1
CPU times: user 6.29 ms, sys: 0 ns, total: 6.29 ms
Wall time: 912 µs
tensor([ 1, 2, 3, ..., 999998, 999999, 1000000],
device='cuda:0')
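One caveat with the timing above: CUDA kernels are launched asynchronously, so wall-clock numbers may mostly measure the launch. A hedged sketch of a synchronized timing, assuming a CUDA device is available:
import time
torch.cuda.synchronize()          # wait for pending GPU work
start = time.perf_counter()
b = a + 1
torch.cuda.synchronize()          # wait for the addition to finish
print(f"{(time.perf_counter() - start) * 1e3:.3f} ms")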
torch.get_default_dtype
torch.get_default_dtype()  # initial default for floating point is torch.float32
torch.float32
torch.set_default_dtype(torch.float64)
torch.get_default_dtype()  # default is now changed to torch.float64
torch.float64
torch.set_printoptions
# Limit the precision of elements
torch.set_printoptions(precision=2)
torch.tensor([1.12345])
tensor([1.12], device='cuda:0')
# Limit the number of elements shown
torch.set_printoptions(threshold=5)
torch.arange(10)
tensor([0, 1, 2, ..., 7, 8, 9], device='cuda:0')
# Restore defaults
torch.set_printoptions(profile='default')
torch.tensor([1.12345])
tensor([1.1235], device='cuda:0')
torch.set_default_device('cpu')
torch.arange(10)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.as_tensor
a = np.array([1, 2, 3])
t = torch.as_tensor(a)
t
tensor([1, 2, 3])
t[0] = -1
a
array([-1, 2, 3])
a = np.array([1, 2, 3])
t = torch.as_tensor(a, device=torch.device('cuda'))
t
tensor([1, 2, 3], device='cuda:0')
t[0] = -1
a
array([1, 2, 3])
t
tensor([-1, 2, 3], device='cuda:0')
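For comparison, torch.from_numpy always shares memory with the NumPy array, while torch.tensor always copies; a minimal sketch, not from the original cells:
a = np.array([1, 2, 3])
shared = torch.from_numpy(a)   # shares memory with a
copied = torch.tensor(a)       # always copies
shared[0] = 99
a, copied                      # a sees the change, copied does not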
torch.empty, torch.zeros, torch.ones
torch.empty((2, 2))
tensor([[1.3554e-20, 3.0851e-41],
[1.3552e-20, 3.0851e-41]])
torch.zeros(2, 3)
tensor([[0., 0., 0.],
[0., 0., 0.]])
torch.zeros(5)
tensor([0., 0., 0., 0., 0.])
torch.ones(2, 3)
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.ones(5)
tensor([1., 1., 1., 1., 1.])
torch.arange, torch.linspace, torch.logspace
torch.arange(5), torch.arange(1, 4), torch.arange(1, 2.5, 0.5)
(tensor([0, 1, 2, 3, 4]), tensor([1, 2, 3]), tensor([1.0000, 1.5000, 2.0000]))
torch.linspace(3, 10, steps=5),\
torch.linspace(-10, 10, steps=5),\
torch.linspace(start=-10, end=10, steps=5),\
torch.linspace(start=-10, end=10, steps=1)
(tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]),
tensor([-10., -5., 0., 5., 10.]),
tensor([-10., -5., 0., 5., 10.]),
tensor([-10.]))
torch.logspace(start=-10, end=10, steps=5),\
torch.logspace(start=0.1, end=1.0, steps=5),\
torch.logspace(start=0.1, end=1.0, steps=1),\
torch.logspace(start=2, end=2, steps=1, base=2)
(tensor([1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]),
tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]),
tensor([1.2589]),
tensor([4.]))
torch.empty((2, 3), dtype=torch.int64)
tensor([[ 140634107035024, 140634107035024, 7454421801564381752],
[2322206376936961119, 7310597164893758754, 145]])
torch.full((2, 3), 3.141592)
tensor([[3.1416, 3.1416, 3.1416],
[3.1416, 3.1416, 3.1416]])
torch.quantize_per_tensor
torch.set_default_dtype(torch.float32)
torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8),\
torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr(),\
torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8)
(tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10),
tensor([ 0, 10, 20, 30], dtype=torch.uint8),
tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8,
quantization_scheme=torch.per_tensor_affine, scale=0.10000000149011612,
zero_point=10))
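The quantized tensor can be mapped back to floating point with Tensor.dequantize(), which applies (int_repr - zero_point) * scale; a minimal sketch, not from the original cells:
q = torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8)
q.int_repr(), q.dequantize()   # stored ints [0, 10, 20, 30] map back to [-1., 0., 1., 2.]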
torch.complex
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
z = torch.complex(real, imag)
z.dtype
torch.complex64
real, imag, z
(tensor([1., 2.]), tensor([3., 4.]), tensor([1.+3.j, 2.+4.j]))
torch.polar
import numpy as np
abs = torch.tensor([1, 2], dtype=torch.float64)
angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64)
z = torch.polar(abs, angle)
z
tensor([ 6.1232e-17+1.0000j, -1.4142e+00-1.4142j], dtype=torch.complex128)
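The magnitude and phase can be read back from the complex tensor; a minimal sketch, not from the original cells:
z.abs(), z.angle()   # magnitudes [1., 2.]; angle() returns phases wrapped to (-pi, pi]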
Indexing, Slicing, Joining, Mutating Ops
torch.cat
x = torch.randn(2, 3)
x
tensor([[-0.1340, 0.5254, -0.3770],
[-2.0310, -0.8961, -0.6459]])
torch.cat((x, x, x), 0)
tensor([[-0.1340, 0.5254, -0.3770],
[-2.0310, -0.8961, -0.6459],
[-0.1340, 0.5254, -0.3770],
[-2.0310, -0.8961, -0.6459],
[-0.1340, 0.5254, -0.3770],
[-2.0310, -0.8961, -0.6459]])
torch.cat((x, x), 1)
tensor([[-0.1340, 0.5254, -0.3770, -0.1340, 0.5254, -0.3770],
[-2.0310, -0.8961, -0.6459, -2.0310, -0.8961, -0.6459]])
torch.conj
x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
x, x.is_conj()
(tensor([-1.+1.j, -2.+2.j, 3.-3.j]), False)
y = torch.conj(x)
y, y.is_conj()
(tensor([-1.-1.j, -2.-2.j, 3.+3.j]), True)
torch.permute
x = torch.randn(2, 3, 5)
x.size()
torch.Size([2, 3, 5])
torch.permute(x, (2, 0, 1)).size()
torch.Size([5, 2, 3])
torch.reshape
a = torch.arange(4.)
a
tensor([0., 1., 2., 3.])
torch.reshape(a, (2, 2))
tensor([[0., 1.],
[2., 3.]])
b = torch.tensor([[0, 1], [2, 3]])
b
tensor([[0, 1],
[2, 3]])
torch.reshape(b, (-1,))
tensor([0, 1, 2, 3])
torch.movedim
t = torch.randn(2, 3, 5)
t.size()
torch.Size([2, 3, 5])
torch.movedim(t, 1, 0).shape
torch.Size([3, 2, 5])
torch.movedim(t, 1, 0)
tensor([[[-1.0370, -0.2811, 0.2693, 0.5935, -0.1354],
[-2.7575, -2.4650, 0.8077, -0.2873, -1.2993]],
[[-0.6173, -0.0460, -0.6329, 1.0519, -0.1674],
[-0.8958, -0.2828, 1.2355, -1.1782, -1.3597]],
[[-2.0996, -0.6692, -0.7840, 0.1171, 0.0334],
[ 0.4422, 0.2438, 1.0947, -1.2390, -1.4378]]])
torch.movedim(t, (1, 2), (0, 1)).shape
torch.Size([3, 5, 2])
torch.split
a = torch.arange(10).reshape(5, 2)
a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
torch.split(a, [1, 4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
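torch.chunk is the closely related call that takes a number of chunks instead of a chunk size; a minimal sketch, not from the original cells:
torch.chunk(a, 3)   # three chunks along dim 0, sizes 2, 2 and 1 -- same pieces as torch.split(a, 2)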
torch.t
x = torch.randn(())
x, torch.t(x)
(tensor(-0.6261), tensor(-0.6261))
x = torch.randn(3)
x, torch.t(x)
(tensor([-1.2207, -0.6549, -0.0028]), tensor([-1.2207, -0.6549, -0.0028]))
x = torch.randn(2, 3)
x, torch.t(x)
(tensor([[-2.0021, 1.3072, -0.9742],
[-1.8025, 0.5369, 0.2517]]),
tensor([[-2.0021, -1.8025],
[ 1.3072, 0.5369],
[-0.9742, 0.2517]]))
Rand
torch.rand
torch.rand(4)
tensor([0.7377, 0.8273, 0.2958, 0.8372])
torch.rand(2, 3)
tensor([[0.5028, 0.1841, 0.1133],
[0.4431, 0.0016, 0.8662]])
torch.randint
torch.randint(3, 5, (3,))
tensor([4, 3, 3])
torch.randint(10, (2, 2))
tensor([[5, 3],
[8, 1]])
torch.randint(3, 10, (2, 2))
tensor([[8, 3],
[8, 3]])
torch.randn
torch.randn(4)
tensor([0.2036, 0.3526, 0.7444, 1.0029])
torch.randn(2, 3)
tensor([[-0.3317, 3.1649, 2.7242],
[ 0.2243, 1.2105, 0.9819]])
torch.normal
torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1))
tensor([ 0.3426, 0.6076, 3.0707, 3.1876, 5.1126, 6.5160, 6.9380, 8.2192,
8.9378, 10.0406])
torch.normal(mean=0.5, std=torch.arange(1., 6.))
tensor([1.4364, 1.6189, 0.1503, 2.7895, 3.0156])
torch.normal(mean=torch.arange(1., 6.))
tensor([0.8270, 0.5561, 2.5076, 2.7576, 2.9344])
torch.normal(2, 3, size=(1, 4))
tensor([[-1.9720, 2.4059, 3.3461, 0.6155]])
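All of the draws above change from run to run; seeding the global generator with torch.manual_seed makes them reproducible. A minimal sketch, not from the original cells:
torch.manual_seed(0)
first = torch.rand(3)
torch.manual_seed(0)
second = torch.rand(3)
torch.equal(first, second)   # True: same seed, same sequence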
Save and Load
# Save to file
import io
x = torch.tensor([0, 1, 2, 3, 4])
torch.save(x, 'Data/tensor.pt')
# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.save(x, buffer)
torch.load('tensors.pt', weights_only=True)
torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
with open('tensor.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())
torch.load(buffer, weights_only=False)
torch.load('module.pt', encoding='ascii', weights_only=False)
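torch.save and torch.load also round-trip containers of tensors, which is the usual checkpoint pattern; a hedged sketch (the file name 'checkpoint.pt' is illustrative, not from the original notebook):
state = {'weights': torch.rand(2, 2), 'step': torch.tensor(100)}
torch.save(state, 'checkpoint.pt')
restored = torch.load('checkpoint.pt', weights_only=True)
restored['step']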
Parallelism
torch.get_num_threads()
6
torch.set_num_threads(12)
torch.get_num_interop_threads()
6
torch.set_num_interop_threads(12)
Locally disabling gradient computation
x = torch.zeros(1, requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad
False
is_train = False
with torch.set_grad_enabled(is_train):
    y = x * 2
y.requires_grad
False
torch.set_grad_enabled(True)  # this can also be used as a function
y = x * 2
y.requires_grad
True
torch.set_grad_enabled(False)
y = x * 2
y.requires_grad
False
no_grad
x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad
False
@torch.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad
False
@torch.no_grad()
def tripler(x):
    return x * 3
z = tripler(x)
z.requires_grad
False
# factory function exception
with torch.no_grad():
    a = torch.nn.Parameter(torch.rand(10))
a.requires_grad
True
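The inverse context manager, torch.enable_grad, re-enables tracking inside a no_grad block; a minimal sketch, not from the original cells:
with torch.no_grad():
    with torch.enable_grad():
        y = x * 2
y.requires_grad   # True: enable_grad overrides the surrounding no_grad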
Math operations
torch.abs(torch.tensor([-1, -2, 3]))
tensor([1, 2, 3])
a = torch.randn(4)
a
tensor([ 0.1512, -0.5116, 1.4073, -0.9758])
torch.add(a, 20)
tensor([20.1512, 19.4884, 21.4073, 19.0242])
b = torch.randn(4)
c = torch.randn(4, 1)
b, c, torch.add(b, c, alpha=10)
(tensor([-0.5080, -0.2541, 0.7946, -0.7497]),
tensor([[ 0.3399],
[-0.8642],
[-1.4262],
[ 0.3894]]),
tensor([[ 2.8912, 3.1451, 4.1939, 2.6495],
[ -9.1502, -8.8963, -7.8475, -9.3918],
[-14.7701, -14.5162, -13.4674, -15.0117],
[ 3.3861, 3.6400, 4.6887, 3.1444]]))
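The alpha argument scales the second operand before adding, so the call above is b + 10 * c with the (4,) and (4, 1) shapes broadcast to (4, 4); a minimal check, not from the original cells:
torch.allclose(torch.add(b, c, alpha=10), b + 10 * c)   # True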
torch.asin(a)
tensor([ 0.1518, -0.5370, nan, -1.3503])
torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8))
tensor([ 0, 1, -4], dtype=torch.int8)
torch.ceil(a)
tensor([1., -0., 2., -0.])
torch.clamp(a, min=-0.5, max=0.5)
tensor([ 0.1512, -0.5000, 0.5000, -0.5000])
min = torch.linspace(-1, 1, steps=4)
min, torch.clamp(a, min=min)
(tensor([-1.0000, -0.3333, 0.3333, 1.0000]),
tensor([ 0.1512, -0.3333, 1.4073, 1.0000]))
Gradient
# Estimates the gradient of f(x)=x^2 at points [-2, -1, 1, 4]
coordinates = (torch.tensor([-2., -1., 1., 4.]),)
values = torch.tensor([4., 1., 1., 16.])
torch.gradient(values, spacing=coordinates)
(tensor([-3., -2., 2., 5.]),)
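At the boundaries torch.gradient falls back to one-sided differences, while at interior points it uses a second-order estimate weighted by the gap sizes; a hedged sketch reproducing the -2. at x = -1 above (left gap hl = 1, right gap hr = 2):
hl, hr = 1.0, 2.0                  # gaps around x = -1
f_l, f_c, f_r = 4.0, 1.0, 1.0      # f(-2), f(-1), f(1)
(hl**2 * f_r - hr**2 * f_l + (hr**2 - hl**2) * f_c) / (hl * hr * (hl + hr))   # -2.0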
# Estimates the gradient of the R^2 -> R function whose samples are
# described by the tensor t. Implicit coordinates are [0, 1] for the outermost
# dimension and [0, 1, 2, 3] for the innermost dimension, and the function
# estimates the partial derivative for both dimensions.
t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]])
torch.gradient(t)
(tensor([[ 9., 18., 36., 72.],
[ 9., 18., 36., 72.]]),
tensor([[ 1.0000, 1.5000, 3.0000, 4.0000],
[10.0000, 15.0000, 30.0000, 40.0000]]))
# A scalar value for spacing modifies the relationship between tensor indices
# and input coordinates by multiplying the indices to find the
# coordinates. For example, below the indices of the innermost dimension
# 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of
# the outermost dimension 0, 1 translate to coordinates of [0, 2].
torch.gradient(t, spacing=2.0)  # dim = None (implicitly [0, 1])
# doubling the spacing between samples halves the estimated partial gradients.
(tensor([[ 4.5000, 9.0000, 18.0000, 36.0000],
[ 4.5000, 9.0000, 18.0000, 36.0000]]),
tensor([[ 0.5000, 0.7500, 1.5000, 2.0000],
[ 5.0000, 7.5000, 15.0000, 20.0000]]))
# Estimates only the partial derivative for dimension 1
torch.gradient(t, dim=1)  # spacing = None (implicitly 1.)
(tensor([[ 1.0000, 1.5000, 3.0000, 4.0000],
[10.0000, 15.0000, 30.0000, 40.0000]]),)
# When spacing is a list of scalars, the relationship between the tensor
# indices and input coordinates changes based on dimension.
# For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate
# to coordinates of [0, 2, 4, 6], and the indices of the outermost dimension
# 0, 1 translate to coordinates of [0, 3].
torch.gradient(t, spacing=[3., 2.])
(tensor([[ 3., 6., 12., 24.],
[ 3., 6., 12., 24.]]),
tensor([[ 0.5000, 0.7500, 1.5000, 2.0000],
[ 5.0000, 7.5000, 15.0000, 20.0000]]))
# The same samples, now with explicit coordinates: [0, 2] for the outermost
# dimension and [0, 3, 6, 9] for the innermost dimension.
coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9]))
torch.gradient(t, spacing=coords)
(tensor([[ 4.5000, 9.0000, 18.0000, 36.0000],
[ 4.5000, 9.0000, 18.0000, 36.0000]]),
tensor([[ 0.3333, 0.5000, 1.0000, 1.3333],
[ 3.3333, 5.0000, 10.0000, 13.3333]]))
Reduction Ops
a = torch.randn(4, 4)
a
tensor([[ 0.5653, -1.6250, -1.7234, 1.6898],
[ 0.0964, 0.4920, -0.8990, 0.2050],
[-0.0388, -0.3062, -2.7269, 2.2214],
[-0.4256, -1.3614, 1.6437, 1.3073]])
torch.argmax(a), torch.argmin(a)
(tensor(11), tensor(10))
torch.argmax(a, dim=1), torch.argmin(a, dim=1)
(tensor([3, 1, 3, 2]), tensor([2, 2, 2, 1]))
torch.argmin(a, dim=1, keepdim=True)
tensor([[2],
[2],
[2],
[1]])
torch.amax(a, 0)
tensor([0.5653, 0.4920, 1.6437, 2.2214])
torch.amin(a, 0)
tensor([-0.4256, -1.6250, -2.7269, 0.2050])
Tests if all elements in input evaluate to True.
torch.all(a)
tensor(True)
Tests if any element in input evaluates to True.
torch.any(a)
tensor(True)
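Since the tensor above is random and nonzero, both calls return True; a deterministic sketch where they differ, not from the original cells:
flags = torch.tensor([True, False, True])
torch.all(flags), torch.any(flags)   # (tensor(False), tensor(True))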
x = torch.randn(4)
y = torch.randn(4)
x, y
(tensor([-0.8039, -1.4679, 0.4484, -0.5348]),
tensor([-0.5979, -1.1656, -1.2298, -0.2573]))
torch.dist(x, y, 3.5)
tensor(1.6806)
torch.dist(x, y, 3)
tensor(1.6850)
torch.dist(x, y, 0)
tensor(4.)
torch.dist(x, y, 2)
tensor(1.7399)
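torch.dist(x, y, p) is the p-norm of the difference, so it can be cross-checked against torch.linalg.vector_norm; a minimal sketch, not from the original cells:
torch.allclose(torch.dist(x, y, 2), torch.linalg.vector_norm(x - y, ord=2))   # True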
a = torch.randn(1, 3)
a
tensor([[-0.7396, -1.3435, 1.0013]])
torch.mean(a)
tensor(-0.3606)
torch.prod(a)
tensor(0.9949)
a = torch.randn(4, 2)
a
tensor([[-0.0451, -0.4858],
[-0.1007, 0.4423],
[-0.2149, 0.4494],
[ 0.7059, 0.0417]])
torch.prod(a, 1)
tensor([ 0.0219, -0.0446, -0.0966, 0.0295])
torch.std(a, dim=1, keepdim=True)
tensor([[0.3116],
[0.3840],
[0.4697],
[0.4696]])
torch.sum(a)
tensor(0.7927)
torch.sum(a, 0)
tensor([0.3451, 0.4476])
torch.var(a, dim=0, keepdim=True)
tensor([[0.1756, 0.1951]])
torch.var_mean(a, dim=0, keepdim=True)
(tensor([[0.1756, 0.1951]]), tensor([[0.0863, 0.1119]]))
Comparison Ops
torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
tensor([[ True, False],
[False, True]])
torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]))
False
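torch.allclose checks |input - other| <= atol + rtol * |other| elementwise (defaults rtol=1e-05, atol=1e-08); the 1e-07 vs 1e-08 pair above fails that test. A minimal sketch of the criterion, not from the original cells:
p = torch.tensor([10000., 1e-07])
q = torch.tensor([10000.1, 1e-08])
(p - q).abs() <= 1e-08 + 1e-05 * q.abs()   # tensor([ True, False]) -> allclose is False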
torch.isnan(torch.tensor([1, float('nan'), 2]))
tensor([False, True, False])
Broadcast
x = torch.tensor([1, 2, 3])
torch.broadcast_to(x, (3, 3))
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
torch.Size([1, 3, 2])
Covariance
x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T
x, torch.cov(x)
(tensor([[0, 1, 2],
[2, 1, 0]]),
tensor([[ 1., -1.],
[-1., 1.]]))
torch.cov(x, correction=0)
tensor([[ 0.6667, -0.6667],
[-0.6667, 0.6667]])
fw = torch.randint(1, 10, (3,))
aw = torch.rand(3)
fw, aw, torch.cov(x, fweights=fw, aweights=aw)
(tensor([2, 2, 9]),
tensor([0.1281, 0.5609, 0.1982]),
tensor([[ 0.4583, -0.4583],
[-0.4583, 0.4583]]))
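torch.cov treats each row as a variable and each column as an observation, so the default (correction=1) result can be reproduced from the definition; a minimal sketch, not from the original cells:
obs = x.float()
centered = obs - obs.mean(dim=1, keepdim=True)
centered @ centered.T / (obs.shape[1] - 1)   # matches torch.cov(x) above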
Diagonal
a = torch.randn(3)
a, torch.diag(a)
(tensor([-0.0159, -0.2550, -0.3652]),
tensor([[-0.0159, 0.0000, 0.0000],
[ 0.0000, -0.2550, 0.0000],
[ 0.0000, 0.0000, -0.3652]]))
torch.diag(a, 1)
tensor([[ 0.0000, -0.0159, 0.0000, 0.0000],
[ 0.0000, 0.0000, -0.2550, 0.0000],
[ 0.0000, 0.0000, 0.0000, -0.3652],
[ 0.0000, 0.0000, 0.0000, 0.0000]])
a = torch.randn(3, 3)
a
tensor([[ 0.8558, -0.8227, 1.5082],
[-0.6674, -0.0815, -0.5271],
[-0.1011, -0.3513, 0.1919]])
torch.diag(a, 0)
tensor([ 0.8558, -0.0815, 0.1919])
torch.diag(a, 1)
tensor([-0.8227, -0.5271])
Diff
a = torch.tensor([1, 3, 2])
a, torch.diff(a)
(tensor([1, 3, 2]), tensor([ 2, -1]))
torch.gradient(a)
(tensor([ 2.0000, 0.5000, -1.0000]),)
b = torch.tensor([4, 5])
torch.diff(a, append=b)
tensor([ 2, -1, 2, 1])
c = torch.tensor([[1, 2, 3], [3, 4, 5]])
c
tensor([[1, 2, 3],
[3, 4, 5]])
torch.diff(c, dim=0)
tensor([[2, 2, 2]])
torch.diff(c, dim=1)
tensor([[1, 1],
[1, 1]])
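With the default n=1, torch.diff along a dimension is just the slicewise difference; a minimal check, not from the original cells:
torch.equal(torch.diff(c, dim=1), c[:, 1:] - c[:, :-1])   # True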
Einsum
# trace
torch.einsum('ii', torch.randn(4, 4))
tensor(2.8873)
# diagonal
torch.einsum('ii->i', torch.randn(4, 4))
tensor([ 0.1727, -1.2934, 0.1134, 0.6699])
# outer product
x = torch.randn(5)
y = torch.randn(4)
x, y
(tensor([ 0.8261, 2.2608, 0.5666, -2.3195, -1.1706]),
tensor([-0.1575, 1.3682, -1.6248, -0.4177]))
torch.einsum('i,j->ij', x, y)
tensor([[-0.1301, 1.1302, -1.3423, -0.3451],
[-0.3561, 3.0932, -3.6735, -0.9444],
[-0.0892, 0.7751, -0.9206, -0.2367],
[ 0.3653, -3.1735, 3.7688, 0.9689],
[ 0.1844, -1.6016, 1.9021, 0.4890]])
# batch matrix multiplication
As = torch.randn(3, 2, 5)
Bs = torch.randn(3, 5, 4)
As, Bs
(tensor([[[ 1.7602, 1.3836, -0.8395, 1.2415, 0.2106],
[-1.0765, -1.6569, 1.0785, 2.6039, -0.2173]],
[[-0.3606, 0.7737, -0.3265, -0.3982, -0.1795],
[ 0.4787, 0.3987, 0.5030, 2.0617, -1.2417]],
[[ 1.2606, -1.7532, 1.2267, 0.2588, -0.6794],
[ 0.4158, 0.3457, -0.3235, -1.4921, -1.0168]]]),
tensor([[[-1.1582, -0.5702, -0.5820, 1.4487],
[ 1.6068, 0.6026, 1.4152, 0.7676],
[-1.3955, -0.1634, -0.9673, -1.1214],
[ 0.8568, -1.2561, -1.0304, 0.8901],
[ 0.1840, 0.1531, 1.4496, 1.6163]],
[[-0.0585, -0.2001, 0.1084, -0.9644],
[-0.6568, 1.1768, -1.4877, 0.0249],
[-0.1339, 0.2070, 1.4734, 0.0627],
[ 0.3322, 0.5430, -0.5031, 0.1328],
[ 0.0214, 1.1490, 0.4914, 0.2774]],
[[ 0.3791, 0.3227, -2.6659, 0.6282],
[ 1.6384, 0.2688, -2.1299, 1.0743],
[ 0.1086, -0.0778, -0.2566, 0.0419],
[ 0.0142, 1.9122, -1.7333, 1.9668],
[ 0.1300, 0.8281, -1.0565, -1.4234]]]))
torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[ 2.4586, -1.5598, 0.7718, 5.9990],
[-0.7296, -3.8648, -5.7596, -2.0744]],
[[-0.5795, 0.4926, -1.5591, 0.2439],
[ 0.3011, 0.1702, -1.4475, -0.4909]],
[[-2.3460, -0.2277, 0.3280, 0.4358],
[ 0.5354, -3.4429, 1.8989, -0.8684]]])
# with sublist format and ellipsis
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[ 2.4586, -1.5598, 0.7718, 5.9990],
[-0.7296, -3.8648, -5.7596, -2.0744]],
[[-0.5795, 0.4926, -1.5591, 0.2439],
[ 0.3011, 0.1702, -1.4475, -0.4909]],
[[-2.3460, -0.2277, 0.3280, 0.4358],
[ 0.5354, -3.4429, 1.8989, -0.8684]]])
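The 'bij,bjk->bik' contraction is exactly the batched matrix product, so it can be cross-checked against torch.bmm; a minimal sketch, not from the original cells:
torch.allclose(torch.einsum('bij,bjk->bik', As, Bs), torch.bmm(As, Bs))   # True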
# batch permute
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])
# equivalent to torch.nn.functional.bilinear
A = torch.randn(3, 5, 4)
l = torch.randn(2, 5)
r = torch.randn(2, 4)
A, l, r
(tensor([[[-0.5127, -0.0817, -0.5872, -2.0090],
[-0.3169, -1.0569, -0.2818, 1.8631],
[-1.5130, -0.7615, -0.3052, 0.7982],
[-0.3297, 1.6522, 0.9849, -1.5223],
[-0.5275, 0.1215, -0.5165, -0.4254]],
[[ 0.7555, -0.9271, 2.2486, -0.5548],
[ 0.0759, -0.3391, -1.3095, -0.2525],
[-0.2529, -1.0799, 0.5418, 0.4821],
[ 0.8987, -0.0494, -0.5371, -1.5568],
[-0.2188, 0.9023, 0.8624, 0.7310]],
[[ 2.1191, 0.3084, -0.8052, -0.2008],
[-0.3211, -0.4985, 0.3982, -1.0806],
[ 0.6403, -0.8154, -2.6253, 2.4096],
[ 0.5290, -1.3181, -0.8800, 0.9082],
[-0.8891, -1.0462, 1.2305, -0.7983]]]),
tensor([[ 1.1455, -1.6298, 0.2697, 1.6974, -1.9843],
[ 0.5881, -0.7756, 0.5787, -0.1486, -1.0844]]),
tensor([[-0.4899, 0.0820, -0.1346, -0.0343],
[-0.6835, 0.4302, -0.0725, 0.6343]]))
torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[ 0.2351, -1.6662, -2.1157],
[-0.7526, -1.4344, 0.7918]])
torch.flatten
t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])
torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
torch.histogram
torch.histogram(torch.tensor([1., 2, 1]),
                bins=4,
                range=(0., 3.),
                weight=torch.tensor([1., 2., 4.]))
torch.return_types.histogram(
hist=tensor([0., 5., 2., 0.]),
bin_edges=tensor([0.0000, 0.7500, 1.5000, 2.2500, 3.0000]))
torch.histogram(torch.tensor([1., 2, 1]),
                bins=4, range=(0., 3.),
                weight=torch.tensor([1., 2., 4.]),
                density=True)
torch.return_types.histogram(
hist=tensor([0.0000, 0.9524, 0.3810, 0.0000]),
bin_edges=tensor([0.0000, 0.7500, 1.5000, 2.2500, 3.0000]))
torch.meshgrid
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
x, y
(tensor([1, 2, 3]), tensor([4, 5, 6]))
grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
grid_x, grid_y
(tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]]),
tensor([[4, 5, 6],
[4, 5, 6],
[4, 5, 6]]))
torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
            torch.cartesian_prod(x, y))
import matplotlib.pyplot as plt
xs = torch.linspace(-5, 5, steps=100)
ys = torch.linspace(-5, 5, steps=100)
x, y = torch.meshgrid(xs, ys, indexing='xy')
z = torch.sin(torch.sqrt(x * x + y * y))
ax = plt.axes(projection='3d')
ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
plt.show()
torch.matmul
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
tensor1, tensor2
(tensor([0.6342, 0.6817, 0.5164]), tensor([-0.5737, 1.6013, 0.5605]))
torch.matmul(tensor1, tensor2)
tensor(1.0173)
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
tensor1, tensor2
(tensor([[-0.7989, -0.2968, -0.5672, -0.3673],
[ 0.3948, 0.9695, -0.9593, 0.5856],
[-0.6744, 1.4336, -1.3985, -0.2974]]),
tensor([-2.0189, -1.3822, -0.2220, 0.0440]))
torch.matmul(tensor1, tensor2).size(), torch.matmul(tensor1, tensor2)
(torch.Size([3]), tensor([ 2.1329, -1.8983, -0.3226]))
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
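For the broadcasted cases, matmul behaves as if the non-batched operand were first expanded to the batch shape; a minimal check, not from the original cells:
torch.allclose(torch.matmul(tensor1, tensor2),
               torch.bmm(tensor1, tensor2.expand(10, 4, 5)))   # True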