diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 589a42d3b..18856b718 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -1,5 +1,5 @@ name: ~CI, single-arch -run-name: CI-${{ inputs.ARCHITECTURE }} +run-name: CI-${{ inputs.ARCHITECTURE }}-${{ inputs.TEST_SUBSET }} on: workflow_call: inputs: @@ -16,26 +16,60 @@ on: description: Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch default: '' required: false + TEST_SUBSET: + type: string + description: | + Subset of tests to run. Allowed values are one of: + - base + - jax + - levanter + - equinox + - triton + - upstream-t5x + - rosetta-t5x + - upstream-pax + - rosetta-pax + - maxtext + - grok + + Will run all downstream-connected nodes and leaves. + default: 'base' + required: false outputs: DOCKER_TAGS: description: JSON object containing tags of all docker images built value: ${{ jobs.collect-docker-tags.outputs.TAGS }} permissions: - contents: read # to fetch code - actions: write # to cancel previous workflows + contents: read # to fetch code + actions: write # to cancel previous workflows packages: write # to upload container jobs: + pre-flight: + runs-on: ubuntu-22.04 + steps: + - name: Validate input `TEST_SUBSET` + shell: bash + run: | + valid_inputs=("base" "jax" "levanter" "equinox" "triton" "upstream-t5x" "rosetta-t5x" "upstream-pax" "rosetta-pax" "maxtext" "grok") + + if [[ " ${valid_inputs[*]} " != *" ${{ inputs.TEST_SUBSET }} "* ]]; then + echo "Invalid value for \`TEST_SUBSET\` provided.
Expected one of: (${valid_inputs[*]}), Actual: ${{ inputs.TEST_SUBSET }}" + exit 1 + fi + # Always build-base: uses: ./.github/workflows/_build_base.yaml + needs: pre-flight with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} BUILD_DATE: ${{ inputs.BUILD_DATE }} MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }} secrets: inherit + # Always build-jax: needs: build-base uses: ./.github/workflows/_build.yaml @@ -50,9 +84,10 @@ jobs: RUNNER_SIZE: large secrets: inherit + # base, jax, triton build-triton: needs: build-jax - if: inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64 + if: contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64 uses: ./.github/workflows/_build.yaml with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} @@ -64,9 +99,11 @@ jobs: DOCKERFILE: .github/container/Dockerfile.triton secrets: inherit + # base, jax, equinox build-equinox: needs: build-jax uses: ./.github/workflows/_build.yaml + if: contains(fromJSON('["base", "jax", "equinox"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} ARTIFACT_NAME: artifact-equinox-build @@ -77,9 +114,10 @@ jobs: DOCKERFILE: .github/container/Dockerfile.equinox secrets: inherit + # base, jax, maxtext build-maxtext: needs: build-jax - if: inputs.ARCHITECTURE == 'amd64' # Triton does not seem to support arm64 + if: contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # MaxText builds from Dockerfile.maxtext.amd64 (amd64-specific) uses: ./.github/workflows/_build.yaml with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} @@ -91,35 +129,41 @@ jobs: DOCKERFILE: .github/container/Dockerfile.maxtext.amd64 secrets: inherit + # base, jax, levanter build-levanter: needs: [build-jax] uses: ./.github/workflows/_build.yaml + if: contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} - ARTIFACT_NAME: "artifact-levanter-build" -
BADGE_FILENAME: "badge-levanter-build" + ARTIFACT_NAME: 'artifact-levanter-build' + BADGE_FILENAME: 'badge-levanter-build' BUILD_DATE: ${{ inputs.BUILD_DATE }} BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }} CONTAINER_NAME: levanter DOCKERFILE: .github/container/Dockerfile.levanter secrets: inherit + # base, jax, upstream-t5x, rosetta-t5x build-upstream-t5x: needs: build-jax uses: ./.github/workflows/_build.yaml + if: contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} - ARTIFACT_NAME: "artifact-t5x-build" - BADGE_FILENAME: "badge-t5x-build" + ARTIFACT_NAME: 'artifact-t5x-build' + BADGE_FILENAME: 'badge-t5x-build' BUILD_DATE: ${{ inputs.BUILD_DATE }} BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }} CONTAINER_NAME: upstream-t5x DOCKERFILE: .github/container/Dockerfile.t5x.${{ inputs.ARCHITECTURE }} secrets: inherit + # base, jax, upstream-pax, rosetta-pax build-upstream-pax: needs: build-jax uses: ./.github/workflows/_build.yaml + if: contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} ARTIFACT_NAME: artifact-pax-build @@ -130,9 +174,11 @@ jobs: DOCKERFILE: .github/container/Dockerfile.pax.${{ inputs.ARCHITECTURE }} secrets: inherit + # base, jax, upstream-t5x, rosetta-t5x build-rosetta-t5x: needs: build-upstream-t5x uses: ./.github/workflows/_build_rosetta.yaml + if: contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} BUILD_DATE: ${{ inputs.BUILD_DATE }} @@ -140,9 +186,11 @@ jobs: BASE_LIBRARY: t5x secrets: inherit + # base, jax, upstream-pax, rosetta-pax build-rosetta-pax: needs: build-upstream-pax uses: ./.github/workflows/_build_rosetta.yaml + if: contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} BUILD_DATE: ${{ inputs.BUILD_DATE
}} @@ -150,22 +198,24 @@ jobs: BASE_LIBRARY: pax secrets: inherit + # base, jax, grok build-grok: needs: [build-jax] uses: ./.github/workflows/_build.yaml + if: contains(fromJSON('["base", "jax", "grok"]'), inputs.TEST_SUBSET) with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} - ARTIFACT_NAME: "artifact-grok-build" - BADGE_FILENAME: "badge-grok-build" + ARTIFACT_NAME: 'artifact-grok-build' + BADGE_FILENAME: 'badge-grok-build' BUILD_DATE: ${{ inputs.BUILD_DATE }} BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }} CONTAINER_NAME: grok DOCKERFILE: .github/container/Dockerfile.grok secrets: inherit - + collect-docker-tags: runs-on: ubuntu-22.04 - if: "!cancelled()" + if: '!cancelled()' needs: - build-base - build-jax @@ -236,9 +286,10 @@ jobs: - name: Run integration test ${{ matrix.TEST_SCRIPT }} run: bash rosetta/tests/${{ matrix.TEST_SCRIPT }} + # base, jax test-jax: needs: build-jax - if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + if: contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a uses: ./.github/workflows/_test_unit.yaml with: TEST_NAME: jax @@ -291,33 +342,37 @@ jobs: # test-equinox.log # secrets: inherit + # base, jax, upstream-pax test-te-multigpu: needs: build-upstream-pax - if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + if: contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a uses: ./.github/workflows/_test_te.yaml with: TE_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }} secrets: inherit + # base, jax, upstream-t5x test-upstream-t5x: needs: build-upstream-t5x - if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + if: contains(fromJSON('["base", "jax", "upstream-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a uses: ./.github/workflows/_test_upstream_t5x.yaml with: T5X_IMAGE: ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_FINAL }} secrets:
inherit + # base, jax, upstream-t5x, rosetta-t5x test-rosetta-t5x: needs: build-rosetta-t5x - if: inputs.ARCHITECTURE == 'amd64' # no images for arm64 + if: contains(fromJSON('["base", "jax", "upstream-t5x", "rosetta-t5x"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64 uses: ./.github/workflows/_test_t5x_rosetta.yaml with: T5X_IMAGE: ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }} secrets: inherit + # base, jax test-pallas: needs: build-jax - if: inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?) + if: contains(fromJSON('["base", "jax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # triton doesn't support arm64(?) uses: ./.github/workflows/_test_unit.yaml with: TEST_NAME: pallas @@ -341,9 +396,10 @@ jobs: test-pallas.log secrets: inherit + # base, jax, triton test-triton: needs: build-triton - if: inputs.ARCHITECTURE == 'amd64' # no images for arm64 + if: contains(fromJSON('["base", "jax", "triton"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64 uses: ./.github/workflows/_test_unit.yaml with: TEST_NAME: triton @@ -367,9 +423,10 @@ jobs: test-triton.log secrets: inherit + # base, jax, levanter test-levanter: needs: build-levanter - if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + if: contains(fromJSON('["base", "jax", "levanter"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a uses: ./.github/workflows/_test_unit.yaml with: TEST_NAME: levanter @@ -394,9 +451,10 @@ jobs: test-levanter.log secrets: inherit + # base, jax, upstream-pax test-te: needs: build-upstream-pax - if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + if: contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a uses: ./.github/workflows/_test_unit.yaml with: TEST_NAME: te @@ -422,25 +480,28 @@ jobs: pytest-report.jsonl secrets: inherit + # base, jax, upstream-pax test-upstream-pax: needs:
build-upstream-pax - if: inputs.ARCHITECTURE == 'amd64' # no images for arm64 + if: contains(fromJSON('["base", "jax", "upstream-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64 uses: ./.github/workflows/_test_upstream_pax.yaml with: PAX_IMAGE: ${{ needs.build-upstream-pax.outputs.DOCKER_TAG_FINAL }} secrets: inherit + # base, jax, upstream-pax, rosetta-pax test-rosetta-pax: needs: build-rosetta-pax - if: inputs.ARCHITECTURE == 'amd64' # no images for arm64 + if: contains(fromJSON('["base", "jax", "upstream-pax", "rosetta-pax"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64 uses: ./.github/workflows/_test_pax_rosetta.yaml with: PAX_IMAGE: ${{ needs.build-rosetta-pax.outputs.DOCKER_TAG_FINAL }} secrets: inherit + # base, jax, maxtext test-maxtext: needs: build-maxtext - if: inputs.ARCHITECTURE == 'amd64' # no images for arm64 + if: contains(fromJSON('["base", "jax", "maxtext"]'), inputs.TEST_SUBSET) && inputs.ARCHITECTURE == 'amd64' # no images for arm64 uses: ./.github/workflows/_test_maxtext.yaml with: MAXTEXT_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }} diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0098b83bf..02bc31ee3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,7 @@ name: CI on: schedule: - - cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC + - cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC pull_request: types: - opened @@ -25,25 +25,41 @@ on: required: false MERGE_BUMPED_MANIFEST: type: boolean - description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch" + description: '(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch' default: false required: false + TEST_SUBSET: + type: choice + options: + - base + - jax + - levanter + - equinox + - triton + - upstream-t5x + - rosetta-t5x + - upstream-pax + - rosetta-pax + - maxtext + - grok + description: Subset of tests
to run. Will run all downstream-connected nodes and leaves. + default: 'base' + required: false concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} permissions: - contents: write # to fetch code and push branch - actions: write # to cancel previous workflows - packages: write # to upload container - pull-requests: write # to make pull request for manifest bump + contents: write # to fetch code and push branch + actions: write # to cancel previous workflows + packages: write # to upload container + pull-requests: write # to make pull request for manifest bump env: DEFAULT_MANIFEST_ARTIFACT_NAME: bumped-manifest jobs: - metadata: runs-on: ubuntu-22.04 outputs: @@ -53,6 +69,7 @@ jobs: MANIFEST_ARTIFACT_NAME: ${{ steps.manifest-branch.outputs.MANIFEST_ARTIFACT_NAME }} MANIFEST_BRANCH: ${{ steps.manifest-branch.outputs.MANIFEST_BRANCH }} MERGE_BUMPED_MANIFEST: ${{ steps.manifest-branch.outputs.MERGE_BUMBED_MANIFEST }} + TEST_SUBSET: ${{ steps.testset.outputs.TEST_SUBSET }} steps: - name: Cancel workflow run if the trigger is a draft PR id: cancel-if-draft @@ -103,6 +120,13 @@ jobs: exit 1 fi + - name: Set TESTSET + id: testset + shell: bash -x -e {0} + run: | + TEST_SUBSET=${{ inputs.TEST_SUBSET || 'base' }} + echo "TEST_SUBSET=$TEST_SUBSET" | tee -a $GITHUB_OUTPUT + bump-manifest: needs: metadata runs-on: ubuntu-22.04 @@ -115,7 +139,7 @@ jobs: shell: bash -x -e {0} run: | bash bump.sh --input-manifest manifest.yaml --output-manifest manifest.yaml.new --base-patch-dir ./patches-new - + - name: Maybe replace current manifest/patches with the new one and show diff working-directory: .github/container shell: bash -x -e {0} @@ -146,6 +170,7 @@ jobs: ARCHITECTURE: amd64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} + TEST_SUBSET: ${{ needs.metadata.outputs.TEST_SUBSET }} secrets: inherit arm64: @@ -155,6
+180,7 @@ jobs: ARCHITECTURE: arm64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} + TEST_SUBSET: ${{ needs.metadata.outputs.TEST_SUBSET }} secrets: inherit # Only merge if everything succeeds @@ -168,12 +194,11 @@ jobs: steps: - name: "Tests Succeeded: ${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}" id: test_result - run: - echo "SUCCEEDED=${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}" | tee -a $GITHUB_OUTPUT + run: echo "SUCCEEDED=${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}" | tee -a $GITHUB_OUTPUT - name: Check out the repository under ${GITHUB_WORKSPACE} uses: actions/checkout@v4 - + - name: Delete checked-out manifest and patches run: | rm .github/container/manifest.yaml @@ -185,7 +210,7 @@ jobs: name: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} path: .github/container/ - - name: "Create local manifest branch: ${{ needs.metadata.outputs.MANIFEST_BRANCH }}" + - name: 'Create local manifest branch: ${{ needs.metadata.outputs.MANIFEST_BRANCH }}' id: local_branch shell: bash -x -e {0} run: | @@ -213,7 +238,7 @@ jobs: git merge --ff-only ${{ needs.metadata.outputs.MANIFEST_BRANCH }} # Push the new change git push origin ${{ github.ref_name }} - + # We will create a Draft PR & remote branch if: # 1. The tests failed # 2.
The merge failed @@ -244,12 +269,12 @@ jobs: draft: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: "Log created PR: #${{ fromJson(steps.create_pr.outputs.data).number }}" + + - name: 'Log created PR: #${{ fromJson(steps.create_pr.outputs.data).number }}' if: steps.create_pr.outcome == 'success' run: | echo "https://github.com/NVIDIA/JAX-Toolbox/pull/${{ fromJson(steps.create_pr.outputs.data).number }}" | tee -a $GITHUB_STEP_SUMMARY - + # Guard delete in simple check to protect other branches - name: Check that the branch matches znightly- prefix run: | @@ -271,7 +296,7 @@ jobs: make-publish-configs: runs-on: ubuntu-22.04 - if: ${{ !cancelled() }} + if: ${{ !cancelled() }} env: MEALKIT_IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax-mealkit' || 'mock-jax-mealkit' }} FINAL_IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax' || 'mock-jax' }} @@ -365,7 +390,7 @@ jobs: needs: - metadata - make-publish-configs - if: ${{ !cancelled() && needs.make-publish-configs.outputs.PUBLISH_CONFIGS.config != '{"config":[]}' }} + if: ${{ !cancelled() && needs.make-publish-configs.outputs.PUBLISH_CONFIGS.config != '{"config":[]}' }} strategy: fail-fast: false matrix: ${{ fromJson(needs.make-publish-configs.outputs.PUBLISH_CONFIGS) }} @@ -381,7 +406,7 @@ jobs: finalize: needs: [metadata, amd64, arm64, publish-containers] - if: "!cancelled()" + if: '!cancelled()' uses: ./.github/workflows/_finalize.yaml with: BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}