
ref(replays): refactor query code #69334


Closed
wants to merge 10 commits
10 changes: 6 additions & 4 deletions src/sentry/replays/endpoints/organization_replay_details.py
@@ -54,13 +54,15 @@ def get(self, request: Request, organization: Organization, replay_id: str) -> R

try:
filter_params = self.get_filter_params(
request, organization, project_ids=ALL_ACCESS_PROJECTS
request,
organization,
project_ids=ALL_ACCESS_PROJECTS,
date_filter_optional=False,
)
except NoProjects:
return Response(status=404)

if not filter_params["start"] or not filter_params["end"]:
return Response(status=404)
# "start" and "end" keys are expected to exist due to date_filter_optional=False.
# The function returns defaults if date filters aren't in the request.

try:
replay_id = str(uuid.UUID(replay_id))
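The removed 404 guard relied on "start"/"end" possibly being None; with date_filter_optional=False that can no longer happen. A minimal sketch of the guarantee the endpoint now relies on, assuming get_filter_params falls back to a default date range when the request carries no explicit date filters (behavior inferred from the comment above, not shown in this diff):

# Sketch (assumed behavior): date_filter_optional=False makes
# get_filter_params supply default "start"/"end" values, so the old
# missing-dates 404 can never trigger.
filter_params = self.get_filter_params(
    request,
    organization,
    project_ids=ALL_ACCESS_PROJECTS,
    date_filter_optional=False,
)
assert filter_params["start"] is not None
assert filter_params["end"] is not None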
4 changes: 2 additions & 2 deletions src/sentry/replays/endpoints/organization_replay_index.py
@@ -18,7 +18,7 @@
from sentry.exceptions import InvalidSearchQuery
from sentry.models.organization import Organization
from sentry.replays.post_process import ReplayDetailsResponse, process_raw_response
from sentry.replays.query import query_replays_collection_raw, replay_url_parser_config
from sentry.replays.query import query_replays_collection_paginated, replay_url_parser_config
from sentry.replays.usecases.errors import handled_snuba_exceptions
from sentry.replays.validators import ReplayValidator
from sentry.utils.cursors import Cursor, CursorResult
@@ -83,7 +83,7 @@ def data_fn(offset: int, limit: int):
# to do this for completeness sake.
return Response({"detail": "Missing start or end period."}, status=400)

return query_replays_collection_raw(
return query_replays_collection_paginated(
project_ids=filter_params["project_id"],
start=start,
end=end,
8 changes: 7 additions & 1 deletion src/sentry/replays/endpoints/project_replay_viewed_by.py
@@ -10,6 +10,7 @@
from sentry.api.api_owners import ApiOwner
from sentry.api.api_publish_status import ApiPublishStatus
from sentry.api.base import region_silo_endpoint
from sentry.api.bases.organization import NoProjects
from sentry.api.bases.project import ProjectEndpoint, ProjectEventPermission
from sentry.apidocs.constants import RESPONSE_BAD_REQUEST, RESPONSE_FORBIDDEN, RESPONSE_NOT_FOUND
from sentry.apidocs.examples.replay_examples import ReplayExamples
@@ -66,7 +67,12 @@ def get(self, request: Request, project: Project, replay_id: str) -> Response:
return Response(status=404)

# query for user ids who viewed the replay
filter_params = self.get_filter_params(request, project, date_filter_optional=False)
try:
filter_params = self.get_filter_params(request, project, date_filter_optional=False)
except NoProjects:
return Response(status=404)
# "start" and "end" keys are expected to exist due to date_filter_optional=False.
# The function returns defaults if date filters aren't in the request.

# If no rows were found then the replay does not exist and a 404 is returned.
viewed_by_ids_response: list[dict[str, Any]] = query_replay_viewed_by_ids(
211 changes: 99 additions & 112 deletions src/sentry/replays/post_process.py
@@ -11,10 +11,10 @@


class DeviceResponseType(TypedDict, total=False):
name: str | None
brand: str | None
model: str | None
family: str | None
model: str | None
name: str | None


class SDKResponseType(TypedDict, total=False):
@@ -33,47 +33,47 @@ class BrowserResponseType(TypedDict, total=False):


class UserResponseType(TypedDict, total=False):
id: str | None
username: str | None
display_name: str | None
email: str | None
id: str | None
ip: str | None
display_name: str | None
username: str | None


@extend_schema_serializer(exclude_fields=["info_ids", "warning_ids"])
class ReplayDetailsResponse(TypedDict, total=False):
id: str
project_id: str
trace_ids: list[str]
error_ids: list[str]
environment: str | None
tags: dict[str, list[str]] | list
user: UserResponseType
sdk: SDKResponseType
os: OSResponseType
activity: int | None
browser: BrowserResponseType
device: DeviceResponseType
is_archived: bool | None
urls: list[str] | None
clicks: list[dict[str, Any]]
count_dead_clicks: int | None
count_rage_clicks: int | None
count_errors: int | None
count_infos: int | None
count_rage_clicks: int | None
count_segments: int | None
count_urls: int | None
count_warnings: int | None
device: DeviceResponseType
dist: str | None
duration: int | None
environment: str | None
error_ids: list[str]
finished_at: str | None
started_at: str | None
activity: int | None
count_urls: int | None
replay_type: str
count_segments: int | None
has_viewed: bool
id: str
info_ids: list[str] | None
is_archived: bool | None
os: OSResponseType
platform: str | None
project_id: str
releases: list[str]
dist: str | None
replay_type: str
sdk: SDKResponseType
started_at: str | None
tags: dict[str, list[str]] | list
trace_ids: list[str]
urls: list[str] | None
user: UserResponseType
warning_ids: list[str] | None
info_ids: list[str] | None
count_warnings: int | None
count_infos: int | None
has_viewed: bool


def process_raw_response(
@@ -95,95 +95,98 @@ def generate_restricted_fieldset(
yield from response


def _strip_dashes(field: str) -> str:
def _strip_dashes(field: str | None) -> str:
if field:
return field.replace("-", "")
return field
return ""


def generate_normalized_output(
response: list[dict[str, Any]]
) -> Generator[ReplayDetailsResponse, None, None]:
"""For each payload in the response strip "agg_" prefixes."""
"""Skip archives, strip "agg_" prefixes, coerce correct types, and compute/nest new fields"""

for item in response:

ret_item: ReplayDetailsResponse = {}
if item["isArchived"]:
if item.get("isArchived"):
yield _archived_row(item["replay_id"], item["project_id"]) # type: ignore[misc]
continue

ret_item["id"] = _strip_dashes(item.pop("replay_id", None))
# required fields
ret_item["project_id"] = str(item["project_id"])
ret_item["trace_ids"] = item.pop("traceIds", [])
ret_item["error_ids"] = item.pop("errorIds", [])
ret_item["environment"] = item.pop("agg_environment", None)

# modified + renamed fields
ret_item["environment"] = item.get("agg_environment", None)
# Returns a UInt8 of either 0 or 1. We coerce to a bool.
ret_item["has_viewed"] = bool(item.get("has_viewed", 0))
ret_item["id"] = _strip_dashes(item.get("replay_id", None))
ret_item["releases"] = list(filter(bool, item.get("releases", [])))

# computed fields
ret_item["browser"] = {
"name": item.get("browser_name", None),
"version": item.get("browser_version", None),
}
ret_item["clicks"] = extract_click_fields(item)
ret_item["device"] = {
"name": item.get("device_name", None),
"brand": item.get("device_brand", None),
"model": item.get("device_model", None),
"family": item.get("device_family", None),
}
ret_item["os"] = {
"name": item.get("os_name", None),
"version": item.get("os_version", None),
}
ret_item["sdk"] = {
"name": item.get("sdk_name", None),
"version": item.get("sdk_version", None),
}
ret_item["tags"] = dict_unique_list(
zip(
item.pop("tk", None) or [],
item.pop("tv", None) or [],
item.get("tk", None) or [],
item.get("tv", None) or [],
)
)
ret_item["user"] = {
"id": item.pop("user_id", None),
"username": item.pop("user_username", None),
"email": item.pop("user_email", None),
"ip": item.pop("user_ip", None),
"id": item.get("user_id", None),
"username": item.get("user_username", None),
"email": item.get("user_email", None),
"ip": item.get("user_ip", None),
}
ret_item["user"]["display_name"] = (
ret_item["user"]["username"]
or ret_item["user"]["email"]
or ret_item["user"]["id"]
or ret_item["user"]["ip"]
)
ret_item["sdk"] = {
"name": item.pop("sdk_name", None),
"version": item.pop("sdk_version", None),
}
ret_item["os"] = {
"name": item.pop("os_name", None),
"version": item.pop("os_version", None),
}
ret_item["browser"] = {
"name": item.pop("browser_name", None),
"version": item.pop("browser_version", None),
}
ret_item["device"] = {
"name": item.pop("device_name", None),
"brand": item.pop("device_brand", None),
"model": item.pop("device_model", None),
"family": item.pop("device_family", None),
}

item.pop("agg_urls", None)
ret_item["urls"] = item.pop("urls_sorted", None)

ret_item["is_archived"] = bool(item.pop("isArchived", 0))

item.pop("clickClass", None)
item.pop("click_selector", None)
ret_item["activity"] = item.pop("activity", None)
# don't need clickClass or click_selector
# for the click field, as they are only used for searching.
# optional fields
ret_item["activity"] = item.get("activity", None)
ret_item["count_dead_clicks"] = item.get("count_dead_clicks", None)
ret_item["count_errors"] = item.get("count_errors", None)
ret_item["count_infos"] = item.get("count_infos", None)
ret_item["count_rage_clicks"] = item.get("count_rage_clicks", None)
ret_item["count_segments"] = item.get("count_segments", None)
ret_item["count_urls"] = item.get("count_urls", None)
ret_item["count_warnings"] = item.get("count_warnings", None)
ret_item["dist"] = item.get("dist", None)
ret_item["duration"] = item.get("duration", None)
ret_item["error_ids"] = item.get("errorIds", [])
ret_item["finished_at"] = item.get("finished_at", None)
ret_item["info_ids"] = item.get("info_ids", None)
ret_item["is_archived"] = item.get("isArchived", None)
ret_item["platform"] = item.get("platform", None)
ret_item["replay_type"] = item.get("replay_type", "session")
ret_item["started_at"] = item.get("started_at", None)
ret_item["trace_ids"] = item.get("traceIds", [])
ret_item["urls"] = item.get("urls_sorted", None)
ret_item["warning_ids"] = item.get("warning_ids", None)

# excluded fields: agg_urls, clickClass, click_selector
# Don't need clickClass and click_selector for the click field, as they are only used for searching.
# (click.classes contains the full list of classes for a click)
ret_item["clicks"] = extract_click_fields(item)
ret_item["count_dead_clicks"] = item.pop("count_dead_clicks", None)
ret_item["count_errors"] = item.pop("count_errors", None)
ret_item["count_rage_clicks"] = item.pop("count_rage_clicks", None)
ret_item["count_segments"] = item.pop("count_segments", None)
ret_item["count_urls"] = item.pop("count_urls", None)
ret_item["dist"] = item.pop("dist", None)
ret_item["duration"] = item.pop("duration", None)
ret_item["finished_at"] = item.pop("finished_at", None)
ret_item["platform"] = item.pop("platform", None)
ret_item["releases"] = list(filter(bool, item.pop("releases", [])))
ret_item["replay_type"] = item.pop("replay_type", "session")
ret_item["started_at"] = item.pop("started_at", None)

ret_item["warning_ids"] = item.pop("warning_ids", None)
ret_item["info_ids"] = item.pop("info_ids", None)
ret_item["count_infos"] = item.pop("count_infos", None)
ret_item["count_warnings"] = item.pop("count_warnings", None)
# Returns a UInt8 of either 0 or 1. We coerce to a bool.
ret_item["has_viewed"] = bool(item.get("has_viewed", 0))
yield ret_item


Expand All @@ -209,32 +212,16 @@ def dict_unique_list(items: Iterable[tuple[str, str]]) -> dict[str, list[str]]:

def _archived_row(replay_id: str, project_id: int) -> dict[str, Any]:
archived_replay_response = {
"browser": {"name": None, "version": None},
"device": {"name": None, "brand": None, "model": None, "family": None},
"error_ids": [],
"id": _strip_dashes(replay_id),
"os": {"name": None, "version": None},
"project_id": str(project_id),
"trace_ids": [],
"error_ids": [],
"environment": None,
"sdk": {"name": None, "version": None},
"tags": [],
"trace_ids": [],
"user": {"id": "Archived Replay", "display_name": "Archived Replay"},
"sdk": {"name": None, "version": None},
"os": {"name": None, "version": None},
"browser": {"name": None, "version": None},
"device": {"name": None, "brand": None, "model": None, "family": None},
"urls": None,
"activity": None,
"count_dead_clicks": None,
"count_rage_clicks": None,
"count_errors": None,
"duration": None,
"finished_at": None,
"started_at": None,
"is_archived": True,
"count_segments": None,
"count_urls": None,
"dist": None,
"platform": None,
"releases": None,
"clicks": None,
}
for field in VALID_FIELD_SET:
if field not in archived_replay_response:
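To make the normalization concrete, here is a hedged sketch of one raw row flowing through generate_normalized_output. The input keys are the pre-normalization names visible in the diff; the values are illustrative, not a real Snuba response, and the sketch assumes the helpers (extract_click_fields, dict_unique_list) tolerate missing keys, as the .get defaults suggest:

# Hypothetical raw row; only a subset of keys is shown.
raw_row = {
    "replay_id": "3e64c8eb-3838-49e4-a7b4-55b9f63a6222",
    "project_id": 1,
    "isArchived": 0,
    "browser_name": "Chrome",
    "browser_version": "120.0",
    "user_id": "42",
    "user_username": None,
    "user_email": "jane@example.com",
    "user_ip": "127.0.0.1",
    "has_viewed": 1,
}

normalized = next(generate_normalized_output([raw_row]))
assert normalized["id"] == "3e64c8eb383849e4a7b455b9f63a6222"  # dashes stripped
assert normalized["browser"] == {"name": "Chrome", "version": "120.0"}
# display_name falls back through username -> email -> id -> ip.
assert normalized["user"]["display_name"] == "jane@example.com"
assert normalized["has_viewed"] is True  # UInt8 coerced to bool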
11 changes: 3 additions & 8 deletions src/sentry/replays/query.py
@@ -39,12 +39,7 @@
ELIGIBLE_SUBQUERY_SORTS = {"started_at", "browser.name", "os.name"}


# Compatibility function for getsentry code.
def query_replays_collection(*args, **kwargs):
return query_replays_collection_raw(*args, **kwargs)[0]


def query_replays_collection_raw(
def query_replays_collection_paginated(
project_ids: list[int],
start: datetime,
end: datetime,
@@ -56,8 +51,8 @@ def query_replays_collection_raw(
search_filters: Sequence[SearchFilter],
organization: Organization | None = None,
actor: Any | None = None,
):
"""Query aggregated replay collection."""
) -> tuple[list[dict[str, Any]], bool]:
"""Query aggregated replay collection. Returns (response, has_more)"""
paginators = Paginators(limit, offset)

return query_using_optimized_search(
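The rename makes the tuple contract explicit at every call site. A hedged sketch of driving offset pagination with the (response, has_more) return value; keyword names for the parameters hidden by the collapsed hunk (fields, environment, sort, limit, offset) are assumed from the call sites elsewhere in this diff:

from datetime import datetime, timedelta, timezone

end = datetime.now(timezone.utc)
start = end - timedelta(days=1)
offset, limit = 0, 100

while True:
    rows, has_more = query_replays_collection_paginated(
        project_ids=[1],  # illustrative project id
        start=start,
        end=end,
        fields=[],
        environment=[],
        sort="started_at",
        limit=limit,
        offset=offset,
        search_filters=[],
    )
    # ... consume rows ...
    if not has_more:
        break
    offset += limit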
6 changes: 3 additions & 3 deletions src/sentry/replays/scripts/delete_replays.py
@@ -9,7 +9,7 @@
from sentry.models.organization import Organization
from sentry.replays.lib.kafka import initialize_replays_publisher
from sentry.replays.post_process import generate_normalized_output
from sentry.replays.query import query_replays_collection, replay_url_parser_config
from sentry.replays.query import query_replays_collection_paginated, replay_url_parser_config
from sentry.replays.tasks import archive_replay, delete_replay_recording_async

logger = logging.getLogger()
@@ -31,7 +31,7 @@ def delete_replays(
while True:
replays = list(
generate_normalized_output(
query_replays_collection(
query_replays_collection_paginated(
project_ids=[project_id],
start=start_utc,
end=end_utc,
@@ -42,7 +42,7 @@
search_filters=search_filters,
sort="started_at",
organization=Organization.objects.filter(project__id=project_id).get(),
)
)[0]
)
)

Expand Down
Loading