From 14bc63ec5422e74b1eb8db26a4b82d91c71cb2ab Mon Sep 17 00:00:00 2001 From: Jeremy Tregunna Date: Thu, 26 Dec 2024 00:56:29 -0600 Subject: [PATCH] feat: Initial implementation of a role based authentication system --- go.mod | 5 ++ go.sum | 2 + rbac.go | 148 +++++++++++++++++++++++++++++++++++++++++++++++++++ rbac_test.go | 98 ++++++++++++++++++++++++++++++++++ 4 files changed, 253 insertions(+) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 rbac.go create mode 100644 rbac_test.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2fa8656 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.canoozie.net/jer/rbac + +go 1.23.2 + +require github.com/mattn/go-sqlite3 v1.14.24 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9dcdc9b --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/rbac.go b/rbac.go new file mode 100644 index 0000000..91f3ff5 --- /dev/null +++ b/rbac.go @@ -0,0 +1,148 @@ +package rbac + +import ( + "database/sql" +) + +// Role in the system +type Role struct { + ID int `json:"id"` + Name string `json:"name"` +} + +// A capability is an action that can be performed in the system. +type Capability struct { + ID int `json:"id"` + Name string `json:"name"` + Description string `json:"description"` +} + +// Links a user to a role. +type UserRole struct { + UserID int `json:"user_id"` + RoleID int `json:"role_id"` +} + +// A role can have many capabilities. +type RoleCapability struct { + RoleID int `json:"role_id"` + CapabilityID int `json:"capability_id"` +} + +// RbacStore is an interface for interacting with the RBAC data store. +type RbacStore interface { + GetUserRoles(userID int) ([]Role, error) + GetRoleCapabilities(roleID int) ([]Capability, error) + HasCapability(userID int, capabilityName string) (bool, error) +} + +// SqlRbacStore implements the RbacStore interface using SQLite. +type SqlRbacStore struct { + db *sql.DB +} + +func NewSqlRbacStore(db *sql.DB) (*SqlRbacStore, error) { + s := &SqlRbacStore{db: db} + err := s.createTablesIfMissing() + if err != nil { + return nil, err + } + return s, nil +} + +func (s *SqlRbacStore) createTablesIfMissing() (err error) { + _, err = s.db.Exec(`CREATE TABLE IF NOT EXISTS roles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL + )`) + if err != nil { + return + } + _, err = s.db.Exec(`CREATE TABLE IF NOT EXISTS capabilities ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + description TEXT + )`) + if err != nil { + return + } + _, err = s.db.Exec(`CREATE TABLE IF NOT EXISTS user_roles ( + user_id INTEGER NOT NULL, + role_id INTEGER NOT NULL, + PRIMARY KEY (user_id, role_id) + )`) + if err != nil { + return + } + _, err = s.db.Exec(`CREATE TABLE IF NOT EXISTS role_capabilities ( + role_id INTEGER NOT NULL, + capability_id INTEGER NOT NULL, + PRIMARY KEY (role_id, capability_id) + )`) + return +} + +func (s *SqlRbacStore) GetUserRoles(userID int) ([]Role, error) { + var roles []Role + rows, err := s.db.Query("SELECT r.id, r.name FROM roles r JOIN user_roles ur ON r.id = ur.role_id WHERE ur.user_id = ?", userID) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var role Role + err = rows.Scan(&role.ID, &role.Name) + if err != nil { + return nil, err + } + roles = append(roles, role) + } + return roles, nil +} + +func (s *SqlRbacStore) GetRoleCapabilities(roleID int) ([]Capability, error) { + var capabilities []Capability + rows, err := s.db.Query("SELECT c.id, c.name, c.description FROM capabilities c JOIN role_capabilities rc ON c.id = rc.capability_id WHERE rc.role_id = ?", roleID) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var capability Capability + err = rows.Scan(&capability.ID, &capability.Name, &capability.Description) + if err != nil { + return nil, err + } + capabilities = append(capabilities, capability) + } + return capabilities, nil +} + +func (s *SqlRbacStore) HasCapability(userID int, capabilityName string) (bool, error) { + var has bool + err := s.db.QueryRow("SELECT EXISTS(SELECT 1 FROM role_capabilities rc JOIN user_roles ur ON rc.role_id = ur.role_id WHERE ur.user_id = ? AND rc.capability_id = (SELECT id FROM capabilities WHERE name = ?))", userID, capabilityName).Scan(&has) + if err != nil { + return false, err + } + return has, nil +} + +type RbacService struct { + store RbacStore +} + +func NewRbacService(store RbacStore) *RbacService { + return &RbacService{store: store} +} + +func (s *RbacService) GetUserRoles(userID int) ([]Role, error) { + return s.store.GetUserRoles(userID) +} + +func (s *RbacService) GetRoleCapabilities(roleID int) ([]Capability, error) { + return s.store.GetRoleCapabilities(roleID) +} + +func (s *RbacService) HasCapability(userID int, capabilityName string) (bool, error) { + return s.store.HasCapability(userID, capabilityName) +} diff --git a/rbac_test.go b/rbac_test.go new file mode 100644 index 0000000..0196068 --- /dev/null +++ b/rbac_test.go @@ -0,0 +1,98 @@ +package rbac + +import ( + "database/sql" + _ "github.com/mattn/go-sqlite3" + "reflect" + "testing" +) + +func compareSlices(a, b []Role) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestSqlRbacStore(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + store, err := NewSqlRbacStore(db) + if err != nil { + t.Fatal(err) + } + + // Create test data + _, err = db.Exec(` + INSERT INTO roles (id, name) VALUES (1, 'admin'); + INSERT INTO user_roles (user_id, role_id) VALUES (1, 1); + INSERT INTO capabilities (id, name, description) VALUES (1, 'Send Email', ''); + INSERT INTO role_capabilities (role_id, capability_id) VALUES (1, 1); + `) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + userID int + capability string + wantHas bool + }{ + { + name: "User has capability", + userID: 1, + capability: "Send Email", + wantHas: true, + }, + { + name: "User does not have capability", + userID: 2, + capability: "Send Email", + wantHas: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + has, err := store.HasCapability(tt.userID, tt.capability) + if err != nil { + t.Fatal(err) + } + if has != tt.wantHas { + t.Errorf("want %v, got %v", tt.wantHas, has) + } + + roles, err := store.GetUserRoles(tt.userID) + if err != nil { + t.Fatal(err) + } + wantRoles := []Role{} + if tt.userID == 1 { + wantRoles = append(wantRoles, Role{ID: 1, Name: "admin"}) + } + if !compareSlices(roles, wantRoles) { + t.Errorf("want %v, got %v", wantRoles, roles) + } + + capabilities, err := store.GetRoleCapabilities(1) + if err != nil { + t.Fatal(err) + } + wantCapabilities := []Capability{{ID: 1, Name: "Send Email"}} + if !reflect.DeepEqual(capabilities, wantCapabilities) { + t.Errorf("want %v, got %v", wantCapabilities, capabilities) + } + }) + } + +}