Skip to content

Commit 7846043

Browse files
spirv-val: Add static RayQuery/RayTracing value check
1 parent f9184c6 commit 7846043

7 files changed

+547
-14
lines changed

source/val/validate.h

+5
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst);
210210
/// Validates correctness of ray tracing instructions.
211211
spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst);
212212

213+
/// Validates constant values of the Ray Flags operand
214+
spv_result_t RayTracingRayFlagsOperandValue(ValidationState_t& _,
215+
const Instruction* inst,
216+
uint32_t ray_flags_value);
217+
213218
/// Validates correctness of shader execution reorder instructions.
214219
spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst);
215220

source/val/validate_ray_query.cpp

+37-6
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,19 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
8282
"OpTypeAccelerationStructureKHR";
8383
}
8484

85-
const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
86-
if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
85+
const uint32_t ray_flags = inst->GetOperandAs<uint32_t>(2);
86+
bool is_ray_flags_int32 = false;
87+
bool is_ray_flags_const = false;
88+
uint32_t ray_flags_value = 0;
89+
std::tie(is_ray_flags_int32, is_ray_flags_const, ray_flags_value) =
90+
_.EvalInt32IfConst(ray_flags);
91+
if (!is_ray_flags_int32) {
8792
return _.diag(SPV_ERROR_INVALID_DATA, inst)
8893
<< "Ray Flags must be a 32-bit int scalar";
94+
} else if (is_ray_flags_const) {
95+
if (auto error =
96+
RayTracingRayFlagsOperandValue(_, inst, ray_flags_value))
97+
return error;
8998
}
9099

91100
const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
@@ -101,10 +110,20 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
101110
<< "Ray Origin must be a 32-bit float 3-component vector";
102111
}
103112

104-
const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
105-
if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
113+
const uint32_t ray_tmin = inst->GetOperandAs<uint32_t>(5);
114+
bool is_ray_tmin_float32 = false;
115+
bool is_ray_tmin_const = false;
116+
float ray_tmin_value = 0;
117+
std::tie(is_ray_tmin_float32, is_ray_tmin_const, ray_tmin_value) =
118+
_.EvalFloat32IfConst(ray_tmin);
119+
if (!is_ray_tmin_float32) {
106120
return _.diag(SPV_ERROR_INVALID_DATA, inst)
107121
<< "Ray TMin must be a 32-bit float scalar";
122+
} else if (is_ray_tmin_const && ray_tmin_value < 0.0f) {
123+
// Don't need to check TMax for being negative because it can't without
124+
// being less than TMin
125+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
126+
<< "Ray Tmin is negative (" << ray_tmin_value << ")";
108127
}
109128

110129
const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
@@ -115,11 +134,23 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
115134
<< "Ray Direction must be a 32-bit float 3-component vector";
116135
}
117136

118-
const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
119-
if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
137+
const uint32_t ray_tmax = inst->GetOperandAs<uint32_t>(7);
138+
bool is_ray_tmax_float32 = false;
139+
bool is_ray_tmax_const = false;
140+
float ray_tmax_value = 0;
141+
std::tie(is_ray_tmax_float32, is_ray_tmax_const, ray_tmax_value) =
142+
_.EvalFloat32IfConst(ray_tmax);
143+
if (!is_ray_tmax_float32) {
120144
return _.diag(SPV_ERROR_INVALID_DATA, inst)
121145
<< "Ray TMax must be a 32-bit float scalar";
122146
}
147+
148+
if (is_ray_tmin_const && is_ray_tmax_const &&
149+
ray_tmin_value > ray_tmax_value) {
150+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
151+
<< "Ray Tmin (" << ray_tmin_value
152+
<< ") is larger than Ray Tmax (" << ray_tmax_value << ")";
153+
}
123154
break;
124155
}
125156

source/val/validate_ray_tracing.cpp

+72-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,41 @@
2222
namespace spvtools {
2323
namespace val {
2424

25+
spv_result_t RayTracingRayFlagsOperandValue(ValidationState_t& _,
26+
const Instruction* inst,
27+
uint32_t ray_flags_value) {
28+
const auto HasMoreThenOneBitSet =
29+
[ray_flags_value](const spv::RayFlagsMask ray_flag_mask) {
30+
const uint32_t mask =
31+
ray_flags_value & static_cast<uint32_t>(ray_flag_mask);
32+
return mask != 0 && (mask & (mask - 1));
33+
};
34+
35+
if (HasMoreThenOneBitSet(spv::RayFlagsMask::SkipAABBsKHR |
36+
spv::RayFlagsMask::SkipTrianglesKHR)) {
37+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
38+
<< "Ray Flags contains both SkipTrianglesKHR and SkipAABBsKHR";
39+
}
40+
41+
if (HasMoreThenOneBitSet(spv::RayFlagsMask::SkipTrianglesKHR |
42+
spv::RayFlagsMask::CullFrontFacingTrianglesKHR |
43+
spv::RayFlagsMask::CullBackFacingTrianglesKHR)) {
44+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
45+
<< "Ray Flags contains more than one of SkipTrianglesKHR or "
46+
"CullFrontFacingTrianglesKHR or CullBackFacingTrianglesKHR";
47+
}
48+
49+
if (HasMoreThenOneBitSet(spv::RayFlagsMask::OpaqueKHR |
50+
spv::RayFlagsMask::NoOpaqueKHR |
51+
spv::RayFlagsMask::CullOpaqueKHR |
52+
spv::RayFlagsMask::CullNoOpaqueKHR)) {
53+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
54+
<< "Ray Flags contains more than one of OpaqueKHR or NoOpaqueKHR or "
55+
"CullOpaqueKHR or CullNoOpaqueKHR";
56+
}
57+
return SPV_SUCCESS;
58+
}
59+
2560
spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
2661
const spv::Op opcode = inst->opcode();
2762
const uint32_t result_type = inst->type_id();
@@ -51,10 +86,19 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
5186
"OpTypeAccelerationStructureKHR";
5287
}
5388

54-
const uint32_t ray_flags = _.GetOperandTypeId(inst, 1);
55-
if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
89+
const uint32_t ray_flags = inst->GetOperandAs<uint32_t>(1);
90+
bool is_ray_flags_int32 = false;
91+
bool is_ray_flags_const = false;
92+
uint32_t ray_flags_value = 0;
93+
std::tie(is_ray_flags_int32, is_ray_flags_const, ray_flags_value) =
94+
_.EvalInt32IfConst(ray_flags);
95+
if (!is_ray_flags_int32) {
5696
return _.diag(SPV_ERROR_INVALID_DATA, inst)
5797
<< "Ray Flags must be a 32-bit int scalar";
98+
} else if (is_ray_flags_const) {
99+
if (auto error =
100+
RayTracingRayFlagsOperandValue(_, inst, ray_flags_value))
101+
return error;
58102
}
59103

60104
const uint32_t cull_mask = _.GetOperandTypeId(inst, 2);
@@ -88,10 +132,20 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
88132
<< "Ray Origin must be a 32-bit float 3-component vector";
89133
}
90134

91-
const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7);
92-
if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
135+
const uint32_t ray_tmin = inst->GetOperandAs<uint32_t>(7);
136+
bool is_ray_tmin_float32 = false;
137+
bool is_ray_tmin_const = false;
138+
float ray_tmin_value = 0;
139+
std::tie(is_ray_tmin_float32, is_ray_tmin_const, ray_tmin_value) =
140+
_.EvalFloat32IfConst(ray_tmin);
141+
if (!is_ray_tmin_float32) {
93142
return _.diag(SPV_ERROR_INVALID_DATA, inst)
94143
<< "Ray TMin must be a 32-bit float scalar";
144+
} else if (is_ray_tmin_const && ray_tmin_value < 0.0f) {
145+
// Don't need to check TMax for being negative because it can't without
146+
// being less than TMin
147+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
148+
<< "Ray Tmin is negative (" << ray_tmin_value << ")";
95149
}
96150

97151
const uint32_t ray_direction = _.GetOperandTypeId(inst, 8);
@@ -102,12 +156,24 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
102156
<< "Ray Direction must be a 32-bit float 3-component vector";
103157
}
104158

105-
const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9);
106-
if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
159+
const uint32_t ray_tmax = inst->GetOperandAs<uint32_t>(9);
160+
bool is_ray_tmax_float32 = false;
161+
bool is_ray_tmax_const = false;
162+
float ray_tmax_value = 0;
163+
std::tie(is_ray_tmax_float32, is_ray_tmax_const, ray_tmax_value) =
164+
_.EvalFloat32IfConst(ray_tmax);
165+
if (!is_ray_tmax_float32) {
107166
return _.diag(SPV_ERROR_INVALID_DATA, inst)
108167
<< "Ray TMax must be a 32-bit float scalar";
109168
}
110169

170+
if (is_ray_tmin_const && is_ray_tmax_const &&
171+
ray_tmin_value > ray_tmax_value) {
172+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
173+
<< "Ray Tmin (" << ray_tmin_value
174+
<< ") is larger than Ray Tmax (" << ray_tmax_value << ")";
175+
}
176+
111177
const Instruction* payload = _.FindDef(inst->GetOperandAs<uint32_t>(10));
112178
if (payload->opcode() != spv::Op::OpVariable) {
113179
return _.diag(SPV_ERROR_INVALID_DATA, inst)

source/val/validation_state.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "source/val/validation_state.h"
1616

1717
#include <cassert>
18+
#include <cstring>
1819
#include <stack>
1920
#include <utility>
2021

@@ -1374,6 +1375,34 @@ std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
13741375
return std::make_tuple(true, true, inst->word(3));
13751376
}
13761377

1378+
std::tuple<bool, bool, float> ValidationState_t::EvalFloat32IfConst(
1379+
uint32_t id) const {
1380+
const Instruction* const inst = FindDef(id);
1381+
assert(inst);
1382+
const uint32_t type = inst->type_id();
1383+
1384+
if (type == 0 || !IsFloatScalarType(type) || GetBitWidth(type) != 32) {
1385+
return std::make_tuple(false, false, 0.0f);
1386+
}
1387+
1388+
// Spec constant values cannot be evaluated so don't consider constant for
1389+
// the purpose of this method.
1390+
if (!spvOpcodeIsConstant(inst->opcode()) ||
1391+
spvOpcodeIsSpecConstant(inst->opcode())) {
1392+
return std::make_tuple(true, false, 0.0f);
1393+
}
1394+
1395+
if (inst->opcode() == spv::Op::OpConstantNull) {
1396+
return std::make_tuple(true, true, 0.0f);
1397+
}
1398+
1399+
assert(inst->words().size() == 4);
1400+
uint32_t word = inst->word(3);
1401+
float value = 0;
1402+
std::memcpy(&value, &word, sizeof(float));
1403+
return std::make_tuple(true, true, value);
1404+
}
1405+
13771406
void ValidationState_t::ComputeFunctionToEntryPointMapping() {
13781407
for (const uint32_t entry_point : entry_points()) {
13791408
std::stack<uint32_t> call_stack;

source/val/validation_state.h

+3
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,9 @@ class ValidationState_t {
730730
// OpSpecConstant* return |is_const_int32| as false since their values cannot
731731
// be relied upon during validation.
732732
std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id) const;
733+
// Tries to evaluate a 32-bit scalar float constant.
734+
// Returns tuple <is_float32, is_const_float32, value>.
735+
std::tuple<bool, bool, float> EvalFloat32IfConst(uint32_t id) const;
733736

734737
// Returns the disassembly string for the given instruction.
735738
std::string Disassemble(const Instruction& inst) const;

0 commit comments

Comments
 (0)