@@ -795,8 +795,8 @@ class llama_model_params(ctypes.Structure):
 # uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
 # uint32_t n_ubatch; // physical maximum batch size
 # uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
-# uint32_t n_threads; // number of threads to use for generation
-# uint32_t n_threads_batch; // number of threads to use for batch processing
+# int32_t n_threads; // number of threads to use for generation
+# int32_t n_threads_batch; // number of threads to use for batch processing
 
 # enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 # enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
@@ -901,8 +901,8 @@ class llama_context_params(ctypes.Structure):
         ("n_batch", ctypes.c_uint32),
         ("n_ubatch", ctypes.c_uint32),
         ("n_seq_max", ctypes.c_uint32),
-        ("n_threads", ctypes.c_uint32),
-        ("n_threads_batch", ctypes.c_uint32),
+        ("n_threads", ctypes.c_int32),
+        ("n_threads_batch", ctypes.c_int32),
         ("rope_scaling_type", ctypes.c_int),
         ("pooling_type", ctypes.c_int),
         ("attention_type", ctypes.c_int),
@@ -1197,6 +1197,16 @@ def llama_numa_init(numa: int, /):
     ...
 
 
+# // Optional: an auto threadpool gets created in ggml if not passed explicitly
+# LLAMA_API void llama_attach_threadpool(
+#            struct llama_context * ctx,
+#               ggml_threadpool_t   threadpool,
+#               ggml_threadpool_t   threadpool_batch);
+
+
+# LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
+
+
 # // Call once at the end of the program - currently only used for MPI
 # LLAMA_API void llama_backend_free(void);
 @ctypes_function(
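The hunk above only carries the new threadpool declarations over as header comments; no Python binding is added in this change. If one were added later, it would presumably follow the file's existing `@ctypes_function` pattern, treating `ggml_threadpool_t` as an opaque pointer. A hypothetical sketch (the binding and the `ggml_threadpool_t` alias below are assumptions, not part of this diff):

```python
# Hypothetical sketch only -- this binding is NOT part of the diff above.
# ggml_threadpool_t is treated as an opaque pointer (assumption).
ggml_threadpool_t = ctypes.c_void_p


@ctypes_function(
    "llama_attach_threadpool",
    [llama_context_p_ctypes, ggml_threadpool_t, ggml_threadpool_t],
    None,
)
def llama_attach_threadpool(
    ctx: llama_context_p,
    threadpool: ctypes.c_void_p,
    threadpool_batch: ctypes.c_void_p,
    /,
):
    """Attach an explicit ggml threadpool to a context; ggml otherwise creates one automatically."""
    ...


@ctypes_function("llama_detach_threadpool", [llama_context_p_ctypes], None)
def llama_detach_threadpool(ctx: llama_context_p, /):
    """Detach a previously attached threadpool from the context."""
    ...
```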
@@ -2478,20 +2488,20 @@ def llama_decode(ctx: llama_context_p, batch: llama_batch, /) -> int:
 # // Set the number of threads used for decoding
 # // n_threads is the number of threads used for generation (single token)
 # // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-# LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+# LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);
 @ctypes_function(
     "llama_set_n_threads",
     [
         llama_context_p_ctypes,
-        ctypes.c_uint32,
-        ctypes.c_uint32,
+        ctypes.c_int32,
+        ctypes.c_int32,
     ],
     None,
 )
 def llama_set_n_threads(
     ctx: llama_context_p,
-    n_threads: Union[ctypes.c_uint32, int],
-    n_threads_batch: Union[ctypes.c_uint32, int],
+    n_threads: Union[ctypes.c_int32, int],
+    n_threads_batch: Union[ctypes.c_int32, int],
     /,
 ):
     """Set the number of threads used for decoding
@@ -2502,16 +2512,16 @@ def llama_set_n_threads(
 
 
 # // Get the number of threads used for generation of a single token.
-# LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
-@ctypes_function("llama_n_threads", [llama_context_p_ctypes], ctypes.c_uint32)
+# LLAMA_API int32_t llama_n_threads(struct llama_context * ctx);
+@ctypes_function("llama_n_threads", [llama_context_p_ctypes], ctypes.c_int32)
 def llama_n_threads(ctx: llama_context_p, /) -> int:
     """Get the number of threads used for generation of a single token"""
     ...
 
 
 # // Get the number of threads used for prompt and batch processing (multiple token).
-# LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
-@ctypes_function("llama_n_threads_batch", [llama_context_p_ctypes], ctypes.c_uint32)
+# LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
+@ctypes_function("llama_n_threads_batch", [llama_context_p_ctypes], ctypes.c_int32)
 def llama_n_threads_batch(ctx: llama_context_p, /) -> int:
     """Get the number of threads used for prompt and batch processing (multiple token)"""
     ...
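Continuing the sketch above, the two getters now declared with `int32_t` return plain Python ints and can be used to verify the values that were set:

```python
# Continuing the sketch from the llama_set_n_threads example above.
assert llama_cpp.llama_n_threads(ctx) == n
assert llama_cpp.llama_n_threads_batch(ctx) == n

llama_cpp.llama_free(ctx)
llama_cpp.llama_free_model(model)
llama_cpp.llama_backend_free()
```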