Skip to content

Commit

Permalink
Merge pull request #2 from edowson/master
Browse files Browse the repository at this point in the history
Fixes for Python3 and Keras-2.3.1
  • Loading branch information
JesperDramsch authored May 10, 2020
2 parents 0460420 + 861f8b7 commit 123bfc0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 15 deletions.
4 changes: 2 additions & 2 deletions complexnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
# What this module includes by default:
from . import bn, conv, dense, init, norm, pool
# from . import fft
from . import fft

from .bn import ComplexBatchNormalization as ComplexBN
from .conv import (
Expand All @@ -17,7 +17,7 @@
WeightNorm_Conv,
)
from .dense import ComplexDense
# from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
from .init import (
ComplexIndependentFilters,
IndependentFilters,
Expand Down
13 changes: 4 additions & 9 deletions scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
import os, pdb, sys
import time

__version__ = "0.0.0"



#
# Message Formatter
#
Expand All @@ -26,7 +22,7 @@ class MsgFormatter(L.Formatter):

def formatTime(self, record, datefmt):
t = record.created
timeFrac = abs(t-long(t))
timeFrac = abs(t-int(t))
timeStruct = time.localtime(record.created)
timeString = ""
timeString += time.strftime("%F %T", timeStruct)
Expand Down Expand Up @@ -84,7 +80,7 @@ def addArgs(cls, argp):
argp.add_argument("-l", "--loglevel", default="info", type=str,
choices=cls.LOGLEVELS.keys(),
help="Logging severity level.")
argp.add_argument("-s", "--seed", default=0xe4223644e98b8e64, type=long,
argp.add_argument("-s", "--seed", default=0xe4223644e98b8e64, type=int,
help="Seed for PRNGs.")
argp.add_argument("--summary", action="store_true",
help="""Print a summary of the network.""")
Expand Down Expand Up @@ -196,8 +192,7 @@ def getArgParser(prog):
argp = Ap.ArgumentParser(prog = prog,
usage = None,
description = None,
epilog = None,
version = __version__)
epilog = None)
subp = argp.add_subparsers()
argp.set_defaults(argp=argp)
argp.set_defaults(subp=subp)
Expand All @@ -207,7 +202,7 @@ def getArgParser(prog):


# Add subcommands
for v in globals().itervalues():
for v in globals().values():
if(isinstance(v, type) and
issubclass(v, Subcommand) and
v != Subcommand):
Expand Down
8 changes: 4 additions & 4 deletions scripts/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def getResnetModel(d):
activation = d.act
advanced_act = d.aact
drop_prob = d.dropout
inputShape = (3, 32, 32) if K.image_dim_ordering() == "th" else (32, 32, 3)
inputShape = (3, 32, 32) if K.image_data_format() == "channels_first" else (32, 32, 3)
channelAxis = 1 if K.image_data_format() == 'channels_first' else -1
filsize = (3, 3)
convArgs = {
Expand Down Expand Up @@ -196,7 +196,7 @@ def getResnetModel(d):
# Stage 2
#

for i in xrange(n):
for i in range(n):
O = getResidualBlock(O, filsize, [sf, sf], 2, str(i), 'regular', convArgs, bnArgs, d)
if i == n//2 and d.spectral_pool_scheme == "stagemiddle":
O = applySpectralPooling(O, d)
Expand All @@ -209,7 +209,7 @@ def getResnetModel(d):
if d.spectral_pool_scheme == "nodownsample":
O = applySpectralPooling(O, d)

for i in xrange(n-1):
for i in range(n-1):
O = getResidualBlock(O, filsize, [sf*2, sf*2], 3, str(i+1), 'regular', convArgs, bnArgs, d)
if i == n//2 and d.spectral_pool_scheme == "stagemiddle":
O = applySpectralPooling(O, d)
Expand All @@ -222,7 +222,7 @@ def getResnetModel(d):
if d.spectral_pool_scheme == "nodownsample":
O = applySpectralPooling(O, d)

for i in xrange(n-1):
for i in range(n-1):
O = getResidualBlock(O, filsize, [sf*4, sf*4], 4, str(i+1), 'regular', convArgs, bnArgs, d)
if i == n//2 and d.spectral_pool_scheme == "stagemiddle":
O = applySpectralPooling(O, d)
Expand Down

0 comments on commit 123bfc0

Please sign in to comment.