MOAI Project
from typing import List, Optional, Union, Tuple, Dict, Any
from pydantic import BaseModel, Field, ValidationError, validator, confloat, conint, constr, Json
from pydantic.tools import parse_obj_as
import pandas as pd
import numpy as np
import fastai.tabular.core as ft

文脈から問題例を生成する関数 instance_generation

引数:

  • 文脈(Context;たとえば日付)を表すデータフレーム
  • 他のTabularPandasの引数

返値:

  • 生成された問題例

これは, 表データから表データを生成する深層学習モデルなので, fastaiの TabularPandas で代用できる. 複数ラベル回帰を用いる. 顧客数だけを予測してから,点の座標を予測する. その後,予測した点の数だけサンプリングする.

https://docs.fast.ai/tabular.core.html#tabularpandas

他の方法としては,問題例をクラスタリングして分類ラベルを付与し, 与えられた文脈からラベルを当てる分類を適当な機械学習で行うことも考えられる.

# Context frame: one row per calendar day, 30 consecutive days starting 2021-01-01.
df = pd.DataFrame(
    data={'date': pd.date_range('2021/1/1', periods=30, freq='d')}
)
# ft.make_date(df, 'date')  # left disabled, as in the original cell
# Expand 'date' into calendar feature columns (Year, Month, Week, Day, Dayofweek, ...).
# df is modified in place, which is why the return value is not assigned.
ft.add_datepart(df, "date")
df.head()
Year Month Week Day Dayofweek Dayofyear Is_month_end Is_month_start Is_quarter_end Is_quarter_start Is_year_end Is_year_start Elapsed
0 2021 1 53 1 4 1 False True False True False True 1.609459e+09
1 2021 1 53 2 5 2 False False False False False False 1.609546e+09
2 2021 1 53 3 6 3 False False False False False False 1.609632e+09
3 2021 1 1 4 0 4 False False False False False False 1.609718e+09
4 2021 1 1 5 1 5 False False False False False False 1.609805e+09

問題例を入力とし,過去の学習データから解の情報を返す関数

問題例間の距離を定義する必要がある.

Instanceクラス

  • たとえばTSPなら点数と座標の列や地点間の移動費用,時間の行列;  点の座標は(たとえば北東方向に)大きい順に並べておく.
  • 他の問題例との距離計算(輸送問題を解く) 割当問題の解法 https://scmopt.github.io/opt100/33ap.html#%E5%89%B2%E5%BD%93%E5%95%8F%E9%A1%8C
  • 問題例をクラスタリングしたときのラベル
  • 対応する文脈データフレーム

Solutionクラス

  • 対応するInstance
class Instance():
    """A problem instance together with the context it belongs to.

    Attributes:
        context: Context associated with the instance; stored as given.
        data: Raw instance data; stored as given.
        feature: Feature representation of the instance. Taken from the
            ``feature`` argument when one is supplied, otherwise computed
            from ``data`` by :meth:`extract_feature`.
    """

    def __init__(self, context=None, data=None, feature=None):
        """Store the context and data and set up the feature representation.

        Args:
            context: Context of the instance (optional).
            data: Raw instance data (optional).
            feature: Precomputed feature (optional). When ``None`` the
                feature is derived from ``data``.
        """
        self.context = context
        self.data = data
        # Honour an explicitly supplied feature; previously the argument was
        # accepted but silently discarded.
        if feature is None:
            feature = self.extract_feature(data)
        self.feature = feature

    def extract_feature(self, data=None):
        """Return the feature representation of ``data``.

        Not implemented yet: always returns ``None``.
        """
        return None

    def distance(self, another):
        """Return the distance between this instance and ``another``.

        Not implemented yet: always returns ``None``.
        """
        return None

    class Config:
        # Pydantic-style configuration. It has no effect while Instance is a
        # plain class rather than a pydantic model.
        arbitrary_types_allowed = True
        
class Solution:
    """Placeholder for a solution; defines no attributes or methods yet."""
an_instance = Instance()

# Customer counts for weekdays (n_max) and weekends (n_min). The original note
# calls them averages, but they are used below as the exact number of points.
n_max, n_min = 100, 50
sd = 10  # standard deviation of each coordinate

# One flattened, sorted coordinate vector per row (day) of the context frame.
data = []
for row in df.itertuples():
    # Dayofweek 0-4 is Monday-Friday; 5 and 6 are Saturday and Sunday.
    n_points = n_max if row.Dayofweek <= 4 else n_min
    loc = np.random.normal(loc=0, scale=sd, size=(n_points, 2))
    # Instance has no `numpyArray` keyword (it raised TypeError); the points
    # belong in `data`.
    instance = Instance(data=loc)

    # Order the points by x + y. argsort reorders whole rows, so every (x, y)
    # pair stays intact; sorting the stacked array with sort(axis=0) sorted
    # each column on its own and mixed coordinates of different points.
    order = np.argsort(loc[:, 0] + loc[:, 1], kind="stable")
    sorted_loc = loc[order]

    # Row layout: x0, y0, x1, y1, ...
    data.append(sorted_loc.flatten())
# Re-run the ordering on the last generated `loc` to inspect the result.
# Column 0 of `new_loc` holds the sort key x + y; columns 1-2 hold the point.
key = loc[:, 0] + loc[:, 1]
keyed = np.concatenate((key.reshape((-1, 1)), loc), axis=1)
# Reorder whole rows by the key. ndarray.sort(axis=0) would sort each column
# independently and break up the (x, y) pairs.
new_loc = keyed[np.argsort(key, kind="stable")]
sorted_loc = new_loc[:, 1:]
sorted_loc
array([[-20.19666462, -24.74534989],
       [-13.69222509, -17.30384304],
       [-12.33393609, -15.45093693],
       [-10.53916806, -15.32671206],
       [ -9.35126821, -15.22409121],
       [ -8.822563  , -14.52154506],
       [ -8.73932943, -12.28191137],
       [ -8.36073931, -12.13593063],
       [ -8.02864509, -11.33781158],
       [ -7.74814991, -11.22964621],
       [ -7.7033224 , -10.10622249],
       [ -7.42682121,  -8.17208738],
       [ -6.55308268,  -7.914439  ],
       [ -5.66587855,  -6.38066481],
       [ -5.282336  ,  -6.34512238],
       [ -4.90897862,  -4.69273309],
       [ -3.94285144,  -4.53736864],
       [ -3.23423259,  -4.38628677],
       [ -3.16378394,  -4.35164651],
       [ -2.82444794,  -3.8416044 ],
       [ -2.72227564,  -3.29531517],
       [ -2.62622257,  -2.25543761],
       [ -2.40086914,  -2.03813249],
       [ -1.5312655 ,  -1.94330015],
       [ -1.5306093 ,  -1.57389452],
       [ -1.24658262,  -0.88679573],
       [ -1.03659721,  -0.83513033],
       [ -0.64599481,  -0.72145542],
       [  1.16255704,   1.08847602],
       [  2.35537998,   1.26442765],
       [  2.54497356,   2.33036034],
       [  4.16398703,   3.35935826],
       [  4.19791133,   3.81257709],
       [  4.6522782 ,   3.85829137],
       [  4.83735901,   3.99427655],
       [  4.96292333,   4.71795374],
       [  5.65189015,   5.10307142],
       [  5.75416739,   6.04482509],
       [  6.40985136,   6.39549295],
       [  7.24370719,   6.71061287],
       [  7.25471248,   7.65835066],
       [  8.06849262,   7.83926117],
       [  8.6756994 ,   8.892523  ],
       [ 11.8957031 ,   9.88954257],
       [ 12.03415855,  10.10811007],
       [ 12.91862634,  10.68724123],
       [ 14.15213083,  12.69164302],
       [ 15.66748155,  12.73932924],
       [ 23.25826818,  13.61147835],
       [ 29.68281631,  22.75131276]])
# Row-major flattening: the (n, 2) array becomes x0, y0, x1, y1, ...
sorted_loc.flatten(order="C")
array([-20.19666462, -24.74534989, -13.69222509, -17.30384304,
       -12.33393609, -15.45093693, -10.53916806, -15.32671206,
        -9.35126821, -15.22409121,  -8.822563  , -14.52154506,
        -8.73932943, -12.28191137,  -8.36073931, -12.13593063,
        -8.02864509, -11.33781158,  -7.74814991, -11.22964621,
        -7.7033224 , -10.10622249,  -7.42682121,  -8.17208738,
        -6.55308268,  -7.914439  ,  -5.66587855,  -6.38066481,
        -5.282336  ,  -6.34512238,  -4.90897862,  -4.69273309,
        -3.94285144,  -4.53736864,  -3.23423259,  -4.38628677,
        -3.16378394,  -4.35164651,  -2.82444794,  -3.8416044 ,
        -2.72227564,  -3.29531517,  -2.62622257,  -2.25543761,
        -2.40086914,  -2.03813249,  -1.5312655 ,  -1.94330015,
        -1.5306093 ,  -1.57389452,  -1.24658262,  -0.88679573,
        -1.03659721,  -0.83513033,  -0.64599481,  -0.72145542,
         1.16255704,   1.08847602,   2.35537998,   1.26442765,
         2.54497356,   2.33036034,   4.16398703,   3.35935826,
         4.19791133,   3.81257709,   4.6522782 ,   3.85829137,
         4.83735901,   3.99427655,   4.96292333,   4.71795374,
         5.65189015,   5.10307142,   5.75416739,   6.04482509,
         6.40985136,   6.39549295,   7.24370719,   6.71061287,
         7.25471248,   7.65835066,   8.06849262,   7.83926117,
         8.6756994 ,   8.892523  ,  11.8957031 ,   9.88954257,
        12.03415855,  10.10811007,  12.91862634,  10.68724123,
        14.15213083,  12.69164302,  15.66748155,  12.73932924,
        23.25826818,  13.61147835,  29.68281631,  22.75131276])
# One row per day. Days with fewer points (n_min) yield shorter vectors, so
# their trailing columns are filled with NaN.
loc_df = pd.DataFrame(data=data)
# Put the context columns and the coordinate columns side by side.
pd.concat(objs=[df, loc_df], axis="columns")
Year Month Week Day Dayofweek Dayofyear Is_month_end Is_month_start Is_quarter_end Is_quarter_start ... 190 191 192 193 194 195 196 197 198 199
0 2021 1 53 1 4 1 False True False True ... 14.470132 20.842812 14.809092 22.660701 15.209006 22.725063 18.868986 23.626687 29.370445 28.334336
1 2021 1 53 2 5 2 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 2021 1 53 3 6 3 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 2021 1 1 4 0 4 False False False False ... 14.920890 16.903697 15.339421 18.232201 15.420866 18.901594 16.108396 24.803074 17.458791 26.837565
4 2021 1 1 5 1 5 False False False False ... 16.620776 13.192277 17.113658 13.442796 17.145945 13.689042 18.394004 20.106800 23.705008 25.717087
5 2021 1 1 6 2 6 False False False False ... 19.174283 15.499169 19.307015 16.227055 21.029683 17.517485 22.046130 17.754123 26.434196 20.584023
6 2021 1 1 7 3 7 False False False False ... 19.204572 17.585625 21.110851 17.667264 21.812600 19.550941 24.309018 22.560459 25.161768 23.408903
7 2021 1 1 8 4 8 False False False False ... 15.978173 15.930315 16.300432 17.716276 18.003666 20.133823 20.179454 22.199705 20.717632 24.163873
8 2021 1 1 9 5 9 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
9 2021 1 1 10 6 10 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
10 2021 1 2 11 0 11 False False False False ... 18.734432 15.311818 19.974946 16.032423 20.910366 20.280738 23.157091 20.382886 26.204243 21.547111
11 2021 1 2 12 1 12 False False False False ... 14.473795 17.296738 18.006466 19.308644 18.114859 19.789006 19.152658 20.259421 23.773261 24.067777
12 2021 1 2 13 2 13 False False False False ... 16.047620 11.983234 16.662413 14.059299 18.522118 15.099365 19.265778 15.182389 19.778607 18.048300
13 2021 1 2 14 3 14 False False False False ... 20.008937 16.221666 20.111617 16.348063 21.458096 16.614298 29.560323 18.886797 33.268425 24.639888
14 2021 1 2 15 4 15 False False False False ... 15.167943 17.820723 16.650802 19.963430 17.963204 22.391149 18.381884 23.846735 22.882777 24.703841
15 2021 1 2 16 5 16 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
16 2021 1 2 17 6 17 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
17 2021 1 3 18 0 18 False False False False ... 17.136287 17.469488 17.142804 17.516448 18.589927 17.575417 21.233247 19.700874 26.086480 22.692306
18 2021 1 3 19 1 19 False False False False ... 15.395525 14.435958 17.677362 15.293260 20.340947 16.074834 25.272895 18.181873 27.679798 24.392501
19 2021 1 3 20 2 20 False False False False ... 18.327361 20.267441 18.634094 20.791375 20.257293 20.978126 21.129962 22.456371 26.234376 23.942431
20 2021 1 3 21 3 21 False False False False ... 16.626384 16.446960 18.830958 17.035065 20.053474 18.842837 20.223198 24.516268 20.841600 25.171439
21 2021 1 3 22 4 22 False False False False ... 22.943273 16.232633 23.352425 16.451594 23.798254 16.814654 27.866677 18.287373 28.669873 18.924273
22 2021 1 3 23 5 23 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
23 2021 1 3 24 6 24 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
24 2021 1 4 25 0 25 False False False False ... 13.051126 16.176910 14.691632 21.956990 14.696065 22.078566 15.721047 22.890334 24.247104 27.744182
25 2021 1 4 26 1 26 False False False False ... 17.066672 18.071581 19.149347 18.227819 19.395548 18.514316 22.566636 19.413032 23.693322 25.134841
26 2021 1 4 27 2 27 False False False False ... 14.615501 18.174073 14.895239 19.835223 16.661473 20.114736 22.003240 23.647944 23.992727 29.237498
27 2021 1 4 28 3 28 False False False False ... 16.130062 16.489784 16.364160 17.313268 17.327189 17.666102 18.210826 20.392974 20.443050 22.736248
28 2021 1 4 29 4 29 False False False False ... 19.313625 17.858271 20.281820 17.864886 21.246656 20.640586 30.321315 23.640204 30.574485 31.760287
29 2021 1 4 30 5 30 False False False False ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

30 rows × 213 columns