Function candle_nn::ops::softmax

source ·
pub fn softmax<D: Dim>(xs: &Tensor, dim: D) -> Result<Tensor>
Expand description

Applies the softmax function to the input tensor, rescaling the elements so that, within every slice taken along dimension `dim`, the elements lie between 0 and 1 and sum to 1.

use candle::{test_utils::to_vec2_round, Device, Tensor};
// A 2x4 tensor of logits; softmax is taken along dimension 1, i.e. across each row.
let logits = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
let probs = candle_nn::ops::softmax(&logits, 1)?;
// After normalisation every row lies in [0, 1] and sums to 1 (shown rounded to 4 decimals).
let expected: [[f32; 4]; 2] = [
    [0.1345, 0.3655, 0.1345, 0.3655],
    [0.0049, 0.2671, 0.7262, 0.0018],
];
assert_eq!(to_vec2_round(&probs, 4)?, &expected);