diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index 912bd117..c6676c27 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -2,13 +2,13 @@ name: E2E Tests # suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292 env: - JUPYTER_PLATFORM_DIRS: '1' + JUPYTER_PLATFORM_DIRS: "1" on: push: branches: main pull_request: - branches: '*' + branches: "*" jobs: e2e-tests: @@ -41,17 +41,17 @@ jobs: ${{ github.workspace }}/pw-browsers key: ${{ runner.os }}-${{ hashFiles('packages/jupyter-ai/ui-tests/yarn.lock') }} - - name: Install browser + - name: Install Chromium working-directory: packages/jupyter-ai/ui-tests run: jlpm install-chromium - - name: Execute e2e tests + - name: Run E2E tests working-directory: packages/jupyter-ai/ui-tests run: jlpm test - - name: Upload Playwright Test report + - name: Upload Playwright test report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: jupyter-ai-playwright-tests-linux path: | diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 518cd437..70c270cd 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,4 +1,4 @@ -name: Python Unit Tests +name: Python Tests # suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292 env: @@ -12,19 +12,52 @@ on: jobs: unit-tests: - name: Linux + name: Unit tests (Python ${{ matrix.python-version }}, ${{ matrix.dependency-type }} dependencies) runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - dependency-type: minimum + python-version: "3.9" + - dependency-type: standard + python-version: "3.12" steps: - name: Checkout uses: actions/checkout@v4 - name: Base Setup uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + with: + python_version: ${{ matrix.python-version }} + dependency_type: ${{ matrix.dependency-type }} - name: Install extension dependencies and build the extension run: ./scripts/install.sh + - name: List installed versions + run: pip list + - name: Execute unit tests run: | set -eux pytest -vv -r ap --cov jupyter_ai + + typing-tests: + name: Typing test + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Base Setup + uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + + - name: Install extension dependencies and build the extension + run: ./scripts/install.sh + + - name: Run mypy + run: | + set -eux + mypy --version + mypy packages/jupyter-ai diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 08b9fe09..c4962313 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: check-case-conflict @@ -18,7 +18,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.10.0 hooks: - id: black @@ -30,13 +30,13 @@ repos: files: \.py$ - repo: https://github.com/asottile/pyupgrade - rev: v3.16.0 + rev: v3.19.0 hooks: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/pycqa/flake8 - rev: 7.1.0 + rev: 7.1.1 hooks: - id: flake8 additional_dependencies: @@ -48,7 +48,7 @@ repos: stages: [manual] - repo: https://github.com/sirosen/check-jsonschema - rev: 0.29.0 + rev: 0.30.0 hooks: - id: check-jsonschema name: "Check GitHub Workflows" diff --git a/CHANGELOG.md b/CHANGELOG.md index 55a13dc0..3a1951d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,309 @@ +## 2.28.3 + +This release notably fixes a major bug with updated model fields not being used until after a server restart, and fixes a bug with Ollama in the chat. Thank you for your patience as we continue to improve Jupyter AI! 🤗 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.28.2...37558cd9233c1971b8793138bd334bc648a69125)) + +### Enhancements made + +- Removes outdated OpenAI models and adds new ones [#1127](https://github.com/jupyterlab/jupyter-ai/pull/1127) ([@srdas](https://github.com/srdas)) + +### Bugs fixed + +- Fix install step in CI [#1139](https://github.com/jupyterlab/jupyter-ai/pull/1139) ([@dlqqq](https://github.com/dlqqq)) +- Update completion model fields immediately on save [#1137](https://github.com/jupyterlab/jupyter-ai/pull/1137) ([@dlqqq](https://github.com/dlqqq)) +- Fix JSON serialization error in Ollama models [#1129](https://github.com/jupyterlab/jupyter-ai/pull/1129) ([@JanusChoi](https://github.com/JanusChoi)) +- Update model fields immediately on save [#1125](https://github.com/jupyterlab/jupyter-ai/pull/1125) ([@dlqqq](https://github.com/dlqqq)) +- Downgrade spurious 'error' logs [#1119](https://github.com/jupyterlab/jupyter-ai/pull/1119) ([@ctcjab](https://github.com/ctcjab)) + +### Maintenance and upkeep improvements + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-11-18&to=2024-12-05&type=c)) + +[@ctcjab](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Actcjab+updated%3A2024-11-18..2024-12-05&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-11-18..2024-12-05&type=Issues) | [@JanusChoi](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3AJanusChoi+updated%3A2024-11-18..2024-12-05&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-11-18..2024-12-05&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-11-18..2024-12-05&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-11-18..2024-12-05&type=Issues) + + + +## 2.28.2 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.28.1...cbfed31bc3ee19c02a12c28a618af99ab10206be)) + +### Bugs fixed + +- Bump LangChain minimum versions [#1109](https://github.com/jupyterlab/jupyter-ai/pull/1109) ([@dlqqq](https://github.com/dlqqq)) +- Catch error on non plaintext files in `@file` and reply gracefully in chat [#1106](https://github.com/jupyterlab/jupyter-ai/pull/1106) ([@srdas](https://github.com/srdas)) +- Fix rendering of code blocks in JupyterLab 4.3.0+ [#1104](https://github.com/jupyterlab/jupyter-ai/pull/1104) ([@dlqqq](https://github.com/dlqqq)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-11-11&to=2024-11-18&type=c)) + +[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-11-11..2024-11-18&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-11-11..2024-11-18&type=Issues) + +## 2.28.1 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.28.0...d5d604d64552ed89ba200d56da9d3044138feaf5)) + +### Bugs fixed + +- Update `faiss-cpu` version range [#1097](https://github.com/jupyterlab/jupyter-ai/pull/1097) ([@dlqqq](https://github.com/dlqqq)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-11-07&to=2024-11-11&type=c)) + +[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-11-07..2024-11-11&type=Issues) + +## 2.28.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.27.0...6eb6a625712b8a77240ddfffe70050ed9ab83576)) + +### Release summary + +This release notably includes the following changes: + +- Models from the `Anthropic` and `ChatAnthropic` providers are now merged in the config UI, so all Anthropic models are shown in the same place in the "Language model" dropdown. + +- Anthropic Claude v1 LLMs have been removed, as the models are retired and no longer available from the API. + +- The chat system prompt has been updated to encourage the LLM to express dollar quantities in LaTeX, i.e. the LLM should prefer returning `\(\$100\)` instead of `$100`. For the latest LLMs, this generally fixes a [rendering issue](https://github.com/jupyterlab/jupyter-ai/issues/1067) when multiple dollar quantities are given literally in the same sentence. + + - Note that the issue may still persist in older LLMs, which do not respect the system prompt as frequently. + +- `/export` has been fixed to include streamed replies, which were previously omitted. + +- Calling non-chat providers with history has been fixed to behave properly in magics. + +### Enhancements made + +- Remove retired models and add new `Haiku-3.5` model in Anthropic [#1092](https://github.com/jupyterlab/jupyter-ai/pull/1092) ([@srdas](https://github.com/srdas)) +- Reduced padding in cell around code icons in code toolbar [#1072](https://github.com/jupyterlab/jupyter-ai/pull/1072) ([@srdas](https://github.com/srdas)) +- Merge Anthropic language model providers [#1069](https://github.com/jupyterlab/jupyter-ai/pull/1069) ([@srdas](https://github.com/srdas)) +- Add examples of using Fields and EnvAuthStrategy to developer documentation [#1056](https://github.com/jupyterlab/jupyter-ai/pull/1056) ([@alanmeeson](https://github.com/alanmeeson)) + +### Bugs fixed + +- Continue to allow `$` symbols to delimit inline math in human messages [#1094](https://github.com/jupyterlab/jupyter-ai/pull/1094) ([@dlqqq](https://github.com/dlqqq)) +- Fix `/export` by including streamed agent messages [#1077](https://github.com/jupyterlab/jupyter-ai/pull/1077) ([@mcavdar](https://github.com/mcavdar)) +- Fix magic commands when using non-chat providers w/ history [#1075](https://github.com/jupyterlab/jupyter-ai/pull/1075) ([@alanmeeson](https://github.com/alanmeeson)) +- Allow `$` to literally denote quantities of USD in chat [#1068](https://github.com/jupyterlab/jupyter-ai/pull/1068) ([@dlqqq](https://github.com/dlqqq)) + +### Documentation improvements + +- Improve installation documentation and clarify provider dependencies [#1087](https://github.com/jupyterlab/jupyter-ai/pull/1087) ([@srdas](https://github.com/srdas)) +- Added Ollama to the providers table in user docs [#1064](https://github.com/jupyterlab/jupyter-ai/pull/1064) ([@srdas](https://github.com/srdas)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-10-29&to=2024-11-07&type=c)) + +[@alanmeeson](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aalanmeeson+updated%3A2024-10-29..2024-11-07&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-10-29..2024-11-07&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-10-29..2024-11-07&type=Issues) | [@mcavdar](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amcavdar+updated%3A2024-10-29..2024-11-07&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-10-29..2024-11-07&type=Issues) + +## 2.27.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.26.0...b621aa476afc3af9e9b3dd398108b7ed9aaf7bf6)) + +### Enhancements made + +- Added new Anthropic Sonnet3.5 v2 models [#1049](https://github.com/jupyterlab/jupyter-ai/pull/1049) ([@srdas](https://github.com/srdas)) +- Implement streaming for `/fix` [#1048](https://github.com/jupyterlab/jupyter-ai/pull/1048) ([@srdas](https://github.com/srdas)) + +### Documentation improvements + +- Added Developer documentation for streaming responses [#1051](https://github.com/jupyterlab/jupyter-ai/pull/1051) ([@srdas](https://github.com/srdas)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-10-21&to=2024-10-29&type=c)) + +[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-10-21..2024-10-29&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-10-21..2024-10-29&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-10-21..2024-10-29&type=Issues) + +## 2.26.0 + +This release notably includes the addition of a "Stop streaming" button, which takes over the "Send" button when a reply is streaming and the chat input is empty. While Jupyternaut is streaming a reply to a user, the user has the option to click the "Stop streaming" button to interrupt Jupyternaut and stop it from streaming further. Thank you @krassowski for contributing this feature! 🎉 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.25.0...838dfa9cdcbd8fc0373e3c056c677b016531a68c)) + +### Enhancements made + +- Support Quarto Markdown in `/learn` [#1047](https://github.com/jupyterlab/jupyter-ai/pull/1047) ([@dlqqq](https://github.com/dlqqq)) +- Update requirements contributors doc [#1045](https://github.com/jupyterlab/jupyter-ai/pull/1045) ([@JasonWeill](https://github.com/JasonWeill)) +- Remove clear_message_ids from RootChatHandler [#1042](https://github.com/jupyterlab/jupyter-ai/pull/1042) ([@michaelchia](https://github.com/michaelchia)) +- Migrate streaming logic to `BaseChatHandler` [#1039](https://github.com/jupyterlab/jupyter-ai/pull/1039) ([@dlqqq](https://github.com/dlqqq)) +- Unify message clearing & broadcast logic [#1038](https://github.com/jupyterlab/jupyter-ai/pull/1038) ([@dlqqq](https://github.com/dlqqq)) +- Learn from JSON files [#1024](https://github.com/jupyterlab/jupyter-ai/pull/1024) ([@jlsajfj](https://github.com/jlsajfj)) +- Allow users to stop message streaming [#1022](https://github.com/jupyterlab/jupyter-ai/pull/1022) ([@krassowski](https://github.com/krassowski)) + +### Bugs fixed + +- Always use `username` from `IdentityProvider` [#1034](https://github.com/jupyterlab/jupyter-ai/pull/1034) ([@krassowski](https://github.com/krassowski)) + +### Maintenance and upkeep improvements + +- Support `jupyter-collaboration` v3 [#1035](https://github.com/jupyterlab/jupyter-ai/pull/1035) ([@krassowski](https://github.com/krassowski)) +- Test Python 3.9 and 3.12 on CI, test minimum dependencies [#1029](https://github.com/jupyterlab/jupyter-ai/pull/1029) ([@krassowski](https://github.com/krassowski)) + +### Documentation improvements + +- Update requirements contributors doc [#1045](https://github.com/jupyterlab/jupyter-ai/pull/1045) ([@JasonWeill](https://github.com/JasonWeill)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-10-07&to=2024-10-21&type=c)) + +[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-10-07..2024-10-21&type=Issues) | [@JasonWeill](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3AJasonWeill+updated%3A2024-10-07..2024-10-21&type=Issues) | [@jlsajfj](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Ajlsajfj+updated%3A2024-10-07..2024-10-21&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-10-07..2024-10-21&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-10-07..2024-10-21&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-10-07..2024-10-21&type=Issues) + +## 2.25.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.24.1...097dbe48722e255173c6504e6de835c297c553ab)) + +### Enhancements made + +- Export context hooks from NPM package entry point [#1020](https://github.com/jupyterlab/jupyter-ai/pull/1020) ([@dlqqq](https://github.com/dlqqq)) +- Add support for optional telemetry plugin [#1018](https://github.com/jupyterlab/jupyter-ai/pull/1018) ([@dlqqq](https://github.com/dlqqq)) +- Add back history and reset subcommand in magics [#997](https://github.com/jupyterlab/jupyter-ai/pull/997) ([@akaihola](https://github.com/akaihola)) + +### Maintenance and upkeep improvements + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-10-04&to=2024-10-07&type=c)) + +[@akaihola](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aakaihola+updated%3A2024-10-04..2024-10-07&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-10-04..2024-10-07&type=Issues) | [@jtpio](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Ajtpio+updated%3A2024-10-04..2024-10-07&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-10-04..2024-10-07&type=Issues) + +## 2.24.1 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.24.0...f3692d94dfbb4837714888d0e69f6c7ca3ba547c)) + +### Enhancements made + +- Make path argument required on /learn [#1012](https://github.com/jupyterlab/jupyter-ai/pull/1012) ([@andrewfulton9](https://github.com/andrewfulton9)) + +### Bugs fixed + +- Export tokens from `lib/index.js` [#1019](https://github.com/jupyterlab/jupyter-ai/pull/1019) ([@dlqqq](https://github.com/dlqqq)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-09-26&to=2024-10-04&type=c)) + +[@andrewfulton9](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aandrewfulton9+updated%3A2024-09-26..2024-10-04&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-09-26..2024-10-04&type=Issues) | [@hockeymomonow](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Ahockeymomonow+updated%3A2024-09-26..2024-10-04&type=Issues) + +## 2.24.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.23.0...e6ec9e9ba4336168ce7874c09d07157be8bbff5a)) + +This release notably introduces a new **context command** `@file:` to the chat UI, which includes the content of the target file with your prompt when sent. This allows you to ask questions like: + +- `What does @file:src/components/ActionButton.tsx do?` +- `Can you refactor @file:src/index.ts to use async/await syntax?` +- `How do I add an optional dependency to @file:pyproject.toml?` + +The context command feature also includes an autocomplete menu UI to help navigate your filesystem with fewer keystrokes. + +Thank you @michaelchia for developing this feature! + +### Enhancements made + +- Migrate to `ChatOllama` base class in Ollama provider [#1015](https://github.com/jupyterlab/jupyter-ai/pull/1015) ([@srdas](https://github.com/srdas)) +- Add `metadata` field to agent messages [#1013](https://github.com/jupyterlab/jupyter-ai/pull/1013) ([@dlqqq](https://github.com/dlqqq)) +- Add OpenRouter support [#996](https://github.com/jupyterlab/jupyter-ai/pull/996) ([@akaihola](https://github.com/akaihola)) +- Framework for adding context to LLM prompt [#993](https://github.com/jupyterlab/jupyter-ai/pull/993) ([@michaelchia](https://github.com/michaelchia)) +- Adds unix shell-style wildcard matching to `/learn` [#989](https://github.com/jupyterlab/jupyter-ai/pull/989) ([@andrewfulton9](https://github.com/andrewfulton9)) + +### Bugs fixed + +- Run mypy on CI, fix or ignore typing issues [#987](https://github.com/jupyterlab/jupyter-ai/pull/987) ([@krassowski](https://github.com/krassowski)) + +### Maintenance and upkeep improvements + +- Upgrade to `actions/upload-artifact@v4` in workflows [#992](https://github.com/jupyterlab/jupyter-ai/pull/992) ([@dlqqq](https://github.com/dlqqq)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-09-11&to=2024-09-26&type=c)) + +[@akaihola](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aakaihola+updated%3A2024-09-11..2024-09-26&type=Issues) | [@andrewfulton9](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aandrewfulton9+updated%3A2024-09-11..2024-09-26&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-09-11..2024-09-26&type=Issues) | [@ellisonbg](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aellisonbg+updated%3A2024-09-11..2024-09-26&type=Issues) | [@hockeymomonow](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Ahockeymomonow+updated%3A2024-09-11..2024-09-26&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-09-11..2024-09-26&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-09-11..2024-09-26&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-09-11..2024-09-26&type=Issues) + +## 2.23.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.22.0...83cbd8ea240f1429766c417bada3bfb39afc4462)) + +### Enhancements made + +- Allow unlimited LLM memory through traitlets configuration [#986](https://github.com/jupyterlab/jupyter-ai/pull/986) ([@krassowski](https://github.com/krassowski)) +- Allow to disable automatic inline completions [#981](https://github.com/jupyterlab/jupyter-ai/pull/981) ([@krassowski](https://github.com/krassowski)) +- Add ability to delete messages + start new chat session [#951](https://github.com/jupyterlab/jupyter-ai/pull/951) ([@michaelchia](https://github.com/michaelchia)) + +### Bugs fixed + +- Fix `RunnableWithMessageHistory` import [#980](https://github.com/jupyterlab/jupyter-ai/pull/980) ([@krassowski](https://github.com/krassowski)) +- Fix sort messages [#975](https://github.com/jupyterlab/jupyter-ai/pull/975) ([@michaelchia](https://github.com/michaelchia)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-08-29&to=2024-09-11&type=c)) + +[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-08-29..2024-09-11&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-08-29..2024-09-11&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-08-29..2024-09-11&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-08-29..2024-09-11&type=Issues) + +## 2.22.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.21.0...79158abf7044605c5776e205dd171fe87fb64142)) + +### Enhancements made + +- Add 'Generative AI' submenu [#971](https://github.com/jupyterlab/jupyter-ai/pull/971) ([@dlqqq](https://github.com/dlqqq)) +- Add Gemini 1.5 to the list of chat options [#964](https://github.com/jupyterlab/jupyter-ai/pull/964) ([@trducng](https://github.com/trducng)) +- Allow configuring a default model for cell magics (and line error magic) [#962](https://github.com/jupyterlab/jupyter-ai/pull/962) ([@krassowski](https://github.com/krassowski)) +- Make chat memory size traitlet configurable + /clear to reset memory [#943](https://github.com/jupyterlab/jupyter-ai/pull/943) ([@michaelchia](https://github.com/michaelchia)) + +### Maintenance and upkeep improvements + +### Documentation improvements + +- Update documentation to cover installation of all dependencies [#961](https://github.com/jupyterlab/jupyter-ai/pull/961) ([@srdas](https://github.com/srdas)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-08-19&to=2024-08-29&type=c)) + +[@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-08-19..2024-08-29&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-08-19..2024-08-29&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-08-19..2024-08-29&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-08-19..2024-08-29&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-08-19..2024-08-29&type=Issues) | [@trducng](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Atrducng+updated%3A2024-08-19..2024-08-29&type=Issues) + +## 2.21.0 + +([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.20.0...83e368b9d04904f9eb0ad4b1f0759bf3b7bbc93d)) + +### Enhancements made + +- Add optional configurable message footer [#942](https://github.com/jupyterlab/jupyter-ai/pull/942) ([@dlqqq](https://github.com/dlqqq)) +- Add support for Azure Open AI Embeddings to Jupyter AI [#940](https://github.com/jupyterlab/jupyter-ai/pull/940) ([@gsrikant7](https://github.com/gsrikant7)) +- Make help message template configurable [#938](https://github.com/jupyterlab/jupyter-ai/pull/938) ([@dlqqq](https://github.com/dlqqq)) +- Add latest Bedrock models (Titan, Llama 3.1 405b, Mistral Large 2, Jamba Instruct) [#923](https://github.com/jupyterlab/jupyter-ai/pull/923) ([@gabrielkoo](https://github.com/gabrielkoo)) +- Add support for custom/provisioned models in Bedrock [#922](https://github.com/jupyterlab/jupyter-ai/pull/922) ([@dlqqq](https://github.com/dlqqq)) +- Settings section improvement [#918](https://github.com/jupyterlab/jupyter-ai/pull/918) ([@andrewfulton9](https://github.com/andrewfulton9)) + +### Bugs fixed + +- Bind reject method to promise, improve typing [#949](https://github.com/jupyterlab/jupyter-ai/pull/949) ([@krassowski](https://github.com/krassowski)) +- Fix sending empty input with Enter [#946](https://github.com/jupyterlab/jupyter-ai/pull/946) ([@michaelchia](https://github.com/michaelchia)) +- Fix saving chat settings [#935](https://github.com/jupyterlab/jupyter-ai/pull/935) ([@dlqqq](https://github.com/dlqqq)) + +### Documentation improvements + +- Add documentation on how to use Amazon Bedrock [#936](https://github.com/jupyterlab/jupyter-ai/pull/936) ([@srdas](https://github.com/srdas)) +- Update copyright template [#925](https://github.com/jupyterlab/jupyter-ai/pull/925) ([@srdas](https://github.com/srdas)) + +### Contributors to this release + +([GitHub contributors page for this release](https://github.com/jupyterlab/jupyter-ai/graphs/contributors?from=2024-07-29&to=2024-08-19&type=c)) + +[@andrewfulton9](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aandrewfulton9+updated%3A2024-07-29..2024-08-19&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-07-29..2024-08-19&type=Issues) | [@gabrielkoo](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Agabrielkoo+updated%3A2024-07-29..2024-08-19&type=Issues) | [@gsrikant7](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Agsrikant7+updated%3A2024-07-29..2024-08-19&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-07-29..2024-08-19&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-07-29..2024-08-19&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-07-29..2024-08-19&type=Issues) + ## 2.20.0 ([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.19.1...79d66daefa2dc2f8a47b55712d7b812ee23acda4)) @@ -19,8 +322,6 @@ [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-07-22..2024-07-29&type=Issues) | [@JasonWeill](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3AJasonWeill+updated%3A2024-07-22..2024-07-29&type=Issues) | [@srdas](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Asrdas+updated%3A2024-07-22..2024-07-29&type=Issues) - - ## 2.19.1 ([Full Changelog](https://github.com/jupyterlab/jupyter-ai/compare/@jupyter-ai/core@2.19.0...cfcb3c6df4e795d877e6d0967a22ecfb880b3be3)) @@ -48,8 +349,6 @@ [@andrewfulton9](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Aandrewfulton9+updated%3A2024-07-15..2024-07-22&type=Issues) | [@dlqqq](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Adlqqq+updated%3A2024-07-15..2024-07-22&type=Issues) | [@jtpio](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Ajtpio+updated%3A2024-07-15..2024-07-22&type=Issues) | [@krassowski](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Akrassowski+updated%3A2024-07-15..2024-07-22&type=Issues) | [@michaelchia](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Amichaelchia+updated%3A2024-07-15..2024-07-22&type=Issues) | [@pre-commit-ci](https://github.com/search?q=repo%3Ajupyterlab%2Fjupyter-ai+involves%3Apre-commit-ci+updated%3A2024-07-15..2024-07-22&type=Issues) - - ## 2.19.0 This is a significant release that implements LLM response streaming in Jupyter AI along with several other enhancements & fixes listed below. Special thanks to @krassowski for his generous contributions this release! diff --git a/README.md b/README.md index 5f2a2fa6..eb90510d 100644 --- a/README.md +++ b/README.md @@ -68,26 +68,80 @@ Below is a simplified overview of the installation and usage process. See [our official documentation](https://jupyter-ai.readthedocs.io/en/latest/users/index.html) for details on installing and using Jupyter AI. -### With pip +We offer 3 different ways to install Jupyter AI. You can read through each +section to pick the installation method that works best for you. + +1. Quick installation via `pip` (recommended) +2. Minimal installation via `pip` +3. Minimal installation via `conda` + +### Quick installation via `pip` (recommended) If you want to install both the `%%ai` magic and the JupyterLab extension, you can run: - $ pip install jupyter-ai + $ pip install jupyter-ai[all] + +Then, restart JupyterLab. This will install every optional dependency, which +provides access to all models currently supported by `jupyter-ai`. -If you are not using JupyterLab and you only want to install the Jupyter AI `%%ai` magic, you can run: +If you are not using JupyterLab and you only want to install the Jupyter AI +`%%ai` magic, you can run: - $ pip install jupyter-ai-magics + $ pip install jupyter-ai-magics[all] +`jupyter-ai` depends on `jupyter-ai-magics`, so installing `jupyter-ai` +automatically installs `jupyter-ai-magics`. -### With conda +### Minimal installation via `pip` + +Most model providers in Jupyter AI require a specific dependency to be installed +before they are available for use. These are called _provider dependencies_. +Provider dependencies are optional to Jupyter AI, meaning that Jupyter AI can be +installed with or without any provider dependencies installed. If a provider +requires a dependency that is not installed, its models are not listed in the +user interface which allows you to select a language model. + +To perform a minimal installation via `pip` without any provider dependencies, +omit the `[all]` optional dependency group from the package name: + +``` +pip install jupyter-ai +``` + +By selectively installing provider dependencies, you can control which models +are available in your Jupyter AI environment. + +For example, to install Jupyter AI with only added support for Anthropic models, run: + +``` +pip install jupyter-ai langchain-anthropic +``` + +For more information on model providers and which dependencies they require, see +[the model provider table](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#model-providers). + +### Minimal installation via `conda` As an alternative to using `pip`, you can install `jupyter-ai` using [Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) -from the `conda-forge` channel, using one of the following two commands: +from the `conda-forge` channel: - $ conda install -c conda-forge jupyter-ai # or, $ conda install conda-forge::jupyter-ai +Most model providers in Jupyter AI require a specific _provider dependency_ to +be installed before they are available for use. Provider dependencies are +not installed when installing `jupyter-ai` from Conda Forge, and should be +installed separately as needed. + +For example, to install Jupyter AI with only added support for OpenAI models, run: + +``` +conda install conda-forge::jupyter-ai conda-forge::langchain-openai +``` + +For more information on model providers and which dependencies they require, see +[the model provider table](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#model-providers). + ## The `%%ai` magic command The `%%ai` magic works anywhere the IPython kernel runs, including JupyterLab, Jupyter Notebook, Google Colab, and Visual Studio Code. diff --git a/docs/source/_static/bedrock-cross-region-inference.png b/docs/source/_static/bedrock-cross-region-inference.png new file mode 100644 index 00000000..15d017db Binary files /dev/null and b/docs/source/_static/bedrock-cross-region-inference.png differ diff --git a/docs/source/_static/chat-icon-left-tab-bar-custom.png b/docs/source/_static/chat-icon-left-tab-bar-custom.png new file mode 100644 index 00000000..b9ec3e71 Binary files /dev/null and b/docs/source/_static/chat-icon-left-tab-bar-custom.png differ diff --git a/docs/source/_static/chat-icon-left-tab-bar.png b/docs/source/_static/chat-icon-left-tab-bar.png index 4ad54b6c..07f21b52 100644 Binary files a/docs/source/_static/chat-icon-left-tab-bar.png and b/docs/source/_static/chat-icon-left-tab-bar.png differ diff --git a/docs/source/contributors/index.md b/docs/source/contributors/index.md index 52d1a545..c6fa7480 100644 --- a/docs/source/contributors/index.md +++ b/docs/source/contributors/index.md @@ -20,11 +20,11 @@ Issues and pull requests that violate the above principles may be declined. If y You can develop Jupyter AI on any system that can run a supported Python version up to and including 3.12, including recent Windows, macOS, and Linux versions. -Each Jupyter AI major version works with only one major version of JupyterLab. Jupyter AI 1.x supports JupyterLab 3.x, and Jupyter AI 2.x supports JupyterLab 4.x. +You should have the newest supported version of JupyterLab installed. -We highly recommend that you install [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) to start developing on Jupyter AI, especially if you are developing on macOS on an Apple Silicon-based Mac (M1, M1 Pro, M2, etc.). +We highly recommend that you install [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) to start contributing to Jupyter AI, especially if you are developing on macOS on an Apple Silicon-based Mac (M1, M1 Pro, M2, etc.). -You will need Node.js 18 to use Jupyter AI. Node.js 18.16.0 is known to work. +You will need [a supported version of node.js](https://github.com/nodejs/release#release-schedule) to use Jupyter AI. :::{warning} :name: node-18-15 diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md index 644dc0a4..af4cfcea 100644 --- a/docs/source/developers/index.md +++ b/docs/source/developers/index.md @@ -121,6 +121,61 @@ your new provider's `id`: [LLM]: https://api.python.langchain.com/en/v0.0.339/llms/langchain.llms.base.LLM.html#langchain.llms.base.LLM [BaseChatModel]: https://api.python.langchain.com/en/v0.0.339/chat_models/langchain.chat_models.base.BaseChatModel.html +### API keys and fields for custom providers + +You can add handle authentication via API keys, and configuration with +custom parameters using an auth strategy and fields as shown in the example +below. + +```python +from typing import ClassVar, List +from jupyter_ai_magics import BaseProvider +from jupyter_ai_magics.providers import EnvAuthStrategy, Field, TextField, MultilineTextField +from langchain_community.llms import FakeListLLM + + +class MyProvider(BaseProvider, FakeListLLM): + id = "my_provider" + name = "My Provider" + model_id_key = "model" + models = [ + "model_a", + "model_b" + ] + + auth_strategy = EnvAuthStrategy( + name="MY_API_KEY", keyword_param="my_api_key_param" + ) + + fields: ClassVar[List[Field]] = [ + TextField(key="my_llm_parameter", label="The name for my_llm_parameter to show in the UI"), + MultilineTextField(key="custom_config", label="Custom Json Config", format="json"), + ] + + def __init__(self, **kwargs): + model = kwargs.get("model_id") + kwargs["responses"] = ( + ["This is a response from model 'a'"] + if model == "model_a" else + ["This is a response from model 'b'"] + ) + super().__init__(**kwargs) +``` + +The `auth_strategy` handles specifying API keys for providers and models. +The example shows the `EnvAuthStrategy` which takes the API key from the +environment variable with the name specified in `name` and be provided to the +model's `__init__` as a kwarg with the name specified in `keyword_param`. +This will also cause a field to be present in the configuration UI with the +`name` of the environment variable as the label. + +Further configuration can be handled adding `fields` into the settings +dialogue for your custom model by specifying a list of fields as shown in +the example. These will be passed into the `__init__` as kwargs, with the +key specified by the key in the field object. The label specified in the field +object determines the text shown in the configuration section of the user +interface. + ### Custom embeddings providers To provide a custom embeddings model an embeddings providers should be defined implementing the API of `jupyter-ai`'s `BaseEmbeddingsProvider` and of `langchain`'s [`Embeddings`][Embeddings] abstract class. @@ -392,6 +447,120 @@ custom = "custom_package:CustomChatHandler" Then, install your package so that Jupyter AI adds custom chat handlers to the existing chat handlers. +## Streaming output from custom slash commands + +Jupyter AI supports streaming output in the chat session. When a response is +streamed to the user, the user can watch the response being constructed in +real-time, which offers a visually pleasing user experience. Custom slash +commands can stream responses in chat by invoking the `stream_reply()` method, +provided by the `BaseChatHandler` class that custom slash commands inherit from. +Custom slash commands should always use `self.stream_reply()` to stream +responses, as it provides support for stopping the response stream from the UI. + +To use `stream_reply()`, your slash command must bind a LangChain +[`Runnable`](https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.base.Runnable.html) +to `self.llm_chain` in the `create_llm_chain()` method. Runnables can be created +by using LangChain Expression Language (LCEL). See below for an example +definition of `create_llm_chain()`, sourced from our implementation of `/fix` in +`fix.py`: + +```python +def create_llm_chain( + self, provider: Type[BaseProvider], provider_params: Dict[str, str] + ): + unified_parameters = { + "verbose": True, + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + llm = provider(**unified_parameters) + self.llm = llm + prompt_template = FIX_PROMPT_TEMPLATE + self.prompt_template = prompt_template + + runnable = prompt_template | llm # type:ignore + self.llm_chain = runnable +``` + +Once your chat handler binds a Runnable to `self.llm_chain` in +`self.create_llm_chain()`, you can define `process_message()` to invoke +`self.stream_reply()`, which streams a reply back to the user using +`self.llm_chain.astream()`. +`self.stream_reply()` has two required arguments: + +- `input`: An input to your LangChain Runnable. This is usually a dictionary +whose keys are input variables specified in your prompt template, but may be +just a string if your Runnable does not use a prompt template. + +- `message`: The `HumanChatMessage` being replied to. + +An example of `process_message()` can also be sourced from our implementation of `/fix`: + +```python +async def process_message(self, message: HumanChatMessage): + if not (message.selection and message.selection.type == "cell-with-error"): + self.reply( + "`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.", + message, + ) + return + + # hint type of selection + selection: CellWithErrorSelection = message.selection + + # parse additional instructions specified after `/fix` + extra_instructions = message.prompt[4:].strip() or "None." + + self.get_llm_chain() + assert self.llm_chain + + inputs = { + "extra_instructions": extra_instructions, + "cell_content": selection.source, + "traceback": selection.error.traceback, + "error_name": selection.error.name, + "error_value": selection.error.value, + } + await self.stream_reply(inputs, message, pending_msg="Analyzing error") +``` + +The last line of `process_message` above calls `stream_reply` in `base.py`. +Note that a custom pending message may also be passed. +The `stream_reply` function leverages the LCEL Runnable. +The function takes in the input, human message, and optional +pending message strings and configuration, as shown below: + +```python +async def stream_reply( + self, + input: Input, + human_msg: HumanChatMessage, + pending_msg="Generating response", + config: Optional[RunnableConfig] = None, + ): + """ + Streams a reply to a human message by invoking + `self.llm_chain.astream()`. A LangChain `Runnable` instance must be + bound to `self.llm_chain` before invoking this method. + + Arguments + --------- + - `input`: The input to your runnable. The type of `input` depends on + the runnable in `self.llm_chain`, but is usually a dictionary whose keys + refer to input variables in your prompt template. + + - `human_msg`: The `HumanChatMessage` being replied to. + + - `config` (optional): A `RunnableConfig` object that specifies + additional configuration when streaming from the runnable. + + - `pending_msg` (optional): Changes the default pending message from + "Generating response". + """ + assert self.llm_chain + assert isinstance(self.llm_chain, Runnable) +``` + ## Custom message footer You can provide a custom message footer that will be rendered under each message diff --git a/docs/source/users/bedrock.md b/docs/source/users/bedrock.md index 558bf93c..8bcce8a5 100644 --- a/docs/source/users/bedrock.md +++ b/docs/source/users/bedrock.md @@ -1,8 +1,10 @@ # Using Amazon Bedrock with Jupyter AI -[(Return to Chat Interface page for Bedrock)](index.md#amazon-bedrock-usage) +[(Return to the Chat Interface page)](index.md#amazon-bedrock-usage) -Bedrock supports many language model providers such as AI21 Labs, Amazon, Anthropic, Cohere, Meta, and Mistral AI. To use the base models from any supported provider make sure to enable them in Amazon Bedrock by using the AWS console. Go to Amazon Bedrock and select `Model Access` as shown here: +Bedrock supports many language model providers such as AI21 Labs, Amazon, Anthropic, Cohere, Meta, and Mistral AI. To use the base models from any supported provider make sure to enable them in Amazon Bedrock by using the AWS console. You should also select embedding models in Bedrock in addition to language completion models if you intend to use retrieval augmented generation (RAG) on your documents. + +Go to Amazon Bedrock and select `Model Access` as shown here: Screenshot of the Jupyter AI chat panel where the base language model and embedding model is selected. -Bedrock also allows custom models to be trained from scratch or fine-tuned from a base model. Jupyter AI enables a custom model to be called in the chat panel using its `arn` (Amazon Resource Name). As with custom models, you can also call a base model by its `model id` or its `arn`. An example of using a base model with its `model id` through the custom model interface is shown below: +If your provider requires an API key, please enter it in the box that will show for that provider. Make sure to click on `Save Changes` to ensure that the inputs have been saved. + +Bedrock also allows custom models to be trained from scratch or fine-tuned from a base model. Jupyter AI enables a custom model to be called in the chat panel using its `arn` (Amazon Resource Name). A fine-tuned model will have your 12-digit customer number in the ARN: + +Screenshot of the Jupyter AI chat panel where the custom model is selected using model arn. + +As with custom models, you can also call a base model by its `model id` or its `arn`. An example of using a base model with its `model id` through the custom model interface is shown below: Screenshot of the Jupyter AI chat panel where the base model is selected using its ARN. +## Fine-tuning in Bedrock + To train a custom model in Amazon Bedrock, select `Custom models` in the Bedrock console as shown below, and then you may customize a base model by fine-tuning it or continuing to pre-train it: Screenshot of Bedrock cross-region inference usage. + +## Summary + +1. Bedrock Base models: All available models will already be available in the drop down model list. The above interface also allows use of base model IDs or ARNs, though this is unnecessary as they are in the dropdown list. +2. Bedrock Custom models: If you have fine-tuned a Bedrock base model you may use the ARN for this custom model. Make sure to enter the correct provider information, such as `amazon`, `anthropic`, `cohere`, `meta`, `mistral` (always in lower case). +3. Provisioned Models: These are models that run on dedicated endpoints. Users can purchase Provisioned Throughput Model Units to get faster throughput. These may be base or custom models. Enter the ARN for these models in the Model ID field. +4. Cross-region Inference: Use the Inference profile ID for the cross-region model instead of the ARN. + +[(Return to the Chat Interface page)](index.md#amazon-bedrock-usage) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 0c920382..ff669f97 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -79,56 +79,96 @@ classes in their code. ## Installation -### Installation via `pip` +### Setup: creating a Jupyter AI environment (recommended) -To install the JupyterLab extension, you can run: +Before installing Jupyter AI, we highly recommend first creating a separate +Conda environment for Jupyter AI. This prevents the installation process from +clobbering Python packages in your existing Python environment. -``` -pip install jupyter-ai -``` +To do so, install +[conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) +and create an environment that uses Python 3.12 and the latest version of +JupyterLab: -The latest major version of `jupyter-ai`, v2, only supports JupyterLab 4. If you -need support for JupyterLab 3, you should install `jupyter-ai` v1 instead: + $ conda create -n jupyter-ai python=3.12 jupyterlab + $ conda activate jupyter-ai -``` -pip install jupyter-ai~=1.0 -``` +You can now choose how to install Jupyter AI. -If you are not using JupyterLab and you only want to install the Jupyter AI `%%ai` magic, you can run: +We offer 3 different ways to install Jupyter AI. You can read through each +section to pick the installation method that works best for you. -``` -$ pip install jupyter-ai-magics -``` +1. Quick installation via `pip` (recommended) +2. Minimal installation via `pip` +3. Minimal installation via `conda` + +### Quick installation via `pip` (recommended) + +If you want to install both the `%%ai` magic and the JupyterLab extension, you can run: + + $ pip install jupyter-ai[all] + +Then, restart JupyterLab. This will install every optional dependency, which +provides access to all models currently supported by `jupyter-ai`. + +If you are not using JupyterLab and you only want to install the Jupyter AI +`%%ai` magic, you can run: + + $ pip install jupyter-ai-magics[all] `jupyter-ai` depends on `jupyter-ai-magics`, so installing `jupyter-ai` automatically installs `jupyter-ai-magics`. -### Installation via `pip` or `conda` in a Conda environment (recommended) +### Minimal installation via `pip` -We highly recommend installing both JupyterLab and Jupyter AI within an isolated -Conda environment to avoid clobbering Python packages in your existing Python -environment. +Most model providers in Jupyter AI require a specific dependency to be installed +before they are available for use. These are called _provider dependencies_. +Provider dependencies are optional to Jupyter AI, meaning that Jupyter AI can be +installed with or without any provider dependencies installed. If a provider +requires a dependency that is not installed, its models are not listed in the +user interface which allows you to select a language model. -First, install -[conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) -and create an environment that uses Python 3.12: +To perform a minimal installation via `pip` without any provider dependencies, +omit the `[all]` optional dependency group from the package name: - $ conda create -n jupyter-ai python=3.12 - $ conda activate jupyter-ai +``` +pip install jupyter-ai +``` + +By selectively installing provider dependencies, you can control which models +are available in your Jupyter AI environment. -Then, use `conda` to install JupyterLab and Jupyter AI in this Conda environment. +For example, to install Jupyter AI with only added support for Anthropic models, run: + +``` +pip install jupyter-ai langchain-anthropic +``` + +For more information on model providers and which dependencies they require, see +[the model provider table](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#model-providers). + +### Minimal installation via `conda` + +As an alternative to using `pip`, you can install `jupyter-ai` using +[Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) +from the `conda-forge` channel: - $ conda install -c conda-forge jupyter-ai # or, $ conda install conda-forge::jupyter-ai -When starting JupyterLab with Jupyter AI, make sure to activate the Conda -environment first: +Most model providers in Jupyter AI require a specific _provider dependency_ to +be installed before they are available for use. Provider dependencies are +not installed when installing `jupyter-ai` from Conda Forge, and should be +installed separately as needed. + +For example, to install Jupyter AI with only added support for OpenAI models, run: ``` -conda activate jupyter-ai -jupyter lab +conda install conda-forge::jupyter-ai conda-forge::langchain-openai ``` +For more information on model providers and which dependencies they require, see +[the model provider table](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#model-providers). + ## Uninstallation If you installed Jupyter AI using `pip`, to remove the extension, run: @@ -168,6 +208,7 @@ Jupyter AI supports the following model providers: | Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` | | MistralAI | `mistralai` | `MISTRAL_API_KEY` | `langchain-mistralai` | | NVIDIA | `nvidia-chat` | `NVIDIA_API_KEY` | `langchain_nvidia_ai_endpoints` | +| Ollama | `ollama` | N/A | `langchain-ollama` | | OpenAI | `openai` | `OPENAI_API_KEY` | `langchain-openai` | | OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `langchain-openai` | | SageMaker endpoint | `sagemaker-endpoint` | N/A | `langchain-aws` | @@ -245,10 +286,31 @@ Before you can use the chat interface, you need to provide your API keys for the alt="Screen shot of the setup interface, showing model selections and key populated" class="screenshot" /> -Once you have set all the necessary keys, click the "back" (left arrow) button in the upper-left corner of the Jupyter AI side panel. The chat interface now appears, and you can ask a question using the message box at the bottom. +Once you have set all the necessary keys, click the "back" (left arrow) button in the upper-left corner of the Jupyter AI side panel. The chat interface now appears, with a help menu of available `/` (slash) commands, and you can ask a question using the message box at the bottom. Screen shot of the initial, blank, chat interface. + +You may customize the template of the chat interface from the default one. The steps are as follows: +1. Create a new `config.py` file in your current directory with the contents you want to see in the help message, by editing the template below: +``` +c.AiExtension.help_message_template = """ +Sup. I'm {persona_name}. This is a sassy custom help message. + +Here's the slash commands you can use. Use 'em or don't... I don't care. + +{slash_commands_list} +""".strip() +``` +2. Start JupyterLab with the following command: +``` +jupyter lab --config=config.py +``` +The new help message will be used instead of the default, as shown below + +Screen shot of the custom chat interface. To compose a message, type it in the text box at the bottom of the chat interface and press ENTER to send it. You can press SHIFT+ENTER to add a new line. (These are the default keybindings; you can change them in the chat settings pane.) Once you have sent a message, you should see a response from Cloudera Copilot, the Jupyter AI chatbot. @@ -272,29 +334,9 @@ The chat backend remembers the last two exchanges in your conversation and passe ### Amazon Bedrock Usage -Jupyter AI enables use of language models hosted on [Amazon Bedrock](https://aws.amazon.com/bedrock/) on AWS. First, ensure that you have authentication to use AWS using the `boto3` SDK with credentials stored in the `default` profile. Guidance on how to do this can be found in the [`boto3` documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). - -For more detailed workflows, see [Using Amazon Bedrock with Jupter AI](bedrock.md). - -Bedrock supports many language model providers such as AI21 Labs, Amazon, Anthropic, Cohere, Meta, and Mistral AI. To use the base models from any supported provider make sure to enable them in Amazon Bedrock by using the AWS console. You should also select embedding models in Bedrock in addition to language completion models if you intend to use retrieval augmented generation (RAG) on your documents. - -You may now select a chosen Bedrock model from the drop-down menu box title `Completion model` in the chat interface. If RAG is going to be used then pick an embedding model that you chose from the Bedrock models as well. An example of these selections is shown below: - -Screenshot of the Jupyter AI chat panel where the base language model and embedding model is selected. - -If your provider requires an API key, please enter it in the box that will show for that provider. Make sure to click on `Save Changes` to ensure that the inputs have been saved. +Jupyter AI enables use of language models hosted on [Amazon Bedrock](https://aws.amazon.com/bedrock/) on AWS. Ensure that you have authentication to use AWS using the `boto3` SDK with credentials stored in the `default` profile. Guidance on how to do this can be found in the [`boto3` documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html). -Bedrock also allows custom models to be trained from scratch or fine-tuned from a base model. Jupyter AI enables a custom model to be called in the chat panel using its `arn` (Amazon Resource Name). The interface is shown below: - -Screenshot of the Jupyter AI chat panel where the custom model is selected using model arn. - -For detailed workflows, see [Using Amazon Bedrock with Jupter AI](bedrock.md). +For details on enabling model access in your AWS account, using cross-region inference, or invoking custom/provisioned models, please see our dedicated documentation page on [using Amazon Bedrock in Jupyter AI](bedrock.md). ### SageMaker endpoints usage @@ -472,6 +514,13 @@ To teach Jupyter AI about a folder full of documentation, for example, run `/lea alt='Screen shot of "/learn docs/" command and a response.' class="screenshot" /> +The `/learn` command also supports unix shell-style wildcard matching. This allows fine-grained file selection for learning. For example, to learn on only notebooks in all directories you can use `/learn **/*.ipynb` and all notebooks within your base (or preferred directory if set) will be indexed, while all other file extensions will be ignored. + +:::{warning} +:name: unix shell-style wildcard matching +Certain patterns may cause `/learn` to run more slowly. For instance `/learn **` may cause directories to be walked multiple times in search of files. +::: + You can then use `/ask` to ask a question specifically about the data that you taught Jupyter AI with `/learn`. :, defaults to None. + """, + config=True, + ) + + max_history = traitlets.Int( + default_value=2, + allow_none=False, + help="""Maximum number of exchanges (user/assistant) to include in the history + when invoking a chat model, defaults to 2. + """, + config=True, + ) + def __init__(self, shell): super().__init__(shell) - self.transcript_openai = [] + self.transcript = [] # suppress warning when using old Anthropic provider warnings.filterwarnings( @@ -258,7 +277,7 @@ def _is_langchain_chain(self, name): if not acceptable_name.match(name): return False - ipython = get_ipython() + ipython = self.shell return name in ipython.user_ns and isinstance(ipython.user_ns[name], LLMChain) # Is this an acceptable name for an alias? @@ -276,7 +295,7 @@ def _validate_name(self, register_name): def _safely_set_target(self, register_name, target): # If target is a string, treat this as an alias to another model. if self._is_langchain_chain(target): - ip = get_ipython() + ip = self.shell self.custom_model_registry[register_name] = ip.user_ns[target] else: # Ensure that the destination is properly formatted @@ -405,7 +424,7 @@ def handle_error(self, args: ErrorArgs): no_errors = "There have been no errors since the kernel started." # Find the most recent error. - ip = get_ipython() + ip = self.shell if "Err" not in ip.user_ns: return TextOrMarkdown(no_errors, no_errors) @@ -430,6 +449,12 @@ def handle_error(self, args: ErrorArgs): return self.run_ai_cell(cell_args, prompt) + def _append_exchange(self, prompt: str, output: str): + """Appends a conversational exchange between user and an OpenAI Chat + model to a transcript that will be included in future exchanges.""" + self.transcript.append(HumanMessage(prompt)) + self.transcript.append(AIMessage(output)) + def _decompose_model_id(self, model_id: str): """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" # custom_model_registry maps keys to either a model name (a string) or an LLMChain. @@ -463,7 +488,7 @@ def display_output(self, output, display_format, md): text=output, replace=False, ) - ip = get_ipython() + ip = self.shell ip.payload_manager.write_payload(new_cell_payload) return HTML( "AI generated code inserted below ⬇️", metadata=md @@ -493,6 +518,9 @@ def handle_list(self, args: ListArgs): def handle_version(self, args: VersionArgs): return __version__ + def handle_reset(self, args: ResetArgs): + self.transcript = [] + def run_ai_cell(self, args: CellArgs, prompt: str): provider_id, local_model_id = self._decompose_model_id(args.model_id) @@ -568,9 +596,10 @@ def run_ai_cell(self, args: CellArgs, prompt: str): prompt = provider.get_prompt_template(args.format).format(prompt=prompt) # interpolate user namespace into prompt - ip = get_ipython() + ip = self.shell prompt = prompt.format_map(FormatDict(ip.user_ns)) + context = self.transcript[-2 * self.max_history :] if self.max_history else [] if provider.is_chat_provider: ut = UsageTracker() ut._SendCopilotEvent({ @@ -580,12 +609,29 @@ def run_ai_cell(self, args: CellArgs, prompt: str): "model_provider_id": provider_id, "prompt_word_count": len(orig_prompt.split(" ")), }) - result = provider.generate([[HumanMessage(content=prompt)]]) + result = provider.generate([[*context, HumanMessage(content=prompt)]]) else: # generate output from model via provider - result = provider.generate([prompt]) + if context: + inputs = "\n\n".join( + [ + ( + f"AI: {message.content}" + if message.type == "ai" + else f"{message.type.title()}: {message.content}" + ) + for message in context + [HumanMessage(content=prompt)] + ] + ) + else: + inputs = prompt + result = provider.generate([inputs]) output = result.generations[0][0].text + + # append exchange to transcript + self._append_exchange(prompt, output) + md = {"jupyter_ai": {"provider_id": provider_id, "model_id": local_model_id}} return self.display_output(output, args.format, md) @@ -593,12 +639,23 @@ def run_ai_cell(self, args: CellArgs, prompt: str): @line_cell_magic def ai(self, line, cell=None): raw_args = line.split(" ") + default_map = {"model_id": self.default_language_model} if cell: - args = cell_magic_parser(raw_args, prog_name="%%ai", standalone_mode=False) + args = cell_magic_parser( + raw_args, + prog_name="%%ai", + standalone_mode=False, + default_map={"cell_magic_parser": default_map}, + ) else: - args = line_magic_parser(raw_args, prog_name="%ai", standalone_mode=False) + args = line_magic_parser( + raw_args, + prog_name="%ai", + standalone_mode=False, + default_map={"error": default_map}, + ) - if args == 0: + if args == 0 and self.default_language_model is None: # this happens when `--help` is called on the root command, in which # case we want to exit early. return @@ -619,6 +676,8 @@ def ai(self, line, cell=None): return self.handle_update(args) if args.type == "version": return self.handle_version(args) + if args.type == "reset": + return self.handle_reset(args) except ValueError as e: print(e, file=sys.stderr) return @@ -636,7 +695,7 @@ def ai(self, line, cell=None): prompt = cell.strip() # interpolate user namespace into prompt - ip = get_ipython() + ip = self.shell prompt = prompt.format_map(FormatDict(ip.user_ns)) return self.run_ai_cell(args, prompt) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py index 147f6cee..f2ee0cd5 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py @@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel): class CompletionError(BaseModel): type: str + title: str traceback: str diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index 3786ab5e..07e26e87 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -95,6 +95,10 @@ class UpdateArgs(BaseModel): target: str +class ResetArgs(BaseModel): + type: Literal["reset"] = "reset" + + class LineMagicGroup(click.Group): """Helper class to print the help string for cell magics as well when `%ai --help` is called.""" @@ -121,7 +125,7 @@ def verify_json_value(ctx, param, value): @click.command() -@click.argument("model_id") +@click.argument("model_id", required=False) @click.option( "-f", "--format", @@ -156,7 +160,8 @@ def verify_json_value(ctx, param, value): callback=verify_json_value, default="{}", ) -def cell_magic_parser(**kwargs): +@click.pass_context +def cell_magic_parser(context: click.Context, **kwargs): """ Invokes a language model identified by MODEL_ID, with the prompt being contained in all lines after the first. Both local model IDs and global @@ -165,6 +170,8 @@ def cell_magic_parser(**kwargs): To view available language models, please run `%ai list`. """ + if not kwargs["model_id"] and context.default_map: + kwargs["model_id"] = context.default_map["cell_magic_parser"]["model_id"] return CellArgs(**kwargs) @@ -176,7 +183,7 @@ def line_magic_parser(): @line_magic_parser.command(name="error") -@click.argument("model_id") +@click.argument("model_id", required=False) @click.option( "-f", "--format", @@ -211,11 +218,14 @@ def line_magic_parser(): callback=verify_json_value, default="{}", ) -def error_subparser(**kwargs): +@click.pass_context +def error_subparser(context: click.Context, **kwargs): """ Explains the most recent error. Takes the same options (except -r) as the basic `%%ai` command. """ + if not kwargs["model_id"] and context.default_map: + kwargs["model_id"] = context.default_map["error_subparser"]["model_id"] return ErrorArgs(**kwargs) @@ -271,3 +281,12 @@ def register_subparser(**kwargs): def register_subparser(**kwargs): """Update an alias called NAME to refer to the model or chain named TARGET.""" return UpdateArgs(**kwargs) + + +@line_magic_parser.command( + name="reset", + short_help="Clear the conversation transcript.", +) +def register_subparser(**kwargs): + """Clear the conversation transcript.""" + return ResetArgs() diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py index 7876c285..8cfee824 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/anthropic.py @@ -1,20 +1,22 @@ -from langchain_anthropic import AnthropicLLM, ChatAnthropic +from langchain_anthropic import ChatAnthropic from ..providers import BaseProvider, EnvAuthStrategy -class AnthropicProvider(BaseProvider, AnthropicLLM): - id = "anthropic" +class ChatAnthropicProvider( + BaseProvider, ChatAnthropic +): # https://docs.anthropic.com/en/docs/about-claude/models + id = "anthropic-chat" name = "Anthropic" models = [ - "claude-v1", - "claude-v1.0", - "claude-v1.2", - "claude-2", "claude-2.0", - "claude-instant-v1", - "claude-instant-v1.0", - "claude-instant-v1.2", + "claude-2.1", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + "claude-3-5-haiku-20241022", + "claude-3-5-sonnet-20240620", + "claude-3-5-sonnet-20241022", ] model_id_key = "model" pypi_package_deps = ["anthropic"] @@ -34,26 +36,3 @@ def is_api_key_exc(cls, e: Exception): if isinstance(e, anthropic.AuthenticationError): return e.status_code == 401 and "Invalid API Key" in str(e) return False - - -class ChatAnthropicProvider( - BaseProvider, ChatAnthropic -): # https://docs.anthropic.com/en/docs/about-claude/models - id = "anthropic-chat" - name = "ChatAnthropic" - models = [ - "claude-2.0", - "claude-2.1", - "claude-instant-1.2", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - "claude-3-5-sonnet-20240620", - ] - model_id_key = "model" - pypi_package_deps = ["anthropic"] - auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") - - @property - def allows_concurrency(self): - return False diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py index 4635a2e2..4e0a5c51 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/aws.py @@ -69,6 +69,9 @@ class BedrockChatProvider(BaseProvider, ChatBedrock): "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-opus-20240229-v1:0", "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-haiku-20241022-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-sonnet-20241022-v2:0", "meta.llama2-13b-chat-v1", "meta.llama2-70b-chat-v1", "meta.llama3-8b-instruct-v1:0", @@ -122,8 +125,9 @@ class BedrockCustomProvider(BaseProvider, ChatBedrock): ), ] help = ( - "Specify the ARN (Amazon Resource Name) of the custom/provisioned model as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n" - "The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g. `amazon`, `anthropic`, `meta`, or `mistral`." + "- For Cross-Region Inference use the appropriate `Inference profile ID` (Model ID with a region prefix, e.g., `us.meta.llama3-2-11b-instruct-v1:0`). See the [inference profiles documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html). \n" + "- For custom/provisioned models, specify the model ARN (Amazon Resource Name) as the model ID. For more information, see the [Amazon Bedrock model IDs documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).\n\n" + "The model provider must also be specified below. This is the provider of your foundation model *in lowercase*, e.g., `amazon`, `anthropic`, `cohere`, `meta`, or `mistral`." ) registry = True diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py index c52e7ebd..5e72abac 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/gemini.py @@ -2,10 +2,13 @@ from langchain_google_genai import GoogleGenerativeAI +# See list of model ids here: https://ai.google.dev/gemini-api/docs/models/gemini class GeminiProvider(BaseProvider, GoogleGenerativeAI): id = "gemini" name = "Gemini" models = [ + "gemini-1.5-pro", + "gemini-1.5-flash", "gemini-1.0-pro", "gemini-1.0-pro-001", "gemini-1.0-pro-latest", diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py new file mode 100644 index 00000000..bf7d8474 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py @@ -0,0 +1,36 @@ +from langchain_ollama import ChatOllama, OllamaEmbeddings + +from ..embedding_providers import BaseEmbeddingsProvider +from ..providers import BaseProvider, TextField + + +class OllamaProvider(BaseProvider, ChatOllama): + id = "ollama" + name = "Ollama" + model_id_key = "model" + help = ( + "See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. " + "Pass a model's name; for example, `deepseek-coder-v2`." + ) + models = ["*"] + registry = True + fields = [ + TextField(key="base_url", label="Base API URL (optional)", format="text"), + ] + + +class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings): + id = "ollama" + name = "Ollama" + # source: https://ollama.com/library + model_id_key = "model" + models = [ + "nomic-embed-text", + "mxbai-embed-large", + "all-minilm", + "snowflake-arctic-embed", + ] + registry = True + fields = [ + TextField(key="base_url", label="Base API URL (optional)", format="text"), + ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index afba7c2b..34ca76a8 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -1,4 +1,10 @@ -from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI, OpenAIEmbeddings +from langchain_openai import ( + AzureChatOpenAI, + AzureOpenAIEmbeddings, + ChatOpenAI, + OpenAI, + OpenAIEmbeddings, +) from ..embedding_providers import BaseEmbeddingsProvider from ..providers import BaseProvider, EnvAuthStrategy, TextField @@ -31,22 +37,17 @@ class ChatOpenAIProvider(BaseProvider, ChatOpenAI): name = "OpenAI" models = [ "gpt-3.5-turbo", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-0301", # Deprecated as of 2024-06-13 - "gpt-3.5-turbo-0613", # Deprecated as of 2024-06-13 "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-16k-0613", # Deprecated as of 2024-06-13 "gpt-4", "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0613", "gpt-4-0125-preview", "gpt-4-1106-preview", "gpt-4o", + "gpt-4o-2024-11-20", "gpt-4o-mini", + "chatgpt-4o-latest", ] model_id_key = "model_name" pypi_package_deps = ["langchain_openai"] @@ -106,3 +107,28 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): model_id_key = "model" pypi_package_deps = ["langchain_openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + registry = True + fields = [ + TextField( + key="openai_api_base", label="Base API URL (optional)", format="text" + ), + ] + + +class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings): + id = "azure" + name = "Azure OpenAI" + models = [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ] + model_id_key = "azure_deployment" + pypi_package_deps = ["langchain_openai"] + auth_strategy = EnvAuthStrategy( + name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key" + ) + registry = True + fields = [ + TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"), + ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py new file mode 100644 index 00000000..81c2d7ab --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openrouter.py @@ -0,0 +1,60 @@ +from typing import Dict + +from jupyter_ai_magics import BaseProvider +from jupyter_ai_magics.providers import EnvAuthStrategy, TextField +from langchain_core.pydantic_v1 import root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_openai import ChatOpenAI + + +class ChatOpenRouter(ChatOpenAI): + @property + def lc_secrets(self) -> Dict[str, str]: + return {"openai_api_key": "OPENROUTER_API_KEY"} + + +class OpenRouterProvider(BaseProvider, ChatOpenRouter): + id = "openrouter" + name = "OpenRouter" + models = [ + "*" + ] # OpenRouter supports multiple models, so we use "*" to indicate it's a registry + model_id_key = "model_name" + pypi_package_deps = ["langchain_openai"] + auth_strategy = EnvAuthStrategy(name="OPENROUTER_API_KEY") + registry = True + + fields = [ + TextField( + key="openai_api_base", label="API Base URL (optional)", format="text" + ), + ] + + def __init__(self, **kwargs): + openrouter_api_key = kwargs.pop("openrouter_api_key", None) + openrouter_api_base = kwargs.pop( + "openai_api_base", "https://openrouter.ai/api/v1" + ) + + super().__init__( + openai_api_key=openrouter_api_key, + openai_api_base=openrouter_api_base, + **kwargs, + ) + + @root_validator(pre=False, skip_on_failure=True, allow_reuse=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["openai_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "openai_api_key", "OPENROUTER_API_KEY") + ) + return super().validate_environment(values) + + @classmethod + def is_api_key_exc(cls, e: Exception): + import openai + + if isinstance(e, openai.AuthenticationError): + error_details = e.json_body.get("error", {}) + return error_details.get("code") == "invalid_api_key" + return False diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 1fca2d68..2ab87a97 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -28,15 +28,6 @@ from langchain.schema import LLMResult from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import Runnable -from langchain_community.chat_models import QianfanChatEndpoint -from langchain_community.llms import ( - AI21, - GPT4All, - HuggingFaceEndpoint, - Ollama, - SagemakerEndpoint, - Together, -) from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM @@ -62,17 +53,35 @@ You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. You are talkative and you provide lots of specific details from the foundation model's context. You may use Markdown to format your response. -Code blocks must be formatted in Markdown. -Math should be rendered with inline TeX markup, surrounded by $. +If your response includes code, they must be enclosed in Markdown fenced code blocks (with triple backticks before and after). +If your response includes mathematical notation, they must be expressed in LaTeX markup and enclosed in LaTeX delimiters. +All dollar quantities (of USD) must be formatted in LaTeX, with the `$` symbol escaped by a single backslash `\\`. +- Example prompt: `If I have \\\\$100 and spend \\\\$20, how much money do I have left?` +- **Correct** response: `You have \\(\\$80\\) remaining.` +- **Incorrect** response: `You have $80 remaining.` If you do not know the answer to a question, answer truthfully by responding that you do not know. The following is a friendly conversation between you and a human. """.strip() -CHAT_DEFAULT_TEMPLATE = """Current conversation: -{history} -Human: {input} +CHAT_DEFAULT_TEMPLATE = """ +{% if context %} +Context: +{{context}} + +{% endif %} +Current conversation: +{{history}} +Human: {{input}} AI:""" +HUMAN_MESSAGE_TEMPLATE = """ +{% if context %} +Context: +{{context}} + +{% endif %} +{{input}} +""" COMPLETION_SYSTEM_PROMPT = """ You are an application built to provide helpful code completion suggestions. @@ -401,17 +410,21 @@ def get_chat_prompt_template(self) -> PromptTemplate: CHAT_SYSTEM_PROMPT ).format(provider_name=name, local_model_id=self.model_id), MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), + HumanMessagePromptTemplate.from_template( + HUMAN_MESSAGE_TEMPLATE, + template_format="jinja2", + ), ] ) else: return PromptTemplate( - input_variables=["history", "input"], + input_variables=["history", "input", "context"], template=CHAT_SYSTEM_PROMPT.format( provider_name=name, local_model_id=self.model_id ) + "\n\n" + CHAT_DEFAULT_TEMPLATE, + template_format="jinja2", ) def get_completion_prompt_template(self) -> PromptTemplate: @@ -524,6 +537,7 @@ def _create_completion_chain(self) -> Runnable: return prompt_template | self | StrOutputParser() +''' class AI21Provider(BaseProvider, AI21): id = "ai21" name = "AI21" @@ -690,20 +704,6 @@ async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) -class OllamaProvider(BaseProvider, Ollama): - id = "ollama" - name = "Ollama" - model_id_key = "model" - help = ( - "See [https://www.ollama.com/library](https://www.ollama.com/library) for a list of models. " - "Pass a model's name; for example, `deepseek-coder-v2`." - ) - models = ["*"] - registry = True - fields = [ - TextField(key="base_url", label="Base API URL (optional)", format="text"), - ] - class TogetherAIProvider(BaseProvider, Together): id = "togetherai" name = "Together AI" @@ -750,3 +750,4 @@ class QianfanProvider(BaseProvider, QianfanChatEndpoint): model_id_key = "model_name" pypi_package_deps = ["qianfan"] auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"]) +''' diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed b/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py index ec163534..2bd44024 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py @@ -1,11 +1,122 @@ +import os +from unittest.mock import Mock, patch + +import pytest from IPython import InteractiveShell +from IPython.core.display import Markdown +from jupyter_ai_magics.magics import AiMagics +from langchain_core.messages import AIMessage, HumanMessage +from pytest import fixture from traitlets.config.loader import Config -def test_aliases_config(): +@fixture +def ip() -> InteractiveShell: ip = InteractiveShell() ip.config = Config() + return ip + + +def test_aliases_config(ip): ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"} ip.extension_manager.load_extension("jupyter_ai_magics") providers_list = ip.run_line_magic("ai", "list").text assert "my_custom_alias" in providers_list + + +def test_default_model_cell(ip): + ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.extension_manager.load_extension("jupyter_ai_magics") + with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: + ip.run_cell_magic("ai", "", cell="Write code for me please") + assert mock_run.called + cell_args = mock_run.call_args.args[0] + assert cell_args.model_id == "my-favourite-llm" + +def test_non_default_model_cell(ip): + ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.extension_manager.load_extension("jupyter_ai_magics") + with patch.object(AiMagics, "run_ai_cell", return_value=None) as mock_run: + ip.run_cell_magic("ai", "some-different-llm", cell="Write code for me please") + assert mock_run.called + cell_args = mock_run.call_args.args[0] + assert cell_args.model_id == "some-different-llm" + + +def test_default_model_error_line(ip): + ip.config.AiMagics.default_language_model = "my-favourite-llm" + ip.extension_manager.load_extension("jupyter_ai_magics") + with patch.object(AiMagics, "handle_error", return_value=None) as mock_run: + ip.run_cell_magic("ai", "error", cell=None) + assert mock_run.called + cell_args = mock_run.call_args.args[0] + assert cell_args.model_id == "my-favourite-llm" + + +PROMPT = HumanMessage( + content=("Write code for me please\n\nProduce output in markdown format only.") +) +RESPONSE = AIMessage(content="Leet code") +AI1 = AIMessage("ai1") +H1 = HumanMessage("h1") +AI2 = AIMessage("ai2") +H2 = HumanMessage("h2") +AI3 = AIMessage("ai3") + + +@pytest.mark.parametrize( + ["transcript", "max_history", "expected_context"], + [ + ([], 3, [PROMPT]), + ([AI1], 0, [PROMPT]), + ([AI1], 1, [AI1, PROMPT]), + ([H1, AI1], 0, [PROMPT]), + ([H1, AI1], 1, [H1, AI1, PROMPT]), + ([AI1, H1, AI2], 0, [PROMPT]), + ([AI1, H1, AI2], 1, [H1, AI2, PROMPT]), + ([AI1, H1, AI2], 2, [AI1, H1, AI2, PROMPT]), + ([H1, AI1, H2, AI2], 0, [PROMPT]), + ([H1, AI1, H2, AI2], 1, [H2, AI2, PROMPT]), + ([H1, AI1, H2, AI2], 2, [H1, AI1, H2, AI2, PROMPT]), + ([AI1, H1, AI2, H2, AI3], 0, [PROMPT]), + ([AI1, H1, AI2, H2, AI3], 1, [H2, AI3, PROMPT]), + ([AI1, H1, AI2, H2, AI3], 2, [H1, AI2, H2, AI3, PROMPT]), + ([AI1, H1, AI2, H2, AI3], 3, [AI1, H1, AI2, H2, AI3, PROMPT]), + ], +) +def test_max_history(ip, transcript, max_history, expected_context): + ip.extension_manager.load_extension("jupyter_ai_magics") + ai_magics = ip.magics_manager.registry["AiMagics"] + ai_magics.transcript = transcript.copy() + ai_magics.max_history = max_history + provider = ai_magics._get_provider("openrouter") + with ( + patch.object(provider, "generate") as generate, + patch.dict(os.environ, OPENROUTER_API_KEY="123"), + ): + generate.return_value.generations = [[Mock(text="Leet code")]] + result = ip.run_cell_magic( + "ai", + "openrouter:anthropic/claude-3.5-sonnet", + cell="Write code for me please", + ) + provider.generate.assert_called_once_with([expected_context]) + assert isinstance(result, Markdown) + assert result.data == "Leet code" + assert result.filename is None + assert result.metadata == { + "jupyter_ai": { + "model_id": "anthropic/claude-3.5-sonnet", + "provider_id": "openrouter", + } + } + assert result.url is None + assert ai_magics.transcript == [*transcript, PROMPT, RESPONSE] + + +def test_reset(ip): + ip.extension_manager.load_extension("jupyter_ai_magics") + ai_magics = ip.magics_manager.registry["AiMagics"] + ai_magics.transcript = [AI1, H1, AI2, H2, AI3] + result = ip.run_line_magic("ai", "reset") + assert ai_magics.transcript == [] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 983bbf2d..5fb8f4fe 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -36,10 +36,9 @@ def get_lm_providers( ) continue except Exception as e: - log.error( - f"Unable to load model provider `{provider_ep.name}`. Printing full exception below." + log.warning( + f"Unable to load model provider `{provider_ep.name}`", exc_info=e ) - log.exception(e) continue if not is_provider_allowed(provider.id, restrictions): @@ -66,7 +65,7 @@ def get_em_providers( try: provider = model_provider_ep.load() except Exception as e: - log.error( + log.warning( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.", e, ) diff --git a/packages/jupyter-ai-magics/package.json b/packages/jupyter-ai-magics/package.json index 569f939f..9586ad99 100644 --- a/packages/jupyter-ai-magics/package.json +++ b/packages/jupyter-ai-magics/package.json @@ -1,11 +1,11 @@ { "name": "@jupyter-ai/magics", - "version": "2.20.0+cloudera", + "version": "2.28.3+cloudera", "description": "Jupyter AI magics Python package. Not published on NPM.", "private": true, - "homepage": "https://github.infra.cloudera.com/Sense/copilot", + "homepage": "https://github.com/cloudera/copilot", "bugs": { - "url": "https://github.infra.cloudera.com/Sense/copilot/issues" + "url": "https://github.com/cloudera/copilot/issues" }, "license": "BSD-3-Clause", "author": { @@ -13,7 +13,7 @@ }, "repository": { "type": "git", - "url": "https://github.infra.cloudera.com/Sense/copilot.git" + "url": "https://github.com/cloudera/copilot.git" }, "scripts": { "dev-install": "pip install -e \".[dev,all]\"", diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 02eb8b9e..29937998 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -24,8 +24,8 @@ dynamic = ["version", "description", "authors", "urls", "keywords"] dependencies = [ "ipython", "importlib_metadata>=5.2.0", - "langchain>=0.1.0,<0.3.0", - "langchain_community>=0.1.0,<0.3.0", + "langchain>=0.2.17,<0.3.0", + "langchain_community>=0.2.19,<0.3.0", "typing_extensions>=4.5.0", "click~=8.0", "jsonpath-ng>=1.5.3,<2", @@ -37,53 +37,17 @@ dev = ["pre-commit>=3.3.3,<4"] test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"] all = [ - "ai21", - "gpt4all", - "huggingface_hub", - "ipywidgets", - "langchain_anthropic", "langchain_aws", - "langchain_cohere", - "langchain_google_genai", - "langchain_mistralai", - "langchain_nvidia_ai_endpoints", - "langchain_openai", - "pillow", "boto3", - "qianfan", - "together", ] [project.entry-points."jupyter_ai.model_providers"] -ai21 = "jupyter_ai_magics:AI21Provider" -anthropic = "jupyter_ai_magics.partner_providers.anthropic:AnthropicProvider" -anthropic-chat = "jupyter_ai_magics.partner_providers.anthropic:ChatAnthropicProvider" -cohere = "jupyter_ai_magics.partner_providers.cohere:CohereProvider" -gpt4all = "jupyter_ai_magics:GPT4AllProvider" -huggingface_hub = "jupyter_ai_magics:HfHubProvider" -ollama = "jupyter_ai_magics:OllamaProvider" -openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider" -openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider" -azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider" -sagemaker-endpoint = "jupyter_ai_magics.partner_providers.aws:SmEndpointProvider" amazon-bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockProvider" amazon-bedrock-chat = "jupyter_ai_magics.partner_providers.aws:BedrockChatProvider" amazon-bedrock-custom = "jupyter_ai_magics.partner_providers.aws:BedrockCustomProvider" -qianfan = "jupyter_ai_magics:QianfanProvider" -nvidia-chat = "jupyter_ai_magics.partner_providers.nvidia:ChatNVIDIAProvider" -together-ai = "jupyter_ai_magics:TogetherAIProvider" -gemini = "jupyter_ai_magics.partner_providers.gemini:GeminiProvider" -mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockEmbeddingsProvider" -cohere = "jupyter_ai_magics.partner_providers.cohere:CohereEmbeddingsProvider" -mistralai = "jupyter_ai_magics.partner_providers.mistralai:MistralAIEmbeddingsProvider" -gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider" -huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider" -ollama = "jupyter_ai_magics:OllamaEmbeddingsProvider" -openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider" -qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider" [tool.hatch.version] source = "nodejs" diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py index c7c72666..17fa4265 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py @@ -48,10 +48,11 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: - time.sleep(5) + time.sleep(1) yield GenerationChunk( - text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n" + text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n", + generation_info={"test_metadata_field": "foobar"}, ) - for i in range(1, 101): - time.sleep(0.5) + for i in range(1, 6): + time.sleep(0.2) yield GenerationChunk(text=f"{i}, ") diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py index f2803dee..1ec042af 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_providers.py @@ -75,3 +75,41 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming): fields: ClassVar[List[Field]] = [] """User inputs expected by this provider when initializing it. Each `Field` `f` should be passed in the constructor as a keyword argument, keyed by `f.key`.""" + + +class TestProviderAskLearnUnsupported(BaseProvider, TestLLMWithStreaming): + id: ClassVar[str] = "test-provider-ask-learn-unsupported" + """ID for this provider class.""" + + name: ClassVar[str] = "Test Provider (/learn and /ask unsupported)" + """User-facing name of this provider.""" + + models: ClassVar[List[str]] = ["test"] + """List of supported models by their IDs. For registry providers, this will + be just ["*"].""" + + help: ClassVar[str] = None + """Text to display in lieu of a model list for a registry provider that does + not provide a list of models.""" + + model_id_key: ClassVar[str] = "model_id" + """Kwarg expected by the upstream LangChain provider.""" + + model_id_label: ClassVar[str] = "Model ID" + """Human-readable label of the model ID.""" + + pypi_package_deps: ClassVar[List[str]] = [] + """List of PyPi package dependencies.""" + + auth_strategy: ClassVar[AuthStrategy] = None + """Authentication/authorization strategy. Declares what credentials are + required to use this model provider. Generally should not be `None`.""" + + registry: ClassVar[bool] = False + """Whether this provider is a registry provider.""" + + fields: ClassVar[List[Field]] = [] + """User inputs expected by this provider when initializing it. Each `Field` `f` + should be passed in the constructor as a keyword argument, keyed by `f.key`.""" + + unsupported_slash_commands = {"/learn", "/ask"} diff --git a/packages/jupyter-ai-test/package.json b/packages/jupyter-ai-test/package.json index df186d72..6c23773f 100644 --- a/packages/jupyter-ai-test/package.json +++ b/packages/jupyter-ai-test/package.json @@ -1,6 +1,6 @@ { "name": "@jupyter-ai/test", - "version": "2.20.0", + "version": "2.28.3", "description": "Jupyter AI test package. Not published on NPM or PyPI.", "private": true, "homepage": "https://github.com/jupyterlab/jupyter-ai", diff --git a/packages/jupyter-ai-test/pyproject.toml b/packages/jupyter-ai-test/pyproject.toml index c50c520d..eaecc09d 100644 --- a/packages/jupyter-ai-test/pyproject.toml +++ b/packages/jupyter-ai-test/pyproject.toml @@ -31,6 +31,7 @@ test = ["coverage", "pytest", "pytest-asyncio", "pytest-cov"] [project.entry-points."jupyter_ai.model_providers"] test-provider = "jupyter_ai_test.test_providers:TestProvider" test-provider-with-streaming = "jupyter_ai_test.test_providers:TestProviderWithStreaming" +test-provider-ask-learn-unsupported = "jupyter_ai_test.test_providers:TestProviderAskLearnUnsupported" [project.entry-points."jupyter_ai.chat_handlers"] test-slash-command = "jupyter_ai_test.test_slash_commands:TestSlashCommand" diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py new file mode 100644 index 00000000..4567ecba --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py @@ -0,0 +1,6 @@ +""" +Provides classes which extend `langchain_core.callbacks:BaseCallbackHandler`. +Not to be confused with Jupyter AI chat handlers. +""" + +from .metadata import MetadataCallbackHandler diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py new file mode 100644 index 00000000..c409a963 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -0,0 +1,55 @@ +import inspect +import json + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + + +def requires_no_arguments(func): + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.default is param.empty and param.kind in ( + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + param.KEYWORD_ONLY, + ): + return False + return True + + +def convert_to_serializable(obj): + """Convert an object to a JSON serializable format""" + if hasattr(obj, "dict") and callable(obj.dict) and requires_no_arguments(obj.dict): + return obj.dict() + if hasattr(obj, "__dict__"): + return obj.__dict__ + return str(obj) + + +class MetadataCallbackHandler(BaseCallbackHandler): + """ + When passed as a callback handler, this stores the LLMResult's + `generation_info` dictionary in the `self.jai_metadata` instance attribute + after the provider fully processes an input. + + If used in a streaming chat handler: the `metadata` field of the final + `AgentStreamChunkMessage` should be set to `self.jai_metadata`. + + If used in a non-streaming chat handler: the `metadata` field of the + returned `AgentChatMessage` should be set to `self.jai_metadata`. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.jai_metadata = {} + + def on_llm_end(self, response: LLMResult, **kwargs) -> None: + if not (len(response.generations) and len(response.generations[0])): + return + + metadata = response.generations[0][0].generation_info or {} + + # Convert any non-serializable objects in metadata + self.jai_metadata = json.loads( + json.dumps(metadata, default=convert_to_serializable) + ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 80c0f76f..b07933a8 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -72,7 +72,8 @@ async def process_message(self, message: HumanChatMessage): self.get_llm_chain() try: - with self.pending("Searching learned documents"): + with self.pending("Searching learned documents", message): + assert self.llm_chain ut = UsageTracker() ut._SendCopilotEvent({ "event_details": "/ask", @@ -83,7 +84,11 @@ async def process_message(self, message: HumanChatMessage): "model_provider_id": self.config_manager.lm_provider.id, "prompt_word_count": len(args.query), }) - result = await self.llm_chain.acall({"question": query}) + # TODO: migrate this class to use a LCEL `Runnable` instead of + # `Chain`, then remove the below ignore comment. + result = await self.llm_chain.acall( # type:ignore[attr-defined] + {"question": query} + ) response = result["answer"] self.reply(response, message) except AssertionError as e: diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 44efe026..0ba85922 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -1,10 +1,12 @@ import argparse +import asyncio import contextlib import os import time import traceback from typing import ( TYPE_CHECKING, + Any, Awaitable, ClassVar, Dict, @@ -13,25 +15,49 @@ Optional, Type, Union, + cast, ) +from typing import get_args as get_type_args from uuid import uuid4 from dask.distributed import Client as DaskClient +from jupyter_ai.callback_handlers import MetadataCallbackHandler from jupyter_ai.config_manager import ConfigManager, Logger +from jupyter_ai.history import WrappedBoundedChatHistory from jupyter_ai.models import ( AgentChatMessage, + AgentStreamChunkMessage, + AgentStreamMessage, ChatMessage, ClosePendingMessage, HumanChatMessage, + Message, PendingMessage, ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider from jupyter_ai_magics.models.usage_tracking import UsageTracker from langchain.pydantic_v1 import BaseModel +from langchain_core.messages import AIMessageChunk +from langchain_core.runnables import Runnable +from langchain_core.runnables.config import RunnableConfig +from langchain_core.runnables.config import merge_configs as merge_runnable_configs +from langchain_core.runnables.utils import Input if TYPE_CHECKING: + from jupyter_ai.context_providers import BaseCommandContextProvider from jupyter_ai.handlers import RootChatHandler + from jupyter_ai.history import BoundedChatHistory + from langchain_core.chat_history import BaseChatMessageHistory + + +def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]: + if preferred_dir is not None and preferred_dir != "": + preferred_dir = os.path.expanduser(preferred_dir) + if not preferred_dir.startswith(root_dir): + preferred_dir = os.path.join(root_dir, preferred_dir) + return os.path.abspath(preferred_dir) + return None def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]: @@ -45,7 +71,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]: # Chat handler type, with specific attributes for each class HandlerRoutingType(BaseModel): - routing_method: ClassVar[Union[Literal["slash_command"]]] = ... + routing_method: ClassVar[Union[Literal["slash_command"]]] """The routing method that sends commands to this handler.""" @@ -81,17 +107,17 @@ class BaseChatHandler: multiple chat handler classes.""" # Class attributes - id: ClassVar[str] = ... + id: ClassVar[str] """ID for this chat handler; should be unique""" - name: ClassVar[str] = ... + name: ClassVar[str] """User-facing name of this handler""" - help: ClassVar[str] = ... + help: ClassVar[str] """What this chat handler does, which third-party models it contacts, the data it returns to the user, and so on, for display in the UI.""" - routing_type: HandlerRoutingType = ... + routing_type: ClassVar[HandlerRoutingType] uses_llm: ClassVar[bool] = True """Class attribute specifying whether this chat handler uses the LLM @@ -103,10 +129,28 @@ class BaseChatHandler: parse the arguments and display help when user queries with `-h` or `--help`""" - _requests_count = 0 + _requests_count: ClassVar[int] = 0 """Class attribute set to the number of requests that Cloudera Copilot is currently handling.""" + # Instance attributes + help_message_template: str + """Format string template that is used to build the help message. Specified + from traitlets configuration.""" + + chat_handlers: Dict[str, "BaseChatHandler"] + """Dictionary of chat handlers. Allows one chat handler to reference other + chat handlers, which is necessary for some use-cases like printing the help + message.""" + + context_providers: Dict[str, "BaseCommandContextProvider"] + """Dictionary of context providers. Allows chat handlers to reference + context providers, which can be used to provide context to the LLM.""" + + message_interrupted: Dict[str, asyncio.Event] + """Dictionary mapping an agent message identifier to an asyncio Event + which indicates if the message generation/streaming was interrupted.""" + def __init__( self, log: Logger, @@ -114,15 +158,21 @@ def __init__( root_chat_handlers: Dict[str, "RootChatHandler"], model_parameters: Dict[str, Dict], chat_history: List[ChatMessage], + llm_chat_memory: "BoundedChatHistory", root_dir: str, preferred_dir: Optional[str], dask_client_future: Awaitable[DaskClient], + help_message_template: str, + chat_handlers: Dict[str, "BaseChatHandler"], + context_providers: Dict[str, "BaseCommandContextProvider"], + message_interrupted: Dict[str, asyncio.Event], ): self.log = log self.config_manager = config_manager self._root_chat_handlers = root_chat_handlers self.model_parameters = model_parameters self._chat_history = chat_history + self.llm_chat_memory = llm_chat_memory self.parser = argparse.ArgumentParser( add_help=False, description=self.help, formatter_class=MarkdownHelpFormatter ) @@ -134,9 +184,14 @@ def __init__( self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) self.preferred_dir = get_preferred_dir(self.root_dir, preferred_dir) self.dask_client_future = dask_client_future - self.llm = None - self.llm_params = None - self.llm_chain = None + self.help_message_template = help_message_template + self.chat_handlers = chat_handlers + self.context_providers = context_providers + self.message_interrupted = message_interrupted + + self.llm: Optional[BaseProvider] = None + self.llm_params: Optional[dict] = None + self.llm_chain: Optional[Runnable] = None async def on_message(self, message: HumanChatMessage): """ @@ -149,9 +204,8 @@ async def on_message(self, message: HumanChatMessage): # ensure the current slash command is supported if self.routing_type.routing_method == "slash_command": - slash_command = ( - "/" + self.routing_type.slash_id if self.routing_type.slash_id else "" - ) + routing_type = cast(SlashCommandRoutingType, self.routing_type) + slash_command = "/" + routing_type.slash_id if routing_type.slash_id else "" if slash_command in lm_provider_klass.unsupported_slash_commands: self.reply( "Sorry, the selected language model does not support this slash command." @@ -227,6 +281,26 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): ) self.reply(response, message) + def broadcast_message(self, message: Message): + """ + Broadcasts a message to all WebSocket connections. If there are no + WebSocket connections and the message is a chat message, this method + directly appends to `self.chat_history`. + """ + broadcast = False + for websocket in self._root_chat_handlers.values(): + if not websocket: + continue + + websocket.broadcast_message(message) + broadcast = True + break + + if not broadcast: + if isinstance(message, get_type_args(ChatMessage)): + cast(ChatMessage, message) + self._chat_history.append(message) + def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): """ Sends an agent message, usually in response to a received @@ -240,18 +314,19 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): persona=self.persona, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(agent_msg) - break + self.broadcast_message(agent_msg) @property def persona(self): return self.config_manager.persona - def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage: + def start_pending( + self, + text: str, + human_msg: Optional[HumanChatMessage] = None, + *, + ellipsis: bool = True, + ) -> PendingMessage: """ Sends a pending message to the client. @@ -263,16 +338,12 @@ def start_pending(self, text: str, ellipsis: bool = True) -> PendingMessage: id=uuid4().hex, time=time.time(), body=text, + reply_to=human_msg.id if human_msg else "", persona=Persona(name=persona.name, avatar_route=persona.avatar_route), ellipsis=ellipsis, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(pending_msg) - break + self.broadcast_message(pending_msg) return pending_msg def close_pending(self, pending_msg: PendingMessage): @@ -286,22 +357,24 @@ def close_pending(self, pending_msg: PendingMessage): id=pending_msg.id, ) - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(close_pending_msg) - break + self.broadcast_message(close_pending_msg) + pending_msg.closed = True pending_msg.closed = True @contextlib.contextmanager - def pending(self, text: str, ellipsis: bool = True): + def pending( + self, + text: str, + human_msg: Optional[HumanChatMessage] = None, + *, + ellipsis: bool = True, + ): """ Context manager that sends a pending message to the client, and closes it after the block is executed. """ - pending_msg = self.start_pending(text, ellipsis=ellipsis) + pending_msg = self.start_pending(text, human_msg=human_msg, ellipsis=ellipsis) try: yield pending_msg finally: @@ -359,6 +432,16 @@ def parse_args(self, message, silent=False): return None return args + def get_llm_chat_memory( + self, + last_human_msg: HumanChatMessage, + **kwargs, + ) -> "BaseChatMessageHistory": + return WrappedBoundedChatHistory( + history=self.llm_chat_memory, + last_human_msg=last_human_msg, + ) + @property def output_dir(self) -> str: # preferred dir is preferred, but if it is not specified, @@ -367,3 +450,176 @@ def output_dir(self) -> str: return self.preferred_dir else: return self.root_dir + + def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None: + """Sends a help message to all connected clients.""" + lm_provider = self.config_manager.lm_provider + unsupported_slash_commands = ( + lm_provider.unsupported_slash_commands if lm_provider else set() + ) + chat_handlers = self.chat_handlers + slash_commands = {k: v for k, v in chat_handlers.items() if k != "default"} + for key in unsupported_slash_commands: + del slash_commands[key] + + # markdown string that lists the slash commands + slash_commands_list = "\n".join( + [ + f"* `{command_name}` — {handler.help}" + for command_name, handler in slash_commands.items() + ] + ) + + context_commands_list = "\n".join( + [ + f"* `{cp.command_id}` — {cp.help}" + for cp in self.context_providers.values() + ] + ) + + help_message_body = self.help_message_template.format( + persona_name=self.persona.name, + slash_commands_list=slash_commands_list, + context_commands_list=context_commands_list, + ) + help_message = AgentChatMessage( + id=uuid4().hex, + time=time.time(), + body=help_message_body, + reply_to=human_msg.id if human_msg else "", + persona=self.persona, + ) + + self.broadcast_message(help_message) + + def _start_stream(self, human_msg: HumanChatMessage) -> str: + """ + Sends an `agent-stream` message to indicate the start of a response + stream. Returns the ID of the message, denoted as the `stream_id`. + """ + stream_id = uuid4().hex + stream_msg = AgentStreamMessage( + id=stream_id, + time=time.time(), + body="", + reply_to=human_msg.id, + persona=self.persona, + complete=False, + ) + + self.broadcast_message(stream_msg) + return stream_id + + def _send_stream_chunk( + self, + stream_id: str, + content: str, + complete: bool = False, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Sends an `agent-stream-chunk` message containing content that should be + appended to an existing `agent-stream` message with ID `stream_id`. + """ + if not metadata: + metadata = {} + + stream_chunk_msg = AgentStreamChunkMessage( + id=stream_id, content=content, stream_complete=complete, metadata=metadata + ) + self.broadcast_message(stream_chunk_msg) + + async def stream_reply( + self, + input: Input, + human_msg: HumanChatMessage, + pending_msg="Generating response", + config: Optional[RunnableConfig] = None, + ): + """ + Streams a reply to a human message by invoking + `self.llm_chain.astream()`. A LangChain `Runnable` instance must be + bound to `self.llm_chain` before invoking this method. + + Arguments + --------- + - `input`: The input to your runnable. The type of `input` depends on + the runnable in `self.llm_chain`, but is usually a dictionary whose keys + refer to input variables in your prompt template. + + - `human_msg`: The `HumanChatMessage` being replied to. + + - `config` (optional): A `RunnableConfig` object that specifies + additional configuration when streaming from the runnable. + + - `pending_msg` (optional): Changes the default pending message from + "Generating response". + """ + assert self.llm_chain + assert isinstance(self.llm_chain, Runnable) + + received_first_chunk = False + metadata_handler = MetadataCallbackHandler() + base_config: RunnableConfig = { + "configurable": {"last_human_msg": human_msg}, + "callbacks": [metadata_handler], + } + merged_config: RunnableConfig = merge_runnable_configs(base_config, config) + + # start with a pending message + with self.pending(pending_msg, human_msg) as pending_message: + # stream response in chunks. this works even if a provider does not + # implement streaming, as `astream()` defaults to yielding `_call()` + # when `_stream()` is not implemented on the LLM class. + chunk_generator = self.llm_chain.astream(input, config=merged_config) + stream_interrupted = False + async for chunk in chunk_generator: + if not received_first_chunk: + # when receiving the first chunk, close the pending message and + # start the stream. + self.close_pending(pending_message) + stream_id = self._start_stream(human_msg=human_msg) + received_first_chunk = True + self.message_interrupted[stream_id] = asyncio.Event() + + if self.message_interrupted[stream_id].is_set(): + try: + # notify the model provider that streaming was interrupted + # (this is essential to allow the model to stop generating) + # + # note: `mypy` flags this line, claiming that `athrow` is + # not defined on `AsyncIterator`. This is why an ignore + # comment is placed here. + await chunk_generator.athrow( # type:ignore[attr-defined] + GenerationInterrupted() + ) + except GenerationInterrupted: + # do not let the exception bubble up in case if + # the provider did not handle it + pass + stream_interrupted = True + break + + if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str): + self._send_stream_chunk(stream_id, chunk.content) + elif isinstance(chunk, str): + self._send_stream_chunk(stream_id, chunk) + else: + self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") + break + + # complete stream after all chunks have been streamed + stream_tombstone = ( + "\n\n(AI response stopped by user)" if stream_interrupted else "" + ) + self._send_stream_chunk( + stream_id, + stream_tombstone, + complete=True, + metadata=metadata_handler.jai_metadata, + ) + del self.message_interrupted[stream_id] + + +class GenerationInterrupted(asyncio.CancelledError): + """Exception raised when streaming is cancelled by the user""" diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index 97cae4ab..d5b0ab6c 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -1,7 +1,4 @@ -from typing import List - -from jupyter_ai.chat_handlers.help import build_help_message -from jupyter_ai.models import ChatMessage, ClearMessage +from jupyter_ai.models import ClearRequest from .base import BaseChatHandler, SlashCommandRoutingType @@ -20,22 +17,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def process_message(self, _): + # Clear chat by triggering `RootChatHandler.on_clear_request()`. for handler in self._root_chat_handlers.values(): if not handler: continue - # Clear chat - handler.broadcast_message(ClearMessage()) - self._chat_history.clear() - - # Build /help message and reinstate it in chat - chat_handlers = handler.chat_handlers - persona = self.config_manager.persona - lm_provider = self.config_manager.lm_provider - unsupported_slash_commands = ( - lm_provider.unsupported_slash_commands if lm_provider else set() - ) - msg = build_help_message(chat_handlers, persona, unsupported_slash_commands) - self.reply(msg.body) - + handler.on_clear_request(ClearRequest()) break diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index de1f0a7f..54272851 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,18 +1,13 @@ -import time +import asyncio from typing import Dict, Type -from uuid import uuid4 -from jupyter_ai.models import ( - AgentStreamChunkMessage, - AgentStreamMessage, - HumanChatMessage, -) from jupyter_ai_magics.models.usage_tracking import UsageTracker +from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider -from langchain_core.messages import AIMessageChunk +from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory -from ..history import BoundedChatHistory +from ..context_providers import ContextProviderException, find_commands from .base import BaseChatHandler, SlashCommandRoutingType @@ -27,6 +22,7 @@ class DefaultChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.prompt_template = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -40,95 +36,65 @@ def create_llm_chain( prompt_template = llm.get_chat_prompt_template() self.llm = llm + self.prompt_template = prompt_template - runnable = prompt_template | llm + runnable = prompt_template | llm # type:ignore if not llm.manages_history: - history = BoundedChatHistory(k=2) runnable = RunnableWithMessageHistory( - runnable=runnable, - get_session_history=lambda *args: history, + runnable=runnable, # type:ignore[arg-type] + get_session_history=self.get_llm_chat_memory, input_messages_key="input", history_messages_key="history", + history_factory_config=[ + ConfigurableFieldSpec( + id="last_human_msg", + annotation=HumanChatMessage, + ), + ], ) - self.llm_chain = runnable - def _start_stream(self, human_msg: HumanChatMessage) -> str: - """ - Sends an `agent-stream` message to indicate the start of a response - stream. Returns the ID of the message, denoted as the `stream_id`. - """ - stream_id = uuid4().hex - stream_msg = AgentStreamMessage( - id=stream_id, - time=time.time(), - body="", - reply_to=human_msg.id, - persona=self.persona, - complete=False, - ) - - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(stream_msg) - break - - return stream_id - - def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False): - """ - Sends an `agent-stream-chunk` message containing content that should be - appended to an existing `agent-stream` message with ID `stream_id`. - """ - stream_chunk_msg = AgentStreamChunkMessage( - id=stream_id, content=content, stream_complete=complete - ) - - for handler in self._root_chat_handlers.values(): - if not handler: - continue - - handler.broadcast_message(stream_chunk_msg) - break - async def process_message(self, message: HumanChatMessage): self.get_llm_chain() - received_first_chunk = False - - # start with a pending message - with self.pending("Generating response") as pending_message: - ut = UsageTracker() - ut._SendCopilotEvent({ - "event_type": "chat", - "include_selection": message.selection is not None, - "model_type": "language", - "model_name": self.llm.model_id, - "model_provider_id": self.config_manager.lm_provider.id, - "prompt_word_count": len(message.body.split(" ")) - }) - # stream response in chunks. this works even if a provider does not - # implement streaming, as `astream()` defaults to yielding `_call()` - # when `_stream()` is not implemented on the LLM class. - async for chunk in self.llm_chain.astream( - {"input": message.body}, - config={"configurable": {"session_id": "static_session"}}, - ): - if not received_first_chunk: - # when receiving the first chunk, close the pending message and - # start the stream. - self.close_pending(pending_message) - stream_id = self._start_stream(human_msg=message) - received_first_chunk = True - - if isinstance(chunk, AIMessageChunk): - self._send_stream_chunk(stream_id, chunk.content) - elif isinstance(chunk, str): - self._send_stream_chunk(stream_id, chunk) - else: - self.log.error(f"Unrecognized type of chunk yielded: {type(chunk)}") - break + assert self.llm_chain + + inputs = {"input": message.body} + ut = UsageTracker() + ut._SendCopilotEvent({ + "event_type": "chat", + "include_selection": message.selection is not None, + "model_type": "language", + "model_name": self.llm.model_id, + "model_provider_id": self.config_manager.lm_provider.id, + "prompt_word_count": len(message.body.split(" ")) + }) + + if "context" in self.prompt_template.input_variables: + # include context from context providers. + try: + context_prompt = await self.make_context_prompt(message) + except ContextProviderException as e: + self.reply(str(e), message) + return + inputs["context"] = context_prompt + inputs["input"] = self.replace_prompt(inputs["input"]) + + await self.stream_reply(inputs, message) + + async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: + return "\n\n".join( + await asyncio.gather( + *[ + provider.make_context_prompt(human_msg) + for provider in self.context_providers.values() + if find_commands(provider, human_msg.prompt) + ] + ) + ) - # complete stream after all chunks have been streamed - self._send_stream_chunk(stream_id, "", complete=True) + def replace_prompt(self, prompt: str) -> str: + # modifies prompt by the context providers. + # some providers may modify or remove their '@' commands from the prompt. + for provider in self.context_providers.values(): + prompt = provider.replace_prompt(prompt) + return prompt diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py index ed478f57..7323d81c 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import List -from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai.models import AgentChatMessage, AgentStreamMessage, HumanChatMessage from .base import BaseChatHandler, SlashCommandRoutingType @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): self.parser.add_argument("path", nargs=argparse.REMAINDER) def chat_message_to_markdown(self, message): - if isinstance(message, AgentChatMessage): + if isinstance(message, (AgentChatMessage, AgentStreamMessage)): agent = self.config_manager.persona.name return f"**{agent}**: {message.body}" elif isinstance(message, HumanChatMessage): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index 633110fb..30e941fa 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -3,7 +3,6 @@ from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage from jupyter_ai_magics.models.usage_tracking import UsageTracker from jupyter_ai_magics.providers import BaseProvider -from langchain.chains import LLMChain from langchain.prompts import PromptTemplate from .base import BaseChatHandler, SlashCommandRoutingType @@ -65,6 +64,7 @@ class FixChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.prompt_template = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -74,9 +74,11 @@ def create_llm_chain( **(self.get_model_parameters(provider, provider_params)), } llm = provider(**unified_parameters) - self.llm = llm - self.llm_chain = LLMChain(llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True) + prompt_template = FIX_PROMPT_TEMPLATE + + runnable = prompt_template | llm # type:ignore + self.llm_chain = runnable async def process_message(self, message: HumanChatMessage): if not (message.selection and message.selection.type == "cell-with-error"): @@ -93,23 +95,23 @@ async def process_message(self, message: HumanChatMessage): extra_instructions = message.prompt[4:].strip() or "None." self.get_llm_chain() - with self.pending("Analyzing error"): - response = await self.llm_chain.apredict( - extra_instructions=extra_instructions, - stop=["\nHuman:"], - cell_content=selection.source, - error_name=selection.error.name, - error_value=selection.error.value, - traceback="\n".join(selection.error.traceback), - ) - ut = UsageTracker() - ut._SendCopilotEvent({ - "event_details": "/fix", - "event_type": "slash", - "include_selection": message.selection is not None, - "model_type": "language", - "model_name": self.llm.model_id, - "model_provider_id": self.config_manager.lm_provider.id, - "prompt_word_count": len(extra_instructions) - }) - self.reply(response, message) + assert self.llm_chain + + inputs = { + "extra_instructions": extra_instructions, + "cell_content": selection.source, + "traceback": "\n".join(selection.error.traceback), + "error_name": selection.error.name, + "error_value": selection.error.value, + } + ut = UsageTracker() + ut._SendCopilotEvent({ + "event_details": "/fix", + "event_type": "slash", + "include_selection": message.selection is not None, + "model_type": "language", + "model_name": self.llm.model_id, + "model_provider_id": self.config_manager.lm_provider.id, + "prompt_word_count": len(extra_instructions) + }) + await self.stream_reply(inputs, message, pending_msg="Analyzing error") diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 654b12aa..95545009 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -227,7 +227,7 @@ class GenerateChatHandler(BaseChatHandler): def __init__(self, log_dir: Optional[str], *args, **kwargs): super().__init__(*args, **kwargs) self.log_dir = Path(log_dir) if log_dir else None - self.llm = None + self.llm: Optional[BaseProvider] = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -249,6 +249,7 @@ async def _generate_notebook(self, prompt: str): # Save the user input prompt, the description property is now LLM generated. outline["prompt"] = prompt + assert self.llm if self.llm.allows_concurrency: # fill the outline concurrently await afill_outline(outline, llm=self.llm, verbose=True) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index 0e82be13..cd855686 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -1,53 +1,7 @@ -import time -from typing import Dict -from uuid import uuid4 - -from jupyter_ai.models import AgentChatMessage, HumanChatMessage -from jupyter_ai_magics import Persona +from jupyter_ai.models import HumanChatMessage from .base import BaseChatHandler, SlashCommandRoutingType -HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant. -You can ask me a question using the text box below. You can also use these commands: -{commands} - -Jupyter AI includes [magic commands](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#the-ai-and-ai-magic-commands) that you can use in your notebooks. -For more information, see the [documentation](https://jupyter-ai.readthedocs.io). -""" - - -def _format_help_message( - chat_handlers: Dict[str, BaseChatHandler], - persona: Persona, - unsupported_slash_commands: set, -): - if unsupported_slash_commands: - keys = set(chat_handlers.keys()) - unsupported_slash_commands - chat_handlers = {key: chat_handlers[key] for key in keys} - - commands = "\n".join( - [ - f"* `{command_name}` — {handler.help}" - for command_name, handler in chat_handlers.items() - if command_name != "default" - ] - ) - return HELP_MESSAGE.format(commands=commands, persona_name=persona.name) - - -def build_help_message( - chat_handlers: Dict[str, BaseChatHandler], - persona: Persona, - unsupported_slash_commands: set, -): - return AgentChatMessage( - id=uuid4().hex, - time=time.time(), - body=_format_help_message(chat_handlers, persona, unsupported_slash_commands), - reply_to="", - persona=Persona(name=persona.name, avatar_route=persona.avatar_route), - ) - class HelpChatHandler(BaseChatHandler): id = "help" @@ -58,19 +12,8 @@ class HelpChatHandler(BaseChatHandler): uses_llm = False - def __init__(self, *args, chat_handlers: Dict[str, BaseChatHandler], **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._chat_handlers = chat_handlers async def process_message(self, message: HumanChatMessage): - persona = self.config_manager.persona - lm_provider = self.config_manager.lm_provider - unsupported_slash_commands = ( - lm_provider.unsupported_slash_commands if lm_provider else set() - ) - self.reply( - _format_help_message( - self._chat_handlers, persona, unsupported_slash_commands - ), - message, - ) + self.send_help_message(message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 4fe490d8..6dfa1ee4 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -1,6 +1,7 @@ import argparse import json import os +from glob import iglob from typing import Any, Coroutine, List, Optional, Tuple from dask.distributed import Client as DaskClient @@ -175,20 +176,34 @@ async def process_message(self, message: HumanChatMessage): return # Make sure the path exists. - if not len(args.path) == 1: - self.reply(f"{self.parser.format_usage()}", message) + if not (len(args.path) == 1 and args.path[0]): + no_path_arg_message = ( + "Please specify a directory or pattern you would like to " + 'learn on. "/learn" supports directories relative to ' + "the root (or preferred dir, if set) and Unix-style " + "wildcard matching.\n\n" + "Examples:\n" + "- Learn on the root directory recursively: `/learn .`\n" + "- Learn on files in the root directory: `/learn *`\n" + "- Learn all python files under the root directory recursively: `/learn **/*.py`" + ) + self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}") return short_path = args.path[0] load_path = os.path.join(self.output_dir, short_path) if not os.path.exists(load_path): - response = f"Sorry, that path doesn't exist: {load_path}" - self.reply(response, message) - return + try: + # check if globbing the load path will return anything + next(iglob(load_path)) + except StopIteration: + response = f"Sorry, that path doesn't exist: {load_path}" + self.reply(response, message) + return # delete and relearn index if embedding model was changed await self.delete_and_relearn() - with self.pending(f"Loading and splitting files for {load_path}"): + with self.pending(f"Loading and splitting files for {load_path}", message): try: _, em_provider_args = self.get_embedding_provider() ut = UsageTracker() @@ -205,11 +220,16 @@ async def process_message(self, message: HumanChatMessage): load_path, args.chunk_size, args.chunk_overlap, args.all_files ) except Exception as e: - response = f"""Learn documents in **{load_path}** failed. {str(e)}.""" + response = """Learn documents in **{}** failed. {}.""".format( + load_path.replace("*", r"\*"), + str(e), + ) else: self.save() - response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. - You can ask questions about these docs by prefixing your message with **/ask**.""" + response = """🎉 I have learned documents at **%s** and I am ready to answer questions about them. + You can ask questions about these docs by prefixing your message with **/ask**.""" % ( + load_path.replace("*", r"\*") + ) self.reply(response, message) def _build_list_response(self): @@ -235,7 +255,9 @@ async def learn_dir( } splitter = ExtensionSplitter( splitters=splitters, - default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs), + default_splitter=RecursiveCharacterTextSplitter( + **splitter_kwargs # type:ignore[arg-type] + ), ) delayed = split(path, all_files, splitter=splitter) @@ -364,7 +386,7 @@ async def aget_relevant_documents( self, query: str ) -> Coroutine[Any, Any, List[Document]]: if not self.index: - return [] + return [] # type:ignore[return-value] await self.delete_and_relearn() docs = self.index.similarity_search(query) @@ -382,12 +404,14 @@ def get_embedding_model(self): class Retriever(BaseRetriever): - learn_chat_handler: LearnChatHandler = None + learn_chat_handler: LearnChatHandler = None # type:ignore[assignment] - def _get_relevant_documents(self, query: str) -> List[Document]: + def _get_relevant_documents( # type:ignore[override] + self, query: str + ) -> List[Document]: raise NotImplementedError() - async def _aget_relevant_documents( + async def _aget_relevant_documents( # type:ignore[override] self, query: str ) -> Coroutine[Any, Any, List[Document]]: docs = await self.learn_chat_handler.aget_relevant_documents(query) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 17d7ddde..a1ccd9e3 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -3,7 +3,8 @@ import os import shutil import time -from typing import List, Optional, Union +from copy import deepcopy +from typing import List, Optional, Type, Union from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator @@ -60,7 +61,7 @@ class BlockedModelError(Exception): pass -def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider): +def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]): # TODO: handle non-env auth strategies if not provider.auth_strategy or provider.auth_strategy.type != "env": return @@ -106,11 +107,11 @@ def __init__( log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, - allowed_providers: Optional[List[str]], - blocked_providers: Optional[List[str]], - allowed_models: Optional[List[str]], - blocked_models: Optional[List[str]], defaults: dict, + allowed_providers: Optional[List[str]] = None, + blocked_providers: Optional[List[str]] = None, + allowed_models: Optional[List[str]] = None, + blocked_models: Optional[List[str]] = None, *args, **kwargs, ): @@ -127,7 +128,13 @@ def __init__( self._allowed_models = allowed_models self._blocked_models = blocked_models self._defaults = defaults - """Provider defaults.""" + """ + Dictionary that maps config keys (e.g. `model_provider_id`, `fields`) to + user-specified overrides, set by traitlets configuration. + + Values in this dictionary should never be mutated as they may refer to + entries in the global `self.settings` dictionary. + """ self._last_read: Optional[int] = None """When the server last read the config file. If the file was not @@ -147,7 +154,7 @@ def _init_config_schema(self): os.makedirs(os.path.dirname(self.schema_path), exist_ok=True) shutil.copy(OUR_SCHEMA_PATH, self.schema_path) - def _init_validator(self) -> Validator: + def _init_validator(self) -> None: with open(OUR_SCHEMA_PATH, encoding="utf-8") as f: schema = json.loads(f.read()) Validator.check_schema(schema) @@ -218,19 +225,22 @@ def _create_default_config(self, default_config): self._write_config(GlobalConfig(**default_config)) def _init_defaults(self): - field_list = GlobalConfig.__fields__.keys() - properties = self.validator.schema.get("properties", {}) - field_dict = { - field: properties.get(field).get("default") for field in field_list + config_keys = GlobalConfig.__fields__.keys() + schema_properties = self.validator.schema.get("properties", {}) + default_config = { + field: schema_properties.get(field).get("default") for field in config_keys } if self._defaults is None: - return field_dict + return default_config - for field in field_list: - default_value = self._defaults.get(field) + for config_key in config_keys: + # we call `deepcopy()` here to avoid directly referring to the + # values in `self._defaults`, as they map to entries in the global + # `self.settings` dictionary and may be mutated otherwise. + default_value = deepcopy(self._defaults.get(config_key)) if default_value is not None: - field_dict[field] = default_value - return field_dict + default_config[config_key] = default_value + return default_config def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. @@ -364,7 +374,7 @@ def delete_api_key(self, key_name: str): config_dict["api_keys"].pop(key_name, None) self._write_config(GlobalConfig(**config_dict)) - def update_config(self, config_update: UpdateConfigRequest): + def update_config(self, config_update: UpdateConfigRequest): # type:ignore last_write = os.stat(self.config_path).st_mtime_ns if config_update.last_read and config_update.last_read < last_write: raise WriteConflictError( @@ -432,20 +442,35 @@ def em_provider_params(self): @property def completions_lm_provider_params(self): return self._provider_params( - "completions_model_provider_id", self._lm_providers + "completions_model_provider_id", self._lm_providers, completions=True ) - def _provider_params(self, key, listing): - # get generic fields + def _provider_params(self, key, listing, completions: bool = False): + # read config config = self._read_config() - gid = getattr(config, key) - if not gid: + + # get model ID (without provider ID component) from model universal ID + # (with provider component). + model_uid = getattr(config, key) + if not model_uid: return None + model_id = model_uid.split(":", 1)[1] + + # get config fields (e.g. base API URL, etc.) + if completions: + fields = config.completions_fields.get(model_uid, {}) + else: + fields = config.fields.get(model_uid, {}) - lid = gid.split(":", 1)[1] + # exclude empty fields + # TODO: modify the config manager to never save empty fields in the + # first place. + for field_key in fields: + if isinstance(fields[field_key], str) and not len(fields[field_key]): + fields[field_key] = None # get authn fields - _, Provider = get_em_provider(gid, listing) + _, Provider = get_em_provider(model_uid, listing) authn_fields = {} if Provider.auth_strategy and Provider.auth_strategy.type == "env": keyword_param = ( @@ -456,7 +481,8 @@ def _provider_params(self, key, listing): authn_fields[keyword_param] = config.api_keys[key_name] return { - "model_id": lid, + "model_id": model_id, + **fields, **authn_fields, } diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py new file mode 100644 index 00000000..7c521d84 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/__init__.py @@ -0,0 +1,7 @@ +from .base import ( + BaseCommandContextProvider, + ContextCommand, + ContextProviderException, + find_commands, +) +from .file import FileContextProvider diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/_learned.py b/packages/jupyter-ai/jupyter_ai/context_providers/_learned.py new file mode 100644 index 00000000..5128487d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/_learned.py @@ -0,0 +1,53 @@ +# Currently unused as it is duplicating the functionality of the /ask command. +# TODO: Rename "learned" to something better. +from typing import List + +from jupyter_ai.chat_handlers.learn import Retriever +from jupyter_ai.models import HumanChatMessage + +from .base import BaseCommandContextProvider, ContextCommand +from .file import FileContextProvider + +FILE_CHUNK_TEMPLATE = """ +Snippet from file: {filepath} +``` +{content} +``` +""".strip() + + +class LearnedContextProvider(BaseCommandContextProvider): + id = "learned" + help = "Include content indexed from `/learn`" + remove_from_prompt = True + header = "Following are snippets from potentially relevant files:" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"]) + + async def _make_context_prompt( + self, message: HumanChatMessage, commands: List[ContextCommand] + ) -> str: + if not self.retriever: + return "" + query = self._clean_prompt(message.body) + docs = await self.retriever.ainvoke(query) + excluded = self._get_repeated_files(message) + context = "\n\n".join( + [ + FILE_CHUNK_TEMPLATE.format( + filepath=d.metadata["path"], content=d.page_content + ) + for d in docs + if d.metadata["path"] not in excluded and d.page_content + ] + ) + return self.header + "\n" + context + + def _get_repeated_files(self, message: HumanChatMessage) -> List[str]: + # don't include files that are already provided by the file context provider + file_context_provider = self.context_providers.get("file") + if isinstance(file_context_provider, FileContextProvider): + return file_context_provider.get_filepaths(message) + return [] diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/base.py b/packages/jupyter-ai/jupyter_ai/context_providers/base.py new file mode 100644 index 00000000..1b0953e8 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/base.py @@ -0,0 +1,243 @@ +import abc +import os +import re +from typing import TYPE_CHECKING, Awaitable, ClassVar, Dict, List, Optional + +from dask.distributed import Client as DaskClient +from jupyter_ai.chat_handlers.base import get_preferred_dir +from jupyter_ai.config_manager import ConfigManager, Logger +from jupyter_ai.models import ChatMessage, HumanChatMessage, ListOptionsEntry +from langchain.pydantic_v1 import BaseModel + +if TYPE_CHECKING: + from jupyter_ai.chat_handlers import BaseChatHandler + from jupyter_ai.history import BoundedChatHistory + + +class _BaseContextProvider(abc.ABC): + id: ClassVar[str] + """Unique identifier for the context provider command.""" + help: ClassVar[str] + """What this chat handler does, which third-party models it contacts, + the data it returns to the user, and so on, for display in the UI.""" + + def __init__( + self, + *, + log: Logger, + config_manager: ConfigManager, + model_parameters: Dict[str, Dict], + chat_history: List[ChatMessage], + llm_chat_memory: "BoundedChatHistory", + root_dir: str, + preferred_dir: Optional[str], + dask_client_future: Awaitable[DaskClient], + chat_handlers: Dict[str, "BaseChatHandler"], + context_providers: Dict[str, "BaseCommandContextProvider"], + ): + preferred_dir = preferred_dir or "" + self.log = log + self.config_manager = config_manager + self.model_parameters = model_parameters + self._chat_history = chat_history + self.llm_chat_memory = llm_chat_memory + self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) + self.preferred_dir = get_preferred_dir(self.root_dir, preferred_dir) + self.dask_client_future = dask_client_future + self.chat_handlers = chat_handlers + self.context_providers = context_providers + + self.llm = None + + @abc.abstractmethod + async def make_context_prompt(self, message: HumanChatMessage) -> str: + """Returns a context prompt for all commands of the context provider + command. + """ + pass + + def replace_prompt(self, prompt: str) -> str: + """Modifies the prompt before sending it to the LLM.""" + return prompt + + def _clean_prompt(self, text: str) -> str: + # util for cleaning up the prompt before sending it to a retriever + for provider in self.context_providers.values(): + text = provider.replace_prompt(text) + return text + + @property + def base_dir(self) -> str: + # same as BaseChatHandler.output_dir + if self.preferred_dir and os.path.exists(self.preferred_dir): + return self.preferred_dir + else: + return self.root_dir + + def get_llm(self): + lm_provider = self.config_manager.lm_provider + lm_provider_params = self.config_manager.lm_provider_params + + curr_lm_id = ( + f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None + ) + next_lm_id = ( + f'{lm_provider.id}:{lm_provider_params["model_id"]}' + if lm_provider + else None + ) + + if not lm_provider or not lm_provider_params: + return None + + if curr_lm_id != next_lm_id: + model_parameters = self.model_parameters.get( + f"{lm_provider.id}:{lm_provider_params['model_id']}", {} + ) + unified_parameters = { + "verbose": True, + **lm_provider_params, + **model_parameters, + } + llm = lm_provider(**unified_parameters) + self.llm = llm + return self.llm + + +class ContextCommand(BaseModel): + cmd: str + + @property + def id(self) -> str: + return self.cmd.partition(":")[0] + + @property + def arg(self) -> Optional[str]: + if ":" not in self.cmd: + return None + return self.cmd.partition(":")[2].strip("'\"").replace("\\ ", " ") + + def __str__(self) -> str: + return self.cmd + + def __hash__(self) -> int: + return hash(self.cmd) + + +class BaseCommandContextProvider(_BaseContextProvider): + id_prefix: ClassVar[str] = "@" + """Prefix symbol for command. Generally should not be overridden.""" + + # Configuration + requires_arg: ClassVar[bool] = False + """Whether command has an argument. E.g. '@file:'.""" + remove_from_prompt: ClassVar[bool] = False + """Whether the command should be removed from prompt when passing to LLM.""" + only_start: ClassVar[bool] = False + """Whether to command can only be inserted at the start of the prompt.""" + + @property + def command_id(self) -> str: + return self.id_prefix + self.id + + @property + def pattern(self) -> str: + # arg pattern allows for arguments between quotes or spaces with escape character ('\ ') + return ( + rf"(? str: + """Returns a context prompt for all commands of the context provider + command. + """ + commands = find_commands(self, message.prompt) + if not commands: + return "" + return await self._make_context_prompt(message, commands) + + @abc.abstractmethod + async def _make_context_prompt( + self, message: HumanChatMessage, commands: List[ContextCommand] + ) -> str: + """Returns a context prompt for the given commands.""" + pass + + def replace_prompt(self, prompt: str) -> str: + """Cleans up commands from the prompt before sending it to the LLM""" + + def replace(match): + if _is_command_call(match, prompt): + return self._replace_command(ContextCommand(cmd=match.group())) + return match.group() + + return re.sub(self.pattern, replace, prompt) + + def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: + """Returns a list of autocomplete options for arguments to the command + based on the prefix. + Only triggered if ':' is present after the command id (e.g. '@file:'). + """ + if self.requires_arg: + # default implementation that should be modified if 'requires_arg' is True + return [self._make_arg_option(arg_prefix)] + return [] + + def _replace_command(self, command: ContextCommand) -> str: + if self.remove_from_prompt: + return "" + return command.cmd + + def _make_arg_option( + self, + arg: str, + *, + is_complete: bool = True, + description: Optional[str] = None, + ) -> ListOptionsEntry: + arg = arg.replace("\\ ", " ").replace(" ", "\\ ") # escape spaces + label = self.command_id + ":" + arg + (" " if is_complete else "") + return ListOptionsEntry( + id=self.command_id, + description=description or self.help, + label=label, + only_start=self.only_start, + ) + + +def find_commands( + context_provider: BaseCommandContextProvider, text: str +) -> List[ContextCommand]: + # finds commands of the context provider in the text + matches = list(re.finditer(context_provider.pattern, text)) + if context_provider.only_start: + matches = [match for match in matches if match.start() == 0] + results = [] + for match in matches: + if _is_command_call(match, text): + results.append(ContextCommand(cmd=match.group())) + return results + + +class ContextProviderException(Exception): + # Used to generate a response when a context provider fails + pass + + +def _is_command_call(match, text): + """Check if the match is a command call rather than a part of a code block. + This is done by checking if there is an even number of backticks before and + after the match. If there is an odd number of backticks, the match is likely + inside a code block. + """ + # potentially buggy if there is a stray backtick in text + # e.g. "help me count the backticks '`' ... ```\n...@cmd in code\n```". + # can be addressed by having selection in context rather than in prompt. + # more generally addressed by having a better command detection mechanism + # such as placing commands within special tags. + start, end = match.span() + before = text[:start] + after = text[end:] + return before.count("`") % 2 == 0 or after.count("`") % 2 == 0 diff --git a/packages/jupyter-ai/jupyter_ai/context_providers/file.py b/packages/jupyter-ai/jupyter_ai/context_providers/file.py new file mode 100644 index 00000000..45619122 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/context_providers/file.py @@ -0,0 +1,169 @@ +import glob +import os +from typing import List + +import nbformat +from jupyter_ai.document_loaders.directory import SUPPORTED_EXTS +from jupyter_ai.models import HumanChatMessage, ListOptionsEntry + +from .base import ( + BaseCommandContextProvider, + ContextCommand, + ContextProviderException, + find_commands, +) + +FILE_CONTEXT_TEMPLATE = """ +File: {filepath} +``` +{content} +``` +""".strip() + + +class FileContextProvider(BaseCommandContextProvider): + id = "file" + help = "Include selected file's contents" + requires_arg = True + header = "Following are contents of files referenced:" + + def get_arg_options(self, arg_prefix: str) -> List[ListOptionsEntry]: + is_abs = not os.path.isabs(arg_prefix) + path_prefix = arg_prefix if is_abs else os.path.join(self.base_dir, arg_prefix) + path_prefix = path_prefix + return [ + self._make_arg_option( + arg=self._make_path(path, is_abs, is_dir), + description="Directory" if is_dir else "File", + is_complete=not is_dir, + ) + for path in glob.glob(path_prefix + "*") + if ( + (is_dir := os.path.isdir(path)) + or os.path.splitext(path)[1] in SUPPORTED_EXTS + ) + ] + + def _make_path(self, path: str, is_abs: bool, is_dir: bool) -> str: + if not is_abs: + path = os.path.relpath(path, self.base_dir) + if is_dir: + path += "/" + return path + + def get_file_type(self, filepath): + """ + Determine the file type of the given file path. + + Args: + filepath (str): The file path to analyze. + + Returns: + str: The file type as a string, e.g. '.txt', '.png', '.pdf', etc. + """ + file_extension = os.path.splitext(filepath)[1].lower() + + # Check if the file is a binary blob + try: + with open(filepath, "rb") as file: + file_header = file.read(4) + if ( + file_header == b"\x89PNG" + or file_header == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a" + ): + return ".png" + elif file_header == b"\xff\xd8\xff\xe0": + return ".jpg" + elif file_header == b"GIF87a" or file_header == b"GIF89a": + return ".gif" + elif file_header == b"\x1f\x8b\x08": + return ".gz" + elif file_header == b"\x50\x4b\x03\x04": + return ".zip" + elif file_header == b"\x75\x73\x74\x61\x72": + return ".tar" + elif file_header == b"\x25\x50\x44\x46": + return ".pdf" + else: + return file_extension + except: + return file_extension + + async def _make_context_prompt( + self, message: HumanChatMessage, commands: List[ContextCommand] + ) -> str: + context = "\n\n".join( + [ + context + for i in set(commands) + if (context := self._make_command_context(i)) + ] + ) + if not context: + return "" + return self.header + "\n" + context + + def _make_command_context(self, command: ContextCommand) -> str: + filepath = command.arg or "" + if not os.path.isabs(filepath): + filepath = os.path.join(self.base_dir, filepath) + + if not os.path.exists(filepath): + raise ContextProviderException( + f"File not found while trying to read '{filepath}' " + f"triggered by `{command}`." + ) + if os.path.isdir(filepath): + raise ContextProviderException( + f"Cannot read directory '{filepath}' triggered by `{command}`. " + f"Only files are supported." + ) + if os.path.splitext(filepath)[1] not in SUPPORTED_EXTS: + raise ContextProviderException( + f"Cannot read unsupported file type '{filepath}' triggered by `{command}`. " + f"Supported file extensions are: {', '.join(SUPPORTED_EXTS)}." + ) + try: + with open(filepath) as f: + content = f.read() + except PermissionError: + raise ContextProviderException( + f"Permission denied while trying to read '{filepath}' " + f"triggered by `{command}`." + ) + except UnicodeDecodeError: + file_extension = self.get_file_type(filepath) + if file_extension: + raise ContextProviderException( + f"The `{file_extension}` file format is not supported for passing context to the LLM. " + f"The `@file` command only supports plaintext files." + ) + else: + raise ContextProviderException( + f"This file format is not supported for passing context to the LLM. " + f"The `@file` command only supports plaintext files." + ) + return FILE_CONTEXT_TEMPLATE.format( + filepath=filepath, + content=self._process_file(content, filepath), + ) + + def _process_file(self, content: str, filepath: str): + if filepath.endswith(".ipynb"): + nb = nbformat.reads(content, as_version=4) + return "\n\n".join([cell.source for cell in nb.cells]) + return content + + def _replace_command(self, command: ContextCommand) -> str: + # replaces commands of @file: with '' + filepath = command.arg or "" + return f"'{filepath}'" + + def get_filepaths(self, message: HumanChatMessage) -> List[str]: + filepaths = [] + for command in find_commands(self, message.prompt): + filepath = command.arg or "" + if not os.path.isabs(filepath): + filepath = os.path.join(self.base_dir, filepath) + filepaths.append(filepath) + return filepaths diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index 7b9b2832..e2840a90 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py @@ -3,6 +3,7 @@ import os import tarfile from datetime import datetime +from glob import iglob from pathlib import Path from typing import List @@ -30,7 +31,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str: output path to the downloaded TeX file """ - import arxiv + import arxiv # type:ignore[import-not-found,import-untyped] outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex" download_filename = "downloaded-paper.tar.gz" @@ -85,6 +86,7 @@ def path_to_doc(path): SUPPORTED_EXTS = { ".py", ".md", + ".qmd", ".R", ".Rmd", ".jl", @@ -97,7 +99,8 @@ def path_to_doc(path): ".txt", ".html", ".pdf", - ".tex", # added for raw latex files from arxiv + ".tex", + ".json", } @@ -109,6 +112,18 @@ def flatten(*chunk_lists): return list(itertools.chain(*chunk_lists)) +def walk_directory(directory, all_files): + filepaths = [] + for dir, subdirs, filenames in os.walk(directory): + # Filter out hidden filenames, hidden directories, and excluded directories, + # unless "all files" are requested + if not all_files: + subdirs[:] = [d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS)] + filenames = [f for f in filenames if not f[0] == "."] + filepaths += [Path(dir) / filename for filename in filenames] + return filepaths + + def collect_filepaths(path, all_files: bool): """Selects eligible files, i.e., 1. Files not in excluded directories, and @@ -119,17 +134,13 @@ def collect_filepaths(path, all_files: bool): # Check if the path points to a single file if os.path.isfile(path): filepaths = [Path(path)] + elif os.path.isdir(path): + filepaths = walk_directory(path, all_files) else: filepaths = [] - for dir, subdirs, filenames in os.walk(path): - # Filter out hidden filenames, hidden directories, and excluded directories, - # unless "all files" are requested - if not all_files: - subdirs[:] = [ - d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS) - ] - filenames = [f for f in filenames if not f[0] == "."] - filepaths.extend([Path(dir) / filename for filename in filenames]) + for glob_path in iglob(str(path), recursive=True): + if os.path.isfile(glob_path): + filepaths.append(Path(glob_path)) valid_exts = {j.lower() for j in SUPPORTED_EXTS} filepaths = [fp for fp in filepaths if fp.suffix.lower() in valid_exts] return filepaths diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index d1530e0c..a12ca46c 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -10,7 +10,7 @@ from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from tornado.web import StaticFileHandler -from traitlets import Dict, List, Unicode +from traitlets import Dict, Integer, List, Unicode from .chat_handlers import ( AskChatHandler, @@ -22,11 +22,12 @@ HelpChatHandler, LearnChatHandler, ) -from .chat_handlers.help import build_help_message from .completions.handlers import DefaultInlineCompletionHandler from .config_manager import ConfigManager +from .context_providers import BaseCommandContextProvider, FileContextProvider from .handlers import ( ApiKeysHandler, + AutocompleteOptionsHandler, ChatHistoryHandler, EmbeddingsModelProviderHandler, GlobalConfigHandler, @@ -35,6 +36,7 @@ SlashCommandsInfoHandler, UsageTrackingHandler, ) +from .history import BoundedChatHistory CLOUDERA_COPILOT_AVATAR_ROUTE = ClouderaCopilotPersona.avatar_route CLOUDERA_COPILOT_AVATAR_PATH = str( @@ -42,14 +44,27 @@ ) +DEFAULT_HELP_MESSAGE_TEMPLATE = """Hi there! I'm {persona_name}, your programming assistant. +You can ask me a question using the text box below. You can also use these commands: +{slash_commands_list} + +You can use the following commands to add context to your questions: +{context_commands_list} + +Jupyter AI includes [magic commands](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#the-ai-and-ai-magic-commands) that you can use in your notebooks. +For more information, see the [documentation](https://jupyter-ai.readthedocs.io). +""" + + class AiExtension(ExtensionApp): name = "jupyter_ai" - handlers = [ + handlers = [ # type:ignore[assignment] (r"api/ai/api_keys/(?P\w+)", ApiKeysHandler), (r"api/ai/config/?", GlobalConfigHandler), (r"api/ai/chats/?", RootChatHandler), (r"api/ai/chats/history?", ChatHistoryHandler), (r"api/ai/chats/slash_commands?", SlashCommandsInfoHandler), + (r"api/ai/chats/autocomplete_options?", AutocompleteOptionsHandler), (r"api/ai/providers?", ModelProviderHandler), (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), (r"api/ai/completion/inline/?", DefaultInlineCompletionHandler), @@ -160,6 +175,37 @@ class AiExtension(ExtensionApp): config=True, ) + help_message_template = Unicode( + default_value=DEFAULT_HELP_MESSAGE_TEMPLATE, + help=""" + A format string accepted by `str.format()`, which is used to generate a + dynamic help message. The format string should contain exactly two + named replacement fields: `persona_name` and `slash_commands_list`. + + - `persona_name`: String containing the name of the persona, which is + defined by the configured language model. Usually defaults to + 'Jupyternaut'. + + - `slash_commands_list`: A string containing a bulleted list of the + slash commands available to the configured language model. + """, + config=True, + ) + + default_max_chat_history = Integer( + default_value=2, + help=""" + Number of chat interactions to keep in the conversational memory object. + + An interaction is defined as an exchange between a human and AI, thus + comprising of one or two messages. + + Set to `None` to keep all interactions. + """, + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time() @@ -224,6 +270,11 @@ def initialize_settings(self): # memory object used by the LM chain. self.settings["chat_history"] = [] + # conversational memory object used by LM chain + self.settings["llm_chat_memory"] = BoundedChatHistory( + k=self.default_max_chat_history + ) + # list of pending messages self.settings["pending_messages"] = [] @@ -240,24 +291,84 @@ def initialize_settings(self): # consumers a Future that resolves to the Dask client when awaited. self.settings["dask_client_future"] = loop.create_task(self._get_dask_client()) - eps = entry_points() + # Create empty context providers dict to be filled later. + # This is created early to use as kwargs for chat handlers. + self.settings["jai_context_providers"] = {} - common_handler_kargs = { - "log": self.log, - "config_manager": self.settings["jai_config_manager"], - "model_parameters": self.settings["model_parameters"], - } + # Create empty dictionary for events communicating that + # message generation/streaming got interrupted. + self.settings["jai_message_interrupted"] = {} # initialize chat handlers + self._init_chat_handlers() + + # initialize context providers + self._init_context_provders() + + # show help message at server start + self._show_help_message() + + latency_ms = round((time.time() - start) * 1000) + self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") + + def _show_help_message(self): + """ + Method that ensures a dynamically-generated help message is included in + the chat history shown to users. + """ + # call `send_help_message()` on any instance of `BaseChatHandler`. The + # `default` chat handler should always exist, so we reference that + # object when calling `send_help_message()`. + default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][ + "default" + ] + default_chat_handler.send_help_message() + + async def _get_dask_client(self): + return DaskClient(processes=False, asynchronous=True) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_exception()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter AI raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. + """ + if "dask_client_future" in self.settings: + dask_client: DaskClient = await self.settings["dask_client_future"] + self.log.info("Closing Dask client.") + await dask_client.close() + self.log.debug("Closed Dask client.") + + def _init_chat_handlers(self): + eps = entry_points() chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") + chat_handlers = {} chat_handler_kwargs = { - **common_handler_kargs, + "log": self.log, + "config_manager": self.settings["jai_config_manager"], + "model_parameters": self.settings["model_parameters"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], "chat_history": self.settings["chat_history"], + "llm_chat_memory": self.settings["llm_chat_memory"], "root_dir": self.serverapp.root_dir, "dask_client_future": self.settings["dask_client_future"], - "model_parameters": self.settings["model_parameters"], "preferred_dir": self.serverapp.contents_manager.preferred_dir, + "help_message_template": self.help_message_template, + "chat_handlers": chat_handlers, + "context_providers": self.settings["jai_context_providers"], + "message_interrupted": self.settings["jai_message_interrupted"], } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) @@ -273,19 +384,13 @@ def initialize_settings(self): fix_chat_handler = FixChatHandler(**chat_handler_kwargs) - jai_chat_handlers = { - "default": default_chat_handler, - "/ask": ask_chat_handler, - "/clear": clear_chat_handler, - "/generate": generate_chat_handler, - "/learn": learn_chat_handler, - "/export": export_chat_handler, - "/fix": fix_chat_handler, - } - - help_chat_handler = HelpChatHandler( - **chat_handler_kwargs, chat_handlers=jai_chat_handlers - ) + chat_handlers["default"] = default_chat_handler + chat_handlers["/ask"] = ask_chat_handler + chat_handlers["/clear"] = clear_chat_handler + chat_handlers["/generate"] = generate_chat_handler + chat_handlers["/learn"] = learn_chat_handler + chat_handlers["/export"] = export_chat_handler + chat_handlers["/fix"] = fix_chat_handler slash_command_pattern = r"^[a-zA-Z0-9_]+$" for chat_handler_ep in chat_handler_eps: @@ -321,74 +426,75 @@ def initialize_settings(self): ) continue - if command_name in jai_chat_handlers: + if command_name in chat_handlers: self.log.error( f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler" ) continue # The entry point is a class; we need to instantiate the class to send messages to it - jai_chat_handlers[command_name] = chat_handler(**chat_handler_kwargs) + chat_handlers[command_name] = chat_handler(**chat_handler_kwargs) self.log.info( f"Registered chat handler `{chat_handler.id}` with command `{command_name}`." ) # Make help always appear as the last command - jai_chat_handlers["/help"] = help_chat_handler + chat_handlers["/help"] = HelpChatHandler(**chat_handler_kwargs) # bind chat handlers to settings - self.settings["jai_chat_handlers"] = jai_chat_handlers - - # show help message at server start - self._show_help_message() - - latency_ms = round((time.time() - start) * 1000) - self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") - - def _show_help_message(self): - """ - Method that ensures a dynamically-generated help message is included in - the chat history shown to users. - """ - chat_handlers = self.settings["jai_chat_handlers"] - config_manager: ConfigManager = self.settings["jai_config_manager"] - lm_provider = config_manager.lm_provider + self.settings["jai_chat_handlers"] = chat_handlers - if not lm_provider: - return + def _init_context_provders(self): + eps = entry_points() + context_providers_eps = eps.select(group="jupyter_ai.context_providers") + context_providers = self.settings["jai_context_providers"] + context_providers_kwargs = { + "log": self.log, + "config_manager": self.settings["jai_config_manager"], + "model_parameters": self.settings["model_parameters"], + "chat_history": self.settings["chat_history"], + "llm_chat_memory": self.settings["llm_chat_memory"], + "root_dir": self.serverapp.root_dir, + "dask_client_future": self.settings["dask_client_future"], + "preferred_dir": self.serverapp.contents_manager.preferred_dir, + "chat_handlers": self.settings["jai_chat_handlers"], + "context_providers": self.settings["jai_context_providers"], + } + context_providers_clses = [ + FileContextProvider, + ] + for context_provider_ep in context_providers_eps: + try: + context_provider = context_provider_ep.load() + except Exception as err: + self.log.error( + f"Unable to load context provider class from entry point `{context_provider_ep.name}`: " + + f"Unexpected {err=}, {type(err)=}" + ) + continue + context_providers_clses.append(context_provider) - persona = config_manager.persona - unsupported_slash_commands = ( - lm_provider.unsupported_slash_commands if lm_provider else set() - ) - help_message = build_help_message( - chat_handlers, persona, unsupported_slash_commands - ) - self.settings["chat_history"].append(help_message) + for context_provider in context_providers_clses: + if not issubclass(context_provider, BaseCommandContextProvider): + self.log.error( + f"Unable to register context provider `{context_provider.id}` because it does not inherit from `BaseCommandContextProvider`" + ) + continue - async def _get_dask_client(self): - return DaskClient(processes=False, asynchronous=True) + if context_provider.id in context_providers: + self.log.error( + f"Unable to register context provider `{context_provider.id}` because it already exists" + ) + continue - async def stop_extension(self): - """ - Public method called by Jupyter Server when the server is stopping. - This calls the cleanup code defined in `self._stop_exception()` inside - an exception handler, as the server halts if this method raises an - exception. - """ - try: - await self._stop_extension() - except Exception as e: - self.log.error("Jupyter AI raised an exception while stopping:") - self.log.exception(e) + if not re.match(r"^[a-zA-Z0-9_]+$", context_provider.id): + self.log.error( + f"Context provider `{context_provider.id}` is an invalid ID; " + + f"must contain only letters, numbers, and underscores" + ) + continue - async def _stop_extension(self): - """ - Private method that defines the cleanup code to run when the server is - stopping. - """ - if "dask_client_future" in self.settings: - dask_client: DaskClient = await self.settings["dask_client_future"] - self.log.info("Closing Dask client.") - await dask_client.close() - self.log.debug("Closed Dask client.") + context_providers[context_provider.id] = context_provider( + **context_providers_kwargs + ) + self.log.info(f"Registered context provider `{context_provider.id}`.") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 27797dab..3f594a47 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -3,13 +3,14 @@ import os import time import uuid -from asyncio import AbstractEventLoop +from asyncio import AbstractEventLoop, Event from dataclasses import asdict -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Set, cast import tornado from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.config_manager import ConfigManager, KeyEmptyError, WriteConflictError +from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand from jupyter_ai_magics.models.usage_tracking import UsageTracker from jupyter_server.base.handlers import APIHandler as BaseAPIHandler from jupyter_server.base.handlers import JupyterHandler @@ -26,15 +27,20 @@ ChatMessage, ChatRequest, ChatUser, + ClearMessage, + ClearRequest, ClosePendingMessage, ConnectionMessage, HumanChatMessage, + ListOptionsEntry, + ListOptionsResponse, ListProvidersEntry, ListProvidersResponse, ListSlashCommandsEntry, ListSlashCommandsResponse, Message, PendingMessage, + StopRequest, UpdateConfigRequest, ) @@ -42,12 +48,13 @@ from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider from jupyter_ai_magics.providers import BaseProvider + from .context_providers import BaseCommandContextProvider + from .history import BoundedChatHistory + class ChatHistoryHandler(BaseAPIHandler): """Handler to return message history""" - _messages = [] - @property def chat_history(self) -> List[ChatMessage]: return self.settings["chat_history"] @@ -116,6 +123,14 @@ def chat_history(self) -> List[ChatMessage]: def chat_history(self, new_history): self.settings["chat_history"] = new_history + @property + def message_interrupted(self) -> Dict[str, Event]: + return self.settings["jai_message_interrupted"] + + @property + def llm_chat_memory(self) -> "BoundedChatHistory": + return self.settings["llm_chat_memory"] + @property def loop(self) -> AbstractEventLoop: return self.settings["jai_event_loop"] @@ -155,11 +170,17 @@ def get_chat_user(self) -> ChatUser: environment.""" # Get a dictionary of all loaded extensions. # (`serverapp` is a property on all `JupyterHandler` subclasses) + assert self.serverapp extensions = self.serverapp.extension_manager.extensions - collaborative = ( + collaborative_legacy = ( "jupyter_collaboration" in extensions and extensions["jupyter_collaboration"].enabled ) + collaborative_v3 = ( + "jupyter_server_ydoc" in extensions + and extensions["jupyter_server_ydoc"].enabled + ) + collaborative = collaborative_legacy or collaborative_v3 if collaborative: names = self.current_user.name.split(" ", maxsplit=2) @@ -181,7 +202,7 @@ def get_chat_user(self) -> ChatUser: login = getpass.getuser() initials = login[0].capitalize() return ChatUser( - username=login, + username=self.current_user.username, initials=initials, name=login, display_name=login, @@ -220,6 +241,15 @@ def broadcast_message(self, message: Message): Appends message to chat history. """ + # do not broadcast agent messages that are replying to cleared human message + if ( + isinstance(message, (AgentChatMessage, AgentStreamMessage)) + and message.reply_to + and message.reply_to + not in [m.id for m in self.chat_history if isinstance(m, HumanChatMessage)] + ): + return + self.log.debug("Broadcasting message: %s to all clients...", message) client_ids = self.root_chat_handlers.keys() @@ -246,6 +276,7 @@ def broadcast_message(self, message: Message): ): stream_message: AgentStreamMessage = history_message stream_message.body += chunk.content + stream_message.metadata = chunk.metadata stream_message.complete = chunk.stream_complete break elif isinstance(message, PendingMessage): @@ -260,11 +291,25 @@ async def on_message(self, message): try: message = json.loads(message) - chat_request = ChatRequest(**message) + if message.get("type") == "clear": + request = ClearRequest(**message) + elif message.get("type") == "stop": + request = StopRequest(**message) + else: + request = ChatRequest(**message) except ValidationError as e: self.log.error(e) return + if isinstance(request, ClearRequest): + self.on_clear_request(request) + return + + if isinstance(request, StopRequest): + self.on_stop_request() + return + + chat_request = request message_body = chat_request.prompt if chat_request.selection: message_body += f"\n\n```\n{chat_request.selection.source}\n```\n" @@ -288,6 +333,69 @@ async def on_message(self, message): # as a distinct concurrent task. self.loop.create_task(self._route(chat_message)) + def on_clear_request(self, request: ClearRequest): + target = request.target + + # if no target, clear all messages + if not target: + self.chat_history.clear() + self.pending_messages.clear() + self.llm_chat_memory.clear() + self.broadcast_message(ClearMessage()) + self.settings["jai_chat_handlers"]["default"].send_help_message() + return + + # otherwise, clear a single message + for msg in self.chat_history[::-1]: + # interrupt the single message + if msg.type == "agent-stream" and getattr(msg, "reply_to", None) == target: + try: + self.message_interrupted[msg.id].set() + except KeyError: + # do nothing if the message was already interrupted + # or stream got completed (thread-safe way!) + pass + break + + self.chat_history[:] = [ + msg + for msg in self.chat_history + if msg.id != target and getattr(msg, "reply_to", None) != target + ] + self.pending_messages[:] = [ + msg for msg in self.pending_messages if msg.reply_to != target + ] + self.llm_chat_memory.clear([target]) + self.broadcast_message(ClearMessage(targets=[target])) + + def on_stop_request(self): + # set of message IDs that were submitted by this user, determined by the + # username associated with this WebSocket connection. + current_user_messages: Set[str] = set() + for message in self.chat_history: + if ( + message.type == "human" + and message.client.username == self.current_user.username + ): + current_user_messages.add(message.id) + + # set of `AgentStreamMessage` IDs to stop + streams_to_stop: Set[str] = set() + for message in self.chat_history: + if ( + message.type == "agent-stream" + and message.reply_to in current_user_messages + ): + streams_to_stop.add(message.id) + + for stream_id in streams_to_stop: + try: + self.message_interrupted[stream_id].set() + except KeyError: + # do nothing if the message was already interrupted + # or stream got completed (thread-safe way!) + pass + async def _route(self, message): """Method that routes an incoming message to the appropriate handler.""" default = self.chat_handlers["default"] @@ -356,7 +464,7 @@ def filter_predicate(local_model_id: str): if self.blocked_models: return model_id not in self.blocked_models else: - return model_id in self.allowed_models + return model_id in cast(List, self.allowed_models) # filter out every model w/ model ID according to allow/blocklist for provider in providers: @@ -535,7 +643,7 @@ def post(self): class ApiKeysHandler(BaseAPIHandler): @property - def config_manager(self) -> ConfigManager: + def config_manager(self) -> ConfigManager: # type:ignore[override] return self.settings["jai_config_manager"] @web.authenticated @@ -550,7 +658,7 @@ class SlashCommandsInfoHandler(BaseAPIHandler): """List slash commands that are currently available to the user.""" @property - def config_manager(self) -> ConfigManager: + def config_manager(self) -> ConfigManager: # type:ignore[override] return self.settings["jai_config_manager"] @property @@ -593,3 +701,109 @@ def get(self): # sort slash commands by slash id and deliver the response response.slash_commands.sort(key=lambda sc: sc.slash_id) self.finish(response.json()) + + +class AutocompleteOptionsHandler(BaseAPIHandler): + """List context that are currently available to the user.""" + + @property + def config_manager(self) -> ConfigManager: # type:ignore[override] + return self.settings["jai_config_manager"] + + @property + def context_providers(self) -> Dict[str, "BaseCommandContextProvider"]: + return self.settings["jai_context_providers"] + + @property + def chat_handlers(self) -> Dict[str, "BaseChatHandler"]: + return self.settings["jai_chat_handlers"] + + @web.authenticated + def get(self): + response = ListOptionsResponse() + + # if no selected LLM, return an empty response + if not self.config_manager.lm_provider: + self.finish(response.json()) + return + + partial_cmd = self.get_query_argument("partialCommand", None) + if partial_cmd: + # if providing options for partial command argument + cmd = ContextCommand(cmd=partial_cmd) + context_provider = next( + ( + cp + for cp in self.context_providers.values() + if isinstance(cp, BaseCommandContextProvider) + and cp.command_id == cmd.id + ), + None, + ) + if ( + cmd.arg is not None + and context_provider + and isinstance(context_provider, BaseCommandContextProvider) + ): + response.options = context_provider.get_arg_options(cmd.arg) + else: + response.options = ( + self._get_slash_command_options() + self._get_context_provider_options() + ) + self.finish(response.json()) + + def _get_slash_command_options(self) -> List[ListOptionsEntry]: + options = [] + for id, chat_handler in self.chat_handlers.items(): + # filter out any chat handler that is not a slash command + if id == "default" or not isinstance( + chat_handler.routing_type, SlashCommandRoutingType + ): + continue + + routing_type = chat_handler.routing_type + + # filter out any chat handler that is unsupported by the current LLM + if ( + not routing_type.slash_id + or "/" + routing_type.slash_id + in self.config_manager.lm_provider.unsupported_slash_commands + ): + continue + + options.append( + self._make_autocomplete_option( + id="/" + routing_type.slash_id, + description=chat_handler.help, + only_start=True, + requires_arg=False, + ) + ) + options.sort(key=lambda opt: opt.id) + return options + + def _get_context_provider_options(self) -> List[ListOptionsEntry]: + options = [ + self._make_autocomplete_option( + id=context_provider.command_id, + description=context_provider.help, + only_start=context_provider.only_start, + requires_arg=context_provider.requires_arg, + ) + for context_provider in self.context_providers.values() + if isinstance(context_provider, BaseCommandContextProvider) + ] + options.sort(key=lambda opt: opt.id) + return options + + def _make_autocomplete_option( + self, + id: str, + description: str, + only_start: bool, + requires_arg: bool, + ): + label = id + (":" if requires_arg else " ") + return ListOptionsEntry( + id=id, description=description, label=label, only_start=only_start + ) diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 02b77b91..0f1ba7dc 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,8 +1,13 @@ -from typing import List, Sequence +import time +from typing import List, Optional, Sequence, Set, Union from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import BaseMessage -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, PrivateAttr + +from .models import HumanChatMessage + +HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id" class BoundedChatHistory(BaseChatMessageHistory, BaseModel): @@ -11,30 +16,96 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): `k` exchanges between a user and an LLM. For example, when `k=2`, `BoundedChatHistory` will store up to 2 human - messages and 2 AI messages. + messages and 2 AI messages. If `k` is set to `None` all messages are kept. """ - messages: List[BaseMessage] = Field(default_factory=list) - size: int = 0 - k: int + k: Union[int, None] + clear_time: float = 0.0 + cleared_msgs: Set[str] = set() + _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list) + + @property + def messages(self) -> List[BaseMessage]: # type:ignore[override] + if self.k is None: + return self._all_messages + return self._all_messages[-self.k * 2 :] async def aget_messages(self) -> List[BaseMessage]: return self.messages def add_message(self, message: BaseMessage) -> None: """Add a self-created message to the store""" - self.messages.append(message) - self.size += 1 - - if self.size > self.k * 2: - self.messages.pop(0) + if HUMAN_MSG_ID_KEY not in message.additional_kwargs: + # human message id must be added to allow for targeted clearing of messages. + # `WrappedBoundedChatHistory` should be used instead to add messages. + raise ValueError( + "Message must have a human message ID to be added to the store." + ) + self._all_messages.append(message) async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: """Add messages to the store""" self.add_messages(messages) - def clear(self) -> None: - self.messages = [] + def clear(self, human_msg_ids: Optional[List[str]] = None) -> None: + """Clears conversation exchanges. If `human_msg_id` is provided, only + clears the respective human message and its reply. Otherwise, clears + all messages.""" + if human_msg_ids: + self._all_messages = [ + m + for m in self._all_messages + if m.additional_kwargs[HUMAN_MSG_ID_KEY] not in human_msg_ids + ] + self.cleared_msgs.update(human_msg_ids) + else: + self._all_messages = [] + self.cleared_msgs = set() + self.clear_time = time.time() async def aclear(self) -> None: self.clear() + + +class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel): + """ + Wrapper around `BoundedChatHistory` that only appends an `AgentChatMessage` + if the `HumanChatMessage` it is replying to was not cleared. If a chat + handler is replying to a `HumanChatMessage`, it should pass this object via + the `last_human_msg` configuration parameter. + + For example, a chat handler that is streaming a reply to a + `HumanChatMessage` should be called via: + + ```py + async for chunk in self.llm_chain.astream( + {"input": message.body}, + config={"configurable": {"last_human_msg": message}}, + ): + ... + ``` + + Reference: https://python.langchain.com/v0.1/docs/expression_language/how_to/message_history/ + """ + + history: BoundedChatHistory + last_human_msg: HumanChatMessage + + @property + def messages(self) -> List[BaseMessage]: # type:ignore[override] + return self.history.messages + + def add_message(self, message: BaseMessage) -> None: + # prevent adding pending messages to the store if clear was triggered. + if ( + self.last_human_msg.time > self.history.clear_time + and self.last_human_msg.id not in self.history.cleared_msgs + ): + message.additional_kwargs[HUMAN_MSG_ID_KEY] = self.last_human_msg.id + self.history.add_message(message) + + async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: + self.add_messages(messages) + + def clear(self) -> None: + self.history.clear() diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 65e401b1..9ddb4ddf 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, List, Literal, Optional, Union from jupyter_ai_magics import Persona @@ -39,6 +40,27 @@ class ChatRequest(BaseModel): selection: Optional[Selection] +class StopRequest(BaseModel): + """ + A request from a user to stop streaming all messages that are replying to + messages previously sent by that user. This request does not stop all + streaming responses for all users, but only the user that issued the + request. User identity is determined by the `username` from the + `IdentityProvider` instance available to each WebSocket handler. + """ + + type: Literal["stop"] + + +class ClearRequest(BaseModel): + type: Literal["clear"] = "clear" + target: Optional[str] + """ + Message ID of the HumanChatMessage to delete an exchange at. + If not provided, this requests the backend to clear all messages. + """ + + class ChatUser(BaseModel): # User ID assigned by IdentityProvider. username: str @@ -56,8 +78,7 @@ class ChatClient(ChatUser): id: str -class AgentChatMessage(BaseModel): - type: Literal["agent"] = "agent" +class BaseAgentMessage(BaseModel): id: str time: float body: str @@ -74,8 +95,20 @@ class AgentChatMessage(BaseModel): this defaults to a description of `ClouderaCopilotPersona`. """ + metadata: Dict[str, Any] = {} + """ + Message metadata set by a provider after fully processing an input. The + contents of this dictionary are provider-dependent, and can be any + dictionary with string keys. This field is not to be displayed directly to + the user, and is intended solely for developer purposes. + """ + -class AgentStreamMessage(AgentChatMessage): +class AgentChatMessage(BaseAgentMessage): + type: Literal["agent"] = "agent" + + +class AgentStreamMessage(BaseAgentMessage): type: Literal["agent-stream"] = "agent-stream" complete: bool # other attrs inherited from `AgentChatMessage` @@ -84,9 +117,26 @@ class AgentStreamMessage(AgentChatMessage): class AgentStreamChunkMessage(BaseModel): type: Literal["agent-stream-chunk"] = "agent-stream-chunk" id: str + """ID of the parent `AgentStreamMessage`.""" content: str + """The string to append to the `AgentStreamMessage` referenced by `id`.""" stream_complete: bool - """Indicates whether this chunk message completes the referenced stream.""" + """Indicates whether this chunk completes the stream referenced by `id`.""" + metadata: Dict[str, Any] = {} + """ + The metadata of the stream referenced by `id`. Metadata from the latest + chunk should override any metadata from previous chunks. See the docstring + on `BaseAgentMessage.metadata` for information. + """ + + @validator("metadata") + def validate_metadata(cls, v): + """Ensure metadata values are JSON serializable""" + try: + json.dumps(v) + return v + except TypeError as e: + raise ValueError(f"Metadata must be JSON serializable: {str(e)}") class HumanChatMessage(BaseModel): @@ -105,6 +155,11 @@ class HumanChatMessage(BaseModel): class ClearMessage(BaseModel): type: Literal["clear"] = "clear" + targets: Optional[List[str]] = None + """ + Message IDs of the HumanChatMessage to delete an exchange at. + If not provided, this instructs the frontend to clear all messages. + """ class PendingMessage(BaseModel): @@ -112,21 +167,20 @@ class PendingMessage(BaseModel): id: str time: float body: str + reply_to: str persona: Persona ellipsis: bool = True closed: bool = False class ClosePendingMessage(BaseModel): - type: Literal["pending"] = "close-pending" + type: Literal["close-pending"] = "close-pending" id: str # the type of messages being broadcast to clients ChatMessage = Union[ - AgentChatMessage, - HumanChatMessage, - AgentStreamMessage, + AgentChatMessage, HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage ] @@ -144,8 +198,7 @@ class ConnectionMessage(BaseModel): Message = Union[ - AgentChatMessage, - HumanChatMessage, + ChatMessage, ConnectionMessage, ClearMessage, PendingMessage, @@ -243,3 +296,21 @@ class ListSlashCommandsEntry(BaseModel): class ListSlashCommandsResponse(BaseModel): slash_commands: List[ListSlashCommandsEntry] = [] + + +class ListOptionsEntry(BaseModel): + id: str + """ID of the autocomplete option. + Includes the command prefix. E.g. "/clear", "@file".""" + label: str + """Text that will be inserted into the prompt when the option is selected. + Includes a space at the end if the option is complete. + Partial suggestions do not include the space and may trigger future suggestions.""" + description: str + """Text next to the option in the autocomplete list.""" + only_start: bool + """Whether to command can only be inserted at the start of the prompt.""" + + +class ListOptionsResponse(BaseModel): + options: List[ListOptionsEntry] = [] diff --git a/packages/jupyter-ai/jupyter_ai/tests/static/file9.ipynb b/packages/jupyter-ai/jupyter_ai/tests/static/file9.ipynb new file mode 100644 index 00000000..2d726593 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/static/file9.ipynb @@ -0,0 +1,51 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "85f55790-78a3-4fd2-bd0f-bf596e28a65c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hello world\n" + ] + } + ], + "source": [ + "print(\"hello world\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "367c03ce-503f-4a2a-9221-c4fcd49b34c5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 9aa16d2f..4a739f6e 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -25,6 +25,25 @@ def schema_path(jp_data_dir): return str(jp_data_dir / "config_schema.json") +@pytest.fixture +def config_file_with_model_fields(jp_data_dir): + """ + Fixture that creates a `config.json` file with the chat model set to + `openai-chat:gpt-4o` and fields for that model. Returns path to the file. + """ + config_data = { + "model_provider_id:": "openai-chat:gpt-4o", + "embeddings_provider_id": None, + "api_keys": {"openai_api_key": "foobar"}, + "send_with_shift_enter": False, + "fields": {"openai-chat:gpt-4o": {"openai_api_base": "https://example.com"}}, + } + config_path = jp_data_dir / "config.json" + with open(config_path, "w") as file: + json.dump(config_data, file) + return str(config_path) + + @pytest.fixture def common_cm_kwargs(config_path, schema_path): """Kwargs that are commonly used when initializing the CM.""" @@ -175,6 +194,43 @@ def configure_to_openai(cm: ConfigManager): return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS +def configure_with_fields(cm: ConfigManager, completions: bool = False): + """ + Default behavior: Configures the ConfigManager with fields and API keys. + Returns the expected result of `cm.lm_provider_params`. + + If `completions` is set to `True`, this configures the ConfigManager with + completion model fields, and returns the expected result of + `cm.completions_lm_provider_params`. + """ + if completions: + req = UpdateConfigRequest( + completions_model_provider_id="openai-chat:gpt-4o", + api_keys={"OPENAI_API_KEY": "foobar"}, + completions_fields={ + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com", + } + }, + ) + else: + req = UpdateConfigRequest( + model_provider_id="openai-chat:gpt-4o", + api_keys={"OPENAI_API_KEY": "foobar"}, + fields={ + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com", + } + }, + ) + cm.update_config(req) + return { + "model_id": "gpt-4o", + "openai_api_key": "foobar", + "openai_api_base": "https://example.com", + } + + def test_snapshot_default_config(cm: ConfigManager, snapshot): config_from_cm: DescribeConfigResponse = cm.get_config() assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read") @@ -402,3 +458,51 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids): config_desc = cm_with_bad_provider_ids.get_config() assert config_desc.model_provider_id is None assert config_desc.embeddings_provider_id is None + + +def test_returns_chat_model_fields(cm): + """ + Asserts that `ConfigManager.lm_provider_params` returns model fields set by + the user. + """ + expected_model_args = configure_with_fields(cm) + assert cm.lm_provider_params == expected_model_args + + +def test_returns_completion_model_fields(cm): + expected_model_args = configure_with_fields(cm, completions=True) + assert cm.completions_lm_provider_params == expected_model_args + + +def test_config_manager_does_not_write_to_defaults( + config_file_with_model_fields, schema_path +): + """ + Asserts that `ConfigManager` does not write to the `defaults` argument when + the configured chat model differs from the one specified in `defaults`. + """ + from copy import deepcopy + + config_path = config_file_with_model_fields + log = logging.getLogger() + lm_providers = get_lm_providers() + em_providers = get_em_providers() + + defaults = { + "model_provider_id": None, + "embeddings_provider_id": None, + "api_keys": {}, + "fields": {}, + } + expected_defaults = deepcopy(defaults) + + cm = ConfigManager( + log=log, + lm_providers=lm_providers, + em_providers=em_providers, + config_path=config_path, + schema_path=schema_path, + defaults=defaults, + ) + + assert defaults == expected_defaults diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py new file mode 100644 index 00000000..132dcf87 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py @@ -0,0 +1,80 @@ +import logging +from unittest import mock + +import pytest +from jupyter_ai.config_manager import ConfigManager +from jupyter_ai.context_providers import FileContextProvider, find_commands +from jupyter_ai.history import BoundedChatHistory +from jupyter_ai.models import ChatClient, HumanChatMessage, Persona + + +@pytest.fixture +def human_chat_message() -> HumanChatMessage: + chat_client = ChatClient( + id=0, username="test", initials="test", name="test", display_name="test" + ) + prompt = ( + "@file:test1.py @file @file:dir/test2.md test test\n" + "@file:/dir/test3.png\n" + "test@file:fail1.py\n" + "@file:dir\\ test\\ /test\\ 4.py\n" # spaces with escape + "@file:'test 5.py' @file:\"test6 .py\"\n" # quotes with spaces + "@file:'test7.py test\"\n" # do not allow for mixed quotes + "```\n@file:fail2.py\n```\n" # do not look within backticks + ) + return HumanChatMessage( + id="test", + time=0, + body=prompt, + prompt=prompt, + client=chat_client, + ) + + +@pytest.fixture +def file_context_provider() -> FileContextProvider: + config_manager = mock.create_autospec(ConfigManager) + config_manager.persona = Persona(name="test", avatar_route="test") + return FileContextProvider( + log=logging.getLogger(__name__), + config_manager=config_manager, + model_parameters={}, + chat_history=[], + llm_chat_memory=BoundedChatHistory(k=2), + root_dir="", + preferred_dir="", + dask_client_future=None, + chat_handlers={}, + context_providers={}, + ) + + +def test_find_instances(file_context_provider, human_chat_message): + expected = [ + "@file:test1.py", + "@file:dir/test2.md", + "@file:/dir/test3.png", + r"@file:dir\ test\ /test\ 4.py", + "@file:'test 5.py'", + '@file:"test6 .py"', + "@file:'test7.py", + ] + commands = [ + cmd.cmd + for cmd in find_commands(file_context_provider, human_chat_message.prompt) + ] + assert commands == expected + + +def test_replace_prompt(file_context_provider, human_chat_message): + expected = ( + "'test1.py' @file 'dir/test2.md' test test\n" + "'/dir/test3.png'\n" + "test@file:fail1.py\n" + "'dir test /test 4.py'\n" + "'test 5.py' 'test6 .py'\n" + "'test7.py' test\"\n" + "```\n@file:fail2.py\n```\n" # do not look within backticks + ) + prompt = file_context_provider.replace_prompt(human_chat_message.prompt) + assert prompt == expected diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_directory.py b/packages/jupyter-ai/jupyter_ai/tests/test_directory.py index f9432b90..728815ec 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_directory.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_directory.py @@ -17,6 +17,7 @@ def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path: file6_path = static_test_files_dir / "file3.csv" file7_path = static_test_files_dir / "file3.xyz" file8_path = static_test_files_dir / "file4.pdf" + file9_path = static_test_files_dir / "file9.ipynb" job_staging_dir = jp_ai_staging_dir / "TestDir" job_staging_dir.mkdir() @@ -33,6 +34,7 @@ def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path: shutil.copy2(file6_path, job_staging_hiddendir) shutil.copy2(file7_path, job_staging_subdir) shutil.copy2(file8_path, job_staging_hiddendir) + shutil.copy2(file9_path, job_staging_subdir) return job_staging_dir @@ -49,8 +51,24 @@ def test_collect_filepaths(staging_dir): # Call the function we want to test result = collect_filepaths(staging_dir_filepath, all_files) - assert len(result) == 3 # Test number of valid files + assert len(result) == 4 # Test number of valid files filenames = [fp.name for fp in result] assert "file0.html" in filenames # Check that valid file is included assert "file3.xyz" not in filenames # Check that invalid file is excluded + + # test unix wildcard pattern + pattern_path = os.path.join(staging_dir_filepath, "**/*.*py*") + results = collect_filepaths(pattern_path, all_files) + assert len(results) == 2 + condition = lambda p: p.suffix in [".py", ".ipynb"] + assert all(map(condition, results)) + + # test unix wildcard pattern returning only directories + pattern_path = f"{str(staging_dir_filepath)}*/" + results = collect_filepaths(pattern_path, all_files) + assert len(result) == 4 + filenames = [fp.name for fp in result] + + assert "file0.html" in filenames # Check that valid file is included + assert "file3.xyz" not in filenames # Check that invalid file is excluded diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py index b46c148c..9ae52d8a 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -1,7 +1,12 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from unittest import mock + import pytest from jupyter_ai.extension import AiExtension +from jupyter_ai.history import HUMAN_MSG_ID_KEY +from jupyter_ai_magics import BaseProvider +from langchain_core.messages import BaseMessage pytest_plugins = ["pytest_jupyter.jupyter_server"] @@ -43,3 +48,46 @@ def test_blocks_providers(argv, jp_configurable_serverapp): ai._link_jupyter_server_extension(server) ai.initialize_settings() assert KNOWN_LM_A not in ai.settings["lm_providers"] + + +@pytest.fixture +def jp_server_config(jp_server_config): + # Disable the extension during server startup to avoid double initialization + return {"ServerApp": {"jpserver_extensions": {"jupyter_ai": False}}} + + +@pytest.fixture +def ai_extension(jp_serverapp): + ai = AiExtension() + # `BaseProvider.server_settings` can be only initialized once; however, the tests + # may run in parallel setting it with race condition; because we are not testing + # the `BaseProvider.server_settings` here, we can just mock the setter + settings_mock = mock.PropertyMock() + with mock.patch.object(BaseProvider.__class__, "server_settings", settings_mock): + yield ai + + +@pytest.mark.parametrize( + "max_history,messages_to_add,expected_size", + [ + # for max_history = 1 we expect to see up to 2 messages (1 human and 1 AI message) + (1, 4, 2), + # if there is less than `max_history` messages, all should be returned + (1, 1, 1), + # if no limit is set, all messages should be returned + (None, 9, 9), + ], +) +def test_max_chat_history(ai_extension, max_history, messages_to_add, expected_size): + ai = ai_extension + ai.default_max_chat_history = max_history + ai.initialize_settings() + for i in range(messages_to_add): + message = BaseMessage( + content=f"Test message {i}", + type="test", + additional_kwargs={HUMAN_MSG_ID_KEY: f"message-{i}"}, + ) + ai.settings["llm_chat_memory"].add_message(message) + + assert len(ai.settings["llm_chat_memory"].messages) == expected_size diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index d2e73ce6..81108bdb 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -7,10 +7,10 @@ import pytest from jupyter_ai.chat_handlers import DefaultChatHandler, learn from jupyter_ai.config_manager import ConfigManager +from jupyter_ai.extension import DEFAULT_HELP_MESSAGE_TEMPLATE from jupyter_ai.handlers import RootChatHandler +from jupyter_ai.history import BoundedChatHistory from jupyter_ai.models import ( - AgentStreamChunkMessage, - AgentStreamMessage, ChatClient, ClosePendingMessage, HumanChatMessage, @@ -70,9 +70,14 @@ def broadcast_message(message: Message) -> None: root_chat_handlers={"root": root_handler}, model_parameters={}, chat_history=[], + llm_chat_memory=BoundedChatHistory(k=2), root_dir="", preferred_dir="", dask_client_future=None, + help_message_template=DEFAULT_HELP_MESSAGE_TEMPLATE, + chat_handlers={}, + context_providers={}, + message_interrupted={}, ) diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index 7c94d4a5..d691d7f5 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -1,6 +1,6 @@ { "name": "@jupyter-ai/core", - "version": "2.20.0+cloudera", + "version": "2.28.3+cloudera", "description": "A generative AI extension for JupyterLab", "keywords": [ "jupyter", diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index e8deeb13..c9d1b5d5 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -29,10 +29,12 @@ dependencies = [ "importlib_metadata>=5.2.0", "jupyter_ai_magics>=2.13.0", "dask[distributed]", - "faiss-cpu<=1.8.0", # Not distributed by official repo + # faiss-cpu is not distributed by the official repo. + # v1.8.0.post0 should be excluded as it lacks macOS x86 wheels. + "faiss-cpu>=1.8.0,<2.0.0,!=1.8.0.post0", "typing_extensions>=4.5.0", "traitlets>=5.0", - "deepmerge>=1.0", + "deepmerge>=2.0,<3", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -50,6 +52,8 @@ test = [ "pytest-tornasync", "pytest-jupyter", "syrupy~=4.0.8", + "types-jsonschema", + "mypy", ] dev = ["jupyter_ai_magics[dev]"] diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts index f1b131dc..e1b1e332 100644 --- a/packages/jupyter-ai/src/chat_handler.ts +++ b/packages/jupyter-ai/src/chat_handler.ts @@ -39,7 +39,7 @@ export class ChatHandler implements IDisposable { * Sends a message across the WebSocket. Promise resolves to the message ID * when the server sends the same message back, acknowledging receipt. */ - public sendMessage(message: AiService.ChatRequest): Promise { + public sendMessage(message: AiService.Request): Promise { return new Promise(resolve => { this._socket?.send(JSON.stringify(message)); this._sendResolverQueue.push(resolve); @@ -132,7 +132,20 @@ export class ChatHandler implements IDisposable { case 'connection': break; case 'clear': - this._messages = []; + if (newMessage.targets) { + const targets = newMessage.targets; + this._messages = this._messages.filter( + msg => + !targets.includes(msg.id) && + !('reply_to' in msg && targets.includes(msg.reply_to)) + ); + this._pendingMessages = this._pendingMessages.filter( + msg => !targets.includes(msg.reply_to) + ); + } else { + this._messages = []; + this._pendingMessages = []; + } break; case 'pending': this._pendingMessages = [...this._pendingMessages, newMessage]; @@ -157,6 +170,7 @@ export class ChatHandler implements IDisposable { } streamMessage.body += newMessage.content; + streamMessage.metadata = newMessage.metadata; if (newMessage.stream_complete) { streamMessage.complete = true; } diff --git a/packages/jupyter-ai/src/completions/handler.ts b/packages/jupyter-ai/src/completions/handler.ts index 6d9d4dbd..8e9cb9f0 100644 --- a/packages/jupyter-ai/src/completions/handler.ts +++ b/packages/jupyter-ai/src/completions/handler.ts @@ -110,7 +110,7 @@ export class CompletionWebsocketHandler implements IDisposable { (value: AiService.InlineCompletionReply) => void > = {}; - private _onClose(e: CloseEvent, reject: any) { + private _onClose(e: CloseEvent, reject: (reason: unknown) => void) { reject(new Error('Inline completion websocket disconnected')); console.error('Inline completion websocket disconnected'); // only attempt re-connect if there was an abnormal closure @@ -137,7 +137,7 @@ export class CompletionWebsocketHandler implements IDisposable { (token ? `?token=${encodeURIComponent(token)}` : ''); const socket = (this._socket = new WebSocket(url)); - socket.onclose = e => this._onClose(e, promise.reject); + socket.onclose = e => this._onClose(e, promise.reject.bind(promise)); socket.onerror = e => promise.reject(e); socket.onmessage = msg => msg.data && this._onMessage(JSON.parse(msg.data)); } diff --git a/packages/jupyter-ai/src/completions/provider.ts b/packages/jupyter-ai/src/completions/provider.ts index 786ced85..80c199f5 100644 --- a/packages/jupyter-ai/src/completions/provider.ts +++ b/packages/jupyter-ai/src/completions/provider.ts @@ -54,6 +54,17 @@ export class JaiInlineProvider request: CompletionHandler.IRequest, context: IInlineCompletionContext ): Promise> { + const allowedTriggerKind = this._settings.triggerKind; + const triggerKind = context.triggerKind; + if ( + allowedTriggerKind === 'manual' && + triggerKind !== InlineCompletionTriggerKind.Invoke + ) { + // Short-circuit if user requested to only invoke inline completions + // on manual trigger for jupyter-ai. Users may still get completions + // from other (e.g. less expensive or faster) providers. + return { items: [] }; + } const mime = request.mimeType ?? 'text/plain'; const language = this.options.languageRegistry.findByMIME(mime); if (!language) { @@ -142,6 +153,16 @@ export class JaiInlineProvider const knownLanguages = this.options.languageRegistry.getLanguages(); return { properties: { + triggerKind: { + title: 'Inline completions trigger', + type: 'string', + oneOf: [ + { const: 'any', title: 'Automatic (on typing or invocation)' }, + { const: 'manual', title: 'Only when invoked manually' } + ], + description: + 'When to trigger inline completions when using jupyter-ai.' + }, maxPrefix: { title: 'Maximum prefix length', minimum: 1, @@ -275,6 +296,7 @@ export namespace JaiInlineProvider { } export interface ISettings { + triggerKind: 'any' | 'manual'; maxPrefix: number; maxSuffix: number; debouncerDelay: number; @@ -284,6 +306,7 @@ export namespace JaiInlineProvider { } export const DEFAULT_SETTINGS: ISettings = { + triggerKind: 'any', maxPrefix: 10000, maxSuffix: 10000, // The debouncer delay handling is implemented upstream in JupyterLab; diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index e254b0ed..97f2ff48 100644 --- a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -36,12 +36,11 @@ type ChatInputProps = { * `'Jupyternaut'`, but can differ for custom providers. */ personaName: string; -}; - -type SlashCommandOption = { - id: string; - label: string; - description: string; + /** + * Whether the backend is streaming a reply to any message sent by the current + * user. + */ + streamingReplyHere: boolean; }; /** @@ -51,28 +50,29 @@ type SlashCommandOption = { * unclear whether custom icons should be defined within a Lumino plugin (in the * frontend) or served from a static server route (in the backend). */ -const DEFAULT_SLASH_COMMAND_ICONS: Record = { - ask: , - clear: , - export: , - fix: , - generate: , - help: , - learn: , +const DEFAULT_COMMAND_ICONS: Record = { + '/ask': , + '/clear': , + '/export': , + '/fix': , + '/generate': , + '/help': , + '/learn': , + '@file': , unknown: }; /** * Renders an option shown in the slash command autocomplete. */ -function renderSlashCommandOption( +function renderAutocompleteOption( optionProps: React.HTMLAttributes, - option: SlashCommandOption + option: AiService.AutocompleteOption ): JSX.Element { const icon = - option.id in DEFAULT_SLASH_COMMAND_ICONS - ? DEFAULT_SLASH_COMMAND_ICONS[option.id] - : DEFAULT_SLASH_COMMAND_ICONS.unknown; + option.id in DEFAULT_COMMAND_ICONS + ? DEFAULT_COMMAND_ICONS[option.id] + : DEFAULT_COMMAND_ICONS.unknown; return (
  • @@ -99,8 +99,14 @@ function renderSlashCommandOption( export function ChatInput(props: ChatInputProps): JSX.Element { const [input, setInput] = useState(''); - const [slashCommandOptions, setSlashCommandOptions] = useState< - SlashCommandOption[] + const [autocompleteOptions, setAutocompleteOptions] = useState< + AiService.AutocompleteOption[] + >([]); + const [autocompleteCommandOptions, setAutocompleteCommandOptions] = useState< + AiService.AutocompleteOption[] + >([]); + const [autocompleteArgOptions, setAutocompleteArgOptions] = useState< + AiService.AutocompleteOption[] >([]); const [currSlashCommand, setCurrSlashCommand] = useState(null); const activeCell = useActiveCellContext(); @@ -110,24 +116,46 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * initial mount to populate the slash command autocomplete. */ useEffect(() => { - async function getSlashCommands() { - const slashCommands = (await AiService.listSlashCommands()) - .slash_commands; - setSlashCommandOptions( - slashCommands.map(slashCommand => ({ - id: slashCommand.slash_id, - label: '/' + slashCommand.slash_id + ' ', - description: slashCommand.description - })) - ); + async function getAutocompleteCommandOptions() { + const response = await AiService.listAutocompleteOptions(); + setAutocompleteCommandOptions(response.options); } - getSlashCommands(); + getAutocompleteCommandOptions(); }, []); - // whether any option is highlighted in the slash command autocomplete + useEffect(() => { + async function getAutocompleteArgOptions() { + let options: AiService.AutocompleteOption[] = []; + const lastWord = getLastWord(input); + if (lastWord.includes(':')) { + const id = lastWord.split(':', 1)[0]; + // get option that matches the command + const option = autocompleteCommandOptions.find( + option => option.id === id + ); + if (option) { + const response = await AiService.listAutocompleteArgOptions(lastWord); + options = response.options; + } + } + setAutocompleteArgOptions(options); + } + getAutocompleteArgOptions(); + }, [autocompleteCommandOptions, input]); + + // Combine the fixed options with the argument options + useEffect(() => { + if (autocompleteArgOptions.length > 0) { + setAutocompleteOptions(autocompleteArgOptions); + } else { + setAutocompleteOptions(autocompleteCommandOptions); + } + }, [autocompleteCommandOptions, autocompleteArgOptions]); + + // whether any option is highlighted in the autocomplete const [highlighted, setHighlighted] = useState(false); - // controls whether the slash command autocomplete is open + // controls whether the autocomplete is open const [open, setOpen] = useState(false); // store reference to the input element to enable focusing it easily @@ -153,7 +181,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { * chat input. Close the autocomplete when the user clears the chat input. */ useEffect(() => { - if (input === '/') { + if (filterAutocompleteOptions(autocompleteOptions, input).length > 0) { setOpen(true); return; } @@ -249,18 +277,52 @@ export function ChatInput(props: ChatInputProps): JSX.Element { const sendButtonProps: SendButtonProps = { onSend, + onStop: () => { + props.chatHandler.sendMessage({ + type: 'stop' + }); + }, + streamingReplyHere: props.streamingReplyHere, sendWithShiftEnter: props.sendWithShiftEnter, inputExists, activeCellHasError: activeCell.hasError, currSlashCommand }; + function filterAutocompleteOptions( + options: AiService.AutocompleteOption[], + inputValue: string + ): AiService.AutocompleteOption[] { + const lastWord = getLastWord(inputValue); + if (lastWord === '') { + return []; + } + const isStart = lastWord === inputValue; + return options.filter( + option => + option.label.startsWith(lastWord) && (!option.only_start || isStart) + ); + } + return ( { + return filterAutocompleteOptions(options, inputValue); + }} + onChange={(_, option) => { + const value = typeof option === 'string' ? option : option.label; + let matchLength = 0; + for (let i = 1; i <= value.length; i++) { + if (input.endsWith(value.slice(0, i))) { + matchLength = i; + } + } + setInput(input + value.slice(matchLength)); + }} onInputChange={(_, newValue: string) => { setInput(newValue); }} @@ -273,12 +335,16 @@ export function ChatInput(props: ChatInputProps): JSX.Element { setHighlighted(!!highlightedOption); } } - onClose={() => setOpen(false)} + onClose={(_, reason) => { + if (reason !== 'selectOption' || input.endsWith(' ')) { + setOpen(false); + } + }} // set this to an empty string to prevent the last selected slash // command from being shown in blue value="" open={open} - options={slashCommandOptions} + options={autocompleteOptions} // hide default extra right padding in the text field disableClearable // ensure the autocomplete popup always renders on top @@ -292,7 +358,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { } } }} - renderOption={renderSlashCommandOption} + renderOption={renderAutocompleteOption} ListboxProps={{ sx: { '& .MuiAutocomplete-option': { @@ -331,3 +397,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { ); } + +function getLastWord(input: string): string { + return input.split(/(? unknown; + onStop: () => unknown; sendWithShiftEnter: boolean; currSlashCommand: string | null; inputExists: boolean; activeCellHasError: boolean; + /** + * Whether the backend is streaming a reply to any message sent by the current + * user. + */ + streamingReplyHere: boolean; }; export function SendButton(props: SendButtonProps): JSX.Element { @@ -34,15 +41,27 @@ export function SendButton(props: SendButtonProps): JSX.Element { setMenuOpen(false); }, []); - const disabled = - props.currSlashCommand === '/fix' - ? !props.inputExists || !props.activeCellHasError - : !props.inputExists; + let action: 'send' | 'stop' | 'fix' = props.inputExists + ? 'send' + : props.streamingReplyHere + ? 'stop' + : 'send'; + if (props.currSlashCommand === '/fix') { + action = 'fix'; + } + + let disabled = false; + if (action === 'send' && !props.inputExists) { + disabled = true; + } + if (action === 'fix' && !props.activeCellHasError) { + disabled = true; + } const includeSelectionDisabled = !(activeCell.exists || textSelection); const includeSelectionTooltip = - props.currSlashCommand === '/fix' + action === 'fix' ? FIX_TOOLTIP : textSelection ? `${textSelection.text.split('\n').length} lines selected` @@ -55,8 +74,10 @@ export function SendButton(props: SendButtonProps): JSX.Element { : 'Send message (ENTER)'; const tooltip = - props.currSlashCommand === '/fix' && !props.activeCellHasError + action === 'fix' && !props.activeCellHasError ? FIX_TOOLTIP + : action === 'stop' + ? 'Stop streaming' : !props.inputExists ? 'Message must not be empty' : defaultTooltip; @@ -65,7 +86,7 @@ export function SendButton(props: SendButtonProps): JSX.Element { // if the current slash command is `/fix`, `props.onSend()` should always // include the code cell with error output, so the `selection` argument does // not need to be defined. - if (props.currSlashCommand === '/fix') { + if (action === 'fix') { props.onSend(); closeMenu(); return; @@ -85,7 +106,8 @@ export function SendButton(props: SendButtonProps): JSX.Element { if (activeCell.exists) { props.onSend({ type: 'cell', - source: activeCell.manager.getContent(false).source + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + source: activeCell.manager.getContent(false)!.source }); closeMenu(); return; @@ -95,7 +117,7 @@ export function SendButton(props: SendButtonProps): JSX.Element { return ( props.onSend()} + onClick={() => (action === 'stop' ? props.onStop() : props.onSend())} disabled={disabled} tooltip={tooltip} buttonProps={{ @@ -108,7 +130,7 @@ export function SendButton(props: SendButtonProps): JSX.Element { borderRadius: '2px 0px 0px 2px' }} > - + {action === 'stop' ? : } { diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx index 86b6793d..5c4286f8 100644 --- a/packages/jupyter-ai/src/components/chat-messages.tsx +++ b/packages/jupyter-ai/src/components/chat-messages.tsx @@ -10,16 +10,20 @@ import { AiService } from '../handler'; import { RendermimeMarkdown } from './rendermime-markdown'; import { useCollaboratorsContext } from '../contexts/collaborators-context'; import { ChatMessageMenu } from './chat-messages/chat-message-menu'; +import { ChatMessageDelete } from './chat-messages/chat-message-delete'; +import { ChatHandler } from '../chat_handler'; import { IJaiMessageFooter } from '../tokens'; type ChatMessagesProps = { rmRegistry: IRenderMimeRegistry; messages: AiService.ChatMessage[]; + chatHandler: ChatHandler; messageFooter: IJaiMessageFooter | null; }; type ChatMessageHeaderProps = { message: AiService.ChatMessage; + chatHandler: ChatHandler; timestamp: string; sx?: SxProps; }; @@ -45,11 +49,11 @@ function sortMessages( */ const aOriginTimestamp = - a.type === 'agent' && a.reply_to in timestampsById + 'reply_to' in a && a.reply_to in timestampsById ? timestampsById[a.reply_to] : a.time; const bOriginTimestamp = - b.type === 'agent' && b.reply_to in timestampsById + 'reply_to' in b && b.reply_to in timestampsById ? timestampsById[b.reply_to] : b.time; @@ -113,6 +117,7 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { const shouldShowMenu = props.message.type === 'agent' || (props.message.type === 'agent-stream' && props.message.complete); + const shouldShowDelete = props.message.type === 'human'; return ( )} + {shouldShowDelete && ( + + )} @@ -208,11 +220,13 @@ export function ChatMessages(props: ChatMessagesProps): JSX.Element { props.chatHandler.sendMessage(request)} + sx={props.sx} + tooltip="Delete this exchange" + > + + + ); +} + +export default ChatMessageDelete; diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index a1ad0a9b..c32eb46f 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -88,6 +88,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [embeddingModelFields, setEmbeddingModelFields] = useState< + Record + >({}); const [isCompleterEnabled, setIsCompleterEnabled] = useState( props.completionProvider && props.completionProvider.isEnabled() @@ -188,7 +191,15 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const currFields: Record = server.config.fields?.[lmGlobalId] ?? {}; setFields(currFields); - }, [server, lmProvider]); + + if (!emGlobalId) { + return; + } + + const initEmbeddingModelFields: Record = + server.config.fields?.[emGlobalId] ?? {}; + setEmbeddingModelFields(initEmbeddingModelFields); + }, [server, lmGlobalId, emGlobalId]); const handleSave = async () => { // compress fields with JSON values @@ -222,6 +233,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { }), ...(clmGlobalId && { [clmGlobalId]: fields + }), + ...(emGlobalId && { + [emGlobalId]: embeddingModelFields }) } }), @@ -376,26 +390,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { {/* Embedding model section */}

    Embedding model

    {server.emProviders.providers.length > 0 ? ( - { + const emGid = e.target.value === 'null' ? null : e.target.value; + setEmGlobalId(emGid); + }} + MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }} + > + None + {server.emProviders.providers.map(emp => + emp.models + .filter(em => em !== '*') // TODO: support registry providers + .map(em => ( + + {emp.name} :: {em} + + )) + )} + + {emGlobalId && ( + )} - + ) : (

    No embedding models available.

    )} diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index bd9123ed..4a1e0ae6 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -3,9 +3,11 @@ import { Box } from '@mui/system'; import { Button, IconButton, Stack } from '@mui/material'; import SettingsIcon from '@mui/icons-material/Settings'; import ArrowBackIcon from '@mui/icons-material/ArrowBack'; +import AddIcon from '@mui/icons-material/Add'; import type { Awareness } from 'y-protocols/awareness'; import type { IThemeManager } from '@jupyterlab/apputils'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; +import type { User } from '@jupyterlab/services'; import { ISignal } from '@lumino/signaling'; import { JlThemeProvider } from './jl-theme-provider'; @@ -18,16 +20,25 @@ import { SelectionContextProvider } from '../contexts/selection-context'; import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; import { CollaboratorsContextProvider } from '../contexts/collaborators-context'; -import { IJaiCompletionProvider, IJaiMessageFooter } from '../tokens'; +import { + IJaiCompletionProvider, + IJaiMessageFooter, + IJaiTelemetryHandler +} from '../tokens'; import { ActiveCellContextProvider, ActiveCellManager } from '../contexts/active-cell-context'; +import { UserContextProvider, useUserContext } from '../contexts/user-context'; import { ScrollContainer } from './scroll-container'; +import { TooltippedIconButton } from './mui-extras/tooltipped-icon-button'; +import { TelemetryContextProvider } from '../contexts/telemetry-context'; type ChatBodyProps = { chatHandler: ChatHandler; - setChatView: (view: ChatView) => void; + openSettingsView: () => void; + showWelcomeMessage: boolean; + setShowWelcomeMessage: (show: boolean) => void; rmRegistry: IRenderMimeRegistry; focusInputSignal: ISignal; messageFooter: IJaiMessageFooter | null; @@ -51,7 +62,9 @@ function getPersonaName(messages: AiService.ChatMessage[]): string { function ChatBody({ chatHandler, focusInputSignal, - setChatView: chatViewHandler, + openSettingsView, + showWelcomeMessage, + setShowWelcomeMessage, rmRegistry: renderMimeRegistry, messageFooter }: ChatBodyProps): JSX.Element { @@ -64,8 +77,8 @@ function ChatBody({ const [personaName, setPersonaName] = useState( getPersonaName(messages) ); - const [showWelcomeMessage, setShowWelcomeMessage] = useState(false); const [sendWithShiftEnter, setSendWithShiftEnter] = useState(true); + const user = useUserContext(); /** * Effect: fetch config on initial render @@ -103,11 +116,6 @@ function ChatBody({ }; }, [chatHandler]); - const openSettingsView = () => { - setShowWelcomeMessage(false); - chatViewHandler(ChatView.Settings); - }; - if (showWelcomeMessage) { return ( m.type === 'human' && m.client.username === user?.identity.username + ) + .map(m => m.id) + ); + + // whether the backend is currently streaming a reply to any message sent by + // the current user. + const streamingReplyHere = messages.some( + m => + m.type === 'agent-stream' && + myHumanMessageIds.has(m.reply_to) && + !m.complete + ); + return ( <> - + ; messageFooter: IJaiMessageFooter | null; + telemetryHandler: IJaiTelemetryHandler | null; + userManager: User.IManager; }; enum ChatView { @@ -185,6 +215,12 @@ enum ChatView { export function Chat(props: ChatProps): JSX.Element { const [view, setView] = useState(props.chatView || ChatView.Chat); + const [showWelcomeMessage, setShowWelcomeMessage] = useState(false); + + const openSettingsView = () => { + setShowWelcomeMessage(false); + setView(ChatView.Settings); + }; return ( @@ -193,55 +229,78 @@ export function Chat(props: ChatProps): JSX.Element { - - {/* top bar */} - - {view !== ChatView.Chat ? ( - setView(ChatView.Chat)}> - - - ) : ( - - )} - {view === ChatView.Chat ? ( - setView(ChatView.Settings)}> - - - ) : ( - - )} - - {/* body */} - {view === ChatView.Chat && ( - - )} - {view === ChatView.Settings && ( - - )} - + + + =4.3.0. + // See: https://jupyterlab.readthedocs.io/en/latest/extension/extension_migration.html#css-styling + className="jp-ThemedContainer" + // root box should not include padding as it offsets the vertical + // scrollbar to the left + sx={{ + width: '100%', + height: '100%', + boxSizing: 'border-box', + background: 'var(--jp-layout-color0)', + display: 'flex', + flexDirection: 'column' + }} + > + {/* top bar */} + + {view !== ChatView.Chat ? ( + setView(ChatView.Chat)}> + + + ) : ( + + )} + {view === ChatView.Chat ? ( + + {!showWelcomeMessage && ( + + props.chatHandler.sendMessage({ type: 'clear' }) + } + tooltip="New chat" + > + + + )} + openSettingsView()}> + + + + ) : ( + + )} + + {/* body */} + {view === ChatView.Chat && ( + + )} + {view === ChatView.Settings && ( + + )} + + + diff --git a/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx b/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx index 7dd0f70e..315e5d4d 100644 --- a/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx +++ b/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx @@ -1,8 +1,10 @@ import React from 'react'; import { Box } from '@mui/material'; -import { addAboveIcon, addBelowIcon } from '@jupyterlab/ui-components'; - -import { CopyButton } from './copy-button'; +import { + addAboveIcon, + addBelowIcon, + copyIcon +} from '@jupyterlab/ui-components'; import { replaceCellIcon } from '../../icons'; import { @@ -11,20 +13,29 @@ import { } from '../../contexts/active-cell-context'; import { TooltippedIconButton } from '../mui-extras/tooltipped-icon-button'; import { useReplace } from '../../hooks/use-replace'; +import { useCopy } from '../../hooks/use-copy'; +import { AiService } from '../../handler'; +import { useTelemetry } from '../../contexts/telemetry-context'; +import { TelemetryEvent } from '../../tokens'; export type CodeToolbarProps = { /** * The content of the Markdown code block this component is attached to. */ - content: string; + code: string; + /** + * Parent message which contains the code referenced by `content`. + */ + parentMessage?: AiService.ChatMessage; }; export function CodeToolbar(props: CodeToolbarProps): JSX.Element { const activeCell = useActiveCellContext(); - const sharedToolbarButtonProps = { - content: props.content, + const sharedToolbarButtonProps: ToolbarButtonProps = { + code: props.code, activeCellManager: activeCell.manager, - activeCellExists: activeCell.exists + activeCellExists: activeCell.exists, + parentMessage: props.parentMessage }; return ( @@ -33,7 +44,7 @@ export function CodeToolbar(props: CodeToolbarProps): JSX.Element { display: 'flex', justifyContent: 'flex-end', alignItems: 'center', - padding: '6px 2px', + padding: '2px 2px', marginBottom: '1em', border: '1px solid var(--jp-cell-editor-border-color)', borderTop: 'none' @@ -41,19 +52,51 @@ export function CodeToolbar(props: CodeToolbarProps): JSX.Element { > - - + + ); } type ToolbarButtonProps = { - content: string; + code: string; activeCellExists: boolean; activeCellManager: ActiveCellManager; + parentMessage?: AiService.ChatMessage; + // TODO: parentMessage should always be defined, but this can be undefined + // when the code toolbar appears in Markdown help messages in the Settings + // UI. The Settings UI should use a different component to render Markdown, + // and should never render code toolbars within it. }; +function buildTelemetryEvent( + type: string, + props: ToolbarButtonProps +): TelemetryEvent { + const charCount = props.code.length; + // number of lines = number of newlines + 1 + const lineCount = (props.code.match(/\n/g) ?? []).length + 1; + + return { + type, + message: { + id: props.parentMessage?.id ?? '', + type: props.parentMessage?.type ?? 'human', + time: props.parentMessage?.time ?? 0, + metadata: + props.parentMessage && 'metadata' in props.parentMessage + ? props.parentMessage.metadata + : {} + }, + code: { + charCount, + lineCount + } + }; +} + function InsertAboveButton(props: ToolbarButtonProps) { + const telemetryHandler = useTelemetry(); const tooltip = props.activeCellExists ? 'Insert above active cell' : 'Insert above active cell (no active cell)'; @@ -61,7 +104,16 @@ function InsertAboveButton(props: ToolbarButtonProps) { return ( props.activeCellManager.insertAbove(props.content)} + onClick={() => { + props.activeCellManager.insertAbove(props.code); + + try { + telemetryHandler.onEvent(buildTelemetryEvent('insert-above', props)); + } catch (e) { + console.error(e); + return; + } + }} disabled={!props.activeCellExists} > @@ -70,6 +122,7 @@ function InsertAboveButton(props: ToolbarButtonProps) { } function InsertBelowButton(props: ToolbarButtonProps) { + const telemetryHandler = useTelemetry(); const tooltip = props.activeCellExists ? 'Insert below active cell' : 'Insert below active cell (no active cell)'; @@ -78,23 +131,67 @@ function InsertBelowButton(props: ToolbarButtonProps) { props.activeCellManager.insertBelow(props.content)} + onClick={() => { + props.activeCellManager.insertBelow(props.code); + + try { + telemetryHandler.onEvent(buildTelemetryEvent('insert-below', props)); + } catch (e) { + console.error(e); + return; + } + }} > ); } -function ReplaceButton(props: { value: string }) { +function ReplaceButton(props: ToolbarButtonProps) { + const telemetryHandler = useTelemetry(); const { replace, replaceDisabled, replaceLabel } = useReplace(); return ( replace(props.value)} + onClick={() => { + replace(props.code); + + try { + telemetryHandler.onEvent(buildTelemetryEvent('replace', props)); + } catch (e) { + console.error(e); + return; + } + }} > ); } + +export function CopyButton(props: ToolbarButtonProps): JSX.Element { + const telemetryHandler = useTelemetry(); + const { copy, copyLabel } = useCopy(); + + return ( + { + copy(props.code); + + try { + telemetryHandler.onEvent(buildTelemetryEvent('copy', props)); + } catch (e) { + console.error(e); + return; + } + }} + aria-label="Copy to clipboard" + > + + + ); +} diff --git a/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx b/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx index 31279b4d..e379ab39 100644 --- a/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx +++ b/packages/jupyter-ai/src/components/mui-extras/tooltipped-icon-button.tsx @@ -80,6 +80,7 @@ export function TooltippedIconButton( onClick={props.onClick} disabled={props.disabled} sx={{ + ml: '8px', lineHeight: 0, ...(props.disabled && { opacity: 0.5 }), ...props.sx diff --git a/packages/jupyter-ai/src/components/pending-messages.tsx b/packages/jupyter-ai/src/components/pending-messages.tsx index e1101695..c258c295 100644 --- a/packages/jupyter-ai/src/components/pending-messages.tsx +++ b/packages/jupyter-ai/src/components/pending-messages.tsx @@ -3,9 +3,11 @@ import React, { useState, useEffect } from 'react'; import { Box, Typography } from '@mui/material'; import { AiService } from '../handler'; import { ChatMessageHeader } from './chat-messages'; +import { ChatHandler } from '../chat_handler'; type PendingMessagesProps = { messages: AiService.PendingMessage[]; + chatHandler: ChatHandler; }; type PendingMessageElementProps = { @@ -58,7 +60,8 @@ export function PendingMessages( time: lastMessage.time, body: '', reply_to: '', - persona: lastMessage.persona + persona: lastMessage.persona, + metadata: {} }); // timestamp format copied from ChatMessage @@ -85,6 +88,7 @@ export function PendingMessages( > { const renderContent = async () => { + // initialize mime model const mdStr = escapeLatexDelimiters(props.markdownStr); const model = props.rmRegistry.createModel({ data: { [MD_MIME_TYPE]: mdStr } }); + // step 1: render markdown await renderer.renderModel(model); - props.rmRegistry.latexTypesetter?.typeset(renderer.node); if (!renderer.node) { throw new Error( 'Rendermime was unable to render Markdown content within a chat message. Please report this upstream to Jupyter AI on GitHub.' ); } + // step 2: render LaTeX via MathJax + props.rmRegistry.latexTypesetter?.typeset(renderer.node); + // insert the rendering into renderingContainer if not yet inserted if (renderingContainer.current !== null && !renderingInserted.current) { renderingContainer.current.appendChild(renderer.node); @@ -87,7 +101,10 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element { ); newCodeToolbarDefns.push([ codeToolbarRoot, - { content: preBlock.textContent || '' } + { + code: preBlock.textContent || '', + parentMessage: props.parentMessage + } ]); }); @@ -95,7 +112,12 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element { }; renderContent(); - }, [props.markdownStr, props.complete, props.rmRegistry]); + }, [ + props.markdownStr, + props.complete, + props.rmRegistry, + props.parentMessage + ]); return (
    diff --git a/packages/jupyter-ai/src/contexts/index.ts b/packages/jupyter-ai/src/contexts/index.ts new file mode 100644 index 00000000..0cc0c017 --- /dev/null +++ b/packages/jupyter-ai/src/contexts/index.ts @@ -0,0 +1,4 @@ +export * from './active-cell-context'; +export * from './collaborators-context'; +export * from './selection-context'; +export * from './telemetry-context'; diff --git a/packages/jupyter-ai/src/contexts/telemetry-context.tsx b/packages/jupyter-ai/src/contexts/telemetry-context.tsx new file mode 100644 index 00000000..6a812e76 --- /dev/null +++ b/packages/jupyter-ai/src/contexts/telemetry-context.tsx @@ -0,0 +1,41 @@ +import React, { useContext, useState } from 'react'; +import { IJaiTelemetryHandler } from '../tokens'; + +const defaultTelemetryHandler: IJaiTelemetryHandler = { + onEvent: e => { + /* no-op */ + } +}; + +const TelemetryContext = React.createContext( + defaultTelemetryHandler +); + +/** + * Retrieves a reference to the current telemetry handler for Jupyter AI events + * returned by another plugin providing the `IJaiTelemetryHandler` token. If + * none exists, then the default telemetry handler is returned, which does + * nothing when `onEvent()` is called. + */ +export function useTelemetry(): IJaiTelemetryHandler { + return useContext(TelemetryContext); +} + +type TelemetryContextProviderProps = { + telemetryHandler: IJaiTelemetryHandler | null; + children: React.ReactNode; +}; + +export function TelemetryContextProvider( + props: TelemetryContextProviderProps +): JSX.Element { + const [telemetryHandler] = useState( + props.telemetryHandler ?? defaultTelemetryHandler + ); + + return ( + + {props.children} + + ); +} diff --git a/packages/jupyter-ai/src/contexts/user-context.tsx b/packages/jupyter-ai/src/contexts/user-context.tsx new file mode 100644 index 00000000..ff9fe8e3 --- /dev/null +++ b/packages/jupyter-ai/src/contexts/user-context.tsx @@ -0,0 +1,35 @@ +import React, { useContext, useEffect, useState } from 'react'; +import type { User } from '@jupyterlab/services'; +import { PartialJSONObject } from '@lumino/coreutils'; + +const UserContext = React.createContext(null); + +export function useUserContext(): User.IUser | null { + return useContext(UserContext); +} + +type UserContextProviderProps = { + userManager: User.IManager; + children: React.ReactNode; +}; + +export function UserContextProvider({ + userManager, + children +}: UserContextProviderProps): JSX.Element { + const [user, setUser] = useState(null); + + useEffect(() => { + userManager.ready.then(() => { + setUser({ + identity: userManager.identity!, + permissions: userManager.permissions as PartialJSONObject + }); + }); + userManager.userChanged.connect((sender, newUser) => { + setUser(newUser); + }); + }, []); + + return {children}; +} diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index af8e32ad..204d8b5d 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -83,6 +83,15 @@ export namespace AiService { selection?: Selection; }; + export type ClearRequest = { + type: 'clear'; + target?: string; + }; + + export type StopRequest = { + type: 'stop'; + }; + export type Collaborator = { username: string; initials: string; @@ -108,6 +117,7 @@ export namespace AiService { body: string; reply_to: string; persona: Persona; + metadata: Record; }; export type HumanChatMessage = { @@ -138,6 +148,7 @@ export namespace AiService { export type ClearMessage = { type: 'clear'; + targets?: string[]; }; export type PendingMessage = { @@ -145,6 +156,7 @@ export namespace AiService { id: string; time: number; body: string; + reply_to: string; persona: Persona; ellipsis: boolean; }; @@ -164,8 +176,11 @@ export namespace AiService { id: string; content: string; stream_complete: boolean; + metadata: Record; }; + export type Request = ChatRequest | ClearRequest | StopRequest; + export type ChatMessage = | AgentChatMessage | HumanChatMessage @@ -319,4 +334,30 @@ export namespace AiService { export async function listSlashCommands(): Promise { return requestAPI('chats/slash_commands'); } + + export type AutocompleteOption = { + id: string; + description: string; + label: string; + only_start: boolean; + }; + + export type ListAutocompleteOptionsResponse = { + options: AutocompleteOption[]; + }; + + export async function listAutocompleteOptions(): Promise { + return requestAPI( + 'chats/autocomplete_options' + ); + } + + export async function listAutocompleteArgOptions( + partialCommand: string + ): Promise { + return requestAPI( + 'chats/autocomplete_options?partialCommand=' + + encodeURIComponent(partialCommand) + ); + } } diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index e4209198..f24fbfa0 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -18,10 +18,16 @@ import { ChatHandler } from './chat_handler'; import { buildErrorWidget } from './widgets/chat-error'; import { completionPlugin } from './completions'; import { statusItemPlugin } from './status'; -import { IJaiCompletionProvider, IJaiMessageFooter } from './tokens'; +import { + IJaiCompletionProvider, + IJaiCore, + IJaiMessageFooter, + IJaiTelemetryHandler +} from './tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ActiveCellManager } from './contexts/active-cell-context'; import { Signal } from '@lumino/signaling'; +import { menuPlugin } from './plugins/menu-plugin'; export type DocumentTracker = IWidgetTracker; @@ -35,17 +41,19 @@ export namespace CommandIDs { /** * Initialization data for the jupyter_ai extension. */ -const plugin: JupyterFrontEndPlugin = { +const plugin: JupyterFrontEndPlugin = { id: '@jupyter-ai/core:plugin', autoStart: true, + requires: [IRenderMimeRegistry], optional: [ IGlobalAwareness, ILayoutRestorer, IThemeManager, IJaiCompletionProvider, - IJaiMessageFooter + IJaiMessageFooter, + IJaiTelemetryHandler ], - requires: [IRenderMimeRegistry], + provides: IJaiCore, activate: async ( app: JupyterFrontEnd, rmRegistry: IRenderMimeRegistry, @@ -53,7 +61,8 @@ const plugin: JupyterFrontEndPlugin = { restorer: ILayoutRestorer | null, themeManager: IThemeManager | null, completionProvider: IJaiCompletionProvider | null, - messageFooter: IJaiMessageFooter | null + messageFooter: IJaiMessageFooter | null, + telemetryHandler: IJaiTelemetryHandler | null ) => { /** * Initialize selection watcher singleton @@ -91,7 +100,9 @@ const plugin: JupyterFrontEndPlugin = { openInlineCompleterSettings, activeCellManager, focusInputSignal, - messageFooter + messageFooter, + telemetryHandler, + app.serviceManager.user ); } catch (e) { chatWidget = buildErrorWidget(themeManager); @@ -114,7 +125,17 @@ const plugin: JupyterFrontEndPlugin = { }, label: 'Focus the jupyter-ai chat' }); + + return { + activeCellManager, + chatHandler, + chatWidget, + selectionWatcher + }; } }; -export default [plugin, statusItemPlugin, completionPlugin]; +export default [plugin, statusItemPlugin, completionPlugin, menuPlugin]; + +export * from './contexts'; +export * from './tokens'; diff --git a/packages/jupyter-ai/src/plugins/menu-plugin.ts b/packages/jupyter-ai/src/plugins/menu-plugin.ts new file mode 100644 index 00000000..8994a552 --- /dev/null +++ b/packages/jupyter-ai/src/plugins/menu-plugin.ts @@ -0,0 +1,158 @@ +import { + JupyterFrontEnd, + JupyterFrontEndPlugin +} from '@jupyterlab/application'; + +import { IJaiCore } from '../tokens'; +import { AiService } from '../handler'; +import { Menu } from '@lumino/widgets'; +import { CommandRegistry } from '@lumino/commands'; + +export namespace CommandIDs { + export const explain = 'jupyter-ai:explain'; + export const fix = 'jupyter-ai:fix'; + export const optimize = 'jupyter-ai:optimize'; + export const refactor = 'jupyter-ai:refactor'; +} + +/** + * Optional plugin that adds a "Generative AI" submenu to the context menu. + * These implement UI shortcuts that explain, fix, refactor, or optimize code in + * a notebook or file. + * + * **This plugin is experimental and may be removed in a future release.** + */ +export const menuPlugin: JupyterFrontEndPlugin = { + id: '@jupyter-ai/core:menu-plugin', + autoStart: true, + requires: [IJaiCore], + activate: (app: JupyterFrontEnd, jaiCore: IJaiCore) => { + const { activeCellManager, chatHandler, chatWidget, selectionWatcher } = + jaiCore; + + function activateChatSidebar() { + app.shell.activateById(chatWidget.id); + } + + function getSelection(): AiService.Selection | null { + const textSelection = selectionWatcher.selection; + const activeCell = activeCellManager.getContent(false); + const selection: AiService.Selection | null = textSelection + ? { type: 'text', source: textSelection.text } + : activeCell + ? { type: 'cell', source: activeCell.source } + : null; + + return selection; + } + + function buildLabelFactory(baseLabel: string): () => string { + return () => { + const textSelection = selectionWatcher.selection; + const activeCell = activeCellManager.getContent(false); + + return textSelection + ? `${baseLabel} (${textSelection.numLines} lines selected)` + : activeCell + ? `${baseLabel} (1 active cell)` + : baseLabel; + }; + } + + // register commands + const menuCommands = new CommandRegistry(); + menuCommands.addCommand(CommandIDs.explain, { + execute: () => { + const selection = getSelection(); + if (!selection) { + return; + } + + activateChatSidebar(); + chatHandler.sendMessage({ + prompt: 'Explain the code below.', + selection + }); + }, + label: buildLabelFactory('Explain code'), + isEnabled: () => !!getSelection() + }); + menuCommands.addCommand(CommandIDs.fix, { + execute: () => { + const activeCellWithError = activeCellManager.getContent(true); + if (!activeCellWithError) { + return; + } + + chatHandler.sendMessage({ + prompt: '/fix', + selection: { + type: 'cell-with-error', + error: activeCellWithError.error, + source: activeCellWithError.source + } + }); + }, + label: () => { + const activeCellWithError = activeCellManager.getContent(true); + return activeCellWithError + ? 'Fix code cell (1 error cell)' + : 'Fix code cell (no error cell)'; + }, + isEnabled: () => { + const activeCellWithError = activeCellManager.getContent(true); + return !!activeCellWithError; + } + }); + menuCommands.addCommand(CommandIDs.optimize, { + execute: () => { + const selection = getSelection(); + if (!selection) { + return; + } + + activateChatSidebar(); + chatHandler.sendMessage({ + prompt: 'Optimize the code below.', + selection + }); + }, + label: buildLabelFactory('Optimize code'), + isEnabled: () => !!getSelection() + }); + menuCommands.addCommand(CommandIDs.refactor, { + execute: () => { + const selection = getSelection(); + if (!selection) { + return; + } + + activateChatSidebar(); + chatHandler.sendMessage({ + prompt: 'Refactor the code below.', + selection + }); + }, + label: buildLabelFactory('Refactor code'), + isEnabled: () => !!getSelection() + }); + + // add commands as a context menu item containing a "Generative AI" submenu + const submenu = new Menu({ + commands: menuCommands + }); + submenu.id = 'jupyter-ai:submenu'; + submenu.title.label = 'Generative AI'; + submenu.addItem({ command: CommandIDs.explain }); + submenu.addItem({ command: CommandIDs.fix }); + submenu.addItem({ command: CommandIDs.optimize }); + submenu.addItem({ command: CommandIDs.refactor }); + + app.contextMenu.addItem({ + type: 'submenu', + selector: '.jp-Editor', + rank: 1, + submenu + }); + } +}; diff --git a/packages/jupyter-ai/src/selection-watcher.ts b/packages/jupyter-ai/src/selection-watcher.ts index 8dd7df58..9cbb67f3 100644 --- a/packages/jupyter-ai/src/selection-watcher.ts +++ b/packages/jupyter-ai/src/selection-watcher.ts @@ -76,6 +76,7 @@ function getTextSelection(widget: Widget | null): Selection | null { start, end, text, + numLines: text.split('\n').length, widgetId: widget.id, ...(cellId && { cellId @@ -88,6 +89,10 @@ export type Selection = CodeEditor.ITextSelection & { * The text within the selection as a string. */ text: string; + /** + * Number of lines contained by the text selection. + */ + numLines: number; /** * The ID of the document widget in which the selection was made. */ @@ -109,6 +114,10 @@ export class SelectionWatcher { setInterval(this._poll.bind(this), 200); } + get selection(): Selection | null { + return this._selection; + } + get selectionChanged(): Signal { return this._selectionChanged; } diff --git a/packages/jupyter-ai/src/tokens.ts b/packages/jupyter-ai/src/tokens.ts index efcada10..1b1c2eb1 100644 --- a/packages/jupyter-ai/src/tokens.ts +++ b/packages/jupyter-ai/src/tokens.ts @@ -1,8 +1,12 @@ import React from 'react'; import { Token } from '@lumino/coreutils'; import { ISignal } from '@lumino/signaling'; -import type { IRankedMenu } from '@jupyterlab/ui-components'; +import type { IRankedMenu, ReactWidget } from '@jupyterlab/ui-components'; + import { AiService } from './handler'; +import { ChatHandler } from './chat_handler'; +import { ActiveCellManager } from './contexts/active-cell-context'; +import { SelectionWatcher } from './selection-watcher'; export interface IJaiStatusItem { addItem(item: IRankedMenu.IItemOptions): void; @@ -46,3 +50,82 @@ export const IJaiMessageFooter = new Token( 'jupyter_ai:IJaiMessageFooter', 'Optional component that is used to render a footer on each Jupyter AI chat message, when provided.' ); + +export interface IJaiCore { + chatWidget: ReactWidget; + chatHandler: ChatHandler; + activeCellManager: ActiveCellManager; + selectionWatcher: SelectionWatcher; +} + +/** + * The Jupyter AI core provider token. Frontend plugins that want to extend the + * Jupyter AI frontend by adding features which send messages or observe the + * current text selection & active cell should require this plugin. + */ +export const IJaiCore = new Token( + 'jupyter_ai:core', + 'The core implementation of the frontend.' +); + +/** + * An object that describes an interaction event from the user. + * + * Jupyter AI natively emits 4 event types: "copy", "replace", "insert-above", + * or "insert-below". These are all emitted by the code toolbar rendered + * underneath code blocks in the chat sidebar. + */ +export type TelemetryEvent = { + /** + * Type of the interaction. + * + * Frontend extensions may add other event types in custom components. Custom + * events can be emitted via the `useTelemetry()` hook. + */ + type: 'copy' | 'replace' | 'insert-above' | 'insert-below' | string; + /** + * Anonymized details about the message that was interacted with. + */ + message: { + /** + * ID of the message assigned by Jupyter AI. + */ + id: string; + /** + * Type of the message. + */ + type: AiService.ChatMessage['type']; + /** + * UNIX timestamp of the message. + */ + time: number; + /** + * Metadata associated with the message, yielded by the underlying language + * model provider. + */ + metadata?: Record; + }; + /** + * Anonymized details about the code block that was interacted with, if any. + * This is left optional for custom events like message upvote/downvote that + * do not involve interaction with a specific code block. + */ + code?: { + charCount: number; + lineCount: number; + }; +}; + +export interface IJaiTelemetryHandler { + onEvent: (e: TelemetryEvent) => unknown; +} + +/** + * An optional plugin that handles telemetry events emitted via user + * interactions, when provided by a separate labextension. Not provided by + * default. + */ +export const IJaiTelemetryHandler = new Token( + 'jupyter_ai:telemetry', + 'An optional plugin that handles telemetry events emitted via interactions on agent messages, when provided by a separate labextension. Not provided by default.' +); diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index e7eee11b..732eedd3 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -2,13 +2,18 @@ import React from 'react'; import { ISignal } from '@lumino/signaling'; import { ReactWidget } from '@jupyterlab/apputils'; import type { IThemeManager } from '@jupyterlab/apputils'; +import type { User } from '@jupyterlab/services'; import type { Awareness } from 'y-protocols/awareness'; import { Chat } from '../components/chat'; import { chatIcon } from '../icons'; import { SelectionWatcher } from '../selection-watcher'; import { ChatHandler } from '../chat_handler'; -import { IJaiCompletionProvider, IJaiMessageFooter } from '../tokens'; +import { + IJaiCompletionProvider, + IJaiMessageFooter, + IJaiTelemetryHandler +} from '../tokens'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import type { ActiveCellManager } from '../contexts/active-cell-context'; @@ -22,7 +27,9 @@ export function buildChatSidebar( openInlineCompleterSettings: () => void, activeCellManager: ActiveCellManager, focusInputSignal: ISignal, - messageFooter: IJaiMessageFooter | null + messageFooter: IJaiMessageFooter | null, + telemetryHandler: IJaiTelemetryHandler | null, + userManager: User.IManager ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat'; diff --git a/packages/jupyter-ai/style/rendermime-markdown.css b/packages/jupyter-ai/style/rendermime-markdown.css index 7e93882a..2ced1722 100644 --- a/packages/jupyter-ai/style/rendermime-markdown.css +++ b/packages/jupyter-ai/style/rendermime-markdown.css @@ -1,8 +1,17 @@ -.jp-ai-rendermime-markdown .jp-RenderedHTMLCommon { +/* + * + * Selectors must be nested in `.jp-ThemedContainer` to have a higher + * specificity than selectors in rules provided by JupyterLab. + * + * See: https://jupyterlab.readthedocs.io/en/latest/extension/extension_migration.html#css-styling + * See also: https://github.com/jupyterlab/jupyter-ai/issues/1090 + */ + +.jp-ThemedContainer .jp-ai-rendermime-markdown .jp-RenderedHTMLCommon { padding-right: 0; } -.jp-ai-rendermime-markdown pre { +.jp-ThemedContainer .jp-ai-rendermime-markdown pre { background-color: var(--jp-cell-editor-background); overflow-x: auto; white-space: pre; @@ -11,12 +20,12 @@ border: var(--jp-border-width) solid var(--jp-cell-editor-border-color); } -.jp-ai-rendermime-markdown pre > code { +.jp-ThemedContainer .jp-ai-rendermime-markdown pre > code { background-color: inherit; overflow-x: inherit; white-space: inherit; } -.jp-ai-rendermime-markdown mjx-container { +.jp-ThemedContainer .jp-ai-rendermime-markdown mjx-container { font-size: 119%; } diff --git a/packages/jupyter-ai/ui-tests/tests/jupyter-ai.spec.ts-snapshots/chat-welcome-message-linux.png b/packages/jupyter-ai/ui-tests/tests/jupyter-ai.spec.ts-snapshots/chat-welcome-message-linux.png index c6f10885..f921708f 100644 Binary files a/packages/jupyter-ai/ui-tests/tests/jupyter-ai.spec.ts-snapshots/chat-welcome-message-linux.png and b/packages/jupyter-ai/ui-tests/tests/jupyter-ai.spec.ts-snapshots/chat-welcome-message-linux.png differ diff --git a/pyproject.toml b/pyproject.toml index cc45dd97..32217745 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,11 @@ build-backend = "hatchling.build" [project] name = "jupyter_ai_monorepo" dynamic = ["version", "description", "authors", "urls", "keywords"] -requires-python = ">=3.8" -dependencies = [] +requires-python = ">=3.9" +dependencies = [ + "jupyter-ai-magics @ {root:uri}/packages/jupyter-ai-magics", + "jupyter-ai @ {root:uri}/packages/jupyter-ai" +] [project.optional-dependencies] build = [] @@ -22,6 +25,15 @@ text = "BSD 3-Clause License" source = "nodejs" path = "package.json" +[tool.hatch.build] +packages = [ + "packages/jupyter-ai-magics", + "packages/jupyter-ai" +] + +[tool.hatch.metadata] +allow-direct-references = true + [tool.check-manifest] ignore = [".*"] @@ -43,3 +55,8 @@ ignore_imports = ["jupyter_ai_magics.providers -> pydantic"] [tool.pytest.ini_options] addopts = "--ignore packages/jupyter-ai-module-cookiecutter" + +[tool.mypy] +exclude = [ + "tests" +] diff --git a/scripts/install.sh b/scripts/install.sh index ae5de0ff..7031bacb 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -1,7 +1,15 @@ #!/bin/bash +set -eux -# install core packages -pip install jupyterlab~=4.0 +# Install JupyterLab +# +# Excludes v4.3.2 as it pins `httpx` to a very narrow range, causing `pip +# install` to stall on package resolution. +# +# See: https://github.com/jupyterlab/jupyter-ai/issues/1138 +pip install jupyterlab~=4.0,!=4.3.2 + +# Install core packages cp playground/config.example.py playground/config.py jlpm install jlpm dev-install