Skip to content

Commit

Permalink
Add first functions, unit tests, Rational, equality operator
Browse files Browse the repository at this point in the history
  • Loading branch information
rikardn committed Sep 11, 2021
1 parent fa49314 commit 2444b94
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build
96 changes: 94 additions & 2 deletions src/symengine.f90
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module symengine

use iso_c_binding, only: c_size_t, c_long, c_char, c_ptr, c_null_ptr, c_null_char, c_f_pointer, c_associated
use iso_c_binding, only: c_size_t, c_int, c_long, c_char, c_ptr, c_null_ptr, c_null_char, c_f_pointer, c_associated
implicit none

interface
Expand Down Expand Up @@ -62,6 +62,26 @@ function c_basic_pow(s, a, b) bind(c, name='basic_pow')
type(c_ptr), value :: s, a, b
integer(c_long) :: c_basic_pow
end function
function c_basic_eq(a, b) bind(c, name='basic_eq')
import :: c_int, c_ptr
type(c_ptr), value :: a, b
integer(c_int) :: c_basic_eq
end function
function c_basic_neq(a, b) bind(c, name='basic_neq')
import :: c_int, c_ptr
type(c_ptr), value :: a, b
integer(c_int) :: c_basic_neq
end function
function c_basic_sin(s, a) bind(c, name='basic_sin')
import :: c_long, c_ptr
type(c_ptr), value :: s, a
integer(c_long) :: c_basic_sin
end function
function c_basic_cos(s, a) bind(c, name='basic_cos')
import :: c_long, c_ptr
type(c_ptr), value :: s, a
integer(c_long) :: c_basic_cos
end function
function c_integer_set_si(s, i) bind(c, name='integer_set_si')
import :: c_long, c_ptr
type(c_ptr), value :: s
Expand All @@ -73,6 +93,12 @@ function c_integer_get_si(s) bind(c, name='integer_get_si')
type(c_ptr), value :: s
integer(c_long) :: c_integer_get_si
end function
function c_rational_set_si(s, a, b) bind(c, name='rational_set_si')
import :: c_long, c_ptr
type(c_ptr), value :: s
integer(c_long), value :: a, b
integer(c_long) :: c_rational_set_si
end function
function c_symbol_set(s, c) bind(c, name='symbol_set')
import c_long, c_ptr, c_char
type(c_ptr), value :: s
Expand All @@ -86,20 +112,30 @@ function c_symbol_set(s, c) bind(c, name='symbol_set')
type(c_ptr) :: ptr = c_null_ptr
logical :: tmp = .false.
contains
procedure :: str, basic_assign, basic_add, basic_sub, basic_mul, basic_div, basic_pow
procedure :: str, basic_assign, basic_add, basic_sub, basic_mul, basic_div, basic_pow, basic_eq, basic_neq
generic :: assignment(=) => basic_assign
generic :: operator(+) => basic_add
generic :: operator(-) => basic_sub
generic :: operator(*) => basic_mul
generic :: operator(/) => basic_div
generic :: operator(**) => basic_pow
generic :: operator(==) => basic_eq
generic :: operator(/=) => basic_neq
final :: basic_free
end type

interface Basic
module procedure basic_new
end interface

interface sin
module procedure basic_sin
end interface

interface cos
module procedure basic_cos
end interface

type, extends(Basic) :: SymInteger
contains
procedure :: get
Expand All @@ -109,13 +145,23 @@ function c_symbol_set(s, c) bind(c, name='symbol_set')
module procedure integer_new
end interface

type, extends(Basic) :: Rational
end type Rational

interface Rational
module procedure rational_new
end interface

type, extends(Basic) :: Symbol
end type Symbol

interface Symbol
module procedure symbol_new
end interface

private
public :: Basic, SymInteger, Rational, Symbol, parse, sin, cos


contains

Expand Down Expand Up @@ -202,6 +248,40 @@ function basic_pow(a, b)
basic_pow%tmp = .true.
end function

function basic_eq(a, b)
class(basic), intent(in) :: a, b
logical :: basic_eq
integer(c_int) :: dummy
dummy = c_basic_eq(a%ptr, b%ptr)
basic_eq = (dummy /= 0)
end function

function basic_neq(a, b)
class(basic), intent(in) :: a, b
logical :: basic_neq
integer(c_int) :: dummy
dummy = c_basic_neq(a%ptr, b%ptr)
basic_neq = (dummy /= 0)
end function

function basic_sin(a)
class(basic), intent(in) :: a
type(basic) :: basic_sin
integer(c_long) :: dummy
basic_sin = Basic()
dummy = c_basic_sin(basic_sin%ptr, a%ptr)
basic_sin%tmp = .true.
end function

function basic_cos(a)
class(basic), intent(in) :: a
type(basic) :: basic_cos
integer(c_long) :: dummy
basic_cos = Basic()
dummy = c_basic_cos(basic_cos%ptr, a%ptr)
basic_cos%tmp = .true.
end function

function integer_new(i)
integer :: i
integer(c_long) :: j
Expand All @@ -219,6 +299,18 @@ function get(this) result(i)
i = int(c_integer_get_si(this%ptr))
end function

function rational_new(a, b)
integer :: a, b
integer(c_long) :: x, y
integer(c_long) :: dummy
type(Rational) :: rational_new
x = int(a)
y = int(b)
rational_new%ptr = c_basic_new_heap()
dummy = c_rational_set_si(rational_new%ptr, x, y)
rational_new%tmp = .true.
end function

function symbol_new(c)
character(len=*) :: c
character(len=len_trim(c) + 1) :: new_c
Expand Down
38 changes: 34 additions & 4 deletions src/tests/test_basic.f90
Original file line number Diff line number Diff line change
@@ -1,13 +1,43 @@
subroutine assert_eq(a, b)
use symengine
type(Basic) :: a, b
if (a /= b) then
stop 1
end if
end subroutine


subroutine dostuff()
use symengine
type(Basic) :: a, b, c

a = SymInteger(12)
b = Symbol('x')
c = a * b
print *, c%str()
c = parse('2*(24+x)')
print *, c%str()
c = parse('x * 12')
call assert_eq(a * b, c)

c = parse('x + 12')
call assert_eq(a + b, c)

c = parse('12 - x')
call assert_eq(a - b, c)

c = parse('12 / x')
call assert_eq(a / b, c)

c = parse('12 ** x')
call assert_eq(a ** b, c)

c = parse('sin(x)')
call assert_eq(sin(b), c)

c = parse('cos(x)')
call assert_eq(cos(b), c)

a = Rational(1, 2)
b = Rational(3, 4)
c = Rational(3, 8)
call assert_eq(a * b, c)
end subroutine


Expand Down

0 comments on commit 2444b94

Please sign in to comment.