
Enable complete static shape information at the type level #711

Merged 36 commits into aesara-devs:main from add-fixed-shape-type on Jan 13, 2022

Conversation

brandonwillard (Member) commented Dec 29, 2021

This PR addresses #431 and—as a result—#608 by adding full static shape information at the Type level.

I recently noticed that something like this was attempted a while back, and that some of the relevant issues are described in typeattr.txt. As mentioned in that document, the main issue is that a.type == b.type is not suitable for use when Types have varying degrees of type/shape information.

This varying degree of type information is currently dealt with by making all Types conform in the case of replacements. Types can differ during rewrites when, for instance, the replacement Variable is the output of an Op that's better at inferring TensorType.broadcastable information (i.e. static shape information). In these cases, if a rewrite/optimization replaces a Variable of Type TensorType("float64", (True, False)) with a Variable of Type TensorType("float64", (False, False)), it will apply a Rebroadcast Op in order to "broaden" the original (True, False) broadcast pattern (i.e. shape equal to (1, s) for some s) so that the replacement Type is equal to the replaced variable's original Type.

There are at least two things wrong with this approach: 1) Rebroadcast is basically just a redundant Assert Op, and 2) Type "broadening" isn't particularly productive or useful. We should only be "narrowing" the Types. In other words, if we get more information about a Type, we should use it. This puts less pressure on each individual Op to operate at an equally precise level of shape inference, and it more explicitly draws out contradictions/inconsistencies between Op implementations.

To summarize the above, we need to "narrow" Types as we refine a graph. We're talking about a simple form of type inference, essentially.

In order to do this, an explicit "information"-based ordering over Types is required. For example, a.type < b.type is true when a is a more specific type than b and, as a result, b can be replaced by a under the assumption that both types represent the same term. Since this is a high-level assumption we necessarily make when performing replacements, we're basically updating our type information about terms when we encounter these cases.

For example, t1 = TensorType("float64", [False, False]) can be replaced by t2 = TensorType("float64", [True, False]), because t2 has a more specific broadcast pattern than t1. In other words, t1 > t2 according to the above.

The implementation in this PR uses a new Type.is_super method to implement the < ordering in the above example, but it can easily be replaced by a more convenient __lt__ implementation. (NB: the name of the method is liable to change.)
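To make the ordering concrete, here is a minimal standalone sketch of the "narrowing" relation described above; `is_narrowing_of` is a hypothetical helper for illustration only, not Aesara's actual `Type.is_super` implementation:

```python
def is_narrowing_of(specific, general):
    """True when a variable of type `specific` may replace one of type `general`."""
    if len(specific) != len(general):
        return False
    # Every statically known dimension of `general` must match exactly;
    # an unknown (None) dimension of `general` imposes no constraint.
    return all(g is None or s == g for s, g in zip(specific, general))


# Encode the broadcast patterns above as static shapes: True -> 1, False -> None.
t1 = (None, None)  # TensorType("float64", [False, False])
t2 = (1, None)     # TensorType("float64", [True, False])

assert is_narrowing_of(t2, t1)      # t1 > t2, so t2 can replace t1
assert not is_narrowing_of(t1, t2)  # the reverse would discard information
```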

After making these important Type system changes, this PR introduces an entirely new Type to implement fixed-shape types; however, since the existing TensorType already has partial shape information (i.e. TensorType.broadcastable), it might be better to extend that by adding a TensorType.shape attribute and making TensorType.broadcastable a property derived therefrom. This aligns more closely with the approach described in the old typeattr.txt file above, since TensorType.shape would need to encode missing shape information with something like Nones (e.g. TensorType.shape = (None, 1, None) would be equivalent to TensorType.broadcastable = (False, True, False)).
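As a quick illustration of that encoding, a dimension is broadcastable exactly when its static length is 1; `shape_to_broadcastable` below is a hypothetical helper, not part of Aesara:

```python
def shape_to_broadcastable(shape):
    """A dimension is broadcastable iff its static length is 1."""
    return tuple(s == 1 for s in shape)


assert shape_to_broadcastable((None, 1, None)) == (False, True, False)
```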

  • Finish refactoring code that uses strict Type equality
  • Add tests for new fixed-shape Types
  • Consider making all TensorTypes for which all(TensorType.broadcastable) is True fixed-shape tensors, since their shapes are already fixed and known.
    This would involve overriding TensorType.__new__ so that it returns the appropriate types (see the sketch after this list).
  • Clarify/fix the Type.filter_variable/Type.convert_variable interface. Currently, these methods overlap a lot, their exact purposes aren't clear, and their usages seem somewhat inconsistent with regard to types and the like.
  • Consider removing distinct FixedShapeTensorType and extending the existing TensorType to include a shape field.
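For the third item above, a hypothetical sketch of such a TensorType.__new__ override, assuming a distinct FixedShapeTensorType (which the last item suggests may be removed); this is illustrative only, not the PR's implementation:

```python
class TensorType:
    def __new__(cls, dtype, broadcastable):
        # If every dimension is broadcastable, every dimension has length 1,
        # so the full shape is statically known: return a fixed-shape type.
        if cls is TensorType and all(broadcastable):
            return super().__new__(FixedShapeTensorType)
        return super().__new__(cls)

    def __init__(self, dtype, broadcastable):
        self.dtype = dtype
        self.broadcastable = tuple(broadcastable)


class FixedShapeTensorType(TensorType):
    @property
    def shape(self):
        # All-broadcastable dimensions all have length 1.
        return tuple(1 for _ in self.broadcastable)


t = TensorType("float64", (True, True))
assert isinstance(t, FixedShapeTensorType)
assert t.shape == (1, 1)
```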

@brandonwillard changed the title from "Introduce a full shape information at the type level" to "Introduce full shape information at the Type level" on Dec 29, 2021
@brandonwillard self-assigned this on Dec 29, 2021
@brandonwillard added the enhancement, important, and refactor labels on Dec 29, 2021
@brandonwillard linked an issue on Dec 29, 2021 that may be closed by this pull request
@brandonwillard marked this pull request as draft on December 29, 2021 22:20
@brandonwillard force-pushed the add-fixed-shape-type branch 3 times, most recently from f569038 to 68b0a3a on December 29, 2021 23:34
codecov bot commented Dec 30, 2021

Codecov Report

Merging #711 (f266c52) into main (51792fe) will increase coverage by 0.18%.
The diff coverage is 87.14%.


@@            Coverage Diff             @@
##             main     #711      +/-   ##
==========================================
+ Coverage   78.17%   78.35%   +0.18%     
==========================================
  Files         152      152              
  Lines       47663    47679      +16     
  Branches    10881    10879       -2     
==========================================
+ Hits        37260    37360     +100     
+ Misses       7846     7772      -74     
+ Partials     2557     2547      -10     
Impacted Files Coverage Δ
aesara/d3viz/d3viz.py 23.68% <0.00%> (ø)
aesara/graph/features.py 65.95% <ø> (ø)
aesara/link/basic.py 85.21% <0.00%> (ø)
aesara/link/vm.py 87.19% <ø> (ø)
aesara/tensor/nnet/batchnorm.py 77.04% <ø> (ø)
aesara/tensor/nnet/neighbours.py 91.26% <0.00%> (ø)
aesara/tensor/nnet/opt.py 42.96% <0.00%> (ø)
aesara/tensor/math_opt.py 86.23% <22.22%> (-0.01%) ⬇️
aesara/compile/function/types.py 78.62% <33.33%> (ø)
aesara/scan/opt.py 80.31% <33.33%> (ø)
... and 57 more

ricardoV94 (Contributor) commented Dec 30, 2021

This is really interesting progress.

I don't have anything to add just yet, but here is a link to an old theano-dev discussion that seems relevant after this PR (about mixing fixed- and non-fixed-shape variables): https://groups.google.com/g/theano-dev/c/OKiVysM4ySg/m/6Q4iHLseBwAJ

kc611 (Contributor) left a comment

Overall this PR is (extremely huge and) a really interesting change.

kc611 (Contributor) commented Dec 30, 2021

> This aligns more closely with the approach described in the old typeattr.txt file above, since TensorType.shape would need to encode missing shape information with something like Nones (e.g. TensorType.shape = (None, 1, None) would be equivalent to TensorType.broadcastable = (False, True, False)).

I have a question here: if we have the full static shape information at the Type level, what are the cases in which we might need this?

I might be misunderstanding what we mean by full static shape information; I assume it is the literal information about the shape (as a tuple of integers, or something that helps us derive that).

brandonwillard (Member, Author) commented Dec 30, 2021

> This aligns more closely with the approach described in the old typeattr.txt file above, since TensorType.shape would need to encode missing shape information with something like Nones (e.g. TensorType.shape = (None, 1, None) would be equivalent to TensorType.broadcastable = (False, True, False)).
>
> I have a question here: if we have the full static shape information at the Type level, what are the cases in which we might need this?
>
> I might be misunderstanding what we mean by full static shape information; I assume it is the literal information about the shape (as a tuple of integers, or something that helps us derive that).

Yes, that's exactly what it is; one can create a TensorVariable (or a subclass thereof) that has fixed, concrete shape information, such that var.shape would return a tuple of ints.
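For instance, a sketch assuming the post-PR constructor, which would take the static shape directly (the exact signature is an assumption here):

```python
from aesara.tensor.type import TensorType

# A type whose shape is completely known at the Type level.
x = TensorType("float64", shape=(3, 2))("x")

print(x.type.shape)  # (3, 2): concrete ints, no symbolic shape inference needed
```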

Aside from providing exact shapes upfront, it can also simplify a lot of the shape inference and remove the need for some graph elements (e.g. Shape, Rebroadcast, and similar Ops).

More immediately, we simplify our Op implementations, which currently have no consistent means of modeling fixed-size/shape terms. For instance, we're forced to use vararg inputs, or Op-level constants that specify the number of inputs/dimensions (e.g. Op.ndim), in order to model shape-like arguments, but those approaches quickly become cumbersome, or even inapplicable, in the face of multiple such inputs. Join, Split, MakeVector, and Alloc are all good examples. With these changes, it would be possible to use a single fixed-length/shape input in these Ops (e.g. MakeVector's *inputs could be a single input, if that input were fixed-length). (Note that this could also be addressed with support for a tuple-like type.)
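To illustrate, continuing the assumed post-PR constructor above, an Op could read a shape-like argument's length from its Type alone, with no varargs or Op-level constants:

```python
from aesara.tensor.type import TensorType

# A fixed-length integer vector: the kind of single input that could replace
# MakeVector's vararg `*inputs` or an Op-level `ndim` constant.
size = TensorType("int64", shape=(2,))("size")

length = size.type.shape[0]  # statically known from the Type alone: 2
```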

For the same reasons, to fix #608 we would need to change RandomVariable so that it carries around the concrete length of its size input, just like those other Ops; however, we can address this—and much more—with the changes here.

There are also a few other related interfaces (e.g. #93) that can be improved along the way, like TensorConstant.shape, which can be made consistent with its premise of modeling a known constant and—thus—a known shape.

@brandonwillard force-pushed the add-fixed-shape-type branch 4 times, most recently from dbc1c05 to fe03603 on December 31, 2021 19:01
@kc611 mentioned this pull request on Jan 1, 2022
@brandonwillard force-pushed the add-fixed-shape-type branch 7 times, most recently from e11a9be to 7a9791d on January 5, 2022 00:01
After fixing `Shape`'s broadcastable information,
`local_subtensor_remove_broadcastable_index` started replacing all forms of
`shape(x)[0]` with `shape(x).dimshuffle(())` when `shape(x).broadcastable
== (True,)`, so a lot of the changes in this commit compensate for that
difference.

These changes enforce a strict "narrowing"-only conversion policy; i.e. `Type`s
can be converted to equal or more specific types.

`TensorType` now supports full shape information.  The field
`TensorType.broadcastable` has been replaced by `TensorType.shape`, which
corresponds to the available static shape information.  `None`s are used to
encode unknown/no static shape information for a dimension.
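A usage sketch of the encoding described in the last commit message above, again assuming the final constructor accepts the static shape directly:

```python
from aesara.tensor.type import TensorType

# `None` encodes a dimension with no static shape information.
t = TensorType("float64", shape=(None, 1, None))

print(t.shape)          # (None, 1, None)
print(t.broadcastable)  # (False, True, False), derived from `shape`
```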
@brandonwillard merged commit 150add2 into aesara-devs:main on Jan 13, 2022
@brandonwillard deleted the add-fixed-shape-type branch on January 13, 2022 03:43
twiecki (Contributor) commented Jan 13, 2022

Nice, this is a big one 🥳.

twiecki (Contributor) commented Jan 13, 2022

I know you don't like to talk about JAX ;) but does that help at all with omnistaging?

brandonwillard (Member, Author)

> I know you don't like to talk about JAX ;) but does that help at all with omnistaging?

It actually could, but there are still a few follow-ups needed in order to consistently propagate this new static shape information. I'll try to provide those in the next couple of days.

rlouf added a commit to rlouf/aehmc that referenced this pull request Jan 14, 2022:

aesara-devs/aesara#711 addressed the warning's underlying issue. The warning is not raised anymore, so we re-ran the notebook to have a cleaner output.

rlouf added a commit to aesara-devs/aehmc that referenced this pull request Jan 14, 2022, with the same message.
@brandonwillard linked an issue on Jan 14, 2022 that may be closed by this pull request
rlouf added a commit to rlouf/aehmc that referenced this pull request Jan 20, 2022, with the same message.
ricardoV94 added a commit to ricardoV94/aesara that referenced this pull request May 28, 2022:

This change was introduced in 9a45333 and did not include specific tests. It was likely introduced to cope with the old restrictions regarding rewrite substitution of variables with different static broadcastable shape information, which was alleviated in aesara-devs#711.

ricardoV94 added commits with the same message to ricardoV94/aesara on May 28, Jun 7, and Jul 7, 2022, and brandonwillard pushed a commit with the same message on Jul 7, 2022.
Labels: enhancement, important, refactor, shape inference
Projects: None yet

Development

Successfully merging this pull request may close these issues:

  • Clarify Type interface and/or its documentation
  • Create a fixed-length/shape TensorType and TensorVariable
4 participants