Skip to content

Commit

Permalink
fixed bug and added test
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Jan 19, 2024
1 parent 4ddfb85 commit a7293c7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
22 changes: 22 additions & 0 deletions tests/test_trajectories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
def test_imports():
from trajectories import TrajectoryFactory

def test_square_trajectory():
from trajectories import TrajectoryFactory
from trajectories.trajectories import SquareLinearTrajectory
import numpy as np
trajectory: SquareLinearTrajectory = TrajectoryFactory.get_linear_square_trajectory()
p0 = trajectory.get_waypoint(0)
p1 = trajectory.get_waypoint(0.25)
p2 = trajectory.get_waypoint(0.5)
p3 = trajectory.get_waypoint(0.75)

# assert corner points
assert((p0.coordinate == trajectory.corner_points[0]).all())
assert((p1.coordinate == trajectory.corner_points[1]).all())
assert((p2.coordinate == trajectory.corner_points[2]).all())
assert((p3.coordinate == trajectory.corner_points[3]).all())

# assert interpolation
assert(trajectory.get_waypoint(0.1).coordinate == np.asarray([0.4, 0., 1.], dtype=np.float32)).all()

4 changes: 2 additions & 2 deletions trajectories/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def get_waypoint(self, time: float):
# in-between two points, linear interpolation
cur_corner = math.floor(time)
upcomimg_corner = math.ceil(time)
diff = upcomimg_corner - cur_corner
target_pos = cur_corner + (time - int(time)) * (diff)
diff = self.corner_points[upcomimg_corner] - self.corner_points[cur_corner]
target_pos = self.corner_points[cur_corner] + (time - int(time)) * (diff)

target_wp = Waypoint(
coordinate=target_pos,
Expand Down
2 changes: 1 addition & 1 deletion trajectories/trajectory_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class TrajectoryFactory:
"""

@classmethod
def get_linear_square_trajectory(cls, square_scale: float, time_scale: float) -> Trajectory:
def get_linear_square_trajectory(cls, square_scale: float=1, time_scale: float=1) -> Trajectory:
return SquareLinearTrajectory(
square_scale, time_scale
)
Expand Down

0 comments on commit a7293c7

Please sign in to comment.