Parcourir la Source

added more tests and refactored

bmallred 9 ans auparavant
Parent
Commettre
c32a1bcc0e
7 fichiers modifiés avec 193 ajouts et 141 suppressions
  1. 8 3
      account.go
  2. 28 0
      environment.go
  3. 16 0
      environment_test.go
  4. 0 122
      main.go
  5. 0 16
      main_test.go
  6. 99 0
      security.go
  7. 42 0
      security_test.go

+ 8 - 3
account.go

@ -12,9 +12,14 @@ import (
12 12
)
13 13
14 14
const (
15
	publishLayout = "200601021504"
15
	publishDateLayout = "200601021504"
16 16
)
17 17
18
type Account struct {
19
	Username   string `json:username`
20
	Passphrase string `json:passphrase`
21
}
22
18 23
func getAccounts() ([]Account, error) {
19 24
	file := UserFile
20 25
@ -195,7 +200,7 @@ func (a *Account) Add(content, filename string, file []byte) (string, error) {
195 200
196 201
	// Create new content directory if necessary
197 202
	if content == "" {
198
		content = fmt.Sprintf("%x", hash(userDir+time.Now().Format(publishLayout), salt()))
203
		content = fmt.Sprintf("%x", hash(userDir+time.Now().Format(publishDateLayout), salt()))
199 204
	}
200 205
201 206
	// Create new content directory
@ -241,7 +246,7 @@ func (a *Account) Publish(title, content string, date time.Time) error {
241 246
	userDir := strings.ToLower(a.Username)
242 247
	s := string(filepath.Separator)
243 248
	contentPath := baseDir + s + userDir + s + "content" + s + content
244
	publishPath := baseDir + s + userDir + s + "blog" + s + date.Format(publishLayout) + "-" + prettyTitle
249
	publishPath := baseDir + s + userDir + s + "blog" + s + date.Format(publishDateLayout) + "-" + prettyTitle
245 250
246 251
	// Check if the content exists
247 252
	if _, err := os.Stat(contentPath); os.IsNotExist(err) {

+ 28 - 0
environment.go

@ -0,0 +1,28 @@
1
package main
2
3
import "os"
4
5
var (
6
	EnvironmentVariables = map[string]string{
7
		"LOOP_URL":  "localhost:6006",
8
		"LOOP_SALT": "whatyoutalkingaboutwillus",
9
		"LOOP_DATA": ".",
10
	}
11
)
12
13
func ConfigureEnvironment() {
14
	for key, value := range EnvironmentVariables {
15
		env := os.Getenv(key)
16
		if env == "" && value != "" {
17
			os.Setenv(key, value)
18
		}
19
	}
20
}
21
22
func address() string {
23
	return os.Getenv("LOOP_URL")
24
}
25
26
func salt() string {
27
	return os.Getenv("LOOP_SALT")
28
}

+ 16 - 0
environment_test.go

@ -0,0 +1,16 @@
1
package main
2
3
import (
4
	"os"
5
	"testing"
6
)
7
8
func TestConfigureEnvironment(t *testing.T) {
9
	ConfigureEnvironment()
10
	for k, v := range EnvironmentVariables {
11
		ev := os.Getenv(k)
12
		if ev == "" || ev != v {
13
			t.FailNow()
14
		}
15
	}
16
}

+ 0 - 122
main.go

@ -1,42 +1,14 @@
1 1
package main
2 2
3 3
import (
4
	"crypto/aes"
5
	"crypto/cipher"
6
	"crypto/md5"
7
	"crypto/rand"
8
	"crypto/sha512"
9
	"encoding/base64"
10
	"errors"
11
	"fmt"
12
	"io"
13 4
	"log"
14 5
	"net/http"
15
	"os"
16
	"strings"
17 6
)
18 7
19 8
const (
20 9
	UserFile = "loop_users"
21 10
)
22 11
23
var (
24
	EnvironmentVariables = map[string]string{
25
		"LOOP_URL":  "localhost:6006",
26
		"LOOP_SALT": "whatyoutalkingaboutwillus",
27
		"LOOP_DATA": ".",
28
	}
29
)
30
31
func ConfigureEnvironment() {
32
	for key, value := range EnvironmentVariables {
33
		env := os.Getenv(key)
34
		if env == "" && value != "" {
35
			os.Setenv(key, value)
36
		}
37
	}
38
}
39
40 12
func main() {
41 13
	log.Print("Configuring environment...")
42 14
	ConfigureEnvironment()
@ -45,97 +17,3 @@ func main() {
45 17
46 18
	log.Fatal(http.ListenAndServe(address(), nil))
47 19
}
48
49
// General stuff
50
51
type Account struct {
52
	Username   string `json:username`
53
	Passphrase string `json:passphrase`
54
}
55
56
func address() string {
57
	return os.Getenv("LOOP_URL")
58
}
59
60
// Security stuff
61
62
func encodeBase64(data []byte) string {
63
	return base64.StdEncoding.EncodeToString(data)
64
}
65
func decodeBase64(data []byte) []byte {
66
	b, _ := base64.StdEncoding.DecodeString(string(data))
67
	return b
68
}
69
func salt() string {
70
	return os.Getenv("LOOP_SALT")
71
}
72
func getSizedKey(key string) string {
73
	// Get the correct key length
74
	l := len(key)
75
	if l < 16 {
76
		for i := 0; i < 16-l; i++ {
77
			key += "."
78
		}
79
	} else if l < 24 {
80
		for i := 0; i < 24-l; i++ {
81
			key += "."
82
		}
83
	} else if l < 32 {
84
		for i := 0; i < 32-l; i++ {
85
			key += "."
86
		}
87
	} else {
88
		key = key[:32]
89
	}
90
91
	return key
92
}
93
func encrypt(text, passphrase string) ([]byte, error) {
94
	key := []byte(getSizedKey(passphrase))
95
	block, err := aes.NewCipher(key)
96
	if err != nil {
97
		return nil, err
98
	}
99
	cipherText := make([]byte, aes.BlockSize+len(text))
100
	iv := cipherText[:aes.BlockSize]
101
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
102
		return nil, err
103
	}
104
	encrypter := cipher.NewCFBEncrypter(block, iv)
105
	encrypter.XORKeyStream(cipherText[aes.BlockSize:], []byte(text))
106
	return cipherText, nil
107
}
108
func decrypt(text []byte, passphrase string) ([]byte, error) {
109
	key := []byte(getSizedKey(passphrase))
110
	block, err := aes.NewCipher(key)
111
	if err != nil {
112
		return nil, err
113
	}
114
	if len(text) < aes.BlockSize {
115
		return nil, errors.New("Cipher text too short")
116
	}
117
	iv := text[:aes.BlockSize]
118
	data := text[aes.BlockSize:]
119
	decrypter := cipher.NewCFBDecrypter(block, iv)
120
	decrypter.XORKeyStream(data, data)
121
	return data, nil
122
}
123
func hash(clearText, salt string) []byte {
124
	h := md5.New()
125
	h.Write([]byte(clearText))
126
	return h.Sum(nil)
127
}
128
func hashCredentials(username, passphrase, salt string) []byte {
129
	clearText := fmt.Sprintf(
130
		"%s%s-%s",
131
		salt,
132
		strings.ToLower(username),
133
		strings.ToLower(passphrase))
134
135
	sha := sha512.New()
136
	sha.Write([]byte(clearText))
137
	return sha.Sum(nil)
138
}
139
func generatePassphrase(username, passphrase string) string {
140
	return encodeBase64(hashCredentials(username, passphrase, salt()))
141
}

+ 0 - 16
main_test.go

@ -1,16 +0,0 @@
1
package main
2
3
import (
4
	"os"
5
	"testing"
6
)
7
8
func TestAddress(t *testing.T) {
9
	expected := "127.0.0.1:8888"
10
	os.Setenv("LOOP_URL", expected)
11
	actual := os.Getenv("LOOP_URL")
12
13
	if actual != expected {
14
		t.FailNow()
15
	}
16
}

+ 99 - 0
security.go

@ -0,0 +1,99 @@
1
package main
2
3
import (
4
	"crypto/aes"
5
	"crypto/cipher"
6
	"crypto/md5"
7
	"crypto/rand"
8
	"crypto/sha512"
9
	"encoding/base64"
10
	"errors"
11
	"fmt"
12
	"io"
13
	"strings"
14
)
15
16
func encodeBase64(data []byte) string {
17
	return base64.StdEncoding.EncodeToString(data)
18
}
19
20
func decodeBase64(data []byte) []byte {
21
	b, _ := base64.StdEncoding.DecodeString(string(data))
22
	return b
23
}
24
25
func getSizedKey(key string) string {
26
	// Get the correct key length
27
	l := len(key)
28
	if l < 16 {
29
		for i := 0; i < 16-l; i++ {
30
			key += "."
31
		}
32
	} else if l < 24 {
33
		for i := 0; i < 24-l; i++ {
34
			key += "."
35
		}
36
	} else if l < 32 {
37
		for i := 0; i < 32-l; i++ {
38
			key += "."
39
		}
40
	} else {
41
		key = key[:32]
42
	}
43
44
	return key
45
}
46
47
func encrypt(text, passphrase string) ([]byte, error) {
48
	key := []byte(getSizedKey(passphrase))
49
	block, err := aes.NewCipher(key)
50
	if err != nil {
51
		return nil, err
52
	}
53
	cipherText := make([]byte, aes.BlockSize+len(text))
54
	iv := cipherText[:aes.BlockSize]
55
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
56
		return nil, err
57
	}
58
	encrypter := cipher.NewCFBEncrypter(block, iv)
59
	encrypter.XORKeyStream(cipherText[aes.BlockSize:], []byte(text))
60
	return cipherText, nil
61
}
62
63
func decrypt(text []byte, passphrase string) ([]byte, error) {
64
	key := []byte(getSizedKey(passphrase))
65
	block, err := aes.NewCipher(key)
66
	if err != nil {
67
		return nil, err
68
	}
69
	if len(text) < aes.BlockSize {
70
		return nil, errors.New("Cipher text too short")
71
	}
72
	iv := text[:aes.BlockSize]
73
	data := text[aes.BlockSize:]
74
	decrypter := cipher.NewCFBDecrypter(block, iv)
75
	decrypter.XORKeyStream(data, data)
76
	return data, nil
77
}
78
79
func hash(clearText, salt string) []byte {
80
	h := md5.New()
81
	h.Write([]byte(clearText))
82
	return h.Sum(nil)
83
}
84
85
func hashCredentials(username, passphrase, salt string) []byte {
86
	clearText := fmt.Sprintf(
87
		"%s%s-%s",
88
		salt,
89
		strings.ToLower(username),
90
		strings.ToLower(passphrase))
91
92
	sha := sha512.New()
93
	sha.Write([]byte(clearText))
94
	return sha.Sum(nil)
95
}
96
97
func generatePassphrase(username, passphrase string) string {
98
	return encodeBase64(hashCredentials(username, passphrase, salt()))
99
}

+ 42 - 0
security_test.go

@ -0,0 +1,42 @@
1
package main
2
3
import "testing"
4
5
func TestEncoding(t *testing.T) {
6
	clear := "encoding test"
7
	base64 := encodeBase64([]byte(clear))
8
	decode64 := string(decodeBase64([]byte(base64)))
9
	if decode64 != clear {
10
		t.Error("Encoding does not match")
11
	}
12
}
13
14
func TestEncryption(t *testing.T) {
15
	clear := "encryption test"
16
	passphrase := "password"
17
	encrypted, err := encrypt(clear, passphrase)
18
	if err != nil {
19
		t.Error(err)
20
	}
21
22
	decrypted, err := decrypt(encrypted, passphrase)
23
	if err != nil {
24
		t.Error(err)
25
	}
26
27
	if clear != string(decrypted) {
28
		t.Error("Encryption does not match")
29
	}
30
}
31
32
func TestHash(t *testing.T) {
33
	t.Skip("Need a test for hashing")
34
}
35
36
func TestHashCredentials(t *testing.T) {
37
	t.Skip("Need a test for hashing credentials")
38
}
39
40
func TestGeneratePassphrase(t *testing.T) {
41
	t.Skip("Need a test for generating a passphrase")
42
}