Skip to content

Commit

Permalink
Merge branch 'main' into metal-bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
kyle-mccarthy authored Jan 9, 2024
2 parents 9a48ae9 + 12b2a33 commit b500aa7
Show file tree
Hide file tree
Showing 58 changed files with 1,320 additions and 1,413 deletions.
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.9.14", features = ["f16"] }
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ And then head over to
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.

If you have an addition to this list, please submit a pull request.

Expand Down
10 changes: 5 additions & 5 deletions candle-book/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
Expand Down
7 changes: 3 additions & 4 deletions candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
metal = { workspace = true, optional = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
Expand Down Expand Up @@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "matmul"
harness = false
required-features = ["metal"]
14 changes: 11 additions & 3 deletions candle-core/src/pickle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ impl PthTensors {
}

pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
Expand All @@ -712,14 +713,21 @@ impl PthTensors {
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;

// Reading the data is a bit tricky as it can be strided, use an offset, etc.
// For now only support the basic case.
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
// Reading the data is a bit tricky as it can be strided, for now only support the basic
// case.
if !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
Expand Down
Loading

0 comments on commit b500aa7

Please sign in to comment.