|
1 | 1 | import {Box} from '@mui/material';
|
2 |
| -import {useViewerInfo} from '@wandb/weave/common/hooks/useViewerInfo'; |
3 | 2 | import {Select} from '@wandb/weave/components/Form/Select';
|
4 | 3 | import React from 'react';
|
5 | 4 |
|
6 | 5 | import {
|
7 | 6 | LLM_MAX_TOKENS,
|
8 | 7 | LLM_PROVIDER_LABELS,
|
| 8 | + LLM_PROVIDERS, |
9 | 9 | LLMMaxTokensKey,
|
10 | 10 | } from '../llmMaxTokens';
|
11 |
| -import {useConfiguredProviders} from '../useConfiguredProviders'; |
12 |
| -import {CustomOption, ProviderOption} from './LLMDropdownOptions'; |
13 | 11 |
|
14 | 12 | interface LLMDropdownProps {
|
15 | 13 | value: LLMMaxTokensKey;
|
16 | 14 | onChange: (value: LLMMaxTokensKey, maxTokens: number) => void;
|
17 |
| - entity: string; |
18 |
| - project: string; |
19 | 15 | }
|
20 | 16 |
|
21 |
| -export const LLMDropdown: React.FC<LLMDropdownProps> = ({ |
22 |
| - value, |
23 |
| - onChange, |
24 |
| - entity, |
25 |
| - project, |
26 |
| -}) => { |
27 |
| - const {result: configuredProviders, loading: configuredProvidersLoading} = |
28 |
| - useConfiguredProviders(entity); |
29 |
| - |
30 |
| - const {loading: loadingUserInfo, userInfo} = useViewerInfo(); |
31 |
| - const isAdmin = !loadingUserInfo && userInfo?.admin; |
32 |
| - |
33 |
| - const options: ProviderOption[] = []; |
34 |
| - const disabledOptions: ProviderOption[] = []; |
35 |
| - |
36 |
| - if (configuredProvidersLoading) { |
37 |
| - options.push({ |
38 |
| - label: 'Loading providers...', |
39 |
| - value: 'loading', |
40 |
| - llms: [], |
41 |
| - }); |
42 |
| - } else { |
43 |
| - Object.entries(configuredProviders).forEach(([provider, {status}]) => { |
44 |
| - const providerLLMs = Object.entries(LLM_MAX_TOKENS) |
45 |
| - .filter(([_, config]) => config.provider === provider) |
46 |
| - .map(([llmKey]) => ({ |
47 |
| - label: llmKey, |
48 |
| - value: llmKey as LLMMaxTokensKey, |
49 |
| - max_tokens: LLM_MAX_TOKENS[llmKey as LLMMaxTokensKey].max_tokens, |
50 |
| - })); |
51 |
| - |
52 |
| - const option = { |
53 |
| - label: |
54 |
| - LLM_PROVIDER_LABELS[provider as keyof typeof LLM_PROVIDER_LABELS], |
55 |
| - value: provider, |
56 |
| - llms: status ? providerLLMs : [], |
57 |
| - isDisabled: !status, |
58 |
| - }; |
59 |
| - |
60 |
| - if (!status) { |
61 |
| - disabledOptions.push(option); |
62 |
| - } else { |
63 |
| - options.push(option); |
64 |
| - } |
65 |
| - }); |
66 |
| - } |
67 |
| - |
68 |
| - // Combine enabled and disabled options |
69 |
| - const allOptions = [...options, ...disabledOptions]; |
| 17 | +export const LLMDropdown: React.FC<LLMDropdownProps> = ({value, onChange}) => { |
| 18 | + const options = LLM_PROVIDERS.map(provider => ({ |
| 19 | + // for each provider, get all the LLMs that are supported by that provider |
| 20 | + label: LLM_PROVIDER_LABELS[provider], |
| 21 | + // filtering to the LLMs that are supported by that provider |
| 22 | + options: Object.keys(LLM_MAX_TOKENS) |
| 23 | + .reduce< |
| 24 | + Array<{ |
| 25 | + provider_label: string; |
| 26 | + label: string; |
| 27 | + value: string; |
| 28 | + }> |
| 29 | + >((acc, llm) => { |
| 30 | + if (LLM_MAX_TOKENS[llm as LLMMaxTokensKey].provider === provider) { |
| 31 | + acc.push({ |
| 32 | + provider_label: LLM_PROVIDER_LABELS[provider], |
| 33 | + // add provider to the label if the LLM is not already prefixed with it |
| 34 | + label: llm.includes(provider) ? llm : provider + '/' + llm, |
| 35 | + value: llm, |
| 36 | + }); |
| 37 | + } |
| 38 | + return acc; |
| 39 | + }, []) |
| 40 | + .sort((a, b) => a.label.localeCompare(b.label)), |
| 41 | + })); |
70 | 42 |
|
71 | 43 | return (
|
72 |
| - <Box sx={{width: '300px'}}> |
| 44 | + <Box |
| 45 | + sx={{ |
| 46 | + width: 'max-content', |
| 47 | + maxWidth: '100%', |
| 48 | + '& .MuiOutlinedInput-root': { |
| 49 | + width: 'max-content', |
| 50 | + maxWidth: '300px', |
| 51 | + }, |
| 52 | + '& > div': { |
| 53 | + width: 'max-content', |
| 54 | + maxWidth: '300px', |
| 55 | + }, |
| 56 | + '& .MuiAutocomplete-popper, & [class*="-menu"]': { |
| 57 | + width: '300px !important', |
| 58 | + }, |
| 59 | + '& #react-select-2-listbox': { |
| 60 | + width: '300px', |
| 61 | + maxHeight: '500px', |
| 62 | + }, |
| 63 | + '& #react-select-2-listbox > div': { |
| 64 | + maxHeight: '500px', |
| 65 | + width: '300px', |
| 66 | + }, |
| 67 | + }}> |
73 | 68 | <Select
|
74 |
| - isDisabled={configuredProvidersLoading} |
75 |
| - placeholder={ |
76 |
| - configuredProvidersLoading ? 'Loading providers...' : 'Select a model' |
77 |
| - } |
78 |
| - value={allOptions.find( |
79 |
| - option => |
80 |
| - 'llms' in option && option.llms?.some(llm => llm.value === value) |
81 |
| - )} |
82 |
| - formatOptionLabel={(option: ProviderOption, meta) => { |
83 |
| - if (meta.context === 'value' && 'llms' in option) { |
84 |
| - const selectedLLM = option.llms.find(llm => llm.value === value); |
85 |
| - return selectedLLM?.label ?? option.label; |
86 |
| - } |
87 |
| - return option.label; |
88 |
| - }} |
| 69 | + value={options.flatMap(o => o.options).find(o => o.value === value)} |
89 | 70 | onChange={option => {
|
90 |
| - // When you click a provider, select the first LLM |
91 |
| - if (option && 'value' in option) { |
92 |
| - const selectedOption = option as ProviderOption; |
93 |
| - if (selectedOption.llms.length > 0) { |
94 |
| - const llm = selectedOption.llms[0]; |
95 |
| - onChange(llm.value, llm.max_tokens); |
96 |
| - } |
| 71 | + if (option) { |
| 72 | + const maxTokens = |
| 73 | + LLM_MAX_TOKENS[option.value as LLMMaxTokensKey]?.max_tokens || 0; |
| 74 | + onChange(option.value as LLMMaxTokensKey, maxTokens); |
97 | 75 | }
|
98 | 76 | }}
|
99 |
| - options={allOptions} |
100 |
| - maxMenuHeight={500} |
101 |
| - components={{ |
102 |
| - Option: props => ( |
103 |
| - <CustomOption |
104 |
| - {...props} |
105 |
| - onChange={onChange} |
106 |
| - entity={entity} |
107 |
| - project={project} |
108 |
| - isAdmin={isAdmin} |
109 |
| - /> |
110 |
| - ), |
111 |
| - }} |
| 77 | + options={options} |
112 | 78 | size="medium"
|
113 | 79 | isSearchable
|
114 | 80 | filterOption={(option, inputValue) => {
|
115 | 81 | const searchTerm = inputValue.toLowerCase();
|
116 |
| - const label = |
117 |
| - typeof option.data.label === 'string' ? option.data.label : ''; |
118 | 82 | return (
|
119 |
| - label.toLowerCase().includes(searchTerm) || |
120 |
| - option.data.llms.some(llm => |
121 |
| - llm.label.toLowerCase().includes(searchTerm) |
122 |
| - ) |
| 83 | + option.data.provider_label.toLowerCase().includes(searchTerm) || |
| 84 | + option.data.label.toLowerCase().includes(searchTerm) || |
| 85 | + option.data.value.toLowerCase().includes(searchTerm) |
123 | 86 | );
|
124 | 87 | }}
|
125 | 88 | />
|
|
0 commit comments