@@ -503,6 +503,59 @@ TEST_F(PjrtCApiTest, PluginAttributes) {
503
503
EXPECT_TRUE (names.find (" stablehlo_minimum_version" ) != names.end ());
504
504
}
505
505
506
+ TEST_F (PjrtCApiTest, ExecutableOutputDimensions) {
507
+ // First compile an executable
508
+ PJRT_Client_Compile_Args compile_args;
509
+ compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
510
+ compile_args.extension_start = nullptr ;
511
+ compile_args.client = client_;
512
+ std::string options_str = BuildSingleDeviceCompileOptionStr ();
513
+ compile_args.compile_options = options_str.c_str ();
514
+ compile_args.compile_options_size = options_str.size ();
515
+
516
+ std::string format (::pjrt::kMlirFormat );
517
+ std::string program_code{module_add_one};
518
+ PJRT_Program program;
519
+ program.struct_size = PJRT_Program_STRUCT_SIZE;
520
+ program.extension_start = nullptr ;
521
+ program.code = program_code.data ();
522
+ program.code_size = program_code.length ();
523
+ program.format = format.c_str ();
524
+ program.format_size = format.size ();
525
+ compile_args.program = &program;
526
+
527
+ PJRT_Error* error = api_->PJRT_Client_Compile (&compile_args);
528
+ ASSERT_EQ (nullptr , error);
529
+
530
+ // Now test output dimensions
531
+ PJRT_Executable_OutputDimensions_Args args;
532
+ args.struct_size = PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE;
533
+ args.extension_start = nullptr ;
534
+ args.executable = compile_args.executable ;
535
+ args.num_outputs = 0 ; // Should be set by the API call
536
+ args.dims = nullptr ;
537
+ args.dim_sizes = nullptr ;
538
+ args.error = nullptr ;
539
+
540
+ PJRT_Error* dim_error = api_->PJRT_Executable_GetOutputDimensions (&args);
541
+ ASSERT_EQ (nullptr , dim_error);
542
+
543
+ // Verify that num_outputs was set properly
544
+ EXPECT_GT (args.num_outputs , 0u );
545
+ ASSERT_NE (args.dim_sizes , nullptr );
546
+ ASSERT_NE (args.dims , nullptr );
547
+
548
+ // Verify that we have the correct number of output shapes
549
+ size_t total_dims = 0 ;
550
+ for (size_t i = 0 ; i < args.num_outputs ; ++i) {
551
+ total_dims += args.dim_sizes [i];
552
+ }
553
+ EXPECT_GT (total_dims, 0 );
554
+
555
+ // Clean up
556
+ destroy_executable (compile_args.executable , api_);
557
+ }
558
+
506
559
// --------------------------------- Devices -----------------------------------
507
560
508
561
TEST_F (PjrtCApiTest, DeviceId) {
0 commit comments