JAX-compatible dataclass.
Write PyTorch-like layers with rich visualizations in JAX.
Installation | Description | Quick Example | Filtering | Stateful Computation | Applications | More | Acknowledgements
🛠️ Installation
pip install pytreeclass
📖 Description
PyTreeClass offers a JAX-compatible, dataclass-like datastructure with the following functionalities:
- 🏗️ Create PyTorch-like NN classes
- 🎨 Visualize pytrees decorated with @pytc.treeclass
- ☂️ Filter by boolean masking, similar to jax.numpy's .at[]
⏩ Quick Example
🏗️ Create a simple MLP
For an Autoencoder example from scratch, see here.
import jax
from jax import numpy as jnp
import pytreeclass as pytc
import matplotlib.pyplot as plt

@pytc.treeclass
class Linear:
    # Any variable not wrapped with @pytc.treeclass
    # should be declared as a dataclass field here
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@pytc.treeclass
class StackedLinear:
    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)
        # Declaring l1, l2, l3 as dataclass fields is optional,
        # as l1, l2, l3 are Linear instances wrapped with @pytc.treeclass.
        # To strictly include only nodes defined in dataclass fields,
        # use `@pytc.treeclass(field_only=True)`.
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x
>>> model = StackedLinear(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))
>>> x = jnp.linspace(0,1,100)[:,None]
>>> y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01
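Because the model instance is itself a pytree, it composes directly with JAX transformations. Below is a minimal sketch of a single gradient-descent step on the data above; the mean-squared-error loss and the 1e-3 step size are illustrative choices, not part of the library:

from jax.tree_util import tree_map

def loss_func(model, x, y):
    # mean squared error between predictions and targets
    return jnp.mean((model(x) - y) ** 2)

@jax.jit
def train_step(model, x, y):
    loss, grads = jax.value_and_grad(loss_func)(model, x, y)
    # `grads` mirrors the structure of `model`, so SGD is a single tree_map
    model = tree_map(lambda p, g: p - 1e-3 * g, model, grads)
    return loss, model

loss, model = train_step(model, x, y)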
🎨 Visualize
summary | tree_box | tree_diagram
>>> print(model.summary())
┌──────┬───────┬───────┬─────────────────┐
│Type  │Param #│Size   │Config           │
├──────┼───────┼───────┼─────────────────┤
│Linear│20     │80.00B │weight=f32[1,10] │
│      │(0)    │(0.00B)│bias=f32[1,10]   │
├──────┼───────┼───────┼─────────────────┤
│Linear│110    │440.00B│weight=f32[10,10]│
│      │(0)    │(0.00B)│bias=f32[1,10]   │
├──────┼───────┼───────┼─────────────────┤
│Linear│11    │44.00B │weight=f32[10,1] │
│      │(0)    │(0.00B)│bias=f32[1,1]    │
└──────┴───────┴───────┴─────────────────┘
Total # : 141(0)
Dynamic #: 141(0)
Static/Frozen #: 0(0)
------------------------------------------
Total size : 564.00B(0.00B)
Dynamic size: 564.00B(0.00B)
Static/Frozen size: 0.00B(0.00B)
==========================================
>>> print(model.tree_box(array=x))
# uses jax.eval_shape (no-FLOPs operation)
# note: the modules created in __init__ should be
# in the same order in which they are called in __call__
┌──────────────────────────────────────┐
│StackedLinear(Parent)                 │
├──────────────────────────────────────┤
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,1]   ││
││ Linear(l1) │────────┼──────────────┤│
││            │ Output │ f32[100,10]  ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,10]  ││
││ Linear(l2) │────────┼──────────────┤│
││            │ Output │ f32[100,10]  ││
│└────────────┴────────┴──────────────┘│
│┌────────────┬────────┬──────────────┐│
││            │ Input  │ f32[100,10]  ││
││ Linear(l3) │────────┼──────────────┤│
││            │ Output │ f32[100,1]   ││
│└────────────┴────────┴──────────────┘│
└──────────────────────────────────────┘
>>> print(model.tree_diagram())
StackedLinear
├── l1=Linear
│   ├── weight=f32[1,10]
│   └── bias=f32[1,10]
├── l2=Linear
│   ├── weight=f32[10,10]
│   └── bias=f32[1,10]
└── l3=Linear
    ├── weight=f32[10,1]
    └── bias=f32[1,1]
mermaid.io (native support in GitHub/Notion)
# generate mermaid diagrams
# print(pytc.tree_viz.tree_mermaid(model))  # generate core syntax
>>> pytc.tree_viz.save_viz(model, filename="test_mermaid", method="tree_mermaid_md")
# use `method="tree_mermaid_html"` to save as html
flowchart LR
id15696277213149321320[StackedLinear]
id15696277213149321320 --> id159132120600507116(l1\nLinear)
id159132120600507116 --- id7500441386962467209["weight\nf32[1,10]"]
id159132120600507116 --- id10793958738030044218["bias\nf32[1,10]"]
id15696277213149321320 --> id10009280772564895168(l2\nLinear)
id10009280772564895168 --- id11951215191344350637["weight\nf32[10,10]"]
id10009280772564895168 --- id1196345851686744158["bias\nf32[1,10]"]
id15696277213149321320 --> id7572222925824649475(l3\nLinear)
id7572222925824649475 --- id4749243995442935477["weight\nf32[10,1]"]
id7572222925824649475 --- id8042761346510512486["bias\nf32[1,1]"]
✨ Generate shareable visualization links ✨
>>> pytc.tree_viz.tree_mermaid(model, link=True)
'Open URL in browser: https://pytreeclass.herokuapp.com/temp/?id=*********'
✂️ Model surgery
# freeze l1
>>> model.l1 = model.l1.freeze()

# set negative values in l2 to 0
>>> model.l2 = model.l2.at[model.l2 < 0].set(0)

# apply sin(x) to all values in l3
>>> model.l3 = model.l3.at[...].apply(jnp.sin)
# frozen nodes are marked with #
>>> print(model.tree_diagram())
StackedLinear
├── l1=Linear
│   ├#─ weight=f32[1,10]
│   └#─ bias=f32[1,10]
├── l2=Linear
│   ├── weight=f32[10,10]
│   └── bias=f32[1,10]
└── l3=Linear
    ├── weight=f32[10,1]
    └── bias=f32[1,1]
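Note that .at[...] operations are out of place: each call returns a new instance, which is why every result above is assigned back to the corresponding attribute. A minimal sketch:

# `.at[...].set` returns a new instance; `model` itself is unchanged
new_model = model.at[model < 0].set(0)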
☂️ Filtering with .at[]
PyTreeClass offers four means of filtering:
- Filter by value
- Filter by field name
- Filter by field type
- Filter by field metadata.
The following example demonstrates the usage of filtering. Suppose you have the following multilayer perceptron (MLP) class.
- Note that in StackedLinear, l1 and l2 have a description in their field metadata.
Model definition
import jax
from jax import numpy as jnp
import pytreeclass as pytc
import matplotlib.pyplot as plt
from dataclasses import field

@pytc.treeclass
class Linear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@pytc.treeclass
class StackedLinear:
    l1: Linear = field(metadata={"description": "First layer"})
    l2: Linear = field(metadata={"description": "Second layer"})

    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        return x

model = StackedLinear(in_dim=1, out_dim=1, hidden_dim=5, key=jax.random.PRNGKey(0))
- Raw model values before any filtering.
>>> print(model)
StackedLinear(
l1=Linear(
weight=[[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812]],
bias=[[1. 1. 1. 1. 1.]]),
l2=Linear(
weight=
[[ 0.98507565]
[ 0.99815285]
[-1.0687716 ]
[-0.19255024]
[-1.2108876 ]],
bias=[[1.]]))
Filter by value
- Get all negative values
>>> print(model.at[model<0].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 -0.40784812],
bias=[]),
l2=Linear(
weight=[-1.0687716 -0.19255024 -1.2108876 ],
bias=[]))
- Set negative values to 0
>>> print(model.at[model<0].set(0))
StackedLinear(
l1=Linear(
weight=[[0. 0. 1.3969219 1.3169124 0. ]],
bias=[[1. 1. 1. 1. 1.]]),
l2=Linear(
weight=
[[0.98507565]
[0.99815285]
[0. ]
[0. ]
[0. ]],
bias=[[1.]]))
- Apply f(x)=x^2 to negative values
>>> print(model.at[model<0].apply(lambda x:x**2))
StackedLinear(
l1=Linear(
weight=[[2.6401937 8.05598 1.3969219 1.3169124 0.16634008]],
bias=[[1. 1. 1. 1. 1.]]),
l2=Linear(
weight=
[[0.98507565]
[0.99815285]
[1.1422727 ]
[0.03707559]
[1.4662486 ]],
bias=[[1.]]))
- Sum all negative values
>>> print(model.at[model<0].reduce_sum())
-7.3432307
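As a sanity check, reduce_sum over a mask should agree with summing the leaves returned by .get(); a small sketch using the standard pytree utilities (the equality is an expectation, not a documented guarantee):

from jax.tree_util import tree_leaves

negatives = model.at[model < 0].get()
# expected to match model.at[model < 0].reduce_sum() above
print(sum(jnp.sum(leaf) for leaf in tree_leaves(negatives)))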
Filter by field name
- Get all fields named l1
>>> print(model.at[model == "l1"].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]),
l2=Linear(weight=[],bias=[]))
Filter by field type
- Get all fields of Linear type
>>> print(model.at[model == Linear].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]),
l2=Linear(
weight=[ 0.98507565 0.99815285 -1.0687716 -0.19255024 -1.2108876 ],
bias=[1.]))
Filter by field metadata
- Get all fields with {"description": "First layer"} in their metadata
>>> print(model.at[model == {"description": "First layer"}].get())
StackedLinear(
l1=Linear(
weight=[-1.6248673 -2.8383057 1.3969219 1.3169124 -0.40784812],
bias=[1. 1. 1. 1. 1.]),
l2=Linear(weight=[],bias=[]))
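The four kinds of masks can also be combined with boolean operators (the application below combines a name mask and a type mask); a hedged sketch selecting only the negative values inside l1:

# combine a value mask with a field-name mask
mask = (model < 0) & (model == "l1")
print(model.at[mask].get())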
🤯 Application: Filtering PyTrees by boolean masking
- Manipulate certain modules' attribute values.
- Set certain modules (e.g. Dropout) to eval mode
- Model definition
import jax
from jax import numpy as jnp
import jax.random as jr
import pytreeclass as pytc

@pytc.treeclass
class Linear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@pytc.treeclass
class Dropout:
    p: float
    eval: bool | None

    def __init__(self, p: float = 0.5, eval: bool | None = None):
        """p : probability of an element to be zeroed out"""
        self.p = p
        self.eval = eval

    def __call__(self, x, *, key=jr.PRNGKey(0)):
        return (
            x if (self.eval is True)
            else jnp.where(jr.bernoulli(key, (1 - self.p), x.shape), x / (1 - self.p), 0)
        )

@pytc.treeclass
class LinearWithDropout:
    def __init__(self):
        self.l1 = Linear(key=jr.PRNGKey(0), in_dim=1, out_dim=5)
        self.d1 = Dropout(p=1.0)  # zero out all elements

    def __call__(self, x):
        x = self.l1(x)
        x = self.d1(x)
        return x
Linear module with full dropout
>>> model = LinearWithDropout()
>>> print(model(jnp.ones((1,1))))
[[0. 0. 0. 0. 0.]]
Disable Dropout
- using boolean masking with .at[].set() to disable Dropout
>>> mask = (model == "eval")
>>> model_no_dropout = model.at[mask].set(True, is_leaf=lambda x: x is None)
>>> print(model_no_dropout(jnp.ones((1,1))))
[[ 1.2656513 -0.8149204 0.61661845 2.7664368 1.3457328 ]]
Set Linear module bias to 0
- Combining attribute name mask and class type mask
>>> mask = (model == "bias") & (model == Linear)
>>> model_no_linear_bias = model.at[mask].set(0)
>>> print(model_no_linear_bias)
LinearWithDropout(
l1=Linear(
weight=[[ 0.26565132 -1.8149204 -0.38338155 1.7664368 0.34573284]],
bias=[[0. 0. 0. 0. 0.]]),
d1=Dropout(p=1.0,eval=None))
📜 Stateful computations
Under jax.jit, JAX requires state to be explicit (see the JAX reference). This means that for any class instance, the variables need to be separated from the class and passed explicitly. However, when using @pytc.treeclass there is no need to separate the instance variables; instead, the whole instance is passed as the state.
Using the following pattern, updating state can be achieved under jax.jit:
@pytc.treeclass
class Counter:
    calls: int = 0

    def increment(self):
        self.calls += 1

>>> c = Counter()

@jax.jit
def update(c):
    c.increment()
    return c

for i in range(10):
    c = update(c)
>>> print(c.calls)
10
The following code snippets compare the two concepts using an MLP implementation.
Explicit state
import jax.numpy as jnp
import jax.random as jr
from jax.nn.initializers import he_normal
from jax.tree_util import tree_map
from jax import nn, value_and_grad, jit
import pytreeclass as pytc

def init_params(layers):
    keys = jr.split(jr.PRNGKey(0), len(layers) - 1)
    params = list()
    init_func = he_normal()
    for key, n_in, n_out in zip(keys, layers[:-1], layers[1:]):
        W = init_func(key, (n_in, n_out))
        B = jr.uniform(key, shape=(n_out,))
        params.append({'W': W, 'B': B})
    return params

def fwd(params, x):
    *hidden, last = params
    for layer in hidden:
        x = nn.tanh(x @ layer['W'] + layer['B'])
    return x @ last['W'] + last['B']

@value_and_grad
def loss_func(params, x, y):
    pred = fwd(params, x)
    return jnp.mean((pred - y) ** 2)

@jit
def update(params, x, y):
    # gradient w.r.t. params
    value, grads = loss_func(params, x, y)
    params = tree_map(lambda x, y: x - 1e-3 * y, params, grads)
    return value, params

x = jnp.linspace(0, 1, 100).reshape(100, 1)
y = x ** 2 - 1

params = init_params([1] + [5] * 4 + [1])
epochs = 10_000

for _ in range(1, epochs + 1):
    value, params = update(params, x, y)
    # print loss and epoch info
    if _ % 1_000 == 0:
        print(f'Epoch={_}\tloss={value:.3e}')
Class instance as state
import jax.numpy as jnp
import jax.random as jr
from jax.nn.initializers import he_normal
from jax.tree_util import tree_map
from jax import nn, value_and_grad, jit
import pytreeclass as pytc

@pytc.treeclass
class MLP:
    Layers: list

    def __init__(self, layers):
        keys = jr.split(jr.PRNGKey(0), len(layers) - 1)
        self.Layers = list()
        init_func = he_normal()
        for key, n_in, n_out in zip(keys, layers[:-1], layers[1:]):
            W = init_func(key, (n_in, n_out))
            B = jr.uniform(key, shape=(n_out,))
            self.Layers.append({'W': W, 'B': B})

    def __call__(self, x):
        *hidden, last = self.Layers
        for layer in hidden:
            x = nn.tanh(x @ layer['W'] + layer['B'])
        return x @ last['W'] + last['B']

@value_and_grad
def loss_func(model, x, y):
    pred = model(x)
    return jnp.mean((pred - y) ** 2)

@jit
def update(model, x, y):
    # gradient w.r.t. model
    value, grads = loss_func(model, x, y)
    model = tree_map(lambda x, y: x - 1e-3 * y, model, grads)
    return value, model

x = jnp.linspace(0, 1, 100).reshape(100, 1)
y = x ** 2 - 1

model = MLP([1] + [5] * 4 + [1])
epochs = 10_000

for _ in range(1, epochs + 1):
    value, model = update(model, x, y)
    # print loss and epoch info
    if _ % 1_000 == 0:
        print(f'Epoch={_}\tloss={value:.3e}')
📝 Applications
🔢 More
More compact boilerplate
Standard definition of nodes in __init__ and calling them in __call__
@pytc.treeclass
class StackedLinear:
    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)
        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)
        return x
Using register_node:
- More compact definition, with nodes defined at runtime call
- The Linear layers are defined on the first call and retrieved on subsequent calls
- This pattern is useful if the module definition depends on runtime data (see the usage sketch after the snippet below).
@pytc.treeclass
class StackedLinear:
    def __init__(self, key):
        self.keys = jax.random.split(key, 3)

    def __call__(self, x):
        x = self.register_node(Linear(self.keys[0], x.shape[-1], 10), name="l1")(x)
        x = jax.nn.tanh(x)
        x = self.register_node(Linear(self.keys[1], 10, 10), name="l2")(x)
        x = jax.nn.tanh(x)
        x = self.register_node(Linear(self.keys[2], 10, x.shape[-1]), name="l3")(x)
        return x
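A short usage sketch (the input shape is illustrative): since the layers are registered on the first call, the model should be called once before its parameters appear in the tree.

model = StackedLinear(key=jax.random.PRNGKey(0))
x = jnp.ones((4, 1))
y = model(x)  # the first call registers l1, l2, l3 as nodes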
Simple AutoEncoder from scratch
While jax.lax can be used to construct convolution, upsample, and max-pooling functions, this example uses kernex for its clear syntax.
AE Construction
from typing import Sequence

import jax
import jax.numpy as jnp
import jax.random as jr
import pytreeclass as pytc  # dataclass-like decorator for JAX
import kernex as kex  # for stencil computations

@pytc.treeclass
class Conv2D:
    weight: jnp.ndarray
    bias: jnp.ndarray

    # define these variables here
    # to be used in __call__
    in_channels: int = pytc.static_field()
    out_channels: int = pytc.static_field()
    kernel_size: Sequence[int] = pytc.static_field()
    padding: Sequence[str] = pytc.static_field()
    strides: Sequence[int] = pytc.static_field()

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        strides=1,
        padding=("same", "same"),
        key=jax.random.PRNGKey(0),
        kernel_initializer=jax.nn.initializers.kaiming_uniform(),
    ):
        self.weight = kernel_initializer(key, (out_channels, in_channels, *kernel_size))
        self.bias = jnp.zeros((out_channels, *((1,) * len(kernel_size))))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = ("valid",) + padding

    def __call__(self, x):
        @kex.kmap(
            kernel_size=(self.in_channels, *self.kernel_size),
            strides=self.strides,
            padding=self.padding,
        )
        def _conv2d(x, w):
            return jnp.sum(x * w)

        @jax.vmap  # vectorize on batch dimension
        def fwd_image(image):
            # filters shape is OIHW
            # vectorize on filters output dimension
            return jax.vmap(lambda w: _conv2d(image, w))(self.weight)[:, 0] + (
                self.bias if self.bias is not None else 0
            )

        return fwd_image(x)

@pytc.treeclass
class Upsample2D:
    scale_factor: int = pytc.static_field()

    def __call__(self, x):
        batch, channel, row, col = x.shape

        @kex.kmap(
            kernel_size=(channel, row, col),
            strides=(1, 1, 1),
            padding="valid",
            relative=False,
        )
        def __upsample2D(x):
            return x.repeat(self.scale_factor, axis=2).repeat(self.scale_factor, axis=1)

        def _upsample2D(batch):
            return jnp.squeeze(
                jax.vmap(__upsample2D, in_axes=(0,))(batch), axis=tuple(range(1, 4))
            )

        return _upsample2D(x)

@pytc.treeclass
class MaxPool2D:
    kernel_size: tuple[int, int] = pytc.static_field(default=(2, 2))
    strides: int = pytc.static_field(default=2)
    padding: str | int = pytc.static_field(default="valid")

    def __call__(self, x):
        @jax.vmap  # apply on batch dimension
        @jax.vmap  # apply on channels dimension
        @kex.kmap(
            kernel_size=self.kernel_size, strides=self.strides, padding=self.padding
        )
        def _maxpool2d(x):
            return jnp.max(x)

        return _maxpool2d(x)

@pytc.treeclass
class AutoEncoder:
    def __init__(self, in_channels, out_channels, key):
        keys = jr.split(key, 5)
        self.l1 = MaxPool2D()
        self.l2 = Conv2D(in_channels, 16, (3, 3), key=keys[0])
        self.l3 = MaxPool2D()
        self.l4 = Conv2D(16, 32, (3, 3), key=keys[1])
        self.l5 = Upsample2D(scale_factor=2)
        self.l6 = Conv2D(32, 16, (3, 3), key=keys[2])
        self.l7 = Upsample2D(scale_factor=2)
        self.l8 = Conv2D(16, 1, (3, 3), key=keys[3])
        self.l9 = Conv2D(1, out_channels, (1, 1), key=keys[4])

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(x)
        x = jax.nn.relu(x)
        x = self.l3(x)
        x = self.l4(x)
        x = jax.nn.relu(x)
        x = self.l5(x)
        x = self.l6(x)
        x = jax.nn.relu(x)
        x = self.l7(x)
        x = self.l8(x)
        x = jax.nn.relu(x)
        x = self.l9(x)
        return x

ae = AutoEncoder(1, 1, jax.random.PRNGKey(0))
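As a quick check (assuming kernex is installed), a forward pass should preserve the input's spatial shape, matching the shape-propagation example further below:

x = jnp.ones([1, 1, 100, 100])
print(ae(x).shape)  # (1, 1, 100, 100) per the tree_box output below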
Model summary
┌──────────┬───────┬───────┬─────────────────────┐
│Type      │Param #│Size   │Config               │
├──────────┼───────┼───────┼─────────────────────┤
│MaxPool2D │0      │0.00B  │                     │
│          │(0)    │(0.00B)│                     │
├──────────┼───────┼───────┼─────────────────────┤
│Conv2D    │160    │640.00B│weight=f32[16,1,3,3] │
│          │(0)    │(0.00B)│bias=f32[16,1,1]     │
├──────────┼───────┼───────┼─────────────────────┤
│MaxPool2D │0      │0.00B  │                     │
│          │(0)    │(0.00B)│                     │
├──────────┼───────┼───────┼─────────────────────┤
│Conv2D    │4,640  │18.12KB│weight=f32[32,16,3,3]│
│          │(0)    │(0.00B)│bias=f32[32,1,1]     │
├──────────┼───────┼───────┼─────────────────────┤
│Upsample2D│0      │0.00B  │                     │
│          │(0)    │(0.00B)│                     │
├──────────┼───────┼───────┼─────────────────────┤
│Conv2D    │4,624  │18.06KB│weight=f32[16,32,3,3]│
│          │(0)    │(0.00B)│bias=f32[16,1,1]     │
├──────────┼───────┼───────┼─────────────────────┤
│Upsample2D│0      │0.00B  │                     │
│          │(0)    │(0.00B)│                     │
├──────────┼───────┼───────┼─────────────────────┤
│Conv2D    │145    │580.00B│weight=f32[1,16,3,3] │
│          │(0)    │(0.00B)│bias=f32[1,1,1]      │
├──────────┼───────┼───────┼─────────────────────┤
│Conv2D    │2      │8.00B  │weight=f32[1,1,1,1]  │
│          │(0)    │(0.00B)│bias=f32[1,1,1]      │
└──────────┴───────┴───────┴─────────────────────┘
Total # : 9,571(0)
Dynamic #: 9,571(0)
Static/Frozen #: 0(0)
--------------------------------------------------
Total size : 37.39KB(0.00B)
Dynamic size: 37.39KB(0.00B)
Static/Frozen size: 0.00B(0.00B)
==================================================
Model diagram
Note: static_field (untrainable) nodes are marked with x
AutoEncoder
├── l1=MaxPool2D
│   ├x─ kernel_size=(2, 2)
│   ├x─ strides=2
│   └x─ padding='valid'
├── l2=Conv2D
│   ├── weight=f32[16,1,3,3]
│   ├── bias=f32[16,1,1]
│   ├x─ in_channels=1
│   ├x─ out_channels=16
│   ├x─ kernel_size=(3, 3)
│   ├x─ padding=('valid', 'same', 'same')
│   └x─ strides=1
├── l3=MaxPool2D
│   ├x─ kernel_size=(2, 2)
│   ├x─ strides=2
│   └x─ padding='valid'
├── l4=Conv2D
│   ├── weight=f32[32,16,3,3]
│   ├── bias=f32[32,1,1]
│   ├x─ in_channels=16
│   ├x─ out_channels=32
│   ├x─ kernel_size=(3, 3)
│   ├x─ padding=('valid', 'same', 'same')
│   └x─ strides=1
├── l5=Upsample2D
│   └x─ scale_factor=2
├── l6=Conv2D
│   ├── weight=f32[16,32,3,3]
│   ├── bias=f32[16,1,1]
│   ├x─ in_channels=32
│   ├x─ out_channels=16
│   ├x─ kernel_size=(3, 3)
│   ├x─ padding=('valid', 'same', 'same')
│   └x─ strides=1
├── l7=Upsample2D
│   └x─ scale_factor=2
├── l8=Conv2D
│   ├── weight=f32[1,16,3,3]
│   ├── bias=f32[1,1,1]
│   ├x─ in_channels=16
│   ├x─ out_channels=1
│   ├x─ kernel_size=(3, 3)
│   ├x─ padding=('valid', 'same', 'same')
│   └x─ strides=1
└── l9=Conv2D
    ├── weight=f32[1,1,1,1]
    ├── bias=f32[1,1,1]
    ├x─ in_channels=1
    ├x─ out_channels=1
    ├x─ kernel_size=(1, 1)
    ├x─ padding=('valid', 'same', 'same')
    └x─ strides=1
Shape propagation
>>> x = jnp.ones([1, 1, 100, 100])
>>> print(ae.tree_box(array=x))
┌────────────────────────────────────────────────┐
│AutoEncoder(Parent)                             │
├────────────────────────────────────────────────┤
│┌───────────────┬────────┬──────────────────┐   │
││               │ Input  │ f32[1,1,100,100] │   │
││ MaxPool2D(l1) │────────┼──────────────────┤   │
││               │ Output │ f32[1,1,50,50]   │   │
│└───────────────┴────────┴──────────────────┘   │
│┌────────────┬────────┬─────────────────┐       │
││            │ Input  │ f32[1,1,50,50]  │       │
││ Conv2D(l2) │────────┼─────────────────┤       │
││            │ Output │ f32[1,16,50,50] │       │
│└────────────┴────────┴─────────────────┘       │
│┌───────────────┬────────┬─────────────────┐    │
││               │ Input  │ f32[1,16,50,50] │    │
││ MaxPool2D(l3) │────────┼─────────────────┤    │
││               │ Output │ f32[1,16,25,25] │    │
│└───────────────┴────────┴─────────────────┘    │
│┌────────────┬────────┬─────────────────┐       │
││            │ Input  │ f32[1,16,25,25] │       │
││ Conv2D(l4) │────────┼─────────────────┤       │
││            │ Output │ f32[1,32,25,25] │       │
│└────────────┴────────┴─────────────────┘       │
│┌────────────────┬────────┬─────────────────┐   │
││                │ Input  │ f32[1,32,25,25] │   │
││ Upsample2D(l5) │────────┼─────────────────┤   │
││                │ Output │ f32[1,32,50,50] │   │
│└────────────────┴────────┴─────────────────┘   │
│┌────────────┬────────┬─────────────────┐       │
││            │ Input  │ f32[1,32,50,50] │       │
││ Conv2D(l6) │────────┼─────────────────┤       │
││            │ Output │ f32[1,16,50,50] │       │
│└────────────┴────────┴─────────────────┘       │
│┌────────────────┬────────┬───────────────────┐ │
││                │ Input  │ f32[1,16,50,50]   │ │
││ Upsample2D(l7) │────────┼───────────────────┤ │
││                │ Output │ f32[1,16,100,100] │ │
│└────────────────┴────────┴───────────────────┘ │
│┌────────────┬────────┬───────────────────┐     │
││            │ Input  │ f32[1,16,100,100] │     │
││ Conv2D(l8) │────────┼───────────────────┤     │
││            │ Output │ f32[1,1,100,100]  │     │
│└────────────┴────────┴───────────────────┘     │
│┌────────────┬────────┬──────────────────┐      │
││            │ Input  │ f32[1,1,100,100] │      │
││ Conv2D(l9) │────────┼──────────────────┤      │
││            │ Output │ f32[1,1,100,100] │      │
│└────────────┴────────┴──────────────────┘      │
└────────────────────────────────────────────────┘
Mermaid diagram
flowchart LR
id15696277213149321320[AutoEncoder]
id15696277213149321320 --> id159132120600507116(l1\nMaxPool2D)
id159132120600507116 --x id7500441386962467209["kernel_size\n(2, 2)"]
id159132120600507116 --x id10793958738030044218["strides\n2"]
id159132120600507116 --x id16245750007010064142["padding\n'valid'"]
id15696277213149321320 --> id10009280772564895168(l2\nConv2D)
id10009280772564895168 --- id11951215191344350637["weight\nf32[16,1,3,3]"]
id10009280772564895168 --- id1196345851686744158["bias\nf32[16,1,1]"]
id10009280772564895168 --x id6648137120666764082["in_channels\n1"]
id10009280772564895168 --x id8609436656910886517["out_channels\n16"]
id10009280772564895168 --x id14061227925890906441["kernel_size\n(3, 3)"]
id10009280772564895168 --x id16022527462135028876["padding\n('valid', 'same', 'same')"]
id10009280772564895168 --x id869300739493054269["strides\n1"]
id15696277213149321320 --> id7572222925824649475(l3\nMaxPool2D)
id7572222925824649475 --x id4749243995442935477["kernel_size\n(2, 2)"]
id7572222925824649475 --x id8042761346510512486["strides\n2"]
id7572222925824649475 --x id17892909998474900538["padding\n'valid'"]
id15696277213149321320 --> id10865740276892226484(l4\nConv2D)
id10865740276892226484 --- id7858522665561710831["weight\nf32[32,16,3,3]"]
id10865740276892226484 --- id11152040016629287840["bias\nf32[32,1,1]"]
id10865740276892226484 --x id2555444594884124276["in_channels\n16"]
id10865740276892226484 --x id118386748143878583["out_channels\n32"]
id10865740276892226484 --x id9968535400108266635["kernel_size\n(3, 3)"]
id10865740276892226484 --x id7531477553368020942["padding\n('valid', 'same', 'same')"]
id10865740276892226484 --x id10824994904435597951["strides\n1"]
id15696277213149321320 --> id2269144855147062920(l5\nUpsample2D)
id2269144855147062920 --x id599357636669938791["scale_factor\n2"]
id15696277213149321320 --> id18278831082116368843(l6\nConv2D)
id18278831082116368843 --- id5107325274042179099["weight\nf32[16,32,3,3]"]
id18278831082116368843 --- id8400842625109756108["bias\nf32[16,1,1]"]
id18278831082116368843 --x id18250991277074144160["in_channels\n32"]
id18278831082116368843 --x id1765546739608714979["out_channels\n16"]
id18278831082116368843 --x id7217338008588734903["kernel_size\n(3, 3)"]
id18278831082116368843 --x id9178637544832857338["padding\n('valid', 'same', 'same')"]
id18278831082116368843 --x id12472154895900434347["strides\n1"]
id15696277213149321320 --> id9682235660371205279(l7\nUpsample2D)
id9682235660371205279 --x id13157878626227910245["scale_factor\n2"]
id15696277213149321320 --> id12975753011438782288(l8\nConv2D)
id12975753011438782288 --- id16267157296346685599["weight\nf32[1,16,3,3]"]
id12975753011438782288 --- id1113930573704710992["bias\nf32[1,1,1]"]
id12975753011438782288 --x id10964079225669099044["in_channels\n16"]
id12975753011438782288 --x id12925378761913221479["out_channels\n1"]
id12975753011438782288 --x id11331140742258384578["kernel_size\n(3, 3)"]
id12975753011438782288 --x id1891725493427812222["padding\n('valid', 'same', 'same')"]
id12975753011438782288 --x id5185242844495389231["strides\n1"]
id15696277213149321320 --> id10538695164698536595(l9\nConv2D)
id10538695164698536595 --- id9065186100445270439["weight\nf32[1,1,1,1]"]
id10538695164698536595 --- id12358703451512847448["bias\nf32[1,1,1]"]
id10538695164698536595 --x id3762108029767683884["in_channels\n1"]
id10538695164698536595 --x id1325050183027438191["out_channels\n1"]
id10538695164698536595 --x id11175198834991826243["kernel_size\n(1, 1)"]
id10538695164698536595 --x id8738140988251580550["padding\n('valid', 'same', 'same')"]
id10538695164698536595 --x id12031658339319157559["strides\n1"]
📙 Acknowledgements