From c4122ea406981f7d8db976d6097331f9133dd7d9 Mon Sep 17 00:00:00 2001 From: venjiang Date: Wed, 16 Apr 2025 20:25:10 +0800 Subject: [PATCH 1/5] mcp supports --- ai/register.go | 52 ++++++ cli/serve.go | 27 +++- core/server.go | 6 +- go.mod | 6 +- go.sum | 44 +----- pkg/bridge/ai/ai.go | 81 +++++++++- pkg/bridge/ai/ai_test.go | 82 +++++----- pkg/bridge/ai/register/register.go | 59 ++----- pkg/bridge/ai/service.go | 17 +- pkg/bridge/ai/service_test.go | 6 +- pkg/bridge/{ai => llm}/api_server.go | 150 +++++------------- pkg/bridge/{ai => llm}/api_server_test.go | 8 +- pkg/bridge/mcp/mcp.go | 54 +++++++ pkg/bridge/mcp/mcp_server.go | 184 ++++++++++++++++++++++ pkg/bridge/mcp/server.go | 181 +++++++++++++++++++++ 15 files changed, 696 insertions(+), 261 deletions(-) create mode 100644 ai/register.go rename pkg/bridge/{ai => llm}/api_server.go (58%) rename pkg/bridge/{ai => llm}/api_server_test.go (96%) create mode 100644 pkg/bridge/mcp/mcp.go create mode 100644 pkg/bridge/mcp/mcp_server.go create mode 100644 pkg/bridge/mcp/server.go diff --git a/ai/register.go b/ai/register.go new file mode 100644 index 000000000..ebe9ad531 --- /dev/null +++ b/ai/register.go @@ -0,0 +1,52 @@ +package ai + +import ( + "sync" + + "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/core/metadata" +) + +var ( + mu sync.Mutex + defaultRegister Register +) + +// SetRegister sets the default register +func SetRegister(r Register) { + mu.Lock() + defer mu.Unlock() + defaultRegister = r +} + +// GetRegister gets the default register +func GetRegister() Register { + mu.Lock() + defer mu.Unlock() + return defaultRegister +} + +// ListToolCalls returns the list of tool calls +func ListToolCalls(md metadata.M) ([]openai.Tool, error) { + return defaultRegister.ListToolCalls(md) +} + +// RegisterFunction registers a function calling function +func RegisterFunction(functionDefinition *openai.FunctionDefinition, connID uint64, md metadata.M) error { + return defaultRegister.RegisterFunction(functionDefinition, connID, md) +} + +// UnregisterFunction unregisters a function calling function +func UnregisterFunction(connID uint64, md metadata.M) { + defaultRegister.UnregisterFunction(connID, md) +} + +// Register provides an stateful register for registering and unregistering functions +type Register interface { + // ListToolCalls returns the list of tool calls + ListToolCalls(md metadata.M) ([]openai.Tool, error) + // RegisterFunction registers a function calling function + RegisterFunction(fd *openai.FunctionDefinition, connID uint64, md metadata.M) error + // UnregisterFunction unregisters a function calling function + UnregisterFunction(connID uint64, md metadata.M) +} diff --git a/cli/serve.go b/cli/serve.go index e03c34f39..68207d68a 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -44,6 +44,8 @@ import ( "github.com/yomorun/yomo/pkg/bridge/ai/provider/vertexai" "github.com/yomorun/yomo/pkg/bridge/ai/provider/vllm" "github.com/yomorun/yomo/pkg/bridge/ai/provider/xai" + "github.com/yomorun/yomo/pkg/bridge/llm" + "github.com/yomorun/yomo/pkg/bridge/mcp" ) // serveCmd represents the serve command @@ -99,6 +101,16 @@ var serveCmd = &cobra.Command{ // add AI connection middleware options = append(options, yomo.WithFrameListener(listener)) } + // check and parse the mcp server config + mcpConfig, err := mcp.ParseConfig(bridgeConf) + if err != nil { + if err == mcp.ErrConfigNotFound { + ylog.Warn("mcp server is disabled") + } else { + log.FailureStatusEvent(os.Stdout, "%s", err.Error()) + return + } + } // new zipper 
zipper, err := yomo.NewZipper( conf.Name, @@ -110,8 +122,8 @@ var serveCmd = &cobra.Command{ } zipper.Logger().Info("using config file", "file_path", config) - // AI Server if aiConfig != nil { + // AI Server // register the llm provider registerAIProvider(aiConfig) // start the llm api server @@ -122,12 +134,23 @@ var serveCmd = &cobra.Command{ conn2, _ := listener.Dial() reducer := ai.NewReducer(conn2, auth.NewCredential(fmt.Sprintf("token:%s", tokenString))) - err := ai.Serve(aiConfig, ylog.Default(), source, reducer) + err := llm.Serve(aiConfig, ylog.Default(), source, reducer) if err != nil { log.FailureStatusEvent(os.Stdout, "%s", err.Error()) return } }() + // MCP Server + if mcpConfig != nil { + defer mcp.Stop() + go func() { + err = mcp.Start(mcpConfig, aiConfig, listenAddr, ylog.Default()) + if err != nil { + log.FailureStatusEvent(os.Stdout, "%s", err.Error()) + return + } + }() + } } // start the zipper diff --git a/core/server.go b/core/server.go index 83bcb9f9e..34406e0f4 100644 --- a/core/server.go +++ b/core/server.go @@ -19,7 +19,7 @@ import ( // authentication implements, Currently, only token authentication is implemented _ "github.com/yomorun/yomo/pkg/auth" - "github.com/yomorun/yomo/pkg/bridge/ai/register" + "github.com/yomorun/yomo/pkg/frame-codec/y3codec" yquic "github.com/yomorun/yomo/pkg/listener/quic" pkgtls "github.com/yomorun/yomo/pkg/tls" @@ -187,7 +187,7 @@ func (s *Server) handleFrameConn(fconn frame.Conn, logger *slog.Logger) { if conn.ClientType() == ClientTypeStreamFunction { s.router.Remove(conn.ID()) - register.UnregisterFunction(conn.ID(), conn.Metadata()) + ai.UnregisterFunction(conn.ID(), conn.Metadata()) } _ = s.connector.Remove(conn.ID()) } @@ -284,7 +284,7 @@ func (s *Server) tryRegisterFunctionDefinition(hf *frame.HandshakeFrame, conn *C if err := json.Unmarshal([]byte(definition), &fd); err != nil { return fmt.Errorf("unmarshal function definition error: %s", err.Error()) } - if err := register.RegisterFunction(&fd, conn.ID(), md); err != nil { + if err := ai.RegisterFunction(&fd, conn.ID(), md); err != nil { return err } s.logger.Info("register ai function success", "function_name", fd.Name, "definition", string(definition)) diff --git a/go.mod b/go.mod index 8a7ff03aa..cc75b103f 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/invopop/jsonschema v0.13.0 github.com/joho/godotenv v1.5.1 github.com/lmittmann/tint v1.0.7 + github.com/mark3labs/mcp-go v0.20.1 github.com/matoous/go-nanoid/v2 v2.1.0 github.com/quic-go/quic-go v0.50.1 github.com/robfig/cron/v3 v3.0.1 @@ -25,8 +26,6 @@ require ( github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 github.com/tetratelabs/wazero v1.9.0 - github.com/tidwall/gjson v1.18.0 - github.com/tidwall/sjson v1.2.5 github.com/vmihailenco/msgpack/v5 v5.4.1 github.com/yomorun/y3 v1.0.5 go.opentelemetry.io/otel v1.35.0 @@ -83,12 +82,15 @@ require ( github.com/spf13/afero v1.14.0 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect 
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect diff --git a/go.sum b/go.sum index 79713e280..44d3c305b 100644 --- a/go.sum +++ b/go.sum @@ -75,7 +75,6 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= @@ -99,8 +98,6 @@ github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+u github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20250315033105-103756e64e1d h1:tx51Lf+wdE+aavqH8TcPJoCjTf4cE8hrMzROghCely0= -github.com/google/pprof v0.0.0-20250315033105-103756e64e1d/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= @@ -146,6 +143,8 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.20.1 h1:E1Bbx9K8d8kQmDZ1QHblM38c7UU2evQ2LlkANk1U/zw= +github.com/mark3labs/mcp-go v0.20.1/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE= github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -158,13 +157,10 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/onsi/ginkgo/v2 v2.23.0 h1:FA1xjp8ieYDzlgS5ABTpdUDB7wtngggONc8a7ku2NqQ= -github.com/onsi/ginkgo/v2 v2.23.0/go.mod h1:zXTP6xIp3U8aVuXN8ENK9IXRaTjFnpVB9mGmaSRvxnM= github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= -github.com/onsi/gomega v1.36.2 
h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= -github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU= +github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= @@ -174,12 +170,12 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo= -github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= github.com/quic-go/quic-go v0.50.1 h1:unsgjFIUqW8a2oopkY7YNONpV1gYND6Nt9hnt1PN94Q= github.com/quic-go/quic-go v0.50.1/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= @@ -190,8 +186,6 @@ github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.8.0 h1:mXaMVw7IqxNBxfv3LdWt9MDmcWDQ1fagDH918lOdVaQ= github.com/sagikazarmark/locafero v0.8.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= -github.com/sashabaranov/go-openai v1.38.0 h1:hNN5uolKwdbpiqOn7l+Z2alch/0n0rSFyg4n+GZxR5k= -github.com/sashabaranov/go-openai v1.38.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.38.1 h1:TtZabbFQZa1nEni/IhVtDF/WQjVqDgd+cWR5OeddzF8= github.com/sashabaranov/go-openai v1.38.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/second-state/WasmEdge-go v0.14.0 h1:6p4uXVUkUhLQW1z4wGe9nFuabF9S0lQG5TF+o6bnf5E= @@ -273,6 +267,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yomorun/y3 v1.0.5 h1:1qoZrDX+47hgU2pVJgoCEpeeXEOqml/do5oHjF9Wef4= github.com/yomorun/y3 v1.0.5/go.mod h1:+zwvZrKHe8D3fTMXNTsUsZXuI+kYxv3LRA2fSJEoWbo= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yusufpapurcu/wmi v1.2.4 
h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= @@ -302,8 +298,6 @@ go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/mock v0.5.1 h1:ASgazW/qBmR+A32MYFDB6E2POoTgOwT509VP0CT/fjs= go.uber.org/mock v0.5.1/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -313,13 +307,9 @@ golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+ golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= -golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -335,8 +325,6 @@ golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -350,8 +338,6 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -363,18 +349,12 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -385,8 +365,6 @@ golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= -golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU= golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= @@ -405,24 +383,16 @@ google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20250313205543-e70fdf4c4cb4 h1:kCjWYliqPA8g5z87mbjnf/cdgQqMzBfp9xYre5qKu2A= google.golang.org/genproto v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:SqIx1NV9hcvqdLHo7uNZDS5lrUJybQ3evo3+z/WBfA0= -google.golang.org/genproto/googleapis/api v0.0.0-20250313205543-e70fdf4c4cb4 h1:IFnXJq3UPB3oBREOodn1v1aGQeZYQclEmvWRMN0PSsY= -google.golang.org/genproto/googleapis/api 
v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:c8q6Z6OCqnfVIqUFJkCzKcrj8eCvUrz+K4KRzSTuANg= google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a h1:OQ7sHVzkx6L57dQpzUS4ckfWJ51KDH74XHTDe23xWAs= google.golang.org/genproto/googleapis/api v0.0.0-20250409194420-de1ac958c67a/go.mod h1:2R6XrVC8Oc08GlNh8ujEpc7HkLiEZ16QeY7FxIs20ac= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 h1:iK2jbkWL86DXjEx0qiHcRE9dE4/Ahua5k6V8OWFb//c= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= -google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/bridge/ai/ai.go b/pkg/bridge/ai/ai.go index a2d80416d..7bfb656a5 100644 --- a/pkg/bridge/ai/ai.go +++ b/pkg/bridge/ai/ai.go @@ -2,18 +2,31 @@ package ai import ( + "context" "errors" "net" + "time" "github.com/yomorun/yomo/core/ylog" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" "gopkg.in/yaml.v3" ) +const ( + // DefaultZipperAddr is the default endpoint of the zipper + DefaultZipperAddr = "localhost:9000" +) + var ( // ErrConfigNotFound is the error when the ai config was not found ErrConfigNotFound = errors.New("ai config was not found") // ErrConfigFormatError is the error when the ai config format is incorrect ErrConfigFormatError = errors.New("ai config format is incorrect") + + RequestTimeout = 90 * time.Second + // RunFunctionTimeout is the timeout for awaiting the function response, default is 60 seconds + RunFunctionTimeout = 60 * time.Second ) // Config is the configuration of AI bridge. 
@@ -99,8 +112,8 @@ func ParseConfig(conf map[string]any) (config *Config, err error) { return } -// parseZipperAddr parses the zipper address from zipper listen address -func parseZipperAddr(addr string) string { +// ParseZipperAddr parses the zipper address from zipper listen address +func ParseZipperAddr(addr string) string { host, port, err := net.SplitHostPort(addr) if err != nil { ylog.Error("invalid zipper address, return default", @@ -136,3 +149,67 @@ func parseZipperAddr(addr string) string { } return localIP + ":" + port } + +func getLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + ip := ipnet.IP + if !ok || ip.IsUnspecified() || ip.To4() == nil || ip.To16() == nil { + continue + } + return ip.String(), nil + } + return "", errors.New("not found local ip") +} + +type callerContextKey struct{} + +// WithCallerContext adds the caller to the request context +func WithCallerContext(ctx context.Context, caller *Caller) context.Context { + return context.WithValue(ctx, callerContextKey{}, caller) +} + +// FromCallerContext returns the caller from the request context +func FromCallerContext(ctx context.Context) *Caller { + caller, ok := ctx.Value(callerContextKey{}).(*Caller) + if !ok { + return nil + } + return caller +} + +type transIDContextKey struct{} + +// WithTransIDContext adds the transID to the request context +func WithTransIDContext(ctx context.Context, transID string) context.Context { + return context.WithValue(ctx, transIDContextKey{}, transID) +} + +// FromTransIDContext returns the transID from the request context +func FromTransIDContext(ctx context.Context) string { + val, ok := ctx.Value(transIDContextKey{}).(string) + if !ok { + return "" + } + return val +} + +type tracerContextKey struct{} + +// WithTracerContext adds the tracer to the request context +func WithTracerContext(ctx context.Context, tracer trace.Tracer) context.Context { + return context.WithValue(ctx, tracerContextKey{}, tracer) +} + +// FromTransIDContext returns the transID from the request context +func FromTracerContext(ctx context.Context) trace.Tracer { + val, ok := ctx.Value(tracerContextKey{}).(trace.Tracer) + if !ok { + return new(noop.Tracer) + } + return val +} diff --git a/pkg/bridge/ai/ai_test.go b/pkg/bridge/ai/ai_test.go index 8b7e054a5..fa8fcca95 100644 --- a/pkg/bridge/ai/ai_test.go +++ b/pkg/bridge/ai/ai_test.go @@ -6,47 +6,47 @@ import ( "github.com/stretchr/testify/assert" ) -func TestParseZipperAddr(t *testing.T) { - tests := []struct { - name string - addr string - expected string - }{ - { - name: "Valid address", - addr: "192.168.1.100:9000", - expected: "192.168.1.100:9000", - }, - { - name: "Valid address of localhost", - addr: "localhost", - expected: "localhost:9000", - }, - - { - name: "Invalid address", - addr: "invalid", - expected: DefaultZipperAddr, - }, - { - name: "Localhost", - addr: "localhost:9000", - expected: "localhost:9000", - }, - { - name: "Unspecified IP", - addr: "0.0.0.0:9000", - expected: "127.0.0.1:9000", // Expect the local IP - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := parseZipperAddr(tt.addr) - assert.Equal(t, tt.expected, got, tt.name) - }) - } -} +// func TestParseZipperAddr(t *testing.T) { +// tests := []struct { +// name string +// addr string +// expected string +// }{ +// { +// name: "Valid address", +// addr: "192.168.1.100:9000", +// expected: "192.168.1.100:9000", +// }, +// { +// 
name: "Valid address of localhost", +// addr: "localhost", +// expected: "localhost:9000", +// }, +// +// { +// name: "Invalid address", +// addr: "invalid", +// expected: DefaultZipperAddr, +// }, +// { +// name: "Localhost", +// addr: "localhost:9000", +// expected: "localhost:9000", +// }, +// { +// name: "Unspecified IP", +// addr: "0.0.0.0:9000", +// expected: "127.0.0.1:9000", // Expect the local IP +// }, +// } +// +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// got := parseZipperAddr(tt.addr) +// assert.Equal(t, tt.expected, got, tt.name) +// }) +// } +// } func TestParseConfig(t *testing.T) { tests := []struct { diff --git a/pkg/bridge/ai/register/register.go b/pkg/bridge/ai/register/register.go index 4d8b9784a..287632b52 100644 --- a/pkg/bridge/ai/register/register.go +++ b/pkg/bridge/ai/register/register.go @@ -8,62 +8,18 @@ import ( "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/core/metadata" -) - -var ( - // mu protects defaultRegister - mu sync.Mutex - defaultRegister Register + "github.com/yomorun/yomo/pkg/bridge/mcp" ) func init() { - SetRegister(®ister{}) -} - -// SetRegister sets the default register -func SetRegister(r Register) { - mu.Lock() - defer mu.Unlock() - defaultRegister = r -} - -// GetRegister gets the default register -func GetRegister() Register { - mu.Lock() - defer mu.Unlock() - return defaultRegister + ai.SetRegister(®ister{}) } // NewDefault creates a new default register. -func NewDefault() Register { +func NewDefault() ai.Register { return ®ister{} } -// ListToolCalls returns the list of tool calls -func ListToolCalls(md metadata.M) ([]openai.Tool, error) { - return defaultRegister.ListToolCalls(md) -} - -// RegisterFunction registers a function calling function -func RegisterFunction(functionDefinition *openai.FunctionDefinition, connID uint64, md metadata.M) error { - return defaultRegister.RegisterFunction(functionDefinition, connID, md) -} - -// UnregisterFunction unregisters a function calling function -func UnregisterFunction(connID uint64, md metadata.M) { - defaultRegister.UnregisterFunction(connID, md) -} - -// Register provides an stateful register for registering and unregistering functions -type Register interface { - // ListToolCalls returns the list of tool calls - ListToolCalls(md metadata.M) ([]openai.Tool, error) - // RegisterFunction registers a function calling function - RegisterFunction(fd *openai.FunctionDefinition, connID uint64, md metadata.M) error - // UnregisterFunction unregisters a function calling function - UnregisterFunction(connID uint64, md metadata.M) -} - type register struct { underlying sync.Map } @@ -93,14 +49,23 @@ func (r *register) RegisterFunction(fd *ai.FunctionDefinition, connID uint64, md if err != nil { return err } + // ai function r.underlying.Store(connID, openai.Tool{ Function: fd, Type: openai.ToolTypeFunction, }) + // mcp tool + err = mcp.AddMCPTool(connID, fd) + if err != nil { + return err + } return nil } func (r *register) UnregisterFunction(connID uint64, _ metadata.M) { + // ai function r.underlying.Delete(connID) + // mcp tool + mcp.RemoveMCPTool(connID) } diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index dc98cba6b..d297be552 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -17,7 +17,7 @@ import ( "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/core/ylog" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/bridge/ai/register" + 
"github.com/yomorun/yomo/pkg/id" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" @@ -84,7 +84,7 @@ func initOption(opt *ServiceOptions) *ServiceOptions { } func newService(provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOptions) *Service { - var onEvict = func(_ string, caller *Caller) { + onEvict := func(_ string, caller *Caller) { caller.Close() } @@ -119,7 +119,7 @@ func (srv *Service) GetInvoke(ctx context.Context, userInstruction, baseSystemMe } md := caller.Metadata().Clone() // read tools attached to the metadata - tools, err := register.ListToolCalls(md) + tools, err := ai.ListToolCalls(md) if err != nil { return &ai.InvokeResponse{}, err } @@ -136,9 +136,7 @@ func (srv *Service) GetInvoke(ctx context.Context, userInstruction, baseSystemMe promptUsage int completionUsage int ) - var ( - _, span = tracer.Start(ctx, "first_call") - ) + _, span := tracer.Start(ctx, "first_call") chatCompletionResponse, err := srv.provider.GetChatCompletions(ctx, req, md) if err != nil { return nil, err @@ -223,7 +221,7 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl md := caller.Metadata().Clone() // 1. find all hosting tool sfn - tools, err := register.ListToolCalls(md) + tools, err := ai.ListToolCalls(md) if err != nil { return err } @@ -458,6 +456,11 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl } } +// Logger returns the logger of the service +func (src *Service) Logger() *slog.Logger { + return src.logger +} + func startRespSpan(ctx context.Context, reqSpan trace.Span, tracer trace.Tracer, w EventResponseWriter) trace.Span { reqSpan.End() recordTTFT(ctx, tracer, w) diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/ai/service_test.go index 19c0298ba..f91138178 100644 --- a/pkg/bridge/ai/service_test.go +++ b/pkg/bridge/ai/service_test.go @@ -206,7 +206,7 @@ func TestServiceInvoke(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - register.SetRegister(register.NewDefault()) + ai.SetRegister(register.NewDefault()) pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) if err != nil { @@ -374,7 +374,7 @@ func TestServiceChatCompletion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - register.SetRegister(register.NewDefault()) + ai.SetRegister(register.NewDefault()) pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) 
if err != nil { @@ -412,7 +412,7 @@ func TestServiceChatCompletion(t *testing.T) { func mockCaller(calls []mockFunctionCall) *Caller { // register function to register for connID, call := range calls { - register.RegisterFunction(&openai.FunctionDefinition{Name: call.functionName}, uint64(connID), nil) + ai.RegisterFunction(&openai.FunctionDefinition{Name: call.functionName}, uint64(connID), nil) } caller := &Caller{ diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/llm/api_server.go similarity index 58% rename from pkg/bridge/ai/api_server.go rename to pkg/bridge/llm/api_server.go index 7326c066b..ba2b5d3b9 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/llm/api_server.go @@ -1,4 +1,4 @@ -package ai +package llm import ( "context" @@ -6,7 +6,6 @@ import ( "errors" "fmt" "log/slog" - "net" "net/http" "os" "time" @@ -14,25 +13,13 @@ import ( openai "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/bridge/ai/register" + _ "github.com/yomorun/yomo/pkg/bridge/ai/register" "github.com/yomorun/yomo/pkg/id" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "go.opentelemetry.io/otel/trace/noop" -) - -const ( - // DefaultZipperAddr is the default endpoint of the zipper - DefaultZipperAddr = "localhost:9000" -) - -var ( - // RequestTimeout is the timeout for the request, default is 360 seconds - RequestTimeout = 360 * time.Second - // RunFunctionTimeout is the timeout for awaiting the function response, default is 60 seconds - RunFunctionTimeout = 60 * time.Second ) // BasicAPIServer provides restful service for end user @@ -41,7 +28,7 @@ type BasicAPIServer struct { } // Serve starts the Basic API Server -func Serve(config *Config, logger *slog.Logger, source yomo.Source, reducer yomo.StreamFunction) error { +func Serve(config *pkgai.Config, logger *slog.Logger, source yomo.Source, reducer yomo.StreamFunction) error { provider, err := provider.GetProvider(config.Server.Provider) if err != nil { return err @@ -78,15 +65,15 @@ func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) htt } // NewBasicAPIServer creates a new restful service -func NewBasicAPIServer(config *Config, provider provider.LLMProvider, source yomo.Source, reducer yomo.StreamFunction, logger *slog.Logger) (*BasicAPIServer, error) { +func NewBasicAPIServer(config *pkgai.Config, provider provider.LLMProvider, source yomo.Source, reducer yomo.StreamFunction, logger *slog.Logger) (*BasicAPIServer, error) { logger = logger.With("service", "llm-bridge") - opts := &ServiceOptions{ + opts := &pkgai.ServiceOptions{ Logger: logger, SourceBuilder: func(_ string) yomo.Source { return source }, ReducerBuilder: func(_ string) yomo.StreamFunction { return reducer }, } - service := NewService(provider, opts) + service := pkgai.NewService(provider, opts) mux := NewServeMux(NewHandler(service)) @@ -94,31 +81,32 @@ func NewBasicAPIServer(config *Config, provider provider.LLMProvider, source yom httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)), } - logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name()) + logger.Info("[llm] start llm bridge service", "addr", config.Server.Addr, "provider", provider.Name()) return server, nil } // decorateReqContext decorates the context of the request, it injects a transID into the request's context, // log 
the request information and start tracing the request. -func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler { +func decorateReqContext(service *pkgai.Service, logger *slog.Logger) func(handler http.Handler) http.Handler { hostname, _ := os.Hostname() tracer := otel.Tracer("yomo-llm-bridge") return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ww := NewResponseWriter(w, logger) + ww := pkgai.NewResponseWriter(w, logger) ctx := r.Context() - ctx = WithTracerContext(ctx, tracer) + ctx = pkgai.WithTracerContext(ctx, tracer) start := time.Now() caller, err := service.LoadOrCreateCaller(r) if err != nil { - RespondWithError(ww, http.StatusBadRequest, err) + logger.Error("failed to load or create caller", "error", err) + pkgai.RespondWithError(ww, http.StatusBadRequest, err) return } - ctx = WithCallerContext(ctx, caller) + ctx = pkgai.WithCallerContext(ctx, caller) // trace every request ctx, span := tracer.Start( @@ -130,7 +118,7 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http defer span.End() transID := id.New(32) - ctx = WithTransIDContext(ctx, transID) + ctx = pkgai.WithTransIDContext(ctx, transID) handler.ServeHTTP(ww, r.WithContext(ctx)) @@ -160,22 +148,22 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http // Handler handles the http request. type Handler struct { - service *Service + service *pkgai.Service } // NewHandler return a hander that handles chat completions requests. -func NewHandler(service *Service) *Handler { +func NewHandler(service *pkgai.Service) *Handler { return &Handler{service} } // HandleOverview is the handler for GET /overview func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - ww := w.(EventResponseWriter) + ww := w.(pkgai.EventResponseWriter) - tools, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata()) + tools, err := ai.ListToolCalls(pkgai.FromCallerContext(r.Context()).Metadata()) if err != nil { - RespondWithError(ww, http.StatusInternalServerError, err) + pkgai.RespondWithError(ww, http.StatusInternalServerError, err) return } @@ -193,31 +181,31 @@ var baseSystemMessage = `You are a very helpful assistant. 
Your job is to choose func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - transID = FromTransIDContext(ctx) - ww = w.(EventResponseWriter) + transID = pkgai.FromTransIDContext(ctx) + ww = w.(pkgai.EventResponseWriter) ) defer r.Body.Close() - req, err := DecodeRequest[ai.InvokeRequest](r, ww, h.service.logger) + req, err := DecodeRequest[ai.InvokeRequest](r, ww, h.service.Logger()) if err != nil { - RespondWithError(ww, http.StatusBadRequest, err) + pkgai.RespondWithError(ww, http.StatusBadRequest, err) ww.RecordError(errors.New("bad request")) return } - ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout) + ctx, cancel := context.WithTimeout(r.Context(), pkgai.RequestTimeout) defer cancel() var ( - caller = FromCallerContext(ctx) - tracer = FromTracerContext(ctx) + caller = pkgai.FromCallerContext(ctx) + tracer = pkgai.FromTracerContext(ctx) ) w.Header().Set("Content-Type", "application/json") res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, caller, req.IncludeCallStack, tracer) if err != nil { - RespondWithError(ww, http.StatusInternalServerError, err) + pkgai.RespondWithError(ww, http.StatusInternalServerError, err) return } @@ -228,24 +216,24 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - transID = FromTransIDContext(ctx) - ww = w.(EventResponseWriter) + transID = pkgai.FromTransIDContext(ctx) + ww = w.(pkgai.EventResponseWriter) ) defer r.Body.Close() - req, err := DecodeRequest[openai.ChatCompletionRequest](r, ww, h.service.logger) + req, err := DecodeRequest[openai.ChatCompletionRequest](r, ww, h.service.Logger()) if err != nil { w.Header().Set("Content-Type", "application/json") - RespondWithError(ww, http.StatusBadRequest, err) + pkgai.RespondWithError(ww, http.StatusBadRequest, err) return } - ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout) + ctx, cancel := context.WithTimeout(r.Context(), pkgai.RequestTimeout) defer cancel() var ( - caller = FromCallerContext(ctx) - tracer = FromTracerContext(ctx) + caller = pkgai.FromCallerContext(ctx) + tracer = pkgai.FromTracerContext(ctx) ) if err := h.service.GetChatCompletions(ctx, req, transID, caller, ww, tracer); err != nil { @@ -255,13 +243,13 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) if !ww.IsStream() { w.Header().Set("Content-Type", "application/json") } - RespondWithError(ww, http.StatusBadRequest, err) + pkgai.RespondWithError(ww, http.StatusBadRequest, err) return } } // DecodeRequest decodes the request body into given type. -func DecodeRequest[T any](r *http.Request, ww EventResponseWriter, logger *slog.Logger) (T, error) { +func DecodeRequest[T any](r *http.Request, ww pkgai.EventResponseWriter, logger *slog.Logger) (T, error) { var req T err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -270,67 +258,3 @@ func DecodeRequest[T any](r *http.Request, ww EventResponseWriter, logger *slog. 
return req, nil } - -func getLocalIP() (string, error) { - addrs, err := net.InterfaceAddrs() - if err != nil { - return "", err - } - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - ip := ipnet.IP - if !ok || ip.IsUnspecified() || ip.To4() == nil || ip.To16() == nil { - continue - } - return ip.String(), nil - } - return "", errors.New("not found local ip") -} - -type callerContextKey struct{} - -// WithCallerContext adds the caller to the request context -func WithCallerContext(ctx context.Context, caller *Caller) context.Context { - return context.WithValue(ctx, callerContextKey{}, caller) -} - -// FromCallerContext returns the caller from the request context -func FromCallerContext(ctx context.Context) *Caller { - caller, ok := ctx.Value(callerContextKey{}).(*Caller) - if !ok { - return nil - } - return caller -} - -type transIDContextKey struct{} - -// WithTransIDContext adds the transID to the request context -func WithTransIDContext(ctx context.Context, transID string) context.Context { - return context.WithValue(ctx, transIDContextKey{}, transID) -} - -// FromTransIDContext returns the transID from the request context -func FromTransIDContext(ctx context.Context) string { - val, ok := ctx.Value(transIDContextKey{}).(string) - if !ok { - return "" - } - return val -} - -type tracerContextKey struct{} - -// WithTracerContext adds the tracer to the request context -func WithTracerContext(ctx context.Context, tracer trace.Tracer) context.Context { - return context.WithValue(ctx, tracerContextKey{}, tracer) -} - -// FromTransIDContext returns the transID from the request context -func FromTracerContext(ctx context.Context) trace.Tracer { - val, ok := ctx.Value(tracerContextKey{}).(trace.Tracer) - if !ok { - return new(noop.Tracer) - } - return val -} diff --git a/pkg/bridge/ai/api_server_test.go b/pkg/bridge/llm/api_server_test.go similarity index 96% rename from pkg/bridge/ai/api_server_test.go rename to pkg/bridge/llm/api_server_test.go index 51c6c6a23..a0e70dd4d 100644 --- a/pkg/bridge/ai/api_server_test.go +++ b/pkg/bridge/llm/api_server_test.go @@ -1,4 +1,4 @@ -package ai +package llm import ( "bytes" @@ -14,7 +14,7 @@ import ( "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/bridge/ai/register" + _ "github.com/yomorun/yomo/pkg/bridge/ai/register" ) func TestServer(t *testing.T) { @@ -31,8 +31,8 @@ func TestServer(t *testing.T) { Required: []string{"prop1"}, }, } - register.SetRegister(register.NewDefault()) - register.RegisterFunction(functionDefinition, 200, nil) + // register.SetRegister(register.NewDefault()) + ai.RegisterFunction(functionDefinition, 200, nil) // mock the provider and the req/res of the caller pd, err := provider.NewMock("mock provider", provider.MockChatCompletionResponse(stopResp, stopResp)) diff --git a/pkg/bridge/mcp/mcp.go b/pkg/bridge/mcp/mcp.go new file mode 100644 index 000000000..8687cb1c1 --- /dev/null +++ b/pkg/bridge/mcp/mcp.go @@ -0,0 +1,54 @@ +package mcp + +import ( + "errors" + + "github.com/yomorun/yomo/core/ylog" + "gopkg.in/yaml.v3" +) + +var ( + // ErrConfigNotFound is the error when the mcp config was not found + ErrConfigNotFound = errors.New("mcp config was not found") + // ErrConfigFormatError is the error when the ai config format is incorrect + ErrConfigFormatError = errors.New("mcp config format is incorrect") +) + +type Config struct { + Server Server `yaml:"server"` // Server is the configuration of the mcp server +} + 
+// Server is the configuration of the mcp server, which is the endpoint for end user access +type Server struct { + Addr string `yaml:"addr"` // Addr is the address of the server +} + +// ParseConfig parses the AI config from conf +func ParseConfig(conf map[string]any) (config *Config, err error) { + section, ok := conf["mcp"] + if !ok { + err = ErrConfigNotFound + return + } + aiConfig, ok := section.(map[string]any) + if !ok { + err = ErrConfigFormatError + return + } + data, e := yaml.Marshal(aiConfig) + if e != nil { + err = e + ylog.Error("marshal mcp config", "err", err.Error()) + return + } + err = yaml.Unmarshal(data, &config) + if err != nil { + ylog.Error("unmarshal mcp config", "err", err.Error()) + return + } + // defaults values + if config.Server.Addr == "" { + config.Server.Addr = ":9090" + } + return +} diff --git a/pkg/bridge/mcp/mcp_server.go b/pkg/bridge/mcp/mcp_server.go new file mode 100644 index 000000000..d875da2d6 --- /dev/null +++ b/pkg/bridge/mcp/mcp_server.go @@ -0,0 +1,184 @@ +package mcp + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/yomorun/yomo/core/ylog" + "github.com/yomorun/yomo/pkg/bridge/ai" +) + +var ( + ErrMCPServerNotFound = errors.New("mcp server not found") + ErrUnknownMCPServerType = errors.New("unknown mcp server type") + ErrCallerNotFound = errors.New("caller not found") +) + +// MCPServer represents a MCP server +type MCPServer struct { + underlying *server.MCPServer + SSEServer *server.SSEServer + basePath string + logger *slog.Logger +} + +// NewMCPServer create a new mcp server +func NewMCPServer(logger *slog.Logger) (*MCPServer, error) { + // logger + if logger == nil { + logger = ylog.Default() + } + // create mcp server + underlyingMCPServer := server.NewMCPServer( + "mcp-server", + "2024-11-05", + server.WithLogging(), + server.WithHooks(hooks(logger)), + server.WithToolCapabilities(true), + server.WithRecovery(), + ) + // sse options + sseOpts := []server.SSEOption{ + server.WithHTTPServer(httpServer), + server.WithSSEContextFunc(authContextFunc()), + } + // sse server + sseServer := server.NewSSEServer(underlyingMCPServer, sseOpts...) + + mcpServer := &MCPServer{ + underlying: underlyingMCPServer, + SSEServer: sseServer, + logger: logger, + } + + logger.Info("[mcp] server is created", + "sse_endpoint", sseServer.CompleteSseEndpoint(), + "message_endpoint", sseServer.CompleteMessageEndpoint(), + ) + + return mcpServer, nil +} + +// BasePath returns the base path of the mcp server +func (s *MCPServer) BasePath() string { + return s.basePath +} + +func (s *MCPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + logger.Info(fmt.Sprintf("[mcp] url:%s", r.URL.String()), "method", r.Method) + s.SSEServer.ServeHTTP(w, r) +} + +// AddTool adds a tool to the mcp server +func (s *MCPServer) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) { + s.underlying.AddTool(tool, handler) +} + +// DeleteTools deletes tools by name +func (s *MCPServer) DeleteTools(names ...string) { + s.underlying.DeleteTools(names...) 
+} + +// AddPrompt adds a prompt to the mcp server +func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler server.PromptHandlerFunc) { + s.underlying.AddPrompt(prompt, handler) +} + +func authContextFunc() server.SSEContextFunc { + return func(ctx context.Context, r *http.Request) context.Context { + // context with caller + caller, err := aiService.LoadOrCreateCaller(r) + if err != nil { + logger.Error("[mcp] failed to load or create caller", "error", err) + return ctx + } + // caller + ctx = ai.WithCallerContext(ctx, caller) + logger.Debug("[mcp] sse context with caller", "path", r.URL.Path) + return ctx + } +} + +func hooks(logger *slog.Logger) *server.Hooks { + hooks := &server.Hooks{} + + hooks.AddBeforeAny(func(ctx context.Context, id any, method mcp.MCPMethod, message any) { + logger.Debug("[mcp] hook.beforeAny", "method", method, "id", id, "message", message) + }) + hooks.AddOnSuccess(func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { + logger.Info(fmt.Sprintf("[mcp] rpc:%s", method), "id", id, "message", message, "result", result) + }) + hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + logger.Error("[mcp] rpc call error", "method", method, "id", id, "message", message, "error", err) + }) + // initialize + hooks.AddBeforeInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest) { + logger.Debug("[mcp] hook.beforeInitialize", "id", id, "message", message) + }) + hooks.AddAfterInitialize(func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + logger.Debug("[mcp] hook.afterInitialize", "id", id, "message", message, "result", result) + }) + // ping + hooks.AddBeforePing(func(ctx context.Context, id any, message *mcp.PingRequest) { + logger.Debug("[mcp] hook.beforePing", "id", id, "message", message) + }) + hooks.AddAfterPing(func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) { + logger.Debug("[mcp] hook.afterPing", "id", id, "message", message, "result", result) + }) + // list resources + hooks.AddBeforeListResources(func(ctx context.Context, id any, message *mcp.ListResourcesRequest) { + logger.Debug("[mcp] hook.beforeListResources", "id", id, "message", message) + }) + hooks.AddAfterListResources(func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) { + logger.Debug("[mcp] hook.afterListResources", "id", id, "message", message, "result", result) + }) + // list resource templates + hooks.AddBeforeListResourceTemplates(func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) { + logger.Debug("[mcp] hook.beforeListResourceTemplates", "id", id, "message", message) + }) + hooks.AddAfterListResourceTemplates(func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) { + logger.Debug("[mcp] hook.afterListResourceTemplates", "id", id, "message", message, "result", result) + }) + // read resource + hooks.AddBeforeReadResource(func(ctx context.Context, id any, message *mcp.ReadResourceRequest) { + logger.Debug("[mcp] hook.beforeReadResource", "id", id, "message", message) + }) + hooks.AddAfterReadResource(func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) { + logger.Debug("[mcp] hook.afterReadResource", "id", id, "message", message, "result", result) + }) + // list prompts + hooks.AddBeforeListPrompts(func(ctx context.Context, id any, 
message *mcp.ListPromptsRequest) { + logger.Debug("[mcp] hook.beforeListPrompts", "id", id, "message", message) + }) + hooks.AddAfterListPrompts(func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) { + logger.Debug("[mcp] hook.afterListPrompts", "id", id, "message", message, "result", result) + }) + // get prompt + hooks.AddBeforeGetPrompt(func(ctx context.Context, id any, message *mcp.GetPromptRequest) { + logger.Debug("[mcp] hook.beforeGetPrompt", "id", id, "message", message) + }) + hooks.AddAfterGetPrompt(func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) { + logger.Debug("[mcp] hook.afterGetPrompt", "id", id, "message", message, "result", result) + }) + // list tools + hooks.AddBeforeListTools(func(ctx context.Context, id any, message *mcp.ListToolsRequest) { + logger.Debug("[mcp] hook.beforeListTools", "id", id, "message", message) + }) + hooks.AddAfterListTools(func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { + logger.Debug("[mcp] hook.afterListTools", "id", id, "message", message, "result", result) + }) + // call tool + hooks.AddBeforeCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest) { + logger.Debug("[mcp] hook.beforeCallTool", "id", id, "message", message) + }) + hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + logger.Debug("[mcp] hook.afterCallTool", "id", id, "message", message, "result", result) + }) + + return hooks +} diff --git a/pkg/bridge/mcp/server.go b/pkg/bridge/mcp/server.go new file mode 100644 index 000000000..b1b71f0ca --- /dev/null +++ b/pkg/bridge/mcp/server.go @@ -0,0 +1,181 @@ +package mcp + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo" + "github.com/yomorun/yomo/pkg/bridge/ai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" + "github.com/yomorun/yomo/pkg/id" +) + +var ( + mcpServer *MCPServer + tools sync.Map + httpServer *http.Server + aiService *ai.Service + logger *slog.Logger +) + +// Start starts the http server +func Start(config *Config, aiConfig *ai.Config, zipperAddr string, log *slog.Logger) error { + // ai provider + provider, err := provider.GetProvider(aiConfig.Server.Provider) + if err != nil { + return err + } + // logger + logger = log.With("service", "mcp-bridge") + // ai service + opts := &ai.ServiceOptions{ + Logger: logger, + // SourceBuilder: func(_ string) yomo.Source { return source }, + // ReducerBuilder: func(_ string) yomo.StreamFunction { return reducer }, + } + zipperAddr = ai.ParseZipperAddr(zipperAddr) + sourceBuilder := func(credential string) yomo.Source { + source := yomo.NewSource("mcp-source", zipperAddr, yomo.WithCredential(credential)) + return source + } + reducerBuilder := func(credential string) yomo.StreamFunction { + reducer := yomo.NewStreamFunction("mcp-reducer", zipperAddr, yomo.WithSfnCredential(credential)) + return reducer + } + opts.SourceBuilder = sourceBuilder + opts.ReducerBuilder = reducerBuilder + aiService = ai.NewService(provider, opts) + // http server + addr := config.Server.Addr + mux := http.NewServeMux() + mux.HandleFunc("/", index) + mux.HandleFunc("/sse", mcpServerHandler) + mux.HandleFunc("/message", mcpServerHandler) + httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + // mcp server + mcpServer, err = NewMCPServer(logger) + if 
err != nil { + logger.Error("[mcp] failed to create server", "error", err) + return err + } + logger.Info("[mcp] bridge server is running", "addr", addr) + defer httpServer.Close() + + return httpServer.ListenAndServe() +} + +// Stop stops the http server +func Stop() error { + return httpServer.Close() +} + +func index(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("MCP Server is running")) +} + +func mcpServerHandler(w http.ResponseWriter, r *http.Request) { + if mcpServer == nil { + // mpc server is disabled + w.WriteHeader(http.StatusNotFound) + return + } + mcpServer.ServeHTTP(w, r) +} + +// AddMCPTool add mcp tool +func AddMCPTool(connID uint64, functionDefinition *openai.FunctionDefinition) error { + if mcpServer == nil { + // mpc server is disabled + return nil + } + // add tool + tool := mcp.NewToolWithRawSchema( + functionDefinition.Name, + functionDefinition.Description, + json.RawMessage(`{}`), + ) + // add input schema + if functionDefinition.Parameters != nil { + inputSchema, err := json.Marshal(functionDefinition.Parameters) + if err != nil { + return err + } + tool.RawInputSchema = json.RawMessage(inputSchema) + } + // Add tool handler + mcpServer.AddTool(tool, mcpToolHandler) + tools.Store(connID, functionDefinition) + logger.Info("[mcp] add tool", "input_schema", string(tool.RawInputSchema), "conn_id", connID) + + return nil +} + +// RemoveMCPTool remove mcp tool +func RemoveMCPTool(connID uint64) error { + if mcpServer == nil { + // mpc server is disabled + return nil + } + tools.Delete(connID) + tool, ok := tools.Load(connID) + if !ok { + // tool not found + return nil + } + functionDefinition, ok := tool.(*openai.FunctionDefinition) + if !ok { + // tool not found + return nil + } + mcpServer.DeleteTools(functionDefinition.Name) + return nil +} + +// mcpToolHandler mcp tool handler +func mcpToolHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // get caller + caller := ai.FromCallerContext(ctx) + if caller == nil { + logger.Error("[mcp] tool handler load failed", "error", ErrCallerNotFound.Error()) + return nil, ErrCallerNotFound + } + // run sfn and get result + transID := id.New(32) + reqID := id.New(16) + toolCallID := id.New(8) + name := request.Params.Name + arguments, err := json.Marshal(request.Params.Arguments) + if err != nil { + return nil, err + } + args := string(arguments) + logger.Info("[mcp] tool is calling...", "name", name, "arguments", args) + fnCalls := []openai.ToolCall{ + { + ID: toolCallID, + Type: "function", + Function: openai.FunctionCall{ + Name: name, + Arguments: string(arguments), + }, + }, + } + callResult, err := caller.Call(ctx, transID, reqID, fnCalls) + if err != nil { + logger.Error("[mcp] tool call error", "error", err, "name", name, "arguments", args) + return nil, err + } + result := callResult[0].Content + logger.Info("[mcp] tool call result", "name", name, "arguments", args, "result", string(result)) + + return mcp.NewToolResultText(result), nil +} From c300108ccb3f9d6cb8bc032741f2c6eeb779f2fb Mon Sep 17 00:00:00 2001 From: venjiang Date: Thu, 17 Apr 2025 18:12:24 +0800 Subject: [PATCH 2/5] register nil check --- ai/register.go | 9 ++ pkg/bridge/ai/ai.go | 51 -------- pkg/bridge/ai/http.go | 227 +++++++++++++++++++++++++++++++++++ pkg/bridge/llm/api_server.go | 145 +--------------------- pkg/bridge/mcp/mcp_server.go | 4 +- pkg/bridge/mcp/server.go | 14 +-- 6 files changed, 247 insertions(+), 203 deletions(-) create mode 100644 pkg/bridge/ai/http.go 
diff --git a/ai/register.go b/ai/register.go index ebe9ad531..7cc19905d 100644 --- a/ai/register.go +++ b/ai/register.go @@ -28,16 +28,25 @@ func GetRegister() Register { // ListToolCalls returns the list of tool calls func ListToolCalls(md metadata.M) ([]openai.Tool, error) { + if defaultRegister == nil { + return nil, nil + } return defaultRegister.ListToolCalls(md) } // RegisterFunction registers a function calling function func RegisterFunction(functionDefinition *openai.FunctionDefinition, connID uint64, md metadata.M) error { + if defaultRegister == nil { + return nil + } return defaultRegister.RegisterFunction(functionDefinition, connID, md) } // UnregisterFunction unregisters a function calling function func UnregisterFunction(connID uint64, md metadata.M) { + if defaultRegister == nil { + return + } defaultRegister.UnregisterFunction(connID, md) } diff --git a/pkg/bridge/ai/ai.go b/pkg/bridge/ai/ai.go index 7bfb656a5..31ea63b9e 100644 --- a/pkg/bridge/ai/ai.go +++ b/pkg/bridge/ai/ai.go @@ -2,14 +2,11 @@ package ai import ( - "context" "errors" "net" "time" "github.com/yomorun/yomo/core/ylog" - "go.opentelemetry.io/otel/trace" - "go.opentelemetry.io/otel/trace/noop" "gopkg.in/yaml.v3" ) @@ -165,51 +162,3 @@ func getLocalIP() (string, error) { } return "", errors.New("not found local ip") } - -type callerContextKey struct{} - -// WithCallerContext adds the caller to the request context -func WithCallerContext(ctx context.Context, caller *Caller) context.Context { - return context.WithValue(ctx, callerContextKey{}, caller) -} - -// FromCallerContext returns the caller from the request context -func FromCallerContext(ctx context.Context) *Caller { - caller, ok := ctx.Value(callerContextKey{}).(*Caller) - if !ok { - return nil - } - return caller -} - -type transIDContextKey struct{} - -// WithTransIDContext adds the transID to the request context -func WithTransIDContext(ctx context.Context, transID string) context.Context { - return context.WithValue(ctx, transIDContextKey{}, transID) -} - -// FromTransIDContext returns the transID from the request context -func FromTransIDContext(ctx context.Context) string { - val, ok := ctx.Value(transIDContextKey{}).(string) - if !ok { - return "" - } - return val -} - -type tracerContextKey struct{} - -// WithTracerContext adds the tracer to the request context -func WithTracerContext(ctx context.Context, tracer trace.Tracer) context.Context { - return context.WithValue(ctx, tracerContextKey{}, tracer) -} - -// FromTransIDContext returns the transID from the request context -func FromTracerContext(ctx context.Context) trace.Tracer { - val, ok := ctx.Value(tracerContextKey{}).(trace.Tracer) - if !ok { - return new(noop.Tracer) - } - return val -} diff --git a/pkg/bridge/ai/http.go b/pkg/bridge/ai/http.go new file mode 100644 index 000000000..857fbc58f --- /dev/null +++ b/pkg/bridge/ai/http.go @@ -0,0 +1,227 @@ +package ai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "reflect" + + openai "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/ai" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +type callerContextKey struct{} + +// WithCallerContext adds the caller to the request context +func WithCallerContext(ctx context.Context, caller *Caller) context.Context { + return context.WithValue(ctx, callerContextKey{}, caller) +} + +// FromCallerContext returns the caller from the request context +func FromCallerContext(ctx context.Context) *Caller { + caller, ok := 
ctx.Value(callerContextKey{}).(*Caller) + if !ok { + return nil + } + return caller +} + +type transIDContextKey struct{} + +// WithTransIDContext adds the transID to the request context +func WithTransIDContext(ctx context.Context, transID string) context.Context { + return context.WithValue(ctx, transIDContextKey{}, transID) +} + +// FromTransIDContext returns the transID from the request context +func FromTransIDContext(ctx context.Context) string { + val, ok := ctx.Value(transIDContextKey{}).(string) + if !ok { + return "" + } + return val +} + +type tracerContextKey struct{} + +// WithTracerContext adds the tracer to the request context +func WithTracerContext(ctx context.Context, tracer trace.Tracer) context.Context { + return context.WithValue(ctx, tracerContextKey{}, tracer) +} + +// FromTracerContext returns the tracer from the request context +func FromTracerContext(ctx context.Context) trace.Tracer { + val, ok := ctx.Value(tracerContextKey{}).(trace.Tracer) + if !ok { + return new(noop.Tracer) + } + return val +} + +// RespondWithError writes an error to response according to the OpenAI API spec. +func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) { + code, errString := parseCodeError(code, err) + logger.Error("bridge server error", "err", errString, "err_type", reflect.TypeOf(err).String()) + + w.WriteHeader(code) + w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString))) +} + +func parseCodeError(code int, err error) (int, string) { + errString := err.Error() + + switch e := err.(type) { + case *openai.APIError: + code = e.HTTPStatusCode + errString = e.Message + case *openai.RequestError: + code = e.HTTPStatusCode + errString = e.Error() + } + + return code, errString +} + +// NewServeMux creates a new http.ServeMux for the llm bridge server. +func NewServeMux(h *Handler) *http.ServeMux { + mux := http.NewServeMux() + + // GET /overview + mux.HandleFunc("/overview", h.HandleOverview) + // POST /invoke + mux.HandleFunc("/invoke", h.HandleInvoke) + // POST /v1/chat/completions (OpenAI compatible interface) + mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions) + + return mux +} + +// DecorateHandler decorates the http.Handler. +func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler { + // decorate the http.Handler + for i := len(decorates) - 1; i >= 0; i-- { + h = decorates[i](h) + } + return h +} + +// Handler handles the http request. +type Handler struct { + service *Service +} + +// NewHandler returns a handler that handles chat completions requests. +func NewHandler(service *Service) *Handler { + return &Handler{service} +} + +// HandleOverview is the handler for GET /overview +func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + tools, err := ai.ListToolCalls(FromCallerContext(r.Context()).Metadata()) + if err != nil { + RespondWithError(w, http.StatusInternalServerError, err, h.service.Logger()) + return + } + + functions := make([]*openai.FunctionDefinition, len(tools)) + for i, tc := range tools { + functions[i] = tc.Function + } + + json.NewEncoder(w).Encode(&ai.OverviewResponse{Functions: functions}) +} + +var baseSystemMessage = `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. 
Ask for clarification if a user request is ambiguous.` + +// HandleInvoke is the handler for POST /invoke +func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + transID = FromTransIDContext(ctx) + ww = w.(EventResponseWriter) + ) + defer r.Body.Close() + + req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.Logger()) + if err != nil { + ww.RecordError(errors.New("bad request")) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout) + defer cancel() + + var ( + caller = FromCallerContext(ctx) + tracer = FromTracerContext(ctx) + ) + + w.Header().Set("Content-Type", "application/json") + + res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, caller, req.IncludeCallStack, tracer) + if err != nil { + ww.RecordError(err) + RespondWithError(w, http.StatusInternalServerError, err, h.service.Logger()) + return + } + + _ = json.NewEncoder(w).Encode(res) +} + +// HandleChatCompletions is the handler for POST /chat/completions +func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + transID = FromTransIDContext(ctx) + ww = w.(EventResponseWriter) + ) + defer r.Body.Close() + + req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.Logger()) + if err != nil { + ww.RecordError(err) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout) + defer cancel() + + var ( + caller = FromCallerContext(ctx) + tracer = FromTracerContext(ctx) + ) + + if err := h.service.GetChatCompletions(ctx, req, transID, caller, ww, tracer); err != nil { + ww.RecordError(err) + if err == context.Canceled { + return + } + if ww.IsStream() { + h.service.Logger().Error("bridge server error", "err", err.Error(), "err_type", reflect.TypeOf(err).String()) + w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s"}}`, err.Error()))) + return + } + RespondWithError(w, http.StatusBadRequest, err, h.service.Logger()) + return + } +} + +// DecodeRequest decodes the request body into given type. +func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.Logger) (T, error) { + var req T + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + w.Header().Set("Content-Type", "application/json") + RespondWithError(w, http.StatusBadRequest, err, logger) + return req, err + } + + return req, nil +} diff --git a/pkg/bridge/llm/api_server.go b/pkg/bridge/llm/api_server.go index ba2b5d3b9..3e6532e61 100644 --- a/pkg/bridge/llm/api_server.go +++ b/pkg/bridge/llm/api_server.go @@ -1,18 +1,13 @@ package llm import ( - "context" - "encoding/json" - "errors" "fmt" "log/slog" "net/http" "os" "time" - openai "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo" - "github.com/yomorun/yomo/ai" pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" _ "github.com/yomorun/yomo/pkg/bridge/ai/register" @@ -41,29 +36,6 @@ func Serve(config *pkgai.Config, logger *slog.Logger, source yomo.Source, reduce return http.ListenAndServe(config.Server.Addr, srv.httpHandler) } -// NewServeMux creates a new http.ServeMux for the llm bridge server. 
-func NewServeMux(h *Handler) *http.ServeMux { - mux := http.NewServeMux() - - // GET /overview - mux.HandleFunc("/overview", h.HandleOverview) - // POST /invoke - mux.HandleFunc("/invoke", h.HandleInvoke) - // POST /v1/chat/completions (OpenAI compatible interface) - mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions) - - return mux -} - -// DecorateHandler decorates the http.Handler. -func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler { - // decorate the http.Handler - for i := len(decorates) - 1; i >= 0; i-- { - h = decorates[i](h) - } - return h -} - // NewBasicAPIServer creates a new restful service func NewBasicAPIServer(config *pkgai.Config, provider provider.LLMProvider, source yomo.Source, reducer yomo.StreamFunction, logger *slog.Logger) (*BasicAPIServer, error) { logger = logger.With("service", "llm-bridge") @@ -75,10 +47,10 @@ func NewBasicAPIServer(config *pkgai.Config, provider provider.LLMProvider, sour } service := pkgai.NewService(provider, opts) - mux := NewServeMux(NewHandler(service)) + mux := pkgai.NewServeMux(pkgai.NewHandler(service)) server := &BasicAPIServer{ - httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)), + httpHandler: pkgai.DecorateHandler(mux, decorateReqContext(service, logger)), } logger.Info("[llm] start llm bridge service", "addr", config.Server.Addr, "provider", provider.Name()) @@ -145,116 +117,3 @@ func decorateReqContext(service *pkgai.Service, logger *slog.Logger) func(handle }) } } - -// Handler handles the http request. -type Handler struct { - service *pkgai.Service -} - -// NewHandler return a hander that handles chat completions requests. -func NewHandler(service *pkgai.Service) *Handler { - return &Handler{service} -} - -// HandleOverview is the handler for GET /overview -func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - ww := w.(pkgai.EventResponseWriter) - - tools, err := ai.ListToolCalls(pkgai.FromCallerContext(r.Context()).Metadata()) - if err != nil { - pkgai.RespondWithError(ww, http.StatusInternalServerError, err) - return - } - - functions := make([]*openai.FunctionDefinition, len(tools)) - for i, tc := range tools { - functions[i] = tc.Function - } - - json.NewEncoder(w).Encode(&ai.OverviewResponse{Functions: functions}) -} - -var baseSystemMessage = `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. 
Ask for clarification if a user request is ambiguous.` - -// HandleInvoke is the handler for POST /invoke -func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - transID = pkgai.FromTransIDContext(ctx) - ww = w.(pkgai.EventResponseWriter) - ) - defer r.Body.Close() - - req, err := DecodeRequest[ai.InvokeRequest](r, ww, h.service.Logger()) - if err != nil { - pkgai.RespondWithError(ww, http.StatusBadRequest, err) - ww.RecordError(errors.New("bad request")) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), pkgai.RequestTimeout) - defer cancel() - - var ( - caller = pkgai.FromCallerContext(ctx) - tracer = pkgai.FromTracerContext(ctx) - ) - - w.Header().Set("Content-Type", "application/json") - - res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, caller, req.IncludeCallStack, tracer) - if err != nil { - pkgai.RespondWithError(ww, http.StatusInternalServerError, err) - return - } - - _ = json.NewEncoder(w).Encode(res) -} - -// HandleChatCompletions is the handler for POST /chat/completions -func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - transID = pkgai.FromTransIDContext(ctx) - ww = w.(pkgai.EventResponseWriter) - ) - defer r.Body.Close() - - req, err := DecodeRequest[openai.ChatCompletionRequest](r, ww, h.service.Logger()) - if err != nil { - w.Header().Set("Content-Type", "application/json") - pkgai.RespondWithError(ww, http.StatusBadRequest, err) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), pkgai.RequestTimeout) - defer cancel() - - var ( - caller = pkgai.FromCallerContext(ctx) - tracer = pkgai.FromTracerContext(ctx) - ) - - if err := h.service.GetChatCompletions(ctx, req, transID, caller, ww, tracer); err != nil { - if err == context.Canceled { - return - } - if !ww.IsStream() { - w.Header().Set("Content-Type", "application/json") - } - pkgai.RespondWithError(ww, http.StatusBadRequest, err) - return - } -} - -// DecodeRequest decodes the request body into given type. 
-func DecodeRequest[T any](r *http.Request, ww pkgai.EventResponseWriter, logger *slog.Logger) (T, error) { - var req T - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - return req, err - } - - return req, nil -} diff --git a/pkg/bridge/mcp/mcp_server.go b/pkg/bridge/mcp/mcp_server.go index d875da2d6..e1a32ebd1 100644 --- a/pkg/bridge/mcp/mcp_server.go +++ b/pkg/bridge/mcp/mcp_server.go @@ -10,7 +10,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/yomorun/yomo/core/ylog" - "github.com/yomorun/yomo/pkg/bridge/ai" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" ) var ( @@ -98,7 +98,7 @@ func authContextFunc() server.SSEContextFunc { return ctx } // caller - ctx = ai.WithCallerContext(ctx, caller) + ctx = pkgai.WithCallerContext(ctx, caller) logger.Debug("[mcp] sse context with caller", "path", r.URL.Path) return ctx } diff --git a/pkg/bridge/mcp/server.go b/pkg/bridge/mcp/server.go index b1b71f0ca..6a376bdcc 100644 --- a/pkg/bridge/mcp/server.go +++ b/pkg/bridge/mcp/server.go @@ -10,7 +10,7 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo" - "github.com/yomorun/yomo/pkg/bridge/ai" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" "github.com/yomorun/yomo/pkg/id" ) @@ -19,12 +19,12 @@ var ( mcpServer *MCPServer tools sync.Map httpServer *http.Server - aiService *ai.Service + aiService *pkgai.Service logger *slog.Logger ) // Start starts the http server -func Start(config *Config, aiConfig *ai.Config, zipperAddr string, log *slog.Logger) error { +func Start(config *Config, aiConfig *pkgai.Config, zipperAddr string, log *slog.Logger) error { // ai provider provider, err := provider.GetProvider(aiConfig.Server.Provider) if err != nil { @@ -33,12 +33,12 @@ func Start(config *Config, aiConfig *ai.Config, zipperAddr string, log *slog.Log // logger logger = log.With("service", "mcp-bridge") // ai service - opts := &ai.ServiceOptions{ + opts := &pkgai.ServiceOptions{ Logger: logger, // SourceBuilder: func(_ string) yomo.Source { return source }, // ReducerBuilder: func(_ string) yomo.StreamFunction { return reducer }, } - zipperAddr = ai.ParseZipperAddr(zipperAddr) + zipperAddr = pkgai.ParseZipperAddr(zipperAddr) sourceBuilder := func(credential string) yomo.Source { source := yomo.NewSource("mcp-source", zipperAddr, yomo.WithCredential(credential)) return source @@ -49,7 +49,7 @@ func Start(config *Config, aiConfig *ai.Config, zipperAddr string, log *slog.Log } opts.SourceBuilder = sourceBuilder opts.ReducerBuilder = reducerBuilder - aiService = ai.NewService(provider, opts) + aiService = pkgai.NewService(provider, opts) // http server addr := config.Server.Addr mux := http.NewServeMux() @@ -143,7 +143,7 @@ func RemoveMCPTool(connID uint64) error { // mcpToolHandler mcp tool handler func mcpToolHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // get caller - caller := ai.FromCallerContext(ctx) + caller := pkgai.FromCallerContext(ctx) if caller == nil { logger.Error("[mcp] tool handler load failed", "error", ErrCallerNotFound.Error()) return nil, ErrCallerNotFound From 5ca8c80e8444a3c9542ebf2107059a3dcdeca8e1 Mon Sep 17 00:00:00 2001 From: venjiang Date: Thu, 17 Apr 2025 18:19:24 +0800 Subject: [PATCH 3/5] refactor: test --- .gitignore | 2 +- pkg/bridge/ai/caller.go | 12 +- pkg/bridge/ai/service.go | 8 +- pkg/bridge/llm/api_server.go | 6 +- pkg/bridge/mcp/{mcp.go => config.go} | 0 
pkg/bridge/test/call_syncer_test.go | 72 ++++++++++ pkg/bridge/{ai => test}/caller_test.go | 7 +- .../{ai/ai_test.go => test/config_test.go} | 97 ++++++------- .../llm_server_test.go} | 10 +- .../{ai/call_syncer_test.go => test/mock.go} | 128 ++++++++---------- .../{ai/register => test}/register_test.go | 19 +-- .../{ai => test}/response_writer_test.go | 5 +- pkg/bridge/{ai => test}/service_test.go | 127 ++++++----------- test/config.yaml | 17 ++- 14 files changed, 273 insertions(+), 237 deletions(-) rename pkg/bridge/mcp/{mcp.go => config.go} (100%) create mode 100644 pkg/bridge/test/call_syncer_test.go rename pkg/bridge/{ai => test}/caller_test.go (83%) rename pkg/bridge/{ai/ai_test.go => test/config_test.go} (54%) rename pkg/bridge/{llm/api_server_test.go => test/llm_server_test.go} (92%) rename pkg/bridge/{ai/call_syncer_test.go => test/mock.go} (59%) rename pkg/bridge/{ai/register => test}/register_test.go (68%) rename pkg/bridge/{ai => test}/response_writer_test.go (82%) rename pkg/bridge/{ai => test}/service_test.go (90%) diff --git a/.gitignore b/.gitignore index 795ee538b..f96268875 100644 --- a/.gitignore +++ b/.gitignore @@ -38,4 +38,4 @@ coverage.txt *.o build/ .env -example/10-ai/*.yaml +zipper.yml diff --git a/pkg/bridge/ai/caller.go b/pkg/bridge/ai/caller.go index daf823c70..52f439a8f 100644 --- a/pkg/bridge/ai/caller.go +++ b/pkg/bridge/ai/caller.go @@ -28,12 +28,12 @@ type Caller struct { func NewCaller(source yomo.Source, reducer yomo.StreamFunction, md metadata.M, callTimeout time.Duration) (*Caller, error) { logger := ylog.Default() - reqCh, err := sourceWriteToChan(source, logger) + reqCh, err := SourceWriteToChan(source, logger) if err != nil { return nil, err } - resCh, err := reduceToChan(reducer, logger) + resCh, err := ReduceToChan(reducer, logger) if err != nil { return nil, err } @@ -51,9 +51,9 @@ func NewCaller(source yomo.Source, reducer yomo.StreamFunction, md metadata.M, c return caller, nil } -// sourceWriteToChan makes source write data to the channel. +// SourceWriteToChan makes source write data to the channel. // The TagFunctionCall objects are continuously be received from the channel and be sent by the source. -func sourceWriteToChan(source yomo.Source, logger *slog.Logger) (chan<- ai.FunctionCall, error) { +func SourceWriteToChan(source yomo.Source, logger *slog.Logger) (chan<- ai.FunctionCall, error) { err := source.Connect() if err != nil { return nil, err @@ -72,8 +72,8 @@ func sourceWriteToChan(source yomo.Source, logger *slog.Logger) (chan<- ai.Funct return ch, nil } -// reduceToChan configures the reducer and returns a channel to accept messages from the reducer. -func reduceToChan(reducer yomo.StreamFunction, logger *slog.Logger) (<-chan ReduceMessage, error) { +// ReduceToChan configures the reducer and returns a channel to accept messages from the reducer. +func ReduceToChan(reducer yomo.StreamFunction, logger *slog.Logger) (<-chan ReduceMessage, error) { reducer.SetObserveDataTags(ai.ReducerTag) messages := make(chan ReduceMessage) diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index d297be552..509f9c25b 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -55,7 +55,7 @@ type ServiceOptions struct { // NewService creates a new service for handling the logic from handler layer. 
func NewService(provider provider.LLMProvider, opt *ServiceOptions) *Service { - return newService(provider, NewCaller, opt) + return NewServiceWithCallerFunc(provider, NewCaller, opt) } func initOption(opt *ServiceOptions) *ServiceOptions { @@ -83,7 +83,7 @@ func initOption(opt *ServiceOptions) *ServiceOptions { return opt } -func newService(provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOptions) *Service { +func NewServiceWithCallerFunc(provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOptions) *Service { onEvict := func(_ string, caller *Caller) { caller.Close() } @@ -230,7 +230,7 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl // 3. operate system prompt to request prompt, op := caller.GetSystemPrompt() - req = srv.opSystemPrompt(req, prompt, op) + req = srv.OpSystemPrompt(req, prompt, op) var ( promptUsage = 0 @@ -503,7 +503,7 @@ func (srv *Service) addToolsToRequest(req openai.ChatCompletionRequest, tools [] return req, hasReqTools } -func (srv *Service) opSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string, op SystemPromptOp) openai.ChatCompletionRequest { +func (srv *Service) OpSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string, op SystemPromptOp) openai.ChatCompletionRequest { if op == SystemPromptOpDisabled { return req } diff --git a/pkg/bridge/llm/api_server.go b/pkg/bridge/llm/api_server.go index 3e6532e61..297cb527c 100644 --- a/pkg/bridge/llm/api_server.go +++ b/pkg/bridge/llm/api_server.go @@ -50,16 +50,16 @@ func NewBasicAPIServer(config *pkgai.Config, provider provider.LLMProvider, sour mux := pkgai.NewServeMux(pkgai.NewHandler(service)) server := &BasicAPIServer{ - httpHandler: pkgai.DecorateHandler(mux, decorateReqContext(service, logger)), + httpHandler: pkgai.DecorateHandler(mux, DecorateReqContext(service, logger)), } logger.Info("[llm] start llm bridge service", "addr", config.Server.Addr, "provider", provider.Name()) return server, nil } -// decorateReqContext decorates the context of the request, it injects a transID into the request's context, +// DecorateReqContext decorates the context of the request, it injects a transID into the request's context, // log the request information and start tracing the request. 
-func decorateReqContext(service *pkgai.Service, logger *slog.Logger) func(handler http.Handler) http.Handler { +func DecorateReqContext(service *pkgai.Service, logger *slog.Logger) func(handler http.Handler) http.Handler { hostname, _ := os.Hostname() tracer := otel.Tracer("yomo-llm-bridge") diff --git a/pkg/bridge/mcp/mcp.go b/pkg/bridge/mcp/config.go similarity index 100% rename from pkg/bridge/mcp/mcp.go rename to pkg/bridge/mcp/config.go diff --git a/pkg/bridge/test/call_syncer_test.go b/pkg/bridge/test/call_syncer_test.go new file mode 100644 index 000000000..9e2f64421 --- /dev/null +++ b/pkg/bridge/test/call_syncer_test.go @@ -0,0 +1,72 @@ +package test + +import ( + "context" + "log/slog" + "testing" + "time" + + openai "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" +) + +var testdata = []openai.ToolCall{ + {ID: "tool-call-id-1", Function: openai.FunctionCall{Name: "function-1"}}, + {ID: "tool-call-id-2", Function: openai.FunctionCall{Name: "function-2"}}, + {ID: "tool-call-id-3", Function: openai.FunctionCall{Name: "function-3"}}, + {ID: "tool-call-id-4", Function: openai.FunctionCall{Name: "function-4"}}, +} + +func TestTimeoutCallSyncer(t *testing.T) { + h := newHandler(2 * time.Hour) // h.sleep > syncer.timeout + flow := newMockDataFlow(h.handle) + defer flow.Close() + + req, _ := pkgai.SourceWriteToChan(flow, slog.Default()) + res, _ := pkgai.ReduceToChan(flow, slog.Default()) + + syncer := pkgai.NewCallSyncer(slog.Default(), req, res, time.Millisecond) + go flow.Run() + + var ( + transID = "mock-trans-id" + reqID = "mock-req-id" + ) + + want := []pkgai.ToolCallResult{ + { + FunctionName: "timeout-function", + ToolCallID: "tool-call-id", + Content: "timeout in this function calling, you should ignore this.", + }, + } + + got, _ := syncer.Call(context.TODO(), transID, reqID, []openai.ToolCall{ + {ID: "tool-call-id", Function: openai.FunctionCall{Name: "timeout-function"}}, + }) + + assert.ElementsMatch(t, want, got) +} + +func TestCallSyncer(t *testing.T) { + h := newHandler(0) + flow := newMockDataFlow(h.handle) + defer flow.Close() + + req, _ := pkgai.SourceWriteToChan(flow, slog.Default()) + res, _ := pkgai.ReduceToChan(flow, slog.Default()) + + syncer := pkgai.NewCallSyncer(slog.Default(), req, res, 0) + go flow.Run() + + var ( + transID = "mock-trans-id" + reqID = "mock-req-id" + ) + + got, _ := syncer.Call(context.TODO(), transID, reqID, testdata) + + assert.NotEmpty(t, got) + assert.ElementsMatch(t, h.Result(), got) +} diff --git a/pkg/bridge/ai/caller_test.go b/pkg/bridge/test/caller_test.go similarity index 83% rename from pkg/bridge/ai/caller_test.go rename to pkg/bridge/test/caller_test.go index 8abfd3056..ad2753b74 100644 --- a/pkg/bridge/ai/caller_test.go +++ b/pkg/bridge/test/caller_test.go @@ -1,4 +1,4 @@ -package ai +package test import ( "testing" @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/yomorun/yomo" "github.com/yomorun/yomo/core/metadata" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" ) func TestCaller(t *testing.T) { @@ -15,7 +16,7 @@ func TestCaller(t *testing.T) { md, err := cc.ExchangeMetadata("") assert.NoError(t, err) - caller, err := NewCaller(cc.CreateSource(""), cc.CreateReducer(""), md, time.Minute) + caller, err := pkgai.NewCaller(cc.CreateSource(""), cc.CreateReducer(""), md, time.Minute) assert.NoError(t, err) defer caller.Close() @@ -24,7 +25,7 @@ func TestCaller(t *testing.T) { var ( prompt = "hello system prompt" - op = 
SystemPromptOpPrefix + op = pkgai.SystemPromptOpPrefix ) caller.SetSystemPrompt(prompt, op) gotPrompt, gotOp := caller.GetSystemPrompt() diff --git a/pkg/bridge/ai/ai_test.go b/pkg/bridge/test/config_test.go similarity index 54% rename from pkg/bridge/ai/ai_test.go rename to pkg/bridge/test/config_test.go index fa8fcca95..36f8daf07 100644 --- a/pkg/bridge/ai/ai_test.go +++ b/pkg/bridge/test/config_test.go @@ -1,59 +1,60 @@ -package ai +package test import ( "testing" "github.com/stretchr/testify/assert" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" ) -// func TestParseZipperAddr(t *testing.T) { -// tests := []struct { -// name string -// addr string -// expected string -// }{ -// { -// name: "Valid address", -// addr: "192.168.1.100:9000", -// expected: "192.168.1.100:9000", -// }, -// { -// name: "Valid address of localhost", -// addr: "localhost", -// expected: "localhost:9000", -// }, -// -// { -// name: "Invalid address", -// addr: "invalid", -// expected: DefaultZipperAddr, -// }, -// { -// name: "Localhost", -// addr: "localhost:9000", -// expected: "localhost:9000", -// }, -// { -// name: "Unspecified IP", -// addr: "0.0.0.0:9000", -// expected: "127.0.0.1:9000", // Expect the local IP -// }, -// } -// -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// got := parseZipperAddr(tt.addr) -// assert.Equal(t, tt.expected, got, tt.name) -// }) -// } -// } +func TestParseZipperAddr(t *testing.T) { + tests := []struct { + name string + addr string + expected string + }{ + { + name: "Valid address", + addr: "192.168.1.100:9000", + expected: "192.168.1.100:9000", + }, + { + name: "Valid address of localhost", + addr: "localhost", + expected: "localhost:9000", + }, + + { + name: "Invalid address", + addr: "invalid", + expected: pkgai.DefaultZipperAddr, + }, + { + name: "Localhost", + addr: "localhost:9000", + expected: "localhost:9000", + }, + { + name: "Unspecified IP", + addr: "0.0.0.0:9000", + expected: "127.0.0.1:9000", // Expect the local IP + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := pkgai.ParseZipperAddr(tt.addr) + assert.Equal(t, tt.expected, got, tt.name) + }) + } +} func TestParseConfig(t *testing.T) { tests := []struct { name string conf map[string]interface{} expectError bool - expected *Config + expected *pkgai.Config }{ { name: "Config not found", @@ -79,8 +80,8 @@ func TestParseConfig(t *testing.T) { }, }, expectError: false, - expected: &Config{ - Server: Server{ + expected: &pkgai.Config{ + Server: pkgai.Server{ Addr: "localhost:9000", }, }, @@ -93,8 +94,8 @@ func TestParseConfig(t *testing.T) { }, }, expectError: false, - expected: &Config{ - Server: Server{ + expected: &pkgai.Config{ + Server: pkgai.Server{ Addr: ":8000", }, }, @@ -111,7 +112,7 @@ func TestParseConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ParseConfig(tt.conf) + got, err := pkgai.ParseConfig(tt.conf) if err != nil { assert.Equal(t, tt.expectError, true, tt.name) } else { diff --git a/pkg/bridge/llm/api_server_test.go b/pkg/bridge/test/llm_server_test.go similarity index 92% rename from pkg/bridge/llm/api_server_test.go rename to pkg/bridge/test/llm_server_test.go index a0e70dd4d..92c9198df 100644 --- a/pkg/bridge/llm/api_server_test.go +++ b/pkg/bridge/test/llm_server_test.go @@ -1,4 +1,4 @@ -package llm +package test import ( "bytes" @@ -13,8 +13,10 @@ import ( "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/core/metadata" + pkgai 
"github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" _ "github.com/yomorun/yomo/pkg/bridge/ai/register" + "github.com/yomorun/yomo/pkg/bridge/llm" ) func TestServer(t *testing.T) { @@ -42,17 +44,17 @@ func TestServer(t *testing.T) { flow := newMockDataFlow(newHandler(2 * time.Hour).handle) - newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { + newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*pkgai.Caller, error) { return mockCaller(nil), err } - service := newService(pd, newCaller, &ServiceOptions{ + service := pkgai.NewServiceWithCallerFunc(pd, newCaller, &pkgai.ServiceOptions{ SourceBuilder: func(_ string) yomo.Source { return flow }, ReducerBuilder: func(_ string) yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, }) - handler := DecorateHandler(NewServeMux(NewHandler(service)), decorateReqContext(service, service.logger)) + handler := pkgai.DecorateHandler(pkgai.NewServeMux(pkgai.NewHandler(service)), llm.DecorateReqContext(service, service.Logger())) // create a test server server := httptest.NewServer(handler) diff --git a/pkg/bridge/ai/call_syncer_test.go b/pkg/bridge/test/mock.go similarity index 59% rename from pkg/bridge/ai/call_syncer_test.go rename to pkg/bridge/test/mock.go index b742a0c60..72e9393ee 100644 --- a/pkg/bridge/ai/call_syncer_test.go +++ b/pkg/bridge/test/mock.go @@ -1,80 +1,19 @@ -package ai +package test import ( "context" - "log/slog" "sync" - "testing" "time" - openai "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/assert" + "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo" + "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/core" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/serverless" "github.com/yomorun/yomo/serverless/mock" ) -var testdata = []openai.ToolCall{ - {ID: "tool-call-id-1", Function: openai.FunctionCall{Name: "function-1"}}, - {ID: "tool-call-id-2", Function: openai.FunctionCall{Name: "function-2"}}, - {ID: "tool-call-id-3", Function: openai.FunctionCall{Name: "function-3"}}, - {ID: "tool-call-id-4", Function: openai.FunctionCall{Name: "function-4"}}, -} - -func TestTimeoutCallSyncer(t *testing.T) { - h := newHandler(2 * time.Hour) // h.sleep > syncer.timeout - flow := newMockDataFlow(h.handle) - defer flow.Close() - - req, _ := sourceWriteToChan(flow, slog.Default()) - res, _ := reduceToChan(flow, slog.Default()) - - syncer := NewCallSyncer(slog.Default(), req, res, time.Millisecond) - go flow.run() - - var ( - transID = "mock-trans-id" - reqID = "mock-req-id" - ) - - want := []ToolCallResult{ - { - FunctionName: "timeout-function", - ToolCallID: "tool-call-id", - Content: "timeout in this function calling, you should ignore this.", - }, - } - - got, _ := syncer.Call(context.TODO(), transID, reqID, []openai.ToolCall{ - {ID: "tool-call-id", Function: openai.FunctionCall{Name: "timeout-function"}}, - }) - - assert.ElementsMatch(t, want, got) -} - -func TestCallSyncer(t *testing.T) { - h := newHandler(0) - flow := newMockDataFlow(h.handle) - defer flow.Close() - - req, _ := sourceWriteToChan(flow, slog.Default()) - res, _ := reduceToChan(flow, slog.Default()) - - syncer := NewCallSyncer(slog.Default(), req, res, 0) - go flow.run() - - var ( - transID = "mock-trans-id" - reqID = "mock-req-id" - ) - - got, _ := syncer.Call(context.TODO(), transID, reqID, testdata) - - 
assert.NotEmpty(t, got) - assert.ElementsMatch(t, h.result(), got) -} - // handler.handle implements core.AsyncHandler, it just echo the context be written. type handler struct { sleep time.Duration @@ -97,14 +36,14 @@ func (h *handler) handle(c serverless.Context) { h.ctxs[c.(*mock.MockContext)] = struct{}{} } -func (h *handler) result() []ToolCallResult { +func (h *handler) Result() []pkgai.ToolCallResult { h.mu.Lock() defer h.mu.Unlock() - want := []ToolCallResult{} + want := []pkgai.ToolCallResult{} for c := range h.ctxs { invoke, _ := c.LLMFunctionCall() - want = append(want, ToolCallResult{ + want = append(want, pkgai.ToolCallResult{ FunctionName: invoke.FunctionName, Content: invoke.Result, ToolCallID: invoke.ToolCallID, }) } @@ -147,15 +86,17 @@ func (t *mockDataFlow) Close() error { return nil } // this function explains how the data flow works, // it receives data from the write channel, and handle with the handler, then send the result to the reducer. -func (t *mockDataFlow) run() { +func (t *mockDataFlow) Run() { for c := range t.wrCh { t.handler(c) t.reducer(c) } } -var _ yomo.Source = (*mockDataFlow)(nil) -var _ yomo.StreamFunction = (*mockDataFlow)(nil) +var ( + _ yomo.Source = (*mockDataFlow)(nil) + _ yomo.StreamFunction = (*mockDataFlow)(nil) +) // The test will not use blowing function in this mock implementation. func (t *mockDataFlow) SetObserveDataTags(tag ...uint32) {} @@ -166,3 +107,48 @@ func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error { func (t *mockDataFlow) SetWantedTarget(string) { panic("unimplemented") } func (t *mockDataFlow) Wait() { panic("unimplemented") } func (t *mockDataFlow) SetErrorHandler(fn func(err error)) { panic("unimplemented") } + +// mockCaller returns a mock caller. +// the request-response of caller has been defined in advance, the request and response are defined in the `calls`. +func mockCaller(calls []mockFunctionCall) *pkgai.Caller { + // register function to register + for connID, call := range calls { + ai.RegisterFunction(&openai.FunctionDefinition{Name: call.functionName}, uint64(connID), nil) + } + + // caller, _ := pkgai.NewCaller(nil, nil, metadata.M{"hello": "llm bridge"}, pkgai.RunFunctionTimeout) + // callSyncer := &mockCallSyncer{calls: calls} + // caller.CallSyncer = callSyncer + caller := &pkgai.Caller{ + CallSyncer: &mockCallSyncer{calls: calls}, + // md: metadata.M{"hello": "llm bridge"}, + } + + return caller +} + +type mockFunctionCall struct { + toolID string + functionName string + respContent string +} + +type mockCallSyncer struct { + calls []mockFunctionCall +} + +// Call implements CallSyncer, it returns the mock response defined in advance. 
+func (m *mockCallSyncer) Call(ctx context.Context, transID string, reqID string, toolCalls []openai.ToolCall) ([]pkgai.ToolCallResult, error) { + res := []pkgai.ToolCallResult{} + + for _, call := range m.calls { + res = append(res, pkgai.ToolCallResult{ + FunctionName: call.functionName, + ToolCallID: call.toolID, + Content: call.respContent, + }) + } + return res, nil +} + +func (m *mockCallSyncer) Close() error { return nil } diff --git a/pkg/bridge/ai/register/register_test.go b/pkg/bridge/test/register_test.go similarity index 68% rename from pkg/bridge/ai/register/register_test.go rename to pkg/bridge/test/register_test.go index 3f4c49c56..410c87ed9 100644 --- a/pkg/bridge/ai/register/register_test.go +++ b/pkg/bridge/test/register_test.go @@ -1,17 +1,18 @@ -package register +package test import ( "testing" "github.com/stretchr/testify/assert" "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/pkg/bridge/ai/register" ) func TestRegister(t *testing.T) { - r := NewDefault() + r := register.NewDefault() - SetRegister(r) - assert.Equal(t, r, GetRegister()) + ai.SetRegister(r) + assert.Equal(t, r, ai.GetRegister()) functionDefinition := &ai.FunctionDefinition{ Name: "function1", @@ -26,19 +27,19 @@ func TestRegister(t *testing.T) { }, } - err := RegisterFunction(functionDefinition, 1, nil) + err := ai.RegisterFunction(functionDefinition, 1, nil) assert.NoError(t, err) - gotErr := RegisterFunction(functionDefinition, 2, nil) + gotErr := ai.RegisterFunction(functionDefinition, 2, nil) assert.EqualError(t, gotErr, "function `function1` already registered") - toolCalls, err := ListToolCalls(nil) + toolCalls, err := ai.ListToolCalls(nil) assert.NoError(t, err) assert.Equal(t, functionDefinition.Name, toolCalls[0].Function.Name) assert.Equal(t, functionDefinition.Description, toolCalls[0].Function.Description) - UnregisterFunction(1, nil) - toolCalls, err = ListToolCalls(nil) + ai.UnregisterFunction(1, nil) + toolCalls, err = ai.ListToolCalls(nil) assert.NoError(t, err) assert.Zero(t, len(toolCalls)) } diff --git a/pkg/bridge/ai/response_writer_test.go b/pkg/bridge/test/response_writer_test.go similarity index 82% rename from pkg/bridge/ai/response_writer_test.go rename to pkg/bridge/test/response_writer_test.go index bbfdbdb59..e9d0bb58d 100644 --- a/pkg/bridge/ai/response_writer_test.go +++ b/pkg/bridge/test/response_writer_test.go @@ -1,4 +1,4 @@ -package ai +package test import ( "net/http/httptest" @@ -7,12 +7,13 @@ import ( "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/assert" "github.com/yomorun/yomo/core/ylog" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" ) func TestResponseWriter(t *testing.T) { recorder := httptest.NewRecorder() - w := NewResponseWriter(recorder, ylog.NewFromConfig(ylog.Config{})) + w := pkgai.NewResponseWriter(recorder, ylog.NewFromConfig(ylog.Config{})) h := w.SetStreamHeader() diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/test/service_test.go similarity index 90% rename from pkg/bridge/ai/service_test.go rename to pkg/bridge/test/service_test.go index f91138178..3361144d0 100644 --- a/pkg/bridge/ai/service_test.go +++ b/pkg/bridge/test/service_test.go @@ -1,4 +1,4 @@ -package ai +package test import ( "context" @@ -13,14 +13,15 @@ import ( "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/core/metadata" + pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/bridge/ai/register" + _ 
"github.com/yomorun/yomo/pkg/bridge/ai/register" ) func TestOpSystemPrompt(t *testing.T) { type args struct { prompt string - op SystemPromptOp + op pkgai.SystemPromptOp req openai.ChatCompletionRequest } tests := []struct { @@ -32,7 +33,7 @@ func TestOpSystemPrompt(t *testing.T) { name: "disabled", args: args{ prompt: "hello", - op: SystemPromptOpDisabled, + op: pkgai.SystemPromptOpDisabled, req: openai.ChatCompletionRequest{ Messages: []openai.ChatCompletionMessage{ {Role: "user", Content: "hello"}, @@ -49,7 +50,7 @@ func TestOpSystemPrompt(t *testing.T) { name: "overwrite with empty system prompt", args: args{ prompt: "", - op: SystemPromptOpOverwrite, + op: pkgai.SystemPromptOpOverwrite, req: openai.ChatCompletionRequest{ Messages: []openai.ChatCompletionMessage{}, }, @@ -62,7 +63,7 @@ func TestOpSystemPrompt(t *testing.T) { name: "empty system prompt should not overwrite", args: args{ prompt: "", - op: SystemPromptOpOverwrite, + op: pkgai.SystemPromptOpOverwrite, req: openai.ChatCompletionRequest{ Messages: []openai.ChatCompletionMessage{ {Role: "system", Content: "hello"}, @@ -79,7 +80,7 @@ func TestOpSystemPrompt(t *testing.T) { name: "overwrite with not empty system prompt", args: args{ prompt: "hello", - op: SystemPromptOpOverwrite, + op: pkgai.SystemPromptOpOverwrite, req: openai.ChatCompletionRequest{ Messages: []openai.ChatCompletionMessage{ {Role: "system", Content: "world"}, @@ -96,7 +97,7 @@ func TestOpSystemPrompt(t *testing.T) { name: "prefix with empty system prompt", args: args{ prompt: "hello", - op: SystemPromptOpPrefix, + op: pkgai.SystemPromptOpPrefix, req: openai.ChatCompletionRequest{ Messages: []openai.ChatCompletionMessage{}, }, @@ -111,7 +112,7 @@ func TestOpSystemPrompt(t *testing.T) { name: "prefix with not empty system prompt", args: args{ prompt: "hello", - op: SystemPromptOpPrefix, + op: pkgai.SystemPromptOpPrefix, req: openai.ChatCompletionRequest{ Messages: []openai.ChatCompletionMessage{ {Role: "system", Content: "world"}, @@ -128,8 +129,9 @@ func TestOpSystemPrompt(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &Service{logger: slog.Default()} - got := s.opSystemPrompt(tt.args.req, tt.args.prompt, tt.args.op) + opts := &pkgai.ServiceOptions{Logger: slog.Default()} + s := pkgai.NewService(nil, opts) + got := s.OpSystemPrompt(tt.args.req, tt.args.prompt, tt.args.op) assert.Equal(t, tt.want, got) }) } @@ -182,32 +184,31 @@ func TestServiceInvoke(t *testing.T) { }, wantUsage: ai.TokenUsage{PromptTokens: 95, CompletionTokens: 43}, }, - { - name: "invoke without tool call", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionResponse(stopResp), - }, - mockCallReqResp: []mockFunctionCall{}, - systemPrompt: "this is a system prompt", - userInstruction: "hi", - baseSystemMessage: "this is a base system message", - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Messages: []openai.ChatCompletionMessage{ - {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, - {Role: "user", Content: "hi"}, - }, - }, - }, - wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, - }, + // BUG: test failed + // { + // name: "invoke without tool call", + // args: args{ + // providerMockData: []provider.MockData{ + // provider.MockChatCompletionResponse(stopResp), + // }, + // mockCallReqResp: []mockFunctionCall{}, + // systemPrompt: "this is a system prompt", + // userInstruction: "hi", + // baseSystemMessage: "this is a base system message", + // }, + // 
wantRequest: []openai.ChatCompletionRequest{ + // { + // Messages: []openai.ChatCompletionMessage{ + // {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, + // {Role: "user", Content: "hi"}, + // }, + // }, + // }, + // wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, + // }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ai.SetRegister(register.NewDefault()) - pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) if err != nil { t.Fatal(err) @@ -215,11 +216,11 @@ func TestServiceInvoke(t *testing.T) { flow := newMockDataFlow(newHandler(2 * time.Hour).handle) - newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { + newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*pkgai.Caller, error) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService(pd, newCaller, &ServiceOptions{ + service := pkgai.NewServiceWithCallerFunc(pd, newCaller, &pkgai.ServiceOptions{ SourceBuilder: func(_ string) yomo.Source { return flow }, ReducerBuilder: func(_ string) yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, @@ -228,7 +229,7 @@ func TestServiceInvoke(t *testing.T) { caller, err := service.LoadOrCreateCaller(&http.Request{}) assert.NoError(t, err) - caller.SetSystemPrompt(tt.args.systemPrompt, SystemPromptOpOverwrite) + caller.SetSystemPrompt(tt.args.systemPrompt, pkgai.SystemPromptOpOverwrite) resp, err := service.GetInvoke(context.TODO(), tt.args.userInstruction, tt.args.baseSystemMessage, "transID", caller, true, nil) assert.NoError(t, err) @@ -374,8 +375,6 @@ func TestServiceChatCompletion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ai.SetRegister(register.NewDefault()) - pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) 
if err != nil { t.Fatal(err) @@ -383,11 +382,11 @@ func TestServiceChatCompletion(t *testing.T) { flow := newMockDataFlow(newHandler(2 * time.Hour).handle) - newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { + newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*pkgai.Caller, error) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService(pd, newCaller, &ServiceOptions{ + service := pkgai.NewServiceWithCallerFunc(pd, newCaller, &pkgai.ServiceOptions{ SourceBuilder: func(_ string) yomo.Source { return flow }, ReducerBuilder: func(_ string) yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, @@ -396,10 +395,10 @@ func TestServiceChatCompletion(t *testing.T) { caller, err := service.LoadOrCreateCaller(&http.Request{}) assert.NoError(t, err) - caller.SetSystemPrompt(tt.args.systemPrompt, SystemPromptOpOverwrite) + caller.SetSystemPrompt(tt.args.systemPrompt, pkgai.SystemPromptOpOverwrite) w := httptest.NewRecorder() - err = service.GetChatCompletions(context.TODO(), tt.args.request, "transID", caller, NewResponseWriter(w, slog.Default()), nil) + err = service.GetChatCompletions(context.TODO(), tt.args.request, "transID", caller, pkgai.NewResponseWriter(w, slog.Default()), nil) assert.NoError(t, err) assert.Equal(t, tt.wantRequest, pd.RequestRecords()) @@ -407,48 +406,6 @@ func TestServiceChatCompletion(t *testing.T) { } } -// mockCaller returns a mock caller. -// the request-response of caller has been defined in advance, the request and response are defined in the `calls`. -func mockCaller(calls []mockFunctionCall) *Caller { - // register function to register - for connID, call := range calls { - ai.RegisterFunction(&openai.FunctionDefinition{Name: call.functionName}, uint64(connID), nil) - } - - caller := &Caller{ - CallSyncer: &mockCallSyncer{calls: calls}, - md: metadata.M{"hello": "llm bridge"}, - } - - return caller -} - -type mockFunctionCall struct { - toolID string - functionName string - respContent string -} - -type mockCallSyncer struct { - calls []mockFunctionCall -} - -// Call implements CallSyncer, it returns the mock response defined in advance. 
-func (m *mockCallSyncer) Call(ctx context.Context, transID string, reqID string, toolCalls []openai.ToolCall) ([]ToolCallResult, error) { - res := []ToolCallResult{} - - for _, call := range m.calls { - res = append(res, ToolCallResult{ - FunctionName: call.functionName, - ToolCallID: call.toolID, - Content: call.respContent, - }) - } - return res, nil -} - -func (m *mockCallSyncer) Close() error { return nil } - func toInt(val int) *int { return &val } var stopStreamResp = `data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null} diff --git a/test/config.yaml b/test/config.yaml index 775b63cce..5d5de75c9 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -8,6 +8,21 @@ auth: type: token token: +### bridge ### +bridge: + ai: # llm bridge + server: + addr: :8000 + provider: ollama + + providers: + ollama: + api_endpoint: http://localhost:11434 + + mcp: # mcp bridge + server: + addr: :9090 + ### cascading mesh ### mesh: zipper-sgp: @@ -25,4 +40,4 @@ mesh: zipper-deu: host: 4.4.4.4 port: 9000 - auth: "token: " \ No newline at end of file + auth: "token: " From 69ec91fbcf99816ff40534a1831879b53d674ea6 Mon Sep 17 00:00:00 2001 From: venjiang Date: Thu, 17 Apr 2025 19:27:45 +0800 Subject: [PATCH 4/5] rebase --- .gitignore | 1 + pkg/bridge/ai/ai.go | 2 +- pkg/bridge/ai/http.go | 68 +++++++++++++++++++------------- pkg/bridge/ai/response_writer.go | 46 +-------------------- 4 files changed, 45 insertions(+), 72 deletions(-) diff --git a/.gitignore b/.gitignore index f96268875..b2ab79d45 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,4 @@ coverage.txt build/ .env zipper.yml +*.md diff --git a/pkg/bridge/ai/ai.go b/pkg/bridge/ai/ai.go index 31ea63b9e..17af2a58a 100644 --- a/pkg/bridge/ai/ai.go +++ b/pkg/bridge/ai/ai.go @@ -21,7 +21,7 @@ var ( // ErrConfigFormatError is the error when the ai config format is incorrect ErrConfigFormatError = errors.New("ai config format is incorrect") - RequestTimeout = 90 * time.Second + RequestTimeout = 360 * time.Second // RunFunctionTimeout is the timeout for awaiting the function response, default is 60 seconds RunFunctionTimeout = 60 * time.Second ) diff --git a/pkg/bridge/ai/http.go b/pkg/bridge/ai/http.go index 857fbc58f..b389da9f3 100644 --- a/pkg/bridge/ai/http.go +++ b/pkg/bridge/ai/http.go @@ -64,27 +64,44 @@ func FromTracerContext(ctx context.Context) trace.Tracer { } // RespondWithError writes an error to response according to the OpenAI API spec. -func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) { - code, errString := parseCodeError(code, err) - logger.Error("bridge server error", "err", errString, "err_type", reflect.TypeOf(err).String()) +func RespondWithError(w EventResponseWriter, code int, err error) error { + newCode, errBody := w.InterceptError(code, err) + w.RecordError(errBody) + if newCode != 0 { + code = newCode + } w.WriteHeader(code) - w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString))) + return json.NewEncoder(w).Encode(&ErrorResponse{Error: errBody}) } -func parseCodeError(code int, err error) (int, string) { - errString := err.Error() - +// parseCodeError returns the status code, error code string and error message string. 
+func parseCodeError(err error) (code int, codeString string, message string) { switch e := err.(type) { + // bad request + case *json.SyntaxError: + return http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Invalid request: %s", e.Error()) + case *json.UnmarshalTypeError: + return http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Invalid type for `%s`: expected a %s, but got a %s", e.Field, e.Type.String(), e.Value) + case *openai.APIError: - code = e.HTTPStatusCode - errString = e.Message + // handle azure api error + if e.InnerError != nil { + return e.HTTPStatusCode, e.InnerError.Code, e.Message + } + // handle openai api error + eCode, ok := e.Code.(string) + if ok { + return e.HTTPStatusCode, eCode, e.Message + } + codeString = e.Type + return + case *openai.RequestError: - code = e.HTTPStatusCode - errString = e.Error() + return e.HTTPStatusCode, e.HTTPStatus, string(e.Body) } - return code, errString + return code, reflect.TypeOf(err).Name(), err.Error() } // NewServeMux creates a new http.ServeMux for the llm bridge server. @@ -123,10 +140,11 @@ func NewHandler(service *Service) *Handler { // HandleOverview is the handler for GET /overview func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + ww := w.(EventResponseWriter) tools, err := ai.ListToolCalls(FromCallerContext(r.Context()).Metadata()) if err != nil { - RespondWithError(w, http.StatusInternalServerError, err, h.service.Logger()) + RespondWithError(ww, http.StatusInternalServerError, err) return } @@ -149,8 +167,9 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { ) defer r.Body.Close() - req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.Logger()) + req, err := DecodeRequest[ai.InvokeRequest](r, ww, h.service.Logger()) if err != nil { + RespondWithError(ww, http.StatusBadRequest, err) ww.RecordError(errors.New("bad request")) return } @@ -167,8 +186,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, caller, req.IncludeCallStack, tracer) if err != nil { - ww.RecordError(err) - RespondWithError(w, http.StatusInternalServerError, err, h.service.Logger()) + RespondWithError(ww, http.StatusInternalServerError, err) return } @@ -184,9 +202,10 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) ) defer r.Body.Close() - req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.Logger()) + req, err := DecodeRequest[openai.ChatCompletionRequest](r, ww, h.service.Logger()) if err != nil { - ww.RecordError(err) + w.Header().Set("Content-Type", "application/json") + RespondWithError(ww, http.StatusBadRequest, err) return } @@ -199,27 +218,22 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) ) if err := h.service.GetChatCompletions(ctx, req, transID, caller, ww, tracer); err != nil { - ww.RecordError(err) if err == context.Canceled { return } - if ww.IsStream() { - h.service.Logger().Error("bridge server error", "err", err.Error(), "err_type", reflect.TypeOf(err).String()) - w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s"}}`, err.Error()))) - return + if !ww.IsStream() { + w.Header().Set("Content-Type", "application/json") } - RespondWithError(w, http.StatusBadRequest, err, h.service.Logger()) + RespondWithError(ww, http.StatusBadRequest, err) return } } // DecodeRequest decodes the request body into given type. 
-func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.Logger) (T, error) { +func DecodeRequest[T any](r *http.Request, ww EventResponseWriter, logger *slog.Logger) (T, error) { var req T err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - w.Header().Set("Content-Type", "application/json") - RespondWithError(w, http.StatusBadRequest, err, logger) return req, err } diff --git a/pkg/bridge/ai/response_writer.go b/pkg/bridge/ai/response_writer.go index 5f5285dc2..4a257a0ca 100644 --- a/pkg/bridge/ai/response_writer.go +++ b/pkg/bridge/ai/response_writer.go @@ -2,7 +2,6 @@ package ai import ( "encoding/json" - "fmt" "io" "log/slog" "net/http" @@ -78,7 +77,8 @@ func (w *responseWriter) InterceptError(code int, err error) (int, ErrorResponse if pcode == 0 { return code, ErrorResponseBody{ Code: http.StatusText(code), - Message: err.Error()} + Message: err.Error(), + } } return pcode, ErrorResponseBody{ Code: codeString, @@ -167,45 +167,3 @@ type ErrorResponseBody struct { func (e ErrorResponseBody) Error() string { return e.Message } - -// RespondWithError writes an error to response according to the OpenAI API spec. -func RespondWithError(w EventResponseWriter, code int, err error) error { - newCode, errBody := w.InterceptError(code, err) - w.RecordError(errBody) - - if newCode != 0 { - code = newCode - } - w.WriteHeader(code) - return json.NewEncoder(w).Encode(&ErrorResponse{Error: errBody}) - -} - -// parseCodeError returns the status code, error code string and error message string. -func parseCodeError(err error) (code int, codeString string, message string) { - switch e := err.(type) { - // bad request - case *json.SyntaxError: - return http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Invalid request: %s", e.Error()) - case *json.UnmarshalTypeError: - return http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Invalid type for `%s`: expected a %s, but got a %s", e.Field, e.Type.String(), e.Value) - - case *openai.APIError: - // handle azure api error - if e.InnerError != nil { - return e.HTTPStatusCode, e.InnerError.Code, e.Message - } - // handle openai api error - eCode, ok := e.Code.(string) - if ok { - return e.HTTPStatusCode, eCode, e.Message - } - codeString = e.Type - return - - case *openai.RequestError: - return e.HTTPStatusCode, e.HTTPStatus, string(e.Body) - } - - return code, reflect.TypeOf(err).Name(), err.Error() -} From 75e5e6b32b791357ee76f2d258c4a489ad4c1a5e Mon Sep 17 00:00:00 2001 From: woorui Date: Sat, 19 Apr 2025 23:48:17 +0800 Subject: [PATCH 5/5] fix: unittest --- pkg/bridge/test/llm_server_test.go | 4 +-- pkg/bridge/test/service_test.go | 46 +++++++++++++++--------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/pkg/bridge/test/llm_server_test.go b/pkg/bridge/test/llm_server_test.go index 92c9198df..a1ce0f3a6 100644 --- a/pkg/bridge/test/llm_server_test.go +++ b/pkg/bridge/test/llm_server_test.go @@ -15,7 +15,7 @@ import ( "github.com/yomorun/yomo/core/metadata" pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - _ "github.com/yomorun/yomo/pkg/bridge/ai/register" + "github.com/yomorun/yomo/pkg/bridge/ai/register" "github.com/yomorun/yomo/pkg/bridge/llm" ) @@ -33,7 +33,7 @@ func TestServer(t *testing.T) { Required: []string{"prop1"}, }, } - // register.SetRegister(register.NewDefault()) + ai.SetRegister(register.NewDefault()) ai.RegisterFunction(functionDefinition, 200, nil) // mock the provider and the req/res of the caller diff --git 
a/pkg/bridge/test/service_test.go b/pkg/bridge/test/service_test.go index 3361144d0..627c79447 100644 --- a/pkg/bridge/test/service_test.go +++ b/pkg/bridge/test/service_test.go @@ -15,7 +15,7 @@ import ( "github.com/yomorun/yomo/core/metadata" pkgai "github.com/yomorun/yomo/pkg/bridge/ai" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - _ "github.com/yomorun/yomo/pkg/bridge/ai/register" + "github.com/yomorun/yomo/pkg/bridge/ai/register" ) func TestOpSystemPrompt(t *testing.T) { @@ -184,31 +184,31 @@ func TestServiceInvoke(t *testing.T) { }, wantUsage: ai.TokenUsage{PromptTokens: 95, CompletionTokens: 43}, }, - // BUG: test failed - // { - // name: "invoke without tool call", - // args: args{ - // providerMockData: []provider.MockData{ - // provider.MockChatCompletionResponse(stopResp), - // }, - // mockCallReqResp: []mockFunctionCall{}, - // systemPrompt: "this is a system prompt", - // userInstruction: "hi", - // baseSystemMessage: "this is a base system message", - // }, - // wantRequest: []openai.ChatCompletionRequest{ - // { - // Messages: []openai.ChatCompletionMessage{ - // {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, - // {Role: "user", Content: "hi"}, - // }, - // }, - // }, - // wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, - // }, + { + name: "invoke without tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(stopResp), + }, + mockCallReqResp: []mockFunctionCall{}, + systemPrompt: "this is a system prompt", + userInstruction: "hi", + baseSystemMessage: "this is a base system message", + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, + {Role: "user", Content: "hi"}, + }, + }, + }, + wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ai.SetRegister(register.NewDefault()) pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) if err != nil { t.Fatal(err)