Skip to content

Commit e076916

Browse files
committed
[GR-10853] add basic superset / subset support for dictview
PullRequest: graalpython/118
2 parents c54567c + 42b93a4 commit e076916

File tree

2 files changed

+216
-0
lines changed

2 files changed

+216
-0
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_dict.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,53 @@ def __hash__(self):
373373
d[x] = 42
374374
x.fail = True
375375
assert_raises(Exc, d.setdefault, x, [])
376+
377+
378+
def test_keys_contained():
379+
helper_keys_contained(lambda x: x.keys())
380+
helper_keys_contained(lambda x: x.items())
381+
382+
383+
def helper_keys_contained(fn):
384+
# Test rich comparisons against dict key views, which should behave the
385+
# same as sets.
386+
empty = fn(dict())
387+
empty2 = fn(dict())
388+
smaller = fn({1: 1, 2: 2})
389+
larger = fn({1: 1, 2: 2, 3: 3})
390+
larger2 = fn({1: 1, 2: 2, 3: 3})
391+
larger3 = fn({4: 1, 2: 2, 3: 3})
392+
393+
assert smaller < larger
394+
assert smaller <= larger
395+
assert larger > smaller
396+
assert larger >= smaller
397+
398+
assert not smaller >= larger
399+
assert not smaller > larger
400+
assert not larger <= smaller
401+
assert not larger < smaller
402+
403+
assert not smaller < larger3
404+
assert not smaller <= larger3
405+
assert not larger3 > smaller
406+
assert not larger3 >= smaller
407+
408+
# Inequality strictness
409+
assert larger2 >= larger
410+
assert larger2 <= larger
411+
assert not larger2 > larger
412+
assert not larger2 < larger
413+
414+
assert larger == larger2
415+
assert smaller != larger
416+
417+
# There is an optimization on the zero-element case.
418+
assert empty == empty2
419+
assert not empty != empty2
420+
assert not empty == smaller
421+
assert empty != smaller
422+
423+
# With the same size, an elementwise compare happens
424+
assert larger != larger3
425+
assert not larger == larger3

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DictViewBuiltins.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,13 @@
4141
import static com.oracle.graal.python.nodes.SpecialMethodNames.__AND__;
4242
import static com.oracle.graal.python.nodes.SpecialMethodNames.__CONTAINS__;
4343
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
44+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GE__;
45+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GT__;
4446
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ITER__;
4547
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LEN__;
48+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LE__;
49+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__LT__;
50+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NE__;
4651
import static com.oracle.graal.python.nodes.SpecialMethodNames.__OR__;
4752
import static com.oracle.graal.python.nodes.SpecialMethodNames.__SUB__;
4853
import static com.oracle.graal.python.nodes.SpecialMethodNames.__XOR__;
@@ -65,6 +70,7 @@
6570
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
6671
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
6772
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
73+
import com.oracle.truffle.api.CompilerDirectives;
6874
import com.oracle.truffle.api.dsl.Cached;
6975
import com.oracle.truffle.api.dsl.Fallback;
7076
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -177,6 +183,36 @@ Object doGeneric(Object self, Object other) {
177183
}
178184
}
179185

186+
@Builtin(name = __NE__, fixedNumOfArguments = 2)
187+
@GenerateNodeFactory
188+
public abstract static class NeNode extends PythonBinaryBuiltinNode {
189+
@Child EqNode eqNode;
190+
191+
private EqNode getEqNode() {
192+
if (eqNode == null) {
193+
CompilerDirectives.transferToInterpreterAndInvalidate();
194+
eqNode = insert(DictViewBuiltinsFactory.EqNodeFactory.create());
195+
}
196+
return eqNode;
197+
}
198+
199+
@Specialization
200+
public boolean notEqual(PDictView self, PDictView other) {
201+
return !(Boolean) getEqNode().execute(self, other);
202+
}
203+
204+
@Specialization
205+
public boolean notEqual(PDictView self, PBaseSet other) {
206+
return !(Boolean) getEqNode().execute(self, other);
207+
}
208+
209+
@Fallback
210+
@SuppressWarnings("unused")
211+
Object doGeneric(Object self, Object other) {
212+
return PNotImplemented.NOT_IMPLEMENTED;
213+
}
214+
}
215+
180216
@Builtin(name = __SUB__, fixedNumOfArguments = 2)
181217
@GenerateNodeFactory
182218
abstract static class SubNode extends PythonBinaryBuiltinNode {
@@ -316,4 +352,134 @@ PBaseSet doItemsView(PDictItemsView self, PDictItemsView other,
316352
return factory().createSet(xorNode.execute(selfSet.getDictStorage(), otherSet.getDictStorage()));
317353
}
318354
}
355+
356+
@Builtin(name = __LE__, fixedNumOfArguments = 2)
357+
@GenerateNodeFactory
358+
abstract static class LessEqualNode extends PythonBinaryBuiltinNode {
359+
@Specialization
360+
boolean lessEqual(PDictKeysView self, PBaseSet other,
361+
@Cached("create()") HashingStorageNodes.KeysIsSubsetNode isSubsetNode) {
362+
return isSubsetNode.execute(self.getDict().getDictStorage(), other.getDictStorage());
363+
}
364+
365+
@Specialization
366+
boolean lessEqual(PDictKeysView self, PDictKeysView other,
367+
@Cached("create()") HashingStorageNodes.KeysIsSubsetNode isSubsetNode) {
368+
return isSubsetNode.execute(self.getDict().getDictStorage(), other.getDict().getDictStorage());
369+
}
370+
371+
@Specialization
372+
boolean lessEqual(PDictItemsView self, PBaseSet other,
373+
@Cached("create()") HashingStorageNodes.KeysIsSubsetNode isSubsetNode,
374+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
375+
PSet selfSet = constructSetNode.executeWith(self);
376+
return isSubsetNode.execute(selfSet.getDictStorage(), other.getDictStorage());
377+
}
378+
379+
@Specialization
380+
boolean lessEqual(PDictItemsView self, PDictItemsView other,
381+
@Cached("create()") HashingStorageNodes.KeysIsSubsetNode isSubsetNode,
382+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
383+
PSet selfSet = constructSetNode.executeWith(self);
384+
PSet otherSet = constructSetNode.executeWith(other);
385+
return isSubsetNode.execute(selfSet.getDictStorage(), otherSet.getDictStorage());
386+
}
387+
}
388+
389+
@Builtin(name = __GE__, fixedNumOfArguments = 2)
390+
@GenerateNodeFactory
391+
abstract static class GreaterEqualNode extends PythonBinaryBuiltinNode {
392+
@Specialization
393+
boolean greaterEqual(PDictKeysView self, PBaseSet other,
394+
@Cached("create()") HashingStorageNodes.KeysIsSupersetNode isSupersetNode) {
395+
return isSupersetNode.execute(self.getDict().getDictStorage(), other.getDictStorage());
396+
}
397+
398+
@Specialization
399+
boolean greaterEqual(PDictKeysView self, PDictKeysView other,
400+
@Cached("create()") HashingStorageNodes.KeysIsSupersetNode isSupersetNode) {
401+
return isSupersetNode.execute(self.getDict().getDictStorage(), other.getDict().getDictStorage());
402+
}
403+
404+
@Specialization
405+
boolean greaterEqual(PDictItemsView self, PBaseSet other,
406+
@Cached("create()") HashingStorageNodes.KeysIsSupersetNode isSupersetNode,
407+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
408+
PSet selfSet = constructSetNode.executeWith(self);
409+
return isSupersetNode.execute(selfSet.getDictStorage(), other.getDictStorage());
410+
}
411+
412+
@Specialization
413+
boolean greaterEqual(PDictItemsView self, PDictItemsView other,
414+
@Cached("create()") HashingStorageNodes.KeysIsSupersetNode isSupersetNode,
415+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
416+
PSet selfSet = constructSetNode.executeWith(self);
417+
PSet otherSet = constructSetNode.executeWith(other);
418+
return isSupersetNode.execute(selfSet.getDictStorage(), otherSet.getDictStorage());
419+
}
420+
}
421+
422+
@Builtin(name = __LT__, fixedNumOfArguments = 2)
423+
@GenerateNodeFactory
424+
abstract static class LessThanNode extends PythonBinaryBuiltinNode {
425+
@Child LessEqualNode lessEqualNode;
426+
427+
private LessEqualNode getLessEqualNode() {
428+
if (lessEqualNode == null) {
429+
CompilerDirectives.transferToInterpreterAndInvalidate();
430+
lessEqualNode = insert(DictViewBuiltinsFactory.LessEqualNodeFactory.create());
431+
}
432+
return lessEqualNode;
433+
}
434+
435+
@Specialization
436+
boolean isLessThan(PDictView self, PBaseSet other,
437+
@Cached("createBinaryProfile()") ConditionProfile sizeProfile) {
438+
if (sizeProfile.profile(self.size() >= other.size())) {
439+
return false;
440+
}
441+
return (Boolean) getLessEqualNode().execute(self, other);
442+
}
443+
444+
@Specialization
445+
boolean isLessThan(PDictView self, PDictView other,
446+
@Cached("createBinaryProfile()") ConditionProfile sizeProfile) {
447+
if (sizeProfile.profile(self.size() >= other.size())) {
448+
return false;
449+
}
450+
return (Boolean) getLessEqualNode().execute(self, other);
451+
}
452+
}
453+
454+
@Builtin(name = __GT__, fixedNumOfArguments = 2)
455+
@GenerateNodeFactory
456+
abstract static class GreaterThanNode extends PythonBinaryBuiltinNode {
457+
@Child GreaterEqualNode greaterEqualNode;
458+
459+
private GreaterEqualNode getGreaterEqualNode() {
460+
if (greaterEqualNode == null) {
461+
CompilerDirectives.transferToInterpreterAndInvalidate();
462+
greaterEqualNode = insert(DictViewBuiltinsFactory.GreaterEqualNodeFactory.create());
463+
}
464+
return greaterEqualNode;
465+
}
466+
467+
@Specialization
468+
boolean isGreaterThan(PDictView self, PBaseSet other,
469+
@Cached("createBinaryProfile()") ConditionProfile sizeProfile) {
470+
if (sizeProfile.profile(self.size() <= other.size())) {
471+
return false;
472+
}
473+
return (Boolean) getGreaterEqualNode().execute(self, other);
474+
}
475+
476+
@Specialization
477+
boolean isGreaterThan(PDictView self, PDictView other,
478+
@Cached("createBinaryProfile()") ConditionProfile sizeProfile) {
479+
if (sizeProfile.profile(self.size() <= other.size())) {
480+
return false;
481+
}
482+
return (Boolean) getGreaterEqualNode().execute(self, other);
483+
}
484+
}
319485
}

0 commit comments

Comments
 (0)