1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Variables are wrappers around tensors whose content can be modified in place; they are
// typically used to hold weights that get updated by gradient descent.
// We do not expose a public way to wrap an arbitrary tensor in a variable, as this would break
// the invariant that the tensor held by a variable always has `is_variable` set to `true`.
use crate::{DType, Device, Error, Result, Shape, Tensor};

/// A mutable wrapper around a [`Tensor`].
///
/// Tensors are immutable, whereas the content of a variable can be overwritten in place
/// (see [`Var::set`]). The wrapped tensor is a private field, so outside this module a `Var`
/// can only be obtained through its constructors, all of which produce a tensor with
/// `is_variable` set to `true`.
#[derive(Debug, Clone)]
pub struct Var(Tensor);

impl std::fmt::Display for Var {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::ops::Deref for Var {
    type Target = Tensor;

    /// Exposes the wrapped tensor so that every `&self` method of `Tensor` can be called
    /// directly on a `Var`.
    fn deref(&self) -> &Tensor {
        &self.0
    }
}

impl Var {
    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
        Ok(Self(inner))
    }

    // Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
    pub fn from_tensor(t: &Tensor) -> Result<Self> {
        if t.is_variable() {
            Ok(Self(t.clone()))
        } else {
            let inner = t.make_var()?;
            Ok(Self(inner))
        }
    }

    pub fn rand_f64<S: Into<Shape>>(
        lo: f64,
        up: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn randn_f64<S: Into<Shape>>(
        mean: f64,
        std: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
        lo: T,
        up: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
        Ok(Self(inner))
    }

    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
        mean: T,
        std: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new tensor on the specified device using the content and shape of the input.
    /// This is similar to `new` but the resulting tensor is a variable.
    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
        let shape = array.shape()?;
        let inner = Tensor::new_impl(array, shape, device, true)?;
        Ok(Self(inner))
    }

    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
        data: Vec<D>,
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
        Ok(Self(inner))
    }

    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
        array: &[D],
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
        Ok(Self(inner))
    }

    pub fn as_detached_tensor(&self) -> Tensor {
        self.0.detach()
    }

    pub fn as_tensor(&self) -> &Tensor {
        &self.0
    }

    /// Consumes this `Var` and return the underlying tensor.
    pub fn into_inner(self) -> Tensor {
        self.0
    }

    /// Sets the content of the inner tensor, this does not require a mutable reference as inner
    /// mutability is used.
    pub fn set(&self, src: &Tensor) -> Result<()> {
        if self.same_storage(src) {
            let msg = "cannot set a variable to a tensor that is derived from its value";
            Err(Error::CannotSetVar { msg }.bt())?
        }
        let (mut dst, layout) = self.storage_mut_and_layout();
        if !layout.is_contiguous() {
            let msg = "cannot set a non-contiguous variable";
            Err(Error::CannotSetVar { msg }.bt())?
        }
        let (src, src_l) = src.storage_and_layout();
        if layout.shape() != src_l.shape() {
            Err(Error::ShapeMismatchBinaryOp {
                lhs: layout.shape().clone(),
                rhs: src_l.shape().clone(),
                op: "set",
            }
            .bt())?
        }
        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
        Ok(())
    }
}