diff --git a/source/val/validate.h b/source/val/validate.h index 52267c8ab6..d20a7fd4fb 100644 --- a/source/val/validate.h +++ b/source/val/validate.h @@ -210,6 +210,11 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst); /// Validates correctness of ray tracing instructions. spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst); +/// Validates constant values of the Ray Flags operand +spv_result_t RayTracingRayFlagsOperandValue(ValidationState_t& _, + const Instruction* inst, + uint32_t ray_flags_value); + /// Validates correctness of shader execution reorder instructions. spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst); diff --git a/source/val/validate_ray_query.cpp b/source/val/validate_ray_query.cpp index 9b67fc922b..01be5e7811 100644 --- a/source/val/validate_ray_query.cpp +++ b/source/val/validate_ray_query.cpp @@ -82,10 +82,19 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) { "OpTypeAccelerationStructureKHR"; } - const uint32_t ray_flags = _.GetOperandTypeId(inst, 2); - if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) { + const uint32_t ray_flags = inst->GetOperandAs(2); + bool is_ray_flags_int32 = false; + bool is_ray_flags_const = false; + uint32_t ray_flags_value = 0; + std::tie(is_ray_flags_int32, is_ray_flags_const, ray_flags_value) = + _.EvalInt32IfConst(ray_flags); + if (!is_ray_flags_int32) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Ray Flags must be a 32-bit int scalar"; + } else if (is_ray_flags_const) { + if (auto error = + RayTracingRayFlagsOperandValue(_, inst, ray_flags_value)) + return error; } const uint32_t cull_mask = _.GetOperandTypeId(inst, 3); @@ -101,10 +110,20 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) { << "Ray Origin must be a 32-bit float 3-component vector"; } - const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5); - if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) { + const uint32_t ray_tmin = inst->GetOperandAs(5); + bool is_ray_tmin_float32 = false; + bool is_ray_tmin_const = false; + float ray_tmin_value = 0; + std::tie(is_ray_tmin_float32, is_ray_tmin_const, ray_tmin_value) = + _.EvalFloat32IfConst(ray_tmin); + if (!is_ray_tmin_float32) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Ray TMin must be a 32-bit float scalar"; + } else if (is_ray_tmin_const && ray_tmin_value < 0.0f) { + // Don't need to check TMax for being negative because it can't without + // being less than TMin + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Tmin is negative (" << ray_tmin_value << ")"; } const uint32_t ray_direction = _.GetOperandTypeId(inst, 6); @@ -115,11 +134,23 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) { << "Ray Direction must be a 32-bit float 3-component vector"; } - const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7); - if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) { + const uint32_t ray_tmax = inst->GetOperandAs(7); + bool is_ray_tmax_float32 = false; + bool is_ray_tmax_const = false; + float ray_tmax_value = 0; + std::tie(is_ray_tmax_float32, is_ray_tmax_const, ray_tmax_value) = + _.EvalFloat32IfConst(ray_tmax); + if (!is_ray_tmax_float32) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Ray TMax must be a 32-bit float scalar"; } + + if (is_ray_tmin_const && is_ray_tmax_const && + ray_tmin_value > ray_tmax_value) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Tmin (" << ray_tmin_value + << ") is larger than Ray Tmax (" << ray_tmax_value << ")"; + } break; } diff --git a/source/val/validate_ray_tracing.cpp b/source/val/validate_ray_tracing.cpp index f74e9d4b9d..14d26367c1 100644 --- a/source/val/validate_ray_tracing.cpp +++ b/source/val/validate_ray_tracing.cpp @@ -22,6 +22,41 @@ namespace spvtools { namespace val { +spv_result_t RayTracingRayFlagsOperandValue(ValidationState_t& _, + const Instruction* inst, + uint32_t ray_flags_value) { + const auto HasMoreThenOneBitSet = + [ray_flags_value](const spv::RayFlagsMask ray_flag_mask) { + const uint32_t mask = + ray_flags_value & static_cast(ray_flag_mask); + return mask != 0 && (mask & (mask - 1)); + }; + + if (HasMoreThenOneBitSet(spv::RayFlagsMask::SkipAABBsKHR | + spv::RayFlagsMask::SkipTrianglesKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Flags contains both SkipTrianglesKHR and SkipAABBsKHR"; + } + + if (HasMoreThenOneBitSet(spv::RayFlagsMask::SkipTrianglesKHR | + spv::RayFlagsMask::CullFrontFacingTrianglesKHR | + spv::RayFlagsMask::CullBackFacingTrianglesKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Flags contains more than one of SkipTrianglesKHR or " + "CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR"; + } + + if (HasMoreThenOneBitSet(spv::RayFlagsMask::OpaqueKHR | + spv::RayFlagsMask::NoOpaqueKHR | + spv::RayFlagsMask::CullOpaqueKHR | + spv::RayFlagsMask::CullNoOpaqueKHR)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Flags contains more than one of OpaqueKHR or NoOpaqueKHR or " + "CullOpaqueKHR or CullNoOpaqueKHR"; + } + return SPV_SUCCESS; +} + spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) { const spv::Op opcode = inst->opcode(); const uint32_t result_type = inst->type_id(); @@ -51,10 +86,19 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) { "OpTypeAccelerationStructureKHR"; } - const uint32_t ray_flags = _.GetOperandTypeId(inst, 1); - if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) { + const uint32_t ray_flags = inst->GetOperandAs(1); + bool is_ray_flags_int32 = false; + bool is_ray_flags_const = false; + uint32_t ray_flags_value = 0; + std::tie(is_ray_flags_int32, is_ray_flags_const, ray_flags_value) = + _.EvalInt32IfConst(ray_flags); + if (!is_ray_flags_int32) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Ray Flags must be a 32-bit int scalar"; + } else if (is_ray_flags_const) { + if (auto error = + RayTracingRayFlagsOperandValue(_, inst, ray_flags_value)) + return error; } const uint32_t cull_mask = _.GetOperandTypeId(inst, 2); @@ -88,10 +132,20 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) { << "Ray Origin must be a 32-bit float 3-component vector"; } - const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7); - if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) { + const uint32_t ray_tmin = inst->GetOperandAs(7); + bool is_ray_tmin_float32 = false; + bool is_ray_tmin_const = false; + float ray_tmin_value = 0; + std::tie(is_ray_tmin_float32, is_ray_tmin_const, ray_tmin_value) = + _.EvalFloat32IfConst(ray_tmin); + if (!is_ray_tmin_float32) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Ray TMin must be a 32-bit float scalar"; + } else if (is_ray_tmin_const && ray_tmin_value < 0.0f) { + // Don't need to check TMax for being negative because it can't without + // being less than TMin + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Tmin is negative (" << ray_tmin_value << ")"; } const uint32_t ray_direction = _.GetOperandTypeId(inst, 8); @@ -102,12 +156,24 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) { << "Ray Direction must be a 32-bit float 3-component vector"; } - const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9); - if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) { + const uint32_t ray_tmax = inst->GetOperandAs(9); + bool is_ray_tmax_float32 = false; + bool is_ray_tmax_const = false; + float ray_tmax_value = 0; + std::tie(is_ray_tmax_float32, is_ray_tmax_const, ray_tmax_value) = + _.EvalFloat32IfConst(ray_tmax); + if (!is_ray_tmax_float32) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Ray TMax must be a 32-bit float scalar"; } + if (is_ray_tmin_const && is_ray_tmax_const && + ray_tmin_value > ray_tmax_value) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Tmin (" << ray_tmin_value + << ") is larger than Ray Tmax (" << ray_tmax_value << ")"; + } + const Instruction* payload = _.FindDef(inst->GetOperandAs(10)); if (payload->opcode() != spv::Op::OpVariable) { return _.diag(SPV_ERROR_INVALID_DATA, inst) diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 971e031558..0ed99737ee 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -15,6 +15,7 @@ #include "source/val/validation_state.h" #include +#include #include #include @@ -1374,6 +1375,34 @@ std::tuple ValidationState_t::EvalInt32IfConst( return std::make_tuple(true, true, inst->word(3)); } +std::tuple ValidationState_t::EvalFloat32IfConst( + uint32_t id) const { + const Instruction* const inst = FindDef(id); + assert(inst); + const uint32_t type = inst->type_id(); + + if (type == 0 || !IsFloatScalarType(type) || GetBitWidth(type) != 32) { + return std::make_tuple(false, false, 0.0f); + } + + // Spec constant values cannot be evaluated so don't consider constant for + // the purpose of this method. + if (!spvOpcodeIsConstant(inst->opcode()) || + spvOpcodeIsSpecConstant(inst->opcode())) { + return std::make_tuple(true, false, 0.0f); + } + + if (inst->opcode() == spv::Op::OpConstantNull) { + return std::make_tuple(true, true, 0.0f); + } + + assert(inst->words().size() == 4); + uint32_t word = inst->word(3); + float value = 0; + std::memcpy(&value, &word, sizeof(float)); + return std::make_tuple(true, true, value); +} + void ValidationState_t::ComputeFunctionToEntryPointMapping() { for (const uint32_t entry_point : entry_points()) { std::stack call_stack; diff --git a/source/val/validation_state.h b/source/val/validation_state.h index 0cd6c789bb..a74e8da3ab 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -730,6 +730,9 @@ class ValidationState_t { // OpSpecConstant* return |is_const_int32| as false since their values cannot // be relied upon during validation. std::tuple EvalInt32IfConst(uint32_t id) const; + // Tries to evaluate a 32-bit scalar float constant. + // Returns tuple . + std::tuple EvalFloat32IfConst(uint32_t id) const; // Returns the disassembly string for the given instruction. std::string Disassemble(const Instruction& inst) const; diff --git a/test/val/val_ray_query_test.cpp b/test/val/val_ray_query_test.cpp index e0eb067589..be16b63a0f 100644 --- a/test/val/val_ray_query_test.cpp +++ b/test/val/val_ray_query_test.cpp @@ -40,12 +40,12 @@ OpCapability Shader OpCapability Int64 OpCapability Float64 OpCapability RayQueryKHR -OpExtension "SPV_KHR_ray_query" )"; ss << capabilities_and_extensions; ss << R"( +OpExtension "SPV_KHR_ray_query" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %main "main" OpExecutionMode %main LocalSize 1 1 1 @@ -626,6 +626,237 @@ TEST_F(ValidateRayQuery, RayQueryArraySuccess) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateRayQuery, InitializeMinMoreThanMax) { + const std::string declarations = R"( +%f32_1 = OpConstant %f32 1 +%f32_2 = OpConstant %f32 2 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_2 %f32vec3_0 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Tmin (2) is larger than Ray Tmax (1)")); +} + +TEST_F(ValidateRayQuery, InitializeMinMoreThanMaxRuntime) { + // Can't check TMin as it is a dynamic variable, so should not return any + // static errors + const std::string shader = R"( + OpCapability Shader + OpCapability RayQueryKHR + OpExtension "SPV_KHR_ray_query" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %tlas DescriptorSet 0 + OpDecorate %tlas Binding 0 + OpMemberDecorate %storage_buffer 0 Offset 0 + OpDecorate %storage_buffer BufferBlock + OpDecorate %foo DescriptorSet 0 + OpDecorate %foo Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %6 = OpTypeRayQueryKHR +%_ptr_Private_6 = OpTypePointer Private %6 + %rayQuery = OpVariable %_ptr_Private_6 Private + %9 = OpTypeAccelerationStructureKHR +%_ptr_UniformConstant_9 = OpTypePointer UniformConstant %9 + %tlas = OpVariable %_ptr_UniformConstant_9 UniformConstant + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_255 = OpConstant %uint 255 + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 + %float_0 = OpConstant %float 0 + %19 = OpConstantComposite %v3float %float_0 %float_0 %float_0 +%storage_buffer = OpTypeStruct %float +%_ptr_Uniform_storage_buffer = OpTypePointer Uniform %storage_buffer + %foo = OpVariable %_ptr_Uniform_storage_buffer Uniform + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float + %float_1 = OpConstant %float 1 + %29 = OpConstantComposite %v3float %float_1 %float_0 %float_0 + %float_4 = OpConstant %float 4 + %main = OpFunction %void None %3 + %5 = OpLabel + %12 = OpLoad %9 %tlas + %26 = OpAccessChain %_ptr_Uniform_float %foo %int_0 +%descriptor_float = OpLoad %float %26 + OpRayQueryInitializeKHR %rayQuery %12 %uint_0 %uint_255 %19 %descriptor_float %29 %float_4 + OpReturn + OpFunctionEnd +)"; + CompileSuccessfully(shader); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, InitializeMinNegative) { + const std::string declarations = R"( +%f32_1 = OpConstant %f32 1 +%f32_n1 = OpConstant %f32 -1 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_n1 %f32vec3_0 %f32_1 +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Ray Tmin is negative (-1)")); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsBothSkipPrimitiveCulling) { + const std::string capabilities = R"( +OpCapability RayTraversalPrimitiveCullingKHR +)"; + + // SkipTrianglesKHR | SkipAABBsKHR + const std::string declarations = R"( +%u32_768 = OpConstant %u32 768 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_768 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, capabilities, declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Flags contains both SkipTrianglesKHR and SkipAABBsKHR")); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsSkipAABBs) { + const std::string capabilities = R"( +OpCapability RayTraversalPrimitiveCullingKHR +)"; + + // only SkipAABBsKHR + const std::string declarations = R"( +%u32_512 = OpConstant %u32 512 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_512 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, capabilities, declarations).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsSkipTriangleCullBack) { + const std::string capabilities = R"( +OpCapability RayTraversalPrimitiveCullingKHR +)"; + + // SkipTrianglesKHR and CullBackFacingTrianglesKHR + const std::string declarations = R"( +%u32_272 = OpConstant %u32 272 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_272 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, capabilities, declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Flags contains more than one of SkipTrianglesKHR or " + "CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR")); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsSkipTriangleCullFrontAndBack) { + const std::string capabilities = R"( +OpCapability RayTraversalPrimitiveCullingKHR +)"; + + // SkipTrianglesKHR and CullFrontFacingTrianglesKHR and + // CullBackFacingTrianglesKHR + const std::string declarations = R"( +%u32_304 = OpConstant %u32 304 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_304 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, capabilities, declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Flags contains more than one of SkipTrianglesKHR or " + "CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR")); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsSkipAABBCullBackward) { + const std::string capabilities = R"( +OpCapability RayTraversalPrimitiveCullingKHR +)"; + + // SkipAABBsKHR and CullBackFacingTrianglesKHR (legal) + const std::string declarations = R"( +%u32_528 = OpConstant %u32 528 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_528 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully( + GenerateShaderCode(body, capabilities, declarations).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsOpaqueAndCullNoOpaque) { + // OpaqueKHR and CullNoOpaqueKHR + const std::string declarations = R"( +%u32_129 = OpConstant %u32 129 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_129 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Flags contains more than one of OpaqueKHR or " + "NoOpaqueKHR or CullOpaqueKHR or CullNoOpaqueKHR")); +} + +TEST_F(ValidateRayQuery, InitializeRayFlagsOpaqueAndCullBack) { + // OpaqueKHR and CullBackFacingTrianglesKHR (legal) + const std::string declarations = R"( +%u32_17 = OpConstant %u32 17 +)"; + + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_17 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", declarations).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + } // namespace } // namespace val } // namespace spvtools diff --git a/test/val/val_ray_tracing_test.cpp b/test/val/val_ray_tracing_test.cpp index 60f2f89117..a32c3b906c 100644 --- a/test/val/val_ray_tracing_test.cpp +++ b/test/val/val_ray_tracing_test.cpp @@ -364,10 +364,12 @@ OpFunctionEnd std::string GenerateRayTraceCode( const std::string& body, - const std::string execution_model = "RayGenerationKHR") { + const std::string execution_model = "RayGenerationKHR", + const std::string& declarations = "") { std::ostringstream ss; ss << R"( OpCapability RayTracingKHR +OpCapability RayTraversalPrimitiveCullingKHR OpCapability Float64 OpExtension "SPV_KHR_ray_tracing" OpMemoryModel Logical GLSL450 @@ -402,6 +404,11 @@ OpDecorate %top_level_as Binding 0 %var_float = OpVariable %ptr_float Private %ptr_f32vec3 = OpTypePointer Private %f32vec3 %var_f32vec3 = OpVariable %ptr_f32vec3 Private +)"; + + ss << declarations; + + ss << R"( %main = OpFunction %void None %func %label = OpLabel )"; @@ -667,6 +674,167 @@ OpFunctionEnd "IncomingCallableDataKHR storage class in the interface")); } +TEST_F(ValidateRayTracing, TraceRayMinMoreThanMax) { + const std::string declarations = R"( +%float_1 = OpConstant %float 1 +%float_2 = OpConstant %float 2 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_2 %v3composite %float_1 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Tmin (2) is larger than Ray Tmax (1)")); +} + +TEST_F(ValidateRayTracing, TraceRayMinNegative) { + const std::string declarations = R"( +%float_n1 = OpConstant %float -1 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_n1 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr("Ray Tmin is negative (-1)")); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsBothSkipPrimitiveCulling) { + // SkipTrianglesKHR | SkipAABBsKHR + const std::string declarations = R"( +%uint_768 = OpConstant %uint 768 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_768 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Flags contains both SkipTrianglesKHR and SkipAABBsKHR")); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsSkipAABBs) { + // only SkipAABBsKHR + const std::string declarations = R"( +%uint_512 = OpConstant %uint 512 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_512 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsSkipTriangleCullBack) { + // SkipTrianglesKHR and CullBackFacingTrianglesKHR + const std::string declarations = R"( +%uint_272 = OpConstant %uint 272 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_272 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Flags contains more than one of SkipTrianglesKHR or " + "CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR")); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsSkipTriangleCullFrontAndBack) { + // SkipTrianglesKHR and CullFrontFacingTrianglesKHR and + // CullBackFacingTrianglesKHR + const std::string declarations = R"( +%uint_304 = OpConstant %uint 304 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_304 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Flags contains more than one of SkipTrianglesKHR or " + "CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR")); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsSkipAABBCullBackward) { + // SkipAABBsKHR and CullBackFacingTrianglesKHR (legal) + const std::string declarations = R"( +%uint_528 = OpConstant %uint 528 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_528 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsOpaqueAndCullNoOpaque) { + // OpaqueKHR and CullNoOpaqueKHR + const std::string declarations = R"( +%uint_129 = OpConstant %uint 129 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_129 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Flags contains more than one of OpaqueKHR or " + "NoOpaqueKHR or CullOpaqueKHR or CullNoOpaqueKHR")); +} + +TEST_F(ValidateRayTracing, TraceRayRayFlagsOpaqueAndCullBack) { + // OpaqueKHR and CullBackFacingTrianglesKHR (legal) + const std::string declarations = R"( +%uint_17 = OpConstant %uint 17 +)"; + + const std::string body = R"( +%as = OpLoad %type_as %top_level_as +OpTraceRayKHR %as %uint_17 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload +)"; + + CompileSuccessfully( + GenerateRayTraceCode(body, "RayGenerationKHR", declarations).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + } // namespace } // namespace val } // namespace spvtools