diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..26a991724e96c1189a854054e5011e84dfcc41bb --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: +- package-ecosystem: gomod + directory: "/" + schedule: + interval: daily + time: "13:00" + open-pull-requests-limit: 10 +- package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: daily diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 0000000000000000000000000000000000000000..d28b2969964fa030b59368d3173f08e752205803 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,53 @@ +name: "Code scanning - action" + +on: + push: + branches-ignore: + - 'dependabot/**' + pull_request: + schedule: + - cron: '0 13 * * 4' + +jobs: + CodeQL-Build: + + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + # We must fetch at least the immediate parents so that if this is + # a pull request then we can checkout the head. + fetch-depth: 2 + + # If this run was triggered by a pull request event, then checkout + # the head of the pull request instead of the merge commit. + - run: git checkout HEAD^2 + if: ${{ github.event_name == 'pull_request' }} + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + # Override language selection by uncommenting this and choosing your languages + # with: + # languages: go, javascript, csharp, python, cpp, java + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # ℹ️ Command-line programs to run using the OS shell. 
+ # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000000000000000000000000000000000000..9fa1aa6a57aee5c11c6e6e1894ec76e3348fa89e --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,33 @@ +name: Go + +on: [push, pull_request] + +jobs: + + build: + name: Build + strategy: + matrix: + go-version: [1.19.x, 1.20.x] + platform: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.platform }} + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go-version }} + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + with: + submodules: true + + - name: Get dependencies + run: go get -v -t -d ./... + + - name: Build + run: go build -v . + + - name: Test + run: go test -race -v ./... 
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 0000000000000000000000000000000000000000..dd4cd67090cccb35156c7d0b0419a971a3c6f2ef --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,14 @@ +name: golangci-lint + +on: [push, pull_request] + +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest diff --git a/.gitmodules b/.gitmodules index 51779cbff134295e7033c8f2500b9ac58182e37d..400b2ab62c0faaf8637ba3fdedec812410d57f36 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "test-data"] path = test-data - url = git://github.com/maxmind/MaxMind-DB.git + url = https://github.com/maxmind/MaxMind-DB.git diff --git a/.golangci.toml b/.golangci.toml new file mode 100644 index 0000000000000000000000000000000000000000..0950c1c5437f606fed177a4e8c9ffc770c673d76 --- /dev/null +++ b/.golangci.toml @@ -0,0 +1,452 @@ +[run] + deadline = "10m" + tests = true + +[linters] + disable-all = true + enable = [ + "asasalint", + "asciicheck", + "bidichk", + "bodyclose", + "containedctx", + "contextcheck", + "depguard", + "dupword", + "durationcheck", + "errcheck", + "errchkjson", + "errname", + "errorlint", + # "exhaustive", + "exportloopref", + "forbidigo", + "goconst", + "gocyclo", + "gocritic", + "godot", + "gofumpt", + "gomodguard", + "gosec", + "gosimple", + "govet", + "grouper", + "ineffassign", + "lll", + "makezero", + "maintidx", + "misspell", + "nakedret", + "nilerr", + "noctx", + "nolintlint", + "nosprintfhostport", + "predeclared", + "revive", + "rowserrcheck", + "sqlclosecheck", + "staticcheck", + "stylecheck", + "tenv", + "tparallel", + "typecheck", + "unconvert", + "unparam", + "unused", + "usestdlibvars", + "vetshadow", + "wastedassign", + ] + +[[linters-settings.depguard.rules.main.deny]] +pkg = "io/ioutil" +desc = "Deprecated. 
Functions have been moved elsewhere." + +[linters-settings.errcheck] + check-blank = true + # Ignoring Close so that we don't have to have a bunch of + # `defer func() { _ = r.Close() }()` constructs when we + # don't actually care about the error. + ignore = "Close,fmt:.*" + +[linters-settings.errorlint] + errorf = true + asserts = true + comparison = true + +[linters-settings.exhaustive] + default-signifies-exhaustive = true + +[linters-settings.forbidigo] + # Forbid the following identifiers + forbid = [ + "Geoip", # use "GeoIP" + "^geoIP", # use "geoip" + "Maxmind", # use "MaxMind" + "^maxMind", # use "maxmind" + ] + +[linters-settings.gocritic] + enabled-checks = [ + "appendAssign", + "appendCombine", + "argOrder", + "assignOp", + "badCall", + "badCond", + "badLock", + "badRegexp", + "badSorting", + "boolExprSimplify", + "builtinShadow", + "builtinShadowDecl", + "captLocal", + "caseOrder", + "codegenComment", + "commentedOutCode", + "commentedOutImport", + "commentFormatting", + "defaultCaseOrder", + "deferInLoop", + "deferUnlambda", + "deprecatedComment", + "docStub", + "dupArg", + "dupBranchBody", + "dupCase", + "dupImport", + "dupSubExpr", + "dynamicFmtString", + "elseif", + "emptyDecl", + "emptyFallthrough", + "emptyStringTest", + "equalFold", + "evalOrder", + "exitAfterDefer", + "exposedSyncMutex", + "externalErrorReassign", + "filepathJoin", + "flagDeref", + "flagName", + "hexLiteral", + "httpNoBody", + "hugeParam", + "ifElseChain", + "importShadow", + "indexAlloc", + "initClause", + "mapKey", + "methodExprCall", + "nestingReduce", + "newDeref", + "nilValReturn", + "octalLiteral", + "offBy1", + "paramTypeCombine", + "preferDecodeRune", + "preferFilepathJoin", + "preferFprint", + "preferStringWriter", + "preferWriteByte", + "ptrToRefParam", + "rangeExprCopy", + "rangeValCopy", + "redundantSprint", + "regexpMust", + "regexpPattern", + "regexpSimplify", + "returnAfterHttpError", + "ruleguard", + "singleCaseSwitch", + "sliceClear", + "sloppyLen", + 
"sloppyReassign", + "sloppyTestFuncName", + "sloppyTypeAssert", + "sortSlice", + "sprintfQuotedString", + "sqlQuery", + "stringsCompare", + "stringConcatSimplify", + "stringXbytes", + "switchTrue", + "syncMapLoadAndDelete", + "timeExprSimplify", + "todoCommentWithoutDetail", + "tooManyResultsChecker", + "truncateCmp", + "typeAssertChain", + "typeDefFirst", + "typeSwitchVar", + "typeUnparen", + "underef", + "unlabelStmt", + "unlambda", + # "unnamedResult", + "unnecessaryBlock", + "unnecessaryDefer", + "unslice", + "valSwap", + "weakCond", + # Covered by nolintlint + # "whyNoLint" + "wrapperFunc", + "yodaStyleExpr", + ] + +[linters-settings.gofumpt] + extra-rules = true + lang-version = "1.19" + +[linters-settings.gosec] + excludes = [ + # G104 - "Audit errors not checked." We use errcheck for this. + "G104", + + # G304 - "Potential file inclusion via variable" + "G304", + + # G306 - "Expect WriteFile permissions to be 0600 or less". + "G306", + + # Prohibits defer (*os.File).Close, which we allow when reading from file. 
+ "G307", + ] + +[linters-settings.govet] + "enable-all" = true + disable = ["shadow"] + +[linters-settings.lll] + line-length = 120 + tab-width = 4 + +[linters-settings.nolintlint] + allow-leading-space = false + allow-unused = false + allow-no-explanation = ["lll", "misspell"] + require-explanation = true + require-specific = true + +[linters-settings.revive] + ignore-generated-header = true + severity = "warning" + + # [[linters-settings.revive.rules]] + # name = "add-constant" + + # [[linters-settings.revive.rules]] + # name = "argument-limit" + + [[linters-settings.revive.rules]] + name = "atomic" + + [[linters-settings.revive.rules]] + name = "bare-return" + + [[linters-settings.revive.rules]] + name = "blank-imports" + + [[linters-settings.revive.rules]] + name = "bool-literal-in-expr" + + [[linters-settings.revive.rules]] + name = "call-to-gc" + + # [[linters-settings.revive.rules]] + # name = "cognitive-complexity" + + [[linters-settings.revive.rules]] + name = "comment-spacings" + arguments = ["easyjson", "nolint"] + + # [[linters-settings.revive.rules]] + # name = "confusing-naming" + + # [[linters-settings.revive.rules]] + # name = "confusing-results" + + [[linters-settings.revive.rules]] + name = "constant-logical-expr" + + [[linters-settings.revive.rules]] + name = "context-as-argument" + + [[linters-settings.revive.rules]] + name = "context-keys-type" + + # [[linters-settings.revive.rules]] + # name = "cyclomatic" + + [[linters-settings.revive.rules]] + name = "datarace" + + # [[linters-settings.revive.rules]] + # name = "deep-exit" + + [[linters-settings.revive.rules]] + name = "defer" + + [[linters-settings.revive.rules]] + name = "dot-imports" + + [[linters-settings.revive.rules]] + name = "duplicated-imports" + + [[linters-settings.revive.rules]] + name = "early-return" + + [[linters-settings.revive.rules]] + name = "empty-block" + + [[linters-settings.revive.rules]] + name = "empty-lines" + + [[linters-settings.revive.rules]] + name = "errorf" + 
+ [[linters-settings.revive.rules]] + name = "error-naming" + + [[linters-settings.revive.rules]] + name = "error-return" + + [[linters-settings.revive.rules]] + name = "error-strings" + + [[linters-settings.revive.rules]] + name = "exported" + + # [[linters-settings.revive.rules]] + # name = "file-header" + + # [[linters-settings.revive.rules]] + # name = "flag-parameter" + + # [[linters-settings.revive.rules]] + # name = "function-result-limit" + + [[linters-settings.revive.rules]] + name = "get-return" + + [[linters-settings.revive.rules]] + name = "identical-branches" + + [[linters-settings.revive.rules]] + name = "if-return" + + [[linters-settings.revive.rules]] + name = "imports-blacklist" + + [[linters-settings.revive.rules]] + name = "import-shadowing" + + [[linters-settings.revive.rules]] + name = "increment-decrement" + + [[linters-settings.revive.rules]] + name = "indent-error-flow" + + # [[linters-settings.revive.rules]] + # name = "line-length-limit" + + # [[linters-settings.revive.rules]] + # name = "max-public-structs" + + [[linters-settings.revive.rules]] + name = "modifies-parameter" + + [[linters-settings.revive.rules]] + name = "modifies-value-receiver" + + # [[linters-settings.revive.rules]] + # name = "nested-structs" + + [[linters-settings.revive.rules]] + name = "optimize-operands-order" + + [[linters-settings.revive.rules]] + name = "package-comments" + + [[linters-settings.revive.rules]] + name = "range" + + [[linters-settings.revive.rules]] + name = "range-val-address" + + [[linters-settings.revive.rules]] + name = "range-val-in-closure" + + [[linters-settings.revive.rules]] + name = "receiver-naming" + + [[linters-settings.revive.rules]] + name = "redefines-builtin-id" + + [[linters-settings.revive.rules]] + name = "string-of-int" + + [[linters-settings.revive.rules]] + name = "struct-tag" + + [[linters-settings.revive.rules]] + name = "superfluous-else" + + [[linters-settings.revive.rules]] + name = "time-equal" + + 
[[linters-settings.revive.rules]] + name = "time-naming" + + [[linters-settings.revive.rules]] + name = "unconditional-recursion" + + [[linters-settings.revive.rules]] + name = "unexported-naming" + + [[linters-settings.revive.rules]] + name = "unexported-return" + + # [[linters-settings.revive.rules]] + # name = "unhandled-error" + + [[linters-settings.revive.rules]] + name = "unnecessary-stmt" + + [[linters-settings.revive.rules]] + name = "unreachable-code" + + [[linters-settings.revive.rules]] + name = "unused-parameter" + + [[linters-settings.revive.rules]] + name = "unused-receiver" + + [[linters-settings.revive.rules]] + name = "use-any" + + [[linters-settings.revive.rules]] + name = "useless-break" + + [[linters-settings.revive.rules]] + name = "var-declaration" + + [[linters-settings.revive.rules]] + name = "var-naming" + + [[linters-settings.revive.rules]] + name = "waitgroup-by-value" + +[linters-settings.unparam] + check-exported = true + +[issues] +exclude-use-default = false + +[[issues.exclude-rules]] + linters = [ + "govet" + ] + path = "_test.go" + text = "^fieldalignment" diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 49c4478b98ea9f4f84e259d4f24acf2805807a22..0000000000000000000000000000000000000000 --- a/.travis.yml +++ /dev/null @@ -1,23 +0,0 @@ -language: go - -go: - - 1.4 - - 1.5 - - 1.6 - - 1.7 - - 1.8 - - tip - -before_install: - - "if [[ $TRAVIS_GO_VERSION == 1.7 ]]; then go get -v github.com/golang/lint/golint; fi" - -install: - - go get -v -t ./... 
- -script: - - go test -race -cpu 1,4 -v - - go test -race -v -tags appengine - - "if [[ $TRAVIS_GO_VERSION == 1.7 ]]; then go vet ./...; fi" - - "if [[ $TRAVIS_GO_VERSION == 1.7 ]]; then golint .; fi" - -sudo: false diff --git a/README.md b/README.md index cdd6bd1a8594a03b5af913fa7bf4dcc6ce6f8b64..9662888bdf9eb7186d30ce294b03195eb44a5387 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # MaxMind DB Reader for Go # -[](https://travis-ci.org/oschwald/maxminddb-golang) -[](https://ci.appveyor.com/project/oschwald/maxminddb-golang/branch/master) -[](https://godoc.org/github.com/oschwald/maxminddb-golang) +[](https://godoc.org/github.com/oschwald/maxminddb-golang) This is a Go reader for the MaxMind DB format. Although this can be used to read [GeoLite2](http://dev.maxmind.com/geoip/geoip2/geolite2/) and diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index e2bb9dd23750ce7a5f4fb4faa8df29b0dcbb9368..0000000000000000000000000000000000000000 --- a/appveyor.yml +++ /dev/null @@ -1,19 +0,0 @@ -version: "{build}" - -os: Windows Server 2012 R2 - -clone_folder: c:\gopath\src\github.com\oschwald\maxminddb-golang - -environment: - GOPATH: c:\gopath - -install: - - echo %PATH% - - echo %GOPATH% - - git submodule update --init --recursive - - go version - - go env - - go get -v -t ./... - -build_script: - - go test -v ./... diff --git a/decoder.go b/decoder.go index 396da75445dfe8a5dd7521ec2de1b34c2826feb9..dd0f9ba3066167a3a60d25a3e38bf711209f4dcb 100644 --- a/decoder.go +++ b/decoder.go @@ -27,20 +27,24 @@ const ( _Uint64 _Uint128 _Slice - _Container - _Marker + // We don't use the next two. They are placeholders. See the spec + // for more details. + _Container //nolint: deadcode, varcheck // above + _Marker //nolint: deadcode, varcheck // above _Bool _Float32 ) const ( - // This is the value used in libmaxminddb + // This is the value used in libmaxminddb. 
maximumDataStructureDepth = 512 ) func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, error) { if depth > maximumDataStructureDepth { - return 0, newInvalidDatabaseError("exceeded maximum data structure depth; database is likely corrupt") + return 0, newInvalidDatabaseError( + "exceeded maximum data structure depth; database is likely corrupt", + ) } typeNum, size, newOffset, err := d.decodeCtrlData(offset) if err != nil { @@ -54,6 +58,36 @@ func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, er return d.decodeFromType(typeNum, size, newOffset, result, depth+1) } +func (d *decoder) decodeToDeserializer( + offset uint, + dser deserializer, + depth int, + getNext bool, +) (uint, error) { + if depth > maximumDataStructureDepth { + return 0, newInvalidDatabaseError( + "exceeded maximum data structure depth; database is likely corrupt", + ) + } + skip, err := dser.ShouldSkip(uintptr(offset)) + if err != nil { + return 0, err + } + if skip { + if getNext { + return d.nextValueOffset(offset, 1) + } + return 0, nil + } + + typeNum, size, newOffset, err := d.decodeCtrlData(offset) + if err != nil { + return 0, err + } + + return d.decodeFromTypeToDeserializer(typeNum, size, newOffset, dser, depth+1) +} + func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) { newOffset := offset + 1 if offset >= uint(len(d.buffer)) { @@ -75,7 +109,11 @@ func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) { return typeNum, size, newOffset, err } -func (d *decoder) sizeFromCtrlByte(ctrlByte byte, offset uint, typeNum dataType) (uint, uint, error) { +func (d *decoder) sizeFromCtrlByte( + ctrlByte byte, + offset uint, + typeNum dataType, +) (uint, uint, error) { size := uint(ctrlByte & 0x1f) if typeNum == _Extended { return size, offset, nil @@ -113,12 +151,12 @@ func (d *decoder) decodeFromType( result reflect.Value, depth int, ) (uint, error) { - result = d.indirect(result) + result = 
indirect(result) // For these types, size has a special meaning switch dtype { case _Bool: - return d.unmarshalBool(size, offset, result) + return unmarshalBool(size, offset, result) case _Map: return d.unmarshalMap(size, offset, result, depth) case _Pointer: @@ -155,14 +193,77 @@ func (d *decoder) decodeFromType( } } -func (d *decoder) unmarshalBool(size uint, offset uint, result reflect.Value) (uint, error) { - if size > 1 { - return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (bool size of %v)", size) +func (d *decoder) decodeFromTypeToDeserializer( + dtype dataType, + size uint, + offset uint, + dser deserializer, + depth int, +) (uint, error) { + // For these types, size has a special meaning + switch dtype { + case _Bool: + v, offset := decodeBool(size, offset) + return offset, dser.Bool(v) + case _Map: + return d.decodeMapToDeserializer(size, offset, dser, depth) + case _Pointer: + pointer, newOffset, err := d.decodePointer(size, offset) + if err != nil { + return 0, err + } + _, err = d.decodeToDeserializer(pointer, dser, depth, false) + return newOffset, err + case _Slice: + return d.decodeSliceToDeserializer(size, offset, dser, depth) } - value, newOffset, err := d.decodeBool(size, offset) - if err != nil { - return 0, err + + // For the remaining types, size is the byte size + if offset+size > uint(len(d.buffer)) { + return 0, newOffsetError() } + switch dtype { + case _Bytes: + v, offset := d.decodeBytes(size, offset) + return offset, dser.Bytes(v) + case _Float32: + v, offset := d.decodeFloat32(size, offset) + return offset, dser.Float32(v) + case _Float64: + v, offset := d.decodeFloat64(size, offset) + return offset, dser.Float64(v) + case _Int32: + v, offset := d.decodeInt(size, offset) + return offset, dser.Int32(int32(v)) + case _String: + v, offset := d.decodeString(size, offset) + return offset, dser.String(v) + case _Uint16: + v, offset := d.decodeUint(size, offset) + return offset, dser.Uint16(uint16(v)) + 
case _Uint32: + v, offset := d.decodeUint(size, offset) + return offset, dser.Uint32(uint32(v)) + case _Uint64: + v, offset := d.decodeUint(size, offset) + return offset, dser.Uint64(v) + case _Uint128: + v, offset := d.decodeUint128(size, offset) + return offset, dser.Uint128(v) + default: + return 0, newInvalidDatabaseError("unknown type: %d", dtype) + } +} + +func unmarshalBool(size, offset uint, result reflect.Value) (uint, error) { + if size > 1 { + return 0, newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (bool size of %v)", + size, + ) + } + value, newOffset := decodeBool(size, offset) + switch result.Kind() { case reflect.Bool: result.SetBool(value) @@ -180,7 +281,7 @@ func (d *decoder) unmarshalBool(size uint, offset uint, result reflect.Value) (u // heavily based on encoding/json as my original version had a subtle // bug. This method should be considered to be licensed under // https://golang.org/LICENSE -func (d *decoder) indirect(result reflect.Value) reflect.Value { +func indirect(result reflect.Value) reflect.Value { for { // Load value from interface, but only if the result will be // usefully addressable. 
@@ -199,6 +300,7 @@ func (d *decoder) indirect(result reflect.Value) reflect.Value { if result.IsNil() { result.Set(reflect.New(result.Type().Elem())) } + result = result.Elem() } return result @@ -206,11 +308,9 @@ func (d *decoder) indirect(result reflect.Value) reflect.Value { var sliceType = reflect.TypeOf([]byte{}) -func (d *decoder) unmarshalBytes(size uint, offset uint, result reflect.Value) (uint, error) { - value, newOffset, err := d.decodeBytes(size, offset) - if err != nil { - return 0, err - } +func (d *decoder) unmarshalBytes(size, offset uint, result reflect.Value) (uint, error) { + value, newOffset := d.decodeBytes(size, offset) + switch result.Kind() { case reflect.Slice: if result.Type() == sliceType { @@ -226,14 +326,14 @@ func (d *decoder) unmarshalBytes(size uint, offset uint, result reflect.Value) ( return newOffset, newUnmarshalTypeError(value, result.Type()) } -func (d *decoder) unmarshalFloat32(size uint, offset uint, result reflect.Value) (uint, error) { +func (d *decoder) unmarshalFloat32(size, offset uint, result reflect.Value) (uint, error) { if size != 4 { - return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (float32 size of %v)", size) - } - value, newOffset, err := d.decodeFloat32(size, offset) - if err != nil { - return 0, err + return 0, newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (float32 size of %v)", + size, + ) } + value, newOffset := d.decodeFloat32(size, offset) switch result.Kind() { case reflect.Float32, reflect.Float64: @@ -248,15 +348,15 @@ func (d *decoder) unmarshalFloat32(size uint, offset uint, result reflect.Value) return newOffset, newUnmarshalTypeError(value, result.Type()) } -func (d *decoder) unmarshalFloat64(size uint, offset uint, result reflect.Value) (uint, error) { - +func (d *decoder) unmarshalFloat64(size, offset uint, result reflect.Value) (uint, error) { if size != 8 { - return 0, newInvalidDatabaseError("the MaxMind DB file's data 
section contains bad data (float 64 size of %v)", size) - } - value, newOffset, err := d.decodeFloat64(size, offset) - if err != nil { - return 0, err + return 0, newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (float 64 size of %v)", + size, + ) } + value, newOffset := d.decodeFloat64(size, offset) + switch result.Kind() { case reflect.Float32, reflect.Float64: if result.OverflowFloat(value) { @@ -273,14 +373,14 @@ func (d *decoder) unmarshalFloat64(size uint, offset uint, result reflect.Value) return newOffset, newUnmarshalTypeError(value, result.Type()) } -func (d *decoder) unmarshalInt32(size uint, offset uint, result reflect.Value) (uint, error) { +func (d *decoder) unmarshalInt32(size, offset uint, result reflect.Value) (uint, error) { if size > 4 { - return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (int32 size of %v)", size) - } - value, newOffset, err := d.decodeInt(size, offset) - if err != nil { - return 0, err + return 0, newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (int32 size of %v)", + size, + ) } + value, newOffset := d.decodeInt(size, offset) switch result.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -289,7 +389,12 @@ func (d *decoder) unmarshalInt32(size uint, offset uint, result reflect.Value) ( result.SetInt(n) return newOffset, nil } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + case reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr: n := uint64(value) if !result.OverflowUint(n) { result.SetUint(n) @@ -310,7 +415,7 @@ func (d *decoder) unmarshalMap( result reflect.Value, depth int, ) (uint, error) { - result = d.indirect(result) + result = indirect(result) switch result.Kind() { default: return 0, newUnmarshalTypeError("map", result.Type()) @@ -320,7 +425,7 @@ func (d *decoder) 
unmarshalMap( return d.decodeMap(size, offset, result, depth) case reflect.Interface: if result.NumMethod() == 0 { - rv := reflect.ValueOf(make(map[string]interface{}, size)) + rv := reflect.ValueOf(make(map[string]any, size)) newOffset, err := d.decodeMap(size, offset, rv, depth) result.Set(rv) return newOffset, err @@ -329,7 +434,11 @@ func (d *decoder) unmarshalMap( } } -func (d *decoder) unmarshalPointer(size uint, offset uint, result reflect.Value, depth int) (uint, error) { +func (d *decoder) unmarshalPointer( + size, offset uint, + result reflect.Value, + depth int, +) (uint, error) { pointer, newOffset, err := d.decodePointer(size, offset) if err != nil { return 0, err @@ -349,7 +458,7 @@ func (d *decoder) unmarshalSlice( return d.decodeSlice(size, offset, result, depth) case reflect.Interface: if result.NumMethod() == 0 { - a := []interface{}{} + a := []any{} rv := reflect.ValueOf(&a).Elem() newOffset, err := d.decodeSlice(size, offset, rv, depth) result.Set(rv) @@ -359,12 +468,9 @@ func (d *decoder) unmarshalSlice( return 0, newUnmarshalTypeError("array", result.Type()) } -func (d *decoder) unmarshalString(size uint, offset uint, result reflect.Value) (uint, error) { - value, newOffset, err := d.decodeString(size, offset) +func (d *decoder) unmarshalString(size, offset uint, result reflect.Value) (uint, error) { + value, newOffset := d.decodeString(size, offset) - if err != nil { - return 0, err - } switch result.Kind() { case reflect.String: result.SetString(value) @@ -376,18 +482,22 @@ func (d *decoder) unmarshalString(size uint, offset uint, result reflect.Value) } } return newOffset, newUnmarshalTypeError(value, result.Type()) - } -func (d *decoder) unmarshalUint(size uint, offset uint, result reflect.Value, uintType uint) (uint, error) { +func (d *decoder) unmarshalUint( + size, offset uint, + result reflect.Value, + uintType uint, +) (uint, error) { if size > uintType/8 { - return 0, newInvalidDatabaseError("the MaxMind DB file's data section 
contains bad data (uint%v size of %v)", uintType, size) + return 0, newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (uint%v size of %v)", + uintType, + size, + ) } - value, newOffset, err := d.decodeUint(size, offset) - if err != nil { - return 0, err - } + value, newOffset := d.decodeUint(size, offset) switch result.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -396,7 +506,12 @@ func (d *decoder) unmarshalUint(size uint, offset uint, result reflect.Value, ui result.SetInt(n) return newOffset, nil } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + case reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr: if !result.OverflowUint(value) { result.SetUint(value) return newOffset, nil @@ -412,14 +527,14 @@ func (d *decoder) unmarshalUint(size uint, offset uint, result reflect.Value, ui var bigIntType = reflect.TypeOf(big.Int{}) -func (d *decoder) unmarshalUint128(size uint, offset uint, result reflect.Value) (uint, error) { +func (d *decoder) unmarshalUint128(size, offset uint, result reflect.Value) (uint, error) { if size > 16 { - return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (uint128 size of %v)", size) - } - value, newOffset, err := d.decodeUint128(size, offset) - if err != nil { - return 0, err + return 0, newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (uint128 size of %v)", + size, + ) } + value, newOffset := d.decodeUint128(size, offset) switch result.Kind() { case reflect.Struct: @@ -436,36 +551,36 @@ func (d *decoder) unmarshalUint128(size uint, offset uint, result reflect.Value) return newOffset, newUnmarshalTypeError(value, result.Type()) } -func (d *decoder) decodeBool(size uint, offset uint) (bool, uint, error) { - return size != 0, offset, nil +func decodeBool(size, offset uint) (bool, uint) { + return 
size != 0, offset } -func (d *decoder) decodeBytes(size uint, offset uint) ([]byte, uint, error) { +func (d *decoder) decodeBytes(size, offset uint) ([]byte, uint) { newOffset := offset + size bytes := make([]byte, size) copy(bytes, d.buffer[offset:newOffset]) - return bytes, newOffset, nil + return bytes, newOffset } -func (d *decoder) decodeFloat64(size uint, offset uint) (float64, uint, error) { +func (d *decoder) decodeFloat64(size, offset uint) (float64, uint) { newOffset := offset + size bits := binary.BigEndian.Uint64(d.buffer[offset:newOffset]) - return math.Float64frombits(bits), newOffset, nil + return math.Float64frombits(bits), newOffset } -func (d *decoder) decodeFloat32(size uint, offset uint) (float32, uint, error) { +func (d *decoder) decodeFloat32(size, offset uint) (float32, uint) { newOffset := offset + size bits := binary.BigEndian.Uint32(d.buffer[offset:newOffset]) - return math.Float32frombits(bits), newOffset, nil + return math.Float32frombits(bits), newOffset } -func (d *decoder) decodeInt(size uint, offset uint) (int, uint, error) { +func (d *decoder) decodeInt(size, offset uint) (int, uint) { newOffset := offset + size var val int32 for _, b := range d.buffer[offset:newOffset] { val = (val << 8) | int32(b) } - return int(val), newOffset, nil + return int(val), newOffset } func (d *decoder) decodeMap( @@ -475,24 +590,65 @@ func (d *decoder) decodeMap( depth int, ) (uint, error) { if result.IsNil() { - result.Set(reflect.MakeMap(result.Type())) + result.Set(reflect.MakeMapWithSize(result.Type(), int(size))) } + mapType := result.Type() + keyValue := reflect.New(mapType.Key()).Elem() + elemType := mapType.Elem() + var elemValue reflect.Value for i := uint(0); i < size; i++ { var key []byte var err error key, offset, err = d.decodeKey(offset) + if err != nil { + return 0, err + } + + if elemValue.IsValid() { + // After 1.20 is the minimum supported version, this can just be + // elemValue.SetZero() + reflectSetZero(elemValue) + } else { + 
elemValue = reflect.New(elemType).Elem() + } + offset, err = d.decode(offset, elemValue, depth) if err != nil { return 0, err } - value := reflect.New(result.Type().Elem()) - offset, err = d.decode(offset, value, depth) + keyValue.SetString(string(key)) + result.SetMapIndex(keyValue, elemValue) + } + return offset, nil +} + +func (d *decoder) decodeMapToDeserializer( + size uint, + offset uint, + dser deserializer, + depth int, +) (uint, error) { + err := dser.StartMap(size) + if err != nil { + return 0, err + } + for i := uint(0); i < size; i++ { + // TODO - implement key/value skipping? + offset, err = d.decodeToDeserializer(offset, dser, depth, true) + if err != nil { + return 0, err + } + + offset, err = d.decodeToDeserializer(offset, dser, depth, true) if err != nil { return 0, err } - result.SetMapIndex(reflect.ValueOf(string(key)), value.Elem()) + } + err = dser.End() + if err != nil { + return 0, err } return offset, nil } @@ -511,7 +667,7 @@ func (d *decoder) decodePointer( if pointerSize == 4 { prefix = 0 } else { - prefix = uint(size & 0x7) + prefix = size & 0x7 } unpacked := uintFromBytes(prefix, pointerBytes) @@ -549,60 +705,44 @@ func (d *decoder) decodeSlice( return offset, nil } -func (d *decoder) decodeString(size uint, offset uint) (string, uint, error) { - newOffset := offset + size - return string(d.buffer[offset:newOffset]), newOffset, nil +func (d *decoder) decodeSliceToDeserializer( + size uint, + offset uint, + dser deserializer, + depth int, +) (uint, error) { + err := dser.StartSlice(size) + if err != nil { + return 0, err + } + for i := uint(0); i < size; i++ { + offset, err = d.decodeToDeserializer(offset, dser, depth, true) + if err != nil { + return 0, err + } + } + err = dser.End() + if err != nil { + return 0, err + } + return offset, nil } -type fieldsType struct { - namedFields map[string]int - anonymousFields []int +func (d *decoder) decodeString(size, offset uint) (string, uint) { + newOffset := offset + size + return 
string(d.buffer[offset:newOffset]), newOffset } -var ( - fieldMap = map[reflect.Type]*fieldsType{} - fieldMapMu sync.RWMutex -) - func (d *decoder) decodeStruct( size uint, offset uint, result reflect.Value, depth int, ) (uint, error) { - resultType := result.Type() - - fieldMapMu.RLock() - fields, ok := fieldMap[resultType] - fieldMapMu.RUnlock() - if !ok { - numFields := resultType.NumField() - namedFields := make(map[string]int, numFields) - var anonymous []int - for i := 0; i < numFields; i++ { - field := resultType.Field(i) - - fieldName := field.Name - if tag := field.Tag.Get("maxminddb"); tag != "" { - if tag == "-" { - continue - } - fieldName = tag - } - if field.Anonymous { - anonymous = append(anonymous, i) - continue - } - namedFields[fieldName] = i - } - fieldMapMu.Lock() - fields = &fieldsType{namedFields, anonymous} - fieldMap[resultType] = fields - fieldMapMu.Unlock() - } + fields := cachedFields(result) // This fills in embedded structs - for i := range fields.anonymousFields { + for _, i := range fields.anonymousFields { _, err := d.unmarshalMap(size, offset, result.Field(i), depth) if err != nil { return 0, err @@ -638,7 +778,45 @@ func (d *decoder) decodeStruct( return offset, nil } -func (d *decoder) decodeUint(size uint, offset uint) (uint64, uint, error) { +type fieldsType struct { + namedFields map[string]int + anonymousFields []int +} + +var fieldsMap sync.Map + +func cachedFields(result reflect.Value) *fieldsType { + resultType := result.Type() + + if fields, ok := fieldsMap.Load(resultType); ok { + return fields.(*fieldsType) + } + numFields := resultType.NumField() + namedFields := make(map[string]int, numFields) + var anonymous []int + for i := 0; i < numFields; i++ { + field := resultType.Field(i) + + fieldName := field.Name + if tag := field.Tag.Get("maxminddb"); tag != "" { + if tag == "-" { + continue + } + fieldName = tag + } + if field.Anonymous { + anonymous = append(anonymous, i) + continue + } + namedFields[fieldName] = i + } + 
fields := &fieldsType{namedFields, anonymous} + fieldsMap.Store(resultType, fields) + + return fields +} + +func (d *decoder) decodeUint(size, offset uint) (uint64, uint) { newOffset := offset + size bytes := d.buffer[offset:newOffset] @@ -646,15 +824,15 @@ func (d *decoder) decodeUint(size uint, offset uint) (uint64, uint, error) { for _, b := range bytes { val = (val << 8) | uint64(b) } - return val, newOffset, nil + return val, newOffset } -func (d *decoder) decodeUint128(size uint, offset uint) (*big.Int, uint, error) { +func (d *decoder) decodeUint128(size, offset uint) (*big.Int, uint) { newOffset := offset + size val := new(big.Int) val.SetBytes(d.buffer[offset:newOffset]) - return val, newOffset, nil + return val, newOffset } func uintFromBytes(prefix uint, uintBytes []byte) uint { @@ -694,8 +872,8 @@ func (d *decoder) decodeKey(offset uint) ([]byte, uint, error) { // This function is used to skip ahead to the next value without decoding // the one at the offset passed in. The size bits have different meanings for -// different data types -func (d *decoder) nextValueOffset(offset uint, numberToSkip uint) (uint, error) { +// different data types. 
+func (d *decoder) nextValueOffset(offset, numberToSkip uint) (uint, error) { if numberToSkip == 0 { return offset, nil } diff --git a/decoder_test.go b/decoder_test.go index 921cbdf7dafef9656f0a568b9e488392e11c9e1c..27e5ece2615ee23fa2793390a946118d6755297c 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -2,17 +2,18 @@ package maxminddb import ( "encoding/hex" - "io/ioutil" "math/big" + "os" "reflect" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBool(t *testing.T) { - bools := map[string]interface{}{ + bools := map[string]any{ "0007": false, "0107": true, } @@ -21,7 +22,7 @@ func TestBool(t *testing.T) { } func TestDouble(t *testing.T) { - doubles := map[string]interface{}{ + doubles := map[string]any{ "680000000000000000": 0.0, "683FE0000000000000": 0.5, "68400921FB54442EEA": 3.14159265359, @@ -35,7 +36,7 @@ func TestDouble(t *testing.T) { } func TestFloat(t *testing.T) { - floats := map[string]interface{}{ + floats := map[string]any{ "040800000000": float32(0.0), "04083F800000": float32(1.0), "04083F8CCCCD": float32(1.1), @@ -50,7 +51,7 @@ func TestFloat(t *testing.T) { } func TestInt32(t *testing.T) { - int32 := map[string]interface{}{ + int32s := map[string]any{ "0001": 0, "0401ffffffff": -1, "0101ff": 255, @@ -64,33 +65,37 @@ func TestInt32(t *testing.T) { "04017fffffff": 2147483647, "040180000001": -2147483647, } - validateDecoding(t, int32) + validateDecoding(t, int32s) } func TestMap(t *testing.T) { - maps := map[string]interface{}{ - "e0": map[string]interface{}{}, - "e142656e43466f6f": map[string]interface{}{"en": "Foo"}, - "e242656e43466f6f427a6843e4baba": map[string]interface{}{"en": "Foo", "zh": "人"}, - "e1446e616d65e242656e43466f6f427a6843e4baba": map[string]interface{}{"name": map[string]interface{}{"en": "Foo", "zh": "人"}}, - "e1496c616e677561676573020442656e427a68": map[string]interface{}{"languages": []interface{}{"en", "zh"}}, + maps := map[string]any{ + "e0": map[string]any{}, 
+ "e142656e43466f6f": map[string]any{"en": "Foo"}, + "e242656e43466f6f427a6843e4baba": map[string]any{"en": "Foo", "zh": "人"}, + "e1446e616d65e242656e43466f6f427a6843e4baba": map[string]any{ + "name": map[string]any{"en": "Foo", "zh": "人"}, + }, + "e1496c616e677561676573020442656e427a68": map[string]any{ + "languages": []any{"en", "zh"}, + }, } validateDecoding(t, maps) } func TestSlice(t *testing.T) { - slice := map[string]interface{}{ - "0004": []interface{}{}, - "010443466f6f": []interface{}{"Foo"}, - "020443466f6f43e4baba": []interface{}{"Foo", "人"}, + slice := map[string]any{ + "0004": []any{}, + "010443466f6f": []any{"Foo"}, + "020443466f6f43e4baba": []any{"Foo", "人"}, } validateDecoding(t, slice) } var testStrings = makeTestStrings() -func makeTestStrings() map[string]interface{} { - str := map[string]interface{}{ +func makeTestStrings() map[string]any { + str := map[string]any{ "40": "", "4131": "1", "43E4BABA": "人", @@ -113,9 +118,10 @@ func TestString(t *testing.T) { } func TestByte(t *testing.T) { - b := make(map[string]interface{}) + b := make(map[string]any) for key, val := range testStrings { - oldCtrl, _ := hex.DecodeString(key[0:2]) + oldCtrl, err := hex.DecodeString(key[0:2]) + require.NoError(t, err) newCtrl := []byte{oldCtrl[0] ^ 0xc0} key = strings.Replace(key, hex.EncodeToString(oldCtrl), hex.EncodeToString(newCtrl), 1) b[key] = []byte(val.(string)) @@ -125,18 +131,18 @@ func TestByte(t *testing.T) { } func TestUint16(t *testing.T) { - uint16 := map[string]interface{}{ + uint16s := map[string]any{ "a0": uint64(0), "a1ff": uint64(255), "a201f4": uint64(500), "a22a78": uint64(10872), "a2ffff": uint64(65535), } - validateDecoding(t, uint16) + validateDecoding(t, uint16s) } func TestUint32(t *testing.T) { - uint32 := map[string]interface{}{ + uint32s := map[string]any{ "c0": uint64(0), "c1ff": uint64(255), "c201f4": uint64(500), @@ -145,14 +151,14 @@ func TestUint32(t *testing.T) { "c3ffffff": uint64(16777215), "c4ffffffff": uint64(4294967295), } - 
validateDecoding(t, uint32) + validateDecoding(t, uint32s) } func TestUint64(t *testing.T) { ctrlByte := "02" bits := uint64(64) - uints := map[string]interface{}{ + uints := map[string]any{ "00" + ctrlByte: uint64(0), "02" + ctrlByte + "01f4": uint64(500), "02" + ctrlByte + "2a78": uint64(10872), @@ -167,12 +173,12 @@ func TestUint64(t *testing.T) { validateDecoding(t, uints) } -// Dedup with above somehow +// Dedup with above somehow. func TestUint128(t *testing.T) { ctrlByte := "03" bits := uint(128) - uints := map[string]interface{}{ + uints := map[string]any{ "00" + ctrlByte: big.NewInt(0), "02" + ctrlByte + "01f4": big.NewInt(500), "02" + ctrlByte + "2a78": big.NewInt(10872), @@ -189,7 +195,7 @@ func TestUint128(t *testing.T) { } // No pow or bit shifting for big int, apparently :-( -// This is _not_ meant to be a comprehensive power function +// This is _not_ meant to be a comprehensive power function. func powBigInt(bi *big.Int, pow uint) *big.Int { newInt := big.NewInt(1) for i := uint(0); i < pow; i++ { @@ -198,14 +204,15 @@ func powBigInt(bi *big.Int, pow uint) *big.Int { return newInt } -func validateDecoding(t *testing.T, tests map[string]interface{}) { +func validateDecoding(t *testing.T, tests map[string]any) { for inputStr, expected := range tests { - inputBytes, _ := hex.DecodeString(inputStr) + inputBytes, err := hex.DecodeString(inputStr) + require.NoError(t, err) d := decoder{inputBytes} - var result interface{} - _, err := d.decode(0, reflect.ValueOf(&result), 0) - assert.Nil(t, err) + var result any + _, err = d.decode(0, reflect.ValueOf(&result), 0) + assert.NoError(t, err) if !reflect.DeepEqual(result, expected) { // A big case statement would produce nicer errors @@ -215,8 +222,8 @@ func validateDecoding(t *testing.T, tests map[string]interface{}) { } func TestPointers(t *testing.T) { - bytes, err := ioutil.ReadFile("test-data/test-data/maps-with-pointers.raw") - assert.Nil(t, err) + bytes, err := 
os.ReadFile(testFile("maps-with-pointers.raw")) + require.NoError(t, err) d := decoder{bytes} expected := map[uint]map[string]string{ @@ -231,10 +238,9 @@ func TestPointers(t *testing.T) { for offset, expectedValue := range expected { var actual map[string]string _, err := d.decode(offset, reflect.ValueOf(&actual), 0) - assert.Nil(t, err) + assert.NoError(t, err) if !reflect.DeepEqual(actual, expectedValue) { t.Errorf("Decode for pointer at %d failed", offset) } } - } diff --git a/deserializer.go b/deserializer.go new file mode 100644 index 0000000000000000000000000000000000000000..c6dd68d14ceb116012762b999ea8dd0efc46e009 --- /dev/null +++ b/deserializer.go @@ -0,0 +1,31 @@ +package maxminddb + +import "math/big" + +// deserializer is an interface for a type that deserializes a MaxMind DB +// data record to some other type. This exists as an alternative to the +// standard reflection API. +// +// This is fundamentally different than the Unmarshaler interface that +// several packages provide. A Deserializer will generally create the +// final struct or value rather than unmarshaling to itself. +// +// This interface and the associated unmarshaling code are EXPERIMENTAL! +// It is not currently covered by any Semantic Versioning guarantees. +// Use at your own risk. 
+type deserializer interface { + ShouldSkip(offset uintptr) (bool, error) + StartSlice(size uint) error + StartMap(size uint) error + End() error + String(string) error + Float64(float64) error + Bytes([]byte) error + Uint16(uint16) error + Uint32(uint32) error + Int32(int32) error + Uint64(uint64) error + Uint128(*big.Int) error + Bool(bool) error + Float32(float32) error +} diff --git a/deserializer_test.go b/deserializer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c68a9d5a9a69feb990ee83375ec623c1f528d72f --- /dev/null +++ b/deserializer_test.go @@ -0,0 +1,120 @@ +package maxminddb + +import ( + "math/big" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecodingToDeserializer(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err, "unexpected error while opening database: %v", err) + + dser := testDeserializer{} + err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser) + require.NoError(t, err, "unexpected error while doing lookup: %v", err) + + checkDecodingToInterface(t, dser.rv) +} + +type stackValue struct { + value any + curNum int +} + +type testDeserializer struct { + stack []*stackValue + rv any + key *string +} + +func (*testDeserializer) ShouldSkip(_ uintptr) (bool, error) { + return false, nil +} + +func (d *testDeserializer) StartSlice(size uint) error { + return d.add(make([]any, size)) +} + +func (d *testDeserializer) StartMap(_ uint) error { + return d.add(map[string]any{}) +} + +//nolint:unparam // This is to meet the requirements of the interface. 
+func (d *testDeserializer) End() error { + d.stack = d.stack[:len(d.stack)-1] + return nil +} + +func (d *testDeserializer) String(v string) error { + return d.add(v) +} + +func (d *testDeserializer) Float64(v float64) error { + return d.add(v) +} + +func (d *testDeserializer) Bytes(v []byte) error { + return d.add(v) +} + +func (d *testDeserializer) Uint16(v uint16) error { + return d.add(uint64(v)) +} + +func (d *testDeserializer) Uint32(v uint32) error { + return d.add(uint64(v)) +} + +func (d *testDeserializer) Int32(v int32) error { + return d.add(int(v)) +} + +func (d *testDeserializer) Uint64(v uint64) error { + return d.add(v) +} + +func (d *testDeserializer) Uint128(v *big.Int) error { + return d.add(v) +} + +func (d *testDeserializer) Bool(v bool) error { + return d.add(v) +} + +func (d *testDeserializer) Float32(v float32) error { + return d.add(v) +} + +func (d *testDeserializer) add(v any) error { + if len(d.stack) == 0 { + d.rv = v + } else { + top := d.stack[len(d.stack)-1] + switch parent := top.value.(type) { + case map[string]any: + if d.key == nil { + key := v.(string) + d.key = &key + } else { + parent[*d.key] = v + d.key = nil + } + + case []any: + parent[top.curNum] = v + top.curNum++ + default: + } + } + + switch v := v.(type) { + case map[string]any, []any: + d.stack = append(d.stack, &stackValue{value: v}) + default: + } + + return nil +} diff --git a/errors.go b/errors.go index 132780019bb3a39e06b0ce85728a735bc75d4ed5..aeba906b52aacf7dc1f55987fa72c7999f7603dd 100644 --- a/errors.go +++ b/errors.go @@ -15,7 +15,7 @@ func newOffsetError() InvalidDatabaseError { return InvalidDatabaseError{"unexpected end of database"} } -func newInvalidDatabaseError(format string, args ...interface{}) InvalidDatabaseError { +func newInvalidDatabaseError(format string, args ...any) InvalidDatabaseError { return InvalidDatabaseError{fmt.Sprintf(format, args...)} } @@ -26,11 +26,11 @@ func (e InvalidDatabaseError) Error() string { // UnmarshalTypeError is 
returned when the value in the database cannot be // assigned to the specified data type. type UnmarshalTypeError struct { - Value string // stringified copy of the database value that caused the error - Type reflect.Type // type of the value that could not be assign to + Type reflect.Type + Value string } -func newUnmarshalTypeError(value interface{}, rType reflect.Type) UnmarshalTypeError { +func newUnmarshalTypeError(value any, rType reflect.Type) UnmarshalTypeError { return UnmarshalTypeError{ Value: fmt.Sprintf("%v", value), Type: rType, diff --git a/example_test.go b/example_test.go index f6768d1cee97d7bbac4d4ecc759b40fba7fce767..9e2bbc373db0bc0ef82015ed80bfd6c8255557c2 100644 --- a/example_test.go +++ b/example_test.go @@ -8,7 +8,7 @@ import ( "github.com/oschwald/maxminddb-golang" ) -// This example shows how to decode to a struct +// This example shows how to decode to a struct. func ExampleReader_Lookup_struct() { db, err := maxminddb.Open("test-data/test-data/GeoIP2-City-Test.mmdb") if err != nil { @@ -26,14 +26,14 @@ func ExampleReader_Lookup_struct() { err = db.Lookup(ip, &record) if err != nil { - log.Fatal(err) + log.Panic(err) } fmt.Print(record.Country.ISOCode) // Output: // GB } -// This example demonstrates how to decode to an interface{} +// This example demonstrates how to decode to an any. func ExampleReader_Lookup_interface() { db, err := maxminddb.Open("test-data/test-data/GeoIP2-City-Test.mmdb") if err != nil { @@ -43,16 +43,16 @@ func ExampleReader_Lookup_interface() { ip := net.ParseIP("81.2.69.142") - var record interface{} + var record any err = db.Lookup(ip, &record) if err != nil { - log.Fatal(err) + log.Panic(err) } fmt.Printf("%v", record) } // This example demonstrates how to iterate over all networks in the -// database +// database. 
func ExampleReader_Networks() { db, err := maxminddb.Open("test-data/test-data/GeoIP2-Connection-Type-Test.mmdb") if err != nil { @@ -64,94 +64,82 @@ func ExampleReader_Networks() { Domain string `maxminddb:"connection_type"` }{} - networks := db.Networks() + networks := db.Networks(maxminddb.SkipAliasedNetworks) for networks.Next() { subnet, err := networks.Network(&record) if err != nil { - log.Fatal(err) + log.Panic(err) } fmt.Printf("%s: %s\n", subnet.String(), record.Domain) } if networks.Err() != nil { - log.Fatal(networks.Err()) + log.Panic(networks.Err()) } // Output: - // ::100:0/120: Dialup - // ::100:100/120: Cable/DSL - // ::100:200/119: Dialup - // ::100:400/118: Dialup - // ::100:800/117: Dialup - // ::100:1000/116: Dialup - // ::100:2000/115: Dialup - // ::100:4000/114: Dialup - // ::100:8000/113: Dialup - // ::50d6:0/116: Cellular - // ::6001:0/112: Cable/DSL - // ::600a:0/111: Cable/DSL - // ::6045:0/112: Cable/DSL - // ::605e:0/111: Cable/DSL - // ::6c60:0/107: Cellular - // ::af10:c700/120: Dialup - // ::bb9c:8a00/120: Cable/DSL - // ::c9f3:c800/120: Corporate - // ::cfb3:3000/116: Cellular - // 1.0.0.0/24: Dialup - // 1.0.1.0/24: Cable/DSL - // 1.0.2.0/23: Dialup - // 1.0.4.0/22: Dialup - // 1.0.8.0/21: Dialup - // 1.0.16.0/20: Dialup - // 1.0.32.0/19: Dialup - // 1.0.64.0/18: Dialup - // 1.0.128.0/17: Dialup + // 1.0.0.0/24: Cable/DSL + // 1.0.1.0/24: Cellular + // 1.0.2.0/23: Cable/DSL + // 1.0.4.0/22: Cable/DSL + // 1.0.8.0/21: Cable/DSL + // 1.0.16.0/20: Cable/DSL + // 1.0.32.0/19: Cable/DSL + // 1.0.64.0/18: Cable/DSL + // 1.0.128.0/17: Cable/DSL + // 2.125.160.216/29: Cable/DSL + // 67.43.156.0/24: Cellular // 80.214.0.0/20: Cellular // 96.1.0.0/16: Cable/DSL // 96.10.0.0/15: Cable/DSL // 96.69.0.0/16: Cable/DSL // 96.94.0.0/15: Cable/DSL // 108.96.0.0/11: Cellular - // 175.16.199.0/24: Dialup + // 149.101.100.0/28: Cellular + // 175.16.199.0/24: Cable/DSL // 187.156.138.0/24: Cable/DSL // 201.243.200.0/24: Corporate // 207.179.48.0/20: 
Cellular - // 2001:0:100::/56: Dialup - // 2001:0:100:100::/56: Cable/DSL - // 2001:0:100:200::/55: Dialup - // 2001:0:100:400::/54: Dialup - // 2001:0:100:800::/53: Dialup - // 2001:0:100:1000::/52: Dialup - // 2001:0:100:2000::/51: Dialup - // 2001:0:100:4000::/50: Dialup - // 2001:0:100:8000::/49: Dialup - // 2001:0:50d6::/52: Cellular - // 2001:0:6001::/48: Cable/DSL - // 2001:0:600a::/47: Cable/DSL - // 2001:0:6045::/48: Cable/DSL - // 2001:0:605e::/47: Cable/DSL - // 2001:0:6c60::/43: Cellular - // 2001:0:af10:c700::/56: Dialup - // 2001:0:bb9c:8a00::/56: Cable/DSL - // 2001:0:c9f3:c800::/56: Corporate - // 2001:0:cfb3:3000::/52: Cellular - // 2002:100::/40: Dialup - // 2002:100:100::/40: Cable/DSL - // 2002:100:200::/39: Dialup - // 2002:100:400::/38: Dialup - // 2002:100:800::/37: Dialup - // 2002:100:1000::/36: Dialup - // 2002:100:2000::/35: Dialup - // 2002:100:4000::/34: Dialup - // 2002:100:8000::/33: Dialup - // 2002:50d6::/36: Cellular - // 2002:6001::/32: Cable/DSL - // 2002:600a::/31: Cable/DSL - // 2002:6045::/32: Cable/DSL - // 2002:605e::/31: Cable/DSL - // 2002:6c60::/27: Cellular - // 2002:af10:c700::/40: Dialup - // 2002:bb9c:8a00::/40: Cable/DSL - // 2002:c9f3:c800::/40: Corporate - // 2002:cfb3:3000::/36: Cellular + // 216.160.83.56/29: Corporate // 2003::/24: Cable/DSL +} + +// This example demonstrates how to iterate over all networks in the +// database which are contained within an arbitrary network. 
+func ExampleReader_NetworksWithin() { + db, err := maxminddb.Open("test-data/test-data/GeoIP2-Connection-Type-Test.mmdb") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + record := struct { + Domain string `maxminddb:"connection_type"` + }{} + _, network, err := net.ParseCIDR("1.0.0.0/8") + if err != nil { + log.Panic(err) + } + + networks := db.NetworksWithin(network, maxminddb.SkipAliasedNetworks) + for networks.Next() { + subnet, err := networks.Network(&record) + if err != nil { + log.Panic(err) + } + fmt.Printf("%s: %s\n", subnet.String(), record.Domain) + } + if networks.Err() != nil { + log.Panic(networks.Err()) + } + + // Output: + // 1.0.0.0/24: Cable/DSL + // 1.0.1.0/24: Cellular + // 1.0.2.0/23: Cable/DSL + // 1.0.4.0/22: Cable/DSL + // 1.0.8.0/21: Cable/DSL + // 1.0.16.0/20: Cable/DSL + // 1.0.32.0/19: Cable/DSL + // 1.0.64.0/18: Cable/DSL + // 1.0.128.0/17: Cable/DSL } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000000000000000000000000000000000000..fadd604e2d4d9687fc5dc94983c503dd13770ebf --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/oschwald/maxminddb-golang + +go 1.19 + +require ( + github.com/stretchr/testify v1.8.4 + golang.org/x/sys v0.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000000000000000000000000000000000000..f7877e0a2fb436befa37d4589972cc745ed22782 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mmap_unix.go b/mmap_unix.go index 99f98cab6caca0d1668c6f7af2edd3090e43ac48..48b2e403ce2bc79e6fd3577ef5fe5bfac6128d14 100644 --- a/mmap_unix.go +++ b/mmap_unix.go @@ -1,15 +1,14 @@ -// +build !windows,!appengine +//go:build !windows && !appengine && !plan9 && !js && !wasip1 && !wasi +// +build !windows,!appengine,!plan9,!js,!wasip1,!wasi package maxminddb import ( - "syscall" - "golang.org/x/sys/unix" ) -func mmap(fd int, length int) (data []byte, err error) { - return unix.Mmap(fd, 0, length, syscall.PROT_READ, syscall.MAP_SHARED) +func mmap(fd, length int) (data []byte, err error) { + return unix.Mmap(fd, 0, length, unix.PROT_READ, unix.MAP_SHARED) } func munmap(b []byte) (err error) { diff --git a/mmap_windows.go b/mmap_windows.go index 661250eca00bc59f5ce00fb8251809d6cbfe45b3..79133a7fb58b7108f6500b36af8f1f9646cbe936 100644 --- a/mmap_windows.go +++ b/mmap_windows.go @@ -1,3 +1,4 @@ +//go:build windows && !appengine // +build windows,!appengine package maxminddb diff --git a/node.go b/node.go new file mode 100644 index 0000000000000000000000000000000000000000..16e8b5f6a0f9d9978e688b71a3920682420d9848 --- /dev/null +++ b/node.go @@ -0,0 +1,58 @@ +package maxminddb + +type nodeReader interface { + readLeft(uint) uint + readRight(uint) uint +} + +type nodeReader24 struct { + buffer []byte +} + +func (n nodeReader24) readLeft(nodeNumber uint) uint { + 
return (uint(n.buffer[nodeNumber]) << 16) | + (uint(n.buffer[nodeNumber+1]) << 8) | + uint(n.buffer[nodeNumber+2]) +} + +func (n nodeReader24) readRight(nodeNumber uint) uint { + return (uint(n.buffer[nodeNumber+3]) << 16) | + (uint(n.buffer[nodeNumber+4]) << 8) | + uint(n.buffer[nodeNumber+5]) +} + +type nodeReader28 struct { + buffer []byte +} + +func (n nodeReader28) readLeft(nodeNumber uint) uint { + return ((uint(n.buffer[nodeNumber+3]) & 0xF0) << 20) | + (uint(n.buffer[nodeNumber]) << 16) | + (uint(n.buffer[nodeNumber+1]) << 8) | + uint(n.buffer[nodeNumber+2]) +} + +func (n nodeReader28) readRight(nodeNumber uint) uint { + return ((uint(n.buffer[nodeNumber+3]) & 0x0F) << 24) | + (uint(n.buffer[nodeNumber+4]) << 16) | + (uint(n.buffer[nodeNumber+5]) << 8) | + uint(n.buffer[nodeNumber+6]) +} + +type nodeReader32 struct { + buffer []byte +} + +func (n nodeReader32) readLeft(nodeNumber uint) uint { + return (uint(n.buffer[nodeNumber]) << 24) | + (uint(n.buffer[nodeNumber+1]) << 16) | + (uint(n.buffer[nodeNumber+2]) << 8) | + uint(n.buffer[nodeNumber+3]) +} + +func (n nodeReader32) readRight(nodeNumber uint) uint { + return (uint(n.buffer[nodeNumber+4]) << 24) | + (uint(n.buffer[nodeNumber+5]) << 16) | + (uint(n.buffer[nodeNumber+6]) << 8) | + uint(n.buffer[nodeNumber+7]) +} diff --git a/reader.go b/reader.go index bc933e99e51a940eecd2fc01b23026e9c126b9d3..470845b2b26bf0491082c501235b89d7d7131d73 100644 --- a/reader.go +++ b/reader.go @@ -1,3 +1,4 @@ +// Package maxminddb provides a reader for the MaxMind DB file format. package maxminddb import ( @@ -20,26 +21,32 @@ var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com") // Reader holds the data corresponding to the MaxMind DB file. Its only public // field is Metadata, which contains the metadata from the MaxMind DB file. +// +// All of the methods on Reader are thread-safe. The struct may be safely +// shared across goroutines. 
type Reader struct { - hasMappedFile bool - buffer []byte - decoder decoder - Metadata Metadata - ipv4Start uint + nodeReader nodeReader + buffer []byte + decoder decoder + Metadata Metadata + ipv4Start uint + ipv4StartBitDepth int + nodeOffsetMult uint + hasMappedFile bool } // Metadata holds the metadata decoded from the MaxMind DB file. In particular -// in has the format version, the build time as Unix epoch time, the database +// it has the format version, the build time as Unix epoch time, the database // type and description, the IP version supported, and a slice of the natural // languages included. type Metadata struct { + Description map[string]string `maxminddb:"description"` + DatabaseType string `maxminddb:"database_type"` + Languages []string `maxminddb:"languages"` BinaryFormatMajorVersion uint `maxminddb:"binary_format_major_version"` BinaryFormatMinorVersion uint `maxminddb:"binary_format_minor_version"` BuildEpoch uint `maxminddb:"build_epoch"` - DatabaseType string `maxminddb:"database_type"` - Description map[string]string `maxminddb:"description"` IPVersion uint `maxminddb:"ip_version"` - Languages []string `maxminddb:"languages"` NodeCount uint `maxminddb:"node_count"` RecordSize uint `maxminddb:"record_size"` } @@ -74,63 +81,130 @@ func FromBytes(buffer []byte) (*Reader, error) { buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)], } + nodeBuffer := buffer[:searchTreeSize] + var nodeReader nodeReader + switch metadata.RecordSize { + case 24: + nodeReader = nodeReader24{buffer: nodeBuffer} + case 28: + nodeReader = nodeReader28{buffer: nodeBuffer} + case 32: + nodeReader = nodeReader32{buffer: nodeBuffer} + default: + return nil, newInvalidDatabaseError("unknown record size: %d", metadata.RecordSize) + } + reader := &Reader{ - buffer: buffer, - decoder: d, - Metadata: metadata, - ipv4Start: 0, + buffer: buffer, + nodeReader: nodeReader, + decoder: d, + Metadata: metadata, + ipv4Start: 0, + nodeOffsetMult: 
metadata.RecordSize / 4, } - reader.ipv4Start, err = reader.startNode() + reader.setIPv4Start() return reader, err } -func (r *Reader) startNode() (uint, error) { +func (r *Reader) setIPv4Start() { if r.Metadata.IPVersion != 6 { - return 0, nil + return } nodeCount := r.Metadata.NodeCount node := uint(0) - var err error - for i := 0; i < 96 && node < nodeCount; i++ { - node, err = r.readNode(node, 0) - if err != nil { - return 0, err - } + i := 0 + for ; i < 96 && node < nodeCount; i++ { + node = r.nodeReader.readLeft(node * r.nodeOffsetMult) } - return node, err + r.ipv4Start = node + r.ipv4StartBitDepth = i } -// Lookup takes an IP address as a net.IP structure and a pointer to the -// result value to Decode into. -func (r *Reader) Lookup(ipAddress net.IP, result interface{}) error { - pointer, err := r.lookupPointer(ipAddress) +// Lookup retrieves the database record for ip and stores it in the value +// pointed to by result. If result is nil or not a pointer, an error is +// returned. If the data in the database record cannot be stored in result +// because of type differences, an UnmarshalTypeError is returned. If the +// database is invalid or otherwise cannot be read, an InvalidDatabaseError +// is returned. +func (r *Reader) Lookup(ip net.IP, result any) error { + if r.buffer == nil { + return errors.New("cannot call Lookup on a closed database") + } + pointer, _, _, err := r.lookupPointer(ip) if pointer == 0 || err != nil { return err } return r.retrieveData(pointer, result) } +// LookupNetwork retrieves the database record for ip and stores it in the +// value pointed to by result. The network returned is the network associated +// with the data record in the database. The ok return value indicates whether +// the database contained a record for the ip. +// +// If result is nil or not a pointer, an error is returned. If the data in the +// database record cannot be stored in result because of type differences, an +// UnmarshalTypeError is returned. 
If the database is invalid or otherwise +// cannot be read, an InvalidDatabaseError is returned. +func (r *Reader) LookupNetwork( + ip net.IP, + result any, +) (network *net.IPNet, ok bool, err error) { + if r.buffer == nil { + return nil, false, errors.New("cannot call Lookup on a closed database") + } + pointer, prefixLength, ip, err := r.lookupPointer(ip) + + network = r.cidr(ip, prefixLength) + if pointer == 0 || err != nil { + return network, false, err + } + + return network, true, r.retrieveData(pointer, result) +} + // LookupOffset maps an argument net.IP to a corresponding record offset in the // database. NotFound is returned if no such record is found, and a record may // otherwise be extracted by passing the returned offset to Decode. LookupOffset // is an advanced API, which exists to provide clients with a means to cache // previously-decoded records. -func (r *Reader) LookupOffset(ipAddress net.IP) (uintptr, error) { - pointer, err := r.lookupPointer(ipAddress) +func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) { + if r.buffer == nil { + return 0, errors.New("cannot call LookupOffset on a closed database") + } + pointer, _, _, err := r.lookupPointer(ip) if pointer == 0 || err != nil { return NotFound, err } return r.resolveDataPointer(pointer) } +func (r *Reader) cidr(ip net.IP, prefixLength int) *net.IPNet { + // This is necessary as the node that the IPv4 start is at may + // be at a bit depth that is less than 96, i.e., ipv4Start points + // to a leaf node. For instance, if a record was inserted at ::/8, + // the ipv4Start would point directly at the leaf node for the + // record and would have a bit depth of 8. This would not happen + // with databases currently distributed by MaxMind as all of them + // have an IPv4 subtree that is greater than a single node. 
+ if r.Metadata.IPVersion == 6 && + len(ip) == net.IPv4len && + r.ipv4StartBitDepth != 96 { + return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(r.ipv4StartBitDepth, 128)} + } + + mask := net.CIDRMask(prefixLength, len(ip)*8) + return &net.IPNet{IP: ip.Mask(mask), Mask: mask} +} + // Decode the record at |offset| into |result|. The result value pointed to // must be a data value that corresponds to a record in the database. This may // include a struct representation of the data, a map capable of holding the -// data or an empty interface{} value. +// data or an empty any value. // // If result is a pointer to a struct, the struct need not include a field // for every value that may be in the database. If a field is not present in @@ -143,103 +217,93 @@ func (r *Reader) LookupOffset(ipAddress net.IP) (uintptr, error) { // the City database, all records of the same country will reference a // single representative record for that country. This uintptr behavior allows // clients to leverage this normalization in their own sub-record caching. 
-func (r *Reader) Decode(offset uintptr, result interface{}) error { +func (r *Reader) Decode(offset uintptr, result any) error { + if r.buffer == nil { + return errors.New("cannot call Decode on a closed database") + } + return r.decode(offset, result) +} + +func (r *Reader) decode(offset uintptr, result any) error { rv := reflect.ValueOf(result) if rv.Kind() != reflect.Ptr || rv.IsNil() { return errors.New("result param must be a pointer") } - _, err := r.decoder.decode(uint(offset), reflect.ValueOf(result), 0) + if dser, ok := result.(deserializer); ok { + _, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0, false) + return err + } + + _, err := r.decoder.decode(uint(offset), rv, 0) return err } -func (r *Reader) lookupPointer(ipAddress net.IP) (uint, error) { - if ipAddress == nil { - return 0, errors.New("ipAddress passed to Lookup cannot be nil") +func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) { + if ip == nil { + return 0, 0, nil, errors.New("IP passed to Lookup cannot be nil") } - ipV4Address := ipAddress.To4() + ipV4Address := ip.To4() if ipV4Address != nil { - ipAddress = ipV4Address + ip = ipV4Address } - if len(ipAddress) == 16 && r.Metadata.IPVersion == 4 { - return 0, fmt.Errorf("error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", ipAddress.String()) + if len(ip) == 16 && r.Metadata.IPVersion == 4 { + return 0, 0, ip, fmt.Errorf( + "error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database", + ip.String(), + ) } - return r.findAddressInTree(ipAddress) -} - -func (r *Reader) findAddressInTree(ipAddress net.IP) (uint, error) { - - bitCount := uint(len(ipAddress) * 8) + bitCount := uint(len(ip) * 8) var node uint if bitCount == 32 { node = r.ipv4Start } + node, prefixLength := r.traverseTree(ip, node, bitCount) nodeCount := r.Metadata.NodeCount - - for i := uint(0); i < bitCount && node < nodeCount; i++ { - bit := uint(1) & (uint(ipAddress[i>>3]) >> (7 
- (i % 8))) - - var err error - node, err = r.readNode(node, bit) - if err != nil { - return 0, err - } - } if node == nodeCount { // Record is empty - return 0, nil + return 0, prefixLength, ip, nil } else if node > nodeCount { - return node, nil + return node, prefixLength, ip, nil } - return 0, newInvalidDatabaseError("invalid node in search tree") + return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree") } -func (r *Reader) readNode(nodeNumber uint, index uint) (uint, error) { - RecordSize := r.Metadata.RecordSize +func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) { + nodeCount := r.Metadata.NodeCount - baseOffset := nodeNumber * RecordSize / 4 + i := uint(0) + for ; i < bitCount && node < nodeCount; i++ { + bit := uint(1) & (uint(ip[i>>3]) >> (7 - (i % 8))) - var nodeBytes []byte - var prefix uint - switch RecordSize { - case 24: - offset := baseOffset + index*3 - nodeBytes = r.buffer[offset : offset+3] - case 28: - prefix = uint(r.buffer[baseOffset+3]) - if index != 0 { - prefix &= 0x0F + offset := node * r.nodeOffsetMult + if bit == 0 { + node = r.nodeReader.readLeft(offset) } else { - prefix = (0xF0 & prefix) >> 4 + node = r.nodeReader.readRight(offset) } - offset := baseOffset + index*4 - nodeBytes = r.buffer[offset : offset+3] - case 32: - offset := baseOffset + index*4 - nodeBytes = r.buffer[offset : offset+4] - default: - return 0, newInvalidDatabaseError("unknown record size: %d", RecordSize) } - return uintFromBytes(prefix, nodeBytes), nil + + return node, int(i) } -func (r *Reader) retrieveData(pointer uint, result interface{}) error { +func (r *Reader) retrieveData(pointer uint, result any) error { offset, err := r.resolveDataPointer(pointer) if err != nil { return err } - return r.Decode(offset, result) + return r.decode(offset, result) } func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) { - var resolved = uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize) + 
resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize) - if resolved > uintptr(len(r.buffer)) { + if resolved >= uintptr(len(r.buffer)) { return 0, newInvalidDatabaseError("the MaxMind DB file's search tree is corrupt") } return resolved, nil diff --git a/reader_appengine.go b/reader_appengine.go deleted file mode 100644 index 631e19532ee18eb317d0b902d52ca2a45aaa17db..0000000000000000000000000000000000000000 --- a/reader_appengine.go +++ /dev/null @@ -1,26 +0,0 @@ -// +build appengine - -package maxminddb - -import "io/ioutil" - -// Open takes a string path to a MaxMind DB file and returns a Reader -// structure or an error. The database file is opened using a memory map, -// except on Google App Engine where mmap is not supported; there the database -// is loaded into memory. Use the Close method on the Reader object to return -// the resources to the system. -func Open(file string) (*Reader, error) { - bytes, err := ioutil.ReadFile(file) - if err != nil { - return nil, err - } - - return FromBytes(bytes) -} - -// Close unmaps the database file from virtual memory and returns the -// resources to the system. If called on a Reader opened using FromBytes -// or Open on Google App Engine, this method does nothing. -func (r *Reader) Close() error { - return nil -} diff --git a/reader_memory.go b/reader_memory.go new file mode 100644 index 0000000000000000000000000000000000000000..4ebb3473d257ffdb5fdc433e18cffaf513d749de --- /dev/null +++ b/reader_memory.go @@ -0,0 +1,26 @@ +//go:build appengine || plan9 || js || wasip1 || wasi +// +build appengine plan9 js wasip1 wasi + +package maxminddb + +import "io/ioutil" + +// Open takes a string path to a MaxMind DB file and returns a Reader +// structure or an error. The database file is opened using a memory map +// on supported platforms. On platforms without memory map support, such +// as WebAssembly or Google App Engine, the database is loaded into memory. 
+// Use the Close method on the Reader object to return the resources to the system. +func Open(file string) (*Reader, error) { + bytes, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + + return FromBytes(bytes) +} + +// Close returns the resources used by the database to the system. +func (r *Reader) Close() error { + r.buffer = nil + return nil +} diff --git a/reader_mmap.go b/reader_mmap.go new file mode 100644 index 0000000000000000000000000000000000000000..1d083019ee6f1b4f23e9bf8be098ed3ae8776671 --- /dev/null +++ b/reader_mmap.go @@ -0,0 +1,64 @@ +//go:build !appengine && !plan9 && !js && !wasip1 && !wasi +// +build !appengine,!plan9,!js,!wasip1,!wasi + +package maxminddb + +import ( + "os" + "runtime" +) + +// Open takes a string path to a MaxMind DB file and returns a Reader +// structure or an error. The database file is opened using a memory map +// on supported platforms. On platforms without memory map support, such +// as WebAssembly or Google App Engine, the database is loaded into memory. +// Use the Close method on the Reader object to return the resources to the system. +func Open(file string) (*Reader, error) { + mapFile, err := os.Open(file) + if err != nil { + _ = mapFile.Close() + return nil, err + } + + stats, err := mapFile.Stat() + if err != nil { + _ = mapFile.Close() + return nil, err + } + + fileSize := int(stats.Size()) + mmap, err := mmap(int(mapFile.Fd()), fileSize) + if err != nil { + _ = mapFile.Close() + return nil, err + } + + if err := mapFile.Close(); err != nil { + //nolint:errcheck // we prefer to return the original error + munmap(mmap) + return nil, err + } + + reader, err := FromBytes(mmap) + if err != nil { + //nolint:errcheck // we prefer to return the original error + munmap(mmap) + return nil, err + } + + reader.hasMappedFile = true + runtime.SetFinalizer(reader, (*Reader).Close) + return reader, nil +} + +// Close returns the resources used by the database to the system. 
+func (r *Reader) Close() error { + var err error + if r.hasMappedFile { + runtime.SetFinalizer(r, nil) + r.hasMappedFile = false + err = munmap(r.buffer) + } + r.buffer = nil + return err +} diff --git a/reader_other.go b/reader_other.go deleted file mode 100644 index b611a9561262b7bc9404faf69eee9e232b6093a1..0000000000000000000000000000000000000000 --- a/reader_other.go +++ /dev/null @@ -1,61 +0,0 @@ -// +build !appengine - -package maxminddb - -import ( - "os" - "runtime" -) - -// Open takes a string path to a MaxMind DB file and returns a Reader -// structure or an error. The database file is opened using a memory map, -// except on Google App Engine where mmap is not supported; there the database -// is loaded into memory. Use the Close method on the Reader object to return -// the resources to the system. -func Open(file string) (*Reader, error) { - mapFile, err := os.Open(file) - if err != nil { - return nil, err - } - defer func() { - if rerr := mapFile.Close(); rerr != nil { - err = rerr - } - }() - - stats, err := mapFile.Stat() - if err != nil { - return nil, err - } - - fileSize := int(stats.Size()) - mmap, err := mmap(int(mapFile.Fd()), fileSize) - if err != nil { - return nil, err - } - - reader, err := FromBytes(mmap) - if err != nil { - if err2 := munmap(mmap); err2 != nil { - // failing to unmap the file is probably the more severe error - return nil, err2 - } - return nil, err - } - - reader.hasMappedFile = true - runtime.SetFinalizer(reader, (*Reader).Close) - return reader, err -} - -// Close unmaps the database file from virtual memory and returns the -// resources to the system. If called on a Reader opened using FromBytes -// or Open on Google App Engine, this method does nothing. 
-func (r *Reader) Close() error { - if !r.hasMappedFile { - return nil - } - runtime.SetFinalizer(r, nil) - r.hasMappedFile = false - return munmap(r.buffer) -} diff --git a/reader_test.go b/reader_test.go index a0909c0f914ac82c95d90b1f358d18ac9800862d..f40541eeee5a488bf379f00fba957a62727ac86e 100644 --- a/reader_test.go +++ b/reader_test.go @@ -3,10 +3,11 @@ package maxminddb import ( "errors" "fmt" - "io/ioutil" "math/big" "math/rand" "net" + "os" + "path/filepath" "testing" "time" @@ -17,9 +18,13 @@ import ( func TestReader(t *testing.T) { for _, recordSize := range []uint{24, 28, 32} { for _, ipVersion := range []uint{4, 6} { - fileName := fmt.Sprintf("test-data/test-data/MaxMind-DB-test-ipv%d-%d.mmdb", ipVersion, recordSize) + fileName := fmt.Sprintf( + testFile("MaxMind-DB-test-ipv%d-%d.mmdb"), + ipVersion, + recordSize, + ) reader, err := Open(fileName) - require.Nil(t, err, "unexpected error while opening database: %v", err) + require.NoError(t, err, "unexpected error while opening database: %v", err) checkMetadata(t, reader, ipVersion, recordSize) if ipVersion == 4 { @@ -34,10 +39,15 @@ func TestReader(t *testing.T) { func TestReaderBytes(t *testing.T) { for _, recordSize := range []uint{24, 28, 32} { for _, ipVersion := range []uint{4, 6} { - fileName := fmt.Sprintf("test-data/test-data/MaxMind-DB-test-ipv%d-%d.mmdb", ipVersion, recordSize) - bytes, _ := ioutil.ReadFile(fileName) + fileName := fmt.Sprintf( + testFile("MaxMind-DB-test-ipv%d-%d.mmdb"), + ipVersion, + recordSize, + ) + bytes, err := os.ReadFile(fileName) + require.NoError(t, err) reader, err := FromBytes(bytes) - require.Nil(t, err, "unexpected error while opening bytes: %v", err) + require.NoError(t, err, "unexpected error while opening bytes: %v", err) checkMetadata(t, reader, ipVersion, recordSize) @@ -50,98 +60,257 @@ func TestReaderBytes(t *testing.T) { } } +func TestLookupNetwork(t *testing.T) { + bigInt := new(big.Int) + bigInt.SetString("1329227995784915872903807060280344576", 10) + 
decoderRecord := map[string]any{ + "array": []any{ + uint64(1), + uint64(2), + uint64(3), + }, + "boolean": true, + "bytes": []uint8{ + 0x0, + 0x0, + 0x0, + 0x2a, + }, + "double": 42.123456, + "float": float32(1.1), + "int32": -268435456, + "map": map[string]any{ + "mapX": map[string]any{ + "arrayX": []any{ + uint64(0x7), + uint64(0x8), + uint64(0x9), + }, + "utf8_stringX": "hello", + }, + }, + "uint128": bigInt, + "uint16": uint64(0x64), + "uint32": uint64(0x10000000), + "uint64": uint64(0x1000000000000000), + "utf8_string": "unicode! ☯ - ♫", + } + + tests := []struct { + IP net.IP + DBFile string + ExpectedCIDR string + ExpectedRecord any + ExpectedOK bool + }{ + { + IP: net.ParseIP("1.1.1.1"), + DBFile: "MaxMind-DB-test-ipv6-32.mmdb", + ExpectedCIDR: "1.0.0.0/8", + ExpectedRecord: nil, + ExpectedOK: false, + }, + { + IP: net.ParseIP("::1:ffff:ffff"), + DBFile: "MaxMind-DB-test-ipv6-24.mmdb", + ExpectedCIDR: "::1:ffff:ffff/128", + ExpectedRecord: map[string]any{"ip": "::1:ffff:ffff"}, + ExpectedOK: true, + }, + { + IP: net.ParseIP("::2:0:1"), + DBFile: "MaxMind-DB-test-ipv6-24.mmdb", + ExpectedCIDR: "::2:0:0/122", + ExpectedRecord: map[string]any{"ip": "::2:0:0"}, + ExpectedOK: true, + }, + { + IP: net.ParseIP("1.1.1.1"), + DBFile: "MaxMind-DB-test-ipv4-24.mmdb", + ExpectedCIDR: "1.1.1.1/32", + ExpectedRecord: map[string]any{"ip": "1.1.1.1"}, + ExpectedOK: true, + }, + { + IP: net.ParseIP("1.1.1.3"), + DBFile: "MaxMind-DB-test-ipv4-24.mmdb", + ExpectedCIDR: "1.1.1.2/31", + ExpectedRecord: map[string]any{"ip": "1.1.1.2"}, + ExpectedOK: true, + }, + { + IP: net.ParseIP("1.1.1.3"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedCIDR: "1.1.1.0/24", + ExpectedRecord: decoderRecord, + ExpectedOK: true, + }, + { + IP: net.ParseIP("::ffff:1.1.1.128"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + ExpectedCIDR: "1.1.1.0/24", + ExpectedRecord: decoderRecord, + ExpectedOK: true, + }, + { + IP: net.ParseIP("::1.1.1.128"), + DBFile: "MaxMind-DB-test-decoder.mmdb", + 
ExpectedCIDR: "::101:100/120", + ExpectedRecord: decoderRecord, + ExpectedOK: true, + }, + { + IP: net.ParseIP("200.0.2.1"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedCIDR: "::/64", + ExpectedRecord: "::0/64", + ExpectedOK: true, + }, + { + IP: net.ParseIP("::200.0.2.1"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedCIDR: "::/64", + ExpectedRecord: "::0/64", + ExpectedOK: true, + }, + { + IP: net.ParseIP("0:0:0:0:ffff:ffff:ffff:ffff"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedCIDR: "::/64", + ExpectedRecord: "::0/64", + ExpectedOK: true, + }, + { + IP: net.ParseIP("ef00::"), + DBFile: "MaxMind-DB-no-ipv4-search-tree.mmdb", + ExpectedCIDR: "8000::/1", + ExpectedRecord: nil, + ExpectedOK: false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s - %s", test.DBFile, test.IP), func(t *testing.T) { + var record any + reader, err := Open(testFile(test.DBFile)) + require.NoError(t, err) + + network, ok, err := reader.LookupNetwork(test.IP, &record) + require.NoError(t, err) + assert.Equal(t, test.ExpectedOK, ok) + assert.Equal(t, test.ExpectedCIDR, network.String()) + assert.Equal(t, test.ExpectedRecord, record) + }) + } +} + func TestDecodingToInterface(t *testing.T) { - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - assert.Nil(t, err, "unexpected error while opening database: %v", err) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err, "unexpected error while opening database: %v", err) - var recordInterface interface{} + var recordInterface any err = reader.Lookup(net.ParseIP("::1.1.1.0"), &recordInterface) - require.Nil(t, err, "unexpected error while doing lookup: %v", err) - - record := recordInterface.(map[string]interface{}) - assert.Equal(t, record["array"], []interface{}{uint64(1), uint64(2), uint64(3)}) - assert.Equal(t, record["boolean"], true) - assert.Equal(t, record["bytes"], []byte{0x00, 0x00, 0x00, 0x2a}) - assert.Equal(t, 
record["double"], 42.123456) - assert.Equal(t, record["float"], float32(1.1)) - assert.Equal(t, record["int32"], -268435456) - assert.Equal(t, record["map"], - map[string]interface{}{ - "mapX": map[string]interface{}{ - "arrayX": []interface{}{uint64(7), uint64(8), uint64(9)}, - "utf8_stringX": "hello", - }}) + require.NoError(t, err, "unexpected error while doing lookup: %v", err) + + checkDecodingToInterface(t, recordInterface) +} + +func TestMetadataPointer(t *testing.T) { + _, err := Open(testFile("MaxMind-DB-test-metadata-pointers.mmdb")) + require.NoError(t, err, "unexpected error while opening database: %v", err) +} - assert.Equal(t, record["uint16"], uint64(100)) - assert.Equal(t, record["uint32"], uint64(268435456)) - assert.Equal(t, record["uint64"], uint64(1152921504606846976)) - assert.Equal(t, record["utf8_string"], "unicode! ☯ - ♫") +func checkDecodingToInterface(t *testing.T, recordInterface any) { + record := recordInterface.(map[string]any) + assert.Equal(t, []any{uint64(1), uint64(2), uint64(3)}, record["array"]) + assert.Equal(t, true, record["boolean"]) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x2a}, record["bytes"]) + assert.Equal(t, 42.123456, record["double"]) + assert.Equal(t, float32(1.1), record["float"]) + assert.Equal(t, -268435456, record["int32"]) + assert.Equal(t, + map[string]any{ + "mapX": map[string]any{ + "arrayX": []any{uint64(7), uint64(8), uint64(9)}, + "utf8_stringX": "hello", + }, + }, + record["map"], + ) + + assert.Equal(t, uint64(100), record["uint16"]) + assert.Equal(t, uint64(268435456), record["uint32"]) + assert.Equal(t, uint64(1152921504606846976), record["uint64"]) + assert.Equal(t, "unicode! 
☯ - ♫", record["utf8_string"]) bigInt := new(big.Int) bigInt.SetString("1329227995784915872903807060280344576", 10) - assert.Equal(t, record["uint128"], bigInt) + assert.Equal(t, bigInt, record["uint128"]) } type TestType struct { - Array []uint `maxminddb:"array"` - Boolean bool `maxminddb:"boolean"` - Bytes []byte `maxminddb:"bytes"` - Double float64 `maxminddb:"double"` - Float float32 `maxminddb:"float"` - Int32 int32 `maxminddb:"int32"` - Map map[string]interface{} `maxminddb:"map"` - Uint16 uint16 `maxminddb:"uint16"` - Uint32 uint32 `maxminddb:"uint32"` - Uint64 uint64 `maxminddb:"uint64"` - Uint128 big.Int `maxminddb:"uint128"` - Utf8String string `maxminddb:"utf8_string"` + Array []uint `maxminddb:"array"` + Boolean bool `maxminddb:"boolean"` + Bytes []byte `maxminddb:"bytes"` + Double float64 `maxminddb:"double"` + Float float32 `maxminddb:"float"` + Int32 int32 `maxminddb:"int32"` + Map map[string]any `maxminddb:"map"` + Uint16 uint16 `maxminddb:"uint16"` + Uint32 uint32 `maxminddb:"uint32"` + Uint64 uint64 `maxminddb:"uint64"` + Uint128 big.Int `maxminddb:"uint128"` + Utf8String string `maxminddb:"utf8_string"` } func TestDecoder(t *testing.T) { - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - require.Nil(t, err) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) verify := func(result TestType) { - assert.Equal(t, result.Array, []uint{uint(1), uint(2), uint(3)}) - assert.Equal(t, result.Boolean, true) - assert.Equal(t, result.Bytes, []byte{0x00, 0x00, 0x00, 0x2a}) - assert.Equal(t, result.Double, 42.123456) - assert.Equal(t, result.Float, float32(1.1)) - assert.Equal(t, result.Int32, int32(-268435456)) - - assert.Equal(t, result.Map, - map[string]interface{}{ - "mapX": map[string]interface{}{ - "arrayX": []interface{}{uint64(7), uint64(8), uint64(9)}, + assert.Equal(t, []uint{uint(1), uint(2), uint(3)}, result.Array) + assert.Equal(t, true, result.Boolean) + assert.Equal(t, []byte{0x00, 
0x00, 0x00, 0x2a}, result.Bytes) + assert.Equal(t, 42.123456, result.Double) + assert.Equal(t, float32(1.1), result.Float) + assert.Equal(t, int32(-268435456), result.Int32) + + assert.Equal(t, + map[string]any{ + "mapX": map[string]any{ + "arrayX": []any{uint64(7), uint64(8), uint64(9)}, "utf8_stringX": "hello", - }}) - - assert.Equal(t, result.Uint16, uint16(100)) - assert.Equal(t, result.Uint32, uint32(268435456)) - assert.Equal(t, result.Uint64, uint64(1152921504606846976)) - assert.Equal(t, result.Utf8String, "unicode! ☯ - ♫") + }, + }, + result.Map, + ) + + assert.Equal(t, uint16(100), result.Uint16) + assert.Equal(t, uint32(268435456), result.Uint32) + assert.Equal(t, uint64(1152921504606846976), result.Uint64) + assert.Equal(t, "unicode! ☯ - ♫", result.Utf8String) bigInt := new(big.Int) bigInt.SetString("1329227995784915872903807060280344576", 10) - assert.Equal(t, &result.Uint128, bigInt) + assert.Equal(t, bigInt, &result.Uint128) } { // Directly lookup and decode. var result TestType - assert.Nil(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) verify(result) } { // Lookup record offset, then Decode. 
var result TestType offset, err := reader.LookupOffset(net.ParseIP("::1.1.1.0")) - assert.Nil(t, err) - assert.NotEqual(t, offset, NotFound) + require.NoError(t, err) + assert.NotEqual(t, NotFound, offset) - assert.Nil(t, reader.Decode(offset, &result)) + assert.NoError(t, reader.Decode(offset, &result)) verify(result) } - assert.Nil(t, reader.Close()) + assert.NoError(t, reader.Close()) } type TestInterface interface { @@ -155,22 +324,26 @@ func (t *TestType) method() bool { func TestStructInterface(t *testing.T) { var result TestInterface = &TestType{} - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - require.Nil(t, err) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) - require.Nil(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + require.NoError(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) - assert.Equal(t, result.method(), true) + assert.Equal(t, true, result.method()) } func TestNonEmptyNilInterface(t *testing.T) { var result TestInterface - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - require.Nil(t, err) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) - assert.Equal(t, err.Error(), "maxminddb: cannot unmarshal map into type maxminddb.TestInterface") + assert.Equal( + t, + "maxminddb: cannot unmarshal map into type maxminddb.TestInterface", + err.Error(), + ) } type CityTraits struct { @@ -183,12 +356,12 @@ type City struct { func TestEmbeddedStructAsInterface(t *testing.T) { var city City - var result interface{} = city.Traits + var result any = city.Traits - db, err := Open("test-data/test-data/GeoIP2-ISP-Test.mmdb") - require.Nil(t, err) + db, err := Open(testFile("GeoIP2-ISP-Test.mmdb")) + require.NoError(t, err) - assert.Nil(t, db.Lookup(net.ParseIP("1.128.0.0"), &result)) + assert.NoError(t, db.Lookup(net.ParseIP("1.128.0.0"), &result)) } type 
BoolInterface interface { @@ -205,15 +378,16 @@ type ValueTypeTestType struct { Boolean BoolInterface `maxminddb:"boolean"` } -func TesValueTypeInterface(t *testing.T) { +func TestValueTypeInterface(t *testing.T) { var result ValueTypeTestType result.Boolean = Bool(false) - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - require.Nil(t, err) - require.Nil(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) - assert.Equal(t, result.Boolean.true(), true) + // although it would be nice to support cases like this, I am not sure it + // is possible to do so in a general way. + assert.Error(t, reader.Lookup(net.ParseIP("::1.1.1.0"), &result)) } type NestedMapX struct { @@ -226,6 +400,7 @@ type NestedPointerMapX struct { type PointerMap struct { MapX struct { + Ignored string NestedMapX *NestedPointerMapX } `maxminddb:"mapX"` @@ -249,43 +424,104 @@ type TestPointerType struct { } func TestComplexStructWithNestingAndPointer(t *testing.T) { - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - assert.Nil(t, err) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + assert.NoError(t, err) var result TestPointerType err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) - require.Nil(t, err) + require.NoError(t, err) - assert.Equal(t, *result.Array, []uint{uint(1), uint(2), uint(3)}) - assert.Equal(t, *result.Boolean, true) - assert.Equal(t, *result.Bytes, []byte{0x00, 0x00, 0x00, 0x2a}) - assert.Equal(t, *result.Double, 42.123456) - assert.Equal(t, *result.Float, float32(1.1)) - assert.Equal(t, *result.Int32, int32(-268435456)) + assert.Equal(t, []uint{uint(1), uint(2), uint(3)}, *result.Array) + assert.Equal(t, true, *result.Boolean) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x2a}, *result.Bytes) + assert.Equal(t, 42.123456, *result.Double) + assert.Equal(t, float32(1.1), *result.Float) + assert.Equal(t, int32(-268435456), 
*result.Int32) - assert.Equal(t, result.Map.MapX.ArrayX, []int{7, 8, 9}) + assert.Equal(t, []int{7, 8, 9}, result.Map.MapX.ArrayX) - assert.Equal(t, result.Map.MapX.UTF8StringX, "hello") + assert.Equal(t, "hello", result.Map.MapX.UTF8StringX) - assert.Equal(t, *result.Uint16, uint16(100)) - assert.Equal(t, *result.Uint32, uint32(268435456)) - assert.Equal(t, **result.Uint64, uint64(1152921504606846976)) - assert.Equal(t, *result.Utf8String, "unicode! ☯ - ♫") + assert.Equal(t, uint16(100), *result.Uint16) + assert.Equal(t, uint32(268435456), *result.Uint32) + assert.Equal(t, uint64(1152921504606846976), **result.Uint64) + assert.Equal(t, "unicode! ☯ - ♫", *result.Utf8String) bigInt := new(big.Int) bigInt.SetString("1329227995784915872903807060280344576", 10) - assert.Equal(t, result.Uint128, bigInt) - - assert.Nil(t, reader.Close()) + assert.Equal(t, bigInt, result.Uint128) + + assert.NoError(t, reader.Close()) +} + +// See GitHub #115. +func TestNestedMapDecode(t *testing.T) { + db, err := Open(testFile("GeoIP2-Country-Test.mmdb")) + require.NoError(t, err) + + var r map[string]map[string]any + + require.NoError(t, db.Lookup(net.ParseIP("89.160.20.128"), &r)) + + assert.Equal( + t, + map[string]map[string]any{ + "continent": { + "code": "EU", + "geoname_id": uint64(6255148), + "names": map[string]any{ + "de": "Europa", + "en": "Europe", + "es": "Europa", + "fr": "Europe", + "ja": "ヨーロッパ", + "pt-BR": "Europa", + "ru": "Европа", + "zh-CN": "欧洲", + }, + }, + "country": { + "geoname_id": uint64(2661886), + "is_in_european_union": true, + "iso_code": "SE", + "names": map[string]any{ + "de": "Schweden", + "en": "Sweden", + "es": "Suecia", + "fr": "Suède", + "ja": "スウェーデン王国", + "pt-BR": "Suécia", + "ru": "Швеция", + "zh-CN": "瑞典", + }, + }, + "registered_country": { + "geoname_id": uint64(2921044), + "is_in_european_union": true, + "iso_code": "DE", + "names": map[string]any{ + "de": "Deutschland", + "en": "Germany", + "es": "Alemania", + "fr": "Allemagne", + "ja": 
"ドイツ連邦共和国", + "pt-BR": "Alemanha", + "ru": "Германия", + "zh-CN": "德国", + }, + }, + }, + r, + ) } func TestNestedOffsetDecode(t *testing.T) { - db, err := Open("test-data/test-data/GeoIP2-City-Test.mmdb") - require.Nil(t, err) + db, err := Open(testFile("GeoIP2-City-Test.mmdb")) + require.NoError(t, err) off, err := db.LookupOffset(net.ParseIP("81.2.69.142")) assert.NotEqual(t, off, NotFound) - require.Nil(t, err) + require.NoError(t, err) var root struct { CountryOffset uintptr `maxminddb:"country"` @@ -298,71 +534,75 @@ func TestNestedOffsetDecode(t *testing.T) { TimeZoneOffset uintptr `maxminddb:"time_zone"` } `maxminddb:"location"` } - assert.Nil(t, db.Decode(off, &root)) - assert.Equal(t, root.Location.Latitude, 51.5142) + assert.NoError(t, db.Decode(off, &root)) + assert.Equal(t, 51.5142, root.Location.Latitude) var longitude float64 - assert.Nil(t, db.Decode(root.Location.LongitudeOffset, &longitude)) - assert.Equal(t, longitude, -0.0931) + assert.NoError(t, db.Decode(root.Location.LongitudeOffset, &longitude)) + assert.Equal(t, -0.0931, longitude) var timeZone string - assert.Nil(t, db.Decode(root.Location.TimeZoneOffset, &timeZone)) - assert.Equal(t, timeZone, "Europe/London") + assert.NoError(t, db.Decode(root.Location.TimeZoneOffset, &timeZone)) + assert.Equal(t, "Europe/London", timeZone) var country struct { IsoCode string `maxminddb:"iso_code"` } - assert.Nil(t, db.Decode(root.CountryOffset, &country)) - assert.Equal(t, country.IsoCode, "GB") + assert.NoError(t, db.Decode(root.CountryOffset, &country)) + assert.Equal(t, "GB", country.IsoCode) - assert.Nil(t, db.Close()) + assert.NoError(t, db.Close()) } func TestDecodingUint16IntoInt(t *testing.T) { - reader, err := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") - require.Nil(t, err, "unexpected error while opening database: %v", err) + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err, "unexpected error while opening database: %v", err) var result struct 
{ Uint16 int `maxminddb:"uint16"` } err = reader.Lookup(net.ParseIP("::1.1.1.0"), &result) - require.Nil(t, err) + require.NoError(t, err) - assert.Equal(t, result.Uint16, 100) + assert.Equal(t, 100, result.Uint16) } func TestIpv6inIpv4(t *testing.T) { - reader, err := Open("test-data/test-data/MaxMind-DB-test-ipv4-24.mmdb") - require.Nil(t, err, "unexpected error while opening database: %v", err) + reader, err := Open(testFile("MaxMind-DB-test-ipv4-24.mmdb")) + require.NoError(t, err, "unexpected error while opening database: %v", err) var result TestType err = reader.Lookup(net.ParseIP("2001::"), &result) var emptyResult TestType - assert.Equal(t, result, emptyResult) + assert.Equal(t, emptyResult, result) - expected := errors.New("error looking up '2001::': you attempted to look up an IPv6 address in an IPv4-only database") - assert.Equal(t, err, expected) - assert.Nil(t, reader.Close(), "error on close") + expected := errors.New( + "error looking up '2001::': you attempted to look up an IPv6 address in an IPv4-only database", + ) + assert.Equal(t, expected, err) + assert.NoError(t, reader.Close(), "error on close") } func TestBrokenDoubleDatabase(t *testing.T) { - reader, err := Open("test-data/test-data/GeoIP2-City-Test-Broken-Double-Format.mmdb") - require.Nil(t, err, "unexpected error while opening database: %v", err) + reader, err := Open(testFile("GeoIP2-City-Test-Broken-Double-Format.mmdb")) + require.NoError(t, err, "unexpected error while opening database: %v", err) - var result interface{} + var result any err = reader.Lookup(net.ParseIP("2001:220::"), &result) - expected := newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (float 64 size of 2)") - assert.Equal(t, err, expected) - assert.Nil(t, reader.Close(), "error on close") + expected := newInvalidDatabaseError( + "the MaxMind DB file's data section contains bad data (float 64 size of 2)", + ) + assert.Equal(t, expected, err) + assert.NoError(t, reader.Close(), "error on 
close") } func TestInvalidNodeCountDatabase(t *testing.T) { - _, err := Open("test-data/test-data/GeoIP2-City-Test-Invalid-Node-Count.mmdb") + _, err := Open(testFile("GeoIP2-City-Test-Invalid-Node-Count.mmdb")) expected := newInvalidDatabaseError("the MaxMind DB contains invalid metadata") - assert.Equal(t, err, expected) + assert.Equal(t, expected, err) } func TestMissingDatabase(t *testing.T) { @@ -374,63 +614,81 @@ func TestMissingDatabase(t *testing.T) { func TestNonDatabase(t *testing.T) { reader, err := Open("README.md") assert.Nil(t, reader, "received reader when doing lookups on DB that doesn't exist") - assert.Equal(t, err.Error(), "error opening database: invalid MaxMind DB file") + assert.Equal(t, "error opening database: invalid MaxMind DB file", err.Error()) } func TestDecodingToNonPointer(t *testing.T) { - reader, _ := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) - var recordInterface interface{} - err := reader.Lookup(net.ParseIP("::1.1.1.0"), recordInterface) - assert.Equal(t, err.Error(), "result param must be a pointer") - assert.Nil(t, reader.Close(), "error on close") + var recordInterface any + err = reader.Lookup(net.ParseIP("::1.1.1.0"), recordInterface) + assert.Equal(t, "result param must be a pointer", err.Error()) + assert.NoError(t, reader.Close(), "error on close") } func TestNilLookup(t *testing.T) { - reader, _ := Open("test-data/test-data/MaxMind-DB-test-decoder.mmdb") + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + + var recordInterface any + err = reader.Lookup(nil, recordInterface) + assert.Equal(t, "IP passed to Lookup cannot be nil", err.Error()) + assert.NoError(t, reader.Close(), "error on close") +} + +func TestUsingClosedDatabase(t *testing.T) { + reader, err := Open(testFile("MaxMind-DB-test-decoder.mmdb")) + require.NoError(t, err) + require.NoError(t, reader.Close()) + + var 
recordInterface any + + err = reader.Lookup(nil, recordInterface) + assert.Equal(t, "cannot call Lookup on a closed database", err.Error()) - var recordInterface interface{} - err := reader.Lookup(nil, recordInterface) - assert.Equal(t, err.Error(), "ipAddress passed to Lookup cannot be nil") - assert.Nil(t, reader.Close(), "error on close") + _, err = reader.LookupOffset(nil) + assert.Equal(t, "cannot call LookupOffset on a closed database", err.Error()) + + err = reader.Decode(0, recordInterface) + assert.Equal(t, "cannot call Decode on a closed database", err.Error()) } -func checkMetadata(t *testing.T, reader *Reader, ipVersion uint, recordSize uint) { +func checkMetadata(t *testing.T, reader *Reader, ipVersion, recordSize uint) { metadata := reader.Metadata - assert.Equal(t, metadata.BinaryFormatMajorVersion, uint(2)) + assert.Equal(t, uint(2), metadata.BinaryFormatMajorVersion) - assert.Equal(t, metadata.BinaryFormatMinorVersion, uint(0)) + assert.Equal(t, uint(0), metadata.BinaryFormatMinorVersion) assert.IsType(t, uint(0), metadata.BuildEpoch) - assert.Equal(t, metadata.DatabaseType, "Test") + assert.Equal(t, "Test", metadata.DatabaseType) assert.Equal(t, metadata.Description, map[string]string{ "en": "Test Database", "zh": "Test Database Chinese", }) - assert.Equal(t, metadata.IPVersion, ipVersion) - assert.Equal(t, metadata.Languages, []string{"en", "zh"}) + assert.Equal(t, ipVersion, metadata.IPVersion) + assert.Equal(t, []string{"en", "zh"}, metadata.Languages) if ipVersion == 4 { - assert.Equal(t, metadata.NodeCount, uint(164)) + assert.Equal(t, uint(164), metadata.NodeCount) } else { - assert.Equal(t, metadata.NodeCount, uint(416)) + assert.Equal(t, uint(416), metadata.NodeCount) } - assert.Equal(t, metadata.RecordSize, recordSize) + assert.Equal(t, recordSize, metadata.RecordSize) } func checkIpv4(t *testing.T, reader *Reader) { - for i := uint(0); i < 6; i++ { address := fmt.Sprintf("1.1.1.%d", uint(1)<<i) ip := net.ParseIP(address) var result 
map[string]string err := reader.Lookup(ip, &result) - assert.Nil(t, err, "unexpected error while doing lookup: %v", err) - assert.Equal(t, result, map[string]string{"ip": address}) + assert.NoError(t, err, "unexpected error while doing lookup: %v", err) + assert.Equal(t, map[string]string{"ip": address}, result) } pairs := map[string]string{ "1.1.1.3": "1.1.1.2", @@ -449,8 +707,8 @@ func checkIpv4(t *testing.T, reader *Reader) { var result map[string]string err := reader.Lookup(ip, &result) - assert.Nil(t, err, "unexpected error while doing lookup: %v", err) - assert.Equal(t, result, data) + assert.NoError(t, err, "unexpected error while doing lookup: %v", err) + assert.Equal(t, data, result) } for _, address := range []string{"1.1.1.33", "255.254.253.123"} { @@ -458,21 +716,22 @@ func checkIpv4(t *testing.T, reader *Reader) { var result map[string]string err := reader.Lookup(ip, &result) - assert.Nil(t, err, "unexpected error while doing lookup: %v", err) + assert.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Nil(t, result) } } func checkIpv6(t *testing.T, reader *Reader) { - - subnets := []string{"::1:ffff:ffff", "::2:0:0", - "::2:0:40", "::2:0:50", "::2:0:58"} + subnets := []string{ + "::1:ffff:ffff", "::2:0:0", + "::2:0:40", "::2:0:50", "::2:0:58", + } for _, address := range subnets { var result map[string]string err := reader.Lookup(net.ParseIP(address), &result) - assert.Nil(t, err, "unexpected error while doing lookup: %v", err) - assert.Equal(t, result, map[string]string{"ip": address}) + assert.NoError(t, err, "unexpected error while doing lookup: %v", err) + assert.Equal(t, map[string]string{"ip": address}, result) } pairs := map[string]string{ @@ -490,37 +749,160 @@ func checkIpv6(t *testing.T, reader *Reader) { data := map[string]string{"ip": valueAddress} var result map[string]string err := reader.Lookup(net.ParseIP(keyAddress), &result) - assert.Nil(t, err, "unexpected error while doing lookup: %v", err) - assert.Equal(t, 
result, data) + assert.NoError(t, err, "unexpected error while doing lookup: %v", err) + assert.Equal(t, data, result) } for _, address := range []string{"1.1.1.33", "255.254.253.123", "89fa::"} { var result map[string]string err := reader.Lookup(net.ParseIP(address), &result) - assert.Nil(t, err, "unexpected error while doing lookup: %v", err) + assert.NoError(t, err, "unexpected error while doing lookup: %v", err) assert.Nil(t, result) } } -func BenchmarkMaxMindDB(b *testing.B) { +func BenchmarkOpen(b *testing.B) { + var db *Reader + var err error + for i := 0; i < b.N; i++ { + db, err = Open("GeoLite2-City.mmdb") + if err != nil { + b.Error(err) + } + } + assert.NotNil(b, db) + assert.NoError(b, db.Close(), "error on close") +} + +func BenchmarkInterfaceLookup(b *testing.B) { db, err := Open("GeoLite2-City.mmdb") - assert.Nil(b, err) + require.NoError(b, err) + //nolint:gosec // this is a test r := rand.New(rand.NewSource(time.Now().UnixNano())) - var result interface{} + var result any - ip := make(net.IP, 4, 4) + ip := make(net.IP, 4) for i := 0; i < b.N; i++ { - randomIPv4Address(b, r, ip) + randomIPv4Address(r, ip) err = db.Lookup(ip, &result) - assert.Nil(b, err) + if err != nil { + b.Error(err) + } } - assert.Nil(b, db.Close(), "error on close") + assert.NoError(b, db.Close(), "error on close") +} + +func BenchmarkInterfaceLookupNetwork(b *testing.B) { + db, err := Open("GeoLite2-City.mmdb") + require.NoError(b, err) + + //nolint:gosec // this is a test + r := rand.New(rand.NewSource(time.Now().UnixNano())) + var result any + + ip := make(net.IP, 4) + for i := 0; i < b.N; i++ { + randomIPv4Address(r, ip) + _, _, err = db.LookupNetwork(ip, &result) + if err != nil { + b.Error(err) + } + } + assert.NoError(b, db.Close(), "error on close") +} + +type fullCity struct { + City struct { + GeoNameID uint `maxminddb:"geoname_id"` + Names map[string]string `maxminddb:"names"` + } `maxminddb:"city"` + Continent struct { + Code string `maxminddb:"code"` + GeoNameID 
uint `maxminddb:"geoname_id"` + Names map[string]string `maxminddb:"names"` + } `maxminddb:"continent"` + Country struct { + GeoNameID uint `maxminddb:"geoname_id"` + IsInEuropeanUnion bool `maxminddb:"is_in_european_union"` + IsoCode string `maxminddb:"iso_code"` + Names map[string]string `maxminddb:"names"` + } `maxminddb:"country"` + Location struct { + AccuracyRadius uint16 `maxminddb:"accuracy_radius"` + Latitude float64 `maxminddb:"latitude"` + Longitude float64 `maxminddb:"longitude"` + MetroCode uint `maxminddb:"metro_code"` + TimeZone string `maxminddb:"time_zone"` + } `maxminddb:"location"` + Postal struct { + Code string `maxminddb:"code"` + } `maxminddb:"postal"` + RegisteredCountry struct { + GeoNameID uint `maxminddb:"geoname_id"` + IsInEuropeanUnion bool `maxminddb:"is_in_european_union"` + IsoCode string `maxminddb:"iso_code"` + Names map[string]string `maxminddb:"names"` + } `maxminddb:"registered_country"` + RepresentedCountry struct { + GeoNameID uint `maxminddb:"geoname_id"` + IsInEuropeanUnion bool `maxminddb:"is_in_european_union"` + IsoCode string `maxminddb:"iso_code"` + Names map[string]string `maxminddb:"names"` + Type string `maxminddb:"type"` + } `maxminddb:"represented_country"` + Subdivisions []struct { + GeoNameID uint `maxminddb:"geoname_id"` + IsoCode string `maxminddb:"iso_code"` + Names map[string]string `maxminddb:"names"` + } `maxminddb:"subdivisions"` + Traits struct { + IsAnonymousProxy bool `maxminddb:"is_anonymous_proxy"` + IsSatelliteProvider bool `maxminddb:"is_satellite_provider"` + } `maxminddb:"traits"` +} + +func BenchmarkCityLookup(b *testing.B) { + db, err := Open("GeoLite2-City.mmdb") + require.NoError(b, err) + + //nolint:gosec // this is a test + r := rand.New(rand.NewSource(time.Now().UnixNano())) + var result fullCity + + ip := make(net.IP, 4) + for i := 0; i < b.N; i++ { + randomIPv4Address(r, ip) + err = db.Lookup(ip, &result) + if err != nil { + b.Error(err) + } + } + assert.NoError(b, db.Close(), "error on 
close") +} + +func BenchmarkCityLookupNetwork(b *testing.B) { + db, err := Open("GeoLite2-City.mmdb") + require.NoError(b, err) + + //nolint:gosec // this is a test + r := rand.New(rand.NewSource(time.Now().UnixNano())) + var result fullCity + + ip := make(net.IP, 4) + for i := 0; i < b.N; i++ { + randomIPv4Address(r, ip) + _, _, err = db.LookupNetwork(ip, &result) + if err != nil { + b.Error(err) + } + } + assert.NoError(b, db.Close(), "error on close") } func BenchmarkCountryCode(b *testing.B) { db, err := Open("GeoLite2-City.mmdb") - assert.Nil(b, err) + require.NoError(b, err) type MinCountry struct { Country struct { @@ -528,22 +910,29 @@ func BenchmarkCountryCode(b *testing.B) { } `maxminddb:"country"` } + //nolint:gosec // this is a test r := rand.New(rand.NewSource(0)) var result MinCountry - ip := make(net.IP, 4, 4) + ip := make(net.IP, 4) for i := 0; i < b.N; i++ { - randomIPv4Address(b, r, ip) + randomIPv4Address(r, ip) err = db.Lookup(ip, &result) - assert.Nil(b, err) + if err != nil { + b.Error(err) + } } - assert.Nil(b, db.Close(), "error on close") + assert.NoError(b, db.Close(), "error on close") } -func randomIPv4Address(b *testing.B, r *rand.Rand, ip []byte) { +func randomIPv4Address(r *rand.Rand, ip []byte) { num := r.Uint32() ip[0] = byte(num >> 24) ip[1] = byte(num >> 16) ip[2] = byte(num >> 8) ip[3] = byte(num) } + +func testFile(file string) string { + return filepath.Join("test-data", "test-data", file) +} diff --git a/set_zero_120.go b/set_zero_120.go new file mode 100644 index 0000000000000000000000000000000000000000..33b9dff9d959e668dc450e0aca696eb7dea38255 --- /dev/null +++ b/set_zero_120.go @@ -0,0 +1,10 @@ +//go:build go1.20 +// +build go1.20 + +package maxminddb + +import "reflect" + +func reflectSetZero(v reflect.Value) { + v.SetZero() +} diff --git a/set_zero_pre120.go b/set_zero_pre120.go new file mode 100644 index 0000000000000000000000000000000000000000..6639de73e612cdc3c9f10a3cee988910868f9316 --- /dev/null +++ 
b/set_zero_pre120.go @@ -0,0 +1,10 @@ +//go:build !go1.20 +// +build !go1.20 + +package maxminddb + +import "reflect" + +func reflectSetZero(v reflect.Value) { + v.Set(reflect.Zero(v.Type())) +} diff --git a/traverse.go b/traverse.go index f9b443c0dffc9a3d03a52dc5db68d17cf099e923..657e2c40c67e90b55e109178d8daebe177589566 100644 --- a/traverse.go +++ b/traverse.go @@ -1,6 +1,9 @@ package maxminddb -import "net" +import ( + "fmt" + "net" +) // Internal structure used to keep track of nodes we still need to visit. type netNode struct { @@ -11,77 +14,141 @@ type netNode struct { // Networks represents a set of subnets that we are iterating over. type Networks struct { - reader *Reader - nodes []netNode // Nodes we still have to visit. - lastNode netNode - err error + err error + reader *Reader + nodes []netNode + lastNode netNode + skipAliasedNetworks bool +} + +var ( + allIPv4 = &net.IPNet{IP: make(net.IP, 4), Mask: net.CIDRMask(0, 32)} + allIPv6 = &net.IPNet{IP: make(net.IP, 16), Mask: net.CIDRMask(0, 128)} +) + +// NetworksOption are options for Networks and NetworksWithin. +type NetworksOption func(*Networks) + +// SkipAliasedNetworks is an option for Networks and NetworksWithin that +// makes them not iterate over aliases of the IPv4 subtree in an IPv6 +// database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. +// +// You most likely want to set this. The only reason it isn't the default +// behavior is to provide backwards compatibility to existing users. +func SkipAliasedNetworks(networks *Networks) { + networks.skipAliasedNetworks = true } // Networks returns an iterator that can be used to traverse all networks in // the database. // // Please note that a MaxMind DB may map IPv4 networks into several locations -// in in an IPv6 database. This iterator will iterate over all of these -// locations separately. -func (r *Reader) Networks() *Networks { - s := 4 +// in an IPv6 database. This iterator will iterate over all of these locations +// separately. 
To only iterate over the IPv4 networks once, use the +// SkipAliasedNetworks option. +func (r *Reader) Networks(options ...NetworksOption) *Networks { + var networks *Networks if r.Metadata.IPVersion == 6 { - s = 16 + networks = r.NetworksWithin(allIPv6, options...) + } else { + networks = r.NetworksWithin(allIPv4, options...) + } + + return networks +} + +// NetworksWithin returns an iterator that can be used to traverse all networks +// in the database which are contained in a given network. +// +// Please note that a MaxMind DB may map IPv4 networks into several locations +// in an IPv6 database. This iterator will iterate over all of these locations +// separately. To only iterate over the IPv4 networks once, use the +// SkipAliasedNetworks option. +// +// If the provided network is contained within a network in the database, the +// iterator will iterate over exactly one network, the containing network. +func (r *Reader) NetworksWithin(network *net.IPNet, options ...NetworksOption) *Networks { + if r.Metadata.IPVersion == 4 && network.IP.To4() == nil { + return &Networks{ + err: fmt.Errorf( + "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", + network.String(), + ), + } + } + + networks := &Networks{reader: r} + for _, option := range options { + option(networks) } - return &Networks{ - reader: r, - nodes: []netNode{ - { - ip: make(net.IP, s), - }, + + ip := network.IP + prefixLength, _ := network.Mask.Size() + + if r.Metadata.IPVersion == 6 && len(ip) == net.IPv4len { + if networks.skipAliasedNetworks { + ip = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ip[0], ip[1], ip[2], ip[3]} + } else { + ip = ip.To16() + } + prefixLength += 96 + } + + pointer, bit := r.traverseTree(ip, 0, uint(prefixLength)) + networks.nodes = []netNode{ + { + ip: ip, + bit: uint(bit), + pointer: pointer, }, } + + return networks } // Next prepares the next network for reading with the Network method. 
It // returns true if there is another network to be processed and false if there // are no more networks or if there is an error. func (n *Networks) Next() bool { + if n.err != nil { + return false + } for len(n.nodes) > 0 { node := n.nodes[len(n.nodes)-1] n.nodes = n.nodes[:len(n.nodes)-1] - for { - if node.pointer < n.reader.Metadata.NodeCount { - ipRight := make(net.IP, len(node.ip)) - copy(ipRight, node.ip) - if len(ipRight) <= int(node.bit>>3) { - n.err = newInvalidDatabaseError( - "invalid search tree at %v/%v", ipRight, node.bit) - return false - } - ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) - - rightPointer, err := n.reader.readNode(node.pointer, 1) - if err != nil { - n.err = err - return false - } - - node.bit++ - n.nodes = append(n.nodes, netNode{ - pointer: rightPointer, - ip: ipRight, - bit: node.bit, - }) - - node.pointer, err = n.reader.readNode(node.pointer, 0) - if err != nil { - n.err = err - return false - } - - } else if node.pointer > n.reader.Metadata.NodeCount { + for node.pointer != n.reader.Metadata.NodeCount { + // This skips IPv4 aliases without hardcoding the networks that the writer + // currently aliases. 
+ if n.skipAliasedNetworks && n.reader.ipv4Start != 0 && + node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) { + break + } + + if node.pointer > n.reader.Metadata.NodeCount { n.lastNode = node return true - } else { - break } + ipRight := make(net.IP, len(node.ip)) + copy(ipRight, node.ip) + if len(ipRight) <= int(node.bit>>3) { + n.err = newInvalidDatabaseError( + "invalid search tree at %v/%v", ipRight, node.bit) + return false + } + ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) + + offset := node.pointer * n.reader.nodeOffsetMult + rightPointer := n.reader.nodeReader.readRight(offset) + + node.bit++ + n.nodes = append(n.nodes, netNode{ + pointer: rightPointer, + ip: ipRight, + bit: node.bit, + }) + + node.pointer = n.reader.nodeReader.readLeft(offset) } } @@ -91,14 +158,29 @@ func (n *Networks) Next() bool { // Network returns the current network or an error if there is a problem // decoding the data for the network. It takes a pointer to a result value to // decode the network's data into. -func (n *Networks) Network(result interface{}) (*net.IPNet, error) { +func (n *Networks) Network(result any) (*net.IPNet, error) { + if n.err != nil { + return nil, n.err + } if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { return nil, err } + ip := n.lastNode.ip + prefixLength := int(n.lastNode.bit) + + // We do this because uses of SkipAliasedNetworks expect the IPv4 networks + // to be returned as IPv4 networks. If we are not skipping aliased + // networks, then the user will get IPv4 networks from the ::FFFF:0:0/96 + // network as Go automatically converts those. 
+ if n.skipAliasedNetworks && isInIPv4Subtree(ip) { + ip = ip[12:] + prefixLength -= 96 + } + return &net.IPNet{ - IP: n.lastNode.ip, - Mask: net.CIDRMask(int(n.lastNode.bit), len(n.lastNode.ip)*8), + IP: ip, + Mask: net.CIDRMask(prefixLength, len(ip)*8), }, nil } @@ -106,3 +188,17 @@ func (n *Networks) Network(result interface{}) (*net.IPNet, error) { func (n *Networks) Err() error { return n.err } + +// isInIPv4Subtree returns true if the IP is an IPv6 address in the database's +// IPv4 subtree. +func isInIPv4Subtree(ip net.IP) bool { + if len(ip) != 16 { + return false + } + for i := 0; i < 12; i++ { + if ip[i] != 0 { + return false + } + } + return true +} diff --git a/traverse_test.go b/traverse_test.go index 717eb70eda70ac2fd17f39944b8f6370159316dd..6a4162cba1ba105846b34e659918d654522e7585 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -2,6 +2,7 @@ package maxminddb import ( "fmt" + "net" "testing" "github.com/stretchr/testify/assert" @@ -11,10 +12,11 @@ import ( func TestNetworks(t *testing.T) { for _, recordSize := range []uint{24, 28, 32} { for _, ipVersion := range []uint{4, 6} { - fileName := fmt.Sprintf("test-data/test-data/MaxMind-DB-test-ipv%d-%d.mmdb", ipVersion, recordSize) + fileName := testFile( + fmt.Sprintf("MaxMind-DB-test-ipv%d-%d.mmdb", ipVersion, recordSize), + ) reader, err := Open(fileName) require.Nil(t, err, "unexpected error while opening database: %v", err) - defer reader.Close() n := reader.Networks() for n.Next() { @@ -28,21 +30,272 @@ func TestNetworks(t *testing.T) { ) } assert.Nil(t, n.Err()) + assert.NoError(t, reader.Close()) } } } func TestNetworksWithInvalidSearchTree(t *testing.T) { - reader, err := Open("test-data/test-data/MaxMind-DB-test-broken-search-tree-24.mmdb") + reader, err := Open(testFile("MaxMind-DB-test-broken-search-tree-24.mmdb")) require.Nil(t, err, "unexpected error while opening database: %v", err) - defer reader.Close() n := reader.Networks() for n.Next() { - var record interface{} + var record any 
_, err := n.Network(&record) assert.Nil(t, err) } assert.NotNil(t, n.Err(), "no error received when traversing an broken search tree") - assert.Equal(t, n.Err().Error(), "invalid search tree at 128.128.128.128/32") + assert.Equal(t, "invalid search tree at 128.128.128.128/32", n.Err().Error()) + + assert.NoError(t, reader.Close()) +} + +type networkTest struct { + Network string + Database string + Expected []string + Options []NetworksOption +} + +var tests = []networkTest{ + { + Network: "0.0.0.0/0", + Database: "ipv4", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + }, + }, + { + Network: "1.1.1.1/30", + Database: "ipv4", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + }, + }, + { + Network: "1.1.1.1/32", + Database: "ipv4", + Expected: []string{ + "1.1.1.1/32", + }, + }, + { + Network: "255.255.255.0/24", + Database: "ipv4", + Expected: []string(nil), + }, + { + Network: "1.1.1.1/32", + Database: "mixed", + Expected: []string{ + "1.1.1.1/32", + }, + }, + { + Network: "255.255.255.0/24", + Database: "mixed", + Expected: []string(nil), + }, + { + Network: "::1:ffff:ffff/128", + Database: "ipv6", + Expected: []string{ + "::1:ffff:ffff/128", + }, + Options: []NetworksOption{SkipAliasedNetworks}, + }, + { + Network: "::/0", + Database: "ipv6", + Expected: []string{ + "::1:ffff:ffff/128", + "::2:0:0/122", + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + }, + Options: []NetworksOption{SkipAliasedNetworks}, + }, + { + Network: "::2:0:40/123", + Database: "ipv6", + Expected: []string{ + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + }, + Options: []NetworksOption{SkipAliasedNetworks}, + }, + { + Network: "0:0:0:0:0:ffff:ffff:ff00/120", + Database: "ipv6", + Expected: []string(nil), + }, + { + Network: "0.0.0.0/0", + Database: "mixed", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + }, + }, + { + 
Network: "0.0.0.0/0", + Database: "mixed", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + }, + Options: []NetworksOption{SkipAliasedNetworks}, + }, + { + Network: "::/0", + Database: "mixed", + Expected: []string{ + "::101:101/128", + "::101:102/127", + "::101:104/126", + "::101:108/125", + "::101:110/124", + "::101:120/128", + "::1:ffff:ffff/128", + "::2:0:0/122", + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + "2001:0:101:101::/64", + "2001:0:101:102::/63", + "2001:0:101:104::/62", + "2001:0:101:108::/61", + "2001:0:101:110::/60", + "2001:0:101:120::/64", + "2002:101:101::/48", + "2002:101:102::/47", + "2002:101:104::/46", + "2002:101:108::/45", + "2002:101:110::/44", + "2002:101:120::/48", + }, + }, + { + Network: "::/0", + Database: "mixed", + Expected: []string{ + "1.1.1.1/32", + "1.1.1.2/31", + "1.1.1.4/30", + "1.1.1.8/29", + "1.1.1.16/28", + "1.1.1.32/32", + "::1:ffff:ffff/128", + "::2:0:0/122", + "::2:0:40/124", + "::2:0:50/125", + "::2:0:58/127", + }, + Options: []NetworksOption{SkipAliasedNetworks}, + }, + { + Network: "1.1.1.16/28", + Database: "mixed", + Expected: []string{ + "1.1.1.16/28", + }, + }, + { + Network: "1.1.1.4/30", + Database: "ipv4", + Expected: []string{ + "1.1.1.4/30", + }, + }, +} + +func TestNetworksWithin(t *testing.T) { + for _, v := range tests { + for _, recordSize := range []uint{24, 28, 32} { + fileName := testFile(fmt.Sprintf("MaxMind-DB-test-%s-%d.mmdb", v.Database, recordSize)) + reader, err := Open(fileName) + require.Nil(t, err, "unexpected error while opening database: %v", err) + + _, network, err := net.ParseCIDR(v.Network) + assert.Nil(t, err) + n := reader.NetworksWithin(network, v.Options...) 
+ var innerIPs []string + + for n.Next() { + record := struct { + IP string `maxminddb:"ip"` + }{} + network, err := n.Network(&record) + assert.Nil(t, err) + innerIPs = append(innerIPs, network.String()) + } + + assert.Equal(t, v.Expected, innerIPs) + assert.Nil(t, n.Err()) + + assert.NoError(t, reader.Close()) + } + } +} + +var geoipTests = []networkTest{ + { + Network: "81.2.69.128/26", + Database: "GeoIP2-Country-Test.mmdb", + Expected: []string{ + "81.2.69.142/31", + "81.2.69.144/28", + "81.2.69.160/27", + }, + }, +} + +func TestGeoIPNetworksWithin(t *testing.T) { + for _, v := range geoipTests { + fileName := testFile(v.Database) + reader, err := Open(fileName) + require.Nil(t, err, "unexpected error while opening database: %v", err) + + _, network, err := net.ParseCIDR(v.Network) + assert.Nil(t, err) + n := reader.NetworksWithin(network) + var innerIPs []string + + for n.Next() { + record := struct { + IP string `maxminddb:"ip"` + }{} + network, err := n.Network(&record) + assert.Nil(t, err) + innerIPs = append(innerIPs, network.String()) + } + + assert.Equal(t, v.Expected, innerIPs) + assert.Nil(t, n.Err()) + + assert.NoError(t, reader.Close()) + } } diff --git a/verifier.go b/verifier.go index ace9d35c400ab8c9aa2f323e577f493329b3bf5b..b14b3e48798f19c022fc5eead34113f672d5aa6d 100644 --- a/verifier.go +++ b/verifier.go @@ -1,6 +1,9 @@ package maxminddb -import "reflect" +import ( + "reflect" + "runtime" +) type verifier struct { reader *Reader @@ -15,7 +18,9 @@ func (r *Reader) Verify() error { return err } - return v.verifyDatabase() + err := v.verifyDatabase() + runtime.KeepAlive(v.reader) + return err } func (v *verifier) verifyMetadata() error { @@ -132,23 +137,34 @@ func (v *verifier) verifyDataSection(offsets map[uint]bool) error { var offset uint bufferLen := uint(len(decoder.buffer)) for offset < bufferLen { - var data interface{} + var data any rv := reflect.ValueOf(&data) newOffset, err := decoder.decode(offset, rv, 0) if err != nil { - return 
newInvalidDatabaseError("received decoding error (%v) at offset of %v", err, offset) + return newInvalidDatabaseError( + "received decoding error (%v) at offset of %v", + err, + offset, + ) } if newOffset <= offset { - return newInvalidDatabaseError("data section offset unexpectedly went from %v to %v", offset, newOffset) + return newInvalidDatabaseError( + "data section offset unexpectedly went from %v to %v", + offset, + newOffset, + ) } pointer := offset - if _, ok := offsets[pointer]; ok { - delete(offsets, pointer) - } else { - return newInvalidDatabaseError("found data (%v) at %v that the search tree does not point to", data, pointer) + if _, ok := offsets[pointer]; !ok { + return newInvalidDatabaseError( + "found data (%v) at %v that the search tree does not point to", + data, + pointer, + ) } + delete(offsets, pointer) offset = newOffset } @@ -173,8 +189,8 @@ func (v *verifier) verifyDataSection(offsets map[uint]bool) error { func testError( field string, - expected interface{}, - actual interface{}, + expected any, + actual any, ) error { return newInvalidDatabaseError( "%v - Expected: %v Actual: %v", diff --git a/verifier_test.go b/verifier_test.go index 8fc6bd446945f68eac3dbdb591fafccee020bfa6..dfdbd6361d3b606e1a4f7433d20b6365a92c9bbb 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -1,50 +1,61 @@ package maxminddb import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestVerifyOnGoodDatabases(t *testing.T) { databases := []string{ - "test-data/test-data/GeoIP2-Anonymous-IP-Test.mmdb", - "test-data/test-data/GeoIP2-City-Test.mmdb", - "test-data/test-data/GeoIP2-Connection-Type-Test.mmdb", - "test-data/test-data/GeoIP2-Country-Test.mmdb", - "test-data/test-data/GeoIP2-Domain-Test.mmdb", - "test-data/test-data/GeoIP2-ISP-Test.mmdb", - "test-data/test-data/GeoIP2-Precision-City-Test.mmdb", - "test-data/test-data/MaxMind-DB-no-ipv4-search-tree.mmdb", - 
"test-data/test-data/MaxMind-DB-string-value-entries.mmdb", - "test-data/test-data/MaxMind-DB-test-decoder.mmdb", - "test-data/test-data/MaxMind-DB-test-ipv4-24.mmdb", - "test-data/test-data/MaxMind-DB-test-ipv4-28.mmdb", - "test-data/test-data/MaxMind-DB-test-ipv4-32.mmdb", - "test-data/test-data/MaxMind-DB-test-ipv6-24.mmdb", - "test-data/test-data/MaxMind-DB-test-ipv6-28.mmdb", - "test-data/test-data/MaxMind-DB-test-ipv6-32.mmdb", - "test-data/test-data/MaxMind-DB-test-mixed-24.mmdb", - "test-data/test-data/MaxMind-DB-test-mixed-28.mmdb", - "test-data/test-data/MaxMind-DB-test-mixed-32.mmdb", - "test-data/test-data/MaxMind-DB-test-nested.mmdb", + "GeoIP2-Anonymous-IP-Test.mmdb", + "GeoIP2-City-Test.mmdb", + "GeoIP2-Connection-Type-Test.mmdb", + "GeoIP2-Country-Test.mmdb", + "GeoIP2-Domain-Test.mmdb", + "GeoIP2-ISP-Test.mmdb", + "GeoIP2-Precision-Enterprise-Test.mmdb", + "MaxMind-DB-no-ipv4-search-tree.mmdb", + "MaxMind-DB-string-value-entries.mmdb", + "MaxMind-DB-test-decoder.mmdb", + "MaxMind-DB-test-ipv4-24.mmdb", + "MaxMind-DB-test-ipv4-28.mmdb", + "MaxMind-DB-test-ipv4-32.mmdb", + "MaxMind-DB-test-ipv6-24.mmdb", + "MaxMind-DB-test-ipv6-28.mmdb", + "MaxMind-DB-test-ipv6-32.mmdb", + "MaxMind-DB-test-mixed-24.mmdb", + "MaxMind-DB-test-mixed-28.mmdb", + "MaxMind-DB-test-mixed-32.mmdb", + "MaxMind-DB-test-nested.mmdb", } for _, database := range databases { - reader, err := Open(database) - assert.Nil(t, err) - assert.Nil(t, reader.Verify(), "Received error (%v) when verifying %v", err, database) + t.Run(database, func(t *testing.T) { + reader, err := Open(testFile(database)) + require.NoError(t, err) + + assert.NoError( + t, + reader.Verify(), + "Received error (%v) when verifying %v", + err, + database, + ) + }) } } func TestVerifyOnBrokenDatabases(t *testing.T) { databases := []string{ - "test-data/test-data/GeoIP2-City-Test-Broken-Double-Format.mmdb", - "test-data/test-data/MaxMind-DB-test-broken-pointers-24.mmdb", - 
"test-data/test-data/MaxMind-DB-test-broken-search-tree-24.mmdb", + "GeoIP2-City-Test-Broken-Double-Format.mmdb", + "MaxMind-DB-test-broken-pointers-24.mmdb", + "MaxMind-DB-test-broken-search-tree-24.mmdb", } for _, database := range databases { - reader, err := Open(database) + reader, err := Open(testFile(database)) assert.Nil(t, err) assert.NotNil(t, reader.Verify(), "Did not receive expected error when verifying %v", database,