Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/sqlmodel/causality.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (r *RowChange) getCausalityString(values []interface{}) []string {
// Only causality keys use this table name; r.sourceTable keeps the original source table.
sourceTable = r.causalityKeySourceTable
}
pkAndUks := r.whereHandle.UniqueIdxs
pkAndUks := r.whereHandle.getUniqueIdxs()
if len(pkAndUks) == 0 {
// the table has no PK/UK, all values of the row consists the causality key
return []string{genKeyString(sourceTable.String(), r.sourceTableInfo.Columns, values)}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlmodel/reduce.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (r *RowChange) IsPrimaryOrUniqueKeyUpdated() bool {
}
}

for _, idx := range r.whereHandle.UniqueIdxs {
for _, idx := range r.whereHandle.getUniqueIdxs() {
if idx == nil || idx == r.whereHandle.UniqueNotNullIdx {
continue
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/sqlmodel/where_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package sqlmodel

import (
"sync"

"github.com/pingcap/log"
"github.com/pingcap/tidb/pkg/meta/model"
pmodel "github.com/pingcap/tidb/pkg/parser/ast"
Expand All @@ -23,6 +25,8 @@ import (

// WhereHandle is used to generate a WHERE clause in SQL.
type WhereHandle struct {
mu sync.RWMutex

UniqueNotNullIdx *model.IndexInfo
// If the index and columns have no NOT NULL constraint, but all data is NOT
// NULL, we can still use it.
Expand Down Expand Up @@ -138,6 +142,9 @@ func (h *WhereHandle) getWhereIdxByData(data []interface{}) *model.IndexInfo {
if h.UniqueNotNullIdx != nil {
return h.UniqueNotNullIdx
}

h.mu.Lock()
defer h.mu.Unlock()
for i, idx := range h.UniqueIdxs {
ok := true
for _, idxCol := range idx.Columns {
Expand All @@ -153,3 +160,13 @@ func (h *WhereHandle) getWhereIdxByData(data []interface{}) *model.IndexInfo {
}
return nil
}

func (h *WhereHandle) getUniqueIdxs() []*model.IndexInfo {
if h.UniqueNotNullIdx != nil {
return h.UniqueIdxs
}

h.mu.RLock()
defer h.mu.RUnlock()
return append([]*model.IndexInfo(nil), h.UniqueIdxs...)
}
52 changes: 52 additions & 0 deletions pkg/sqlmodel/where_handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package sqlmodel

import (
"fmt"
"sync"
"testing"

"github.com/pingcap/tidb/pkg/ddl"
Expand Down Expand Up @@ -215,3 +217,53 @@ CREATE TABLE t (
idx = handle.getWhereIdxByData([]interface{}{1, nil, 3, nil})
require.Nil(t, idx)
}

func TestGetWhereIdxByDataNoRace(t *testing.T) {
t.Parallel()

createSQL := `
CREATE TABLE t (
c1 INT,
c2 INT,
UNIQUE INDEX idx1 (c1),
UNIQUE INDEX idx2 (c2)
)`
p := parser.New()
node, err := p.ParseOneStmt(createSQL, "", "")
require.NoError(t, err)
ti, err := ddl.BuildTableInfoFromAST(metabuild.NewContext(), node.(*ast.CreateTableStmt))
require.NoError(t, err)

handle := GetWhereHandle(ti, ti)
checkIndex := func(data []interface{}, expected string) error {
idx := handle.getWhereIdxByData(data)
if idx == nil {
return fmt.Errorf("expected %s, got nil", expected)
}
if idx.Name.L != expected {
return fmt.Errorf("expected %s, got %s", expected, idx.Name.L)
}
return nil
}

const concurrency = 100
var wg sync.WaitGroup
errCh := make(chan error, concurrency*2)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := checkIndex([]interface{}{1, nil}, "idx1"); err != nil {
errCh <- err
}
if err := checkIndex([]interface{}{nil, 2}, "idx2"); err != nil {
errCh <- err
}
}()
}
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
}
Loading