... | @@ -7,7 +7,19 @@ Both examples consider the following combinations of hyperparameters used for Ba |
... | @@ -7,7 +7,19 @@ Both examples consider the following combinations of hyperparameters used for Ba |
|
* K2 metric with GMM and logit nodes;
|
|
* K2 metric with GMM and logit nodes;
|
|
* K2 with initial structure.
|
|
* K2 with initial structure.
|
|
|
|
|
|
All the examples are executed using cross-validation.
|
|
All the examples are executed using cross-validation, the data is preproccesed as follows:
|
|
|
|
|
|
|
|
~~~python
|
|
|
|
data.dropna(inplace=True)
|
|
|
|
data.reset_index(inplace=True, drop=True)
|
|
|
|
|
|
|
|
encoder = preprocessing.LabelEncoder()
|
|
|
|
discretizer = preprocessing.KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='quantile')
|
|
|
|
|
|
|
|
p = pp.Preprocessor([('encoder', encoder), ('discretizer', discretizer)])
|
|
|
|
|
|
|
|
discretized_data, est = p.apply(train)
|
|
|
|
~~~
|
|
|
|
|
|
# Geological data example
|
|
# Geological data example
|
|
|
|
|
... | @@ -19,18 +31,58 @@ The data set contains 9 variables with 442 samples. The target variable for pred |
... | @@ -19,18 +31,58 @@ The data set contains 9 variables with 442 samples. The target variable for pred |
|
|
|
|
|
### K2 metric sampling example
|
|
### K2 metric sampling example
|
|
|
|
|
|
|
|
To sample using K2 metric the following code can be used:
|
|
|
|
|
|
|
|
~~~python
|
|
|
|
train, validation = train_test_split(data, test_size=0.1)
|
|
|
|
bn = Nets.HybridBN(has_logit=False, use_mixture=False)
|
|
|
|
bn.add_nodes(info)
|
|
|
|
bn.add_edges(discretized_data, scoring_function=('K2',K2Score))
|
|
|
|
bn.fit_parameters(train)
|
|
|
|
# prediction
|
|
|
|
val_pred = bn.predict(validation.iloc[:,:8], 5)
|
|
|
|
# sampling
|
|
|
|
sample = bn.sample(5000, parall_count=5)
|
|
|
|
~~~
|
|
|
|
|
|
<img width="353" alt="K2 geo" src="https://user-images.githubusercontent.com/86363785/188191005-45898257-ff57-4a5c-ba6c-1012c28e689e.png">
|
|
<img width="353" alt="K2 geo" src="https://user-images.githubusercontent.com/86363785/188191005-45898257-ff57-4a5c-ba6c-1012c28e689e.png">
|
|
|
|
|
|
![k2](https://user-images.githubusercontent.com/86363785/188129119-dfa62b6d-b1fd-4e63-aa75-fb7aafba95a1.png)
|
|
![k2](https://user-images.githubusercontent.com/86363785/188129119-dfa62b6d-b1fd-4e63-aa75-fb7aafba95a1.png)
|
|
|
|
|
|
### Sampling with K2 + GMM example
|
|
### Sampling with K2 + GMM example
|
|
|
|
|
|
|
|
To sample using K2 with GMM the following code can be used:
|
|
|
|
|
|
|
|
~~~python
|
|
|
|
train, validation = train_test_split(data, test_size=0.1)
|
|
|
|
bn = Nets.HybridBN(has_logit=False, use_mixture=True)
|
|
|
|
bn.add_nodes(info)
|
|
|
|
bn.add_edges(discretized_data, scoring_function=('K2',K2Score))
|
|
|
|
bn.fit_parameters(train)
|
|
|
|
# prediction
|
|
|
|
val_pred = bn.predict(validation.iloc[:,:8], 5)
|
|
|
|
# sampling
|
|
|
|
sample = bn.sample(5000, parall_count=5)
|
|
|
|
~~~
|
|
|
|
|
|
<img width="566" alt="K2 + GMM geo" src="https://user-images.githubusercontent.com/86363785/188191226-4c7c3e7a-d91a-4c43-81b5-39698c87c0b8.png">
|
|
<img width="566" alt="K2 + GMM geo" src="https://user-images.githubusercontent.com/86363785/188191226-4c7c3e7a-d91a-4c43-81b5-39698c87c0b8.png">
|
|
|
|
|
|
![geo_k2_gmm](https://user-images.githubusercontent.com/86363785/188129748-ce239eb4-bbab-43f0-9d80-c92483f27613.png)
|
|
![geo_k2_gmm](https://user-images.githubusercontent.com/86363785/188129748-ce239eb4-bbab-43f0-9d80-c92483f27613.png)
|
|
|
|
|
|
### Sampling with K2 + GMM + logit nodes example
|
|
### Sampling with K2 + GMM + logit nodes example
|
|
|
|
|
|
|
|
~~~python
|
|
|
|
train, validation = train_test_split(data, test_size=0.1)
|
|
|
|
bn = Nets.HybridBN(has_logit=True, use_mixture=False)
|
|
|
|
bn.add_nodes(info)
|
|
|
|
bn.add_edges(discretized_data, scoring_function=('K2',K2Score))
|
|
|
|
bn.fit_parameters(train)
|
|
|
|
# prediction
|
|
|
|
val_pred = bn.predict(validation.iloc[:,:8], 5)
|
|
|
|
# sampling
|
|
|
|
sample = bn.sample(5000, parall_count=5)
|
|
|
|
~~~
|
|
|
|
|
|
<img width="637" alt="K2 + gmm + logit geo" src="https://user-images.githubusercontent.com/86363785/188191300-759aa54d-3d2e-4b50-80d5-7fa7347e557f.png">
|
|
<img width="637" alt="K2 + gmm + logit geo" src="https://user-images.githubusercontent.com/86363785/188191300-759aa54d-3d2e-4b50-80d5-7fa7347e557f.png">
|
|
|
|
|
|
![geo_k2_gmm_logit](https://user-images.githubusercontent.com/86363785/188129774-a3695199-776d-493f-8a9c-bf78125f03fb.png)
|
|
![geo_k2_gmm_logit](https://user-images.githubusercontent.com/86363785/188129774-a3695199-776d-493f-8a9c-bf78125f03fb.png)
|
... | | ... | |