diff --git a/internal/federation/handle.go b/internal/federation/handle.go index bc1f7e28..dd1f3711 100644 --- a/internal/federation/handle.go +++ b/internal/federation/handle.go @@ -38,37 +38,48 @@ func MakeJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request return } + makeJoinResp, err := MakeRespMakeJoin(s, room, userID) + if err != nil { + w.WriteHeader(500) + w.Write([]byte(fmt.Sprintf("complement: HandleMakeSendJoinRequests %s", err))) + return + } + + // Send it + w.WriteHeader(200) + b, _ := json.Marshal(makeJoinResp) + w.Write(b) +} + +// MakeRespMakeJoin makes the response for a /make_join request, without verifying any signatures +// or dealing with HTTP responses itself. +func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp gomatrixserverlib.RespMakeJoin, err error) { // Generate a join event builder := gomatrixserverlib.EventBuilder{ Sender: userID, - RoomID: roomID, + RoomID: room.RoomID, Type: "m.room.member", StateKey: &userID, PrevEvents: []string{room.Timeline[len(room.Timeline)-1].EventID()}, Depth: room.Timeline[len(room.Timeline)-1].Depth() + 1, } - err := builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join}) + err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join}) if err != nil { - w.WriteHeader(500) - w.Write([]byte("complement: HandleMakeSendJoinRequests make_join cannot set membership content: " + err.Error())) + err = fmt.Errorf("make_join cannot set membership content: %w", err) return } stateNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&builder) if err != nil { - w.WriteHeader(500) - w.Write([]byte("complement: HandleMakeSendJoinRequests make_join cannot calculate auth_events: " + err.Error())) + err = fmt.Errorf("make_join cannot calculate auth_events: %w", err) return } builder.AuthEvents = room.AuthEvents(stateNeeded) - // Send it - res := map[string]interface{}{ - "event": builder, - "room_version": room.Version, + resp = gomatrixserverlib.RespMakeJoin{ + RoomVersion: room.Version, + JoinEvent: builder, } - w.WriteHeader(200) - b, _ := json.Marshal(res) - w.Write(b) + return } // SendJoinRequestsHandler is the http.Handler implementation for the send_join part of diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index 05187c38..4cef4a46 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -806,6 +806,125 @@ func TestPartialStateJoin(t *testing.T) { }, ) }) + + // when the server is in the middle of a partial state join, it should not accept + // /make_join because it can't give a full answer. + t.Run("Rejects make_join during partial join", func(t *testing.T) { + // In this test, we have 3 homeservers: + // hs1 (the server under test) with @alice:hs1 + // This is the server that will be in the middle of a partial join. + // testServer1 (a Complement test server) with @bob: + // This is the server that created the room originally. + // testServer2 (another Complement test server) with @charlie: + // This is the server that will try to make a join via testServer1. + deployment := Deploy(t, b.BlueprintAlice) + defer deployment.Destroy(t) + alice := deployment.Client(t, "hs1", "@alice:hs1") + + testServer1 := createTestServer(t, deployment) + cancel := testServer1.Listen() + defer cancel() + serverRoom := createTestRoom(t, testServer1, alice.GetDefaultRoomVersion(t)) + roomID := serverRoom.RoomID + psjResult := beginPartialStateJoin(t, testServer1, serverRoom, alice) + defer psjResult.Destroy() + + // The partial join is now in progress. + // Let's have a new test server rock up and ask to join the room by making a + // /make_join request. + + testServer2 := createTestServer(t, deployment) + cancel2 := testServer2.Listen() + defer cancel2() + + fedClient2 := testServer2.FederationClient(deployment) + + // charlie sends a make_join + _, err := fedClient2.MakeJoin(context.Background(), "hs1", roomID, testServer2.UserID("charlie"), federation.SupportedRoomVersions()) + + if err == nil { + t.Errorf("MakeJoin returned 200, want 404") + } else if httpError, ok := err.(gomatrix.HTTPError); ok { + t.Logf("MakeJoin => %d/%s", httpError.Code, string(httpError.Contents)) + if httpError.Code != 404 { + t.Errorf("expected 404, got %d", httpError.Code) + } + errcode := must.GetJSONFieldStr(t, httpError.Contents, "errcode") + if errcode != "M_NOT_FOUND" { + t.Errorf("errcode: got %s, want M_NOT_FOUND", errcode) + } + } else { + t.Errorf("MakeJoin: non-HTTPError: %v", err) + } + }) + + // when the server is in the middle of a partial state join, it should not accept + // /send_join because it can't give a full answer. + t.Run("Rejects send_join during partial join", func(t *testing.T) { + // In this test, we have 3 homeservers: + // hs1 (the server under test) with @alice:hs1 + // This is the server that will be in the middle of a partial join. + // testServer1 (a Complement test server) with @charlie: + // This is the server that will create the room originally. + // testServer2 (another Complement test server) with @daniel: + // This is the server that will try to join the room via hs2, + // but only after using hs1 to /make_join (as otherwise we have no way + // of being able to build a request to /send_join) + // + deployment := Deploy(t, b.BlueprintAlice) + defer deployment.Destroy(t) + alice := deployment.Client(t, "hs1", "@alice:hs1") + + testServer1 := createTestServer(t, deployment) + cancel := testServer1.Listen() + defer cancel() + serverRoom := createTestRoom(t, testServer1, alice.GetDefaultRoomVersion(t)) + psjResult := beginPartialStateJoin(t, testServer1, serverRoom, alice) + defer psjResult.Destroy() + + // hs1's partial join is now in progress. + // Let's have a test server rock up and ask to /send_join in the room via hs1. + // To do that, we need to /make_join first. + // Asking hs1 to /make_join won't work, because it should reject that request. + // To work around that, we /make_join via hs2. + + testServer2 := createTestServer(t, deployment) + cancel2 := testServer2.Listen() + defer cancel2() + + fedClient2 := testServer2.FederationClient(deployment) + + // Manually /make_join via testServer1. + // This is permissible because testServer1 is fully joined to the room. + // We can't actually use /make_join because host.docker.internal doesn't resolve, + // so compute it without making any requests: + makeJoinResp, err := federation.MakeRespMakeJoin(testServer1, serverRoom, testServer2.UserID("daniel")) + if err != nil { + t.Fatalf("MakeRespMakeJoin failed : %s", err) + } + + // charlie then tries to /send_join via the homeserver under test + joinEvent, err := makeJoinResp.JoinEvent.Build(time.Now(), gomatrixserverlib.ServerName(testServer2.ServerName()), testServer2.KeyID, testServer2.Priv, makeJoinResp.RoomVersion) + must.NotError(t, "JoinEvent.Build", err) + + // SendJoin should return a 404 because the homeserver under test has not + // finished its partial join. + _, err = fedClient2.SendJoin(context.Background(), "hs1", joinEvent) + if err == nil { + t.Errorf("SendJoin returned 200, want 404") + } else if httpError, ok := err.(gomatrix.HTTPError); ok { + t.Logf("SendJoin => %d/%s", httpError.Code, string(httpError.Contents)) + if httpError.Code != 404 { + t.Errorf("expected 404, got %d", httpError.Code) + } + errcode := must.GetJSONFieldStr(t, httpError.Contents, "errcode") + if errcode != "M_NOT_FOUND" { + t.Errorf("errcode: got %s, want M_NOT_FOUND", errcode) + } + } else { + t.Errorf("SendJoin: non-HTTPError: %v", err) + } + }) } // test reception of an event over federation during a resync