Skip to content

Commit c134a92

Browse files
committed
Refactor ExecutableOutputDimensions test: use CreateExecutable() helper
1 parent db69918 commit c134a92

File tree

1 file changed

+9
-26
lines changed

1 file changed

+9
-26
lines changed

xla/pjrt/c/pjrt_c_api_test.cc

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -504,34 +504,19 @@ TEST_F(PjrtCApiTest, PluginAttributes) {
504504
}
505505

506506
TEST_F(PjrtCApiTest, ExecutableOutputDimensions) {
507-
// First compile an executable
508-
PJRT_Client_Compile_Args compile_args;
509-
compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
510-
compile_args.extension_start = nullptr;
511-
compile_args.client = client_;
512-
std::string options_str = BuildSingleDeviceCompileOptionStr();
513-
compile_args.compile_options = options_str.c_str();
514-
compile_args.compile_options_size = options_str.size();
507+
// Create executable using the helper
508+
auto executable_or = create_executable(api_, client_);
509+
ASSERT_NE(executable_or.get(), nullptr);
515510

516-
std::string format(::pjrt::kMlirFormat);
517-
std::string program_code{module_add_one};
518-
PJRT_Program program;
519-
program.struct_size = PJRT_Program_STRUCT_SIZE;
520-
program.extension_start = nullptr;
521-
program.code = program_code.data();
522-
program.code_size = program_code.length();
523-
program.format = format.c_str();
524-
program.format_size = format.size();
525-
compile_args.program = &program;
526-
527-
PJRT_Error* error = api_->PJRT_Client_Compile(&compile_args);
528-
ASSERT_EQ(nullptr, error);
511+
// Get the underlying PJRT_Executable
512+
auto executable = GetExecutable(executable_or.get(), api_);
513+
ASSERT_NE(executable.get(), nullptr);
529514

530515
// Now test output dimensions
531516
PJRT_Executable_OutputDimensions_Args args;
532517
args.struct_size = PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE;
533518
args.extension_start = nullptr;
534-
args.executable = compile_args.executable;
519+
args.executable = executable.get();
535520
args.num_outputs = 0; // Should be set by the API call
536521
args.dims = nullptr;
537522
args.dim_sizes = nullptr;
@@ -541,7 +526,8 @@ TEST_F(PjrtCApiTest, ExecutableOutputDimensions) {
541526
ASSERT_EQ(nullptr, dim_error);
542527

543528
// Verify that num_outputs was set properly
544-
EXPECT_GT(args.num_outputs, 0u);
529+
const size_t expected_num_outputs = 1; // The add_one computation has 1 output
530+
EXPECT_EQ(args.num_outputs, expected_num_outputs);
545531
ASSERT_NE(args.dim_sizes, nullptr);
546532
ASSERT_NE(args.dims, nullptr);
547533

@@ -551,9 +537,6 @@ TEST_F(PjrtCApiTest, ExecutableOutputDimensions) {
551537
total_dims += args.dim_sizes[i];
552538
}
553539
EXPECT_GT(total_dims, 0);
554-
555-
// Clean up
556-
destroy_executable(compile_args.executable, api_);
557540
}
558541

559542
// --------------------------------- Devices -----------------------------------

0 commit comments

Comments
 (0)