22
22
namespace spvtools {
23
23
namespace val {
24
24
25
+ spv_result_t RayTracingRayFlagsOperandValue (ValidationState_t& _,
26
+ const Instruction* inst,
27
+ uint32_t ray_flags_value) {
28
+ const auto HasMoreThenOneBitSet =
29
+ [ray_flags_value](const spv::RayFlagsMask ray_flag_mask) {
30
+ const uint32_t mask =
31
+ ray_flags_value & static_cast <uint32_t >(ray_flag_mask);
32
+ return mask != 0 && (mask & (mask - 1 ));
33
+ };
34
+
35
+ if (HasMoreThenOneBitSet (spv::RayFlagsMask::SkipAABBsKHR |
36
+ spv::RayFlagsMask::SkipTrianglesKHR)) {
37
+ return _.diag (SPV_ERROR_INVALID_DATA, inst)
38
+ << " Ray Flags contains both SkipTrianglesKHR and SkipAABBsKHR" ;
39
+ }
40
+
41
+ if (HasMoreThenOneBitSet (spv::RayFlagsMask::SkipTrianglesKHR |
42
+ spv::RayFlagsMask::CullFrontFacingTrianglesKHR |
43
+ spv::RayFlagsMask::CullBackFacingTrianglesKHR)) {
44
+ return _.diag (SPV_ERROR_INVALID_DATA, inst)
45
+ << " Ray Flags contains more than one of SkipTrianglesKHR or "
46
+ " CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR" ;
47
+ }
48
+
49
+ if (HasMoreThenOneBitSet (spv::RayFlagsMask::OpaqueKHR |
50
+ spv::RayFlagsMask::NoOpaqueKHR |
51
+ spv::RayFlagsMask::CullOpaqueKHR |
52
+ spv::RayFlagsMask::CullNoOpaqueKHR)) {
53
+ return _.diag (SPV_ERROR_INVALID_DATA, inst)
54
+ << " Ray Flags contains more than one of OpaqueKHR or NoOpaqueKHR or "
55
+ " CullOpaqueKHR or CullNoOpaqueKHR" ;
56
+ }
57
+ return SPV_SUCCESS;
58
+ }
59
+
25
60
spv_result_t RayTracingPass (ValidationState_t& _, const Instruction* inst) {
26
61
const spv::Op opcode = inst->opcode ();
27
62
const uint32_t result_type = inst->type_id ();
@@ -51,10 +86,19 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
51
86
" OpTypeAccelerationStructureKHR" ;
52
87
}
53
88
54
- const uint32_t ray_flags = _.GetOperandTypeId (inst, 1 );
55
- if (!_.IsIntScalarType (ray_flags) || _.GetBitWidth (ray_flags) != 32 ) {
89
+ const uint32_t ray_flags = inst->GetOperandAs <uint32_t >(1 );
90
+ bool is_ray_flags_int32 = false ;
91
+ bool is_ray_flags_const = false ;
92
+ uint32_t ray_flags_value = 0 ;
93
+ std::tie (is_ray_flags_int32, is_ray_flags_const, ray_flags_value) =
94
+ _.EvalInt32IfConst (ray_flags);
95
+ if (!is_ray_flags_int32) {
56
96
return _.diag (SPV_ERROR_INVALID_DATA, inst)
57
97
<< " Ray Flags must be a 32-bit int scalar" ;
98
+ } else if (is_ray_flags_const) {
99
+ if (auto error =
100
+ RayTracingRayFlagsOperandValue (_, inst, ray_flags_value))
101
+ return error;
58
102
}
59
103
60
104
const uint32_t cull_mask = _.GetOperandTypeId (inst, 2 );
@@ -88,10 +132,20 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
88
132
<< " Ray Origin must be a 32-bit float 3-component vector" ;
89
133
}
90
134
91
- const uint32_t ray_tmin = _.GetOperandTypeId (inst, 7 );
92
- if (!_.IsFloatScalarType (ray_tmin) || _.GetBitWidth (ray_tmin) != 32 ) {
135
+ const uint32_t ray_tmin = inst->GetOperandAs <uint32_t >(7 );
136
+ bool is_ray_tmin_float32 = false ;
137
+ bool is_ray_tmin_const = false ;
138
+ float ray_tmin_value = 0 ;
139
+ std::tie (is_ray_tmin_float32, is_ray_tmin_const, ray_tmin_value) =
140
+ _.EvalFloat32IfConst (ray_tmin);
141
+ if (!is_ray_tmin_float32) {
93
142
return _.diag (SPV_ERROR_INVALID_DATA, inst)
94
143
<< " Ray TMin must be a 32-bit float scalar" ;
144
+ } else if (is_ray_tmin_const && ray_tmin_value < 0 .0f ) {
145
+ // Don't need to check TMax for being negative because it can't without
146
+ // being less than TMin
147
+ return _.diag (SPV_ERROR_INVALID_DATA, inst)
148
+ << " Ray Tmin is negative (" << ray_tmin_value << " )" ;
95
149
}
96
150
97
151
const uint32_t ray_direction = _.GetOperandTypeId (inst, 8 );
@@ -102,12 +156,24 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
102
156
<< " Ray Direction must be a 32-bit float 3-component vector" ;
103
157
}
104
158
105
- const uint32_t ray_tmax = _.GetOperandTypeId (inst, 9 );
106
- if (!_.IsFloatScalarType (ray_tmax) || _.GetBitWidth (ray_tmax) != 32 ) {
159
+ const uint32_t ray_tmax = inst->GetOperandAs <uint32_t >(9 );
160
+ bool is_ray_tmax_float32 = false ;
161
+ bool is_ray_tmax_const = false ;
162
+ float ray_tmax_value = 0 ;
163
+ std::tie (is_ray_tmax_float32, is_ray_tmax_const, ray_tmax_value) =
164
+ _.EvalFloat32IfConst (ray_tmax);
165
+ if (!is_ray_tmax_float32) {
107
166
return _.diag (SPV_ERROR_INVALID_DATA, inst)
108
167
<< " Ray TMax must be a 32-bit float scalar" ;
109
168
}
110
169
170
+ if (is_ray_tmin_const && is_ray_tmax_const &&
171
+ ray_tmin_value > ray_tmax_value) {
172
+ return _.diag (SPV_ERROR_INVALID_DATA, inst)
173
+ << " Ray Tmin (" << ray_tmin_value
174
+ << " ) is larger than Ray Tmax (" << ray_tmax_value << " )" ;
175
+ }
176
+
111
177
const Instruction* payload = _.FindDef (inst->GetOperandAs <uint32_t >(10 ));
112
178
if (payload->opcode () != spv::Op::OpVariable) {
113
179
return _.diag (SPV_ERROR_INVALID_DATA, inst)
0 commit comments