use crate::{Error, Tensor};
use std::ops::{
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
impl Tensor {
    /// Applies `indexers` to this tensor in order, one indexer per leading dimension,
    /// and returns the resulting tensor.
    ///
    /// * `Select(n)` narrows the current dimension to the single position `n` and then
    ///   squeezes that dimension away, so the following indexer targets the same
    ///   dimension index of the intermediate result.
    /// * `Narrow(start, end)` keeps the half-open span described by the two bounds; an
    ///   unbounded start means `0` and an unbounded end means the size of the matching
    ///   dimension of `self`.
    /// * `IndexSelect(t)` gathers the positions listed in the rank-1 tensor `t`, after
    ///   moving `t` to the device of the tensor being indexed.
    /// * `Err(e)` aborts indexing with an error that embeds `e`.
    ///
    /// # Errors
    ///
    /// Returns an error when an open-ended range is supplied for a dimension the tensor
    /// does not have, when an `IndexSelect` tensor is not rank 1, when an indexer is
    /// `TensorIndexer::Err`, or when any underlying tensor operation fails.
    fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
        let mut x = self.clone();
        let dims = self.shape().dims();
        // Dimension of `x` that the next indexer applies to. It only advances for
        // indexers that keep their dimension; `Select` removes it instead.
        let mut current_dim = 0;
        for (i, indexer) in indexers.iter().enumerate() {
            x = match indexer {
                TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
                TensorIndexer::Narrow(left_bound, right_bound) => {
                    // `saturating_add` avoids the overflow panic (debug) or silent wrap
                    // to 0 (release) for a `usize::MAX` bound; the saturated value is
                    // left for `narrow` to validate against the dimension size.
                    let start = match left_bound {
                        Bound::Included(n) => *n,
                        Bound::Excluded(n) => n.saturating_add(1),
                        Bound::Unbounded => 0,
                    };
                    let stop = match right_bound {
                        Bound::Included(n) => n.saturating_add(1),
                        Bound::Excluded(n) => *n,
                        // `i` indexes the dimensions of `self`, not of `x`: every
                        // indexer consumes exactly one original dimension. Use a
                        // checked lookup so surplus indexers yield an error, not a panic.
                        Bound::Unbounded => match dims.get(i) {
                            Some(dim) => *dim,
                            None => {
                                let rank = dims.len();
                                crate::bail!("too many indices: indexer {i} has no matching dimension in a tensor of rank {rank}")
                            }
                        },
                    };
                    // An inverted range (stop < start) collapses to length 0.
                    let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
                    current_dim += 1;
                    out
                }
                TensorIndexer::IndexSelect(indexes) => {
                    if indexes.rank() != 1 {
                        crate::bail!("multi-dimensional tensor indexing is not supported")
                    }
                    let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
                    current_dim += 1;
                    out
                }
                TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
            };
        }
        Ok(x)
    }
}
/// One indexing operation applied to a single tensor dimension.
///
/// Values are normally produced through the `From` conversions in this file and
/// consumed by `Tensor::index`.
#[derive(Debug)]
pub enum TensorIndexer {
/// Pick a single position; `Tensor::index` narrows the dimension to length 1 at this
/// position and then squeezes the dimension away.
Select(usize),
/// Keep a contiguous span given by a start bound and an end bound. `Tensor::index`
/// treats an unbounded start as `0` and an unbounded end as the dimension size.
Narrow(Bound<usize>, Bound<usize>),
/// Gather the positions listed in a tensor. `Tensor::index` rejects tensors whose
/// rank is not 1 and moves the index tensor to the indexed tensor's device.
IndexSelect(Tensor),
/// An error captured while building the indexer (the `From<&[u32]>` and
/// `From<Vec<u32>>` impls store tensor-creation failures here); `Tensor::index`
/// bails with an "indexing error" message when it meets this variant.
Err(Error),
}
impl From<usize> for TensorIndexer {
fn from(index: usize) -> Self {
TensorIndexer::Select(index)
}
}
impl From<&[u32]> for TensorIndexer {
    /// Builds an index-select indexer from a slice of positions.
    ///
    /// The positions are copied into a new CPU tensor. If that construction fails the
    /// error is stored in `TensorIndexer::Err` rather than returned, because `From`
    /// cannot fail.
    fn from(index: &[u32]) -> Self {
        let built = Tensor::new(index, &crate::Device::Cpu);
        match built {
            Err(err) => Self::Err(err),
            Ok(positions) => Self::IndexSelect(positions),
        }
    }
}
impl From<Vec<u32>> for TensorIndexer {
    /// Builds an index-select indexer from an owned vector of positions.
    ///
    /// The vector is handed to `Tensor::from_vec` with its own length as the shape
    /// argument, targeting the CPU device. A construction failure is stored in
    /// `TensorIndexer::Err` rather than returned, because `From` cannot fail.
    fn from(index: Vec<u32>) -> Self {
        // Read the length before `index` is moved into the tensor constructor.
        let count = index.len();
        let built = Tensor::from_vec(index, count, &crate::Device::Cpu);
        match built {
            Err(err) => Self::Err(err),
            Ok(positions) => Self::IndexSelect(positions),
        }
    }
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
// Private marker trait for the standard range types over `usize`. It is the bound
// used by the blanket `From<T>` impl for `TensorIndexer` below.
// NOTE(review): presumably a marker is used instead of bounding directly on
// `RangeBounds<usize>` so the blanket impl cannot overlap the other `From` impls —
// confirm.
trait RB: RangeBounds<usize> {}
// `a..b`
impl RB for Range<usize> {}
// `a..`
impl RB for RangeFrom<usize> {}
// `..`
impl RB for RangeFull {}
// `a..=b`
impl RB for RangeInclusive<usize> {}
// `..b`
impl RB for RangeTo<usize> {}
// `..=b`
impl RB for RangeToInclusive<usize> {}
impl<T: RB> From<T> for TensorIndexer {
    /// Converts any range type marked with `RB` into `TensorIndexer::Narrow`,
    /// carrying over the range's start and end bounds unchanged.
    fn from(range: T) -> Self {
        // `start_bound`/`end_bound` hand back `Bound<&usize>`; `cloned` turns each
        // into an owned `Bound<usize>` while keeping the Included/Excluded/Unbounded kind.
        let lower = range.start_bound().cloned();
        let upper = range.end_bound().cloned();
        TensorIndexer::Narrow(lower, upper)
    }
}
/// Indexing by a value of type `T`.
///
/// In this file the trait is implemented for `Tensor` with any single
/// `T: Into<TensorIndexer>` and with tuples of one to seven such values.
pub trait IndexOp<T> {
/// Indexes `self` with `index`, returning the selected tensor or an error.
fn i(&self, index: T) -> Result<Tensor, Error>;
}
impl<T> IndexOp<T> for Tensor
where
    T: Into<TensorIndexer>,
{
    /// Indexes the tensor with a single indexer, converting `index` into a
    /// `TensorIndexer` first.
    fn i(&self, index: T) -> Result<Tensor, Error> {
        let indexer: TensorIndexer = index.into();
        self.index(&[indexer])
    }
}
// Generates `IndexOp<(A, B, ...)> for Tensor` for a tuple whose elements each convert
// into a `TensorIndexer`. The generated `i` destructures the tuple, converts every
// element with `.into()`, and forwards the resulting array to `Tensor::index`.
macro_rules! index_op_tuple {
($($t:ident),+) => {
// The type-parameter identifiers (`A`, `B`, ...) are reused as the binding names
// in the tuple pattern below, hence the lint allowance.
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
// Tuple arities 1 through 7. The single-identifier invocation targets the one-element
// tuple `(A,)`, not a bare `A`.
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);