Commit 20332767 by Tianqi Yang

Initial commit

parents
#pragma once
#include "dist.h"
#include <memory>
class Bandit
{
public:
Bandit ( std::vector < std::shared_ptr < rnd_dist > > _distributions, std::vector < double > means )
: n ( _distributions.size () ), means ( means ), distributions ( _distributions )
{
max_mean = *std::max_element ( means.begin (), means.end () );
}
virtual void new_iter ()
{
dist_val.resize ( n );
for ( int i = 0; i < n; ++i ) {
dist_val[i] = distributions[i]->operator() ();
}
}
virtual double operator () ( int i ) = 0;
virtual double get_regret ( int i ) = 0;
int get_n () const
{
return n;
}
protected:
int n;
std::vector < double > means;
double max_mean;
std::vector < std::shared_ptr < rnd_dist > > distributions;
std::vector < double > dist_val;
};
class MAB : public Bandit
{
public:
MAB ( std::vector < std::shared_ptr < rnd_dist > > distributions, std::vector < double > means )
: Bandit ( distributions, means )
{
}
virtual double operator () ( int i )
{
return dist_val[i];
}
virtual double get_regret ( int i )
{
return max_mean - means[i];
}
};
class MixedMAB : public Bandit
{
public:
MixedMAB ( std::vector < std::shared_ptr < rnd_dist > > distributions, std::vector < double > means, std::vector < std::shared_ptr < rnd_dist > > explicit_distribution )
: Bandit ( distributions, means ), explicit_distribution ( explicit_distribution )
{
}
virtual void new_iter ()
{
Bandit::new_iter ();
explicit_val.resize ( n );
for ( int i = 0; i < n; ++i ) {
explicit_val[i] = explicit_distribution[i]->operator () ();
}
}
virtual double operator () ( int i )
{
return dist_val[i] + explicit_val[i];
}
virtual double get_regret ( int i )
{
double opt = -1, cur_v;
int opt_j = -1;
for ( int j = 0; j < n; ++j ) {
cur_v = means[j] + explicit_val[j];
if ( j == 0 || cur_v > opt ) {
opt = cur_v;
opt_j = j;
}
}
return opt - ( means[i] + explicit_val[i] );
}
std::vector < double > & get_explicit_val ()
{
return explicit_val;
}
protected:
std::vector < std::shared_ptr < rnd_dist > > explicit_distribution;
std::vector < double > explicit_val;
};
\ No newline at end of file
#include "solver.h"
#include "utils.h"
#include <bits/stdc++.h>
using namespace std;
class arm_param
{
public:
template < typename T, typename... Us >
void add_arm ( double mean, Us... args )
{
++n;
means.push_back ( mean );
arms.push_back ( make_shared < T > ( mean, args..., &rnd ) );
}
std::vector < double > get_means () const
{
return means;
}
std::vector < std::shared_ptr < rnd_dist > > get_arms () const
{
return arms;
}
private:
int n;
std::vector < double > means;
std::vector < std::shared_ptr < rnd_dist > > arms;
};
int main ( int argc, char **argv )
{
auto args = parse_args ( argc, argv );
int num_iter = args[1];
int num_arms = args[5];
arm_param param;
for ( int i = 1; i <= num_arms; ++i ) {
param.add_arm < bernoulli_dist > ( rnd.rndd () );
}
shared_ptr < Bandit > bandit;
if ( string ( argv[2] ) == "MAB" ) {
bandit = make_shared < MAB > ( param.get_arms (), param.get_means () );
} else if ( string ( argv[2] ) == "MixedMAB" ) {
arm_param explicit_arms;
for ( int i = 1; i <= num_arms; ++i ) {
explicit_arms.add_arm < bernoulli_dist > ( rnd.rndd () );
}
bandit = make_shared < MixedMAB > ( param.get_arms (), param.get_means (), explicit_arms.get_arms () );
}
shared_ptr < Solver > solver;
if ( string ( argv[3] ) == "UCB1" ) {
solver = make_shared < UCB1 > ( bandit );
} else if ( string ( argv[3] ) == "MixedUCB" ) {
solver = make_shared < MixedUCB > ( bandit );
}
auto result = solver->solve ( num_iter, num_iter / 10000 );
ofstream OUT ( argv[4] );
for ( auto x : result ) {
OUT << x.first << " " << x.second << endl;
}
}
\ No newline at end of file
#!/bin/bash
if [ $# == 0 ]; then
_s=main
else
_s=$1
fi
g++ $_s.cpp -o $_s -O2 -std=c++17 -lpthread
#pragma once
#include "rnd.h"
class rnd_dist
{
public:
rnd_dist ( double clip_l, double clip_r, rnd_engine * engine )
: clip_l ( clip_l ), clip_r ( clip_r ), engine ( engine )
{
}
virtual double get () = 0;
virtual double operator () ()
{
while ( true ) {
double val = get ();
if ( val >= clip_l && val <= clip_r ) return val;
}
}
protected:
double clip_l, clip_r;
rnd_engine * engine;
};
class normal_dist : public rnd_dist
{
public:
normal_dist ( double mean, double stddev, double clip_l, double clip_r, rnd_engine * engine )
: rnd_dist ( clip_l, clip_r, engine ), dist ( mean, stddev )
{
}
virtual double get ()
{
return dist ( engine->get_engine () );
}
private:
std::normal_distribution < double > dist;
};
class bernoulli_dist : public rnd_dist
{
public:
bernoulli_dist ( double mean, rnd_engine * engine )
: rnd_dist ( -1, 2, engine ), dist ( mean )
{
}
virtual double get ()
{
return dist ( engine->get_engine () ) ? 1.0 : 0.0;
}
private:
std::bernoulli_distribution dist;
};
\ No newline at end of file
#include "solver.h"
#include "utils.h"
#include <bits/stdc++.h>
using namespace std;
class arm_param
{
public:
template < typename T, typename... Us >
void add_arm ( double mean, Us... args )
{
++n;
means.push_back ( mean );
arms.push_back ( make_shared < T > ( mean, args..., &rnd ) );
}
std::vector < double > get_means () const
{
return means;
}
std::vector < std::shared_ptr < rnd_dist > > get_arms () const
{
return arms;
}
private:
int n;
std::vector < double > means;
std::vector < std::shared_ptr < rnd_dist > > arms;
};
int main ( int argc, char **argv )
{
auto args = parse_args ( argc, argv );
int num_iter = args[1];
int num_arms = args[5];
arm_param param;
for ( int i = 1; i <= num_arms; ++i ) {
double mean = rnd.rndd () * 0.6 + 0.2;
param.add_arm < normal_dist > ( mean, 1.0 / 10, mean - 0.2, mean + 0.2 );
}
shared_ptr < Bandit > bandit;
if ( string ( argv[2] ) == "MAB" ) {
bandit = make_shared < MAB > ( param.get_arms (), param.get_means () );
} else if ( string ( argv[2] ) == "MixedMAB" ) {
arm_param explicit_arms;
for ( int i = 1; i <= num_arms; ++i ) {
double mean = - ( rnd.rndd () * 0.6 + 0.2 );
explicit_arms.add_arm < normal_dist > ( mean, 1.0 / 10, mean - 0.2, mean + 0.2 );
}
bandit = make_shared < MixedMAB > ( param.get_arms (), param.get_means (), explicit_arms.get_arms () );
}
shared_ptr < Solver > solver;
if ( string ( argv[3] ) == "UCB1" ) {
solver = make_shared < UCB1 > ( bandit );
} else if ( string ( argv[3] ) == "MixedUCB" ) {
solver = make_shared < MixedUCB > ( bandit );
}
auto result = solver->solve ( num_iter, num_iter / 10000 );
ofstream OUT ( argv[4] );
for ( auto x : result ) {
OUT << x.first << " " << x.second << endl;
}
}
\ No newline at end of file
#pragma once
#include <iostream>
#include <iomanip>
#include <cmath>
#include <sstream>
#include <thread>
#include <chrono>
#include <mutex>
class ProgressBar
{
public:
ProgressBar ( int _total = 0 ) : total ( _total ), count ( 0 ), fill_char ( '#' ), width ( 50 ), desc ( "" ), unit ( "it" ), precision ( 2 ), flush_frequency ( 50 ), last_line_length ( 0 ), finished ( true )
{
}
void clear ()
{
if ( !finished ) {
finish ();
}
count = 0;
}
void start ()
{
clear ();
start_time = std::chrono::steady_clock::now ();
finished = false;
daemon = std::thread ( &ProgressBar::daemon_thread, this );
}
template < typename T >
ProgressBar & operator << ( T x )
{
std::lock_guard < std::mutex > lock ( output_buf_lock );
output_buf << x;
return *this;
}
void finish ()
{
if ( finished ) {
return;
}
finished = true;
daemon.join ();
std::cout << std::endl;
}
int operator = ( const int _count )
{
std::lock_guard < std::mutex > lock ( count_lock );
return count = _count;
}
void operator ++ ()
{
std::lock_guard < std::mutex > lock ( count_lock );
++count;
}
void operator ++ ( int _ )
{
std::lock_guard < std::mutex > lock ( count_lock );
++count;
}
int operator += ( int delta )
{
std::lock_guard < std::mutex > lock ( count_lock );
return ( count += delta );
}
int get_count () const
{
return count;
}
void set_total ( int _total )
{
total = _total;
}
int get_total () const
{
return total;
}
void set_fill_char ( char _fill_char )
{
fill_char = _fill_char;
}
char get_fill_char () const
{
return fill_char;
}
void set_width ( int _width )
{
width = _width;
}
int get_width () const
{
return width;
}
void set_desc ( std::string _desc )
{
desc = _desc;
}
std::string get_desc () const
{
return desc;
}
void set_unit ( std::string _unit )
{
unit = _unit;
}
std::string get_unit () const
{
return unit;
}
void set_precision ( int _precision )
{
precision = _precision;
}
int get_precision () const
{
return precision;
}
void set_flush_frequency ( int _flush_frequency )
{
flush_frequency = _flush_frequency;
}
int get_flush_frequency () const
{
return flush_frequency;
}
void clear_tail_info ()
{
tail_info.clear ();
tail_info.str ( "" );
}
template < typename T >
void add_tail_info ( T s )
{
std::lock_guard < std::mutex > lock ( tail_info_lock );
tail_info << s;
}
private:
void daemon_thread ()
{
while ( 1 ) {
bool current_finished = finished;
std::cout << "\r";
for ( int i = 0; i < last_line_length; ++i )
{
std::cout << " ";
}
std::cout << "\r";
{
std::lock_guard < std::mutex > lock ( output_buf_lock );
if ( output_buf.good () ) {
std::cout << output_buf.str ();
std::cout.flush ();
}
output_buf.clear ();
output_buf.str ( "" );
}
std::stringstream temp_buf;
if ( desc != "" ) {
temp_buf << desc << ": ";
}
double percent = static_cast < double > ( std::min ( count, total ) ) * 100 / total;
if ( precision ) {
percent = static_cast < int > ( floor ( percent * pow ( 10, precision ) ) ) / pow ( 10, precision );
if ( count >= total ) percent = 100;
temp_buf << std::setw ( precision + 4 ) << std::setprecision ( precision ) << std::fixed << percent << "%";
} else {
temp_buf << static_cast < int > ( floor ( percent ) ) << "%";
}
temp_buf << "|";
double char_width = 100.0 / width;
for ( int i = 1; i <= width; ++i ) {
if ( count >= total || percent >= i * char_width ) {
temp_buf << fill_char;
} else {
temp_buf << " ";
}
}
temp_buf << "| ";
temp_buf << count << "/" << total << " ";
temp_buf << "[";
auto time_used = std::chrono::steady_clock::now () - start_time;
long long time_count = std::chrono::duration_cast < std::chrono::seconds > ( time_used ).count ();
long long time_remain = count ? static_cast < long long > ( std::chrono::duration_cast < std::chrono::duration < double > > ( time_used ).count () * ( total - count ) / count ) : -1;
auto add_time = [&] ( long long x )
{
if ( x == -1 ) {
temp_buf << "??:??";
return;
}
if ( x / 60 / 60 > 0 ) {
temp_buf << std::setw ( 2 ) << std::setfill ( '0' ) << x / 60 / 60 << ":";
}
temp_buf << std::setw ( 2 ) << std::setfill ( '0' ) << x / 60 % 60 << ":";
temp_buf << std::setw ( 2 ) << std::setfill ( '0' ) << x % 60;
};
add_time ( time_count );
temp_buf << "<";
add_time ( time_remain );
temp_buf << ", ";
if ( count ) {
double speed = static_cast < double > ( count ) / std::chrono::duration_cast < std::chrono::duration < double > > ( time_used ).count ();
if ( speed >= 1 ) {
temp_buf << std::setprecision ( 2 ) << std::fixed << speed << unit << "/s";
} else {
temp_buf << std::setprecision ( 2 ) << std::fixed << 1 / speed << "s/" << unit;
}
} else {
temp_buf << "?.??" << unit << "/s";
}
temp_buf << "]";
{
std::lock_guard < std::mutex > lock_guard ( tail_info_lock );
if ( !tail_info.str ().empty () ) {
temp_buf << ", " << tail_info.str ();
}
}
std::cout << temp_buf.str ();
last_line_length = temp_buf.str ().size ();
std::cout.flush ();
if ( current_finished ) {
return;
}
std::this_thread::sleep_for ( std::chrono::milliseconds ( flush_frequency ) );
}
}
int total;
int count;
char fill_char;
int width;
std::string desc;
std::string unit;
int precision;
int flush_frequency;
std::chrono::steady_clock::time_point start_time;
std::stringstream output_buf;
mutable std::mutex count_lock;
mutable std::mutex tail_info_lock;
mutable std::mutex output_buf_lock;
int last_line_length;
std::thread daemon;
bool finished;
std::stringstream tail_info;
};
class pb_tail_clear_t
{
public:
pb_tail_clear_t ()
{
}
};
pb_tail_clear_t pb_tail_clear;
class pb_tail_info
{
public:
pb_tail_info ( ProgressBar &_progress_bar )
: progress_bar ( &_progress_bar )
{
}
friend pb_tail_info operator << ( pb_tail_info pb, const pb_tail_clear_t a )
{
pb.progress_bar->clear_tail_info ();
return pb;
}
template < typename T > friend pb_tail_info operator << ( pb_tail_info pb, const T a )
{
pb.progress_bar->add_tail_info ( a );
return pb;
}
private:
ProgressBar * progress_bar;
};
#pragma once
#include <vector>
#include <chrono>
#include <random>
#include <algorithm>
class rnd_engine
{
public:
rnd_engine ()
: mt_engine ( std::chrono::system_clock::now ().time_since_epoch ().count () )
{
}
long long rnd ()
{
return static_cast <long long> ( mt_engine () & ( static_cast <unsigned long long> ( -1 ) >> 1 ) );
}
long double rndd ()
{
return static_cast < long double > ( rnd () ) / ( static_cast <unsigned long long> ( -1 ) >> 1 );
}
long long operator () ()
{
return rnd ();
}
int operator () ( int l, int r )
{
return static_cast < int > ( ( rnd () % ( r - l + 1 ) ) + l );
}
long long operator () ( long long l, long long r )
{
return ( rnd () % ( r - l + 1 ) ) + l;
}
int choice ( const std::vector <int> &vec )
{
return vec[this->operator () ( 0, vec.size () - 1 )];
}
std::vector <int> sample ( const std::vector <int> &vec, int k )
{
std::vector <int> p ( k );
for ( int i = 0; i < k; ++i ) {
p[i] = this->operator() ( 0, vec.size () - k );
}
std::sort ( p.begin (), p.end () );
std::vector <int> res ( k );
for ( int i = 0; i < k; ++i ) {
p[i] += i;
res[i] = vec[p[i]];
}
return res;
}
std::mt19937_64 & get_engine ()
{
return mt_engine;
}
private:
std::mt19937_64 mt_engine;
};
rnd_engine rnd;
#pragma once
#include "bandit.h"
#include "progress_bar.h"
#include <algorithm>
class Solver
{
public:
Solver ( std::shared_ptr < Bandit > _bandit )
: n ( _bandit->get_n () ), bandit ( _bandit )
{
}
virtual void clear ()
{
regret = 0;
}
virtual double apply_arm ( int i )
{
double value = bandit->operator() ( i );
regret += bandit->get_regret ( i );
return value;
}
virtual int solve_iter () = 0;
virtual std::vector < std::pair < int, int > > solve ( int num_iter, int store_regret_gap = 1 )
{
clear ();
std::vector < std::pair < int, int > > res;
ProgressBar progress_bar ( num_iter );
progress_bar.set_desc ( "Simulating" );
progress_bar.start ();
for ( int i = 1; i <= num_iter; ++i ) {
++progress_bar;
bandit->new_iter ();
solve_iter ();
pb_tail_info ( progress_bar ) << pb_tail_clear << "current regret: " << get_regret ();
if ( i % store_regret_gap == 0 ) {
res.push_back ( std::make_pair ( i, get_regret () ) );
}
}
progress_bar.finish ();
return res;
}
int get_regret () const
{
return static_cast < int > ( round ( regret ) );
}
protected:
int n;
double regret;
std::shared_ptr < Bandit > bandit;
};
class UCB1 : public Solver
{
public:
UCB1 ( std::shared_ptr < Bandit > _bandit )
: Solver ( _bandit )
{
}
virtual void clear ()
{
Solver::clear ();
tot = 0;
count.resize ( n );
for ( int i = 0; i < n; ++i ) {
count[i] = 0;
}
totval.resize ( n );
for ( int i = 0; i < n; ++i ) {
totval[i] = 0;
}
}
virtual int choose_arm ()
{
int ansp = -1;
double max_heuristic = 0, cur_heuristic;
for ( int i = 0; i < n; ++i ) {
if ( !count[i] ) {
return i;
}
cur_heuristic = totval[i] / count[i] + sqrt ( 2 * log ( tot ) / count[i] );
if ( ansp == -1 || cur_heuristic > max_heuristic ) {
ansp = i;
max_heuristic = cur_heuristic;
}
}
return ansp;
}
virtual int solve_iter ()
{
int arm = choose_arm ();
++tot;
++count[arm];
totval[arm] += apply_arm ( arm );
return arm;
}
protected:
int tot;
std::vector < int > count;
std::vector < double > totval;
};
class MixedUCB : public Solver
{
public:
MixedUCB ( std::shared_ptr < Bandit > _bandit )
: Solver ( _bandit )
{
mixed_mab = std::dynamic_pointer_cast < MixedMAB > ( bandit );
}
virtual void clear ()
{
Solver::clear ();
tot = 0;
count.resize ( n );
for ( int i = 0; i < n; ++i ) {
count[i] = 0;
}
totval.resize ( n );
for ( int i = 0; i < n; ++i ) {
totval[i] = 0;
}
}
virtual int choose_arm ()
{
int ansp = -1;
double max_heuristic = 0, cur_heuristic;
for ( int i = 0; i < n; ++i ) {
if ( !count[i] ) {
return i;
}
cur_heuristic = totval[i] / count[i] + mixed_mab->get_explicit_val ()[i] + sqrt ( 2 * log ( tot ) / count[i] );
if ( ansp == -1 || cur_heuristic > max_heuristic ) {
ansp = i;
max_heuristic = cur_heuristic;
}
}
return ansp;
}
virtual int solve_iter ()
{
int arm = choose_arm ();
++tot;
++count[arm];
totval[arm] += apply_arm ( arm ) - mixed_mab->get_explicit_val ()[arm];
return arm;
}
protected:
int tot;
std::vector < int > count;
std::vector < double > totval;
std::shared_ptr < MixedMAB > mixed_mab;
};
\ No newline at end of file
#pragma once
#include <utility>
#include <algorithm>
#include <vector>
#include <sstream>
template < typename T >
T string2T ( char *s )
{
std::stringstream buf;
buf << s;
T x;
buf >> x;
return x;
}
std::vector < int > parse_args ( int argc, char **argv )
{
std::vector < int > args;
args.push_back ( 0 );
for ( int i = 1; i < argc; ++i ) {
args.push_back ( string2T <int> ( argv[i] ) );
}
return args;
}
std::vector <int> xrange ( int l, int r )
{
std::vector <int> res ( r - l + 1 );
for ( int i = l; i <= r; ++i ) {
res[i - l] = i;
}
return res;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment