Skip to content
Snippets Groups Projects
Commit b3fae7a4 authored by Martina Ferrari's avatar Martina Ferrari
Browse files

New upstream version 1.3.0

parent 42a0bf16
No related branches found
No related tags found
No related merge requests found
language: go language: go
sudo: false
go: go:
- 1.2 - 1.2.x
- 1.3 - 1.3.x
- 1.4 - 1.4 # has no cover tool for latest releases
- release - 1.5.x
- tip - 1.6.x
- 1.7.x
- 1.8.x
# - tip # sadly fails most of the times
script: script:
- go test -v ./... - go vet
- go test -race ./... - test -z "$(go fmt ./...)" # fail if not formatted properly
- go test -race -coverprofile=coverage.txt -covermode=atomic
after_success:
- bash <(curl -s https://codecov.io/bash)
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
Copyright (c) 2013-2015, DataDog.lt team Copyright (c) 2013-2017, DATA-DOG team
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
......
[![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.png)](https://travis-ci.org/DATA-DOG/go-sqlmock) [![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.svg)](https://travis-ci.org/DATA-DOG/go-sqlmock)
[![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.png)](https://godoc.org/github.com/DATA-DOG/go-sqlmock) [![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.svg)](https://godoc.org/github.com/DATA-DOG/go-sqlmock)
[![codecov.io](https://codecov.io/github/DATA-DOG/go-sqlmock/branch/master/graph/badge.svg)](https://codecov.io/github/DATA-DOG/go-sqlmock)
# Sql driver mock for Golang # Sql driver mock for Golang
This is a **mock** driver as **database/sql/driver** which is very flexible and pragmatic to **sqlmock** is a mock library implementing [sql/driver](https://godoc.org/database/sql/driver). Which has one and only
manage and mock expected queries. All the expectations should be met and all queries and actions purpose - to simulate any **sql** driver behavior in tests, without needing a real database connection. It helps to
triggered should be mocked in order to pass a test. The package has no 3rd party dependencies. maintain correct **TDD** workflow.
**NOTE:** regarding major issues #20 and #9 the **api** has changed to support concurrency and more than - this library is now complete and stable. (you may not find new changes for this reason)
one database connection. - supports concurrency and multiple connections.
- supports **go1.8** Context related feature mocking and Named sql parameters.
- does not require any modifications to your source code.
- the driver allows to mock any sql driver method behavior.
- has strict by default expectation order matching.
- has no third party dependencies.
If you need an old version, checkout **go-sqlmock** at gopkg.in: **NOTE:** in **v1.2.0** **sqlmock.Rows** has changed to struct from interface, if you were using any type references to that
interface, you will need to switch it to a pointer struct type. Also, **sqlmock.Rows** were used to implement **driver.Rows**
go get gopkg.in/DATA-DOG/go-sqlmock.v0 interface, which was not required or useful for mocking and was removed. Hope it will not cause issues.
Otherwise use the **v1** branch from master which should be stable afterwards, because all the issues which
were known will be fixed in this version.
## Install ## Install
go get gopkg.in/DATA-DOG/go-sqlmock.v1 go get gopkg.in/DATA-DOG/go-sqlmock.v1
Or take an older version:
go get gopkg.in/DATA-DOG/go-sqlmock.v0
## Documentation and Examples ## Documentation and Examples
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference. Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
...@@ -91,7 +90,7 @@ import ( ...@@ -91,7 +90,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/DATA-DOG/go-sqlmock" "gopkg.in/DATA-DOG/go-sqlmock.v1"
) )
// a successful case // a successful case
...@@ -145,12 +144,65 @@ func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { ...@@ -145,12 +144,65 @@ func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
} }
``` ```
## Matching arguments like time.Time
There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case
**sqlmock** provides an [Argument](https://godoc.org/github.com/DATA-DOG/go-sqlmock#Argument) interface which
can be used in more sophisticated matching. Here is a simple example of time argument matching:
``` go
type AnyTime struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyTime) Match(v driver.Value) bool {
_, ok := v.(time.Time)
return ok
}
func TestAnyTimeArgument(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("john", AnyTime{}).
WillReturnResult(NewResult(1, 1))
_, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now())
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
```
It only asserts that argument is of `time.Time` type.
## Run tests ## Run tests
go test -race go test -race
## Changes ## Change Log
- **2017-09-01** - it is now possible to expect that prepared statement will be closed,
using **ExpectedPrepare.WillBeClosed**.
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct
but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now
accept multiple row sets.
- **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL
query. It should still be validated even if Exec or Query is not
executed on that prepared statement.
- **2016-02-23** - added **sqlmock.AnyArg()** function to provide any kind
of argument matcher.
- **2016-02-23** - convert expected arguments to driver.Value as natural
driver does, the change may affect time.Time comparison and will be
stricter. See [issue](https://github.com/DATA-DOG/go-sqlmock/issues/31).
- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed. - **2015-08-27** - **v1** api change, concurrency support, all known issues fixed.
- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error - **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for - **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
......
package sqlmock
import "database/sql/driver"
// Argument interface allows to match
// any argument in specific way when used with
// ExpectedQuery and ExpectedExec expectations.
type Argument interface {
Match(driver.Value) bool
}
// AnyArg will return an Argument which can
// match any kind of arguments.
//
// Useful for time.Time or similar kinds of arguments.
func AnyArg() Argument {
return anyArgument{}
}
type anyArgument struct{}
func (a anyArgument) Match(_ driver.Value) bool {
return true
}
package sqlmock
import (
"database/sql/driver"
"testing"
"time"
)
type AnyTime struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyTime) Match(v driver.Value) bool {
_, ok := v.(time.Time)
return ok
}
func TestAnyTimeArgument(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("john", AnyTime{}).
WillReturnResult(NewResult(1, 1))
_, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now())
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestByteSliceArgument(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
username := []byte("user")
mock.ExpectExec("INSERT INTO users").WithArgs(username).WillReturnResult(NewResult(1, 1))
_, err = db.Exec("INSERT INTO users(username) VALUES (?)", username)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
...@@ -39,7 +39,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) { ...@@ -39,7 +39,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
// and a mock to manage expectations. // and a mock to manage expectations.
// Pings db so that all expectations could be // Pings db so that all expectations could be
// asserted. // asserted.
func New() (db *sql.DB, mock Sqlmock, err error) { func New() (*sql.DB, Sqlmock, error) {
pool.Lock() pool.Lock()
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter) dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
pool.counter++ pool.counter++
...@@ -48,9 +48,31 @@ func New() (db *sql.DB, mock Sqlmock, err error) { ...@@ -48,9 +48,31 @@ func New() (db *sql.DB, mock Sqlmock, err error) {
pool.conns[dsn] = smock pool.conns[dsn] = smock
pool.Unlock() pool.Unlock()
db, err = sql.Open("sqlmock", dsn) return smock.open()
if err != nil { }
return
// NewWithDSN creates sqlmock database connection
// with a specific DSN and a mock to manage expectations.
// Pings db so that all expectations could be asserted.
//
// This method is introduced because of sql abstraction
// libraries, which do not provide a way to initialize
// with sql.DB instance. For example GORM library.
//
// Note, it will error if attempted to create with an
// already used dsn
//
// It is not recommended to use this method, unless you
// really need it and there is no other way around.
func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
pool.Lock()
if _, ok := pool.conns[dsn]; ok {
pool.Unlock()
return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn)
} }
return db, smock, db.Ping() smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
pool.conns[dsn] = smock
pool.Unlock()
return smock.open()
} }
...@@ -5,6 +5,10 @@ import ( ...@@ -5,6 +5,10 @@ import (
"testing" "testing"
) )
type void struct{}
func (void) Print(...interface{}) {}
func ExampleNew() { func ExampleNew() {
db, mock, err := New() db, mock, err := New()
if err != nil { if err != nil {
...@@ -71,6 +75,9 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { ...@@ -71,6 +75,9 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
t.Errorf("expected no error, but got: %s", err) t.Errorf("expected no error, but got: %s", err)
} }
db2, mock2, err := New() db2, mock2, err := New()
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
if len(pool.conns) != 2 { if len(pool.conns) != 2 {
t.Errorf("expected 2 connection in pool, but there is: %d", len(pool.conns)) t.Errorf("expected 2 connection in pool, but there is: %d", len(pool.conns))
} }
...@@ -82,3 +89,24 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { ...@@ -82,3 +89,24 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
t.Errorf("expected not the same mock instance, but it is the same") t.Errorf("expected not the same mock instance, but it is the same")
} }
} }
func TestWrongDSN(t *testing.T) {
t.Parallel()
db, _, _ := New()
defer db.Close()
if _, err := db.Driver().Open("wrong_dsn"); err == nil {
t.Error("expected error on Open")
}
}
func TestNewDSN(t *testing.T) {
if _, _, err := NewWithDSN("sqlmock_db_99"); err != nil {
t.Errorf("expected no error on NewWithDSN, but got: %s", err)
}
}
func TestDuplicateNewDSN(t *testing.T) {
if _, _, err := NewWithDSN("sqlmock_db_1"); err == nil {
t.Error("expected error on NewWithDSN")
}
}
...@@ -3,19 +3,12 @@ package sqlmock ...@@ -3,19 +3,12 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"time"
) )
// Argument interface allows to match
// any argument in specific way when used with
// ExpectedQuery and ExpectedExec expectations.
type Argument interface {
Match(driver.Value) bool
}
// an expectation interface // an expectation interface
type expectation interface { type expectation interface {
fulfilled() bool fulfilled() bool
...@@ -61,6 +54,7 @@ func (e *ExpectedClose) String() string { ...@@ -61,6 +54,7 @@ func (e *ExpectedClose) String() string {
// returned by *Sqlmock.ExpectBegin. // returned by *Sqlmock.ExpectBegin.
type ExpectedBegin struct { type ExpectedBegin struct {
commonExpectation commonExpectation
delay time.Duration
} }
// WillReturnError allows to set an error for *sql.DB.Begin action // WillReturnError allows to set an error for *sql.DB.Begin action
...@@ -78,6 +72,13 @@ func (e *ExpectedBegin) String() string { ...@@ -78,6 +72,13 @@ func (e *ExpectedBegin) String() string {
return msg return msg
} }
// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
e.delay = duration
return e
}
// ExpectedCommit is used to manage *sql.Tx.Commit expectation // ExpectedCommit is used to manage *sql.Tx.Commit expectation
// returned by *Sqlmock.ExpectCommit. // returned by *Sqlmock.ExpectCommit.
type ExpectedCommit struct { type ExpectedCommit struct {
...@@ -125,7 +126,8 @@ func (e *ExpectedRollback) String() string { ...@@ -125,7 +126,8 @@ func (e *ExpectedRollback) String() string {
// Returned by *Sqlmock.ExpectQuery. // Returned by *Sqlmock.ExpectQuery.
type ExpectedQuery struct { type ExpectedQuery struct {
queryBasedExpectation queryBasedExpectation
rows driver.Rows rows driver.Rows
delay time.Duration
} }
// WithArgs will match given expected args to actual database query arguments. // WithArgs will match given expected args to actual database query arguments.
...@@ -142,16 +144,16 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { ...@@ -142,16 +144,16 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
return e return e
} }
// WillReturnRows specifies the set of resulting rows that will be returned // WillDelayFor allows to specify duration for which it will delay
// by the triggered query // result. May be used together with Context
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery { func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
e.rows = rows e.delay = duration
return e return e
} }
// String returns string representation // String returns string representation
func (e *ExpectedQuery) String() string { func (e *ExpectedQuery) String() string {
msg := "ExpectedQuery => expecting Query or QueryRow which:" msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
if len(e.args) == 0 { if len(e.args) == 0 {
...@@ -165,12 +167,7 @@ func (e *ExpectedQuery) String() string { ...@@ -165,12 +167,7 @@ func (e *ExpectedQuery) String() string {
} }
if e.rows != nil { if e.rows != nil {
msg += "\n - should return rows:\n" msg += fmt.Sprintf("\n - %s", e.rows)
rs, _ := e.rows.(*rows)
for i, row := range rs.rows {
msg += fmt.Sprintf(" %d - %+v\n", i, row)
}
msg = strings.TrimSpace(msg)
} }
if e.err != nil { if e.err != nil {
...@@ -185,6 +182,7 @@ func (e *ExpectedQuery) String() string { ...@@ -185,6 +182,7 @@ func (e *ExpectedQuery) String() string {
type ExpectedExec struct { type ExpectedExec struct {
queryBasedExpectation queryBasedExpectation
result driver.Result result driver.Result
delay time.Duration
} }
// WithArgs will match given expected args to actual database exec operation arguments. // WithArgs will match given expected args to actual database exec operation arguments.
...@@ -201,9 +199,16 @@ func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { ...@@ -201,9 +199,16 @@ func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
return e return e
} }
// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
e.delay = duration
return e
}
// String returns string representation // String returns string representation
func (e *ExpectedExec) String() string { func (e *ExpectedExec) String() string {
msg := "ExpectedExec => expecting Exec which:" msg := "ExpectedExec => expecting Exec or ExecContext which:"
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'" msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
if len(e.args) == 0 { if len(e.args) == 0 {
...@@ -247,10 +252,13 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec { ...@@ -247,10 +252,13 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
// Returned by *Sqlmock.ExpectPrepare. // Returned by *Sqlmock.ExpectPrepare.
type ExpectedPrepare struct { type ExpectedPrepare struct {
commonExpectation commonExpectation
mock *sqlmock mock *sqlmock
sqlRegex *regexp.Regexp sqlRegex *regexp.Regexp
statement driver.Stmt statement driver.Stmt
closeErr error closeErr error
mustBeClosed bool
wasClosed bool
delay time.Duration
} }
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action. // WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
...@@ -259,12 +267,26 @@ func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { ...@@ -259,12 +267,26 @@ func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
return e return e
} }
// WillReturnCloseError allows to set an error for this prapared statement Close action // WillReturnCloseError allows to set an error for this prepared statement Close action
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
e.closeErr = err e.closeErr = err
return e return e
} }
// WillDelayFor allows to specify duration for which it will delay
// result. May be used together with Context
func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare {
e.delay = duration
return e
}
// WillBeClosed expects this prepared statement to
// be closed.
func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
e.mustBeClosed = true
return e
}
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement. // ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
// this method is convenient in order to prevent duplicating sql query string matching. // this method is convenient in order to prevent duplicating sql query string matching.
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
...@@ -307,63 +329,25 @@ type queryBasedExpectation struct { ...@@ -307,63 +329,25 @@ type queryBasedExpectation struct {
args []driver.Value args []driver.Value
} }
func (e *queryBasedExpectation) attemptMatch(sql string, args []driver.Value) (ret bool) { func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {
if !e.queryMatches(sql) { if !e.queryMatches(sql) {
return return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String())
} }
defer recover() // ignore panic since we attempt a match // catch panic
defer func() {
if e := recover(); e != nil {
_, ok := e.(error)
if !ok {
err = fmt.Errorf(e.(string))
}
}
}()
if e.argsMatches(args) { err = e.argsMatches(args)
return true
}
return return
} }
func (e *queryBasedExpectation) queryMatches(sql string) bool { func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql) return e.sqlRegex.MatchString(sql)
} }
func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
if nil == e.args {
return true
}
if len(args) != len(e.args) {
return false
}
for k, v := range args {
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v) {
return false
}
continue
}
vi := reflect.ValueOf(v)
ai := reflect.ValueOf(e.args[k])
switch vi.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if vi.Int() != ai.Int() {
return false
}
case reflect.Float32, reflect.Float64:
if vi.Float() != ai.Float() {
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if vi.Uint() != ai.Uint() {
return false
}
case reflect.String:
if vi.String() != ai.String() {
return false
}
default:
// compare types like time.Time based on type only
if vi.Kind() != ai.Kind() {
return false
}
}
}
return true
}
// +build !go1.8
package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
e.rows = &rowSets{sets: []*Rows{rows}}
return e
}
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
// @TODO: does it make sense to pass value instead of named value?
if !matcher.Match(v.Value) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
dval := e.args[k]
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if !reflect.DeepEqual(darg, v.Value) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
}
}
return nil
}
// +build go1.8
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
sets := make([]*Rows, len(rows))
for i, r := range rows {
sets[i] = r
}
e.rows = &rowSets{sets: sets}
return e
}
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
if nil == e.args {
return nil
}
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
// @TODO should we assert either all args are named or ordinal?
for k, v := range args {
// custom argument matcher
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v.Value) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}
dval := e.args[k]
if named, isNamed := dval.(sql.NamedArg); isNamed {
dval = named.Value
if v.Name != named.Name {
return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
}
} else if k+1 != v.Ordinal {
return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal)
}
// convert to driver converter
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
if err != nil {
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
}
if !driver.IsValue(darg) {
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
}
if !reflect.DeepEqual(darg, v.Value) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
}
}
return nil
}
// +build go1.8
package sqlmock
import (
"database/sql"
"database/sql/driver"
"testing"
)
func TestQueryExpectationNamedArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
against := []namedValue{{Value: int64(5), Name: "id"}}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
}
e.args = []driver.Value{
sql.Named("id", 5),
sql.Named("s", "str"),
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the size is not the same")
}
against = []namedValue{
{Value: int64(5), Name: "id"},
{Value: "str", Name: "s"},
}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should have matched, but it did not: %v", err)
}
against = []namedValue{
{Value: int64(5), Name: "id"},
{Value: "str", Name: "username"},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments matched, but it should have not due to Name")
}
e.args = []driver.Value{int64(5), "str"}
against = []namedValue{
{Value: int64(5), Ordinal: 0},
{Value: "str", Ordinal: 1},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments matched, but it should have not due to wrong Ordinal position")
}
against = []namedValue{
{Value: int64(5), Ordinal: 1},
{Value: "str", Ordinal: 2},
}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should have matched, but it did not: %v", err)
}
}
...@@ -8,60 +8,101 @@ import ( ...@@ -8,60 +8,101 @@ import (
"time" "time"
) )
type matcher struct {
}
func (m matcher) Match(driver.Value) bool {
return true
}
func TestQueryExpectationArgComparison(t *testing.T) { func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{} e := &queryBasedExpectation{}
against := []driver.Value{5} against := []namedValue{{Value: int64(5), Ordinal: 1}}
if !e.argsMatches(against) { if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since the no expectation was set") t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
} }
e.args = []driver.Value{5, "str"} e.args = []driver.Value{5, "str"}
against = []driver.Value{5} against = []namedValue{{Value: int64(5), Ordinal: 1}}
if e.argsMatches(against) { if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the size is not the same") t.Error("arguments should not match, since the size is not the same")
} }
against = []driver.Value{3, "str"} against = []namedValue{
if e.argsMatches(against) { {Value: int64(3), Ordinal: 1},
{Value: "str", Ordinal: 2},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the first argument (int value) is different") t.Error("arguments should not match, since the first argument (int value) is different")
} }
against = []driver.Value{5, "st"} against = []namedValue{
if e.argsMatches(against) { {Value: int64(5), Ordinal: 1},
{Value: "st", Ordinal: 2},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since the second argument (string value) is different") t.Error("arguments should not match, since the second argument (string value) is different")
} }
against = []driver.Value{5, "str"} against = []namedValue{
if !e.argsMatches(against) { {Value: int64(5), Ordinal: 1},
t.Error("arguments should match, but it did not") {Value: "str", Ordinal: 2},
}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, but it did not: %s", err)
} }
e.args = []driver.Value{5, time.Now()}
const longForm = "Jan 2, 2006 at 3:04pm (MST)" const longForm = "Jan 2, 2006 at 3:04pm (MST)"
tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)")
e.args = []driver.Value{5, tm}
against = []driver.Value{5, tm} against = []namedValue{
if !e.argsMatches(against) { {Value: int64(5), Ordinal: 1},
t.Error("arguments should match (time will be compared only by type), but it did not") {Value: tm, Ordinal: 2},
} }
if err := e.argsMatches(against); err != nil {
against = []driver.Value{5, matcher{}}
if !e.argsMatches(against) {
t.Error("arguments should match, but it did not") t.Error("arguments should match, but it did not")
} }
e.args = []driver.Value{5, AnyArg()}
if err := e.argsMatches(against); err != nil {
t.Errorf("arguments should match, but it did not: %s", err)
}
}
func TestQueryExpectationArgComparisonBool(t *testing.T) {
var e *queryBasedExpectation
e = &queryBasedExpectation{args: []driver.Value{true}}
against := []namedValue{
{Value: true, Ordinal: 1},
}
if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since arguments are the same")
}
e = &queryBasedExpectation{args: []driver.Value{false}}
against = []namedValue{
{Value: false, Ordinal: 1},
}
if err := e.argsMatches(against); err != nil {
t.Error("arguments should match, since argument are the same")
}
e = &queryBasedExpectation{args: []driver.Value{true}}
against = []namedValue{
{Value: false, Ordinal: 1},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since argument is different")
}
e = &queryBasedExpectation{args: []driver.Value{false}}
against = []namedValue{
{Value: true, Ordinal: 1},
}
if err := e.argsMatches(against); err == nil {
t.Error("arguments should not match, since argument is different")
}
} }
func TestQueryExpectationSqlMatch(t *testing.T) { func TestQueryExpectationSqlMatch(t *testing.T) {
e := &ExpectedExec{} e := &ExpectedExec{}
e.sqlRegex = regexp.MustCompile("SELECT x FROM") e.sqlRegex = regexp.MustCompile("SELECT x FROM")
if !e.queryMatches("SELECT x FROM someting") { if !e.queryMatches("SELECT x FROM someting") {
t.Errorf("Sql must have matched the query") t.Errorf("Sql must have matched the query")
...@@ -73,7 +114,7 @@ func TestQueryExpectationSqlMatch(t *testing.T) { ...@@ -73,7 +114,7 @@ func TestQueryExpectationSqlMatch(t *testing.T) {
} }
} }
func ExampleExpectExec() { func ExampleExpectedExec() {
db, mock, _ := New() db, mock, _ := New()
result := NewErrorResult(fmt.Errorf("some error")) result := NewErrorResult(fmt.Errorf("some error"))
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
...@@ -82,3 +123,32 @@ func ExampleExpectExec() { ...@@ -82,3 +123,32 @@ func ExampleExpectExec() {
fmt.Println(err) fmt.Println(err)
// Output: some error // Output: some error
} }
func TestBuildQuery(t *testing.T) {
db, mock, _ := New()
query := `
SELECT
name,
email,
address,
anotherfield
FROM user
where
name = 'John'
and
address = 'Jakarta'
`
mock.ExpectQuery(query)
mock.ExpectExec(query)
mock.ExpectPrepare(query)
db.QueryRow(query)
db.Exec(query)
db.Prepare(query)
if err := mock.ExpectationsWereMet(); err != nil {
t.Error(err)
}
}
...@@ -23,7 +23,7 @@ func ExampleNewResult() { ...@@ -23,7 +23,7 @@ func ExampleNewResult() {
result := NewResult(lastInsertID, affected) result := NewResult(lastInsertID, affected)
mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) mock.ExpectExec("^INSERT (.+)").WillReturnResult(result)
fmt.Println(mock.ExpectationsWereMet()) fmt.Println(mock.ExpectationsWereMet())
// Output: there is a remaining expectation which was not matched: ExpectedExec => expecting Exec which: // Output: there is a remaining expectation which was not matched: ExpectedExec => expecting Exec or ExecContext which:
// - matches sql: '^INSERT (.+)' // - matches sql: '^INSERT (.+)'
// - is without arguments // - is without arguments
// - should return Result having: // - should return Result having:
......
...@@ -3,6 +3,7 @@ package sqlmock ...@@ -3,6 +3,7 @@ package sqlmock
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/csv" "encoding/csv"
"fmt"
"io" "io"
"strings" "strings"
) )
...@@ -18,57 +19,22 @@ var CSVColumnParser = func(s string) []byte { ...@@ -18,57 +19,22 @@ var CSVColumnParser = func(s string) []byte {
return []byte(s) return []byte(s)
} }
// Rows interface allows to construct rows type rowSets struct {
// which also satisfies database/sql/driver.Rows interface sets []*Rows
type Rows interface { pos int
// composed interface, supports sql driver.Rows
driver.Rows
// AddRow composed from database driver.Value slice
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
AddRow(columns ...driver.Value) Rows
// FromCSVString build rows from csv string.
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
FromCSVString(s string) Rows
// RowError allows to set an error
// which will be returned when a given
// row number is read
RowError(row int, err error) Rows
// CloseError allows to set an error
// which will be returned by rows.Close
// function.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, that is
// a default sql library behavior
CloseError(err error) Rows
} }
type rows struct { func (rs *rowSets) Columns() []string {
cols []string return rs.sets[rs.pos].cols
rows [][]driver.Value
pos int
nextErr map[int]error
closeErr error
}
func (r *rows) Columns() []string {
return r.cols
} }
func (r *rows) Close() error { func (rs *rowSets) Close() error {
return r.closeErr return rs.sets[rs.pos].closeErr
} }
// advances to next row // advances to next row
func (r *rows) Next(dest []driver.Value) error { func (rs *rowSets) Next(dest []driver.Value) error {
r := rs.sets[rs.pos]
r.pos++ r.pos++
if r.pos > len(r.rows) { if r.pos > len(r.rows) {
return io.EOF // per interface spec return io.EOF // per interface spec
...@@ -81,24 +47,79 @@ func (r *rows) Next(dest []driver.Value) error { ...@@ -81,24 +47,79 @@ func (r *rows) Next(dest []driver.Value) error {
return r.nextErr[r.pos-1] return r.nextErr[r.pos-1]
} }
// transforms to debuggable printable string
func (rs *rowSets) String() string {
if rs.empty() {
return "with empty rows"
}
msg := "should return rows:\n"
if len(rs.sets) == 1 {
for n, row := range rs.sets[0].rows {
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
}
return strings.TrimSpace(msg)
}
for i, set := range rs.sets {
msg += fmt.Sprintf(" result set: %d\n", i)
for n, row := range set.rows {
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
}
}
return strings.TrimSpace(msg)
}
func (rs *rowSets) empty() bool {
for _, set := range rs.sets {
if len(set.rows) > 0 {
return false
}
}
return true
}
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {
cols []string
rows [][]driver.Value
pos int
nextErr map[int]error
closeErr error
}
// NewRows allows Rows to be created from a // NewRows allows Rows to be created from a
// sql driver.Value slice or from the CSV string and // sql driver.Value slice or from the CSV string and
// to be used as sql driver.Rows // to be used as sql driver.Rows
func NewRows(columns []string) Rows { func NewRows(columns []string) *Rows {
return &rows{cols: columns, nextErr: make(map[int]error)} return &Rows{cols: columns, nextErr: make(map[int]error)}
} }
func (r *rows) CloseError(err error) Rows { // CloseError allows to set an error
// which will be returned by rows.Close
// function.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, that is
// a default sql library behavior
func (r *Rows) CloseError(err error) *Rows {
r.closeErr = err r.closeErr = err
return r return r
} }
func (r *rows) RowError(row int, err error) Rows { // RowError allows to set an error
// which will be returned when a given
// row number is read
func (r *Rows) RowError(row int, err error) *Rows {
r.nextErr[row] = err r.nextErr[row] = err
return r return r
} }
func (r *rows) AddRow(values ...driver.Value) Rows { // AddRow composed from database driver.Value slice
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
func (r *Rows) AddRow(values ...driver.Value) *Rows {
if len(values) != len(r.cols) { if len(values) != len(r.cols) {
panic("Expected number of values to match number of columns") panic("Expected number of values to match number of columns")
} }
...@@ -112,7 +133,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows { ...@@ -112,7 +133,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
return r return r
} }
func (r *rows) FromCSVString(s string) Rows { // FromCSVString build rows from csv string.
// return the same instance to perform subsequent actions.
// Note that the number of values must match the number
// of columns
func (r *Rows) FromCSVString(s string) *Rows {
res := strings.NewReader(strings.TrimSpace(s)) res := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(res) csvReader := csv.NewReader(res)
......
// +build go1.8
package sqlmock
import "io"
// Implement the "RowsNextResultSet" interface
func (rs *rowSets) HasNextResultSet() bool {
return rs.pos+1 < len(rs.sets)
}
// Implement the "RowsNextResultSet" interface
func (rs *rowSets) NextResultSet() error {
if !rs.HasNextResultSet() {
return io.EOF
}
rs.pos++
return nil
}
// +build go1.8
package sqlmock
import (
"fmt"
"testing"
)
func TestQueryMultiRows(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
WithArgs(5).
WillReturnRows(rs1, rs2)
rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
defer rows.Close()
if !rows.Next() {
t.Error("expected a row to be available in first result set")
}
var id int
var name string
err = rows.Scan(&id, &name)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if id != 5 || name != "hello world" {
t.Errorf("unexpected row values id: %v name: %v", id, name)
}
if rows.Next() {
t.Error("was not expecting next row in first result set")
}
if !rows.NextResultSet() {
t.Error("had to have next result set")
}
if !rows.Next() {
t.Error("expected a row to be available in second result set")
}
err = rows.Scan(&name)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if name != "gopher" {
t.Errorf("unexpected row name: %v", name)
}
if !rows.Next() {
t.Error("expected a row to be available in second result set")
}
err = rows.Scan(&name)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if name != "john" {
t.Errorf("unexpected row name: %v", name)
}
if rows.Next() {
t.Error("expected next row to produce error")
}
if rows.Err() == nil {
t.Error("expected an error, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
...@@ -246,3 +246,40 @@ func TestCSVRowParser(t *testing.T) { ...@@ -246,3 +246,40 @@ func TestCSVRowParser(t *testing.T) {
t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2) t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2)
} }
} }
func TestWrongNumberOfValues(t *testing.T) {
// Open new mock database
db, mock, err := New()
if err != nil {
fmt.Println("error creating mock database")
return
}
defer db.Close()
defer func() {
recover()
}()
mock.ExpectQuery("SELECT ID FROM TABLE").WithArgs(101).WillReturnRows(NewRows([]string{"ID"}).AddRow(101, "Hello"))
db.Query("SELECT ID FROM TABLE", 101)
// shouldn't reach here
t.Error("expected panic from query")
}
func TestEmptyRowSets(t *testing.T) {
rs1 := NewRows([]string{"a"}).AddRow("a")
rs2 := NewRows([]string{"b"})
rs3 := NewRows([]string{"c"})
set1 := &rowSets{sets: []*Rows{rs1, rs2}}
set2 := &rowSets{sets: []*Rows{rs3, rs2}}
set3 := &rowSets{sets: []*Rows{rs2}}
if set1.empty() {
t.Fatalf("expected rowset 1, not to be empty, but it was")
}
if !set2.empty() {
t.Fatalf("expected rowset 2, to be empty, but it was not")
}
if !set3.empty() {
t.Fatalf("expected rowset 3, to be empty, but it was not")
}
}
/* /*
Package sqlmock provides sql driver connection, which allows to test database Package sqlmock is a mock library implementing sql driver. Which has one and only
interactions by expected calls and simulate their results or errors. purpose - to simulate any sql driver behavior in tests, without needing a real
database connection. It helps to maintain correct **TDD** workflow.
It does not require any modifications to your source code in order to test It does not require any modifications to your source code in order to test
and mock database operations. It does not even require a real database in order and mock database operations. Supports concurrency and multiple database mocking.
to test your application.
The driver allows to mock any sql driver method behavior. Concurrent actions The driver allows to mock any sql driver method behavior.
are also supported.
*/ */
package sqlmock package sqlmock
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect"
"regexp" "regexp"
"time"
) )
// Sqlmock interface serves to create expectations // Sqlmock interface serves to create expectations
...@@ -67,6 +67,11 @@ type Sqlmock interface { ...@@ -67,6 +67,11 @@ type Sqlmock interface {
// By default it is set to - true. But if you use goroutines // By default it is set to - true. But if you use goroutines
// to parallelize your query executation, that option may // to parallelize your query executation, that option may
// be handy. // be handy.
//
// This option may be turned on anytime during tests. As soon
// as it is switched to false, expectations will be matched
// in any order. Or otherwise if switched to true, any unmatched
// expectations will be expected in order
MatchExpectationsInOrder(bool) MatchExpectationsInOrder(bool)
} }
...@@ -79,6 +84,14 @@ type sqlmock struct { ...@@ -79,6 +84,14 @@ type sqlmock struct {
expected []expectation expected []expectation
} }
func (c *sqlmock) open() (*sql.DB, Sqlmock, error) {
db, err := sql.Open("sqlmock", c.dsn)
if err != nil {
return db, c, err
}
return db, c, db.Ping()
}
func (c *sqlmock) ExpectClose() *ExpectedClose { func (c *sqlmock) ExpectClose() *ExpectedClose {
e := &ExpectedClose{} e := &ExpectedClose{}
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
...@@ -141,12 +154,29 @@ func (c *sqlmock) ExpectationsWereMet() error { ...@@ -141,12 +154,29 @@ func (c *sqlmock) ExpectationsWereMet() error {
if !e.fulfilled() { if !e.fulfilled() {
return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
} }
// for expected prepared statement check whether it was closed if expected
if prep, ok := e.(*ExpectedPrepare); ok {
if prep.mustBeClosed && !prep.wasClosed {
return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
}
}
} }
return nil return nil
} }
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface // Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Begin() (driver.Tx, error) { func (c *sqlmock) Begin() (driver.Tx, error) {
ex, err := c.begin()
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return c, nil
}
func (c *sqlmock) begin() (*ExpectedBegin, error) {
var expected *ExpectedBegin var expected *ExpectedBegin
var ok bool var ok bool
var fulfilled int var fulfilled int
...@@ -177,7 +207,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) { ...@@ -177,7 +207,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
expected.triggered = true expected.triggered = true
expected.Unlock() expected.Unlock()
return c, expected.err
return expected, expected.err
} }
func (c *sqlmock) ExpectBegin() *ExpectedBegin { func (c *sqlmock) ExpectBegin() *ExpectedBegin {
...@@ -187,7 +218,25 @@ func (c *sqlmock) ExpectBegin() *ExpectedBegin { ...@@ -187,7 +218,25 @@ func (c *sqlmock) ExpectBegin() *ExpectedBegin {
} }
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer // Exec meets http://golang.org/pkg/database/sql/driver/#Execer
func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, err error) { func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
namedArgs := make([]namedValue, len(args))
for i, v := range args {
namedArgs[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
ex, err := c.exec(query, namedArgs)
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return ex.result, nil
}
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
query = stripQuery(query) query = stripQuery(query)
var expected *ExpectedExec var expected *ExpectedExec
var fulfilled int var fulfilled int
...@@ -205,10 +254,10 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er ...@@ -205,10 +254,10 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er
break break
} }
next.Unlock() next.Unlock()
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
} }
if exec, ok := next.(*ExpectedExec); ok { if exec, ok := next.(*ExpectedExec); ok {
if exec.attemptMatch(query, args) { if err := exec.attemptMatch(query, args); err == nil {
expected = exec expected = exec
break break
} }
...@@ -216,47 +265,37 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er ...@@ -216,47 +265,37 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (res driver.Result, er
next.Unlock() next.Unlock()
} }
if expected == nil { if expected == nil {
msg := "call to exec '%s' query with args %+v was not expected" msg := "call to ExecQuery '%s' with args %+v was not expected"
if fulfilled == len(c.expected) { if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg msg = "all expectations were already fulfilled, " + msg
} }
return nil, fmt.Errorf(msg, query, args) return nil, fmt.Errorf(msg, query, args)
} }
defer expected.Unlock() defer expected.Unlock()
expected.triggered = true
// converts panic to error in case of reflect value type mismatch
defer func(errp *error, exp *ExpectedExec, q string, a []driver.Value) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
msg := "exec query \"%s\", args \"%+v\" failed to match with error \"%s\" expectation: %s"
*errp = fmt.Errorf(msg, q, a, se, exp)
} else {
panic(e) // overwise if unknown error panic
}
}
}(&err, expected, query, args)
if !expected.queryMatches(query) { if !expected.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String()) return nil, fmt.Errorf("ExecQuery '%s', does not match regex '%s'", query, expected.sqlRegex.String())
} }
if !expected.argsMatches(args) { if err := expected.argsMatches(args); err != nil {
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, expected.args) return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err)
} }
expected.triggered = true
if expected.err != nil { if expected.err != nil {
return nil, expected.err // mocked to return error return nil, expected.err // mocked to return error
} }
if expected.result == nil { if expected.result == nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected) return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
} }
return expected.result, err
return expected, nil
} }
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
e := &ExpectedExec{} e := &ExpectedExec{}
sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr)
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
...@@ -264,9 +303,22 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec { ...@@ -264,9 +303,22 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface // Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
ex, err := c.prepare(query)
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return &statement{c, ex, query}, nil
}
func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
var expected *ExpectedPrepare var expected *ExpectedPrepare
var fulfilled int var fulfilled int
var ok bool var ok bool
query = stripQuery(query)
for _, next := range c.expected { for _, next := range c.expected {
next.Lock() next.Lock()
if next.fulfilled() { if next.fulfilled() {
...@@ -275,17 +327,24 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { ...@@ -275,17 +327,24 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
continue continue
} }
if expected, ok = next.(*ExpectedPrepare); ok { if c.ordered {
break if expected, ok = next.(*ExpectedPrepare); ok {
break
}
next.Unlock()
return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next)
} }
next.Unlock() if pr, ok := next.(*ExpectedPrepare); ok {
if c.ordered { if pr.sqlRegex.MatchString(query) {
return nil, fmt.Errorf("call to Prepare stetement with query '%s', was not expected, next expectation is: %s", query, next) expected = pr
break
}
} }
next.Unlock()
} }
query = stripQuery(query)
if expected == nil { if expected == nil {
msg := "call to Prepare '%s' query was not expected" msg := "call to Prepare '%s' query was not expected"
if fulfilled == len(c.expected) { if fulfilled == len(c.expected) {
...@@ -293,20 +352,48 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) { ...@@ -293,20 +352,48 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
} }
return nil, fmt.Errorf(msg, query) return nil, fmt.Errorf(msg, query)
} }
defer expected.Unlock()
if !expected.sqlRegex.MatchString(query) {
return nil, fmt.Errorf("Prepare query string '%s', does not match regex [%s]", query, expected.sqlRegex.String())
}
expected.triggered = true expected.triggered = true
expected.Unlock() return expected, expected.err
return &statement{c, query, expected.closeErr}, expected.err
} }
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare { func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
sqlRegexStr = stripQuery(sqlRegexStr)
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c} e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
} }
type namedValue struct {
Name string
Ordinal int
Value driver.Value
}
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer // Query meets http://golang.org/pkg/database/sql/driver/#Queryer
func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) { func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
namedArgs := make([]namedValue, len(args))
for i, v := range args {
namedArgs[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
ex, err := c.query(query, namedArgs)
if err != nil {
return nil, err
}
time.Sleep(ex.delay)
return ex.rows, nil
}
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
query = stripQuery(query) query = stripQuery(query)
var expected *ExpectedQuery var expected *ExpectedQuery
var fulfilled int var fulfilled int
...@@ -324,10 +411,10 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err ...@@ -324,10 +411,10 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
break break
} }
next.Unlock() next.Unlock()
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
} }
if qr, ok := next.(*ExpectedQuery); ok { if qr, ok := next.(*ExpectedQuery); ok {
if qr.attemptMatch(query, args) { if err := qr.attemptMatch(query, args); err == nil {
expected = qr expected = qr
break break
} }
...@@ -336,7 +423,7 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err ...@@ -336,7 +423,7 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
} }
if expected == nil { if expected == nil {
msg := "call to query '%s' with args %+v was not expected" msg := "call to Query '%s' with args %+v was not expected"
if fulfilled == len(c.expected) { if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg msg = "all expectations were already fulfilled, " + msg
} }
...@@ -344,40 +431,29 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err ...@@ -344,40 +431,29 @@ func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err
} }
defer expected.Unlock() defer expected.Unlock()
expected.triggered = true
// converts panic to error in case of reflect value type mismatch
defer func(errp *error, exp *ExpectedQuery, q string, a []driver.Value) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
msg := "query \"%s\", args \"%+v\" failed to match with error \"%s\" expectation: %s"
*errp = fmt.Errorf(msg, q, a, se, exp)
} else {
panic(e) // overwise if unknown error panic
}
}
}(&err, expected, query, args)
if !expected.queryMatches(query) { if !expected.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String()) return nil, fmt.Errorf("Query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
} }
if !expected.argsMatches(args) { if err := expected.argsMatches(args); err != nil {
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, expected.args) return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err)
} }
expected.triggered = true
if expected.err != nil { if expected.err != nil {
return nil, expected.err // mocked to return error return nil, expected.err // mocked to return error
} }
if expected.rows == nil { if expected.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected) return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
} }
return expected, nil
return expected.rows, err
} }
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery { func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
e := &ExpectedQuery{} e := &ExpectedQuery{}
sqlRegexStr = stripQuery(sqlRegexStr)
e.sqlRegex = regexp.MustCompile(sqlRegexStr) e.sqlRegex = regexp.MustCompile(sqlRegexStr)
c.expected = append(c.expected, e) c.expected = append(c.expected, e)
return e return e
...@@ -414,11 +490,11 @@ func (c *sqlmock) Commit() error { ...@@ -414,11 +490,11 @@ func (c *sqlmock) Commit() error {
next.Unlock() next.Unlock()
if c.ordered { if c.ordered {
return fmt.Errorf("call to commit transaction, was not expected, next expectation is: %s", next) return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next)
} }
} }
if expected == nil { if expected == nil {
msg := "call to commit transaction was not expected" msg := "call to Commit transaction was not expected"
if fulfilled == len(c.expected) { if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg msg = "all expectations were already fulfilled, " + msg
} }
...@@ -449,11 +525,11 @@ func (c *sqlmock) Rollback() error { ...@@ -449,11 +525,11 @@ func (c *sqlmock) Rollback() error {
next.Unlock() next.Unlock()
if c.ordered { if c.ordered {
return fmt.Errorf("call to rollback transaction, was not expected, next expectation is: %s", next) return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next)
} }
} }
if expected == nil { if expected == nil {
msg := "call to rollback transaction was not expected" msg := "call to Rollback transaction was not expected"
if fulfilled == len(c.expected) { if fulfilled == len(c.expected) {
msg = "all expectations were already fulfilled, " + msg msg = "all expectations were already fulfilled, " + msg
} }
......
// +build go1.8
package sqlmock
import (
"context"
"database/sql/driver"
"errors"
"time"
)
var ErrCancelled = errors.New("canceling query due to user request")
// Implement the "QueryerContext" interface
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
namedArgs := make([]namedValue, len(args))
for i, nv := range args {
namedArgs[i] = namedValue(nv)
}
ex, err := c.query(query, namedArgs)
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return ex.rows, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "ExecerContext" interface
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
namedArgs := make([]namedValue, len(args))
for i, nv := range args {
namedArgs[i] = namedValue(nv)
}
ex, err := c.exec(query, namedArgs)
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return ex.result, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "ConnBeginTx" interface
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
ex, err := c.begin()
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return c, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "ConnPrepareContext" interface
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
ex, err := c.prepare(query)
if err != nil {
return nil, err
}
select {
case <-time.After(ex.delay):
return &statement{c, ex, query}, nil
case <-ctx.Done():
return nil, ErrCancelled
}
}
// Implement the "Pinger" interface
// for now we do not have a Ping expectation
// may be something for the future
func (c *sqlmock) Ping(ctx context.Context) error {
return nil
}
// Implement the "StmtExecContext" interface
func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return stmt.conn.ExecContext(ctx, stmt.query, args)
}
// Implement the "StmtQueryContext" interface
func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
return stmt.conn.QueryContext(ctx, stmt.query, args)
}
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
// +build go1.8
package sqlmock
import (
"context"
"database/sql"
"testing"
"time"
)
func TestContextExecCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("DELETE FROM users").
WillDelayFor(time.Second).
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.ExecContext(ctx, "DELETE FROM users")
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.ExecContext(ctx, "DELETE FROM users")
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestPreparedStatementContextExecCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectPrepare("DELETE FROM users").
ExpectExec().
WillDelayFor(time.Second).
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
stmt, err := db.Prepare("DELETE FROM users")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
_, err = stmt.ExecContext(ctx)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = stmt.ExecContext(ctx)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextExecWithNamedArg(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("DELETE FROM users").
WithArgs(sql.Named("id", 5)).
WillDelayFor(time.Second).
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5))
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.ExecContext(ctx, "DELETE FROM users WHERE id = :id", sql.Named("id", 5))
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextExec(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectExec("DELETE FROM users").
WillReturnResult(NewResult(1, 1))
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
res, err := db.ExecContext(ctx, "DELETE FROM users")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
affected, err := res.RowsAffected()
if affected != 1 {
t.Errorf("expected affected rows 1, but got %v", affected)
}
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextQueryCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillDelayFor(time.Second).
WillReturnRows(rs)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = ?", 5)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestPreparedStatementContextQueryCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
mock.ExpectPrepare("SELECT (.+) FROM articles WHERE id = ?").
ExpectQuery().
WithArgs(5).
WillDelayFor(time.Second).
WillReturnRows(rs)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
stmt, err := db.Prepare("SELECT id, title FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
_, err = stmt.QueryContext(ctx, 5)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = stmt.QueryContext(ctx, 5)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextQuery(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id =").
WithArgs(sql.Named("id", 5)).
WillDelayFor(time.Millisecond * 3).
WillReturnRows(rs)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
rows, err := db.QueryContext(ctx, "SELECT id, title FROM articles WHERE id = :id", sql.Named("id", 5))
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if !rows.Next() {
t.Error("expected one row, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextBeginCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin().WillDelayFor(time.Second)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.BeginTx(ctx, nil)
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.BeginTx(ctx, nil)
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextBegin(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectBegin().WillDelayFor(time.Millisecond * 3)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if tx == nil {
t.Error("expected tx, but there was nil")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextPrepareCancel(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectPrepare("SELECT").WillDelayFor(time.Second)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
_, err = db.PrepareContext(ctx, "SELECT")
if err == nil {
t.Error("error was expected, but there was none")
}
if err != ErrCancelled {
t.Errorf("was expecting cancel error, but got: %v", err)
}
_, err = db.PrepareContext(ctx, "SELECT")
if err != context.Canceled {
t.Error("error was expected since context was already done, but there was none")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
func TestContextPrepare(t *testing.T) {
t.Parallel()
db, mock, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
mock.ExpectPrepare("SELECT").WillDelayFor(time.Millisecond * 3)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Millisecond * 10)
cancel()
}()
stmt, err := db.PrepareContext(ctx, "SELECT")
if err != nil {
t.Errorf("error was not expected, but got: %v", err)
}
if stmt == nil {
t.Error("expected stmt, but there was nil")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expections: %s", err)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment