Skip to content

Commit

Permalink
Refactor existing drivers to bring the code closer together
Browse files Browse the repository at this point in the history
  • Loading branch information
CharliePoole committed Jan 4, 2025
1 parent d9af6bb commit d82d271
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 84 deletions.
10 changes: 4 additions & 6 deletions src/NUnitEngine/nunit.engine.core/Drivers/NUnit3DriverFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ public bool IsSupportedTestFramework(AssemblyName reference)

#if NETFRAMEWORK
/// <summary>
/// Gets a driver for a given test assembly and a framework
/// which the assembly is already known to reference.
/// Gets a driver for a given test framework.
/// </summary>
/// <param name="domain">The domain in which the assembly will be loaded</param>
/// <param name="reference">An AssemblyName referring to the test framework.</param>
/// <returns></returns>
/// <returns>An IFrameworkDriver</returns>
public IFrameworkDriver GetDriver(AppDomain domain, AssemblyName reference)
{
Guard.ArgumentValid(IsSupportedTestFramework(reference), "Invalid framework", "reference");
Expand All @@ -39,16 +38,15 @@ public IFrameworkDriver GetDriver(AppDomain domain, AssemblyName reference)
}
#else
/// <summary>
/// Gets a driver for a given test assembly and a framework
/// which the assembly is already known to reference.
/// Gets a driver for a given test framework.
/// </summary>
/// <param name="reference">An AssemblyName referring to the test framework.</param>
/// <returns></returns>
public IFrameworkDriver GetDriver(AssemblyName reference)
{
Guard.ArgumentValid(IsSupportedTestFramework(reference), "Invalid framework", "reference");
log.Info("Using NUnitNetCore31Driver");
return new NUnitNetCore31Driver();
return new NUnitNetCore31Driver(reference);
}
#endif
}
Expand Down
55 changes: 31 additions & 24 deletions src/NUnitEngine/nunit.engine.core/Drivers/NUnit3FrameworkDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using NUnit.Common;
using NUnit.Engine.Internal;
using NUnit.Engine.Extensibility;
using System.Diagnostics.CodeAnalysis;

namespace NUnit.Engine.Drivers
{
Expand All @@ -19,51 +18,60 @@ namespace NUnit.Engine.Drivers
/// </summary>
public class NUnit3FrameworkDriver : IFrameworkDriver
{
// Messages
private const string LOAD_MESSAGE = "Method called without calling Load first";

// API Constants
private static readonly string CONTROLLER_TYPE = "NUnit.Framework.Api.FrameworkController";
private static readonly string LOAD_ACTION = CONTROLLER_TYPE + "+LoadTestsAction";
private static readonly string EXPLORE_ACTION = CONTROLLER_TYPE + "+ExploreTestsAction";
private static readonly string COUNT_ACTION = CONTROLLER_TYPE + "+CountTestsAction";
private static readonly string RUN_ACTION = CONTROLLER_TYPE + "+RunTestsAction";
private static readonly string STOP_RUN_ACTION = CONTROLLER_TYPE + "+StopRunAction";

static readonly ILogger log = InternalTrace.GetLogger("NUnitFrameworkDriver");
static readonly ILogger log = InternalTrace.GetLogger(nameof(NUnit3FrameworkDriver));

readonly AppDomain _testDomain;
readonly AssemblyName _reference;
readonly AssemblyName _nunitRef;
string? _testAssemblyPath;

object? _frameworkController;
Type? _frameworkControllerType;

/// <summary>
/// Construct an NUnit3FrameworkDriver
/// </summary>
/// <param name="testDomain">The application domain in which to create the FrameworkController</param>
/// <param name="reference">An AssemblyName referring to the test framework.</param>
public NUnit3FrameworkDriver(AppDomain testDomain, AssemblyName reference)
/// <param name="nunitRef">An AssemblyName referring to the test framework.</param>
public NUnit3FrameworkDriver(AppDomain testDomain, AssemblyName nunitRef)
{
_testDomain = testDomain;
_reference = reference;
_nunitRef = nunitRef;
}

/// <summary>
/// An id prefix that will be passed to the test framework and used as part of the
/// test ids created.
/// </summary>
public string ID { get; set; } = string.Empty;

/// <summary>
/// Loads the tests in an assembly.
/// </summary>
/// <returns>An Xml string representing the loaded test</returns>
/// <param name="testAssemblyPath">The path to the test assembly</param>
/// <param name="settings">The test settings</param>
/// <returns>An XML string representing the loaded test</returns>
public string Load(string testAssemblyPath, IDictionary<string, object> settings)
{
Guard.ArgumentValid(File.Exists(testAssemblyPath), "Framework driver constructor called with a file name that doesn't exist.", "testAssemblyPath");

Guard.ArgumentValid(File.Exists(testAssemblyPath), "Framework driver called with a file name that doesn't exist.", "testAssemblyPath");
log.Debug($"Loading {testAssemblyPath}");
var idPrefix = string.IsNullOrEmpty(ID) ? "" : ID + "-";

// Normally, the runner should check for an invalid requested runtime, but we make sure here
// Normally, the caller should check for an invalid requested runtime, but we make sure here
var requestedRuntime = settings.ContainsKey(EnginePackageSettings.RequestedRuntimeFramework)
? settings[EnginePackageSettings.RequestedRuntimeFramework] : null;

_testAssemblyPath = testAssemblyPath;
_testAssemblyPath = Path.GetFullPath(testAssemblyPath);

try
{
Expand All @@ -78,6 +86,9 @@ public string Load(string testAssemblyPath, IDictionary<string, object> settings
throw new NUnitEngineException("The NUnit 3 driver cannot support this test assembly. Use a platform specific runner.", ex);
}

_frameworkControllerType = _frameworkController.GetType();
log.Debug($"Created FrameworkControler {_frameworkControllerType.Name}");

CallbackHandler handler = new CallbackHandler();

var fileName = Path.GetFileName(_testAssemblyPath);
Expand All @@ -86,19 +97,21 @@ public string Load(string testAssemblyPath, IDictionary<string, object> settings

CreateObject(LOAD_ACTION, _frameworkController, handler);

log.Info("Loaded {0}", fileName);
log.Debug($"Loaded {testAssemblyPath}");

return handler.Result.ShouldNotBeNull();
}

/// <summary>
/// Counts the number of test cases for the loaded test assembly
/// </summary>
/// <param name="filter">The XML test filter</param>
/// <returns>The number of test cases</returns>
public int CountTestCases(string filter)
{
CheckLoadWasCalled();

CallbackHandler handler = new CallbackHandler();

CreateObject(COUNT_ACTION, _frameworkController.ShouldNotBeNull(), filter, handler);

return int.Parse(handler.Result.ShouldNotBeNull());
}

Expand All @@ -111,12 +124,9 @@ public int CountTestCases(string filter)
public string Run(ITestEventListener? listener, string filter)
{
CheckLoadWasCalled();

var handler = new RunTestsCallbackHandler(listener);

log.Info("Running {0} - see separate log file", Path.GetFileName(_testAssemblyPath.ShouldNotBeNull()));
var handler = new RunTestsCallbackHandler(listener);
CreateObject(RUN_ACTION, _frameworkController.ShouldNotBeNull(), filter, handler);

return handler.Result.ShouldNotBeNull();
}

Expand All @@ -137,12 +147,9 @@ public void StopRun(bool force)
public string Explore(string filter)
{
CheckLoadWasCalled();

CallbackHandler handler = new CallbackHandler();

log.Info("Exploring {0} - see separate log file", Path.GetFileName(_testAssemblyPath.ShouldNotBeNull()));
CallbackHandler handler = new CallbackHandler();
CreateObject(EXPLORE_ACTION, _frameworkController.ShouldNotBeNull(), filter, handler);

return handler.Result.ShouldNotBeNull();
}

Expand All @@ -157,7 +164,7 @@ private object CreateObject(string typeName, params object?[]? args)
try
{
return _testDomain.CreateInstanceAndUnwrap(
_reference.FullName, typeName, false, 0, null, args, null, null )!;
_nunitRef.FullName, typeName, false, 0, null, args, null, null )!;
}
catch (TargetInvocationException ex)
{
Expand Down
135 changes: 83 additions & 52 deletions src/NUnitEngine/nunit.engine.core/Drivers/NUnitNetCore31Driver.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
// Copyright (c) Charlie Poole, Rob Prouse and Contributors. MIT License - see LICENSE.txt

#if NETCOREAPP3_1_OR_GREATER
#if NETCOREAPP
using System;
using System.Linq;
using System.Collections.Generic;
using System.IO;
using NUnit.Engine.Internal;
using System.Reflection;
using NUnit.Engine.Extensibility;
using System.Diagnostics;
using NUnit.Common;
using NUnit.Engine.Internal;
using NUnit.Engine.Extensibility;

namespace NUnit.Engine.Drivers
{
Expand All @@ -21,27 +20,40 @@ namespace NUnit.Engine.Drivers
/// </summary>
public class NUnitNetCore31Driver : IFrameworkDriver
{
const string LOAD_MESSAGE = "Method called without calling Load first";
private const string LOAD_MESSAGE = "Method called without calling Load first";
const string INVALID_FRAMEWORK_MESSAGE = "Running tests against this version of the framework using this driver is not supported. Please update NUnit.Framework to the latest version.";
const string FAILED_TO_LOAD_TEST_ASSEMBLY = "Failed to load the test assembly {0}";
const string FAILED_TO_LOAD_ASSEMBLY = "Failed to load assembly ";
const string FAILED_TO_LOAD_NUNIT = "Failed to load the NUnit Framework in the test assembly";

static readonly string CONTROLLER_TYPE = "NUnit.Framework.Api.FrameworkController";
static readonly string LOAD_METHOD = "LoadTests";
static readonly string EXPLORE_METHOD = "ExploreTests";
static readonly string COUNT_METHOD = "CountTests";
static readonly string RUN_METHOD = "RunTests";
static readonly string RUN_ASYNC_METHOD = "RunTests";
static readonly string STOP_RUN_METHOD = "StopRun";
private static readonly string CONTROLLER_TYPE = "NUnit.Framework.Api.FrameworkController";
private static readonly string LOAD_METHOD = "LoadTests";
private static readonly string EXPLORE_METHOD = "ExploreTests";
private static readonly string COUNT_METHOD = "CountTests";
private static readonly string RUN_METHOD = "RunTests";
private static readonly string RUN_ASYNC_METHOD = "RunTests";
private static readonly string STOP_RUN_METHOD = "StopRun";

static ILogger log = InternalTrace.GetLogger(nameof(NUnitNetCore31Driver));
static readonly ILogger log = InternalTrace.GetLogger(nameof(NUnitNetCore31Driver));

readonly AssemblyName _nunitRef;
string? _testAssemblyPath;

Assembly? _testAssembly;
Assembly? _frameworkAssembly;
object? _frameworkController;
Type? _frameworkControllerType;
Assembly? _testAssembly;
Assembly? _frameworkAssembly;
TestAssemblyLoadContext? _assemblyLoadContext;

/// <summary>
/// Construct an NUnitNetCore31Driver
/// </summary>
/// <param name="reference">An AssemblyName referring to the test framework.</param>
public NUnitNetCore31Driver(AssemblyName nunitRef)
{
Guard.ArgumentNotNull(nunitRef, nameof(nunitRef));
_nunitRef = nunitRef;
}

/// <summary>
/// An id prefix that will be passed to the test framework and used as part of the
/// test ids created.
Expand All @@ -51,46 +63,24 @@ public class NUnitNetCore31Driver : IFrameworkDriver
/// <summary>
/// Loads the tests in an assembly.
/// </summary>
/// <param name="assemblyPath">The path to the test assembly</param>
/// <param name="testAssemblyPath">The path to the test assembly</param>
/// <param name="settings">The test settings</param>
/// <returns>An XML string representing the loaded test</returns>
public string Load(string assemblyPath, IDictionary<string, object> settings)
public string Load(string testAssemblyPath, IDictionary<string, object> settings)
{
log.Debug($"Loading {assemblyPath}");
Guard.ArgumentValid(File.Exists(testAssemblyPath), "Framework driver called with a file name that doesn't exist.", "testAssemblyPath");
log.Debug($"Loading {testAssemblyPath}");
var idPrefix = string.IsNullOrEmpty(ID) ? "" : ID + "-";

assemblyPath = Path.GetFullPath(assemblyPath); //AssemblyLoadContext requires an absolute path
_assemblyLoadContext = new TestAssemblyLoadContext(assemblyPath);

try
{
_testAssembly = _assemblyLoadContext.LoadFromAssemblyPath(assemblyPath);
}
catch (Exception e)
{
var msg = string.Format(FAILED_TO_LOAD_TEST_ASSEMBLY, assemblyPath);
log.Error(msg);
throw new NUnitEngineException(msg, e);
}
log.Debug($"Loaded {assemblyPath}");
// Normally, the caller should check for an invalid requested runtime, but we make sure here
var requestedRuntime = settings.ContainsKey(EnginePackageSettings.RequestedRuntimeFramework)
? settings[EnginePackageSettings.RequestedRuntimeFramework] : null;

var nunitRef = _testAssembly.GetReferencedAssemblies().FirstOrDefault(reference => string.Equals(reference.Name, "nunit.framework", StringComparison.OrdinalIgnoreCase));
if (nunitRef == null)
{
log.Error(FAILED_TO_LOAD_NUNIT);
throw new NUnitEngineException(FAILED_TO_LOAD_NUNIT);
}
_testAssemblyPath = Path.GetFullPath(testAssemblyPath);
_assemblyLoadContext = new TestAssemblyLoadContext(_testAssemblyPath);

try
{
_frameworkAssembly = _assemblyLoadContext.LoadFromAssemblyName(nunitRef);
}
catch (Exception e)
{
log.Error($"{FAILED_TO_LOAD_NUNIT}\r\n{e}");
throw new NUnitEngineException(FAILED_TO_LOAD_NUNIT, e);
}
log.Debug("Loaded nunit.framework");
_testAssembly = LoadAssembly(_testAssemblyPath!);
_frameworkAssembly = LoadAssembly(_nunitRef);

_frameworkController = CreateObject(CONTROLLER_TYPE, _testAssembly, idPrefix, settings);
if (_frameworkController == null)
Expand Down Expand Up @@ -161,23 +151,64 @@ public void StopRun(bool force)
public string Explore(string filter)
{
CheckLoadWasCalled();

log.Info("Exploring {0} - see separate log file", _testAssembly.ShouldNotBeNull().FullName!);
return (string)ExecuteMethod(EXPLORE_METHOD, filter);
}

void CheckLoadWasCalled()
private void CheckLoadWasCalled()
{
if (_frameworkController == null)
throw new InvalidOperationException(LOAD_MESSAGE);
}

object CreateObject(string typeName, params object?[]? args)
private object CreateObject(string typeName, params object?[]? args)
{
var type = _frameworkAssembly.ShouldNotBeNull().GetType(typeName, throwOnError: true)!;
return Activator.CreateInstance(type, args)!;
}

private Assembly LoadAssembly(string assemblyPath)
{
Assembly assembly;

try
{
assembly = _assemblyLoadContext?.LoadFromAssemblyPath(assemblyPath)!;
if (assembly == null)
throw new Exception("LoadFromAssemblyPath returned null");
}
catch (Exception e)
{
var msg = string.Format(FAILED_TO_LOAD_ASSEMBLY + assemblyPath);
log.Error(msg);
throw new NUnitEngineException(msg, e);
}

log.Debug($"Loaded {assemblyPath}");
return assembly;
}

private Assembly LoadAssembly(AssemblyName assemblyName)
{
Assembly assembly;

try
{
assembly = _assemblyLoadContext?.LoadFromAssemblyName(assemblyName)!;
if (assembly == null)
throw new Exception("LoadFromAssemblyName returned null");
}
catch (Exception e)
{
var msg = string.Format(FAILED_TO_LOAD_ASSEMBLY + assemblyName.FullName);
log.Error($"{FAILED_TO_LOAD_ASSEMBLY}\r\n{e}");
throw new NUnitEngineException(FAILED_TO_LOAD_NUNIT, e);
}

log.Debug($"Loaded {assemblyName.FullName}");
return assembly;
}

object ExecuteMethod(string methodName, params object?[] args)
{
var method = _frameworkControllerType.ShouldNotBeNull().GetMethod(methodName, BindingFlags.Public | BindingFlags.Instance);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ public class TestFilteringTests
public void LoadAssembly()
{
var mockAssemblyPath = System.IO.Path.Combine(TestContext.CurrentContext.TestDirectory, MOCK_ASSEMBLY);
var assemblyName = typeof(TestAttribute).Assembly.GetName();
#if NETCOREAPP3_1_OR_GREATER
_driver = new NUnitNetCore31Driver();
_driver = new NUnitNetCore31Driver(assemblyName);
#else
var assemblyName = typeof(NUnit.Framework.TestAttribute).Assembly.GetName();
_driver = new NUnit3FrameworkDriver(AppDomain.CurrentDomain, assemblyName);
#endif
_driver.Load(mockAssemblyPath, new Dictionary<string, object>());
Expand Down

0 comments on commit d82d271

Please sign in to comment.