Skip to content

Commit 9a0edc6

Browse files
committed
Add audio-levels variable, make map work with arrays
1 parent b06145b commit 9a0edc6

File tree

4 files changed

+36
-19
lines changed

4 files changed

+36
-19
lines changed

auto_editor/analyze.py

-4
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,6 @@ def link_nodes(*nodes: FilterContext) -> None:
9090
c.link_to(n)
9191

9292

93-
def to_threshold(arr: np.ndarray, t: int | float) -> NDArray[np.bool_]:
94-
return np.fromiter((x >= t for x in arr), dtype=np.bool_)
95-
96-
9793
def mut_remove_small(
9894
arr: NDArray[np.bool_], lim: int, replace: int, with_: int
9995
) -> None:

auto_editor/lang/palet.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
import numpy as np
1919
from numpy import logical_and, logical_not, logical_or, logical_xor
2020

21-
from auto_editor.analyze import (
22-
LevelError,
23-
mut_remove_large,
24-
mut_remove_small,
25-
to_threshold,
26-
)
21+
from auto_editor.analyze import LevelError, mut_remove_large, mut_remove_small
2722
from auto_editor.lib.contracts import *
2823
from auto_editor.lib.data_structs import *
2924
from auto_editor.lib.err import MyError
@@ -690,6 +685,9 @@ def palet_map(proc: Proc, seq: Any) -> Any:
690685
return Quoted(tuple(map(proc, seq.val)))
691686
if isinstance(seq, list | range):
692687
return list(map(proc, seq))
688+
elif isinstance(seq, np.ndarray):
689+
vectorized_proc = np.vectorize(proc)
690+
return vectorized_proc(seq)
693691
return proc(seq)
694692

695693

@@ -1469,6 +1467,16 @@ def edit_all() -> np.ndarray:
14691467
return env["@levels"].all()
14701468

14711469

1470+
def audio_levels(stream: int) -> np.ndarray:
1471+
if "@levels" not in env:
1472+
raise MyError("Can't use `audio` if there's no input media")
1473+
1474+
try:
1475+
return env["@levels"].audio(stream)
1476+
except LevelError as e:
1477+
raise MyError(e)
1478+
1479+
14721480
def edit_audio(
14731481
threshold: float = 0.04,
14741482
stream: object = Sym("all"),
@@ -1491,7 +1499,7 @@ def edit_audio(
14911499

14921500
try:
14931501
for s in stream_range:
1494-
audio_list = to_threshold(levels.audio(s), threshold)
1502+
audio_list = levels.audio(s) >= threshold
14951503
if stream_data is None:
14961504
stream_data = audio_list
14971505
else:
@@ -1521,7 +1529,7 @@ def edit_motion(
15211529
levels = env["@levels"]
15221530
strict = env["@filesetup"].strict
15231531
try:
1524-
return to_threshold(levels.motion(stream, blur, width), threshold)
1532+
return levels.motion(stream, blur, width) >= threshold
15251533
except LevelError as e:
15261534
return raise_(e) if strict else levels.all()
15271535

@@ -1582,7 +1590,7 @@ def my_eval(env: Env, node: object) -> Any:
15821590
return ref(oper, my_eval(env, node[1]))
15831591

15841592
raise MyError(
1585-
f"Tried to run: {print_str(oper)} with args: {print_str(node[1:])}"
1593+
f"{print_str(oper)} is not a function. Tried to run with args: {print_str(node[1:])}"
15861594
)
15871595

15881596
if type(oper) is Syntax:
@@ -1617,6 +1625,7 @@ def my_eval(env: Env, node: object) -> Any:
16171625
# edit procedures
16181626
"none": Proc("none", edit_none, (0, 0)),
16191627
"all/e": Proc("all/e", edit_all, (0, 0)),
1628+
"audio-levels": Proc("audio-levels", audio_levels, (1, 1), is_nat),
16201629
"audio": Proc("audio", edit_audio, (0, 4),
16211630
is_threshold, orc(is_nat, Sym("all")), is_nat,
16221631
{"threshold": 0, "stream": 1, "minclip": 2, "mincut": 2}

auto_editor/lib/contracts.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from fractions import Fraction
66
from typing import Any
77

8+
from numpy import float64
9+
810
from .data_structs import Sym, print_str
911
from .err import MyError
1012

@@ -41,7 +43,7 @@ def check_contract(c: object, val: object) -> bool:
4143
return val is True
4244
if c is False:
4345
return val is False
44-
if type(c) in (int, float, Fraction, complex, str, Sym):
46+
if type(c) in (int, float, float64, Fraction, complex, str, Sym):
4547
return val == c
4648
raise MyError(f"Invalid contract, got: {print_str(c)}")
4749

@@ -163,17 +165,21 @@ def is_contract(c: object) -> bool:
163165
is_nat = Contract("nat?", lambda v: type(v) is int and v > -1)
164166
is_nat1 = Contract("nat1?", lambda v: type(v) is int and v > 0)
165167
int_not_zero = Contract("(or/c (not/c 0) int?)", lambda v: v != 0 and is_int(v))
166-
is_num = Contract("number?", lambda v: type(v) in (int, float, Fraction, complex))
167-
is_real = Contract("real?", lambda v: type(v) in (int, float, Fraction))
168-
is_float = Contract("float?", lambda v: type(v) is float)
168+
is_num = Contract(
169+
"number?", lambda v: type(v) in (int, float, float64, Fraction, complex)
170+
)
171+
is_real = Contract("real?", lambda v: type(v) in (int, float, float64, Fraction))
172+
is_float = Contract("float?", lambda v: type(v) in (float, float64))
169173
is_frac = Contract("frac?", lambda v: type(v) is Fraction)
170174
is_str = Contract("string?", lambda v: type(v) is str)
171175
any_p = Contract("any", lambda v: True)
172176
is_void = Contract("void?", lambda v: v is None)
173-
is_int_or_float = Contract("(or/c int? float?)", lambda v: type(v) in (int, float))
177+
is_int_or_float = Contract(
178+
"(or/c int? float?)", lambda v: type(v) in (int, float, float64)
179+
)
174180
is_threshold = Contract(
175181
"threshold?",
176-
lambda v: type(v) in (int, float) and v >= 0 and v <= 1, # type: ignore
182+
lambda v: type(v) in (int, float, float64) and v >= 0 and v <= 1, # type: ignore
177183
)
178184
is_proc = Contract("procedure?", lambda v: isinstance(v, Proc | Contract))
179185

resources/scripts/maxcut.pal

+6
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@
4646
(maxcut my-arr 3)
4747
(bool-array 1 0 1 0 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1)
4848
))
49+
50+
51+
(define (b x) (>= x 0.5))
52+
(define arr (array 'float64 0.1 0.2 0.3 0.6 0.7))
53+
54+
(assert (bool-array? (map b arr)))

0 commit comments

Comments
 (0)