Skip to content

Commit 7316d58

Browse files
committed
Add import, math module
1 parent a9fa6ad commit 7316d58

File tree

4 files changed

+74
-16
lines changed

4 files changed

+74
-16
lines changed

auto_editor/lang/libmath.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import annotations
2+
3+
import math
4+
5+
from auto_editor.lib.contracts import Proc, andc, gt_c, is_real, between_c
6+
7+
8+
def all() -> dict[str, object]:
9+
return {
10+
"exp": Proc("exp", math.exp, (1, 1), is_real),
11+
"ceil": Proc("ceil", math.ceil, (1, 1), is_real),
12+
"floor": Proc("floor", math.floor, (1, 1), is_real),
13+
"sin": Proc("sin", math.sin, (1, 1), is_real),
14+
"cos": Proc("cos", math.cos, (1, 1), is_real),
15+
"tan": Proc("tan", math.tan, (1, 1), is_real),
16+
"asin": Proc("asin", math.asin, (1, 1), between_c(-1, 1)),
17+
"acos": Proc("acos", math.acos, (1, 1), between_c(-1, 1)),
18+
"atan": Proc("atan", math.atan, (1, 1), is_real),
19+
"log": Proc("log", math.log, (1, 2), andc(is_real, gt_c(0))),
20+
"pi": math.pi,
21+
"e": math.e,
22+
"tau": math.tau,
23+
}

auto_editor/lang/palet.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
from __future__ import annotations
88

9-
import cmath
10-
import math
9+
from cmath import sqrt as complex_sqrt
1110
from dataclasses import dataclass
1211
from difflib import get_close_matches
1312
from fractions import Fraction
@@ -549,7 +548,7 @@ def int_div(n: int, *m: int) -> int:
549548

550549

551550
def _sqrt(v: Number) -> Number:
552-
r = cmath.sqrt(v)
551+
r = complex_sqrt(v)
553552
if r.imag == 0:
554553
if int(r.real) == r.real:
555554
return int(r.real)
@@ -1396,6 +1395,25 @@ def syn_let_star(env: Env, node: Node) -> Any:
13961395
return my_eval(inner_env, node[-1])
13971396

13981397

1398+
def syn_import(env: Env, node: Node) -> None:
1399+
guard_term(node, 2, 2)
1400+
1401+
if type(node[1]) is not Sym:
1402+
raise MyError("class name must be an identifier")
1403+
1404+
module = node[1].val
1405+
error = MyError(f"No module named `{module}`")
1406+
1407+
if module != "math":
1408+
raise error
1409+
try:
1410+
obj = __import__("auto_editor.lang.libmath", fromlist=["lang"])
1411+
except ImportError:
1412+
raise error
1413+
1414+
env.update(obj.all())
1415+
1416+
13991417
def syn_class(env: Env, node: Node) -> None:
14001418
if len(node) < 2:
14011419
raise MyError(f"{node[0]}: Expects at least 1 term")
@@ -1544,6 +1562,7 @@ def my_eval(env: Env, node: object) -> Any:
15441562
"case": Syntax(syn_case),
15451563
"let": Syntax(syn_let),
15461564
"let*": Syntax(syn_let_star),
1565+
"import": Syntax(syn_import),
15471566
"class": Syntax(syn_class),
15481567
"@r": Syntax(attr),
15491568
# loops
@@ -1615,17 +1634,10 @@ def my_eval(env: Env, node: object) -> Any:
16151634
"imag-part": Proc("imag-part", lambda v: v.imag, (1, 1), is_num),
16161635
# reals
16171636
"pow": Proc("pow", pow, (2, 2), is_real),
1618-
"exp": Proc("exp", math.exp, (1, 1), is_real),
16191637
"abs": Proc("abs", abs, (1, 1), is_real),
1620-
"ceil": Proc("ceil", math.ceil, (1, 1), is_real),
1621-
"floor": Proc("floor", math.floor, (1, 1), is_real),
16221638
"round": Proc("round", round, (1, 1), is_real),
16231639
"max": Proc("max", lambda *v: max(v), (1, None), is_real),
16241640
"min": Proc("min", lambda *v: min(v), (1, None), is_real),
1625-
"sin": Proc("sin", math.sin, (1, 1), is_real),
1626-
"cos": Proc("cos", math.cos, (1, 1), is_real),
1627-
"log": Proc("log", math.log, (1, 2), andc(is_real, gt_c(0))),
1628-
"tan": Proc("tan", math.tan, (1, 1), is_real),
16291641
"mod": Proc("mod", mod, (2, 2), is_int),
16301642
"modulo": Proc("modulo", mod, (2, 2), is_int),
16311643
# symbols

auto_editor/subcommands/test.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -577,12 +577,6 @@ def cases(*cases: tuple[str, Any]) -> None:
577577
("(pow 4 0.5)", 2.0),
578578
("(abs 1.0)", 1.0),
579579
("(abs -1)", 1),
580-
("(round 3.5)", 4),
581-
("(round 2.5)", 2),
582-
("(ceil 2.1)", 3),
583-
("(ceil 2.9)", 3),
584-
("(floor 2.1)", 2),
585-
("(floor 2.9)", 2),
586580
("(bool? #t)", True),
587581
("(bool? #f)", True),
588582
("(bool? 0)", False),
@@ -693,6 +687,7 @@ def palet_scripts():
693687
run.raw(["palet", "resources/scripts/maxcut.pal"])
694688
run.raw(["palet", "resources/scripts/scope.pal"])
695689
run.raw(["palet", "resources/scripts/case.pal"])
690+
run.raw(["palet", "resources/scripts/testmath.pal"])
696691

697692
tests = []
698693

resources/scripts/testmath.pal

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/usr/bin/env auto-editor palet
2+
#lang palet
3+
4+
(import math)
5+
6+
(assert (equal? (round 3.5) 4))
7+
(assert (equal? (round 2.5) 2))
8+
(assert (equal? (ceil 2.1) 3))
9+
(assert (equal? (ceil 2.9) 3))
10+
(assert (equal? (floor 2.1) 2))
11+
(assert (equal? (floor 2.9) 2))
12+
13+
(assert (equal? (sin 0) 0.0))
14+
(assert (equal? (sin 0/1) 0.0))
15+
(assert (equal? (sin (/ pi 2)) 1.0))
16+
17+
(assert (equal? (cos 0) 1.0))
18+
(assert (equal? (cos (* pi 2)) 1.0))
19+
(assert (equal? (cos pi) -1.0))
20+
(assert (equal? (cos tau) 1.0))
21+
22+
(assert (equal? (asin 0) 0.0))
23+
(assert (equal? (asin 0/1) 0.0))
24+
(assert (equal? (acos 1) 0.0))
25+
(assert (equal? (acos -1) pi))
26+
27+
(assert (equal? (log 1) 0.0))
28+
(assert (equal? (log e) 1.0))

0 commit comments

Comments
 (0)