From 3d004b8bc48fcd16d434c444cf49c0d1d0361a7e Mon Sep 17 00:00:00 2001 From: Robby <45851384+h0rv@users.noreply.github.com> Date: Fri, 31 May 2024 18:31:57 -0400 Subject: [PATCH] Add Cohere Support (#24) * x * Add Cohere Support --------- Co-authored-by: Robby --- README.md | 259 +++++++++++++++++- examples/document_segmentation/main.go | 168 +++++++----- .../self-attention-from-scratch.txt | 222 +++++++++++++++ examples/function_calling/main.go | 2 +- examples/images/anthropic/main.go | 2 +- examples/images/openai/main.go | 2 +- examples/ollama/main.go | 2 +- examples/streaming/cohere/main.go | 75 +++++ examples/streaming/{ => openai}/main.go | 4 +- go.mod | 2 +- go.sum | 2 + pkg/instructor/chat.go | 2 +- pkg/instructor/chat_stream.go | 125 +++++---- pkg/instructor/cohere.go | 48 ---- pkg/instructor/cohere_chat.go | 100 +++++++ pkg/instructor/cohere_chat_stream.go | 91 ++++++ pkg/instructor/cohere_stream.go | 11 - pkg/instructor/cohere_struct.go | 41 +++ pkg/instructor/openai_chat.go | 4 +- pkg/instructor/openai_chat_stream.go | 6 +- pkg/instructor/role_enum.go | 17 -- pkg/instructor/utils.go | 57 +++- 22 files changed, 1022 insertions(+), 220 deletions(-) create mode 100644 examples/document_segmentation/self-attention-from-scratch.txt create mode 100644 examples/streaming/cohere/main.go rename examples/streaming/{ => openai}/main.go (98%) delete mode 100644 pkg/instructor/cohere.go create mode 100644 pkg/instructor/cohere_chat.go create mode 100644 pkg/instructor/cohere_chat_stream.go delete mode 100644 pkg/instructor/cohere_stream.go create mode 100644 pkg/instructor/cohere_struct.go delete mode 100644 pkg/instructor/role_enum.go diff --git a/README.md b/README.md index 0791f83..8eabda3 100644 --- a/README.md +++ b/README.md @@ -597,7 +597,7 @@ func urlToBase64(url string) (string, error) { ```bash export OPENAI_API_KEY= -go run examples/streaming/main.go +go run examples/streaming/openai/main.go ``` @@ -735,9 +735,262 @@ Product list: -## Providers +
+Document Segmentation with Cohere + +
+Running + +```bash +export COHERE_API_KEY= +go run examples/document_segmentation/main.go +``` + +
+ +```go +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + cohere "github.com/cohere-ai/cohere-go/v2" + cohereclient "github.com/cohere-ai/cohere-go/v2/client" + "github.com/instructor-ai/instructor-go/pkg/instructor" +) + +type Section struct { + Title string `json:"title" jsonschema:"description=main topic of this section of the document"` + StartIndex int `json:"start_index" jsonschema:"description=line number where the section begins"` + EndIndex int `json:"end_index" jsonschema:"description=line number where the section ends"` +} + +type StructuredDocument struct { + Sections []Section `json:"sections" jsonschema:"description=a list of sections of the document"` +} + +type Segment struct { + Title string `json:"title"` + Content string `json:"content"` + Start int `json:"start"` + End int `json:"end"` +} + +func (s Segment) String() string { + return fmt.Sprintf("Title: %s\nContent:\n%s\nStart: %d\nEnd: %d\n", + s.Title, s.Content, s.Start, s.End) +} + +func (sd *StructuredDocument) PrettyPrint() string { + s, err := json.MarshalIndent(sd, "", " ") + if err != nil { + panic(err) + } + return string(s) +} + +func main() { + ctx := context.Background() + + client := instructor.FromCohere( + cohereclient.NewClient(cohereclient.WithToken(os.Getenv("COHERE_API_KEY"))), + instructor.WithMode(instructor.ModeToolCall), + instructor.WithMaxRetries(3), + ) + + /* + * Document is downloaded from a tutorial on Transformers from Sebastian Raschka: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html + * Downloaded and scraped via `trafilatura`: https://github.com/adbar/trafilatura + */ + doc, err := os.ReadFile("examples/document_segmentation/self-attention-from-scratch.txt") + if err != nil { + panic(err) + } + + getStructuredDocument := func(docWithLines string) *StructuredDocument { + var structuredDoc StructuredDocument + _, err := client.Chat(ctx, &cohere.ChatRequest{ + Model: toPtr("command-r-plus"), + Preamble: toPtr(` +You are a world class educator working on organizing your lecture notes. +Read the document below and extract a StructuredDocument object from it where each section of the document is centered around a single concept/topic that can be taught in one lesson. +Each line of the document is marked with its line number in square brackets (e.g. [1], [2], [3], etc). Use the line numbers to indicate section start and end. +`), + Message: docWithLines, + }, + &structuredDoc, + ) + if err != nil { + panic(err) + } + return &structuredDoc + } + + documentWithLineNumbers, line2text := docWithLines(string(doc)) + structuredDoc := getStructuredDocument(documentWithLineNumbers) + segments := getSectionsText(structuredDoc, line2text) + + println(segments[0].String()) + /* + Title: Introduction to Self-Attention + Content: + Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch + In this article, we are going to understand how self-attention works from scratch. This means we will code it ourselves one step at a time. + Since its introduction via the original transformer paper (Attention Is All You Need), self-attention has become a cornerstone of many state-of-the-art deep learning models, particularly in the field of Natural Language Processing (NLP). Since self-attention is now everywhere, it’s important to understand how it works. + Self-Attention + The concept of “attention” in deep learning has its roots in the effort to improve Recurrent Neural Networks (RNNs) for handling longer sequences or sentences. For instance, consider translating a sentence from one language to another. Translating a sentence word-by-word does not work effectively. + To overcome this issue, attention mechanisms were introduced to give access to all sequence elements at each time step. The key is to be selective and determine which words are most important in a specific context. In 2017, the transformer architecture introduced a standalone self-attention mechanism, eliminating the need for RNNs altogether. + (For brevity, and to keep the article focused on the technical self-attention details, and I am skipping parts of the motivation, but my Machine Learning with PyTorch and Scikit-Learn book has some additional details in Chapter 16 if you are interested.) + We can think of self-attention as a mechanism that enhances the information content of an input embedding by including information about the input’s context. In other words, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output. This is especially important for language processing tasks, where the meaning of a word can change based on its context within a sentence or document. + Note that there are many variants of self-attention. A particular focus has been on making self-attention more efficient. However, most papers still implement the original scaled-dot product attention mechanism discussed in this paper since it usually results in superior accuracy and because self-attention is rarely a computational bottleneck for most companies training large-scale transformers. + Start: 0 + End: 9 + */ +} + +/* + * Preprocessing utilties + */ + +func toPtr[T any](val T) *T { + return &val +} -Most model API providers do not provide an official Go client, so here are the ones we chose for the following providers: +func docWithLines(document string) (string, map[int]string) { + documentLines := strings.Split(document, "\n") + documentWithLineNumbers := "" + line2text := make(map[int]string) + for i, line := range documentLines { + documentWithLineNumbers += fmt.Sprintf("[%d] %s\n", i, line) + line2text[i] = line + } + return documentWithLineNumbers, line2text +} + +func getSectionsText(structuredDoc *StructuredDocument, line2text map[int]string) []Segment { + var segments []Segment + for _, s := range structuredDoc.Sections { + var contents []string + for lineID := s.StartIndex; lineID < s.EndIndex; lineID++ { + if line, exists := line2text[lineID]; exists { + contents = append(contents, line) + } + } + segment := Segment{ + Title: s.Title, + Content: strings.Join(contents, "\n"), + Start: s.StartIndex, + End: s.EndIndex, + } + segments = append(segments, segment) + } + return segments +} +``` + +
+ + + +
+Streaming with Cohere + +
+Running + +```bash +export COHERE_API_KEY= +go run examples/streaming/cohere/main.go +``` + +
+ +```go +package main + +import ( + "context" + "fmt" + "os" + + cohere "github.com/cohere-ai/cohere-go/v2" + cohereclient "github.com/cohere-ai/cohere-go/v2/client" + "github.com/instructor-ai/instructor-go/pkg/instructor" +) + +type HistoricalFact struct { + Decade string `json:"decade" jsonschema:"title=Decade of the Fact,description=Decade when the fact occurred"` + Topic string `json:"topic" jsonschema:"title=Topic of the Fact,description=General category or topic of the fact"` + Description string `json:"description" jsonschema:"title=Description of the Fact,description=Description or details of the fact"` +} + +func (hf HistoricalFact) String() string { + return fmt.Sprintf(` +Decade: %s +Topic: %s +Description: %s`, hf.Decade, hf.Topic, hf.Description) +} + +func main() { + ctx := context.Background() + + client := instructor.FromCohere( + cohereclient.NewClient(cohereclient.WithToken(os.Getenv("COHERE_API_KEY"))), + instructor.WithMode(instructor.ModeJSON), + instructor.WithMaxRetries(3), + ) + + hfStream, err := client.ChatStream(ctx, &cohere.ChatStreamRequest{ + Model: toPtr("command-r-plus"), + Message: "Tell me about the history of artificial intelligence up to year 2000", + MaxTokens: toPtr(2500), + }, + *new(HistoricalFact), + ) + if err != nil { + panic(err) + } + + for instance := range hfStream { + hf := instance.(*HistoricalFact) + println(hf.String()) + } + /* + Decade: 1950s + Topic: Birth of AI + Description: The term 'Artificial Intelligence' is coined by John McCarthy at the Dartmouth Conference in 1956, considered the birth of AI as a field. Early research focuses on areas like problem solving, search algorithms, and logic. + + Decade: 1960s + Topic: Expert Systems and LISP + Description: The language LISP is developed, which becomes widely used in AI applications. Research also leads to the development of expert systems, which emulate human decision-making abilities in specific domains. + + Decade: 1970s + Topic: AI Winter + Description: AI experiences its first 'winter', a period of reduced funding and interest due to unmet expectations. Despite this, research continues in areas like knowledge representation and natural language processing. + + Decade: 1980s + Topic: Machine Learning and Neural Networks + Description: The field of machine learning emerges, with a focus on developing algorithms that can learn from data. Neural networks, inspired by the structure of biological brains, gain traction during this decade. + + Decade: 1990s + Topic: AI in Practice + Description: AI starts to find practical applications in various industries. Speech recognition, image processing, and expert systems are used in fields like healthcare, finance, and manufacturing. + */ +} + +func toPtr[T any](val T) *T { + return &val +} +``` + +
+ +## Providers - [OpenAI](https://github.com/sashabaranov/go-openai) - [Anthropic](https://github.com/liushuangls/go-anthropic) +- [Cohere](github.com/cohere-ai/cohere-go) diff --git a/examples/document_segmentation/main.go b/examples/document_segmentation/main.go index 3e32390..6db5020 100644 --- a/examples/document_segmentation/main.go +++ b/examples/document_segmentation/main.go @@ -2,23 +2,113 @@ package main import ( "context" + "encoding/json" "fmt" + "os" "strings" cohere "github.com/cohere-ai/cohere-go/v2" cohereclient "github.com/cohere-ai/cohere-go/v2/client" + "github.com/instructor-ai/instructor-go/pkg/instructor" ) type Section struct { - Title string `json:"title" jsonschema:"description=main topic of this section of the document"` - StartIndex int `json:"start_index" jsonschema:"description=line number where the section begins"` - EndIndex int `json:"end_index" jsonschema:"description=line number where the section ends"` + Title string `json:"title" jsonschema:"description=main topic of this section of the document"` + StartIndex int `json:"start_index" jsonschema:"description=line number where the section begins"` + EndIndex int `json:"end_index" jsonschema:"description=line number where the section ends"` } type StructuredDocument struct { Sections []Section `json:"sections" jsonschema:"description=a list of sections of the document"` } +type Segment struct { + Title string `json:"title"` + Content string `json:"content"` + Start int `json:"start"` + End int `json:"end"` +} + +func (s Segment) String() string { + return fmt.Sprintf("Title: %s\nContent:\n%s\nStart: %d\nEnd: %d\n", + s.Title, s.Content, s.Start, s.End) +} + +func (sd *StructuredDocument) PrettyPrint() string { + s, err := json.MarshalIndent(sd, "", " ") + if err != nil { + panic(err) + } + return string(s) +} + +func main() { + ctx := context.Background() + + client := instructor.FromCohere( + cohereclient.NewClient(cohereclient.WithToken(os.Getenv("COHERE_API_KEY"))), + instructor.WithMode(instructor.ModeToolCall), + instructor.WithMaxRetries(3), + ) + + /* + * Document is downloaded from a tutorial on Transformers from Sebastian Raschka: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html + * Downloaded and scraped via `trafilatura`: https://github.com/adbar/trafilatura + */ + doc, err := os.ReadFile("examples/document_segmentation/self-attention-from-scratch.txt") + if err != nil { + panic(err) + } + + getStructuredDocument := func(docWithLines string) *StructuredDocument { + var structuredDoc StructuredDocument + _, err := client.Chat(ctx, &cohere.ChatRequest{ + Model: toPtr("command-r-plus"), + Preamble: toPtr(` +You are a world class educator working on organizing your lecture notes. +Read the document below and extract a StructuredDocument object from it where each section of the document is centered around a single concept/topic that can be taught in one lesson. +Each line of the document is marked with its line number in square brackets (e.g. [1], [2], [3], etc). Use the line numbers to indicate section start and end. +`), + Message: docWithLines, + }, + &structuredDoc, + ) + if err != nil { + panic(err) + } + return &structuredDoc + } + + documentWithLineNumbers, line2text := docWithLines(string(doc)) + structuredDoc := getStructuredDocument(documentWithLineNumbers) + segments := getSectionsText(structuredDoc, line2text) + + println(segments[0].String()) + /* + Title: Introduction to Self-Attention + Content: + Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch + In this article, we are going to understand how self-attention works from scratch. This means we will code it ourselves one step at a time. + Since its introduction via the original transformer paper (Attention Is All You Need), self-attention has become a cornerstone of many state-of-the-art deep learning models, particularly in the field of Natural Language Processing (NLP). Since self-attention is now everywhere, it’s important to understand how it works. + Self-Attention + The concept of “attention” in deep learning has its roots in the effort to improve Recurrent Neural Networks (RNNs) for handling longer sequences or sentences. For instance, consider translating a sentence from one language to another. Translating a sentence word-by-word does not work effectively. + To overcome this issue, attention mechanisms were introduced to give access to all sequence elements at each time step. The key is to be selective and determine which words are most important in a specific context. In 2017, the transformer architecture introduced a standalone self-attention mechanism, eliminating the need for RNNs altogether. + (For brevity, and to keep the article focused on the technical self-attention details, and I am skipping parts of the motivation, but my Machine Learning with PyTorch and Scikit-Learn book has some additional details in Chapter 16 if you are interested.) + We can think of self-attention as a mechanism that enhances the information content of an input embedding by including information about the input’s context. In other words, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output. This is especially important for language processing tasks, where the meaning of a word can change based on its context within a sentence or document. + Note that there are many variants of self-attention. A particular focus has been on making self-attention more efficient. However, most papers still implement the original scaled-dot product attention mechanism discussed in this paper since it usually results in superior accuracy and because self-attention is rarely a computational bottleneck for most companies training large-scale transformers. + Start: 0 + End: 9 + */ +} + +/* + * Preprocessing utilties + */ + +func toPtr[T any](val T) *T { + return &val +} + func docWithLines(document string) (string, map[int]string) { documentLines := strings.Split(document, "\n") documentWithLineNumbers := "" @@ -30,28 +120,8 @@ func docWithLines(document string) (string, map[int]string) { return documentWithLineNumbers, line2text } -// Mocking the call to the instructor and cohere client. Replace this with actual implementation. -func getStructuredDocument(documentWithLineNumbers string) StructuredDocument { - // Mock response - return StructuredDocument{ - Sections: []Section{ - { - Title: "Introduction", - StartIndex: 0, - EndIndex: 10, - }, - { - Title: "Background", - StartIndex: 10, - EndIndex: 20, - }, - // Add more sections as needed - }, - } -} - -func getSectionsText(structuredDoc StructuredDocument, line2text map[int]string) []map[string]interface{} { - var segments []map[string]interface{} +func getSectionsText(structuredDoc *StructuredDocument, line2text map[int]string) []Segment { + var segments []Segment for _, s := range structuredDoc.Sections { var contents []string for lineID := s.StartIndex; lineID < s.EndIndex; lineID++ { @@ -59,51 +129,13 @@ func getSectionsText(structuredDoc StructuredDocument, line2text map[int]string) contents = append(contents, line) } } - segment := map[string]interface{}{ - "title": s.Title, - "content": strings.Join(contents, "\n"), - "start": s.StartIndex, - "end": s.EndIndex, + segment := Segment{ + Title: s.Title, + Content: strings.Join(contents, "\n"), + Start: s.StartIndex, + End: s.EndIndex, } segments = append(segments, segment) } return segments } - -func main() { - ctx := context.Background() - - client := cohereclient.NewClient(cohereclient.WithToken("")) - - document := ` -Introduction to Multi-Head Attention -In the very first figure, at the top of this article, we saw that transformers use a module called multi-head attention. How does that relate to the self-attention mechanism (scaled-dot product attention) we walked through above? -In the scaled dot-product attention, the input sequence was transformed using three matrices representing the query, key, and value. These three matrices can be considered as a single attention head in the context of multi-head attention. The figure below summarizes this single attention head we covered previously: -As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks. -To illustrate this in code, suppose we have 3 attention heads, so we now extend the \(d' \times d\) dimensional weight matrices so \(3 \times d' \times d\): -In: -h = 3 -multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d)) -multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d)) -multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d)) -Consequently, each query element is now \(3 \times d_q\) dimensional, where \(d_q=24\) (here, let’s keep the focus on the 3rd element corresponding to index position 2): -In: -multihead_query_2 = multihead_W_query.matmul(x_2) -print(multihead_query_2.shape) -Out: -torch.Size([3, 24]) -` - - response, err := client.Chat(ctx, &cohere.ChatRequest{ - Message: "How is the weather today?", - }, - ) - _, _ = response, err - - documentWithLineNumbers, line2text := docWithLines(document) - structuredDoc := getStructuredDocument(documentWithLineNumbers) - segments := getSectionsText(structuredDoc, line2text) - - fmt.Println(segments[1]["title"]) - fmt.Println(segments[1]["content"]) -} diff --git a/examples/document_segmentation/self-attention-from-scratch.txt b/examples/document_segmentation/self-attention-from-scratch.txt new file mode 100644 index 0000000..d84d898 --- /dev/null +++ b/examples/document_segmentation/self-attention-from-scratch.txt @@ -0,0 +1,222 @@ +Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch +In this article, we are going to understand how self-attention works from scratch. This means we will code it ourselves one step at a time. +Since its introduction via the original transformer paper (Attention Is All You Need), self-attention has become a cornerstone of many state-of-the-art deep learning models, particularly in the field of Natural Language Processing (NLP). Since self-attention is now everywhere, it’s important to understand how it works. +Self-Attention +The concept of “attention” in deep learning has its roots in the effort to improve Recurrent Neural Networks (RNNs) for handling longer sequences or sentences. For instance, consider translating a sentence from one language to another. Translating a sentence word-by-word does not work effectively. +To overcome this issue, attention mechanisms were introduced to give access to all sequence elements at each time step. The key is to be selective and determine which words are most important in a specific context. In 2017, the transformer architecture introduced a standalone self-attention mechanism, eliminating the need for RNNs altogether. +(For brevity, and to keep the article focused on the technical self-attention details, and I am skipping parts of the motivation, but my Machine Learning with PyTorch and Scikit-Learn book has some additional details in Chapter 16 if you are interested.) +We can think of self-attention as a mechanism that enhances the information content of an input embedding by including information about the input’s context. In other words, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output. This is especially important for language processing tasks, where the meaning of a word can change based on its context within a sentence or document. +Note that there are many variants of self-attention. A particular focus has been on making self-attention more efficient. However, most papers still implement the original scaled-dot product attention mechanism discussed in this paper since it usually results in superior accuracy and because self-attention is rarely a computational bottleneck for most companies training large-scale transformers. +In this article, we focus on the original scaled-dot product attention mechanism (referred to as self-attention), which remains the most popular and most widely used attention mechanism in practice. However, if you are interested in other types of attention mechanisms, check out the 2020 Efficient Transformers: A Survey and the 2023 A Survey on Efficient Training of Transformers review and the recent FlashAttention paper. +Embedding an Input Sentence +Before we begin, let’s consider an input sentence “Life is short, eat dessert first” that we want to put through the self-attention mechanism. Similar to other types of modeling approaches for processing text (e.g., using recurrent neural networks or convolutional neural networks), we create a sentence embedding first. +For simplicity, here our dictionary dc +is restricted to the words that occur in the input sentence. In a real-world application, we would consider all words in the training dataset (typical vocabulary sizes range between 30k to 50k). +In: +sentence = 'Life is short, eat dessert first' +dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))} +print(dc) +Out: +{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5} +Next, we use this dictionary to assign an integer index to each word: +In: +import torch +sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()]) +print(sentence_int) +Out: +tensor([0, 4, 5, 2, 1, 3]) +Now, using the integer-vector representation of the input sentence, we can use an embedding layer to encode the inputs into a real-vector embedding. Here, we will use a 16-dimensional embedding such that each input word is represented by a 16-dimensional vector. Since the sentence consists of 6 words, this will result in a \(6 \times 16\)-dimensional embedding: +In: +torch.manual_seed(123) +embed = torch.nn.Embedding(6, 16) +embedded_sentence = embed(sentence_int).detach() +print(embedded_sentence) +print(embedded_sentence.shape) +Out: +tensor([[ 0.3374, -0.1778, -0.3035, -0.5880, 0.3486, 0.6603, -0.2196, -0.3792, +0.7671, -1.1925, 0.6984, -1.4097, 0.1794, 1.8951, 0.4954, 0.2692], +[ 0.5146, 0.9938, -0.2587, -1.0826, -0.0444, 1.6236, -2.3229, 1.0878, +0.6716, 0.6933, -0.9487, -0.0765, -0.1526, 0.1167, 0.4403, -1.4465], +[ 0.2553, -0.5496, 1.0042, 0.8272, -0.3948, 0.4892, -0.2168, -1.7472, +-1.6025, -1.0764, 0.9031, -0.7218, -0.5951, -0.7112, 0.6230, -1.3729], +[-1.3250, 0.1784, -2.1338, 1.0524, -0.3885, -0.9343, -0.4991, -1.0867, +0.8805, 1.5542, 0.6266, -0.1755, 0.0983, -0.0935, 0.2662, -0.5850], +[-0.0770, -1.0205, -0.1690, 0.9178, 1.5810, 1.3010, 1.2753, -0.2010, +0.4965, -1.5723, 0.9666, -1.1481, -1.1589, 0.3255, -0.6315, -2.8400], +[ 0.8768, 1.6221, -1.4779, 1.1331, -1.2203, 1.3139, 1.0533, 0.1388, +2.2473, -0.8036, -0.2808, 0.7697, -0.6596, -0.7979, 0.1838, 0.2293]]) +torch.Size([6, 16]) +Defining the Weight Matrices +Now, let’s discuss the widely utilized self-attention mechanism known as the scaled dot-product attention, which is integrated into the transformer architecture. +Self-attention utilizes three weight matrices, referred to as \(\mathbf{W}_q\), \(\mathbf{W}_k\), and \(\mathbf{W}_v\), which are adjusted as model parameters during training. These matrices serve to project the inputs into query, key, and value components of the sequence, respectively. +The respective query, key and value sequences are obtained via matrix multiplication between the weight matrices \(\mathbf{W}\) and the embedded inputs \(\mathbf{x}\): +- Query sequence: \(\mathbf{q}^{(i)}=\mathbf{W}_q \mathbf{x}^{(i)}\) for \(i \in[1, T]\) +- Key sequence: \(\mathbf{k}^{(i)}=\mathbf{W}_k \mathbf{x}^{(i)}\) for \(i \in[1, T]\) +- Value sequence: \(\mathbf{v}^{(i)}=\mathbf{W}_v \mathbf{x}^{(i)}\) for \(i \in[1, T]\) +The index \(i\) refers to the token index position in the input sequence, which has length \(T\). +Here, both \(\mathbf{q}^{(i)}\) and \(\mathbf{k}^{(i)}\) are vectors of dimension \(d_k\). The projection matrices \(\mathbf{W}_{q}\) and \(\mathbf{W}_{k}\) have a shape of \(d_k \times d\), while \(\mathbf{W}_{v}\) has the shape \(d_v \times d\). +(It’s important to note that \(d\) represents the size of each word vector, \(\mathbf{x}\).) +Since we are computing the dot-product between the query and key vectors, these two vectors have to contain the same number of elements (\(d_q = d_k\)). However, the number of elements in the value vector \(\mathbf{v}^{(i)}\), which determines the size of the resulting context vector, is arbitrary. +So, for the following code walkthrough, we will set \(d_q = d_k = 24\) and use \(d_v = 28\), initializing the projection matrices as follows: +In: +torch.manual_seed(123) +d = embedded_sentence.shape[1] +d_q, d_k, d_v = 24, 24, 28 +W_query = torch.nn.Parameter(torch.rand(d_q, d)) +W_key = torch.nn.Parameter(torch.rand(d_k, d)) +W_value = torch.nn.Parameter(torch.rand(d_v, d)) +Computing the Unnormalized Attention Weights +Now, let’s suppose we are interested in computing the attention-vector for the second input element – the second input element acts as the query here: +In code, this looks like as follows: +In: +x_2 = embedded_sentence[1] +query_2 = W_query.matmul(x_2) +key_2 = W_key.matmul(x_2) +value_2 = W_value.matmul(x_2) +print(query_2.shape) +print(key_2.shape) +print(value_2.shape) +torch.Size([24]) +torch.Size([24]) +torch.Size([28]) +We can then generalize this to compute th remaining key, and value elements for all inputs as well, since we will need them in the next step when we compute the unnormalized attention weights \(\omega\): +In: +keys = W_key.matmul(embedded_sentence.T).T +values = W_value.matmul(embedded_sentence.T).T +print("keys.shape:", keys.shape) +print("values.shape:", values.shape) +Out: +keys.shape: torch.Size([6, 24]) +values.shape: torch.Size([6, 28]) +Now that we have all the required keys and values, we can proceed to the next step and compute the unnormalized attention weights \(\omega\) , which are illustrated in the figure below: +As illustrated in the figure above, we compute \(\omega_{i, j}\) as the dot product between the query and key sequences, \(\omega_{i j}=\mathbf{q}^{(i)^{\top}} \mathbf{k}^{(j)}\). +For example, we can compute the unnormalized attention weight for the query and 5th input element (corresponding to index position 4) as follows: +In: +omega_24 = query_2.dot(keys[4]) +print(omega_24) +Out: +tensor(11.1466) +Since we will need those to compute the attention scores later, let’s compute the \(\omega\) values for all input tokens as illustrated in the previous figure: +In: +omega_2 = query_2.matmul(keys.T) +print(omega_2) +Out: +tensor([ 8.5808, -7.6597, 3.2558, 1.0395, 11.1466, -0.4800]) +Computing the Attention Scores +The subsequent step in self-attention is to normalize the unnormalized attention weights, \(\omega\), to obtain the normalized attention weights, \(\alpha\), by applying the softmax function. Additionally, \(1/\sqrt{d_k}\) is used to scale \(\omega\) before normalizing it through the softmax function, as shown below: +The scaling by \(d_k\) ensures that the Euclidean length of the weight vectors will be approximately in the same magnitude. This helps prevent the attention weights from becoming too small or too large, which could lead to numerical instability or affect the model’s ability to converge during training. +In code, we can implement the computation of the attention weights as follows: +In: +import torch.nn.functional as F +attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0) +print(attention_weights_2) +Out: +tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458]) +Finally, the last step is to compute the context vector \(\mathbf{z}^{(2)}\), which is an attention-weighted version of our original query input \(\mathbf{x}^{(2)}\), including all the other input elements as its context via the attention weights: +In code, this looks like as follows: +In: +context_vector_2 = attention_weights_2.matmul(values) +print(context_vector_2.shape) +print(context_vector_2) +Out: +torch.Size([28]) +tensor(torch.Size([28]) +tensor([-1.5993, 0.0156, 1.2670, 0.0032, -0.6460, -1.1407, -0.4908, -1.4632, +0.4747, 1.1926, 0.4506, -0.7110, 0.0602, 0.7125, -0.1628, -2.0184, +0.3838, -2.1188, -0.8136, -1.5694, 0.7934, -0.2911, -1.3640, -0.2366, +-0.9564, -0.5265, 0.0624, 1.7084]) +Note that this output vector has more dimensions (\(d_v=28\)) than the original input vector (\(d=16\)) since we specified \(d_v > d\) earlier; however, the embedding size choice is arbitrary. +Multi-Head Attention +In the very first figure, at the top of this article, we saw that transformers use a module called multi-head attention. How does that relate to the self-attention mechanism (scaled-dot product attention) we walked through above? +In the scaled dot-product attention, the input sequence was transformed using three matrices representing the query, key, and value. These three matrices can be considered as a single attention head in the context of multi-head attention. The figure below summarizes this single attention head we covered previously: +As its name implies, multi-head attention involves multiple such heads, each consisting of query, key, and value matrices. This concept is similar to the use of multiple kernels in convolutional neural networks. +To illustrate this in code, suppose we have 3 attention heads, so we now extend the \(d' \times d\) dimensional weight matrices so \(3 \times d' \times d\): +In: +h = 3 +multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d)) +multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d)) +multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d)) +Consequently, each query element is now \(3 \times d_q\) dimensional, where \(d_q=24\) (here, let’s keep the focus on the 3rd element corresponding to index position 2): +In: +multihead_query_2 = multihead_W_query.matmul(x_2) +print(multihead_query_2.shape) +Out: +torch.Size([3, 24]) +We can then obtain the keys and values in a similar fashion: +In: +multihead_key_2 = multihead_W_key.matmul(x_2) +multihead_value_2 = multihead_W_value.matmul(x_2) +Now, these key and value elements are specific to the query element. But, similar to earlier, we will also need the value and keys for the other sequence elements in order to compute the attention scores for the query. We can do this is by expanding the input sequence embeddings to size 3, i.e., the number of attention heads: +In: +stacked_inputs = embedded_sentence.T.repeat(3, 1, 1) +print(stacked_inputs.shape) +Out: +torch.Size([3, 16, 6]) +Now, we can compute all the keys and values using via torch.bmm() +( batch matrix multiplication): +In: +multihead_keys = torch.bmm(multihead_W_key, stacked_inputs) +multihead_values = torch.bmm(multihead_W_value, stacked_inputs) +print("multihead_keys.shape:", multihead_keys.shape) +print("multihead_values.shape:", multihead_values.shape) +Out: +multihead_keys.shape: torch.Size([3, 24, 6]) +multihead_values.shape: torch.Size([3, 28, 6]) +We now have tensors that represent the three attention heads in their first dimension. The third and second dimensions refer to the number of words and the embedding size, respectively. To make the values and keys more intuitive to interpret, we will swap the second and third dimensions, resulting in tensors with the same dimensional structure as the original input sequence, embedded_sentence +: +In: +multihead_keys = multihead_keys.permute(0, 2, 1) +multihead_values = multihead_values.permute(0, 2, 1) +print("multihead_keys.shape:", multihead_keys.shape) +print("multihead_values.shape:", multihead_values.shape) +Out: +multihead_keys.shape: torch.Size([3, 6, 24]) +multihead_values.shape: torch.Size([3, 6, 28]) +Then, we follow the same steps as previously to compute the unscaled attention weights \(\omega\) and attention weights \(\alpha\), followed by the scaled-softmax computation to obtain an \(h \times d_v\) (here: \(3 \times d_v\)) dimensional context vector \(\mathbf{z}\) for the input element \(\mathbf{x}^{(2)}\). +Cross-Attention +In the code walkthrough above, we set \(d_q = d_k = 24\) and \(d_v=28\). Or in other words, we used the same dimensions for query and key sequences. While the value matrix \(\mathbf{W}_v\) is often chosen to have the same dimension as the query and key matrices (such as in PyTorch’s MultiHeadAttention class), we can select an arbitrary number size for the value dimensions. +Since the dimensions are sometimes a bit tricky to keep track of, let’s summarize everything we have covered so far in the figure below, which depicts the various tensor sizes for a single attention head. +Now, the illustration above corresponds to the self-attention mechanism used in transformers. One particular flavor of this attention mechanism we have yet to discuss is cross-attention. +What is cross-attention, and how does it differ from self-attention? +In self-attention, we work with the same input sequence. In cross-attention, we mix or combine two different input sequences. In the case of the original transformer architecture above, that’s the sequence returned by the encoder module on the left and the input sequence being processed by the decoder part on the right. +Note that in cross-attention, the two input sequences \(\mathbf{x}_1\) and \(\mathbf{x}_2\) can have different numbers of elements. However, their embedding dimensions must match. +The figure below illustrates the concept of cross-attention. If we set \(\mathbf{x}_1 = \mathbf{x}_2\), this is equivalent to self-attention. +(Note that the queries usually come from the decoder, and the keys and values usually come from the encoder.) +How does that work in code? Previously, when we implemented the self-attention mechanism at the beginning of this article, we used the following code to compute the query of the second input element along with all the keys and values as follows: +In: +torch.manual_seed(123) +d = embedded_sentence.shape[1] +print("embedded_sentence.shape:", embedded_sentence.shape:) +d_q, d_k, d_v = 24, 24, 28 +W_query = torch.rand(d_q, d) +W_key = torch.rand(d_k, d) +W_value = torch.rand(d_v, d) +x_2 = embedded_sentence[1] +query_2 = W_query.matmul(x_2) +print("query.shape", query_2.shape) +keys = W_key.matmul(embedded_sentence.T).T +values = W_value.matmul(embedded_sentence.T).T +print("keys.shape:", keys.shape) +print("values.shape:", values.shape) +Out: +embedded_sentence.shape: torch.Size([6, 16]) +queries.shape: torch.Size([24]) +keys.shape: torch.Size([6, 24]) +values.shape: torch.Size([6, 28]) +The only part that changes in cross attention is that we now have a second input sequence, for example, a second sentence with 8 instead of 6 input elements. Here, suppose this is a sentence with 8 tokens. +In: +embedded_sentence_2 = torch.rand(8, 16) # 2nd input sequence +keys = W_key.matmul(embedded_sentence_2.T).T +values = W_value.matmul(embedded_sentence_2.T).T +print("keys.shape:", keys.shape) +print("values.shape:", values.shape) +Out: +keys.shape: torch.Size([8, 24]) +values.shape: torch.Size([8, 28]) +Notice that compared to self-attention, the keys and values now have 8 instead of 6 rows. Everything else stays the same. +We talked a lot about language transformers above. In the original transformer architecture, cross-attention is useful when we go from an input sentence to an output sentence in the context of language translation. The input sentence represents one input sequence, and the translation represent the second input sequence (the two sentences can different numbers of words). +Another popular model where cross-attention is used is Stable Diffusion. Stable Diffusion uses cross-attention between the generated image in the U-Net model and the text prompts used for conditioning as described in High-Resolution Image Synthesis with Latent Diffusion Models – the original paper that describes the Stable Diffusion model that was later adopted by Stability AI to implement the popular Stable Diffusion model. +Conclusion +In this article, we saw how self-attention works using a step-by-step coding approach. We then extended this concept to multi-head attention, the widely used component of large-language transformers. After discussing self-attention and multi-head attention, we introduced yet another concept: cross-attention, which is a flavor of self-attention that we can apply between two different sequences. This is already a lot of information to take in. Let’s leave the training of a neural network using this multi-head attention block to a future article. +This blog is personal passion project that does not offer direct compensation. However, for those who wish to support me, please consider purchasing a copy of one of my books. If you find them insightful and beneficial, please feel free to recommend them to your friends and colleagues. +Your support means a great deal! Thank you! diff --git a/examples/function_calling/main.go b/examples/function_calling/main.go index a79e289..0ae4e98 100644 --- a/examples/function_calling/main.go +++ b/examples/function_calling/main.go @@ -42,7 +42,7 @@ func segment(ctx context.Context, data string) *Searches { Model: openai.GPT4o, Messages: []openai.ChatCompletionMessage{ { - Role: instructor.RoleUser, + Role: openai.ChatMessageRoleUser, Content: fmt.Sprintf("Consider the data below: '\n%s' and segment it into multiple search queries", data), }, }, diff --git a/examples/images/anthropic/main.go b/examples/images/anthropic/main.go index 0dab9fc..eaa6e6f 100644 --- a/examples/images/anthropic/main.go +++ b/examples/images/anthropic/main.go @@ -52,7 +52,7 @@ func main() { Model: "claude-3-haiku-20240307", Messages: []anthropic.Message{ { - Role: instructor.RoleUser, + Role: anthropic.RoleUser, Content: []anthropic.MessageContent{ anthropic.NewImageMessageContent(anthropic.MessageContentImageSource{ Type: "base64", diff --git a/examples/images/openai/main.go b/examples/images/openai/main.go index 6cb6900..6e59acc 100644 --- a/examples/images/openai/main.go +++ b/examples/images/openai/main.go @@ -43,7 +43,7 @@ func main() { Model: openai.GPT4o, Messages: []openai.ChatCompletionMessage{ { - Role: instructor.RoleUser, + Role: openai.ChatMessageRoleUser, MultiContent: []openai.ChatMessagePart{ { Type: openai.ChatMessagePartTypeText, diff --git a/examples/ollama/main.go b/examples/ollama/main.go index e3eb854..1433a45 100644 --- a/examples/ollama/main.go +++ b/examples/ollama/main.go @@ -45,7 +45,7 @@ func main() { Model: "llama3", Messages: []openai.ChatCompletionMessage{ { - Role: instructor.RoleUser, + Role: openai.ChatMessageRoleUser, Content: "Tell me about the Hal 9000", }, }, diff --git a/examples/streaming/cohere/main.go b/examples/streaming/cohere/main.go new file mode 100644 index 0000000..3047d0f --- /dev/null +++ b/examples/streaming/cohere/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "context" + "fmt" + "os" + + cohere "github.com/cohere-ai/cohere-go/v2" + cohereclient "github.com/cohere-ai/cohere-go/v2/client" + "github.com/instructor-ai/instructor-go/pkg/instructor" +) + +type HistoricalFact struct { + Decade string `json:"decade" jsonschema:"title=Decade of the Fact,description=Decade when the fact occurred"` + Topic string `json:"topic" jsonschema:"title=Topic of the Fact,description=General category or topic of the fact"` + Description string `json:"description" jsonschema:"title=Description of the Fact,description=Description or details of the fact"` +} + +func (hf HistoricalFact) String() string { + return fmt.Sprintf(` +Decade: %s +Topic: %s +Description: %s`, hf.Decade, hf.Topic, hf.Description) +} + +func main() { + ctx := context.Background() + + client := instructor.FromCohere( + cohereclient.NewClient(cohereclient.WithToken(os.Getenv("COHERE_API_KEY"))), + instructor.WithMode(instructor.ModeJSON), + instructor.WithMaxRetries(3), + ) + + hfStream, err := client.ChatStream(ctx, &cohere.ChatStreamRequest{ + Model: toPtr("command-r-plus"), + Message: "Tell me about the history of artificial intelligence up to year 2000", + MaxTokens: toPtr(2500), + }, + *new(HistoricalFact), + ) + if err != nil { + panic(err) + } + + for instance := range hfStream { + hf := instance.(*HistoricalFact) + println(hf.String()) + } + /* + Decade: 1950s + Topic: Birth of AI + Description: The term 'Artificial Intelligence' is coined by John McCarthy at the Dartmouth Conference in 1956, considered the birth of AI as a field. Early research focuses on areas like problem solving, search algorithms, and logic. + + Decade: 1960s + Topic: Expert Systems and LISP + Description: The language LISP is developed, which becomes widely used in AI applications. Research also leads to the development of expert systems, which emulate human decision-making abilities in specific domains. + + Decade: 1970s + Topic: AI Winter + Description: AI experiences its first 'winter', a period of reduced funding and interest due to unmet expectations. Despite this, research continues in areas like knowledge representation and natural language processing. + + Decade: 1980s + Topic: Machine Learning and Neural Networks + Description: The field of machine learning emerges, with a focus on developing algorithms that can learn from data. Neural networks, inspired by the structure of biological brains, gain traction during this decade. + + Decade: 1990s + Topic: AI in Practice + Description: AI starts to find practical applications in various industries. Speech recognition, image processing, and expert systems are used in fields like healthcare, finance, and manufacturing. + */ +} + +func toPtr[T any](val T) *T { + return &val +} diff --git a/examples/streaming/main.go b/examples/streaming/openai/main.go similarity index 98% rename from examples/streaming/main.go rename to examples/streaming/openai/main.go index a8ce604..3516a08 100644 --- a/examples/streaming/main.go +++ b/examples/streaming/openai/main.go @@ -75,7 +75,7 @@ Preferred Shopping Times: Weekend Evenings Model: openai.GPT4o20240513, Messages: []openai.ChatCompletionMessage{ { - Role: instructor.RoleSystem, + Role: openai.ChatMessageRoleSystem, Content: fmt.Sprintf(` Generate the product recommendations from the product list based on the customer profile. Return in order of highest recommended first. @@ -83,7 +83,7 @@ Product list: %s`, productList), }, { - Role: instructor.RoleUser, + Role: openai.ChatMessageRoleUser, Content: fmt.Sprintf("User profile:\n%s", profileData), }, }, diff --git a/go.mod b/go.mod index e09410a..a11bb7a 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/google/uuid v1.4.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0c13b63..ec9cab3 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= diff --git a/pkg/instructor/chat.go b/pkg/instructor/chat.go index f15a924..6a4aeaa 100644 --- a/pkg/instructor/chat.go +++ b/pkg/instructor/chat.go @@ -26,7 +26,7 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons return nil, err } - text = extractJSON(text) + text = extractJSON(&text) err = json.Unmarshal([]byte(text), &response) if err != nil { diff --git a/pkg/instructor/chat_stream.go b/pkg/instructor/chat_stream.go index b39394e..80642db 100644 --- a/pkg/instructor/chat_stream.go +++ b/pkg/instructor/chat_stream.go @@ -7,13 +7,13 @@ import ( "strings" ) -func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, response any) (<-chan any, error) { +type StreamWrapper[T any] struct { + Items []T `json:"items"` +} - type StreamWrapper[T any] struct { - Items []T `json:"items"` - } +const WRAPPER_END = `"items": [` - const WRAPPER_END = `"items": [` +func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, response any) (<-chan interface{}, error) { responseType := reflect.TypeOf(response) @@ -36,11 +36,19 @@ func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, r return nil, err } - parsedChan := make(chan any) // Buffered channel for parsed objects + parsedChan := parseStream(ctx, ch, responseType) + + return parsedChan, nil +} + +func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Type) <-chan interface{} { + + parsedChan := make(chan any) go func() { defer close(parsedChan) - var buffer strings.Builder + + buffer := new(strings.Builder) inArray := false for { @@ -49,61 +57,72 @@ func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, r return case text, ok := <-ch: if !ok { - // Steeam closed - - // Get last element out of stream wrapper - - data := buffer.String() - - if idx := strings.LastIndex(data, "]"); idx != -1 { - data = data[:idx] + data[idx+1:] - } - - // Process the remaining data in the buffer - decoder := json.NewDecoder(strings.NewReader(data)) - for decoder.More() { - instance := reflect.New(responseType).Interface() - err := decoder.Decode(instance) - if err != nil { - break - } - parsedChan <- instance - } + // Stream closed + processRemainingBuffer(buffer, parsedChan, responseType) return } + buffer.WriteString(text) - // eat all input until elements stream starts + // Eat all input until elements stream starts if !inArray { - idx := strings.Index(buffer.String(), WRAPPER_END) - if idx == -1 { - continue - } - - inArray = true - bufferStr := buffer.String() - trimmed := strings.TrimSpace(bufferStr[idx+len(WRAPPER_END):]) - buffer.Reset() - buffer.WriteString(trimmed) + inArray = startArray(buffer) } - data := buffer.String() - decoder := json.NewDecoder(strings.NewReader(data)) - - for decoder.More() { - instance := reflect.New(responseType).Interface() - err := decoder.Decode(instance) - if err != nil { - break - } - parsedChan <- instance - - buffer.Reset() - buffer.WriteString(data[len(data):]) - } + processBuffer(buffer, parsedChan, responseType) } } }() - return parsedChan, nil + return parsedChan +} + +func startArray(buffer *strings.Builder) bool { + + data := buffer.String() + + idx := strings.Index(data, WRAPPER_END) + if idx == -1 { + return false + } + + trimmed := strings.TrimSpace(data[idx+len(WRAPPER_END):]) + buffer.Reset() + buffer.WriteString(trimmed) + + return true +} + +func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) { + + data := buffer.String() + + data, remaining := getFirstFullJSONElement(&data) + + decoder := json.NewDecoder(strings.NewReader(data)) + + for decoder.More() { + instance := reflect.New(responseType).Interface() + err := decoder.Decode(instance) + if err != nil { + break + } + parsedChan <- instance + + buffer.Reset() + buffer.WriteString(remaining) + } +} + +func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) { + + data := buffer.String() + + data = extractJSON(&data) + + if idx := strings.LastIndex(data, "]"); idx != -1 { + data = data[:idx] + } + + processBuffer(buffer, parsedChan, responseType) } diff --git a/pkg/instructor/cohere.go b/pkg/instructor/cohere.go deleted file mode 100644 index 3cb3e1f..0000000 --- a/pkg/instructor/cohere.go +++ /dev/null @@ -1,48 +0,0 @@ -package instructor - -// import ( -// "context" -// "fmt" -// -// cohere "github.com/cohere-ai/cohere-go/v2" -// cohereclient "github.com/cohere-ai/cohere-go/v2/client" -// ) -// -// type CohereClient struct { -// Name string -// -// *cohereclient.Client -// } -// -// var _ Client = &CohereClient{} -// -// func NewCohereClient(client *cohereclient.Client) (*CohereClient, error) { -// o := &CohereClient{ -// Name: "Cohere", -// Client: client, -// } -// return o, nil -// } -// -// func (c *CohereClient) Chat(ctx context.Context, request interface{}, mode Mode, schema *Schema) (string, error) { -// -// req, ok := request.(cohere.ChatRequest) -// if !ok { -// return "", fmt.Errorf("invalid request type for %s client", c.Name) -// } -// -// switch mode { -// // case ModeToolCall: -// // return c.completionToolCall(ctx, &req, schema) -// case ModeJSON: -// return c.completionJSON(ctx, &req, schema) -// // case ModeJSONSchema: -// // return c.completionJSONSchema(ctx, &req, schema) -// default: -// return "", fmt.Errorf("mode '%s' is not supported for %s", mode, c.Name) -// } -// } -// -// func (c *CohereClient) completionJSON(ctx context.Context, request *cohere.ChatRequest, schema *Schema) (string, error) { -// panic("not implemented") -// } diff --git a/pkg/instructor/cohere_chat.go b/pkg/instructor/cohere_chat.go new file mode 100644 index 0000000..28373a2 --- /dev/null +++ b/pkg/instructor/cohere_chat.go @@ -0,0 +1,100 @@ +package instructor + +import ( + "context" + "fmt" + + cohere "github.com/cohere-ai/cohere-go/v2" + option "github.com/cohere-ai/cohere-go/v2/option" +) + +func (i *InstructorCohere) Chat( + ctx context.Context, + request *cohere.ChatRequest, + response any, + opts ...option.RequestOption, +) (*cohere.NonStreamedChatResponse, error) { + + resp, err := chatHandler(i, ctx, request, response) + if err != nil { + return nil, err + } + + return resp.(*cohere.NonStreamedChatResponse), nil +} + +func (i *InstructorCohere) chat(ctx context.Context, request interface{}, schema *Schema) (string, interface{}, error) { + + req, ok := request.(*cohere.ChatRequest) + if !ok { + return "", nil, fmt.Errorf("invalid request type for %s client", i.Provider()) + } + + switch i.Mode() { + case ModeToolCall: + return i.chatToolCall(ctx, req, schema) + case ModeJSON: + return i.chatJSON(ctx, req, schema) + default: + return "", nil, fmt.Errorf("mode '%s' is not supported for %s", i.Mode(), i.Provider()) + } +} + +func (i *InstructorCohere) chatToolCall(ctx context.Context, request *cohere.ChatRequest, schema *Schema) (string, *cohere.NonStreamedChatResponse, error) { + + request.Tools = []*cohere.Tool{createCohereTools(schema)} + + resp, err := i.Client.Chat(ctx, request) + if err != nil { + return "", nil, err + } + + _ = resp + + // TODO: implement + + panic("tool call not implemented Cohere") + +} + +func (i *InstructorCohere) chatJSON(ctx context.Context, request *cohere.ChatRequest, schema *Schema) (string, *cohere.NonStreamedChatResponse, error) { + + i.addOrConcatJSONSystemPrompt(request, schema) + + resp, err := i.Client.Chat(ctx, request) + if err != nil { + return "", nil, err + } + + return resp.Text, resp, nil +} + +func (i *InstructorCohere) addOrConcatJSONSystemPrompt(request *cohere.ChatRequest, schema *Schema) { + + schemaPrompt := fmt.Sprintf("```json!Please respond with JSON in the following JSON schema - make sure to return an instance of the JSON, not the schema itself: %s ", schema.String) + + if request.Preamble == nil { + request.Preamble = &schemaPrompt + } else { + request.Preamble = toPtr(*request.Preamble + "\n" + schemaPrompt) + } +} + +func createCohereTools(schema *Schema) *cohere.Tool { + + tool := &cohere.Tool{ + Name: "functions", + Description: schema.Schema.Description, + ParameterDefinitions: make(map[string]*cohere.ToolParameterDefinitionsValue), + } + + for _, function := range schema.Functions { + parameterDefinition := &cohere.ToolParameterDefinitionsValue{ + Description: toPtr(function.Description), + Type: function.Parameters.Type, + } + tool.ParameterDefinitions[function.Name] = parameterDefinition + } + + return tool +} diff --git a/pkg/instructor/cohere_chat_stream.go b/pkg/instructor/cohere_chat_stream.go new file mode 100644 index 0000000..9ed9f6b --- /dev/null +++ b/pkg/instructor/cohere_chat_stream.go @@ -0,0 +1,91 @@ +package instructor + +import ( + "context" + "errors" + "fmt" + "io" + + cohere "github.com/cohere-ai/cohere-go/v2" + option "github.com/cohere-ai/cohere-go/v2/option" +) + +func (i *InstructorCohere) ChatStream( + ctx context.Context, + request *cohere.ChatStreamRequest, + responseType any, + opts ...option.RequestOption, +) (<-chan any, error) { + + stream, err := chatStreamHandler(i, ctx, request, responseType) + if err != nil { + return nil, err + } + + return stream, err +} + +func (i *InstructorCohere) chatStream(ctx context.Context, request interface{}, schema *Schema) (<-chan string, error) { + + req, ok := request.(*cohere.ChatStreamRequest) + if !ok { + return nil, fmt.Errorf("invalid request type for %s client", i.Provider()) + } + + switch i.Mode() { + case ModeJSON: + return i.chatJSONStream(ctx, req, schema) + default: + return nil, fmt.Errorf("mode '%s' is not supported for %s", i.Mode(), i.Provider()) + } +} + +func (i *InstructorCohere) chatJSONStream(ctx context.Context, request *cohere.ChatStreamRequest, schema *Schema) (<-chan string, error) { + i.addOrConcatJSONSystemPromptStream(request, schema) + return i.createStream(ctx, request) +} + +func (i *InstructorCohere) addOrConcatJSONSystemPromptStream(request *cohere.ChatStreamRequest, schema *Schema) { + + schemaPrompt := fmt.Sprintf("```json!Please respond with JSON in the following JSON schema - make sure to return an instance of the JSON, not the schema itself: %s ", schema.String) + + if request.Preamble == nil { + request.Preamble = &schemaPrompt + } else { + request.Preamble = toPtr(*request.Preamble + "\n" + schemaPrompt) + } +} + +func (i *InstructorCohere) createStream(ctx context.Context, request *cohere.ChatStreamRequest) (<-chan string, error) { + stream, err := i.Client.ChatStream(ctx, request) + if err != nil { + return nil, err + } + + ch := make(chan string) + + go func() { + defer stream.Close() + defer close(ch) + for { + message, err := stream.Recv() + if errors.Is(err, io.EOF) { + return + } + if err != nil { + return + } + switch message.EventType { + case "stream-start": + continue + case "stream-end": + return + case "text-generation": + ch <- message.TextGeneration.Text + default: + panic(errors.New("cohere streaming event type not supported by instructor: " + message.EventType)) + } + } + }() + return ch, nil +} diff --git a/pkg/instructor/cohere_stream.go b/pkg/instructor/cohere_stream.go deleted file mode 100644 index d23305f..0000000 --- a/pkg/instructor/cohere_stream.go +++ /dev/null @@ -1,11 +0,0 @@ -package instructor - -// import ( -// "context" -// _ "github.com/cohere-ai/cohere-go/v2" -// _ "github.com/cohere-ai/cohere-go/v2/client" -// ) -// -// func (c *CohereClient) ChatStream(ctx context.Context, request interface{}, mode string, schema *Schema) (<-chan string, error) { -// panic("unimplemented") -// } diff --git a/pkg/instructor/cohere_struct.go b/pkg/instructor/cohere_struct.go new file mode 100644 index 0000000..e6cdb3d --- /dev/null +++ b/pkg/instructor/cohere_struct.go @@ -0,0 +1,41 @@ +package instructor + +import ( + cohere "github.com/cohere-ai/cohere-go/v2/client" +) + +type InstructorCohere struct { + *cohere.Client + + provider Provider + mode Mode + maxRetries int +} + +var _ Instructor = &InstructorCohere{} + +func FromCohere(client *cohere.Client, opts ...Options) *InstructorCohere { + + options := mergeOptions(opts...) + + i := &InstructorCohere{ + Client: client, + + provider: ProviderCohere, + mode: *options.Mode, + maxRetries: *options.MaxRetries, + } + return i +} + +func (i *InstructorCohere) Provider() string { + return i.provider +} + +func (i *InstructorCohere) Mode() string { + return i.mode +} + +func (i *InstructorCohere) MaxRetries() int { + return i.maxRetries +} diff --git a/pkg/instructor/openai_chat.go b/pkg/instructor/openai_chat.go index c8c1c69..7505de4 100644 --- a/pkg/instructor/openai_chat.go +++ b/pkg/instructor/openai_chat.go @@ -50,7 +50,7 @@ func (i *InstructorOpenAI) chat(ctx context.Context, request interface{}, schema func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (string, *openai.ChatCompletionResponse, error) { - request.Tools = createTools(schema) + request.Tools = createOpenAITools(schema) resp, err := i.Client.CreateChatCompletion(ctx, *request) if err != nil { @@ -138,7 +138,7 @@ Make sure to return an instance of the JSON, not the schema itself `, schema.String) msg := &openai.ChatCompletionMessage{ - Role: RoleSystem, + Role: openai.ChatMessageRoleSystem, Content: message, } diff --git a/pkg/instructor/openai_chat_stream.go b/pkg/instructor/openai_chat_stream.go index b995597..8a46c0e 100644 --- a/pkg/instructor/openai_chat_stream.go +++ b/pkg/instructor/openai_chat_stream.go @@ -47,7 +47,7 @@ func (i *InstructorOpenAI) chatStream(ctx context.Context, request interface{}, } func (i *InstructorOpenAI) chatToolCallStream(ctx context.Context, request *openai.ChatCompletionRequest, schema *Schema) (<-chan string, error) { - request.Tools = createTools(schema) + request.Tools = createOpenAITools(schema) return i.createStream(ctx, request) } @@ -63,7 +63,7 @@ func (i *InstructorOpenAI) chatJSONSchemaStream(ctx context.Context, request *op return i.createStream(ctx, request) } -func createTools(schema *Schema) []openai.Tool { +func createOpenAITools(schema *Schema) []openai.Tool { tools := make([]openai.Tool, 0, len(schema.Functions)) for _, function := range schema.Functions { f := openai.FunctionDefinition{ @@ -90,7 +90,7 @@ Make sure to return an array with the elements an instance of the JSON, not the `, schema.String) msg := &openai.ChatCompletionMessage{ - Role: RoleSystem, + Role: openai.ChatMessageRoleSystem, Content: message, } diff --git a/pkg/instructor/role_enum.go b/pkg/instructor/role_enum.go deleted file mode 100644 index 53cdb0a..0000000 --- a/pkg/instructor/role_enum.go +++ /dev/null @@ -1,17 +0,0 @@ -package instructor - -type Role = string - -const ( - RoleSystem = "system" - ChatMessageRoleSystem = RoleSystem - - RoleUser = "user" - ChatMessageRoleUser = RoleUser - - RoleAssistant = "assistant" - ChatMessageRoleAssistant = RoleAssistant - - RoleTool = "tool" - ChatMessageRoleTool = RoleTool -) diff --git a/pkg/instructor/utils.go b/pkg/instructor/utils.go index ad341d4..2eb0a11 100644 --- a/pkg/instructor/utils.go +++ b/pkg/instructor/utils.go @@ -12,14 +12,57 @@ func prepend[T any](to []T, from T) []T { return append([]T{from}, to...) } +func findMatchingBracket(json *string, start int) int { + stack := []int{} + openBracket := rune('{') + closeBracket := rune('}') + + for i := start; i < len(*json); i++ { + if rune((*json)[i]) == openBracket { + stack = append(stack, i) + } else if rune((*json)[i]) == closeBracket { + if len(stack) == 0 { + return -1 // Unbalanced brackets + } + stack = stack[:len(stack)-1] + if len(stack) == 0 { + return i // Found the matching bracket + } + } + } + + return -1 // Unbalanced brackets +} + +func getFirstFullJSONElement(json *string) (element string, remaining string) { + matchingBracketIdx := findMatchingBracket(json, 0) + + if matchingBracketIdx == -1 { + return "", *json + } + + element = (*json)[:matchingBracketIdx+1] + remaining = "" + + if matchingBracketIdx+1 < len(*json) { + remaining = (*json)[matchingBracketIdx+1:] + + if (*json)[matchingBracketIdx+1] == ',' { + remaining = (*json)[matchingBracketIdx+2:] + } + } + + return element, remaining +} + // Removes any prefixes before the JSON (like "Sure, here you go:") -func trimPrefixBeforeJSON(jsonStr string) string { - startObject := strings.IndexByte(jsonStr, '{') - startArray := strings.IndexByte(jsonStr, '[') +func trimPrefixBeforeJSON(json *string) string { + startObject := strings.IndexByte(*json, '{') + startArray := strings.IndexByte(*json, '[') var start int if startObject == -1 && startArray == -1 { - return jsonStr // No opening brace or bracket found, return the original string + return *json // No opening brace or bracket found, return the original string } else if startObject == -1 { start = startArray } else if startArray == -1 { @@ -28,7 +71,7 @@ func trimPrefixBeforeJSON(jsonStr string) string { start = min(startObject, startArray) } - return jsonStr[start:] + return (*json)[start:] } // Removes any postfixes after the JSON @@ -51,8 +94,8 @@ func trimPostfixAfterJSON(jsonStr string) string { } // Extracts the JSON by trimming prefixes and postfixes -func extractJSON(jsonStr string) string { - trimmedPrefix := trimPrefixBeforeJSON(jsonStr) +func extractJSON(json *string) string { + trimmedPrefix := trimPrefixBeforeJSON(json) trimmedJSON := trimPostfixAfterJSON(trimmedPrefix) return trimmedJSON }