use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer};
use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Cow;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
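/// Upper bound, in bytes, accepted for the JSON header when deserializing.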
const MAX_HEADER_SIZE: usize = 100_000_000;
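/// Errors that can occur while reading or writing a safetensors buffer.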
#[derive(Debug)]
pub enum SafeTensorError {
InvalidHeader,
InvalidHeaderStart,
InvalidHeaderDeserialization,
HeaderTooLarge,
HeaderTooSmall,
InvalidHeaderLength,
TensorNotFound(String),
TensorInvalidInfo,
InvalidOffset(String),
IoError(std::io::Error),
JsonError(serde_json::Error),
InvalidTensorView(Dtype, Vec<usize>, usize),
MetadataIncompleteBuffer,
ValidationOverflow,
}
impl From<std::io::Error> for SafeTensorError {
fn from(error: std::io::Error) -> SafeTensorError {
SafeTensorError::IoError(error)
}
}
impl From<serde_json::Error> for SafeTensorError {
fn from(error: serde_json::Error) -> SafeTensorError {
SafeTensorError::JsonError(error)
}
}
impl std::fmt::Display for SafeTensorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for SafeTensorError {}
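/// Intermediate result of `prepare`: the header length, the serialized header bytes
/// (space-padded to an 8-byte multiple), and the total size of the tensor data section.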
struct PreparedData {
n: u64,
header_bytes: Vec<u8>,
offset: usize,
}
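/// A read-only view of a tensor that can be serialized: its dtype, shape, raw bytes,
/// and byte length. `data_len()` should match `data().len()`, so that the offsets
/// written in the header agree with the bytes actually emitted; it lets `prepare`
/// compute offsets without materializing the data.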
pub trait View {
fn dtype(&self) -> Dtype;
fn shape(&self) -> &[usize];
fn data(&self) -> Cow<[u8]>;
fn data_len(&self) -> usize;
}
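/// Sorts the tensors, builds the JSON header (padded with spaces to an 8-byte multiple),
/// and returns it together with the tensors in serialization order.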
fn prepare<S: AsRef<str> + Ord + std::fmt::Display, V: View, I: IntoIterator<Item = (S, V)>>(
data: I,
data_info: &Option<HashMap<String, String>>,
) -> Result<(PreparedData, Vec<V>), SafeTensorError> {
let mut data: Vec<_> = data.into_iter().collect();
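// Sort by descending dtype, then by name, so tensors with the widest dtypes are laid out first.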
data.sort_by(|(lname, left), (rname, right)| {
right.dtype().cmp(&left.dtype()).then(lname.cmp(rname))
});
let mut tensors: Vec<V> = Vec::with_capacity(data.len());
let mut hmetadata = Vec::with_capacity(data.len());
let mut offset = 0;
for (name, tensor) in data {
let n = tensor.data_len();
let tensor_info = TensorInfo {
dtype: tensor.dtype(),
shape: tensor.shape().to_vec(),
data_offsets: (offset, offset + n),
};
offset += n;
hmetadata.push((name.to_string(), tensor_info));
tensors.push(tensor);
}
let metadata: Metadata = Metadata::new(data_info.clone(), hmetadata)?;
let mut metadata_buf = serde_json::to_string(&metadata)?.into_bytes();
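// Pad the header with spaces so the tensor data section starts on an 8-byte boundary.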
let extra = (8 - metadata_buf.len() % 8) % 8;
metadata_buf.extend(vec![b' '; extra]);
let n: u64 = metadata_buf.len() as u64;
Ok((
PreparedData {
n,
header_bytes: metadata_buf,
offset,
},
tensors,
))
}
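/// Serializes the given tensors into a safetensors byte buffer: an 8-byte little-endian
/// header length, the JSON header (space-padded to an 8-byte multiple), then the tensor
/// data back to back.
///
/// A minimal sketch (tensor name and values are made up):
/// ```ignore
/// let raw: Vec<u8> = [1.0f32, 2.0].iter().flat_map(|f| f.to_le_bytes()).collect();
/// let view = TensorView::new(Dtype::F32, vec![2], &raw)?;
/// let bytes = serialize([("weight", &view)], &None)?;
/// ```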
pub fn serialize<
S: AsRef<str> + Ord + std::fmt::Display,
V: View,
I: IntoIterator<Item = (S, V)>,
>(
data: I,
data_info: &Option<HashMap<String, String>>,
) -> Result<Vec<u8>, SafeTensorError> {
let (
PreparedData {
n,
header_bytes,
offset,
},
tensors,
) = prepare(data, data_info)?;
let expected_size = 8 + header_bytes.len() + offset;
let mut buffer: Vec<u8> = Vec::with_capacity(expected_size);
buffer.extend_from_slice(&n.to_le_bytes());
buffer.extend(&header_bytes);
for tensor in tensors {
buffer.extend(tensor.data().as_ref());
}
Ok(buffer)
}
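/// Same as `serialize`, but streams the result into `filename` through a `BufWriter`
/// instead of building the whole buffer in memory.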
pub fn serialize_to_file<
S: AsRef<str> + Ord + std::fmt::Display,
V: View,
I: IntoIterator<Item = (S, V)>,
>(
data: I,
data_info: &Option<HashMap<String, String>>,
filename: &Path,
) -> Result<(), SafeTensorError> {
let (
PreparedData {
n, header_bytes, ..
},
tensors,
) = prepare(data, data_info)?;
let mut f = BufWriter::new(File::create(filename)?);
f.write_all(n.to_le_bytes().as_ref())?;
f.write_all(&header_bytes)?;
for tensor in tensors {
f.write_all(tensor.data().as_ref())?;
}
f.flush()?;
Ok(())
}
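/// A parsed safetensors buffer: the deserialized header plus a borrow of the raw
/// tensor data section. No tensor data is copied.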
#[derive(Debug)]
pub struct SafeTensors<'data> {
metadata: Metadata,
data: &'data [u8],
}
impl<'data> SafeTensors<'data> {
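/// Reads the 8-byte little-endian header length and parses the JSON header from
/// `buffer`, validating the declared offsets against the buffer size.
/// Returns the header length together with the parsed `Metadata`.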
pub fn read_metadata<'in_data>(
buffer: &'in_data [u8],
) -> Result<(usize, Metadata), SafeTensorError>
where
'in_data: 'data,
{
let buffer_len = buffer.len();
if buffer_len < 8 {
return Err(SafeTensorError::HeaderTooSmall);
}
let arr: [u8; 8] = [
buffer[0], buffer[1], buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7],
];
let n: usize = u64::from_le_bytes(arr)
.try_into()
.map_err(|_| SafeTensorError::HeaderTooLarge)?;
if n > MAX_HEADER_SIZE {
return Err(SafeTensorError::HeaderTooLarge);
}
let stop = n
.checked_add(8)
.ok_or(SafeTensorError::InvalidHeaderLength)?;
if stop > buffer_len {
return Err(SafeTensorError::InvalidHeaderLength);
}
let string =
std::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?;
let metadata: Metadata = serde_json::from_str(string)
.map_err(|_| SafeTensorError::InvalidHeaderDeserialization)?;
let buffer_end = metadata.validate()?;
if buffer_end + 8 + n != buffer_len {
return Err(SafeTensorError::MetadataIncompleteBuffer);
}
Ok((n, metadata))
}
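/// Parses `buffer` as a safetensors file and returns a view over it; the tensor
/// bytes stay borrowed from `buffer`.
///
/// A minimal sketch (the file name is made up):
/// ```ignore
/// let raw = std::fs::read("model.safetensors")?;
/// let tensors = SafeTensors::deserialize(&raw)?;
/// for (name, view) in tensors.tensors() {
///     println!("{name}: {:?} {:?}", view.dtype(), view.shape());
/// }
/// ```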
pub fn deserialize<'in_data>(buffer: &'in_data [u8]) -> Result<Self, SafeTensorError>
where
'in_data: 'data,
{
let (n, metadata) = SafeTensors::read_metadata(buffer)?;
let data = &buffer[n + 8..];
Ok(Self { metadata, data })
}
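/// Returns all tensors as `(name, view)` pairs.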
pub fn tensors(&self) -> Vec<(String, TensorView<'_>)> {
let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
for (name, &index) in &self.metadata.index_map {
let info = &self.metadata.tensors[index];
let tensorview = TensorView {
dtype: info.dtype,
shape: info.shape.clone(),
data: &self.data[info.data_offsets.0..info.data_offsets.1],
};
tensors.push((name.to_string(), tensorview));
}
tensors
}
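/// Looks up a single tensor by name, or returns `TensorNotFound`.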
pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'_>, SafeTensorError> {
if let Some(&index) = self.metadata.index_map.get(tensor_name) {
if let Some(info) = self.metadata.tensors.get(index) {
Ok(TensorView {
dtype: info.dtype,
shape: info.shape.clone(),
data: &self.data[info.data_offsets.0..info.data_offsets.1],
})
} else {
Err(SafeTensorError::TensorNotFound(tensor_name.to_string()))
}
} else {
Err(SafeTensorError::TensorNotFound(tensor_name.to_string()))
}
}
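/// Returns the names of all tensors in the buffer.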
pub fn names(&self) -> Vec<&'_ String> {
self.metadata.index_map.keys().collect()
}
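/// Number of tensors in the buffer.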
#[inline]
pub fn len(&self) -> usize {
self.metadata.tensors.len()
}
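/// True if the buffer contains no tensors.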
#[inline]
pub fn is_empty(&self) -> bool {
self.metadata.tensors.is_empty()
}
}
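/// The parsed header: optional free-form `__metadata__` strings, the per-tensor info
/// in offset order, and an index from tensor name to position in that list.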
#[derive(Debug, Clone)]
pub struct Metadata {
metadata: Option<HashMap<String, String>>,
tensors: Vec<TensorInfo>,
index_map: HashMap<String, usize>,
}
#[derive(Serialize, Deserialize)]
struct HashMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "__metadata__")]
metadata: Option<HashMap<String, String>>,
#[serde(flatten)]
tensors: HashMap<String, TensorInfo>,
}
impl<'de> Deserialize<'de> for Metadata {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let hashdata: HashMetadata = HashMetadata::deserialize(deserializer)?;
let (metadata, tensors) = (hashdata.metadata, hashdata.tensors);
let mut tensors: Vec<_> = tensors.into_iter().collect();
tensors.sort_by(|(_, left), (_, right)| left.data_offsets.cmp(&right.data_offsets));
Metadata::new(metadata, tensors).map_err(serde::de::Error::custom)
}
}
impl Serialize for Metadata {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut names = vec![""; self.index_map.len()];
for (name, index) in &self.index_map {
names[*index] = name;
}
let tensors: Vec<_> = names.iter().zip(self.tensors.iter()).collect();
let mut map = serializer.serialize_map(Some(tensors.len()))?;
if let Some(metadata) = &self.metadata {
map.serialize_entry("__metadata__", metadata)?;
}
for (name, info) in tensors {
map.serialize_entry(&name, &info)?;
}
map.end()
}
}
impl Metadata {
fn new(
metadata: Option<HashMap<String, String>>,
tensors: Vec<(String, TensorInfo)>,
) -> Result<Self, SafeTensorError> {
let mut index_map = HashMap::with_capacity(tensors.len());
let tensors: Vec<_> = tensors
.into_iter()
.enumerate()
.map(|(index, (k, tensor))| {
index_map.insert(k, index);
tensor
})
.collect();
let metadata = Self {
metadata,
tensors,
index_map,
};
Ok(metadata)
}
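/// Checks that tensor offsets start at 0, are contiguous and non-overlapping, and that
/// each byte range matches the tensor's shape and dtype size. Returns the total data length.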
fn validate(&self) -> Result<usize, SafeTensorError> {
let mut start = 0;
for (i, info) in self.tensors.iter().enumerate() {
let (s, e) = info.data_offsets;
if s != start || e < s {
let tensor_name = self
.index_map
.iter()
.find_map(|(name, &index)| if index == i { Some(&name[..]) } else { None })
.unwrap_or("no_tensor");
return Err(SafeTensorError::InvalidOffset(tensor_name.to_string()));
}
start = e;
let nelements: usize = info
.shape
.iter()
.cloned()
.try_fold(1usize, usize::checked_mul)
.ok_or(SafeTensorError::ValidationOverflow)?;
let nbytes = nelements
.checked_mul(info.dtype.size())
.ok_or(SafeTensorError::ValidationOverflow)?;
if (e - s) != nbytes {
return Err(SafeTensorError::TensorInvalidInfo);
}
}
Ok(start)
}
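/// Returns the `TensorInfo` for `name`, if present.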
pub fn info(&self, name: &str) -> Option<&TensorInfo> {
let index = self.index_map.get(name)?;
self.tensors.get(*index)
}
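/// Returns a map from tensor name to its `TensorInfo`.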
pub fn tensors(&self) -> HashMap<String, &TensorInfo> {
self.index_map
.iter()
.map(|(tensor_name, index)| (tensor_name.clone(), &self.tensors[*index]))
.collect()
}
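/// Returns the free-form `__metadata__` string map, if any.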
pub fn metadata(&self) -> &Option<HashMap<String, String>> {
&self.metadata
}
}
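/// A borrowed view of one tensor: its dtype, shape, and raw bytes.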
#[derive(Debug, PartialEq, Eq)]
pub struct TensorView<'data> {
dtype: Dtype,
shape: Vec<usize>,
data: &'data [u8],
}
impl<'data> View for &TensorView<'data> {
fn dtype(&self) -> Dtype {
self.dtype
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn data(&self) -> Cow<[u8]> {
self.data.into()
}
fn data_len(&self) -> usize {
self.data.len()
}
}
impl<'data> View for TensorView<'data> {
fn dtype(&self) -> Dtype {
self.dtype
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn data(&self) -> Cow<[u8]> {
self.data.into()
}
fn data_len(&self) -> usize {
self.data.len()
}
}
impl<'data> TensorView<'data> {
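/// Builds a view, checking that `data.len()` equals the number of elements times the dtype size.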
pub fn new(
dtype: Dtype,
shape: Vec<usize>,
data: &'data [u8],
) -> Result<Self, SafeTensorError> {
let n = data.len();
let n_elements: usize = shape.iter().product();
if n != n_elements * dtype.size() {
Err(SafeTensorError::InvalidTensorView(dtype, shape, n))
} else {
Ok(Self { dtype, shape, data })
}
}
pub fn dtype(&self) -> Dtype {
self.dtype
}
pub fn shape(&'data self) -> &'data [usize] {
&self.shape
}
pub fn data(&self) -> &'data [u8] {
self.data
}
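/// Iterates over the byte chunks selected by `slices` (see the `slice` module).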
pub fn sliced_data(
&'data self,
slices: &[TensorIndexer],
) -> Result<SliceIterator<'data>, InvalidSlice> {
SliceIterator::new(self, slices)
}
}
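/// Header entry for a single tensor: dtype, shape, and its byte range within the data
/// section (offsets are relative to the end of the header).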
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TensorInfo {
pub dtype: Dtype,
pub shape: Vec<usize>,
pub data_offsets: (usize, usize),
}
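/// The available element types. Variants are ordered by increasing size so that the
/// descending sort in `prepare` lays out the widest dtypes first.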
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
#[non_exhaustive]
pub enum Dtype {
BOOL,
U8,
I8,
#[allow(non_camel_case_types)]
F8_E5M2,
#[allow(non_camel_case_types)]
F8_E4M3,
I16,
U16,
F16,
BF16,
I32,
U32,
F32,
F64,
I64,
U64,
}
impl Dtype {
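/// Size in bytes of one element of this dtype.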
pub fn size(&self) -> usize {
match self {
Dtype::BOOL => 1,
Dtype::U8 => 1,
Dtype::I8 => 1,
Dtype::F8_E5M2 => 1,
Dtype::F8_E4M3 => 1,
Dtype::I16 => 2,
Dtype::U16 => 2,
Dtype::I32 => 4,
Dtype::U32 => 4,
Dtype::I64 => 8,
Dtype::U64 => 8,
Dtype::F16 => 2,
Dtype::BF16 => 2,
Dtype::F32 => 4,
Dtype::F64 => 8,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::slice::IndexOp;
use proptest::prelude::*;
const MAX_DIMENSION: usize = 8;
const MAX_SIZE: usize = 8;
const MAX_TENSORS: usize = 8;
fn arbitrary_dtype() -> impl Strategy<Value = Dtype> {
prop_oneof![
Just(Dtype::BOOL),
Just(Dtype::U8),
Just(Dtype::I8),
Just(Dtype::I16),
Just(Dtype::U16),
Just(Dtype::I32),
Just(Dtype::U32),
Just(Dtype::I64),
Just(Dtype::U64),
Just(Dtype::F16),
Just(Dtype::BF16),
Just(Dtype::F32),
Just(Dtype::F64),
]
}
fn arbitrary_shape() -> impl Strategy<Value = Vec<usize>> {
(1..MAX_DIMENSION).prop_flat_map(|length| prop::collection::vec(1..MAX_SIZE, length))
}
fn arbitrary_metadata() -> impl Strategy<Value = Metadata> {
(1..MAX_TENSORS)
.prop_flat_map(|size| {
(
prop::collection::vec(arbitrary_dtype(), size),
prop::collection::vec(arbitrary_shape(), size),
)
})
.prop_map(|(dtypes, shapes)| {
let mut start = 0;
let tensors: Vec<TensorInfo> = dtypes
.iter()
.zip(shapes)
.map(|(dtype, shape)| {
let length: usize = shape.iter().product();
let end = start + length * dtype.size();
let tensor = TensorInfo {
dtype: *dtype,
shape,
data_offsets: (start, end),
};
start = end;
tensor
})
.collect();
let index_map = (0..tensors.len())
.map(|index| (format!("t.{index}"), index))
.collect();
Metadata {
metadata: None,
tensors,
index_map,
}
})
}
fn data_size(metadata: &Metadata) -> usize {
metadata.tensors.last().unwrap().data_offsets.1
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn test_indexing(metadata in arbitrary_metadata()) {
let data = vec![0u8; data_size(&metadata)];
let tensors = SafeTensors { metadata, data: &data };
for name in tensors.names() {
assert!(tensors.tensor(name).is_ok());
}
}
#[test]
fn test_roundtrip(metadata in arbitrary_metadata()) {
let data: Vec<u8> = (0..data_size(&metadata)).map(|x| x as u8).collect();
let before = SafeTensors { metadata, data: &data };
let tensors = before.tensors();
let bytes = serialize(tensors.iter().map(|(name, view)| (name.to_string(), view)), &None).unwrap();
let after = SafeTensors::deserialize(&bytes).unwrap();
assert_eq!(before.names().len(), after.names().len());
for name in before.names() {
let tensor_before = before.tensor(name).unwrap();
let tensor_after = after.tensor(name).unwrap();
assert_eq!(tensor_after.data().as_ptr() as usize % tensor_after.dtype().size(), 0);
assert_eq!(tensor_before, tensor_after);
}
}
}
#[test]
fn test_serialization() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let shape = vec![1, 2, 3];
let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
let metadata: HashMap<String, TensorView> =
[("attn.0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, &None).unwrap();
assert_eq!(
out,
[
64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
116, 121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34,
58, 91, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115,
101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 0, 0, 0, 0, 0, 0, 128, 63,
0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64
]
);
let _parsed = SafeTensors::deserialize(&out).unwrap();
}
#[test]
fn test_serialization_forced_alignment() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let shape = vec![1, 1, 2, 3];
let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
let metadata: HashMap<String, TensorView> =
[("attn0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, &None).unwrap();
assert_eq!(
out,
[
72, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 48, 34, 58, 123, 34, 100, 116,
121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
91, 49, 44, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102,
115, 101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 32, 32, 32, 32, 32,
32, 32, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0,
160, 64
],
);
let parsed = SafeTensors::deserialize(&out).unwrap();
let tensor = parsed.tensor("attn0").unwrap();
assert_eq!(tensor.data().as_ptr() as usize % tensor.dtype().size(), 0);
}
#[test]
fn test_slicing() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let attn_0 = TensorView {
dtype: Dtype::F32,
shape: vec![1, 2, 3],
data: &data,
};
let metadata: HashMap<String, TensorView> =
[("attn.0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, &None).unwrap();
let parsed = SafeTensors::deserialize(&out).unwrap();
let out_buffer: Vec<u8> = parsed
.tensor("attn.0")
.unwrap()
.slice((.., ..1))
.unwrap()
.flat_map(|b| b.to_vec())
.collect();
assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64]);
assert_eq!(
out_buffer,
vec![0.0f32, 1.0, 2.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect::<Vec<_>>()
);
let out_buffer: Vec<u8> = parsed
.tensor("attn.0")
.unwrap()
.slice((.., .., ..1))
.unwrap()
.flat_map(|b| b.to_vec())
.collect();
assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 64, 64]);
assert_eq!(
out_buffer,
vec![0.0f32, 3.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect::<Vec<_>>()
);
}
#[test]
fn test_gpt2() {
gpt2_like(12, "gpt2");
}
#[test]
fn test_gpt2_tiny() {
gpt2_like(6, "gpt2_tiny");
}
fn gpt2_like(n_heads: usize, model_id: &str) {
let mut tensors_desc = vec![];
tensors_desc.push(("wte".to_string(), vec![50257, 768]));
tensors_desc.push(("wpe".to_string(), vec![1024, 768]));
for i in 0..n_heads {
tensors_desc.push((format!("h.{i}.ln_1.weight"), vec![768]));
tensors_desc.push((format!("h.{i}.ln_1.bias"), vec![768]));
tensors_desc.push((format!("h.{i}.attn.bias"), vec![1, 1, 1024, 1024]));
tensors_desc.push((format!("h.{i}.attn.c_attn.weight"), vec![768, 2304]));
tensors_desc.push((format!("h.{i}.attn.c_attn.bias"), vec![2304]));
tensors_desc.push((format!("h.{i}.attn.c_proj.weight"), vec![768, 768]));
tensors_desc.push((format!("h.{i}.attn.c_proj.bias"), vec![768]));
tensors_desc.push((format!("h.{i}.ln_2.weight"), vec![768]));
tensors_desc.push((format!("h.{i}.ln_2.bias"), vec![768]));
tensors_desc.push((format!("h.{i}.mlp.c_fc.weight"), vec![768, 3072]));
tensors_desc.push((format!("h.{i}.mlp.c_fc.bias"), vec![3072]));
tensors_desc.push((format!("h.{i}.mlp.c_proj.weight"), vec![3072, 768]));
tensors_desc.push((format!("h.{i}.mlp.c_proj.bias"), vec![768]));
}
tensors_desc.push(("ln_f.weight".to_string(), vec![768]));
tensors_desc.push(("ln_f.bias".to_string(), vec![768]));
let dtype = Dtype::F32;
let n: usize = tensors_desc
.iter()
.map(|(_, shape)| shape.iter().product::<usize>())
.sum::<usize>()
* dtype.size();
let all_data = vec![0; n];
let mut metadata = HashMap::with_capacity(tensors_desc.len());
let mut offset = 0;
for (name, shape) in tensors_desc {
let n: usize = shape.iter().product();
let buffer = &all_data[offset..offset + n * dtype.size()];
let tensor = TensorView::new(dtype, shape, buffer).unwrap();
metadata.insert(name, tensor);
offset += n;
}
let filename = format!("./out_{model_id}.safetensors");
let out = serialize(&metadata, &None).unwrap();
std::fs::write(&filename, out).unwrap();
let raw = std::fs::read(&filename).unwrap();
let _deserialized = SafeTensors::deserialize(&raw).unwrap();
std::fs::remove_file(&filename).unwrap();
serialize_to_file(&metadata, &None, Path::new(&filename)).unwrap();
let raw = std::fs::read(&filename).unwrap();
let _deserialized = SafeTensors::deserialize(&raw).unwrap();
std::fs::remove_file(&filename).unwrap();
}
#[test]
fn test_empty_shapes_allowed() {
let serialized = b"8\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[],\"data_offsets\":[0,4]}}\x00\x00\x00\x00";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.names(), vec!["test"]);
let tensor = loaded.tensor("test").unwrap();
assert!(tensor.shape().is_empty());
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"\0\0\0\0");
}
#[test]
fn test_deserialization() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded.names(), vec!["test"]);
let tensor = loaded.tensor("test").unwrap();
assert_eq!(tensor.shape(), vec![2, 2]);
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
}
#[test]
fn test_json_attack() {
let mut tensors = HashMap::new();
let dtype = Dtype::F32;
let shape = vec![2, 2];
let data_offsets = (0, 16);
for i in 0..10 {
tensors.insert(
format!("weight_{i}"),
TensorInfo {
dtype,
shape: shape.clone(),
data_offsets,
},
);
}
let metadata = HashMetadata {
metadata: None,
tensors,
};
let serialized = serde_json::to_string(&metadata).unwrap();
let serialized = serialized.as_bytes();
let n = serialized.len();
let filename = "out.safetensors";
let mut f = BufWriter::new(File::create(filename).unwrap());
f.write_all(n.to_le_bytes().as_ref()).unwrap();
f.write_all(serialized).unwrap();
f.write_all(b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0").unwrap();
f.flush().unwrap();
let reloaded = std::fs::read(filename).unwrap();
match SafeTensors::deserialize(&reloaded) {
Err(SafeTensorError::InvalidOffset(_)) => {}
Err(err) => panic!("Unexpected error {err:?}"),
Ok(_) => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_metadata_incomplete_buffer() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00extra_bogus_data_for_polyglot_file";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::MetadataIncompleteBuffer) => {}
_ => panic!("This should not be able to be deserialized"),
}
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"; match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::MetadataIncompleteBuffer) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_header_too_large() {
let serialized = b"<\x00\x00\x00\x00\xff\xff\xff{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::HeaderTooLarge) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_header_too_small() {
let serialized = b"";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::HeaderTooSmall) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_length() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeaderLength) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_non_utf8() {
let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00\xff";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeader) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_not_json() {
let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00{";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeaderDeserialization) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_whitespace_padded_header() {
let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00{}\x0D\x20\x09\x0A";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.len(), 0);
}
#[test]
fn test_zero_sized_tensor() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,0],\"data_offsets\":[0, 0]}}";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.names(), vec!["test"]);
let tensor = loaded.tensor("test").unwrap();
assert_eq!(tensor.shape(), vec![2, 0]);
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"");
}
#[test]
fn test_invalid_info() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0, 4]}}";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::TensorInvalidInfo) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_validation_overflow() {
let serialized = b"O\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,18446744073709551614],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::ValidationOverflow) => {}
_ => panic!("This should not be able to be deserialized"),
}
let serialized = b"N\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,9223372036854775807],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::ValidationOverflow) => {}
_ => panic!("This should not be able to be deserialized"),
}
}
}