Skip to content

Commit

Permalink
Fix encoded array fill values (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjala authored Aug 22, 2023
1 parent 21e0e25 commit 334bef8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 26 deletions.
4 changes: 2 additions & 2 deletions hsds/util/arrayUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def getNumpyValue(value, dt=None, encoding=None):
msg = "Unable to decode base64 string: {value}"
# log.warn(msg)
raise ValueError(msg)
arr = bytesToArray(data, dt, ())
arr = bytesToArray(data, dt, dt.shape)
else:
if isinstance(value, list):
# convert to tuple
Expand All @@ -524,7 +524,7 @@ def getNumpyValue(value, dt=None, encoding=None):
else:
# use as is
pass
arr = np.asarray(value, dtype=dt)
arr = np.asarray(value, dtype=dt.base)
return arr[()]


Expand Down
104 changes: 80 additions & 24 deletions tests/unit/array_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np

import sys
import base64

sys.path.append("../..")
from hsds.util.arrayUtil import (
Expand Down Expand Up @@ -681,30 +682,12 @@ def testGetNumpyValue(self):
self.assertTrue(isinstance(val, np.int32))
self.assertEqual(42, val)

# test int with base64 encoding
dt = np.dtype("<i4")
val = getNumpyValue("KgAAAA==", dt=dt, encoding="base64")
self.assertTrue(isinstance(val, np.int32))
self.assertEqual(42, val)

# test float with base64 encoding
dt = np.dtype("f4")
val = getNumpyValue("AADAfw==", dt=dt, encoding="base64")
self.assertTrue(isinstance(val, np.float32))
self.assertTrue(val != val)

# test fixed length string conversion
dt = np.dtype("S5")
val = getNumpyValue("hello", dt=dt)
self.assertTrue(isinstance(val, np.bytes_))
self.assertEqual(val, b"hello")

# base64 encoded string
dt = np.dtype("S5")
val = getNumpyValue("aGVsbG8=", dt=dt, encoding="base64")
self.assertTrue(isinstance(val, np.bytes_))
self.assertEqual(val, b"hello")

# test variable length string conversion
dt = np.dtype("O", metadata={"vlen": bytes})
val = getNumpyValue("hello", dt=dt)
Expand All @@ -718,19 +701,92 @@ def testGetNumpyValue(self):
self.assertEqual(val[0], 42)
self.assertEqual(val[1], b'hdf5')

# test compound type encoded
dt = np.dtype([('int', "<i4"), ('str', "S4")])
val = getNumpyValue("KgAAAGhkZjU=", dt=dt, encoding="base64")
self.assertTrue(isinstance(val, np.void))
self.assertEqual(val[0], 42)
self.assertEqual(val[1], b'hdf5')
# test array of ints
dt = np.dtype("<i4")
arr = np.array([0, 1], dtype=dt)
dt = np.dtype(("<i4", (len(arr),)))
val = getNumpyValue(arr, dt=dt)

self.assertTrue(np.array_equal(val, arr))
self.assertTrue(isinstance(val[0], np.int32))

# test array of floats
dt = np.dtype("f4")
arr = np.array([0.001, 1.001], dtype=dt)
val = getNumpyValue(arr, dt=np.dtype(("f4", (len(arr),))))

self.assertTrue(np.array_equal(val, arr))
self.assertTrue(isinstance(val[0], np.float32))

# test array of fixed-length strings
dt = np.dtype("S5")
arr = np.array([b'hello', b'world'], dtype=dt)
val = getNumpyValue(arr, dt=np.dtype(("S5", (len(arr),))))

self.assertTrue(np.array_equal(val, arr))
self.assertTrue(isinstance(val[0], np.bytes_))

# test nan string
dt = np.dtype("f4")
val = getNumpyValue("nan", dt=dt)
self.assertTrue(isinstance(val, np.float32))
self.assertTrue(val != val)

def testGetNumpyValueBase64Encoded(self):
# Set up value, numpy dtype, and expected type after decoding
value_info = []
value_info.append([42, np.dtype("<i4"), np.int32]) # int
value_info.append([1.001, np.dtype("f4"), np.float32]) # float
value_info.append([b"hello", np.dtype("S5"), np.bytes_]) # fixed-length string
value_info.append([(42, b'hdf5'),
np.dtype([('int', "<i4"), ('str', "S4")]), np.void]) # compound type
np_values = []

for vi in value_info:
np_values.append(np.array(vi[0], dtype=vi[1]))

for i in range(len(np_values)):
numpy_dtype_out = value_info[i][2]

# Turn numpy array to bytes object which can be encoded
encoded_val = np_values[i].tobytes()
# Encode numpy bytes object
encoded_val = base64.b64encode(encoded_val)
# Decode from bytes object to regular string containing a base64 encoded numpy array
# This prevents the utf-8 encoding inside getNumpyValue from prepending b'
encoded_val = encoded_val.decode()
decoded_val = getNumpyValue(encoded_val, dt=np_values[i].dtype, encoding="base64")
self.assertTrue(isinstance(decoded_val, numpy_dtype_out))
self.assertEqual(decoded_val, np_values[i])

# test array types

# Set up value, numpy dtype, and expected type after decoding
value_info = []
value_info.append([np.array([0, 1], dtype=np.dtype("<i4")),
np.dtype(("<i4", (2,))), np.int32]) # int array
value_info.append([np.array([0.001, 1.001], dtype=np.dtype("f4")),
np.dtype(("f4", (2,))), np.float32]) # float array
value_info.append([np.array([b'hello', b'world'], dtype=np.dtype("S5")),
np.dtype(("S5", (2,))), np.bytes_]) # fixed length string array

for i in range(len(value_info)):
this_array = value_info[i][0]
array_dtype = value_info[i][1]
array_dtype_out = value_info[i][2]

# Turn numpy array to bytes object which can be encoded
encoded_val = this_array.tobytes()
# Encode numpy bytes object
encoded_val = base64.b64encode(encoded_val)
# Decode from bytes object to regular string containing a base64 encoded numpy array
# This prevents the utf-8 encoding inside getNumpyValue from prepending b'
encoded_val = encoded_val.decode()
decoded_val = getNumpyValue(encoded_val, dt=array_dtype, encoding="base64")

self.assertTrue(np.array_equal(decoded_val, this_array))
self.assertTrue(isinstance(decoded_val[0], array_dtype_out))

# test invalid base64 length
try:
dt = np.dtype("<i8")
Expand Down

0 comments on commit 334bef8

Please sign in to comment.