Skip to content

Commit 1621857

Browse files
authored
Cache incremental listens and metadata tables in Spark (#3178)
* Test persisting incremental listens df * fix type issue * test incremental users df * test artist country code df * test artist country code df - 2 * test rec/rel/rg cache dfs * fix typo * globalize cache table handling * fix typo * fix typo - 2
1 parent 666de8d commit 1621857

34 files changed

+265
-191
lines changed

listenbrainz_spark/hdfs/upload.py

+12
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,15 @@ def process_full_listens_dump(self):
183183

184184
if path_exists(path.LISTENBRAINZ_BASE_STATS_DIRECTORY):
185185
hdfs_connection.client.delete(path.LISTENBRAINZ_BASE_STATS_DIRECTORY, recursive=True, skip_trash=True)
186+
187+
def process_incremental_listens_dump(self):
    """Build the incremental-users lookup table from the incremental dumps.

    Aggregates the incremental listens parquet down to one row per user
    holding the latest ``created`` timestamp, and writes the result to
    ``path.INCREMENTAL_USERS_DF``, replacing any previous version.
    """
    users_query = f"""
        SELECT user_id
             , max(created) AS created
          FROM parquet.`{path.INCREMENTAL_DUMPS_SAVE_PATH}`
      GROUP BY user_id
    """
    users_df = run_query(users_query)
    users_df.write.mode("overwrite").parquet(path.INCREMENTAL_USERS_DF)

listenbrainz_spark/path.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
# path to save incremental dumps
2020
INCREMENTAL_DUMPS_SAVE_PATH = os.path.join(LISTENBRAINZ_NEW_DATA_DIRECTORY, "incremental.parquet")
21+
INCREMENTAL_USERS_DF = os.path.join("/", "incremental-users")
2122

2223
# Directory containing RDD checkpoints to break lineage while using iterative algorithms.
2324
CHECKPOINT_DIR = os.path.join('/', 'checkpoint')

listenbrainz_spark/persisted.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Optional

# NOTE: the cached objects returned by read_files_from_HDFS are Spark
# DataFrames (they are .persist()/.unpersist()-ed below, which pandas
# DataFrames do not support), so the annotation must come from pyspark,
# not pandas.
from pyspark.sql import DataFrame

from listenbrainz_spark.path import INCREMENTAL_DUMPS_SAVE_PATH, INCREMENTAL_USERS_DF
from listenbrainz_spark.utils import read_files_from_HDFS

# Module-level caches of the persisted incremental dataframes. They are
# populated lazily by the getters below and cleared whenever a new dump
# import invalidates the data on HDFS.
_incremental_listens_df: Optional[DataFrame] = None
_incremental_users_df: Optional[DataFrame] = None


def unpersist_incremental_df():
    """Unpersist and drop both cached incremental dataframes.

    Called after a dump import changes the underlying HDFS data, so that
    the next getter call re-reads and re-persists the fresh data.
    """
    global _incremental_listens_df, _incremental_users_df
    if _incremental_listens_df is not None:
        _incremental_listens_df.unpersist()
        _incremental_listens_df = None
    if _incremental_users_df is not None:
        _incremental_users_df.unpersist()
        _incremental_users_df = None


def get_incremental_listens_df() -> DataFrame:
    """Return the incremental listens dataframe, reading from HDFS and
    persisting it on first use."""
    global _incremental_listens_df
    if _incremental_listens_df is None:
        _incremental_listens_df = read_files_from_HDFS(INCREMENTAL_DUMPS_SAVE_PATH)
        _incremental_listens_df.persist()
    return _incremental_listens_df


def get_incremental_users_df() -> DataFrame:
    """Return the incremental users dataframe, reading from HDFS and
    persisting it on first use."""
    global _incremental_users_df
    if _incremental_users_df is None:
        _incremental_users_df = read_files_from_HDFS(INCREMENTAL_USERS_DF)
        _incremental_users_df.persist()
    return _incremental_users_df

listenbrainz_spark/popularity/listens.py

+12-19
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datetime import datetime
22
from typing import List, Optional
33

4-
from listenbrainz_spark.path import LISTENBRAINZ_POPULARITY_DIRECTORY, RELEASE_METADATA_CACHE_DATAFRAME
4+
from listenbrainz_spark.path import LISTENBRAINZ_POPULARITY_DIRECTORY
55
from listenbrainz_spark.popularity.common import get_popularity_per_artist_query, \
66
get_release_group_popularity_per_artist_query, get_popularity_query
7+
from listenbrainz_spark.postgres.release import get_release_metadata_cache
78
from listenbrainz_spark.stats.incremental.query_provider import QueryProvider
89
from listenbrainz_spark.stats.incremental.range_selector import ListenRangeSelector
910

@@ -25,7 +26,7 @@ def get_base_path(self) -> str:
2526
return LISTENBRAINZ_POPULARITY_DIRECTORY
2627

2728
def get_filter_aggregate_query(self, existing_aggregate: str, incremental_aggregate: str,
28-
existing_created: Optional[datetime], cache_tables: List[str]) -> str:
29+
existing_created: Optional[datetime]) -> str:
2930
inc_where_clause = f"WHERE created >= to_timestamp('{existing_created}')" if existing_created else ""
3031
entity_id = self.get_entity_id()
3132
return f"""
@@ -37,23 +38,19 @@ def get_filter_aggregate_query(self, existing_aggregate: str, incremental_aggreg
3738
WHERE EXISTS(SELECT 1 FROM incremental_users iu WHERE iu.{entity_id} = ea.{entity_id})
3839
"""
3940

40-
def get_cache_tables(self) -> List[str]:
41-
if self.entity == "release_group":
42-
return [RELEASE_METADATA_CACHE_DATAFRAME]
43-
return []
44-
4541
def get_entity_id(self):
    """Name of the MBID column for this entity, e.g. ``artist_mbid``."""
    return f"{self.entity}_mbid"
4743

48-
def get_aggregate_query(self, table: str) -> str:
    """Return the popularity aggregation SQL for this provider's entity."""
    if self.entity == "artist":
        return get_popularity_per_artist_query("artist", table)
    if self.entity == "release_group":
        # release groups need the release metadata cache view to join against
        return get_release_group_popularity_per_artist_query(
            table, get_release_metadata_cache()
        )
    return get_popularity_query(self.entity, table)
5552

56-
def get_stats_query(self, final_aggregate: str, cache_tables: List[str]) -> str:
53+
def get_stats_query(self, final_aggregate: str) -> str:
    """The stats are simply everything in the final aggregate table."""
    return "SELECT * FROM " + final_aggregate
5855

5956
def get_combine_aggregates_query(self, existing_aggregate: str, incremental_aggregate: str) -> str:
@@ -95,7 +92,7 @@ def get_base_path(self) -> str:
9592
return LISTENBRAINZ_POPULARITY_DIRECTORY
9693

9794
def get_filter_aggregate_query(self, existing_aggregate: str, incremental_aggregate: str,
98-
existing_created: Optional[datetime], cache_tables: List[str]) -> str:
95+
existing_created: Optional[datetime]) -> str:
9996
inc_where_clause = f"WHERE created >= to_timestamp('{existing_created}')" if existing_created else ""
10097
entity_id = self.get_entity_id()
10198
return f"""
@@ -111,20 +108,16 @@ def get_filter_aggregate_query(self, existing_aggregate: str, incremental_aggreg
111108
)
112109
"""
113110

114-
def get_cache_tables(self) -> List[str]:
115-
if self.entity == "release_group":
116-
return [RELEASE_METADATA_CACHE_DATAFRAME]
117-
return []
118-
119111
def get_entity_id(self):
120112
return self.entity + "_mbid"
121113

122-
def get_aggregate_query(self, table: str) -> str:
    """Return the per-artist popularity aggregation SQL for this provider's entity."""
    if self.entity == "release_group":
        # release groups need the release metadata cache view to join against
        return get_release_group_popularity_per_artist_query(
            table, get_release_metadata_cache()
        )
    return get_popularity_per_artist_query(self.entity, table)
126119

127-
def get_stats_query(self, final_aggregate: str, cache_tables: List[str]) -> str:
120+
def get_stats_query(self, final_aggregate: str) -> str:
    """The stats are simply everything in the final aggregate table."""
    return "SELECT * FROM " + final_aggregate
129122

130123
def get_combine_aggregates_query(self, existing_aggregate: str, incremental_aggregate: str) -> str:

listenbrainz_spark/popularity/mlhd.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,6 @@ class MlhdStatsEngine:
2121
def __init__(self, provider: QueryProvider, message_creator: MessageCreator):
    """Bind the query provider and message creator the engine delegates to."""
    self.provider = provider
    self.message_creator = message_creator
24-
self._cache_tables = []
25-
26-
def _setup_cache_tables(self):
27-
""" Set up metadata cache tables by reading data from HDFS and creating temporary views. """
28-
cache_tables = []
29-
for idx, df_path in enumerate(self.provider.get_cache_tables()):
30-
df_name = f"entity_data_cache_{idx}"
31-
cache_tables.append(df_name)
32-
read_files_from_HDFS(df_path).createOrReplaceTempView(df_name)
33-
self._cache_tables = cache_tables
3424

3525
def create_partial_aggregate(self) -> DataFrame:
3626
metadata_path = self.provider.get_bookkeeping_path()
@@ -41,7 +31,7 @@ def create_partial_aggregate(self) -> DataFrame:
4131

4232
logger.info("Creating partial aggregate from full dump listens")
4333
hdfs_connection.client.makedirs(Path(existing_aggregate_path).parent)
44-
full_query = self.provider.get_aggregate_query(table, self._cache_tables)
34+
full_query = self.provider.get_aggregate_query(table)
4535
full_df = run_query(full_query)
4636
full_df.write.mode("overwrite").parquet(existing_aggregate_path)
4737

@@ -56,15 +46,14 @@ def create_partial_aggregate(self) -> DataFrame:
5646
return full_df
5747

5848
def generate_stats(self) -> DataFrame:
    """Run the full stats pipeline.

    Builds the partial aggregate, registers it as a temp view and
    evaluates the provider's stats query over that view.
    """
    table_prefix = self.provider.get_table_prefix()
    self.create_partial_aggregate()

    aggregate_df = read_files_from_HDFS(self.provider.get_existing_aggregate_path())
    aggregate_table = f"{table_prefix}_existing_aggregate"
    aggregate_df.createOrReplaceTempView(aggregate_table)

    return run_query(self.provider.get_stats_query(aggregate_table))
7059

listenbrainz_spark/postgres/artist.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1+
from typing import Optional
2+
13
import pycountry
4+
from pyspark import StorageLevel
5+
from pyspark.sql import DataFrame
26

37
import listenbrainz_spark
48
from listenbrainz_spark import config
59
from listenbrainz_spark.path import ARTIST_COUNTRY_CODE_DATAFRAME
610
from listenbrainz_spark.postgres.utils import load_from_db
711
from listenbrainz_spark.stats import run_query
12+
from listenbrainz_spark.utils import read_files_from_HDFS
13+
14+
_ARTIST_COUNTRY_CACHE = "artist_country_cache"
15+
_artist_country_df: Optional[DataFrame] = None
816

917

1018
def create_iso_country_codes_df():
@@ -18,7 +26,6 @@ def create_iso_country_codes_df():
1826
df.createOrReplaceTempView("iso_codes")
1927

2028

21-
2229
def create_artist_country_cache():
2330
""" Import artist country from postgres to HDFS for use in artist map stats calculation. """
2431
query = """
@@ -64,3 +71,19 @@ def create_artist_country_cache():
6471
.write \
6572
.format("parquet") \
6673
.save(config.HDFS_CLUSTER_URI + ARTIST_COUNTRY_CODE_DATAFRAME, mode="overwrite")
74+
75+
global _artist_country_df
76+
if _artist_country_df is not None:
77+
_artist_country_df.unpersist()
78+
_artist_country_df = None
79+
80+
81+
def get_artist_country_cache():
    """Return the name of the spark SQL view over the artist country cache.

    On first use, reads the parquet files from HDFS, persists the dataframe
    to disk and registers it as a temp view; later calls reuse the cached view.
    """
    global _artist_country_df
    if _artist_country_df is None:
        df = read_files_from_HDFS(ARTIST_COUNTRY_CODE_DATAFRAME)
        df.persist(StorageLevel.DISK_ONLY)
        df.createOrReplaceTempView(_ARTIST_COUNTRY_CACHE)
        _artist_country_df = df
    return _ARTIST_COUNTRY_CACHE

listenbrainz_spark/postgres/recording.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
from typing import Optional
2+
3+
from pyspark import StorageLevel
4+
from pyspark.sql import DataFrame
5+
16
from listenbrainz_spark.path import RECORDING_LENGTH_DATAFRAME, RECORDING_ARTIST_DATAFRAME
27
from listenbrainz_spark.postgres.utils import save_pg_table_to_hdfs
8+
from listenbrainz_spark.utils import read_files_from_HDFS
9+
10+
_RECORDING_ARTIST_CACHE = "recording_artist_cache"
11+
_recording_artist_df: Optional[DataFrame] = None
312

413

514
def create_recording_length_cache():
@@ -35,3 +44,20 @@ def create_recording_artist_cache():
3544
"""
3645

3746
save_pg_table_to_hdfs(query, RECORDING_ARTIST_DATAFRAME, process_artists_column=True)
47+
48+
global _recording_artist_df
49+
if _recording_artist_df is not None:
50+
_recording_artist_df.unpersist()
51+
_recording_artist_df = None
52+
53+
54+
def get_recording_artist_cache():
    """Return the name of the spark SQL view over the recording artist cache.

    On first use, reads the parquet files from HDFS, persists the dataframe
    to disk and registers it as a temp view; later calls reuse the cached view.
    """
    global _recording_artist_df
    if _recording_artist_df is None:
        df = read_files_from_HDFS(RECORDING_ARTIST_DATAFRAME)
        df.persist(StorageLevel.DISK_ONLY)
        df.createOrReplaceTempView(_RECORDING_ARTIST_CACHE)
        _recording_artist_df = df
    return _RECORDING_ARTIST_CACHE
63+

listenbrainz_spark/postgres/release.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
from typing import Optional
2+
3+
from pyspark import StorageLevel
4+
from pyspark.sql import DataFrame
5+
16
from listenbrainz_spark.path import RELEASE_METADATA_CACHE_DATAFRAME
27
from listenbrainz_spark.postgres.utils import save_pg_table_to_hdfs
8+
from listenbrainz_spark.utils import read_files_from_HDFS
9+
10+
_RELEASE_METADATA_CACHE = "release_metadata_cache"
11+
_release_metadata_df: Optional[DataFrame] = None
312

413

514
def create_release_metadata_cache():
@@ -104,3 +113,19 @@ def create_release_metadata_cache():
104113
"""
105114

106115
save_pg_table_to_hdfs(query, RELEASE_METADATA_CACHE_DATAFRAME, process_artists_column=True)
116+
117+
global _release_metadata_df
118+
if _release_metadata_df is not None:
119+
_release_metadata_df.unpersist()
120+
_release_metadata_df = None
121+
122+
123+
def get_release_metadata_cache():
    """Return the name of the spark SQL view over the release metadata cache.

    On first use, reads the parquet files from HDFS, persists the dataframe
    to disk and registers it as a temp view; later calls reuse the cached view.
    """
    global _release_metadata_df
    if _release_metadata_df is None:
        df = read_files_from_HDFS(RELEASE_METADATA_CACHE_DATAFRAME)
        df.persist(StorageLevel.DISK_ONLY)
        df.createOrReplaceTempView(_RELEASE_METADATA_CACHE)
        _release_metadata_df = df
    return _RELEASE_METADATA_CACHE

listenbrainz_spark/postgres/release_group.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
from typing import Optional
2+
3+
from pyspark import StorageLevel
4+
from pyspark.sql import DataFrame
5+
16
from listenbrainz_spark.path import RELEASE_GROUP_METADATA_CACHE_DATAFRAME
27
from listenbrainz_spark.postgres.utils import save_pg_table_to_hdfs
8+
from listenbrainz_spark.utils import read_files_from_HDFS
9+
10+
_RELEASE_GROUP_METADATA_CACHE = "release_group_metadata_cache"
11+
_release_group_metadata_df: Optional[DataFrame] = None
312

413

514
def create_release_group_metadata_cache():
@@ -72,3 +81,19 @@ def create_release_group_metadata_cache():
7281
"""
7382

7483
save_pg_table_to_hdfs(query, RELEASE_GROUP_METADATA_CACHE_DATAFRAME, process_artists_column=True)
84+
85+
global _release_group_metadata_df
86+
if _release_group_metadata_df is not None:
87+
_release_group_metadata_df.unpersist()
88+
_release_group_metadata_df = None
89+
90+
91+
def get_release_group_metadata_cache():
    """Return the name of the spark SQL view over the release group metadata cache.

    On first use, reads the parquet files from HDFS, persists the dataframe
    to disk and registers it as a temp view; later calls reuse the cached view.
    """
    global _release_group_metadata_df
    if _release_group_metadata_df is None:
        df = read_files_from_HDFS(RELEASE_GROUP_METADATA_CACHE_DATAFRAME)
        df.persist(StorageLevel.DISK_ONLY)
        df.createOrReplaceTempView(_RELEASE_GROUP_METADATA_CACHE)
        _release_group_metadata_df = df
    return _RELEASE_GROUP_METADATA_CACHE

listenbrainz_spark/request_consumer/jobs/import_dump.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
""" Spark job that downloads the latest listenbrainz dumps and imports into HDFS
22
"""
33
import logging
4-
import shutil
54
import tempfile
6-
import time
7-
from datetime import datetime
5+
from datetime import datetime, timezone
86

97
import listenbrainz_spark.request_consumer.jobs.utils as utils
108
from listenbrainz_spark.dump import DumpType
119
from listenbrainz_spark.dump.local import ListenbrainzLocalDumpLoader
1210
from listenbrainz_spark.ftp.download import ListenbrainzDataDownloader
1311
from listenbrainz_spark.hdfs.upload import ListenbrainzDataUploader
12+
from listenbrainz_spark.persisted import unpersist_incremental_df
1413

1514
logger = logging.getLogger(__name__)
1615

@@ -40,7 +39,8 @@ def import_full_dump_to_hdfs(dump_id: int = None, local: bool = False) -> str:
4039
uploader = ListenbrainzDataUploader()
4140
uploader.upload_new_listens_full_dump(src)
4241
uploader.process_full_listens_dump()
43-
utils.insert_dump_data(dump_id, DumpType.FULL, datetime.utcnow())
42+
utils.insert_dump_data(dump_id, DumpType.FULL, datetime.now(tz=timezone.utc))
43+
unpersist_incremental_df()
4444
return dump_name
4545

4646

@@ -68,8 +68,11 @@ def import_incremental_dump_to_hdfs(dump_id: int = None, local: bool = False) ->
6868
# instantiating ListenbrainzDataUploader creates a spark session which
6969
# is a bit non-intuitive.
7070
# FIXME in future to make initializing of spark session more explicit?
71-
ListenbrainzDataUploader().upload_new_listens_incremental_dump(src)
72-
utils.insert_dump_data(dump_id, DumpType.INCREMENTAL, datetime.utcnow())
71+
uploader = ListenbrainzDataUploader()
72+
uploader.upload_new_listens_incremental_dump(src)
73+
uploader.process_incremental_listens_dump()
74+
utils.insert_dump_data(dump_id, DumpType.INCREMENTAL, datetime.now(tz=timezone.utc))
75+
unpersist_incremental_df()
7376
return dump_name
7477

7578

0 commit comments

Comments
 (0)