-
Notifications
You must be signed in to change notification settings - Fork 154
/
Copy pathimage_task_generator.py
87 lines (75 loc) · 3 KB
/
image_task_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import bittensor as bt
import numpy as np
import hashlib
import random
from typing import Tuple, List
from webgenie.tasks.metric_types import (
ACCURACY_METRIC_NAME,
QUALITY_METRIC_NAME,
SEO_METRIC_NAME,
)
from webgenie.tasks.task_generator import TaskGenerator
from webgenie.constants import IMAGE_TASK_TIMEOUT, GROUND_TRUTH_HTML_LOAD_TIME
from webgenie.helpers.htmls import (
html_to_screenshot,
preprocess_html,
is_empty_html,
)
from webgenie.helpers.images import base64_to_image
from webgenie.protocol import WebgenieImageSynapse
from webgenie.tasks.solution import Solution
from webgenie.tasks.task import Task, ImageTask
from webgenie.rewards import (
QualityReward,
VisualReward,
LighthouseReward,
)
from webgenie.datasets import (
RandomWebsiteDataset,
SyntheticDataset,
HuggingfaceDataset,
)
class ImageTaskGenerator(TaskGenerator):
    """Generate image-to-HTML tasks from sampled ground-truth web pages.

    Each task pairs a ground-truth HTML document with a screenshot of it; the
    screenshot is what gets sent out in a ``WebgenieImageSynapse``.
    """

    # Screenshots whose height exceeds this multiple of their width are
    # rejected by ``generate_task``. Subclasses may override it.
    MAX_ASPECT_RATIO = 7

    def __init__(self):
        super().__init__()
        # (dataset, sampling weight) pairs. ``generate_task`` draws one
        # dataset per task using these weights.
        self.datasets = [
            # (RandomWebsiteDataset(), 1),  # currently disabled
            (SyntheticDataset(), 0.5),
            (
                HuggingfaceDataset(
                    dataset_name="SALT-NLP/Design2Code-hf",
                    split="train",
                    html_column="text",
                ),
                0.5,
            ),
        ]
        # Metric name -> reward object associated with this task type.
        self.metrics = {
            ACCURACY_METRIC_NAME: VisualReward(),
            SEO_METRIC_NAME: LighthouseReward(),
            QUALITY_METRIC_NAME: QualityReward(),
        }

    async def generate_task(self) -> Tuple[Task, bt.Synapse]:
        """Build one image task and the synapse that carries it.

        A dataset is chosen by weighted random sampling from
        ``self.datasets``. One entry is drawn from it, its HTML is
        preprocessed and validated, and the HTML is rendered to a screenshot.

        Returns:
            A ``(ImageTask, WebgenieImageSynapse)`` pair. The task id is the
            SHA-256 hex digest of the dataset entry's URL, so the same URL
            always yields the same id.

        Raises:
            ValueError: if the preprocessed HTML is falsy or empty, or if
                the screenshot has non-positive dimensions. Also raised if its
                height / width ratio exceeds ``MAX_ASPECT_RATIO``.
        """
        bt.logging.info("Generating Image task")

        weights = [weight for _, weight in self.datasets]
        dataset, _ = random.choices(self.datasets, weights=weights)[0]
        dataset_entry = await dataset.generate_context()
        bt.logging.debug(f"Generated dataset entry: {dataset_entry.url}")

        ground_truth_html = preprocess_html(dataset_entry.ground_truth_html)
        bt.logging.info("Preprocessed ground truth html")
        if not ground_truth_html:
            raise ValueError("Invalid ground truth html")
        if is_empty_html(ground_truth_html):
            raise ValueError("Empty ground truth html")

        base64_image = await html_to_screenshot(
            ground_truth_html, page_load_time=GROUND_TRUTH_HTML_LOAD_TIME
        )
        bt.logging.debug(f"Screenshot generated for {dataset_entry.url}")
        self._check_aspect_ratio(base64_image)

        image_task = ImageTask(
            base64_image=base64_image,
            ground_truth_html=ground_truth_html,
            generator=self,
            src=dataset_entry.src,
            task_id=hashlib.sha256(dataset_entry.url.encode()).hexdigest(),
            timeout=IMAGE_TASK_TIMEOUT,
        )
        return (
            image_task,
            WebgenieImageSynapse(base64_image=base64_image, task_id=image_task.task_id),
        )

    def _check_aspect_ratio(self, base64_image):
        """Raise ValueError if the screenshot is degenerate or too tall."""
        # NOTE(review): assumes base64_to_image returns a PIL-style image whose
        # ``.size`` is (width, height) -- confirm against webgenie.helpers.images.
        width, height = base64_to_image(base64_image).size
        # Guard the division below; a zero-width render would otherwise raise
        # ZeroDivisionError instead of a validation error.
        if width <= 0 or height <= 0:
            raise ValueError(f"Invalid screenshot dimensions: {width}x{height}")
        aspect_ratio = height / width
        if aspect_ratio > self.MAX_ASPECT_RATIO:
            raise ValueError(
                f"Image aspect ratio too extreme: {aspect_ratio:.2f}. "
                f"Height should not exceed {self.MAX_ASPECT_RATIO:g}x width."
            )