diff --git a/internal/ast/symbol.go b/internal/ast/symbol.go index ea5d73a7eb..55c2909572 100644 --- a/internal/ast/symbol.go +++ b/internal/ast/symbol.go @@ -27,6 +27,15 @@ type Symbol struct { type SymbolTable map[string]*Symbol +// GetOrInit returns the symbol table, or initializes it if it's nil. +// This will modify whatever holds the SymbolTable, so is not safe for concurrent use. +func (s *SymbolTable) GetOrInit() SymbolTable { + if *s == nil { + *s = make(SymbolTable) + } + return *s +} + const InternalSymbolNamePrefix = "\xFE" // Invalid UTF8 sequence, will never occur as IdentifierName const ( diff --git a/internal/ast/utilities.go b/internal/ast/utilities.go index 166d3063ed..c83f5e1d4d 100644 --- a/internal/ast/utilities.go +++ b/internal/ast/utilities.go @@ -42,25 +42,6 @@ func GetSymbolId(symbol *Symbol) SymbolId { return SymbolId(id) } -func GetSymbolTable(data *SymbolTable) SymbolTable { - if *data == nil { - *data = make(SymbolTable) - } - return *data -} - -func GetMembers(symbol *Symbol) SymbolTable { - return GetSymbolTable(&symbol.Members) -} - -func GetExports(symbol *Symbol) SymbolTable { - return GetSymbolTable(&symbol.Exports) -} - -func GetLocals(container *Node) SymbolTable { - return GetSymbolTable(&container.LocalsContainerData().Locals) -} - // Determines if a node is missing (either `nil` or empty) func NodeIsMissing(node *Node) bool { return node == nil || node.Loc.Pos() == node.Loc.End() && node.Loc.Pos() >= 0 && node.Kind != KindEndOfFile diff --git a/internal/binder/binder.go b/internal/binder/binder.go index 5f18fa716a..e7faf82ef8 100644 --- a/internal/binder/binder.go +++ b/internal/binder/binder.go @@ -376,6 +376,18 @@ func GetSymbolNameForPrivateIdentifier(containingClassSymbol *ast.Symbol, descri return ast.InternalSymbolNamePrefix + "#" + strconv.Itoa(int(ast.GetSymbolId(containingClassSymbol))) + "@" + description } +func getMembers(symbol *ast.Symbol) ast.SymbolTable { + return symbol.Members.GetOrInit() +} + +func getExports(symbol *ast.Symbol) ast.SymbolTable { + return symbol.Exports.GetOrInit() +} + +func getLocals(container *ast.Node) ast.SymbolTable { + return container.LocalsContainerData().Locals.GetOrInit() +} + func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol { container := b.container if node.Kind == ast.KindCommonJSExport { @@ -384,9 +396,9 @@ func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags hasExportModifier := ast.GetCombinedModifierFlags(node)&ast.ModifierFlagsExport != 0 if symbolFlags&ast.SymbolFlagsAlias != 0 { if node.Kind == ast.KindExportSpecifier || (node.Kind == ast.KindImportEqualsDeclaration && hasExportModifier) { - return b.declareSymbol(ast.GetExports(container.Symbol()), container.Symbol(), node, symbolFlags, symbolExcludes) + return b.declareSymbol(getExports(container.Symbol()), container.Symbol(), node, symbolFlags, symbolExcludes) } - return b.declareSymbol(ast.GetLocals(container), nil /*parent*/, node, symbolFlags, symbolExcludes) + return b.declareSymbol(getLocals(container), nil /*parent*/, node, symbolFlags, symbolExcludes) } // Exported module members are given 2 symbols: A local symbol that is classified with an ExportValue flag, // and an associated export symbol with all the correct flags set on it. There are 2 main reasons: @@ -405,33 +417,33 @@ func (b *Binder) declareModuleMember(node *ast.Node, symbolFlags ast.SymbolFlags // and should never be merged directly with other augmentation, and the latter case would be possible if automatic merge is allowed. if !ast.IsAmbientModule(node) && (hasExportModifier || container.Flags&ast.NodeFlagsExportContext != 0) { if !ast.IsLocalsContainer(container) || (ast.HasSyntacticModifier(node, ast.ModifierFlagsDefault) && b.getDeclarationName(node) == ast.InternalSymbolNameMissing) || ast.IsCommonJSExport(node) { - return b.declareSymbol(ast.GetExports(container.Symbol()), container.Symbol(), node, symbolFlags, symbolExcludes) + return b.declareSymbol(getExports(container.Symbol()), container.Symbol(), node, symbolFlags, symbolExcludes) // No local symbol for an unnamed default! } exportKind := ast.SymbolFlagsNone if symbolFlags&ast.SymbolFlagsValue != 0 { exportKind = ast.SymbolFlagsExportValue } - local := b.declareSymbol(ast.GetLocals(container), nil /*parent*/, node, exportKind, symbolExcludes) - local.ExportSymbol = b.declareSymbol(ast.GetExports(container.Symbol()), container.Symbol(), node, symbolFlags, symbolExcludes) + local := b.declareSymbol(getLocals(container), nil /*parent*/, node, exportKind, symbolExcludes) + local.ExportSymbol = b.declareSymbol(getExports(container.Symbol()), container.Symbol(), node, symbolFlags, symbolExcludes) node.ExportableData().LocalSymbol = local return local } - return b.declareSymbol(ast.GetLocals(container), nil /*parent*/, node, symbolFlags, symbolExcludes) + return b.declareSymbol(getLocals(container), nil /*parent*/, node, symbolFlags, symbolExcludes) } func (b *Binder) declareClassMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol { if ast.IsStatic(node) { - return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) + return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) } - return b.declareSymbol(ast.GetMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) + return b.declareSymbol(getMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) } func (b *Binder) declareSourceFileMember(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol { if ast.IsExternalOrCommonJSModule(b.file) { return b.declareModuleMember(node, symbolFlags, symbolExcludes) } - return b.declareSymbol(ast.GetLocals(b.file.AsNode()), nil /*parent*/, node, symbolFlags, symbolExcludes) + return b.declareSymbol(getLocals(b.file.AsNode()), nil /*parent*/, node, symbolFlags, symbolExcludes) } func (b *Binder) declareSymbolAndAddToSymbolTable(node *ast.Node, symbolFlags ast.SymbolFlags, symbolExcludes ast.SymbolFlags) *ast.Symbol { @@ -443,14 +455,14 @@ func (b *Binder) declareSymbolAndAddToSymbolTable(node *ast.Node, symbolFlags as case ast.KindClassExpression, ast.KindClassDeclaration: return b.declareClassMember(node, symbolFlags, symbolExcludes) case ast.KindEnumDeclaration: - return b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) + return b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) case ast.KindTypeLiteral, ast.KindJSDocTypeLiteral, ast.KindObjectLiteralExpression, ast.KindInterfaceDeclaration, ast.KindJsxAttributes: - return b.declareSymbol(ast.GetMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) + return b.declareSymbol(getMembers(b.container.Symbol()), b.container.Symbol(), node, symbolFlags, symbolExcludes) case ast.KindFunctionType, ast.KindConstructorType, ast.KindCallSignature, ast.KindConstructSignature, ast.KindJSDocSignature, ast.KindIndexSignature, ast.KindMethodDeclaration, ast.KindMethodSignature, ast.KindConstructor, ast.KindGetAccessor, ast.KindSetAccessor, ast.KindFunctionDeclaration, ast.KindFunctionExpression, ast.KindArrowFunction, ast.KindClassStaticBlockDeclaration, ast.KindTypeAliasDeclaration, ast.KindJSTypeAliasDeclaration, ast.KindMappedType: - return b.declareSymbol(ast.GetLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes) + return b.declareSymbol(getLocals(b.container), nil /*parent*/, node, symbolFlags, symbolExcludes) } panic("Unhandled case in declareSymbolAndAddToSymbolTable") } @@ -779,7 +791,7 @@ func (b *Binder) bindSourceFileIfExternalModule() { b.bindSourceFileAsExternalModule() // Create symbol equivalent for the module.exports = {} originalSymbol := b.file.Symbol - b.declareSymbol(ast.GetSymbolTable(&b.file.Symbol.Exports), b.file.Symbol, b.file.AsNode(), ast.SymbolFlagsProperty, ast.SymbolFlagsAll) + b.declareSymbol(b.file.Symbol.Exports.GetOrInit(), b.file.Symbol, b.file.AsNode(), ast.SymbolFlagsProperty, ast.SymbolFlagsAll) b.file.Symbol = originalSymbol } } @@ -841,7 +853,7 @@ func (b *Binder) bindNamespaceExportDeclaration(node *ast.Node) { case !node.Parent.AsSourceFile().IsDeclarationFile: b.errorOnNode(node, diagnostics.Global_module_exports_may_only_appear_in_declaration_files) default: - b.declareSymbol(ast.GetSymbolTable(&b.file.Symbol.GlobalExports), b.file.Symbol, node, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes) + b.declareSymbol(b.file.Symbol.GlobalExports.GetOrInit(), b.file.Symbol, node, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes) } } @@ -858,12 +870,12 @@ func (b *Binder) bindExportDeclaration(node *ast.Node) { b.bindAnonymousDeclaration(node, ast.SymbolFlagsExportStar, b.getDeclarationName(node)) } else if decl.ExportClause == nil { // All export * declarations are collected in an __export symbol - b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), node, ast.SymbolFlagsExportStar, ast.SymbolFlagsNone) + b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), node, ast.SymbolFlagsExportStar, ast.SymbolFlagsNone) } else if ast.IsNamespaceExport(decl.ExportClause) { // declareSymbol walks up parents to find name text, parent _must_ be set // but won't be set by the normal binder walk until `bindChildren` later on. setParent(decl.ExportClause, node) - b.declareSymbol(ast.GetExports(b.container.Symbol()), b.container.Symbol(), decl.ExportClause, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes) + b.declareSymbol(getExports(b.container.Symbol()), b.container.Symbol(), decl.ExportClause, ast.SymbolFlagsAlias, ast.SymbolFlagsAliasExcludes) } } @@ -882,7 +894,7 @@ func (b *Binder) bindExportAssignment(node *ast.Node) { } // If there is an `export default x;` alias declaration, can't `export default` anything else. // (In contrast, you can still have `export default function f() {}` and `export default interface I {}`.) - symbol := b.declareSymbol(ast.GetExports(container.Symbol()), container.Symbol(), node, flags, ast.SymbolFlagsAll) + symbol := b.declareSymbol(getExports(container.Symbol()), container.Symbol(), node, flags, ast.SymbolFlagsAll) if ast.IsJSExportAssignment(node) || node.AsExportAssignment().IsExportEquals { // Will be an error later, since the module already has other exports. Just make sure this has a valueDeclaration set. SetValueDeclaration(symbol, node) @@ -971,12 +983,12 @@ func (b *Binder) bindClassLikeDeclaration(node *ast.Node) { // module might have an exported variable called 'prototype'. We can't allow that as // that would clash with the built-in 'prototype' for the class. prototypeSymbol := b.newSymbol(ast.SymbolFlagsProperty|ast.SymbolFlagsPrototype, "prototype") - symbolExport := ast.GetExports(symbol)[prototypeSymbol.Name] + symbolExport := getExports(symbol)[prototypeSymbol.Name] if symbolExport != nil { setParent(name, node) b.errorOnNode(symbolExport.Declarations[0], diagnostics.Duplicate_identifier_0, ast.SymbolName(prototypeSymbol)) } - ast.GetExports(symbol)[prototypeSymbol.Name] = prototypeSymbol + getExports(symbol)[prototypeSymbol.Name] = prototypeSymbol prototypeSymbol.Parent = symbol } @@ -1039,7 +1051,7 @@ func (b *Binder) bindFunctionPropertyAssignment(node *ast.Node) { b.bindAnonymousDeclaration(node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.InternalSymbolNameComputed) addLateBoundAssignmentDeclarationToSymbol(node, funcSymbol) } else { - b.declareSymbol(ast.GetExports(funcSymbol), funcSymbol, node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.SymbolFlagsPropertyExcludes) + b.declareSymbol(getExports(funcSymbol), funcSymbol, node, ast.SymbolFlagsProperty|ast.SymbolFlagsAssignment, ast.SymbolFlagsPropertyExcludes) } } } @@ -1063,9 +1075,9 @@ func (b *Binder) bindThisPropertyAssignment(node *ast.Node) { classSymbol := containingClass.Symbol() var symbolTable ast.SymbolTable if ast.IsStatic(thisContainer) { - symbolTable = ast.GetExports(containingClass.Symbol()) + symbolTable = getExports(containingClass.Symbol()) } else { - symbolTable = ast.GetMembers(containingClass.Symbol()) + symbolTable = getMembers(containingClass.Symbol()) } if ast.HasDynamicName(node) { b.declareSymbolEx(symbolTable, containingClass.Symbol(), node, ast.SymbolFlagsProperty, ast.SymbolFlagsNone, true /*isReplaceableByMethod*/, true /*isComputedName*/) @@ -1137,7 +1149,7 @@ func (b *Binder) bindParameter(node *ast.Node) { if ast.IsParameterPropertyDeclaration(node, node.Parent) { classDeclaration := node.Parent.Parent flags := ast.SymbolFlagsProperty | core.IfElse(decl.QuestionToken != nil, ast.SymbolFlagsOptional, ast.SymbolFlagsNone) - b.declareSymbol(ast.GetMembers(classDeclaration.Symbol()), classDeclaration.Symbol(), node, flags, ast.SymbolFlagsPropertyExcludes) + b.declareSymbol(getMembers(classDeclaration.Symbol()), classDeclaration.Symbol(), node, flags, ast.SymbolFlagsPropertyExcludes) } } @@ -1181,7 +1193,7 @@ func (b *Binder) bindBlockScopedDeclaration(node *ast.Node, symbolFlags ast.Symb } fallthrough default: - b.declareSymbol(ast.GetLocals(b.blockScopeContainer), nil /*parent*/, node, symbolFlags, symbolExcludes) + b.declareSymbol(getLocals(b.blockScopeContainer), nil /*parent*/, node, symbolFlags, symbolExcludes) } } @@ -1200,7 +1212,7 @@ func (b *Binder) bindTypeParameter(node *ast.Node) { if node.Parent.Kind == ast.KindInferType { container := b.getInferTypeContainer(node.Parent) if container != nil { - b.declareSymbol(ast.GetLocals(container), nil /*parent*/, node, ast.SymbolFlagsTypeParameter, ast.SymbolFlagsTypeParameterExcludes) + b.declareSymbol(getLocals(container), nil /*parent*/, node, ast.SymbolFlagsTypeParameter, ast.SymbolFlagsTypeParameterExcludes) } else { b.bindAnonymousDeclaration(node, ast.SymbolFlagsTypeParameter, b.getDeclarationName(node)) } diff --git a/internal/checker/checker.go b/internal/checker/checker.go index f7db09f9fc..09e32fcaee 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -1334,7 +1334,7 @@ func (c *Checker) mergeModuleAugmentation(moduleName *ast.Node) { }) { merged := c.mergeSymbol(moduleAugmentation.Symbol, mainModule, true /*unidirectional*/) // moduleName will be a StringLiteral since this is not `declare global`. - ast.GetSymbolTable(&c.patternAmbientModuleAugmentations)[moduleName.Text()] = merged + c.patternAmbientModuleAugmentations.GetOrInit()[moduleName.Text()] = merged } else { if mainModule.Exports[ast.InternalSymbolNameExportStar] != nil && len(moduleAugmentation.Symbol.Exports) != 0 { // We may need to merge the module augmentation's exports into the target symbols of the resolved exports @@ -13370,10 +13370,10 @@ func (c *Checker) mergeSymbol(target *ast.Symbol, source *ast.Symbol, unidirecti } target.Declarations = append(target.Declarations, source.Declarations...) if source.Members != nil { - c.mergeSymbolTable(ast.GetSymbolTable(&target.Members), source.Members, unidirectional, nil) + c.mergeSymbolTable(target.Members.GetOrInit(), source.Members, unidirectional, nil) } if source.Exports != nil { - c.mergeSymbolTable(ast.GetSymbolTable(&target.Exports), source.Exports, unidirectional, target) + c.mergeSymbolTable(target.Exports.GetOrInit(), source.Exports, unidirectional, target) } if !unidirectional { c.recordMergedSymbol(target, source) @@ -19951,9 +19951,9 @@ func (c *Checker) getPropertyOfUnionOrIntersectionType(t *Type, name string, ski func (c *Checker) getUnionOrIntersectionProperty(t *Type, name string, skipObjectFunctionPropertyAugment bool) *ast.Symbol { var cache ast.SymbolTable if skipObjectFunctionPropertyAugment { - cache = ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCacheWithoutFunctionPropertyAugment) + cache = t.AsUnionOrIntersectionType().propertyCacheWithoutFunctionPropertyAugment.GetOrInit() } else { - cache = ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCache) + cache = t.AsUnionOrIntersectionType().propertyCache.GetOrInit() } if prop := cache[name]; prop != nil { return prop @@ -19963,7 +19963,7 @@ func (c *Checker) getUnionOrIntersectionProperty(t *Type, name string, skipObjec cache[name] = prop // Propagate an entry from the non-augmented cache to the augmented cache unless the property is partial. if skipObjectFunctionPropertyAugment && prop.CheckFlags&ast.CheckFlagsPartial == 0 { - augmentedCache := ast.GetSymbolTable(&t.AsUnionOrIntersectionType().propertyCache) + augmentedCache := t.AsUnionOrIntersectionType().propertyCache.GetOrInit() if augmentedCache[name] == nil { augmentedCache[name] = prop }