Skip to content

Commit e0aad72

Browse files
committed
Ak/backport rework float round
PullRequest: truffleruby/4467
2 parents 21efe56 + 261f4d2 commit e0aad72

File tree

5 files changed

+151
-267
lines changed

5 files changed

+151
-267
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Compatibility:
6060
* Fix `Integer#ceil` when self is 0 (@andrykonchin).
6161
* Fix `Module#remove_const` and emit warning when constant is deprecated (@andrykonchin).
6262
* Add `Module#set_temporary_name` (#3681, @andrykonchin).
63+
* Modify `Float#round` to match MRI behavior (#3676, @andrykonchin).
6364

6465
Performance:
6566

spec/ruby/core/float/round_spec.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
12.345678.round(3.999).should == 12.346
3131
end
3232

33+
it "correctly rounds exact floats with a numerous digits in a fraction part" do
34+
0.8241000000000004.round(10).should == 0.8241
35+
0.8241000000000002.round(10).should == 0.8241
36+
end
37+
3338
it "returns zero when passed a negative argument with magnitude greater than magnitude of the whole number portion of the Float" do
3439
0.8346268.round(-1).should eql(0)
3540
end
@@ -68,6 +73,10 @@
6873
0.42.round(2.0**30).should == 0.42
6974
end
7075

76+
it "returns rounded values for not so big argument" do
77+
0.42.round(2.0**23).should == 0.42
78+
end
79+
7180
it "returns big values rounded to nearest" do
7281
+2.5e20.round(-20).should eql( +3 * 10 ** 20 )
7382
-2.5e20.round(-20).should eql( -3 * 10 ** 20 )

src/main/java/org/truffleruby/core/numeric/FloatNodes.java

Lines changed: 0 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
3131
import com.oracle.truffle.api.strings.TruffleString;
3232
import org.truffleruby.annotations.Split;
33-
import org.truffleruby.annotations.SuppressFBWarnings;
3433
import org.truffleruby.annotations.CoreMethod;
3534
import org.truffleruby.builtins.CoreMethodArrayArgumentsNode;
3635
import org.truffleruby.annotations.CoreModule;
@@ -623,228 +622,6 @@ public abstract static class FloatFloorNDigitsPrimitiveNode extends PrimitiveArr
623622

624623
}
625624

626-
@ImportStatic(FloatRoundGuards.class)
627-
@Primitive(name = "float_round_up")
628-
public abstract static class FloatRoundUpPrimitiveNode extends PrimitiveArrayArgumentsNode {
629-
630-
@Specialization(guards = "fitsInInteger(n)")
631-
int roundFittingInt(double n) {
632-
int l = (int) n;
633-
int signum = (int) Math.signum(n);
634-
double d = Math.abs(n - l);
635-
if (d >= 0.5) {
636-
l += signum;
637-
}
638-
return l;
639-
}
640-
641-
@Specialization(guards = "fitsInLong(n)", replaces = "roundFittingInt")
642-
long roundFittingLong(double n) {
643-
long l = (long) n;
644-
long signum = (long) Math.signum(n);
645-
double d = Math.abs(n - l);
646-
if (d >= 0.5) {
647-
l += signum;
648-
}
649-
return l;
650-
}
651-
652-
@Specialization(replaces = "roundFittingLong")
653-
Object round(double n,
654-
@Cached FloatToIntegerNode floatToIntegerNode) {
655-
double signum = Math.signum(n);
656-
double f = Math.floor(Math.abs(n));
657-
double d = Math.abs(n) - f;
658-
if (d >= 0.5) {
659-
f += 1;
660-
}
661-
return floatToIntegerNode.execute(this, f * signum);
662-
}
663-
}
664-
665-
@ImportStatic(FloatRoundGuards.class)
666-
@Primitive(name = "float_round_up_decimal", lowerFixnum = 1)
667-
public abstract static class FloatRoundUpDecimalPrimitiveNode extends PrimitiveArrayArgumentsNode {
668-
669-
@Specialization
670-
double roundNDecimal(double n, int ndigits,
671-
@Cached InlinedConditionProfile boundaryCase) {
672-
long intPart = (long) n;
673-
double s = Math.pow(10.0, ndigits) * Math.signum(n);
674-
double f = (n % 1) * s;
675-
long fInt = (long) f;
676-
double d = f % 1;
677-
int limit = Math.getExponent(n) + Math.getExponent(s) - 51;
678-
if (boundaryCase.profile(this, (Math.getExponent(d) <= limit) ||
679-
(Math.getExponent(1.0 - d) <= limit))) {
680-
return findClosest(n, s, d);
681-
} else if (d > 0.5 || Math.abs(n) - Math.abs((intPart + (fInt + 0.5) / s)) >= 0) {
682-
fInt += 1;
683-
}
684-
return intPart + fInt / s;
685-
}
686-
}
687-
688-
/* If the rounding result is very near to an integer boundary then we need to find the number that is closest to the
689-
* correct result. If we don't do this then it's possible to get errors in the least significant bit of the result.
690-
* We'll test the adjacent double in the direction closest to the boundary and compare the fractional portions. If
691-
* we're already at the minimum error we'll return the original number as it is already rounded as well as it can
692-
* be. In the case of a tie we return the lower number, otherwise we check the go round again. */
693-
private static double findClosest(double n, double s, double d) {
694-
double n2;
695-
while (true) {
696-
if (d > 0.5) {
697-
n2 = Math.nextAfter(n, n + s);
698-
} else {
699-
n2 = Math.nextAfter(n, n - s);
700-
}
701-
double f = (n2 % 1) * s;
702-
double d2 = f % 1;
703-
if (((d > 0.5) ? 1 - d : d) < ((d2 > 0.5) ? 1 - d2 : d2)) {
704-
return n;
705-
} else if (((d > 0.5) ? 1 - d : d) == ((d2 > 0.5) ? 1 - d2 : d2)) {
706-
return Math.abs(n) < Math.abs(n2) ? n : n2;
707-
} else {
708-
d = d2;
709-
n = n2;
710-
}
711-
}
712-
}
713-
714-
@SuppressFBWarnings("FE_FLOATING_POINT_EQUALITY")
715-
@ImportStatic(FloatRoundGuards.class)
716-
@Primitive(name = "float_round_even")
717-
public abstract static class FloatRoundEvenPrimitiveNode extends PrimitiveArrayArgumentsNode {
718-
719-
@Specialization(guards = { "fitsInInteger(n)" })
720-
int roundFittingInt(double n) {
721-
int l = (int) n;
722-
int signum = (int) Math.signum(n);
723-
double d = Math.abs(n - l);
724-
if (d > 0.5) {
725-
l += signum;
726-
} else if (d == 0.5) {
727-
l += l % 2;
728-
}
729-
return l;
730-
}
731-
732-
@Specialization(guards = "fitsInLong(n)", replaces = "roundFittingInt")
733-
long roundFittingLong(double n) {
734-
long l = (long) n;
735-
long signum = (long) Math.signum(n);
736-
double d = Math.abs(n - l);
737-
if (d > 0.5) {
738-
l += signum;
739-
} else if (d == 0.5) {
740-
l += l % 2;
741-
}
742-
return l;
743-
}
744-
745-
@Specialization(replaces = "roundFittingLong")
746-
Object round(double n,
747-
@Cached FloatToIntegerNode floatToIntegerNode) {
748-
double signum = Math.signum(n);
749-
double f = Math.floor(Math.abs(n));
750-
double d = Math.abs(n) - f;
751-
if (d > 0.5) {
752-
f += signum;
753-
} else if (d == 0.5) {
754-
f += f % 2;
755-
}
756-
return floatToIntegerNode.execute(this, f * signum);
757-
}
758-
}
759-
760-
@SuppressFBWarnings("FE_FLOATING_POINT_EQUALITY")
761-
@ImportStatic(FloatRoundGuards.class)
762-
@Primitive(name = "float_round_even_decimal", lowerFixnum = 1)
763-
public abstract static class FloatRoundEvenDecimalPrimitiveNode extends PrimitiveArrayArgumentsNode {
764-
765-
@Specialization
766-
double roundNDecimal(double n, int ndigits,
767-
@Cached InlinedConditionProfile boundaryCase) {
768-
long intPart = (long) n;
769-
double s = Math.pow(10.0, ndigits) * Math.signum(n);
770-
double f = (n % 1) * s;
771-
long fInt = (long) f;
772-
double d = f % 1;
773-
int limit = Math.getExponent(n) + Math.getExponent(s) - 51;
774-
if (boundaryCase.profile(this, (Math.getExponent(d) <= limit) ||
775-
(Math.getExponent(1.0 - d) <= limit))) {
776-
return findClosest(n, s, d);
777-
} else if (d > 0.5) {
778-
fInt += 1;
779-
} else if (d == 0.5 || Math.abs(n) - Math.abs((intPart + (fInt + 0.5) / s)) >= 0) {
780-
fInt += fInt % 2;
781-
}
782-
return intPart + fInt / s;
783-
}
784-
}
785-
786-
@ImportStatic(FloatRoundGuards.class)
787-
@Primitive(name = "float_round_down")
788-
public abstract static class FloatRoundDownPrimitiveNode extends PrimitiveArrayArgumentsNode {
789-
790-
@Specialization(guards = "fitsInInteger(n)")
791-
int roundFittingInt(double n) {
792-
int l = (int) n;
793-
int signum = (int) Math.signum(n);
794-
double d = Math.abs(n - l);
795-
if (d > 0.5) {
796-
l += signum;
797-
}
798-
return l;
799-
}
800-
801-
@Specialization(guards = "fitsInLong(n)", replaces = "roundFittingInt")
802-
long roundFittingLong(double n) {
803-
long l = (long) n;
804-
long signum = (long) Math.signum(n);
805-
double d = Math.abs(n - l);
806-
if (d > 0.5) {
807-
l += signum;
808-
}
809-
return l;
810-
}
811-
812-
@Specialization(replaces = "roundFittingLong")
813-
Object round(double n,
814-
@Cached FloatToIntegerNode floatToIntegerNode) {
815-
double signum = Math.signum(n);
816-
double f = Math.floor(Math.abs(n));
817-
double d = Math.abs(n) - f;
818-
if (d > 0.5) {
819-
f += 1;
820-
}
821-
return floatToIntegerNode.execute(this, f * signum);
822-
}
823-
}
824-
825-
@ImportStatic(FloatRoundGuards.class)
826-
@Primitive(name = "float_round_down_decimal", lowerFixnum = 1)
827-
public abstract static class FloatRoundDownDecimalPrimitiveNode extends PrimitiveArrayArgumentsNode {
828-
829-
@Specialization
830-
double roundNDecimal(double n, int ndigits,
831-
@Cached InlinedConditionProfile boundaryCase) {
832-
long intPart = (long) n;
833-
double s = Math.pow(10.0, ndigits) * Math.signum(n);
834-
double f = (n % 1) * s;
835-
long fInt = (long) f;
836-
double d = f % 1;
837-
int limit = Math.getExponent(n) + Math.getExponent(s) - 51;
838-
if (boundaryCase.profile(this, (Math.getExponent(d) <= limit) ||
839-
(Math.getExponent(1.0 - d) <= limit))) {
840-
return findClosest(n, s, d);
841-
} else if (d > 0.5 && Math.abs(n) - Math.abs((intPart + (fInt + 0.5) / s)) > 0) {
842-
fInt += 1;
843-
}
844-
return intPart + fInt / s;
845-
}
846-
}
847-
848625
@Primitive(name = "float_exp")
849626
public abstract static class FloatExpNode extends PrimitiveArrayArgumentsNode {
850627

src/main/ruby/truffleruby/core/float.rb

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -193,42 +193,44 @@ def floor(ndigits = undefined)
193193
end
194194
end
195195

196-
def round(ndigits = undefined, half: nil)
197-
ndigits = if Primitive.undefined?(ndigits)
198-
nil
199-
else
200-
Truffle::Type.coerce_to(ndigits, Integer, :to_int)
201-
end
196+
def round(ndigits = 0, half: :up)
197+
ndigits = Truffle::Type.coerce_to(ndigits, Integer, :to_int)
198+
202199
if self == 0.0
203-
return ndigits && ndigits > 0 ? self : 0
200+
return ndigits > 0 ? self : 0
201+
end
202+
203+
half = :up if Primitive.nil?(half)
204+
if half != :up && half != :down && half != :even
205+
raise ArgumentError, "invalid rounding mode: #{half}"
204206
end
205-
if Primitive.nil?(ndigits)
206-
if infinite?
207+
208+
if ndigits <= 0
209+
if self.infinite?
207210
raise FloatDomainError, 'Infinite'
208-
elsif nan?
211+
elsif self.nan?
209212
raise FloatDomainError, 'NaN'
210-
else
211-
case half
212-
when nil, :up
213-
Primitive.float_round_up(self)
214-
when :even
215-
Primitive.float_round_even(self)
216-
when :down
217-
Primitive.float_round_down(self)
218-
else
219-
raise ArgumentError, "invalid rounding mode: #{half}"
220-
end
221213
end
222-
else
223-
if ndigits == 0
224-
round(half: half)
225-
elsif ndigits < 0
226-
to_i.round(ndigits, :half => half)
227-
elsif infinite? or nan?
214+
end
215+
216+
if ndigits < 0
217+
to_i.round(ndigits, half: half)
218+
elsif ndigits == 0
219+
Truffle::FloatOperations.round_to_n_place(self, ndigits, half)
220+
elsif !infinite? && !nan?
221+
exponent = Primitive.float_exp(self)
222+
223+
if Truffle::FloatOperations.round_overflow?(ndigits, exponent)
228224
self
225+
elsif Truffle::FloatOperations.round_overflow?(ndigits, exponent)
226+
0.0
227+
elsif ndigits > 14
228+
to_r.round(ndigits, half: half).to_f
229229
else
230230
Truffle::FloatOperations.round_to_n_place(self, ndigits, half)
231231
end
232+
else
233+
self # Infinity or NaN
232234
end
233235
end
234236

0 commit comments

Comments
 (0)