Function candle_nn::encoding::one_hot

pub fn one_hot<D: WithDType>(
    indices: Tensor,
    depth: usize,
    on_value: D,
    off_value: D
) -> Result<Tensor>

One-hot/cold encoding.

Given an input tensor of indices, this function returns a tensor with the same shape as the input plus one additional trailing dimension of size depth. Every value in the returned tensor is set to off_value, except at the positions selected by the indices, which are set to on_value.

The returned tensor therefore has a rank one greater than the rank of the input tensor.
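The shape rule can be illustrated with a small plain-Rust sketch. It is not candle's implementation, and the helper names `one_hot_shape` and `one_hot_offset` are hypothetical. Because the new dimension of size depth is appended last, each input element owns a contiguous run of depth output values in row-major order. So the element at flat position `i` holding index `k` sets output position `i * depth + k`.

```rust
/// Hypothetical helper: the output shape of a one-hot encoding, i.e. the
/// input shape with `depth` appended as a new trailing dimension.
fn one_hot_shape(input_shape: &[usize], depth: usize) -> Vec<usize> {
    let mut shape = input_shape.to_vec();
    shape.push(depth);
    shape
}

/// Hypothetical helper: the flat (row-major) output position that receives
/// `on_value` for the input element at flat position `i` holding index `k`.
fn one_hot_offset(i: usize, k: usize, depth: usize) -> usize {
    i * depth + k
}

fn main() {
    // A (2, 2) input with depth 4 yields a (2, 2, 4) output: rank 2 -> rank 3.
    println!("{:?}", one_hot_shape(&[2, 2], 4));
    // Element [0][1] (flat position 1) with index 2 sets output position 6.
    println!("{}", one_hot_offset(1, 2, 4));
}
```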

For example, encoding the following tensor with a depth of 4:

[[0i64, 2], [1, -1]]

produces the rank-3 tensor:

[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]

An index of -1 is ignored: its one-hot vector has every value set to off_value, as in the last vector of the example above.
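The encoding rule, including the -1 case, can be sketched in plain Rust on row-major buffers. This hypothetical helper, `one_hot_flat`, illustrates the semantics only and is not candle's implementation. It assumes the indices are stored as i64 and have already been checked to lie in -1..depth.

```rust
/// Hypothetical std-only sketch of the one-hot semantics.
/// `indices` is the input flattened in row-major order; the result is the
/// flattened output of shape `input_shape + [depth]`.
/// Assumes every index was already validated to lie in `-1..depth`.
fn one_hot_flat<T: Copy>(indices: &[i64], depth: usize, on_value: T, off_value: T) -> Vec<T> {
    // Start with every position set to `off_value`.
    let mut out = vec![off_value; indices.len() * depth];
    for (i, &idx) in indices.iter().enumerate() {
        // An index of -1 is skipped, leaving an all-`off_value` vector.
        if idx == -1 {
            continue;
        }
        // Each input element owns `depth` consecutive output slots.
        out[i * depth + idx as usize] = on_value;
    }
    out
}

fn main() {
    // The [[0, 2], [1, -1]] example, flattened, with depth 4.
    let encoded = one_hot_flat(&[0, 2, 1, -1], 4, 1u8, 0u8);
    for row in encoded.chunks(4) {
        println!("{:?}", row);
    }
}
```

Printing the output in chunks of depth reproduces the four vectors of the example, with the -1 entry yielding `[0, 0, 0, 0]`.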

Setting on_value to 0 and off_value to 1 produces a one-cold encoding, while the conventional one-hot encoding uses on_value = 1 and off_value = 0. Both values are required arguments, so any other pair of values can be used as well.
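For numeric types, every on/off encoding is an elementwise affine map of the plain 0/1 one-hot tensor: `off_value + (on_value - off_value) * x`. With on_value = 0 and off_value = 1 this reduces to `1 - x`, the one-cold encoding. The sketch below shows this on a single row. `remap` is a hypothetical helper, f32 is used to avoid unsigned underflow, and this is not necessarily how candle computes the result.

```rust
/// Hypothetical helper: remap a 0/1 one-hot row to arbitrary on/off values
/// via the elementwise affine map `off + (on - off) * x`.
fn remap(one_hot_row: &[f32], on_value: f32, off_value: f32) -> Vec<f32> {
    one_hot_row
        .iter()
        .map(|&x| off_value + (on_value - off_value) * x)
        .collect()
}

fn main() {
    // One-hot row for index 2 with depth 4.
    let row: [f32; 4] = [0.0, 0.0, 1.0, 0.0];
    // One-cold: on_value = 0, off_value = 1, i.e. 1 - x.
    println!("{:?}", remap(&row, 0.0, 1.0));
    // Any other pair, e.g. on_value = 5, off_value = -1.
    println!("{:?}", remap(&row, 5.0, -1.0));
}
```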

§Examples

§One-hot encoding

use candle::{Shape, Tensor, Device};
use candle_nn::encoding::one_hot;

let device = Device::Cpu;

let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap();
let depth = 4;
let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap();

let expected_matrix = [
    [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
    [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
];

assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));

let matrix = one_hot.to_vec3::<f32>().unwrap();

assert_eq!(matrix, expected_matrix);

§One-cold encoding

use candle::{Shape, Tensor, Device};
use candle_nn::encoding::one_hot;

let device = Device::Cpu;
let depth = 4;
let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap();
let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap();

let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]];

assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));

let matrix = one_cold.to_vec3::<u8>().unwrap();

assert_eq!(matrix, expected_matrix);

§Bails

This function returns an error if:

  • Any index value is less than -1.
  • Any index value is greater than or equal to depth.
  • The input data type is not U8, U32, or I64.
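The three error conditions can be expressed as a small validation sketch. The `DType` enum and `check_one_hot_inputs` below are hypothetical stand-ins, not candle types. The sketch assumes the indices have been widened to i64 so that one comparison covers all three accepted integer types.

```rust
/// Hypothetical stand-in for a tensor's element type; only the variants
/// needed to illustrate the rules are listed.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    U8,
    U32,
    I64,
    F32,
}

/// Sketch of the documented error conditions; returns the first problem found.
fn check_one_hot_inputs(dtype: DType, indices: &[i64], depth: usize) -> Result<(), String> {
    // Only the integer index types U8, U32, and I64 are accepted.
    if !matches!(dtype, DType::U8 | DType::U32 | DType::I64) {
        return Err(format!("unsupported index dtype {:?}", dtype));
    }
    for &idx in indices {
        if idx < -1 {
            return Err(format!("index {} is less than -1", idx));
        }
        // -1 is allowed (it is ignored); any other index must be below depth.
        if idx >= 0 && idx as usize >= depth {
            return Err(format!("index {} is out of range for depth {}", idx, depth));
        }
    }
    Ok(())
}

fn main() {
    println!("{:?}", check_one_hot_inputs(DType::I64, &[0, 2, 1, -1], 4));
    println!("{:?}", check_one_hot_inputs(DType::U32, &[4], 4));
    println!("{:?}", check_one_hot_inputs(DType::U8, &[-2], 4));
    println!("{:?}", check_one_hot_inputs(DType::F32, &[0], 4));
}
```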

§API Design

The API design of this function is loosely based on TensorFlow's one-hot operation (tf.one_hot).