Skip to content

Commit db69918

Browse files
committed
Fix #25211: mark num_outputs as out parameter in PJRT_Executable_OutputDimensions_Args
- Annotate num_outputs in pjrt_c_api.h as /* out */ - Add C API test verifying num_outputs, dims, dim_sizes
1 parent bf3f0a4 commit db69918

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

xla/pjrt/c/pjrt_c_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1713,7 +1713,7 @@ struct PJRT_Executable_OutputDimensions_Args {
17131713
size_t struct_size;
17141714
PJRT_Extension_Base* extension_start;
17151715
PJRT_Executable* executable;
1716-
size_t num_outputs;
1716+
size_t num_outputs; // out - Number of output shapes
17171717
// Has length: sum of all elements in the list `dim_sizes`.
17181718
const int64_t* dims; // out
17191719
// Has length `num_outputs`.

xla/pjrt/c/pjrt_c_api_test.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,59 @@ TEST_F(PjrtCApiTest, PluginAttributes) {
503503
EXPECT_TRUE(names.find("stablehlo_minimum_version") != names.end());
504504
}
505505

506+
TEST_F(PjrtCApiTest, ExecutableOutputDimensions) {
507+
// First compile an executable
508+
PJRT_Client_Compile_Args compile_args;
509+
compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
510+
compile_args.extension_start = nullptr;
511+
compile_args.client = client_;
512+
std::string options_str = BuildSingleDeviceCompileOptionStr();
513+
compile_args.compile_options = options_str.c_str();
514+
compile_args.compile_options_size = options_str.size();
515+
516+
std::string format(::pjrt::kMlirFormat);
517+
std::string program_code{module_add_one};
518+
PJRT_Program program;
519+
program.struct_size = PJRT_Program_STRUCT_SIZE;
520+
program.extension_start = nullptr;
521+
program.code = program_code.data();
522+
program.code_size = program_code.length();
523+
program.format = format.c_str();
524+
program.format_size = format.size();
525+
compile_args.program = &program;
526+
527+
PJRT_Error* error = api_->PJRT_Client_Compile(&compile_args);
528+
ASSERT_EQ(nullptr, error);
529+
530+
// Now test output dimensions
531+
PJRT_Executable_OutputDimensions_Args args;
532+
args.struct_size = PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE;
533+
args.extension_start = nullptr;
534+
args.executable = compile_args.executable;
535+
args.num_outputs = 0; // Should be set by the API call
536+
args.dims = nullptr;
537+
args.dim_sizes = nullptr;
538+
args.error = nullptr;
539+
540+
PJRT_Error* dim_error = api_->PJRT_Executable_GetOutputDimensions(&args);
541+
ASSERT_EQ(nullptr, dim_error);
542+
543+
// Verify that num_outputs was set properly
544+
EXPECT_GT(args.num_outputs, 0u);
545+
ASSERT_NE(args.dim_sizes, nullptr);
546+
ASSERT_NE(args.dims, nullptr);
547+
548+
// Verify that we have the correct number of output shapes
549+
size_t total_dims = 0;
550+
for (size_t i = 0; i < args.num_outputs; ++i) {
551+
total_dims += args.dim_sizes[i];
552+
}
553+
EXPECT_GT(total_dims, 0);
554+
555+
// Clean up
556+
destroy_executable(compile_args.executable, api_);
557+
}
558+
506559
// --------------------------------- Devices -----------------------------------
507560

508561
TEST_F(PjrtCApiTest, DeviceId) {

0 commit comments

Comments
 (0)