From a59a89cb64611bd81361cca8c1cf6efacb9061d1 Mon Sep 17 00:00:00 2001 From: mib1185 Date: Sun, 5 Jan 2025 20:03:22 +0000 Subject: [PATCH] allow to use raw response content (StreamReader) --- src/synology_dsm/synology_dsm.py | 49 +++++++++++++++++++++++++------- tests/__init__.py | 4 ++- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/synology_dsm/synology_dsm.py b/src/synology_dsm/synology_dsm.py index 039931b..5fd88b0 100644 --- a/src/synology_dsm/synology_dsm.py +++ b/src/synology_dsm/synology_dsm.py @@ -11,7 +11,14 @@ from typing import Any, Coroutine, TypedDict from urllib.parse import quote, urlencode -from aiohttp import ClientError, ClientSession, ClientTimeout, MultipartWriter, hdrs +from aiohttp import ( + ClientError, + ClientSession, + ClientTimeout, + MultipartWriter, + StreamReader, + hdrs, +) from yarl import URL from .api import SynoBaseApi @@ -239,13 +246,13 @@ def device_token(self) -> str | None: async def get( self, api: str, method: str, params: dict | None = None, **kwargs: Any - ) -> bytes | dict | str: + ) -> bytes | dict | str | StreamReader: """Handles API GET request.""" return await self._request("GET", api, method, params, **kwargs) async def post( self, api: str, method: str, params: dict | None = None, **kwargs: Any - ) -> bytes | dict | str: + ) -> bytes | dict | str | StreamReader: """Handles API POST request.""" return await self._request("POST", api, method, params, **kwargs) @@ -310,8 +317,9 @@ async def _request( method: str, params: dict | None = None, retry_once: bool = True, + raw_response_content: bool = False, **kwargs: Any, - ) -> bytes | dict | str: + ) -> bytes | dict | str | StreamReader: """Handles API request.""" url, params, kwargs = await self._prepare_request(api, method, params, **kwargs) @@ -319,9 +327,12 @@ async def _request( self._debuglog("---------------------------------------------------------") self._debuglog("API: " + api) self._debuglog("Request Method: " + request_method) - response = await self._execute_request(request_method, url, params, **kwargs) + response = await self._execute_request( + request_method, url, params, raw_response_content, **kwargs + ) self._debuglog("Successful returned data") - self._debuglog("RESPONSE: " + str(response)) + if not raw_response_content: + self._debuglog("RESPONSE: " + str(response)) # Handle data errors if isinstance(response, dict) and response.get("error") and api != API_AUTH: @@ -339,8 +350,13 @@ async def _request( return response async def _execute_request( - self, method: str, url: URL, params: dict, **kwargs: Any - ) -> bytes | dict | str: + self, + method: str, + url: URL, + params: dict, + raw_response_content: bool = False, + **kwargs: Any, + ) -> bytes | dict | str | StreamReader: """Function to execute and handle a request.""" # special handling for spaces in parameters # because yarl.URL does encode a space as + instead of %20 @@ -348,10 +364,18 @@ async def _execute_request( query = urlencode(params, safe="?/:@-._~!$'()*,", quote_via=quote) url_encoded = url.join(URL(f"?{query}", encoded=True)) + if params.get("api") in [ + SynoFileStation.UPLOAD_API_KEY, + SynoFileStation.DOWNLOAD_API_KEY, + ]: + timeout = ClientTimeout(connect=10.0, total=43200.0) + else: + timeout = self._aiohttp_timeout + try: if method == "GET": response = await self._session.get( - url_encoded, timeout=self._aiohttp_timeout, **kwargs + url_encoded, timeout=timeout, **kwargs ) elif ( method == "POST" and params.get("api") == SynoFileStation.UPLOAD_API_KEY @@ -377,7 +401,7 @@ async def _execute_request( response = await self._session.post( url_encoded, - timeout=ClientTimeout(connect=10.0, total=43200.0), + timeout=timeout, data=mp, ) elif method == "POST": @@ -390,7 +414,7 @@ async def _execute_request( self._debuglog("POST data: " + str(data)) response = await self._session.post( - url_encoded, timeout=self._aiohttp_timeout, **kwargs + url_encoded, timeout=timeout, **kwargs ) # mask sesitive parameters @@ -408,6 +432,9 @@ async def _execute_request( # We got a DSM response content_type = response.headers.get("Content-Type", "").split(";")[0] + if raw_response_content: + return response.content + if content_type in [ "application/json", "text/json", diff --git a/tests/__init__.py b/tests/__init__.py index 3bbd76c..b97271c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -179,7 +179,9 @@ def __init__( self.error = False self.with_surveillance = False - async def _execute_request(self, method, url, params, **kwargs): + async def _execute_request( + self, method, url, params, raw_response_content, **kwargs + ): url = str(url) url += urlencode(params or {})