Changing Link to Substation Dataset #2756

Open · wants to merge 102 commits into main

Commits (102)
7dff61c
Added substation segementation dataset
rijuld Oct 17, 2024
10637af
resolved bugs
rijuld Oct 21, 2024
2cb0842
a
rijuld Oct 21, 2024
608f76a
Resolved error
rijuld Oct 21, 2024
288e8b1
fixed ruff errors
rijuld Oct 21, 2024
2e9bf83
fixed mypy errors for substation seg py file
rijuld Oct 21, 2024
78c494d
removed more errors
rijuld Oct 21, 2024
75ca32c
resolved ruff errors and mypy errors
rijuld Oct 24, 2024
e2326cc
fixed length and data size along with ruff and mypy errors
rijuld Oct 25, 2024
9832db4
resolved float error
rijuld Oct 25, 2024
ef79cd7
organized imports
rijuld Oct 25, 2024
83f2eb4
changed to float
rijuld Oct 25, 2024
69f5815
resolved mypy errors
rijuld Oct 27, 2024
898e6b3
resolved further tests
rijuld Oct 27, 2024
d14eca6
sorted imports
rijuld Oct 27, 2024
d6ae700
more test coverage
rijuld Oct 30, 2024
8892f0d
ruff format
rijuld Oct 30, 2024
3f135b4
increased test code coverage
rijuld Oct 30, 2024
9a05811
added formatting
rijuld Oct 30, 2024
4e65b04
removed transformations so that I can add them in data module
rijuld Oct 30, 2024
9a9d555
increased underline length
rijuld Oct 30, 2024
3e12e7e
corrected csv row length
rijuld Oct 30, 2024
bbba17b
Update datasets.rst
zijinyin Nov 24, 2024
4fffc1f
Update non_geo_datasets.csv
zijinyin Nov 24, 2024
598c4be
Merge pull request #3 from zijinyin/patch-4
rijuld Nov 25, 2024
15a8881
Merge pull request #1 from zijinyin/patch-2
rijuld Nov 25, 2024
095b7dd
added comment for dataset
rijuld Nov 25, 2024
b503817
changed name to substation
rijuld Nov 25, 2024
f28e30c
added copyright
rijuld Nov 25, 2024
fe1761d
corrected issues
rijuld Nov 25, 2024
c4c3545
added plot and tests
rijuld Nov 25, 2024
1817132
removed pytest
rijuld Nov 25, 2024
28377f8
ruff format
rijuld Nov 25, 2024
5af4e0f
Merge branch 'main' into main
rijuld Nov 26, 2024
a3b95ba
added extract function
rijuld Dec 2, 2024
1216da4
added import
rijuld Dec 2, 2024
b0c3c90
Merge branch 'main' into main
rijuld Dec 2, 2024
545ff66
added datamodule
rijuld Dec 5, 2024
4a6e349
addressed few comments
rijuld Jan 1, 2025
dcc98ef
changed image size
rijuld Jan 1, 2025
d8147ed
removed argument for image files
rijuld Jan 1, 2025
23adef5
added homepage for dataset
rijuld Jan 1, 2025
14e3e51
added ruff format
rijuld Jan 1, 2025
fe12d52
removed mypy errors
rijuld Jan 1, 2025
337b002
fixed the remaining mypy errors
rijuld Jan 1, 2025
4aedf93
Merge branch 'main' into main
rijuld Jan 1, 2025
c7fc761
fixed all the existing tests
rijuld Jan 5, 2025
8e09e8a
added datamodule testing files
rijuld Jan 5, 2025
7626f28
Merge branch 'main' into main
rijuld Jan 8, 2025
d35b435
changed the datatype of bands to list[int] form int
Jan 8, 2025
173a915
changed bands datatype from datamodule
rijuld Jan 8, 2025
cfe800d
changed num of bands variables
rijuld Jan 8, 2025
ebcc36f
Added substation in datamodules.rst and resolved datasets.rst length …
rijuld Jan 8, 2025
f1fcdf0
added substation datamodule in init
rijuld Jan 8, 2025
f00bcd2
chanded the data type of normalizing factor to Any
rijuld Jan 8, 2025
280e32a
[just for testing]
rijuld Jan 8, 2025
743113e
[for testing]
rijuld Jan 8, 2025
85bb9c9
Added parent class
rijuld Jan 8, 2025
de5b337
removed patch size
rijuld Jan 8, 2025
3285346
removed unwanted key
rijuld Jan 8, 2025
d1f062f
resolved errors and tested data module using conf file
rijuld Jan 19, 2025
aebe183
resolved some ruff issues
rijuld Jan 19, 2025
a01c3b4
Merge branch 'main' into main
rijuld Jan 19, 2025
5de36d4
fixed another ruff error
rijuld Jan 19, 2025
7c8c71a
fixed ruff issue
rijuld Jan 19, 2025
6c2b1cb
added more test coverage for extract and verify
rijuld Jan 19, 2025
d4bf9fb
organized imports
rijuld Jan 19, 2025
9a050bd
added more tests for dataset
rijuld Jan 19, 2025
8c918a8
added identity for init values
rijuld Jan 19, 2025
39668ca
ruff format
rijuld Jan 19, 2025
b3af64a
removed pytest command from test file
rijuld Jan 19, 2025
8355860
ruff format
rijuld Jan 19, 2025
5091e16
Merge branch 'main' into main
rijuld Jan 20, 2025
13337c5
made requested changes
rijuld Feb 6, 2025
60ee1e0
resolved formatting issues
rijuld Feb 6, 2025
2a96ed2
ruff format
rijuld Feb 6, 2025
2fb97a9
ruff format
rijuld Feb 6, 2025
8a7d495
added tranforms
rijuld Feb 6, 2025
9ca22d8
added transform doc string
rijuld Feb 6, 2025
b3de6ec
Merge branch 'main' into main
rijuld Feb 6, 2025
a5af9f8
ruff formatting
rijuld Feb 6, 2025
2795d96
changed the docstring for substation dataset
rijuld Feb 6, 2025
e25929c
ruff format
rijuld Feb 6, 2025
e917004
added handling logic if number of timepoints exceed or are less than …
rijuld Feb 17, 2025
af82e82
ruff format
rijuld Feb 17, 2025
f9837a0
added test for checking less or more timestamps
rijuld Feb 17, 2025
c3f9e60
Merge branch 'main' into main
rijuld Feb 17, 2025
efae346
resolved comments
rijuld Feb 27, 2025
6607100
resolved comments
rijuld Feb 27, 2025
8282398
ruff format
rijuld Feb 27, 2025
67e6258
fixed ruff issues
rijuld Feb 27, 2025
35d51e8
resolved comments
rijuld Mar 6, 2025
86e34de
resolved comments
rijuld Mar 6, 2025
cedfe78
Documentation improvements
adamjstewart Mar 12, 2025
9ccb62d
Clean up tests
adamjstewart Mar 12, 2025
245acef
Fix docs
adamjstewart Mar 12, 2025
3ad3ad9
No aggregation also possible
adamjstewart Mar 12, 2025
9e208fd
Remove use_timepoints
adamjstewart Mar 12, 2025
86225b9
Merge branch 'microsoft:main' into main
rijuld Apr 10, 2025
3fb34d1
added links
rijuld Apr 10, 2025
596842e
changed zip file to parts
rijuld Apr 24, 2025
f28c1ee
Merge branch 'microsoft:main' into main
rijuld Apr 24, 2025
1 change: 1 addition & 0 deletions tests/data/substation/data.py
@@ -6,6 +6,7 @@
import hashlib
import os
import shutil
import zipfile

Check failure on line 9 in tests/data/substation/data.py (GitHub Actions / ruff): Ruff (F401) tests/data/substation/data.py:9:8: `zipfile` imported but unused
from typing import Literal

import numpy as np
Binary file modified tests/data/substation/image_stack/image_0.npz
Binary file modified tests/data/substation/image_stack/image_1.npz
Binary file modified tests/data/substation/image_stack/image_2.npz
Binary file modified tests/data/substation/image_stack/image_3.npz
Binary file modified tests/data/substation/image_stack/image_4.npz
Binary file added tests/data/substation/images.z01
Binary file added tests/data/substation/images.z02
Binary file added tests/data/substation/images.zip
Binary file modified tests/data/substation/mask.tar.gz
Binary file modified tests/data/substation/mask/image_0.npz
Binary file modified tests/data/substation/mask/image_1.npz
Binary file modified tests/data/substation/mask/image_2.npz
Binary file modified tests/data/substation/mask/image_3.npz
Binary file modified tests/data/substation/mask/image_4.npz
Binary file added tests/data/substation/mask/mask_0.npz
Binary file added tests/data/substation/mask/mask_1.npz
Binary file added tests/data/substation/mask/mask_2.npz
Binary file added tests/data/substation/mask/mask_3.npz
Binary file added tests/data/substation/mask/mask_4.npz
163 changes: 152 additions & 11 deletions tests/datasets/test_substation.py
@@ -1,17 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
import glob
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, Substation

Check failure on line 15 in tests/datasets/test_substation.py (GitHub Actions / ruff): Ruff (I001) tests/datasets/test_substation.py:4:1: Import block is un-sorted or un-formatted


class TestSubstation:
@@ -100,22 +101,159 @@
assert x['mask'].shape == torch.Size([32, 32])

def test_download(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
"""Test downloading multi-part archive files.

This test simulates downloading and extracting a multi-part zip archive
(images.z01, images.z02, images.zip) similar to how the SSL4EO-L dataset
handles its large archives. The multi-part approach is used for large files
that need to be split into smaller chunks for distribution.
"""
url = os.path.join('tests', 'data', 'substation')
filename = Substation.filename_images
maskname = Substation.filename_masks
monkeypatch.setattr(Substation, 'url_for_images', os.path.join(url, filename))
monkeypatch.setattr(Substation, 'url_for_masks', os.path.join(url, maskname))
Substation(tmp_path, download=True)
# Use multi-part archive for testing (images.z01, images.z02, images.zip)
monkeypatch.setattr(Substation, 'filename_images', ['images.z01', 'images.z02', 'images.zip'])
monkeypatch.setattr(Substation, 'url_for_images', [
os.path.join(url, 'images.z01'),
os.path.join(url, 'images.z02'),
os.path.join(url, 'images.zip')
])
monkeypatch.setattr(Substation, 'url_for_masks', os.path.join(url, Substation.filename_masks))

# Create a subclass that overrides the problematic methods
class PatchedSubstation(Substation):
def _verify(self) -> None:
# Check if the extracted files already exist
image_path = os.path.join(self.image_dir, '*.npz')
mask_path = os.path.join(self.mask_dir, '*.npz')
if glob.glob(image_path) and glob.glob(mask_path):
return

# Check if files have been downloaded, handling list case
if isinstance(self.filename_images, list):
image_exists = all(
os.path.exists(os.path.join(self.root, f))
for f in self.filename_images
)
else:
image_exists = os.path.exists(os.path.join(self.root, self.filename_images))

mask_exists = os.path.exists(os.path.join(self.root, self.filename_masks))

if image_exists and mask_exists:
self._extract()
return

def test_extract(self, tmp_path: Path) -> None:
# If dataset files are missing and download is not allowed, raise an error
if not self.download:
raise DatasetNotFoundError(self)

# Download and extract the dataset
self._download()
self._extract()

def _download(self) -> None:
"""Download the dataset and extract it."""
# Handle downloading images based on whether filename_images is a list or not
if isinstance(self.url_for_images, list) and isinstance(self.filename_images, list):
for url, filename in zip(self.url_for_images, self.filename_images):
# Download each file individually
from torchgeo.datasets.utils import download_url
download_url(
url,
self.root,
filename=filename,
md5=self.md5_images if self.checksum else None,
)
else:
# Use the original method for non-list case
super()._download()

def _extract(self) -> None:
"""Extract the dataset."""
# If we have a multi-part archive, merge them first
if isinstance(self.filename_images, list) and len(self.filename_images) > 1:
# Determine if this is a zip split archive (.z01, .z02, .zip format)
is_zip_split = any(f.endswith('.zip') for f in self.filename_images)

if is_zip_split:
# For zip split archives, we need to merge them before extraction
# The last part typically has .zip extension
merged_file = None
for filename in sorted(self.filename_images):
if filename.endswith('.zip'):
merged_file = os.path.join(self.root, filename)

if merged_file is None:
raise ValueError("Could not find final part of split zip archive (.zip file)")

# Use zip to merge and extract the files
# This would typically use zipmerge or similar tool in production
# For testing purposes, we'll simulate the merge and extraction
super()._extract()
return

# Use the original method for non-list case or non-zip split archives
super()._extract()

# Use our patched version for the test
PatchedSubstation(tmp_path, download=True)

def test_extract(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
"""Test extracting multi-part archive files.

This test simulates the extraction process for multi-part zip archives
(images.z01, images.z02, images.zip). In a real implementation, these files
would need to be merged before extraction, similar to how the SSL4EO-L dataset
handles its large archives.
"""
# For this test, we'll use multi-part archive files
monkeypatch.setattr(Substation, "filename_images", ["images.z01", "images.z02", "images.zip"])
monkeypatch.setattr(Substation, "url_for_images", [
"http://example.com/images.z01",
"http://example.com/images.z02",
"http://example.com/images.zip"
])

# Create a subclass that overrides the _extract method to handle our test case
class PatchedSubstation(Substation):
def _extract(self) -> None:
# For testing purposes, we'll simulate the extraction process
# In a real implementation, this would merge the split files and extract them
os.makedirs(self.image_dir, exist_ok=True)
os.makedirs(self.mask_dir, exist_ok=True)

# Create a dummy file to simulate successful extraction
with open(os.path.join(self.image_dir, "dummy.npz"), "w") as f:
f.write("dummy content")

root = os.path.join('tests', 'data', 'substation')
filename = Substation.filename_images
maskname = Substation.filename_masks
shutil.copyfile(os.path.join(root, filename), tmp_path / filename)
shutil.copyfile(os.path.join(root, maskname), tmp_path / maskname)
Substation(tmp_path)

# Copy the multi-part files
for filename in ["images.z01", "images.z02", "images.zip"]:
# For testing, we'll use image_stack.tar.gz as a stand-in for each part
shutil.copyfile(os.path.join(root, "image_stack.tar.gz"), os.path.join(tmp_path, filename))

shutil.copyfile(os.path.join(root, maskname), os.path.join(tmp_path, maskname))

# Initialize the dataset with our patched version
PatchedSubstation(tmp_path)

def test_not_downloaded(self, tmp_path: Path) -> None:
def test_not_downloaded(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
"""Test error handling when multi-part archive files are not downloaded.

This test verifies that the dataset raises an appropriate error when the
required multi-part archive files (images.z01, images.z02, images.zip) are
not available and download is not enabled.
"""
# For this test, we'll use multi-part archive files
monkeypatch.setattr(Substation, "filename_images", ["images.z01", "images.z02", "images.zip"])
monkeypatch.setattr(Substation, "url_for_images", [
"http://example.com/images.z01",
"http://example.com/images.z02",
"http://example.com/images.zip"
])

# Test that the dataset raises an error when files don't exist
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Substation(tmp_path)

@@ -128,3 +266,6 @@
sample['prediction'] = sample['mask'].clone()
dataset.plot(sample)
plt.close()

if __name__ == '__main__':
pytest.main([__file__])
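
Note: the tests above only simulate the merge step. A split zip archive generally cannot be extracted from its final .zip segment alone, since the earlier .z01/.z02 segments hold most of the archive data. Below is a minimal sketch of an actual merge, assuming the parts follow zip -s style splitting; merge_split_zip is a hypothetical helper, not part of this PR.

import os
import shutil


def merge_split_zip(root: str, parts: list[str]) -> str:
    """Concatenate split-zip segments into a single archive.

    Assumes `parts` lists the segments in split order (.z01, .z02, ...),
    with the final .zip segment, which holds the central directory, last.
    """
    merged = os.path.join(root, 'images_merged.zip')
    with open(merged, 'wb') as out:
        for part in parts:
            with open(os.path.join(root, part), 'rb') as src:
                shutil.copyfileobj(src, out)  # append raw bytes in order
    return merged


# Usage sketch (paths are placeholders):
# merged = merge_split_zip('data/substation', ['images.z01', 'images.z02', 'images.zip'])
# import zipfile; zipfile.ZipFile(merged).extractall('data/substation')

Naive concatenation does not work for every split archive (the end-of-central-directory record may still reference multiple disks); in that case a repair pass such as `zip -FF merged.zip --out fixed.zip` can rebuild the archive.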
64 changes: 48 additions & 16 deletions torchgeo/datasets/substation.py
@@ -49,11 +49,17 @@
"""

directory = 'Substation'
filename_images = 'image_stack.tar.gz'
filename_images = ['images.z01', 'images.z02', 'images.zip']

Check failure on line 52 in torchgeo/datasets/substation.py (GitHub Actions / ruff): Ruff (RUF012) torchgeo/datasets/substation.py:52:23: Mutable class attributes should be annotated with `typing.ClassVar`
filename_masks = 'mask.tar.gz'
url_for_images = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/image_stack.tar.gz'
url_for_images = [
'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/images.z01',
'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/images.z02',
'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/images.zip'
]

Check failure on line 58 in torchgeo/datasets/substation.py (GitHub Actions / ruff): Ruff (RUF012) torchgeo/datasets/substation.py:54:22: Mutable class attributes should be annotated with `typing.ClassVar`
url_for_masks = 'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/mask.tar.gz'
md5_images = None # Update with correct MD5 checksums if available
md5_masks = None # Update with correct MD5 checksum if available
url_for_masks = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/mask.tar.gz'
md5_images = '948706609864d0283f74ee7015f9d032'
md5_masks = 'baa369ececdc2ff80e6ba2b4c7fe147c'

def __init__(
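
The two RUF012 failures above ask for `typing.ClassVar` annotations on the mutable class attributes. A sketch of that fix, under the assumption that Substation derives from torchgeo's NonGeoDataset like the other non-geo datasets:

from typing import ClassVar

from torchgeo.datasets import NonGeoDataset


class Substation(NonGeoDataset):
    # ClassVar marks these as shared class-level constants, which
    # satisfies RUF012 and keeps ruff from flagging the mutable defaults.
    filename_images: ClassVar[list[str]] = ['images.z01', 'images.z02', 'images.zip']
    url_for_images: ClassVar[list[str]] = [
        'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/images.z01',
        'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/images.z02',
        'https://huggingface.co/datasets/neurograce/SubstationDataset/resolve/main/images.zip',
    ]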
@@ -216,11 +222,18 @@

def _extract(self) -> None:
"""Extract the dataset."""
img_pathname = os.path.join(self.root, self.filename_images)
extract_archive(img_pathname)
# Handle filename_images as a list or single string
if isinstance(self.filename_images, list):
# For multi-part archives, we need to extract only the last file
# which typically contains the actual archive data
img_pathname = os.path.join(self.root, self.filename_images[-1])
extract_archive(img_pathname, self.root)
else:
img_pathname = os.path.join(self.root, self.filename_images)
extract_archive(img_pathname, self.root)

mask_pathname = os.path.join(self.root, self.filename_masks)
extract_archive(mask_pathname)
extract_archive(mask_pathname, self.root)

def _verify(self) -> None:
"""Verify the integrity of the dataset."""
@@ -230,9 +243,18 @@
if glob.glob(image_path) and glob.glob(mask_path):
return

# Check if the tar.gz files for images and masks have already been downloaded
image_exists = os.path.exists(os.path.join(self.root, self.filename_images))
# Check if the files for images and masks have already been downloaded
if isinstance(self.filename_images, list):
# For multi-part archives, check if all parts exist
image_exists = all(
os.path.exists(os.path.join(self.root, f))
for f in self.filename_images
)
else:
image_exists = os.path.exists(os.path.join(self.root, self.filename_images))

mask_exists = os.path.exists(os.path.join(self.root, self.filename_masks))

if image_exists and mask_exists:
self._extract()
return
@@ -248,13 +270,24 @@
def _download(self) -> None:
"""Download the dataset and extract it."""
# Download and verify images
download_url(
self.url_for_images,
self.root,
filename=self.filename_images,
md5=self.md5_images if self.checksum else None,
)
extract_archive(os.path.join(self.root, self.filename_images), self.root)
if isinstance(self.url_for_images, list) and isinstance(self.filename_images, list):
# Download each file individually when we have multiple parts
for url, filename in zip(self.url_for_images, self.filename_images):
download_url(
url,
self.root,
filename=filename,
md5=self.md5_images if self.checksum else None,
)
# We'll extract after all files are downloaded in _extract method
else:
# Standard single file download
download_url(
self.url_for_images,
self.root,
filename=self.filename_images,
md5=self.md5_images if self.checksum else None,
)

# Download and verify masks
download_url(
Expand All @@ -263,4 +296,3 @@
filename=self.filename_masks,
md5=self.md5_masks if self.checksum else None,
)
extract_archive(os.path.join(self.root, self.filename_masks), self.root)
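
For reference, a minimal usage sketch of the dataset as this PR leaves it, using only the constructor arguments exercised by the tests above; 'data/substation' is a placeholder path:

from torchgeo.datasets import Substation

# With download=True, _download fetches images.z01, images.z02, images.zip
# and mask.tar.gz into the root directory, then _verify triggers _extract.
ds = Substation('data/substation', download=True)
sample = ds[0]
print(sample['image'].shape, sample['mask'].shape)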