Deps update, gopkg.lock fix, flagger initialization fix.

This commit is contained in:
Stanislav Nikitin 2019-02-25 16:29:48 +05:00
parent 23510e22a0
commit afb076d116
406 changed files with 55802 additions and 18666 deletions
Gopkg.lockGopkg.toml
config
context
parsers/default
vendor
github.com/sirupsen/logrus
golang.org/x/crypto

62
Gopkg.lock generated

@ -2,47 +2,71 @@
[[projects]]
digest = "1:6a874e3ddfb9db2b42bd8c85b6875407c702fa868eed20634ff489bc896ccfd3"
name = "github.com/konsorten/go-windows-terminal-sequences"
packages = ["."]
pruneopts = ""
revision = "5c8c8bd35d3832f5d134ae1e1e375b69a4d25242"
version = "v1.0.1"
[[projects]]
branch = "master"
digest = "1:58942e5575baa224511b2cdda65b643bae65cea4e205a779198072205b9d91f7"
name = "github.com/pztrn/mogrus"
packages = ["."]
pruneopts = ""
revision = "2d9bba232129a20da87a9a9ab404b0f33c218622"
[[projects]]
digest = "1:9a3c631555e0351fdc4e696577bb63afd90c399d782a8462dba9d100d7021db3"
name = "github.com/sirupsen/logrus"
packages = ["."]
revision = "d682213848ed68c0a260ca37d6dd5ace8423f5ba"
version = "v1.0.4"
pruneopts = ""
revision = "e1e72e9de974bd926e5c56f83753fba2df402ce5"
version = "v1.3.0"
[[projects]]
branch = "master"
digest = "1:e08724b3e77d1ce400cff5185a4532f01cacc8059dd5a619fdf4a91e79374dfa"
name = "gitlab.com/pztrn/flagger"
packages = ["."]
pruneopts = ""
revision = "d429d7149cc92b7eaa7c98ab3d80b6d5c2e44046"
[[projects]]
branch = "master"
digest = "1:d0f4eb7abce3fbd3f0dcbbc03ffe18464846afd34c815928d2ae11c1e5aded04"
name = "golang.org/x/crypto"
packages = ["ssh/terminal"]
revision = "1875d0a70c90e57f11972aefd42276df65e895b9"
pruneopts = ""
revision = "ffb98f73852f696ea2bb21a617a5c4b3e067a439"
[[projects]]
branch = "master"
digest = "1:d8c7c5e28d31494f02b25867de6827adaa90005f0ff3407f3336751fc0bc7432"
name = "golang.org/x/sys"
packages = [
"unix",
"windows"
"windows",
]
revision = "8f27ce8a604014414f8dfffc25cbcde83a3f2216"
pruneopts = ""
revision = "cc5685c2db1239775905f3911f0067c0fa74762f"
[[projects]]
branch = "v2"
digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d"
name = "gopkg.in/yaml.v2"
packages = ["."]
revision = "d670f9405373e636a5a2765eea47fac0c9bc91a4"
[[projects]]
branch = "master"
name = "source.pztrn.name/golibs/flagger"
packages = ["."]
revision = "c9b4726e1787adb01126cdf9a7932e2bccbb51a7"
[[projects]]
branch = "master"
name = "source.pztrn.name/golibs/mogrus"
packages = ["."]
revision = "2e3a9d38c8e4b0036f714f88ef0fe9fc14b4c6b8"
pruneopts = ""
revision = "51d6538a90f86fe93ac480b35f37b2be17fef232"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "2b99cf1ed594e6ce9e2bc8bf1be8cbe0018db897ef6dc6c3e59ae5030099ef48"
input-imports = [
"github.com/pztrn/mogrus",
"gitlab.com/pztrn/flagger",
"gopkg.in/yaml.v2",
]
solver-name = "gps-cdcl"
solver-version = 1

@ -26,8 +26,8 @@
[[constraint]]
branch = "master"
name = "source.pztrn.name/golibs/flagger"
name = "gitlab.com/pztrn/flagger"
[[constraint]]
branch = "master"
name = "source.pztrn.name/golibs/mogrus"
name = "github.com/pztrn/mogrus"

@ -29,8 +29,8 @@ import (
"gitlab.com/pztrn/opensaps/config/struct"
// other
"gitlab.com/pztrn/flagger"
"gopkg.in/yaml.v2"
"source.pztrn.name/golibs/flagger"
)
type Configuration struct{}

@ -30,8 +30,8 @@ import (
"gitlab.com/pztrn/opensaps/slack/message"
// other
"source.pztrn.name/golibs/flagger"
"source.pztrn.name/golibs/mogrus"
"github.com/pztrn/mogrus"
"gitlab.com/pztrn/flagger"
)
type Context struct {
@ -52,7 +52,7 @@ func (c *Context) Initialize() {
c.Log = l.CreateLogger("opensaps")
c.Log.CreateOutput("stdout", os.Stdout, true, "debug")
c.Flagger = flagger.New(c.Log)
c.Flagger = flagger.New("opensaps", flagger.LoggerInterface(c.Log))
c.Flagger.Initialize()
}

@ -19,8 +19,8 @@ package defaultparser
import (
// local
"gitlab.com/pztrn/misc/opensaps/context"
"gitlab.com/pztrn/misc/opensaps/parsers/interface"
"gitlab.com/pztrn/opensaps/context"
"gitlab.com/pztrn/opensaps/parsers/interface"
)
var (

@ -1 +1,2 @@
logrus
vendor

@ -1,15 +1,52 @@
language: go
go:
- 1.6.x
- 1.7.x
- 1.8.x
- tip
go_import_path: github.com/sirupsen/logrus
env:
- GOMAXPROCS=4 GORACE=halt_on_error=1
install:
- go get github.com/stretchr/testify/assert
- go get gopkg.in/gemnasium/logrus-airbrake-hook.v2
- go get golang.org/x/sys/unix
- go get golang.org/x/sys/windows
script:
- go test -race -v ./...
matrix:
include:
- go: 1.10.x
install:
- go get github.com/stretchr/testify/assert
- go get golang.org/x/crypto/ssh/terminal
- go get golang.org/x/sys/unix
- go get golang.org/x/sys/windows
script:
- go test -race -v ./...
- go: 1.11.x
env: GO111MODULE=on
install:
- go mod download
script:
- go test -race -v ./...
- go: 1.11.x
env: GO111MODULE=off
install:
- go get github.com/stretchr/testify/assert
- go get golang.org/x/crypto/ssh/terminal
- go get golang.org/x/sys/unix
- go get golang.org/x/sys/windows
script:
- go test -race -v ./...
- go: 1.10.x
install:
- go get github.com/stretchr/testify/assert
- go get golang.org/x/crypto/ssh/terminal
- go get golang.org/x/sys/unix
- go get golang.org/x/sys/windows
script:
- go test -race -v -tags appengine ./...
- go: 1.11.x
env: GO111MODULE=on
install:
- go mod download
script:
- go test -race -v -tags appengine ./...
- go: 1.11.x
env: GO111MODULE=off
install:
- go get github.com/stretchr/testify/assert
- go get golang.org/x/crypto/ssh/terminal
- go get golang.org/x/sys/unix
- go get golang.org/x/sys/windows
script:
- go test -race -v -tags appengine ./...

@ -1,3 +1,50 @@
# 1.2.0
This new release introduces:
* A new method `SetReportCaller` in the `Logger` to enable the file, line and calling function from which the trace has been issued
* A new trace level named `Trace` whose level is below `Debug`
* A configurable exit function to be called upon a Fatal trace
* The `Level` object now implements `encoding.TextUnmarshaler` interface
# 1.1.1
This is a bug fix release.
* fix the build break on Solaris
* don't drop a whole trace in JSONFormatter when a field param is a function pointer which can not be serialized
# 1.1.0
This new release introduces:
* several fixes:
* a fix for a race condition on entry formatting
* proper cleanup of previously used entries before putting them back in the pool
* the extra new line at the end of message in text formatter has been removed
* a new global public API to check if a level is activated: IsLevelEnabled
* the following methods have been added to the Logger object
* IsLevelEnabled
* SetFormatter
* SetOutput
* ReplaceHooks
* introduction of go module
* an indent configuration for the json formatter
* output colour support for windows
* the field sort function is now configurable for text formatter
* the CLICOLOR and CLICOLOR\_FORCE environment variable support in text formater
# 1.0.6
This new release introduces:
* a new api WithTime which allows to easily force the time of the log entry
which is mostly useful for logger wrapper
* a fix reverting the immutability of the entry given as parameter to the hooks
a new configuration field of the json formatter in order to put all the fields
in a nested dictionnary
* a new SetOutput method in the Logger
* a new configuration of the textformatter to configure the name of the default keys
* a new configuration of the text formatter to disable the level truncation
# 1.0.5
* Fix hooks race (#707)
* Fix panic deadlock (#695)
# 1.0.4
* Fix race when adding hooks (#612)

@ -56,8 +56,39 @@ time="2015-03-26T01:27:38-04:00" level=warning msg="The group's number increased
time="2015-03-26T01:27:38-04:00" level=debug msg="Temperature changes" temperature=-4
time="2015-03-26T01:27:38-04:00" level=panic msg="It's over 9000!" animal=orca size=9009
time="2015-03-26T01:27:38-04:00" level=fatal msg="The ice breaks!" err=&{0x2082280c0 map[animal:orca size:9009] 2015-03-26 01:27:38.441574009 -0400 EDT panic It's over 9000!} number=100 omg=true
exit status 1
```
To ensure this behaviour even if a TTY is attached, set your formatter as follows:
```go
log.SetFormatter(&log.TextFormatter{
DisableColors: true,
FullTimestamp: true,
})
```
#### Logging Method Name
If you wish to add the calling method as a field, instruct the logger via:
```go
log.SetReportCaller(true)
```
This adds the caller as 'method' like so:
```json
{"animal":"penguin","level":"fatal","method":"github.com/sirupsen/arcticcreatures.migrate","msg":"a penguin swims by",
"time":"2014-03-10 19:57:38.562543129 -0400 EDT"}
```
```text
time="2015-03-26T01:27:38-04:00" level=fatal method=github.com/sirupsen/arcticcreatures.migrate msg="a penguin swims by" animal=penguin
```
Note that this does add measurable overhead - the cost will depend on the version of Go, but is
between 20 and 40% in recent tests with 1.6 and 1.7. You can validate this in your
environment via benchmarks:
```
go test -bench=.*CallerTracing
```
#### Case-sensitivity
@ -220,7 +251,7 @@ Logrus comes with [built-in hooks](hooks/). Add those, or your custom hook, in
```go
import (
log "github.com/sirupsen/logrus"
"gopkg.in/gemnasium/logrus-airbrake-hook.v2" // the package is named "aibrake"
"gopkg.in/gemnasium/logrus-airbrake-hook.v2" // the package is named "airbrake"
logrus_syslog "github.com/sirupsen/logrus/hooks/syslog"
"log/syslog"
)
@ -241,64 +272,15 @@ func init() {
```
Note: Syslog hook also support connecting to local syslog (Ex. "/dev/log" or "/var/run/syslog" or "/var/run/log"). For the detail, please check the [syslog hook README](hooks/syslog/README.md).
| Hook | Description |
| ----- | ----------- |
| [Airbrake "legacy"](https://github.com/gemnasium/logrus-airbrake-legacy-hook) | Send errors to an exception tracking service compatible with the Airbrake API V2. Uses [`airbrake-go`](https://github.com/tobi/airbrake-go) behind the scenes. |
| [Airbrake](https://github.com/gemnasium/logrus-airbrake-hook) | Send errors to the Airbrake API V3. Uses the official [`gobrake`](https://github.com/airbrake/gobrake) behind the scenes. |
| [Amazon Kinesis](https://github.com/evalphobia/logrus_kinesis) | Hook for logging to [Amazon Kinesis](https://aws.amazon.com/kinesis/) |
| [Amqp-Hook](https://github.com/vladoatanasov/logrus_amqp) | Hook for logging to Amqp broker (Like RabbitMQ) |
| [AzureTableHook](https://github.com/kpfaulkner/azuretablehook/) | Hook for logging to Azure Table Storage|
| [Bugsnag](https://github.com/Shopify/logrus-bugsnag/blob/master/bugsnag.go) | Send errors to the Bugsnag exception tracking service. |
| [DeferPanic](https://github.com/deferpanic/dp-logrus) | Hook for logging to DeferPanic |
| [Discordrus](https://github.com/kz/discordrus) | Hook for logging to [Discord](https://discordapp.com/) |
| [ElasticSearch](https://github.com/sohlich/elogrus) | Hook for logging to ElasticSearch|
| [Firehose](https://github.com/beaubrewer/logrus_firehose) | Hook for logging to [Amazon Firehose](https://aws.amazon.com/kinesis/firehose/)
| [Fluentd](https://github.com/evalphobia/logrus_fluent) | Hook for logging to fluentd |
| [Go-Slack](https://github.com/multiplay/go-slack) | Hook for logging to [Slack](https://slack.com) |
| [Graylog](https://github.com/gemnasium/logrus-graylog-hook) | Hook for logging to [Graylog](http://graylog2.org/) |
| [Hiprus](https://github.com/nubo/hiprus) | Send errors to a channel in hipchat. |
| [Honeybadger](https://github.com/agonzalezro/logrus_honeybadger) | Hook for sending exceptions to Honeybadger |
| [InfluxDB](https://github.com/Abramovic/logrus_influxdb) | Hook for logging to influxdb |
| [Influxus](http://github.com/vlad-doru/influxus) | Hook for concurrently logging to [InfluxDB](http://influxdata.com/) |
| [Journalhook](https://github.com/wercker/journalhook) | Hook for logging to `systemd-journald` |
| [KafkaLogrus](https://github.com/tracer0tong/kafkalogrus) | Hook for logging to Kafka |
| [LFShook](https://github.com/rifflock/lfshook) | Hook for logging to the local filesystem |
| [Logbeat](https://github.com/macandmia/logbeat) | Hook for logging to [Opbeat](https://opbeat.com/) |
| [Logentries](https://github.com/jcftang/logentriesrus) | Hook for logging to [Logentries](https://logentries.com/) |
| [Logentrus](https://github.com/puddingfactory/logentrus) | Hook for logging to [Logentries](https://logentries.com/) |
| [Logmatic.io](https://github.com/logmatic/logmatic-go) | Hook for logging to [Logmatic.io](http://logmatic.io/) |
| [Logrusly](https://github.com/sebest/logrusly) | Send logs to [Loggly](https://www.loggly.com/) |
| [Logstash](https://github.com/bshuster-repo/logrus-logstash-hook) | Hook for logging to [Logstash](https://www.elastic.co/products/logstash) |
| [Mail](https://github.com/zbindenren/logrus_mail) | Hook for sending exceptions via mail |
| [Mattermost](https://github.com/shuLhan/mattermost-integration/tree/master/hooks/logrus) | Hook for logging to [Mattermost](https://mattermost.com/) |
| [Mongodb](https://github.com/weekface/mgorus) | Hook for logging to mongodb |
| [NATS-Hook](https://github.com/rybit/nats_logrus_hook) | Hook for logging to [NATS](https://nats.io) |
| [Octokit](https://github.com/dorajistyle/logrus-octokit-hook) | Hook for logging to github via octokit |
| [Papertrail](https://github.com/polds/logrus-papertrail-hook) | Send errors to the [Papertrail](https://papertrailapp.com) hosted logging service via UDP. |
| [PostgreSQL](https://github.com/gemnasium/logrus-postgresql-hook) | Send logs to [PostgreSQL](http://postgresql.org) |
| [Promrus](https://github.com/weaveworks/promrus) | Expose number of log messages as [Prometheus](https://prometheus.io/) metrics |
| [Pushover](https://github.com/toorop/logrus_pushover) | Send error via [Pushover](https://pushover.net) |
| [Raygun](https://github.com/squirkle/logrus-raygun-hook) | Hook for logging to [Raygun.io](http://raygun.io/) |
| [Redis-Hook](https://github.com/rogierlommers/logrus-redis-hook) | Hook for logging to a ELK stack (through Redis) |
| [Rollrus](https://github.com/heroku/rollrus) | Hook for sending errors to rollbar |
| [Scribe](https://github.com/sagar8192/logrus-scribe-hook) | Hook for logging to [Scribe](https://github.com/facebookarchive/scribe)|
| [Sentry](https://github.com/evalphobia/logrus_sentry) | Send errors to the Sentry error logging and aggregation service. |
| [Slackrus](https://github.com/johntdyer/slackrus) | Hook for Slack chat. |
| [Stackdriver](https://github.com/knq/sdhook) | Hook for logging to [Google Stackdriver](https://cloud.google.com/logging/) |
| [Sumorus](https://github.com/doublefree/sumorus) | Hook for logging to [SumoLogic](https://www.sumologic.com/)|
| [Syslog](https://github.com/sirupsen/logrus/blob/master/hooks/syslog/syslog.go) | Send errors to remote syslog server. Uses standard library `log/syslog` behind the scenes. |
| [Syslog TLS](https://github.com/shinji62/logrus-syslog-ng) | Send errors to remote syslog server with TLS support. |
| [Telegram](https://github.com/rossmcdonald/telegram_hook) | Hook for logging errors to [Telegram](https://telegram.org/) |
| [TraceView](https://github.com/evalphobia/logrus_appneta) | Hook for logging to [AppNeta TraceView](https://www.appneta.com/products/traceview/) |
| [Typetalk](https://github.com/dragon3/logrus-typetalk-hook) | Hook for logging to [Typetalk](https://www.typetalk.in/) |
| [logz.io](https://github.com/ripcurld00d/logrus-logzio-hook) | Hook for logging to [logz.io](https://logz.io), a Log as a Service using Logstash |
| [SQS-Hook](https://github.com/tsarpaul/logrus_sqs) | Hook for logging to [Amazon Simple Queue Service (SQS)](https://aws.amazon.com/sqs/) |
A list of currently known of service hook can be found in this wiki [page](https://github.com/sirupsen/logrus/wiki/Hooks)
#### Level logging
Logrus has six logging levels: Debug, Info, Warning, Error, Fatal and Panic.
Logrus has seven logging levels: Trace, Debug, Info, Warning, Error, Fatal and Panic.
```go
log.Trace("Something very low level.")
log.Debug("Useful debugging information.")
log.Info("Something noteworthy happened!")
log.Warn("You should probably take a look at this.")
@ -370,6 +352,8 @@ The built-in logging formatters are:
field to `true`. To force no colored output even if there is a TTY set the
`DisableColors` field to `true`. For Windows, see
[github.com/mattn/go-colorable](https://github.com/mattn/go-colorable).
* When colors are enabled, levels are truncated to 4 characters by default. To disable
truncation set the `DisableLevelTruncation` field to `true`.
* All options are listed in the [generated docs](https://godoc.org/github.com/sirupsen/logrus#TextFormatter).
* `logrus.JSONFormatter`. Logs fields as JSON.
* All options are listed in the [generated docs](https://godoc.org/github.com/sirupsen/logrus#JSONFormatter).
@ -377,6 +361,7 @@ The built-in logging formatters are:
Third party logging formatters:
* [`FluentdFormatter`](https://github.com/joonix/log). Formats entries that can be parsed by Kubernetes and Google Container Engine.
* [`GELF`](https://github.com/fabienm/go-logrus-formatters). Formats entries so they comply to Graylog's [GELF 1.1 specification](http://docs.graylog.org/en/2.4/pages/gelf.html).
* [`logstash`](https://github.com/bshuster-repo/logrus-logstash-hook). Logs fields as [Logstash](http://logstash.net) Events.
* [`prefixed`](https://github.com/x-cray/logrus-prefixed-formatter). Displays log entry source along with alternative layout.
* [`zalgo`](https://github.com/aybabtme/logzalgo). Invoking the P͉̫o̳̼̊w̖͈̰͎e̬͔̭͂r͚̼̹̲ ̫͓͉̳͈ō̠͕͖̚f̝͍̠ ͕̲̞͖͑Z̖̫̤̫ͪa͉̬͈̗l͖͎g̳̥o̰̥̅!̣͔̲̻͊̄ ̙̘̦̹̦.
@ -493,7 +478,7 @@ logrus.RegisterExitHandler(handler)
#### Thread safety
By default Logger is protected by mutex for concurrent writes, this mutex is invoked when calling hooks and writing logs.
By default, Logger is protected by a mutex for concurrent writes. The mutex is held when calling hooks and writing logs.
If you are sure such locking is not needed, you can call logger.SetNoLock() to disable the locking.
Situation when locking is not needed includes:

@ -6,6 +6,8 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
)
@ -19,6 +21,9 @@ func TestRegister(t *testing.T) {
}
func TestHandler(t *testing.T) {
testprog := testprogleader
testprog = append(testprog, getPackage()...)
testprog = append(testprog, testprogtrailer...)
tempDir, err := ioutil.TempDir("", "test_handler")
if err != nil {
log.Fatalf("can't create temp dir. %q", err)
@ -47,13 +52,24 @@ func TestHandler(t *testing.T) {
}
}
var testprog = []byte(`
// getPackage returns the name of the current package, which makes running this
// test in a fork simpler
func getPackage() []byte {
pc, _, _, _ := runtime.Caller(0)
fullFuncName := runtime.FuncForPC(pc).Name()
idx := strings.LastIndex(fullFuncName, ".")
return []byte(fullFuncName[:idx]) // trim off function details
}
var testprogleader = []byte(`
// Test program for atexit, gets output file and data as arguments and writes
// data to output file in atexit handler.
package main
import (
"github.com/sirupsen/logrus"
"`)
var testprogtrailer = []byte(
`"
"flag"
"fmt"
"io/ioutil"

@ -4,11 +4,30 @@ import (
"bytes"
"fmt"
"os"
"reflect"
"runtime"
"strings"
"sync"
"time"
)
var bufferPool *sync.Pool
var (
bufferPool *sync.Pool
// qualified package name, cached at first use
logrusPackage string
// Positions in the call stack when tracing to report the calling method
minimumCallerDepth int
// Used for caller information initialisation
callerInitOnce sync.Once
)
const (
maximumCallerDepth int = 25
knownLogrusFrames int = 4
)
func init() {
bufferPool = &sync.Pool{
@ -16,15 +35,18 @@ func init() {
return new(bytes.Buffer)
},
}
// start at the bottom of the stack before the package-name cache is primed
minimumCallerDepth = 1
}
// Defines the key when adding errors using WithError.
var ErrorKey = "error"
// An entry is the final or intermediate Logrus logging entry. It contains all
// the fields passed with WithField{,s}. It's finally logged when Debug, Info,
// Warn, Error, Fatal or Panic is called on it. These objects can be reused and
// passed around as much as you wish to avoid field duplication.
// the fields passed with WithField{,s}. It's finally logged when Trace, Debug,
// Info, Warn, Error, Fatal or Panic is called on it. These objects can be
// reused and passed around as much as you wish to avoid field duplication.
type Entry struct {
Logger *Logger
@ -34,22 +56,28 @@ type Entry struct {
// Time at which the log entry was created
Time time.Time
// Level the log entry was logged at: Debug, Info, Warn, Error, Fatal or Panic
// Level the log entry was logged at: Trace, Debug, Info, Warn, Error, Fatal or Panic
// This field will be set on entry firing and the value will be equal to the one in Logger struct field.
Level Level
// Message passed to Debug, Info, Warn, Error, Fatal or Panic
// Calling method, with package name
Caller *runtime.Frame
// Message passed to Trace, Debug, Info, Warn, Error, Fatal or Panic
Message string
// When formatter is called in entry.log(), an Buffer may be set to entry
// When formatter is called in entry.log(), a Buffer may be set to entry
Buffer *bytes.Buffer
// err may contain a field formatting error
err string
}
func NewEntry(logger *Logger) *Entry {
return &Entry{
Logger: logger,
// Default is three fields, give a little extra room
Data: make(Fields, 5),
// Default is three fields, plus one optional. Give a little extra room.
Data: make(Fields, 6),
}
}
@ -80,46 +108,117 @@ func (entry *Entry) WithFields(fields Fields) *Entry {
for k, v := range entry.Data {
data[k] = v
}
fieldErr := entry.err
for k, v := range fields {
data[k] = v
isErrField := false
if t := reflect.TypeOf(v); t != nil {
switch t.Kind() {
case reflect.Func:
isErrField = true
case reflect.Ptr:
isErrField = t.Elem().Kind() == reflect.Func
}
}
if isErrField {
tmp := fmt.Sprintf("can not add field %q", k)
if fieldErr != "" {
fieldErr = entry.err + ", " + tmp
} else {
fieldErr = tmp
}
} else {
data[k] = v
}
}
return &Entry{Logger: entry.Logger, Data: data}
return &Entry{Logger: entry.Logger, Data: data, Time: entry.Time, err: fieldErr}
}
// Overrides the time of the Entry.
func (entry *Entry) WithTime(t time.Time) *Entry {
return &Entry{Logger: entry.Logger, Data: entry.Data, Time: t, err: entry.err}
}
// getPackageName reduces a fully qualified function name to the package name
// There really ought to be to be a better way...
func getPackageName(f string) string {
for {
lastPeriod := strings.LastIndex(f, ".")
lastSlash := strings.LastIndex(f, "/")
if lastPeriod > lastSlash {
f = f[:lastPeriod]
} else {
break
}
}
return f
}
// getCaller retrieves the name of the first non-logrus calling function
func getCaller() *runtime.Frame {
// Restrict the lookback frames to avoid runaway lookups
pcs := make([]uintptr, maximumCallerDepth)
depth := runtime.Callers(minimumCallerDepth, pcs)
frames := runtime.CallersFrames(pcs[:depth])
// cache this package's fully-qualified name
callerInitOnce.Do(func() {
logrusPackage = getPackageName(runtime.FuncForPC(pcs[0]).Name())
// now that we have the cache, we can skip a minimum count of known-logrus functions
// XXX this is dubious, the number of frames may vary store an entry in a logger interface
minimumCallerDepth = knownLogrusFrames
})
for f, again := frames.Next(); again; f, again = frames.Next() {
pkg := getPackageName(f.Function)
// If the caller isn't part of this package, we're done
if pkg != logrusPackage {
return &f
}
}
// if we got here, we failed to find the caller's context
return nil
}
func (entry Entry) HasCaller() (has bool) {
return entry.Logger != nil &&
entry.Logger.ReportCaller &&
entry.Caller != nil
}
// This function is not declared with a pointer value because otherwise
// race conditions will occur when using multiple goroutines
func (entry Entry) log(level Level, msg string) {
var buffer *bytes.Buffer
entry.Time = time.Now()
// Default to now, but allow users to override if they want.
//
// We don't have to worry about polluting future calls to Entry#log()
// with this assignment because this function is declared with a
// non-pointer receiver.
if entry.Time.IsZero() {
entry.Time = time.Now()
}
entry.Level = level
entry.Message = msg
entry.Logger.mu.Lock()
err := entry.Logger.Hooks.Fire(level, &entry)
entry.Logger.mu.Unlock()
if err != nil {
entry.Logger.mu.Lock()
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
entry.Logger.mu.Unlock()
if entry.Logger.ReportCaller {
entry.Caller = getCaller()
}
entry.fireHooks()
buffer = bufferPool.Get().(*bytes.Buffer)
buffer.Reset()
defer bufferPool.Put(buffer)
entry.Buffer = buffer
serialized, err := entry.Logger.Formatter.Format(&entry)
entry.write()
entry.Buffer = nil
if err != nil {
entry.Logger.mu.Lock()
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
entry.Logger.mu.Unlock()
} else {
entry.Logger.mu.Lock()
_, err = entry.Logger.Out.Write(serialized)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
entry.Logger.mu.Unlock()
}
// To avoid Entry#log() returning a value that only would make sense for
// panic() to use in Entry#Panic(), we avoid the allocation by checking
@ -129,26 +228,53 @@ func (entry Entry) log(level Level, msg string) {
}
}
func (entry *Entry) Debug(args ...interface{}) {
if entry.Logger.level() >= DebugLevel {
entry.log(DebugLevel, fmt.Sprint(args...))
func (entry *Entry) fireHooks() {
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
err := entry.Logger.Hooks.Fire(entry.Level, entry)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to fire hook: %v\n", err)
}
}
func (entry *Entry) write() {
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
serialized, err := entry.Logger.Formatter.Format(entry)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
} else {
_, err = entry.Logger.Out.Write(serialized)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
}
}
func (entry *Entry) Log(level Level, args ...interface{}) {
if entry.Logger.IsLevelEnabled(level) {
entry.log(level, fmt.Sprint(args...))
}
}
func (entry *Entry) Trace(args ...interface{}) {
entry.Log(TraceLevel, args...)
}
func (entry *Entry) Debug(args ...interface{}) {
entry.Log(DebugLevel, args...)
}
func (entry *Entry) Print(args ...interface{}) {
entry.Info(args...)
}
func (entry *Entry) Info(args ...interface{}) {
if entry.Logger.level() >= InfoLevel {
entry.log(InfoLevel, fmt.Sprint(args...))
}
entry.Log(InfoLevel, args...)
}
func (entry *Entry) Warn(args ...interface{}) {
if entry.Logger.level() >= WarnLevel {
entry.log(WarnLevel, fmt.Sprint(args...))
}
entry.Log(WarnLevel, args...)
}
func (entry *Entry) Warning(args ...interface{}) {
@ -156,37 +282,35 @@ func (entry *Entry) Warning(args ...interface{}) {
}
func (entry *Entry) Error(args ...interface{}) {
if entry.Logger.level() >= ErrorLevel {
entry.log(ErrorLevel, fmt.Sprint(args...))
}
entry.Log(ErrorLevel, args...)
}
func (entry *Entry) Fatal(args ...interface{}) {
if entry.Logger.level() >= FatalLevel {
entry.log(FatalLevel, fmt.Sprint(args...))
}
Exit(1)
entry.Log(FatalLevel, args...)
entry.Logger.Exit(1)
}
func (entry *Entry) Panic(args ...interface{}) {
if entry.Logger.level() >= PanicLevel {
entry.log(PanicLevel, fmt.Sprint(args...))
}
entry.Log(PanicLevel, args...)
panic(fmt.Sprint(args...))
}
// Entry Printf family functions
func (entry *Entry) Logf(level Level, format string, args ...interface{}) {
entry.Log(level, fmt.Sprintf(format, args...))
}
func (entry *Entry) Tracef(format string, args ...interface{}) {
entry.Logf(TraceLevel, format, args...)
}
func (entry *Entry) Debugf(format string, args ...interface{}) {
if entry.Logger.level() >= DebugLevel {
entry.Debug(fmt.Sprintf(format, args...))
}
entry.Logf(DebugLevel, format, args...)
}
func (entry *Entry) Infof(format string, args ...interface{}) {
if entry.Logger.level() >= InfoLevel {
entry.Info(fmt.Sprintf(format, args...))
}
entry.Logf(InfoLevel, format, args...)
}
func (entry *Entry) Printf(format string, args ...interface{}) {
@ -194,9 +318,7 @@ func (entry *Entry) Printf(format string, args ...interface{}) {
}
func (entry *Entry) Warnf(format string, args ...interface{}) {
if entry.Logger.level() >= WarnLevel {
entry.Warn(fmt.Sprintf(format, args...))
}
entry.Logf(WarnLevel, format, args...)
}
func (entry *Entry) Warningf(format string, args ...interface{}) {
@ -204,36 +326,36 @@ func (entry *Entry) Warningf(format string, args ...interface{}) {
}
func (entry *Entry) Errorf(format string, args ...interface{}) {
if entry.Logger.level() >= ErrorLevel {
entry.Error(fmt.Sprintf(format, args...))
}
entry.Logf(ErrorLevel, format, args...)
}
func (entry *Entry) Fatalf(format string, args ...interface{}) {
if entry.Logger.level() >= FatalLevel {
entry.Fatal(fmt.Sprintf(format, args...))
}
Exit(1)
entry.Logf(FatalLevel, format, args...)
entry.Logger.Exit(1)
}
func (entry *Entry) Panicf(format string, args ...interface{}) {
if entry.Logger.level() >= PanicLevel {
entry.Panic(fmt.Sprintf(format, args...))
}
entry.Logf(PanicLevel, format, args...)
}
// Entry Println family functions
func (entry *Entry) Debugln(args ...interface{}) {
if entry.Logger.level() >= DebugLevel {
entry.Debug(entry.sprintlnn(args...))
func (entry *Entry) Logln(level Level, args ...interface{}) {
if entry.Logger.IsLevelEnabled(level) {
entry.Log(level, entry.sprintlnn(args...))
}
}
func (entry *Entry) Traceln(args ...interface{}) {
entry.Logln(TraceLevel, args...)
}
func (entry *Entry) Debugln(args ...interface{}) {
entry.Logln(DebugLevel, args...)
}
func (entry *Entry) Infoln(args ...interface{}) {
if entry.Logger.level() >= InfoLevel {
entry.Info(entry.sprintlnn(args...))
}
entry.Logln(InfoLevel, args...)
}
func (entry *Entry) Println(args ...interface{}) {
@ -241,9 +363,7 @@ func (entry *Entry) Println(args ...interface{}) {
}
func (entry *Entry) Warnln(args ...interface{}) {
if entry.Logger.level() >= WarnLevel {
entry.Warn(entry.sprintlnn(args...))
}
entry.Logln(WarnLevel, args...)
}
func (entry *Entry) Warningln(args ...interface{}) {
@ -251,22 +371,16 @@ func (entry *Entry) Warningln(args ...interface{}) {
}
func (entry *Entry) Errorln(args ...interface{}) {
if entry.Logger.level() >= ErrorLevel {
entry.Error(entry.sprintlnn(args...))
}
entry.Logln(ErrorLevel, args...)
}
func (entry *Entry) Fatalln(args ...interface{}) {
if entry.Logger.level() >= FatalLevel {
entry.Fatal(entry.sprintlnn(args...))
}
Exit(1)
entry.Logln(FatalLevel, args...)
entry.Logger.Exit(1)
}
func (entry *Entry) Panicln(args ...interface{}) {
if entry.Logger.level() >= PanicLevel {
entry.Panic(entry.sprintlnn(args...))
}
entry.Logln(PanicLevel, args...)
}
// Sprintlnn => Sprint no newline. This is to get the behavior of how

@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@ -75,3 +76,66 @@ func TestEntryPanicf(t *testing.T) {
entry := NewEntry(logger)
entry.WithField("err", errBoom).Panicf("kaboom %v", true)
}
const (
badMessage = "this is going to panic"
panicMessage = "this is broken"
)
type panickyHook struct{}
func (p *panickyHook) Levels() []Level {
return []Level{InfoLevel}
}
func (p *panickyHook) Fire(entry *Entry) error {
if entry.Message == badMessage {
panic(panicMessage)
}
return nil
}
func TestEntryHooksPanic(t *testing.T) {
logger := New()
logger.Out = &bytes.Buffer{}
logger.Level = InfoLevel
logger.Hooks.Add(&panickyHook{})
defer func() {
p := recover()
assert.NotNil(t, p)
assert.Equal(t, panicMessage, p)
entry := NewEntry(logger)
entry.Info("another message")
}()
entry := NewEntry(logger)
entry.Info(badMessage)
}
func TestEntryWithIncorrectField(t *testing.T) {
assert := assert.New(t)
fn := func() {}
e := Entry{}
eWithFunc := e.WithFields(Fields{"func": fn})
eWithFuncPtr := e.WithFields(Fields{"funcPtr": &fn})
assert.Equal(eWithFunc.err, `can not add field "func"`)
assert.Equal(eWithFuncPtr.err, `can not add field "funcPtr"`)
eWithFunc = eWithFunc.WithField("not_a_func", "it is a string")
eWithFuncPtr = eWithFuncPtr.WithField("not_a_func", "it is a string")
assert.Equal(eWithFunc.err, `can not add field "func"`)
assert.Equal(eWithFuncPtr.err, `can not add field "funcPtr"`)
eWithFunc = eWithFunc.WithTime(time.Now())
eWithFuncPtr = eWithFuncPtr.WithTime(time.Now())
assert.Equal(eWithFunc.err, `can not add field "func"`)
assert.Equal(eWithFuncPtr.err, `can not add field "funcPtr"`)
}

@ -1,16 +1,18 @@
package logrus_test
import (
"github.com/sirupsen/logrus"
"os"
"github.com/sirupsen/logrus"
)
func Example_basic() {
var log = logrus.New()
log.Formatter = new(logrus.JSONFormatter)
log.Formatter = new(logrus.TextFormatter) //default
log.Formatter.(*logrus.TextFormatter).DisableColors = true // remove colors
log.Formatter.(*logrus.TextFormatter).DisableTimestamp = true // remove timestamp from test output
log.Level = logrus.DebugLevel
log.Level = logrus.TraceLevel
log.Out = os.Stdout
// file, err := os.OpenFile("logrus.log", os.O_CREATE|os.O_WRONLY, 0666)
@ -35,6 +37,11 @@ func Example_basic() {
}
}()
log.WithFields(logrus.Fields{
"animal": "walrus",
"number": 0,
}).Trace("Went to the beach")
log.WithFields(logrus.Fields{
"animal": "walrus",
"number": 8,
@ -60,6 +67,7 @@ func Example_basic() {
}).Panic("It's over 9000!")
// Output:
// level=trace msg="Went to the beach" animal=walrus number=0
// level=debug msg="Started observing beach" animal=walrus number=8
// level=info msg="A group of walrus emerges from the ocean" animal=walrus size=10
// level=warning msg="The group's number increased tremendously!" number=122 omg=true

@ -1,16 +1,24 @@
// +build !windows
package logrus_test
import (
"github.com/sirupsen/logrus"
"gopkg.in/gemnasium/logrus-airbrake-hook.v2"
"log/syslog"
"os"
"github.com/sirupsen/logrus"
slhooks "github.com/sirupsen/logrus/hooks/syslog"
)
// An example on how to use a hook
func Example_hook() {
var log = logrus.New()
log.Formatter = new(logrus.TextFormatter) // default
log.Formatter.(*logrus.TextFormatter).DisableColors = true // remove colors
log.Formatter.(*logrus.TextFormatter).DisableTimestamp = true // remove timestamp from test output
log.Hooks.Add(airbrake.NewHook(123, "xyz", "development"))
if sl, err := slhooks.NewSyslogHook("udp", "localhost:514", syslog.LOG_INFO, ""); err == nil {
log.Hooks.Add(sl)
}
log.Out = os.Stdout
log.WithFields(logrus.Fields{

@ -2,6 +2,7 @@ package logrus
import (
"io"
"time"
)
var (
@ -15,37 +16,38 @@ func StandardLogger() *Logger {
// SetOutput sets the standard logger output.
func SetOutput(out io.Writer) {
std.mu.Lock()
defer std.mu.Unlock()
std.Out = out
std.SetOutput(out)
}
// SetFormatter sets the standard logger formatter.
func SetFormatter(formatter Formatter) {
std.mu.Lock()
defer std.mu.Unlock()
std.Formatter = formatter
std.SetFormatter(formatter)
}
// SetReportCaller sets whether the standard logger will include the calling
// method as a field.
func SetReportCaller(include bool) {
std.SetReportCaller(include)
}
// SetLevel sets the standard logger level.
func SetLevel(level Level) {
std.mu.Lock()
defer std.mu.Unlock()
std.SetLevel(level)
}
// GetLevel returns the standard logger level.
func GetLevel() Level {
std.mu.Lock()
defer std.mu.Unlock()
return std.level()
return std.GetLevel()
}
// IsLevelEnabled checks if the log level of the standard logger is greater than the level param
func IsLevelEnabled(level Level) bool {
return std.IsLevelEnabled(level)
}
// AddHook adds a hook to the standard logger hooks.
func AddHook(hook Hook) {
std.mu.Lock()
defer std.mu.Unlock()
std.Hooks.Add(hook)
std.AddHook(hook)
}
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
@ -72,6 +74,20 @@ func WithFields(fields Fields) *Entry {
return std.WithFields(fields)
}
// WithTime creats an entry from the standard logger and overrides the time of
// logs generated with it.
//
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
// or Panic on the Entry it returns.
func WithTime(t time.Time) *Entry {
return std.WithTime(t)
}
// Trace logs a message at level Trace on the standard logger.
func Trace(args ...interface{}) {
std.Trace(args...)
}
// Debug logs a message at level Debug on the standard logger.
func Debug(args ...interface{}) {
std.Debug(args...)
@ -107,11 +123,16 @@ func Panic(args ...interface{}) {
std.Panic(args...)
}
// Fatal logs a message at level Fatal on the standard logger.
// Fatal logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
func Fatal(args ...interface{}) {
std.Fatal(args...)
}
// Tracef logs a message at level Trace on the standard logger.
func Tracef(format string, args ...interface{}) {
std.Tracef(format, args...)
}
// Debugf logs a message at level Debug on the standard logger.
func Debugf(format string, args ...interface{}) {
std.Debugf(format, args...)
@ -147,11 +168,16 @@ func Panicf(format string, args ...interface{}) {
std.Panicf(format, args...)
}
// Fatalf logs a message at level Fatal on the standard logger.
// Fatalf logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
func Fatalf(format string, args ...interface{}) {
std.Fatalf(format, args...)
}
// Traceln logs a message at level Trace on the standard logger.
func Traceln(args ...interface{}) {
std.Traceln(args...)
}
// Debugln logs a message at level Debug on the standard logger.
func Debugln(args ...interface{}) {
std.Debugln(args...)
@ -187,7 +213,7 @@ func Panicln(args ...interface{}) {
std.Panicln(args...)
}
// Fatalln logs a message at level Fatal on the standard logger.
// Fatalln logs a message at level Fatal on the standard logger then the process will exit with status set to 1.
func Fatalln(args ...interface{}) {
std.Fatalln(args...)
}

@ -2,7 +2,16 @@ package logrus
import "time"
const defaultTimestampFormat = time.RFC3339
// Default key names for the default fields
const (
defaultTimestampFormat = time.RFC3339
FieldKeyMsg = "msg"
FieldKeyLevel = "level"
FieldKeyTime = "time"
FieldKeyLogrusError = "logrus_error"
FieldKeyFunc = "func"
FieldKeyFile = "file"
)
// The Formatter interface is used to implement a custom Formatter. It takes an
// `Entry`. It exposes all the fields, including the default ones:
@ -18,7 +27,7 @@ type Formatter interface {
Format(*Entry) ([]byte, error)
}
// This is to not silently overwrite `time`, `msg` and `level` fields when
// This is to not silently overwrite `time`, `msg`, `func` and `level` fields when
// dumping it. If this code wasn't there doing:
//
// logrus.WithField("level", 1).Info("hello")
@ -30,16 +39,40 @@ type Formatter interface {
//
// It's not exported because it's still using Data in an opinionated way. It's to
// avoid code duplication between the two default formatters.
func prefixFieldClashes(data Fields) {
if t, ok := data["time"]; ok {
data["fields.time"] = t
func prefixFieldClashes(data Fields, fieldMap FieldMap, reportCaller bool) {
timeKey := fieldMap.resolve(FieldKeyTime)
if t, ok := data[timeKey]; ok {
data["fields."+timeKey] = t
delete(data, timeKey)
}
if m, ok := data["msg"]; ok {
data["fields.msg"] = m
msgKey := fieldMap.resolve(FieldKeyMsg)
if m, ok := data[msgKey]; ok {
data["fields."+msgKey] = m
delete(data, msgKey)
}
if l, ok := data["level"]; ok {
data["fields.level"] = l
levelKey := fieldMap.resolve(FieldKeyLevel)
if l, ok := data[levelKey]; ok {
data["fields."+levelKey] = l
delete(data, levelKey)
}
logrusErrKey := fieldMap.resolve(FieldKeyLogrusError)
if l, ok := data[logrusErrKey]; ok {
data["fields."+logrusErrKey] = l
delete(data, logrusErrKey)
}
// If reportCaller is not set, 'func' will not conflict.
if reportCaller {
funcKey := fieldMap.resolve(FieldKeyFunc)
if l, ok := data[funcKey]; ok {
data["fields."+funcKey] = l
}
fileKey := fieldMap.resolve(FieldKeyFile)
if l, ok := data[fileKey]; ok {
data["fields."+fileKey] = l
}
}
}

@ -1,10 +1,16 @@
package logrus
package logrus_test
import (
"bytes"
"encoding/json"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
. "github.com/sirupsen/logrus"
. "github.com/sirupsen/logrus/internal/testutils"
)
type TestHook struct {
@ -18,6 +24,7 @@ func (hook *TestHook) Fire(entry *Entry) error {
func (hook *TestHook) Levels() []Level {
return []Level{
TraceLevel,
DebugLevel,
InfoLevel,
WarnLevel,
@ -50,6 +57,7 @@ func (hook *ModifyHook) Fire(entry *Entry) error {
func (hook *ModifyHook) Levels() []Level {
return []Level{
TraceLevel,
DebugLevel,
InfoLevel,
WarnLevel,
@ -85,6 +93,46 @@ func TestCanFireMultipleHooks(t *testing.T) {
})
}
type SingleLevelModifyHook struct {
ModifyHook
}
func (h *SingleLevelModifyHook) Levels() []Level {
return []Level{InfoLevel}
}
func TestHookEntryIsPristine(t *testing.T) {
l := New()
b := &bytes.Buffer{}
l.Formatter = &JSONFormatter{}
l.Out = b
l.AddHook(&SingleLevelModifyHook{})
l.Error("error message")
data := map[string]string{}
err := json.Unmarshal(b.Bytes(), &data)
require.NoError(t, err)
_, ok := data["wow"]
require.False(t, ok)
b.Reset()
l.Info("error message")
data = map[string]string{}
err = json.Unmarshal(b.Bytes(), &data)
require.NoError(t, err)
_, ok = data["wow"]
require.True(t, ok)
b.Reset()
l.Error("error message")
data = map[string]string{}
err = json.Unmarshal(b.Bytes(), &data)
require.NoError(t, err)
_, ok = data["wow"]
require.False(t, ok)
b.Reset()
}
type ErrorHook struct {
Fired bool
}

@ -43,7 +43,7 @@ func (hook *SyslogHook) Fire(entry *logrus.Entry) error {
return hook.Writer.Warning(line)
case logrus.InfoLevel:
return hook.Writer.Info(line)
case logrus.DebugLevel:
case logrus.DebugLevel, logrus.TraceLevel:
return hook.Writer.Debug(line)
default:
return nil

@ -1,3 +1,5 @@
// +build !windows,!nacl,!plan9
package syslog
import (

@ -15,7 +15,7 @@ type Hook struct {
// Entries is an array of all entries that have been received by this hook.
// For safe access, use the AllEntries() method, rather than reading this
// value directly.
Entries []*logrus.Entry
Entries []logrus.Entry
mu sync.RWMutex
}
@ -52,7 +52,7 @@ func NewNullLogger() (*logrus.Logger, *Hook) {
func (t *Hook) Fire(e *logrus.Entry) error {
t.mu.Lock()
defer t.mu.Unlock()
t.Entries = append(t.Entries, e)
t.Entries = append(t.Entries, *e)
return nil
}
@ -68,9 +68,7 @@ func (t *Hook) LastEntry() *logrus.Entry {
if i < 0 {
return nil
}
// Make a copy, for safety
e := *t.Entries[i]
return &e
return &t.Entries[i]
}
// AllEntries returns all entries that were logged.
@ -79,10 +77,9 @@ func (t *Hook) AllEntries() []*logrus.Entry {
defer t.mu.RUnlock()
// Make a copy so the returned value won't race with future log requests
entries := make([]*logrus.Entry, len(t.Entries))
for i, entry := range t.Entries {
for i := 0; i < len(t.Entries); i++ {
// Make a copy, for safety
e := *entry
entries[i] = &e
entries[i] = &t.Entries[i]
}
return entries
}
@ -91,5 +88,5 @@ func (t *Hook) AllEntries() []*logrus.Entry {
func (t *Hook) Reset() {
t.mu.Lock()
defer t.mu.Unlock()
t.Entries = make([]*logrus.Entry, 0)
t.Entries = make([]logrus.Entry, 0)
}

@ -1,14 +1,16 @@
package test
import (
"math/rand"
"sync"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestAllHooks(t *testing.T) {
assert := assert.New(t)
logger, hook := NewNullLogger()
@ -35,5 +37,49 @@ func TestAllHooks(t *testing.T) {
assert.Equal(logrus.ErrorLevel, hook.LastEntry().Level)
assert.Equal("Hello error", hook.LastEntry().Message)
assert.Equal(1, len(hook.Entries))
}
func TestLoggingWithHooksRace(t *testing.T) {
rand.Seed(time.Now().Unix())
unlocker := rand.Int() % 100
assert := assert.New(t)
logger, hook := NewNullLogger()
var wgOne, wgAll sync.WaitGroup
wgOne.Add(1)
wgAll.Add(100)
for i := 0; i < 100; i++ {
go func(i int) {
logger.Info("info")
wgAll.Done()
if i == unlocker {
wgOne.Done()
}
}(i)
}
wgOne.Wait()
assert.Equal(logrus.InfoLevel, hook.LastEntry().Level)
assert.Equal("info", hook.LastEntry().Message)
wgAll.Wait()
entries := hook.AllEntries()
assert.Equal(100, len(entries))
}
func TestFatalWithAlternateExit(t *testing.T) {
assert := assert.New(t)
logger, hook := NewNullLogger()
logger.ExitFunc = func(code int) {}
logger.Fatal("something went very wrong")
assert.Equal(logrus.FatalLevel, hook.LastEntry().Level)
assert.Equal("something went very wrong", hook.LastEntry().Message)
assert.Equal(1, len(hook.Entries))
}

@ -1,6 +1,7 @@
package logrus
import (
"bytes"
"encoding/json"
"fmt"
)
@ -10,13 +11,6 @@ type fieldKey string
// FieldMap allows customization of the key names for default fields.
type FieldMap map[fieldKey]string
// Default key names for the default fields
const (
FieldKeyMsg = "msg"
FieldKeyLevel = "level"
FieldKeyTime = "time"
)
func (f FieldMap) resolve(key fieldKey) string {
if k, ok := f[key]; ok {
return k
@ -33,21 +27,28 @@ type JSONFormatter struct {
// DisableTimestamp allows disabling automatic timestamps in output
DisableTimestamp bool
// DataKey allows users to put all the log entry parameters into a nested dictionary at a given key.
DataKey string
// FieldMap allows users to customize the names of keys for default fields.
// As an example:
// formatter := &JSONFormatter{
// FieldMap: FieldMap{
// FieldKeyTime: "@timestamp",
// FieldKeyTime: "@timestamp",
// FieldKeyLevel: "@level",
// FieldKeyMsg: "@message",
// FieldKeyMsg: "@message",
// FieldKeyFunc: "@caller",
// },
// }
FieldMap FieldMap
// PrettyPrint will indent all json logs
PrettyPrint bool
}
// Format renders a single log entry
func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
data := make(Fields, len(entry.Data)+3)
data := make(Fields, len(entry.Data)+4)
for k, v := range entry.Data {
switch v := v.(type) {
case error:
@ -58,22 +59,47 @@ func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
data[k] = v
}
}
prefixFieldClashes(data)
if f.DataKey != "" {
newData := make(Fields, 4)
newData[f.DataKey] = data
data = newData
}
prefixFieldClashes(data, f.FieldMap, entry.HasCaller())
timestampFormat := f.TimestampFormat
if timestampFormat == "" {
timestampFormat = defaultTimestampFormat
}
if entry.err != "" {
data[f.FieldMap.resolve(FieldKeyLogrusError)] = entry.err
}
if !f.DisableTimestamp {
data[f.FieldMap.resolve(FieldKeyTime)] = entry.Time.Format(timestampFormat)
}
data[f.FieldMap.resolve(FieldKeyMsg)] = entry.Message
data[f.FieldMap.resolve(FieldKeyLevel)] = entry.Level.String()
if entry.HasCaller() {
data[f.FieldMap.resolve(FieldKeyFunc)] = entry.Caller.Function
data[f.FieldMap.resolve(FieldKeyFile)] = fmt.Sprintf("%s:%d", entry.Caller.File, entry.Caller.Line)
}
serialized, err := json.Marshal(data)
if err != nil {
var b *bytes.Buffer
if entry.Buffer != nil {
b = entry.Buffer
} else {
b = &bytes.Buffer{}
}
encoder := json.NewEncoder(b)
if f.PrettyPrint {
encoder.SetIndent("", " ")
}
if err := encoder.Encode(data); err != nil {
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
}
return append(serialized, '\n'), nil
return b.Bytes(), nil
}

@ -3,6 +3,8 @@ package logrus
import (
"encoding/json"
"errors"
"fmt"
"runtime"
"strings"
"testing"
)
@ -106,6 +108,102 @@ func TestFieldClashWithLevel(t *testing.T) {
}
}
func TestFieldClashWithRemappedFields(t *testing.T) {
formatter := &JSONFormatter{
FieldMap: FieldMap{
FieldKeyTime: "@timestamp",
FieldKeyLevel: "@level",
FieldKeyMsg: "@message",
},
}
b, err := formatter.Format(WithFields(Fields{
"@timestamp": "@timestamp",
"@level": "@level",
"@message": "@message",
"timestamp": "timestamp",
"level": "level",
"msg": "msg",
}))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
for _, field := range []string{"timestamp", "level", "msg"} {
if entry[field] != field {
t.Errorf("Expected field %v to be untouched; got %v", field, entry[field])
}
remappedKey := fmt.Sprintf("fields.%s", field)
if remapped, ok := entry[remappedKey]; ok {
t.Errorf("Expected %s to be empty; got %v", remappedKey, remapped)
}
}
for _, field := range []string{"@timestamp", "@level", "@message"} {
if entry[field] == field {
t.Errorf("Expected field %v to be mapped to an Entry value", field)
}
remappedKey := fmt.Sprintf("fields.%s", field)
if remapped, ok := entry[remappedKey]; ok {
if remapped != field {
t.Errorf("Expected field %v to be copied to %s; got %v", field, remappedKey, remapped)
}
} else {
t.Errorf("Expected field %v to be copied to %s; was absent", field, remappedKey)
}
}
}
func TestFieldsInNestedDictionary(t *testing.T) {
formatter := &JSONFormatter{
DataKey: "args",
}
logEntry := WithFields(Fields{
"level": "level",
"test": "test",
})
logEntry.Level = InfoLevel
b, err := formatter.Format(logEntry)
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
args := entry["args"].(map[string]interface{})
for _, field := range []string{"test", "level"} {
if value, present := args[field]; !present || value != field {
t.Errorf("Expected field %v to be present under 'args'; untouched", field)
}
}
for _, field := range []string{"test", "fields.level"} {
if _, present := entry[field]; present {
t.Errorf("Expected field %v not to be present at top level", field)
}
}
// with nested object, "level" shouldn't clash
if entry["level"] != "info" {
t.Errorf("Expected 'level' field to contain 'info'")
}
}
func TestJSONEntryEndsWithNewline(t *testing.T) {
formatter := &JSONFormatter{}
@ -170,6 +268,55 @@ func TestJSONTimeKey(t *testing.T) {
}
}
func TestFieldDoesNotClashWithCaller(t *testing.T) {
SetReportCaller(false)
formatter := &JSONFormatter{}
b, err := formatter.Format(WithField("func", "howdy pardner"))
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["func"] != "howdy pardner" {
t.Fatal("func field replaced when ReportCaller=false")
}
}
func TestFieldClashWithCaller(t *testing.T) {
SetReportCaller(true)
formatter := &JSONFormatter{}
e := WithField("func", "howdy pardner")
e.Caller = &runtime.Frame{Function: "somefunc"}
b, err := formatter.Format(e)
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
entry := make(map[string]interface{})
err = json.Unmarshal(b, &entry)
if err != nil {
t.Fatal("Unable to unmarshal formatted entry: ", err)
}
if entry["fields.func"] != "howdy pardner" {
t.Fatalf("fields.func not set to original func field when ReportCaller=true (got '%s')",
entry["fields.func"])
}
if entry["func"] != "somefunc" {
t.Fatalf("func not set as expected when ReportCaller=true (got '%s')",
entry["func"])
}
SetReportCaller(false) // return to default value
}
func TestJSONDisableTimestamp(t *testing.T) {
formatter := &JSONFormatter{
DisableTimestamp: true,

@ -5,12 +5,13 @@ import (
"os"
"sync"
"sync/atomic"
"time"
)
type Logger struct {
// The logs are `io.Copy`'d to this in a mutex. It's common to set this to a
// file, or leave it default which is `os.Stderr`. You can also set this to
// something more adventorous, such as logging to Kafka.
// something more adventurous, such as logging to Kafka.
Out io.Writer
// Hooks for the logger instance. These allow firing events based on logging
// levels and log entries. For example, to send errors to an error tracking
@ -23,6 +24,10 @@ type Logger struct {
// own that implements the `Formatter` interface, see the `README` or included
// formatters for examples.
Formatter Formatter
// Flag for whether to log caller info (off by default)
ReportCaller bool
// The logging level the logger should log at. This is typically (and defaults
// to) `logrus.Info`, which allows Info(), Warn(), Error() and Fatal() to be
// logged.
@ -31,8 +36,12 @@ type Logger struct {
mu MutexWrap
// Reusable empty entry
entryPool sync.Pool
// Function to exit the application, defaults to `os.Exit()`
ExitFunc exitFunc
}
type exitFunc func(int)
type MutexWrap struct {
lock sync.Mutex
disabled bool
@ -68,10 +77,12 @@ func (mw *MutexWrap) Disable() {
// It's recommended to make this a global instance called `log`.
func New() *Logger {
return &Logger{
Out: os.Stderr,
Formatter: new(TextFormatter),
Hooks: make(LevelHooks),
Level: InfoLevel,
Out: os.Stderr,
Formatter: new(TextFormatter),
Hooks: make(LevelHooks),
Level: InfoLevel,
ExitFunc: os.Exit,
ReportCaller: false,
}
}
@ -84,11 +95,12 @@ func (logger *Logger) newEntry() *Entry {
}
func (logger *Logger) releaseEntry(entry *Entry) {
entry.Data = map[string]interface{}{}
logger.entryPool.Put(entry)
}
// Adds a field to the log entry, note that it doesn't log until you call
// Debug, Print, Info, Warn, Fatal or Panic. It only creates a log entry.
// Debug, Print, Info, Warn, Error, Fatal or Panic. It only creates a log entry.
// If you want multiple fields, use `WithFields`.
func (logger *Logger) WithField(key string, value interface{}) *Entry {
entry := logger.newEntry()
@ -112,20 +124,31 @@ func (logger *Logger) WithError(err error) *Entry {
return entry.WithError(err)
}
func (logger *Logger) Debugf(format string, args ...interface{}) {
if logger.level() >= DebugLevel {
// Overrides the time of the log entry.
func (logger *Logger) WithTime(t time.Time) *Entry {
entry := logger.newEntry()
defer logger.releaseEntry(entry)
return entry.WithTime(t)
}
func (logger *Logger) Logf(level Level, format string, args ...interface{}) {
if logger.IsLevelEnabled(level) {
entry := logger.newEntry()
entry.Debugf(format, args...)
entry.Logf(level, format, args...)
logger.releaseEntry(entry)
}
}
func (logger *Logger) Tracef(format string, args ...interface{}) {
logger.Logf(TraceLevel, format, args...)
}
func (logger *Logger) Debugf(format string, args ...interface{}) {
logger.Logf(DebugLevel, format, args...)
}
func (logger *Logger) Infof(format string, args ...interface{}) {
if logger.level() >= InfoLevel {
entry := logger.newEntry()
entry.Infof(format, args...)
logger.releaseEntry(entry)
}
logger.Logf(InfoLevel, format, args...)
}
func (logger *Logger) Printf(format string, args ...interface{}) {
@ -135,60 +158,44 @@ func (logger *Logger) Printf(format string, args ...interface{}) {
}
func (logger *Logger) Warnf(format string, args ...interface{}) {
if logger.level() >= WarnLevel {
entry := logger.newEntry()
entry.Warnf(format, args...)
logger.releaseEntry(entry)
}
logger.Logf(WarnLevel, format, args...)
}
func (logger *Logger) Warningf(format string, args ...interface{}) {
if logger.level() >= WarnLevel {
entry := logger.newEntry()
entry.Warnf(format, args...)
logger.releaseEntry(entry)
}
logger.Warnf(format, args...)
}
func (logger *Logger) Errorf(format string, args ...interface{}) {
if logger.level() >= ErrorLevel {
entry := logger.newEntry()
entry.Errorf(format, args...)
logger.releaseEntry(entry)
}
logger.Logf(ErrorLevel, format, args...)
}
func (logger *Logger) Fatalf(format string, args ...interface{}) {
if logger.level() >= FatalLevel {
entry := logger.newEntry()
entry.Fatalf(format, args...)
logger.releaseEntry(entry)
}
Exit(1)
logger.Logf(FatalLevel, format, args...)
logger.Exit(1)
}
func (logger *Logger) Panicf(format string, args ...interface{}) {
if logger.level() >= PanicLevel {
logger.Logf(PanicLevel, format, args...)
}
func (logger *Logger) Log(level Level, args ...interface{}) {
if logger.IsLevelEnabled(level) {
entry := logger.newEntry()
entry.Panicf(format, args...)
entry.Log(level, args...)
logger.releaseEntry(entry)
}
}
func (logger *Logger) Trace(args ...interface{}) {
logger.Log(TraceLevel, args...)
}
func (logger *Logger) Debug(args ...interface{}) {
if logger.level() >= DebugLevel {
entry := logger.newEntry()
entry.Debug(args...)
logger.releaseEntry(entry)
}
logger.Log(DebugLevel, args...)
}
func (logger *Logger) Info(args ...interface{}) {
if logger.level() >= InfoLevel {
entry := logger.newEntry()
entry.Info(args...)
logger.releaseEntry(entry)
}
logger.Log(InfoLevel, args...)
}
func (logger *Logger) Print(args ...interface{}) {
@ -198,60 +205,44 @@ func (logger *Logger) Print(args ...interface{}) {
}
func (logger *Logger) Warn(args ...interface{}) {
if logger.level() >= WarnLevel {
entry := logger.newEntry()
entry.Warn(args...)
logger.releaseEntry(entry)
}
logger.Log(WarnLevel, args...)
}
func (logger *Logger) Warning(args ...interface{}) {
if logger.level() >= WarnLevel {
entry := logger.newEntry()
entry.Warn(args...)
logger.releaseEntry(entry)
}
logger.Warn(args...)
}
func (logger *Logger) Error(args ...interface{}) {
if logger.level() >= ErrorLevel {
entry := logger.newEntry()
entry.Error(args...)
logger.releaseEntry(entry)
}
logger.Log(ErrorLevel, args...)
}
func (logger *Logger) Fatal(args ...interface{}) {
if logger.level() >= FatalLevel {
entry := logger.newEntry()
entry.Fatal(args...)
logger.releaseEntry(entry)
}
Exit(1)
logger.Log(FatalLevel, args...)
logger.Exit(1)
}
func (logger *Logger) Panic(args ...interface{}) {
if logger.level() >= PanicLevel {
logger.Log(PanicLevel, args...)
}
func (logger *Logger) Logln(level Level, args ...interface{}) {
if logger.IsLevelEnabled(level) {
entry := logger.newEntry()
entry.Panic(args...)
entry.Logln(level, args...)
logger.releaseEntry(entry)
}
}
func (logger *Logger) Traceln(args ...interface{}) {
logger.Logln(TraceLevel, args...)
}
func (logger *Logger) Debugln(args ...interface{}) {
if logger.level() >= DebugLevel {
entry := logger.newEntry()
entry.Debugln(args...)
logger.releaseEntry(entry)
}
logger.Logln(DebugLevel, args...)
}
func (logger *Logger) Infoln(args ...interface{}) {
if logger.level() >= InfoLevel {
entry := logger.newEntry()
entry.Infoln(args...)
logger.releaseEntry(entry)
}
logger.Logln(InfoLevel, args...)
}
func (logger *Logger) Println(args ...interface{}) {
@ -261,44 +252,32 @@ func (logger *Logger) Println(args ...interface{}) {
}
func (logger *Logger) Warnln(args ...interface{}) {
if logger.level() >= WarnLevel {
entry := logger.newEntry()
entry.Warnln(args...)
logger.releaseEntry(entry)
}
logger.Logln(WarnLevel, args...)
}
func (logger *Logger) Warningln(args ...interface{}) {
if logger.level() >= WarnLevel {
entry := logger.newEntry()
entry.Warnln(args...)
logger.releaseEntry(entry)
}
logger.Warn(args...)
}
func (logger *Logger) Errorln(args ...interface{}) {
if logger.level() >= ErrorLevel {
entry := logger.newEntry()
entry.Errorln(args...)
logger.releaseEntry(entry)
}
logger.Logln(ErrorLevel, args...)
}
func (logger *Logger) Fatalln(args ...interface{}) {
if logger.level() >= FatalLevel {
entry := logger.newEntry()
entry.Fatalln(args...)
logger.releaseEntry(entry)
}
Exit(1)
logger.Logln(FatalLevel, args...)
logger.Exit(1)
}
func (logger *Logger) Panicln(args ...interface{}) {
if logger.level() >= PanicLevel {
entry := logger.newEntry()
entry.Panicln(args...)
logger.releaseEntry(entry)
logger.Logln(PanicLevel, args...)
}
func (logger *Logger) Exit(code int) {
runHandlers()
if logger.ExitFunc == nil {
logger.ExitFunc = os.Exit
}
logger.ExitFunc(code)
}
//When file is opened with appending mode, it's safe to
@ -312,12 +291,53 @@ func (logger *Logger) level() Level {
return Level(atomic.LoadUint32((*uint32)(&logger.Level)))
}
// SetLevel sets the logger level.
func (logger *Logger) SetLevel(level Level) {
atomic.StoreUint32((*uint32)(&logger.Level), uint32(level))
}
// GetLevel returns the logger level.
func (logger *Logger) GetLevel() Level {
return logger.level()
}
// AddHook adds a hook to the logger hooks.
func (logger *Logger) AddHook(hook Hook) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.Hooks.Add(hook)
}
// IsLevelEnabled checks if the log level of the logger is greater than the level param
func (logger *Logger) IsLevelEnabled(level Level) bool {
return logger.level() >= level
}
// SetFormatter sets the logger formatter.
func (logger *Logger) SetFormatter(formatter Formatter) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.Formatter = formatter
}
// SetOutput sets the logger output.
func (logger *Logger) SetOutput(output io.Writer) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.Out = output
}
func (logger *Logger) SetReportCaller(reportCaller bool) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.ReportCaller = reportCaller
}
// ReplaceHooks replaces the logger hooks and returns the old ones
func (logger *Logger) ReplaceHooks(hooks LevelHooks) LevelHooks {
logger.mu.Lock()
oldHooks := logger.Hooks
logger.Hooks = hooks
logger.mu.Unlock()
return oldHooks
}

@ -1,6 +1,7 @@
package logrus
import (
"io/ioutil"
"os"
"testing"
)
@ -59,3 +60,26 @@ func doLoggerBenchmarkNoLock(b *testing.B, out *os.File, formatter Formatter, fi
}
})
}
func BenchmarkLoggerJSONFormatter(b *testing.B) {
doLoggerBenchmarkWithFormatter(b, &JSONFormatter{})
}
func BenchmarkLoggerTextFormatter(b *testing.B) {
doLoggerBenchmarkWithFormatter(b, &TextFormatter{})
}
func doLoggerBenchmarkWithFormatter(b *testing.B, f Formatter) {
b.SetParallelism(100)
log := New()
log.Formatter = f
log.Out = ioutil.Discard
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
log.
WithField("foo1", "bar1").
WithField("foo2", "bar2").
Info("this is a dummy log")
}
})
}

@ -14,22 +14,11 @@ type Level uint32
// Convert the Level to a string. E.g. PanicLevel becomes "panic".
func (level Level) String() string {
switch level {
case DebugLevel:
return "debug"
case InfoLevel:
return "info"
case WarnLevel:
return "warning"
case ErrorLevel:
return "error"
case FatalLevel:
return "fatal"
case PanicLevel:
return "panic"
if b, err := level.MarshalText(); err == nil {
return string(b)
} else {
return "unknown"
}
return "unknown"
}
// ParseLevel takes a string level and returns the Logrus log level constant.
@ -47,12 +36,47 @@ func ParseLevel(lvl string) (Level, error) {
return InfoLevel, nil
case "debug":
return DebugLevel, nil
case "trace":
return TraceLevel, nil
}
var l Level
return l, fmt.Errorf("not a valid logrus Level: %q", lvl)
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (level *Level) UnmarshalText(text []byte) error {
l, err := ParseLevel(string(text))
if err != nil {
return err
}
*level = Level(l)
return nil
}
func (level Level) MarshalText() ([]byte, error) {
switch level {
case TraceLevel:
return []byte("trace"), nil
case DebugLevel:
return []byte("debug"), nil
case InfoLevel:
return []byte("info"), nil
case WarnLevel:
return []byte("warning"), nil
case ErrorLevel:
return []byte("error"), nil
case FatalLevel:
return []byte("fatal"), nil
case PanicLevel:
return []byte("panic"), nil
}
return nil, fmt.Errorf("not a valid lorus level %q", level)
}
// A constant exposing all logging levels
var AllLevels = []Level{
PanicLevel,
@ -61,6 +85,7 @@ var AllLevels = []Level{
WarnLevel,
InfoLevel,
DebugLevel,
TraceLevel,
}
// These are the different logging levels. You can set the logging level to log
@ -69,7 +94,7 @@ const (
// PanicLevel level, highest level of severity. Logs and then calls panic with the
// message passed to Debug, Info, ...
PanicLevel Level = iota
// FatalLevel level. Logs and then calls `os.Exit(1)`. It will exit even if the
// FatalLevel level. Logs and then calls `logger.Exit(1)`. It will exit even if the
// logging level is set to Panic.
FatalLevel
// ErrorLevel level. Logs. Used for errors that should definitely be noted.
@ -82,6 +107,8 @@ const (
InfoLevel
// DebugLevel level. Usually only enabled when debugging. Very verbose logging.
DebugLevel
// TraceLevel level. Designates finer-grained informational events than the Debug.
TraceLevel
)
// Won't compile if StdLogger can't be realized by a log.Logger
@ -140,4 +167,20 @@ type FieldLogger interface {
Errorln(args ...interface{})
Fatalln(args ...interface{})
Panicln(args ...interface{})
// IsDebugEnabled() bool
// IsInfoEnabled() bool
// IsWarnEnabled() bool
// IsErrorEnabled() bool
// IsFatalEnabled() bool
// IsPanicEnabled() bool
}
// Ext1FieldLogger (the first extension to FieldLogger) is superfluous, it is
// here for consistancy. Do not use. Use Logger or Entry instead.
type Ext1FieldLogger interface {
FieldLogger
Tracef(format string, args ...interface{})
Trace(args ...interface{})
Traceln(args ...interface{})
}

@ -1,67 +1,95 @@
package logrus
package logrus_test
import (
"bytes"
"encoding/json"
"strconv"
"strings"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
. "github.com/sirupsen/logrus"
. "github.com/sirupsen/logrus/internal/testutils"
)
func LogAndAssertJSON(t *testing.T, log func(*Logger), assertions func(fields Fields)) {
// TestReportCaller verifies that when ReportCaller is set, the 'func' field
// is added, and when it is unset it is not set or modified
// Verify that functions within the Logrus package aren't considered when
// discovering the caller.
func TestReportCallerWhenConfigured(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.ReportCaller = false
log.Print("testNoCaller")
}, func(fields Fields) {
assert.Equal(t, "testNoCaller", fields["msg"])
assert.Equal(t, "info", fields["level"])
assert.Equal(t, nil, fields["func"])
})
LogAndAssertJSON(t, func(log *Logger) {
log.ReportCaller = true
log.Print("testWithCaller")
}, func(fields Fields) {
assert.Equal(t, "testWithCaller", fields["msg"])
assert.Equal(t, "info", fields["level"])
assert.Equal(t,
"github.com/sirupsen/logrus_test.TestReportCallerWhenConfigured.func3", fields["func"])
})
}
func logSomething(t *testing.T, message string) Fields {
var buffer bytes.Buffer
var fields Fields
logger := New()
logger.Out = &buffer
logger.Formatter = new(JSONFormatter)
logger.ReportCaller = true
log(logger)
entry := logger.WithFields(Fields{
"foo": "bar",
})
entry.Info(message)
err := json.Unmarshal(buffer.Bytes(), &fields)
assert.Nil(t, err)
assertions(fields)
return fields
}
func LogAndAssertText(t *testing.T, log func(*Logger), assertions func(fields map[string]string)) {
var buffer bytes.Buffer
// TestReportCallerHelperDirect - verify reference when logging from a regular function
func TestReportCallerHelperDirect(t *testing.T) {
fields := logSomething(t, "direct")
logger := New()
logger.Out = &buffer
logger.Formatter = &TextFormatter{
DisableColors: true,
}
assert.Equal(t, "direct", fields["msg"])
assert.Equal(t, "info", fields["level"])
assert.Regexp(t, "github.com/.*/logrus_test.logSomething", fields["func"])
}
log(logger)
// TestReportCallerHelperDirect - verify reference when logging from a function called via pointer
func TestReportCallerHelperViaPointer(t *testing.T) {
fptr := logSomething
fields := fptr(t, "via pointer")
fields := make(map[string]string)
for _, kv := range strings.Split(buffer.String(), " ") {
if !strings.Contains(kv, "=") {
continue
}
kvArr := strings.Split(kv, "=")
key := strings.TrimSpace(kvArr[0])
val := kvArr[1]
if kvArr[1][0] == '"' {
var err error
val, err = strconv.Unquote(val)
assert.NoError(t, err)
}
fields[key] = val
}
assertions(fields)
assert.Equal(t, "via pointer", fields["msg"])
assert.Equal(t, "info", fields["level"])
assert.Regexp(t, "github.com/.*/logrus_test.logSomething", fields["func"])
}
func TestPrint(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Print("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["level"], "info")
assert.Equal(t, "test", fields["msg"])
assert.Equal(t, "info", fields["level"])
})
}
@ -69,8 +97,8 @@ func TestInfo(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["level"], "info")
assert.Equal(t, "test", fields["msg"])
assert.Equal(t, "info", fields["level"])
})
}
@ -78,8 +106,17 @@ func TestWarn(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Warn("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["level"], "warning")
assert.Equal(t, "test", fields["msg"])
assert.Equal(t, "warning", fields["level"])
})
}
func TestLog(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Log(WarnLevel, "test")
}, func(fields Fields) {
assert.Equal(t, "test", fields["msg"])
assert.Equal(t, "warning", fields["level"])
})
}
@ -87,7 +124,7 @@ func TestInfolnShouldAddSpacesBetweenStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln("test", "test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test test")
assert.Equal(t, "test test", fields["msg"])
})
}
@ -95,7 +132,7 @@ func TestInfolnShouldAddSpacesBetweenStringAndNonstring(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln("test", 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test 10")
assert.Equal(t, "test 10", fields["msg"])
})
}
@ -103,7 +140,7 @@ func TestInfolnShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln(10, 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "10 10")
assert.Equal(t, "10 10", fields["msg"])
})
}
@ -111,7 +148,7 @@ func TestInfoShouldAddSpacesBetweenTwoNonStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Infoln(10, 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "10 10")
assert.Equal(t, "10 10", fields["msg"])
})
}
@ -119,7 +156,7 @@ func TestInfoShouldNotAddSpacesBetweenStringAndNonstring(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Info("test", 10)
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test10")
assert.Equal(t, "test10", fields["msg"])
})
}
@ -127,7 +164,7 @@ func TestInfoShouldNotAddSpacesBetweenStrings(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.Info("test", "test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "testtest")
assert.Equal(t, "testtest", fields["msg"])
})
}
@ -165,7 +202,7 @@ func TestUserSuppliedFieldDoesNotOverwriteDefaults(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("msg", "hello").Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, "test", fields["msg"])
})
}
@ -173,8 +210,8 @@ func TestUserSuppliedMsgFieldHasPrefix(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("msg", "hello").Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["msg"], "test")
assert.Equal(t, fields["fields.msg"], "hello")
assert.Equal(t, "test", fields["msg"])
assert.Equal(t, "hello", fields["fields.msg"])
})
}
@ -182,7 +219,7 @@ func TestUserSuppliedTimeFieldHasPrefix(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("time", "hello").Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["fields.time"], "hello")
assert.Equal(t, "hello", fields["fields.time"])
})
}
@ -190,8 +227,8 @@ func TestUserSuppliedLevelFieldHasPrefix(t *testing.T) {
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("level", 1).Info("test")
}, func(fields Fields) {
assert.Equal(t, fields["level"], "info")
assert.Equal(t, fields["fields.level"], 1.0) // JSON has floats only
assert.Equal(t, "info", fields["level"])
assert.Equal(t, 1.0, fields["fields.level"]) // JSON has floats only
})
}
@ -209,6 +246,65 @@ func TestDefaultFieldsAreNotPrefixed(t *testing.T) {
})
}
func TestWithTimeShouldOverrideTime(t *testing.T) {
now := time.Now().Add(24 * time.Hour)
LogAndAssertJSON(t, func(log *Logger) {
log.WithTime(now).Info("foobar")
}, func(fields Fields) {
assert.Equal(t, fields["time"], now.Format(time.RFC3339))
})
}
func TestWithTimeShouldNotOverrideFields(t *testing.T) {
now := time.Now().Add(24 * time.Hour)
LogAndAssertJSON(t, func(log *Logger) {
log.WithField("herp", "derp").WithTime(now).Info("blah")
}, func(fields Fields) {
assert.Equal(t, fields["time"], now.Format(time.RFC3339))
assert.Equal(t, fields["herp"], "derp")
})
}
func TestWithFieldShouldNotOverrideTime(t *testing.T) {
now := time.Now().Add(24 * time.Hour)
LogAndAssertJSON(t, func(log *Logger) {
log.WithTime(now).WithField("herp", "derp").Info("blah")
}, func(fields Fields) {
assert.Equal(t, fields["time"], now.Format(time.RFC3339))
assert.Equal(t, fields["herp"], "derp")
})
}
func TestTimeOverrideMultipleLogs(t *testing.T) {
var buffer bytes.Buffer
var firstFields, secondFields Fields
logger := New()
logger.Out = &buffer
formatter := new(JSONFormatter)
formatter.TimestampFormat = time.StampMilli
logger.Formatter = formatter
llog := logger.WithField("herp", "derp")
llog.Info("foo")
err := json.Unmarshal(buffer.Bytes(), &firstFields)
assert.NoError(t, err, "should have decoded first message")
buffer.Reset()
time.Sleep(10 * time.Millisecond)
llog.Info("bar")
err = json.Unmarshal(buffer.Bytes(), &secondFields)
assert.NoError(t, err, "should have decoded second message")
assert.NotEqual(t, firstFields["time"], secondFields["time"], "timestamps should not be equal")
}
func TestDoubleLoggingDoesntPrefixPreviousFields(t *testing.T) {
var buffer bytes.Buffer
@ -235,13 +331,119 @@ func TestDoubleLoggingDoesntPrefixPreviousFields(t *testing.T) {
err = json.Unmarshal(buffer.Bytes(), &fields)
assert.NoError(t, err, "should have decoded second message")
assert.Equal(t, len(fields), 4, "should only have msg/time/level/context fields")
assert.Equal(t, fields["msg"], "omg it is!")
assert.Equal(t, fields["context"], "eating raw fish")
assert.Equal(t, "omg it is!", fields["msg"])
assert.Equal(t, "eating raw fish", fields["context"])
assert.Nil(t, fields["fields.msg"], "should not have prefixed previous `msg` entry")
}
func TestNestedLoggingReportsCorrectCaller(t *testing.T) {
var buffer bytes.Buffer
var fields Fields
logger := New()
logger.Out = &buffer
logger.Formatter = new(JSONFormatter)
logger.ReportCaller = true
llog := logger.WithField("context", "eating raw fish")
llog.Info("looks delicious")
_, _, line, _ := runtime.Caller(0)
err := json.Unmarshal(buffer.Bytes(), &fields)
require.NoError(t, err, "should have decoded first message")
assert.Equal(t, 6, len(fields), "should have msg/time/level/func/context fields")
assert.Equal(t, "looks delicious", fields["msg"])
assert.Equal(t, "eating raw fish", fields["context"])
assert.Equal(t,
"github.com/sirupsen/logrus_test.TestNestedLoggingReportsCorrectCaller", fields["func"])
cwd, err := os.Getwd()
require.NoError(t, err)
assert.Equal(t, filepath.ToSlash(fmt.Sprintf("%s/logrus_test.go:%d", cwd, line-1)), filepath.ToSlash(fields["file"].(string)))
buffer.Reset()
logger.WithFields(Fields{
"Clyde": "Stubblefield",
}).WithFields(Fields{
"Jab'o": "Starks",
}).WithFields(Fields{
"uri": "https://www.youtube.com/watch?v=V5DTznu-9v0",
}).WithFields(Fields{
"func": "y drummer",
}).WithFields(Fields{
"James": "Brown",
}).Print("The hardest workin' man in show business")
_, _, line, _ = runtime.Caller(0)
err = json.Unmarshal(buffer.Bytes(), &fields)
assert.NoError(t, err, "should have decoded second message")
assert.Equal(t, 11, len(fields), "should have all builtin fields plus foo,bar,baz,...")
assert.Equal(t, "Stubblefield", fields["Clyde"])
assert.Equal(t, "Starks", fields["Jab'o"])
assert.Equal(t, "https://www.youtube.com/watch?v=V5DTznu-9v0", fields["uri"])
assert.Equal(t, "y drummer", fields["fields.func"])
assert.Equal(t, "Brown", fields["James"])
assert.Equal(t, "The hardest workin' man in show business", fields["msg"])
assert.Nil(t, fields["fields.msg"], "should not have prefixed previous `msg` entry")
assert.Equal(t,
"github.com/sirupsen/logrus_test.TestNestedLoggingReportsCorrectCaller", fields["func"])
require.NoError(t, err)
assert.Equal(t, filepath.ToSlash(fmt.Sprintf("%s/logrus_test.go:%d", cwd, line-1)), filepath.ToSlash(fields["file"].(string)))
logger.ReportCaller = false // return to default value
}
func logLoop(iterations int, reportCaller bool) {
var buffer bytes.Buffer
logger := New()
logger.Out = &buffer
logger.Formatter = new(JSONFormatter)
logger.ReportCaller = reportCaller
for i := 0; i < iterations; i++ {
logger.Infof("round %d of %d", i, iterations)
}
}
// Assertions for upper bounds to reporting overhead
func TestCallerReportingOverhead(t *testing.T) {
iterations := 5000
before := time.Now()
logLoop(iterations, false)
during := time.Now()
logLoop(iterations, true)
after := time.Now()
elapsedNotReporting := during.Sub(before).Nanoseconds()
elapsedReporting := after.Sub(during).Nanoseconds()
maxDelta := 1 * time.Second
assert.WithinDuration(t, during, before, maxDelta,
"%d log calls without caller name lookup takes less than %d second(s) (was %d nanoseconds)",
iterations, maxDelta.Seconds(), elapsedNotReporting)
assert.WithinDuration(t, after, during, maxDelta,
"%d log calls without caller name lookup takes less than %d second(s) (was %d nanoseconds)",
iterations, maxDelta.Seconds(), elapsedReporting)
}
// benchmarks for both with and without caller-function reporting
func BenchmarkWithoutCallerTracing(b *testing.B) {
for i := 0; i < b.N; i++ {
logLoop(1000, false)
}
}
func BenchmarkWithCallerTracing(b *testing.B) {
for i := 0; i < b.N; i++ {
logLoop(1000, true)
}
}
func TestConvertLevelToString(t *testing.T) {
assert.Equal(t, "trace", TraceLevel.String())
assert.Equal(t, "debug", DebugLevel.String())
assert.Equal(t, "info", InfoLevel.String())
assert.Equal(t, "warning", WarnLevel.String())
@ -307,6 +509,14 @@ func TestParseLevel(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, DebugLevel, l)
l, err = ParseLevel("trace")
assert.Nil(t, err)
assert.Equal(t, TraceLevel, l)
l, err = ParseLevel("TRACE")
assert.Nil(t, err)
assert.Equal(t, TraceLevel, l)
l, err = ParseLevel("invalid")
assert.Equal(t, "not a valid logrus Level: \"invalid\"", err.Error())
}
@ -343,10 +553,52 @@ func TestLoggingRace(t *testing.T) {
wg.Wait()
}
func TestLoggingRaceWithHooksOnEntry(t *testing.T) {
logger := New()
hook := new(ModifyHook)
logger.AddHook(hook)
entry := logger.WithField("context", "clue")
var wg sync.WaitGroup
wg.Add(100)
for i := 0; i < 100; i++ {
go func() {
entry.Info("info")
wg.Done()
}()
}
wg.Wait()
}
func TestReplaceHooks(t *testing.T) {
old, cur := &TestHook{}, &TestHook{}
logger := New()
logger.SetOutput(ioutil.Discard)
logger.AddHook(old)
hooks := make(LevelHooks)
hooks.Add(cur)
replaced := logger.ReplaceHooks(hooks)
logger.Info("test")
assert.Equal(t, old.Fired, false)
assert.Equal(t, cur.Fired, true)
logger.ReplaceHooks(replaced)
logger.Info("test")
assert.Equal(t, old.Fired, true)
}
// Compile test
func TestLogrusInterface(t *testing.T) {
func TestLogrusInterfaces(t *testing.T) {
var buffer bytes.Buffer
fn := func(l FieldLogger) {
// This verifies FieldLogger and Ext1FieldLogger work as designed.
// Please don't use them. Use Logger and Entry directly.
fn := func(xl Ext1FieldLogger) {
var l FieldLogger = xl
b := l.WithField("key", "value")
b.Debug("Test")
}
@ -384,3 +636,81 @@ func TestEntryWriter(t *testing.T) {
assert.Equal(t, fields["foo"], "bar")
assert.Equal(t, fields["level"], "warning")
}
func TestLogLevelEnabled(t *testing.T) {
log := New()
log.SetLevel(PanicLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, false, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, false, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, false, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, false, log.IsLevelEnabled(TraceLevel))
log.SetLevel(FatalLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, false, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, false, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, false, log.IsLevelEnabled(TraceLevel))
log.SetLevel(ErrorLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, false, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, false, log.IsLevelEnabled(TraceLevel))
log.SetLevel(WarnLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, false, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, false, log.IsLevelEnabled(TraceLevel))
log.SetLevel(InfoLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, true, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, false, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, false, log.IsLevelEnabled(TraceLevel))
log.SetLevel(DebugLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, true, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, true, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, false, log.IsLevelEnabled(TraceLevel))
log.SetLevel(TraceLevel)
assert.Equal(t, true, log.IsLevelEnabled(PanicLevel))
assert.Equal(t, true, log.IsLevelEnabled(FatalLevel))
assert.Equal(t, true, log.IsLevelEnabled(ErrorLevel))
assert.Equal(t, true, log.IsLevelEnabled(WarnLevel))
assert.Equal(t, true, log.IsLevelEnabled(InfoLevel))
assert.Equal(t, true, log.IsLevelEnabled(DebugLevel))
assert.Equal(t, true, log.IsLevelEnabled(TraceLevel))
}
func TestReportCallerOnTextFormatter(t *testing.T) {
l := New()
l.Formatter.(*TextFormatter).ForceColors = true
l.Formatter.(*TextFormatter).DisableColors = false
l.WithFields(Fields{"func": "func", "file": "file"}).Info("test")
l.Formatter.(*TextFormatter).ForceColors = false
l.Formatter.(*TextFormatter).DisableColors = true
l.WithFields(Fields{"func": "func", "file": "file"}).Info("test")
}

@ -1,10 +0,0 @@
// +build darwin freebsd openbsd netbsd dragonfly
// +build !appengine
package logrus
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TIOCGETA
type Termios unix.Termios

@ -1,4 +1,4 @@
// +build !appengine
// +build !appengine,!js,!windows,!aix
package logrus

@ -1,14 +0,0 @@
// Based on ssh/terminal:
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
package logrus
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TCGETS
type Termios unix.Termios

@ -3,6 +3,8 @@ package logrus
import (
"bytes"
"fmt"
"os"
"runtime"
"sort"
"strings"
"sync"
@ -20,6 +22,7 @@ const (
var (
baseTimestamp time.Time
emptyFieldMap FieldMap
)
func init() {
@ -34,6 +37,9 @@ type TextFormatter struct {
// Force disabling colors.
DisableColors bool
// Override coloring based on CLICOLOR and CLICOLOR_FORCE. - https://bixense.com/clicolors/
EnvironmentOverrideColors bool
// Disable timestamp logging. useful when output is redirected to logging
// system that already adds timestamps.
DisableTimestamp bool
@ -50,60 +56,135 @@ type TextFormatter struct {
// be desired.
DisableSorting bool
// The keys sorting function, when uninitialized it uses sort.Strings.
SortingFunc func([]string)
// Disables the truncation of the level text to 4 characters.
DisableLevelTruncation bool
// QuoteEmptyFields will wrap empty fields in quotes if true
QuoteEmptyFields bool
// Whether the logger's out is to a terminal
isTerminal bool
sync.Once
// FieldMap allows users to customize the names of keys for default fields.
// As an example:
// formatter := &TextFormatter{
// FieldMap: FieldMap{
// FieldKeyTime: "@timestamp",
// FieldKeyLevel: "@level",
// FieldKeyMsg: "@message"}}
FieldMap FieldMap
terminalInitOnce sync.Once
}
func (f *TextFormatter) init(entry *Entry) {
if entry.Logger != nil {
f.isTerminal = checkIfTerminal(entry.Logger.Out)
if f.isTerminal {
initTerminal(entry.Logger.Out)
}
}
}
func (f *TextFormatter) isColored() bool {
isColored := f.ForceColors || (f.isTerminal && (runtime.GOOS != "windows"))
if f.EnvironmentOverrideColors {
if force, ok := os.LookupEnv("CLICOLOR_FORCE"); ok && force != "0" {
isColored = true
} else if ok && force == "0" {
isColored = false
} else if os.Getenv("CLICOLOR") == "0" {
isColored = false
}
}
return isColored && !f.DisableColors
}
// Format renders a single log entry
func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
var b *bytes.Buffer
keys := make([]string, 0, len(entry.Data))
for k := range entry.Data {
data := make(Fields)
for k, v := range entry.Data {
data[k] = v
}
prefixFieldClashes(data, f.FieldMap, entry.HasCaller())
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
if !f.DisableSorting {
sort.Strings(keys)
fixedKeys := make([]string, 0, 4+len(data))
if !f.DisableTimestamp {
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyTime))
}
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyLevel))
if entry.Message != "" {
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyMsg))
}
if entry.err != "" {
fixedKeys = append(fixedKeys, f.FieldMap.resolve(FieldKeyLogrusError))
}
if entry.HasCaller() {
fixedKeys = append(fixedKeys,
f.FieldMap.resolve(FieldKeyFunc), f.FieldMap.resolve(FieldKeyFile))
}
if !f.DisableSorting {
if f.SortingFunc == nil {
sort.Strings(keys)
fixedKeys = append(fixedKeys, keys...)
} else {
if !f.isColored() {
fixedKeys = append(fixedKeys, keys...)
f.SortingFunc(fixedKeys)
} else {
f.SortingFunc(keys)
}
}
} else {
fixedKeys = append(fixedKeys, keys...)
}
var b *bytes.Buffer
if entry.Buffer != nil {
b = entry.Buffer
} else {
b = &bytes.Buffer{}
}
prefixFieldClashes(entry.Data)
f.Do(func() { f.init(entry) })
isColored := (f.ForceColors || f.isTerminal) && !f.DisableColors
f.terminalInitOnce.Do(func() { f.init(entry) })
timestampFormat := f.TimestampFormat
if timestampFormat == "" {
timestampFormat = defaultTimestampFormat
}
if isColored {
f.printColored(b, entry, keys, timestampFormat)
if f.isColored() {
f.printColored(b, entry, keys, data, timestampFormat)
} else {
if !f.DisableTimestamp {
f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat))
}
f.appendKeyValue(b, "level", entry.Level.String())
if entry.Message != "" {
f.appendKeyValue(b, "msg", entry.Message)
}
for _, key := range keys {
f.appendKeyValue(b, key, entry.Data[key])
for _, key := range fixedKeys {
var value interface{}
switch {
case key == f.FieldMap.resolve(FieldKeyTime):
value = entry.Time.Format(timestampFormat)
case key == f.FieldMap.resolve(FieldKeyLevel):
value = entry.Level.String()
case key == f.FieldMap.resolve(FieldKeyMsg):
value = entry.Message
case key == f.FieldMap.resolve(FieldKeyLogrusError):
value = entry.err
case key == f.FieldMap.resolve(FieldKeyFunc) && entry.HasCaller():
value = entry.Caller.Function
case key == f.FieldMap.resolve(FieldKeyFile) && entry.HasCaller():
value = fmt.Sprintf("%s:%d", entry.Caller.File, entry.Caller.Line)
default:
value = data[key]
}
f.appendKeyValue(b, key, value)
}
}
@ -111,10 +192,10 @@ func (f *TextFormatter) Format(entry *Entry) ([]byte, error) {
return b.Bytes(), nil
}
func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []string, timestampFormat string) {
func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []string, data Fields, timestampFormat string) {
var levelColor int
switch entry.Level {
case DebugLevel:
case DebugLevel, TraceLevel:
levelColor = gray
case WarnLevel:
levelColor = yellow
@ -124,17 +205,31 @@ func (f *TextFormatter) printColored(b *bytes.Buffer, entry *Entry, keys []strin
levelColor = blue
}
levelText := strings.ToUpper(entry.Level.String())[0:4]
levelText := strings.ToUpper(entry.Level.String())
if !f.DisableLevelTruncation {
levelText = levelText[0:4]
}
// Remove a single newline if it already exists in the message to keep
// the behavior of logrus text_formatter the same as the stdlib log package
entry.Message = strings.TrimSuffix(entry.Message, "\n")
caller := ""
if entry.HasCaller() {
caller = fmt.Sprintf("%s:%d %s()",
entry.Caller.File, entry.Caller.Line, entry.Caller.Function)
}
if f.DisableTimestamp {
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m %-44s ", levelColor, levelText, entry.Message)
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m%s %-44s ", levelColor, levelText, caller, entry.Message)
} else if !f.FullTimestamp {
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%04d] %-44s ", levelColor, levelText, int(entry.Time.Sub(baseTimestamp)/time.Second), entry.Message)
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%04d]%s %-44s ", levelColor, levelText, int(entry.Time.Sub(baseTimestamp)/time.Second), caller, entry.Message)
} else {
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s] %-44s ", levelColor, levelText, entry.Time.Format(timestampFormat), entry.Message)
fmt.Fprintf(b, "\x1b[%dm%s\x1b[0m[%s]%s %-44s ", levelColor, levelText, entry.Time.Format(timestampFormat), caller, entry.Message)
}
for _, k := range keys {
v := entry.Data[k]
v := data[k]
fmt.Fprintf(b, " \x1b[%dm%s\x1b[0m=", levelColor, k)
f.appendValue(b, v)
}

@ -4,9 +4,15 @@ import (
"bytes"
"errors"
"fmt"
"os"
"runtime"
"sort"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFormatting(t *testing.T) {
@ -128,6 +134,44 @@ func TestTimestampFormat(t *testing.T) {
checkTimeStr("")
}
func TestDisableLevelTruncation(t *testing.T) {
entry := &Entry{
Time: time.Now(),
Message: "testing",
}
keys := []string{}
timestampFormat := "Mon Jan 2 15:04:05 -0700 MST 2006"
checkDisableTruncation := func(disabled bool, level Level) {
tf := &TextFormatter{DisableLevelTruncation: disabled}
var b bytes.Buffer
entry.Level = level
tf.printColored(&b, entry, keys, nil, timestampFormat)
logLine := (&b).String()
if disabled {
expected := strings.ToUpper(level.String())
if !strings.Contains(logLine, expected) {
t.Errorf("level string expected to be %s when truncation disabled", expected)
}
} else {
expected := strings.ToUpper(level.String())
if len(level.String()) > 4 {
if strings.Contains(logLine, expected) {
t.Errorf("level string %s expected to be truncated to %s when truncation is enabled", expected, expected[0:4])
}
} else {
if !strings.Contains(logLine, expected) {
t.Errorf("level string expected to be %s when truncation is enabled and level string is below truncation threshold", expected)
}
}
}
}
checkDisableTruncation(true, DebugLevel)
checkDisableTruncation(true, InfoLevel)
checkDisableTruncation(false, ErrorLevel)
checkDisableTruncation(false, InfoLevel)
}
func TestDisableTimestampWithColoredOutput(t *testing.T) {
tf := &TextFormatter{DisableTimestamp: true, ForceColors: true}
@ -137,5 +181,305 @@ func TestDisableTimestampWithColoredOutput(t *testing.T) {
}
}
// TODO add tests for sorting etc., this requires a parser for the text
// formatter output.
func TestNewlineBehavior(t *testing.T) {
tf := &TextFormatter{ForceColors: true}
// Ensure a single new line is removed as per stdlib log
e := NewEntry(StandardLogger())
e.Message = "test message\n"
b, _ := tf.Format(e)
if bytes.Contains(b, []byte("test message\n")) {
t.Error("first newline at end of Entry.Message resulted in unexpected 2 newlines in output. Expected newline to be removed.")
}
// Ensure a double new line is reduced to a single new line
e = NewEntry(StandardLogger())
e.Message = "test message\n\n"
b, _ = tf.Format(e)
if bytes.Contains(b, []byte("test message\n\n")) {
t.Error("Double newline at end of Entry.Message resulted in unexpected 2 newlines in output. Expected single newline")
}
if !bytes.Contains(b, []byte("test message\n")) {
t.Error("Double newline at end of Entry.Message did not result in a single newline after formatting")
}
}
func TestTextFormatterFieldMap(t *testing.T) {
formatter := &TextFormatter{
DisableColors: true,
FieldMap: FieldMap{
FieldKeyMsg: "message",
FieldKeyLevel: "somelevel",
FieldKeyTime: "timeywimey",
},
}
entry := &Entry{
Message: "oh hi",
Level: WarnLevel,
Time: time.Date(1981, time.February, 24, 4, 28, 3, 100, time.UTC),
Data: Fields{
"field1": "f1",
"message": "messagefield",
"somelevel": "levelfield",
"timeywimey": "timeywimeyfield",
},
}
b, err := formatter.Format(entry)
if err != nil {
t.Fatal("Unable to format entry: ", err)
}
assert.Equal(t,
`timeywimey="1981-02-24T04:28:03Z" `+
`somelevel=warning `+
`message="oh hi" `+
`field1=f1 `+
`fields.message=messagefield `+
`fields.somelevel=levelfield `+
`fields.timeywimey=timeywimeyfield`+"\n",
string(b),
"Formatted output doesn't respect FieldMap")
}
func TestTextFormatterIsColored(t *testing.T) {
params := []struct {
name string
expectedResult bool
isTerminal bool
disableColor bool
forceColor bool
envColor bool
clicolorIsSet bool
clicolorForceIsSet bool
clicolorVal string
clicolorForceVal string
}{
// Default values
{
name: "testcase1",
expectedResult: false,
isTerminal: false,
disableColor: false,
forceColor: false,
envColor: false,
clicolorIsSet: false,
clicolorForceIsSet: false,
},
// Output on terminal
{
name: "testcase2",
expectedResult: true,
isTerminal: true,
disableColor: false,
forceColor: false,
envColor: false,
clicolorIsSet: false,
clicolorForceIsSet: false,
},
// Output on terminal with color disabled
{
name: "testcase3",
expectedResult: false,
isTerminal: true,
disableColor: true,
forceColor: false,
envColor: false,
clicolorIsSet: false,
clicolorForceIsSet: false,
},
// Output not on terminal with color disabled
{
name: "testcase4",
expectedResult: false,
isTerminal: false,
disableColor: true,
forceColor: false,
envColor: false,
clicolorIsSet: false,
clicolorForceIsSet: false,
},
// Output not on terminal with color forced
{
name: "testcase5",
expectedResult: true,
isTerminal: false,
disableColor: false,
forceColor: true,
envColor: false,
clicolorIsSet: false,
clicolorForceIsSet: false,
},
// Output on terminal with clicolor set to "0"
{
name: "testcase6",
expectedResult: false,
isTerminal: true,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: true,
clicolorForceIsSet: false,
clicolorVal: "0",
},
// Output on terminal with clicolor set to "1"
{
name: "testcase7",
expectedResult: true,
isTerminal: true,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: true,
clicolorForceIsSet: false,
clicolorVal: "1",
},
// Output not on terminal with clicolor set to "0"
{
name: "testcase8",
expectedResult: false,
isTerminal: false,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: true,
clicolorForceIsSet: false,
clicolorVal: "0",
},
// Output not on terminal with clicolor set to "1"
{
name: "testcase9",
expectedResult: false,
isTerminal: false,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: true,
clicolorForceIsSet: false,
clicolorVal: "1",
},
// Output not on terminal with clicolor set to "1" and force color
{
name: "testcase10",
expectedResult: true,
isTerminal: false,
disableColor: false,
forceColor: true,
envColor: true,
clicolorIsSet: true,
clicolorForceIsSet: false,
clicolorVal: "1",
},
// Output not on terminal with clicolor set to "0" and force color
{
name: "testcase11",
expectedResult: false,
isTerminal: false,
disableColor: false,
forceColor: true,
envColor: true,
clicolorIsSet: true,
clicolorForceIsSet: false,
clicolorVal: "0",
},
// Output not on terminal with clicolor_force set to "1"
{
name: "testcase12",
expectedResult: true,
isTerminal: false,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: false,
clicolorForceIsSet: true,
clicolorForceVal: "1",
},
// Output not on terminal with clicolor_force set to "0"
{
name: "testcase13",
expectedResult: false,
isTerminal: false,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: false,
clicolorForceIsSet: true,
clicolorForceVal: "0",
},
// Output on terminal with clicolor_force set to "0"
{
name: "testcase14",
expectedResult: false,
isTerminal: true,
disableColor: false,
forceColor: false,
envColor: true,
clicolorIsSet: false,
clicolorForceIsSet: true,
clicolorForceVal: "0",
},
}
cleanenv := func() {
os.Unsetenv("CLICOLOR")
os.Unsetenv("CLICOLOR_FORCE")
}
defer cleanenv()
for _, val := range params {
t.Run("textformatter_"+val.name, func(subT *testing.T) {
tf := TextFormatter{
isTerminal: val.isTerminal,
DisableColors: val.disableColor,
ForceColors: val.forceColor,
EnvironmentOverrideColors: val.envColor,
}
cleanenv()
if val.clicolorIsSet {
os.Setenv("CLICOLOR", val.clicolorVal)
}
if val.clicolorForceIsSet {
os.Setenv("CLICOLOR_FORCE", val.clicolorForceVal)
}
res := tf.isColored()
if runtime.GOOS == "windows" && !tf.ForceColors && !val.clicolorForceIsSet {
assert.Equal(subT, false, res)
} else {
assert.Equal(subT, val.expectedResult, res)
}
})
}
}
func TestCustomSorting(t *testing.T) {
formatter := &TextFormatter{
DisableColors: true,
SortingFunc: func(keys []string) {
sort.Slice(keys, func(i, j int) bool {
if keys[j] == "prefix" {
return false
}
if keys[i] == "prefix" {
return true
}
return strings.Compare(keys[i], keys[j]) == -1
})
},
}
entry := &Entry{
Message: "Testing custom sort function",
Time: time.Now(),
Level: InfoLevel,
Data: Fields{
"test": "testvalue",
"prefix": "the application prefix",
"blablabla": "blablabla",
},
}
b, err := formatter.Format(entry)
require.NoError(t, err)
require.True(t, strings.HasPrefix(string(b), "prefix="), "format output is %q", string(b))
}

@ -24,6 +24,8 @@ func (entry *Entry) WriterLevel(level Level) *io.PipeWriter {
var printFunc func(args ...interface{})
switch level {
case TraceLevel:
printFunc = entry.Trace
case DebugLevel:
printFunc = entry.Debug
case InfoLevel:

@ -4,16 +4,15 @@ Go is an open source project.
It is the work of hundreds of contributors. We appreciate your help!
## Filing issues
When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions:
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
The gophers there will answer or ask you to file an issue if you've tripped over a bug.
@ -23,9 +22,5 @@ The gophers there will answer or ask you to file an issue if you've tripped over
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
before sending patches.
**We do not accept GitHub pull requests**
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review).
Unless otherwise noted, the Go source files are distributed under
the BSD-style license found in the LICENSE file.

@ -14,7 +14,6 @@
package acme
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
@ -23,6 +22,8 @@ import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"encoding/json"
@ -33,14 +34,26 @@ import (
"io/ioutil"
"math/big"
"net/http"
"strconv"
"strings"
"sync"
"time"
)
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
const LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory"
const (
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory"
// ALPNProto is the ALPN protocol name used by a CA server when validating
// tls-alpn-01 challenges.
//
// Package users must ensure their servers can negotiate the ACME ALPN in
// order for tls-alpn-01 challenge verifications to succeed.
// See the crypto/tls package's Config.NextProtos field.
ALPNProto = "acme-tls/1"
)
// idPeACMEIdentifierV1 is the OID for the ACME extension for the TLS-ALPN challenge.
var idPeACMEIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
const (
maxChainLen = 5 // max depth and breadth of a certificate chain
@ -64,6 +77,10 @@ const (
type Client struct {
// Key is the account key used to register with a CA and sign requests.
// Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey.
//
// The following algorithms are supported:
// RS256, ES256, ES384 and ES512.
// See RFC7518 for more details about the algorithms.
Key crypto.Signer
// HTTPClient optionally specifies an HTTP client to use
@ -76,6 +93,22 @@ type Client struct {
// will have no effect.
DirectoryURL string
// RetryBackoff computes the duration after which the nth retry of a failed request
// should occur. The value of n for the first call on failure is 1.
// The values of r and resp are the request and response of the last failed attempt.
// If the returned value is negative or zero, no more retries are done and an error
// is returned to the caller of the original method.
//
// Requests which result in a 4xx client error are not retried,
// except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests.
//
// If RetryBackoff is nil, a truncated exponential backoff algorithm
// with the ceiling of 10 seconds is used, where each subsequent retry n
// is done after either ("Retry-After" + jitter) or (2^n seconds + jitter),
// preferring the former if "Retry-After" header is found in the resp.
// The jitter is a random value up to 1 second.
RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration
dirMu sync.Mutex // guards writes to dir
dir *Directory // cached result of Client's Discover method
@ -95,19 +128,12 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
return *c.dir, nil
}
dirURL := c.DirectoryURL
if dirURL == "" {
dirURL = LetsEncryptURL
}
res, err := c.get(ctx, dirURL)
res, err := c.get(ctx, c.directoryURL(), wantStatus(http.StatusOK))
if err != nil {
return Directory{}, err
}
defer res.Body.Close()
c.addNonce(res.Header)
if res.StatusCode != http.StatusOK {
return Directory{}, responseError(res)
}
var v struct {
Reg string `json:"new-reg"`
@ -135,6 +161,13 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
return *c.dir, nil
}
func (c *Client) directoryURL() string {
if c.DirectoryURL != "" {
return c.DirectoryURL
}
return LetsEncryptURL
}
// CreateCert requests a new certificate using the Certificate Signing Request csr encoded in DER format.
// The exp argument indicates the desired certificate validity duration. CA may issue a certificate
// with a different duration.
@ -166,14 +199,11 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
req.NotAfter = now.Add(exp).Format(time.RFC3339)
}
res, err := c.retryPostJWS(ctx, c.Key, c.dir.CertURL, req)
res, err := c.post(ctx, c.Key, c.dir.CertURL, req, wantStatus(http.StatusCreated))
if err != nil {
return nil, "", err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return nil, "", responseError(res)
}
curl := res.Header.Get("Location") // cert permanent URL
if res.ContentLength == 0 {
@ -196,26 +226,11 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
// Callers are encouraged to parse the returned value to ensure the certificate is valid
// and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
for {
res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
return c.responseCert(ctx, res, bundle)
}
if res.StatusCode > 299 {
return nil, responseError(res)
}
d := retryAfter(res.Header.Get("Retry-After"), 3*time.Second)
select {
case <-time.After(d):
// retry
case <-ctx.Done():
return nil, ctx.Err()
}
res, err := c.get(ctx, url, wantStatus(http.StatusOK))
if err != nil {
return nil, err
}
return c.responseCert(ctx, res, bundle)
}
// RevokeCert revokes a previously issued certificate cert, provided in DER format.
@ -241,14 +256,11 @@ func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte,
if key == nil {
key = c.Key
}
res, err := c.retryPostJWS(ctx, key, c.dir.RevokeURL, body)
res, err := c.post(ctx, key, c.dir.RevokeURL, body, wantStatus(http.StatusOK))
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return responseError(res)
}
return nil
}
@ -314,6 +326,20 @@ func (c *Client) UpdateReg(ctx context.Context, a *Account) (*Account, error) {
// a valid authorization (Authorization.Status is StatusValid). If so, the caller
// need not fulfill any challenge and can proceed to requesting a certificate.
func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) {
return c.authorize(ctx, "dns", domain)
}
// AuthorizeIP is the same as Authorize but requests IP address authorization.
// Clients which successfully obtain such authorization may request to issue
// a certificate for IP addresses.
//
// See the ACME spec extension for more details about IP address identifiers:
// https://tools.ietf.org/html/draft-ietf-acme-ip.
func (c *Client) AuthorizeIP(ctx context.Context, ipaddr string) (*Authorization, error) {
return c.authorize(ctx, "ip", ipaddr)
}
func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization, error) {
if _, err := c.Discover(ctx); err != nil {
return nil, err
}
@ -327,16 +353,13 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
Identifier authzID `json:"identifier"`
}{
Resource: "new-authz",
Identifier: authzID{Type: "dns", Value: domain},
Identifier: authzID{Type: typ, Value: val},
}
res, err := c.retryPostJWS(ctx, c.Key, c.dir.AuthzURL, req)
res, err := c.post(ctx, c.Key, c.dir.AuthzURL, req, wantStatus(http.StatusCreated))
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return nil, responseError(res)
}
var v wireAuthz
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
@ -353,14 +376,11 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
// If a caller needs to poll an authorization until its status is final,
// see the WaitAuthorization method.
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
res, err := c.get(ctx, url)
res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
return nil, responseError(res)
}
var v wireAuthz
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return nil, fmt.Errorf("acme: invalid response: %v", err)
@ -387,56 +407,58 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
Status: "deactivated",
Delete: true,
}
res, err := c.retryPostJWS(ctx, c.Key, url, req)
res, err := c.post(ctx, c.Key, url, req, wantStatus(http.StatusOK))
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return responseError(res)
}
return nil
}
// WaitAuthorization polls an authorization at the given URL
// until it is in one of the final states, StatusValid or StatusInvalid,
// or the context is done.
// the ACME CA responded with a 4xx error code, or the context is done.
//
// It returns a non-nil Authorization only if its Status is StatusValid.
// In all other cases WaitAuthorization returns an error.
// If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
sleep := sleeper(ctx)
for {
res, err := c.get(ctx, url)
res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
if err != nil {
return nil, err
}
retry := res.Header.Get("Retry-After")
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
res.Body.Close()
if err := sleep(retry, 1); err != nil {
return nil, err
}
continue
}
var raw wireAuthz
err = json.NewDecoder(res.Body).Decode(&raw)
res.Body.Close()
if err != nil {
if err := sleep(retry, 0); err != nil {
return nil, err
}
continue
}
if raw.Status == StatusValid {
switch {
case err != nil:
// Skip and retry.
case raw.Status == StatusValid:
return raw.authorization(url), nil
}
if raw.Status == StatusInvalid {
case raw.Status == StatusInvalid:
return nil, raw.error(url)
}
if err := sleep(retry, 0); err != nil {
return nil, err
// Exponential backoff is implemented in c.get above.
// This is just to prevent continuously hitting the CA
// while waiting for a final authorization status.
d := retryAfter(res.Header.Get("Retry-After"))
if d == 0 {
// Given that the fastest challenges TLS-SNI and HTTP-01
// require a CA to make at least 1 network round trip
// and most likely persist a challenge state,
// this default delay seems reasonable.
d = time.Second
}
t := time.NewTimer(d)
select {
case <-ctx.Done():
t.Stop()
return nil, ctx.Err()
case <-t.C:
// Retry.
}
}
}
@ -445,14 +467,11 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
//
// A client typically polls a challenge status using this method.
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
res, err := c.get(ctx, url)
res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
return nil, responseError(res)
}
v := wireChallenge{URI: url}
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return nil, fmt.Errorf("acme: invalid response: %v", err)
@ -479,16 +498,14 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
Type: chal.Type,
Auth: auth,
}
res, err := c.retryPostJWS(ctx, c.Key, chal.URI, req)
res, err := c.post(ctx, c.Key, chal.URI, req, wantStatus(
http.StatusOK, // according to the spec
http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md)
))
if err != nil {
return nil, err
}
defer res.Body.Close()
// Note: the protocol specifies 200 as the expected response code, but
// letsencrypt seems to be returning 202.
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
return nil, responseError(res)
}
var v wireChallenge
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
@ -545,7 +562,7 @@ func (c *Client) HTTP01ChallengePath(token string) string {
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
//
// The returned certificate is valid for the next 24 hours and must be presented only when
// the server name of the client hello matches exactly the returned name value.
// the server name of the TLS ClientHello matches exactly the returned name value.
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
@ -572,7 +589,7 @@ func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tl
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
//
// The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the client hello matches exactly the returned name value.
// the server name in the TLS ClientHello matches exactly the returned name value.
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
b := sha256.Sum256([]byte(token))
h := hex.EncodeToString(b[:])
@ -593,6 +610,52 @@ func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tl
return cert, sanA, nil
}
// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response.
// Servers can present the certificate to validate the challenge and prove control
// over a domain name. For more details on TLS-ALPN-01 see
// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3
//
// The token argument is a Challenge.Token value.
// If a WithKey option is provided, its private part signs the returned cert,
// and the public part is used to specify the signee.
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
//
// The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol
// has been specified.
func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) {
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
return tls.Certificate{}, err
}
shasum := sha256.Sum256([]byte(ka))
extValue, err := asn1.Marshal(shasum[:])
if err != nil {
return tls.Certificate{}, err
}
acmeExtension := pkix.Extension{
Id: idPeACMEIdentifierV1,
Critical: true,
Value: extValue,
}
tmpl := defaultTLSChallengeCertTemplate()
var newOpt []CertOption
for _, o := range opt {
switch o := o.(type) {
case *certOptTemplate:
t := *(*x509.Certificate)(o) // shallow copy is ok
tmpl = &t
default:
newOpt = append(newOpt, o)
}
}
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension)
newOpt = append(newOpt, WithTemplate(tmpl))
return tlsChallengeCert([]string{domain}, newOpt)
}
// doReg sends all types of registration requests.
// The type of request is identified by typ argument, which is a "resource"
// in the ACME spec terms.
@ -612,14 +675,15 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
req.Contact = acct.Contact
req.Agreement = acct.AgreedTerms
}
res, err := c.retryPostJWS(ctx, c.Key, url, req)
res, err := c.post(ctx, c.Key, url, req, wantStatus(
http.StatusOK, // updates and deletes
http.StatusCreated, // new account creation
http.StatusAccepted, // Let's Encrypt divergent implementation
))
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 299 {
return nil, responseError(res)
}
var v struct {
Contact []string
@ -649,66 +713,19 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
}, nil
}
// retryPostJWS will retry calls to postJWS if there is a badNonce error,
// clearing the stored nonces after each error.
// If the response was 4XX-5XX, then responseError is called on the body,
// the body is closed, and the error returned.
func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
sleep := sleeper(ctx)
for {
res, err := c.postJWS(ctx, key, url, body)
if err != nil {
return nil, err
}
// handle errors 4XX-5XX with responseError
if res.StatusCode >= 400 && res.StatusCode <= 599 {
err := responseError(res)
res.Body.Close()
// according to spec badNonce is urn:ietf:params:acme:error:badNonce
// however, acme servers in the wild return their version of the error
// https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
if ae, ok := err.(*Error); ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") {
// clear any nonces that we might've stored that might now be
// considered bad
c.clearNonces()
retry := res.Header.Get("Retry-After")
if err := sleep(retry, 1); err != nil {
return nil, err
}
continue
}
return nil, err
}
return res, nil
}
}
// postJWS signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
nonce, err := c.popNonce(ctx, url)
if err != nil {
return nil, err
}
b, err := jwsEncodeJSON(body, key, nonce)
if err != nil {
return nil, err
}
res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
if err != nil {
return nil, err
}
c.addNonce(res.Header)
return res, nil
}
// popNonce returns a nonce value previously stored with c.addNonce
// or fetches a fresh one from the given URL.
// or fetches a fresh one from a URL by issuing a HEAD request.
// It first tries c.directoryURL() and then the provided url if the former fails.
func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
if len(c.nonces) == 0 {
return c.fetchNonce(ctx, url)
dirURL := c.directoryURL()
v, err := c.fetchNonce(ctx, dirURL)
if err != nil && url != dirURL {
v, err = c.fetchNonce(ctx, url)
}
return v, err
}
var nonce string
for nonce = range c.nonces {
@ -742,58 +759,12 @@ func (c *Client) addNonce(h http.Header) {
c.nonces[v] = struct{}{}
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", urlStr, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return c.do(ctx, req)
}
func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
res, err := c.httpClient().Do(req.WithContext(ctx))
if err != nil {
select {
case <-ctx.Done():
// Prefer the unadorned context error.
// (The acme package had tests assuming this, previously from ctxhttp's
// behavior, predating net/http supporting contexts natively)
// TODO(bradfitz): reconsider this in the future. But for now this
// requires no test updates.
return nil, ctx.Err()
default:
return nil, err
}
}
return res, nil
}
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
resp, err := c.head(ctx, url)
r, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return "", err
}
resp, err := c.doNoRetry(ctx, r)
if err != nil {
return "", err
}
@ -845,24 +816,6 @@ func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bo
return cert, nil
}
// responseError creates an error of Error type from resp.
func responseError(resp *http.Response) error {
// don't care if ReadAll returns an error:
// json.Unmarshal will fail in that case anyway
b, _ := ioutil.ReadAll(resp.Body)
e := &wireError{Status: resp.StatusCode}
if err := json.Unmarshal(b, e); err != nil {
// this is not a regular error response:
// populate detail with anything we received,
// e.Status will already contain HTTP response code value
e.Detail = string(b)
if e.Detail == "" {
e.Detail = resp.Status
}
}
return e.error(resp.Header)
}
// chainCert fetches CA certificate chain recursively by following "up" links.
// Each recursive call increments the depth by 1, resulting in an error
// if the recursion level reaches maxChainLen.
@ -873,14 +826,11 @@ func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte
return nil, errors.New("acme: certificate chain is too deep")
}
res, err := c.get(ctx, url)
res, err := c.get(ctx, url, wantStatus(http.StatusOK))
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, responseError(res)
}
b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
if err != nil {
return nil, err
@ -925,65 +875,6 @@ func linkHeader(h http.Header, rel string) []string {
return links
}
// sleeper returns a function that accepts the Retry-After HTTP header value
// and an increment that's used with backoff to increasingly sleep on
// consecutive calls until the context is done. If the Retry-After header
// cannot be parsed, then backoff is used with a maximum sleep time of 10
// seconds.
func sleeper(ctx context.Context) func(ra string, inc int) error {
var count int
return func(ra string, inc int) error {
count += inc
d := backoff(count, 10*time.Second)
d = retryAfter(ra, d)
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
}
// retryAfter parses a Retry-After HTTP header value,
// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
// It returns d if v cannot be parsed.
func retryAfter(v string, d time.Duration) time.Duration {
if i, err := strconv.Atoi(v); err == nil {
return time.Duration(i) * time.Second
}
t, err := http.ParseTime(v)
if err != nil {
return d
}
return t.Sub(timeNow())
}
// backoff computes a duration after which an n+1 retry iteration should occur
// using truncated exponential backoff algorithm.
//
// The n argument is always bounded between 0 and 30.
// The max argument defines upper bound for the returned value.
func backoff(n int, max time.Duration) time.Duration {
if n < 0 {
n = 0
}
if n > 30 {
n = 30
}
var d time.Duration
if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
d = time.Duration(x.Int64()) * time.Millisecond
}
d += time.Duration(1<<uint(n)) * time.Second
if d > max {
return max
}
return d
}
// keyAuth generates a key authorization string for a given token.
func keyAuth(pub crypto.PublicKey, token string) (string, error) {
th, err := JWKThumbprint(pub)
@ -993,15 +884,25 @@ func keyAuth(pub crypto.PublicKey, token string) (string, error) {
return fmt.Sprintf("%s.%s", token, th), nil
}
// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges.
func defaultTLSChallengeCertTemplate() *x509.Certificate {
return &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
}
// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges
// with the given SANs and auto-generated public/private key pair.
// The Subject Common Name is set to the first SAN to aid debugging.
// To create a cert with a custom key pair, specify WithKey option.
func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
var (
key crypto.Signer
tmpl *x509.Certificate
)
var key crypto.Signer
tmpl := defaultTLSChallengeCertTemplate()
for _, o := range opt {
switch o := o.(type) {
case *certOptKey:
@ -1010,7 +911,7 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
}
key = o.key
case *certOptTemplate:
var t = *(*x509.Certificate)(o) // shallow copy is ok
t := *(*x509.Certificate)(o) // shallow copy is ok
tmpl = &t
default:
// package's fault, if we let this happen:
@ -1023,16 +924,6 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
return tls.Certificate{}, err
}
}
if tmpl == nil {
tmpl = &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
}
tmpl.DNSNames = san
if len(san) > 0 {
tmpl.Subject.CommonName = san[0]

@ -13,9 +13,9 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"net/http/httptest"
@ -75,6 +75,7 @@ func TestDiscover(t *testing.T) {
)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "testnonce")
fmt.Fprintf(w, `{
"new-reg": %q,
"new-authz": %q,
@ -100,6 +101,9 @@ func TestDiscover(t *testing.T) {
if dir.RevokeURL != revoke {
t.Errorf("dir.RevokeURL = %q; want %q", dir.RevokeURL, revoke)
}
if _, exist := c.nonces["testnonce"]; !exist {
t.Errorf("c.nonces = %q; want 'testnonce' in the map", c.nonces)
}
}
func TestRegister(t *testing.T) {
@ -147,7 +151,11 @@ func TestRegister(t *testing.T) {
return false
}
c := Client{Key: testKeyEC, dir: &Directory{RegURL: ts.URL}}
c := Client{
Key: testKeyEC,
DirectoryURL: ts.URL,
dir: &Directory{RegURL: ts.URL},
}
a := &Account{Contact: contacts}
var err error
if a, err = c.Register(context.Background(), a, prompt); err != nil {
@ -288,106 +296,131 @@ func TestGetReg(t *testing.T) {
}
func TestAuthorize(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
w.Header().Set("Replay-Nonce", "test-nonce")
return
}
if r.Method != "POST" {
t.Errorf("r.Method = %q; want POST", r.Method)
}
var j struct {
Resource string
Identifier struct {
Type string
Value string
}
}
decodeJWSRequest(t, &j, r)
// Test request
if j.Resource != "new-authz" {
t.Errorf("j.Resource = %q; want new-authz", j.Resource)
}
if j.Identifier.Type != "dns" {
t.Errorf("j.Identifier.Type = %q; want dns", j.Identifier.Type)
}
if j.Identifier.Value != "example.com" {
t.Errorf("j.Identifier.Value = %q; want example.com", j.Identifier.Value)
}
w.Header().Set("Location", "https://ca.tld/acme/auth/1")
w.WriteHeader(http.StatusCreated)
fmt.Fprintf(w, `{
"identifier": {"type":"dns","value":"example.com"},
"status":"pending",
"challenges":[
{
"type":"http-01",
"status":"pending",
"uri":"https://ca.tld/acme/challenge/publickey/id1",
"token":"token1"
},
{
"type":"tls-sni-01",
"status":"pending",
"uri":"https://ca.tld/acme/challenge/publickey/id2",
"token":"token2"
tt := []struct{ typ, value string }{
{"dns", "example.com"},
{"ip", "1.2.3.4"},
}
for _, test := range tt {
t.Run(test.typ, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
w.Header().Set("Replay-Nonce", "test-nonce")
return
}
if r.Method != "POST" {
t.Errorf("r.Method = %q; want POST", r.Method)
}
],
"combinations":[[0],[1]]}`)
}))
defer ts.Close()
cl := Client{Key: testKeyEC, dir: &Directory{AuthzURL: ts.URL}}
auth, err := cl.Authorize(context.Background(), "example.com")
if err != nil {
t.Fatal(err)
}
var j struct {
Resource string
Identifier struct {
Type string
Value string
}
}
decodeJWSRequest(t, &j, r)
if auth.URI != "https://ca.tld/acme/auth/1" {
t.Errorf("URI = %q; want https://ca.tld/acme/auth/1", auth.URI)
}
if auth.Status != "pending" {
t.Errorf("Status = %q; want pending", auth.Status)
}
if auth.Identifier.Type != "dns" {
t.Errorf("Identifier.Type = %q; want dns", auth.Identifier.Type)
}
if auth.Identifier.Value != "example.com" {
t.Errorf("Identifier.Value = %q; want example.com", auth.Identifier.Value)
}
// Test request
if j.Resource != "new-authz" {
t.Errorf("j.Resource = %q; want new-authz", j.Resource)
}
if j.Identifier.Type != test.typ {
t.Errorf("j.Identifier.Type = %q; want %q", j.Identifier.Type, test.typ)
}
if j.Identifier.Value != test.value {
t.Errorf("j.Identifier.Value = %q; want %q", j.Identifier.Value, test.value)
}
if n := len(auth.Challenges); n != 2 {
t.Fatalf("len(auth.Challenges) = %d; want 2", n)
}
w.Header().Set("Location", "https://ca.tld/acme/auth/1")
w.WriteHeader(http.StatusCreated)
fmt.Fprintf(w, `{
"identifier": {"type":%q,"value":%q},
"status":"pending",
"challenges":[
{
"type":"http-01",
"status":"pending",
"uri":"https://ca.tld/acme/challenge/publickey/id1",
"token":"token1"
},
{
"type":"tls-sni-01",
"status":"pending",
"uri":"https://ca.tld/acme/challenge/publickey/id2",
"token":"token2"
}
],
"combinations":[[0],[1]]
}`, test.typ, test.value)
}))
defer ts.Close()
c := auth.Challenges[0]
if c.Type != "http-01" {
t.Errorf("c.Type = %q; want http-01", c.Type)
}
if c.URI != "https://ca.tld/acme/challenge/publickey/id1" {
t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id1", c.URI)
}
if c.Token != "token1" {
t.Errorf("c.Token = %q; want token1", c.Token)
}
var (
auth *Authorization
err error
)
cl := Client{
Key: testKeyEC,
DirectoryURL: ts.URL,
dir: &Directory{AuthzURL: ts.URL},
}
switch test.typ {
case "dns":
auth, err = cl.Authorize(context.Background(), test.value)
case "ip":
auth, err = cl.AuthorizeIP(context.Background(), test.value)
default:
t.Fatalf("unknown identifier type: %q", test.typ)
}
if err != nil {
t.Fatal(err)
}
c = auth.Challenges[1]
if c.Type != "tls-sni-01" {
t.Errorf("c.Type = %q; want tls-sni-01", c.Type)
}
if c.URI != "https://ca.tld/acme/challenge/publickey/id2" {
t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id2", c.URI)
}
if c.Token != "token2" {
t.Errorf("c.Token = %q; want token2", c.Token)
}
if auth.URI != "https://ca.tld/acme/auth/1" {
t.Errorf("URI = %q; want https://ca.tld/acme/auth/1", auth.URI)
}
if auth.Status != "pending" {
t.Errorf("Status = %q; want pending", auth.Status)
}
if auth.Identifier.Type != test.typ {
t.Errorf("Identifier.Type = %q; want %q", auth.Identifier.Type, test.typ)
}
if auth.Identifier.Value != test.value {
t.Errorf("Identifier.Value = %q; want %q", auth.Identifier.Value, test.value)
}
combs := [][]int{{0}, {1}}
if !reflect.DeepEqual(auth.Combinations, combs) {
t.Errorf("auth.Combinations: %+v\nwant: %+v\n", auth.Combinations, combs)
if n := len(auth.Challenges); n != 2 {
t.Fatalf("len(auth.Challenges) = %d; want 2", n)
}
c := auth.Challenges[0]
if c.Type != "http-01" {
t.Errorf("c.Type = %q; want http-01", c.Type)
}
if c.URI != "https://ca.tld/acme/challenge/publickey/id1" {
t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id1", c.URI)
}
if c.Token != "token1" {
t.Errorf("c.Token = %q; want token1", c.Token)
}
c = auth.Challenges[1]
if c.Type != "tls-sni-01" {
t.Errorf("c.Type = %q; want tls-sni-01", c.Type)
}
if c.URI != "https://ca.tld/acme/challenge/publickey/id2" {
t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id2", c.URI)
}
if c.Token != "token2" {
t.Errorf("c.Token = %q; want token2", c.Token)
}
combs := [][]int{{0}, {1}}
if !reflect.DeepEqual(auth.Combinations, combs) {
t.Errorf("auth.Combinations: %+v\nwant: %+v\n", auth.Combinations, combs)
}
})
}
}
@ -401,7 +434,11 @@ func TestAuthorizeValid(t *testing.T) {
w.Write([]byte(`{"status":"valid"}`))
}))
defer ts.Close()
client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
client := Client{
Key: testKey,
DirectoryURL: ts.URL,
dir: &Directory{AuthzURL: ts.URL},
}
_, err := client.Authorize(context.Background(), "example.com")
if err != nil {
t.Errorf("err = %v", err)
@ -485,95 +522,98 @@ func TestGetAuthorization(t *testing.T) {
}
func TestWaitAuthorization(t *testing.T) {
var count int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("Retry-After", "0")
if count > 1 {
fmt.Fprintf(w, `{"status":"valid"}`)
return
t.Run("wait loop", func(t *testing.T) {
var count int
authz, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("Retry-After", "0")
if count > 1 {
fmt.Fprintf(w, `{"status":"valid"}`)
return
}
fmt.Fprintf(w, `{"status":"pending"}`)
})
if err != nil {
t.Fatalf("non-nil error: %v", err)
}
fmt.Fprintf(w, `{"status":"pending"}`)
}))
if authz == nil {
t.Fatal("authz is nil")
}
})
t.Run("invalid status", func(t *testing.T) {
_, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"status":"invalid"}`)
})
if _, ok := err.(*AuthorizationError); !ok {
t.Errorf("err is %v (%T); want non-nil *AuthorizationError", err, err)
}
})
t.Run("non-retriable error", func(t *testing.T) {
const code = http.StatusBadRequest
_, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(code)
})
res, ok := err.(*Error)
if !ok {
t.Fatalf("err is %v (%T); want a non-nil *Error", err, err)
}
if res.StatusCode != code {
t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, code)
}
})
for _, code := range []int{http.StatusTooManyRequests, http.StatusInternalServerError} {
t.Run(fmt.Sprintf("retriable %d error", code), func(t *testing.T) {
var count int
authz, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("Retry-After", "0")
if count > 1 {
fmt.Fprintf(w, `{"status":"valid"}`)
return
}
w.WriteHeader(code)
})
if err != nil {
t.Fatalf("non-nil error: %v", err)
}
if authz == nil {
t.Fatal("authz is nil")
}
})
}
t.Run("context cancel", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_, err := runWaitAuthorization(ctx, t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Retry-After", "60")
fmt.Fprintf(w, `{"status":"pending"}`)
})
if err == nil {
t.Error("err is nil")
}
})
}
func runWaitAuthorization(ctx context.Context, t *testing.T, h http.HandlerFunc) (*Authorization, error) {
t.Helper()
ts := httptest.NewServer(h)
defer ts.Close()
type res struct {
authz *Authorization
err error
}
done := make(chan res)
defer close(done)
ch := make(chan res, 1)
go func() {
var client Client
a, err := client.WaitAuthorization(context.Background(), ts.URL)
done <- res{a, err}
a, err := client.WaitAuthorization(ctx, ts.URL)
ch <- res{a, err}
}()
select {
case <-time.After(5 * time.Second):
t.Fatal("WaitAuthz took too long to return")
case res := <-done:
if res.err != nil {
t.Fatalf("res.err = %v", res.err)
}
if res.authz == nil {
t.Fatal("res.authz is nil")
}
}
}
func TestWaitAuthorizationInvalid(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"status":"invalid"}`)
}))
defer ts.Close()
res := make(chan error)
defer close(res)
go func() {
var client Client
_, err := client.WaitAuthorization(context.Background(), ts.URL)
res <- err
}()
select {
case <-time.After(3 * time.Second):
t.Fatal("WaitAuthz took too long to return")
case err := <-res:
if err == nil {
t.Error("err is nil")
}
if _, ok := err.(*AuthorizationError); !ok {
t.Errorf("err is %T; want *AuthorizationError", err)
}
}
}
func TestWaitAuthorizationCancel(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Retry-After", "60")
fmt.Fprintf(w, `{"status":"pending"}`)
}))
defer ts.Close()
res := make(chan error)
defer close(res)
go func() {
var client Client
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
_, err := client.WaitAuthorization(ctx, ts.URL)
res <- err
}()
select {
case <-time.After(time.Second):
t.Fatal("WaitAuthz took too long to return")
case err := <-res:
if err == nil {
t.Error("err is nil")
}
t.Fatal("WaitAuthorization took too long to return")
case v := <-ch:
return v.authz, v.err
}
panic("runWaitAuthorization: out of select")
}
func TestRevokeAuthorization(t *testing.T) {
@ -600,7 +640,7 @@ func TestRevokeAuthorization(t *testing.T) {
t.Errorf("req.Delete is false")
}
case "/2":
w.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(http.StatusBadRequest)
}
}))
defer ts.Close()
@ -821,7 +861,7 @@ func TestFetchCertRetry(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count < 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusAccepted)
w.WriteHeader(http.StatusTooManyRequests)
count++
return
}
@ -1013,6 +1053,53 @@ func TestNonce_fetchError(t *testing.T) {
}
}
func TestNonce_popWhenEmpty(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "HEAD" {
t.Errorf("r.Method = %q; want HEAD", r.Method)
}
switch r.URL.Path {
case "/dir-with-nonce":
w.Header().Set("Replay-Nonce", "dirnonce")
case "/new-nonce":
w.Header().Set("Replay-Nonce", "newnonce")
case "/dir-no-nonce", "/empty":
// No nonce in the header.
default:
t.Errorf("Unknown URL: %s", r.URL)
}
}))
defer ts.Close()
ctx := context.Background()
tt := []struct {
dirURL, popURL, nonce string
wantOK bool
}{
{ts.URL + "/dir-with-nonce", ts.URL + "/new-nonce", "dirnonce", true},
{ts.URL + "/dir-no-nonce", ts.URL + "/new-nonce", "newnonce", true},
{ts.URL + "/dir-no-nonce", ts.URL + "/empty", "", false},
}
for _, test := range tt {
t.Run(fmt.Sprintf("nonce:%s wantOK:%v", test.nonce, test.wantOK), func(t *testing.T) {
c := Client{DirectoryURL: test.dirURL}
v, err := c.popNonce(ctx, test.popURL)
if !test.wantOK {
if err == nil {
t.Fatalf("c.popNonce(%q) returned nil error", test.popURL)
}
return
}
if err != nil {
t.Fatalf("c.popNonce(%q): %v", test.popURL, err)
}
if v != test.nonce {
t.Errorf("c.popNonce(%q) = %q; want %q", test.popURL, v, test.nonce)
}
})
}
}
func TestNonce_postJWS(t *testing.T) {
var count int
seen := make(map[string]bool)
@ -1046,7 +1133,11 @@ func TestNonce_postJWS(t *testing.T) {
}))
defer ts.Close()
client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
client := Client{
Key: testKey,
DirectoryURL: ts.URL, // nonces are fetched from here first
dir: &Directory{AuthzURL: ts.URL},
}
if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
t.Errorf("client.Authorize 1: %v", err)
}
@ -1068,44 +1159,6 @@ func TestNonce_postJWS(t *testing.T) {
}
}
func TestRetryPostJWS(t *testing.T) {
var count int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
w.Header().Set("Replay-Nonce", fmt.Sprintf("nonce%d", count))
if r.Method == "HEAD" {
// We expect the client to do 2 head requests to fetch
// nonces, one to start and another after getting badNonce
return
}
head, err := decodeJWSHead(r)
if err != nil {
t.Errorf("decodeJWSHead: %v", err)
} else if head.Nonce == "" {
t.Error("head.Nonce is empty")
} else if head.Nonce == "nonce1" {
// return a badNonce error to force the call to retry
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`))
return
}
// Make client.Authorize happy; we're not testing its result.
w.WriteHeader(http.StatusCreated)
w.Write([]byte(`{"status":"valid"}`))
}))
defer ts.Close()
client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
// This call will fail with badNonce, causing a retry
if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
t.Errorf("client.Authorize 1: %v", err)
}
if count != 4 {
t.Errorf("total requests count: %d; want 4", count)
}
}
func TestLinkHeader(t *testing.T) {
h := http.Header{"Link": {
`<https://example.com/acme/new-authz>;rel="next"`,
@ -1129,37 +1182,6 @@ func TestLinkHeader(t *testing.T) {
}
}
func TestErrorResponse(t *testing.T) {
s := `{
"status": 400,
"type": "urn:acme:error:xxx",
"detail": "text"
}`
res := &http.Response{
StatusCode: 400,
Status: "400 Bad Request",
Body: ioutil.NopCloser(strings.NewReader(s)),
Header: http.Header{"X-Foo": {"bar"}},
}
err := responseError(res)
v, ok := err.(*Error)
if !ok {
t.Fatalf("err = %+v (%T); want *Error type", err, err)
}
if v.StatusCode != 400 {
t.Errorf("v.StatusCode = %v; want 400", v.StatusCode)
}
if v.ProblemType != "urn:acme:error:xxx" {
t.Errorf("v.ProblemType = %q; want urn:acme:error:xxx", v.ProblemType)
}
if v.Detail != "text" {
t.Errorf("v.Detail = %q; want text", v.Detail)
}
if !reflect.DeepEqual(v.Header, res.Header) {
t.Errorf("v.Header = %+v; want %+v", v.Header, res.Header)
}
}
func TestTLSSNI01ChallengeCert(t *testing.T) {
const (
token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA"
@ -1227,6 +1249,58 @@ func TestTLSSNI02ChallengeCert(t *testing.T) {
}
}
func TestTLSALPN01ChallengeCert(t *testing.T) {
const (
token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA"
keyAuth = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA." + testKeyECThumbprint
// echo -n <token.testKeyECThumbprint> | shasum -a 256
h = "0420dbbd5eefe7b4d06eb9d1d9f5acb4c7cda27d320e4b30332f0b6cb441734ad7b0"
domain = "example.com"
)
extValue, err := hex.DecodeString(h)
if err != nil {
t.Fatal(err)
}
client := &Client{Key: testKeyEC}
tlscert, err := client.TLSALPN01ChallengeCert(token, domain)
if err != nil {
t.Fatal(err)
}
if n := len(tlscert.Certificate); n != 1 {
t.Fatalf("len(tlscert.Certificate) = %d; want 1", n)
}
cert, err := x509.ParseCertificate(tlscert.Certificate[0])
if err != nil {
t.Fatal(err)
}
names := []string{domain}
if !reflect.DeepEqual(cert.DNSNames, names) {
t.Fatalf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names)
}
if cn := cert.Subject.CommonName; cn != domain {
t.Errorf("CommonName = %q; want %q", cn, domain)
}
acmeExts := []pkix.Extension{}
for _, ext := range cert.Extensions {
if idPeACMEIdentifierV1.Equal(ext.Id) {
acmeExts = append(acmeExts, ext)
}
}
if len(acmeExts) != 1 {
t.Errorf("acmeExts = %v; want exactly one", acmeExts)
}
if !acmeExts[0].Critical {
t.Errorf("acmeExt.Critical = %v; want true", acmeExts[0].Critical)
}
if bytes.Compare(acmeExts[0].Value, extValue) != 0 {
t.Errorf("acmeExt.Value = %v; want %v", acmeExts[0].Value, extValue)
}
}
func TestTLSChallengeCertOpt(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
@ -1325,28 +1399,3 @@ func TestDNS01ChallengeRecord(t *testing.T) {
t.Errorf("val = %q; want %q", val, value)
}
}
func TestBackoff(t *testing.T) {
tt := []struct{ min, max time.Duration }{
{time.Second, 2 * time.Second},
{2 * time.Second, 3 * time.Second},
{4 * time.Second, 5 * time.Second},
{8 * time.Second, 9 * time.Second},
}
for i, test := range tt {
d := backoff(i, time.Minute)
if d < test.min || test.max < d {
t.Errorf("%d: d = %v; want between %v and %v", i, d, test.min, test.max)
}
}
min, max := time.Second, 2*time.Second
if d := backoff(-1, time.Minute); d < min || max < d {
t.Errorf("d = %v; want between %v and %v", d, min, max)
}
bound := 10 * time.Second
if d := backoff(100, bound); d != bound {
t.Errorf("d = %v; want %v", d, bound)
}
}

@ -27,7 +27,6 @@ import (
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
"time"
@ -45,7 +44,7 @@ var createCertRetryAfter = time.Minute
var pseudoRand *lockedMathRand
func init() {
src := mathrand.NewSource(timeNow().UnixNano())
src := mathrand.NewSource(time.Now().UnixNano())
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)}
}
@ -70,7 +69,7 @@ func HostWhitelist(hosts ...string) HostPolicy {
}
return func(_ context.Context, host string) error {
if !whitelist[host] {
return errors.New("acme/autocert: host not configured")
return fmt.Errorf("acme/autocert: host %q not configured in HostWhitelist", host)
}
return nil
}
@ -82,9 +81,9 @@ func defaultHostPolicy(context.Context, string) error {
}
// Manager is a stateful certificate manager built on top of acme.Client.
// It obtains and refreshes certificates automatically using "tls-sni-01",
// "tls-sni-02" and "http-01" challenge types, as well as providing them
// to a TLS server via tls.Config.
// It obtains and refreshes certificates automatically using "tls-alpn-01",
// "tls-sni-01", "tls-sni-02" and "http-01" challenge types,
// as well as providing them to a TLS server via tls.Config.
//
// You must specify a cache implementation, such as DirCache,
// to reuse obtained certificates across program restarts.
@ -99,11 +98,11 @@ type Manager struct {
// To always accept the terms, the callers can use AcceptTOS.
Prompt func(tosURL string) bool
// Cache optionally stores and retrieves previously-obtained certificates.
// If nil, certs will only be cached for the lifetime of the Manager.
// Cache optionally stores and retrieves previously-obtained certificates
// and other state. If nil, certs will only be cached for the lifetime of
// the Manager. Multiple Managers can share the same Cache.
//
// Manager passes the Cache certificates data encoded in PEM, with private/public
// parts combined in a single Cache.Put call, private key first.
// Using a persistent Cache, such as DirCache, is strongly recommended.
Cache Cache
// HostPolicy controls which domains the Manager will attempt
@ -128,8 +127,10 @@ type Manager struct {
// Client is used to perform low-level operations, such as account registration
// and requesting new certificates.
//
// If Client is nil, a zero-value acme.Client is used with acme.LetsEncryptURL
// directory endpoint and a newly-generated ECDSA P-256 key.
// as directory endpoint. If the Client.Key is nil, a new ECDSA P-256 key is
// generated and, if Cache is not nil, stored in cache.
//
// Mutating the field after the first call of GetCertificate method will have no effect.
Client *acme.Client
@ -141,22 +142,30 @@ type Manager struct {
// If the Client's account key is already registered, Email is not used.
Email string
// ForceRSA makes the Manager generate certificates with 2048-bit RSA keys.
// ForceRSA used to make the Manager generate RSA certificates. It is now ignored.
//
// If false, a default is used. Currently the default
// is EC-based keys using the P-256 curve.
// Deprecated: the Manager will request the correct type of certificate based
// on what each client supports.
ForceRSA bool
// ExtraExtensions are used when generating a new CSR (Certificate Request),
// thus allowing customization of the resulting certificate.
// For instance, TLS Feature Extension (RFC 7633) can be used
// to prevent an OCSP downgrade attack.
//
// The field value is passed to crypto/x509.CreateCertificateRequest
// in the template's ExtraExtensions field as is.
ExtraExtensions []pkix.Extension
clientMu sync.Mutex
client *acme.Client // initialized by acmeClient method
stateMu sync.Mutex
state map[string]*certState // keyed by domain name
state map[certKey]*certState
// renewal tracks the set of domains currently running renewal timers.
// It is keyed by domain name.
renewalMu sync.Mutex
renewal map[string]*domainRenewal
renewal map[certKey]*domainRenewal
// tokensMu guards the rest of the fields: tryHTTP01, certTokens and httpTokens.
tokensMu sync.RWMutex
@ -168,21 +177,60 @@ type Manager struct {
// to be provisioned.
// The entries are stored for the duration of the authorization flow.
httpTokens map[string][]byte
// certTokens contains temporary certificates for tls-sni challenges
// certTokens contains temporary certificates for tls-sni and tls-alpn challenges
// and is keyed by token domain name, which matches server name of ClientHello.
// Keys always have ".acme.invalid" suffix.
// Keys always have ".acme.invalid" suffix for tls-sni. Otherwise, they are domain names
// for tls-alpn.
// The entries are stored for the duration of the authorization flow.
certTokens map[string]*tls.Certificate
// nowFunc, if not nil, returns the current time. This may be set for
// testing purposes.
nowFunc func() time.Time
}
// certKey is the key by which certificates are tracked in state, renewal and cache.
type certKey struct {
domain string // without trailing dot
isRSA bool // RSA cert for legacy clients (as opposed to default ECDSA)
isToken bool // tls-based challenge token cert; key type is undefined regardless of isRSA
}
func (c certKey) String() string {
if c.isToken {
return c.domain + "+token"
}
if c.isRSA {
return c.domain + "+rsa"
}
return c.domain
}
// TLSConfig creates a new TLS config suitable for net/http.Server servers,
// supporting HTTP/2 and the tls-alpn-01 ACME challenge type.
func (m *Manager) TLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: m.GetCertificate,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
acme.ALPNProto, // enable tls-alpn ACME challenges
},
}
}
// GetCertificate implements the tls.Config.GetCertificate hook.
// It provides a TLS certificate for hello.ServerName host, including answering
// *.acme.invalid (TLS-SNI) challenges. All other fields of hello are ignored.
// tls-alpn-01 and *.acme.invalid (tls-sni-01 and tls-sni-02) challenges.
// All other fields of hello are ignored.
//
// If m.HostPolicy is non-nil, GetCertificate calls the policy before requesting
// a new cert. A non-nil error returned from m.HostPolicy halts TLS negotiation.
// The error is propagated back to the caller of GetCertificate and is user-visible.
// This does not affect cached certs. See HostPolicy field description for more details.
//
// If GetCertificate is used directly, instead of via Manager.TLSConfig, package users will
// also have to add acme.ALPNProto to NextProtos for tls-alpn-01, or use HTTPHandler
// for http-01. (The tls-sni-* challenges have been deprecated by popular ACME providers
// due to security issues in the ecosystem.)
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if m.Prompt == nil {
return nil, errors.New("acme/autocert: Manager.Prompt not set")
@ -195,7 +243,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if !strings.Contains(strings.Trim(name, "."), ".") {
return nil, errors.New("acme/autocert: server name component count invalid")
}
if strings.ContainsAny(name, `/\`) {
if strings.ContainsAny(name, `+/\`) {
return nil, errors.New("acme/autocert: server name contains invalid character")
}
@ -204,14 +252,17 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge
if strings.HasSuffix(name, ".acme.invalid") {
// Check whether this is a token cert requested for TLS-SNI or TLS-ALPN challenge.
if wantsTokenCert(hello) {
m.tokensMu.RLock()
defer m.tokensMu.RUnlock()
// It's ok to use the same token cert key for both tls-sni and tls-alpn
// because there's always at most 1 token cert per on-going domain authorization.
// See m.verify for details.
if cert := m.certTokens[name]; cert != nil {
return cert, nil
}
if cert, err := m.cacheGet(ctx, name); err == nil {
if cert, err := m.cacheGet(ctx, certKey{domain: name, isToken: true}); err == nil {
return cert, nil
}
// TODO: cache error results?
@ -219,8 +270,11 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
}
// regular domain
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114
cert, err := m.cert(ctx, name)
ck := certKey{
domain: strings.TrimSuffix(name, "."), // golang.org/issue/18114
isRSA: !supportsECDSA(hello),
}
cert, err := m.cert(ctx, ck)
if err == nil {
return cert, nil
}
@ -232,14 +286,71 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
cert, err = m.createCert(ctx, name)
cert, err = m.createCert(ctx, ck)
if err != nil {
return nil, err
}
m.cachePut(ctx, name, cert)
m.cachePut(ctx, ck, cert)
return cert, nil
}
// wantsTokenCert reports whether a TLS request with SNI is made by a CA server
// for a challenge verification.
func wantsTokenCert(hello *tls.ClientHelloInfo) bool {
// tls-alpn-01
if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == acme.ALPNProto {
return true
}
// tls-sni-xx
return strings.HasSuffix(hello.ServerName, ".acme.invalid")
}
func supportsECDSA(hello *tls.ClientHelloInfo) bool {
// The "signature_algorithms" extension, if present, limits the key exchange
// algorithms allowed by the cipher suites. See RFC 5246, section 7.4.1.4.1.
if hello.SignatureSchemes != nil {
ecdsaOK := false
schemeLoop:
for _, scheme := range hello.SignatureSchemes {
const tlsECDSAWithSHA1 tls.SignatureScheme = 0x0203 // constant added in Go 1.10
switch scheme {
case tlsECDSAWithSHA1, tls.ECDSAWithP256AndSHA256,
tls.ECDSAWithP384AndSHA384, tls.ECDSAWithP521AndSHA512:
ecdsaOK = true
break schemeLoop
}
}
if !ecdsaOK {
return false
}
}
if hello.SupportedCurves != nil {
ecdsaOK := false
for _, curve := range hello.SupportedCurves {
if curve == tls.CurveP256 {
ecdsaOK = true
break
}
}
if !ecdsaOK {
return false
}
}
for _, suite := range hello.CipherSuites {
switch suite {
case tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
return true
}
}
return false
}
// HTTPHandler configures the Manager to provision ACME "http-01" challenge responses.
// It returns an http.Handler that responds to the challenges and must be
// running on port 80. If it receives a request that is not an ACME challenge,
@ -253,8 +364,8 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
// Because the fallback handler is run with unencrypted port 80 requests,
// the fallback should not serve TLS-only requests.
//
// If HTTPHandler is never called, the Manager will only use TLS SNI
// challenges for domain verification.
// If HTTPHandler is never called, the Manager will only use the "tls-alpn-01"
// challenge for domain verification.
func (m *Manager) HTTPHandler(fallback http.Handler) http.Handler {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
@ -305,16 +416,16 @@ func stripPort(hostport string) string {
// cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value.
func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) {
func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error) {
m.stateMu.Lock()
if s, ok := m.state[name]; ok {
if s, ok := m.state[ck]; ok {
m.stateMu.Unlock()
s.RLock()
defer s.RUnlock()
return s.tlscert()
}
defer m.stateMu.Unlock()
cert, err := m.cacheGet(ctx, name)
cert, err := m.cacheGet(ctx, ck)
if err != nil {
return nil, err
}
@ -323,25 +434,25 @@ func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, erro
return nil, errors.New("acme/autocert: private key cannot sign")
}
if m.state == nil {
m.state = make(map[string]*certState)
m.state = make(map[certKey]*certState)
}
s := &certState{
key: signer,
cert: cert.Certificate,
leaf: cert.Leaf,
}
m.state[name] = s
go m.renew(name, s.key, s.leaf.NotAfter)
m.state[ck] = s
go m.renew(ck, s.key, s.leaf.NotAfter)
return cert, nil
}
// cacheGet always returns a valid certificate, or an error otherwise.
// If a cached certficate exists but is not valid, ErrCacheMiss is returned.
func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) {
// If a cached certificate exists but is not valid, ErrCacheMiss is returned.
func (m *Manager) cacheGet(ctx context.Context, ck certKey) (*tls.Certificate, error) {
if m.Cache == nil {
return nil, ErrCacheMiss
}
data, err := m.Cache.Get(ctx, domain)
data, err := m.Cache.Get(ctx, ck.String())
if err != nil {
return nil, err
}
@ -372,7 +483,7 @@ func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate
}
// verify and create TLS cert
leaf, err := validCert(domain, pubDER, privKey)
leaf, err := validCert(ck, pubDER, privKey, m.now())
if err != nil {
return nil, ErrCacheMiss
}
@ -384,7 +495,7 @@ func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate
return tlscert, nil
}
func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error {
func (m *Manager) cachePut(ctx context.Context, ck certKey, tlscert *tls.Certificate) error {
if m.Cache == nil {
return nil
}
@ -416,7 +527,7 @@ func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Cert
}
}
return m.Cache.Put(ctx, domain, buf.Bytes())
return m.Cache.Put(ctx, ck.String(), buf.Bytes())
}
func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error {
@ -433,9 +544,9 @@ func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error {
//
// If the domain is already being verified, it waits for the existing verification to complete.
// Either way, createCert blocks for the duration of the whole process.
func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certificate, error) {
func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate, error) {
// TODO: maybe rewrite this whole piece using sync.Once
state, err := m.certState(domain)
state, err := m.certState(ck)
if err != nil {
return nil, err
}
@ -453,44 +564,44 @@ func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certifica
defer state.Unlock()
state.locked = false
der, leaf, err := m.authorizedCert(ctx, state.key, domain)
der, leaf, err := m.authorizedCert(ctx, state.key, ck)
if err != nil {
// Remove the failed state after some time,
// making the manager call createCert again on the following TLS hello.
time.AfterFunc(createCertRetryAfter, func() {
defer testDidRemoveState(domain)
defer testDidRemoveState(ck)
m.stateMu.Lock()
defer m.stateMu.Unlock()
// Verify the state hasn't changed and it's still invalid
// before deleting.
s, ok := m.state[domain]
s, ok := m.state[ck]
if !ok {
return
}
if _, err := validCert(domain, s.cert, s.key); err == nil {
if _, err := validCert(ck, s.cert, s.key, m.now()); err == nil {
return
}
delete(m.state, domain)
delete(m.state, ck)
})
return nil, err
}
state.cert = der
state.leaf = leaf
go m.renew(domain, state.key, state.leaf.NotAfter)
go m.renew(ck, state.key, state.leaf.NotAfter)
return state.tlscert()
}
// certState returns a new or existing certState.
// If a new certState is returned, state.exist is false and the state is locked.
// The returned error is non-nil only in the case where a new state could not be created.
func (m *Manager) certState(domain string) (*certState, error) {
func (m *Manager) certState(ck certKey) (*certState, error) {
m.stateMu.Lock()
defer m.stateMu.Unlock()
if m.state == nil {
m.state = make(map[string]*certState)
m.state = make(map[certKey]*certState)
}
// existing state
if state, ok := m.state[domain]; ok {
if state, ok := m.state[ck]; ok {
return state, nil
}
@ -499,7 +610,7 @@ func (m *Manager) certState(domain string) (*certState, error) {
err error
key crypto.Signer
)
if m.ForceRSA {
if ck.isRSA {
key, err = rsa.GenerateKey(rand.Reader, 2048)
} else {
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -513,22 +624,22 @@ func (m *Manager) certState(domain string) (*certState, error) {
locked: true,
}
state.Lock() // will be unlocked by m.certState caller
m.state[domain] = state
m.state[ck] = state
return state, nil
}
// authorizedCert starts the domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) {
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, ck certKey) (der [][]byte, leaf *x509.Certificate, err error) {
client, err := m.acmeClient(ctx)
if err != nil {
return nil, nil, err
}
if err := m.verify(ctx, client, domain); err != nil {
if err := m.verify(ctx, client, ck.domain); err != nil {
return nil, nil, err
}
csr, err := certRequest(key, domain)
csr, err := certRequest(key, ck.domain, m.ExtraExtensions)
if err != nil {
return nil, nil, err
}
@ -536,25 +647,55 @@ func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain
if err != nil {
return nil, nil, err
}
leaf, err = validCert(domain, der, key)
leaf, err = validCert(ck, der, key, m.now())
if err != nil {
return nil, nil, err
}
return der, leaf, nil
}
// revokePendingAuthz revokes all authorizations idenfied by the elements of uri slice.
// It ignores revocation errors.
func (m *Manager) revokePendingAuthz(ctx context.Context, uri []string) {
client, err := m.acmeClient(ctx)
if err != nil {
return
}
for _, u := range uri {
client.RevokeAuthorization(ctx, u)
}
}
// verify runs the identifier (domain) authorization flow
// using each applicable ACME challenge type.
func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string) error {
// The list of challenge types we'll try to fulfill
// in this specific order.
challengeTypes := []string{"tls-sni-02", "tls-sni-01"}
challengeTypes := []string{"tls-alpn-01", "tls-sni-02", "tls-sni-01"}
m.tokensMu.RLock()
if m.tryHTTP01 {
challengeTypes = append(challengeTypes, "http-01")
}
m.tokensMu.RUnlock()
// Keep track of pending authzs and revoke the ones that did not validate.
pendingAuthzs := make(map[string]bool)
defer func() {
var uri []string
for k, pending := range pendingAuthzs {
if pending {
uri = append(uri, k)
}
}
if len(uri) > 0 {
// Use "detached" background context.
// The revocations need not happen in the current verification flow.
go m.revokePendingAuthz(context.Background(), uri)
}
}()
// errs accumulates challenge failure errors, printed if all fail
errs := make(map[*acme.Challenge]error)
var nextTyp int // challengeType index of the next challenge type to try
for {
// Start domain authorization and get the challenge.
@ -571,6 +712,8 @@ func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string
return fmt.Errorf("acme/autocert: invalid authorization %q", authz.URI)
}
pendingAuthzs[authz.URI] = true
// Pick the next preferred challenge.
var chal *acme.Challenge
for chal == nil && nextTyp < len(challengeTypes) {
@ -578,28 +721,44 @@ func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string
nextTyp++
}
if chal == nil {
return fmt.Errorf("acme/autocert: unable to authorize %q; tried %q", domain, challengeTypes)
errorMsg := fmt.Sprintf("acme/autocert: unable to authorize %q", domain)
for chal, err := range errs {
errorMsg += fmt.Sprintf("; challenge %q failed with error: %v", chal.Type, err)
}
return errors.New(errorMsg)
}
cleanup, err := m.fulfill(ctx, client, chal)
cleanup, err := m.fulfill(ctx, client, chal, domain)
if err != nil {
errs[chal] = err
continue
}
defer cleanup()
if _, err := client.Accept(ctx, chal); err != nil {
errs[chal] = err
continue
}
// A challenge is fulfilled and accepted: wait for the CA to validate.
if _, err := client.WaitAuthorization(ctx, authz.URI); err == nil {
return nil
if _, err := client.WaitAuthorization(ctx, authz.URI); err != nil {
errs[chal] = err
continue
}
delete(pendingAuthzs, authz.URI)
return nil
}
}
// fulfill provisions a response to the challenge chal.
// The cleanup is non-nil only if provisioning succeeded.
func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge) (cleanup func(), err error) {
func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge, domain string) (cleanup func(), err error) {
switch chal.Type {
case "tls-alpn-01":
cert, err := client.TLSALPN01ChallengeCert(chal.Token, domain)
if err != nil {
return nil, err
}
m.putCertToken(ctx, domain, &cert)
return func() { go m.deleteCertToken(domain) }, nil
case "tls-sni-01":
cert, name, err := client.TLSSNI01ChallengeCert(chal.Token)
if err != nil {
@ -635,8 +794,8 @@ func pickChallenge(typ string, chal []*acme.Challenge) *acme.Challenge {
return nil
}
// putCertToken stores the cert under the named key in both m.certTokens map
// and m.Cache.
// putCertToken stores the token certificate with the specified name
// in both m.certTokens map and m.Cache.
func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certificate) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
@ -644,17 +803,18 @@ func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certi
m.certTokens = make(map[string]*tls.Certificate)
}
m.certTokens[name] = cert
m.cachePut(ctx, name, cert)
m.cachePut(ctx, certKey{domain: name, isToken: true}, cert)
}
// deleteCertToken removes the token certificate for the specified domain name
// deleteCertToken removes the token certificate with the specified name
// from both m.certTokens map and m.Cache.
func (m *Manager) deleteCertToken(name string) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
delete(m.certTokens, name)
if m.Cache != nil {
m.Cache.Delete(context.Background(), name)
ck := certKey{domain: name, isToken: true}
m.Cache.Delete(context.Background(), ck.String())
}
}
@ -705,7 +865,7 @@ func (m *Manager) deleteHTTPToken(tokenPath string) {
// httpTokenCacheKey returns a key at which an http-01 token value may be stored
// in the Manager's optional Cache.
func httpTokenCacheKey(tokenPath string) string {
return "http-01-" + path.Base(tokenPath)
return path.Base(tokenPath) + "+http-01"
}
// renew starts a cert renewal timer loop, one per domain.
@ -716,18 +876,18 @@ func httpTokenCacheKey(tokenPath string) string {
//
// The key argument is a certificate private key.
// The exp argument is the cert expiration time (NotAfter).
func (m *Manager) renew(domain string, key crypto.Signer, exp time.Time) {
func (m *Manager) renew(ck certKey, key crypto.Signer, exp time.Time) {
m.renewalMu.Lock()
defer m.renewalMu.Unlock()
if m.renewal[domain] != nil {
if m.renewal[ck] != nil {
// another goroutine is already on it
return
}
if m.renewal == nil {
m.renewal = make(map[string]*domainRenewal)
m.renewal = make(map[certKey]*domainRenewal)
}
dr := &domainRenewal{m: m, domain: domain, key: key}
m.renewal[domain] = dr
dr := &domainRenewal{m: m, ck: ck, key: key}
m.renewal[ck] = dr
dr.start(exp)
}
@ -743,7 +903,10 @@ func (m *Manager) stopRenew() {
}
func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) {
const keyName = "acme_account.key"
const keyName = "acme_account+key"
// Previous versions of autocert stored the value under a different key.
const legacyKeyName = "acme_account.key"
genKey := func() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -754,6 +917,9 @@ func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) {
}
data, err := m.Cache.Get(ctx, keyName)
if err == ErrCacheMiss {
data, err = m.Cache.Get(ctx, legacyKeyName)
}
if err == ErrCacheMiss {
key, err := genKey()
if err != nil {
@ -825,6 +991,13 @@ func (m *Manager) renewBefore() time.Duration {
return 720 * time.Hour // 30 days
}
func (m *Manager) now() time.Time {
if m.nowFunc != nil {
return m.nowFunc()
}
return time.Now()
}
// certState is ready when its mutex is unlocked for reading.
type certState struct {
sync.RWMutex
@ -850,12 +1023,12 @@ func (s *certState) tlscert() (*tls.Certificate, error) {
}, nil
}
// certRequest creates a certificate request for the given common name cn
// and optional SANs.
func certRequest(key crypto.Signer, cn string, san ...string) ([]byte, error) {
// certRequest generates a CSR for the given common name cn and optional SANs.
func certRequest(key crypto.Signer, cn string, ext []pkix.Extension, san ...string) ([]byte, error) {
req := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: cn},
DNSNames: san,
Subject: pkix.Name{CommonName: cn},
DNSNames: san,
ExtraExtensions: ext,
}
return x509.CreateCertificateRequest(rand.Reader, req, key)
}
@ -886,12 +1059,12 @@ func parsePrivateKey(der []byte) (crypto.Signer, error) {
return nil, errors.New("acme/autocert: failed to parse private key")
}
// validCert parses a cert chain provided as der argument and verifies the leaf, der[0],
// corresponds to the private key, as well as the domain match and expiration dates.
// It doesn't do any revocation checking.
// validCert parses a cert chain provided as der argument and verifies the leaf and der[0]
// correspond to the private key, the domain and key type match, and expiration dates
// are valid. It doesn't do any revocation checking.
//
// The returned value is the verified leaf cert.
func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certificate, err error) {
func validCert(ck certKey, der [][]byte, key crypto.Signer, now time.Time) (leaf *x509.Certificate, err error) {
// parse public part(s)
var n int
for _, b := range der {
@ -903,22 +1076,21 @@ func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certi
n += copy(pub[n:], b)
}
x509Cert, err := x509.ParseCertificates(pub)
if len(x509Cert) == 0 {
if err != nil || len(x509Cert) == 0 {
return nil, errors.New("acme/autocert: no public key found")
}
// verify the leaf is not expired and matches the domain name
leaf = x509Cert[0]
now := timeNow()
if now.Before(leaf.NotBefore) {
return nil, errors.New("acme/autocert: certificate is not valid yet")
}
if now.After(leaf.NotAfter) {
return nil, errors.New("acme/autocert: expired certificate")
}
if err := leaf.VerifyHostname(domain); err != nil {
if err := leaf.VerifyHostname(ck.domain); err != nil {
return nil, err
}
// ensure the leaf corresponds to the private key
// ensure the leaf corresponds to the private key and matches the certKey type
switch pub := leaf.PublicKey.(type) {
case *rsa.PublicKey:
prv, ok := key.(*rsa.PrivateKey)
@ -928,6 +1100,9 @@ func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certi
if pub.N.Cmp(prv.N) != 0 {
return nil, errors.New("acme/autocert: private key does not match public key")
}
if !ck.isRSA && !ck.isToken {
return nil, errors.New("acme/autocert: key type does not match expected value")
}
case *ecdsa.PublicKey:
prv, ok := key.(*ecdsa.PrivateKey)
if !ok {
@ -936,22 +1111,15 @@ func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certi
if pub.X.Cmp(prv.X) != 0 || pub.Y.Cmp(prv.Y) != 0 {
return nil, errors.New("acme/autocert: private key does not match public key")
}
if ck.isRSA && !ck.isToken {
return nil, errors.New("acme/autocert: key type does not match expected value")
}
default:
return nil, errors.New("acme/autocert: unknown public key algorithm")
}
return leaf, nil
}
func retryAfter(v string) time.Duration {
if i, err := strconv.Atoi(v); err == nil {
return time.Duration(i) * time.Second
}
if t, err := http.ParseTime(v); err == nil {
return t.Sub(timeNow())
}
return time.Second
}
type lockedMathRand struct {
sync.Mutex
rnd *mathrand.Rand
@ -966,8 +1134,6 @@ func (r *lockedMathRand) int63n(max int64) int64 {
// For easier testing.
var (
timeNow = time.Now
// Called when a state is removed.
testDidRemoveState = func(domain string) {}
testDidRemoveState = func(certKey) {}
)

@ -5,6 +5,7 @@
package autocert
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
@ -14,11 +15,13 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"fmt"
"html/template"
"io"
"io/ioutil"
"math/big"
"net/http"
"net/http/httptest"
@ -29,6 +32,13 @@ import (
"time"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert/internal/acmetest"
)
var (
exampleDomain = "example.org"
exampleCertKey = certKey{domain: exampleDomain}
exampleCertKeyRSA = certKey{domain: exampleDomain, isRSA: true}
)
var discoTmpl = template.Must(template.New("disco").Parse(`{
@ -64,6 +74,7 @@ var authzTmpl = template.Must(template.New("authz").Parse(`{
}`))
type memCache struct {
t *testing.T
mu sync.Mutex
keyData map[string][]byte
}
@ -79,7 +90,26 @@ func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
return v, nil
}
// filenameSafe returns whether all characters in s are printable ASCII
// and safe to use in a filename on most filesystems.
func filenameSafe(s string) bool {
for _, c := range s {
if c < 0x20 || c > 0x7E {
return false
}
switch c {
case '\\', '/', ':', '*', '?', '"', '<', '>', '|':
return false
}
}
return true
}
func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
if !filenameSafe(key) {
m.t.Errorf("invalid characters in cache key %q", key)
}
m.mu.Lock()
defer m.mu.Unlock()
@ -95,12 +125,29 @@ func (m *memCache) Delete(ctx context.Context, key string) error {
return nil
}
func newMemCache() *memCache {
func newMemCache(t *testing.T) *memCache {
return &memCache{
t: t,
keyData: make(map[string][]byte),
}
}
func (m *memCache) numCerts() int {
m.mu.Lock()
defer m.mu.Unlock()
res := 0
for key := range m.keyData {
if strings.HasSuffix(key, "+token") ||
strings.HasSuffix(key, "+key") ||
strings.HasSuffix(key, "+http-01") {
continue
}
res++
}
return res
}
func dummyCert(pub interface{}, san ...string) ([]byte, error) {
return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
}
@ -137,53 +184,58 @@ func decodePayload(v interface{}, r io.Reader) error {
return json.Unmarshal(payload, v)
}
func clientHelloInfo(sni string, ecdsaSupport bool) *tls.ClientHelloInfo {
hello := &tls.ClientHelloInfo{
ServerName: sni,
CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
}
if ecdsaSupport {
hello.CipherSuites = append(hello.CipherSuites, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305)
}
return hello
}
func TestGetCertificate(t *testing.T) {
man := &Manager{Prompt: AcceptTOS}
defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
hello := clientHelloInfo("example.org", true)
testGetCertificate(t, man, "example.org", hello)
}
func TestGetCertificate_trailingDot(t *testing.T) {
man := &Manager{Prompt: AcceptTOS}
defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: "example.org."}
hello := clientHelloInfo("example.org.", true)
testGetCertificate(t, man, "example.org", hello)
}
func TestGetCertificate_ForceRSA(t *testing.T) {
man := &Manager{
Prompt: AcceptTOS,
Cache: newMemCache(),
Cache: newMemCache(t),
ForceRSA: true,
}
defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello)
hello := clientHelloInfo(exampleDomain, true)
testGetCertificate(t, man, exampleDomain, hello)
cert, err := man.cacheGet(context.Background(), "example.org")
// ForceRSA was deprecated and is now ignored.
cert, err := man.cacheGet(context.Background(), exampleCertKey)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok {
t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey)
if _, ok := cert.PrivateKey.(*ecdsa.PrivateKey); !ok {
t.Errorf("cert.PrivateKey is %T; want *ecdsa.PrivateKey", cert.PrivateKey)
}
}
func TestGetCertificate_nilPrompt(t *testing.T) {
man := &Manager{}
defer man.stopRenew()
url, finish := startACMEServerStub(t, man, "example.org")
url, finish := startACMEServerStub(t, getCertificateFromManager(man, true), "example.org")
defer finish()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man.Client = &acme.Client{
Key: key,
DirectoryURL: url,
}
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
man.Client = &acme.Client{DirectoryURL: url}
hello := clientHelloInfo("example.org", true)
if _, err := man.GetCertificate(hello); err == nil {
t.Error("got certificate for example.org; wanted error")
}
@ -197,7 +249,7 @@ func TestGetCertificate_expiredCache(t *testing.T) {
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "example.org"},
Subject: pkix.Name{CommonName: exampleDomain},
NotAfter: time.Now(),
}
pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
@ -209,16 +261,16 @@ func TestGetCertificate_expiredCache(t *testing.T) {
PrivateKey: pk,
}
man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
man := &Manager{Prompt: AcceptTOS, Cache: newMemCache(t)}
defer man.stopRenew()
if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
if err := man.cachePut(context.Background(), exampleCertKey, tlscert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
// The expired cached cert should trigger a new cert issuance
// and return without an error.
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello)
hello := clientHelloInfo(exampleDomain, true)
testGetCertificate(t, man, exampleDomain, hello)
}
func TestGetCertificate_failedAttempt(t *testing.T) {
@ -227,7 +279,6 @@ func TestGetCertificate_failedAttempt(t *testing.T) {
}))
defer ts.Close()
const example = "example.org"
d := createCertRetryAfter
f := testDidRemoveState
defer func() {
@ -236,51 +287,168 @@ func TestGetCertificate_failedAttempt(t *testing.T) {
}()
createCertRetryAfter = 0
done := make(chan struct{})
testDidRemoveState = func(domain string) {
if domain != example {
t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
testDidRemoveState = func(ck certKey) {
if ck != exampleCertKey {
t.Errorf("testDidRemoveState: domain = %v; want %v", ck, exampleCertKey)
}
close(done)
}
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man := &Manager{
Prompt: AcceptTOS,
Client: &acme.Client{
Key: key,
DirectoryURL: ts.URL,
},
}
defer man.stopRenew()
hello := &tls.ClientHelloInfo{ServerName: example}
hello := clientHelloInfo(exampleDomain, true)
if _, err := man.GetCertificate(hello); err == nil {
t.Error("GetCertificate: err is nil")
}
select {
case <-time.After(5 * time.Second):
t.Errorf("took too long to remove the %q state", example)
t.Errorf("took too long to remove the %q state", exampleCertKey)
case <-done:
man.stateMu.Lock()
defer man.stateMu.Unlock()
if v, exist := man.state[example]; exist {
t.Errorf("state exists for %q: %+v", example, v)
if v, exist := man.state[exampleCertKey]; exist {
t.Errorf("state exists for %v: %+v", exampleCertKey, v)
}
}
}
// testGetCertificate_tokenCache tests the fallback of token certificate fetches
// to cache when Manager.certTokens misses. ecdsaSupport refers to the CA when
// verifying the certificate token.
func testGetCertificate_tokenCache(t *testing.T, ecdsaSupport bool) {
man1 := &Manager{
Cache: newMemCache(t),
Prompt: AcceptTOS,
}
defer man1.stopRenew()
man2 := &Manager{
Cache: man1.Cache,
Prompt: AcceptTOS,
}
defer man2.stopRenew()
// Send the verification request to a different Manager from the one that
// initiated the authorization, when they share caches.
url, finish := startACMEServerStub(t, getCertificateFromManager(man2, ecdsaSupport), "example.org")
defer finish()
man1.Client = &acme.Client{DirectoryURL: url}
hello := clientHelloInfo("example.org", true)
if _, err := man1.GetCertificate(hello); err != nil {
t.Error(err)
}
if _, err := man2.GetCertificate(hello); err != nil {
t.Error(err)
}
}
func TestGetCertificate_tokenCache(t *testing.T) {
t.Run("ecdsaSupport=true", func(t *testing.T) {
testGetCertificate_tokenCache(t, true)
})
t.Run("ecdsaSupport=false", func(t *testing.T) {
testGetCertificate_tokenCache(t, false)
})
}
func TestGetCertificate_ecdsaVsRSA(t *testing.T) {
cache := newMemCache(t)
man := &Manager{Prompt: AcceptTOS, Cache: cache}
defer man.stopRenew()
url, finish := startACMEServerStub(t, getCertificateFromManager(man, true), "example.org")
defer finish()
man.Client = &acme.Client{DirectoryURL: url}
cert, err := man.GetCertificate(clientHelloInfo("example.org", true))
if err != nil {
t.Error(err)
}
if _, ok := cert.Leaf.PublicKey.(*ecdsa.PublicKey); !ok {
t.Error("an ECDSA client was served a non-ECDSA certificate")
}
cert, err = man.GetCertificate(clientHelloInfo("example.org", false))
if err != nil {
t.Error(err)
}
if _, ok := cert.Leaf.PublicKey.(*rsa.PublicKey); !ok {
t.Error("a RSA client was served a non-RSA certificate")
}
if _, err := man.GetCertificate(clientHelloInfo("example.org", true)); err != nil {
t.Error(err)
}
if _, err := man.GetCertificate(clientHelloInfo("example.org", false)); err != nil {
t.Error(err)
}
if numCerts := cache.numCerts(); numCerts != 2 {
t.Errorf("found %d certificates in cache; want %d", numCerts, 2)
}
}
func TestGetCertificate_wrongCacheKeyType(t *testing.T) {
cache := newMemCache(t)
man := &Manager{Prompt: AcceptTOS, Cache: cache}
defer man.stopRenew()
url, finish := startACMEServerStub(t, getCertificateFromManager(man, true), exampleDomain)
defer finish()
man.Client = &acme.Client{DirectoryURL: url}
// Make an RSA cert and cache it without suffix.
pk, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatal(err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: exampleDomain},
NotAfter: time.Now().Add(90 * 24 * time.Hour),
}
pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
if err != nil {
t.Fatal(err)
}
rsaCert := &tls.Certificate{
Certificate: [][]byte{pub},
PrivateKey: pk,
}
if err := man.cachePut(context.Background(), exampleCertKey, rsaCert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
// The RSA cached cert should be silently ignored and replaced.
cert, err := man.GetCertificate(clientHelloInfo(exampleDomain, true))
if err != nil {
t.Error(err)
}
if _, ok := cert.Leaf.PublicKey.(*ecdsa.PublicKey); !ok {
t.Error("an ECDSA client was served a non-ECDSA certificate")
}
if numCerts := cache.numCerts(); numCerts != 1 {
t.Errorf("found %d certificates in cache; want %d", numCerts, 1)
}
}
func getCertificateFromManager(man *Manager, ecdsaSupport bool) func(string) error {
return func(sni string) error {
_, err := man.GetCertificate(clientHelloInfo(sni, ecdsaSupport))
return err
}
}
// startACMEServerStub runs an ACME server
// The domain argument is the expected domain name of a certificate request.
func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
// TODO: Drop this in favour of x/crypto/acme/autocert/internal/acmetest.
func startACMEServerStub(t *testing.T, getCertificate func(string) error, domain string) (url string, finish func()) {
// echo token-02 | shasum -a 256
// then divide result in 2 parts separated by dot
tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
verifyTokenCert := func() {
hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
_, err := man.GetCertificate(hello)
if err != nil {
if err := getCertificate(tokenCertName); err != nil {
t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err)
return
}
@ -362,8 +530,7 @@ func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string,
tick := time.NewTicker(100 * time.Millisecond)
defer tick.Stop()
for {
hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
if _, err := man.GetCertificate(hello); err != nil {
if err := getCertificate(tokenCertName); err != nil {
return
}
select {
@ -387,21 +554,13 @@ func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string,
// tests man.GetCertificate flow using the provided hello argument.
// The domain argument is the expected domain name of a certificate request.
func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
url, finish := startACMEServerStub(t, man, domain)
url, finish := startACMEServerStub(t, getCertificateFromManager(man, true), domain)
defer finish()
// use EC key to run faster on 386
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man.Client = &acme.Client{
Key: key,
DirectoryURL: url,
}
man.Client = &acme.Client{DirectoryURL: url}
// simulate tls.Config.GetCertificate
var tlscert *tls.Certificate
var err error
done := make(chan struct{})
go func() {
tlscert, err = man.GetCertificate(hello)
@ -445,13 +604,13 @@ func TestVerifyHTTP01(t *testing.T) {
if w.Code != http.StatusOK {
t.Errorf("http token: w.Code = %d; want %d", w.Code, http.StatusOK)
}
if v := string(w.Body.Bytes()); !strings.HasPrefix(v, "token-http-01.") {
if v := w.Body.String(); !strings.HasPrefix(v, "token-http-01.") {
t.Errorf("http token value = %q; want 'token-http-01.' prefix", v)
}
}
// ACME CA server stub, only the needed bits.
// TODO: Merge this with startACMEServerStub, making it a configurable CA for testing.
// TODO: Replace this with x/crypto/acme/autocert/internal/acmetest.
var ca *httptest.Server
ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "nonce")
@ -505,18 +664,18 @@ func TestVerifyHTTP01(t *testing.T) {
}))
defer ca.Close()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
m := &Manager{
Client: &acme.Client{
Key: key,
DirectoryURL: ca.URL,
},
}
http01 = m.HTTPHandler(nil)
if err := m.verify(context.Background(), m.Client, "example.org"); err != nil {
ctx := context.Background()
client, err := m.acmeClient(ctx)
if err != nil {
t.Fatalf("m.acmeClient: %v", err)
}
if err := m.verify(ctx, client, "example.org"); err != nil {
t.Errorf("m.verify: %v", err)
}
// Only tls-sni-01, tls-sni-02 and http-01 must be accepted
@ -529,6 +688,111 @@ func TestVerifyHTTP01(t *testing.T) {
}
}
func TestRevokeFailedAuthz(t *testing.T) {
// Prefill authorization URIs expected to be revoked.
// The challenges are selected in a specific order,
// each tried within a newly created authorization.
// This means each authorization URI corresponds to a different challenge type.
revokedAuthz := map[string]bool{
"/authz/0": false, // tls-sni-02
"/authz/1": false, // tls-sni-01
"/authz/2": false, // no viable challenge, but authz is created
}
var authzCount int // num. of created authorizations
var revokeCount int // num. of revoked authorizations
done := make(chan struct{}) // closed when revokeCount is 3
// ACME CA server stub, only the needed bits.
// TODO: Replace this with x/crypto/acme/autocert/internal/acmetest.
var ca *httptest.Server
ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "nonce")
if r.Method == "HEAD" {
// a nonce request
return
}
switch r.URL.Path {
// Discovery.
case "/":
if err := discoTmpl.Execute(w, ca.URL); err != nil {
t.Errorf("discoTmpl: %v", err)
}
// Client key registration.
case "/new-reg":
w.Write([]byte("{}"))
// New domain authorization.
case "/new-authz":
w.Header().Set("Location", fmt.Sprintf("%s/authz/%d", ca.URL, authzCount))
w.WriteHeader(http.StatusCreated)
if err := authzTmpl.Execute(w, ca.URL); err != nil {
t.Errorf("authzTmpl: %v", err)
}
authzCount++
// tls-sni-02 challenge "accept" request.
case "/challenge/2":
// Refuse.
http.Error(w, "won't accept tls-sni-02 challenge", http.StatusBadRequest)
// tls-sni-01 challenge "accept" request.
case "/challenge/1":
// Accept but the authorization will be "expired".
w.Write([]byte("{}"))
// Authorization requests.
case "/authz/0", "/authz/1", "/authz/2":
// Revocation requests.
if r.Method == "POST" {
var req struct{ Status string }
if err := decodePayload(&req, r.Body); err != nil {
t.Errorf("%s: decodePayload: %v", r.URL, err)
}
switch req.Status {
case "deactivated":
revokedAuthz[r.URL.Path] = true
revokeCount++
if revokeCount >= 3 {
// Last authorization is revoked.
defer close(done)
}
default:
t.Errorf("%s: req.Status = %q; want 'deactivated'", r.URL, req.Status)
}
w.Write([]byte(`{"status": "invalid"}`))
return
}
// Authorization status requests.
// Simulate abandoned authorization, deleted by the CA.
w.WriteHeader(http.StatusNotFound)
default:
http.NotFound(w, r)
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
}
}))
defer ca.Close()
m := &Manager{
Client: &acme.Client{DirectoryURL: ca.URL},
}
// Should fail and revoke 3 authorizations.
// The first 2 are tsl-sni-02 and tls-sni-01 challenges.
// The third time an authorization is created but no viable challenge is found.
// See revokedAuthz above for more explanation.
if _, err := m.createCert(context.Background(), exampleCertKey); err == nil {
t.Errorf("m.createCert returned nil error")
}
select {
case <-time.After(3 * time.Second):
t.Error("revocations took too long")
case <-done:
// revokeCount is at least 3.
}
for uri, ok := range revokedAuthz {
if !ok {
t.Errorf("%q authorization was not revoked", uri)
}
}
}
func TestHTTPHandlerDefaultFallback(t *testing.T) {
tt := []struct {
method, url string
@ -571,7 +835,7 @@ func TestHTTPHandlerDefaultFallback(t *testing.T) {
}
func TestAccountKeyCache(t *testing.T) {
m := Manager{Cache: newMemCache()}
m := Manager{Cache: newMemCache(t)}
ctx := context.Background()
k1, err := m.accountKey(ctx)
if err != nil {
@ -587,36 +851,57 @@ func TestAccountKeyCache(t *testing.T) {
}
func TestCache(t *testing.T) {
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "example.org"},
NotAfter: time.Now().Add(time.Hour),
}
pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey)
cert, err := dummyCert(ecdsaKey.Public(), exampleDomain)
if err != nil {
t.Fatal(err)
}
tlscert := &tls.Certificate{
Certificate: [][]byte{pub},
PrivateKey: privKey,
ecdsaCert := &tls.Certificate{
Certificate: [][]byte{cert},
PrivateKey: ecdsaKey,
}
man := &Manager{Cache: newMemCache()}
rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatal(err)
}
cert, err = dummyCert(rsaKey.Public(), exampleDomain)
if err != nil {
t.Fatal(err)
}
rsaCert := &tls.Certificate{
Certificate: [][]byte{cert},
PrivateKey: rsaKey,
}
man := &Manager{Cache: newMemCache(t)}
defer man.stopRenew()
ctx := context.Background()
if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
if err := man.cachePut(ctx, exampleCertKey, ecdsaCert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
res, err := man.cacheGet(ctx, "example.org")
if err := man.cachePut(ctx, exampleCertKeyRSA, rsaCert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
res, err := man.cacheGet(ctx, exampleCertKey)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
if res == nil {
t.Fatal("res is nil")
if res == nil || !bytes.Equal(res.Certificate[0], ecdsaCert.Certificate[0]) {
t.Errorf("man.cacheGet = %+v; want %+v", res, ecdsaCert)
}
res, err = man.cacheGet(ctx, exampleCertKeyRSA)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
if res == nil || !bytes.Equal(res.Certificate[0], rsaCert.Certificate[0]) {
t.Errorf("man.cacheGet = %+v; want %+v", res, rsaCert)
}
}
@ -680,26 +965,28 @@ func TestValidCert(t *testing.T) {
}
tt := []struct {
domain string
key crypto.Signer
cert [][]byte
ok bool
ck certKey
key crypto.Signer
cert [][]byte
ok bool
}{
{"example.org", key1, [][]byte{cert1}, true},
{"example.org", key3, [][]byte{cert3}, true},
{"example.org", key1, [][]byte{cert1, cert2, cert3}, true},
{"example.org", key1, [][]byte{cert1, {1}}, false},
{"example.org", key1, [][]byte{{1}}, false},
{"example.org", key1, [][]byte{cert2}, false},
{"example.org", key2, [][]byte{cert1}, false},
{"example.org", key1, [][]byte{cert3}, false},
{"example.org", key3, [][]byte{cert1}, false},
{"example.net", key1, [][]byte{cert1}, false},
{"example.org", key1, [][]byte{early}, false},
{"example.org", key1, [][]byte{expired}, false},
{certKey{domain: "example.org"}, key1, [][]byte{cert1}, true},
{certKey{domain: "example.org", isRSA: true}, key3, [][]byte{cert3}, true},
{certKey{domain: "example.org"}, key1, [][]byte{cert1, cert2, cert3}, true},
{certKey{domain: "example.org"}, key1, [][]byte{cert1, {1}}, false},
{certKey{domain: "example.org"}, key1, [][]byte{{1}}, false},
{certKey{domain: "example.org"}, key1, [][]byte{cert2}, false},
{certKey{domain: "example.org"}, key2, [][]byte{cert1}, false},
{certKey{domain: "example.org"}, key1, [][]byte{cert3}, false},
{certKey{domain: "example.org"}, key3, [][]byte{cert1}, false},
{certKey{domain: "example.net"}, key1, [][]byte{cert1}, false},
{certKey{domain: "example.org"}, key1, [][]byte{early}, false},
{certKey{domain: "example.org"}, key1, [][]byte{expired}, false},
{certKey{domain: "example.org", isRSA: true}, key1, [][]byte{cert1}, false},
{certKey{domain: "example.org"}, key3, [][]byte{cert3}, false},
}
for i, test := range tt {
leaf, err := validCert(test.domain, test.cert, test.key)
leaf, err := validCert(test.ck, test.cert, test.key, now)
if err != nil && test.ok {
t.Errorf("%d: err = %v", i, err)
}
@ -748,10 +1035,155 @@ func TestManagerGetCertificateBogusSNI(t *testing.T) {
{"fo.o", "cache.Get of fo.o"},
}
for _, tt := range tests {
_, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
_, err := m.GetCertificate(clientHelloInfo(tt.name, true))
got := fmt.Sprint(err)
if got != tt.wantErr {
t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
}
}
}
func TestCertRequest(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
// An extension from RFC7633. Any will do.
ext := pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1},
Value: []byte("dummy"),
}
b, err := certRequest(key, "example.org", []pkix.Extension{ext}, "san.example.org")
if err != nil {
t.Fatalf("certRequest: %v", err)
}
r, err := x509.ParseCertificateRequest(b)
if err != nil {
t.Fatalf("ParseCertificateRequest: %v", err)
}
var found bool
for _, v := range r.Extensions {
if v.Id.Equal(ext.Id) {
found = true
break
}
}
if !found {
t.Errorf("want %v in Extensions: %v", ext, r.Extensions)
}
}
func TestSupportsECDSA(t *testing.T) {
tests := []struct {
CipherSuites []uint16
SignatureSchemes []tls.SignatureScheme
SupportedCurves []tls.CurveID
ecdsaOk bool
}{
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}, nil, nil, false},
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
}, nil, nil, true},
// SignatureSchemes limits, not extends, CipherSuites
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}, []tls.SignatureScheme{
tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
}, nil, false},
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
}, []tls.SignatureScheme{
tls.PKCS1WithSHA256,
}, nil, false},
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
}, []tls.SignatureScheme{
tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
}, nil, true},
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
}, []tls.SignatureScheme{
tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
}, []tls.CurveID{
tls.CurveP521,
}, false},
{[]uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
}, []tls.SignatureScheme{
tls.PKCS1WithSHA256, tls.ECDSAWithP256AndSHA256,
}, []tls.CurveID{
tls.CurveP256,
tls.CurveP521,
}, true},
}
for i, tt := range tests {
result := supportsECDSA(&tls.ClientHelloInfo{
CipherSuites: tt.CipherSuites,
SignatureSchemes: tt.SignatureSchemes,
SupportedCurves: tt.SupportedCurves,
})
if result != tt.ecdsaOk {
t.Errorf("%d: supportsECDSA = %v; want %v", i, result, tt.ecdsaOk)
}
}
}
// TODO: add same end-to-end for http-01 challenge type.
func TestEndToEnd(t *testing.T) {
const domain = "example.org"
// ACME CA server
ca := acmetest.NewCAServer([]string{"tls-alpn-01"}, []string{domain})
defer ca.Close()
// User dummy server.
m := &Manager{
Prompt: AcceptTOS,
Client: &acme.Client{DirectoryURL: ca.URL},
}
us := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}))
us.TLS = &tls.Config{
NextProtos: []string{"http/1.1", acme.ALPNProto},
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := m.GetCertificate(hello)
if err != nil {
t.Errorf("m.GetCertificate: %v", err)
}
return cert, err
},
}
us.StartTLS()
defer us.Close()
// In TLS-ALPN challenge verification, CA connects to the domain:443 in question.
// Because the domain won't resolve in tests, we need to tell the CA
// where to dial to instead.
ca.Resolve(domain, strings.TrimPrefix(us.URL, "https://"))
// A client visiting user dummy server.
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: ca.Roots,
ServerName: domain,
},
}
client := &http.Client{Transport: tr}
res, err := client.Get(us.URL)
if err != nil {
t.Logf("CA errors: %v", ca.Errors())
t.Fatal(err)
}
defer res.Body.Close()
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if v := string(b); v != "OK" {
t.Errorf("user server response: %q; want 'OK'", v)
}
}

@ -16,10 +16,10 @@ import (
var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss")
// Cache is used by Manager to store and retrieve previously obtained certificates
// as opaque data.
// and other account data as opaque blobs.
//
// The key argument of the methods refers to a domain name but need not be an FQDN.
// Cache implementations should not rely on the key naming pattern.
// Cache implementations should not rely on the key naming pattern. Keys can
// include any printable ASCII characters, except the following: \/:*?"<>|
type Cache interface {
// Get returns a certificate data for the specified key.
// If there's no such key, Get returns ErrCacheMiss.

@ -5,7 +5,6 @@
package autocert_test
import (
"crypto/tls"
"fmt"
"log"
"net/http"
@ -25,12 +24,11 @@ func ExampleManager() {
m := &autocert.Manager{
Cache: autocert.DirCache("secret-dir"),
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("example.org"),
HostPolicy: autocert.HostWhitelist("example.org", "www.example.org"),
}
go http.ListenAndServe(":http", m.HTTPHandler(nil))
s := &http.Server{
Addr: ":https",
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
TLSConfig: m.TLSConfig(),
}
s.ListenAndServeTLS("", "")
}

@ -72,11 +72,8 @@ func NewListener(domains ...string) net.Listener {
// the Manager m's Prompt, Cache, HostPolicy, and other desired options.
func (m *Manager) Listener() net.Listener {
ln := &listener{
m: m,
conf: &tls.Config{
GetCertificate: m.GetCertificate, // bonus: panic on nil m
NextProtos: []string{"h2", "http/1.1"}, // Enable HTTP/2
},
m: m,
conf: m.TLSConfig(),
}
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
return ln

@ -17,9 +17,9 @@ const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert.
type domainRenewal struct {
m *Manager
domain string
key crypto.Signer
m *Manager
ck certKey
key crypto.Signer
timerMu sync.Mutex
timer *time.Timer
@ -71,25 +71,43 @@ func (dr *domainRenewal) renew() {
testDidRenewLoop(next, err)
}
// updateState locks and replaces the relevant Manager.state item with the given
// state. It additionally updates dr.key with the given state's key.
func (dr *domainRenewal) updateState(state *certState) {
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
dr.key = state.key
dr.m.state[dr.ck] = state
}
// do is similar to Manager.createCert but it doesn't lock a Manager.state item.
// Instead, it requests a new certificate independently and, upon success,
// replaces dr.m.state item with a new one and updates cache for the given domain.
//
// It may return immediately if the expiration date of the currently cached cert
// is far enough in the future.
// It may lock and update the Manager.state if the expiration date of the currently
// cached cert is far enough in the future.
//
// The returned value is a time interval after which the renewal should occur again.
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment
// but we try nonetheless
if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil {
next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+renewJitter {
return next, nil
signer, ok := tlscert.PrivateKey.(crypto.Signer)
if ok {
state := &certState{
key: signer,
cert: tlscert.Certificate,
leaf: tlscert.Leaf,
}
dr.updateState(state)
return next, nil
}
}
}
der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.domain)
der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.ck)
if err != nil {
return 0, err
}
@ -102,16 +120,15 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
if err != nil {
return 0, err
}
dr.m.cachePut(ctx, dr.domain, tlscert)
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
// m.state is guaranteed to be non-nil at this point
dr.m.state[dr.domain] = state
if err := dr.m.cachePut(ctx, dr.ck, tlscert); err != nil {
return 0, err
}
dr.updateState(state)
return dr.next(leaf.NotAfter), nil
}
func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore()
d := expiry.Sub(dr.m.now()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n)

@ -23,10 +23,10 @@ import (
func TestRenewalNext(t *testing.T) {
now := time.Now()
timeNow = func() time.Time { return now }
defer func() { timeNow = time.Now }()
man := &Manager{RenewBefore: 7 * 24 * time.Hour}
man := &Manager{
RenewBefore: 7 * 24 * time.Hour,
nowFunc: func() time.Time { return now },
}
defer man.stopRenew()
tt := []struct {
expiry time.Time
@ -48,8 +48,6 @@ func TestRenewalNext(t *testing.T) {
}
func TestRenewFromCache(t *testing.T) {
const domain = "example.org"
// ACME CA server stub
var ca *httptest.Server
ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -84,7 +82,7 @@ func TestRenewFromCache(t *testing.T) {
if err != nil {
t.Fatalf("new-cert: CSR: %v", err)
}
der, err := dummyCert(csr.PublicKey, domain)
der, err := dummyCert(csr.PublicKey, exampleDomain)
if err != nil {
t.Fatalf("new-cert: dummyCert: %v", err)
}
@ -105,30 +103,28 @@ func TestRenewFromCache(t *testing.T) {
}))
defer ca.Close()
// use EC key to run faster on 386
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
man := &Manager{
Prompt: AcceptTOS,
Cache: newMemCache(),
Cache: newMemCache(t),
RenewBefore: 24 * time.Hour,
Client: &acme.Client{
Key: key,
DirectoryURL: ca.URL,
},
}
defer man.stopRenew()
// cache an almost expired cert
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
now := time.Now()
cert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), domain)
cert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), exampleDomain)
if err != nil {
t.Fatal(err)
}
tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}}
if err := man.cachePut(context.Background(), domain, tlscert); err != nil {
if err := man.cachePut(context.Background(), exampleCertKey, tlscert); err != nil {
t.Fatal(err)
}
@ -152,7 +148,7 @@ func TestRenewFromCache(t *testing.T) {
// ensure the new cert is cached
after := time.Now().Add(future)
tlscert, err := man.cacheGet(context.Background(), domain)
tlscert, err := man.cacheGet(context.Background(), exampleCertKey)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
@ -163,9 +159,9 @@ func TestRenewFromCache(t *testing.T) {
// verify the old cert is also replaced in memory
man.stateMu.Lock()
defer man.stateMu.Unlock()
s := man.state[domain]
s := man.state[exampleCertKey]
if s == nil {
t.Fatalf("m.state[%q] is nil", domain)
t.Fatalf("m.state[%q] is nil", exampleCertKey)
}
tlscert, err = s.tlscert()
if err != nil {
@ -177,7 +173,7 @@ func TestRenewFromCache(t *testing.T) {
}
// trigger renew
hello := &tls.ClientHelloInfo{ServerName: domain}
hello := clientHelloInfo(exampleDomain, true)
if _, err := man.GetCertificate(hello); err != nil {
t.Fatal(err)
}
@ -189,3 +185,145 @@ func TestRenewFromCache(t *testing.T) {
case <-done:
}
}
func TestRenewFromCacheAlreadyRenewed(t *testing.T) {
man := &Manager{
Prompt: AcceptTOS,
Cache: newMemCache(t),
RenewBefore: 24 * time.Hour,
Client: &acme.Client{
DirectoryURL: "invalid",
},
}
defer man.stopRenew()
// cache a recently renewed cert with a different private key
newKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
now := time.Now()
newCert, err := dateDummyCert(newKey.Public(), now.Add(-2*time.Hour), now.Add(time.Hour*24*90), exampleDomain)
if err != nil {
t.Fatal(err)
}
newLeaf, err := validCert(exampleCertKey, [][]byte{newCert}, newKey, now)
if err != nil {
t.Fatal(err)
}
newTLSCert := &tls.Certificate{PrivateKey: newKey, Certificate: [][]byte{newCert}, Leaf: newLeaf}
if err := man.cachePut(context.Background(), exampleCertKey, newTLSCert); err != nil {
t.Fatal(err)
}
// set internal state to an almost expired cert
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
oldCert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), exampleDomain)
if err != nil {
t.Fatal(err)
}
oldLeaf, err := validCert(exampleCertKey, [][]byte{oldCert}, key, now)
if err != nil {
t.Fatal(err)
}
man.stateMu.Lock()
if man.state == nil {
man.state = make(map[certKey]*certState)
}
s := &certState{
key: key,
cert: [][]byte{oldCert},
leaf: oldLeaf,
}
man.state[exampleCertKey] = s
man.stateMu.Unlock()
// veriy the renewal accepted the newer cached cert
defer func() {
testDidRenewLoop = func(next time.Duration, err error) {}
}()
done := make(chan struct{})
testDidRenewLoop = func(next time.Duration, err error) {
defer close(done)
if err != nil {
t.Errorf("testDidRenewLoop: %v", err)
}
// Next should be about 90 days
// Previous expiration was within 1 min.
future := 88 * 24 * time.Hour
if next < future {
t.Errorf("testDidRenewLoop: next = %v; want >= %v", next, future)
}
// ensure the cached cert was not modified
tlscert, err := man.cacheGet(context.Background(), exampleCertKey)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
t.Errorf("cache leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
}
// verify the old cert is also replaced in memory
man.stateMu.Lock()
defer man.stateMu.Unlock()
s := man.state[exampleCertKey]
if s == nil {
t.Fatalf("m.state[%q] is nil", exampleCertKey)
}
stateKey := s.key.Public().(*ecdsa.PublicKey)
if stateKey.X.Cmp(newKey.X) != 0 || stateKey.Y.Cmp(newKey.Y) != 0 {
t.Fatalf("state key was not updated from cache x: %v y: %v; want x: %v y: %v", stateKey.X, stateKey.Y, newKey.X, newKey.Y)
}
tlscert, err = s.tlscert()
if err != nil {
t.Fatalf("s.tlscert: %v", err)
}
if !tlscert.Leaf.NotAfter.Equal(newLeaf.NotAfter) {
t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newLeaf.NotAfter)
}
// verify the private key is replaced in the renewal state
r := man.renewal[exampleCertKey]
if r == nil {
t.Fatalf("m.renewal[%q] is nil", exampleCertKey)
}
renewalKey := r.key.Public().(*ecdsa.PublicKey)
if renewalKey.X.Cmp(newKey.X) != 0 || renewalKey.Y.Cmp(newKey.Y) != 0 {
t.Fatalf("renewal private key was not updated from cache x: %v y: %v; want x: %v y: %v", renewalKey.X, renewalKey.Y, newKey.X, newKey.Y)
}
}
// assert the expiring cert is returned from state
hello := clientHelloInfo(exampleDomain, true)
tlscert, err := man.GetCertificate(hello)
if err != nil {
t.Fatal(err)
}
if !oldLeaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, oldLeaf.NotAfter)
}
// trigger renew
go man.renew(exampleCertKey, s.key, s.leaf.NotAfter)
// wait for renew loop
select {
case <-time.After(10 * time.Second):
t.Fatal("renew took too long to occur")
case <-done:
// assert the new cert is returned from state after renew
hello := clientHelloInfo(exampleDomain, true)
tlscert, err := man.GetCertificate(hello)
if err != nil {
t.Fatal(err)
}
if !newTLSCert.Leaf.NotAfter.Equal(tlscert.Leaf.NotAfter) {
t.Errorf("state leaf.NotAfter = %v; want == %v", tlscert.Leaf.NotAfter, newTLSCert.Leaf.NotAfter)
}
}
}

@ -25,7 +25,7 @@ func jwsEncodeJSON(claimset interface{}, key crypto.Signer, nonce string) ([]byt
if err != nil {
return nil, err
}
alg, sha := jwsHasher(key)
alg, sha := jwsHasher(key.Public())
if alg == "" || !sha.Available() {
return nil, ErrUnsupportedKey
}
@ -97,13 +97,16 @@ func jwkEncode(pub crypto.PublicKey) (string, error) {
}
// jwsSign signs the digest using the given key.
// It returns ErrUnsupportedKey if the key type is unknown.
// The hash is used only for RSA keys.
// The hash is unused for ECDSA keys.
//
// Note: non-stdlib crypto.Signer implementations are expected to return
// the signature in the format as specified in RFC7518.
// See https://tools.ietf.org/html/rfc7518 for more details.
func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) {
switch key := key.(type) {
case *rsa.PrivateKey:
return key.Sign(rand.Reader, digest, hash)
case *ecdsa.PrivateKey:
if key, ok := key.(*ecdsa.PrivateKey); ok {
// The key.Sign method of ecdsa returns ASN1-encoded signature.
// So, we use the package Sign function instead
// to get R and S values directly and format the result accordingly.
r, s, err := ecdsa.Sign(rand.Reader, key, digest)
if err != nil {
return nil, err
@ -118,18 +121,18 @@ func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error)
copy(sig[size*2-len(sb):], sb)
return sig, nil
}
return nil, ErrUnsupportedKey
return key.Sign(rand.Reader, digest, hash)
}
// jwsHasher indicates suitable JWS algorithm name and a hash function
// to use for signing a digest with the provided key.
// It returns ("", 0) if the key is not supported.
func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
switch key := key.(type) {
case *rsa.PrivateKey:
func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) {
switch pub := pub.(type) {
case *rsa.PublicKey:
return "RS256", crypto.SHA256
case *ecdsa.PrivateKey:
switch key.Params().Name {
case *ecdsa.PublicKey:
switch pub.Params().Name {
case "P-256":
return "ES256", crypto.SHA256
case "P-384":

@ -5,6 +5,7 @@
package acme
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
@ -13,6 +14,7 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"testing"
)
@ -241,6 +243,79 @@ func TestJWSEncodeJSONEC(t *testing.T) {
}
}
type customTestSigner struct {
sig []byte
pub crypto.PublicKey
}
func (s *customTestSigner) Public() crypto.PublicKey { return s.pub }
func (s *customTestSigner) Sign(io.Reader, []byte, crypto.SignerOpts) ([]byte, error) {
return s.sig, nil
}
func TestJWSEncodeJSONCustom(t *testing.T) {
claims := struct{ Msg string }{"hello"}
const (
// printf '{"Msg":"hello"}' | base64 | tr -d '=' | tr '/+' '_-'
payload = "eyJNc2ciOiJoZWxsbyJ9"
// printf 'testsig' | base64 | tr -d '='
testsig = "dGVzdHNpZw"
// printf '{"alg":"ES256","jwk":{"crv":"P-256","kty":"EC","x":<testKeyECPubY>,"y":<testKeyECPubY>,"nonce":"nonce"}' | \
// base64 | tr -d '=' | tr '/+' '_-'
es256phead = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJjcnYiOiJQLTI1NiIsImt0eSI6IkVDIiwieCI6IjVsaEV1" +
"ZzV4SzR4QkRaMm5BYmF4THRhTGl2ODVieEo3ZVBkMWRrTzIzSFEiLCJ5IjoiNGFpSzcyc0JlVUFH" +
"a3YwVGFMc213b2tZVVl5TnhHc1M1RU1JS3dzTklLayJ9LCJub25jZSI6Im5vbmNlIn0"
// {"alg":"RS256","jwk":{"e":"AQAB","kty":"RSA","n":"..."},"nonce":"nonce"}
rs256phead = "eyJhbGciOiJSUzI1NiIsImp3ayI6eyJlIjoiQVFBQiIsImt0eSI6" +
"IlJTQSIsIm4iOiI0eGdaM2VSUGt3b1J2eTdxZVJVYm1NRGUwVi14" +
"SDllV0xkdTBpaGVlTGxybUQybXFXWGZQOUllU0tBcGJuMzRnOFR1" +
"QVM5ZzV6aHE4RUxRM2ttanItS1Y4NkdBTWdJNlZBY0dscTNRcnpw" +
"VENmXzMwQWI3LXphd3JmUmFGT05hMUh3RXpQWTFLSG5HVmt4SmM4" +
"NWdOa3dZSTlTWTJSSFh0dmxuM3pzNXdJVE5yZG9zcUVYZWFJa1ZZ" +
"QkVoYmhOdTU0cHAza3hvNlR1V0xpOWU2cFhlV2V0RXdtbEJ3dFda" +
"bFBvaWIyajNUeExCa3NLWmZveUZ5ZWszODBtSGdKQXVtUV9JMmZq" +
"ajk4Xzk3bWszaWhPWTRBZ1ZkQ0RqMXpfR0NvWmtHNVJxN25iQ0d5" +
"b3N5S1d5RFgwMFpzLW5OcVZob0xlSXZYQzRubldkSk1aNnJvZ3h5" +
"UVEifSwibm9uY2UiOiJub25jZSJ9"
)
tt := []struct {
alg, phead string
pub crypto.PublicKey
}{
{"RS256", rs256phead, testKey.Public()},
{"ES256", es256phead, testKeyEC.Public()},
}
for _, tc := range tt {
tc := tc
t.Run(tc.alg, func(t *testing.T) {
signer := &customTestSigner{
sig: []byte("testsig"),
pub: tc.pub,
}
b, err := jwsEncodeJSON(claims, signer, "nonce")
if err != nil {
t.Fatal(err)
}
var j struct{ Protected, Payload, Signature string }
if err := json.Unmarshal(b, &j); err != nil {
t.Fatal(err)
}
if j.Protected != tc.phead {
t.Errorf("j.Protected = %q\nwant %q", j.Protected, tc.phead)
}
if j.Payload != payload {
t.Errorf("j.Payload = %q\nwant %q", j.Payload, payload)
}
if j.Signature != testsig {
t.Errorf("j.Signature = %q\nwant %q", j.Signature, testsig)
}
})
}
}
func TestJWKThumbprintRSA(t *testing.T) {
// Key example from RFC 7638
const base64N = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAt" +

@ -104,7 +104,7 @@ func RateLimit(err error) (time.Duration, bool) {
if e.Header == nil {
return 0, true
}
return retryAfter(e.Header.Get("Retry-After"), 0), true
return retryAfter(e.Header.Get("Retry-After")), true
}
// Account is a user account. It is associated with a private key.
@ -296,8 +296,8 @@ func (e *wireError) error(h http.Header) *Error {
}
}
// CertOption is an optional argument type for the TLSSNIxChallengeCert methods for
// customizing a temporary certificate for TLS-SNI challenges.
// CertOption is an optional argument type for the TLS ChallengeCert methods for
// customizing a temporary certificate for TLS-based challenges.
type CertOption interface {
privateCertOpt()
}
@ -317,7 +317,7 @@ func (*certOptKey) privateCertOpt() {}
// WithTemplate creates an option for specifying a certificate template.
// See x509.CreateCertificate for template usage details.
//
// In TLSSNIxChallengeCert methods, the template is also used as parent,
// In TLS ChallengeCert methods, the template is also used as parent,
// resulting in a self-signed certificate.
// The DNSNames field of t is always overwritten for tls-sni challenge certs.
func WithTemplate(t *x509.Certificate) CertOption {

@ -5,7 +5,35 @@
// Package argon2 implements the key derivation function Argon2.
// Argon2 was selected as the winner of the Password Hashing Competition and can
// be used to derive cryptographic keys from passwords.
// Argon2 is specfifed at https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
//
// For a detailed specification of Argon2 see [1].
//
// If you aren't sure which function you need, use Argon2id (IDKey) and
// the parameter recommendations for your scenario.
//
//
// Argon2i
//
// Argon2i (implemented by Key) is the side-channel resistant version of Argon2.
// It uses data-independent memory access, which is preferred for password
// hashing and password-based key derivation. Argon2i requires more passes over
// memory than Argon2id to protect from trade-off attacks. The recommended
// parameters (taken from [2]) for non-interactive operations are time=3 and to
// use the maximum available memory.
//
//
// Argon2id
//
// Argon2id (implemented by IDKey) is a hybrid version of Argon2 combining
// Argon2i and Argon2d. It uses data-independent memory access for the first
// half of the first iteration over the memory and data-dependent memory access
// for the rest. Argon2id is side-channel resistant and provides better brute-
// force cost savings due to time-memory tradeoffs than Argon2i. The recommended
// parameters for non-interactive operations (taken from [2]) are time=1 and to
// use the maximum available memory.
//
// [1] https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
// [2] https://tools.ietf.org/html/draft-irtf-cfrg-argon2-03#section-9.3
package argon2
import (
@ -25,23 +53,52 @@ const (
)
// Key derives a key from the password, salt, and cost parameters using Argon2i
// returning a byte slice of length keyLen that can be used as cryptographic key.
// The CPU cost and parallism degree must be greater than zero.
// returning a byte slice of length keyLen that can be used as cryptographic
// key. The CPU cost and parallelism degree must be greater than zero.
//
// For example, you can get a derived key for e.g. AES-256 (which needs a 32-byte key) by doing:
// `key := argon2.Key([]byte("some password"), salt, 4, 32*1024, 4, 32)`
// For example, you can get a derived key for e.g. AES-256 (which needs a
// 32-byte key) by doing:
//
// The recommended parameters for interactive logins as of 2017 are time=4, memory=32*1024.
// The number of threads can be adjusted to the numbers of available CPUs.
// The time parameter specifies the number of passes over the memory and the memory
// parameter specifies the size of the memory in KiB. For example memory=32*1024 sets the
// memory cost to ~32 MB.
// The cost parameters should be increased as memory latency and CPU parallelism increases.
// Remember to get a good random salt.
// key := argon2.Key([]byte("some password"), salt, 3, 32*1024, 4, 32)
//
// The draft RFC recommends[2] time=3, and memory=32*1024 is a sensible number.
// If using that amount of memory (32 MB) is not possible in some contexts then
// the time parameter can be increased to compensate.
//
// The time parameter specifies the number of passes over the memory and the
// memory parameter specifies the size of the memory in KiB. For example
// memory=32*1024 sets the memory cost to ~32 MB. The number of threads can be
// adjusted to the number of available CPUs. The cost parameters should be
// increased as memory latency and CPU parallelism increases. Remember to get a
// good random salt.
func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2i, password, salt, nil, nil, time, memory, threads, keyLen)
}
// IDKey derives a key from the password, salt, and cost parameters using
// Argon2id returning a byte slice of length keyLen that can be used as
// cryptographic key. The CPU cost and parallelism degree must be greater than
// zero.
//
// For example, you can get a derived key for e.g. AES-256 (which needs a
// 32-byte key) by doing:
//
// key := argon2.IDKey([]byte("some password"), salt, 1, 64*1024, 4, 32)
//
// The draft RFC recommends[2] time=1, and memory=64*1024 is a sensible number.
// If using that amount of memory (64 MB) is not possible in some contexts then
// the time parameter can be increased to compensate.
//
// The time parameter specifies the number of passes over the memory and the
// memory parameter specifies the size of the memory in KiB. For example
// memory=64*1024 sets the memory cost to ~64 MB. The number of threads can be
// adjusted to the numbers of available CPUs. The cost parameters should be
// increased as memory latency and CPU parallelism increases. Remember to get a
// good random salt.
func IDKey(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2id, password, salt, nil, nil, time, memory, threads, keyLen)
}
func deriveKey(mode int, password, salt, secret, data []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
if time < 1 {
panic("argon2: number of rounds too small")

@ -6,12 +6,11 @@
package argon2
func init() {
useSSE4 = supportsSSE4()
}
import "golang.org/x/sys/cpu"
//go:noescape
func supportsSSE4() bool
func init() {
useSSE4 = cpu.X86.HasSSE41
}
//go:noescape
func mixBlocksSSE2(out, a, b, c *block)

@ -241,12 +241,3 @@ loop:
SUBQ $2, BP
JA loop
RET
// func supportsSSE4() bool
TEXT ·supportsSSE4(SB), 4, $0-1
MOVL $1, AX
CPUID
SHRL $19, CX // Bit 19 indicates SSE4 support
ANDL $1, CX // CX != 0 if support SSE4
MOVB CX, ret+0(FP)
RET

@ -209,19 +209,19 @@ func TestMinorNotRequired(t *testing.T) {
func BenchmarkEqual(b *testing.B) {
b.StopTimer()
passwd := []byte("somepasswordyoulike")
hash, _ := GenerateFromPassword(passwd, 10)
hash, _ := GenerateFromPassword(passwd, DefaultCost)
b.StartTimer()
for i := 0; i < b.N; i++ {
CompareHashAndPassword(hash, passwd)
}
}
func BenchmarkGeneration(b *testing.B) {
func BenchmarkDefaultCost(b *testing.B) {
b.StopTimer()
passwd := []byte("mylongpassword1234")
b.StartTimer()
for i := 0; i < b.N; i++ {
GenerateFromPassword(passwd, 10)
GenerateFromPassword(passwd, DefaultCost)
}
}

@ -75,23 +75,25 @@ func Sum256(data []byte) [Size256]byte {
}
// New512 returns a new hash.Hash computing the BLAKE2b-512 checksum. A non-nil
// key turns the hash into a MAC. The key must between zero and 64 bytes long.
// key turns the hash into a MAC. The key must be between zero and 64 bytes long.
func New512(key []byte) (hash.Hash, error) { return newDigest(Size, key) }
// New384 returns a new hash.Hash computing the BLAKE2b-384 checksum. A non-nil
// key turns the hash into a MAC. The key must between zero and 64 bytes long.
// key turns the hash into a MAC. The key must be between zero and 64 bytes long.
func New384(key []byte) (hash.Hash, error) { return newDigest(Size384, key) }
// New256 returns a new hash.Hash computing the BLAKE2b-256 checksum. A non-nil
// key turns the hash into a MAC. The key must between zero and 64 bytes long.
// key turns the hash into a MAC. The key must be between zero and 64 bytes long.
func New256(key []byte) (hash.Hash, error) { return newDigest(Size256, key) }
// New returns a new hash.Hash computing the BLAKE2b checksum with a custom length.
// A non-nil key turns the hash into a MAC. The key must between zero and 64 bytes long.
// A non-nil key turns the hash into a MAC. The key must be between zero and 64 bytes long.
// The hash size can be a value between 1 and 64 but it is highly recommended to use
// values equal or greater than:
// - 32 if BLAKE2b is used as a hash function (The key is zero bytes long).
// - 16 if BLAKE2b is used as a MAC function (The key is at least 16 bytes long).
// When the key is nil, the returned hash.Hash implements BinaryMarshaler
// and BinaryUnmarshaler for state (de)serialization as documented by hash.Hash.
func New(size int, key []byte) (hash.Hash, error) { return newDigest(size, key) }
func newDigest(hashSize int, key []byte) (*digest, error) {
@ -150,6 +152,50 @@ type digest struct {
keyLen int
}
const (
magic = "b2b"
marshaledSize = len(magic) + 8*8 + 2*8 + 1 + BlockSize + 1
)
func (d *digest) MarshalBinary() ([]byte, error) {
if d.keyLen != 0 {
return nil, errors.New("crypto/blake2b: cannot marshal MACs")
}
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
for i := 0; i < 8; i++ {
b = appendUint64(b, d.h[i])
}
b = appendUint64(b, d.c[0])
b = appendUint64(b, d.c[1])
// Maximum value for size is 64
b = append(b, byte(d.size))
b = append(b, d.block[:]...)
b = append(b, byte(d.offset))
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("crypto/blake2b: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/blake2b: invalid hash state size")
}
b = b[len(magic):]
for i := 0; i < 8; i++ {
b, d.h[i] = consumeUint64(b)
}
b, d.c[0] = consumeUint64(b)
b, d.c[1] = consumeUint64(b)
d.size = int(b[0])
b = b[1:]
copy(d.block[:], b[:BlockSize])
b = b[BlockSize:]
d.offset = int(b[0])
return nil
}
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Size() int { return d.size }
@ -219,3 +265,25 @@ func (d *digest) finalize(hash *[Size]byte) {
binary.LittleEndian.PutUint64(hash[8*i:], v)
}
}
func appendUint64(b []byte, x uint64) []byte {
var a [8]byte
binary.BigEndian.PutUint64(a[:], x)
return append(b, a[:]...)
}
func appendUint32(b []byte, x uint32) []byte {
var a [4]byte
binary.BigEndian.PutUint32(a[:], x)
return append(b, a[:]...)
}
func consumeUint64(b []byte) ([]byte, uint64) {
x := binary.BigEndian.Uint64(b)
return b[8:], x
}
func consumeUint32(b []byte) ([]byte, uint32) {
x := binary.BigEndian.Uint32(b)
return b[4:], x
}

@ -6,21 +6,14 @@
package blake2b
import "golang.org/x/sys/cpu"
func init() {
useAVX2 = supportsAVX2()
useAVX = supportsAVX()
useSSE4 = supportsSSE4()
useAVX2 = cpu.X86.HasAVX2
useAVX = cpu.X86.HasAVX
useSSE4 = cpu.X86.HasSSE41
}
//go:noescape
func supportsSSE4() bool
//go:noescape
func supportsAVX() bool
//go:noescape
func supportsAVX2() bool
//go:noescape
func hashBlocksAVX2(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
@ -31,13 +24,14 @@ func hashBlocksAVX(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) {
if useAVX2 {
switch {
case useAVX2:
hashBlocksAVX2(h, c, flag, blocks)
} else if useAVX {
case useAVX:
hashBlocksAVX(h, c, flag, blocks)
} else if useSSE4 {
case useSSE4:
hashBlocksSSE4(h, c, flag, blocks)
} else {
default:
hashBlocksGeneric(h, c, flag, blocks)
}
}

@ -748,15 +748,3 @@ noinc:
MOVQ BP, SP
RET
// func supportsAVX2() bool
TEXT ·supportsAVX2(SB), 4, $0-1
MOVQ runtime·support_avx2(SB), AX
MOVB AX, ret+0(FP)
RET
// func supportsAVX() bool
TEXT ·supportsAVX(SB), 4, $0-1
MOVQ runtime·support_avx(SB), AX
MOVB AX, ret+0(FP)
RET

@ -6,12 +6,11 @@
package blake2b
func init() {
useSSE4 = supportsSSE4()
}
import "golang.org/x/sys/cpu"
//go:noescape
func supportsSSE4() bool
func init() {
useSSE4 = cpu.X86.HasSSE41
}
//go:noescape
func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)

@ -279,12 +279,3 @@ noinc:
MOVQ BP, SP
RET
// func supportsSSE4() bool
TEXT ·supportsSSE4(SB), 4, $0-1
MOVL $1, AX
CPUID
SHRL $19, CX // Bit 19 indicates SSE4 support
ANDL $1, CX // CX != 0 if support SSE4
MOVB CX, ret+0(FP)
RET

@ -6,6 +6,7 @@ package blake2b
import (
"bytes"
"encoding"
"encoding/hex"
"fmt"
"hash"
@ -69,6 +70,54 @@ func TestHashes2X(t *testing.T) {
testHashes2X(t)
}
func TestMarshal(t *testing.T) {
input := make([]byte, 255)
for i := range input {
input[i] = byte(i)
}
for _, size := range []int{Size, Size256, Size384, 12, 25, 63} {
for i := 0; i < 256; i++ {
h, err := New(size, nil)
if err != nil {
t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
}
h2, err := New(size, nil)
if err != nil {
t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
}
h.Write(input[:i/2])
halfstate, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Fatalf("size=%d, len(input)=%d: could not marshal: %v", size, i, err)
}
err = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(halfstate)
if err != nil {
t.Fatalf("size=%d, len(input)=%d: could not unmarshal: %v", size, i, err)
}
h.Write(input[i/2 : i])
sum := h.Sum(nil)
h2.Write(input[i/2 : i])
sum2 := h2.Sum(nil)
if !bytes.Equal(sum, sum2) {
t.Fatalf("size=%d, len(input)=%d: results do not match; sum = %v, sum2 = %v", size, i, sum, sum2)
}
h3, err := New(size, nil)
if err != nil {
t.Fatalf("size=%d, len(input)=%d: error from New(%v, nil): %v", size, i, size, err)
}
h3.Write(input[:i])
sum3 := h3.Sum(nil)
if !bytes.Equal(sum, sum3) {
t.Fatalf("size=%d, len(input)=%d: sum = %v, want %v", size, i, sum, sum3)
}
}
}
}
func testHashes(t *testing.T) {
key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f")

@ -29,7 +29,7 @@ type XOF interface {
}
// OutputLengthUnknown can be used as the size argument to NewXOF to indicate
// the the length of the output is not known in advance.
// the length of the output is not known in advance.
const OutputLengthUnknown = 0
// magicUnknownOutputLength is a magic value for the output size that indicates

@ -49,6 +49,8 @@ func Sum256(data []byte) [Size]byte {
// New256 returns a new hash.Hash computing the BLAKE2s-256 checksum. A non-nil
// key turns the hash into a MAC. The key must between zero and 32 bytes long.
// When the key is nil, the returned hash.Hash implements BinaryMarshaler
// and BinaryUnmarshaler for state (de)serialization as documented by hash.Hash.
func New256(key []byte) (hash.Hash, error) { return newDigest(Size, key) }
// New128 returns a new hash.Hash computing the BLAKE2s-128 checksum given a
@ -120,6 +122,50 @@ type digest struct {
keyLen int
}
const (
magic = "b2s"
marshaledSize = len(magic) + 8*4 + 2*4 + 1 + BlockSize + 1
)
func (d *digest) MarshalBinary() ([]byte, error) {
if d.keyLen != 0 {
return nil, errors.New("crypto/blake2s: cannot marshal MACs")
}
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
for i := 0; i < 8; i++ {
b = appendUint32(b, d.h[i])
}
b = appendUint32(b, d.c[0])
b = appendUint32(b, d.c[1])
// Maximum value for size is 32
b = append(b, byte(d.size))
b = append(b, d.block[:]...)
b = append(b, byte(d.offset))
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("crypto/blake2s: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/blake2s: invalid hash state size")
}
b = b[len(magic):]
for i := 0; i < 8; i++ {
b, d.h[i] = consumeUint32(b)
}
b, d.c[0] = consumeUint32(b)
b, d.c[1] = consumeUint32(b)
d.size = int(b[0])
b = b[1:]
copy(d.block[:], b[:BlockSize])
b = b[BlockSize:]
d.offset = int(b[0])
return nil
}
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Size() int { return d.size }
@ -185,3 +231,14 @@ func (d *digest) finalize(hash *[Size]byte) {
binary.LittleEndian.PutUint32(hash[4*i:], v)
}
}
func appendUint32(b []byte, x uint32) []byte {
var a [4]byte
binary.BigEndian.PutUint32(a[:], x)
return append(b, a[:]...)
}
func consumeUint32(b []byte) ([]byte, uint32) {
x := binary.BigEndian.Uint32(b)
return b[4:], x
}

@ -6,18 +6,14 @@
package blake2s
import "golang.org/x/sys/cpu"
var (
useSSE4 = false
useSSSE3 = supportSSSE3()
useSSE2 = supportSSE2()
useSSSE3 = cpu.X86.HasSSSE3
useSSE2 = cpu.X86.HasSSE2
)
//go:noescape
func supportSSE2() bool
//go:noescape
func supportSSSE3() bool
//go:noescape
func hashBlocksSSE2(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte)
@ -25,11 +21,12 @@ func hashBlocksSSE2(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte)
func hashBlocksSSSE3(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte)
func hashBlocks(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte) {
if useSSSE3 {
switch {
case useSSSE3:
hashBlocksSSSE3(h, c, flag, blocks)
} else if useSSE2 {
case useSSE2:
hashBlocksSSE2(h, c, flag, blocks)
} else {
default:
hashBlocksGeneric(h, c, flag, blocks)
}
}

@ -433,28 +433,3 @@ loop:
MOVL BP, SP
RET
// func supportSSSE3() bool
TEXT ·supportSSSE3(SB), 4, $0-1
MOVL $1, AX
CPUID
MOVL CX, BX
ANDL $0x1, BX // supports SSE3
JZ FALSE
ANDL $0x200, CX // supports SSSE3
JZ FALSE
MOVB $1, ret+0(FP)
RET
FALSE:
MOVB $0, ret+0(FP)
RET
// func supportSSE2() bool
TEXT ·supportSSE2(SB), 4, $0-1
MOVL $1, AX
CPUID
SHRL $26, DX
ANDL $1, DX // DX != 0 if support SSE2
MOVB DX, ret+0(FP)
RET

@ -6,18 +6,14 @@
package blake2s
import "golang.org/x/sys/cpu"
var (
useSSE4 = supportSSE4()
useSSSE3 = supportSSSE3()
useSSE2 = true // Always available on amd64
useSSE4 = cpu.X86.HasSSE41
useSSSE3 = cpu.X86.HasSSSE3
useSSE2 = cpu.X86.HasSSE2
)
//go:noescape
func supportSSSE3() bool
//go:noescape
func supportSSE4() bool
//go:noescape
func hashBlocksSSE2(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte)
@ -28,13 +24,14 @@ func hashBlocksSSSE3(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte)
func hashBlocksSSE4(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte)
func hashBlocks(h *[8]uint32, c *[2]uint32, flag uint32, blocks []byte) {
if useSSE4 {
switch {
case useSSE4:
hashBlocksSSE4(h, c, flag, blocks)
} else if useSSSE3 {
case useSSSE3:
hashBlocksSSSE3(h, c, flag, blocks)
} else if useSSE2 {
case useSSE2:
hashBlocksSSE2(h, c, flag, blocks)
} else {
default:
hashBlocksGeneric(h, c, flag, blocks)
}
}

@ -436,28 +436,3 @@ TEXT ·hashBlocksSSSE3(SB), 0, $672-48 // frame = 656 + 16 byte alignment
TEXT ·hashBlocksSSE4(SB), 0, $32-48 // frame = 16 + 16 byte alignment
HASH_BLOCKS(h+0(FP), c+8(FP), flag+16(FP), blocks_base+24(FP), blocks_len+32(FP), BLAKE2s_SSE4)
RET
// func supportSSE4() bool
TEXT ·supportSSE4(SB), 4, $0-1
MOVL $1, AX
CPUID
SHRL $19, CX // Bit 19 indicates SSE4.1.
ANDL $1, CX
MOVB CX, ret+0(FP)
RET
// func supportSSSE3() bool
TEXT ·supportSSSE3(SB), 4, $0-1
MOVL $1, AX
CPUID
MOVL CX, BX
ANDL $0x1, BX // Bit zero indicates SSE3 support.
JZ FALSE
ANDL $0x200, CX // Bit nine indicates SSSE3 support.
JZ FALSE
MOVB $1, ret+0(FP)
RET
FALSE:
MOVB $0, ret+0(FP)
RET

@ -5,6 +5,8 @@
package blake2s
import (
"bytes"
"encoding"
"encoding/hex"
"fmt"
"testing"
@ -64,6 +66,52 @@ func TestHashes2X(t *testing.T) {
testHashes2X(t)
}
func TestMarshal(t *testing.T) {
input := make([]byte, 255)
for i := range input {
input[i] = byte(i)
}
for i := 0; i < 256; i++ {
h, err := New256(nil)
if err != nil {
t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
}
h2, err := New256(nil)
if err != nil {
t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
}
h.Write(input[:i/2])
halfstate, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Fatalf("len(input)=%d: could not marshal: %v", i, err)
}
err = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(halfstate)
if err != nil {
t.Fatalf("len(input)=%d: could not unmarshal: %v", i, err)
}
h.Write(input[i/2 : i])
sum := h.Sum(nil)
h2.Write(input[i/2 : i])
sum2 := h2.Sum(nil)
if !bytes.Equal(sum, sum2) {
t.Fatalf("len(input)=%d: results do not match; sum = %v, sum2 = %v", i, sum, sum2)
}
h3, err := New256(nil)
if err != nil {
t.Fatalf("len(input)=%d: error from New256(nil): %v", i, err)
}
h3.Write(input[:i])
sum3 := h3.Sum(nil)
if !bytes.Equal(sum, sum3) {
t.Fatalf("len(input)=%d: sum = %v, want %v", i, sum, sum3)
}
}
}
func testHashes(t *testing.T) {
key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")

@ -29,7 +29,7 @@ type XOF interface {
}
// OutputLengthUnknown can be used as the size argument to NewXOF to indicate
// the the length of the output is not known in advance.
// the length of the output is not known in advance.
const OutputLengthUnknown = 0
// magicUnknownOutputLength is a magic value for the output size that indicates

@ -3,6 +3,14 @@
// license that can be found in the LICENSE file.
// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm.
//
// Blowfish is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package blowfish // import "golang.org/x/crypto/blowfish"
// The code is a port of Bruce Schneier's C implementation.

@ -15,9 +15,14 @@
// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible
// with the implementation described in that paper.
//
// (This package previously claimed to operate at a 128-bit security level.
// This package previously claimed to operate at a 128-bit security level.
// However, recent improvements in attacks mean that is no longer true. See
// https://moderncrypto.org/mail-archive/curves/2016/000740.html.)
// https://moderncrypto.org/mail-archive/curves/2016/000740.html.
//
// Deprecated: due to its weakened security, new systems should not rely on this
// elliptic curve. This package is frozen, and not implemented in constant time.
// There is a more complete implementation at github.com/cloudflare/bn256, but
// note that it suffers from the same security issues of the underlying curve.
package bn256 // import "golang.org/x/crypto/bn256"
import (
@ -26,9 +31,6 @@ import (
"math/big"
)
// BUG(agl): this implementation is not constant time.
// TODO(agl): keep GF(p²) elements in Mongomery form.
// G1 is an abstract cyclic group. The zero value is suitable for use as the
// output of an operation, but cannot be used as an input.
type G1 struct {
@ -54,6 +56,9 @@ func RandomG1(r io.Reader) (*big.Int, *G1, error) {
}
func (e *G1) String() string {
if e.p == nil {
return "bn256.G1" + newCurvePoint(nil).String()
}
return "bn256.G1" + e.p.String()
}
@ -77,7 +82,8 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 {
}
// Add sets e to a+b and then returns e.
// BUG(agl): this function is not complete: a==b fails.
//
// Warning: this function is not complete, it fails for a equal to b.
func (e *G1) Add(a, b *G1) *G1 {
if e.p == nil {
e.p = newCurvePoint(nil)
@ -97,14 +103,18 @@ func (e *G1) Neg(a *G1) *G1 {
// Marshal converts n to a byte slice.
func (e *G1) Marshal() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
if e.p.IsInfinity() {
return make([]byte, numBytes*2)
}
e.p.MakeAffine(nil)
xBytes := new(big.Int).Mod(e.p.x, p).Bytes()
yBytes := new(big.Int).Mod(e.p.y, p).Bytes()
// Each value is a 256-bit number.
const numBytes = 256 / 8
ret := make([]byte, numBytes*2)
copy(ret[1*numBytes-len(xBytes):], xBytes)
copy(ret[2*numBytes-len(yBytes):], yBytes)
@ -171,6 +181,9 @@ func RandomG2(r io.Reader) (*big.Int, *G2, error) {
}
func (e *G2) String() string {
if e.p == nil {
return "bn256.G2" + newTwistPoint(nil).String()
}
return "bn256.G2" + e.p.String()
}
@ -194,7 +207,8 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 {
}
// Add sets e to a+b and then returns e.
// BUG(agl): this function is not complete: a==b fails.
//
// Warning: this function is not complete, it fails for a equal to b.
func (e *G2) Add(a, b *G2) *G2 {
if e.p == nil {
e.p = newTwistPoint(nil)
@ -205,6 +219,13 @@ func (e *G2) Add(a, b *G2) *G2 {
// Marshal converts n into a byte slice.
func (n *G2) Marshal() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
if n.p.IsInfinity() {
return make([]byte, numBytes*4)
}
n.p.MakeAffine(nil)
xxBytes := new(big.Int).Mod(n.p.x.x, p).Bytes()
@ -212,9 +233,6 @@ func (n *G2) Marshal() []byte {
yxBytes := new(big.Int).Mod(n.p.y.x, p).Bytes()
yyBytes := new(big.Int).Mod(n.p.y.y, p).Bytes()
// Each value is a 256-bit number.
const numBytes = 256 / 8
ret := make([]byte, numBytes*4)
copy(ret[1*numBytes-len(xxBytes):], xxBytes)
copy(ret[2*numBytes-len(xyBytes):], xyBytes)
@ -269,8 +287,11 @@ type GT struct {
p *gfP12
}
func (g *GT) String() string {
return "bn256.GT" + g.p.String()
func (e *GT) String() string {
if e.p == nil {
return "bn256.GT" + newGFp12(nil).String()
}
return "bn256.GT" + e.p.String()
}
// ScalarMult sets e to a*k and then returns e.

@ -245,10 +245,19 @@ func (c *curvePoint) Mul(a *curvePoint, scalar *big.Int, pool *bnPool) *curvePoi
return c
}
// MakeAffine converts c to affine form and returns c. If c is ∞, then it sets
// c to 0 : 1 : 0.
func (c *curvePoint) MakeAffine(pool *bnPool) *curvePoint {
if words := c.z.Bits(); len(words) == 1 && words[0] == 1 {
return c
}
if c.IsInfinity() {
c.x.SetInt64(0)
c.y.SetInt64(1)
c.z.SetInt64(0)
c.t.SetInt64(0)
return c
}
zInv := pool.Get().ModInverse(c.z, p)
t := pool.Get().Mul(c.y, zInv)

@ -125,8 +125,8 @@ func (e *gfP12) Mul(a, b *gfP12, pool *bnPool) *gfP12 {
}
func (e *gfP12) MulScalar(a *gfP12, b *gfP6, pool *bnPool) *gfP12 {
e.x.Mul(e.x, b, pool)
e.y.Mul(e.y, b, pool)
e.x.Mul(a.x, b, pool)
e.y.Mul(a.y, b, pool)
return e
}

@ -219,10 +219,19 @@ func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int, pool *bnPool) *twistPoi
return c
}
// MakeAffine converts c to affine form and returns c. If c is ∞, then it sets
// c to 0 : 1 : 0.
func (c *twistPoint) MakeAffine(pool *bnPool) *twistPoint {
if c.z.IsOne() {
return c
}
if c.IsInfinity() {
c.x.SetZero()
c.y.SetOne()
c.z.SetZero()
c.t.SetZero()
return c
}
zInv := newGFp2(pool).Invert(c.z, pool)
t := newGFp2(pool).Mul(c.y, zInv, pool)

@ -2,8 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common
// OpenPGP cipher.
// Package cast5 implements CAST5, as defined in RFC 2144.
//
// CAST5 is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package cast5 // import "golang.org/x/crypto/cast5"
import "errors"

@ -2,32 +2,50 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539.
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539,
// and its extended nonce variant XChaCha20-Poly1305.
package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305"
import (
"crypto/cipher"
"encoding/binary"
"errors"
)
const (
// KeySize is the size of the key used by this AEAD, in bytes.
KeySize = 32
// NonceSize is the size of the nonce used with this AEAD, in bytes.
// NonceSize is the size of the nonce used with the standard variant of this
// AEAD, in bytes.
//
// Note that this is too short to be safely generated at random if the same
// key is reused more than 2³² times.
NonceSize = 12
// NonceSizeX is the size of the nonce used with the XChaCha20-Poly1305
// variant of this AEAD, in bytes.
NonceSizeX = 24
)
type chacha20poly1305 struct {
key [32]byte
key [8]uint32
}
// New returns a ChaCha20-Poly1305 AEAD that uses the given, 256-bit key.
// New returns a ChaCha20-Poly1305 AEAD that uses the given 256-bit key.
func New(key []byte) (cipher.AEAD, error) {
if len(key) != KeySize {
return nil, errors.New("chacha20poly1305: bad key length")
}
ret := new(chacha20poly1305)
copy(ret.key[:], key)
ret.key[0] = binary.LittleEndian.Uint32(key[0:4])
ret.key[1] = binary.LittleEndian.Uint32(key[4:8])
ret.key[2] = binary.LittleEndian.Uint32(key[8:12])
ret.key[3] = binary.LittleEndian.Uint32(key[12:16])
ret.key[4] = binary.LittleEndian.Uint32(key[16:20])
ret.key[5] = binary.LittleEndian.Uint32(key[20:24])
ret.key[6] = binary.LittleEndian.Uint32(key[24:28])
ret.key[7] = binary.LittleEndian.Uint32(key[28:32])
return ret, nil
}

@ -6,7 +6,12 @@
package chacha20poly1305
import "encoding/binary"
import (
"encoding/binary"
"golang.org/x/crypto/internal/subtle"
"golang.org/x/sys/cpu"
)
//go:noescape
func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
@ -14,78 +19,26 @@ func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
//go:noescape
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
// cpuid is implemented in chacha20poly1305_amd64.s.
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
// xgetbv with ecx = 0 is implemented in chacha20poly1305_amd64.s.
func xgetbv() (eax, edx uint32)
var (
useASM bool
useAVX2 bool
useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2
)
func init() {
detectCPUFeatures()
}
// detectCPUFeatures is used to detect if cpu instructions
// used by the functions implemented in assembler in
// chacha20poly1305_amd64.s are supported.
func detectCPUFeatures() {
maxID, _, _, _ := cpuid(0, 0)
if maxID < 1 {
return
}
_, _, ecx1, _ := cpuid(1, 0)
haveSSSE3 := isSet(9, ecx1)
useASM = haveSSSE3
haveOSXSAVE := isSet(27, ecx1)
osSupportsAVX := false
// For XGETBV, OSXSAVE bit is required and sufficient.
if haveOSXSAVE {
eax, _ := xgetbv()
// Check if XMM and YMM registers have OS support.
osSupportsAVX = isSet(1, eax) && isSet(2, eax)
}
haveAVX := isSet(28, ecx1) && osSupportsAVX
if maxID < 7 {
return
}
_, ebx7, _, _ := cpuid(7, 0)
haveAVX2 := isSet(5, ebx7) && haveAVX
haveBMI2 := isSet(8, ebx7)
useAVX2 = haveAVX2 && haveBMI2
}
// isSet checks if bit at bitpos is set in value.
func isSet(bitpos uint, value uint32) bool {
return value&(1<<bitpos) != 0
}
// setupState writes a ChaCha20 input matrix to state. See
// https://tools.ietf.org/html/rfc7539#section-2.3.
func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
func setupState(state *[16]uint32, key *[8]uint32, nonce []byte) {
state[0] = 0x61707865
state[1] = 0x3320646e
state[2] = 0x79622d32
state[3] = 0x6b206574
state[4] = binary.LittleEndian.Uint32(key[:4])
state[5] = binary.LittleEndian.Uint32(key[4:8])
state[6] = binary.LittleEndian.Uint32(key[8:12])
state[7] = binary.LittleEndian.Uint32(key[12:16])
state[8] = binary.LittleEndian.Uint32(key[16:20])
state[9] = binary.LittleEndian.Uint32(key[20:24])
state[10] = binary.LittleEndian.Uint32(key[24:28])
state[11] = binary.LittleEndian.Uint32(key[28:32])
state[4] = key[0]
state[5] = key[1]
state[6] = key[2]
state[7] = key[3]
state[8] = key[4]
state[9] = key[5]
state[10] = key[6]
state[11] = key[7]
state[12] = 0
state[13] = binary.LittleEndian.Uint32(nonce[:4])
@ -94,7 +47,7 @@ func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
}
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
if !useASM {
if !cpu.X86.HasSSSE3 {
return c.sealGeneric(dst, nonce, plaintext, additionalData)
}
@ -102,12 +55,15 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []
setupState(&state, &c.key, nonce)
ret, out := sliceForAppend(dst, len(plaintext)+16)
if subtle.InexactOverlap(out, plaintext) {
panic("chacha20poly1305: invalid buffer overlap")
}
chacha20Poly1305Seal(out[:], state[:], plaintext, additionalData)
return ret
}
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if !useASM {
if !cpu.X86.HasSSSE3 {
return c.openGeneric(dst, nonce, ciphertext, additionalData)
}
@ -116,6 +72,9 @@ func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) (
ciphertext = ciphertext[:len(ciphertext)-16]
ret, out := sliceForAppend(dst, len(ciphertext))
if subtle.InexactOverlap(out, ciphertext) {
panic("chacha20poly1305: invalid buffer overlap")
}
if !chacha20Poly1305Open(out, state[:], ciphertext, additionalData) {
for i := range out {
out[i] = 0

@ -2693,22 +2693,3 @@ sealAVX2Tail512LoopB:
VPERM2I128 $0x13, tmpStoreAVX2, DD3, DD0
JMP sealAVX2SealHash
// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
TEXT ·cpuid(SB), NOSPLIT, $0-24
MOVL eaxArg+0(FP), AX
MOVL ecxArg+4(FP), CX
CPUID
MOVL AX, eax+8(FP)
MOVL BX, ebx+12(FP)
MOVL CX, ecx+16(FP)
MOVL DX, edx+20(FP)
RET
// func xgetbv() (eax, edx uint32)
TEXT ·xgetbv(SB),NOSPLIT,$0-8
MOVL $0, CX
XGETBV
MOVL AX, eax+0(FP)
MOVL DX, edx+4(FP)
RET

@ -8,6 +8,7 @@ import (
"encoding/binary"
"golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/internal/subtle"
"golang.org/x/crypto/poly1305"
)
@ -16,15 +17,20 @@ func roundTo16(n int) int {
}
func (c *chacha20poly1305) sealGeneric(dst, nonce, plaintext, additionalData []byte) []byte {
var counter [16]byte
copy(counter[4:], nonce)
ret, out := sliceForAppend(dst, len(plaintext)+poly1305.TagSize)
if subtle.InexactOverlap(out, plaintext) {
panic("chacha20poly1305: invalid buffer overlap")
}
var polyKey [32]byte
chacha20.XORKeyStream(polyKey[:], polyKey[:], &counter, &c.key)
ret, out := sliceForAppend(dst, len(plaintext)+poly1305.TagSize)
counter[0] = 1
chacha20.XORKeyStream(out, plaintext, &counter, &c.key)
s := chacha20.New(c.key, [3]uint32{
binary.LittleEndian.Uint32(nonce[0:4]),
binary.LittleEndian.Uint32(nonce[4:8]),
binary.LittleEndian.Uint32(nonce[8:12]),
})
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip the next 32 bytes
s.XORKeyStream(out, plaintext)
polyInput := make([]byte, roundTo16(len(additionalData))+roundTo16(len(plaintext))+8+8)
copy(polyInput, additionalData)
@ -44,11 +50,14 @@ func (c *chacha20poly1305) openGeneric(dst, nonce, ciphertext, additionalData []
copy(tag[:], ciphertext[len(ciphertext)-16:])
ciphertext = ciphertext[:len(ciphertext)-16]
var counter [16]byte
copy(counter[4:], nonce)
var polyKey [32]byte
chacha20.XORKeyStream(polyKey[:], polyKey[:], &counter, &c.key)
s := chacha20.New(c.key, [3]uint32{
binary.LittleEndian.Uint32(nonce[0:4]),
binary.LittleEndian.Uint32(nonce[4:8]),
binary.LittleEndian.Uint32(nonce[8:12]),
})
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip the next 32 bytes
polyInput := make([]byte, roundTo16(len(additionalData))+roundTo16(len(ciphertext))+8+8)
copy(polyInput, additionalData)
@ -57,6 +66,9 @@ func (c *chacha20poly1305) openGeneric(dst, nonce, ciphertext, additionalData []
binary.LittleEndian.PutUint64(polyInput[len(polyInput)-8:], uint64(len(ciphertext)))
ret, out := sliceForAppend(dst, len(ciphertext))
if subtle.InexactOverlap(out, ciphertext) {
panic("chacha20poly1305: invalid buffer overlap")
}
if !poly1305.Verify(&tag, polyInput, &polyKey) {
for i := range out {
out[i] = 0
@ -64,7 +76,6 @@ func (c *chacha20poly1305) openGeneric(dst, nonce, ciphertext, additionalData []
return nil, errOpen
}
counter[0] = 1
chacha20.XORKeyStream(out, ciphertext, &counter, &c.key)
s.XORKeyStream(out, ciphertext)
return ret, nil
}

@ -6,9 +6,13 @@ package chacha20poly1305
import (
"bytes"
cr "crypto/rand"
"crypto/cipher"
cryptorand "crypto/rand"
"encoding/hex"
mr "math/rand"
"fmt"
"log"
mathrand "math/rand"
"strconv"
"testing"
)
@ -19,7 +23,18 @@ func TestVectors(t *testing.T) {
ad, _ := hex.DecodeString(test.aad)
plaintext, _ := hex.DecodeString(test.plaintext)
aead, err := New(key)
var (
aead cipher.AEAD
err error
)
switch len(nonce) {
case NonceSize:
aead, err = New(key)
case NonceSizeX:
aead, err = NewX(key)
default:
t.Fatalf("#%d: wrong nonce length: %d", i, len(nonce))
}
if err != nil {
t.Fatal(err)
}
@ -42,7 +57,7 @@ func TestVectors(t *testing.T) {
}
if len(ad) > 0 {
alterAdIdx := mr.Intn(len(ad))
alterAdIdx := mathrand.Intn(len(ad))
ad[alterAdIdx] ^= 0x80
if _, err := aead.Open(nil, nonce, ct, ad); err == nil {
t.Errorf("#%d: Open was successful after altering additional data", i)
@ -50,14 +65,14 @@ func TestVectors(t *testing.T) {
ad[alterAdIdx] ^= 0x80
}
alterNonceIdx := mr.Intn(aead.NonceSize())
alterNonceIdx := mathrand.Intn(aead.NonceSize())
nonce[alterNonceIdx] ^= 0x80
if _, err := aead.Open(nil, nonce, ct, ad); err == nil {
t.Errorf("#%d: Open was successful after altering nonce", i)
}
nonce[alterNonceIdx] ^= 0x80
alterCtIdx := mr.Intn(len(ct))
alterCtIdx := mathrand.Intn(len(ct))
ct[alterCtIdx] ^= 0x80
if _, err := aead.Open(nil, nonce, ct, ad); err == nil {
t.Errorf("#%d: Open was successful after altering ciphertext", i)
@ -68,87 +83,117 @@ func TestVectors(t *testing.T) {
func TestRandom(t *testing.T) {
// Some random tests to verify Open(Seal) == Plaintext
for i := 0; i < 256; i++ {
var nonce [12]byte
var key [32]byte
f := func(t *testing.T, nonceSize int) {
for i := 0; i < 256; i++ {
var nonce = make([]byte, nonceSize)
var key [32]byte
al := mr.Intn(128)
pl := mr.Intn(16384)
ad := make([]byte, al)
plaintext := make([]byte, pl)
cr.Read(key[:])
cr.Read(nonce[:])
cr.Read(ad)
cr.Read(plaintext)
al := mathrand.Intn(128)
pl := mathrand.Intn(16384)
ad := make([]byte, al)
plaintext := make([]byte, pl)
cryptorand.Read(key[:])
cryptorand.Read(nonce[:])
cryptorand.Read(ad)
cryptorand.Read(plaintext)
aead, err := New(key[:])
if err != nil {
t.Fatal(err)
}
ct := aead.Seal(nil, nonce[:], plaintext, ad)
plaintext2, err := aead.Open(nil, nonce[:], ct, ad)
if err != nil {
t.Errorf("Random #%d: Open failed", i)
continue
}
if !bytes.Equal(plaintext, plaintext2) {
t.Errorf("Random #%d: plaintext's don't match: got %x vs %x", i, plaintext2, plaintext)
continue
}
if len(ad) > 0 {
alterAdIdx := mr.Intn(len(ad))
ad[alterAdIdx] ^= 0x80
if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil {
t.Errorf("Random #%d: Open was successful after altering additional data", i)
var (
aead cipher.AEAD
err error
)
switch len(nonce) {
case NonceSize:
aead, err = New(key[:])
case NonceSizeX:
aead, err = NewX(key[:])
default:
t.Fatalf("#%d: wrong nonce length: %d", i, len(nonce))
}
if err != nil {
t.Fatal(err)
}
ad[alterAdIdx] ^= 0x80
}
alterNonceIdx := mr.Intn(aead.NonceSize())
nonce[alterNonceIdx] ^= 0x80
if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil {
t.Errorf("Random #%d: Open was successful after altering nonce", i)
}
nonce[alterNonceIdx] ^= 0x80
ct := aead.Seal(nil, nonce[:], plaintext, ad)
alterCtIdx := mr.Intn(len(ct))
ct[alterCtIdx] ^= 0x80
if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil {
t.Errorf("Random #%d: Open was successful after altering ciphertext", i)
plaintext2, err := aead.Open(nil, nonce[:], ct, ad)
if err != nil {
t.Errorf("Random #%d: Open failed", i)
continue
}
if !bytes.Equal(plaintext, plaintext2) {
t.Errorf("Random #%d: plaintext's don't match: got %x vs %x", i, plaintext2, plaintext)
continue
}
if len(ad) > 0 {
alterAdIdx := mathrand.Intn(len(ad))
ad[alterAdIdx] ^= 0x80
if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil {
t.Errorf("Random #%d: Open was successful after altering additional data", i)
}
ad[alterAdIdx] ^= 0x80
}
alterNonceIdx := mathrand.Intn(aead.NonceSize())
nonce[alterNonceIdx] ^= 0x80
if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil {
t.Errorf("Random #%d: Open was successful after altering nonce", i)
}
nonce[alterNonceIdx] ^= 0x80
alterCtIdx := mathrand.Intn(len(ct))
ct[alterCtIdx] ^= 0x80
if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil {
t.Errorf("Random #%d: Open was successful after altering ciphertext", i)
}
ct[alterCtIdx] ^= 0x80
}
ct[alterCtIdx] ^= 0x80
}
t.Run("Standard", func(t *testing.T) { f(t, NonceSize) })
t.Run("X", func(t *testing.T) { f(t, NonceSizeX) })
}
func benchamarkChaCha20Poly1305Seal(b *testing.B, buf []byte) {
func benchamarkChaCha20Poly1305Seal(b *testing.B, buf []byte, nonceSize int) {
b.ReportAllocs()
b.SetBytes(int64(len(buf)))
var key [32]byte
var nonce [12]byte
var nonce = make([]byte, nonceSize)
var ad [13]byte
var out []byte
aead, _ := New(key[:])
var aead cipher.AEAD
switch len(nonce) {
case NonceSize:
aead, _ = New(key[:])
case NonceSizeX:
aead, _ = NewX(key[:])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
out = aead.Seal(out[:0], nonce[:], buf[:], ad[:])
}
}
func benchamarkChaCha20Poly1305Open(b *testing.B, buf []byte) {
func benchamarkChaCha20Poly1305Open(b *testing.B, buf []byte, nonceSize int) {
b.ReportAllocs()
b.SetBytes(int64(len(buf)))
var key [32]byte
var nonce [12]byte
var nonce = make([]byte, nonceSize)
var ad [13]byte
var ct []byte
var out []byte
aead, _ := New(key[:])
var aead cipher.AEAD
switch len(nonce) {
case NonceSize:
aead, _ = New(key[:])
case NonceSizeX:
aead, _ = NewX(key[:])
}
ct = aead.Seal(ct[:0], nonce[:], buf[:], ad[:])
b.ResetTimer()
@ -157,26 +202,54 @@ func benchamarkChaCha20Poly1305Open(b *testing.B, buf []byte) {
}
}
func BenchmarkChacha20Poly1305Open_64(b *testing.B) {
benchamarkChaCha20Poly1305Open(b, make([]byte, 64))
func BenchmarkChacha20Poly1305(b *testing.B) {
for _, length := range []int{64, 1350, 8 * 1024} {
b.Run("Open-"+strconv.Itoa(length), func(b *testing.B) {
benchamarkChaCha20Poly1305Open(b, make([]byte, length), NonceSize)
})
b.Run("Seal-"+strconv.Itoa(length), func(b *testing.B) {
benchamarkChaCha20Poly1305Seal(b, make([]byte, length), NonceSize)
})
b.Run("Open-"+strconv.Itoa(length)+"-X", func(b *testing.B) {
benchamarkChaCha20Poly1305Open(b, make([]byte, length), NonceSizeX)
})
b.Run("Seal-"+strconv.Itoa(length)+"-X", func(b *testing.B) {
benchamarkChaCha20Poly1305Seal(b, make([]byte, length), NonceSizeX)
})
}
}
func BenchmarkChacha20Poly1305Seal_64(b *testing.B) {
benchamarkChaCha20Poly1305Seal(b, make([]byte, 64))
}
var key = make([]byte, KeySize)
func BenchmarkChacha20Poly1305Open_1350(b *testing.B) {
benchamarkChaCha20Poly1305Open(b, make([]byte, 1350))
}
func ExampleNewX() {
aead, err := NewX(key)
if err != nil {
log.Fatalln("Failed to instantiate XChaCha20-Poly1305:", err)
}
func BenchmarkChacha20Poly1305Seal_1350(b *testing.B) {
benchamarkChaCha20Poly1305Seal(b, make([]byte, 1350))
}
for _, msg := range []string{
"Attack at dawn.",
"The eagle has landed.",
"Gophers, gophers, gophers everywhere!",
} {
// Encryption.
nonce := make([]byte, NonceSizeX)
if _, err := cryptorand.Read(nonce); err != nil {
panic(err)
}
ciphertext := aead.Seal(nil, nonce, []byte(msg), nil)
func BenchmarkChacha20Poly1305Open_8K(b *testing.B) {
benchamarkChaCha20Poly1305Open(b, make([]byte, 8*1024))
}
// Decryption.
plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
if err != nil {
log.Fatalln("Failed to decrypt or authenticate message:", err)
}
func BenchmarkChacha20Poly1305Seal_8K(b *testing.B) {
benchamarkChaCha20Poly1305Seal(b, make([]byte, 8*1024))
fmt.Printf("%s\n", plaintext)
}
// Output: Attack at dawn.
// The eagle has landed.
// Gophers, gophers, gophers everywhere!
}

@ -7,6 +7,13 @@ package chacha20poly1305
var chacha20Poly1305Tests = []struct {
plaintext, aad, key, nonce, out string
}{
{
"",
"",
"808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f",
"070000004041424344454647",
"a0784d7a4716f3feb4f64e7f4b39bf04",
},
{
"4c616469657320616e642047656e746c656d656e206f662074686520636c617373206f66202739393a204966204920636f756c64206f6666657220796f75206f6e6c79206f6e652074697020666f7220746865206675747572652c2073756e73637265656e20776f756c642062652069742e",
"50515253c0c1c2c3c4c5c6c7",
@ -329,4 +336,391 @@ var chacha20Poly1305Tests = []struct {
"129039b5572e8a7a8131f76a",
"2c125232a59879aee36cacc4aca5085a4688c4f776667a8fbd86862b5cfb1d57c976688fdd652eafa2b88b1b8e358aa2110ff6ef13cdc1ceca9c9f087c35c38d89d6fbd8de89538070f17916ecb19ca3ef4a1c834f0bdaa1df62aaabef2e117106787056c909e61ecd208357dd5c363f11c5d6cf24992cc873cf69f59360a820fcf290bd90b2cab24c47286acb4e1033962b6d41e562a206a94796a8ab1c6b8bade804ff9bdf5ba6062d2c1f8fe0f4dfc05720bd9a612b92c26789f9f6a7ce43f5e8e3aee99a9cd7d6c11eaa611983c36935b0dda57d898a60a0ab7c4b54",
},
// XChaCha20-Poly1305 vectors
{
"000000000000000000000000000000",
"",
"0000000000000000000000000000000000000000000000000000000000000000",
"000000000000000000000000000000000000000000000000",
"789e9689e5208d7fd9e1f3c5b5341fb2f7033812ac9ebd3745e2c99c7bbfeb",
},
{
"02dc819b71875e49f5e1e5a768141cfd3f14307ae61a34d81decd9a3367c00c7",
"",
"b7bbfe61b8041658ddc95d5cbdc01bbe7626d24f3a043b70ddee87541234cff7",
"e293239d4c0a07840c5f83cb515be7fd59c333933027e99c",
"7a51f271bd2e547943c7be3316c05519a5d16803712289aa2369950b1504dd8267222e47b13280077ecada7b8795d535",
},
{
"7afc5f3f24155002e17dc176a8f1f3a097ff5a991b02ff4640f70b90db0c15c328b696d6998ea7988edfe3b960e47824e4ae002fbe589be57896a9b7bf5578599c6ba0153c7c",
"d499bb9758debe59a93783c61974b7",
"4ea8fab44a07f7ffc0329b2c2f8f994efdb6d505aec32113ae324def5d929ba1",
"404d5086271c58bf27b0352a205d21ce4367d7b6a7628961",
"26d2b46ad58b6988e2dcf1d09ba8ab6f532dc7e0847cdbc0ed00284225c02bbdb278ee8381ebd127a06926107d1b731cfb1521b267168926492e8f77219ad922257a5be2c5e52e6183ca4dfd0ad3912d7bd1ec968065",
},
{
"",
"",
"48d8bd02c2e9947eae58327114d35e055407b5519c8019535efcb4fc875b5e2b",
"cc0a587a475caba06f8dbc09afec1462af081fe1908c2cba",
"fc3322d0a9d6fac3eb4a9e09b00b361e",
},
{
"e0862731e5",
"",
"6579e7ee96151131a1fcd06fe0d52802c0021f214960ecceec14b2b8591f62cd",
"e2230748649bc22e2b71e46a7814ecabe3a7005e949bd491",
"e991efb85d8b1cfa3f92cb72b8d3c882e88f4529d9",
},
{
"00c7dd8f440af1530b44",
"",
"ffb733657c849d50ab4ab40c4ae18f8ee2f0acf7c907afefdc04dff3537fdff3",
"02c6fd8032a8d89edbedcd1db024c09d29f08b1e74325085",
"13dbcdb8c60c3ed28449a57688edfaea89e309ab4faa6d51e532",
},
{
"7422f311ea476cf819cb8b3c77369f",
"",
"ef0d05d028d6abdd5e99d1761d2028de75ee6eb376ff0dc8036e9a8e10743876",
"f772745200b0f92e38f1d8dae79bf8138e84b301f0be74df",
"d5f992f9834df1be86b580ac59c7eae063a68072829c51bc8a26970dd3d310",
},
{
"ba09ca69450e6c7bece31a7a3f216e3b9ed0e536",
"",
"8d93e31abfe22a63faf45cbea91877050718f13fef6e2664a1892d7f23007ccf",
"260b7b3554a7e6ff8aae7dd6234077ca539689a20c1610a8",
"c99e9a768eb2ec8569bdff8a37295069552faebcafb1a76e98bc7c5b6b778b3d1b6291f0",
},
{
"424ec5f98a0fdc5a7388532d11ab0edb26733505627b7f2d1f",
"",
"b68d5e6c46cdbb0060445522bdc5c562ae803b6aaaf1e103c146e93527a59299",
"80bb5dc1dd44a35ec4f91307f1a95b4ca31183a1a596fb7c",
"29d4eed0fff0050d4bb40de3b055d836206e7cbd62de1a63904f0cf731129ba3f9c2b9d46251a6de89",
},
{
"e7e4515cc0a6ef0491af983eaac4f862d6e726758a3c657f4ec444841e42",
"",
"e31a1d3af650e8e2848bd78432d89ecd1fdece9842dc2792e7bda080f537b17b",
"f3f09905e9a871e757348834f483ed71be9c0f437c8d74b0",
"f5c69528963e17db725a28885d30a45194f12848b8b7644c7bded47a2ee83e6d4ef34006305cfdf82effdced461d",
},
{
"0f5ca45a54875d1d19e952e53caeaa19389342f776dab11723535503338d6f77202a37",
"",
"1031bc920d4fcb4434553b1bf2d25ab375200643bf523ff037bf8914297e8dca",
"4cc77e2ef5445e07b5f44de2dc5bf62d35b8c6f69502d2bf",
"7aa8669e1bfe8b0688899cdddbb8cee31265928c66a69a5090478da7397573b1cc0f64121e7d8bff8db0ddd3c17460d7f29a12",
},
{
"c45578c04c194994e89025c7ffb015e5f138be3cd1a93640af167706aee2ad25ad38696df41ad805",
"",
"ac8648b7c94328419c668ce1c57c71893adf73abbb98892a4fc8da17400e3a5e",
"4ad637facf97af5fc03207ae56219da9972858b7430b3611",
"49e093fcd074fb67a755669119b8bd430d98d9232ca988882deeb3508bde7c00160c35cea89092db864dcb6d440aefa5aacb8aa7b9c04cf0",
},
{
"b877bfa192ea7e4c7569b9ee973f89924d45f9d8ed03c7098ad0cad6e7880906befedcaf6417bb43efabca7a2f",
"",
"125e331d5da423ecabc8adf693cdbc2fc3d3589740d40a3894f914db86c02492",
"913f8b2f08006e6260de41ec3ee01d938a3e68fb12dc44c4",
"1be334253423c90fc8ea885ee5cd3a54268c035ba8a2119e5bd4f7822cd7bf9cb4cec568d5b6d6292606d32979e044df3504e6eb8c0b2fc7e2a0e17d62",
},
{
"d946484a1df5f85ff72c92ff9e192660cde5074bd0ddd5de900c35eb10ed991113b1b19884631bc8ceb386bcd83908061ce9",
"",
"b7e83276373dcf8929b6a6ea80314c9de871f5f241c9144189ee4caf62726332",
"f59f9d6e3e6c00720dc20dc21586e8330431ebf42cf9180e",
"a38a662b18c2d15e1b7b14443cc23267a10bee23556b084b6254226389c414069b694159a4d0b5abbe34de381a0e2c88b947b4cfaaebf50c7a1ad6c656e386280ad7",
},
{
"d266927ca40b2261d5a4722f3b4da0dd5bec74e103fab431702309fd0d0f1a259c767b956aa7348ca923d64c04f0a2e898b0670988b15e",
"",
"a60e09cd0bea16f26e54b62b2908687aa89722c298e69a3a22cf6cf1c46b7f8a",
"92da9d67854c53597fc099b68d955be32df2f0d9efe93614",
"9dd6d05832f6b4d7f555a5a83930d6aed5423461d85f363efb6c474b6c4c8261b680dea393e24c2a3c8d1cc9db6df517423085833aa21f9ab5b42445b914f2313bcd205d179430",
},
{
"f7e11b4d372ed7cb0c0e157f2f9488d8efea0f9bbe089a345f51bdc77e30d1392813c5d22ca7e2c7dfc2e2d0da67efb2a559058d4de7a11bd2a2915e",
"",
"194b1190fa31d483c222ec475d2d6117710dd1ac19a6f1a1e8e894885b7fa631",
"6b07ea26bb1f2d92e04207b447f2fd1dd2086b442a7b6852",
"25ae14585790d71d39a6e88632228a70b1f6a041839dc89a74701c06bfa7c4de3288b7772cb2919818d95777ab58fe5480d6e49958f5d2481431014a8f88dab8f7e08d2a9aebbe691430011d",
},
{
"",
"1e2b11e3",
"70cd96817da85ede0efdf03a358103a84561b25453dee73735e5fb0161b0d493",
"5ddeba49f7266d11827a43931d1c300dd47a3c33f9f8bf9b",
"592fc4c19f3cddec517b2a00f9df9665",
},
{
"81b3cb7eb3",
"efcfd0cf",
"a977412f889281a6d75c24186f1bfaa00dcc5132f0929f20ef15bbf9e63c4c91",
"3f26ca997fb9166d9c615babe3e543ca43ab7cab20634ac5",
"8e4ade3e254cf52e93eace5c46667f150832725594",
},
{
"556f97f2ebdb4e949923",
"f7cee2e0",
"787b3e86546a51028501c801dadf8d5b996fd6f6f2363d5d0f900c44f6a2f4c2",
"7fa6af59a779657d1cada847439ea5b92a1337cfbebbc3b1",
"608ec22dae5f48b89d6f0d2a940d5a7661e0a8e68aaee4ad2d96",
},
{
"c06847a36ad031595b60edd44dc245",
"d4175e1f",
"16de31e534dd5af32801b1acd0ec541d1f8d82bcbc3af25ec815f3575b7aca73",
"29f6656972838f56c1684f6a278f9e4e207b51d68706fc25",
"836082cc51303e500fceade0b1a18f1d97d64ff41cc81754c07d6231b9fd1b",
},
{
"0d03c22ced7b29c6741e72166cd61792028dfc80",
"e505dad0",
"ac2b426e5c5c8e00666180a3410e8a2f6e52247a43aecea9622163e8433c93b2",
"c1123430468228625967bbc0fbd0f963e674372259ff2deb",
"bf09979bf4fed2eec6c97f6e1bcfac35eeffc6d54a55cc1d83d8767ae74db2d7cdfbc371",
},
{
"05bf00e1707cffe7ccbd06a9f846d0fd471a700ed43b4facb8",
"d863bebe",
"66c121f0f84b95ba1e6d29e7d81900bc96a642421b9b6105ae5eb5f2e7b07577",
"8ed6ae211a661e967995b71f7316ba88f44322bb62b4187b",
"b2c5c85d087e0305e9058fba52b661fb3d7f21cb4d4915ae048bc9e5d66a2f921dd4a1c1b030f442c9",
},
{
"5f2b91a9be8bfaa21451ddc6c5cf28d1cc00b046b76270b95cda3c280c83",
"a8750275",
"39592eb276877fca9dd11e2181c0b23127328407e3cc11e315e5d748f43529cc",
"1084bebd756f193d9eea608b3a0193a5028f8ced19684821",
"eaee1f49ac8468154c601a5dd8b84d597602e5a73534b5fad5664f97d0f017dd114752be969679cf610340c6a312",
},
{
"01e8e269b5376943f3b2d245483a76461dc8b7634868b559165f5dbb20839029fae9bb",
"a1e96da0",
"b8386123b87e50d9d046242cf1bf141fce7f65aff0fba76861a2bc72582d6ff0",
"0fbe2a13a89bea031de96d78f9f11358ba7b6a5e724b4392",
"705ec3f910ec85c6005baa99641de6ca43332ff52b5466df6af4ffbe4ef2a376a8f871d1eae503b5896601fee005cdc1f4c1c6",
},
{
"706daba66e2edb1f828f3c0051e3cc214b12210bde0587bba02580f741a4c83e84d4e9fe961120cd",
"87663c5a",
"d519d82ba8a3f0c3af9efe36682b62e285167be101a526c1d73000f169c2a486",
"ad651aac536978e2bc1a54816345ac5e9a9b43b3d9cc0bfc",
"07051b5e72da9c4811beb07ff9f95aece67eae18420eb3f0e8bb8a5e26d4b483fa40eb063a2354842d0c8a41d981cc2b77c530b496db01c8",
},
{
"1f6b24f2f0d9eb460d726bed953d66fcc4ecc29da6ed2fd711358eac3b2609d74ba3e21885156cde3cbe6d9b6f",
"f5efbc4e",
"86068a00544f749ad4ad15bb8e427ae78577ae22f4ca9778efff828ba10f6b20",
"c8420412c9626dcd34ece14593730f6aa2d01ec51cacd59f",
"a99f6c88eac35bb34439e34b292fe9db8192446dcdc81e2192060ec36d98b47de2bee12bf0f67cb24fb0949c07733a6781cd9455cdc61123f506886b04",
},
{
"d69389d83362be8c0ddb738659a6cc4bd65d88cb5b525232f4d59a7d4751a7203c254923ecb6873e803220aab19664789a63",
"bc35fb1c",
"835855b326a98682b3075b4d7f1b89059c3cdfc547d4296c80ce7a77ba6434e3",
"c27cb75fc319ba431cbaeb120341d0c4745d883eb47e92bc",
"db6dc3f9a0f4f1a6df2495a88910550c2c6205478bfc1e81282e34b5b36d984c72c0509c522c987c61d2e640ced69402a6d33aa10d3d0b81e680b3c19bc142e81923",
},
{
"a66a7f089115ed9e2d5bb5d33d7282a7afe401269b00f2a233a59c04b794a42901d862140b61d18d7c7f0ad5da040613e557f8abc74219",
"2c060aaf",
"99758aa7714fd707931f71803eefe04a06955041308a0b2a1104313b270ccf34",
"63f690d8926408c7a34fe8ddd505a8dc58769dc74e8d5da6",
"92b21ee85afcd8996ac28f3aed1047ad814d6e4ffbca3159af16f26eded83e4abda9e4275eb3ff0ad90dffe09f2d443b628f824f680b46527ce0128e8de1920f7c44350ebe7913",
},
{
"f955183b1f762d4536d3f6885ea7f5ac27414caf46c2e24a2fd3bd56b91c53d840fb657224565e0a6f686f8ba320e04a401057399d9a3d995ab17c13",
"c372ddc5",
"a188be3795b2ca2e69b6aa263244f0963c492d694cf6c9b705a1d7045f3f2a26",
"51bb484ea094ee140474681e1c838e4442fd148de2cc345a",
"48759a5ddfdd829d11de8e0c538ce4a9c475faab6912039b568ad92d737d172fc1eb0c00c3793de6dddbfacfdbbc7f44aeba33684e18005aa982b6fc6c556e63bb90ff7a1dde8153a63eabe0",
},
{
"",
"e013cd0bfafd486d",
"af3d3ba094d38299ecb91c17bfe3d085da5bd42e11acf8acb5bc26a4be9a7583",
"7dd63c14173831f109761b1c1abe18f6ba937d825957011b",
"8bc685a7d9d501952295cd25d8c92517",
},
{
"284b64597e",
"31d013e53aa3ea79",
"93c77409d7f805f97fe683b2dd6ee06152a5e918b3eed5b731acccffdcb2cc04",
"3d331e90c4597cf0c30d1b7cfbd07bcb6ab927eda056873c",
"3538a449d6c18d148a8c6cb76f1bc288657ac7036a",
},
{
"9fe67f5c78180ede8274",
"188608d230d75860",
"b7cca89a82640aea6f80b458c9e633d88594fb498959d39787be87030892d48f",
"ef891d50e8c08958f814590fdb7a9f16c61cc2aae1682109",
"bbb40c30f3d1391a5b38df480cbbf964b71e763e8140751f4e28",
},
{
"3a2826b6f7e3d542e4ded8f23c9aa4",
"260033e789c4676a",
"7fe2731214f2b4b42f93217d43f1776498413725e4f6cfe62b756e5a52df10ea",
"888728219ebf761547f5e2218532714403020e5a8b7a49d0",
"fe0328f883fcd88930ae017c0f54ed90f883041efc020e959125af370c1d47",
},
{
"91858bf7b969005d7164acbd5678052b651c53e0",
"f3cc53ecafcbadb3",
"d69c04e9726b22d51f97bc9da0f0fda86736e6b78e8ef9f6f0000f79890d6d43",
"6de3c45161b434e05445cf6bf69eef7bddf595fc6d8836bd",
"a8869dd578c0835e120c843bb7dedc7a1e9eae24ffd742be6bf5b74088a8a2c550976fcb",
},
{
"b3b1a4d6b2a2b9c5a1ca6c1efaec34dcfa1acbe7074d5e10cc",
"d0f72bd16cda3bae",
"2b317857b089c9305c49b83019f6e158bc4ecc3339b39ade02ee10c37c268da0",
"cb5fa6d1e14a0b4bdf350cd10c8a7bd638102911ec74be09",
"e6372f77c14343650074e07a2b7223c37b29242224b722b24d63b5956f27aa64ce7ce4e39cd14a2787",
},
{
"057d3e9f865be7dff774938cab6d080e50cf9a1593f53c0063201e0bb7ae",
"fd3881e505c8b12d",
"36e42b1ef1ee8d068f09b5fad3ee43d98d34aa3e3f994f2055aee139da71de9d",
"24124da36473d01bdca30297c9eef4fe61955525a453da17",
"a8b28139524c98c1f8776f442eac4c22766fe6aac83224641c58bf021fc9cb709ec4706f49c2d0c1828acf2bfe8d",
},
{
"bd8f13e928c34d67a6c70c3c7efdf2982ecc31d8cee68f9cbddc75912cd828ac93d28b",
"193206c8fcc5b19b",
"6e47c40c9d7b757c2efca4d73890e4c73f3c859aab4fdc64b564b8480dd84e72",
"ca31340ae20d30fe488be355cb36652c5db7c9d6265a3e95",
"a121efc5e1843deade4b8adbfef1808de4eda222f176630ad34fb476fca19e0299e4a13668e53cf13882035ba4f04f47c8b4e3",
},
{
"23067a196e977d10039c14ff358061c918d2148d31961bb3e12c27c5122383cb25c4d1d79c775720",
"62338d02fff78a00",
"2c5c79c92d91fb40ef7d0a77e8033f7b265e3bab998b8116d17b2e62bb4f8a09",
"024736adb1d5c01006dffd8158b57936d158d5b42054336d",
"46d0905473a995d38c7cdbb8ef3da96ecc82a22c5b3c6c9d1c4a61ae7a17db53cb88c5f7eccf2da1d0c417c300f989b4273470e36f03542f",
},
{
"252e966c680329eb687bff813b78fea3bfd3505333f106c6f9f45ba69896723c41bb763793d9b266e897d05557",
"1e93e0cfe6523380",
"9ec6fd1baa13ee16aec3fac16718a2baccf18a403cec467c25b7448e9b321110",
"e7120b1018ab363a36e61102eedbcbe9847a6cbacaa9c328",
"2934f034587d4144bb11182679cd2cd1c99c8088d18e233379e9bc9c41107a1f57a2723ecc7b9ba4e6ee198adf0fd766738e828827dc73136fc5b996e9",
},
{
"6744aefcb318f12bc6eeb59d4d62f7eb95f347cea14bd5158415f07f84e4e3baa3de07512d9b76095ac1312cfcb1bb77f499",
"608d2a33ce5d0b04",
"0f665cbdaaa40f4f5a00c53d951b0a98aac2342be259a52670f650a783be7aab",
"378bdb57e957b8c2e1500c9513052a3b02ff5b7edbd4a3a7",
"341c60fcb374b394f1b01a4a80aedef49ab0b67ec963675e6eec43ef106f7003be87dbf4a8976709583dccc55abc7f979c4721837e8664a69804ea31736aa2af615a",
},
{
"bcf1004f988220b7ce063ef2ec4e276ffd074f0a90aa807de1532679d2a1505568eaa4192d9a6ea52cc500322343ce9f8e68cc2c606d83",
"e64bd00126c8792c",
"58e65150d6a15dcefbc14a171998987ad0d709fb06a17d68d6a778759681c308",
"106d2bd120b06e4eb10bc674fe55c77a3742225268319303",
"a28052a6686a1e9435fee8702f7da563a7b3d7b5d3e9e27f11abf73db309cd1f39a34756258c1c5c7f2fb12cf15eb20175c2a08fc93dd19c5e482ef3fbef3d8404a3cfd54a7baf",
},
{
"acd08d4938a224b4cb2d723bf75420f3ea27b698fadd815bb7db9548a05651398644354334e69f8e4e5503bf1a6f92b38e860044a7edca6874038ce1",
"28a137808d0225b8",
"a031203b963a395b08be55844d81af39d19b23b7cc24b21afa31edc1eea6edd6",
"e8b31c52b6690f10f4ae62ba9d50ba39fb5edcfb78400e35",
"35cf39ba31da95ac9b661cdbd5e9c9655d13b8ff065c4ec10c810833a47a87d8057dd1948a7801bfe6904b49fed0aabfb3cd755a1a262d372786908ddcf64cae9f71cb9ed199c3ddacc50116",
},
{
"",
"cda7ee2857e09e9054ef6806",
"d91dffb18132d8dd3d144a2f10ba28bc5df36cb60369f3b19893ec91db3cf904",
"ee56f19c62b0438da6a0d9e01844313902be44f84a6a4ce7",
"ccd48b61a5683c195d4424009eb1d147",
},
{
"350f4c7ac2",
"7c104b539c1d2ae022434cd6",
"cbb61e369117f9250f68fa707240c554359262a4d66c757f80e3aeb6920894fb",
"fbb14c9943444eac5413c6f5c8095451eddece02c9461043",
"b5c6a35865ed8e5216ff6c77339ee1ab570de50e51",
},
{
"4f0d61d3ea03a44a8df0",
"51c20a8ae9e9794da931fe23",
"ba6ced943aa62f9261d7513b822e02054e099acafb5360f0d850064da48b5a4f",
"04c68cb50cdbb0ec03f8381cf59b886e64c40548bf8e3f82",
"ea45a73957e2a853655623f2a3bb58791f7ea36dd2957ed66ffa",
},
{
"4fbdd4d4293a8f34fdbc8f3ad44cf6",
"8212f315e3759c3253c588bb",
"5354791bc2370415811818e913e310dd12e6a0cf5dcab2b6424816eecccf4b65",
"7ee6353c2fbc73c9ebc652270bc86e4008e09583e623e679",
"50a354811a918e1801fb567621a8924baf8dd79da6d36702855d3753f1319c",
},
{
"5a6f68b5a9a9920ca9c6edf5be7c0af150a063c4",
"9a524aa62938fb7a1e50ed06",
"fd91605a6ad85d8ba7a71b08dce1032aa9992bf4f28d407a53ddda04c043cada",
"46791d99d6de33e79025bf9e97c198e7cf409614c6284b4d",
"648033c1eb615467e90b7d3ac24202d8b849549141f9bab03e9e910c29b8eab3d4fb3f2c",
},
{
"d9318c2c0d9ed89e35d242a6b1d496e7e0c5bbdf77eba14c56",
"a16053c35fbe8dc93c14a81f",
"f21406aec83134ebf7bc48c6d0f45acb5f341fbc7d3b5a9bff3ea1333c916af7",
"de6b977be450d5efa7777e006802ddbb10814a22da1c3cd9",
"8d3dad487d5161663da830b71c3e24ec5cdb74d858cbb73b084ed0902198532aad3a18416966bff223",
},
{
"68d0ee08d38cb4bcc9268fee3030666e70e41fcabf6fe06536eeec43eec5",
"11e09447d40b22dc98070eec",
"da5ee1ec02eab13220fcb94f16efec848a8dd57c0f4d67955423f5d17fde5aa3",
"8f13e61d773a250810f75d46bf163a3f9205be5751f6049a",
"92a103b03764c1ad1f88500d22eeae5c0fe1044c872987c0b97affc5e8c3d783f8cc28a11dc91990ea22dd1bad74",
},
{
"a1d960bda08efcf19e136dc1e8b05b6b381c820eda5f9a8047e1a2dd1803a1e4d11a7f",
"aa73d8d4aaa0cfd9d80a9ae8",
"08028833d617c28ba75b48f177cb5da87189189abb68dcb8974eca9230c25945",
"f7b6f34a910fd11588f567de8555932291f7df05f6e2b193",
"99cfc4cca193998bae153b744e6c94a82a2867780aa0f43acddb7c433fcb297311313ec2199f00d7ca7da0646b40113c60e935",
},
{
"3b4ae39a745b6247ce5baf675ec36c5065b1bf76c8379eab4b769961d43a753896d068938017777e",
"128c017a985052f8cdbc6b28",
"4683d5caff613187a9b16af897253848e9c54fc0ec319de62452a86961d3cbb2",
"5612a13c2da003b91188921cbac3fa093eba99d8cbbb51ff",
"91a98b93b2174257175f7c882b45cc252e0db8667612bd270c1c12fe28b6bf209760bf8f370318f92ae3f88a5d4773b05714132cc28dddb8",
},
{
"22ccf680d2995ef6563de281cff76882a036a59ad73f250e710b3040590d69bccde8a8411abe8b0d3cb728ca82",
"13a97d0a167a61aa21e531ec",
"9e140762eed274948b66de25e6e8f36ab65dc730b0cb096ef15aaba900a5588c",
"d0e9594cfd42ab72553bf34062a263f588bb8f1fc86a19f5",
"f194fc866dfba30e42c4508b7d90b3fa3f8983831ede713334563e36aa861f2f885b40be1dbe20ba2d10958a12823588d4bbbefb81a87d87315204f5e3",
},
{
"a65f5d10c482b3381af296e631eb605eba6a11ccec6ceab021460d0bd35feb676ec6dbba5d4ad6c9f4d683ea541035bc80fa",
"f15ae71ffed50a8fcc4996b0",
"f535d60e8b75ac7e526041eed86eb4d65ae7e315eff15dba6c0133acc2a6a4bf",
"01ba61691ebb3c66d2f94c1b1c597ecd7b5ff7d2a30be405",
"d79e7c3893df5a5879c2f0a3f7ca619f08e4540f3ac7db35790b4211b9d47ae735adadf35fd47252a4763e3fd2b2cd8157f6ea7986108a53437962670a97d68ee281",
},
{
"8c014655b97f6da76b0b168b565fd62de874c164fd7e227346a0ec22c908bed1e2a0b429620e6f3a68dd518f13a2c0250608a1cb08a7c3",
"10a7eff999029c5040c1b3bd",
"bf11af23e88c350a443493f6fa0eb34f234f4daa2676e26f0701bce5642d13f4",
"f14c97392afd2e32e2c625910ca029f9b6e81676c79cc42f",
"78d5226f372d5d60681dbfc749d12df74249f196b0cbf14fa65a3a59dc65ae458455ec39baa1df3397afe752bb06f6f13bf03c99abda7a95c1d0b73fd92d5f888a5f6f889a9aea",
},
{
"66234d7a5b71eef134d60eccf7d5096ee879a33983d6f7a575e3a5e3a4022edccffe7865dde20b5b0a37252e31cb9a3650c63e35b057a1bc200a5b5b",
"ccc2406f997bcae737ddd0f5",
"d009eeb5b9b029577b14d200b7687b655eedb7d74add488f092681787999d66d",
"99319712626b400f9458dbb7a9abc9f5810f25b47fc90b39",
"543a2bbf52fd999027ae7c297353f3ce986f810bc2382583d0a81fda5939e4c87b6e8d262790cd614d6f753d8035b32adf43acc7f6d4c2c44289538928564b6587c2fcb99de1d8e34ffff323",
},
}

@ -23,6 +23,12 @@ func (b *Builder) AddASN1Int64(v int64) {
b.addASN1Signed(asn1.INTEGER, v)
}
// AddASN1Int64WithTag appends a DER-encoded ASN.1 INTEGER with the
// given tag.
func (b *Builder) AddASN1Int64WithTag(v int64, tag asn1.Tag) {
b.addASN1Signed(tag, v)
}
// AddASN1Enum appends a DER-encoded ASN.1 ENUMERATION.
func (b *Builder) AddASN1Enum(v int64) {
b.addASN1Signed(asn1.ENUM, v)
@ -224,6 +230,9 @@ func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) {
// String
// ReadASN1Boolean decodes an ASN.1 INTEGER and converts it to a boolean
// representation into out and advances. It reports whether the read
// was successful.
func (s *String) ReadASN1Boolean(out *bool) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.INTEGER) || len(bytes) != 1 {
@ -245,8 +254,8 @@ func (s *String) ReadASN1Boolean(out *bool) bool {
var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem()
// ReadASN1Integer decodes an ASN.1 INTEGER into out and advances. If out does
// not point to an integer or to a big.Int, it panics. It returns true on
// success and false on error.
// not point to an integer or to a big.Int, it panics. It reports whether the
// read was successful.
func (s *String) ReadASN1Integer(out interface{}) bool {
if reflect.TypeOf(out).Kind() != reflect.Ptr {
panic("out is not a pointer")
@ -359,8 +368,16 @@ func asn1Unsigned(out *uint64, n []byte) bool {
return true
}
// ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It returns
// true on success and false on error.
// ReadASN1Int64WithTag decodes an ASN.1 INTEGER with the given tag into out
// and advances. It reports whether the read was successful and resulted in a
// value that can be represented in an int64.
func (s *String) ReadASN1Int64WithTag(out *int64, tag asn1.Tag) bool {
var bytes String
return s.ReadASN1(&bytes, tag) && checkASN1Integer(bytes) && asn1Signed(out, bytes)
}
// ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It reports
// whether the read was successful.
func (s *String) ReadASN1Enum(out *int) bool {
var bytes String
var i int64
@ -392,7 +409,7 @@ func (s *String) readBase128Int(out *int) bool {
}
// ReadASN1ObjectIdentifier decodes an ASN.1 OBJECT IDENTIFIER into out and
// advances. It returns true on success and false on error.
// advances. It reports whether the read was successful.
func (s *String) ReadASN1ObjectIdentifier(out *encoding_asn1.ObjectIdentifier) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.OBJECT_IDENTIFIER) || len(bytes) == 0 {
@ -431,7 +448,7 @@ func (s *String) ReadASN1ObjectIdentifier(out *encoding_asn1.ObjectIdentifier) b
}
// ReadASN1GeneralizedTime decodes an ASN.1 GENERALIZEDTIME into out and
// advances. It returns true on success and false on error.
// advances. It reports whether the read was successful.
func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.GeneralizedTime) {
@ -449,8 +466,8 @@ func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool {
return true
}
// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. It
// returns true on success and false on error.
// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances.
// It reports whether the read was successful.
func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 {
@ -471,8 +488,8 @@ func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool {
}
// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. It is
// an error if the BIT STRING is not a whole number of bytes. This function
// returns true on success and false on error.
// an error if the BIT STRING is not a whole number of bytes. It reports
// whether the read was successful.
func (s *String) ReadASN1BitStringAsBytes(out *[]byte) bool {
var bytes String
if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 {
@ -489,14 +506,14 @@ func (s *String) ReadASN1BitStringAsBytes(out *[]byte) bool {
// ReadASN1Bytes reads the contents of a DER-encoded ASN.1 element (not including
// tag and length bytes) into out, and advances. The element must match the
// given tag. It returns true on success and false on error.
// given tag. It reports whether the read was successful.
func (s *String) ReadASN1Bytes(out *[]byte, tag asn1.Tag) bool {
return s.ReadASN1((*String)(out), tag)
}
// ReadASN1 reads the contents of a DER-encoded ASN.1 element (not including
// tag and length bytes) into out, and advances. The element must match the
// given tag. It returns true on success and false on error.
// given tag. It reports whether the read was successful.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadASN1(out *String, tag asn1.Tag) bool {
@ -509,7 +526,7 @@ func (s *String) ReadASN1(out *String, tag asn1.Tag) bool {
// ReadASN1Element reads the contents of a DER-encoded ASN.1 element (including
// tag and length bytes) into out, and advances. The element must match the
// given tag. It returns true on success and false on error.
// given tag. It reports whether the read was successful.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadASN1Element(out *String, tag asn1.Tag) bool {
@ -521,8 +538,8 @@ func (s *String) ReadASN1Element(out *String, tag asn1.Tag) bool {
}
// ReadAnyASN1 reads the contents of a DER-encoded ASN.1 element (not including
// tag and length bytes) into out, sets outTag to its tag, and advances. It
// returns true on success and false on error.
// tag and length bytes) into out, sets outTag to its tag, and advances.
// It reports whether the read was successful.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadAnyASN1(out *String, outTag *asn1.Tag) bool {
@ -531,14 +548,14 @@ func (s *String) ReadAnyASN1(out *String, outTag *asn1.Tag) bool {
// ReadAnyASN1Element reads the contents of a DER-encoded ASN.1 element
// (including tag and length bytes) into out, sets outTag to is tag, and
// advances. It returns true on success and false on error.
// advances. It reports whether the read was successful.
//
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
func (s *String) ReadAnyASN1Element(out *String, outTag *asn1.Tag) bool {
return s.readASN1(out, outTag, false /* include header */)
}
// PeekASN1Tag returns true if the next ASN.1 value on the string starts with
// PeekASN1Tag reports whether the next ASN.1 value on the string starts with
// the given tag.
func (s String) PeekASN1Tag(tag asn1.Tag) bool {
if len(s) == 0 {
@ -547,7 +564,8 @@ func (s String) PeekASN1Tag(tag asn1.Tag) bool {
return asn1.Tag(s[0]) == tag
}
// SkipASN1 reads and discards an ASN.1 element with the given tag.
// SkipASN1 reads and discards an ASN.1 element with the given tag. It
// reports whether the operation was successful.
func (s *String) SkipASN1(tag asn1.Tag) bool {
var unused String
return s.ReadASN1(&unused, tag)
@ -556,7 +574,7 @@ func (s *String) SkipASN1(tag asn1.Tag) bool {
// ReadOptionalASN1 attempts to read the contents of a DER-encoded ASN.1
// element (not including tag and length bytes) tagged with the given tag into
// out. It stores whether an element with the tag was found in outPresent,
// unless outPresent is nil. It returns true on success and false on error.
// unless outPresent is nil. It reports whether the read was successful.
func (s *String) ReadOptionalASN1(out *String, outPresent *bool, tag asn1.Tag) bool {
present := s.PeekASN1Tag(tag)
if outPresent != nil {
@ -569,7 +587,7 @@ func (s *String) ReadOptionalASN1(out *String, outPresent *bool, tag asn1.Tag) b
}
// SkipOptionalASN1 advances s over an ASN.1 element with the given tag, or
// else leaves s unchanged.
// else leaves s unchanged. It reports whether the operation was successful.
func (s *String) SkipOptionalASN1(tag asn1.Tag) bool {
if !s.PeekASN1Tag(tag) {
return true
@ -581,8 +599,8 @@ func (s *String) SkipOptionalASN1(tag asn1.Tag) bool {
// ReadOptionalASN1Integer attempts to read an optional ASN.1 INTEGER
// explicitly tagged with tag into out and advances. If no element with a
// matching tag is present, it writes defaultValue into out instead. If out
// does not point to an integer or to a big.Int, it panics. It returns true on
// success and false on error.
// does not point to an integer or to a big.Int, it panics. It reports
// whether the read was successful.
func (s *String) ReadOptionalASN1Integer(out interface{}, tag asn1.Tag, defaultValue interface{}) bool {
if reflect.TypeOf(out).Kind() != reflect.Ptr {
panic("out is not a pointer")
@ -619,8 +637,8 @@ func (s *String) ReadOptionalASN1Integer(out interface{}, tag asn1.Tag, defaultV
// ReadOptionalASN1OctetString attempts to read an optional ASN.1 OCTET STRING
// explicitly tagged with tag into out and advances. If no element with a
// matching tag is present, it writes defaultValue into out instead. It returns
// true on success and false on error.
// matching tag is present, it sets "out" to nil instead. It reports
// whether the read was successful.
func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag asn1.Tag) bool {
var present bool
var child String
@ -644,6 +662,7 @@ func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag
// ReadOptionalASN1Boolean sets *out to the value of the next ASN.1 BOOLEAN or,
// if the next bytes are not an ASN.1 BOOLEAN, to the value of defaultValue.
// It reports whether the operation was successful.
func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool {
var present bool
var child String

@ -149,6 +149,39 @@ func TestReadASN1IntegerSigned(t *testing.T) {
}
}
})
// Repeat with the implicit-tagging functions
t.Run("WithTag", func(t *testing.T) {
for i, test := range testData64 {
tag := asn1.Tag((i * 3) % 32).ContextSpecific()
testData := make([]byte, len(test.in))
copy(testData, test.in)
// Alter the tag of the test case.
testData[0] = uint8(tag)
in := String(testData)
var out int64
ok := in.ReadASN1Int64WithTag(&out, tag)
if !ok || out != test.out {
t.Errorf("#%d: in.ReadASN1Int64WithTag() = %v, want true; out = %d, want %d", i, ok, out, test.out)
}
var b Builder
b.AddASN1Int64WithTag(test.out, tag)
result, err := b.Bytes()
if err != nil {
t.Errorf("#%d: AddASN1Int64WithTag failed: %s", i, err)
continue
}
if !bytes.Equal(result, testData) {
t.Errorf("#%d: AddASN1Int64WithTag: got %x, want %x", i, result, testData)
}
}
})
}
func TestReadASN1IntegerUnsigned(t *testing.T) {

@ -50,8 +50,14 @@ func NewFixedBuilder(buffer []byte) *Builder {
}
}
// SetError sets the value to be returned as the error from Bytes. Writes
// performed after calling SetError are ignored.
func (b *Builder) SetError(err error) {
b.err = err
}
// Bytes returns the bytes written by the builder or an error if one has
// occurred during during building.
// occurred during building.
func (b *Builder) Bytes() ([]byte, error) {
if b.err != nil {
return nil, b.err
@ -94,7 +100,7 @@ func (b *Builder) AddBytes(v []byte) {
b.add(v...)
}
// BuilderContinuation is continuation-passing interface for building
// BuilderContinuation is a continuation-passing interface for building
// length-prefixed byte sequences. Builder methods for length-prefixed
// sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation
// supplied to them. The child builder passed to the continuation can be used
@ -268,9 +274,11 @@ func (b *Builder) flushChild() {
return
}
if !b.fixedSize {
b.result = child.result // In case child reallocated result.
if b.fixedSize && &b.result[0] != &child.result[0] {
panic("cryptobyte: BuilderContinuation reallocated a fixed-size buffer")
}
b.result = child.result
}
func (b *Builder) add(bytes ...byte) {
@ -278,7 +286,7 @@ func (b *Builder) add(bytes ...byte) {
return
}
if b.child != nil {
panic("attempted write while child is pending")
panic("cryptobyte: attempted write while child is pending")
}
if len(b.result)+len(bytes) < len(bytes) {
b.err = errors.New("cryptobyte: length overflow")
@ -290,6 +298,26 @@ func (b *Builder) add(bytes ...byte) {
b.result = append(b.result, bytes...)
}
// Unwrite rolls back n bytes written directly to the Builder. An attempt by a
// child builder passed to a continuation to unwrite bytes from its parent will
// panic.
func (b *Builder) Unwrite(n int) {
if b.err != nil {
return
}
if b.child != nil {
panic("cryptobyte: attempted unwrite while child is pending")
}
length := len(b.result) - b.pendingLenLen - b.offset
if length < 0 {
panic("cryptobyte: internal error")
}
if n > length {
panic("cryptobyte: attempted to unwrite more than was written")
}
b.result = b.result[:len(b.result)-n]
}
// A MarshalingValue marshals itself into a Builder.
type MarshalingValue interface {
// Marshal is called by Builder.AddValue. It receives a pointer to a builder

@ -327,12 +327,14 @@ func TestWriteWithPendingChild(t *testing.T) {
var b Builder
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8LengthPrefixed(func(d *Builder) {
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
}
func() {
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
}
}()
c.AddUint8(2) // panics
}()
c.AddUint8(2) // panics
defer func() {
if recover() == nil {
@ -351,6 +353,92 @@ func TestWriteWithPendingChild(t *testing.T) {
})
}
func TestSetError(t *testing.T) {
const errorStr = "TestSetError"
var b Builder
b.SetError(errors.New(errorStr))
ret, err := b.Bytes()
if ret != nil {
t.Error("expected nil result")
}
if err == nil {
t.Fatal("unexpected nil error")
}
if s := err.Error(); s != errorStr {
t.Errorf("expected error %q, got %v", errorStr, s)
}
}
func TestUnwrite(t *testing.T) {
var b Builder
b.AddBytes([]byte{1, 2, 3, 4, 5})
b.Unwrite(2)
if err := builderBytesEq(&b, 1, 2, 3); err != nil {
t.Error(err)
}
func() {
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
}
}()
b.Unwrite(4) // panics
}()
b = Builder{}
b.AddBytes([]byte{1, 2, 3, 4, 5})
b.AddUint8LengthPrefixed(func(b *Builder) {
b.AddBytes([]byte{1, 2, 3, 4, 5})
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
}
}()
b.Unwrite(6) // panics
})
b = Builder{}
b.AddBytes([]byte{1, 2, 3, 4, 5})
b.AddUint8LengthPrefixed(func(c *Builder) {
defer func() {
if recover() == nil {
t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
}
}()
b.Unwrite(2) // panics (attempted unwrite while child is pending)
})
}
func TestFixedBuilderLengthPrefixed(t *testing.T) {
bufCap := 10
inner := bytes.Repeat([]byte{0xff}, bufCap-2)
buf := make([]byte, 0, bufCap)
b := NewFixedBuilder(buf)
b.AddUint16LengthPrefixed(func(b *Builder) {
b.AddBytes(inner)
})
if got := b.BytesOrPanic(); len(got) != bufCap {
t.Errorf("Expected output length to be %d, got %d", bufCap, len(got))
}
}
func TestFixedBuilderPanicReallocate(t *testing.T) {
defer func() {
recover()
}()
b := NewFixedBuilder(make([]byte, 0, 10))
b1 := NewFixedBuilder(make([]byte, 0, 10))
b.AddUint16LengthPrefixed(func(b *Builder) {
*b = *b1
})
t.Error("Builder did not panic")
}
// ASN.1
func TestASN1Int64(t *testing.T) {

@ -37,8 +37,8 @@ func (s *String) Skip(n int) bool {
return s.read(n) != nil
}
// ReadUint8 decodes an 8-bit value into out and advances over it. It
// returns true on success and false on error.
// ReadUint8 decodes an 8-bit value into out and advances over it.
// It reports whether the read was successful.
func (s *String) ReadUint8(out *uint8) bool {
v := s.read(1)
if v == nil {
@ -49,7 +49,7 @@ func (s *String) ReadUint8(out *uint8) bool {
}
// ReadUint16 decodes a big-endian, 16-bit value into out and advances over it.
// It returns true on success and false on error.
// It reports whether the read was successful.
func (s *String) ReadUint16(out *uint16) bool {
v := s.read(2)
if v == nil {
@ -60,7 +60,7 @@ func (s *String) ReadUint16(out *uint16) bool {
}
// ReadUint24 decodes a big-endian, 24-bit value into out and advances over it.
// It returns true on success and false on error.
// It reports whether the read was successful.
func (s *String) ReadUint24(out *uint32) bool {
v := s.read(3)
if v == nil {
@ -71,7 +71,7 @@ func (s *String) ReadUint24(out *uint32) bool {
}
// ReadUint32 decodes a big-endian, 32-bit value into out and advances over it.
// It returns true on success and false on error.
// It reports whether the read was successful.
func (s *String) ReadUint32(out *uint32) bool {
v := s.read(4)
if v == nil {
@ -119,28 +119,27 @@ func (s *String) readLengthPrefixed(lenLen int, outChild *String) bool {
}
// ReadUint8LengthPrefixed reads the content of an 8-bit length-prefixed value
// into out and advances over it. It returns true on success and false on
// error.
// into out and advances over it. It reports whether the read was successful.
func (s *String) ReadUint8LengthPrefixed(out *String) bool {
return s.readLengthPrefixed(1, out)
}
// ReadUint16LengthPrefixed reads the content of a big-endian, 16-bit
// length-prefixed value into out and advances over it. It returns true on
// success and false on error.
// length-prefixed value into out and advances over it. It reports whether the
// read was successful.
func (s *String) ReadUint16LengthPrefixed(out *String) bool {
return s.readLengthPrefixed(2, out)
}
// ReadUint24LengthPrefixed reads the content of a big-endian, 24-bit
// length-prefixed value into out and advances over it. It returns true on
// success and false on error.
// length-prefixed value into out and advances over it. It reports whether
// the read was successful.
func (s *String) ReadUint24LengthPrefixed(out *String) bool {
return s.readLengthPrefixed(3, out)
}
// ReadBytes reads n bytes into out and advances over them. It returns true on
// success and false and error.
// ReadBytes reads n bytes into out and advances over them. It reports
// whether the read was successful.
func (s *String) ReadBytes(out *[]byte, n int) bool {
v := s.read(n)
if v == nil {
@ -150,8 +149,8 @@ func (s *String) ReadBytes(out *[]byte, n int) bool {
return true
}
// CopyBytes copies len(out) bytes into out and advances over them. It returns
// true on success and false on error.
// CopyBytes copies len(out) bytes into out and advances over them. It reports
// whether the copy operation was successful
func (s *String) CopyBytes(out []byte) bool {
n := len(out)
v := s.read(n)

@ -86,7 +86,7 @@ func feFromBytes(dst *fieldElement, src *[32]byte) {
h6 := load3(src[20:]) << 7
h7 := load3(src[23:]) << 5
h8 := load3(src[26:]) << 4
h9 := load3(src[29:]) << 2
h9 := (load3(src[29:]) & 0x7fffff) << 2
var carry [10]int64
carry[9] = (h9 + 1<<24) >> 25

@ -5,6 +5,8 @@
package curve25519
import (
"bytes"
"crypto/rand"
"fmt"
"testing"
)
@ -28,6 +30,30 @@ func TestBaseScalarMult(t *testing.T) {
}
}
// TestHighBitIgnored tests the following requirement in RFC 7748:
//
// When receiving such an array, implementations of X25519 (but not X448) MUST
// mask the most significant bit in the final byte.
//
// Regression test for issue #30095.
func TestHighBitIgnored(t *testing.T) {
var s, u [32]byte
rand.Read(s[:])
rand.Read(u[:])
var hi0, hi1 [32]byte
u[31] &= 0x7f
ScalarMult(&hi0, &s, &u)
u[31] |= 0x80
ScalarMult(&hi1, &s, &u)
if !bytes.Equal(hi0[:], hi1[:]) {
t.Errorf("high bit of group point should not affect result")
}
}
func BenchmarkScalarBaseMult(b *testing.B) {
var in, out [32]byte
in[0] = 1

@ -6,7 +6,10 @@
// https://ed25519.cr.yp.to/.
//
// These functions are also compatible with the “Ed25519” function defined in
// RFC 8032.
// RFC 8032. However, unlike RFC 8032's formulation, this package's private key
// representation includes a public key suffix to make multiple signing
// operations with the same key more efficient. This package refers to the RFC
// 8032 private key as the “seed”.
package ed25519
// This code is a port of the public domain, “ref10” implementation of ed25519
@ -31,6 +34,8 @@ const (
PrivateKeySize = 64
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = 64
// SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032.
SeedSize = 32
)
// PublicKey is the type of Ed25519 public keys.
@ -46,6 +51,15 @@ func (priv PrivateKey) Public() crypto.PublicKey {
return PublicKey(publicKey)
}
// Seed returns the private key seed corresponding to priv. It is provided for
// interoperability with RFC 8032. RFC 8032's private keys correspond to seeds
// in this package.
func (priv PrivateKey) Seed() []byte {
seed := make([]byte, SeedSize)
copy(seed, priv[:32])
return seed
}
// Sign signs the given message with priv.
// Ed25519 performs two passes over messages to be signed and therefore cannot
// handle pre-hashed messages. Thus opts.HashFunc() must return zero to
@ -61,19 +75,33 @@ func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOp
// GenerateKey generates a public/private key pair using entropy from rand.
// If rand is nil, crypto/rand.Reader will be used.
func GenerateKey(rand io.Reader) (publicKey PublicKey, privateKey PrivateKey, err error) {
func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) {
if rand == nil {
rand = cryptorand.Reader
}
privateKey = make([]byte, PrivateKeySize)
publicKey = make([]byte, PublicKeySize)
_, err = io.ReadFull(rand, privateKey[:32])
if err != nil {
seed := make([]byte, SeedSize)
if _, err := io.ReadFull(rand, seed); err != nil {
return nil, nil, err
}
digest := sha512.Sum512(privateKey[:32])
privateKey := NewKeyFromSeed(seed)
publicKey := make([]byte, PublicKeySize)
copy(publicKey, privateKey[32:])
return publicKey, privateKey, nil
}
// NewKeyFromSeed calculates a private key from a seed. It will panic if
// len(seed) is not SeedSize. This function is provided for interoperability
// with RFC 8032. RFC 8032's private keys correspond to seeds in this
// package.
func NewKeyFromSeed(seed []byte) PrivateKey {
if l := len(seed); l != SeedSize {
panic("ed25519: bad seed length: " + strconv.Itoa(l))
}
digest := sha512.Sum512(seed)
digest[0] &= 248
digest[31] &= 127
digest[31] |= 64
@ -85,10 +113,11 @@ func GenerateKey(rand io.Reader) (publicKey PublicKey, privateKey PrivateKey, er
var publicKeyBytes [32]byte
A.ToBytes(&publicKeyBytes)
privateKey := make([]byte, PrivateKeySize)
copy(privateKey, seed)
copy(privateKey[32:], publicKeyBytes[:])
copy(publicKey, publicKeyBytes[:])
return publicKey, privateKey, nil
return privateKey
}
// Sign signs the message with privateKey and returns a signature. It will
@ -171,9 +200,16 @@ func Verify(publicKey PublicKey, message, sig []byte) bool {
edwards25519.ScReduce(&hReduced, &digest)
var R edwards25519.ProjectiveGroupElement
var b [32]byte
copy(b[:], sig[32:])
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &b)
var s [32]byte
copy(s[:], sig[32:])
// https://tools.ietf.org/html/rfc8032#section-5.1.7 requires that s be in
// the range [0, order) in order to prevent signature malleability.
if !edwards25519.ScMinimal(&s) {
return false
}
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &s)
var checkR [32]byte
R.ToBytes(&checkR)

@ -139,6 +139,19 @@ func TestGolden(t *testing.T) {
if !Verify(pubKey, msg, sig2) {
t.Errorf("signature failed to verify on line %d", lineNo)
}
priv2 := NewKeyFromSeed(priv[:32])
if !bytes.Equal(priv[:], priv2) {
t.Errorf("recreating key pair gave different private key on line %d: %x vs %x", lineNo, priv[:], priv2)
}
if pubKey2 := priv2.Public().(PublicKey); !bytes.Equal(pubKey, pubKey2) {
t.Errorf("recreating key pair gave different public key on line %d: %x vs %x", lineNo, pubKey, pubKey2)
}
if seed := priv2.Seed(); !bytes.Equal(priv[:32], seed) {
t.Errorf("recreating key pair gave different seed on line %d: %x vs %x", lineNo, priv[:32], seed)
}
}
if err := scanner.Err(); err != nil {
@ -146,6 +159,30 @@ func TestGolden(t *testing.T) {
}
}
func TestMalleability(t *testing.T) {
// https://tools.ietf.org/html/rfc8032#section-5.1.7 adds an additional test
// that s be in [0, order). This prevents someone from adding a multiple of
// order to s and obtaining a second valid signature for the same message.
msg := []byte{0x54, 0x65, 0x73, 0x74}
sig := []byte{
0x7c, 0x38, 0xe0, 0x26, 0xf2, 0x9e, 0x14, 0xaa, 0xbd, 0x05, 0x9a,
0x0f, 0x2d, 0xb8, 0xb0, 0xcd, 0x78, 0x30, 0x40, 0x60, 0x9a, 0x8b,
0xe6, 0x84, 0xdb, 0x12, 0xf8, 0x2a, 0x27, 0x77, 0x4a, 0xb0, 0x67,
0x65, 0x4b, 0xce, 0x38, 0x32, 0xc2, 0xd7, 0x6f, 0x8f, 0x6f, 0x5d,
0xaf, 0xc0, 0x8d, 0x93, 0x39, 0xd4, 0xee, 0xf6, 0x76, 0x57, 0x33,
0x36, 0xa5, 0xc5, 0x1e, 0xb6, 0xf9, 0x46, 0xb3, 0x1d,
}
publicKey := []byte{
0x7d, 0x4d, 0x0e, 0x7f, 0x61, 0x53, 0xa6, 0x9b, 0x62, 0x42, 0xb5,
0x22, 0xab, 0xbe, 0xe6, 0x85, 0xfd, 0xa4, 0x42, 0x0f, 0x88, 0x34,
0xb1, 0x08, 0xc3, 0xbd, 0xae, 0x36, 0x9e, 0xf5, 0x49, 0xfa,
}
if Verify(publicKey, msg, sig) {
t.Fatal("non-canonical signature accepted")
}
}
func BenchmarkKeyGeneration(b *testing.B) {
var zero zeroReader
for i := 0; i < b.N; i++ {

@ -4,6 +4,8 @@
package edwards25519
import "encoding/binary"
// This code is a port of the public domain, “ref10” implementation of ed25519
// from SUPERCOP.
@ -1769,3 +1771,23 @@ func ScReduce(out *[32]byte, s *[64]byte) {
out[30] = byte(s11 >> 9)
out[31] = byte(s11 >> 17)
}
// order is the order of Curve25519 in little-endian form.
var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000000}
// ScMinimal returns true if the given scalar is less than the order of the
// curve.
func ScMinimal(scalar *[32]byte) bool {
for i := 3; ; i-- {
v := binary.LittleEndian.Uint64(scalar[i*8:])
if v > order[i] {
return false
} else if v < order[i] {
break
} else if i == 0 {
return false
}
}
return true
}

@ -9,49 +9,44 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"golang.org/x/crypto/hkdf"
"io"
"golang.org/x/crypto/hkdf"
)
// Usage example that expands one master key into three other cryptographically
// secure keys.
// Usage example that expands one master secret into three other
// cryptographically secure keys.
func Example_usage() {
// Underlying hash function to use
// Underlying hash function for HMAC.
hash := sha256.New
// Cryptographically secure master key.
master := []byte{0x00, 0x01, 0x02, 0x03} // i.e. NOT this.
// Cryptographically secure master secret.
secret := []byte{0x00, 0x01, 0x02, 0x03} // i.e. NOT this.
// Non secret salt, optional (can be nil)
// Recommended: hash-length sized random
// Non-secret salt, optional (can be nil).
// Recommended: hash-length random value.
salt := make([]byte, hash().Size())
n, err := io.ReadFull(rand.Reader, salt)
if n != len(salt) || err != nil {
fmt.Println("error:", err)
return
if _, err := rand.Read(salt); err != nil {
panic(err)
}
// Non secret context specific info, optional (can be nil).
// Note, independent from the master key.
info := []byte{0x03, 0x14, 0x15, 0x92, 0x65}
// Non-secret context info, optional (can be nil).
info := []byte("hkdf example")
// Create the key derivation function
hkdf := hkdf.New(hash, master, salt, info)
// Generate three 128-bit derived keys.
hkdf := hkdf.New(hash, secret, salt, info)
// Generate the required keys
keys := make([][]byte, 3)
for i := 0; i < len(keys); i++ {
keys[i] = make([]byte, 24)
n, err := io.ReadFull(hkdf, keys[i])
if n != len(keys[i]) || err != nil {
fmt.Println("error:", err)
return
var keys [][]byte
for i := 0; i < 3; i++ {
key := make([]byte, 16)
if _, err := io.ReadFull(hkdf, key); err != nil {
panic(err)
}
keys = append(keys, key)
}
// Keys should contain 192 bit random keys
for i := 1; i <= len(keys); i++ {
fmt.Printf("Key #%d: %v\n", i, !bytes.Equal(keys[i-1], make([]byte, 24)))
for i := range keys {
fmt.Printf("Key #%d: %v\n", i+1, !bytes.Equal(keys[i], make([]byte, 16)))
}
// Output:

@ -8,8 +8,6 @@
// HKDF is a cryptographic key derivation function (KDF) with the goal of
// expanding limited input keying material into one or more cryptographically
// strong secret keys.
//
// RFC 5869: https://tools.ietf.org/html/rfc5869
package hkdf // import "golang.org/x/crypto/hkdf"
import (
@ -19,6 +17,21 @@ import (
"io"
)
// Extract generates a pseudorandom key for use with Expand from an input secret
// and an optional independent salt.
//
// Only use this function if you need to reuse the extracted key with multiple
// Expand invocations and different context values. Most common scenarios,
// including the generation of multiple keys, should use New instead.
func Extract(hash func() hash.Hash, secret, salt []byte) []byte {
if salt == nil {
salt = make([]byte, hash().Size())
}
extractor := hmac.New(hash, salt)
extractor.Write(secret)
return extractor.Sum(nil)
}
type hkdf struct {
expander hash.Hash
size int
@ -26,22 +39,22 @@ type hkdf struct {
info []byte
counter byte
prev []byte
cache []byte
prev []byte
buf []byte
}
func (f *hkdf) Read(p []byte) (int, error) {
// Check whether enough data can be generated
need := len(p)
remains := len(f.cache) + int(255-f.counter+1)*f.size
remains := len(f.buf) + int(255-f.counter+1)*f.size
if remains < need {
return 0, errors.New("hkdf: entropy limit reached")
}
// Read from the cache, if enough data is present
n := copy(p, f.cache)
// Read any leftover from the buffer
n := copy(p, f.buf)
p = p[n:]
// Fill the buffer
// Fill the rest of the buffer
for len(p) > 0 {
f.expander.Reset()
f.expander.Write(f.prev)
@ -51,25 +64,30 @@ func (f *hkdf) Read(p []byte) (int, error) {
f.counter++
// Copy the new batch into p
f.cache = f.prev
n = copy(p, f.cache)
f.buf = f.prev
n = copy(p, f.buf)
p = p[n:]
}
// Save leftovers for next run
f.cache = f.cache[n:]
f.buf = f.buf[n:]
return need, nil
}
// New returns a new HKDF using the given hash, the secret keying material to expand
// and optional salt and info fields.
func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader {
if salt == nil {
salt = make([]byte, hash().Size())
}
extractor := hmac.New(hash, salt)
extractor.Write(secret)
prk := extractor.Sum(nil)
return &hkdf{hmac.New(hash, prk), extractor.Size(), info, 1, nil, nil}
// Expand returns a Reader, from which keys can be read, using the given
// pseudorandom key and optional context info, skipping the extraction step.
//
// The pseudorandomKey should have been generated by Extract, or be a uniformly
// random or pseudorandom cryptographically strong key. See RFC 5869, Section
// 3.3. Most common scenarios will want to use New instead.
func Expand(hash func() hash.Hash, pseudorandomKey, info []byte) io.Reader {
expander := hmac.New(hash, pseudorandomKey)
return &hkdf{expander, expander.Size(), info, 1, nil, nil}
}
// New returns a Reader, from which keys can be read, using the given hash,
// secret, salt and context info. Salt and info can be nil.
func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader {
prk := Extract(hash, secret, salt)
return Expand(hash, prk, info)
}

@ -18,6 +18,7 @@ type hkdfTest struct {
hash func() hash.Hash
master []byte
salt []byte
prk []byte
info []byte
out []byte
}
@ -35,6 +36,12 @@ var hkdfTests = []hkdfTest{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c,
},
[]byte{
0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf,
0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63,
0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31,
0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5,
},
[]byte{
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
0xf8, 0xf9,
@ -74,6 +81,12 @@ var hkdfTests = []hkdfTest{
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
},
[]byte{
0x06, 0xa6, 0xb8, 0x8c, 0x58, 0x53, 0x36, 0x1a,
0x06, 0x10, 0x4c, 0x9c, 0xeb, 0x35, 0xb4, 0x5c,
0xef, 0x76, 0x00, 0x14, 0x90, 0x46, 0x71, 0x01,
0x4a, 0x19, 0x3f, 0x40, 0xc1, 0x5f, 0xc2, 0x44,
},
[]byte{
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
@ -108,6 +121,12 @@ var hkdfTests = []hkdfTest{
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
},
[]byte{},
[]byte{
0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16,
0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, 0x8b, 0xdf,
0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77,
0xac, 0x43, 0x4c, 0x1c, 0x29, 0x3c, 0xcb, 0x04,
},
[]byte{},
[]byte{
0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f,
@ -118,6 +137,30 @@ var hkdfTests = []hkdfTest{
0x96, 0xc8,
},
},
{
sha256.New,
[]byte{
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
},
nil,
[]byte{
0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16,
0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, 0x8b, 0xdf,
0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77,
0xac, 0x43, 0x4c, 0x1c, 0x29, 0x3c, 0xcb, 0x04,
},
nil,
[]byte{
0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f,
0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c, 0x5a, 0x31,
0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e,
0xc3, 0x45, 0x4e, 0x5f, 0x3c, 0x73, 0x8d, 0x2d,
0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a,
0x96, 0xc8,
},
},
{
sha1.New,
[]byte{
@ -128,6 +171,11 @@ var hkdfTests = []hkdfTest{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c,
},
[]byte{
0x9b, 0x6c, 0x18, 0xc4, 0x32, 0xa7, 0xbf, 0x8f,
0x0e, 0x71, 0xc8, 0xeb, 0x88, 0xf4, 0xb3, 0x0b,
0xaa, 0x2b, 0xa2, 0x43,
},
[]byte{
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
0xf8, 0xf9,
@ -167,6 +215,11 @@ var hkdfTests = []hkdfTest{
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
},
[]byte{
0x8a, 0xda, 0xe0, 0x9a, 0x2a, 0x30, 0x70, 0x59,
0x47, 0x8d, 0x30, 0x9b, 0x26, 0xc4, 0x11, 0x5a,
0x22, 0x4c, 0xfa, 0xf6,
},
[]byte{
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
@ -201,6 +254,11 @@ var hkdfTests = []hkdfTest{
0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b,
},
[]byte{},
[]byte{
0xda, 0x8c, 0x8a, 0x73, 0xc7, 0xfa, 0x77, 0x28,
0x8e, 0xc6, 0xf5, 0xe7, 0xc2, 0x97, 0x78, 0x6a,
0xa0, 0xd3, 0x2d, 0x01,
},
[]byte{},
[]byte{
0x0a, 0xc1, 0xaf, 0x70, 0x02, 0xb3, 0xd7, 0x61,
@ -219,7 +277,12 @@ var hkdfTests = []hkdfTest{
0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c,
},
nil,
[]byte{},
[]byte{
0x2a, 0xdc, 0xca, 0xda, 0x18, 0x77, 0x9e, 0x7c,
0x20, 0x77, 0xad, 0x2e, 0xb1, 0x9d, 0x3f, 0x3e,
0x73, 0x13, 0x85, 0xdd,
},
nil,
[]byte{
0x2c, 0x91, 0x11, 0x72, 0x04, 0xd7, 0x45, 0xf3,
0x50, 0x0d, 0x63, 0x6a, 0x62, 0xf6, 0x4f, 0x0a,
@ -233,6 +296,11 @@ var hkdfTests = []hkdfTest{
func TestHKDF(t *testing.T) {
for i, tt := range hkdfTests {
prk := Extract(tt.hash, tt.master, tt.salt)
if !bytes.Equal(prk, tt.prk) {
t.Errorf("test %d: incorrect PRK: have %v, need %v.", i, prk, tt.prk)
}
hkdf := New(tt.hash, tt.master, tt.salt, tt.info)
out := make([]byte, len(tt.out))
@ -244,6 +312,17 @@ func TestHKDF(t *testing.T) {
if !bytes.Equal(out, tt.out) {
t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out)
}
hkdf = Expand(tt.hash, prk, tt.info)
n, err = io.ReadFull(hkdf, out)
if n != len(tt.out) || err != nil {
t.Errorf("test %d: not enough output bytes from Expand: %d.", i, n)
}
if !bytes.Equal(out, tt.out) {
t.Errorf("test %d: incorrect output from Expand: have %v, need %v.", i, out, tt.out)
}
}
}

@ -2,197 +2,263 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ChaCha20 implements the core ChaCha20 function as specified in https://tools.ietf.org/html/rfc7539#section-2.3.
// Package ChaCha20 implements the core ChaCha20 function as specified
// in https://tools.ietf.org/html/rfc7539#section-2.3.
package chacha20
import "encoding/binary"
import (
"crypto/cipher"
"encoding/binary"
const rounds = 20
"golang.org/x/crypto/internal/subtle"
)
// core applies the ChaCha20 core function to 16-byte input in, 32-byte key k,
// and 16-byte constant c, and puts the result into 64-byte array out.
func core(out *[64]byte, in *[16]byte, k *[32]byte) {
j0 := uint32(0x61707865)
j1 := uint32(0x3320646e)
j2 := uint32(0x79622d32)
j3 := uint32(0x6b206574)
j4 := binary.LittleEndian.Uint32(k[0:4])
j5 := binary.LittleEndian.Uint32(k[4:8])
j6 := binary.LittleEndian.Uint32(k[8:12])
j7 := binary.LittleEndian.Uint32(k[12:16])
j8 := binary.LittleEndian.Uint32(k[16:20])
j9 := binary.LittleEndian.Uint32(k[20:24])
j10 := binary.LittleEndian.Uint32(k[24:28])
j11 := binary.LittleEndian.Uint32(k[28:32])
j12 := binary.LittleEndian.Uint32(in[0:4])
j13 := binary.LittleEndian.Uint32(in[4:8])
j14 := binary.LittleEndian.Uint32(in[8:12])
j15 := binary.LittleEndian.Uint32(in[12:16])
// assert that *Cipher implements cipher.Stream
var _ cipher.Stream = (*Cipher)(nil)
x0, x1, x2, x3, x4, x5, x6, x7 := j0, j1, j2, j3, j4, j5, j6, j7
x8, x9, x10, x11, x12, x13, x14, x15 := j8, j9, j10, j11, j12, j13, j14, j15
// Cipher is a stateful instance of ChaCha20 using a particular key
// and nonce. A *Cipher implements the cipher.Stream interface.
type Cipher struct {
key [8]uint32
counter uint32 // incremented after each block
nonce [3]uint32
buf [bufSize]byte // buffer for unused keystream bytes
len int // number of unused keystream bytes at end of buf
}
for i := 0; i < rounds; i += 2 {
x0 += x4
x12 ^= x0
x12 = (x12 << 16) | (x12 >> (16))
x8 += x12
x4 ^= x8
x4 = (x4 << 12) | (x4 >> (20))
x0 += x4
x12 ^= x0
x12 = (x12 << 8) | (x12 >> (24))
x8 += x12
x4 ^= x8
x4 = (x4 << 7) | (x4 >> (25))
x1 += x5
x13 ^= x1
x13 = (x13 << 16) | (x13 >> 16)
x9 += x13
x5 ^= x9
x5 = (x5 << 12) | (x5 >> 20)
x1 += x5
x13 ^= x1
x13 = (x13 << 8) | (x13 >> 24)
x9 += x13
x5 ^= x9
x5 = (x5 << 7) | (x5 >> 25)
x2 += x6
x14 ^= x2
x14 = (x14 << 16) | (x14 >> 16)
x10 += x14
x6 ^= x10
x6 = (x6 << 12) | (x6 >> 20)
x2 += x6
x14 ^= x2
x14 = (x14 << 8) | (x14 >> 24)
x10 += x14
x6 ^= x10
x6 = (x6 << 7) | (x6 >> 25)
x3 += x7
x15 ^= x3
x15 = (x15 << 16) | (x15 >> 16)
x11 += x15
x7 ^= x11
x7 = (x7 << 12) | (x7 >> 20)
x3 += x7
x15 ^= x3
x15 = (x15 << 8) | (x15 >> 24)
x11 += x15
x7 ^= x11
x7 = (x7 << 7) | (x7 >> 25)
x0 += x5
x15 ^= x0
x15 = (x15 << 16) | (x15 >> 16)
x10 += x15
x5 ^= x10
x5 = (x5 << 12) | (x5 >> 20)
x0 += x5
x15 ^= x0
x15 = (x15 << 8) | (x15 >> 24)
x10 += x15
x5 ^= x10
x5 = (x5 << 7) | (x5 >> 25)
x1 += x6
x12 ^= x1
x12 = (x12 << 16) | (x12 >> 16)
x11 += x12
x6 ^= x11
x6 = (x6 << 12) | (x6 >> 20)
x1 += x6
x12 ^= x1
x12 = (x12 << 8) | (x12 >> 24)
x11 += x12
x6 ^= x11
x6 = (x6 << 7) | (x6 >> 25)
x2 += x7
x13 ^= x2
x13 = (x13 << 16) | (x13 >> 16)
x8 += x13
x7 ^= x8
x7 = (x7 << 12) | (x7 >> 20)
x2 += x7
x13 ^= x2
x13 = (x13 << 8) | (x13 >> 24)
x8 += x13
x7 ^= x8
x7 = (x7 << 7) | (x7 >> 25)
x3 += x4
x14 ^= x3
x14 = (x14 << 16) | (x14 >> 16)
x9 += x14
x4 ^= x9
x4 = (x4 << 12) | (x4 >> 20)
x3 += x4
x14 ^= x3
x14 = (x14 << 8) | (x14 >> 24)
x9 += x14
x4 ^= x9
x4 = (x4 << 7) | (x4 >> 25)
// New creates a new ChaCha20 stream cipher with the given key and nonce.
// The initial counter value is set to 0.
func New(key [8]uint32, nonce [3]uint32) *Cipher {
return &Cipher{key: key, nonce: nonce}
}
// ChaCha20 constants spelling "expand 32-byte k"
const (
j0 uint32 = 0x61707865
j1 uint32 = 0x3320646e
j2 uint32 = 0x79622d32
j3 uint32 = 0x6b206574
)
func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) {
a += b
d ^= a
d = (d << 16) | (d >> 16)
c += d
b ^= c
b = (b << 12) | (b >> 20)
a += b
d ^= a
d = (d << 8) | (d >> 24)
c += d
b ^= c
b = (b << 7) | (b >> 25)
return a, b, c, d
}
// XORKeyStream XORs each byte in the given slice with a byte from the
// cipher's key stream. Dst and src must overlap entirely or not at all.
//
// If len(dst) < len(src), XORKeyStream will panic. It is acceptable
// to pass a dst bigger than src, and in that case, XORKeyStream will
// only update dst[:len(src)] and will not touch the rest of dst.
//
// Multiple calls to XORKeyStream behave as if the concatenation of
// the src buffers was passed in a single run. That is, Cipher
// maintains state and does not reset at each XORKeyStream call.
func (s *Cipher) XORKeyStream(dst, src []byte) {
if len(dst) < len(src) {
panic("chacha20: output smaller than input")
}
if subtle.InexactOverlap(dst[:len(src)], src) {
panic("chacha20: invalid buffer overlap")
}
x0 += j0
x1 += j1
x2 += j2
x3 += j3
x4 += j4
x5 += j5
x6 += j6
x7 += j7
x8 += j8
x9 += j9
x10 += j10
x11 += j11
x12 += j12
x13 += j13
x14 += j14
x15 += j15
// xor src with buffered keystream first
if s.len != 0 {
buf := s.buf[len(s.buf)-s.len:]
if len(src) < len(buf) {
buf = buf[:len(src)]
}
td, ts := dst[:len(buf)], src[:len(buf)] // BCE hint
for i, b := range buf {
td[i] = ts[i] ^ b
}
s.len -= len(buf)
if s.len != 0 {
return
}
s.buf = [len(s.buf)]byte{} // zero the empty buffer
src = src[len(buf):]
dst = dst[len(buf):]
}
binary.LittleEndian.PutUint32(out[0:4], x0)
binary.LittleEndian.PutUint32(out[4:8], x1)
binary.LittleEndian.PutUint32(out[8:12], x2)
binary.LittleEndian.PutUint32(out[12:16], x3)
binary.LittleEndian.PutUint32(out[16:20], x4)
binary.LittleEndian.PutUint32(out[20:24], x5)
binary.LittleEndian.PutUint32(out[24:28], x6)
binary.LittleEndian.PutUint32(out[28:32], x7)
binary.LittleEndian.PutUint32(out[32:36], x8)
binary.LittleEndian.PutUint32(out[36:40], x9)
binary.LittleEndian.PutUint32(out[40:44], x10)
binary.LittleEndian.PutUint32(out[44:48], x11)
binary.LittleEndian.PutUint32(out[48:52], x12)
binary.LittleEndian.PutUint32(out[52:56], x13)
binary.LittleEndian.PutUint32(out[56:60], x14)
binary.LittleEndian.PutUint32(out[60:64], x15)
if len(src) == 0 {
return
}
if haveAsm {
if uint64(len(src))+uint64(s.counter)*64 > (1<<38)-64 {
panic("chacha20: counter overflow")
}
s.xorKeyStreamAsm(dst, src)
return
}
// set up a 64-byte buffer to pad out the final block if needed
// (hoisted out of the main loop to avoid spills)
rem := len(src) % 64 // length of final block
fin := len(src) - rem // index of final block
if rem > 0 {
copy(s.buf[len(s.buf)-64:], src[fin:])
}
// pre-calculate most of the first round
s1, s5, s9, s13 := quarterRound(j1, s.key[1], s.key[5], s.nonce[0])
s2, s6, s10, s14 := quarterRound(j2, s.key[2], s.key[6], s.nonce[1])
s3, s7, s11, s15 := quarterRound(j3, s.key[3], s.key[7], s.nonce[2])
n := len(src)
src, dst = src[:n:n], dst[:n:n] // BCE hint
for i := 0; i < n; i += 64 {
// calculate the remainder of the first round
s0, s4, s8, s12 := quarterRound(j0, s.key[0], s.key[4], s.counter)
// execute the second round
x0, x5, x10, x15 := quarterRound(s0, s5, s10, s15)
x1, x6, x11, x12 := quarterRound(s1, s6, s11, s12)
x2, x7, x8, x13 := quarterRound(s2, s7, s8, s13)
x3, x4, x9, x14 := quarterRound(s3, s4, s9, s14)
// execute the remaining 18 rounds
for i := 0; i < 9; i++ {
x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
}
x0 += j0
x1 += j1
x2 += j2
x3 += j3
x4 += s.key[0]
x5 += s.key[1]
x6 += s.key[2]
x7 += s.key[3]
x8 += s.key[4]
x9 += s.key[5]
x10 += s.key[6]
x11 += s.key[7]
x12 += s.counter
x13 += s.nonce[0]
x14 += s.nonce[1]
x15 += s.nonce[2]
// increment the counter
s.counter += 1
if s.counter == 0 {
panic("chacha20: counter overflow")
}
// pad to 64 bytes if needed
in, out := src[i:], dst[i:]
if i == fin {
// src[fin:] has already been copied into s.buf before
// the main loop
in, out = s.buf[len(s.buf)-64:], s.buf[len(s.buf)-64:]
}
in, out = in[:64], out[:64] // BCE hint
// XOR the key stream with the source and write out the result
xor(out[0:], in[0:], x0)
xor(out[4:], in[4:], x1)
xor(out[8:], in[8:], x2)
xor(out[12:], in[12:], x3)
xor(out[16:], in[16:], x4)
xor(out[20:], in[20:], x5)
xor(out[24:], in[24:], x6)
xor(out[28:], in[28:], x7)
xor(out[32:], in[32:], x8)
xor(out[36:], in[36:], x9)
xor(out[40:], in[40:], x10)
xor(out[44:], in[44:], x11)
xor(out[48:], in[48:], x12)
xor(out[52:], in[52:], x13)
xor(out[56:], in[56:], x14)
xor(out[60:], in[60:], x15)
}
// copy any trailing bytes out of the buffer and into dst
if rem != 0 {
s.len = 64 - rem
copy(dst[fin:], s.buf[len(s.buf)-64:])
}
}
// Advance discards bytes in the key stream until the next 64 byte block
// boundary is reached and updates the counter accordingly. If the key
// stream is already at a block boundary no bytes will be discarded and
// the counter will be unchanged.
func (s *Cipher) Advance() {
s.len -= s.len % 64
if s.len == 0 {
s.buf = [len(s.buf)]byte{}
}
}
// XORKeyStream crypts bytes from in to out using the given key and counters.
// In and out must overlap entirely or not at all. Counter contains the raw
// ChaCha20 counter bytes (i.e. block counter followed by nonce).
func XORKeyStream(out, in []byte, counter *[16]byte, key *[32]byte) {
var block [64]byte
var counterCopy [16]byte
copy(counterCopy[:], counter[:])
for len(in) >= 64 {
core(&block, &counterCopy, key)
for i, x := range block {
out[i] = in[i] ^ x
}
u := uint32(1)
for i := 0; i < 4; i++ {
u += uint32(counterCopy[i])
counterCopy[i] = byte(u)
u >>= 8
}
in = in[64:]
out = out[64:]
}
if len(in) > 0 {
core(&block, &counterCopy, key)
for i, v := range in {
out[i] = v ^ block[i]
}
s := Cipher{
key: [8]uint32{
binary.LittleEndian.Uint32(key[0:4]),
binary.LittleEndian.Uint32(key[4:8]),
binary.LittleEndian.Uint32(key[8:12]),
binary.LittleEndian.Uint32(key[12:16]),
binary.LittleEndian.Uint32(key[16:20]),
binary.LittleEndian.Uint32(key[20:24]),
binary.LittleEndian.Uint32(key[24:28]),
binary.LittleEndian.Uint32(key[28:32]),
},
nonce: [3]uint32{
binary.LittleEndian.Uint32(counter[4:8]),
binary.LittleEndian.Uint32(counter[8:12]),
binary.LittleEndian.Uint32(counter[12:16]),
},
counter: binary.LittleEndian.Uint32(counter[0:4]),
}
s.XORKeyStream(out, in)
}
// HChaCha20 uses the ChaCha20 core to generate a derived key from a key and a
// nonce. It should only be used as part of the XChaCha20 construction.
func HChaCha20(key *[8]uint32, nonce *[4]uint32) [8]uint32 {
x0, x1, x2, x3 := j0, j1, j2, j3
x4, x5, x6, x7 := key[0], key[1], key[2], key[3]
x8, x9, x10, x11 := key[4], key[5], key[6], key[7]
x12, x13, x14, x15 := nonce[0], nonce[1], nonce[2], nonce[3]
for i := 0; i < 10; i++ {
x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
}
var out [8]uint32
out[0], out[1], out[2], out[3] = x0, x1, x2, x3
out[4], out[5], out[6], out[7] = x12, x13, x14, x15
return out
}

@ -5,7 +5,10 @@
package chacha20
import (
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"testing"
)
@ -31,3 +34,192 @@ func TestCore(t *testing.T) {
t.Errorf("wanted %x but got %x", expected, result)
}
}
// Run the test cases with the input and output in different buffers.
func TestNoOverlap(t *testing.T) {
for _, c := range testVectors {
s := New(c.key, c.nonce)
input, err := hex.DecodeString(c.input)
if err != nil {
t.Fatalf("cannot decode input %#v: %v", c.input, err)
}
output := make([]byte, c.length)
s.XORKeyStream(output, input)
got := hex.EncodeToString(output)
if got != c.output {
t.Errorf("length=%v: got %#v, want %#v", c.length, got, c.output)
}
}
}
// Run the test cases with the input and output overlapping entirely.
func TestOverlap(t *testing.T) {
for _, c := range testVectors {
s := New(c.key, c.nonce)
data, err := hex.DecodeString(c.input)
if err != nil {
t.Fatalf("cannot decode input %#v: %v", c.input, err)
}
s.XORKeyStream(data, data)
got := hex.EncodeToString(data)
if got != c.output {
t.Errorf("length=%v: got %#v, want %#v", c.length, got, c.output)
}
}
}
// Run the test cases with various source and destination offsets.
func TestUnaligned(t *testing.T) {
const max = 8 // max offset (+1) to test
for _, c := range testVectors {
input := make([]byte, c.length+max)
output := make([]byte, c.length+max)
for i := 0; i < max; i++ { // input offsets
for j := 0; j < max; j++ { // output offsets
s := New(c.key, c.nonce)
input := input[i : i+c.length]
output := output[j : j+c.length]
data, err := hex.DecodeString(c.input)
if err != nil {
t.Fatalf("cannot decode input %#v: %v", c.input, err)
}
copy(input, data)
s.XORKeyStream(output, input)
got := hex.EncodeToString(output)
if got != c.output {
t.Errorf("length=%v: got %#v, want %#v", c.length, got, c.output)
}
}
}
}
}
// Run the test cases by calling XORKeyStream multiple times.
func TestStep(t *testing.T) {
// wide range of step sizes to try and hit edge cases
steps := [...]int{1, 3, 4, 7, 8, 17, 24, 30, 64, 256}
rnd := rand.New(rand.NewSource(123))
for _, c := range testVectors {
s := New(c.key, c.nonce)
input, err := hex.DecodeString(c.input)
if err != nil {
t.Fatalf("cannot decode input %#v: %v", c.input, err)
}
output := make([]byte, c.length)
// step through the buffers
i, step := 0, steps[rnd.Intn(len(steps))]
for i+step < c.length {
s.XORKeyStream(output[i:i+step], input[i:i+step])
if i+step < c.length && output[i+step] != 0 {
t.Errorf("length=%v, i=%v, step=%v: output overwritten", c.length, i, step)
}
i += step
step = steps[rnd.Intn(len(steps))]
}
// finish the encryption
s.XORKeyStream(output[i:], input[i:])
got := hex.EncodeToString(output)
if got != c.output {
t.Errorf("length=%v: got %#v, want %#v", c.length, got, c.output)
}
}
}
// Test that Advance() discards bytes until a block boundary is hit.
func TestAdvance(t *testing.T) {
for _, c := range testVectors {
for i := 0; i < 63; i++ {
s := New(c.key, c.nonce)
z := New(c.key, c.nonce)
input, err := hex.DecodeString(c.input)
if err != nil {
t.Fatalf("cannot decode input %#v: %v", c.input, err)
}
zeros, discard := make([]byte, 64), make([]byte, 64)
so, zo := make([]byte, c.length), make([]byte, c.length)
for j := 0; j < c.length; j += 64 {
lim := j + i
if lim > c.length {
lim = c.length
}
s.XORKeyStream(so[j:lim], input[j:lim])
// calling s.Advance() multiple times should have no effect
for k := 0; k < i%3+1; k++ {
s.Advance()
}
z.XORKeyStream(zo[j:lim], input[j:lim])
if lim < c.length {
end := 64 - i
if c.length-lim < end {
end = c.length - lim
}
z.XORKeyStream(discard[:], zeros[:end])
}
}
got := hex.EncodeToString(so)
want := hex.EncodeToString(zo)
if got != want {
t.Errorf("length=%v: got %#v, want %#v", c.length, got, want)
}
}
}
}
func BenchmarkChaCha20(b *testing.B) {
sizes := []int{32, 63, 64, 256, 1024, 1350, 65536}
for _, size := range sizes {
s := size
b.Run(fmt.Sprint(s), func(b *testing.B) {
k := [32]byte{}
c := [16]byte{}
src := make([]byte, s)
dst := make([]byte, s)
b.SetBytes(int64(s))
b.ResetTimer()
for i := 0; i < b.N; i++ {
XORKeyStream(dst, src, &c, &k)
}
})
}
}
func TestHChaCha20(t *testing.T) {
// See draft-paragon-paseto-rfc-00 §7.2.1.
key := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}
nonce := []byte{0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4a,
0x00, 0x00, 0x00, 0x00, 0x31, 0x41, 0x59, 0x27}
expected := []byte{0x82, 0x41, 0x3b, 0x42, 0x27, 0xb2, 0x7b, 0xfe,
0xd3, 0x0e, 0x42, 0x50, 0x8a, 0x87, 0x7d, 0x73,
0xa0, 0xf9, 0xe4, 0xd5, 0x8a, 0x74, 0xa8, 0x53,
0xc1, 0x2e, 0xc4, 0x13, 0x26, 0xd3, 0xec, 0xdc,
}
result := HChaCha20(&[8]uint32{
binary.LittleEndian.Uint32(key[0:4]),
binary.LittleEndian.Uint32(key[4:8]),
binary.LittleEndian.Uint32(key[8:12]),
binary.LittleEndian.Uint32(key[12:16]),
binary.LittleEndian.Uint32(key[16:20]),
binary.LittleEndian.Uint32(key[20:24]),
binary.LittleEndian.Uint32(key[24:28]),
binary.LittleEndian.Uint32(key[28:32]),
}, &[4]uint32{
binary.LittleEndian.Uint32(nonce[0:4]),
binary.LittleEndian.Uint32(nonce[4:8]),
binary.LittleEndian.Uint32(nonce[8:12]),
binary.LittleEndian.Uint32(nonce[12:16]),
})
for i := 0; i < 8; i++ {
want := binary.LittleEndian.Uint32(expected[i*4 : (i+1)*4])
if got := result[i]; got != want {
t.Errorf("word %d incorrect: want 0x%x, got 0x%x", i, want, got)
}
}
}

@ -3,6 +3,10 @@
// license that can be found in the LICENSE file.
// Package md4 implements the MD4 hash algorithm as defined in RFC 1320.
//
// Deprecated: MD4 is cryptographically broken and should should only be used
// where compatibility with legacy systems, not security, is the goal. Instead,
// use a secure hash like SHA-256 (from crypto/sha256).
package md4 // import "golang.org/x/crypto/md4"
import (

@ -35,6 +35,7 @@ This package is interoperable with NaCl: https://nacl.cr.yp.to/secretbox.html.
package secretbox // import "golang.org/x/crypto/nacl/secretbox"
import (
"golang.org/x/crypto/internal/subtle"
"golang.org/x/crypto/poly1305"
"golang.org/x/crypto/salsa20/salsa"
)
@ -87,6 +88,9 @@ func Seal(out, message []byte, nonce *[24]byte, key *[32]byte) []byte {
copy(poly1305Key[:], firstBlock[:])
ret, out := sliceForAppend(out, len(message)+poly1305.TagSize)
if subtle.AnyOverlap(out, message) {
panic("nacl: invalid buffer overlap")
}
// We XOR up to 32 bytes of message with the keystream generated from
// the first block.
@ -118,7 +122,7 @@ func Seal(out, message []byte, nonce *[24]byte, key *[32]byte) []byte {
// Open authenticates and decrypts a box produced by Seal and appends the
// message to out, which must not overlap box. The output will be Overhead
// bytes smaller than box.
func Open(out []byte, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool) {
func Open(out, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool) {
if len(box) < Overhead {
return nil, false
}
@ -143,6 +147,9 @@ func Open(out []byte, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool)
}
ret, out := sliceForAppend(out, len(box)-Overhead)
if subtle.AnyOverlap(out, box) {
panic("nacl: invalid buffer overlap")
}
// We XOR up to 32 bytes of box with the keystream generated from
// the first block.

@ -63,7 +63,7 @@ func (r ResponseStatus) String() string {
}
// ResponseError is an error that may be returned by ParseResponse to indicate
// that the response itself is an error, not just that its indicating that a
// that the response itself is an error, not just that it's indicating that a
// certificate is revoked, unknown, etc.
type ResponseError struct {
Status ResponseStatus
@ -487,9 +487,8 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
if err != nil {
return nil, err
}
if len(basicResp.Certificates) > 1 {
return nil, ParseError("OCSP response contains bad number of certificates")
if len(rest) > 0 {
return nil, ParseError("trailing data in OCSP response")
}
if n := len(basicResp.TBSResponseData.Responses); n == 0 || cert == nil && n > 1 {
@ -544,6 +543,13 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
}
if len(basicResp.Certificates) > 0 {
// Responders should only send a single certificate (if they
// send any) that connects the responder's certificate to the
// original issuer. We accept responses with multiple
// certificates due to a number responders sending them[1], but
// ignore all but the first.
//
// [1] https://github.com/golang/go/issues/21527
ret.Certificate, err = x509.ParseCertificate(basicResp.Certificates[0].FullBytes)
if err != nil {
return nil, err

@ -13,6 +13,7 @@ import (
"bufio"
"bytes"
"crypto"
"fmt"
"hash"
"io"
"net/textproto"
@ -177,8 +178,9 @@ func Decode(data []byte) (b *Block, rest []byte) {
// message.
type dashEscaper struct {
buffered *bufio.Writer
h hash.Hash
hashers []hash.Hash // one per key in privateKeys
hashType crypto.Hash
toHash io.Writer // writes to all the hashes in hashers
atBeginningOfLine bool
isFirstLine bool
@ -186,8 +188,8 @@ type dashEscaper struct {
whitespace []byte
byteBuf []byte // a one byte buffer to save allocations
privateKey *packet.PrivateKey
config *packet.Config
privateKeys []*packet.PrivateKey
config *packet.Config
}
func (d *dashEscaper) Write(data []byte) (n int, err error) {
@ -198,7 +200,7 @@ func (d *dashEscaper) Write(data []byte) (n int, err error) {
// The final CRLF isn't included in the hash so we have to wait
// until this point (the start of the next line) before writing it.
if !d.isFirstLine {
d.h.Write(crlf)
d.toHash.Write(crlf)
}
d.isFirstLine = false
}
@ -219,12 +221,12 @@ func (d *dashEscaper) Write(data []byte) (n int, err error) {
if _, err = d.buffered.Write(dashEscape); err != nil {
return
}
d.h.Write(d.byteBuf)
d.toHash.Write(d.byteBuf)
d.atBeginningOfLine = false
} else if b == '\n' {
// Nothing to do because we delay writing CRLF to the hash.
} else {
d.h.Write(d.byteBuf)
d.toHash.Write(d.byteBuf)
d.atBeginningOfLine = false
}
if err = d.buffered.WriteByte(b); err != nil {
@ -245,13 +247,13 @@ func (d *dashEscaper) Write(data []byte) (n int, err error) {
// Any buffered whitespace wasn't at the end of the line so
// we need to write it out.
if len(d.whitespace) > 0 {
d.h.Write(d.whitespace)
d.toHash.Write(d.whitespace)
if _, err = d.buffered.Write(d.whitespace); err != nil {
return
}
d.whitespace = d.whitespace[:0]
}
d.h.Write(d.byteBuf)
d.toHash.Write(d.byteBuf)
if err = d.buffered.WriteByte(b); err != nil {
return
}
@ -269,25 +271,29 @@ func (d *dashEscaper) Close() (err error) {
return
}
}
sig := new(packet.Signature)
sig.SigType = packet.SigTypeText
sig.PubKeyAlgo = d.privateKey.PubKeyAlgo
sig.Hash = d.hashType
sig.CreationTime = d.config.Now()
sig.IssuerKeyId = &d.privateKey.KeyId
if err = sig.Sign(d.h, d.privateKey, d.config); err != nil {
return
}
out, err := armor.Encode(d.buffered, "PGP SIGNATURE", nil)
if err != nil {
return
}
if err = sig.Serialize(out); err != nil {
return
t := d.config.Now()
for i, k := range d.privateKeys {
sig := new(packet.Signature)
sig.SigType = packet.SigTypeText
sig.PubKeyAlgo = k.PubKeyAlgo
sig.Hash = d.hashType
sig.CreationTime = t
sig.IssuerKeyId = &k.KeyId
if err = sig.Sign(d.hashers[i], k, d.config); err != nil {
return
}
if err = sig.Serialize(out); err != nil {
return
}
}
if err = out.Close(); err != nil {
return
}
@ -300,8 +306,17 @@ func (d *dashEscaper) Close() (err error) {
// Encode returns a WriteCloser which will clear-sign a message with privateKey
// and write it to w. If config is nil, sensible defaults are used.
func Encode(w io.Writer, privateKey *packet.PrivateKey, config *packet.Config) (plaintext io.WriteCloser, err error) {
if privateKey.Encrypted {
return nil, errors.InvalidArgumentError("signing key is encrypted")
return EncodeMulti(w, []*packet.PrivateKey{privateKey}, config)
}
// EncodeMulti returns a WriteCloser which will clear-sign a message with all the
// private keys indicated and write it to w. If config is nil, sensible defaults
// are used.
func EncodeMulti(w io.Writer, privateKeys []*packet.PrivateKey, config *packet.Config) (plaintext io.WriteCloser, err error) {
for _, k := range privateKeys {
if k.Encrypted {
return nil, errors.InvalidArgumentError(fmt.Sprintf("signing key %s is encrypted", k.KeyIdString()))
}
}
hashType := config.Hash()
@ -313,7 +328,14 @@ func Encode(w io.Writer, privateKey *packet.PrivateKey, config *packet.Config) (
if !hashType.Available() {
return nil, errors.UnsupportedError("unsupported hash type: " + strconv.Itoa(int(hashType)))
}
h := hashType.New()
var hashers []hash.Hash
var ws []io.Writer
for range privateKeys {
h := hashType.New()
hashers = append(hashers, h)
ws = append(ws, h)
}
toHash := io.MultiWriter(ws...)
buffered := bufio.NewWriter(w)
// start has a \n at the beginning that we don't want here.
@ -338,16 +360,17 @@ func Encode(w io.Writer, privateKey *packet.PrivateKey, config *packet.Config) (
plaintext = &dashEscaper{
buffered: buffered,
h: h,
hashers: hashers,
hashType: hashType,
toHash: toHash,
atBeginningOfLine: true,
isFirstLine: true,
byteBuf: make([]byte, 1),
privateKey: privateKey,
config: config,
privateKeys: privateKeys,
config: config,
}
return

@ -6,8 +6,11 @@ package clearsign
import (
"bytes"
"golang.org/x/crypto/openpgp"
"fmt"
"testing"
"golang.org/x/crypto/openpgp"
"golang.org/x/crypto/openpgp/packet"
)
func testParse(t *testing.T, input []byte, expected, expectedPlaintext string) {
@ -125,6 +128,71 @@ func TestSigning(t *testing.T) {
}
}
// We use this to make test keys, so that they aren't all the same.
type quickRand byte
func (qr *quickRand) Read(p []byte) (int, error) {
for i := range p {
p[i] = byte(*qr)
}
*qr++
return len(p), nil
}
func TestMultiSign(t *testing.T) {
zero := quickRand(0)
config := packet.Config{Rand: &zero}
for nKeys := 0; nKeys < 4; nKeys++ {
nextTest:
for nExtra := 0; nExtra < 4; nExtra++ {
var signKeys []*packet.PrivateKey
var verifyKeys openpgp.EntityList
desc := fmt.Sprintf("%d keys; %d of which will be used to verify", nKeys+nExtra, nKeys)
for i := 0; i < nKeys+nExtra; i++ {
e, err := openpgp.NewEntity("name", "comment", "email", &config)
if err != nil {
t.Errorf("cannot create key: %v", err)
continue nextTest
}
if i < nKeys {
verifyKeys = append(verifyKeys, e)
}
signKeys = append(signKeys, e.PrivateKey)
}
input := []byte("this is random text\r\n4 17")
var output bytes.Buffer
w, err := EncodeMulti(&output, signKeys, nil)
if err != nil {
t.Errorf("EncodeMulti (%s) failed: %v", desc, err)
}
if _, err := w.Write(input); err != nil {
t.Errorf("Write(%q) to signer (%s) failed: %v", string(input), desc, err)
}
if err := w.Close(); err != nil {
t.Errorf("Close() of signer (%s) failed: %v", desc, err)
}
block, _ := Decode(output.Bytes())
if string(block.Bytes) != string(input) {
t.Errorf("Inline data didn't match original; got %q want %q", string(block.Bytes), string(input))
}
_, err = openpgp.CheckDetachedSignature(verifyKeys, bytes.NewReader(block.Bytes), block.ArmoredSignature.Body)
if nKeys == 0 {
if err == nil {
t.Errorf("verifying inline (%s) succeeded; want failure", desc)
}
} else {
if err != nil {
t.Errorf("verifying inline (%s) failed (%v); want success", desc, err)
}
}
}
}
}
var clearsignInput = []byte(`
;lasjlkfdsa

@ -333,7 +333,6 @@ func ReadEntity(packets *packet.Reader) (*Entity, error) {
return nil, errors.StructuralError("primary key cannot be used for signatures")
}
var current *Identity
var revocations []*packet.Signature
EachPacket:
for {
@ -346,32 +345,8 @@ EachPacket:
switch pkt := p.(type) {
case *packet.UserId:
current = new(Identity)
current.Name = pkt.Id
current.UserId = pkt
e.Identities[pkt.Id] = current
for {
p, err = packets.Next()
if err == io.EOF {
return nil, io.ErrUnexpectedEOF
} else if err != nil {
return nil, err
}
sig, ok := p.(*packet.Signature)
if !ok {
return nil, errors.StructuralError("user ID packet not followed by self-signature")
}
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, e.PrimaryKey, sig); err != nil {
return nil, errors.StructuralError("user ID self-signature invalid: " + err.Error())
}
current.SelfSignature = sig
break
}
current.Signatures = append(current.Signatures, sig)
if err := addUserID(e, packets, pkt); err != nil {
return nil, err
}
case *packet.Signature:
if pkt.SigType == packet.SigTypeKeyRevocation {
@ -380,11 +355,9 @@ EachPacket:
// TODO: RFC4880 5.2.1 permits signatures
// directly on keys (eg. to bind additional
// revocation keys).
} else if current == nil {
return nil, errors.StructuralError("signature packet found before user id packet")
} else {
current.Signatures = append(current.Signatures, pkt)
}
// Else, ignoring the signature as it does not follow anything
// we would know to attach it to.
case *packet.PrivateKey:
if pkt.IsSubkey == false {
packets.Unread(p)
@ -425,33 +398,105 @@ EachPacket:
return e, nil
}
func addUserID(e *Entity, packets *packet.Reader, pkt *packet.UserId) error {
// Make a new Identity object, that we might wind up throwing away.
// We'll only add it if we get a valid self-signature over this
// userID.
identity := new(Identity)
identity.Name = pkt.Id
identity.UserId = pkt
for {
p, err := packets.Next()
if err == io.EOF {
break
} else if err != nil {
return err
}
sig, ok := p.(*packet.Signature)
if !ok {
packets.Unread(p)
break
}
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, e.PrimaryKey, sig); err != nil {
return errors.StructuralError("user ID self-signature invalid: " + err.Error())
}
identity.SelfSignature = sig
e.Identities[pkt.Id] = identity
} else {
identity.Signatures = append(identity.Signatures, sig)
}
}
return nil
}
func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *packet.PrivateKey) error {
var subKey Subkey
subKey.PublicKey = pub
subKey.PrivateKey = priv
p, err := packets.Next()
if err == io.EOF {
return io.ErrUnexpectedEOF
for {
p, err := packets.Next()
if err == io.EOF {
break
} else if err != nil {
return errors.StructuralError("subkey signature invalid: " + err.Error())
}
sig, ok := p.(*packet.Signature)
if !ok {
packets.Unread(p)
break
}
if sig.SigType != packet.SigTypeSubkeyBinding && sig.SigType != packet.SigTypeSubkeyRevocation {
return errors.StructuralError("subkey signature with wrong type")
}
if err := e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, sig); err != nil {
return errors.StructuralError("subkey signature invalid: " + err.Error())
}
switch sig.SigType {
case packet.SigTypeSubkeyRevocation:
subKey.Sig = sig
case packet.SigTypeSubkeyBinding:
if shouldReplaceSubkeySig(subKey.Sig, sig) {
subKey.Sig = sig
}
}
}
if err != nil {
return errors.StructuralError("subkey signature invalid: " + err.Error())
}
var ok bool
subKey.Sig, ok = p.(*packet.Signature)
if !ok {
if subKey.Sig == nil {
return errors.StructuralError("subkey packet not followed by signature")
}
if subKey.Sig.SigType != packet.SigTypeSubkeyBinding && subKey.Sig.SigType != packet.SigTypeSubkeyRevocation {
return errors.StructuralError("subkey signature with wrong type")
}
err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig)
if err != nil {
return errors.StructuralError("subkey signature invalid: " + err.Error())
}
e.Subkeys = append(e.Subkeys, subKey)
return nil
}
func shouldReplaceSubkeySig(existingSig, potentialNewSig *packet.Signature) bool {
if potentialNewSig == nil {
return false
}
if existingSig == nil {
return true
}
if existingSig.SigType == packet.SigTypeSubkeyRevocation {
return false // never override a revocation signature
}
return potentialNewSig.CreationTime.After(existingSig.CreationTime)
}
const defaultRSAKeyBits = 2048
// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
@ -486,7 +531,7 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
}
isPrimaryId := true
e.Identities[uid.Id] = &Identity{
Name: uid.Name,
Name: uid.Id,
UserId: uid,
SelfSignature: &packet.Signature{
CreationTime: currentTime,
@ -500,6 +545,10 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
err = e.Identities[uid.Id].SelfSignature.SignUserId(uid.Id, e.PrimaryKey, e.PrivateKey, config)
if err != nil {
return nil, err
}
// If the user passes in a DefaultHash via packet.Config,
// set the PreferredHash for the SelfSignature.
@ -507,6 +556,11 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
e.Identities[uid.Id].SelfSignature.PreferredHash = []uint8{hashToHashId(config.DefaultHash)}
}
// Likewise for DefaultCipher.
if config != nil && config.DefaultCipher != 0 {
e.Identities[uid.Id].SelfSignature.PreferredSymmetric = []uint8{uint8(config.DefaultCipher)}
}
e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{
PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey),
@ -524,13 +578,16 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
}
e.Subkeys[0].PublicKey.IsSubkey = true
e.Subkeys[0].PrivateKey.IsSubkey = true
err = e.Subkeys[0].Sig.SignKey(e.Subkeys[0].PublicKey, e.PrivateKey, config)
if err != nil {
return nil, err
}
return e, nil
}
// SerializePrivate serializes an Entity, including private key material, to
// the given Writer. For now, it must only be used on an Entity returned from
// NewEntity.
// SerializePrivate serializes an Entity, including private key material, but
// excluding signatures from other entities, to the given Writer.
// Identities and subkeys are re-signed in case they changed since NewEntry.
// If config is nil, sensible defaults will be used.
func (e *Entity) SerializePrivate(w io.Writer, config *packet.Config) (err error) {
err = e.PrivateKey.Serialize(w)
@ -568,8 +625,8 @@ func (e *Entity) SerializePrivate(w io.Writer, config *packet.Config) (err error
return nil
}
// Serialize writes the public part of the given Entity to w. (No private
// key material will be output).
// Serialize writes the public part of the given Entity to w, including
// signatures from other entities. No private key material will be output.
func (e *Entity) Serialize(w io.Writer) error {
err := e.PrimaryKey.Serialize(w)
if err != nil {

File diff suppressed because one or more lines are too long

@ -42,12 +42,18 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
switch e.Algo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
e.encryptedMPI1.bytes, e.encryptedMPI1.bitLength, err = readMPI(r)
if err != nil {
return
}
case PubKeyAlgoElGamal:
e.encryptedMPI1.bytes, e.encryptedMPI1.bitLength, err = readMPI(r)
if err != nil {
return
}
e.encryptedMPI2.bytes, e.encryptedMPI2.bitLength, err = readMPI(r)
if err != nil {
return
}
}
_, err = consumeAll(r)
return
@ -72,7 +78,8 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error {
// padding oracle attacks.
switch priv.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
b, err = rsa.DecryptPKCS1v15(config.Random(), priv.PrivateKey.(*rsa.PrivateKey), e.encryptedMPI1.bytes)
k := priv.PrivateKey.(*rsa.PrivateKey)
b, err = rsa.DecryptPKCS1v15(config.Random(), k, padToKeySize(&k.PublicKey, e.encryptedMPI1.bytes))
case PubKeyAlgoElGamal:
c1 := new(big.Int).SetBytes(e.encryptedMPI1.bytes)
c2 := new(big.Int).SetBytes(e.encryptedMPI2.bytes)

@ -39,39 +39,44 @@ var encryptedKeyPriv = &PrivateKey{
}
func TestDecryptingEncryptedKey(t *testing.T) {
const encryptedKeyHex = "c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8"
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
for i, encryptedKeyHex := range []string{
"c18c032a67d68660df41c70104005789d0de26b6a50c985a02a13131ca829c413a35d0e6fa8d6842599252162808ac7439c72151c8c6183e76923fe3299301414d0c25a2f06a2257db3839e7df0ec964773f6e4c4ac7ff3b48c444237166dd46ba8ff443a5410dc670cb486672fdbe7c9dfafb75b4fea83af3a204fe2a7dfa86bd20122b4f3d2646cbeecb8f7be8",
// MPI can be shorter than the length of the key.
"c18b032a67d68660df41c70103f8e520c52ae9807183c669ce26e772e482dc5d8cf60e6f59316e145be14d2e5221ee69550db1d5618a8cb002a719f1f0b9345bde21536d410ec90ba86cac37748dec7933eb7f9873873b2d61d3321d1cd44535014f6df58f7bc0c7afb5edc38e1a974428997d2f747f9a173bea9ca53079b409517d332df62d805564cffc9be6",
} {
const expectedKeyHex = "d930363f7e0308c333b9618617ea728963d8df993665ae7be1092d4926fd864b"
p, err := Read(readerFromHex(encryptedKeyHex))
if err != nil {
t.Errorf("error from Read: %s", err)
return
}
ek, ok := p.(*EncryptedKey)
if !ok {
t.Errorf("didn't parse an EncryptedKey, got %#v", p)
return
}
p, err := Read(readerFromHex(encryptedKeyHex))
if err != nil {
t.Errorf("#%d: error from Read: %s", i, err)
return
}
ek, ok := p.(*EncryptedKey)
if !ok {
t.Errorf("#%d: didn't parse an EncryptedKey, got %#v", i, p)
return
}
if ek.KeyId != 0x2a67d68660df41c7 || ek.Algo != PubKeyAlgoRSA {
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
return
}
if ek.KeyId != 0x2a67d68660df41c7 || ek.Algo != PubKeyAlgoRSA {
t.Errorf("#%d: unexpected EncryptedKey contents: %#v", i, ek)
return
}
err = ek.Decrypt(encryptedKeyPriv, nil)
if err != nil {
t.Errorf("error from Decrypt: %s", err)
return
}
err = ek.Decrypt(encryptedKeyPriv, nil)
if err != nil {
t.Errorf("#%d: error from Decrypt: %s", i, err)
return
}
if ek.CipherFunc != CipherAES256 {
t.Errorf("unexpected EncryptedKey contents: %#v", ek)
return
}
if ek.CipherFunc != CipherAES256 {
t.Errorf("#%d: unexpected EncryptedKey contents: %#v", i, ek)
return
}
keyHex := fmt.Sprintf("%x", ek.Key)
if keyHex != expectedKeyHex {
t.Errorf("bad key, got %s want %x", keyHex, expectedKeyHex)
keyHex := fmt.Sprintf("%x", ek.Key)
if keyHex != expectedKeyHex {
t.Errorf("#%d: bad key, got %s want %s", i, keyHex, expectedKeyHex)
}
}
}
@ -121,7 +126,7 @@ func TestEncryptingEncryptedKey(t *testing.T) {
keyHex := fmt.Sprintf("%x", ek.Key)
if keyHex != expectedKeyHex {
t.Errorf("bad key, got %s want %x", keyHex, expectedKeyHex)
t.Errorf("bad key, got %s want %s", keyHex, expectedKeyHex)
}
}

Some files were not shown because too many files have changed in this diff Show More