❤️ Lovely JAX

Read full docs here

Install

pip install lovely-jax

How to use

How often do you find yourself debugging JAX code? You dump an array to the cell output, and see this:

numbers
DeviceArray([[[-0.35405433, -0.33692956, -0.4054286 , ..., -0.55955136,
               -0.4739276 ,  2.2489083 ],
              [-0.4054286 , -0.42255333, -0.49105233, ..., -0.91917115,
               -0.8506721 ,  2.1632845 ],
              [-0.4739276 , -0.4739276 , -0.5424266 , ..., -1.0390445 ,
               -1.0390445 ,  2.1975338 ],
              ...,
              [-0.9020464 , -0.8335474 , -0.9362959 , ..., -1.4671633 ,
               -1.2959158 ,  2.2317834 ],
              [-0.8506721 , -0.78217316, -0.9362959 , ..., -1.6041614 ,
               -1.5014129 ,  2.1804092 ],
              [-0.8335474 , -0.81642264, -0.9705454 , ..., -1.6555357 ,
               -1.5527872 ,  2.11191   ]],

             [[-0.19747896, -0.19747896, -0.30252096, ..., -0.47759098,
               -0.37254897,  2.4110641 ],
              [-0.24999997, -0.23249297, -0.33753496, ..., -0.705182  ,
               -0.670168  ,  2.3585434 ],
              [-0.30252096, -0.28501397, -0.39005598, ..., -0.740196  ,
               -0.810224  ,  2.3760502 ],
              ...,
              [-0.42507   , -0.23249297, -0.37254897, ..., -1.0903361 ,
               -1.0203081 ,  2.4285715 ],
              [-0.39005598, -0.23249297, -0.42507   , ..., -1.230392  ,
               -1.230392  ,  2.4110641 ],
              [-0.40756297, -0.28501397, -0.47759098, ..., -1.2829131 ,
               -1.2829131 ,  2.3410363 ]],

             [[-0.67154676, -0.9852723 , -0.88069713, ..., -0.9678431 ,
               -0.68897593,  2.3959913 ],
              [-0.7238344 , -1.0724182 , -0.9678431 , ..., -1.2467101 ,
               -1.0201306 ,  2.3262744 ],
              [-0.82840955, -1.1247058 , -1.0201306 , ..., -1.2641394 ,
               -1.1595641 ,  2.3785625 ],
              ...,
              [-1.229281  , -1.4732897 , -1.3861438 , ..., -1.5081482 ,
               -1.2641394 ,  2.5179958 ],
              [-1.1944225 , -1.4558606 , -1.4210021 , ..., -1.6475817 ,
               -1.4732897 ,  2.4308496 ],
              [-1.229281  , -1.5255773 , -1.5081482 , ..., -1.68244   ,
               -1.5255773 ,  2.3611329 ]]], dtype=float32)

Was it really useful for you, as a human, to see all these numbers?

What is the shape? The size?
What are the statistics?
Are any of the values nan or inf?
Is it an image of a man holding a tench?
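Answering the first few of those questions by hand takes several separate calls. Here is a minimal sketch of that manual routine, using only standard `jax.numpy` functions. The small array `x` is a hypothetical stand-in for `numbers`, and computing the statistics over finite values only is a choice made for this example, not something the text above specifies.

```python
import jax.numpy as jnp

# Hypothetical stand-in for `numbers`, with a nan and an inf mixed in.
x = jnp.array([[1.0, -2.0, 3.0], [0.5, jnp.nan, jnp.inf]])

shape = x.shape                        # what is the shape?
size = x.size                          # the size?
has_nan = bool(jnp.isnan(x).any())     # any nan?
has_inf = bool(jnp.isinf(x).any())     # any inf?

# Statistics over the finite values only, so nan/inf do not poison them.
# Boolean-mask indexing like this works eagerly (outside jax.jit).
finite = x[jnp.isfinite(x)]
stats = {
    "min": float(finite.min()),
    "max": float(finite.max()),
    "mean": float(finite.mean()),
    "std": float(finite.std()),
}

print(shape, size, has_nan, has_inf, stats)
```

That is a lot of typing for something you want at a glance.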

import jax.numpy as jnp
import lovely_jax as lj
lj.monkey_patch()

__repr__

numbers # DeviceArray
DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073

Better, huh?

numbers[1,:6,1] # Still shows values if there are not too many.
DeviceArray[6] x∈[-0.443, -0.197] μ=-0.311 σ=0.083 [-0.197, -0.232, -0.285, -0.373, -0.443, -0.338]

spicy = numbers.flatten()[:12].clone()

spicy = (spicy  .at[0].mul(10000)
                .at[1].divide(10000)
                .at[2].set(float('inf'))
                .at[3].set(float('-inf'))
                .at[4].set(float('nan'))
                .reshape((2,6)))
spicy # Spicy stuff
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!

jnp.zeros((10, 10)) # A zero array - make it obvious
DeviceArray[10, 10] n=100 all_zeros

spicy.v # Verbose
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
DeviceArray([[-3.5405432e+03, -3.3692959e-05,            inf,
                        -inf,            nan, -4.0542859e-01],
             [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
              -5.5955136e-01, -5.4242659e-01, -5.0817710e-01]],            dtype=float32)

spicy.p # The plain old way
DeviceArray([[-3.5405432e+03, -3.3692959e-05,            inf,
                        -inf,            nan, -4.0542859e-01],
             [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
              -5.5955136e-01, -5.4242659e-01, -5.0817710e-01]],            dtype=float32)
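To make the one-line format above less magical, here is a rough sketch of a formatter that produces a similar summary. `summarize` is a hypothetical helper written for this example, not the lovely-jax implementation. The number formatting is simplified, and the choice to compute statistics over finite values only is an assumption inferred from the `spicy` output above.

```python
import jax.numpy as jnp

def summarize(x):
    """Hypothetical helper: one-line summary loosely modelled on the output above."""
    x = jnp.asarray(x)
    parts = [f"Array{list(x.shape)}", f"n={x.size}"]

    # Make an all-zero array obvious instead of printing trivial statistics.
    if x.size and not bool(jnp.any(x != 0)):
        parts.append("all_zeros")
        return " ".join(parts)

    # Assumption: statistics ignore nan/inf so they stay readable.
    finite = x[jnp.isfinite(x)]
    if finite.size:
        parts.append(f"x∈[{float(finite.min()):.3f}, {float(finite.max()):.3f}]")
        parts.append(f"μ={float(finite.mean()):.3f}")
        parts.append(f"σ={float(finite.std()):.3f}")

    # Flag the non-finite values loudly at the end of the line.
    if bool(jnp.any(x == jnp.inf)):
        parts.append("+inf!")
    if bool(jnp.any(x == -jnp.inf)):
        parts.append("-inf!")
    if bool(jnp.isnan(x).any()):
        parts.append("nan!")
    return " ".join(parts)

print(summarize(jnp.zeros((10, 10))))
print(summarize(jnp.array([1.0, jnp.inf, -jnp.inf, jnp.nan])))
```

The real library also switches to scientific notation and appends the values of small arrays, which this sketch leaves out.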

Going .deeper

numbers.deeper
DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178

# You can go deeper if you need to
numbers[:,:3,:5].deeper(2)
DeviceArray[3, 3, 5] n=45 x∈[-1.316, -0.197] μ=-0.593 σ=0.302
  DeviceArray[3, 5] n=15 x∈[-0.765, -0.337] μ=-0.492 σ=0.119
    DeviceArray[5] x∈[-0.440, -0.337] μ=-0.385 σ=0.037 [-0.354, -0.337, -0.405, -0.440, -0.388]
    DeviceArray[5] x∈[-0.662, -0.405] μ=-0.512 σ=0.097 [-0.405, -0.423, -0.491, -0.577, -0.662]
    DeviceArray[5] x∈[-0.765, -0.474] μ=-0.580 σ=0.112 [-0.474, -0.474, -0.542, -0.645, -0.765]
  DeviceArray[3, 5] n=15 x∈[-0.513, -0.197] μ=-0.321 σ=0.096
    DeviceArray[5] x∈[-0.303, -0.197] μ=-0.243 σ=0.049 [-0.197, -0.197, -0.303, -0.303, -0.215]
    DeviceArray[5] x∈[-0.408, -0.232] μ=-0.327 σ=0.075 [-0.250, -0.232, -0.338, -0.408, -0.408]
    DeviceArray[5] x∈[-0.513, -0.285] μ=-0.394 σ=0.091 [-0.303, -0.285, -0.390, -0.478, -0.513]
  DeviceArray[3, 5] n=15 x∈[-1.316, -0.672] μ=-0.964 σ=0.170
    DeviceArray[5] x∈[-0.985, -0.672] μ=-0.846 σ=0.110 [-0.672, -0.985, -0.881, -0.776, -0.916]
    DeviceArray[5] x∈[-1.212, -0.724] μ=-0.989 σ=0.160 [-0.724, -1.072, -0.968, -0.968, -1.212]
    DeviceArray[5] x∈[-1.316, -0.828] μ=-1.058 σ=0.160 [-0.828, -1.125, -1.020, -1.003, -1.316]
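The nested output above reads like a recursion over the first axis. Here is a minimal sketch of that idea, assuming each level simply summarizes the slices along axis 0 with two more spaces of indentation. Both `brief` and `deeper` are hypothetical names for this example and are not the library's code.

```python
import jax.numpy as jnp

def brief(x):
    # Hypothetical one-line summary; much simpler than the real format.
    return f"Array{list(x.shape)} μ={float(x.mean()):.3f} σ={float(x.std()):.3f}"

def deeper(x, depth=1, indent=0):
    """Summary lines for x and, up to `depth` levels down, its slices along axis 0."""
    lines = [" " * indent + brief(x)]
    if depth > 0 and x.ndim > 1:
        for sub in x:  # iterating a JAX array yields slices along axis 0
            lines.extend(deeper(sub, depth - 1, indent + 2))
    return lines

x = jnp.arange(24.0).reshape(2, 3, 4)
print("\n".join(deeper(x, depth=2)))
```

With `depth=2` on a `[2, 3, 4]` array this prints one line for the whole array, one per `[3, 4]` slice, and one per `[4]` row, mirroring the layout of `numbers[:,:3,:5].deeper(2)`.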

Without .monkey_patch

lj.lovely(spicy)
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!

lj.lovely(spicy, verbose=True)
DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
DeviceArray([[-3.5405432e+03, -3.3692959e-05,            inf,
                        -inf,            nan, -4.0542859e-01],
             [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
              -5.5955136e-01, -5.4242659e-01, -5.0817710e-01]],            dtype=float32)

lj.lovely(numbers, depth=1)
DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178
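If the difference between the two styles is unclear, here is a toy illustration of what monkey patching means in general. The `Box` class and both functions are hypothetical and defined only for this example. How lovely-jax actually hooks into JAX arrays is not shown here.

```python
class Box:
    """Hypothetical stand-in for an array type."""
    def __init__(self, values):
        self.values = list(values)

def lovely(box):
    # Explicit call: works without touching the class at all.
    return f"Box[{len(box.values)}] x∈[{min(box.values)}, {max(box.values)}]"

def monkey_patch(cls=Box):
    # Replace the class-level __repr__, so every instance prints the summary.
    cls.__repr__ = lovely

b = Box([3, 1, 2])
explicit = lovely(b)   # no patching needed
monkey_patch()
patched = repr(b)      # now the default repr gives the same summary
print(explicit)
print(patched)
```

The explicit call is the safer choice when you would rather not change how every array in the session prints.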
