Skip to content

Commit a59d820

Browse files
committed
ipn/warp: exportable proton config not setup on refresh
1 parent dfed1d8 commit a59d820

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

intra/ipn/warp/proton.go

+37-25
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,9 @@ type ProtonClient struct {
549549
refreshToken string
550550
}
551551
cert struct {
552-
SerialNumber string
553-
ExpirationTime int64
554-
RefreshTime int64
552+
serialNumber string
553+
expirationTime int64
554+
refreshTime int64
555555
}
556556

557557
config *ProtonWgConfig
@@ -598,7 +598,7 @@ func (a *ProtonClient) Config() (*ProtonWgConfig, error) {
598598
return nil, errNoProtonConfig
599599
}
600600

601-
func (a *ProtonClient) refreshConf() error {
601+
func (a *ProtonClient) refreshWgConfig() error {
602602
pc := a.config
603603
if pc == nil {
604604
return errNoProtonConfig
@@ -619,9 +619,9 @@ func (a *ProtonClient) refreshConf() error {
619619
return errProtonCredsMismatch
620620
}
621621
// cert info
622-
if pc.CertSerialNumber != a.cert.SerialNumber {
622+
if pc.CertSerialNumber != a.cert.serialNumber {
623623
log.W("proton: refresh: serial number mismatch conf(%s) != struct(%s)",
624-
pc.CertSerialNumber, a.cert.SerialNumber)
624+
pc.CertSerialNumber, a.cert.serialNumber)
625625
// expect it to be the same when the key is the same
626626
}
627627
// wg info
@@ -679,6 +679,8 @@ func (a *ProtonClient) newConf() error {
679679
return errNoProtonServerInfo
680680
}
681681

682+
// reverse of restoreConfigFrom()
683+
682684
// key
683685
pc.Ed25519PrivBase64 = a.key.PrivateKeyBase64()
684686
// session info
@@ -690,10 +692,10 @@ func (a *ProtonClient) newConf() error {
690692
pc.CredsAccessToken = a.creds.accessToken
691693
pc.CredsRefreshToken = a.creds.refreshToken
692694
// cert info
693-
pc.CertSerialNumber = a.cert.SerialNumber
694-
pc.CertExpTime = a.cert.ExpirationTime
695-
pc.CertRefreshTime = a.cert.RefreshTime
696-
// wg info
695+
pc.CertSerialNumber = a.cert.serialNumber
696+
pc.CertExpTime = a.cert.expirationTime
697+
pc.CertRefreshTime = a.cert.refreshTime
698+
// wg info; similar: refreshWgConfig
697699
for _, c := range rwgConfs {
698700
c.genWgConf()
699701
}
@@ -790,11 +792,15 @@ retryAfterRefresh:
790792
}
791793
// TODO: certResponse.ClientPublicKey == a.key.PublicKeyPKIXPem()
792794

793-
a.cert.SerialNumber = certResponse.SerialNumber
794-
a.cert.ExpirationTime = certResponse.ExpirationTime
795-
a.cert.RefreshTime = certResponse.RefreshTime
795+
a.cert.serialNumber = certResponse.SerialNumber
796+
a.cert.expirationTime = certResponse.ExpirationTime
797+
a.cert.refreshTime = certResponse.RefreshTime
798+
pc := a.config
799+
pc.CertSerialNumber = a.cert.serialNumber
800+
pc.CertExpTime = a.cert.expirationTime
801+
pc.CertRefreshTime = a.cert.refreshTime
796802

797-
refreshAt := time.Unix(int64(a.cert.RefreshTime), 0)
803+
refreshAt := time.Unix(int64(a.cert.refreshTime), 0)
798804

799805
log.I("proton: regcert: success: serial(%s): next refresh(%s)",
800806
certResponse.SerialNumber, refreshAt.Format(time.RFC1123))
@@ -813,7 +819,7 @@ func (a *ProtonClient) Refresh() error {
813819
return err
814820
}
815821

816-
return a.refreshConf()
822+
return a.refreshWgConfig()
817823
}
818824

819825
func (a *ProtonClient) fetchCreds() error {
@@ -1041,6 +1047,11 @@ func (a *ProtonClient) refreshCreds() error {
10411047

10421048
a.creds.accessToken = refreshCredResponse.AccessToken
10431049
a.creds.refreshToken = refreshCredResponse.RefreshToken
1050+
pc := a.config
1051+
pc.CredsAccessToken = a.creds.accessToken
1052+
pc.CredsRefreshToken = a.creds.refreshToken
1053+
1054+
log.I("proton: refreshcreds: ok; new access+refresh tokens")
10441055

10451056
return nil
10461057
}
@@ -1148,7 +1159,7 @@ func (a *ProtonClient) rereg(force bool) error {
11481159
fresh := a.config != nil && a.config.CertRefreshTime-now > 0
11491160

11501161
log.I("proton: re-reg %s (exp? %t, force? %t)",
1151-
a.cert.SerialNumber, !fresh, force)
1162+
a.cert.serialNumber, !fresh, force)
11521163

11531164
if !force && fresh {
11541165
return nil // ok
@@ -1254,13 +1265,13 @@ func (w *Client) MakeProtonWg(allServersFilePath string) (*ProtonClient, error)
12541265
return a, nil
12551266
}
12561267

1257-
func (w *Client) MakeProtonWgFrom(fromConfigJson []byte, allServersFilePath string) (*ProtonClient, error) {
1258-
if len(fromConfigJson) <= 0 {
1268+
func (w *Client) MakeProtonWgFrom(existingConfigJson []byte, allServersFilePath string) (*ProtonClient, error) {
1269+
if len(existingConfigJson) <= 0 {
12591270
return nil, errNoProtonJsonConfig
12601271
}
12611272

12621273
var existingConf ProtonWgConfig
1263-
err := json.Unmarshal(fromConfigJson, &existingConf)
1274+
err := json.Unmarshal(existingConfigJson, &existingConf)
12641275
if err != nil {
12651276
return nil, err
12661277
}
@@ -1280,7 +1291,7 @@ func (w *Client) MakeProtonWgFrom(fromConfigJson []byte, allServersFilePath stri
12801291
return nil, err
12811292
}
12821293

1283-
err = a.assignConfig(&existingConf)
1294+
err = a.restoreConfigFrom(&existingConf)
12841295
if err != nil {
12851296
return nil, err
12861297
}
@@ -1295,15 +1306,16 @@ func (w *Client) MakeProtonWgFrom(fromConfigJson []byte, allServersFilePath stri
12951306
return nil, err
12961307
}
12971308

1298-
err = a.refreshConf()
1309+
err = a.refreshWgConfig()
12991310
if err != nil {
13001311
return nil, err
13011312
}
13021313

13031314
return a, nil
13041315
}
13051316

1306-
func (a *ProtonClient) assignConfig(conf *ProtonWgConfig) error {
1317+
func (a *ProtonClient) restoreConfigFrom(conf *ProtonWgConfig) error {
1318+
// top-level config
13071319
a.config = conf
13081320

13091321
// session info
@@ -1315,9 +1327,9 @@ func (a *ProtonClient) assignConfig(conf *ProtonWgConfig) error {
13151327
a.creds.accessToken = conf.CredsAccessToken
13161328
a.creds.refreshToken = conf.CredsRefreshToken
13171329
// cert info
1318-
a.cert.SerialNumber = conf.CertSerialNumber
1319-
a.cert.ExpirationTime = conf.CertExpTime
1320-
a.cert.RefreshTime = conf.CertRefreshTime
1330+
a.cert.serialNumber = conf.CertSerialNumber
1331+
a.cert.expirationTime = conf.CertExpTime
1332+
a.cert.refreshTime = conf.CertRefreshTime
13211333

13221334
protonLogicalsUpdateTime = time.Unix(conf.CreateTimestamp, 0)
13231335

0 commit comments

Comments
 (0)