package ai_provider_local

import (
	"context"
	"fmt"
	"sync"

	"github.com/eolinker/eosc"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/progress"
)

var (
	taskExecutor = NewAsyncExecutor(100)
)

// Pipeline represents a single subscriber's channel for pull-progress messages.
type Pipeline struct {
	id      string
	channel chan PullMessage
	ctx     context.Context
	cancel  context.CancelFunc
}

func (p *Pipeline) Message() <-chan PullMessage {
	return p.channel
}
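
// The sketch below is illustrative and not part of the original commit: it
// shows the intended consumption pattern for a Pipeline. A subscriber drains
// Message() until the distributor closes the channel (terminal status) or
// the pipeline's context is cancelled.
func consumePipeline(p *Pipeline, handle func(PullMessage)) {
	for {
		select {
		case msg, ok := <-p.Message():
			if !ok {
				return // channel closed by ClosePipeline
			}
			handle(msg)
		case <-p.ctx.Done():
			return
		}
	}
}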

// AsyncExecutor manages per-model pipelines and the shared message queue.
type AsyncExecutor struct {
	ctx       context.Context
	cancel    context.CancelFunc
	mu        sync.Mutex
	pipelines map[string]*modelPipeline // keyed by model name
	msgQueue  chan messageTask          // drained by the message distributor
}

type modelPipeline struct {
	ctx       context.Context
	cancel    context.CancelFunc
	pipelines eosc.Untyped[string, *Pipeline]
	pullFn    PullCallback
	maxSize   int
}

func (m *modelPipeline) List() []*Pipeline {
	return m.pipelines.List()
}

func (m *modelPipeline) Get(id string) (*Pipeline, bool) {
	return m.pipelines.Get(id)
}

func (m *modelPipeline) Set(id string, p *Pipeline) error {
	_, ok := m.pipelines.Get(id)
	if !ok {
		if m.pipelines.Count() >= m.maxSize {
			return fmt.Errorf("pipeline count exceeds %d", m.maxSize)
		}
	}
	m.pipelines.Set(id, p)
	return nil
}

func (m *modelPipeline) AddPipeline(id string) (*Pipeline, error) {
	ctx, cancel := context.WithCancel(m.ctx)
	pipeline := &Pipeline{
		ctx:     ctx,
		cancel:  cancel,
		id:      id,
		channel: make(chan PullMessage, 10), // buffered so a slow reader does not block the distributor
	}
	err := m.Set(id, pipeline)
	if err != nil {
		cancel()
		return nil, err
	}
	return pipeline, nil
}

func (m *modelPipeline) Close() {
	m.cancel()
	ids := m.pipelines.Keys()
	for _, id := range ids {
		m.ClosePipeline(id)
	}
}

// ClosePipeline removes the pipeline, cancels its context, and closes its channel.
func (m *modelPipeline) ClosePipeline(id string) {
	p, has := m.pipelines.Del(id)
	if !has {
		return
	}
	p.cancel()
	close(p.channel)
}

func newModelPipeline(ctx context.Context, maxSize int) *modelPipeline {
	ctx, cancel := context.WithCancel(ctx)
	return &modelPipeline{
		pipelines: eosc.BuildUntyped[string, *Pipeline](),
		ctx:       ctx,
		cancel:    cancel,
		maxSize:   maxSize,
	}
}

// messageTask wraps a progress message for the distributor queue.
type messageTask struct {
	message PullMessage
}

type PullMessage struct {
	Model     string
	Status    string
	Digest    string
	Total     int64
	Completed int64
	Msg       string
}

// NewAsyncExecutor creates an async task executor and starts its message distributor.
func NewAsyncExecutor(queueSize int) *AsyncExecutor {
	ctx, cancel := context.WithCancel(context.Background())
	executor := &AsyncExecutor{
		ctx:       ctx,
		cancel:    cancel,
		pipelines: make(map[string]*modelPipeline), // keyed by model name
		msgQueue:  make(chan messageTask, queueSize),
	}
	executor.StartMessageDistributor()

	return executor
}

func (e *AsyncExecutor) GetModelPipeline(model string) (*modelPipeline, bool) {
	e.mu.Lock()
	defer e.mu.Unlock()

	mp, ok := e.pipelines[model]
	return mp, ok
}

func (e *AsyncExecutor) SetModelPipeline(model string, mp *modelPipeline) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.pipelines[model] = mp
}

// ClosePipeline closes a single subscriber's pipeline and removes it.
func (e *AsyncExecutor) ClosePipeline(model string, id string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	mp, ok := e.pipelines[model]
	if !ok {
		return
	}
	mp.ClosePipeline(id)
}

// CloseModelPipeline closes every pipeline subscribed to the given model.
func (e *AsyncExecutor) CloseModelPipeline(model string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	mp, ok := e.pipelines[model]
	if !ok {
		return
	}
	mp.Close()
	delete(e.pipelines, model)
}

// StartMessageDistributor starts the goroutine that drains the message queue
// and fans messages out to the subscribing pipelines.
func (e *AsyncExecutor) StartMessageDistributor() {
	go func() {
		for task := range e.msgQueue {
			msg := task.message
			e.DistributeToModelPipelines(msg.Model, msg)
			if msg.Status == "error" || msg.Status == "success" {
				// Terminal status: fire the model-level callback, then tear
				// down every pipeline for this model.
				mp, has := e.GetModelPipeline(msg.Model)
				if has && mp.pullFn != nil {
					mp.pullFn(msg)
				}
				e.CloseModelPipeline(msg.Model)
			}
		}
	}()
}
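
// Illustrative sketch (an assumption, not in the original commit): any
// producer can drive the teardown path above by enqueuing a terminal status.
// The distributor then invokes the model-level pullFn and closes every
// pipeline subscribed to that model.
func finishPull(model, status, detail string) {
	taskExecutor.msgQueue <- messageTask{
		message: PullMessage{
			Model:  model,
			Status: status, // "error" or "success" triggers teardown
			Msg:    detail,
		},
	}
}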

// DistributeToModelPipelines delivers a message to every pipeline subscribed
// to the given model, dropping it for subscribers whose buffers are full.
func (e *AsyncExecutor) DistributeToModelPipelines(model string, msg PullMessage) {
	e.mu.Lock()
	defer e.mu.Unlock()
	pipelines, ok := e.pipelines[model]
	if !ok {
		return
	}
	for _, pipeline := range pipelines.List() {
		select {
		case pipeline.channel <- msg:
		default:
			// Buffer full: drop the message rather than block the distributor.
		}
	}
}

// PullCallback is invoked once when a pull reaches a terminal status.
type PullCallback func(msg PullMessage) error

func PullModel(model string, id string, fn PullCallback) (*Pipeline, error) {
	// Create-or-fetch under the executor lock so concurrent callers for the
	// same model cannot both start a pull.
	taskExecutor.mu.Lock()
	mp, has := taskExecutor.pipelines[model]
	if !has {
		mp = newModelPipeline(taskExecutor.ctx, 100)
		mp.pullFn = fn
		taskExecutor.pipelines[model] = mp
	}
	taskExecutor.mu.Unlock()

	p, err := mp.AddPipeline(id)
	if err != nil {
		return nil, err
	}
	if !has {
		// First subscriber for this model: start the actual pull.
		var status string
		bars := make(map[string]*progress.Bar)
		progressFn := func(resp api.ProgressResponse) error {
			if resp.Digest != "" {
				bar, ok := bars[resp.Digest]
				if !ok {
					bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
					bars[resp.Digest] = bar
				}
				bar.Set(resp.Completed)

				taskExecutor.msgQueue <- messageTask{
					message: PullMessage{
						Model:     model,
						Digest:    resp.Digest,
						Total:     resp.Total,
						Completed: resp.Completed,
						Msg:       bar.String(),
						Status:    resp.Status,
					},
				}
			} else if status != resp.Status {
				status = resp.Status
				taskExecutor.msgQueue <- messageTask{
					message: PullMessage{
						Model:     model,
						Digest:    resp.Digest,
						Total:     resp.Total,
						Completed: resp.Completed,
						Msg:       status,
						Status:    resp.Status,
					},
				}
			}

			return nil
		}
		go func() {
			// Use a goroutine-local error so we don't race on the outer err.
			if pullErr := client.Pull(mp.ctx, &api.PullRequest{Model: model}, progressFn); pullErr != nil {
				taskExecutor.msgQueue <- messageTask{
					message: PullMessage{
						Model:  model,
						Status: "error",
						Msg:    pullErr.Error(),
					},
				}
			}
		}()
	}

	return p, nil
}
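
// Usage sketch (illustrative, not in the original commit): subscribe to a
// pull and stream progress until it reaches a terminal status. The callback
// passed to PullModel fires once, with status "error" or "success", right
// before the model's pipelines are torn down.
func pullAndWait(model, subscriberID string) error {
	p, err := PullModel(model, subscriberID, func(msg PullMessage) error {
		fmt.Printf("pull of %s finished: %s\n", msg.Model, msg.Status)
		return nil
	})
	if err != nil {
		return err
	}
	for msg := range p.Message() {
		fmt.Println(msg.Msg) // rendered progress-bar text
	}
	return nil
}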

func StopPull(model string) {
	taskExecutor.CloseModelPipeline(model)
}

func CancelPipeline(model string, id string) {
	taskExecutor.ClosePipeline(model, id)
}

func RemoveModel(model string) error {
	taskExecutor.CloseModelPipeline(model)
	err := client.Delete(context.Background(), &api.DeleteRequest{Model: model})
	if err != nil {
		// Treat "model not found" as already removed.
		if err.Error() == fmt.Sprintf("model '%s' not found", model) {
			return nil
		}
	}
	return err
}

func ModelsInstalled() ([]Model, error) {
	result, err := client.List(context.Background())
	if err != nil {
		return nil, err
	}
	models := make([]Model, 0, len(result.Models))
	for _, m := range result.Models {
		models = append(models, Model{
			Name:       m.Name,
			Model:      m.Model,
			ModifiedAt: m.ModifiedAt,
			Size:       m.Size,
			Digest:     m.Digest,
			Details: ModelDetails{
				ParentModel:       m.Details.ParentModel,
				Format:            m.Details.Format,
				Family:            m.Details.Family,
				Families:          m.Details.Families,
				ParameterSize:     m.Details.ParameterSize,
				QuantizationLevel: m.Details.QuantizationLevel,
			},
		})
	}
	return models, nil
}