
Commit 3e50895

Merge pull request #200 from APIParkLab/feature/1.5-local-model
Feature/1.5 local model
2 parents 901bef1 + 620bd4c commit 3e50895

File tree

274 files changed (+78349, -1587 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 /build/
 /apipark
 .gitlab-ci.yml
+/.vscode/

ai-provider/local/entity.go

Lines changed: 22 additions & 0 deletions
package ai_provider_local

import "time"

type Model struct {
	Name       string       `json:"name"`
	Model      string       `json:"model"`
	ModifiedAt time.Time    `json:"modified_at"`
	Size       int64        `json:"size"`
	Digest     string       `json:"digest"`
	Details    ModelDetails `json:"details,omitempty"`
}

// ModelDetails provides details about a model.
type ModelDetails struct {
	ParentModel       string   `json:"parent_model"`
	Format            string   `json:"format"`
	Family            string   `json:"family"`
	Families          []string `json:"families"`
	ParameterSize     string   `json:"parameter_size"`
	QuantizationLevel string   `json:"quantization_level"`
}

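Note: these fields mirror the entries returned by the Ollama model-list API, which ModelsInstalled in executor.go (below) maps into this package's types. As a minimal sketch, the snippet here only shows the JSON shape a populated Model would marshal to; every field value is made up for illustration and is not part of this commit.

package ai_provider_local

import (
	"encoding/json"
	"fmt"
	"time"
)

// printSampleModel prints the JSON produced by the Model struct above.
// The model name, size, digest and details are illustrative values only.
func printSampleModel() {
	m := Model{
		Name:       "qwen2:0.5b",
		Model:      "qwen2:0.5b",
		ModifiedAt: time.Now(),
		Size:       352 << 20,
		Digest:     "sha256:example",
		Details: ModelDetails{
			Format:            "gguf",
			Family:            "qwen2",
			ParameterSize:     "494M",
			QuantizationLevel: "Q4_0",
		},
	}
	b, _ := json.MarshalIndent(m, "", "  ")
	fmt.Println(string(b))
}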
ai-provider/local/executor.go

Lines changed: 324 additions & 0 deletions
package ai_provider_local

import (
	"context"
	"fmt"
	"sync"

	"github.com/ollama/ollama/progress"

	"github.com/eolinker/eosc"
	"github.com/ollama/ollama/api"
)

var (
	taskExecutor = NewAsyncExecutor(100)
)

// Pipeline represents a single user's pipeline.
type Pipeline struct {
	id      string
	channel chan PullMessage
	ctx     context.Context
	cancel  context.CancelFunc
}

func (p *Pipeline) Message() <-chan PullMessage {
	return p.channel
}

// AsyncExecutor manages the pipelines and the task queue for each model.
type AsyncExecutor struct {
	ctx       context.Context
	cancel    context.CancelFunc
	mu        sync.Mutex
	pipelines map[string]*modelPipeline // pipeline lists keyed by model
	msgQueue  chan messageTask          // message queue
}

type modelPipeline struct {
	ctx       context.Context
	cancel    context.CancelFunc
	pipelines eosc.Untyped[string, *Pipeline]
	pullFn    PullCallback
	maxSize   int
}

func (m *modelPipeline) List() []*Pipeline {
	return m.pipelines.List()
}

func (m *modelPipeline) Get(id string) (*Pipeline, bool) {
	return m.pipelines.Get(id)
}

func (m *modelPipeline) Set(id string, p *Pipeline) error {
	_, ok := m.pipelines.Get(id)
	if !ok {
		if m.pipelines.Count() > m.maxSize {
			return fmt.Errorf("pipeline size exceed %d", m.maxSize)
		}
	}
	m.pipelines.Set(id, p)
	return nil
}

func (m *modelPipeline) AddPipeline(id string) (*Pipeline, error) {
	ctx, cancel := context.WithCancel(m.ctx)
	pipeline := &Pipeline{
		ctx:     ctx,
		cancel:  cancel,
		id:      id,
		channel: make(chan PullMessage, 10), // buffered to avoid blocking
	}
	err := m.Set(id, pipeline)
	if err != nil {
		return nil, err
	}
	return pipeline, nil
}

func (m *modelPipeline) Close() {
	m.cancel()
	ids := m.pipelines.Keys()
	for _, id := range ids {
		m.ClosePipeline(id)
	}
	return
}

func (m *modelPipeline) ClosePipeline(id string) {
	// Close and remove the pipeline.
	p, has := m.pipelines.Del(id)
	if !has {
		return
	}
	p.cancel()
	close(p.channel)
}

func newModelPipeline(ctx context.Context, maxSize int) *modelPipeline {
	ctx, cancel := context.WithCancel(ctx)
	return &modelPipeline{
		pipelines: eosc.BuildUntyped[string, *Pipeline](),
		ctx:       ctx,
		cancel:    cancel,
		maxSize:   maxSize,
	}
}

// messageTask wraps a message carrying the model name and message content.
type messageTask struct {
	message PullMessage
}

type PullMessage struct {
	Model     string
	Status    string
	Digest    string
	Total     int64
	Completed int64
	Msg       string
}

// NewAsyncExecutor creates a new asynchronous task executor.
func NewAsyncExecutor(queueSize int) *AsyncExecutor {
	ctx, cancel := context.WithCancel(context.Background())
	executor := &AsyncExecutor{
		ctx:       ctx,
		cancel:    cancel,
		pipelines: make(map[string]*modelPipeline), // pipeline lists keyed by model
		msgQueue:  make(chan messageTask, queueSize),
	}
	executor.StartMessageDistributor()

	return executor
}

func (e *AsyncExecutor) GetModelPipeline(model string) (*modelPipeline, bool) {
	e.mu.Lock()
	defer e.mu.Unlock()

	mp, ok := e.pipelines[model]
	return mp, ok
}

func (e *AsyncExecutor) SetModelPipeline(model string, mp *modelPipeline) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.pipelines[model] = mp
}

// ClosePipeline closes a pipeline and removes it.
func (e *AsyncExecutor) ClosePipeline(model string, id string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	mp, ok := e.pipelines[model]
	if !ok {
		return
	}
	mp.ClosePipeline(id)
}

// CloseModelPipeline closes all pipelines of the given model.
func (e *AsyncExecutor) CloseModelPipeline(model string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	mp, ok := e.pipelines[model]
	if !ok {
		return
	}
	mp.Close()
	delete(e.pipelines, model)
}

// StartMessageDistributor starts the message distributor.
func (e *AsyncExecutor) StartMessageDistributor() {
	go func() {
		for task := range e.msgQueue {
			msg := task.message
			e.DistributeToModelPipelines(msg.Model, msg)
			if msg.Status == "error" || msg.Status == "success" {
				mp, has := e.GetModelPipeline(msg.Model)
				if has && mp.pullFn != nil {
					mp.pullFn(msg)
				}
				e.CloseModelPipeline(msg.Model)
				continue
			}
		}
	}()
}

// DistributeToModelPipelines delivers the message only to the pipelines of the specified model.
func (e *AsyncExecutor) DistributeToModelPipelines(model string, msg PullMessage) {
	e.mu.Lock()
	defer e.mu.Unlock()
	pipelines, ok := e.pipelines[model]
	if !ok {
		return
	}
	for _, pipeline := range pipelines.List() {
		select {
		case pipeline.channel <- msg:
		default:
			// Skip if the pipeline's channel is full.
		}
	}
}

type PullCallback func(msg PullMessage) error

func PullModel(model string, id string, fn PullCallback) (*Pipeline, error) {
	mp, has := taskExecutor.GetModelPipeline(model)
	if !has {
		mp = newModelPipeline(taskExecutor.ctx, 100)
		mp.pullFn = fn
		taskExecutor.SetModelPipeline(model, mp)
	}
	p, err := mp.AddPipeline(id)
	if err != nil {
		return nil, err
	}
	if !has {
		var status string
		bars := make(map[string]*progress.Bar)
		fn := func(resp api.ProgressResponse) error {
			if resp.Digest != "" {
				bar, ok := bars[resp.Digest]
				if !ok {
					bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
					bars[resp.Digest] = bar
				}
				bar.Set(resp.Completed)

				taskExecutor.msgQueue <- messageTask{
					message: PullMessage{
						Model:     model,
						Digest:    resp.Digest,
						Total:     resp.Total,
						Completed: resp.Completed,
						Msg:       bar.String(),
						Status:    resp.Status,
					},
				}
			} else if status != resp.Status {
				taskExecutor.msgQueue <- messageTask{
					message: PullMessage{
						Model:     model,
						Digest:    resp.Digest,
						Total:     resp.Total,
						Completed: resp.Completed,
						Msg:       status,
						Status:    resp.Status,
					},
				}
			}

			return nil
		}
		go func() {
			err = client.Pull(mp.ctx, &api.PullRequest{Model: model}, fn)
			if err != nil {
				taskExecutor.msgQueue <- messageTask{
					message: PullMessage{
						Model:     model,
						Status:    "error",
						Digest:    "",
						Total:     0,
						Completed: 0,
						Msg:       err.Error(),
					},
				}
			}
		}()

	}

	return p, nil
}

func StopPull(model string) {
	taskExecutor.CloseModelPipeline(model)
}

func CancelPipeline(model string, id string) {
	taskExecutor.ClosePipeline(model, id)
}

func RemoveModel(model string) error {
	taskExecutor.CloseModelPipeline(model)
	err := client.Delete(context.Background(), &api.DeleteRequest{Model: model})
	if err != nil {
		if err.Error() == fmt.Sprintf("model '%s' not found", model) {
			return nil
		}
	}
	return err
}

func ModelsInstalled() ([]Model, error) {
	result, err := client.List(context.Background())
	if err != nil {
		return nil, err
	}
	models := make([]Model, 0, len(result.Models))
	for _, m := range result.Models {
		models = append(models, Model{
			Name:       m.Name,
			Model:      m.Model,
			ModifiedAt: m.ModifiedAt,
			Size:       m.Size,
			Digest:     m.Digest,
			Details: ModelDetails{
				ParentModel:       m.Details.ParentModel,
				Format:            m.Details.Format,
				Family:            m.Details.Family,
				Families:          m.Details.Families,
				ParameterSize:     m.Details.ParameterSize,
				QuantizationLevel: m.Details.QuantizationLevel,
			},
		})
	}
	return models, nil
}

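Taken together, PullModel starts at most one Ollama pull per model and fans progress out to any number of watcher pipelines; a caller drains Pipeline.Message() and drops its pipeline with CancelPipeline when it stops listening, while StopPull tears down all pipelines for a model. The package-level client used above is not defined in these two files; it is presumably the Ollama API client set up elsewhere in this commit. Below is a minimal consumer sketch under that assumption; the "request-1" id, the pullAndWatch name, and the log output are illustrative only, not part of this commit.

package ai_provider_local

import "log"

// pullAndWatch is a hypothetical caller: it registers one watcher pipeline for
// the pull of `model`, logs progress until the executor closes the channel,
// and removes its own pipeline if it returns early.
func pullAndWatch(model string) error {
	onDone := func(msg PullMessage) error {
		// Invoked by the message distributor when the pull ends with status
		// "error" or "success" (only the callback passed by the first
		// PullModel call for a given model is registered).
		log.Printf("pull %s finished: %s (%s)", msg.Model, msg.Status, msg.Msg)
		return nil
	}

	p, err := PullModel(model, "request-1", onDone)
	if err != nil {
		return err
	}
	// Removes only this watcher's pipeline; the pull itself keeps running.
	defer CancelPipeline(model, "request-1")

	// Progress messages arrive here until the pull finishes or errors, at
	// which point CloseModelPipeline closes the channel and the loop ends.
	for msg := range p.Message() {
		log.Printf("%s: %s %d/%d", msg.Model, msg.Status, msg.Completed, msg.Total)
	}
	return nil
}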