Skip to content

Commit 62c6ecf

Browse files
committed
Support pointer base type in OpTypeVector when capabilit
MaskedGatherScatterINTEL is enabled Add supporting tests for spv_intel_masked_gather_scatter Formatting changes for unrelated files
1 parent ba1359d commit 62c6ecf

File tree

4 files changed

+69
-35
lines changed

4 files changed

+69
-35
lines changed

source/val/validate_memory_semantics.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ spv_result_t ValidateMemorySemantics(ValidationState_t& _,
203203
"storage class";
204204
}
205205

206-
if (opcode == spv::Op::OpControlBarrier && value && !includes_storage_class) {
206+
if (opcode == spv::Op::OpControlBarrier && value &&
207+
!includes_storage_class) {
207208
return _.diag(SPV_ERROR_INVALID_DATA, inst)
208209
<< _.VkErrorID(4650) << spvOpcodeString(opcode)
209210
<< ": expected Memory Semantics to include a Vulkan-supported "

source/val/validate_type.cpp

+18-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,24 @@ spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
140140
const auto component_index = 1;
141141
const auto component_id = inst->GetOperandAs<uint32_t>(component_index);
142142
const auto component_type = _.FindDef(component_id);
143-
if (!component_type || !spvOpcodeIsScalarType(component_type->opcode())) {
143+
if (component_type) {
144+
bool isPointer = component_type->opcode() == spv::Op::OpTypePointer;
145+
bool isScalar = spvOpcodeIsScalarType(component_type->opcode());
146+
147+
if (_.HasCapability(spv::Capability::MaskedGatherScatterINTEL) &&
148+
!isPointer && !isScalar) {
149+
return _.diag(SPV_ERROR_INVALID_ID, inst)
150+
<< "Invalid OpTypeVector Component Type<id> "
151+
<< _.getIdName(component_id)
152+
<< ": Expected a scalar or pointer type when using the "
153+
"SPV_INTEL_masked_gather_scatter extension.";
154+
} else if (!_.HasCapability(spv::Capability::MaskedGatherScatterINTEL) &&
155+
!isScalar) {
156+
return _.diag(SPV_ERROR_INVALID_ID, inst)
157+
<< "OpTypeVector Component Type <id> " << _.getIdName(component_id)
158+
<< " is not a scalar type.";
159+
}
160+
} else {
144161
return _.diag(SPV_ERROR_INVALID_ID, inst)
145162
<< "OpTypeVector Component Type <id> " << _.getIdName(component_id)
146163
<< " is not a scalar type.";

test/string_utils_test.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "source/util/string_utils.h"
16+
1517
#include <string>
1618

1719
#include "gtest/gtest.h"
18-
#include "source/util/string_utils.h"
1920
#include "spirv-tools/libspirv.h"
2021

2122
namespace spvtools {

test/val/val_id_test.cpp

+47-32
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ std::string kOpVariablePtrSetUp = R"(
7979
OpExtension "SPV_KHR_variable_pointers"
8080
)";
8181

82-
std::string kGLSL450MemoryModel =
83-
kOpCapabilitySetup + kOpVariablePtrSetUp + R"(
82+
std::string kGLSL450MemoryModel = kOpCapabilitySetup + kOpVariablePtrSetUp + R"(
8483
OpMemoryModel Logical GLSL450
8584
)";
8685

@@ -709,6 +708,24 @@ TEST_P(ValidateIdWithMessage, OpTypeVectorComponentTypeBad) {
709708
"'2[%_ptr_UniformConstant_float]' is not a scalar type.")));
710709
}
711710

711+
TEST_P(ValidateIdWithMessage, OpTypeVectorComponentTypeCanBePointerType) {
712+
std::string spirv = R"(
713+
OpCapability Addresses
714+
OpCapability Linkage
715+
OpCapability Kernel
716+
OpCapability Int64
717+
OpCapability GenericPointer
718+
OpCapability MaskedGatherScatterINTEL
719+
OpExtension "SPV_INTEL_masked_gather_scatter"
720+
OpMemoryModel Physical64 OpenCL
721+
722+
%2 = OpTypeInt 32 0
723+
%3 = OpTypePointer Generic %2
724+
%4 = OpTypeVector %3 4)";
725+
CompileSuccessfully(spirv.c_str());
726+
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
727+
}
728+
712729
TEST_P(ValidateIdWithMessage, OpTypeVectorColumnCountLessThanTwoBad) {
713730
std::string spirv = kGLSL450MemoryModel + R"(
714731
%1 = OpTypeFloat 32
@@ -4024,8 +4041,7 @@ TEST_P(AccessChainInstructionTest, AccessChainResultTypeBad) {
40244041
const std::string instr = GetParam();
40254042
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
40264043
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4027-
%float_entry = )" +
4028-
instr +
4044+
%float_entry = )" + instr +
40294045
R"( %float %my_matrix )" + elem +
40304046
R"(%int_0 %int_1
40314047
OpReturn
@@ -4045,8 +4061,8 @@ TEST_P(AccessChainInstructionTest, AccessChainBaseTypeVoidBad) {
40454061
const std::string instr = GetParam();
40464062
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
40474063
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4048-
%float_entry = )" +
4049-
instr + " %_ptr_Private_float %void " + elem +
4064+
%float_entry = )" + instr +
4065+
" %_ptr_Private_float %void " + elem +
40504066
R"(%int_0 %int_1
40514067
OpReturn
40524068
OpFunctionEnd
@@ -4062,8 +4078,7 @@ TEST_P(AccessChainInstructionTest, AccessChainBaseTypeNonPtrVariableBad) {
40624078
const std::string instr = GetParam();
40634079
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
40644080
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4065-
%entry = )" +
4066-
instr + R"( %_ptr_Private_float %_ptr_Private_float )" +
4081+
%entry = )" + instr + R"( %_ptr_Private_float %_ptr_Private_float )" +
40674082
elem +
40684083
R"(%int_0 %int_1
40694084
OpReturn
@@ -4081,8 +4096,8 @@ TEST_P(AccessChainInstructionTest,
40814096
const std::string instr = GetParam();
40824097
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
40834098
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4084-
%entry = )" +
4085-
instr + R"( %_ptr_Function_float %my_matrix )" + elem +
4099+
%entry = )" + instr + R"( %_ptr_Function_float %my_matrix )" +
4100+
elem +
40864101
R"(%int_0 %int_1
40874102
OpReturn
40884103
OpFunctionEnd
@@ -4102,8 +4117,8 @@ TEST_P(AccessChainInstructionTest,
41024117
const std::string instr = GetParam();
41034118
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
41044119
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4105-
%entry = )" +
4106-
instr + R"( %_ptr_Private_float %my_float_var )" + elem +
4120+
%entry = )" + instr + R"( %_ptr_Private_float %my_float_var )" +
4121+
elem +
41074122
R"(%int_0
41084123
OpReturn
41094124
OpFunctionEnd
@@ -4122,8 +4137,8 @@ TEST_P(AccessChainInstructionTest, AccessChainNoIndexesGood) {
41224137
const std::string instr = GetParam();
41234138
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
41244139
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4125-
%entry = )" +
4126-
instr + R"( %_ptr_Private_float %my_float_var )" + elem +
4140+
%entry = )" + instr + R"( %_ptr_Private_float %my_float_var )" +
4141+
elem +
41274142
R"(
41284143
OpReturn
41294144
OpFunctionEnd
@@ -4138,8 +4153,8 @@ TEST_P(AccessChainInstructionTest, AccessChainNoIndexesBad) {
41384153
const std::string instr = GetParam();
41394154
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
41404155
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4141-
%entry = )" +
4142-
instr + R"( %_ptr_Private_mat4x3 %my_float_var )" + elem +
4156+
%entry = )" + instr + R"( %_ptr_Private_mat4x3 %my_float_var )" +
4157+
elem +
41434158
R"(
41444159
OpReturn
41454160
OpFunctionEnd
@@ -4295,8 +4310,8 @@ TEST_P(AccessChainInstructionTest, AccessChainUndefinedIndexBad) {
42954310
const std::string instr = GetParam();
42964311
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
42974312
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4298-
%entry = )" +
4299-
instr + R"( %_ptr_Private_float %my_matrix )" + elem +
4313+
%entry = )" + instr + R"( %_ptr_Private_float %my_matrix )" +
4314+
elem +
43004315
R"(%float_0 %int_1
43014316
OpReturn
43024317
OpFunctionEnd
@@ -4314,8 +4329,8 @@ TEST_P(AccessChainInstructionTest, AccessChainStructIndexNotConstantBad) {
43144329
const std::string instr = GetParam();
43154330
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
43164331
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4317-
%f = )" +
4318-
instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
4332+
%f = )" + instr + R"( %_ptr_Uniform_float %blockName_var )" +
4333+
elem +
43194334
R"(%int_0 %spec_int %int_2
43204335
OpReturn
43214336
OpFunctionEnd
@@ -4333,8 +4348,8 @@ TEST_P(AccessChainInstructionTest,
43334348
const std::string instr = GetParam();
43344349
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
43354350
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4336-
%entry = )" +
4337-
instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
4351+
%entry = )" + instr + R"( %_ptr_Uniform_float %blockName_var )" +
4352+
elem +
43384353
R"(%int_0 %int_1 %int_2
43394354
OpReturn
43404355
OpFunctionEnd
@@ -4353,8 +4368,8 @@ TEST_P(AccessChainInstructionTest, AccessChainStructTooManyIndexesBad) {
43534368
const std::string instr = GetParam();
43544369
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
43554370
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4356-
%entry = )" +
4357-
instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
4371+
%entry = )" + instr + R"( %_ptr_Uniform_float %blockName_var )" +
4372+
elem +
43584373
R"(%int_0 %int_2 %int_2
43594374
OpReturn
43604375
OpFunctionEnd
@@ -4372,8 +4387,8 @@ TEST_P(AccessChainInstructionTest, AccessChainStructIndexOutOfBoundBad) {
43724387
const std::string instr = GetParam();
43734388
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
43744389
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4375-
%entry = )" +
4376-
instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
4390+
%entry = )" + instr + R"( %_ptr_Uniform_float %blockName_var )" +
4391+
elem +
43774392
R"(%int_3 %int_2 %int_2
43784393
OpReturn
43794394
OpFunctionEnd
@@ -4428,8 +4443,8 @@ TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) {
44284443
" OpDecorate %_ptr_Uniform_blockName ArrayStride 8 ";
44294444
std::string spirv = kGLSL450MemoryModel + arrayStride +
44304445
kDeeplyNestedStructureSetup + R"(
4431-
%runtime_arr_entry = )" + instr +
4432-
R"( %_ptr_Uniform_float %blockName_var )" + elem +
4446+
%runtime_arr_entry = )" +
4447+
instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
44334448
R"(%int_2 %int_0
44344449
OpReturn
44354450
OpFunctionEnd
@@ -4463,8 +4478,8 @@ TEST_P(AccessChainInstructionTest, AccessChainMatrixMoreArgsThanNeededBad) {
44634478
const std::string instr = GetParam();
44644479
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
44654480
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4466-
%entry = )" +
4467-
instr + R"( %_ptr_Private_float %my_matrix )" + elem +
4481+
%entry = )" + instr + R"( %_ptr_Private_float %my_matrix )" +
4482+
elem +
44684483
R"(%int_0 %int_1 %int_0
44694484
OpReturn
44704485
OpFunctionEnd
@@ -4483,8 +4498,8 @@ TEST_P(AccessChainInstructionTest,
44834498
const std::string instr = GetParam();
44844499
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
44854500
std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
4486-
%entry = )" +
4487-
instr + R"( %_ptr_Private_mat4x3 %my_matrix )" + elem +
4501+
%entry = )" + instr + R"( %_ptr_Private_mat4x3 %my_matrix )" +
4502+
elem +
44884503
R"(%int_0 %int_1
44894504
OpReturn
44904505
OpFunctionEnd

0 commit comments

Comments
 (0)