
Commit 6eed77e

feat(clients): support streaming response
1 parent 76382a3 commit 6eed77e

12 files changed, with 166 additions and 13 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 ### Added
 
+- Support streaming response.
 - Support for Hugging Face.
 
 ## [2.5.0] - 2024-09-22

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/AppSettings2.kt

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,9 @@ class AppSettings2 : PersistentStateComponent<AppSettings2> {
 
     var appExclusions: Set<String> = setOf()
 
+    @Attribute
+    var useStreamingResponse: Boolean = true
+
     override fun getState() = this
 
     override fun loadState(state: AppSettings2) {
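
The new toggle is persisted at the application level: AppSettings2 is a PersistentStateComponent, so the @Attribute-annotated useStreamingResponse field (default true) is written to the IDE settings XML and restored on startup. A minimal, illustrative sketch of reading it from other plugin code (the helper name is hypothetical, not part of this commit):

    // Hypothetical helper, for illustration only: the flag is read through the
    // existing application-level settings singleton.
    fun streamingResponseEnabled(): Boolean = AppSettings2.instance.useStreamingResponse

LLMClientService below consults the same flag before it even asks a client for a streaming model.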

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/AppSettingsConfigurable.kt

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ class AppSettingsConfigurable(val project: Project, cs: CoroutineScope) : BoundC
                     .bindSelected(project.service<ProjectSettings>()::isProjectSpecificLLMClient)
                 contextHelp(message("settings.llmClient.projectSpecific.contextHelp"))
                     .align(AlignX.LEFT)
+                checkBox(message("settings.llmClient.streamingResponse"))
+                    .bindSelected(AppSettings2.instance::useStreamingResponse)
+                contextHelp(message("settings.llmClient.streamingResponse.contextHelp"))
+                    .align(AlignX.LEFT)
             }
             row {
                 llmClientToolbarDecorator = ToolbarDecorator.createDecorator(llmClientTable.table)

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt

Lines changed: 58 additions & 13 deletions
@@ -21,8 +21,11 @@ import com.intellij.platform.ide.progress.withBackgroundProgress
 import com.intellij.ui.components.JBLabel
 import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import com.intellij.vcs.commit.isAmendCommitMode
+import dev.langchain4j.data.message.AiMessage
 import dev.langchain4j.data.message.UserMessage
+import dev.langchain4j.model.StreamingResponseHandler
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import git4idea.GitCommit
 import git4idea.history.GitHistoryUtils
 import git4idea.repo.GitRepositoryManager
@@ -35,6 +38,8 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
 
     abstract suspend fun buildChatModel(client: C): ChatLanguageModel
 
+    abstract suspend fun buildStreamingChatModel(client: C): StreamingChatLanguageModel?
+
     fun generateCommitMessage(clientConfiguration: C, commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
 
         val commitContext = commitWorkflowHandler.workflow.commitContext
@@ -58,7 +63,7 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
         val branch = commonBranch(includedChanges, project)
         val prompt = constructPrompt(project.service<ProjectSettings>().activePrompt.content, diff, branch, commitMessage.text, project)
 
-        sendRequest(clientConfiguration, prompt, onSuccess = {
+        makeRequest(clientConfiguration, prompt, onSuccess = {
             withContext(Dispatchers.EDT) {
                 commitMessage.setCommitMessage(it)
             }
@@ -72,6 +77,7 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
         }
     }
 
+
     fun verifyConfiguration(client: C, label: JBLabel) {
         label.text = message("settings.verify.running")
         cs.launch(ModalityState.current().asContextElement()) {
@@ -89,20 +95,15 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
            }
        }
 
-    private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
+    private suspend fun makeRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
         try {
-            val model = buildChatModel(client)
-            val response = withContext(Dispatchers.IO) {
-                model.generate(
-                    listOf(
-                        UserMessage.from(
-                            "user",
-                            text
-                        )
-                    )
-                ).content().text()
+            if (AppSettings2.instance.useStreamingResponse) {
+                buildStreamingChatModel(client)?.let { streamingChatModel ->
+                    sendStreamingRequest(streamingChatModel, text, onSuccess, onError)
+                    return
+                }
             }
-            onSuccess(response)
+            sendRequest(client, text, onSuccess, onError)
         } catch (e: IllegalArgumentException) {
             onError(message("settings.verify.invalid", e.message ?: message("unknown-error")))
         } catch (e: Exception) {
@@ -112,6 +113,50 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
         }
     }
 
+    private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
+        var response = ""
+        withContext(Dispatchers.IO) {
+            streamingModel.generate(
+                listOf(
+                    UserMessage.from(
+                        "user",
+                        text
+                    )
+                ),
+                object : StreamingResponseHandler<AiMessage> {
+                    override fun onNext(token: String?) {
+                        response += token
+                        cs.launch {
+                            onSuccess(response)
+                        }
+                    }
+
+                    override fun onError(error: Throwable?) {
+                        response = error?.message.toString()
+                        cs.launch {
+                            onError(response)
+                        }
+                    }
+                }
+            )
+        }
+    }
+
+    private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
+        val model = buildChatModel(client)
+        val response = withContext(Dispatchers.IO) {
+            model.generate(
+                listOf(
+                    UserMessage.from(
+                        "user",
+                        text
+                    )
+                )
+            ).content().text()
+        }
+        onSuccess(response)
+    }
+
     private suspend fun getLastCommitChanges(project: Project): List<Change> {
         return withContext(Dispatchers.IO) {
             GitRepositoryManager.getInstance(project).repositories.map { repo ->
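
To summarize the new flow: makeRequest first checks the user setting; when streaming is enabled and the client can supply a StreamingChatLanguageModel, sendStreamingRequest is used, otherwise it falls through to the original blocking sendRequest. Each streamed token is appended to a buffer and onSuccess is called with the partial text, so the commit message field fills in progressively; on failure the buffer is replaced by the error message. Below is a minimal, self-contained sketch of the same langchain4j streaming pattern outside the plugin; Ollama is used only as an example backend, and the base URL and model name are assumptions rather than plugin defaults.

    import dev.langchain4j.data.message.AiMessage
    import dev.langchain4j.data.message.UserMessage
    import dev.langchain4j.model.StreamingResponseHandler
    import dev.langchain4j.model.ollama.OllamaStreamingChatModel
    import dev.langchain4j.model.output.Response
    import java.util.concurrent.CountDownLatch

    fun main() {
        // Assumed local Ollama instance and model; adjust to your environment.
        val model = OllamaStreamingChatModel.builder()
            .baseUrl("http://localhost:11434")
            .modelName("llama3")
            .build()

        val done = CountDownLatch(1)
        val buffer = StringBuilder()

        model.generate(
            listOf(UserMessage.from("user", "Write a one-line commit message for a typo fix.")),
            object : StreamingResponseHandler<AiMessage> {
                override fun onNext(token: String) {
                    // Called once per token; the partial text grows as tokens arrive.
                    buffer.append(token)
                    print(token)
                }

                override fun onComplete(response: Response<AiMessage>) {
                    // The complete message is also available here once streaming finishes.
                    done.countDown()
                }

                override fun onError(error: Throwable) {
                    System.err.println("Streaming failed: ${error.message}")
                    done.countDown()
                }
            }
        )

        done.await()
        println()
        println("Final: $buffer")
    }

Note that the handler callbacks are plain (non-suspending) functions invoked from langchain4j's worker thread, which is why the plugin code above launches a coroutine (cs.launch) to call its suspending onSuccess/onError handlers instead of invoking them directly.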

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/anthropic/AnthropicClientService.kt

Lines changed: 18 additions & 0 deletions
@@ -10,7 +10,9 @@ import com.intellij.openapi.components.Service
 import com.intellij.openapi.components.service
 import com.intellij.util.text.nullize
 import dev.langchain4j.model.anthropic.AnthropicChatModel
+import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
@@ -45,6 +47,22 @@ class AnthropicClientService(private val cs: CoroutineScope) : LLMClientService<
 
     }
 
+    override suspend fun buildStreamingChatModel(client: AnthropicClientConfiguration): StreamingChatLanguageModel {
+        val token = client.token.nullize(true) ?: retrieveToken(client.id)?.toString(true)
+        val builder = AnthropicStreamingChatModel.builder()
+            .modelName(client.modelId)
+            .temperature(client.temperature.toDouble())
+            .apiKey(token ?: "")
+            .baseUrl(client.host)
+            .timeout(Duration.ofSeconds(client.timeout.toLong()))
+
+        client.version?.takeIf { it.isNotBlank() }?.let {
+            builder.version(it)
+        }
+
+        return builder.build()
+    }
+
     fun saveToken(client: AnthropicClientConfiguration, token: String) {
         cs.launch(Dispatchers.Default) {
             try {

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/azureOpenAi/AzureOpenAiClientService.kt

Lines changed: 13 additions & 0 deletions
@@ -10,7 +10,9 @@ import com.intellij.openapi.components.Service
 import com.intellij.openapi.components.service
 import com.intellij.util.text.nullize
 import dev.langchain4j.model.azure.AzureOpenAiChatModel
+import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
@@ -36,6 +38,17 @@ class AzureOpenAiClientService(private val cs: CoroutineScope) : LLMClientServic
             .build()
     }
 
+    override suspend fun buildStreamingChatModel(client: AzureOpenAiClientConfiguration): StreamingChatLanguageModel {
+        val token = client.token.nullize(true) ?: retrieveToken(client.id)?.toString(true)
+        return AzureOpenAiStreamingChatModel.builder()
+            .deploymentName(client.modelId)
+            .temperature(client.temperature.toDouble())
+            .timeout(Duration.ofSeconds(client.timeout.toLong()))
+            .endpoint(client.host)
+            .apiKey(token ?: "")
+            .build()
+    }
+
     fun saveToken(client: AzureOpenAiClientConfiguration, token: String) {
         cs.launch(Dispatchers.Default) {
             try {

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/gemini/GeminiClientService.kt

Lines changed: 10 additions & 0 deletions
@@ -4,7 +4,9 @@ import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientSer
 import com.intellij.openapi.components.Service
 import com.intellij.openapi.components.service
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import dev.langchain4j.model.vertexai.VertexAiGeminiChatModel
+import dev.langchain4j.model.vertexai.VertexAiGeminiStreamingChatModel
 import kotlinx.coroutines.CoroutineScope
 
 @Service(Service.Level.APP)
@@ -24,4 +26,12 @@ class GeminiClientService(private val cs: CoroutineScope): LLMClientService<Gemi
             .build()
     }
 
+    override suspend fun buildStreamingChatModel(client: GeminiClientConfiguration): StreamingChatLanguageModel {
+        return VertexAiGeminiStreamingChatModel.builder()
+            .project(client.projectId)
+            .location(client.location)
+            .modelName(client.modelId)
+            .temperature(client.temperature.toFloat())
+            .build()
+    }
 }

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/huggingface/HuggingFaceClientService.kt

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ class HuggingFaceClientService(private val cs: CoroutineScope) : LLMClientServic
             .build()
     }
 
+    override suspend fun buildStreamingChatModel(client: HuggingFaceClientConfiguration) = null
+
     fun saveToken(client: HuggingFaceClientConfiguration, token: String) {
         cs.launch(Dispatchers.Default) {
             try {
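
Returning null here is the opt-out for clients that cannot stream: makeRequest in LLMClientService treats a null StreamingChatLanguageModel as "streaming unavailable" and falls back to the blocking sendRequest path, which is what the new context-help text about falling back to a normal response refers to. A small, illustrative sketch of that decision in isolation (standalone function, not plugin code):

    import dev.langchain4j.model.chat.StreamingChatLanguageModel

    // Sketch only: streaming is used when the user setting is on AND the client could
    // build a streaming model; Hugging Face always returns null, so it always falls back.
    fun shouldStream(useStreamingResponse: Boolean, streamingModel: StreamingChatLanguageModel?): Boolean =
        useStreamingResponse && streamingModel != null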

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/ollama/OllamaClientService.kt

Lines changed: 11 additions & 0 deletions
@@ -6,8 +6,10 @@ import com.intellij.openapi.components.service
 import com.intellij.openapi.ui.ComboBox
 import com.intellij.openapi.ui.naturalSorted
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import dev.langchain4j.model.ollama.OllamaChatModel
 import dev.langchain4j.model.ollama.OllamaModels
+import dev.langchain4j.model.ollama.OllamaStreamingChatModel
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
@@ -53,4 +55,13 @@ class OllamaClientService(private val cs: CoroutineScope) : LLMClientService<Oll
             .baseUrl(client.host)
             .build()
     }
+
+    override suspend fun buildStreamingChatModel(client: OllamaClientConfiguration): StreamingChatLanguageModel {
+        return OllamaStreamingChatModel.builder()
+            .modelName(client.modelId)
+            .temperature(client.temperature.toDouble())
+            .timeout(Duration.ofSeconds(client.timeout.toLong()))
+            .baseUrl(client.host)
+            .build()
+    }
 }

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/openAi/OpenAiClientService.kt

Lines changed: 23 additions & 0 deletions
@@ -10,7 +10,9 @@ import com.intellij.openapi.components.Service
 import com.intellij.openapi.components.service
 import com.intellij.util.text.nullize
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import dev.langchain4j.model.openai.OpenAiChatModel
+import dev.langchain4j.model.openai.OpenAiStreamingChatModel
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
@@ -48,6 +50,27 @@ class OpenAiClientService(private val cs: CoroutineScope) : LLMClientService<Ope
         return builder.build()
     }
 
+    override suspend fun buildStreamingChatModel(client: OpenAiClientConfiguration): StreamingChatLanguageModel {
+        val token = client.token.nullize(true) ?: retrieveToken(client.id)?.toString(true)
+        val builder = OpenAiStreamingChatModel.builder()
+            .apiKey(token ?: "")
+            .modelName(client.modelId)
+            .temperature(client.temperature.toDouble())
+            .timeout(Duration.ofSeconds(client.timeout.toLong()))
+            .baseUrl(client.host)
+
+        client.proxyUrl?.takeIf { it.isNotBlank() }?.let {
+            val uri = URI(it)
+            builder.proxy(Proxy(Proxy.Type.HTTP, InetSocketAddress(uri.host, uri.port)))
+        }
+
+        client.organizationId?.takeIf { it.isNotBlank() }?.let {
+            builder.organizationId(it)
+        }
+
+        return builder.build()
+    }
+
     fun saveToken(client: OpenAiClientConfiguration, token: String) {
         cs.launch(Dispatchers.Default) {
             try {

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/qianfan/QianfanClientService.kt

Lines changed: 20 additions & 0 deletions
@@ -10,8 +10,10 @@ import com.intellij.openapi.components.Service
 import com.intellij.openapi.components.service
 import com.intellij.util.text.nullize
 import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.StreamingChatLanguageModel
 import dev.langchain4j.model.qianfan.QianfanChatModel
 import dev.langchain4j.model.qianfan.QianfanChatModelNameEnum
+import dev.langchain4j.model.qianfan.QianfanStreamingChatModel
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
@@ -42,6 +44,24 @@ class QianfanClientService(private val cs: CoroutineScope) : LLMClientService<Qi
         return builder.build()
     }
 
+    override suspend fun buildStreamingChatModel(client: QianfanClientConfiguration): StreamingChatLanguageModel {
+        val apiKey = client.apiKey.nullize(true) ?: retrieveToken(client.id + "apiKey")?.toString(true)
+        val secretKey = client.secretKey.nullize(true) ?: retrieveToken(client.id + "secretKey")?.toString(true)
+
+        val builder = QianfanStreamingChatModel.builder()
+            .baseUrl(client.host)
+            .apiKey(apiKey)
+            .secretKey(secretKey)
+            .modelName(client.modelId)
+            .temperature(client.temperature.toDouble())
+        // Fix https://github.com/langchain4j/langchain4j/pull/1426. Remove this 'if' statement when langchain4j releases a new version that resolves this issue.
+        if (client.modelId == QianfanChatModelNameEnum.ERNIE_SPEED_128K.modelName) {
+            builder.endpoint("ernie-speed-128k")
+        }
+
+        return builder.build()
+    }
+
     fun saveApiKey(client: QianfanClientConfiguration, key: String) {
         cs.launch(Dispatchers.Default) {
             try {

src/main/resources/messages/AiCommitsBundle.properties

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,9 @@ settings.llmClient.timeout=Timeout
 settings.llmClient.temperature=Temperature
 settings.llmClient.temperature.comment=What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make \
   the output more random,while lower values like 0.2 will make it more focused and deterministic.
+settings.llmClient.streamingResponse=Streaming response
+settings.llmClient.streamingResponse.contextHelp=Some models do not support streaming response and will fall back to normal response.
+
 
 settings.openAI.token.example=sk-ABCdefgHIjKlxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
 settings.openAi.token.comment=You can get your token <a href="https://platform.openai.com/account/api-keys">here.</a>
