Skip to content

Commit

Permalink
fix nova model ids, update formatting for converse
Browse files Browse the repository at this point in the history
  • Loading branch information
brianandres2 committed Feb 20, 2025
1 parent 703777a commit 7372636
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 46 deletions.
6 changes: 3 additions & 3 deletions src/e84_geoai_common/llm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
NovaInvokeLLMRequest,
)

__all__ = [ # noqa: RUF022
__all__ = [
"CLAUDE_BEDROCK_MODEL_IDS",
"CONVERSE_BEDROCK_MODEL_IDS",
"NOVA_BEDROCK_MODEL_IDS",
"BedrockClaudeLLM",
"BedrockConverseLLM",
"BedrockNovaLLM",
"ClaudeInvokeLLMRequest",
"CONVERSE_BEDROCK_MODEL_IDS",
"BedrockConverseLLM",
"ConverseInvokeLLMRequest",
"NovaInvokeLLMRequest",
]
89 changes: 51 additions & 38 deletions src/e84_geoai_common/llm/models/converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
}


#################################################################################
# Messages Object Components


class ConverseTextContent(BaseModel):
"""Converse text context model."""

Expand Down Expand Up @@ -94,6 +98,7 @@ class ConverseToolResultContent(BaseModel):

toolResult: ConverseToolResultInnerContent # noqa: N815


class ConverseImageSource(BaseModel):

model_config = ConfigDict(strict=True, extra="forbid")
Expand Down Expand Up @@ -230,43 +235,8 @@ class ConverseAssistantMessage(ConverseMessage):
content: list[ConverseTextContent | ConverseToolUseContent]


class ConverseUsageInfo(BaseModel):
"""Usage info from the Converse API."""

model_config = ConfigDict(strict=True, extra="forbid")

inputTokens: int # noqa: N815
outputTokens: int # noqa: N815
totalTokens: int # noqa: N815

class ConverseMessageResponse(BaseModel):

model_config = ConfigDict(strict=True, extra="forbid")

message: ConverseAssistantMessage

class ConverseMetrics(BaseModel):

model_config = ConfigDict(strict=True, extra="forbid")

latencyMs: int # noqa: N815

class ConverseResponse(BaseModel):
"""Converse response model."""

model_config = ConfigDict(strict=True, extra="forbid")

additionalModelResponseFields: dict[str, Any] | None = Field(default = None) # noqa: N815
metrics: ConverseMetrics
output: ConverseMessageResponse
performanceConfig: dict[str, Any] | None = Field(default = None) # noqa: N815
ResponseMetadata: dict[str, Any]
role: Literal["assistant"] = "assistant"
stopReason: Literal[ # noqa: N815
"end_turn", "max_tokens", "stop_sequence", "tool_use"
]
trace: dict[str, Any] | None = Field(default = None)
usage: ConverseUsageInfo
#################################################################################
# Other Request Objects


class ConverseToolSpec(BaseModel):
Expand Down Expand Up @@ -338,13 +308,15 @@ class ConverseToolChoice(BaseModel):
# not be supported in Bedrock
# disable_parallel_tool_use: bool | None = None # noqa: ERA001


class SystemContentBlock(BaseModel):
"""A system prompt block."""

model_config = ConfigDict(strict=True, extra="forbid")

text: str


class ConverseInferenceConfig(BaseModel):
"""Converse inference config model."""

Expand All @@ -355,6 +327,7 @@ class ConverseInferenceConfig(BaseModel):
temperature: float | None
topP: float | None # noqa: N815


class ConverseAdditionalModelRequestFields(BaseModel):
"""Converse additional fields for certain models."""

Expand All @@ -378,7 +351,7 @@ class ConverseInvokeLLMRequest(BaseModel):

modelId: str = Field( # noqa: N815
default=CONVERSE_BEDROCK_MODEL_IDS["Claude 3 Haiku"],
description="model used for the Converse api"
description="Model used for the Converse api"
)

messages: list[ConverseMessage] = Field(
Expand All @@ -391,6 +364,46 @@ class ConverseInvokeLLMRequest(BaseModel):
default=None, description="List of tools that the model may call."
)

#################################################################################
# Response objects

class ConverseUsageInfo(BaseModel):
"""Usage info from the Converse API."""

model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

inputTokens: int # noqa: N815
outputTokens: int # noqa: N815
totalTokens: int # noqa: N815

class ConverseMessageResponse(BaseModel):

model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

message: ConverseAssistantMessage

class ConverseMetrics(BaseModel):

model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

latencyMs: int # noqa: N815

class ConverseResponse(BaseModel):
"""Converse response model."""

model_config = ConfigDict(strict=True, extra="forbid", frozen=True)

additionalModelResponseFields: dict[str, Any] | None = Field(default = None) # noqa: N815
metrics: ConverseMetrics
output: ConverseMessageResponse
performanceConfig: dict[str, Any] | None = Field(default = None) # noqa: N815
ResponseMetadata: dict[str, Any]
role: Literal["assistant"] = "assistant"
stopReason: Literal[ # noqa: N815
"end_turn", "max_tokens", "stop_sequence", "tool_use"
]
trace: dict[str, Any] | None = Field(default = None)
usage: ConverseUsageInfo

def _config_to_response_prefix(config: LLMInferenceConfig) -> str | None:
if config.json_mode:
Expand Down
10 changes: 5 additions & 5 deletions src/e84_geoai_common/llm/models/nova.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

# https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
NOVA_BEDROCK_MODEL_IDS = {
"Nova Canvas": "amazon.nova-canvas-v1:0",
"Nova Lite": "amazon.nova-lite-v1:0",
"Nova Micro": "amazon.nova-micro-v1:0",
"Nova Pro": "amazon.nova-pro-v1:0",
"Nova Reel": "amazon.nova-reel-v1:0",
"Nova Canvas": "us.amazon.nova-canvas-v1:0",
"Nova Lite": "us.amazon.nova-lite-v1:0",
"Nova Micro": "us.amazon.nova-micro-v1:0",
"Nova Pro": "us.amazon.nova-pro-v1:0",
"Nova Reel": "us.amazon.nova-reel-v1:0",
}


Expand Down

0 comments on commit 7372636

Please sign in to comment.