JAX study notes [7]
Contents
- jax.numpy
- jax.numpy.linspace
- jax.numpy.arange
- references
jax.numpy
jax.numpy.linspace
jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)
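Returns num evenly spaced numbers from start to stop. With endpoint=True (the default), stop is included as the last value, so consecutive numbers differ by (stop - start) / (num - 1). Passing retstep=True also returns that spacing.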
import jax
import jax.numpy as jnp
print(jnp.linspace(11,55,5))
[11. 22. 33. 44. 55.]
import jax
import jax.numpy as jnp
print(jnp.linspace(11,55,5,retstep=True))
(Array([11., 22., 33., 44., 55.], dtype=float32), Array(11., dtype=float32))
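Here is a small sketch that checks the spacing rule. It uses only the documented endpoint and retstep parameters. With endpoint=True the step is (stop - start) / (num - 1). With endpoint=False, stop is left out and the step becomes (stop - start) / num.

```python
import jax.numpy as jnp

start, stop = 11.0, 55.0

# endpoint=True (default): 5 points, stop included,
# spacing = (stop - start) / (num - 1) = 44 / 4
vals, step = jnp.linspace(start, stop, 5, retstep=True)
print(vals.tolist(), float(step))    # [11.0, 22.0, 33.0, 44.0, 55.0] 11.0

# endpoint=False: 4 points, stop left out,
# spacing = (stop - start) / num = 44 / 4
vals2, step2 = jnp.linspace(start, stop, 4, endpoint=False, retstep=True)
print(vals2.tolist(), float(step2))  # [11.0, 22.0, 33.0, 44.0] 11.0
```

Both calls give the same spacing of 11. The second call simply stops one step short of 55.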
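The function also accepts arrays for start and stop, and then it generates a multi-dimensional array. Each column of the result is an evenly spaced sequence running from one element of start to the matching element of stop, and the samples are stacked along axis (0 by default).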
import jax
import jax.numpy as jnp
print(jnp.linspace(jnp.array([1,11]),jnp.array([5,55]),5))
[[ 1. 11.]
 [ 2. 22.]
 [ 3. 33.]
 [ 4. 44.]
 [ 5. 55.]]
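The axis argument from the signature decides where the new sample dimension goes. Below is a quick sketch with the same start and stop arrays as above. With axis=-1, each sequence becomes a row instead of a column, so the result is the transpose of the default layout.

```python
import jax.numpy as jnp

starts = jnp.array([1, 11])
stops = jnp.array([5, 55])

# axis=0 (default): the 5 samples run along the first axis -> shape (5, 2)
stacked0 = jnp.linspace(starts, stops, 5)
# axis=-1: the 5 samples run along the last axis -> shape (2, 5)
stacked_last = jnp.linspace(starts, stops, 5, axis=-1)

print(stacked0.shape, stacked_last.shape)  # (5, 2) (2, 5)
print(stacked_last.tolist())  # [[1.0, 2.0, 3.0, 4.0, 5.0], [11.0, 22.0, 33.0, 44.0, 55.0]]
```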
jax.numpy.arange
This function also produces a sequence of evenly spaced numbers, much like jax.numpy.linspace. The important difference is that jax.numpy.arange is controlled by step, the interval between consecutive numbers. By contrast, jax.numpy.linspace is controlled by num, the number of values to generate.
jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)
import jax
import jax.numpy as jnp
print(jnp.arange(11,55,11))
[11 22 33 44]
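To see the step-versus-num difference side by side, here is a small sketch that uses only the documented arguments. arange fixes the spacing and derives the length, ceil((stop - start) / step). linspace fixes the length and derives the spacing. The last line shows one common trick for integer ranges: passing stop + 1 makes stop itself appear. That trick is my own addition, not something the docs prescribe.

```python
import math
import jax.numpy as jnp

start, stop = 0.0, 1.0

# arange: spacing fixed by step, length derived from it
a = jnp.arange(start, stop, 0.25)
print(a.tolist())                                   # [0.0, 0.25, 0.5, 0.75]
print(len(a) == math.ceil((stop - start) / 0.25))   # True

# linspace: length fixed by num, spacing derived from it
b = jnp.linspace(start, stop, num=5)
print(b.tolist())                                   # [0.0, 0.25, 0.5, 0.75, 1.0]

# integer ranges only: stop + 1 lets stop itself appear in the result
print(jnp.arange(11, 55 + 1, 11).tolist())          # [11, 22, 33, 44, 55]
```

Note that 1.0 shows up only in the linspace result.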
One difference that is easy to overlook: in jax.numpy.arange, stop marks the end of the range but is itself excluded from the result. In addition, the arguments of jax.numpy.arange must be scalars, so the following code raises an error when run:
import jax.numpy as jnp
# raises an error: start and stop are arrays, not scalars
print(jnp.arange(jnp.array([1,11]),jnp.array([5,55]),11))
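Sometimes what you actually want is one sequence per element of an array of starting values. One workaround is to call arange once with scalar arguments and broadcast the result against the array. This is my own sketch, not something the JAX docs prescribe. The names starts, step and n are just illustrative, and n has to be a plain Python int because the length of the output depends on it.

```python
import jax.numpy as jnp

starts = jnp.array([1, 11])   # illustrative per-row starting values
step = 11                     # one shared scalar step
n = 4                         # values per row; a plain Python int

# scalar-only arange builds the offsets once: [0, 11, 22, 33]
offsets = jnp.arange(n) * step
# broadcasting (2, 1) against (4,) gives one row per start -> shape (2, 4)
seqs = starts[:, None] + offsets
print(seqs.tolist())   # [[1, 12, 23, 34], [11, 22, 33, 44]]
```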
references
https://docs.jax.dev/