@@ -504,34 +504,19 @@ TEST_F(PjrtCApiTest, PluginAttributes) {
504
504
}
505
505
506
506
TEST_F (PjrtCApiTest, ExecutableOutputDimensions) {
507
- // First compile an executable
508
- PJRT_Client_Compile_Args compile_args;
509
- compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
510
- compile_args.extension_start = nullptr ;
511
- compile_args.client = client_;
512
- std::string options_str = BuildSingleDeviceCompileOptionStr ();
513
- compile_args.compile_options = options_str.c_str ();
514
- compile_args.compile_options_size = options_str.size ();
507
+ // Create executable using the helper
508
+ auto executable_or = create_executable (api_, client_);
509
+ ASSERT_NE (executable_or.get (), nullptr );
515
510
516
- std::string format (::pjrt::kMlirFormat );
517
- std::string program_code{module_add_one};
518
- PJRT_Program program;
519
- program.struct_size = PJRT_Program_STRUCT_SIZE;
520
- program.extension_start = nullptr ;
521
- program.code = program_code.data ();
522
- program.code_size = program_code.length ();
523
- program.format = format.c_str ();
524
- program.format_size = format.size ();
525
- compile_args.program = &program;
526
-
527
- PJRT_Error* error = api_->PJRT_Client_Compile (&compile_args);
528
- ASSERT_EQ (nullptr , error);
511
+ // Get the underlying PJRT_Executable
512
+ auto executable = GetExecutable (executable_or.get (), api_);
513
+ ASSERT_NE (executable.get (), nullptr );
529
514
530
515
// Now test output dimensions
531
516
PJRT_Executable_OutputDimensions_Args args;
532
517
args.struct_size = PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE;
533
518
args.extension_start = nullptr ;
534
- args.executable = compile_args. executable ;
519
+ args.executable = executable. get () ;
535
520
args.num_outputs = 0 ; // Should be set by the API call
536
521
args.dims = nullptr ;
537
522
args.dim_sizes = nullptr ;
@@ -541,7 +526,8 @@ TEST_F(PjrtCApiTest, ExecutableOutputDimensions) {
541
526
ASSERT_EQ (nullptr , dim_error);
542
527
543
528
// Verify that num_outputs was set properly
544
- EXPECT_GT (args.num_outputs , 0u );
529
+ const size_t expected_num_outputs = 1 ; // The add_one computation has 1 output
530
+ EXPECT_EQ (args.num_outputs , expected_num_outputs);
545
531
ASSERT_NE (args.dim_sizes , nullptr );
546
532
ASSERT_NE (args.dims , nullptr );
547
533
@@ -551,9 +537,6 @@ TEST_F(PjrtCApiTest, ExecutableOutputDimensions) {
551
537
total_dims += args.dim_sizes [i];
552
538
}
553
539
EXPECT_GT (total_dims, 0 );
554
-
555
- // Clean up
556
- destroy_executable (compile_args.executable , api_);
557
540
}
558
541
559
542
// --------------------------------- Devices -----------------------------------
0 commit comments