Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KaHyPar selector and greedy merge #46

Merged
merged 17 commits into from
Jan 18, 2025
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ lib/OptimalBranchingMIS/docs/build/
lib/OptimalBranchingCore/docs/build/

docs/src/generated/

report.typ
29 changes: 28 additions & 1 deletion examples/rule_discovery.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,31 @@ branching_region = SimpleGraph(Graphs.SimpleEdge.(edges))
# Generate the tree-like N3 neighborhood of R
graph = tree_like_N3_neighborhood(branching_region)

solve_opt_rule(branching_region, graph, vs)
solve_opt_rule(branching_region, graph, vs)


# ## Generating rules for large scale problems
# For large scale problems, we can use the greedy merge rule to generate rules, which avoids generating all candidate clauses.
function solve_greedy_rule(branching_region, graph, vs)
## Use default solver and measure
m = D3Measure()
table_solver = TensorNetworkSolver(; prune_by_env=true)

## Pruning irrelevant entries
ovs = OptimalBranchingMIS.open_vertices(graph, vs)
subg, vmap = induced_subgraph(graph, vs)
@info "solving the branching table..."
tbl = OptimalBranchingMIS.reduced_alpha_configs(table_solver, subg, Int[findfirst(==(v), vs) for v in ovs])
@info "the length of the truth_table after pruning irrelevant entries: $(length(tbl.table))"

@info "generating the optimal branching rule via greedy merge..."
candidates = OptimalBranchingCore.bit_clauses(tbl)
result = OptimalBranchingMIS.OptimalBranchingCore.greedymerge(candidates, MISProblem(graph), vs, m)
return result
@info "the greedily minimized gamma: $(result.γ)"

@info "the branching rule on R:"
viz_dnf(result.optimal_rule, vs)
end

result = solve_greedy_rule(branching_region, graph, vs)
2 changes: 2 additions & 0 deletions lib/OptimalBranchingCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ version = "0.1.1"

[deps]
BitBasis = "50ba71b6-fa0f-514d-ae9a-0916efc90dcf"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"

[compat]
BitBasis = "0.9"
DataStructures = "0.18.20"
HiGHS = "1.12"
JuMP = "1.23"
julia = "1.10"
Expand Down
5 changes: 4 additions & 1 deletion lib/OptimalBranchingCore/src/OptimalBranchingCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module OptimalBranchingCore

using JuMP, HiGHS
using BitBasis
using DataStructures

# logic expressions
export Clause, BranchingTable, DNF, booleans, ∨, ∧, ¬, covered_by, literals, is_true_literal, is_false_literal
Expand All @@ -16,7 +17,7 @@ export AbstractProblem, branch_and_reduce, BranchingStrategy
# variable selector interface
export select_variable, AbstractSelector
# branching table solver interface
export branching_table, AbstractTableSolver
export branching_table, AbstractTableSolver, NaiveBranch, GreedyMerge
# measure interface
export measure, AbstractMeasure
# reducer interface
Expand All @@ -30,5 +31,7 @@ include("interfaces.jl")
include("branching_table.jl")
include("setcovering.jl")
include("branch.jl")
include("greedymerge.jl")
include("mockproblem.jl")

end
54 changes: 37 additions & 17 deletions lib/OptimalBranchingCore/src/branch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ A [`OptimalBranchingResult`](@ref) object representing the optimal branching rul
"""
function optimal_branching_rule(table::BranchingTable, variables::Vector, problem::AbstractProblem, m::AbstractMeasure, solver::AbstractSetCoverSolver)
candidates = candidate_clauses(table)
size_reductions = [measure(problem, m) - measure(first(apply_branch(problem, candidate, variables)), m) for candidate in candidates]
return minimize_γ(table, candidates, size_reductions, solver; γ0=2.0)
size_reductions = [size_reduction(problem, m, candidate, variables) for candidate in candidates]
return minimize_γ(table, candidates, size_reductions, solver; γ0 = 2.0)
end

function size_reduction(p::AbstractProblem, m::AbstractMeasure, cl::Clause{INT}, variables::Vector) where {INT}
return measure(p, m) - measure(first(apply_branch(p, cl, variables)), m)
end


"""
BranchingStrategy
BranchingStrategy(; kwargs...)
Expand All @@ -31,23 +36,23 @@ A struct representing the configuration for a solver, including the reducer and
- `selector::AbstractSelector`: The selector to select the next branching variable or decision.
- `m::AbstractMeasure`: The measure to evaluate the performance of the branching strategy.
"""
@kwdef struct BranchingStrategy{TS<:AbstractTableSolver, SCS<:AbstractSetCoverSolver, SL<:AbstractSelector, M<:AbstractMeasure}
@kwdef struct BranchingStrategy{TS <: AbstractTableSolver, SCS <: AbstractSetCoverSolver, SL <: AbstractSelector, M <: AbstractMeasure}
set_cover_solver::SCS = IPSolver()
table_solver::TS
selector::SL
measure::M
end
Base.show(io::IO, config::BranchingStrategy) = print(io,
"""
BranchingStrategy
├── table_solver - $(config.table_solver)
├── set_cover_solver - $(config.set_cover_solver)
├── selector - $(config.selector)
└── measure - $(config.measure)
""")
Base.show(io::IO, config::BranchingStrategy) = print(io,
"""
BranchingStrategy
├── table_solver - $(config.table_solver)
├── set_cover_solver - $(config.set_cover_solver)
├── selector - $(config.selector)
└── measure - $(config.measure)
""")

"""
branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy; reducer::AbstractReducer=NoReducer(), result_type=Int)
branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy; reducer::AbstractReducer=NoReducer(), result_type=Int, show_progress=false)

Branch the given problem using the specified solver configuration.

Expand All @@ -62,19 +67,34 @@ Branch the given problem using the specified solver configuration.
### Returns
The resulting value, which may have different type depending on the `result_type`.
"""
function branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy, reducer::AbstractReducer, result_type)
function branch_and_reduce(problem::AbstractProblem, config::BranchingStrategy, reducer::AbstractReducer, result_type; show_progress=false, tag=Tuple{Int,Int}[])
@debug "Branching and reducing problem" problem
isempty(problem) && return zero(result_type)
has_zero_size(problem) && return zero(result_type)
# reduce the problem
rp, reducedvalue = reduce_problem(result_type, problem, reducer)
rp !== problem && return branch_and_reduce(rp, config, reducer, result_type) * reducedvalue
rp !== problem && return branch_and_reduce(rp, config, reducer, result_type; tag) * reducedvalue

# branch the problem
variables = select_variables(rp, config.measure, config.selector) # select a subset of variables
tbl = branching_table(rp, config.table_solver, variables) # compute the BranchingTable
result = optimal_branching_rule(tbl, variables, rp, config.measure, config.set_cover_solver) # compute the optimal branching rule
return sum(result.optimal_rule.clauses) do branch # branch and recurse
return sum(enumerate(get_clauses(result))) do (i, branch) # branch and recurse
show_progress && (print_sequence(stdout, tag); println(stdout))
subproblem, localvalue = apply_branch(rp, branch, variables)
branch_and_reduce(subproblem, config, reducer, result_type) * result_type(localvalue) * reducedvalue
branch_and_reduce(subproblem, config, reducer, result_type;
tag=(show_progress ? [tag..., (i, length(get_clauses(result)))] : tag),
show_progress) * result_type(localvalue) * reducedvalue
end
end

function print_sequence(io::IO, sequence::Vector{Tuple{Int,Int}})
for (i, n) in sequence
if i == n
print(io, "■")
elseif i == 1
print(io, "□")
else
print(io, "▦")
end
end
end
78 changes: 78 additions & 0 deletions lib/OptimalBranchingCore/src/greedymerge.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
struct GreedyMerge <: AbstractSetCoverSolver end
struct NaiveBranch <: AbstractSetCoverSolver end
function optimal_branching_rule(table::BranchingTable, variables::Vector, problem::AbstractProblem, m::AbstractMeasure, solver::GreedyMerge)
candidates = bit_clauses(table)
return greedymerge(candidates, problem, variables, m)
end

function optimal_branching_rule(table::BranchingTable, variables::Vector, problem::AbstractProblem, m::AbstractMeasure, solver::NaiveBranch)
candidates = bit_clauses(table)
size_reductions = [Float64(size_reduction(problem, m, first(candidate), variables)) for candidate in candidates]
γ = complexity_bv(size_reductions)
return OptimalBranchingResult(DNF(first.(candidates)), size_reductions, γ)
end

function bit_clauses(tbl::BranchingTable{INT}) where {INT}
n, bss = tbl.bit_length, tbl.table
temp_clauses = [[Clause(bmask(INT, 1:n), bs) for bs in bss1] for bss1 in bss]
return temp_clauses
end

function greedymerge(cls::Vector{Vector{Clause{INT}}}, problem::AbstractProblem, variables::Vector, m::AbstractMeasure) where {INT}
function reduction_merge(cli, clj)
clmax, iimax, jjmax, reductionmax = Clause(zero(INT), zero(INT)), -1, -1, 0.0
@inbounds for ii = 1:length(cli), jj = 1:length(clj)
cl12 = gather2(length(variables), cli[ii], clj[jj])
iszero(cl12.mask) && continue
reduction = Float64(size_reduction(problem, m, cl12, variables))
if reduction > reductionmax
clmax, iimax, jjmax, reductionmax = cl12, ii, jj, reduction
end
end
return clmax, iimax, jjmax, reductionmax
end
cls = copy(cls)
size_reductions = [Float64(size_reduction(problem, m, first(candidate), variables)) for candidate in cls]
k = 0
@inbounds while true
nc = length(cls)
mask = trues(nc)
γ = complexity_bv(size_reductions)
weights = map(s -> γ^(-s), size_reductions)
queue = PriorityQueue{NTuple{2, Int}, Float64}() # from small to large
for i ∈ 1:nc, j ∈ i+1:nc
_, _, _, reduction = reduction_merge(cls[i], cls[j])
dE = γ^(-reduction) - weights[i] - weights[j]
dE <= -1e-12 && enqueue!(queue, (i, j), dE - 1e-12 * (k += 1; k))
end
isempty(queue) && return OptimalBranchingResult(DNF(first.(cls)), size_reductions, γ)
while !isempty(queue)
(i, j) = dequeue!(queue)
# remove i, j-th row
for rowid in (i, j)
mask[rowid] = false
for k = 1:nc
if mask[k]
a, b = minmax(rowid, k)
haskey(queue, (a, b)) && delete!(queue, (a, b))
end
end
end
# add i-th row
mask[i] = true
clij, _, _, size_reductions[i] = reduction_merge(cls[i], cls[j])
cls[i] = [clij]
weights[i] = γ^(-size_reductions[i])
for k = 1:nc
if i !== k && mask[k]
a, b = minmax(i, k)
_, _, _, reduction = reduction_merge(cls[a], cls[b])
dE = γ^(-reduction) - weights[a] - weights[b]

dE <= -1e-12 && enqueue!(queue, (a, b), dE - 1e-12 * (k += 1; k))
end
end
end
cls, size_reductions = cls[mask], size_reductions[mask]
end
end
76 changes: 76 additions & 0 deletions lib/OptimalBranchingCore/src/mockproblem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
struct MockProblem <: AbstractProblem
optimal::BitVector
end

"""
NumOfVariables

A struct representing a measure that counts the number of variables in a problem.
Each variable is counted as 1.
"""
struct NumOfVariables <: AbstractMeasure end
measure(p::MockProblem, ::NumOfVariables) = length(p.optimal)


"""
struct RandomSelector <: AbstractSelector

The `RandomSelector` struct represents a strategy for selecting a subset of variables randomly.

# Fields
- `n::Int`: The number of variables to select.
"""
struct RandomSelector <: AbstractSelector
n::Int
end
function select_variables(p::MockProblem, ::NumOfVariables, selector::RandomSelector)
nv = min(length(p.optimal), selector.n)
return sortperm(rand(length(p.optimal)))[1:nv]
end

"""
struct MockTableSolver <: AbstractTableSolver

The `MockTableSolver` randomly generates a branching table with a given number of rows.
Each row must have at least one variable to be covered by the branching rule.

### Fields
- `n::Int`: The number of rows in the branching table.
- `p::Float64 = 0.0`: The probability of generating more than one variables in a row, following the Poisson distribution.
"""
struct MockTableSolver <: AbstractTableSolver
n::Int
p::Float64
end
MockTableSolver(n::Int) = MockTableSolver(n, 0.0)
function branching_table(p::MockProblem, table_solver::MockTableSolver, variables)
function rand_fib() # random independent set on 1D chain
bs = falses(length(variables))
for i=1:length(variables)
if rand() < min(0.5, i == 1 ? 1.0 : 1 - bs[i-1])
bs[i] = true
end
end
return bs
end
rows = unique!([[rand_fib()] for _ in 1:table_solver.n] ∪ [[p.optimal[variables]]])
for i in 1:length(rows)
for _ = 1:100
if rand() < table_solver.p
push!(rows[i], rand_fib())
else
break
end
end
end
return BranchingTable(length(variables), unique!.(rows))
end

function apply_branch(p::MockProblem, clause::Clause{INT}, variables::Vector{T}) where {INT<:Integer, T<:Integer}
remain_mask = trues(length(p.optimal))
for i in 1:length(variables)
isone(readbit(clause.mask, i)) && (remain_mask[variables[i]] = false)
end
return MockProblem(p.optimal[remain_mask]), count(i -> isone(readbit(clause.mask, i)) && (readbit(clause.val, i) == p.optimal[variables[i]]), 1:length(variables))
end
has_zero_size(p::MockProblem) = measure(p, NumOfVariables()) == 0
14 changes: 7 additions & 7 deletions lib/OptimalBranchingCore/src/setcovering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,18 @@ end
The result type for the optimal branching rule.

### Fields
- `selected_ids::Vector{Int}`: The indices of the selected rows in the branching table.
- `optimal_rule::DNF{INT}`: The optimal branching rule.
- `branching_vector::Vector{T<:Real}`: The branching vector that records the size reduction in each subproblem.
- `γ::Float64`: The optimal γ value (the complexity of the branching rule).
"""
struct OptimalBranchingResult{INT <: Integer, T <: Real}
selected_ids::Vector{Int}
optimal_rule::DNF{INT}
branching_vector::Vector{T}
γ::Float64
end
Base.show(io::IO, results::OptimalBranchingResult{INT, T}) where {INT, T} = print(io, "OptimalBranchingResult{$INT, $T}:\n selected_ids: $(results.selected_ids)\n optimal_rule: $(results.optimal_rule)\n branching_vector: $(results.branching_vector)\n γ: $(results.γ)")
Base.show(io::IO, results::OptimalBranchingResult{INT, T}) where {INT, T} = print(io, "OptimalBranchingResult{$INT, $T}:\n optimal_rule: $(results.optimal_rule)\n branching_vector: $(results.branching_vector)\n γ: $(results.γ)")
get_clauses(results::OptimalBranchingResult) = results.optimal_rule.clauses
get_clauses(res::AbstractArray) = res

"""
minimize_γ(table::BranchingTable, candidates::Vector{Clause}, Δρ::Vector, solver)
Expand Down Expand Up @@ -140,7 +140,7 @@ function minimize_γ(table::BranchingTable, candidates::Vector{Clause{INT}}, Δ

# Note: the following instance is captured for time saving, and also for it may cause IP solver to fail
for (k, subset) in enumerate(subsets)
(length(subset) == num_items) && return OptimalBranchingResult([k], DNF([candidates[k]]), [Δρ[k]], 1.0)
(length(subset) == num_items) && return OptimalBranchingResult(DNF([candidates[k]]), [Δρ[k]], 1.0)
end

cx_old = cx = γ0
Expand All @@ -153,7 +153,7 @@ function minimize_γ(table::BranchingTable, candidates::Vector{Clause{INT}}, Δ
cx ≈ cx_old && break # convergence
cx_old = cx
end
return OptimalBranchingResult(picked_scs, DNF([candidates[i] for i in picked_scs]), Δρ[picked_scs], cx)
return OptimalBranchingResult(DNF([candidates[i] for i in picked_scs]), Δρ[picked_scs], cx)
end

# TODO: we need to extend this function to trim the candidate clauses
Expand Down Expand Up @@ -204,7 +204,7 @@ function gather2(n::Int, c1::Clause{INT}, c2::Clause{INT}) where INT
return Clause(mask, val)
end

function is_solved(xs::Vector{T}, sets_id::Vector{Vector{Int}}, num_items::Int) where{T}
function is_solved_by(xs::Vector{T}, sets_id::Vector{Vector{Int}}, num_items::Int) where{T}
for i in 1:num_items
flag = sum(xs[j] for j in sets_id[i])
((flag < 1) && !(flag ≈ 1)) && return false
Expand Down Expand Up @@ -247,7 +247,7 @@ function weighted_minimum_set_cover(solver::LPSolver, weights::AbstractVector, s

optimize!(model)
xs = value.(x)
@assert is_solved(xs, sets_id, num_items)
@assert is_solved_by(xs, sets_id, num_items)
return pick_sets(xs, subsets, num_items)
end

Expand Down
2 changes: 1 addition & 1 deletion lib/OptimalBranchingCore/test/branching_table.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using OptimalBranchingCore, GenericTensorNetworks
using BitBasis
using OptimalBranchingCore.BitBasis
using Test

@testset "branching table" begin
Expand Down
Loading
Loading