Type Alias candle_nn::var_builder::VarBuilder  
pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
A simple VarBuilder. It is less generic than VarBuilderArgs but should cover most common
use cases.
Aliased Type
struct VarBuilder<'a> { /* private fields */ }
Implementations
impl<'a> VarBuilder<'a>
pub fn from_backend(
    backend: Box<dyn SimpleBackend + 'a>,
    dtype: DType,
    device: Device
) -> Self
Initializes a VarBuilder using a custom backend.
Prefer one of the more specific constructors; this one is provided so that downstream users can define their own backends.
pub fn zeros(dtype: DType, dev: &Device) -> Self
Initializes a VarBuilder that uses zeros for any tensor.
pub fn from_tensors(
    ts: HashMap<String, Tensor>,
    dtype: DType,
    dev: &Device
) -> Self
Initializes a VarBuilder that retrieves tensors stored in a hash map. An error is
returned if no tensor is stored under the requested path, or if the stored tensor's shape does not match the requested shape.
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self
Initializes a VarBuilder using a VarMap. A tensor requested under a new path is created and
initialized; requesting the same path again returns that same tensor. This is commonly used
when initializing a model before training.
Tensor values can also be loaded after the model has been created, using the load method on
the VarMap; this makes it possible to start training from an existing checkpoint.
pub unsafe fn from_mmaped_safetensors<P: AsRef<Path>>(
    paths: &[P],
    dtype: DType,
    dev: &Device
) -> Result<Self>
Initializes a VarBuilder that retrieves tensors stored in a collection of safetensors
files.
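Safety
This function is unsafe because the files are memory-mapped through memmap2::MmapOptions, whose file-backed mapping constructors are unsafe: using the mapping is undefined behavior if a file is modified, in or out of process, while it is mapped.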
pub fn from_buffered_safetensors(
    data: Vec<u8>,
    dtype: DType,
    dev: &Device
) -> Result<Self>
Initializes a VarBuilder from an in-memory binary buffer in the safetensors format.
pub fn from_npz<P: AsRef<Path>>(
    p: P,
    dtype: DType,
    dev: &Device
) -> Result<Self>
Initializes a VarBuilder that retrieves tensors stored in a NumPy .npz file.
pub fn from_pth<P: AsRef<Path>>(
    p: P,
    dtype: DType,
    dev: &Device
) -> Result<Self>
Initializes a VarBuilder that retrieves tensors stored in a PyTorch .pth file.
pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(
    self,
    f: F
) -> Self
Returns a VarBuilder that applies a renaming function to every tensor name it is queried with, then passes the renamed name to the inner VarBuilder.
```rust
use std::collections::HashMap;

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // A single 2x3 tensor stored under the name "foo".
    let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    let tensors: HashMap<_, _> = [("foo".to_string(), a)].into_iter().collect();
    let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
    assert!(vb.contains_tensor("foo"));
    assert!(vb.get((2, 3), "foo").is_ok());
    assert!(!vb.contains_tensor("bar"));

    // Redirect requests for "bar" to "foo"; every other name passes through unchanged.
    let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
    assert!(vb.contains_tensor("bar"));
    assert!(vb.contains_tensor("foo"));
    assert!(vb.get((2, 3), "bar").is_ok());
    assert!(vb.get((2, 3), "foo").is_ok());
    assert!(!vb.contains_tensor("baz"));
    Ok(())
}
```