Skip to content

Commit

Permalink
feat: add semantic similar search
Browse files Browse the repository at this point in the history
  • Loading branch information
sixwaaaay committed Oct 9, 2024
1 parent 980f852 commit 8e4586a
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 60 deletions.
21 changes: 14 additions & 7 deletions sharp/content.Tests/DomainTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ public async Task FindById_ReturnsVideoDto_WhenVideoExists()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 32 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 32 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 32 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var video = new Video { Id = 1, UserId = 1 };
var user = new User { Id = "1" };
mockVideoRepo.Setup(repo => repo.FindById(1)).ReturnsAsync(video);
mockUserRepo.Setup(repo => repo.FindById(1)).ReturnsAsync(user);
mockVoteRepo.Setup(repo => repo.VotedOfVideos(It.IsAny<List<long>>())).ReturnsAsync([1]);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);

// Act
var result = await service.FindById(1);
Expand All @@ -54,12 +55,13 @@ public async Task FindAllByIds_ReturnsVideoDtos_WhenVideosExist()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 58 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 58 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 58 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var videos = new List<Video> { new() { Id = 1, UserId = 1 }, new() { Id = 2, UserId = 2 } };
var users = new List<User> { new() { Id = "1" }, new() { Id = "2" } };
mockVideoRepo.Setup(repo => repo.FindAllByIds(It.IsAny<IReadOnlyList<long>>())).ReturnsAsync(videos);
mockUserRepo.Setup(repo => repo.FindAllByIds(It.IsAny<IEnumerable<long>>())).ReturnsAsync(users);
mockVoteRepo.Setup(repo => repo.VotedOfVideos(It.IsAny<List<long>>())).ReturnsAsync([]);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);

// Act
var result = await service.FindAllByIds([1, 2]);
Expand All @@ -76,10 +78,11 @@ public async Task Save_ReturnsSavedVideoDto_WhenVideoIsSaved()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 81 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 81 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 81 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var video = new Video { Id = 1, UserId = 1 };
var user = new User { Id = "1" };
mockVideoRepo.Setup(repo => repo.Save(It.IsAny<Video>())).ReturnsAsync(video);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);

// Act
await service.Save(new Video());
Expand All @@ -93,13 +96,14 @@ public async Task FindByUserId_ReturnsExpectedVideos()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 99 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 99 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var videos = new List<Video> { new() { Id = 1, UserId = 1 }, new() { Id = 2, UserId = 1 } };
var user = new User { Id = "1" };
var voteVideoIds = new List<long> { 1, 2 };
mockVideoRepo.Setup(repo => repo.FindByUserId(1, 1, 2)).ReturnsAsync(videos);
mockUserRepo.Setup(repo => repo.FindById(1)).ReturnsAsync(user);
mockVoteRepo.Setup(repo => repo.VotedOfVideos(It.IsAny<List<long>>())).ReturnsAsync(voteVideoIds);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);

// Act
var result = await service.FindByUserId(1, 1, 2);
Expand All @@ -116,14 +120,15 @@ public async Task FindRecent_ReturnsExpectedVideos()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 123 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 123 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var videos = new List<Video> { new() { Id = 1, UserId = 1 }, new() { Id = 2, UserId = 2 } };
var users = new List<User> { new() { Id = "1" }, new() { Id = "2" } };
var voteVideoIds = new List<long> { 1 };

mockVideoRepo.Setup(repo => repo.FindRecent(1, 2)).ReturnsAsync(videos);
mockUserRepo.Setup(repo => repo.FindAllByIds(It.IsAny<IEnumerable<long>>())).ReturnsAsync(users);
mockVoteRepo.Setup(repo => repo.VotedOfVideos(It.IsAny<List<long>>())).ReturnsAsync(voteVideoIds);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);
// Act
var result = await service.FindRecent(1, 2);

Expand All @@ -139,14 +144,15 @@ public async Task DailyPopularVideos_ReturnsExpectedVideos()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 147 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 147 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var videos = new List<Video> { new() { Id = 1, UserId = 1 }, new() { Id = 2, UserId = 2 } };
var users = new List<User> { new() { Id = "1" }, new() { Id = "2" } };
var voteVideoIds = new List<long> { 1 };
mockVideoRepo.Setup(repo => repo.DailyPopularVideos(1, 2)).ReturnsAsync((2, videos));
mockVoteRepo.Setup(repo => repo.VotedOfVideos(It.IsAny<List<long>>())).ReturnsAsync(voteVideoIds);
mockVideoRepo.Setup(repo => repo.FindAllByIds(It.IsAny<long[]>())).ReturnsAsync(videos);
mockUserRepo.Setup(repo => repo.FindAllByIds(It.IsAny<IEnumerable<long>>())).ReturnsAsync(users);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);

// Act
var result = await service.DailyPopularVideos(1, 2);
Expand All @@ -164,6 +170,7 @@ public async Task VotedVideos_ReturnsExpectedVideos()
var mockVideoRepo = new Mock<IVideoRepository>();
var mockUserRepo = new Mock<IUserRepository>();
var mockVoteRepo = new Mock<IVoteRepository>();
var mockSearchClient = new Mock<SearchClient>(null);

Check warning on line 173 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / Analyze (csharp)

Cannot convert null literal to non-nullable reference type.

Check warning on line 173 in sharp/content.Tests/DomainTest.cs

View workflow job for this annotation

GitHub Actions / build

Cannot convert null literal to non-nullable reference type.
var videoIds = new long[] { 1, 2 };
var videos = new List<Video> { new() { Id = 1, UserId = 1 }, new() { Id = 2, UserId = 2 } };
var users = new List<User> { new() { Id = "1" }, new() { Id = "2" } };
Expand All @@ -172,7 +179,7 @@ public async Task VotedVideos_ReturnsExpectedVideos()
mockVoteRepo.Setup(repo => repo.VotedOfVideos(It.IsAny<List<long>>())).ReturnsAsync(voteVideoIds);
mockVideoRepo.Setup(repo => repo.FindAllByIds(It.IsAny<long[]>())).ReturnsAsync(videos);
mockUserRepo.Setup(repo => repo.FindAllByIds(It.IsAny<IEnumerable<long>>())).ReturnsAsync(users);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object);
var service = new DomainService(mockVideoRepo.Object, mockUserRepo.Object, mockVoteRepo.Object, mockSearchClient.Object);

// Act
var result = await service.VotedVideos(1, 1, 2);
Expand Down
44 changes: 43 additions & 1 deletion sharp/content.Tests/repository/ClientTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,48 @@ public async Task VotedVideos_ThrowsException_WhenResponseIsNotSuccessful()
// Act & Assert
await Assert.ThrowsAsync<HttpRequestException>(() => voteRepository.VotedVideos(page, size));
}

[Fact]
public async Task SimilarSearch_ReturnsListOfVideoIds()
{
// Arrange
var videoId = 1;
var expectedResponse = new Response()
{
Hits = [
new SimilarVideo(2), new SimilarVideo(3)
]
};

var mockHttpMessageHandler = new Mock<HttpMessageHandler>();
var mockFactory = new Mock<IHttpClientFactory>();
var client = new HttpClient(mockHttpMessageHandler.Object)
{
BaseAddress = new Uri("http://localhost:5151")
};
mockFactory.Setup(_ => _.CreateClient("Search")).Returns(client);

var searchClient = new SearchClient(mockFactory.Object);

mockHttpMessageHandler.Protected()
.Setup<Task<HttpResponseMessage>>(
"SendAsync",
ItExpr.IsAny<HttpRequestMessage>(),
ItExpr.IsAny<CancellationToken>()
)
.ReturnsAsync(new HttpResponseMessage
{
StatusCode = HttpStatusCode.OK,
Content = JsonContent.Create(expectedResponse, SearchContext.Default.Response)
});

// Act
var result = await searchClient.SimilarSearch(videoId);

// Assert
Assert.Equal(expectedResponse.Hits.Select(h => h.Id).ToList(), result);

}
}


Expand All @@ -173,4 +215,4 @@ public record ScanResp(long? NextToken, List<long> TargetIds);
[JsonSerializable(typeof(List<long>))]
[JsonSerializable(typeof(ScanResp))]
[JsonSerializable(typeof(InQuery))]
partial class VoteJsonContext: JsonSerializerContext;
partial class VoteJsonContext : JsonSerializerContext;
3 changes: 3 additions & 0 deletions sharp/content/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@

builder.Services.AddVoteRepository(builder.Configuration.GetConnectionString("Vote") ?? throw new InvalidOperationException("Vote connection string is null"));

builder.Services.AddSearchClient(builder.Configuration.GetConnectionString("Search") ?? throw new InvalidOperationException("Search connection string is null"),
builder.Configuration["Token"] ?? throw new InvalidOperationException("Token is null"));

Check warning on line 90 in sharp/content/Program.cs

View check run for this annotation

Codecov / codecov/patch

sharp/content/Program.cs#L90

Added line #L90 was not covered by tests

builder.Services.AddDomainService().AddMessageDomain();

var app = builder.Build();
Expand Down
10 changes: 6 additions & 4 deletions sharp/content/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
},
"AllowedHosts": "*",
"ConnectionStrings": {
"Default" : "Server=db; User ID=name; Password=passwd; Database=db;",
"Default": "Server=db; User ID=name; Password=passwd; Database=db;",
"User": "http://user:8080",
"Vote": "http://graph:8088"
"Vote": "http://graph:8088",
"Search": "http://localhost:7700"
},
"Secret": "secret"
}
"Secret": "secret",
"Token": "this is a token for search"
}
7 changes: 5 additions & 2 deletions sharp/content/domainservice/DomainService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@ public interface IDomainService
Task<Pagination<VideoDto>> FindRecent(long page, int size);
Task<Pagination<VideoDto>> VotedVideos(long userId, long page, int size);
Task<Pagination<VideoDto>> DailyPopularVideos(long page, int size);
Task<IReadOnlyList<VideoDto>> FindSimilarVideos(long videoId);
Task Save(Video video);
}

public class DomainService(IVideoRepository videoRepo, IUserRepository userRepo, IVoteRepository voteRepo)
public class DomainService(IVideoRepository videoRepo, IUserRepository userRepo, IVoteRepository voteRepo, SearchClient searchClient)
: IDomainService
{
public async Task<VideoDto> FindById(long id)
{
var video = await videoRepo.FindById(id);
var videoToVideoDto = video.ToDto();
var (UserTask, VoteTask) = ( userRepo.FindById(video.UserId), voteRepo.VotedOfVideos([video.Id]));
var (UserTask, VoteTask) = (userRepo.FindById(video.UserId), voteRepo.VotedOfVideos([video.Id]));
var user = await UserTask;
var votedVideos = await VoteTask;

Expand Down Expand Up @@ -128,6 +129,8 @@ public async Task<Pagination<VideoDto>> DailyPopularVideos(long page, int size)
NextPage = token.ToString()
};
}

public async Task<IReadOnlyList<VideoDto>> FindSimilarVideos(long videoId) => await FindAllByIds(await searchClient.SimilarSearch(videoId));

Check warning on line 133 in sharp/content/domainservice/DomainService.cs

View check run for this annotation

Codecov / codecov/patch

sharp/content/domainservice/DomainService.cs#L133

Added line #L133 was not covered by tests
}

[Mapper]
Expand Down
4 changes: 4 additions & 0 deletions sharp/content/endpoints/Endpoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public static Task<Pagination<VideoDto>> Likes(IDomainService service, long user
return service.VotedVideos(userId, page ?? long.MaxValue, size ?? 10);
}

public static Task<IReadOnlyList<VideoDto>> SimilarVideos(IDomainService service, long id) =>
service.FindSimilarVideos(id);

Check warning on line 74 in sharp/content/endpoints/Endpoints.cs

View check run for this annotation

Codecov / codecov/patch

sharp/content/endpoints/Endpoints.cs#L74

Added line #L74 was not covered by tests

public static async Task CreateVideo(IDomainService service, IProbe probe, ClaimsPrincipal user,
VideoRequest request, VideoRequestValidator validator)
{
Expand Down Expand Up @@ -123,6 +126,7 @@ public static void MapEndpoints(this IEndpointRouteBuilder endpoints)
endpoints.MapGet("/users/{userId:long}/likes", Likes).WithName("getUserLikes");
endpoints.MapGet("/videos", Videos).WithName("getVideos");
endpoints.MapGet("/videos/{id:long}", FindVideoById).WithName("getVideo");
endpoints.MapGet("/videos/{id:long}/similar", SimilarVideos).WithName("getSimilarVideos");
endpoints.MapPost("/videos/popular", DailyPopularVideos).WithName("getDailyPopularVideos");
endpoints.MapPost("/videos", CreateVideo).RequireAuthorization().WithName("createVideo");

Expand Down
60 changes: 48 additions & 12 deletions sharp/content/repository/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,11 @@ public interface IUserRepository
{
string? Token { get; set; }

/// <summary>
/// Find user information by id.
/// </summary>
/// <summary> Find user information by id. </summary>
/// <param name="id"> User id. </param>
/// <returns> User information. </returns>
Task<User> FindById(long id);
/// <summary>
/// Find user information by id list.
/// </summary>
/// <summary> Find user information by id list. </summary>
/// <param name="ids"> User id list. </param>
/// <returns> User information list. </returns>
Task<IReadOnlyList<User>> FindAllByIds(IEnumerable<long> ids);
Expand Down Expand Up @@ -109,17 +105,13 @@ public interface IVoteRepository
{
string? Token { get; set; }

/// <summary>
/// Get voted status of videos.
/// </summary>
/// <summary> Get voted status of videos. </summary>
/// <param name="videoIds"> Video ids. </param>
/// <returns> Voted status of videos. </returns>
Task<IReadOnlyList<long>> VotedOfVideos(List<long> videoIds);


/// <summary>
/// Scan voted videos, which means paging through all voted videos.
/// </summary>
/// <summary> Scan voted videos, which means paging through all voted videos. </summary>
/// <param name="userId"> User id. </param>
/// <param name="page"> Page token. </param>
/// <param name="size"> Page size. </param>
Expand Down Expand Up @@ -184,6 +176,41 @@ public record InQuery(List<long> ObjectIds);
[JsonSerializable(typeof(InQuery))]
internal partial class VoteJsonContext : JsonSerializerContext;

public class SearchClient(IHttpClientFactory clientFactory)
{

public async Task<IReadOnlyList<long>> SimilarSearch(long videoId)
{
using var client = clientFactory.CreateClient("Search");
var body = new RequestBody(videoId, ["id"]);
var content = JsonContent.Create(body, SearchContext.Default.RequestBody);
var req = new HttpRequestMessage(HttpMethod.Post, "/indexes/videos/similar") { Content = content };
var resp = await client.SendAsync(req);
resp.EnsureSuccessStatusCode();

var result = await resp.Content.ReadFromJsonAsync(SearchContext.Default.Response) ?? new Response();

return result.Hits.Select(h => h.Id).ToList();
}

}
public record RequestBody(
[property: JsonPropertyName("id")] long Id,
[property: JsonPropertyName("attributesToRetrieve")] string[] AttributesToRetrieve

Check warning on line 199 in sharp/content/repository/Client.cs

View check run for this annotation

Codecov / codecov/patch

sharp/content/repository/Client.cs#L198-L199

Added lines #L198 - L199 were not covered by tests
);

public record Response

Check warning on line 202 in sharp/content/repository/Client.cs

View check run for this annotation

Codecov / codecov/patch

sharp/content/repository/Client.cs#L202

Added line #L202 was not covered by tests
{
[JsonPropertyName("hits")]
public IReadOnlyList<SimilarVideo> Hits { get; init; } = [];
}

public record SimilarVideo([property: JsonPropertyName("id")] long Id);

[JsonSerializable(typeof(Response))]
[JsonSerializable(typeof(RequestBody))]
public partial class SearchContext : JsonSerializerContext;

public static class Extension
{
public static IServiceCollection AddVoteRepository(this IServiceCollection services, string baseAddress)
Expand All @@ -194,6 +221,15 @@ public static IServiceCollection AddVoteRepository(this IServiceCollection servi
return services;
}

public static IServiceCollection AddSearchClient(this IServiceCollection services, string baseAddress, string token)
{
services.AddScoped<SearchClient>().AddHttpClient("Search", client => {
client.BaseAddress = new Uri(baseAddress.TrimEnd('/'));
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", token);
});
return services;
}

Check warning on line 231 in sharp/content/repository/Client.cs

View check run for this annotation

Codecov / codecov/patch

sharp/content/repository/Client.cs#L225-L231

Added lines #L225 - L231 were not covered by tests

public static IServiceCollection AddUserRepository(this IServiceCollection services) =>
services.AddScoped<IUserRepository, UserRepository>();

Expand Down
Loading

0 comments on commit 8e4586a

Please sign in to comment.