-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathpy_bind.cpp
37 lines (32 loc) · 899 Bytes
/
py_bind.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <torch/extension.h>
#include <torch/torch.h>
/// Forward declaration of the dequantize kernel that this file binds to
/// Python under the name "ggml_dequantize". No definition is visible here.
///
/// @param X     Tensor handed to the kernel (presumably the GGML-quantized
///              data; confirm against the definition).
/// @param type  8-bit type selector (presumably a ggml quantization type id;
///              TODO confirm).
/// @param m     64-bit size argument (presumably one dimension of the
///              dequantized result; verify against the definition).
/// @param n     64-bit size argument (presumably the other dimension; verify).
/// @return      The tensor produced by the kernel.
torch::Tensor ggml_dequantize(torch::Tensor X, int8_t type, int64_t m, int64_t n);
/// Forward declaration of the kernel that this file binds to Python under
/// the name "ggml_mul_mat_vec". No definition is visible here.
///
/// @param W     Quantized weight tensor.
/// @param X     Input tensor.
/// @param type  8-bit type selector (presumably the ggml quantization type
///              of W; TODO confirm).
/// @param m     64-bit size argument (presumably a row count; verify against
///              the definition).
/// @return      The tensor produced by the kernel.
torch::Tensor ggml_mul_mat_vec(torch::Tensor W, torch::Tensor X, int8_t type, int64_t m);
/// Forward declaration of the kernel that this file binds to Python under
/// the name "ggml_mul_mat_vec_a8". No definition is visible here.
///
/// @param W     Quantized weight tensor.
/// @param X     Input tensor.
/// @param type  8-bit type selector (presumably the ggml quantization type
///              of W; TODO confirm).
/// @param row   64-bit size argument (presumably a row count; verify against
///              the definition).
/// @return      The tensor produced by the kernel.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type, int64_t row);
/// Forward declaration of the kernel that this file binds to Python under
/// the name "ggml_mul_mat_a8". No definition is visible here.
///
/// @param W     Quantized weight tensor.
/// @param X     Input tensor.
/// @param type  8-bit type selector (presumably the ggml quantization type
///              of W; TODO confirm).
/// @param row   64-bit size argument (presumably a row count; verify against
///              the definition).
/// @return      The tensor produced by the kernel.
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type, int64_t row);
// Python module definition for the four GGML kernels declared in this file.
//
// Each kernel is exported under its C++ name with its original docstring.
// Arguments are registered by name via pybind11::arg, mirroring the C++
// parameter names, so the generated Python signature shows real names
// instead of arg0/arg1/... and callers may pass keywords. Positional calls
// behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ggml_dequantize", &ggml_dequantize, "ggml_dequantize",
          pybind11::arg("X"), pybind11::arg("type"),
          pybind11::arg("m"), pybind11::arg("n"));

    m.def("ggml_mul_mat_vec", &ggml_mul_mat_vec, "ggml_mul_mat_vec",
          pybind11::arg("W"), pybind11::arg("X"),
          pybind11::arg("type"), pybind11::arg("m"));

    m.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8, "ggml_mul_mat_vec_a8",
          pybind11::arg("W"), pybind11::arg("X"),
          pybind11::arg("type"), pybind11::arg("row"));

    m.def("ggml_mul_mat_a8", &ggml_mul_mat_a8, "ggml_mul_mat_a8",
          pybind11::arg("W"), pybind11::arg("X"),
          pybind11::arg("type"), pybind11::arg("row"));
}