@@ -18,6 +18,7 @@ module nf_activation
18
18
public :: softplus
19
19
public :: step
20
20
public :: tanhf
21
+ public :: celu
21
22
22
23
type, abstract :: activation_function
23
24
contains
@@ -140,6 +141,15 @@ end function eval_3d_i
140
141
procedure :: eval_3d_prime = > eval_3d_tanh_prime
141
142
end type tanhf
142
143
144
+ type, extends(activation_function) :: celu
145
+ real :: alpha = 1.0 ! Pytorch default
146
+ contains
147
+ procedure :: eval_1d = > eval_1d_celu
148
+ procedure :: eval_1d_prime = > eval_1d_celu_prime
149
+ procedure :: eval_3d = > eval_3d_celu
150
+ procedure :: eval_3d_prime = > eval_3d_celu_prime
151
+ end type celu
152
+
143
153
contains
144
154
145
155
pure function eval_1d_elu (self , x ) result(res)
@@ -522,6 +532,54 @@ pure function eval_3d_tanh_prime(self, x) result(res)
522
532
res = 1 - tanh (x)** 2
523
533
end function eval_3d_tanh_prime
524
534
535
+ pure function eval_1d_celu (self , x ) result(res)
536
+ ! Celu activation function.
537
+ class(celu), intent (in ) :: self
538
+ real , intent (in ) :: x(:)
539
+ real :: res(size (x))
540
+ where (x >= 0.0 )
541
+ res = x
542
+ else where
543
+ res = self % alpha * (exp (x / self % alpha) - 1.0 )
544
+ end where
545
+ end function
546
+
547
+ pure function eval_1d_celu_prime (self , x ) result(res)
548
+ ! Celu activation function.
549
+ class(celu), intent (in ) :: self
550
+ real , intent (in ) :: x(:)
551
+ real :: res(size (x))
552
+ where (x >= 0.0 )
553
+ res = 1.0
554
+ else where
555
+ res = exp (x / self % alpha)
556
+ end where
557
+ end function
558
+
559
+ pure function eval_3d_celu (self , x ) result(res)
560
+ ! Celu activation function.
561
+ class(celu), intent (in ) :: self
562
+ real , intent (in ) :: x(:,:,:)
563
+ real :: res(size (x,1 ),size (x,2 ),size (x,3 ))
564
+ where (x >= 0.0 )
565
+ res = x
566
+ else where
567
+ res = self % alpha * (exp (x / self % alpha) - 1.0 )
568
+ end where
569
+ end function
570
+
571
+ pure function eval_3d_celu_prime (self , x ) result(res)
572
+ ! Celu activation function.
573
+ class(celu), intent (in ) :: self
574
+ real , intent (in ) :: x(:,:,:)
575
+ real :: res(size (x,1 ),size (x,2 ),size (x,3 ))
576
+ where (x >= 0.0 )
577
+ res = 1.0
578
+ else where
579
+ res = exp (x / self % alpha)
580
+ end where
581
+ end function
582
+
525
583
pure function get_name (self ) result(name)
526
584
! ! Return the name of the activation function.
527
585
! !
@@ -556,6 +614,8 @@ pure function get_name(self) result(name)
556
614
name = ' step'
557
615
class is (tanhf)
558
616
name = ' tanh'
617
+ class is (celu)
618
+ name = ' celu'
559
619
class default
560
620
error stop ' Unknown activation function type.'
561
621
end select
0 commit comments