Enable complete static shape information at the type level #711
Codecov Report
```diff
@@            Coverage Diff             @@
##             main     #711      +/-   ##
==========================================
+ Coverage   78.17%   78.35%   +0.18%
==========================================
  Files         152      152
  Lines       47663    47679      +16
  Branches    10881    10879       -2
==========================================
+ Hits        37260    37360     +100
+ Misses       7846     7772      -74
+ Partials     2557     2547      -10
```
This is really interesting progress. I don't have anything to add just yet, but I'm adding this link to an old Theano devs discussion that seems relevant after this PR (when mixing fixed and non-fixed shape variables): https://groups.google.com/g/theano-dev/c/OKiVysM4ySg/m/6Q4iHLseBwAJ
Overall, this PR is (extremely huge and) a really interesting change.

I have a question here: if we have the full static shape information at the `Type` level, what are the cases in which we might still need this? I might be misunderstanding what we mean by …
Yes, that's exactly what it is; one can create a …

Aside from providing exact shapes upfront, it can also simplify a lot of the shape inference and remove the need for some graph elements (e.g. …).

More immediately, we simplify our …

For the same reasons, to fix #608 we would need to change …

There are also a few other related interfaces (e.g. #93) that can be improved along the way, like …
After fixing `Shape`'s broadcastable information, `local_subtensor_remove_broadcastable_index` started replacing all forms of `shape(x)[0]` with `shape(x).dimshuffle(())` when `shape(x).broadcastable == (True,)`, so a lot of the changes in this commit compensate for that difference.
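Roughly, the two equivalent expressions look like this (a minimal sketch, assuming the post-fix `Shape` output of a vector carries static shape `(1,)`; it only builds the graphs, it does not run the rewrite):

```python
import aesara.tensor as at

x = at.vector("x")     # one dimension, so x.shape is a length-1 vector
s = x.shape            # after the fix: broadcastable == (True,), i.e. static shape (1,)
a = s[0]               # indexing the single, statically length-1 dimension ...
b = s.dimshuffle(())   # ... which the rewrite now expresses by dropping that dimension
```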
These changes enforce a strict "narrowing"-only conversion policy; i.e. `Type`s can be converted to equal or more specific types.
`TensorType` now supports full shape information. The field `TensorType.broadcastable` has been replaced by `TensorType.shape`, which corresponds to the available static shape information. `None`s are used to encode unknown/no static shape information for a dimension.
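For instance, assuming the updated constructor accepts a `shape` keyword (a small illustrative sketch, not taken from the PR's tests):

```python
from aesara.tensor.type import TensorType

# A matrix with an unknown number of rows and exactly three columns.
xt = TensorType("float64", shape=(None, 3))
print(xt.shape)          # (None, 3)
print(xt.broadcastable)  # (False, False): neither dimension is statically length 1
```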
Nice, this is a big one 🥳.

I know you don't like to talk about JAX ;) but does that help at all with omnistaging?

It actually could, but there are still a few follow-ups needed in order to consistently propagate this new static shape information. I'll try to provide those in the next couple of days.
aesara-devs/aesara#711 addressed the warning's underlying issue. The warning is no longer raised, so we re-ran the notebook to get a cleaner output.
This change was introduced in 9a45333 and did not include specific tests. It was likely introduced to cope with the old restrictions regarding rewrite substitution of variables with different static broadcastable shape information, which were alleviated in aesara-devs#711.
This PR addresses #431 and, as a result, #608 by adding full static shape information at the `Type` level.

I recently noticed that something like this was attempted a while back, and that some of the relevant issues are described in `typeattr.txt`. As mentioned in that document, the main issue is that `a.type == b.type` is not suitable for use when `Type`s have varying degrees of type/shape information.

This varying degree of type information is currently dealt with by making all `Type`s conform in the case of replacements. `Type`s can differ during rewrites when, for instance, the replacement `Variable` is the output of an `Op` that's better at inferring `TensorType.broadcastable` information (i.e. static shape information). In these cases, if a rewrite/optimization replaces a `Variable` of `Type` `TensorType("float64", (True, False))` with a `Variable` of `Type` `TensorType("float64", (False, False))`, it will apply a `Rebroadcast` `Op` in order to "broaden" the original `(True, False)` broadcast pattern (i.e. a shape equal to `(1, s)` for some `s`) so that the replacement `Type` is equal to the replaced variable's original `Type`.

There are at least two things wrong with this approach: 1) `Rebroadcast` is basically just a redundant `Assert` `Op`, and 2) `Type` "broadening" isn't particularly productive or useful. We should only be "narrowing" the `Type`s; in other words, if we get more information about a `Type`, we should use it. This puts less pressure on each individual `Op` to operate at an equally precise level of shape inference, and it more explicitly draws out contradictions/inconsistencies between `Op` implementations.

To summarize the above: we need to "narrow" `Type`s as we refine a graph. We're talking about a simple form of type inference, essentially.

In order to do this, an explicit "information"-based ordering over `Type`s is required. For example, `a.type < b.type` would be true when `a.type` is a more specific type than `b.type`, and, as a result, `b` can be replaced by `a` under the assumption that both represent the same term. Since this is a high-level assumption we necessarily make when performing replacements, we're basically updating our type information about terms when we encounter these cases.

For example, `t1 = TensorType("float64", [False, False])` can be replaced by `t2 = TensorType("float64", [True, False])`, because `t2` has a more specific broadcast pattern than `t1`; in other words, `t1 > t2` according to the above.

The implementation in this PR uses a new `Type.is_super` method to implement the `<` ordering in the above example, but it can easily be replaced by a more convenient `__lt__` implementation. (NB: the name of the method is liable to change.)
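To make the ordering concrete, here is a hypothetical sketch of such an "information" comparison over broadcast patterns; `is_more_specific` is an illustrative helper, not the PR's `Type.is_super`:

```python
def is_more_specific(bcast_a, bcast_b):
    """True if broadcast pattern `bcast_a` is at least as specific as `bcast_b`.

    `True` in a pattern means the dimension is statically known to have length 1;
    `False` means its length is unknown, so `True` carries more information.
    """
    if len(bcast_a) != len(bcast_b):
        return False
    # Every dimension `bcast_b` pins to length 1 must also be pinned in `bcast_a`;
    # `bcast_a` may additionally pin dimensions that `bcast_b` leaves open.
    return all(a or not b for a, b in zip(bcast_a, bcast_b))

t1 = (False, False)  # e.g. TensorType("float64", [False, False])
t2 = (True, False)   # e.g. TensorType("float64", [True, False])
assert is_more_specific(t2, t1)      # t2 < t1, so t2 can replace t1 (narrowing)
assert not is_more_specific(t1, t2)  # replacing t2 with t1 would discard information
```

With full static shapes the same idea generalizes: a concrete length is more specific than an unknown (`None`) one, and two different concrete lengths are simply incompatible.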
After making these important `Type` system changes, this PR introduces an entirely new `Type` to implement fixed-shape types; however, since the existing `TensorType` already has partial shape information (i.e. `TensorType.broadcastable`), it might be better to extend that by adding a `TensorType.shape` attribute and making `TensorType.broadcastable` a property derived from it. This aligns more closely with the approach described in the old `typeattr.txt` file mentioned above, since `TensorType.shape` would need to encode missing shape information with something like `None`s (e.g. `TensorType.shape = (None, 1, None)` would be equivalent to `TensorType.broadcastable = (False, True, False)`).
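As an illustration of that equivalence, the derived property could be computed roughly like this (a hypothetical helper, not code from the PR):

```python
from typing import Optional, Tuple

def broadcastable_from_shape(shape: Tuple[Optional[int], ...]) -> Tuple[bool, ...]:
    """Only a dimension statically known to have length 1 is broadcastable."""
    return tuple(s == 1 for s in shape)

assert broadcastable_from_shape((None, 1, None)) == (False, True, False)
```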
Some open questions/follow-ups:

- How `Type` equality should behave when `Type`s carry differing degrees of static shape information.
- Whether `TensorType`s with `all(TensorType.broadcastable) is True` should simply be fixed-shape tensors, since their shapes are fixed and known already. This would involve overriding `TensorType.__new__` so that it returns the appropriate types.
- The `Type.filter_variable`/`Type.convert_variable` interface. Currently, these methods overlap a lot, their exact purposes aren't clear, and their usages seem somewhat inconsistent with regard to types and the like.
- Whether to keep the new `FixedShapeTensorType` or instead extend the existing `TensorType` to include a `shape` field.