# Refactor/replace/remove `aesara.tensor.basic.get_scalar_constant_value` (#96)
`aesara.tensor.get_scalar_constant_value` is a utility function that contains unreasonably long conditional statements and introduces more unnecessary cross-module—and ultimately circular—dependencies (e.g. `aesara.tensor.subtensor.Subtensor.get_constant_idx` calls `aesara.tensor.get_scalar_constant_value` and vice versa).

Let's fix this situation by

- moving this utility function into its own module and/or reimplementing it using some form of dispatch (e.g. single-dispatch, if possible). With the latter, the `Op`-specific parts of `get_scalar_constant_value` can be defined within each `Op`'s own module and utilized only when relevant (i.e. when/if the `Op` itself is imported), or
- using the existing constant folding capabilities—perhaps with some additional features (e.g. more fine-grained folding conditions).
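To make the single-dispatch option concrete, here is a minimal sketch of how it could look. The handler name `op_constant_value` and the overall structure are hypothetical, not an existing Aesara API, and the import paths follow recent Aesara and may differ by version:

```python
from functools import singledispatch

from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op
from aesara.tensor.exceptions import NotScalarConstantError


@singledispatch
def op_constant_value(op: Op, var: Variable):
    # Default case: we don't know how to look through this `Op`.
    raise NotScalarConstantError(f"Cannot extract a constant through {op}")


def get_scalar_constant_value(v: Variable):
    """Extract a constant by dispatching on the `Op` that produced `v`."""
    if isinstance(v, Constant):
        return v.data
    if v.owner is None:
        raise NotScalarConstantError(f"{v} is not constant")
    # Dispatch on the `Op` subclass that produced `v`.
    return op_constant_value(v.owner.op, v)


# Each `Op`'s module would register its own handler when imported, e.g.:
#
# @op_constant_value.register
# def _(op: DimShuffle, var):
#     return get_scalar_constant_value(var.owner.inputs[0])
```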
## Comments

Here's some more info on the situation with `get_scalar_constant_value`. The function is supposed to return the constant value underlying a variable by digging through a fixed set of `Op`s. For example, the following provides some simple cases in which the current implementation only partially works:

```python
import theano.tensor as tt

test_array = tt.as_tensor([[1], [4]])

test_val = test_array.shape[-1]  # i.e. 1
tt.get_scalar_constant_value(test_val)  # works

test_val = test_array.shape[0]  # i.e. 2
tt.get_scalar_constant_value(test_val)  # doesn't work

test_val = test_array.shape[:-1]  # i.e. [2]
tt.get_scalar_constant_value(test_val)  # doesn't work
```

Now, we could fix these specific cases, but the entire idea behind `get_scalar_constant_value` is already covered by constant folding. Here's a simple demonstration for the last example above:

```python
import theano

test_val = test_array.shape[:-1]
fgraph = theano.gof.fg.FunctionGraph(theano.gof.graph.inputs([test_val]), [test_val], clone=True)
res = tt.opt.topo_constant_folding.optimize(fgraph)
```

```python
>>> tt_dprint(fgraph)
TensorConstant{(1,) of 2} [id A]
```

The only downsides to this approach (that I can think of right now) involve the overhead for cloning and possibly some cases of overly cautious constant folding. First off, graph cloning should be very cheap, and, if it isn't, then we need to fix that ASAP—regardless of this issue.
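For reference, here is the same clone-and-fold idea expressed with current Aesara names (matching the imports used later in this thread). This is a sketch; the printed output is what one would expect, not a verified transcript:

```python
import aesara
import aesara.tensor as at
from aesara.graph.basic import graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.tensor.basic_opt import topo_constant_folding

test_array = at.as_tensor([[1], [4]])
test_val = test_array.shape[0]  # the scalar case `get_scalar_constant_value` misses

# Clone the subgraph so folding doesn't mutate the user's graph.
fgraph = FunctionGraph(list(graph_inputs([test_val])), [test_val], clone=True)
topo_constant_folding.optimize(fgraph)
aesara.dprint(fgraph.outputs[0])
# expected: TensorConstant{2}
```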
---

Took a stab at this and came up with this:

```python
def get_constant_value(v, scalar=True, *args, **kwargs):
    """Return the constant value underlying variable `v`.

    If `v` is the output of dimshuffles, fills, allocs, rebroadcasts,
    cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise,
    and some patterns with Subtensor, this function digs through them.

    If `v` is not some view of constant data, then raise a
    `NotConstantError`.  If `scalar` is true and `v` is not a scalar,
    raise a `NotScalarConstantError`.

    Parameters
    ----------
    v: Variable
        Variable to be evaluated
    scalar: bool
        Specify if the returned value should be a scalar

    """
    if v is None:
        raise NotConstantError()

    v = as_tensor_variable(v)

    if not isinstance(v, Constant) and v.owner is not None:
        from aesara.graph.opt_utils import optimize_graph

        v_fgraph_clone = FunctionGraph([*graph_inputs([v])], [v], clone=True)
        optimize_graph(v_fgraph_clone)
        v = v_fgraph_clone.outputs[0]

    if not isinstance(v, Constant):
        raise NotConstantError()
    elif scalar and v.ndim != 0:
        raise NotScalarConstantError()
    else:
        unique_value = get_unique_value(v)
        if unique_value is not None:
            v = unique_value
        else:
            v = v.data

    return v
```

The issue with using 'only' `topo_constant_folding` is that it can't fold patterns that first need other rewrites to run. For instance:

```python
from aesara.tensor.basic_opt import topo_constant_folding
from aesara.graph.fg import FunctionGraph
from aesara.tensor.type import iscalar
import aesara.tensor.basic as at
import aesara
from aesara.graph.basic import graph_inputs


def get_constant_value(v):
    v = at.as_tensor_variable(v)
    v_fgraph_clone = FunctionGraph([*graph_inputs([v])], [v], clone=True)
    topo_constant_folding.optimize(v_fgraph_clone)
    v = v_fgraph_clone.outputs[0]
    return v


b = iscalar()
a = at.stack([b, 2, 3])

aesara.dprint(get_constant_value(a[1]))
# Isn't able to convert it into TensorConstant{2} by itself
# Subtensor{int64} [id A] ''
# |MakeVector{dtype='int32'} [id B] ''
# | |<TensorType(int32, scalar)> [id C]
# | |TensorConstant{2} [id D]
# | |TensorConstant{3} [id E]
# |ScalarConstant{1} [id F]
```

I do agree that calling the entire set of optimizations (i.e. `optimize_graph`) for every such lookup could get expensive, though.
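Continuing the snippet above: running the broader rewrite set via `optimize_graph` (the helper the earlier `get_constant_value` draft imports) should let the `Subtensor`-of-`MakeVector` simplification fire first, after which folding succeeds. A sketch; the output comment is my expectation:

```python
from aesara.graph.opt_utils import optimize_graph

fgraph = FunctionGraph(list(graph_inputs([a[1]])), [a[1]], clone=True)
# Applies the default rewrites (canonicalizations + constant folding) in place.
optimize_graph(fgraph)
aesara.dprint(fgraph.outputs[0])
# expected: TensorConstant{2}
```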
---

That's a great example, and it could also be telling us that perhaps we need to extend/improve our notion (and implementation) of constant folding. I say this because there's a weird dependency chain in this situation: the non-constant folding rewrite that simplifies that `Subtensor`-of-`MakeVector` pattern mostly serves to expose a constant that constant folding can then pick up.

Instead, it might make more sense to consider such a rewrite as a type of constant folding and somehow incorporate it into that framework; then there's no drawn-out dependency chain.

Aside from that consideration, I think you have the correct overall approach, but that we need to start working out a way to make these rewrites-within-rewrites more efficient. This is also a rising issue for our future rewriting work. If we could continue to use in-place updates in these situations, that would be great, but we really can't circumvent the rewrite framework like that without incurring some problems/challenges.
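To illustrate what treating such a rewrite "as a type of constant folding" might mean, here is a hypothetical folding rule (function name and placement made up for illustration) that handles the `Subtensor`-of-`MakeVector` case directly, folding the indexed element even though the `MakeVector` as a whole isn't constant:

```python
from aesara.graph.basic import Constant
from aesara.tensor.basic import MakeVector
from aesara.tensor.subtensor import Subtensor


def try_fold_subtensor_of_make_vector(node):
    """Return the constant element for `MakeVector(...)[i]`, else None."""
    if not isinstance(node.op, Subtensor) or len(node.inputs) != 2:
        return None
    vec, idx = node.inputs
    if vec.owner is None or not isinstance(vec.owner.op, MakeVector):
        return None
    if not isinstance(idx, Constant):
        return None
    elem = vec.owner.inputs[int(idx.data)]
    # Foldable only when the selected element is itself constant.
    return elem if isinstance(elem, Constant) else None
```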
---

Actually, I'm starting to wonder how much we even need `aesara.tensor.basic.get_scalar_constant_value`.

We've updated and added a lot of missing canonicalizations recently, and most—if not all—of the steps performed by `get_scalar_constant_value` should now be covered by existing rewrites. If you think of it that way, using `get_scalar_constant_value` inside a rewrite is largely redundant: by the time the relevant canonicalizations have run, any constants it would dig out are already exposed as `Constant`s.

The only instances where that's not true is when a rewrite uses `get_scalar_constant_value` on graphs that haven't been canonicalized yet.
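A schematic before/after of that argument inside a rewrite body (the `node` handling is hypothetical); once canonicalization has exposed the constant, the lookup becomes a plain `isinstance` check:

```python
from aesara.graph.basic import Constant

# Before: dig through the graph manually on every visit, e.g.
#   val = get_scalar_constant_value(node.inputs[1])

# After canonicalization, a constant input simply *is* a `Constant`:
def input_scalar_constant(node, i):
    inp = node.inputs[i]
    # Assumes a tensor constant, whose `data` is an ndarray.
    if isinstance(inp, Constant) and inp.data.ndim == 0:
        return inp.data  # the scalar value, no graph traversal needed
    return None
```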
---

We can just keep the above `get_constant_value` implementation around for the cases that need it, then.
---

Yeah, that's likely the best course of action.
---

@brandonwillard Can you describe the cache-based memoization approach you proposed in the meeting for this particular functionality? I remember it being something like being able to call certain constant folding optimizations at proper "checkpoint"s during graph optimizations.
---

I was really just thinking of using something like simple memoization of these constant folding results. The question is really about how we can effectively use caching in these cases given that the graphs involved are mutated in place during rewriting.
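One shape this could take (purely a sketch; the class and its invalidation policy are hypothetical): memoize folding results per variable and use a `FunctionGraph` `Feature` to drop entries whenever the graph is mutated underneath the cache:

```python
from aesara.graph.features import Feature


class ConstantFoldingCache(Feature):
    """Cache folded values, invalidated on graph mutation (hypothetical)."""

    def __init__(self):
        self.cache = {}  # Variable -> folded constant value

    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        # Coarse invalidation: any in-place edit may stale any entry.
        # A real implementation would only drop entries that depend on `r`.
        self.cache.clear()

    def on_prune(self, fgraph, node, reason=None):
        self.cache.clear()
```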