diff --git a/functions.go b/functions.go new file mode 100644 index 0000000..56f5210 --- /dev/null +++ b/functions.go @@ -0,0 +1,108 @@ +package sets + +// Union returns a union of the given sets (left ∪ right) +func Union[T comparable](left, right *Set[T]) *Set[T] { + left.mux.RLock() + right.mux.RLock() + defer left.mux.RUnlock() + defer right.mux.RUnlock() + + store := map[T]struct{}{} + + for k := range left.store { + store[k] = struct{}{} + } + for k := range right.store { + if _, ok := store[k]; !ok { + store[k] = struct{}{} + } + } + + return &Set[T]{ + store: store, + } +} + +// Intersection returns an intersection of the given sets (left ∩ right) +func Intersection[T comparable](left, right *Set[T]) *Set[T] { + left.mux.RLock() + right.mux.RLock() + defer left.mux.RUnlock() + defer right.mux.RUnlock() + + store := map[T]struct{}{} + + for lk := range left.store { + if _, ok := right.store[lk]; ok { + store[lk] = struct{}{} + } + } + + return &Set[T]{ + store: store, + } +} + +// Diff returns the relative complement of sets (left ∖ right) +func Diff[T comparable](left, right *Set[T]) *Set[T] { + left.mux.RLock() + right.mux.RLock() + defer left.mux.RUnlock() + defer right.mux.RUnlock() + + store := map[T]struct{}{} + + for lk := range left.store { + if _, ok := right.store[lk]; !ok { + store[lk] = struct{}{} + } + } + + return &Set[T]{ + store: store, + } +} + +// SymmetricDiff returns the symmetric difference between sets (left ⊖ right) +func SymmetricDiff[T comparable](left, right *Set[T]) *Set[T] { + left.mux.RLock() + right.mux.RLock() + defer left.mux.RUnlock() + defer right.mux.RUnlock() + + store := map[T]struct{}{} + + for lk := range left.store { + if _, ok := right.store[lk]; !ok { + store[lk] = struct{}{} + } + } + for rk := range right.store { + if _, ok := left.store[rk]; !ok { + store[rk] = struct{}{} + } + } + + return &Set[T]{ + store: store, + } +} + +// Equal checks whether the sets are equal (left = right) +func Equal[T comparable](left, right *Set[T]) bool { + left.mux.RLock() + right.mux.RLock() + defer left.mux.RUnlock() + defer right.mux.RUnlock() + + if len(left.store) != len(right.store) { + return false + } + + for lk := range left.store { + if _, ok := right.store[lk]; !ok { + return false + } + } + return true +} diff --git a/functions_test.go b/functions_test.go new file mode 100644 index 0000000..1149f7e --- /dev/null +++ b/functions_test.go @@ -0,0 +1,61 @@ +package sets + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func getSetForElements(t *testing.T, elements ...int) *Set[int] { + t.Helper() + + target := map[int]struct{}{} + for _, el := range elements { + target[el] = struct{}{} + } + + return &Set[int]{ + store: target, + } +} + +func TestUnion(t *testing.T) { + base := getSetForElements(t, 0, 1, 2, 3, 4, 5) + other := getSetForElements(t, 3, 4, 5, 6, 7, 8) + expected := []int{0, 1, 2, 3, 4, 5, 6, 7, 8} + + require.ElementsMatch(t, expected, Union(base, other).Slice()) +} + +func TestIntersection(t *testing.T) { + base := getSetForElements(t, 0, 1, 2, 3, 4, 5) + other := getSetForElements(t, 3, 4, 5, 6, 7, 8) + expected := []int{3, 4, 5} + + require.ElementsMatch(t, expected, Intersection(base, other).Slice()) +} + +func TestDiff(t *testing.T) { + base := getSetForElements(t, 0, 1, 2, 3, 4, 5) + other := getSetForElements(t, 3, 4, 5, 6, 7, 8) + expected := []int{0, 1, 2} + + require.ElementsMatch(t, expected, Diff(base, other).Slice()) +} + +func TestSymmetricDiff(t *testing.T) { + base := getSetForElements(t, 0, 1, 2, 3, 4, 5) + other := getSetForElements(t, 3, 4, 5, 6, 7, 8) + expected := []int{0, 1, 2, 6, 7, 8} + + require.ElementsMatch(t, expected, SymmetricDiff(base, other).Slice()) +} + +func TestEqual(t *testing.T) { + base := getSetForElements(t, 0, 1, 2, 3, 4, 5) + otherDifferent := getSetForElements(t, 3, 4, 5, 6, 7, 8) + otherEqual := getSetForElements(t, 0, 1, 2, 3, 4, 5) + + require.False(t, Equal(base, otherDifferent)) + require.True(t, Equal(base, otherEqual)) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..eeea02f --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module pkg.icikowski.pl/sets + +go 1.20 + +require github.com/stretchr/testify v1.8.4 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..fa4b6e6 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/package.go b/package.go new file mode 100644 index 0000000..96fabac --- /dev/null +++ b/package.go @@ -0,0 +1,4 @@ +/* +Sets-related types & functions +*/ +package sets diff --git a/set.go b/set.go new file mode 100644 index 0000000..ae88d37 --- /dev/null +++ b/set.go @@ -0,0 +1,77 @@ +package sets + +import ( + "sync" +) + +// Set represents a set of values +type Set[T comparable] struct { + store map[T]struct{} + mux sync.RWMutex +} + +// New creates a new set +func New[T comparable](data ...T) *Set[T] { + set := &Set[T]{ + store: map[T]struct{}{}, + } + + for _, element := range data { + set.Insert(element) + } + return set +} + +// Size returns number of elements in set +func (s *Set[T]) Size() int { + s.mux.RLock() + defer s.mux.RUnlock() + + return len(s.store) +} + +// Contains checks whether the value is contained in the set +func (s *Set[T]) Contains(val T) bool { + s.mux.RLock() + defer s.mux.RUnlock() + + _, ok := s.store[val] + return ok +} + +// Insert inserts a value into the set if the value was not already present +func (s *Set[T]) Insert(val T) bool { + s.mux.Lock() + defer s.mux.Unlock() + + if _, ok := s.store[val]; !ok { + s.store[val] = struct{}{} + return false + } + return true +} + +// Delete removes a value from the set if the value was already present +func (s *Set[T]) Delete(val T) bool { + s.mux.Lock() + defer s.mux.Unlock() + + if _, ok := s.store[val]; ok { + delete(s.store, val) + return true + } + return false +} + +// Slice returns a slice which contains the elements from the set +func (s *Set[T]) Slice() []T { + s.mux.RLock() + defer s.mux.RUnlock() + + elements := []T{} + for k := range s.store { + elements = append(elements, k) + } + + return elements +} diff --git a/set_test.go b/set_test.go new file mode 100644 index 0000000..3792e43 --- /dev/null +++ b/set_test.go @@ -0,0 +1,61 @@ +package sets + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func getTestSet(t *testing.T) *Set[int] { + t.Helper() + + return &Set[int]{ + store: map[int]struct{}{ + 0: {}, + 1: {}, + 2: {}, + 3: {}, + 4: {}, + 5: {}, + 6: {}, + 7: {}, + 8: {}, + 9: {}, + }, + } +} + +func TestSetSize(t *testing.T) { + s := getTestSet(t) + + require.Equal(t, 10, s.Size()) +} + +func TestSetContains(t *testing.T) { + s := getTestSet(t) + + require.True(t, s.Contains(0)) + require.False(t, s.Contains(10)) +} + +func TestSetInsert(t *testing.T) { + s := getTestSet(t) + + require.True(t, s.Insert(0)) + require.False(t, s.Insert(10)) + require.Contains(t, s.store, 10) +} + +func TestSetDelete(t *testing.T) { + s := getTestSet(t) + + require.False(t, s.Delete(10)) + require.True(t, s.Delete(0)) + require.NotContains(t, s.store, 0) +} + +func TestSetSlice(t *testing.T) { + s := getTestSet(t) + + require.ElementsMatch(t, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, s.Slice()) +}