Skip to content

Commit

Permalink
Merge pull request #78 from c-bata/remove-to-internal-repr
Browse files Browse the repository at this point in the history
Remove a ToInternalRepresentation function.
  • Loading branch information
c-bata authored Mar 11, 2020
2 parents 5bad401 + b675684 commit 66ce706
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 135 deletions.
54 changes: 1 addition & 53 deletions distribution.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (

// Distribution represents a parameter that can be optimized.
type Distribution interface {
// ToInternalRepr to convert external representation of a parameter value into internal representation.
ToInternalRepr(interface{}) float64
// ToExternalRepr to convert internal representation of a parameter value into external representation.
ToExternalRepr(float64) interface{}
// Single to test whether the range of this distribution contains just a single value.
Expand All @@ -30,11 +28,6 @@ type UniformDistribution struct {
// UniformDistributionName is the identifier name of UniformDistribution
const UniformDistributionName = "UniformDistribution"

// ToInternalRepr to convert external representation of a parameter value into internal representation.
func (d *UniformDistribution) ToInternalRepr(xr interface{}) float64 {
return xr.(float64)
}

// ToExternalRepr to convert internal representation of a parameter value into external representation.
func (d *UniformDistribution) ToExternalRepr(ir float64) interface{} {
return ir
Expand Down Expand Up @@ -66,11 +59,6 @@ type LogUniformDistribution struct {
// LogUniformDistributionName is the identifier name of LogUniformDistribution
const LogUniformDistributionName = "LogUniformDistribution"

// ToInternalRepr to convert external representation of a parameter value into internal representation.
func (d *LogUniformDistribution) ToInternalRepr(xr interface{}) float64 {
return xr.(float64)
}

// ToExternalRepr to convert internal representation of a parameter value into external representation.
func (d *LogUniformDistribution) ToExternalRepr(ir float64) interface{} {
return ir
Expand Down Expand Up @@ -102,12 +90,6 @@ type IntUniformDistribution struct {
// IntUniformDistributionName is the identifier name of IntUniformDistribution
const IntUniformDistributionName = "IntUniformDistribution"

// ToInternalRepr to convert external representation of a parameter value into internal representation.
func (d *IntUniformDistribution) ToInternalRepr(xr interface{}) float64 {
x := xr.(int)
return float64(x)
}

// ToExternalRepr to convert internal representation of a parameter value into external representation.
func (d *IntUniformDistribution) ToExternalRepr(ir float64) interface{} {
return int(ir)
Expand Down Expand Up @@ -142,14 +124,9 @@ type DiscreteUniformDistribution struct {
// DiscreteUniformDistributionName is the identifier name of DiscreteUniformDistribution
const DiscreteUniformDistributionName = "DiscreteUniformDistribution"

// ToInternalRepr to convert external representation of a parameter value into internal representation.
func (d *DiscreteUniformDistribution) ToInternalRepr(xr interface{}) float64 {
return xr.(float64)
}

// ToExternalRepr to convert internal representation of a parameter value into external representation.
func (d *DiscreteUniformDistribution) ToExternalRepr(ir float64) interface{} {
return ir
return math.Floor((ir-d.Low)/d.Q+0.5)*d.Q + d.Low
}

// Single to test whether the range of this distribution contains just a single value.
Expand Down Expand Up @@ -187,17 +164,6 @@ type CategoricalDistribution struct {
// CategoricalDistributionName is the identifier name of CategoricalDistribution
const CategoricalDistributionName = "CategoricalDistribution"

// ToInternalRepr to convert external representation of a parameter value into internal representation.
func (d *CategoricalDistribution) ToInternalRepr(er interface{}) float64 {
value := er.(string)
for i := range d.Choices {
if d.Choices[i] == value {
return float64(i)
}
}
panic("must not reach here")
}

// ToExternalRepr to convert internal representation of a parameter value into external representation.
func (d *CategoricalDistribution) ToExternalRepr(ir float64) interface{} {
return d.Choices[int(ir)]
Expand Down Expand Up @@ -232,24 +198,6 @@ func ToExternalRepresentation(distribution interface{}, ir float64) (interface{}
}
}

// ToInternalRepresentation converts to internal representation
func ToInternalRepresentation(distribution interface{}, xr interface{}) (float64, error) {
switch d := distribution.(type) {
case UniformDistribution:
return d.ToInternalRepr(xr), nil
case LogUniformDistribution:
return d.ToInternalRepr(xr), nil
case IntUniformDistribution:
return d.ToInternalRepr(xr), nil
case DiscreteUniformDistribution:
return d.ToInternalRepr(xr), nil
case CategoricalDistribution:
return d.ToInternalRepr(xr), nil
default:
return -1, ErrUnknownDistribution
}
}

// DistributionIsSingle whether the distribution contains just a single value.
func DistributionIsSingle(distribution interface{}) (bool, error) {
switch d := distribution.(type) {
Expand Down
65 changes: 15 additions & 50 deletions distribution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,53 +65,6 @@ func TestDistributionConversionBetweenDistributionAndJSON(t *testing.T) {
}
}

func TestDistributionToInternalRepresentation(t *testing.T) {
tests := []struct {
name string
distribution goptuna.Distribution
args interface{}
want float64
}{
{
name: "uniform distribution",
distribution: &goptuna.UniformDistribution{Low: 0.5, High: 5.5},
args: 3.5,
want: 3.5,
},
{
name: "log uniform distribution",
distribution: &goptuna.LogUniformDistribution{Low: 1e-2, High: 1e5},
args: float64(1e3),
want: float64(1e3),
},
{
name: "int uniform distribution",
distribution: &goptuna.IntUniformDistribution{Low: 0, High: 10},
args: 3,
want: 3.0,
},
{
name: "discrete uniform distribution",
distribution: &goptuna.DiscreteUniformDistribution{Low: 0.5, High: 5.5, Q: 0.5},
args: 3.5,
want: 3.5,
},
{
name: "categorical distribution",
distribution: &goptuna.CategoricalDistribution{Choices: []string{"a", "b", "c"}},
args: "b",
want: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.distribution.ToInternalRepr(tt.args); got != tt.want {
t.Errorf("UniformDistribution.ToInternalRepr() = %v, want %v", got, tt.want)
}
})
}
}

func TestDistributionToExternalRepresentation(t *testing.T) {
tests := []struct {
name string
Expand All @@ -138,11 +91,23 @@ func TestDistributionToExternalRepresentation(t *testing.T) {
want: 3,
},
{
name: "discrete uniform distribution",
name: "discrete uniform distribution 1",
distribution: &goptuna.DiscreteUniformDistribution{Low: 0.5, High: 5.5, Q: 0.5},
args: 3.5,
want: 3.5,
},
{
name: "discrete uniform distribution 2",
distribution: &goptuna.DiscreteUniformDistribution{Low: 0.5, High: 5.5, Q: 0.5},
args: 3.3,
want: 3.5,
},
{
name: "discrete uniform distribution 3",
distribution: &goptuna.DiscreteUniformDistribution{Low: 0.5, High: 5.5, Q: 0.05},
args: 3.52,
want: 3.5,
},
{
name: "categorical distribution",
distribution: &goptuna.CategoricalDistribution{Choices: []string{"a", "b", "c"}},
Expand All @@ -153,7 +118,7 @@ func TestDistributionToExternalRepresentation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.distribution.ToExternalRepr(tt.args); !reflect.DeepEqual(got, tt.want) {
t.Errorf("UniformDistribution.ToInternalRepr() = %v, want %v", got, tt.want)
t.Errorf("UniformDistribution.ToExternalRepr() = %v, want %v", got, tt.want)
}
})
}
Expand Down Expand Up @@ -219,7 +184,7 @@ func TestDistributionSingle(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.distribution.Single(); got != tt.want {
t.Errorf("UniformDistribution.ToInternalRepr() = %v, want %v", got, tt.want)
t.Errorf("UniformDistribution.Single() = %v, want %v", got, tt.want)
}
})
}
Expand Down
8 changes: 7 additions & 1 deletion rdb/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
}

distributions := make(map[string]interface{}, len(trial.TrialParams))
paramsInIR := make(map[string]float64, len(trial.TrialParams))
paramsInXR := make(map[string]interface{}, len(trial.TrialParams))
for i := range trial.TrialParams {
// distributions
Expand All @@ -40,7 +41,11 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
return goptuna.FrozenTrial{}, err
}
distributions[trial.TrialParams[i].Name] = d
// external representations

// internal representation
paramsInIR[trial.TrialParams[i].Name] = trial.TrialParams[i].Value

// external representation
paramsInXR[trial.TrialParams[i].Name], err = goptuna.ToExternalRepresentation(d, trial.TrialParams[i].Value)
if err != nil {
return goptuna.FrozenTrial{}, err
Expand Down Expand Up @@ -83,6 +88,7 @@ func toFrozenTrial(trial trialModel) (goptuna.FrozenTrial, error) {
IntermediateValues: intermediateValue,
DatetimeStart: datetimeStart,
DatetimeComplete: datetimeComplete,
InternalParams: paramsInIR,
Params: paramsInXR,
Distributions: distributions,
UserAttrs: userAttrs,
Expand Down
19 changes: 8 additions & 11 deletions storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type FrozenTrial struct {
IntermediateValues map[int]float64 `json:"intermediate_values"`
DatetimeStart time.Time `json:"datetime_start"`
DatetimeComplete time.Time `json:"datetime_complete"`
InternalParams map[string]float64 `json:"internal_params"`
Params map[string]interface{} `json:"params"`
Distributions map[string]interface{} `json:"distributions"`
UserAttrs map[string]string `json:"user_attrs"`
Expand Down Expand Up @@ -332,6 +333,7 @@ func (s *InMemoryStorage) CreateNewTrial(studyID int) (int, error) {
IntermediateValues: make(map[int]float64, 8),
DatetimeStart: time.Now(),
DatetimeComplete: time.Time{},
InternalParams: make(map[string]float64, 8),
Params: make(map[string]interface{}, 8),
Distributions: make(map[string]interface{}, 8),
UserAttrs: make(map[string]string, 8),
Expand Down Expand Up @@ -407,14 +409,14 @@ func (s *InMemoryStorage) SetTrialParam(
return ErrTrialCannotBeUpdated
}

// Set param distribution
trial.Distributions[paramName] = distribution
var err error
trial.Params[paramName], err = ToExternalRepresentation(distribution, paramValueInternal)
paramValueExternal, err := ToExternalRepresentation(distribution, paramValueInternal)
if err != nil {
return err
}

trial.Distributions[paramName] = distribution
trial.InternalParams[paramName] = paramValueInternal
trial.Params[paramName] = paramValueExternal
s.trials[trialID] = trial
return nil
}
Expand Down Expand Up @@ -487,16 +489,11 @@ func (s *InMemoryStorage) GetTrialParam(trialID int, paramName string) (float64,

for i := range s.trials {
if s.trials[i].ID == trialID {
xr, ok := s.trials[i].Params[paramName]
ir, ok := s.trials[i].InternalParams[paramName]
if !ok {
return -1.0, errors.New("param doesn't exist")
}
d, ok := s.trials[i].Distributions[paramName]
if !ok {
return -1.0, errors.New("distribution doesn't exist")
}
ir, err := ToInternalRepresentation(d, xr)
return ir, err
return ir, nil
}
}
return -1, ErrInvalidTrialID
Expand Down
10 changes: 1 addition & 9 deletions tpe/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,18 +523,10 @@ func getObservationPairs(study *goptuna.Study, paramName string) ([]float64, [][
values := make([]float64, 0, len(trials))
scores := make([][2]float64, 0, len(trials))
for _, trial := range trials {
xr, ok := trial.Params[paramName]
ir, ok := trial.InternalParams[paramName]
if !ok {
continue
}
distribution, ok := trial.Distributions[paramName]
if !ok {
continue
}
ir, err := goptuna.ToInternalRepresentation(distribution, xr)
if err != nil {
continue
}

var paramValue, score0, score1 float64
paramValue = ir
Expand Down
18 changes: 7 additions & 11 deletions trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,6 @@ func (t *Trial) suggest(name string, distribution interface{}) (float64, error)
return 0.0, err
}

if trial.Params == nil {
trial.Params = make(map[string]interface{}, 8)
}
trial.Params[name], err = ToExternalRepresentation(distribution, v)
if err != nil {
return 0.0, err
}

err = t.Study.Storage.SetTrialParam(trial.ID, name, v, distribution)
return v, err
}
Expand Down Expand Up @@ -149,10 +141,14 @@ func (t *Trial) SuggestDiscreteUniform(name string, low, high, q float64) (float
if low > high {
return 0, errors.New("'low' must be smaller than or equal to the 'high'")
}
v, err := t.suggest(name, DiscreteUniformDistribution{
d := DiscreteUniformDistribution{
High: high, Low: low, Q: q,
})
return v, err
}
ir, err := t.suggest(name, d)
if err != nil {
return 0, err
}
return d.ToExternalRepr(ir).(float64), err
}

// SuggestCategorical suggests an categorical parameter.
Expand Down

0 comments on commit 66ce706

Please sign in to comment.