Add DNS optimize support #19429

Merged: 9 commits, Feb 8, 2024
10 changes: 5 additions & 5 deletions src/lightning/data/processing/data_processor.py
@@ -20,18 +20,18 @@
from tqdm.auto import tqdm as _tqdm

from lightning import seed_everything
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.constants import (
from lightning.data.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_LATEST,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.resolver import _resolve_dir
from lightning.data.utilities.broadcast import broadcast_object
from lightning.data.utilities.packing import _pack_greedily
47 changes: 47 additions & 0 deletions src/lightning/data/processing/dns.py
@@ -0,0 +1,47 @@
from contextlib import contextmanager
from subprocess import Popen
from typing import Any

from lightning.data.constants import _IS_IN_STUDIO


@contextmanager
def optimize_dns_context(enable: bool) -> Any:
optimize_dns(enable)
try:
yield
optimize_dns(False) # always disable the optimize DNS
except Exception as e:
optimize_dns(False) # always disable the optimize DNS
raise e

def optimize_dns(enable: bool) -> None:
if not _IS_IN_STUDIO:
return

with open("/etc/resolv.conf") as f:
lines = f.readlines()

if (
(enable and any("127.0.0.53" in line for line in lines))
or (not enable and any("127.0.0.1" in line for line in lines))
): # noqa E501
Popen(f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns({enable})'", shell=True).wait() # noqa E501

def _optimize_dns(enable: bool) -> None:
with open("/etc/resolv.conf") as f:
lines = f.readlines()

write_lines = []
for line in lines:
if "nameserver 127" in line:
if enable:
write_lines.append('nameserver 127.0.0.1\n')
else:
write_lines.append('nameserver 127.0.0.53\n')
else:
write_lines.append(line)

with open("/etc/resolv.conf", "w") as f:
for line in write_lines:
f.write(line)
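
For orientation, a minimal usage sketch of the context manager above (illustrative only, not part of the diff; run_heavy_download is a hypothetical caller):

# Sketch, assuming execution inside a Studio; outside a Studio (_IS_IN_STUDIO is False) this is a no-op.
from lightning.data.processing.dns import optimize_dns_context

with optimize_dns_context(True):
    run_heavy_download()  # hypothetical networking-heavy work

# On entry, _optimize_dns(True) rewrites "nameserver 127.0.0.53" in /etc/resolv.conf
# to "nameserver 127.0.0.1"; on exit, or on any exception, the context manager calls
# optimize_dns(False), which restores "nameserver 127.0.0.53".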
25 changes: 15 additions & 10 deletions src/lightning/data/processing/functions.py
@@ -22,9 +22,10 @@

import torch

from lightning.data.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.processing.dns import optimize_dns_context
from lightning.data.processing.readers import BaseReader
from lightning.data.streaming.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.resolver import (
Dir,
_assert_dir_has_index_file,
@@ -218,7 +219,8 @@ def map(
weights=weights,
reader=reader,
)
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
with optimize_dns_context(True):
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return _execute(
f"data-prep-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
@@ -303,15 +305,18 @@ def optimize(
reorder_files=reorder_files,
reader=reader,
)
return data_processor.run(
LambdaDataChunkRecipe(
fn,
inputs,
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,

with optimize_dns_context(True):
data_processor.run(
LambdaDataChunkRecipe(
fn,
inputs,
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
)
)
)
return None
return _execute(
f"data-prep-optimize-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
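
With this change, the local code path of optimize() runs the chunking recipe inside optimize_dns_context(True) and returns None instead of the DataProcessor result. A hedged usage sketch (keyword names other than chunk_size/chunk_bytes/compression are assumed from the public API; the function and inputs are placeholders):

from lightning.data.processing.functions import optimize

def tokenize(path):  # hypothetical per-item function
    return open(path).read().split()

# Runs locally under the DNS context introduced in this PR; returns None.
optimize(
    fn=tokenize,
    inputs=["a.txt", "b.txt"],  # placeholder inputs
    output_dir="./chunks",      # assumed parameter name
    chunk_bytes="64MB",         # placeholder size
)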
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/cache.py
@@ -15,7 +15,7 @@
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.streaming.constants import (
from lightning.data.constants import (
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TORCH_GREATER_EQUAL_2_1_0,
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/client.py
@@ -2,7 +2,7 @@
from time import time
from typing import Any, Optional

from lightning.data.streaming.constants import _BOTO3_AVAILABLE
from lightning.data.constants import _BOTO3_AVAILABLE

if _BOTO3_AVAILABLE:
import boto3
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/config.py
@@ -15,7 +15,7 @@
import os
from typing import Any, Dict, List, Optional, Tuple

from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.downloader import get_downloader_cls
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from lightning.data.streaming.sampler import ChunkedIndex
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/dataloader.py
@@ -33,13 +33,13 @@
)
from torch.utils.data.sampler import BatchSampler, Sampler

from lightning.data.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
from lightning.data.streaming import Cache
from lightning.data.streaming.combined import (
__NUM_SAMPLES_YIELDED_KEY__,
__SAMPLES_KEY__,
CombinedStreamingDataset,
)
from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.sampler import CacheBatchSampler
from lightning.data.utilities.env import _DistributedEnv
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/dataset.py
@@ -19,11 +19,11 @@
import numpy as np
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
from lightning.data.streaming.constants import (
from lightning.data.constants import (
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
)
from lightning.data.streaming import Cache
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.resolver import Dir, _resolve_dir
from lightning.data.streaming.sampler import ChunkedIndex
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/downloader.py
@@ -19,8 +19,8 @@

from filelock import FileLock, Timeout

from lightning.data.constants import _INDEX_FILENAME
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.constants import _INDEX_FILENAME


class Downloader(ABC):
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/item_loader.py
@@ -19,7 +19,7 @@
import numpy as np
import torch

from lightning.data.streaming.constants import (
from lightning.data.constants import (
_TORCH_DTYPES_MAPPING,
_TORCH_GREATER_EQUAL_2_1_0,
)
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/reader.py
@@ -20,8 +20,8 @@
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.constants import _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.config import ChunksConfig
from lightning.data.streaming.constants import _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer, _get_serializers
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/serializers.py
@@ -23,7 +23,7 @@
import torch
from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from lightning.data.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/writer.py
@@ -21,8 +21,8 @@
import numpy as np
import torch

from lightning.data.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.compression import _COMPRESSORS, Compressor
from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.serializers import Serializer, _get_serializers
from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
from lightning.data.utilities.format import _convert_bytes_to_int, _human_readable_bytes
33 changes: 33 additions & 0 deletions tests/tests_data/processing/test_dns.py
@@ -0,0 +1,33 @@
from unittest.mock import MagicMock

from lightning.data.processing import dns as dns_module
from lightning.data.processing.dns import optimize_dns_context


def test_optimize_dns_context(monkeypatch):
popen_mock = MagicMock()

monkeypatch.setattr(dns_module, "_IS_IN_STUDIO", True)
monkeypatch.setattr(dns_module, "Popen", popen_mock)

class FakeFile:

def __init__(self, *args, **kwargs):
pass

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
return self

def readlines(self):
return ["127.0.0.53"]

monkeypatch.setitem(__builtins__, "open", MagicMock(return_value=FakeFile()))

with optimize_dns_context(True):
pass

cmd = popen_mock._mock_call_args_list[0].args[0]
assert cmd == "sudo /home/zeus/miniconda3/envs/cloudspace/bin/python -c 'from lightning.data.processing.dns import _optimize_dns; _optimize_dns(True)'" # noqa: E501
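
A brief note on the mocking approach: the FakeFile helper hand-rolls the file object returned by the patched open. A roughly equivalent sketch (an assumption for illustration, not part of the PR) could rely on unittest.mock.mock_open, whose file handle also supports readlines:

from unittest import mock

# Hypothetical alternative to FakeFile; read_data stands in for /etc/resolv.conf.
with mock.patch("builtins.open", mock.mock_open(read_data="nameserver 127.0.0.53\n")):
    with open("/etc/resolv.conf") as f:
        assert f.readlines() == ["nameserver 127.0.0.53\n"]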