Skip to content

Commit

Permalink
Remove RandomType's get_shape_info and get_size methods
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 29, 2021
1 parent d9fd640 commit 7d07260
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 56 deletions.
15 changes: 0 additions & 15 deletions aesara/tensor/random/type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import sys

import numpy as np

import aesara
Expand Down Expand Up @@ -30,10 +28,6 @@ def filter(cls, data, strict=False, allow_downcast=None):
else:
raise TypeError()

@staticmethod
def get_shape_info(obj):
return obj.get_value(borrow=True)

@staticmethod
def may_share_memory(a, b):
return a._bit_generator is b._bit_generator
Expand Down Expand Up @@ -99,10 +93,6 @@ def _eq(sa, sb):

return _eq(sa, sb)

@staticmethod
def get_size(shape_info):
return sys.getsizeof(shape_info.get_state(legacy=False))


# Register `RandomStateType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code(
Expand Down Expand Up @@ -184,11 +174,6 @@ def _eq(sa, sb):

return _eq(sa, sb)

@staticmethod
def get_size(shape_info):
state = shape_info.__getstate__()
return sys.getsizeof(state)


# Register `RandomGeneratorType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code(
Expand Down
65 changes: 24 additions & 41 deletions tests/tensor/random/test_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pickle
import sys

import numpy as np
import pytest
Expand Down Expand Up @@ -95,21 +94,6 @@ def test_values_eq(self):
assert not rng_type.values_eq(rng_g, rng_a)
assert not rng_type.values_eq(rng_e, rng_g)

def test_get_shape_info(self):
rng = np.random.RandomState(12)
rng_a = shared(rng)

assert isinstance(
random_state_type.get_shape_info(rng_a), np.random.RandomState
)

def test_get_size(self):
rng = np.random.RandomState(12)
rng_a = shared(rng)
shape_info = random_state_type.get_shape_info(rng_a)
size = random_state_type.get_size(shape_info)
assert size == sys.getsizeof(rng.get_state(legacy=False))

def test_may_share_memory(self):
bg1 = np.random.MT19937()
bg2 = np.random.MT19937()
Expand All @@ -119,16 +103,23 @@ def test_may_share_memory(self):

rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
shape_info_a = random_state_type.get_shape_info(rng_var_a)
shape_info_b = random_state_type.get_shape_info(rng_var_b)

assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False
assert (
random_state_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
)

rng_c = np.random.RandomState(bg2)
rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)

assert random_state_type.may_share_memory(shape_info_b, shape_info_c) is True
assert (
random_state_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
)


class TestRandomGeneratorType:
Expand Down Expand Up @@ -197,21 +188,6 @@ def test_values_eq(self):
assert rng_type.is_valid_value(bitgen_g, strict=True)
assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False)

def test_get_shape_info(self):
rng = np.random.default_rng(12)
rng_a = shared(rng)

assert isinstance(
random_generator_type.get_shape_info(rng_a), np.random.Generator
)

def test_get_size(self):
rng = np.random.Generator(np.random.PCG64(12))
rng_a = shared(rng)
shape_info = random_generator_type.get_shape_info(rng_a)
size = random_generator_type.get_size(shape_info)
assert size == sys.getsizeof(rng.__getstate__())

def test_may_share_memory(self):
bg_a = np.random.PCG64()
bg_b = np.random.PCG64()
Expand All @@ -220,13 +196,20 @@ def test_may_share_memory(self):

rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
shape_info_a = random_state_type.get_shape_info(rng_var_a)
shape_info_b = random_state_type.get_shape_info(rng_var_b)

assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False
assert (
random_state_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
)

rng_c = np.random.Generator(bg_b)
rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)

assert random_state_type.may_share_memory(shape_info_b, shape_info_c) is True
assert (
random_state_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
)

0 comments on commit 7d07260

Please sign in to comment.