Skip to content

Commit c3c048a

Browse files
authored
add CELU activation function (#143)
1 parent f9ff658 commit c3c048a

File tree

3 files changed

+68
-3
lines changed

3 files changed

+68
-3
lines changed

src/nf.f90

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ module nf
88
use nf_optimizers, only: sgd
99
use nf_activation, only: activation_function, elu, exponential, &
1010
gaussian, linear, relu, leaky_relu, &
11-
sigmoid, softmax, softplus, step, tanhf
11+
sigmoid, softmax, softplus, step, tanhf, &
12+
celu
1213
end module nf

src/nf/nf_activation.f90

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module nf_activation
1818
public :: softplus
1919
public :: step
2020
public :: tanhf
21+
public :: celu
2122

2223
type, abstract :: activation_function
2324
contains
@@ -140,6 +141,15 @@ end function eval_3d_i
140141
procedure :: eval_3d_prime => eval_3d_tanh_prime
141142
end type tanhf
142143

144+
type, extends(activation_function) :: celu
  !! Continuously Differentiable Exponential Linear Unit (CELU):
  !!   celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  !! Smooth for alpha > 0; reduces to ELU when alpha == 1.
  real :: alpha = 1.0 ! PyTorch default
contains
  procedure :: eval_1d => eval_1d_celu
  procedure :: eval_1d_prime => eval_1d_celu_prime
  procedure :: eval_3d => eval_3d_celu
  procedure :: eval_3d_prime => eval_3d_celu_prime
end type celu
152+
143153
contains
144154

145155
pure function eval_1d_elu(self, x) result(res)
@@ -522,6 +532,54 @@ pure function eval_3d_tanh_prime(self, x) result(res)
522532
res = 1 - tanh(x)**2
523533
end function eval_3d_tanh_prime
524534

535+
pure function eval_1d_celu(self, x) result(res)
  !! Apply the CELU activation elementwise to a rank-1 array:
  !! returns x where x >= 0, and alpha * (exp(x / alpha) - 1) otherwise.
  class(celu), intent(in) :: self
  real, intent(in) :: x(:)
  real :: res(size(x))
  ! Masked assignment keeps the exponential evaluated only on the
  ! negative branch.
  where (x >= 0.0)
    res = x
  elsewhere
    res = self % alpha * (exp(x / self % alpha) - 1.0)
  end where
end function eval_1d_celu
546+
547+
pure function eval_1d_celu_prime(self, x) result(res)
  !! First derivative of the CELU activation function (rank-1):
  !! d/dx celu(x) = 1 for x >= 0, exp(x / alpha) otherwise.
  class(celu), intent(in) :: self
  real, intent(in) :: x(:)
  real :: res(size(x))
  where (x >= 0.0)
    res = 1.0
  else where
    res = exp(x / self % alpha)
  end where
end function eval_1d_celu_prime
558+
559+
pure function eval_3d_celu(self, x) result(res)
  !! Apply the CELU activation elementwise to a rank-3 array:
  !! returns x where x >= 0, and alpha * (exp(x / alpha) - 1) otherwise.
  class(celu), intent(in) :: self
  real, intent(in) :: x(:,:,:)
  real :: res(size(x,1),size(x,2),size(x,3))
  ! Masked assignment keeps the exponential evaluated only on the
  ! negative branch.
  where (x >= 0.0)
    res = x
  elsewhere
    res = self % alpha * (exp(x / self % alpha) - 1.0)
  end where
end function eval_3d_celu
570+
571+
pure function eval_3d_celu_prime(self, x) result(res)
  !! First derivative of the CELU activation function (rank-3):
  !! d/dx celu(x) = 1 for x >= 0, exp(x / alpha) otherwise.
  class(celu), intent(in) :: self
  real, intent(in) :: x(:,:,:)
  real :: res(size(x,1),size(x,2),size(x,3))
  where (x >= 0.0)
    res = 1.0
  else where
    res = exp(x / self % alpha)
  end where
end function eval_3d_celu_prime
582+
525583
pure function get_name(self) result(name)
526584
!! Return the name of the activation function.
527585
!!
@@ -556,6 +614,8 @@ pure function get_name(self) result(name)
556614
name = 'step'
557615
class is (tanhf)
558616
name = 'tanh'
617+
class is (celu)
618+
name = 'celu'
559619
class default
560620
error stop 'Unknown activation function type.'
561621
end select

src/nf/nf_network_submodule.f90

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
softmax, &
2626
softplus, &
2727
step, &
28-
tanhf
28+
tanhf, &
29+
celu
2930

3031
implicit none
3132

@@ -268,10 +269,13 @@ pure function get_activation_by_name(activation_name) result(res)
268269
case('tanh')
269270
allocate ( res, source = tanhf() )
270271

272+
case('celu')
273+
allocate ( res, source = celu() )
274+
271275
case default
272276
error stop 'activation_name must be one of: ' // &
273277
'"elu", "exponential", "gaussian", "linear", "relu", ' // &
274-
'"leaky_relu", "sigmoid", "softmax", "softplus", "step", or "tanh".'
278+
'"leaky_relu", "sigmoid", "softmax", "softplus", "step", "tanh" or "celu".'
275279
end select
276280

277281
end function get_activation_by_name

0 commit comments

Comments
 (0)