❤️ Lovely JAX
Read full docs here
Install
pip install lovely-jax
How to use
How often do you find yourself debugging JAX code? You dump an array to the cell output, and see this:
numbers
DeviceArray([[[-0.35405433, -0.33692956, -0.4054286 , ..., -0.55955136,
-0.4739276 , 2.2489083 ],
[-0.4054286 , -0.42255333, -0.49105233, ..., -0.91917115,
-0.8506721 , 2.1632845 ],
[-0.4739276 , -0.4739276 , -0.5424266 , ..., -1.0390445 ,
-1.0390445 , 2.1975338 ],
...,
[-0.9020464 , -0.8335474 , -0.9362959 , ..., -1.4671633 ,
-1.2959158 , 2.2317834 ],
[-0.8506721 , -0.78217316, -0.9362959 , ..., -1.6041614 ,
-1.5014129 , 2.1804092 ],
[-0.8335474 , -0.81642264, -0.9705454 , ..., -1.6555357 ,
-1.5527872 , 2.11191 ]],
[[-0.19747896, -0.19747896, -0.30252096, ..., -0.47759098,
-0.37254897, 2.4110641 ],
[-0.24999997, -0.23249297, -0.33753496, ..., -0.705182 ,
-0.670168 , 2.3585434 ],
[-0.30252096, -0.28501397, -0.39005598, ..., -0.740196 ,
-0.810224 , 2.3760502 ],
...,
[-0.42507 , -0.23249297, -0.37254897, ..., -1.0903361 ,
-1.0203081 , 2.4285715 ],
[-0.39005598, -0.23249297, -0.42507 , ..., -1.230392 ,
-1.230392 , 2.4110641 ],
[-0.40756297, -0.28501397, -0.47759098, ..., -1.2829131 ,
-1.2829131 , 2.3410363 ]],
[[-0.67154676, -0.9852723 , -0.88069713, ..., -0.9678431 ,
-0.68897593, 2.3959913 ],
[-0.7238344 , -1.0724182 , -0.9678431 , ..., -1.2467101 ,
-1.0201306 , 2.3262744 ],
[-0.82840955, -1.1247058 , -1.0201306 , ..., -1.2641394 ,
-1.1595641 , 2.3785625 ],
...,
[-1.229281 , -1.4732897 , -1.3861438 , ..., -1.5081482 ,
-1.2641394 , 2.5179958 ],
[-1.1944225 , -1.4558606 , -1.4210021 , ..., -1.6475817 ,
-1.4732897 , 2.4308496 ],
[-1.229281 , -1.5255773 , -1.5081482 , ..., -1.68244 ,
-1.5255773 , 2.3611329 ]]], dtype=float32)
Was it really useful for you, as a human, to see all these numbers?
What is the shape? The size?
What are the statistics?
Are any of the values nan or inf?
Is it an image of a man holding a tench?
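lovely-jax answers all of these at a glance. For comparison, here is how you might dig the same answers out by hand with plain jax.numpy (a sketch; `x` is a stand-in for your mystery array):

```python
import jax.numpy as jnp

x = jnp.zeros((3, 196, 196))                   # stand-in for your array

print(x.shape, x.size)                         # (3, 196, 196) 115248
print(x.min(), x.max())                        # value range
print(x.mean(), x.std())                       # basic statistics
print(jnp.isnan(x).any(), jnp.isinf(x).any())  # any bad values?
```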
import lovely_jax as lj
lj.monkey_patch()
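Under the hood, monkey-patching a repr just means replacing the class's `__repr__` with a compact one. A toy illustration of the general idea (not lovely-jax's actual code):

```python
class Box:
    """Toy container standing in for an array type."""
    def __init__(self, data):
        self.data = data

def short_repr(self):
    # Compact summary instead of dumping every element
    return f"Box(n={len(self.data)})"

Box.__repr__ = short_repr        # swap in the compact repr for the whole class
print(Box(list(range(1000))))    # Box(n=1000)
```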
__repr__
numbers # jnp.DeviceArray
DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
Better, huh?
numbers[1,:6,1] # Still shows values if there are not too many.
DeviceArray[6] x∈[-0.443, -0.197] μ=-0.311 σ=0.083 [-0.197, -0.232, -0.285, -0.373, -0.443, -0.338]
spicy = numbers.flatten()[:12].copy()
spicy = (spicy.at[0].mul(10000)
              .at[1].divide(10000)
              .at[2].set(float('inf'))
              .at[3].set(float('-inf'))
              .at[4].set(float('nan'))
              .reshape((2,6)))
spicy # Spicy stuff
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
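The `.at[...]` chain above is needed because JAX arrays are immutable: each indexed update returns a new array rather than modifying the original. A minimal example:

```python
import jax.numpy as jnp

a = jnp.arange(3.0)
# a[0] = 5.0  # would raise TypeError: JAX arrays are immutable
b = a.at[0].set(5.0)   # functional update: returns a new array
print(a)  # [0. 1. 2.]  (unchanged)
print(b)  # [5. 1. 2.]
```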
jnp.zeros((10, 10)) # A zero array - make it obvious
DeviceArray[10, 10] n=100 all_zeros
spicy.v # Verbose
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
DeviceArray([[-3.5405432e+03, -3.3692959e-05, inf,
-inf, nan, -4.0542859e-01],
[-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
-5.5955136e-01, -5.4242659e-01, -5.0817710e-01]], dtype=float32)
spicy.p # The plain old way
DeviceArray([[-3.5405432e+03, -3.3692959e-05, inf,
-inf, nan, -4.0542859e-01],
[-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
-5.5955136e-01, -5.4242659e-01, -5.0817710e-01]], dtype=float32)
Going .deeper
numbers.deeper
DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973
DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178
# You can go deeper if you need to
numbers[:,:3,:5].deeper(2)
DeviceArray[3, 3, 5] n=45 x∈[-1.316, -0.197] μ=-0.593 σ=0.302
DeviceArray[3, 5] n=15 x∈[-0.765, -0.337] μ=-0.492 σ=0.119
DeviceArray[5] x∈[-0.440, -0.337] μ=-0.385 σ=0.037 [-0.354, -0.337, -0.405, -0.440, -0.388]
DeviceArray[5] x∈[-0.662, -0.405] μ=-0.512 σ=0.097 [-0.405, -0.423, -0.491, -0.577, -0.662]
DeviceArray[5] x∈[-0.765, -0.474] μ=-0.580 σ=0.112 [-0.474, -0.474, -0.542, -0.645, -0.765]
DeviceArray[3, 5] n=15 x∈[-0.513, -0.197] μ=-0.321 σ=0.096
DeviceArray[5] x∈[-0.303, -0.197] μ=-0.243 σ=0.049 [-0.197, -0.197, -0.303, -0.303, -0.215]
DeviceArray[5] x∈[-0.408, -0.232] μ=-0.327 σ=0.075 [-0.250, -0.232, -0.338, -0.408, -0.408]
DeviceArray[5] x∈[-0.513, -0.285] μ=-0.394 σ=0.091 [-0.303, -0.285, -0.390, -0.478, -0.513]
DeviceArray[3, 5] n=15 x∈[-1.316, -0.672] μ=-0.964 σ=0.170
DeviceArray[5] x∈[-0.985, -0.672] μ=-0.846 σ=0.110 [-0.672, -0.985, -0.881, -0.776, -0.916]
DeviceArray[5] x∈[-1.212, -0.724] μ=-0.989 σ=0.160 [-0.724, -1.072, -0.968, -0.968, -1.212]
DeviceArray[5] x∈[-1.316, -0.828] μ=-1.058 σ=0.160 [-0.828, -1.125, -1.020, -1.003, -1.316]
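Conceptually, .deeper just walks down the leading axis and summarizes each sub-array, recursing to the requested depth. A rough sketch of the idea (the summary format here is simplified, not lovely-jax's actual output):

```python
import jax.numpy as jnp

def summarize(a):
    # Simplified one-line summary: shape, count, range, mean, std
    return (f"[{', '.join(map(str, a.shape))}] n={a.size} "
            f"x∈[{float(a.min()):.3f}, {float(a.max()):.3f}] "
            f"μ={float(a.mean()):.3f} σ={float(a.std()):.3f}")

def deeper(a, depth=1, level=0):
    lines = ["  " * level + summarize(a)]
    if depth > 0 and a.ndim > 1:
        for sub in a:                        # walk the leading axis
            lines += deeper(sub, depth - 1, level + 1)
    return lines

print("\n".join(deeper(jnp.zeros((3, 4, 5)), depth=1)))
```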
Without .monkey_patch
lj.lovely(spicy)
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
lj.lovely(spicy, verbose=True)
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
DeviceArray([[-3.5405432e+03, -3.3692959e-05, inf,
-inf, nan, -4.0542859e-01],
[-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
-5.5955136e-01, -5.4242659e-01, -5.0817710e-01]], dtype=float32)
lj.lovely(numbers, depth=1)
DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973
DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178