JAX test: tracing functions that take and mutate Python list arguments
In [2]:
Copied!
from pylab import *
from pylab import *
In [3]:
Copied!
from jax import numpy as jnp
import jax
def simple_graph_with_log(x, aol):
    """Compute mean((x + 2) ** 2 + 3) and record it in a list.

    Args:
        x: Array-like value supporting ``+``, ``**`` and ``.mean()``.
        aol: List that the computed mean is appended to. The list object
            passed in is modified in place.

    Returns:
        A ``(y, aol)`` tuple: the mean, and the same list object that was
        passed in, now ending with that mean.
    """
    shifted = x + 2
    squared = shifted ** 2
    y = (squared + 3).mean()
    # Intentional side effect: the result is logged into the list argument.
    aol.append(y)
    return y, aol
# Example array input: the float32 vector [0., 1., 2.].
inp = jnp.arange(3, dtype=jnp.float32)
# Two plain Python lists of ints with different lengths (4 and 5 items),
# passed as the list argument in the tracing cells that follow.
aol = [2,3,4,5]
aol2 = [2,3,4,5,6]
from jax import numpy as jnp
import jax
def simple_graph_with_log(x, aol):
    """Compute mean((x + 2) ** 2 + 3) and record it in a list.

    Args:
        x: Array-like value supporting ``+``, ``**`` and ``.mean()``.
        aol: List that the computed mean is appended to. The list object
            passed in is modified in place.

    Returns:
        A ``(y, aol)`` tuple: the mean, and the same list object that was
        passed in, now ending with that mean.
    """
    shifted = x + 2
    squared = shifted ** 2
    y = (squared + 3).mean()
    # Intentional side effect: the result is logged into the list argument.
    aol.append(y)
    return y, aol
# Example array input: the float32 vector [0., 1., 2.].
inp = jnp.arange(3, dtype=jnp.float32)
# Two plain Python lists of ints with different lengths (4 and 5 items),
# passed as the list argument in the tracing cells that follow.
aol = [2,3,4,5]
aol2 = [2,3,4,5,6]
In [4]:
Copied!
# Enable JAX compile logging only while this trace runs; the recorded output
# shows "Finished tracing + transforming" messages for add, _reduce_sum,
# _mean and simple_graph_with_log. The jaxpr itself is not displayed here
# because the expression sits inside the `with` body.
with jax.log_compiles():
    jax.make_jaxpr(simple_graph_with_log)(inp, aol)
# Enable JAX compile logging only while this trace runs; the recorded output
# shows "Finished tracing + transforming" messages for add, _reduce_sum,
# _mean and simple_graph_with_log. The jaxpr itself is not displayed here
# because the expression sits inside the `with` body.
with jax.log_compiles():
    jax.make_jaxpr(simple_graph_with_log)(inp, aol)
WARNING:2025-05-14 08:18:39,104:jax._src.dispatch:184: Finished tracing + transforming add for pjit in 0.000792742 sec WARNING:jax._src.dispatch:Finished tracing + transforming add for pjit in 0.000792742 sec WARNING:2025-05-14 08:18:39,111:jax._src.dispatch:184: Finished tracing + transforming _reduce_sum for pjit in 0.001460314 sec WARNING:jax._src.dispatch:Finished tracing + transforming _reduce_sum for pjit in 0.001460314 sec WARNING:2025-05-14 08:18:39,115:jax._src.dispatch:184: Finished tracing + transforming _mean for pjit in 0.006220579 sec WARNING:jax._src.dispatch:Finished tracing + transforming _mean for pjit in 0.006220579 sec WARNING:2025-05-14 08:18:39,118:jax._src.dispatch:184: Finished tracing + transforming simple_graph_with_log for pjit in 0.015530109 sec WARNING:jax._src.dispatch:Finished tracing + transforming simple_graph_with_log for pjit in 0.015530109 sec
In [5]:
Copied!
# Trace with the 4-item list `aol` and display the jaxpr. In the recorded
# output each list item is its own scalar i32 input, and the outputs are the
# mean, the four items, then the mean again (the appended entry).
jax.make_jaxpr(simple_graph_with_log)(inp, aol)
# Trace with the 4-item list `aol` and display the jaxpr. In the recorded
# output each list item is its own scalar i32 input, and the outputs are the
# mean, the four items, then the mean again (the appended entry).
jax.make_jaxpr(simple_graph_with_log)(inp, aol)
Out[5]:
{ lambda ; a:f32[3] b:i32[] c:i32[] d:i32[] e:i32[]. let f:f32[3] = add a 2.0 g:f32[3] = integer_pow[y=2] f h:f32[3] = add g 3.0 i:f32[] = reduce_sum[axes=(0,)] h j:f32[] = div i 3.0 in (j, b, c, d, e, j) }
In [6]:
Copied!
# Same trace with the 5-item list `aol2`: the recorded jaxpr has one more
# scalar i32 input and one more pass-through output than the `aol` trace.
jax.make_jaxpr(simple_graph_with_log)(inp, aol2)
# Same trace with the 5-item list `aol2`: the recorded jaxpr has one more
# scalar i32 input and one more pass-through output than the `aol` trace.
jax.make_jaxpr(simple_graph_with_log)(inp, aol2)
Out[6]:
{ lambda ; a:f32[3] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[]. let g:f32[3] = add a 2.0 h:f32[3] = integer_pow[y=2] g i:f32[3] = add h 3.0 j:f32[] = reduce_sum[axes=(0,)] i k:f32[] = div j 3.0 in (k, b, c, d, e, f, k) }
In [7]:
Copied!
def modify_third(aol):
    """Square the element at index 3 of ``aol`` in place and return ``aol``.

    Note that index 3 is the fourth element (0-based indexing), despite the
    function's name.

    Args:
        aol: Mutable sequence with at least four elements.

    Returns:
        The same object that was passed in, with ``aol[3]`` replaced by its
        square.
    """
    squared = aol[3] ** 2
    aol[3] = squared
    return aol
def modify_third(aol):
    """Square the element at index 3 of ``aol`` in place and return ``aol``.

    Note that index 3 is the fourth element (0-based indexing), despite the
    function's name.

    Args:
        aol: Mutable sequence with at least four elements.

    Returns:
        The same object that was passed in, with ``aol[3]`` replaced by its
        square.
    """
    squared = aol[3] ** 2
    aol[3] = squared
    return aol
In [8]:
Copied!
# Trace modify_third on `aol2` and display the jaxpr. In the recorded output
# only the input at index 3 goes through integer_pow[y=2]; the other four
# inputs are returned unchanged.
jax.make_jaxpr(modify_third)(aol2)
# Trace modify_third on `aol2` and display the jaxpr. In the recorded output
# only the input at index 3 goes through integer_pow[y=2]; the other four
# inputs are returned unchanged.
jax.make_jaxpr(modify_third)(aol2)
Out[8]:
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[]. let f:i32[] = integer_pow[y=2] d in (a, b, c, f, e) }
In [9]:
Copied!
# Display `aol2` after the traces above to see whether they modified the
# original list; the recorded output is still [2, 3, 4, 5, 6].
aol2
# Display `aol2` after the traces above to see whether they modified the
# original list; the recorded output is still [2, 3, 4, 5, 6].
aol2
Out[9]:
[2, 3, 4, 5, 6]
In [9]:
Copied!